Skip to content

Commit

Permalink
normalize arg
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed May 25, 2018
1 parent 50011f1 commit 80a4551
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 17 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ out = SplineConv.apply(src,
kernel_size,
is_open_spline,
degree=1,
norm=True,
root_weight=None,
bias=None)
```
Expand All @@ -66,6 +67,7 @@ The kernel function is defined over the weighted B-spline tensor product basis,
* **kernel_size** *(LongTensor)* - Number of trainable weight parameters in each edge dimension.
* **is_open_spline** *(ByteTensor)* - Whether to use open or closed B-spline bases for each dimension.
* **degree** *(int, optional)* - B-spline basis degree. (default: `1`)
* **norm** *(bool, optional)*: Whether to normalize output by node degree. (default: `True`)
* **root_weight** *(Tensor, optional)* - Additional shared trainable parameters for each feature of the root node of shape `(in_channels x out_channels)`. (default: `None`)
* **bias** *(Tensor, optional)* - Optional bias of shape `(out_channels)`. (default: `None`)

Expand All @@ -86,11 +88,12 @@ weight = torch.rand((25, 2, 4), dtype=torch.float) # 25 parameters for in_chann
kernel_size = torch.tensor([5, 5]) # 5 parameters in each edge dimension
is_open_spline = torch.tensor([1, 1], dtype=torch.uint8) # only use open B-splines
degree = 1 # B-spline degree of 1
norm = True # Normalize output by node degree.
root_weight = torch.rand((2, 4), dtype=torch.float) # separately weight root nodes
bias = None # do not apply an additional bias

out = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size,
is_open_spline, degree, root_weight, bias)
is_open_spline, degree, norm root_weight, bias)

print(out.size())
torch.Size([4, 4]) # 4 nodes with 4 features each
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from setuptools import setup, find_packages

__version__ = '1.0.2'
__version__ = '1.0.3'
url = 'https://github.com/rusty1s/pytorch_spline_conv'

install_requires = ['cffi']
Expand Down
17 changes: 8 additions & 9 deletions test/test_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@
'root_weight': [[12.5], [13]],
'bias': [1],
'expected': [
1 + (12.5 * 9 + 13 * 10 + 8.5 + 40.5 + 107.5 + 101.5) / 5,
1 + 12.5 * 1 + 13 * 2,
1 + 12.5 * 3 + 13 * 4,
1 + 12.5 * 5 + 13 * 6,
1 + 12.5 * 7 + 13 * 8,
[1 + 12.5 * 9 + 13 * 10 + (8.5 + 40.5 + 107.5 + 101.5) / 4],
[1 + 12.5 * 1 + 13 * 2],
[1 + 12.5 * 3 + 13 * 4],
[1 + 12.5 * 5 + 13 * 6],
[1 + 12.5 * 7 + 13 * 8],
]
}]

Expand All @@ -52,9 +52,8 @@ def test_spline_conv_forward(test, dtype, device):
bias = tensor(test['bias'], dtype, device)

out = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size,
is_open_spline, 1, root_weight, bias)
assert list(out.size()) == [5, 1]
assert pytest.approx(out.view(-1).tolist()) == test['expected']
is_open_spline, 1, True, root_weight, bias)
assert out.tolist() == test['expected']


@pytest.mark.parametrize('degree,device', product(degrees.keys(), devices))
Expand All @@ -74,5 +73,5 @@ def test_spline_basis_backward(degree, device):
bias.requires_grad_()

data = (src, edge_index, pseudo, weight, kernel_size, is_open_spline,
degree, root_weight, bias)
degree, True, root_weight, bias)
assert gradcheck(SplineConv.apply, data, eps=1e-6, atol=1e-4) is True
2 changes: 1 addition & 1 deletion torch_spline_conv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from .weighting import SplineWeighting
from .conv import SplineConv

__version__ = '1.0.2'
__version__ = '1.0.3'

__all__ = ['SplineBasis', 'SplineWeighting', 'SplineConv', '__version__']
12 changes: 7 additions & 5 deletions torch_spline_conv/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class SplineConv(object):
is_open_spline (:class:`ByteTensor`): Whether to use open or closed
B-spline bases for each dimension.
degree (int, optional): B-spline basis degree. (default: :obj:`1`)
norm (bool, optional): Whether to normalize output by node degree.
(default: :obj:`True`)
root_weight (:class:`Tensor`, optional): Additional shared trainable
parameters for each feature of the root node of shape
(in_channels x out_channels). (default: :obj:`None`)
Expand All @@ -45,6 +47,7 @@ def apply(src,
kernel_size,
is_open_spline,
degree=1,
norm=True,
root_weight=None,
bias=None):

Expand All @@ -62,15 +65,14 @@ def apply(src,
row_expand = row.unsqueeze(-1).expand_as(out)
out = src.new_zeros((n, m_out)).scatter_add_(0, row_expand, out)

deg = node_degree(row, n, out.dtype, out.device)
# Normalize out by node degree (if wished).
if norm:
deg = node_degree(row, n, out.dtype, out.device)
out = out / deg.unsqueeze(-1).clamp(min=1)

# Weight root node separately (if wished).
if root_weight is not None:
out += torch.mm(src, root_weight)
deg += 1

# Normalize out by node degree.
out /= deg.unsqueeze(-1).clamp(min=1)

# Add bias (if wished).
if bias is not None:
Expand Down

0 comments on commit 80a4551

Please sign in to comment.