Skip to content

Commit

Permalink
update doc
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jul 6, 2019
1 parent 9a07e84 commit b4e6f3b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
10 changes: 5 additions & 5 deletions torch_geometric/nn/pool/sag_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ class SAGPooling(torch.nn.Module):
if min_score :math:`\tilde{\alpha}` is a value in [0, 1]:
.. math::
\mathbf{y} &= softmax(\textrm{GNN}(\mathbf{X}, \mathbf{A}))
\mathbf{y} &= \mathrm{softmax}(\textrm{GNN}(\mathbf{X},\mathbf{A}))
\mathbf{i} &: \mathbf{y}_i > \tilde{\alpha}
\mathbf{i} &= \mathbf{y}_i > \tilde{\alpha}
\mathbf{X}^{\prime} &= (\mathbf{X} \odot (\mathbf{y}))_{\mathbf{i}}
\mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathbf{y})_{\mathbf{i}}
\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}},
Expand Down Expand Up @@ -58,8 +58,8 @@ class SAGPooling(torch.nn.Module):
neural network layer.
"""

def __init__(self, in_channels, ratio=0.5, min_score=None,
multiplier=None, gnn='GraphConv', **kwargs):
def __init__(self, in_channels, ratio=0.5, min_score=None, multiplier=None,
gnn='GraphConv', **kwargs):
super(SAGPooling, self).__init__()

self.in_channels = in_channels
Expand Down
14 changes: 7 additions & 7 deletions torch_geometric/nn/pool/topk_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def topk(x, ratio, batch, min_score=None, tol=1e-7):
if min_score is not None:
# make sure that we do not drop all nodes in a graph
scores_max = scatter_max(x, batch)[0][batch] - tol
scores_min = torch.min(scores_max, x.new_full((1,), min_score))
scores_min = torch.min(scores_max, x.new_full((1, ), min_score))

perm = torch.nonzero(x > scores_min).view(-1)

Expand All @@ -26,7 +26,7 @@ def topk(x, ratio, batch, min_score=None, tol=1e-7):
index = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes)

dense_x = x.new_full((batch_size * max_num_nodes,), -2)
dense_x = x.new_full((batch_size * max_num_nodes, ), -2)
dense_x[index] = x
dense_x = dense_x.view(batch_size, max_num_nodes)

Expand Down Expand Up @@ -87,11 +87,11 @@ class TopKPooling(torch.nn.Module):
if min_score :math:`\tilde{\alpha}` is a value in [0, 1]:
.. math::
\mathbf{y} &= softmax(\mathbf{X}\mathbf{p})
\mathbf{y} &= \mathrm{softmax}(\mathbf{X}\mathbf{p})
\mathbf{i} &: \mathbf{y}_i > \tilde{\alpha}
\mathbf{i} &= \mathbf{y}_i > \tilde{\alpha}
\mathbf{X}^{\prime} &= (\mathbf{X} \odot (\mathbf{y}))_{\mathbf{i}}
\mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathbf{y})_{\mathbf{i}}
\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}},
Expand Down Expand Up @@ -157,8 +157,8 @@ def forward(self, x, edge_index, edge_attr=None, batch=None,
x = x * self.multiplier

batch = batch[perm]
edge_index, edge_attr = filter_adj(
edge_index, edge_attr, perm, num_nodes=score.size(0))
edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm,
num_nodes=score.size(0))

return x, edge_index, edge_attr, batch, perm, score[perm]

Expand Down

0 comments on commit b4e6f3b

Please sign in to comment.