test coverage
rusty1s committed Jul 4, 2019
1 parent 2bd6dff commit a076bc2
Showing 6 changed files with 17 additions and 23 deletions.
2 changes: 2 additions & 0 deletions test/nn/conv/test_gat_conv.py
@@ -11,6 +11,8 @@ def test_gat_conv():
    conv = GATConv(in_channels, out_channels, heads=2, dropout=0.5)
    assert conv.__repr__() == 'GATConv(16, 32, heads=2)'
    assert conv(x, edge_index).size() == (num_nodes, 2 * out_channels)
    assert conv((x, None), edge_index).size() == (num_nodes, 2 * out_channels)

    conv = GATConv(in_channels, out_channels, heads=2, concat=False)
    assert conv(x, edge_index).size() == (num_nodes, out_channels)
    assert conv((x, None), edge_index).size() == (num_nodes, out_channels)
1 change: 1 addition & 0 deletions test/nn/conv/test_sage_conv.py
@@ -11,3 +11,4 @@ def test_sage_conv():
    conv = SAGEConv(in_channels, out_channels)
    assert conv.__repr__() == 'SAGEConv(16, 32)'
    assert conv(x, edge_index).size() == (num_nodes, out_channels)
    assert conv((x, None), edge_index).size() == (num_nodes, out_channels)
22 changes: 6 additions & 16 deletions test/nn/pool/test_sag_pool.py
@@ -11,19 +11,9 @@ def test_sag_pooling():
    num_nodes = edge_index.max().item() + 1
    x = torch.randn((num_nodes, in_channels))

    pool = SAGPooling(in_channels, ratio=0.5, gnn='GCN')
    assert pool.__repr__() == 'SAGPooling(GCN, 16, ratio=0.5)'

    out = pool(x, edge_index)
    assert out[0].size() == (num_nodes // 2, in_channels)
    assert out[1].size() == (2, 2)

    pool = SAGPooling(in_channels, ratio=0.5, gnn='GAT')
    out = pool(x, edge_index)
    assert out[0].size() == (num_nodes // 2, in_channels)
    assert out[1].size() == (2, 2)

    pool = SAGPooling(in_channels, ratio=0.5, gnn='SAGE')
    out = pool(x, edge_index)
    assert out[0].size() == (num_nodes // 2, in_channels)
    assert out[1].size() == (2, 2)
    for gnn in ['GraphConv', 'GCN', 'GAT', 'SAGE']:
        pool = SAGPooling(in_channels, ratio=0.5, gnn=gnn)
        assert pool.__repr__() == 'SAGPooling({}, 16, ratio=0.5)'.format(gnn)
        out = pool(x, edge_index)
        assert out[0].size() == (num_nodes // 2, in_channels)
        assert out[1].size() == (2, 2)
2 changes: 1 addition & 1 deletion torch_geometric/nn/conv/gat_conv.py
@@ -78,7 +78,7 @@ def reset_parameters(self):

    def forward(self, x, edge_index, size=None):
        """"""
        if size is None:
        if size is None and torch.is_tensor(x):
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

2 changes: 1 addition & 1 deletion torch_geometric/nn/conv/sage_conv.py
@@ -51,7 +51,7 @@ def reset_parameters(self):

    def forward(self, x, edge_index, size=None):
        """"""
        if size is None:
        if size is None and torch.is_tensor(x):
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

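The torch.is_tensor(x) guards above restrict the self-loop insertion to the case where the node features are a single tensor; when a pair such as (x, None) is passed, as the updated tests now exercise, that branch is skipped. A minimal usage sketch of the two call styles, assuming a toy 4-node graph in the spirit of the tests (the edge_index, channel sizes, and variable names here are illustrative, not taken from the commit):

import torch
from torch_geometric.nn import GATConv, SAGEConv

# Small bidirectional graph with 4 nodes (illustrative).
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
num_nodes = edge_index.max().item() + 1
x = torch.randn((num_nodes, 16))

conv = SAGEConv(16, 32)
out = conv(x, edge_index)               # tensor input: self-loops are added internally
out_pair = conv((x, None), edge_index)  # pair input: the self-loop branch is skipped
assert out.size() == (num_nodes, 32)
assert out_pair.size() == (num_nodes, 32)

conv = GATConv(16, 32, heads=2)
assert conv((x, None), edge_index).size() == (num_nodes, 2 * 32)

Both call styles return tensors of the same shape; only the internal handling of self-loops differs.
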
11 changes: 6 additions & 5 deletions torch_geometric/nn/pool/sag_pool.py
@@ -28,19 +28,20 @@ class SAGPooling(torch.nn.Module):
            (default: :obj:`0.5`)
        gnn (string, optional): Specifies which graph neural network layer to
            use for calculating projection scores (one of
            :obj:`"GCN"`, :obj:`"GAT"` or :obj:`"SAGE"`). (default: :obj:`GCN`)
            :obj:`"GraphConv"`, :obj:`"GCN"`, :obj:`"GAT"` or :obj:`"SAGE"`).
            (default: :obj:`GraphConv`)
        **kwargs (optional): Additional parameters for initializing the graph
            neural network layer.
    """

    def __init__(self, in_channels, ratio=0.5, gnn='gConv', **kwargs):
    def __init__(self, in_channels, ratio=0.5, gnn='GraphConv', **kwargs):
        super(SAGPooling, self).__init__()

        self.in_channels = in_channels
        self.ratio = ratio
        self.gnn_name = gnn

        assert gnn in ['GCN', 'GAT', 'SAGE', 'gConv']
        assert gnn in ['GraphConv', 'GCN', 'GAT', 'SAGE']
        if gnn == 'GCN':
            self.gnn = GCNConv(self.in_channels, 1, **kwargs)
        elif gnn == 'GAT':
@@ -66,8 +67,8 @@ def forward(self, x, edge_index, edge_attr=None, batch=None):
        perm = topk(score, self.ratio, batch)
        x = x[perm] * score[perm].view(-1, 1)
        batch = batch[perm]
        edge_index, edge_attr = filter_adj(
            edge_index, edge_attr, perm, num_nodes=score.size(0))
        edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm,
                                           num_nodes=score.size(0))

        return x, edge_index, edge_attr, batch, perm

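With the projection-score option now spelled 'GraphConv' (and used as the default), a short sketch of constructing SAGPooling with each supported gnn string, mirroring the loop added to the test above (graph and sizes are illustrative, not from the commit):

import torch
from torch_geometric.nn import SAGPooling

# Small bidirectional graph with 4 nodes (illustrative).
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
x = torch.randn((4, 16))

for gnn in ['GraphConv', 'GCN', 'GAT', 'SAGE']:
    pool = SAGPooling(16, ratio=0.5, gnn=gnn)
    x_out, edge_index_out, edge_attr_out, batch_out, perm = pool(x, edge_index)
    # With ratio=0.5, roughly half of the 4 nodes are kept.
    assert x_out.size() == (2, 16)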
