clean up
rusty1s committed Jul 6, 2019
1 parent 3935ccf commit 349d874
Showing 1 changed file with 52 additions and 81 deletions: examples/colors_topk_pool.py
@@ -2,120 +2,96 @@
 import numpy as np
 
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn import Sequential as Seq, Linear as Lin, ReLU
 from torch_geometric.datasets import TUDataset
 from torch_geometric.data import DataLoader
 from torch_geometric.nn import GINConv, TopKPooling
-from torch_geometric.nn import global_add_pool as gsum
+from torch_geometric.nn import global_add_pool
 from torch_scatter import scatter_mean
 
 
 class HandleNodeAttention(object):
     def __call__(self, data):
-        if data.x.dim() == 1:
-            data.x = data.x.unsqueeze(-1)
-        data.node_attention = torch.softmax(data.x[:, 0], dim=0)
-        if data.x.shape[1] > 1:
-            data.x = data.x[:, 1:]
-        else:
-            # not supposed to use node attention as node features,
-            # because it is typically not available in the val/test set
-            data.x = None
-
+        data.attn = torch.softmax(data.x[:, 0], dim=0)
+        data.x = data.x[:, 1:]
         return data
 
 
-train_path = osp.join(
-    osp.dirname(osp.realpath(__file__)), '..', 'data', 'COLORS-3')
-dataset = TUDataset(train_path, name='COLORS-3', use_node_attr=True,
+path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'COLORS-3')
+dataset = TUDataset(path, 'COLORS-3', use_node_attr=True,
                     transform=HandleNodeAttention())
 
-n_train, n_val, n_test_each = 500, 2500, 2500
-
-train_dataset = dataset[:n_train]
-train_loader = DataLoader(train_dataset, batch_size=60, shuffle=True)
-val_loader = DataLoader(dataset[n_train:n_train + n_val], batch_size=60)
-test_loader = DataLoader(dataset[n_train + n_val:], batch_size=60)
+train_loader = DataLoader(dataset[:500], batch_size=60, shuffle=True)
+val_loader = DataLoader(dataset[500:3000], batch_size=60)
+test_loader = DataLoader(dataset[3000:], batch_size=60)
 
 
 class Net(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, in_channels):
         super(Net, self).__init__()
 
-        self.conv1 = GINConv(
-            nn.Sequential(
-                nn.Linear(train_dataset.num_features, 256), nn.ReLU(),
-                nn.Linear(256, 64)))
-        self.pool1 = TopKPooling(train_dataset.num_features, min_score=0.05)
-        self.conv2 = GINConv(
-            nn.Sequential(nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, 64)))
+        self.conv1 = GINConv(Seq(Lin(in_channels, 256), ReLU(), Lin(256, 64)))
+        self.pool1 = TopKPooling(in_channels, min_score=0.05)
+        self.conv2 = GINConv(Seq(Lin(64, 256), ReLU(), Lin(256, 64)))
 
-        self.lin = torch.nn.Linear(64, 1)  # regression
+        self.lin = torch.nn.Linear(64, 1)
 
     def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
 
-        x_input = x
-        x = F.relu(self.conv1(x_input, edge_index))
+        out = F.relu(self.conv1(x, edge_index))
 
-        x, edge_index, _, batch, perm, score = self.pool1(
-            x, edge_index, None, batch, attn_input=x_input)
-        ratio = x.shape[0] / float(x_input.shape[0])
+        out, edge_index, _, batch, perm, score = self.pool1(
+            out, edge_index, None, batch, attn_input=x)
+        ratio = out.size(0) / x.size(0)
 
-        x = F.relu(self.conv2(x, edge_index))
-        x = gsum(x, batch)
-        x = self.lin(x)
+        out = F.relu(self.conv2(out, edge_index))
+        out = global_add_pool(out, batch)
+        out = self.lin(out).view(-1)
 
-        # supervised node attention
-        attn_loss_batch = scatter_mean(
-            F.kl_div(
-                torch.log(score + 1e-14), data.node_attention[perm],
-                reduction='none'), batch)
+        attn_loss = F.kl_div(
+            torch.log(score + 1e-14), data.attn[perm], reduction='none')
+        attn_loss = scatter_mean(attn_loss, batch)
 
-        return x, attn_loss_batch, ratio
+        return out, attn_loss, ratio
 
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-model = Net().to(device)
-
-# Initialize to optimal attention weights:
-# model.pool1.weight.data = torch.tensor([0., 1., 0., 0.]).view(1,4).to(device)
-
-print(model)
-print('model size: %d trainable parameters' % np.sum(
-    [np.prod(p.size()) if p.requires_grad else 0 for p in model.parameters()]))
-
-optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+model = Net(dataset.num_features).to(device)
+optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
 
 
 def train(epoch):
     model.train()
 
-    loss_all = 0
+    total_loss = 0
     for data in train_loader:
         data = data.to(device)
         optimizer.zero_grad()
-        output, attn_loss, _ = model(data)
-        loss = ((data.y - output.view_as(data.y))**2 + 100 * attn_loss).mean()
-
+        out, attn_loss, _ = model(data)
+        loss = ((out - data.y).pow(2) + 100 * attn_loss).mean()
         loss.backward()
-        loss_all += data.num_graphs * loss.item()
+        total_loss += loss.item() * data.num_graphs
         optimizer.step()
-
-    return loss_all / len(train_dataset)
+    return total_loss / len(train_loader.dataset)
 
 
 def test(loader):
     model.eval()
 
-    correct, ratio_all = [], 0
+    corrects, total_ratio = [], 0
     for data in loader:
         data = data.to(device)
-        output, _, ratio = model(data)
-        pred = output.round().long().view_as(data.y)
-        correct += list(pred.eq(data.y.long()).data.cpu().numpy())
-        ratio_all += ratio
-    return np.array(correct), ratio_all / len(loader)
+        out, _, ratio = model(data)
+        pred = out.round().to(torch.long)
+        corrects.append(pred.eq(data.y.to(torch.long)))
+        total_ratio += ratio
+    return torch.cat(corrects, dim=0), total_ratio / len(loader)
 
 
 for epoch in range(1, 301):
@@ -124,21 +100,16 @@ def test(loader):
     val_correct, val_ratio = test(val_loader)
     test_correct, test_ratio = test(test_loader)
 
-    train_acc = train_correct.sum() / len(train_correct)
-    val_acc = val_correct.sum() / len(val_correct)
-
-    # Test on three different subsets
-    test_correct1 = test_correct[:n_test_each].sum()
-    test_correct2 = test_correct[n_test_each:2 * n_test_each].sum()
-    test_correct3 = test_correct[n_test_each * 2:].sum()
-    assert len(test_correct) == n_test_each * 3, len(test_correct)
-
-    print('Epoch: {:03d}, Loss: {:.5f}, Train Acc: {:.3f}, Val Acc: {:.3f}, '
-          'Test Acc Orig: {:.3f} ({}/{}), '
-          'Test Acc Large: {:.3f} ({}/{}), '
-          'Test Acc LargeC: {:.3f} ({}/{}), '
-          'Train/Val/Test Pool Ratio={:.3f}/{:.3f}/{:.3f}'.format(
-              epoch, loss, train_acc, val_acc, test_correct1 / n_test_each,
-              test_correct1, n_test_each, test_correct2 / n_test_each,
-              test_correct2, n_test_each, test_correct3 / n_test_each,
-              test_correct3, n_test_each, train_ratio, val_ratio, test_ratio))
+    train_acc = train_correct.sum().item() / train_correct.size(0)
+    val_acc = val_correct.sum().item() / val_correct.size(0)
+    test_acc = test_correct.sum().item() / test_correct.size(0)
+
+    # Test on three different subsets.
+    test_acc1 = test_correct[:2500].sum().item() / 2500
+    test_acc2 = test_correct[2500:5000].sum().item() / 2500
+    test_acc3 = test_correct[5000:].sum().item() / 2500
+
+    print(('Epoch: {:03d}, Loss: {:.4f}, Train: {:.3f}, Val: {:.3f}, '
+           'Test: {:.3f}/{:.3f}/{:.3f}, '
+           'Train/Val/Test Ratio={:.3f}/{:.3f}/{:.3f}').format(
+               epoch, loss, train_acc, val_acc, test_acc1, test_acc2,
+               test_acc3, train_ratio, val_ratio, test_ratio))
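
A note on the simplified transform: with use_node_attr=True, column 0 of data.x carries the ground-truth node attention of COLORS-3 and the remaining columns are the actual node features, so the old dim()/shape special cases are gone. A minimal sketch of what HandleNodeAttention now does, on a toy graph (the tensors below are illustrative, not taken from the dataset):

import torch
from torch_geometric.data import Data

# Toy graph: 3 nodes, 4 raw columns each; column 0 marks node importance.
x = torch.tensor([[1., 0., 1., 0.],
                  [0., 1., 0., 0.],
                  [2., 0., 0., 1.]])
data = Data(x=x, edge_index=torch.empty((2, 0), dtype=torch.long))

data.attn = torch.softmax(data.x[:, 0], dim=0)  # target attention over nodes
data.x = data.x[:, 1:]                          # columns 1.. stay as features

print(data.attn.sum())  # tensor(1.), a valid distribution over the nodes
print(data.x.size())    # torch.Size([3, 3])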

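The supervised attention loss in forward() trains the TopKPooling scores directly: a KL divergence between the log of the pooled scores and the ground-truth attention of the kept nodes (data.attn[perm]), averaged per graph with scatter_mean so that every graph contributes one term. A standalone sketch with dummy tensors (all shapes and values assumed):

import torch
import torch.nn.functional as F
from torch_scatter import scatter_mean

score = torch.tensor([0.7, 0.3, 0.5, 0.5])   # TopKPooling scores (kept nodes)
target = torch.tensor([0.9, 0.1, 0.5, 0.5])  # data.attn[perm], ground truth
batch = torch.tensor([0, 0, 1, 1])           # graph id of each kept node

# F.kl_div expects log-probabilities; the 1e-14 guards against log(0).
kl = F.kl_div(torch.log(score + 1e-14), target, reduction='none')
attn_loss = scatter_mean(kl, batch)  # one value per graph, shape [2]

# train() then weights this into the objective:
# loss = ((out - data.y).pow(2) + 100 * attn_loss).mean()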
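
On evaluation: the model regresses a single number per graph, so test() rounds the output to the nearest integer before comparing with data.y, and the concatenated correctness vector is sliced into the three COLORS-3 test subsets of 2,500 graphs each (labelled Orig, Large and LargeC in the old print). A small sketch of the rounding step (values assumed):

import torch

out = torch.tensor([1.9, 3.2, 0.4])  # regressed predictions for 3 graphs
y = torch.tensor([2., 3., 1.])       # ground-truth targets

pred = out.round().to(torch.long)    # -> tensor([2, 3, 0])
correct = pred.eq(y.to(torch.long))  # -> tensor([True, True, False])
acc = correct.sum().item() / correct.size(0)  # 2/3 accuracy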