Skip to content

Commit

Permalink
add size attribute to rgcn conv
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jul 6, 2019
1 parent 89afdb9 commit 5847d09
Showing 1 changed file with 5 additions and 10 deletions.
15 changes: 5 additions & 10 deletions torch_geometric/nn/conv/rgcn_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,8 @@ class RGCNConv(MessagePassing):
:class:`torch_geometric.nn.conv.MessagePassing`.
"""

def __init__(self,
in_channels,
out_channels,
num_relations,
num_bases,
bias=True,
**kwargs):
def __init__(self, in_channels, out_channels, num_relations, num_bases,
bias=True, **kwargs):
super(RGCNConv, self).__init__(aggr='add', **kwargs)

self.in_channels = in_channels
Expand All @@ -63,10 +58,10 @@ def reset_parameters(self):
uniform(size, self.root)
uniform(size, self.bias)

def forward(self, x, edge_index, edge_type, edge_norm=None):
def forward(self, x, edge_index, edge_type, edge_norm=None, size=None):
""""""
return self.propagate(
edge_index, x=x, edge_type=edge_type, edge_norm=edge_norm)
return self.propagate(edge_index, size=size, x=x, edge_type=edge_type,
edge_norm=edge_norm)

def message(self, x_j, edge_index_j, edge_type, edge_norm):
w = torch.matmul(self.att, self.basis.view(self.num_bases, -1))
Expand Down

0 comments on commit 5847d09

Please sign in to comment.