Commit

update gae models
rusty1s committed Mar 26, 2019
1 parent c33af0f commit 45a03ac
Showing 1 changed file with 63 additions and 112 deletions.
175 changes: 63 additions & 112 deletions torch_geometric/nn/models/autoencoder.py
@@ -4,6 +4,8 @@
 import torch
 from sklearn.metrics import roc_auc_score, average_precision_score
 
+from ..inits import reset
+
 
 class GAE(torch.nn.Module):
     r"""The Graph Auto-Encoder model from the
@@ -21,38 +23,33 @@ def __init__(self, encoder):
         super(GAE, self).__init__()
         self.encoder = encoder
 
+    def reset_parameters(self):
+        reset(self.encoder)
+
     def encode(self, *args, **kwargs):
         r"""Runs the encoder and computes latent variables for each node."""
         return self.encoder(*args, **kwargs)
 
-    def decode_all(self, z, sigmoid=True):
+    def decode(self, z):
         r"""Decodes the latent variables :obj:`z` into a probabilistic
         dense adjacency matrix.
 
         Args:
             z (Tensor): The latent space :math:`\mathbf{Z}`.
-            sigmoid (bool, optional): If set to :obj:`False`, does not apply
-                the logistic sigmoid function to the output.
-                (default: :obj:`True`)
         """
         adj = torch.matmul(z, z.t())
-        adj = torch.sigmoid(adj) if sigmoid else adj
-        return adj
+        return torch.sigmoid(adj)
 
-    def decode_indices(self, z, edge_index, sigmoid=True):
+    def decode_indices(self, z, edge_index):
         r"""Decodes the latent variables :obj:`z` into edge probabilities for
         the given node-pairs :obj:`edge_index`.
 
         Args:
             z (Tensor): The latent space :math:`\mathbf{Z}`.
             edge_index (LongTensor): The edge indices to predict.
-            sigmoid (bool, optional): If set to :obj:`False`, does not apply
-                the logistic sigmoid function to the output.
-                (default: :obj:`True`)
         """
         value = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=1)
-        value = torch.sigmoid(value) if sigmoid else value
-        return value
+        return torch.sigmoid(value)
 
     def split_edges(self, data, val_ratio=0.05, test_ratio=0.1):
         r"""Splits the edges of a :obj:`torch_geometric.data.Data` object
@@ -110,7 +107,7 @@ def split_edges(self, data, val_ratio=0.05, test_ratio=0.1):
 
         return data
 
-    def loss(self, z, pos_edge_index, neg_adj_mask):
+    def reconstruction_loss(self, z, pos_edge_index, neg_adj_mask):
         r"""Given latent variables :obj:`z`, computes the binary cross
         entropy loss for positive edges :obj:`pos_edge_index` and a negative
         adjacency matrix mask :obj:`neg_adj_mask`.
@@ -122,15 +119,16 @@ def loss(self, z, pos_edge_index, neg_adj_mask):
             :obj:`[N, N]` denoting the negative edges to train against.
         """
 
-        pos_loss = -torch.log(
-            self.decode_indices(z, pos_edge_index, sigmoid=True)).mean()
+        pos_loss = -torch.log(self.decode_indices(z, pos_edge_index)).mean()
 
         neg_loss = -torch.log(
-            (1 - self.decode_all(z, sigmoid=True)[neg_adj_mask]).clamp(
-                min=1e-8)).mean()
+            (1 - self.decode(z)[neg_adj_mask]).clamp(min=1e-8)).mean()
 
         return pos_loss + neg_loss
 
+    def loss(self, z, pos_edge_index, neg_adj_mask):
+        return self.reconstruction_loss(z, pos_edge_index, neg_adj_mask)
+
     def test(self, z, pos_edge_index, neg_edge_index):
         r"""Given latent variables :obj:`z`, positive edges
         :obj:`pos_edge_index` and negative edges :obj:`neg_edge_index`,
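The refactor splits the objective cleanly: reconstruction_loss is binary cross entropy over positive edges plus masked negative entries, and the base GAE.loss now simply forwards to it. A training-step sketch continuing the example above; the train_pos_edge_index and train_neg_adj_mask attribute names are assumed to follow split_edges' convention:

    data = model.split_edges(data)  # attaches train/val/test edge attributes
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    model.train()
    optimizer.zero_grad()
    z = model.encode(data.x, data.train_pos_edge_index)
    loss = model.loss(z, data.train_pos_edge_index, data.train_neg_adj_mask)
    loss.backward()
    optimizer.step()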
@@ -148,8 +146,8 @@ def test(self, z, pos_edge_index, neg_edge_index):
         neg_y = z.new_zeros(neg_edge_index.size(1))
         y = torch.cat([pos_y, neg_y], dim=0)
 
-        pos_pred = self.decode_indices(z, pos_edge_index, sigmoid=True)
-        neg_pred = self.decode_indices(z, neg_edge_index, sigmoid=True)
+        pos_pred = self.decode_indices(z, pos_edge_index)
+        neg_pred = self.decode_indices(z, neg_edge_index)
         pred = torch.cat([pos_pred, neg_pred], dim=0)
 
         y, pred = y.detach().cpu().numpy(), pred.detach().cpu().numpy()
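Evaluation then scores held-out edges with scikit-learn's ROC-AUC and average precision; for example (attribute names again assumed from the split_edges convention):

    model.eval()
    with torch.no_grad():
        z = model.encode(data.x, data.train_pos_edge_index)
    auc, ap = model.test(z, data.test_pos_edge_index, data.test_neg_edge_index)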
@@ -158,108 +156,61 @@


 class VGAE(GAE):
-    def __init__(self, encoder, out_channels, n_latent):
+    def __init__(self, encoder):
         super(VGAE, self).__init__(encoder)
-        self.z_mean = torch.nn.Linear(out_channels, n_latent)
-        self.z_var = torch.nn.Linear(out_channels, n_latent)
-        torch.nn.init.xavier_uniform(self.z_mean.weight)
-        torch.nn.init.xavier_uniform(self.z_var.weight)
 
-    def kl_loss(self, mean, logvar):
-        loss = torch.mean(0.5 * torch.sum(
-            torch.exp(logvar) + mean**2 - 1. - logvar, 1))
-        print(loss)
-        return loss
-
-    def reconstruction_loss(self, adj, edge_index, neg_adj_mask):
-        row, col = edge_index
-        loss = -torch.log(torch.sigmoid(adj[row, col])).mean()
-        print(loss)
-        loss = loss - torch.log(1 - torch.sigmoid(adj[neg_adj_mask])).mean()
-        return loss
-
-    def sample_z(self, mean, logvar):
-        stddev = torch.exp(0.5 * logvar)
-        noise = torch.randn(stddev.size())
-        if torch.cuda.is_available():
-            noise = noise.cuda()
-        return (noise * stddev) + mean
-
-    def encode(self, x, edge_index):
-        z = torch.nn.functional.relu(self.encoder(x, edge_index))
-        mean, logvar = self.z_mean(z), self.z_var(z)
-        z = self.sample_z(mean, logvar)
-        return z, mean, logvar
-
-    def loss(self, z, mean, logvar, *args):
-        args = list(args)
-        args[0] = self.decoder(args[0])
-        recon_loss = self.reconstruction_loss(*args)
-        kl_loss = self.kl_loss(mean, logvar)
-        total_loss = recon_loss + kl_loss
-        return total_loss
+    def sample(self, mu, logvar):
+        return mu + torch.randn_like(logvar) * torch.exp(0.5 * logvar)
+
+    def kl_loss(self, mu, logvar):
+        return -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
+
+    def loss(self, mu, logvar, pos_edge_index, neg_adj_mask):
+        z = self.sample(mu, logvar)
+        recon_loss = self.reconstruction_loss(z, pos_edge_index, neg_adj_mask)
+        kl_loss = self.kl_loss(mu, logvar)
+        return recon_loss + kl_loss
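VGAE no longer owns the mean/log-variance heads; the encoder itself is now expected to return both, and sample applies the standard reparameterization z = mu + eps * exp(logvar / 2) with eps ~ N(0, I). A sketch of a compatible encoder (hypothetical, not part of this commit, reusing the imports from the first example):

    class VariationalEncoder(torch.nn.Module):
        def __init__(self, in_channels, out_channels):
            super(VariationalEncoder, self).__init__()
            self.conv1 = GCNConv(in_channels, 2 * out_channels)
            self.conv_mu = GCNConv(2 * out_channels, out_channels)
            self.conv_logvar = GCNConv(2 * out_channels, out_channels)

        def forward(self, x, edge_index):
            x = torch.relu(self.conv1(x, edge_index))
            return self.conv_mu(x, edge_index), self.conv_logvar(x, edge_index)

    model = VGAE(VariationalEncoder(8, 16))
    mu, logvar = model.encode(data.x, data.train_pos_edge_index)
    loss = model.loss(mu, logvar, data.train_pos_edge_index, data.train_neg_adj_mask)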


 class ARGA(GAE):
-    def __init__(self, encoder, discriminator, n_latent):
+    def __init__(self, encoder, discriminator):
         super(ARGA, self).__init__(encoder)
         self.discriminator = discriminator
 
-    def reconstruction_loss(self, adj, edge_index, neg_adj_mask):
-        row, col = edge_index
-        loss = -torch.log(torch.sigmoid(adj[row, col])).mean()
-        loss = loss - torch.log(1 - torch.sigmoid(adj[neg_adj_mask])).mean()
-        return loss
+    def reset_parameters(self):
+        super(ARGA, self).reset_parameters()
+        reset(self.discriminator)
 
     def discriminate(self, z):
-        z_real = torch.randn(z.size())
-        d_real = self.discriminator(z_real)
-        d_fake = self.discriminator(z)
-        return d_real, d_fake
-
-    def discriminator_loss(self, d_real, d_fake):
-        dc_real_loss = torch.nn.BCELoss(reduction='mean')(
-            d_real, torch.ones(d_real.size()))
-        dc_fake_loss = torch.nn.BCELoss(reduction='mean')(
-            d_fake, torch.zeros(d_fake.size()))
-        dc_gen_loss = torch.nn.BCELoss(reduction='mean')(
-            d_fake, torch.ones(d_fake.size()))
-        return dc_real_loss + dc_fake_loss + dc_gen_loss
-
-    def loss(self, d_real, d_fake, *args):
-        args = list(args)
-        args[0] = self.decoder(args[0])
-        recon_loss = self.reconstruction_loss(*args)
-        d_loss = self.discriminator_loss(d_real, d_fake)
-        total_loss = recon_loss + d_loss
-        return total_loss
+        real = torch.sigmoid(self.discriminator(torch.randn_like(z)))
+        fake = torch.sigmoid(self.discriminator(z))
+        return real, fake
+
+    def discriminator_loss(self, real, fake):
+        real_loss = -torch.log(real).mean()
+        fake_loss = -torch.log((1 - fake).clamp(min=1e-8)).mean()
+        return real_loss + fake_loss
+
+    def loss(self, z, pos_edge_index, neg_adj_mask):
+        recon_loss = self.reconstruction_loss(z, pos_edge_index, neg_adj_mask)
+        d_loss = self.discriminator_loss(*self.discriminate(z))
+        return recon_loss + d_loss
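discriminate draws the "real" sample from a standard normal prior via torch.randn_like(z) and applies the sigmoid itself, so the discriminator should emit one raw logit per node. A hypothetical MLP discriminator, continuing the sketch above:

    discriminator = torch.nn.Sequential(
        torch.nn.Linear(16, 64),
        torch.nn.ReLU(),
        torch.nn.Linear(64, 1),
    )

    model = ARGA(Encoder(8, 16), discriminator)
    z = model.encode(data.x, data.train_pos_edge_index)
    loss = model.loss(z, data.train_pos_edge_index, data.train_neg_adj_mask)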


 class ARGVA(ARGA):
-    def __init__(self, encoder, discriminator, out_channels):
-        n_latent = out_channels
-        super(ARGVA, self).__init__(encoder, discriminator, n_latent)
-        self.discriminator = discriminator
-        self.z_mean = torch.nn.Linear(out_channels, n_latent)
-        self.z_var = torch.nn.Linear(out_channels, n_latent)
-        torch.nn.init.xavier_uniform(self.z_mean.weight)
-        torch.nn.init.xavier_uniform(self.z_var.weight)
-
-    def kl_loss(self, mean, logvar):
-        loss = torch.mean(0.5 * torch.sum(
-            torch.exp(logvar) + mean**2 - 1. - logvar, 1))
-        print(loss)
-        return loss
-
-    def sample_z(self, mean, logvar):
-        stddev = torch.exp(0.5 * logvar)
-        noise = torch.randn(stddev.size())
-        if torch.cuda.is_available():
-            noise = noise.cuda()
-        return (noise * stddev) + mean
-
-    def encode(self, x, edge_index):
-        z = torch.nn.functional.relu(self.encoder(x, edge_index))
-        mean, logvar = self.z_mean(z), self.z_var(z)
-        z = self.sample_z(mean, logvar)
-        return z, mean, logvar
+    def __init__(self, encoder, discriminator):
+        super(ARGVA, self).__init__(encoder, discriminator)
+        self.VGAE = VGAE(encoder)
+
+    def sample(self, mu, logvar):
+        return self.VGAE.sample(mu, logvar)
+
+    def kl_loss(self, mu, logvar):
+        return self.VGAE.kl_loss(mu, logvar)
+
+    def loss(self, mu, logvar, pos_edge_index, neg_adj_mask):
+        z = self.sample(mu, logvar)
+        recon_loss = self.reconstruction_loss(z, pos_edge_index, neg_adj_mask)
+        kl_loss = self.kl_loss(mu, logvar)
+        d_loss = self.discriminator_loss(*self.discriminate(z))
+        return recon_loss + kl_loss + d_loss
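ARGVA delegates sampling and the KL term to an internal VGAE and adds the adversarial term on top, so its loss is reconstruction + KL + discriminator. A combined training-loop sketch under the same assumptions as the previous examples:

    model = ARGVA(VariationalEncoder(8, 16), discriminator)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

    for epoch in range(200):
        model.train()
        optimizer.zero_grad()
        mu, logvar = model.encode(data.x, data.train_pos_edge_index)
        loss = model.loss(mu, logvar, data.train_pos_edge_index,
                          data.train_neg_adj_mask)
        loss.backward()
        optimizer.step()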
