Commit

update __repr__
rusty1s committed Jul 6, 2019
1 parent ad11b3c commit 64c6de3
Showing 4 changed files with 17 additions and 30 deletions.
test/nn/pool/test_sag_pool.py (5 changes: 2 additions & 3 deletions)
@@ -13,9 +13,8 @@ def test_sag_pooling():

for gnn in ['GraphConv', 'GCN', 'GAT', 'SAGE']:
pool = SAGPooling(in_channels, ratio=0.5, gnn=gnn)
- assert pool.__repr__() == 'SAGPooling({}, 16, ratio=0.5, ' \
- 'min_score=None, ' \
- 'multiplier=None)'.format(gnn)
+ assert pool.__repr__() == ('SAGPooling({}, 16, ratio=0.5, '
+ 'multiplier=1)').format(gnn)
out = pool(x, edge_index)
assert out[0].size() == (num_nodes // 2, in_channels)
assert out[1].size() == (2, 2)
test/nn/pool/test_topk_pool.py (3 changes: 1 addition & 2 deletions)
@@ -34,8 +34,7 @@ def test_topk_pooling():
x = torch.randn((num_nodes, in_channels))

pool = TopKPooling(in_channels, ratio=0.5)
- assert pool.__repr__() == 'TopKPooling(16, ratio=0.5, ' \
- 'min_score=None, multiplier=None)'
+ assert pool.__repr__() == 'TopKPooling(16, ratio=0.5, multiplier=1)'

x, edge_index, _, _, _, _ = pool(x, edge_index)
assert x.size() == (num_nodes // 2, in_channels)
torch_geometric/nn/pool/sag_pool.py (20 changes: 7 additions & 13 deletions)
@@ -1,7 +1,7 @@
import torch
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, GraphConv
from torch_geometric.nn.pool.topk_pool import topk, filter_adj
- from ...utils import softmax
+ from torch_geometric.utils import softmax


class SAGPooling(torch.nn.Module):
@@ -84,7 +84,7 @@ def reset_parameters(self):
self.gnn.reset_parameters()

def forward(self, x, edge_index, edge_attr=None, batch=None,
- attn_input=None): # attn_input can differ from x in general
+ attn_input=None): # `attn_input` can differ from x in general.
""""""
if batch is None:
batch = edge_index.new_zeros(x.size(0))
@@ -113,14 +113,8 @@ def forward(self, x, edge_index, edge_attr=None, batch=None,
return x, edge_index, edge_attr, batch, perm, score[perm]

def __repr__(self):
- return '{}({}, {}, ' \
- 'ratio={}{}, ' \
- 'min_score={}, ' \
- 'multiplier={})'.format(self.__class__.__name__,
- self.gnn_name,
- self.in_channels,
- self.ratio,
- '(ignored)' if self.
- min_score is not None else '',
- self.min_score,
- self.multiplier)
+ return '{}({}, {}, {}={}, multiplier={})'.format(
+ self.__class__.__name__, self.gnn_name, self.in_channels,
+ 'ratio' if self.min_score is None else 'min_score',
+ self.ratio if self.min_score is None else self.min_score,
+ self.multiplier)
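
Note (not part of the commit): a minimal sketch of the selection logic introduced by the new __repr__ above, which prints either ratio or min_score (never both) followed by multiplier. The class _ReprDemo and its default values are illustrative assumptions, mirroring the attribute names used in the diff.

# Illustrative sketch only: mirrors the new __repr__ logic of SAGPooling,
# assuming attributes `gnn_name`, `in_channels`, `ratio`, `min_score` and
# `multiplier` are set in `__init__` (names taken from the diff above).
class _ReprDemo(object):
    def __init__(self, gnn_name='GraphConv', in_channels=16, ratio=0.5,
                 min_score=None, multiplier=1):
        self.gnn_name = gnn_name
        self.in_channels = in_channels
        self.ratio = ratio
        self.min_score = min_score
        self.multiplier = multiplier

    def __repr__(self):
        # `ratio` is reported only when `min_score` is unset; otherwise
        # `min_score` takes precedence and `ratio` is omitted.
        return '{}({}, {}, {}={}, multiplier={})'.format(
            'SAGPooling', self.gnn_name, self.in_channels,
            'ratio' if self.min_score is None else 'min_score',
            self.ratio if self.min_score is None else self.min_score,
            self.multiplier)

print(_ReprDemo())               # SAGPooling(GraphConv, 16, ratio=0.5, multiplier=1)
print(_ReprDemo(min_score=0.1))  # SAGPooling(GraphConv, 16, min_score=0.1, multiplier=1)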
torch_geometric/nn/pool/topk_pool.py (19 changes: 7 additions & 12 deletions)
@@ -1,10 +1,10 @@
import torch
from torch.nn import Parameter
from torch_scatter import scatter_add, scatter_max
+ from torch_geometric.utils import softmax

from ..inits import uniform
from ...utils.num_nodes import maybe_num_nodes
- from ...utils import softmax


def topk(x, ratio, batch, min_score=None, tol=1e-7):
@@ -131,7 +131,7 @@ def reset_parameters(self):
uniform(size, self.weight)

def forward(self, x, edge_index, edge_attr=None, batch=None,
- attn_input=None): # attn_input can differ from x in general
+ attn_input=None): # `attn_input` can differ from x in general.
""""""

if batch is None:
@@ -162,13 +162,8 @@ def forward(self, x, edge_index, edge_attr=None, batch=None,
return x, edge_index, edge_attr, batch, perm, score[perm]

def __repr__(self):
- return '{}({}, ' \
- 'ratio={}{}, ' \
- 'min_score={}, ' \
- 'multiplier={})'.format(self.__class__.__name__,
- self.in_channels,
- self.ratio,
- '(ignored)' if self.
- min_score is not None else '',
- self.min_score,
- self.multiplier)
+ return '{}({}, {}={}, multiplier={})'.format(
+ self.__class__.__name__, self.in_channels,
+ 'ratio' if self.min_score is None else 'min_score',
+ self.ratio if self.min_score is None else self.min_score,
+ self.multiplier)
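
Note (not part of the commit): a hedged usage sketch of what the updated reprs should print, assuming the constructor arguments used in the tests above and that SAGPooling and TopKPooling are importable from torch_geometric.nn at this revision.

# Illustrative only: expected repr output after this commit, mirroring the
# assertions in the updated tests above.
from torch_geometric.nn import SAGPooling, TopKPooling

print(TopKPooling(16, ratio=0.5))
# Expected: TopKPooling(16, ratio=0.5, multiplier=1)
print(SAGPooling(16, ratio=0.5, gnn='GraphConv'))
# Expected: SAGPooling(GraphConv, 16, ratio=0.5, multiplier=1)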
