Skip to content

Commit

Permalink
updated rgcn conv with memory-efficient aggregations
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jun 30, 2020
1 parent 271146a commit 776f891
Show file tree
Hide file tree
Showing 5 changed files with 397 additions and 83 deletions.
66 changes: 47 additions & 19 deletions examples/rgcn.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,81 @@
import argparse
import os.path as osp

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Entities
from torch_geometric.nn import RGCNConv
from torch_geometric.utils import k_hop_subgraph
from torch_geometric.nn import RGCNConv, FastRGCNConv

# Select the dataset on the command line. MUTAG is the default so the script
# still runs when it is invoked without any arguments.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='MUTAG',
                    choices=['AIFB', 'MUTAG', 'BGS', 'AM'])
args = parser.parse_args()

# Trade memory consumption for faster computation.
if args.dataset in ['AIFB', 'MUTAG']:
    RGCNConv = FastRGCNConv

# All Entities datasets share one root directory; the dataset name is passed
# separately to `Entities`.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Entities')
dataset = Entities(path, args.dataset)
data = dataset[0]

# The BGS and AM graphs are too big to be processed in a full-batch fashion.
# Since our model only makes use of a rather small receptive field, we filter
# the graph to contain only the nodes that are at most 2 hops away from any
# training/test node.
# Seed the 2-hop subgraph extraction with every training and test node
# (training nodes first, test nodes second).
node_idx = torch.cat([data.train_idx, data.test_idx], dim=0)
node_idx, edge_index, mapping, edge_mask = k_hop_subgraph(
    node_idx, 2, data.edge_index, relabel_nodes=True)

# Replace the full graph stored on `data` with the extracted subgraph.
data.num_nodes = node_idx.shape[0]
data.edge_index = edge_index
data.edge_type = data.edge_type[edge_mask]

# `mapping` follows the order of the concatenated seed nodes, so its leading
# entries belong to the training nodes and the remainder to the test nodes.
# Both slices are taken before either attribute is overwritten.
data.train_idx, data.test_idx = (
    mapping[:data.train_idx.size(0)],
    mapping[data.train_idx.size(0):],
)


