-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing 3 changed files with 259 additions and 30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import os.path as osp | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch_geometric.datasets import MNISTSuperpixels | ||
import torch_geometric.transforms as T | ||
from torch_geometric.data import DataLoader | ||
from torch_geometric.utils import normalized_cut | ||
from torch_geometric.nn import (NNConv, graclus, max_pool, max_pool_x, | ||
global_mean_pool) | ||
|
||
# Resolve the dataset root relative to this script and build the
# MNIST-superpixel loaders; T.Cartesian() attaches relative 2D node
# positions as edge attributes (consumed by NNConv below).
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'MNIST')

cartesian = T.Cartesian()
train_dataset = MNISTSuperpixels(path, True, transform=cartesian)
test_dataset = MNISTSuperpixels(path, False, transform=cartesian)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64)

# Dataset metadata handle (num_features / num_classes) used by Net below.
d = train_dataset.data
|
||
|
||
def normalized_cut_2d(edge_index, pos):
    """Edge weights for graclus pooling: the normalized cut computed from
    the 2D Euclidean length of every edge."""
    src, dst = edge_index
    edge_len = torch.norm(pos[src] - pos[dst], p=2, dim=1)
    return normalized_cut(edge_index, edge_len, num_nodes=pos.size(0))
|
||
|
||
class Net(nn.Module):
    """Two NNConv layers interleaved with graclus/max pooling on MNIST
    superpixel graphs, followed by global mean pooling and a two-layer
    classifier with log-softmax output."""

    def __init__(self):
        super().__init__()
        # Edge networks map the 2D Cartesian edge attributes to the
        # in_channels * out_channels filter weights NNConv expects.
        edge_net1 = nn.Sequential(nn.Linear(2, 25), nn.ReLU(),
                                  nn.Linear(25, 32))
        self.conv1 = NNConv(d.num_features, 32, edge_net1)

        edge_net2 = nn.Sequential(nn.Linear(2, 25), nn.ReLU(),
                                  nn.Linear(25, 2048))  # 2048 = 32 * 64
        self.conv2 = NNConv(32, 64, edge_net2)

        self.fc1 = nn.Linear(64, 128)
        self.fc2 = nn.Linear(128, d.num_classes)

    def forward(self, data):
        data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr))
        cut = normalized_cut_2d(data.edge_index, data.pos)
        cluster = graclus(data.edge_index, cut, data.x.size(0))
        # Coarsen the whole Data object and recompute Cartesian edge attrs
        # for the pooled graph (cat=False replaces rather than appends).
        data = max_pool(cluster, data, transform=T.Cartesian(cat=False))

        data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr))
        cut = normalized_cut_2d(data.edge_index, data.pos)
        cluster = graclus(data.edge_index, cut, data.x.size(0))
        x, batch = max_pool_x(cluster, data.x, data.batch)

        x = global_mean_pool(x, batch)
        x = F.elu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        return F.log_softmax(self.fc2(x), dim=1)
|
||
|
||
# Train on GPU when available; Adam starts at lr=0.01, which train()
# decays manually at fixed epochs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
|
||
|
||
def train(epoch):
    """Run one training epoch over `train_loader`.

    Implements a manual step LR schedule: 0.01 until epoch 16,
    0.001 until epoch 26, then 0.0001.
    """
    model.train()

    lr_schedule = {16: 0.001, 26: 0.0001}
    if epoch in lr_schedule:
        for group in optimizer.param_groups:
            group['lr'] = lr_schedule[epoch]

    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(batch), batch.y)
        loss.backward()
        optimizer.step()
|
||
|
||
def test():
    """Evaluate the model on the full test set.

    Returns:
        float: top-1 classification accuracy in [0, 1].
    """
    model.eval()
    correct = 0

    # no_grad: evaluation needs no autograd graph, saving memory/compute.
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            pred = model(data).max(1)[1]
            correct += pred.eq(data.y).sum().item()
    return correct / len(test_dataset)
|
||
|
||
# Train for 30 epochs, reporting test accuracy after each one
# (the LR decay happens inside train() at epochs 16 and 26).
for epoch in range(1, 31):
    train(epoch)
    test_acc = test()
    print('Epoch: {:02d}, Test: {:.4f}'.format(epoch, test_acc))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
import os.path as osp | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
from torch.nn import Sequential, Linear, ReLU, GRU | ||
|
||
import torch_geometric.transforms as T | ||
from torch_geometric.datasets import QM9 | ||
from torch_geometric.nn import NNConv, Set2Set | ||
from torch_geometric.data import DataLoader | ||
from torch_geometric.utils import remove_self_loops | ||
|
||
# Column index of the regression target selected from data.y.
target = 0
# Node-feature width after zero-padding; also the model's hidden size.
dim = 73
|
||
|
||
class MyTransform(object):
    """Zero-pad node features to `dim` columns and keep only the single
    regression target column `target` of data.y."""

    def __call__(self, data):
        # Pad features on the right with zeros up to `dim` columns.
        x = data.x
        pad_width = dim - x.size(1)
        data.x = torch.cat([x, x.new_zeros(x.size(0), pad_width)], dim=1)

        # Keep only the target column we regress on.
        data.y = data.y[:, target]
        return data
