Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jul 6, 2019
1 parent 3fbceb7 commit 29e36da
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 19 deletions.
16 changes: 7 additions & 9 deletions torch_geometric/nn/pool/sag_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ class SAGPooling(torch.nn.Module):
min_score (float, optional): Minimal node score :math:`\tilde{\alpha}`
which is used to compute indices of pooled nodes
:math:`\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}`.
When this value is not :obj:`None`, the :arg:`ratio` argument is
When this value is not :obj:`None`, the :obj:`ratio` argument is
ignored. (default: :obj:`None`)
multiplier (float, optional): Coefficient by which features gets
multiplied after pooling. This can be useful for large graphs and
when :arg:`min_score` is used. (default: :obj:`1`)
when :obj:`min_score` is used. (default: :obj:`1`)
**kwargs (optional): Additional parameters for initializing the graph
neural network layer.
"""
Expand Down Expand Up @@ -89,10 +89,9 @@ def forward(self, x, edge_index, edge_attr=None, batch=None,
if batch is None:
batch = edge_index.new_zeros(x.size(0))

if attn_input is None:
attn_input = x
attn_input = attn_input.unsqueeze(-1) if attn_input.dim() == 1 \
else attn_input
attn_input = x if attn_input is None else attn_input
if attn_input.dim() == 1:
attn_input = attn_input.unsqueeze(-1)

score = self.gnn(attn_input, edge_index).view(-1)

Expand All @@ -101,10 +100,9 @@ def forward(self, x, edge_index, edge_attr=None, batch=None,
else:
score = softmax(score, batch)

perm = topk(score, self.ratio, batch, min_score=self.min_score)
perm = topk(score, self.ratio, batch, self.min_score)
x = x[perm] * score[perm].view(-1, 1)
if self.multiplier != 1:
x = x * self.multiplier
x = self.multiplier * x if self.multiplier != 1 else x

batch = batch[perm]
edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm,
Expand Down
17 changes: 7 additions & 10 deletions torch_geometric/nn/pool/topk_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ class TopKPooling(torch.nn.Module):
min_score (float, optional): Minimal node score :math:`\tilde{\alpha}`
which is used to compute indices of pooled nodes
:math:`\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}`.
When this value is not :obj:`None`, the :arg:`ratio` argument is
When this value is not :obj:`None`, the :obj:`ratio` argument is
ignored. (default: :obj:`None`)
multiplier (float, optional): Coefficient by which features gets
multiplied after pooling. This can be useful for large graphs and
when :arg:`min_score` is used. (default: :obj:`1`)
when :obj:`min_score` is used. (default: :obj:`1`)
"""

def __init__(self, in_channels, ratio=0.5, min_score=None, multiplier=1):
Expand All @@ -137,10 +137,9 @@ def forward(self, x, edge_index, edge_attr=None, batch=None,
if batch is None:
batch = edge_index.new_zeros(x.size(0))

if attn_input is None:
attn_input = x
attn_input = attn_input.unsqueeze(-1) if attn_input.dim() == 1 \
else attn_input
attn_input = x if attn_input is None else attn_input
if attn_input.dim() == 1:
attn_input = attn_input.unsqueeze(-1)

score = (attn_input * self.weight).sum(dim=-1)

Expand All @@ -149,11 +148,9 @@ def forward(self, x, edge_index, edge_attr=None, batch=None,
else:
score = softmax(score, batch)

perm = topk(score, self.ratio, batch, min_score=self.min_score)

perm = topk(score, self.ratio, batch, self.min_score)
x = x[perm] * score[perm].view(-1, 1)
if self.multiplier != 1:
x = x * self.multiplier
x = self.multiplier * x if self.multiplier != 1 else x

batch = batch[perm]
edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm,
Expand Down

0 comments on commit 29e36da

Please sign in to comment.