Skip to content

Commit

Permalink
typo
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Aug 11, 2019
1 parent dd97be0 commit 6219117
Showing 1 changed file with 16 additions and 34 deletions.
50 changes: 16 additions & 34 deletions torch_geometric/nn/models/re_net.py
Expand Up @@ -35,20 +35,14 @@ class RENet(torch.nn.Module):
seq_len (int): The sequence length of past events.
num_layers (int, optional): The number of recurrent layers.
(default: :obj:`1`)
dropout (float): If non-zero, introduces a dropout layer before the final
prediction. (default: :obj:`0.`)
dropout (float): If non-zero, introduces a dropout layer before the
final prediction. (default: :obj:`0.`)
bias (bool, optional): If set to :obj:`False`, all layers will not
learn an additive bias. (default: :obj:`True`)
"""

def __init__(self,
num_nodes,
num_rels,
hidden_channels,
seq_len,
num_layers=1,
dropout=0.,
bias=True):
def __init__(self, num_nodes, num_rels, hidden_channels, seq_len,
num_layers=1, dropout=0., bias=True):
super(RENet, self).__init__()

self.num_nodes = num_nodes
Expand All @@ -60,18 +54,10 @@ def __init__(self,
self.ent = Parameter(torch.Tensor(num_nodes, hidden_channels))
self.rel = Parameter(torch.Tensor(num_rels, hidden_channels))

self.sub_gru = GRU(
3 * hidden_channels,
hidden_channels,
num_layers,
batch_first=True,
bias=bias)
self.obj_gru = GRU(
3 * hidden_channels,
hidden_channels,
num_layers,
batch_first=True,
bias=bias)
self.sub_gru = GRU(3 * hidden_channels, hidden_channels, num_layers,
batch_first=True, bias=bias)
self.obj_gru = GRU(3 * hidden_channels, hidden_channels, num_layers,
batch_first=True, bias=bias)

self.sub_lin = Linear(3 * hidden_channels, num_nodes, bias=bias)
self.obj_lin = Linear(3 * hidden_channels, num_nodes, bias=bias)
Expand Down Expand Up @@ -116,8 +102,8 @@ def get_history(self, hist, node, rel):
h = hist[node][s]
hists += h
ts.append(torch.full((len(h), ), s, dtype=torch.long))
node, r = torch.tensor(
hists, dtype=torch.long).view(-1, 2).t().contiguous()
node, r = torch.tensor(hists, dtype=torch.long).view(
-1, 2).t().contiguous()
node = node[r == rel]
t = torch.cat(ts, dim=0)[r == rel]
return node, t
Expand Down Expand Up @@ -180,16 +166,12 @@ def forward(self, data):
h_sub_t = data.h_sub_t + data.h_sub_batch * seq_len
h_obj_t = data.h_obj_t + data.h_obj_batch * seq_len

h_sub = scatter_mean(
self.ent[data.h_sub],
h_sub_t,
dim=0,
dim_size=batch_size * seq_len).view(batch_size, seq_len, -1)
h_obj = scatter_mean(
self.ent[data.h_obj],
h_obj_t,
dim=0,
dim_size=batch_size * seq_len).view(batch_size, seq_len, -1)
h_sub = scatter_mean(self.ent[data.h_sub], h_sub_t, dim=0,
dim_size=batch_size * seq_len).view(
batch_size, seq_len, -1)
h_obj = scatter_mean(self.ent[data.h_obj], h_obj_t, dim=0,
dim_size=batch_size * seq_len).view(
batch_size, seq_len, -1)

sub = self.ent[data.sub].unsqueeze(1).repeat(1, seq_len, 1)
rel = self.rel[data.rel].unsqueeze(1).repeat(1, seq_len, 1)
Expand Down

0 comments on commit 6219117

Please sign in to comment.