Merge pull request #484 from bknyaz/master
Synthetic COLORS and TRIANGLES datasets and Threshold-based Top-K and SAG pooling
rusty1s committed Jul 6, 2019
2 parents 5847d09 + 8400c3f commit 55c2e7d
Showing 13 changed files with 507 additions and 66 deletions.
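The headline change adds an optional `min_score` argument to `TopKPooling` and `SAGPooling`: instead of keeping a fixed `ratio` of nodes per graph, the layer keeps the nodes whose attention score clears the threshold, and both layers now return the scores as a sixth value. A minimal sketch of the new call convention, following the examples below (graph and feature values are illustrative):

```python
import torch
from torch_geometric.nn import TopKPooling

pool = TopKPooling(in_channels=16, min_score=0.05)  # threshold-based variant
x = torch.randn(10, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
batch = torch.zeros(10, dtype=torch.long)

# The return value grew from a 5-tuple to a 6-tuple: `perm` holds the
# indices of the kept nodes and `score` their attention scores.
x, edge_index, edge_attr, batch, perm, score = pool(x, edge_index, None, batch)
```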
4 changes: 2 additions & 2 deletions README.md
@@ -88,8 +88,8 @@ In detail, the following methods are currently implemented:
 * **[Dense Differentiable Pooling](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.dense.diff_pool.dense_diff_pool)** from Ying *et al.*: [Hierarchical Graph Representation Learning with Differentiable Pooling](https://arxiv.org/abs/1806.08804) (NeurIPS 2018)
 * **[Graclus Pooling](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.pool.graclus)** from Dhillon *et al.*: [Weighted Graph Cuts without Eigenvectors: A Multilevel Approach](http://www.cs.utexas.edu/users/inderjit/public_papers/multilevel_pami.pdf) (PAMI 2007)
 * **[Voxel Grid Pooling](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.pool.voxel_grid)** from, *e.g.*, Simonovsky and Komodakis: [Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs](https://arxiv.org/abs/1704.02901) (CVPR 2017)
-* **[Top-K Pooling](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.pool.TopKPooling)** from Gao and Ji: [Graph U-Net](https://openreview.net/forum?id=HJePRoAct7) (ICLR 2019 submission) and Cangea *et al.*: [Towards Sparse Hierarchical Graph Classifiers](https://arxiv.org/abs/1811.01287) (NeurIPS-W 2018)
-* **[SAG Pooling](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.pool.SAGPooling)** from Lee *et al.*: [Self-Attention Graph Pooling](https://arxiv.org/abs/1904.08082) (ICML 2019)
+* **[Top-K Pooling](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.pool.TopKPooling)** from Gao and Ji: [Graph U-Net](https://openreview.net/forum?id=HJePRoAct7) (ICLR 2019 submission), Cangea *et al.*: [Towards Sparse Hierarchical Graph Classifiers](https://arxiv.org/abs/1811.01287) (NeurIPS-W 2018) and Knyazev *et al.*: [Understanding Attention and Generalization in Graph Neural Networks](https://arxiv.org/abs/1905.02850) (ICLR-W 2019)
+* **[SAG Pooling](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.pool.SAGPooling)** from Lee *et al.*: [Self-Attention Graph Pooling](https://arxiv.org/abs/1904.08082) (ICML 2019) and Knyazev *et al.*: [Understanding Attention and Generalization in Graph Neural Networks](https://arxiv.org/abs/1905.02850) (ICLR-W 2019)
 * **[Local Degree Profile](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.transforms.LocalDegreeProfile)** from Cai and Wang: [A Simple yet Effective Baseline for Non-attribute Graph Classification](https://arxiv.org/abs/1811.03508) (CoRR 2018)
 * **[Jumping Knowledge](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.models.JumpingKnowledge)** from Xu *et al.*: [Representation Learning on Graphs with Jumping Knowledge Networks](https://arxiv.org/abs/1806.03536) (ICML 2018)
 * **[Deep Graph Infomax](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.models.DeepGraphInfomax)** from Veličković *et al.*: [Deep Graph Infomax](https://arxiv.org/abs/1809.10341) (ICLR 2019)
134 changes: 134 additions & 0 deletions examples/colors_topk_pool.py
@@ -0,0 +1,134 @@
import os.path as osp
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import SyntheticDataset
from torch_geometric.transforms import HandleNodeAttention
from torch_geometric.data import DataLoader
from torch_geometric.nn import GraphConv, GINConv, TopKPooling
from torch_geometric.nn import global_add_pool as gsum
from torch_scatter import scatter_mean


train_path = osp.join(osp.dirname(osp.realpath(__file__)), '..',
                      'data', 'COLORS-3')
dataset = SyntheticDataset(train_path, name='COLORS-3', use_node_attr=True,
                           transform=HandleNodeAttention())

n_train, n_val, n_test_each = 500, 2500, 2500

train_dataset = dataset[:n_train]
train_loader = DataLoader(train_dataset, batch_size=60, shuffle=True)
val_loader = DataLoader(dataset[n_train:n_train + n_val], batch_size=60)
test_loader = DataLoader(dataset[n_train + n_val:], batch_size=60)


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.conv1 = GINConv(
            nn.Sequential(nn.Linear(train_dataset.num_features, 256),
                          nn.ReLU(),
                          nn.Linear(256, 64)))
        self.pool1 = TopKPooling(train_dataset.num_features, min_score=0.05)
        self.conv2 = GINConv(nn.Sequential(nn.Linear(64, 256),
                                           nn.ReLU(),
                                           nn.Linear(256, 64)))

        self.lin = torch.nn.Linear(64, 1)  # regression

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x_input = x
        x = F.relu(self.conv1(x_input, edge_index))

        # Scores are computed from the raw input features (attn_input);
        # with min_score set, nodes scoring below the threshold are dropped.
        x, edge_index, _, batch, perm, score = self.pool1(
            x, edge_index, None, batch, attn_input=x_input)
        ratio = x.shape[0] / float(x_input.shape[0])  # fraction of kept nodes

        x = F.relu(self.conv2(x, edge_index))
        x = gsum(x, batch)
        x = self.lin(x)

        # Supervised node attention: KL divergence between predicted scores
        # and ground-truth attention, averaged per graph.
        attn_loss_batch = scatter_mean(
            F.kl_div(torch.log(score + 1e-14), data.node_attention[perm],
                     reduction='none'), batch)

        return x, attn_loss_batch, ratio


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
# Initialize to optimal attention weights:
# model.pool1.weight.data = torch.tensor([0., 1., 0., 0.]).view(1, 4).to(device)

print(model)
print('model size: %d trainable parameters' %
      np.sum([np.prod(p.size()) if p.requires_grad else 0
              for p in model.parameters()]))

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


def train(epoch):
    model.train()

    loss_all = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output, attn_loss, _ = model(data)
        loss = ((data.y - output.view_as(data.y)) ** 2 +
                100 * attn_loss).mean()

        loss.backward()
        loss_all += data.num_graphs * loss.item()
        optimizer.step()

    return loss_all / len(train_dataset)


def test(loader):
    model.eval()

    correct, ratio_all = [], 0
    for data in loader:
        data = data.to(device)
        output, _, ratio = model(data)
        pred = output.round().long().view_as(data.y)
        correct += list(pred.eq(data.y.long()).data.cpu().numpy())
        ratio_all += ratio
    return np.array(correct), ratio_all / len(loader)


for epoch in range(1, 301):
    loss = train(epoch)
    train_correct, train_ratio = test(train_loader)
    val_correct, val_ratio = test(val_loader)
    test_correct, test_ratio = test(test_loader)

    train_acc = train_correct.sum() / len(train_correct)
    val_acc = val_correct.sum() / len(val_correct)

    # Test on three different subsets.
    test_correct1 = test_correct[:n_test_each].sum()
    test_correct2 = test_correct[n_test_each:2 * n_test_each].sum()
    test_correct3 = test_correct[2 * n_test_each:].sum()
    assert len(test_correct) == 3 * n_test_each, len(test_correct)

    print('Epoch: {:03d}, Loss: {:.5f}, Train Acc: {:.3f}, Val Acc: {:.3f}, '
          'Test Acc Orig: {:.3f} ({}/{}), '
          'Test Acc Large: {:.3f} ({}/{}), '
          'Test Acc LargeC: {:.3f} ({}/{}), '
          'Train/Val/Test Pool Ratio={:.3f}/{:.3f}/{:.3f}'.format(
              epoch, loss, train_acc, val_acc,
              test_correct1 / n_test_each, test_correct1, n_test_each,
              test_correct2 / n_test_each, test_correct2, n_test_each,
              test_correct3 / n_test_each, test_correct3, n_test_each,
              train_ratio, val_ratio, test_ratio))
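The supervised attention term above is worth unpacking: `F.kl_div` with `reduction='none'` returns the pointwise contribution `t * (log t - log s)` for target attention `t` and predicted score `s`, and `scatter_mean` then averages these per graph. A toy check with hypothetical values:

```python
import torch
import torch.nn.functional as F
from torch_scatter import scatter_mean

# Two graphs of three nodes each; targets are ground-truth node attention.
score = torch.tensor([0.7, 0.2, 0.1, 0.1, 0.8, 0.1])   # predicted scores
target = torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0])  # true attention
batch = torch.tensor([0, 0, 0, 1, 1, 1])               # graph assignment

per_node = F.kl_div(torch.log(score + 1e-14), target, reduction='none')
per_graph = scatter_mean(per_node, batch)  # one loss value per graph
print(per_graph)  # tensor([0.1189, 0.0744]), i.e. -log(0.7)/3 and -log(0.8)/3
```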
6 changes: 3 additions & 3 deletions examples/enzymes_topk_pool.py
@@ -36,15 +36,15 @@ def forward(self, data):
         x, edge_index, batch = data.x, data.edge_index, data.batch
 
         x = F.relu(self.conv1(x, edge_index))
-        x, edge_index, _, batch, _ = self.pool1(x, edge_index, None, batch)
+        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
         x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
 
         x = F.relu(self.conv2(x, edge_index))
-        x, edge_index, _, batch, _ = self.pool2(x, edge_index, None, batch)
+        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
         x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
 
         x = F.relu(self.conv3(x, edge_index))
-        x, edge_index, _, batch, _ = self.pool3(x, edge_index, None, batch)
+        x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)
         x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
 
         x = x1 + x2 + x3
139 changes: 139 additions & 0 deletions examples/triangles_sag_pool.py
@@ -0,0 +1,139 @@
import os.path as osp
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import SyntheticDataset
from torch_geometric.transforms import OneHotDegree, HandleNodeAttention
from torch_geometric.transforms import Compose
from torch_geometric.data import DataLoader
from torch_geometric.nn import GraphConv, GINConv, SAGPooling
from torch_geometric.nn import global_max_pool as gmp
from torch_scatter import scatter_mean


transform = Compose([HandleNodeAttention(), OneHotDegree(max_degree=14)])

train_path = osp.join(osp.dirname(osp.realpath(__file__)),
                      '..', 'data', 'TRIANGLES')
dataset = SyntheticDataset(train_path, name='TRIANGLES', use_node_attr=True,
                           transform=transform)

n_train, n_val, n_test_each = 30000, 5000, 5000

train_dataset = dataset[:n_train]
train_loader = DataLoader(train_dataset, batch_size=60, shuffle=True)
val_loader = DataLoader(dataset[n_train:n_train + n_val], batch_size=60)
test_loader = DataLoader(dataset[n_train + n_val:], batch_size=60)


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.conv1 = GINConv(
            nn.Sequential(nn.Linear(train_dataset.num_features, 64),
                          nn.ReLU(),
                          nn.Linear(64, 64)))
        self.pool1 = SAGPooling(64, min_score=0.001, gnn='GCN')
        self.conv2 = GINConv(nn.Sequential(nn.Linear(64, 64),
                                           nn.ReLU(),
                                           nn.Linear(64, 64)))
        self.pool2 = SAGPooling(64, min_score=0.001, gnn='GCN')
        self.conv3 = GINConv(nn.Sequential(nn.Linear(64, 64),
                                           nn.ReLU(),
                                           nn.Linear(64, 64)))

        self.lin = torch.nn.Linear(64, 1)  # regression

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, perm1, score = self.pool1(x, edge_index,
                                                           None, batch)
        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, perm2, score = self.pool2(x, edge_index,
                                                           None, batch)
        ratio = x.shape[0] / float(data.x.shape[0])  # fraction of kept nodes

        x = F.relu(self.conv3(x, edge_index))
        x = gmp(x, batch)

        x = self.lin(x)

        # Supervised node attention: `perm2` indexes into the node set kept
        # by `pool1`, so compose the two permutations to look up the
        # ground-truth attention of the nodes that survive both poolings.
        attn_loss_batch = scatter_mean(
            F.kl_div(torch.log(score + 1e-14),
                     data.node_attention[perm1][perm2],
                     reduction='none'), batch)

        return x, attn_loss_batch, ratio


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
print(model)
print('model size: %d trainable parameters' %
      np.sum([np.prod(p.size()) if p.requires_grad else 0
              for p in model.parameters()]))

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


def train(epoch):
    model.train()

    loss_all = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output, attn_loss, _ = model(data)

        loss = ((data.y - output.view_as(data.y)) ** 2 +
                100 * attn_loss).mean()

        loss.backward()
        loss_all += data.num_graphs * loss.item()
        optimizer.step()

    return loss_all / len(train_dataset)


def test(loader):
    model.eval()

    correct, ratio_all = [], 0
    for data in loader:
        data = data.to(device)
        output, _, ratio = model(data)
        pred = output.round().long().view_as(data.y)
        correct += list(pred.eq(data.y.long()).data.cpu().numpy())
        ratio_all += ratio
    return np.array(correct), ratio_all / len(loader)


for epoch in range(1, 101):
    loss = train(epoch)
    train_correct, train_ratio = test(train_loader)
    val_correct, val_ratio = test(val_loader)
    test_correct, test_ratio = test(test_loader)

    train_acc = train_correct.sum() / len(train_correct)
    val_acc = val_correct.sum() / len(val_correct)

    # Test on two different subsets.
    test_correct1 = test_correct[:n_test_each].sum()
    test_correct2 = test_correct[n_test_each:].sum()
    assert len(test_correct) == 2 * n_test_each, len(test_correct)

    print('Epoch: {:03d}, Loss: {:.5f}, Train Acc: {:.3f}, Val Acc: {:.3f}, '
          'Test Acc Orig: {:.3f} ({}/{}), '
          'Test Acc Large: {:.3f} ({}/{}), '
          'Train/Val/Test Pool Ratio={:.3f}/{:.3f}/{:.3f}'.format(
              epoch, loss, train_acc, val_acc,
              test_correct1 / n_test_each, test_correct1, n_test_each,
              test_correct2 / n_test_each, test_correct2, n_test_each,
              train_ratio, val_ratio, test_ratio))
4 changes: 3 additions & 1 deletion test/nn/pool/test_sag_pool.py
@@ -13,7 +13,9 @@ def test_sag_pooling():
 
     for gnn in ['GraphConv', 'GCN', 'GAT', 'SAGE']:
         pool = SAGPooling(in_channels, ratio=0.5, gnn=gnn)
-        assert pool.__repr__() == 'SAGPooling({}, 16, ratio=0.5)'.format(gnn)
+        assert pool.__repr__() == 'SAGPooling({}, 16, ratio=0.5, ' \
+                                  'min_score=None, ' \
+                                  'multiplier=None)'.format(gnn)
         out = pool(x, edge_index)
         assert out[0].size() == (num_nodes // 2, in_channels)
         assert out[1].size() == (2, 2)
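The loop above covers the four operators `SAGPooling` accepts for its `gnn` argument, which selects the graph convolution that computes the attention scores. A minimal sketch of the threshold-based configuration used by the TRIANGLES example (graph values are illustrative):

```python
import torch
from torch_geometric.nn import SAGPooling

pool = SAGPooling(64, min_score=0.001, gnn='GCN')
x = torch.randn(8, 64)
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7],
                           [1, 2, 3, 4, 5, 6, 7, 0]])

# Same 6-tuple return convention as TopKPooling.
x, edge_index, edge_attr, batch, perm, score = pool(x, edge_index)
```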
5 changes: 3 additions & 2 deletions test/nn/pool/test_topk_pool.py
@@ -34,8 +34,9 @@ def test_topk_pooling():
     x = torch.randn((num_nodes, in_channels))
 
     pool = TopKPooling(in_channels, ratio=0.5)
-    assert pool.__repr__() == 'TopKPooling(16, ratio=0.5)'
+    assert pool.__repr__() == 'TopKPooling(16, ratio=0.5, ' \
+                              'min_score=None, multiplier=None)'
 
-    x, edge_index, _, _, _ = pool(x, edge_index)
+    x, edge_index, _, _, _, _ = pool(x, edge_index)
     assert x.size() == (num_nodes // 2, in_channels)
     assert edge_index.size() == (2, 2)
2 changes: 2 additions & 0 deletions torch_geometric/datasets/__init__.py
@@ -24,6 +24,7 @@
 from .bitcoin_otc import BitcoinOTC
 from .icews import ICEWS18
 from .gdelt import GDELT
+from .synthetic_dataset import SyntheticDataset
 
 __all__ = [
     'KarateClub',
@@ -52,4 +53,5 @@
     'BitcoinOTC',
     'ICEWS18',
     'GDELT',
+    'SyntheticDataset',
 ]
28 changes: 28 additions & 0 deletions torch_geometric/datasets/synthetic_dataset.py
@@ -0,0 +1,28 @@
from ..datasets import TUDataset


class SyntheticDataset(TUDataset):
    r"""Synthetic COLORS and TRIANGLES datasets from the `"Understanding
    Attention and Generalization in Graph Neural Networks"
    <https://arxiv.org/abs/1905.02850>`_ paper.
    The datasets have the same format as :class:`TUDataset`, but carry
    additional ground-truth node attention data; this class accepts the
    same arguments as :class:`TUDataset`.
    """

    url = 'https://github.com/bknyaz/graph_attention_pool/raw/master/data'

    def __init__(self, root, name, transform=None, pre_transform=None,
                 pre_filter=None, use_node_attr=False):
        self.name = name
        super(SyntheticDataset, self).__init__(root, name, transform,
                                               pre_transform, pre_filter,
                                               use_node_attr=use_node_attr)

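A usage sketch for the new dataset class, modeled on the examples above (the root path is illustrative; like :class:`TUDataset`, the data is downloaded from `url` on first use):

```python
from torch_geometric.datasets import SyntheticDataset
from torch_geometric.transforms import HandleNodeAttention

# HandleNodeAttention (used by both examples above) exposes the per-node
# ground-truth attention as `data.node_attention`.
dataset = SyntheticDataset('data/TRIANGLES', name='TRIANGLES',
                           use_node_attr=True,
                           transform=HandleNodeAttention())
print(len(dataset), dataset.num_features)
print(dataset[0])
```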