class Net(torch.nn.Module):
    """Two-layer relational GCN that classifies nodes without input features.

    The first layer is built with ``data.num_nodes`` input channels and is
    called with ``None`` as its node features; the second layer maps the 16
    hidden channels to ``dataset.num_classes`` outputs.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = RGCNConv(data.num_nodes, 16, dataset.num_relations,
                              num_bases=30)
        self.conv2 = RGCNConv(16, dataset.num_classes, dataset.num_relations,
                              num_bases=30)

    def forward(self, edge_index, edge_type):
        """Return per-node log-probabilities over the classes.

        Args:
            edge_index: Edge indices of the graph.
            edge_type: Relation type of every edge in ``edge_index``.
        """
        # No node feature matrix is supplied to the first layer.
        x = F.relu(self.conv1(None, edge_index, edge_type))
        x = self.conv2(x, edge_index, edge_type)
        return F.log_softmax(x, dim=1)


# AM is always placed on the CPU; every other dataset uses CUDA when it is
# available.
if args.dataset != 'AM' and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model, data = Net().to(device), data.to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.01,
    weight_decay=0.0005,
)


def train():
    """Run one full-batch optimisation step.

    Returns:
        The negative log-likelihood loss on the training nodes, as a Python
        float.
    """
    model.train()
    optimizer.zero_grad()
    out = model(data.edge_index, data.edge_type)
    # Only the training nodes contribute to the loss.
    loss = F.nll_loss(out[data.train_idx], data.train_y)
    loss.backward()
    optimizer.step()
    return loss.item()


@torch.no_grad()
def test():
    """Evaluate the model on the training and test nodes.

    Returns:
        A ``(train_acc, test_acc)`` tuple of Python floats, each the fraction
        of correctly classified nodes in the respective split.
    """
    model.eval()
    # A single forward pass yields the predicted class of every node.
    pred = model(data.edge_index, data.edge_type).argmax(dim=-1)
    train_acc = pred[data.train_idx].eq(data.train_y).to(torch.float).mean()
    test_acc = pred[data.test_idx].eq(data.test_y).to(torch.float).mean()
    return train_acc.item(), test_acc.item()


# Train for 50 epochs: one optimisation step per epoch, followed by an
# evaluation on both splits.
for epoch in range(1, 51):
    loss = train()
    train_acc, test_acc = test()
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f} '
          f'Test: {test_acc:.4f}')
111 changes: 97 additions & 14 deletions test/nn/conv/test_rgcn_conv.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,102 @@
from itertools import product

import pytest
import torch
from torch_geometric.nn import RGCNConv
from torch_sparse import SparseTensor
from torch_geometric.nn import RGCNConv, FastRGCNConv

classes = [RGCNConv, FastRGCNConv]
confs = [(None, None), (2, None), (None, 2)]


@pytest.mark.parametrize('conf', confs)
def test_rgcn_conv_equality(conf):
    """Check that RGCNConv and FastRGCNConv compute the same output.

    Args:
        conf: A ``(num_bases, num_blocks)`` pair forwarded to both layers.
    """
    num_bases, num_blocks = conf

    x1 = torch.randn(4, 4)
    # Twelve edges spread over four relation types (0-3).
    edge_index = torch.tensor([
        [0, 1, 1, 2, 2, 3, 0, 1, 1, 2, 2, 3],
        [0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1],
    ])
    edge_type = torch.tensor([0, 1, 1, 0, 0, 1, 2, 3, 3, 2, 2, 3])

    # Re-seed before constructing each layer; presumably this makes both
    # layers start from identical parameters.
    torch.manual_seed(12345)
    conv1 = RGCNConv(4, 32, 4, num_bases, num_blocks)

    torch.manual_seed(12345)
    conv2 = FastRGCNConv(4, 32, 4, num_bases, num_blocks)

    out1 = conv1(x1, edge_index, edge_type)
    out2 = conv2(x1, edge_index, edge_type)
    # The two layers may accumulate in a different order, so allow a small
    # absolute float32 rounding error instead of the 1e-8 default.
    assert torch.allclose(out1, out2, atol=1e-6)


@pytest.mark.parametrize('cls,conf', product(classes, confs))
def test_rgcn_conv(cls, conf):
    """Exercise an R-GCN layer class with dense and sparse graph inputs.

    Every call is checked in eager mode and through TorchScript, first with a
    single node set and then with a pair of source/target node sets.

    Args:
        cls: The layer class under test.
        conf: A ``(num_bases, num_blocks)`` pair forwarded to the layer.
    """
    num_bases, num_blocks = conf

    x1 = torch.randn(4, 4)
    x2 = torch.randn(2, 16)
    idx1 = torch.arange(4)
    idx2 = torch.arange(2)
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [0, 0, 1, 0, 1, 1]])
    edge_type = torch.tensor([0, 1, 1, 0, 0, 1])
    row, col = edge_index
    # Same graph as `edge_index`, with the relation type stored as the value.
    adj = SparseTensor(row=row, col=col, value=edge_type, sparse_sizes=(4, 4))

    # --- Single node set -------------------------------------------------
    conv = cls(4, 32, 2, num_bases, num_blocks)
    assert conv.__repr__() == f'{cls.__name__}(4, 32, num_relations=2)'
    out1 = conv(x1, edge_index, edge_type)
    assert out1.size() == (4, 32)
    assert conv(x1, adj.t()).tolist() == out1.tolist()

    # Calls without a feature matrix are only exercised when `num_blocks`
    # is not set.
    if num_blocks is None:
        out2 = conv(None, edge_index, edge_type)
        assert out2.size() == (4, 32)
        assert conv(None, adj.t()).tolist() == out2.tolist()

    t = '(OptTensor, Tensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit(x1, edge_index, edge_type).tolist() == out1.tolist()
    if num_blocks is None:
        assert jit(idx1, edge_index, edge_type).tolist() == out2.tolist()
        assert jit(None, edge_index, edge_type).tolist() == out2.tolist()

    t = '(OptTensor, SparseTensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit(x1, adj.t()).tolist() == out1.tolist()
    if num_blocks is None:
        assert jit(idx1, adj.t()).tolist() == out2.tolist()
        assert jit(None, adj.t()).tolist() == out2.tolist()

    # --- Pair of node sets -----------------------------------------------
    adj = adj.sparse_resize((4, 2))
    conv = cls((4, 16), 32, 2, num_bases, num_blocks)
    assert conv.__repr__() == f'{cls.__name__}((4, 16), 32, num_relations=2)'
    out1 = conv((x1, x2), edge_index, edge_type)
    assert out1.size() == (2, 32)
    assert conv((x1, x2), adj.t()).tolist() == out1.tolist()

    if num_blocks is None:
        out2 = conv((None, idx2), edge_index, edge_type)
        assert out2.size() == (2, 32)
        assert torch.allclose(conv((idx1, idx2), edge_index, edge_type), out2)
        assert conv((None, idx2), adj.t()).tolist() == out2.tolist()
        assert conv((idx1, idx2), adj.t()).tolist() == out2.tolist()

    t = '(Tuple[OptTensor, Tensor], Tensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit((x1, x2), edge_index, edge_type).tolist() == out1.tolist()
    if num_blocks is None:
        assert torch.allclose(jit((None, idx2), edge_index, edge_type), out2)
        assert torch.allclose(jit((idx1, idx2), edge_index, edge_type), out2)

    t = '(Tuple[OptTensor, Tensor], SparseTensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit((x1, x2), adj.t()).tolist() == out1.tolist()
    if num_blocks is None:
        assert jit((None, idx2), adj.t()).tolist() == out2.tolist()
        assert jit((idx1, idx2), adj.t()).tolist() == out2.tolist()
18 changes: 9 additions & 9 deletions torch_geometric/datasets/entities.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os
import os.path as osp
from collections import Counter

import gzip
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch_scatter import scatter_add

from torch_geometric.data import (InMemoryDataset, Data, download_url,
extract_tar)
Expand Down Expand Up @@ -40,6 +39,14 @@ def __init__(self, root, name, transform=None, pre_transform=None):
super(Entities, self).__init__(root, transform, pre_transform)
self.data, self.slices = torch.load(self.processed_paths[0])

@property
def raw_dir(self):
    """Return the ``raw`` directory located under ``<root>/<name>``."""
    dataset_root = osp.join(self.root, self.name)
    return osp.join(dataset_root, 'raw')

@property
def processed_dir(self):
    """Return the ``processed`` directory located under ``<root>/<name>``."""
    dataset_root = osp.join(self.root, self.name)
    return osp.join(dataset_root, 'processed')

@property
def num_relations(self):
    """Return the largest value in ``data.edge_type`` plus one."""
    highest_type = self.data.edge_type.max().item()
    return highest_type + 1
Expand Down Expand Up @@ -102,12 +109,6 @@ def freq(rel):
edge = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
edge_index, edge_type = edge[:2], edge[2]

oh = F.one_hot(edge_type,
num_classes=2 * len(relations)).to(torch.float)
deg = scatter_add(oh, edge_index[0], dim=0, dim_size=len(nodes))
index = edge_type + torch.arange(len(edge_list)) * 2 * len(relations)
edge_norm = 1 / deg[edge_index[0]].view(-1)[index]

if self.name == 'am':
label_header = 'label_cateogory'
nodes_header = 'proxy'
Expand Down Expand Up @@ -148,7 +149,6 @@ def freq(rel):

data = Data(edge_index=edge_index)
data.edge_type = edge_type
data.edge_norm = edge_norm
data.train_idx = train_idx
data.train_y = train_y
data.test_idx = test_idx
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/nn/conv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .sg_conv import SGConv
from .appnp import APPNP
from .mf_conv import MFConv
from .rgcn_conv import RGCNConv
from .rgcn_conv import RGCNConv, FastRGCNConv
from .signed_conv import SignedConv
from .dna_conv import DNAConv
from .point_conv import PointConv
Expand Down Expand Up @@ -46,6 +46,7 @@
'APPNP',
'MFConv',
'RGCNConv',
'FastRGCNConv',
'SignedConv',
'DNAConv',
'PointConv',
Expand Down

0 comments on commit 776f891

Please sign in to comment.