Skip to content

Commit

Permalink
update doc
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jul 6, 2019
1 parent b4e6f3b commit ad11b3c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 22 deletions.
22 changes: 11 additions & 11 deletions torch_geometric/nn/pool/sag_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,20 @@ class SAGPooling(torch.nn.Module):
gnn (string, optional): Specifies which graph neural network layer to
use for calculating projection scores (one of
:obj:`"GCN"`, :obj:`"GAT"` or :obj:`"SAGE"`). (default: :obj:`GCN`)
min_score (float): Minimal node score, which is used to compute indices
of pooled nodes :math:`\mathbf{i} : \mathbf{y}_i > \tilde{\alpha}`.
When this value is not None, the ratio argument is ignored.
(default: None)
multiplier (float): Coefficient by which multiply features
after pooling. This can be useful for large graphs and
when min_score is used.
(default: None)
min_score (float, optional): Minimal node score :math:`\tilde{\alpha}`
which is used to compute indices of pooled nodes
:math:`\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}`.
When this value is not :obj:`None`, the :arg:`ratio` argument is
ignored. (default: :obj:`None`)
multiplier (float, optional): Coefficient by which features gets
multiplied after pooling. This can be useful for large graphs and
when :arg:`min_score` is used. (default: :obj:`1`)
**kwargs (optional): Additional parameters for initializing the graph
neural network layer.
"""

def __init__(self, in_channels, ratio=0.5, min_score=None, multiplier=None,
gnn='GraphConv', **kwargs):
def __init__(self, in_channels, ratio=0.5, gnn='GraphConv', min_score=None,
multiplier=1, **kwargs):
super(SAGPooling, self).__init__()

self.in_channels = in_channels
Expand Down Expand Up @@ -103,7 +103,7 @@ def forward(self, x, edge_index, edge_attr=None, batch=None,

perm = topk(score, self.ratio, batch, min_score=self.min_score)
x = x[perm] * score[perm].view(-1, 1)
if self.multiplier is not None:
if self.multiplier != 1:
x = x * self.multiplier

batch = batch[perm]
Expand Down
21 changes: 10 additions & 11 deletions torch_geometric/nn/pool/topk_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,18 +104,17 @@ class TopKPooling(torch.nn.Module):
:math:`k = \lceil \mathrm{ratio} \cdot N \rceil`.
This value is ignored if min_score is not None.
(default: :obj:`0.5`)
min_score (float): Minimal node score, which is used to compute indices
of pooled nodes :math:`\mathbf{i} : \mathbf{y}_i > \tilde{\alpha}`.
When this value is not None, the ratio argument is ignored.
(default: None)
multiplier (float): Coefficient by which multiply features
after pooling. This can be useful for large graphs and
when min_score is used.
(default: None)
min_score (float, optional): Minimal node score :math:`\tilde{\alpha}`
which is used to compute indices of pooled nodes
:math:`\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}`.
When this value is not :obj:`None`, the :arg:`ratio` argument is
ignored. (default: :obj:`None`)
multiplier (float, optional): Coefficient by which features gets
multiplied after pooling. This can be useful for large graphs and
when :arg:`min_score` is used. (default: :obj:`1`)
"""

def __init__(self, in_channels, ratio=0.5, min_score=None,
multiplier=None):
def __init__(self, in_channels, ratio=0.5, min_score=None, multiplier=1):
super(TopKPooling, self).__init__()

self.in_channels = in_channels
Expand Down Expand Up @@ -153,7 +152,7 @@ def forward(self, x, edge_index, edge_attr=None, batch=None,
perm = topk(score, self.ratio, batch, min_score=self.min_score)

x = x[perm] * score[perm].view(-1, 1)
if self.multiplier is not None:
if self.multiplier != 1:
x = x * self.multiplier

batch = batch[perm]
Expand Down

0 comments on commit ad11b3c

Please sign in to comment.