Skip to content
Permalink
Browse files

rgcn bugfix

  • Loading branch information...
rusty1s committed Apr 15, 2019
1 parent 44f1fa5 commit 5ebeac473f854b08ca9a04d6dbfd5bf8ea2a7d78
Showing with 3 additions and 2 deletions.
  1. +3 −2 torch_geometric/nn/conv/rgcn_conv.py
@@ -72,7 +72,8 @@ def message(self, x_j, edge_index, edge_type, edge_norm):
# loopkup based on the target node index and its edge type.
if x_j is None:
w = w.view(-1, self.out_channels)
index = edge_type * self.in_channels + edge_index[1]
j = 1 if self.flow == 'target_to_source' else 0
index = edge_type * self.in_channels + edge_index[j]
out = torch.index_select(w, 0, index)
else:
w = w.view(self.num_relations, self.in_channels, self.out_channels)
@@ -82,7 +83,7 @@ def message(self, x_j, edge_index, edge_type, edge_norm):
return out if edge_norm is None else out * edge_norm.view(-1, 1)

def update(self, aggr_out, x):
if x.dtype == torch.long:
if x is None:
out = aggr_out + self.root
else:
out = aggr_out + torch.matmul(x, self.root)

0 comments on commit 5ebeac4

Please sign in to comment.
You can’t perform that action at this time.