Skip to content

Commit

Permalink
fixed arga
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Apr 18, 2019
1 parent 75276a6 commit adf1ec0
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
7 changes: 4 additions & 3 deletions test/nn/models/test_autoencoder.py
Expand Up @@ -12,17 +12,18 @@ def test_gae():
z = model.encode(x)
assert z.tolist() == x.tolist()

adj = model.decode(z)
adj = model.decoder.forward_all(z)
assert adj.tolist() == torch.sigmoid(
torch.Tensor([[+2, -1, +1], [-1, +5, +4], [+1, +4, +5]])).tolist()

edge_index = torch.tensor([[0, 1], [1, 2]])
value = model.decode_indices(z, edge_index)
value = model.decode(z, edge_index)
assert value.tolist() == torch.sigmoid(torch.Tensor([-1, 4])).tolist()

edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
data = Data(edge_index=edge_index)
data.num_nodes = edge_index.max().item() + 1
data = model.split_edges(data, val_ratio=0.2, test_ratio=0.3)

assert data.val_pos_edge_index.size() == (2, 2)
Expand Down Expand Up @@ -68,5 +69,5 @@ def test_argva():

x = torch.Tensor([[1, -1], [1, 2], [2, 1]])
model.encode(x)
model.reparametrize(model.mu, model.logvar)
model.reparametrize(model.__mu__, model.__logvar__)
assert model.kl_loss().item() > 0
4 changes: 2 additions & 2 deletions torch_geometric/nn/models/autoencoder.py
Expand Up @@ -36,7 +36,7 @@ class InnerProductDecoder(torch.nn.Module):
\sigma(\mathbf{Z}\mathbf{Z}^{\top})
where :math:`\mathbf{Z} \in \mathbb{R}^{N \times d}` denotes the latent
space produced by the encoder"""
space produced by the encoder."""

def forward(self, z, edge_index, sigmoid=True):
r"""Decodes the latent variables :obj:`z` into edge probabilties for
Expand Down Expand Up @@ -261,8 +261,8 @@ class ARGA(GAE):
"""

def __init__(self, encoder, discriminator, decoder=None):
super(ARGA, self).__init__(encoder, decoder)
self.discriminator = discriminator
super(ARGA, self).__init__(encoder, decoder)

def reset_parameters(self):
super(ARGA, self).reset_parameters()
Expand Down

0 comments on commit adf1ec0

Please sign in to comment.