Skip to content

Commit

Permalink
new signed gcn model with example
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Apr 10, 2019
1 parent f42ecf8 commit 7a11a86
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 40 deletions.
11 changes: 4 additions & 7 deletions examples/ppi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch_geometric.datasets import PPI
from torch_geometric.data import DataLoader
from torch_geometric.nn import GATConv
from sklearn import metrics
from sklearn.metrics import f1_score

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PPI')
train_dataset = PPI(path, split='train')
Expand Down Expand Up @@ -66,14 +66,11 @@ def test(loader):
out = model(data.x.to(device), data.edge_index.to(device))
preds.append((out > 0).float().cpu())

y, pred = torch.cat(ys, dim=0), torch.cat(preds, dim=0)
if pred.sum().item() == 0:
return 0
return metrics.f1_score(y.numpy(), pred.numpy(), average='micro')
y, pred = torch.cat(ys, dim=0).numpy(), torch.cat(preds, dim=0).numpy()
return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0


for epoch in range(1, 101):
loss = train()
acc = test(val_loader)
print('Epoch: {:02d}, Loss: {:.4f}, Micro-F1: {:.4f}'.format(
epoch, loss, acc))
print('Epoch: {:02d}, Loss: {:.4f}, F1: {:.4f}'.format(epoch, loss, acc))
50 changes: 50 additions & 0 deletions examples/signed_gcn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import os.path as osp

import torch
from torch_geometric.datasets import BitcoinOTC
from torch_geometric.nn import SignedGCN