|
||
|
||
class Complete(object):
    """Replace the graph's connectivity with a complete directed graph
    (self-loops removed), carrying existing edge attributes over to their
    positions in the dense layout; absent edges get zero attributes."""

    def __call__(self, data):
        device = data.edge_index.device
        n = data.num_nodes

        # Enumerate every (row, col) pair of the complete graph on n nodes.
        nodes = torch.arange(n, dtype=torch.long, device=device)
        row = nodes.view(-1, 1).repeat(1, n).view(-1)
        col = nodes.repeat(n)
        edge_index = torch.stack([row, col], dim=0)

        edge_attr = None
        if data.edge_attr is not None:
            # Flatten (src, dst) into a single index of the dense n*n layout
            # and scatter the original attributes there.
            flat_idx = data.edge_index[0] * n + data.edge_index[1]
            size = list(data.edge_attr.size())
            size[0] = n * n
            edge_attr = data.edge_attr.new_zeros(size)
            edge_attr[flat_idx] = data.edge_attr

        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
        data.edge_attr = edge_attr
        data.edge_index = edge_index

        return data
|
||
|
||
# Dataset root, composed transform pipeline, and a fresh shuffle.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'QM9')
transform = T.Compose([MyTransform(), Complete(), T.Distance()])
dataset = QM9(path, transform=transform).shuffle()

# Normalize targets to mean = 0 and std = 1.
y_col = dataset.data.y[:, target]
mean, std = y_col.mean().item(), y_col.std().item()
dataset.data.y[:, target] = (y_col - mean) / std

# Split datasets.
test_dataset = dataset[:10000]
val_dataset = dataset[10000:20000]
train_dataset = dataset[20000:]
test_loader = DataLoader(test_dataset, batch_size=64)
val_loader = DataLoader(val_dataset, batch_size=64)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
|
||
|
||
class Net(torch.nn.Module):
    """NNConv message passing with a GRU state update (3 rounds) and a
    Set2Set readout, regressing one scalar per graph."""

    def __init__(self):
        super().__init__()
        # Edge network: 5-dim edge attributes -> dim*dim filter weights.
        edge_net = Sequential(Linear(5, 128), ReLU(), Linear(128, dim * dim))
        self.conv = NNConv(dim, dim, edge_net, root_weight=False)
        self.gru = GRU(dim, dim, batch_first=True)
        self.set2set = Set2Set(dim, dim, processing_steps=3)
        self.fc1 = Linear(2 * dim, dim)
        self.fc2 = Linear(dim, 1)

    def forward(self, data):
        out = data.x
        h = data.x.unsqueeze(0)  # initial GRU hidden state

        # Three message-passing rounds; the GRU updates node states.
        for _ in range(3):
            m = F.relu(self.conv(out, data.edge_index, data.edge_attr))
            out, h = self.gru(m.unsqueeze(1), h)
            out = out.squeeze(1)

        out = self.set2set(out, data.batch)
        out = F.relu(self.fc1(out))
        out = self.fc2(out)
        return out.view(-1)
|
||
|
||
# Train on GPU when available. ReduceLROnPlateau shrinks the LR by 30%
# whenever the validation error fails to improve for 5 epochs (floor 1e-5).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.7, patience=5, min_lr=0.00001)
|
||
|
||
def train(epoch):
    """Run one training epoch and return the graph-count-weighted mean
    MSE loss over the training set.

    `epoch` is currently unused but kept for the caller's interface.
    """
    model.train()

    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        loss = F.mse_loss(model(batch), batch.y)
        loss.backward()
        total_loss += loss.item() * batch.num_graphs
        optimizer.step()
    return total_loss / len(train_loader.dataset)
|
||
|
||
def test(loader):
    """Compute mean absolute error over `loader`.

    Multiplying both prediction and target by `std` reports the MAE in the
    original (un-normalized) units of the target.

    Returns:
        float: mean absolute error per graph.
    """
    model.eval()
    error = 0

    # no_grad: evaluation needs no autograd graph, saving memory/compute.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            error += (model(data) * std - data.y * std).abs().sum().item()  # MAE
    return error / len(loader.dataset)
|
||
|
||
# Track the best validation error seen so far; the test set is only
# re-evaluated when validation improves, so `test_error` in the printout
# always corresponds to the best validation checkpoint.
best_val_error = None
for epoch in range(1, 301):
    # Read the current LR before scheduler.step() can change it.
    lr = scheduler.optimizer.param_groups[0]['lr']
    loss = train(epoch)
    val_error = test(val_loader)
    scheduler.step(val_error)

    if best_val_error is None or val_error <= best_val_error:
        test_error = test(test_loader)
        best_val_error = val_error

    print('Epoch: {:03d}, LR: {:7f}, Loss: {:.7f}, Validation MAE: {:.7f}, '
          'Test MAE: {:.7f},'.format(epoch, lr, loss, val_error, test_error))