major update: neighborloader (lacks insertion though)
rusty1s committed Dec 7, 2020
1 parent b378031 commit 7be5640
Showing 5 changed files with 163 additions and 146 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -130,7 +130,7 @@ In detail, the following methods are currently implemented:
 * **[DropEdge](https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.dropout_adj)** from Rong *et al.*: [DropEdge: Towards Deep Graph Convolutional Networks on Node Classification](https://openreview.net/forum?id=Hkx1qkrKPr) (ICLR 2020)
 * **[PairNorm](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.norm.PairNorm)** from Zhao and Akoglu: [PairNorm: Tackling Oversmoothing in GNNs](https://arxiv.org/abs/1909.12223) (ICLR 2020)
 * **[Tree Decomposition](https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.tree_decomposition)** from Jin *et al.*: [Junction Tree Variational Autoencoder for Molecular Graph Generation](https://arxiv.org/abs/1802.04364) (ICML 2018)
-* **[TGN](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.models.TGN)** from Rossi *et al.*: [Temporal Graph Networks for Deep Learning on Dynamic Graphs](https://arxiv.org/abs/2006.10637) (GRL+ 2020) [[**Example**](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/tgn.py)]
+* **[TGN](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.models.TGNMemory)** from Rossi *et al.*: [Temporal Graph Networks for Deep Learning on Dynamic Graphs](https://arxiv.org/abs/2006.10637) (GRL+ 2020) [[**Example**](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/tgn.py)]

 --------------------------------------------------------------------------------
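With the rename this commit applies across the docs and `torch_geometric/nn/models/__init__.py` (see the last diff below), the linked class would now be imported under its new name. A one-line sketch, assuming only what the diff itself shows:

    from torch_geometric.nn import TGNMemory  # formerly: from torch_geometric.nn import TGN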
163 changes: 72 additions & 91 deletions examples/tgn.py
@@ -2,178 +2,157 @@

 import torch
 from torch.nn import Linear
-from torch_sparse import SparseTensor
 from sklearn.metrics import average_precision_score, roc_auc_score

 from torch_geometric.datasets import JODIEDataset
-from torch_geometric.nn import TGN, TransformerConv
-from torch_geometric.nn.models.tgn import IdentityMessage, LastAggregator
+from torch_geometric.nn import TGNMemory, TransformerConv
+from torch_geometric.nn.models.tgn import (LastNeighborLoader, IdentityMessage,
+                                           LastAggregator)

 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

 path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'JODIE')
 dataset = JODIEDataset(path, name='wikipedia')
 data = dataset[0].to(device)

-# Ensure to only sample actual destination nodes as negatives.
+# Ensure to only sample *real* destination nodes as negatives.
 min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

-
-class NeighborSampler(object):
-    def __init__(self, data, size):
-        self.size = size
-        self.device = data.src.device
-        edge_features_one_dir = torch.cat([data.t.cpu().unsqueeze(1), data.msg.cpu()], dim=1)
-        edge_features = torch.cat([edge_features_one_dir, edge_features_one_dir], dim=0)
-        self.adj = SparseTensor(row=torch.cat([data.src.cpu(), data.dst.cpu()]),
-                                col=torch.cat([data.dst.cpu(), data.src.cpu()]),
-                                value=edge_features,
-                                sparse_sizes=(data.num_nodes, data.num_nodes))
-
-    def __call__(self, n_id, t):
-        _, _, value = self.adj.coo()
-        edge_t = value[:, 0].squeeze()
-        mask = edge_t < t
-        adj = self.adj.masked_select_nnz(mask, layout='coo')
-        if adj.numel() == 0:
-            adj = adj.sparse_resize([n_id.numel(), n_id.numel()])
-        else:
-            adj, n_id = adj.sample_adj(n_id.cpu(), num_neighbors=self.size)
-        return adj.to(self.device), n_id.to(self.device)
-
-
-sampler = NeighborSampler(data, size=10)
 train_data, val_data, test_data = data.train_val_test_split(
     val_ratio=0.15, test_ratio=0.15)

+neighbor_loader = LastNeighborLoader(data.num_nodes, size=10,
+                                     msg_dim=data.msg.size(-1), device=device)
+
+# for batch in train_data.seq_batches(batch_size=200):
+#     src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
+#     neighbors.insert(src, pos_dst, t, msg)
+
+# raise NotImplementedError
+

 class GraphAttentionEmbedding(torch.nn.Module):
-    def __init__(self, in_channels, out_channels, time_enc, edge_dim):
+    def __init__(self, in_channels, out_channels, msg_dim, time_enc):
         super(GraphAttentionEmbedding, self).__init__()
-        self.conv = TransformerConv(in_channels, out_channels, edge_dim=edge_dim)
+        self.conv = TransformerConv(in_channels, out_channels,
+                                    edge_dim=msg_dim + time_enc.out_channels)
         self.time_enc = time_enc

-    def forward(self, x, adj_t, t):
-        if adj_t.nnz() > 0:
-            _, _, value = adj_t.coo()
-
-            edge_t = value[:, 0].squeeze()
-            rel_t = edge_t - t
-            rel_t_enc = self.time_enc(rel_t.float())
-
-            edge_feat = value[:, 1:]
-
-            edge_attr = torch.cat([rel_t_enc, edge_feat], dim=1)
-            adj_t.set_value_(edge_attr, layout='coo')
-            x = self.conv((x, x[:adj_t.size(0)]), adj_t)
-        else:
-            x = self.conv.lin_skip(x)
-        return x
+    def forward(self, x, last_update, edge_index, t, msg):
+        rel_t = last_update[edge_index[0]] - t
+        rel_t_enc = self.time_enc(rel_t.to(x.dtype))
+        edge_attr = torch.cat([rel_t_enc, msg], dim=-1)
+        return self.conv(x, edge_index, edge_attr)


 class LinkPredictor(torch.nn.Module):
     def __init__(self, in_channels):
         super(LinkPredictor, self).__init__()
         self.lin_src = Linear(in_channels, in_channels)
         self.lin_dst = Linear(in_channels, in_channels)
-        self.lin_end = Linear(in_channels, 1)
+        self.lin_final = Linear(in_channels, 1)

     def forward(self, z_src, z_dst):
         h = self.lin_src(z_src) + self.lin_dst(z_dst)
         h = h.relu()
-        return self.lin_end(h)
-
-
-class TimeEncoder(torch.nn.Module):
-    def __init__(self, dimension):
-        super(TimeEncoder, self).__init__()
-        self.dimension = dimension
-        self.lin = Linear(1, dimension)
-
-    def forward(self, t):
-        return self.lin(t.view(-1, 1)).cos()
-
-    def reset_parameters(self):
-        self.lin.reset_parameters()
-
-
-raw_msg_dim = data.msg.size(-1)
-memory_dim = 100
-time_dim = 100
-
-time_enc = TimeEncoder(time_dim)
-model = TGN(data.num_nodes, raw_msg_dim, memory_dim,
-            message_module=IdentityMessage(raw_msg_dim, memory_dim, time_dim),
-            aggregator_module=LastAggregator(),
-            time_enc=time_enc).to(device)
-gnn = GraphAttentionEmbedding(in_channels=memory_dim, out_channels=100, time_enc=time_enc, edge_dim=time_dim+raw_msg_dim).to(device)
-link_pred = LinkPredictor(in_channels=100).to(device)
+        return self.lin_final(h)
+
+
+memory_dim = time_dim = embedding_dim = 100
+
+memory = TGNMemory(
+    data.num_nodes,
+    data.msg.size(-1),
+    memory_dim,
+    time_dim,
+    message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim),
+    aggregator_module=LastAggregator(),
+).to(device)
+
+gnn = GraphAttentionEmbedding(
+    in_channels=memory_dim,
+    out_channels=embedding_dim,
+    msg_dim=data.msg.size(-1),
+    time_enc=memory.time_enc,
+).to(device)
+
+link_pred = LinkPredictor(in_channels=embedding_dim).to(device)

 optimizer = torch.optim.Adam(
-    list(model.parameters()) + list(gnn.parameters()) +
-    list(link_pred.parameters()), lr=0.0001)
+    set(memory.parameters()) | set(gnn.parameters())
+    | set(link_pred.parameters()), lr=0.0001)
 criterion = torch.nn.BCEWithLogitsLoss()

 # Helper vector to map global node indices to local ones.
 assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)


 def train():
-    model.train()
-    model.reset_state()
+    memory.train()
+    gnn.train()
+    link_pred.train()
+
+    memory.reset_state()  # Start with a fresh memory.

     total_loss = 0
     for batch in train_data.seq_batches(batch_size=200):
         optimizer.zero_grad()

         src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

         # Sample negative destination nodes.
         neg_dst = torch.randint(min_dst_idx, max_dst_idx + 1, (src.size(0), ),
                                 dtype=torch.long, device=device)

         n_id = torch.cat([src, pos_dst, neg_dst]).unique()
-        query_t = t[0]
-        adj_t, n_id = sampler(n_id, t=query_t.cpu())
+        n_id, edge_index, t_edge, msg_edge = neighbor_loader(n_id)
         assoc[n_id] = torch.arange(n_id.size(0), device=device)

-        z, _ = model(n_id, query_t)  # Get memory.
-        z = gnn(z, adj_t, query_t)  # Embed memory via graph convolution.
+        # Get updated memory of all nodes involved in the computation.
+        z, last_update = memory(n_id)
+
+        z = gnn(z, last_update, edge_index, t_edge, msg_edge)

         pos_out = link_pred(z[assoc[src]], z[assoc[pos_dst]])
         neg_out = link_pred(z[assoc[src]], z[assoc[neg_dst]])

         loss = criterion(pos_out, torch.ones_like(pos_out))
         loss += criterion(neg_out, torch.zeros_like(neg_out))

-        model.update_state(src, pos_dst, t, msg)
+        # Update memory and neighbor loader with ground-truth state.
+        memory.update_state(src, pos_dst, t, msg)
+        neighbor_loader.insert(src, pos_dst, t, msg)
+        neighbor_loader.insert(pos_dst, src, t, msg)

         loss.backward()
         optimizer.step()
-        model.detach_memory()
+        memory.detach()
         total_loss += float(loss) * batch.num_events

-    model.flush_msg_store()
-
     return total_loss / train_data.num_events


 @torch.no_grad()
-def test(data, current_event_id):
-    model.eval()
+def test(data):
+    memory.eval()
+    gnn.eval()
+    link_pred.eval()
+
+    torch.manual_seed(12345)  # Ensure deterministic sampling across epochs.

     aps, aucs = [], []
     for batch in data.seq_batches(batch_size=200):
         src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

         neg_dst = torch.randint(min_dst_idx, max_dst_idx + 1, (src.size(0), ),
                                 dtype=torch.long, device=device)

         n_id = torch.cat([src, pos_dst, neg_dst]).unique()
-        query_t = t[0]
-        adj_t, n_id = sampler(n_id, t=query_t.cpu())
+        n_id, edge_index, t_edge, msg_edge = neighbor_loader(n_id)
         assoc[n_id] = torch.arange(n_id.size(0), device=device)

-        z, _ = model(n_id, query_t)
-        z = gnn(z, adj_t, query_t)  # Embed memory via graph convolution.
+        z, last_update = memory(n_id)
+        z = gnn(z, last_update, edge_index, t_edge, msg_edge)

         pos_out = link_pred(z[assoc[src]], z[assoc[pos_dst]])
         neg_out = link_pred(z[assoc[src]], z[assoc[neg_dst]])
@@ -186,15 +165,17 @@ def test(data, current_event_id):
         aps.append(average_precision_score(y_true, y_pred))
         aucs.append(roc_auc_score(y_true, y_pred))

-        model.update_state(src, pos_dst, t, msg)
+        memory.update_state(src, pos_dst, t, msg)
+        neighbor_loader.insert(src, pos_dst, t, msg)
+        neighbor_loader.insert(pos_dst, src, t, msg)

     return float(torch.tensor(aps).mean()), float(torch.tensor(aucs).mean())


 for epoch in range(1, 51):
     loss = train()
     print(f'  Epoch: {epoch:02d}, Loss: {loss:.4f}')
-    val_ap, val_auc = test(val_data, len(train_data))
-    test_ap, test_auc = test(test_data, len(train_data) + len(val_data))
+    val_ap, val_auc = test(val_data)
+    test_ap, test_auc = test(test_data)
     print(f' Val AP: {val_ap:.4f}, Val AUC: {val_auc:.4f}')
     print(f'Test AP: {test_ap:.4f}, Test AUC: {test_auc:.4f}')
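The commit title notes that the neighbor loader still "lacks insertion", yet the example above already calls `neighbor_loader.insert(...)` after each batch and `neighbor_loader(n_id)` before each memory/GNN step. The following is a minimal, hypothetical sketch of a loader with that interface; `LastNeighborLoaderSketch` is an illustration name, not the torch_geometric implementation. It assumes only the API visible in the diff: the `(num_nodes, size, msg_dim, device)` constructor, an `insert(src, dst, t, msg)` method, and a call returning `(n_id, edge_index, t, msg)` with locally re-indexed nodes, under the assumption that each node keeps just its `size` most recent temporal neighbors:

    import torch


    class LastNeighborLoaderSketch(object):
        # Hypothetical stand-in for `LastNeighborLoader`: keeps, per node,
        # the `size` most recent events with their timestamps and messages.
        def __init__(self, num_nodes, size, msg_dim, device=None):
            self.size = size
            self.msg_dim = msg_dim
            self.device = device
            self.store = [[] for _ in range(num_nodes)]  # (neighbor, t, msg)

        def insert(self, src, dst, t, msg):
            # Record each event at its destination node, keeping only the
            # most recent `size` entries (the piece the commit says is missing).
            for s, d, time, m in zip(src.tolist(), dst.tolist(), t.tolist(), msg):
                self.store[d].append((s, time, m))
                self.store[d] = self.store[d][-self.size:]

        def __call__(self, n_id):
            # Gather the stored neighborhoods of `n_id` and re-index them into
            # a local node set, mirroring the `assoc` bookkeeping above.
            nodes = n_id.tolist()
            local = {n: i for i, n in enumerate(nodes)}
            rows, cols, ts, msgs = [], [], [], []
            for n in n_id.tolist():
                for nbr, time, m in self.store[n]:
                    if nbr not in local:
                        local[nbr] = len(nodes)
                        nodes.append(nbr)
                    rows.append(local[nbr])  # neighbor acts as edge source.
                    cols.append(local[n])    # center node acts as edge target.
                    ts.append(time)
                    msgs.append(m)
            n_id = torch.tensor(nodes, dtype=torch.long, device=self.device)
            if len(rows) == 0:
                edge_index = torch.empty(2, 0, dtype=torch.long, device=self.device)
                t = torch.empty(0, dtype=torch.long, device=self.device)
                msg = torch.empty(0, self.msg_dim, device=self.device)
            else:
                edge_index = torch.tensor([rows, cols], device=self.device)
                t = torch.tensor(ts, device=self.device)
                msg = torch.stack(msgs).to(self.device)
            return n_id, edge_index, t, msg

Usage would mirror the training loop above: `insert(src, pos_dst, t, msg)` once per direction after every batch, and `n_id, edge_index, t_edge, msg_edge = loader(n_id)` before every memory/GNN call. A production implementation would presumably use fixed-size tensors rather than Python lists; this sketch only pins down the semantics.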
4 changes: 2 additions & 2 deletions torch_geometric/nn/models/__init__.py
@@ -10,7 +10,7 @@
 from .gnn_explainer import GNNExplainer
 from .metapath2vec import MetaPath2Vec
 from .deepgcn import DeepGCNLayer
-from .tgn import TGN
+from .tgn import TGNMemory

 __all__ = [
     'JumpingKnowledge',
@@ -29,5 +29,5 @@
     'GNNExplainer',
     'MetaPath2Vec',
     'DeepGCNLayer',
-    'TGN',
+    'TGNMemory',
 ]
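A minimal smoke test of the renamed export is sketched below. It assumes only the constructor signature used in `examples/tgn.py` above; the concrete numbers (`num_nodes`, the 172-dimensional raw message size matching JODIE/wikipedia) are placeholder assumptions:

    import torch

    from torch_geometric.nn import TGNMemory
    from torch_geometric.nn.models.tgn import IdentityMessage, LastAggregator

    num_nodes, raw_msg_dim, memory_dim, time_dim = 1000, 172, 100, 100

    memory = TGNMemory(
        num_nodes, raw_msg_dim, memory_dim, time_dim,
        message_module=IdentityMessage(raw_msg_dim, memory_dim, time_dim),
        aggregator_module=LastAggregator(),
    )

    n_id = torch.arange(10)
    z, last_update = memory(n_id)  # Memory states and last-update timestamps.
    print(z.shape)  # Expected: torch.Size([10, 100])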