# Locate (or download into) ../data/BitcoinOTC-1 relative to this script.
name = 'BitcoinOTC-1'
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name)
dataset = BitcoinOTC(path, edge_window_size=1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Generate dataset.
# Every graph in the dataset contributes its edges to one of two pools,
# chosen by the sign of the corresponding `edge_attr` entry.
pos_edge_indices, neg_edge_indices = [], []
for data in dataset:
    is_positive = data.edge_attr > 0
    is_negative = data.edge_attr < 0
    pos_edge_indices.append(data.edge_index[:, is_positive])
    neg_edge_indices.append(data.edge_index[:, is_negative])
pos_edge_index = torch.cat(pos_edge_indices, dim=1).to(device)
neg_edge_index = torch.cat(neg_edge_indices, dim=1).to(device)

# Build and train model.
model = SignedGCN(64, 64, num_layers=2, lamb=5).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# Hold out test edges separately for each sign; the spectral input features
# are computed from the training edges only.
train_pos_edge_index, test_pos_edge_index = model.split_edges(pos_edge_index)
train_neg_edge_index, test_neg_edge_index = model.split_edges(neg_edge_index)
x = model.create_spectral_features(train_pos_edge_index, train_neg_edge_index)


def train():
    """Run a single full-batch optimization step and return the loss.

    Uses the module-level ``model``, ``optimizer``, input features ``x`` and
    the training edge sets.

    Returns:
        float: The training loss of this step as a Python number.
    """
    model.train()
    optimizer.zero_grad()

    embedding = model(x, train_pos_edge_index, train_neg_edge_index)
    objective = model.loss(embedding, train_pos_edge_index,
                           train_neg_edge_index)

    objective.backward()
    optimizer.step()
    return objective.item()


def test():
    """Evaluate the model on the held-out positive and negative edges.

    Node embeddings are computed from the *training* edges with gradients
    disabled; the held-out edges are only passed to ``model.test``.

    Returns:
        Whatever ``model.test`` returns (unpacked by the caller as
        ``(auc, f1)``).
    """
    model.eval()
    with torch.no_grad():
        embedding = model(x, train_pos_edge_index, train_neg_edge_index)
    scores = model.test(embedding, test_pos_edge_index, test_neg_edge_index)
    return scores


# Train for 201 epochs (0..200), evaluating and logging after every step.
log_template = 'Epoch: {:03d}, Loss: {:.4f}, AUC: {:.4f}, F1: {:.4f}'
for epoch in range(201):
    loss = train()
    auc, f1 = test()
    print(log_template.format(epoch, loss, auc, f1))
37 changes: 19 additions & 18 deletions torch_geometric/nn/models/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,24 @@
EPS = 1e-15


def negative_sampling(pos_edge_index, num_nodes):
    r"""Samples one random negative edge for every edge in
    :obj:`pos_edge_index`.

    Node pairs are encoded as flat indices :obj:`row * num_nodes + col` and
    drawn from :obj:`range(num_nodes ** 2)`. Draws that hit an existing edge
    are re-drawn until none of the returned pairs is contained in
    :obj:`pos_edge_index`.

    Args:
        pos_edge_index (LongTensor): Existing edges, indexed as
            :obj:`pos_edge_index[0]` (rows) and :obj:`pos_edge_index[1]`
            (columns).
        num_nodes (int): Number of nodes spanning the index space.

    :rtype: :class:`LongTensor` holding stacked rows and columns of the
        sampled pairs, placed on the device of :obj:`pos_edge_index`.

    .. note::
        The rejection loop only terminates if at least one node pair is
        *not* contained in :obj:`pos_edge_index`.
    """
    idx = (pos_edge_index[0] * num_nodes + pos_edge_index[1])
    idx = idx.to(torch.device('cpu')).numpy()

    rng = range(num_nodes**2)
    # An explicit dtype keeps the result integral even when nothing is
    # sampled (`torch.tensor([])` would otherwise default to float).
    perm = torch.tensor(random.sample(rng, idx.shape[0]), dtype=torch.long)
    mask = torch.from_numpy(np.isin(perm.numpy(), idx)).to(torch.bool)
    rest = mask.nonzero().view(-1)
    while rest.numel() > 0:  # pragma: no cover
        tmp = torch.tensor(
            random.sample(rng, rest.size(0)), dtype=torch.long)
        mask = torch.from_numpy(np.isin(tmp.numpy(), idx)).to(torch.bool)
        perm[rest] = tmp
        # `mask` is relative to `tmp`, so translate its hits back into
        # positions of `perm` before the next round.
        rest = rest[mask.nonzero().view(-1)]

    # Floor division: `/` is true division and would yield float rows.
    row, col = perm // num_nodes, perm % num_nodes
    return torch.stack([row, col], dim=0).to(pos_edge_index.device)


class GAE(torch.nn.Module):
r"""The Graph Auto-Encoder model from the
`"Variational Graph Auto-Encoders" <https://arxiv.org/abs/1611.07308>`_
Expand Down Expand Up @@ -116,23 +134,6 @@ def split_edges(self, data, val_ratio=0.05, test_ratio=0.1):

return data

def negative_sampling(self, pos_edge_index, num_nodes):
idx = (pos_edge_index[0] * num_nodes + pos_edge_index[1])
idx = idx.to(torch.device('cpu'))

rng = range(num_nodes**2)
perm = torch.tensor(random.sample(rng, idx.size(0)))
mask = torch.from_numpy(np.isin(perm, idx).astype(np.uint8))
rest = mask.nonzero().view(-1)
while rest.numel() > 0: # pragma: no cover
tmp = torch.tensor(random.sample(rng, rest.size(0)))
mask = torch.from_numpy(np.isin(tmp, idx).astype(np.uint8))
perm[rest] = tmp
rest = mask.nonzero().view(-1)

row, col = perm / num_nodes, perm % num_nodes
return torch.stack([row, col], dim=0).to(pos_edge_index.device)

def recon_loss(self, z, pos_edge_index):
r"""Given latent variables :obj:`z`, computes the binary cross
entropy loss for positive edges :obj:`pos_edge_index` and negative
Expand All @@ -146,7 +147,7 @@ def recon_loss(self, z, pos_edge_index):
pos_loss = -torch.log(self.decode_indices(z, pos_edge_index) +
EPS).mean()

neg_edge_index = self.negative_sampling(pos_edge_index, z.size(0))
neg_edge_index = negative_sampling(pos_edge_index, z.size(0))
neg_loss = -torch.log(1 - self.decode_indices(z, neg_edge_index) +
EPS).mean()

Expand Down
111 changes: 96 additions & 15 deletions torch_geometric/nn/models/signed_gcn.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import torch
import torch.nn.functional as F
import numpy as np
import scipy.sparse
from sklearn.decomposition import TruncatedSVD
from sklearn.metrics import roc_auc_score, f1_score
import torch
import torch.nn.functional as F
from torch_sparse import coalesce
from torch_geometric.nn import SignedConv

from .autoencoder import negative_sampling


class SignedGCN(torch.nn.Module):
def __init__(self,
Expand All @@ -15,6 +20,7 @@ def __init__(self,
super(SignedGCN, self).__init__()

self.in_channels = in_channels
self.lamb = lamb

self.conv1 = SignedConv(
in_channels, hidden_channels // 2, first_aggr=True)
Expand All @@ -26,7 +32,7 @@ def __init__(self,
hidden_channels // 2,
first_aggr=False))

self.lin = torch.nn.Linear(2 * hidden_channels, 1)
self.lin = torch.nn.Linear(2 * hidden_channels, 3)

self.reset_parameters()

Expand All @@ -36,18 +42,34 @@ def reset_parameters(self):
conv.reset_parameters()
self.lin.reset_parameters()

def split_edges(self, edge_index, test_ratio=0.2):
    r"""Randomly splits the columns of :obj:`edge_index` into a training
    and a test set.

    Args:
        edge_index (Tensor): Edges stored column-wise; the split is taken
            over :obj:`edge_index.size(1)`.
        test_ratio (float, optional): Fraction of edges moved to the test
            set; the count is truncated towards zero.
            (default: :obj:`0.2`)

    :rtype: (:class:`Tensor`, :class:`Tensor`) -- the training edges and
        the test edges, each keeping the column order of
        :obj:`edge_index`.
    """
    num_edges = edge_index.size(1)
    num_test = int(test_ratio * num_edges)

    # A boolean mask replaces the deprecated uint8 mask, so the complement
    # is `~mask` rather than `1 - mask`.
    mask = torch.ones(num_edges, dtype=torch.bool)
    mask[torch.randperm(num_edges)[:num_test]] = False
    mask = mask.to(edge_index.device)

    train_edge_index = edge_index[:, mask]
    test_edge_index = edge_index[:, ~mask]

    return train_edge_index, test_edge_index

def create_spectral_features(self, pos_edge_index, neg_edge_index):
edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=1)
N = edge_index.max().item() + 1
edge_index = edge_index.to(torch.device('cpu')).detach().numpy()
edge_index = edge_index.to(torch.device('cpu'))

pos_val = torch.full((pos_edge_index.size(1), ), 1, dtype=torch.float)
neg_val = torch.full((neg_edge_index.size(1), ), -1, dtype=torch.float)
pos_val = torch.full((pos_edge_index.size(1), ), 2, dtype=torch.float)
neg_val = torch.full((neg_edge_index.size(1), ), 0, dtype=torch.float)
val = torch.cat([pos_val, neg_val], dim=0)
val = val.to(torch.device('cpu')).detach().numpy()

row, col = edge_index
edge_index = torch.cat([edge_index, torch.stack([col, row])], dim=1)
val = torch.cat([val, val], dim=0)

edge_index, val = coalesce(edge_index, val, N, N)

# Borrowed from:
# https://github.com/benedekrozemberczki/SGCN/blob/master/src/utils.py
edge_index = edge_index.detach().numpy()
val = val.detach().numpy()
A = scipy.sparse.coo_matrix((val, edge_index), shape=(N, N))
svd = TruncatedSVD(n_components=self.in_channels, n_iter=128)
svd.fit(A)
Expand All @@ -56,18 +78,77 @@ def create_spectral_features(self, pos_edge_index, neg_edge_index):

def forward(self, x, pos_edge_index, neg_edge_index):
""""""
x = F.relu(self.conv1(x, pos_edge_index, neg_edge_index))
z = F.relu(self.conv1(x, pos_edge_index, neg_edge_index))
for conv in self.convs:
x = F.relu(conv(x, pos_edge_index, neg_edge_index))
return x
z = F.relu(conv(z, pos_edge_index, neg_edge_index))
return z

def discriminate(self, z, edge_index, sigmoid=True):
def discriminate(self, z, edge_index):
value = torch.cat([z[edge_index[0]], z[edge_index[1]]], dim=1)
value = self.lin(value).view(-1)
return torch.sigmoid(value) if sigmoid else value
value = self.lin(value)
return torch.log_softmax(value, dim=1)

def nll_loss(self, z, pos_edge_index, neg_edge_index):
    r"""Computes the three-way edge classification loss.

    Node pairs are classified as positive edge (class :obj:`0`), negative
    edge (class :obj:`1`) or no edge (class :obj:`2`). The "no edge" pairs
    are drawn with :func:`negative_sampling` so that they avoid every pair
    in the union of both given edge sets.

    Args:
        z (Tensor): Node embeddings; :obj:`z.size(0)` is used as the
            number of nodes.
        pos_edge_index (LongTensor): Positive edges.
        neg_edge_index (LongTensor): Negative edges.

    :rtype: The sum of the three per-class negative log-likelihoods
        divided by :obj:`3.0`.
    """
    known_edges = torch.cat([pos_edge_index, neg_edge_index], dim=1)
    none_edge_index = negative_sampling(known_edges, z.size(0))

    # The position in this tuple is the target class of the edge set.
    edge_sets = (pos_edge_index, neg_edge_index, none_edge_index)

    total = 0
    for label, edges in enumerate(edge_sets):
        target = edges.new_full((edges.size(1), ), label)
        total += F.nll_loss(self.discriminate(z, edges), target)
    return total / 3.0

def negative_sampling(self, pos_edge_index, num_nodes):
    r"""Samples, for every edge :obj:`(i, j)`, a node :obj:`k` such that
    :obj:`(i, k)` is not an edge of :obj:`pos_edge_index`.

    Args:
        pos_edge_index (LongTensor): Edges, unpacked into sources
            :obj:`i` and targets :obj:`j`. Despite the parameter name,
            this is whichever edge set the caller wants to avoid.
        num_nodes (int): Number of nodes; :obj:`k` is drawn uniformly
            from :obj:`[0, num_nodes)`.

    :rtype: (:class:`LongTensor`, :class:`LongTensor`,
        :class:`LongTensor`) -- the tensors :obj:`i`, :obj:`j` and
        :obj:`k`, placed on the device of :obj:`pos_edge_index`.

    .. note::
        The rejection loop only terminates if every source node has at
        least one node it is not linked to.
    """
    device = pos_edge_index.device
    i, j = pos_edge_index.to(torch.device('cpu'))
    idx_1 = (i * num_nodes + j).numpy()

    k = torch.randint(num_nodes, (i.size(0), ), dtype=torch.long)
    # Encode the candidate pair as (i, k), i.e. the same way the existing
    # edges (i, j) are encoded, so the membership test is meaningful.
    idx_2 = i * num_nodes + k
    mask = torch.from_numpy(np.isin(idx_2.numpy(), idx_1)).to(torch.bool)
    rest = mask.nonzero().view(-1)
    while rest.numel() > 0:  # pragma: no cover
        tmp = torch.randint(num_nodes, (rest.numel(), ), dtype=torch.long)
        idx_2 = i[rest] * num_nodes + tmp
        mask = torch.from_numpy(np.isin(idx_2.numpy(), idx_1)).to(
            torch.bool)
        k[rest] = tmp
        # `mask` is relative to `tmp`; map its hits back to positions of
        # `k` before retrying.
        rest = rest[mask.nonzero().view(-1)]
    return i.to(device), j.to(device), k.to(device)

def pos_embedding_loss(self, z, pos_edge_index):
i, j, k = self.negative_sampling(pos_edge_index, z.size(0))

out = (z[i] - z[j]).pow(2).sum(dim=1) - (z[i] - z[k]).pow(2).sum(dim=1)
return torch.clamp(out, min=0).mean()

def neg_embedding_loss(self, z, neg_edge_index):
i, j, k = self.negative_sampling(neg_edge_index, z.size(0))

out = (z[i] - z[k]).pow(2).sum(dim=1) - (z[i] - z[j]).pow(2).sum(dim=1)
return torch.clamp(out, min=0).mean()

def loss(self, z, pos_edge_index, neg_edge_index):
    r"""Computes the overall training objective.

    The objective is the edge classification loss of :meth:`nll_loss` plus
    :obj:`self.lamb` times the sum of :meth:`pos_embedding_loss` and
    :meth:`neg_embedding_loss`.

    Args:
        z (Tensor): Node embeddings.
        pos_edge_index (LongTensor): Positive edges.
        neg_edge_index (LongTensor): Negative edges.

    :rtype: Scalar :class:`Tensor`.
    """
    classification = self.nll_loss(z, pos_edge_index, neg_edge_index)
    pos_term = self.pos_embedding_loss(z, pos_edge_index)
    neg_term = self.neg_embedding_loss(z, neg_edge_index)
    embedding_penalty = pos_term + neg_term
    return classification + self.lamb * embedding_penalty

def test(self, z, pos_edge_index, neg_edge_index):
    r"""Evaluates sign prediction on the given edges.

    Each edge is assigned the more likely of the first two classes of
    :meth:`discriminate` (:obj:`0` for positive, :obj:`1` for negative);
    the third class is ignored. Ground-truth labels follow the same
    encoding, so positive edges are labelled :obj:`0` and negative edges
    :obj:`1`.

    Args:
        z (Tensor): Node embeddings.
        pos_edge_index (LongTensor): Positive edges to evaluate.
        neg_edge_index (LongTensor): Negative edges to evaluate.

    :rtype: (:obj:`auc`, :obj:`f1`) as returned by
        :func:`sklearn.metrics.roc_auc_score` and
        :func:`sklearn.metrics.f1_score` (:obj:`average='binary'`);
        :obj:`f1` is :obj:`0` when no edge is predicted as class
        :obj:`1`.
    """

    def predict(edges):
        # Index of the larger of the first two class log-probabilities.
        _, label = self.discriminate(z, edges)[:, :2].max(dim=1)
        return label

    with torch.no_grad():
        pos_p = predict(pos_edge_index)
        neg_p = predict(neg_edge_index)

    pred = torch.cat([pos_p, neg_p]).cpu()
    truth = torch.cat([
        pred.new_zeros((pos_p.size(0))),
        pred.new_ones(neg_p.size(0)),
    ])
    pred, truth = pred.numpy(), truth.numpy()

    auc = roc_auc_score(truth, pred)
    if pred.sum() > 0:
        f1 = f1_score(truth, pred, average='binary')
    else:
        f1 = 0

    return auc, f1

0 comments on commit 7a11a86

Please sign in to comment.