Skip to content
Permalink
Browse files

linting

  • Loading branch information...
rusty1s committed Sep 8, 2019
1 parent 5a7fb62 commit ef85bdd95a84fb29ff8f44df166a66042db0e12e
Showing with 13 additions and 11 deletions.
  1. +3 −3 test/nn/pool/test_edge_pool.py
  2. +10 −8 torch_geometric/nn/pool/edge_pool.py
@@ -7,7 +7,7 @@ def test_compute_edge_score_softmax():
edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5],
[1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]])
raw = torch.randn(edge_index.size(1))
e = EdgePooling.compute_edge_score_softmax(raw, edge_index)
e = EdgePooling.compute_edge_score_softmax(raw, edge_index, 6)
assert torch.all(e >= 0) and torch.all(e <= 1)

# Test whether all incoming edge scores sum up to one.
@@ -19,7 +19,7 @@ def test_compute_edge_score_tanh():
edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5],
[1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]])
raw = torch.randn(edge_index.size(1))
e = EdgePooling.compute_edge_score_tanh(raw, edge_index)
e = EdgePooling.compute_edge_score_tanh(raw, edge_index, 6)
assert torch.all(e >= -1) and torch.all(e <= 1)
assert torch.all(torch.argsort(raw) == torch.argsort(e))

@@ -28,7 +28,7 @@ def test_compute_edge_score_sigmoid():
edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5],
[1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]])
raw = torch.randn(edge_index.size(1))
e = EdgePooling.compute_edge_score_sigmoid(raw, edge_index)
e = EdgePooling.compute_edge_score_sigmoid(raw, edge_index, 6)
assert torch.all(e >= 0) and torch.all(e <= 1)
assert torch.all(torch.argsort(raw) == torch.argsort(e))

@@ -30,10 +30,11 @@ class EdgePooling(torch.nn.Module):
edge_score_method (function, optional): The function to apply
to compute the edge score from raw edge scores. By default,
this is the softmax over all incoming edges for each node.
Functions that can be used take in a :obj:`raw_edge_score` tensor
of shape :obj:`[num_nodes]` and :obj:`edge_index`, and produces
a new tensor of the same size as :obj:`raw_edge_score` describing
the normalized edge scores. Included functions are
This function takes in a :obj:`raw_edge_score` tensor of shape
:obj:`[num_nodes]`, an :obj:`edge_index` tensor and the number of
nodes :obj:`num_nodes`, and produces a new tensor of the same size
as :obj:`raw_edge_score` describing normalized edge scores.
Included functions are
:func:`EdgePooling.compute_edge_score_softmax`,
:func:`EdgePooling.compute_edge_score_tanh`, and
:func:`EdgePooling.compute_edge_score_sigmoid`.
@@ -67,15 +68,15 @@ def reset_parameters(self):
self.lin.reset_parameters()

@staticmethod
def compute_edge_score_softmax(raw_edge_score, edge_index, num_nodes=None):
def compute_edge_score_softmax(raw_edge_score, edge_index, num_nodes):
return softmax(raw_edge_score, edge_index[1], num_nodes)

@staticmethod
def compute_edge_score_tanh(raw_edge_score, edge_index, num_nodes=None):
def compute_edge_score_tanh(raw_edge_score, edge_index, num_nodes):
return torch.tanh(raw_edge_score)

@staticmethod
def compute_edge_score_sigmoid(raw_edge_score, edge_index, num_nodes=None):
def compute_edge_score_sigmoid(raw_edge_score, edge_index, num_nodes):
return torch.sigmoid(raw_edge_score)

def forward(self, x, edge_index, batch):
@@ -99,7 +100,8 @@ def forward(self, x, edge_index, batch):
e = torch.cat([x[edge_index[0]], x[edge_index[1]]], dim=-1)
e = self.lin(e).view(-1)
e = F.dropout(e, p=self.dropout, training=self.training)
e = self.compute_edge_score(e, edge_index, x.size(0)) + self.add_to_edge_score
e = self.compute_edge_score(e, edge_index, x.size(0))
e = e + self.add_to_edge_score

x, edge_index, batch, unpool_info = self.__merge_edges__(
x, edge_index, batch, e)

0 comments on commit ef85bdd

Please sign in to comment.
You can’t perform that action at this time.