Remove size parameter for GATv2 and HEAT (#3744)
* refactor heat_conv and test

* refactor gatv2_conv and test
saiden89 committed Dec 22, 2021
1 parent 95ef04f · commit d47d9cd
Showing 4 changed files with 20 additions and 25 deletions.
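The net effect for callers: forward() no longer accepts a size argument, and node counts are inferred from the feature tensors instead. A minimal before/after sketch of the GATv2 case (the edge_index values are illustrative, not taken from the diff):

import torch
from torch_geometric.nn import GATv2Conv

conv = GATv2Conv(8, 32, heads=2)
x = torch.randn(4, 8)  # 4 nodes, 8 features each
edge_index = torch.tensor([[0, 1, 2, 3],  # hypothetical edges
                           [0, 0, 1, 1]])

# Before this commit, conv(x, edge_index, size=(4, 4)) was also accepted.
# After it, the size keyword is gone and shapes come from x itself.
out = conv(x, edge_index)  # shape [4, 64], i.e. 32 channels * 2 heads
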
test/nn/conv/test_gatv2_conv.py (8 additions, 10 deletions)
@@ -14,15 +14,14 @@ def test_gatv2_conv():
     assert conv.__repr__() == 'GATv2Conv(8, 32, heads=2)'
     out = conv(x1, edge_index)
     assert out.size() == (4, 64)
-    assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out)
+    assert torch.allclose(conv(x1, edge_index), out)
     assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6)

-    t = '(Tensor, Tensor, OptTensor, Size, NoneType) -> Tensor'
+    t = '(Tensor, Tensor, OptTensor, NoneType) -> Tensor'
     jit = torch.jit.script(conv.jittable(t))
     assert torch.allclose(jit(x1, edge_index), out)
-    assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out)

-    t = '(Tensor, SparseTensor, OptTensor, Size, NoneType) -> Tensor'
+    t = '(Tensor, SparseTensor, OptTensor, NoneType) -> Tensor'
     jit = torch.jit.script(conv.jittable(t))
     assert torch.allclose(jit(x1, adj.t()), out, atol=1e-6)
@@ -39,7 +38,7 @@ def test_gatv2_conv():
     assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 7
     assert conv._alpha is None

-    t = ('(Tensor, Tensor, OptTensor, Size, bool) -> '
+    t = ('(Tensor, Tensor, OptTensor, bool) -> '
          'Tuple[Tensor, Tuple[Tensor, Tensor]]')
     jit = torch.jit.script(conv.jittable(t))
     result = jit(x1, edge_index, return_attention_weights=True)
@@ -49,7 +48,7 @@ def test_gatv2_conv():
     assert result[1][1].min() >= 0 and result[1][1].max() <= 1
     assert conv._alpha is None

-    t = ('(Tensor, SparseTensor, OptTensor, Size, bool) -> '
+    t = ('(Tensor, SparseTensor, OptTensor, bool) -> '
         'Tuple[Tensor, SparseTensor]')
     jit = torch.jit.script(conv.jittable(t))
     result = jit(x1, adj.t(), return_attention_weights=True)
@@ -60,15 +59,14 @@ def test_gatv2_conv():
     adj = adj.sparse_resize((4, 2))
     out1 = conv((x1, x2), edge_index)
     assert out1.size() == (2, 64)
-    assert torch.allclose(conv((x1, x2), edge_index, size=(4, 2)), out1)
+    assert torch.allclose(conv((x1, x2), edge_index), out1)
     assert torch.allclose(conv((x1, x2), adj.t()), out1, atol=1e-6)

-    t = '(OptPairTensor, Tensor, OptTensor, Size, NoneType) -> Tensor'
+    t = '(OptPairTensor, Tensor, OptTensor, NoneType) -> Tensor'
     jit = torch.jit.script(conv.jittable(t))
     assert torch.allclose(jit((x1, x2), edge_index), out1)
-    assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out1)

-    t = '(OptPairTensor, SparseTensor, OptTensor, Size, NoneType) -> Tensor'
+    t = '(OptPairTensor, SparseTensor, OptTensor, NoneType) -> Tensor'
     jit = torch.jit.script(conv.jittable(t))
     assert torch.allclose(jit((x1, x2), adj.t()), out1, atol=1e-6)
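For context, the type strings in these tests select one of the # type: overload comments declared on forward() (visible in the gatv2_conv.py diff below), so dropping Size here must match the slimmed-down declarations exactly. A sketch of the pattern, reusing the conv, x1, and edge_index from the test above:

t = '(Tensor, Tensor, OptTensor, NoneType) -> Tensor'
jit = torch.jit.script(conv.jittable(t))  # scripted copy bound to that overload
assert torch.allclose(jit(x1, edge_index), conv(x1, edge_index))
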
test/nn/conv/test_heat_conv.py (2 additions, 2 deletions)
@@ -20,7 +20,7 @@ def test_heat_conv():
     assert out.size() == (4, 32)
     assert torch.allclose(conv(x, adj.t(), node_type, edge_type), out)

-    t = '(Tensor, Tensor, Tensor, Tensor, OptTensor, Size) -> Tensor'
+    t = '(Tensor, Tensor, Tensor, Tensor, OptTensor) -> Tensor'
     jit = torch.jit.script(conv.jittable(t))
     assert torch.allclose(jit(x, edge_index, node_type, edge_type, edge_attr),
                           out)
@@ -33,6 +33,6 @@ def test_heat_conv():
     assert out.size() == (4, 16)
     assert torch.allclose(conv(x, adj.t(), node_type, edge_type), out)

-    t = '(Tensor, SparseTensor, Tensor, Tensor, OptTensor, Size) -> Tensor'
+    t = '(Tensor, SparseTensor, Tensor, Tensor, OptTensor) -> Tensor'
     jit = torch.jit.script(conv.jittable(t))
     assert torch.allclose(jit(x, adj.t(), node_type, edge_type), out)
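HEAT's calling convention changes the same way. A hedged sketch of the new call (the constructor arguments and tensor values are illustrative guesses consistent with the shapes asserted above, not read from the diff):

import torch
from torch_geometric.nn import HEATConv

conv = HEATConv(in_channels=8, out_channels=16, num_node_types=3,
                num_edge_types=3, edge_type_emb_dim=5, edge_dim=2,
                edge_attr_emb_dim=6, heads=2)
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
node_type = torch.tensor([0, 0, 1, 2])
edge_type = torch.tensor([0, 2, 1, 2])
edge_attr = torch.randn(4, 2)

# size is no longer a keyword here either:
out = conv(x, edge_index, node_type, edge_type, edge_attr)  # shape [4, 32]
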
torch_geometric/nn/conv/gatv2_conv.py (7 additions, 9 deletions)
@@ -1,5 +1,5 @@
 from typing import Union, Tuple, Optional
-from torch_geometric.typing import (Adj, Size, OptTensor, PairTensor)
+from torch_geometric.typing import (Adj, OptTensor, PairTensor)

 import torch
 from torch import Tensor
@@ -163,12 +163,12 @@ def reset_parameters(self):
         zeros(self.bias)

     def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
-                edge_attr: OptTensor = None, size: Size = None,
+                edge_attr: OptTensor = None,
                 return_attention_weights: bool = None):
-        # type: (Union[Tensor, PairTensor], Tensor, OptTensor, Size, NoneType) -> Tensor  # noqa
-        # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, Size, NoneType) -> Tensor  # noqa
-        # type: (Union[Tensor, PairTensor], Tensor, OptTensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
-        # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
+        # type: (Union[Tensor, PairTensor], Tensor, OptTensor, NoneType) -> Tensor  # noqa
+        # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, NoneType) -> Tensor  # noqa
+        # type: (Union[Tensor, PairTensor], Tensor, OptTensor, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
+        # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, bool) -> Tuple[Tensor, SparseTensor]  # noqa
         r"""
         Args:
             return_attention_weights (bool, optional): If set to :obj:`True`,
@@ -202,8 +202,6 @@ def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
                 num_nodes = x_l.size(0)
                 if x_r is not None:
                     num_nodes = min(num_nodes, x_r.size(0))
-                if size is not None:
-                    num_nodes = min(size[0], size[1])
                 edge_index, edge_attr = remove_self_loops(
                     edge_index, edge_attr)
                 edge_index, edge_attr = add_self_loops(
@@ -220,7 +218,7 @@ def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,

         # propagate_type: (x: PairTensor, edge_attr: OptTensor)
         out = self.propagate(edge_index, x=(x_l, x_r), edge_attr=edge_attr,
-                             size=size)
+                             size=None)

         alpha = self._alpha
         self._alpha = None
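Passing size=None means MessagePassing.propagate() falls back to inferring the (source, target) node counts from the tensors it receives. A simplified sketch of that inference under the bipartite/monopartite split used above; this is an approximation, not the library's exact code:

from typing import Optional, Tuple
from torch import Tensor

def infer_size(x_l: Tensor, x_r: Optional[Tensor]) -> Tuple[int, int]:
    # Source count comes from the source features; the target count comes
    # from the target features when present (bipartite case), otherwise
    # the graph is assumed to be square.
    num_src = x_l.size(0)
    num_dst = x_r.size(0) if x_r is not None else num_src
    return num_src, num_dst
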
torch_geometric/nn/conv/heat_conv.py (3 additions, 4 deletions)
@@ -1,5 +1,5 @@
 from typing import Optional
-from torch_geometric.typing import Adj, Size, OptTensor
+from torch_geometric.typing import Adj, OptTensor

 import torch
 from torch import Tensor
@@ -89,8 +89,7 @@ def reset_parameters(self):
         self.lin.reset_parameters()

     def forward(self, x: Tensor, edge_index: Adj, node_type: Tensor,
-                edge_type: Tensor, edge_attr: OptTensor = None,
-                size: Size = None) -> Tensor:
+                edge_type: Tensor, edge_attr: OptTensor = None) -> Tensor:
         """"""
         x = self.hetero_lin(x, node_type)

@@ -99,7 +98,7 @@ def forward(self, x: Tensor, edge_index: Adj, node_type: Tensor,

         # propagate_type: (x: Tensor, edge_type_emb: Tensor, edge_attr: OptTensor)  # noqa
         out = self.propagate(edge_index, x=x, edge_type_emb=edge_type_emb,
-                             edge_attr=edge_attr, size=size)
+                             edge_attr=edge_attr, size=None)

         if self.concat:
             if self.root_weight: