Commit

refactored
Weihua Hu authored and Weihua Hu committed Dec 28, 2018
1 parent 0636984 commit 6bce28c
Showing 7 changed files with 630 additions and 0 deletions.
Binary file added dataset.zip
159 changes: 159 additions & 0 deletions main.py
@@ -0,0 +1,159 @@
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from pathlib import Path

from tqdm import tqdm

from util import load_data, separate_data
from models.graphcnn import GraphCNN

criterion = nn.CrossEntropyLoss()

def train(args, model, device, train_graphs, optimizer, epoch):
model.train()

total_iters = args.iters_per_epoch
pbar = tqdm(range(total_iters), unit='batch')

loss_accum = 0
for pos in pbar:
selected_idx = np.random.permutation(len(train_graphs))[:args.batch_size]

batch_graph = [train_graphs[idx] for idx in selected_idx]
output = model(batch_graph)

labels = torch.LongTensor([graph.label for graph in batch_graph]).to(device)

#compute loss
loss = criterion(output, labels)

#backprop
if optimizer is not None:
optimizer.zero_grad()
loss.backward()
optimizer.step()


loss = loss.detach().cpu().numpy()
loss_accum += loss

#report
pbar.set_description('epoch: %d' % (epoch))

average_loss = loss_accum/total_iters
print("loss training: %f" % (average_loss))

return average_loss

###pass data to the model in minibatches during testing to avoid memory overflow (no backpropagation is performed)
def pass_data_iteratively(model, graphs, minibatch_size = 64):
model.eval()
output = []
idx = np.arange(len(graphs))
for i in range(0, len(graphs), minibatch_size):
sampled_idx = idx[i:i+minibatch_size]
if len(sampled_idx) == 0:
continue
output.append(model([graphs[j] for j in sampled_idx]).detach())
return torch.cat(output, 0)

def test(args, model, device, train_graphs, test_graphs, epoch):
model.eval()

output = pass_data_iteratively(model, train_graphs)
pred = output.max(1, keepdim=True)[1]
labels = torch.LongTensor([graph.label for graph in train_graphs]).to(device)
correct = pred.eq(labels.view_as(pred)).sum().cpu().item()
acc_train = correct / float(len(train_graphs))

output = pass_data_iteratively(model, test_graphs)
pred = output.max(1, keepdim=True)[1]
labels = torch.LongTensor([graph.label for graph in test_graphs]).to(device)
correct = pred.eq(labels.view_as(pred)).sum().cpu().item()
acc_test = correct / float(len(test_graphs))

print("accuracy train: %f test: %f" % (acc_train, acc_test))

return acc_train, acc_test

def main():
# Training settings
# Note: Hyper-parameters need to be tuned in order to obtain results reported in the paper.
parser = argparse.ArgumentParser(description='PyTorch graph convolutional neural net for whole-graph classification')
parser.add_argument('--dataset', type=str, default="MUTAG",
help='name of dataset (default: MUTAG)')
parser.add_argument('--device', type=int, default=0,
help='which gpu to use if any (default: 0)')
parser.add_argument('--batch_size', type=int, default=32,
help='input batch size for training (default: 32)')
parser.add_argument('--iters_per_epoch', type=int, default=50,
help='number of iterations per each epoch (default: 50)')
parser.add_argument('--epochs', type=int, default=350,
help='number of epochs to train (default: 350)')
parser.add_argument('--lr', type=float, default=0.01,
help='learning rate (default: 0.01)')
parser.add_argument('--decay', type=float, default=0,
help='weight decay (default: 0.0)')
parser.add_argument('--seed', type=int, default=0,
help='random seed for splitting the dataset into 10 (default: 0)')
parser.add_argument('--fold_idx', type=int, default=0,
help='the index of the fold in 10-fold cross-validation. Should be less than 10.')
parser.add_argument('--num_layers', type=int, default=5,
help='number of layers INCLUDING the input one (default: 5)')
parser.add_argument('--num_mlp_layers', type=int, default=2,
help='number of layers for MLP EXCLUDING the input one (default: 2). 1 means linear model.')
parser.add_argument('--hidden_dim', type=int, default=64,
help='number of hidden units (default: 64)')
parser.add_argument('--final_dropout', type=float, default=0.5,
help='final layer dropout (default: 0.5)')
parser.add_argument('--graph_pooling_type', type=str, default="sum", choices=["sum", "average"],
help='Pooling over all nodes in a graph: sum or average')
parser.add_argument('--neighbor_pooling_type', type=str, default="sum", choices=["sum", "average", "max"],
help='Pooling over neighboring nodes: sum, average or max')
parser.add_argument('--learn_eps', action="store_true",
help='Whether to learn the epsilon weighting for the center nodes. Does not affect training accuracy.')
parser.add_argument('--degree_as_tag', action="store_true",
help='let the input node features be the degree of nodes (heuristics for unlabeled graph)')
parser.add_argument('--filename', type = str, default = "",
help='output file')
args = parser.parse_args()

#set up seeds and gpu device
torch.manual_seed(0)
np.random.seed(0)
device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)

graphs, num_classes = load_data(args.dataset, args.degree_as_tag)

##10-fold cross validation. Conduct an experiment on the fold specified by args.fold_idx.
train_graphs, test_graphs = separate_data(graphs, args.seed, args.fold_idx)

model = GraphCNN(args.num_layers, args.num_mlp_layers, train_graphs[0].node_features.shape[1], args.hidden_dim, num_classes, args.final_dropout, args.learn_eps, args.graph_pooling_type, args.neighbor_pooling_type, device).to(device)

optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)


for epoch in range(1, args.epochs + 1):
scheduler.step()

avg_loss = train(args, model, device, train_graphs, optimizer, epoch)
acc_train, acc_test = test(args, model, device, train_graphs, test_graphs, epoch)

if not args.filename == "":
with open(args.filename, 'w') as f:
f.write("%f %f %f" % (avg_loss, acc_train, acc_test))
f.write("\n")
print("")

print(model.eps)


if __name__ == '__main__':
main()
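main.py trains and evaluates a single fold per run (selected by --fold_idx). A full 10-fold estimate is usually obtained by running every fold and averaging the test accuracies. A minimal driver sketch (hypothetical, not part of this commit) that relies only on the command-line flags defined above and assumes it is run from the repository root with the dataset unzipped:

import subprocess

# run each of the 10 folds on MUTAG, writing results to per-fold files
for fold in range(10):
    subprocess.run(
        ["python", "main.py", "--dataset", "MUTAG",
         "--fold_idx", str(fold), "--filename", "result_fold%d.txt" % fold],
        check=True)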
Binary file added models/__pycache__/graphcnn.cpython-36.pyc
Binary file added models/__pycache__/mlp.cpython-36.pyc
227 changes: 227 additions & 0 deletions models/graphcnn.py
@@ -0,0 +1,227 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import sys
sys.path.append("models/")
from mlp import MLP

class GraphCNN(nn.Module):
def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim, output_dim, final_dropout, learn_eps, graph_pooling_type, neighbor_pooling_type, device):
'''
num_layers: number of layers in the neural networks (INCLUDING the input layer)
num_mlp_layers: number of layers in mlps (EXCLUDING the input layer)
input_dim: dimensionality of input features
hidden_dim: dimensionality of hidden units at ALL layers
output_dim: number of classes for prediction
final_dropout: dropout ratio on the final linear layer
learn_eps: If True, learn epsilon to distinguish center nodes from neighboring nodes. If False, aggregate neighbors and center nodes altogether.
neighbor_pooling_type: how to aggregate neighbors (sum, average, or max)
graph_pooling_type: how to aggregate all nodes in a graph (sum or average)
device: which device to use
'''

super(GraphCNN, self).__init__()

self.final_dropout = final_dropout
self.device = device
self.num_layers = num_layers
self.graph_pooling_type = graph_pooling_type
self.neighbor_pooling_type = neighbor_pooling_type
self.learn_eps = learn_eps
self.eps = nn.Parameter(torch.zeros(self.num_layers-1))

###List of MLPs
self.mlps = torch.nn.ModuleList()

###List of batchnorms applied to the output of MLP (input of the final prediction linear layer)
self.batch_norms = torch.nn.ModuleList()

for layer in range(self.num_layers-1):
if layer == 0:
self.mlps.append(MLP(num_mlp_layers, input_dim, hidden_dim, hidden_dim))
else:
self.mlps.append(MLP(num_mlp_layers, hidden_dim, hidden_dim, hidden_dim))

self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

#Linear functions that map the hidden representations at different layers into prediction scores
self.linears_prediction = torch.nn.ModuleList()
for layer in range(num_layers):
if layer == 0:
self.linears_prediction.append(nn.Linear(input_dim, output_dim))
else:
self.linears_prediction.append(nn.Linear(hidden_dim, output_dim))
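# Summary of the resulting architecture (restating the code above): the model stacks
# (num_layers - 1) GIN-style layers, each an MLP followed by BatchNorm1d and ReLU
# (applied in next_layer / next_layer_eps below), and keeps one Linear(., output_dim)
# readout head per layer (including the raw input layer) whose scores are summed in forward().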


def __preprocess_neighbors_maxpool(self, batch_graph):
###create padded_neighbor_list for the concatenated batch graph

#compute the maximum number of neighbors within the graphs in the current minibatch
max_deg = max([graph.max_neighbor for graph in batch_graph])

padded_neighbor_list = []
start_idx = [0]


for i, graph in enumerate(batch_graph):
start_idx.append(start_idx[i] + len(graph.g))
padded_neighbors = []
for j in range(len(graph.neighbors)):
#add off-set values to the neighbor indices
pad = [n + start_idx[i] for n in graph.neighbors[j]]
#padding: dummy neighbors are indexed by -1, which selects the dummy row appended in maxpool
pad.extend([-1]*(max_deg - len(pad)))

#Add center nodes in the maxpooling if learn_eps is False, i.e., aggregate center nodes and neighbor nodes altogether.
if not self.learn_eps:
pad.append(j + start_idx[i])

padded_neighbors.append(pad)
padded_neighbor_list.extend(padded_neighbors)

return torch.LongTensor(padded_neighbor_list)
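# Worked example (illustrative only): a batch of two graphs, a 3-node path
# (edges 0-1, 1-2) and a 2-node edge, gives start_idx = [0, 3, 5] and max_deg = 2.
# With learn_eps == False each row also appends the center node, so the returned
# padded_neighbor_list is
#   [[1, -1, 0], [0, 2, 1], [1, -1, 2],   # graph 0 (global node ids 0..2)
#    [4, -1, 3], [3, -1, 4]]              # graph 1 (global node ids 3..4)
# where -1 indexes the dummy row appended in maxpool() below.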


def __preprocess_neighbors_sumavepool(self, batch_graph):
###create block diagonal sparse matrix

edge_mat_list = []
start_idx = [0]
for i, graph in enumerate(batch_graph):
start_idx.append(start_idx[i] + len(graph.g))
edge_mat_list.append(graph.edge_mat + start_idx[i])
Adj_block_idx = torch.cat(edge_mat_list, 1)
Adj_block_elem = torch.ones(Adj_block_idx.shape[1])

#Add self-loops in the adjacency matrix if learn_eps is False, i.e., aggregate center nodes and neighbor nodes altogether.

if not self.learn_eps:
num_node = start_idx[-1]
self_loop_edge = torch.LongTensor([range(num_node), range(num_node)])
elem = torch.ones(num_node)
Adj_block_idx = torch.cat([Adj_block_idx, self_loop_edge], 1)
Adj_block_elem = torch.cat([Adj_block_elem, elem], 0)

Adj_block = torch.sparse.FloatTensor(Adj_block_idx, Adj_block_elem, torch.Size([start_idx[-1],start_idx[-1]]))

return Adj_block.to(self.device)
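# Illustrative example: for the same two-graph batch as above (3-node path + 2-node edge),
# the returned Adj_block is a 5 x 5 block-diagonal sparse matrix whose blocks are the two
# graphs' adjacency matrices; with learn_eps == False the identity is added on top
# (self-loops), so torch.spmm(Adj_block, h) sums each node's neighbors plus the node itself.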


def __preprocess_graphpool(self, batch_graph):
###create sum or average pooling sparse matrix over all nodes in each graph (num graphs x num nodes)

start_idx = [0]

#compute the node index offset of each graph in the concatenated batch
for i, graph in enumerate(batch_graph):
start_idx.append(start_idx[i] + len(graph.g))

idx = []
elem = []
for i, graph in enumerate(batch_graph):
###average pooling
if self.graph_pooling_type == "average":
elem.extend([1./len(graph.g)]*len(graph.g))

else:
###sum pooling
elem.extend([1]*len(graph.g))

idx.extend([[i, j] for j in range(start_idx[i], start_idx[i+1], 1)])
elem = torch.FloatTensor(elem)
idx = torch.LongTensor(idx).transpose(0,1)
graph_pool = torch.sparse.FloatTensor(idx, elem, torch.Size([len(batch_graph), start_idx[-1]]))

return graph_pool.to(self.device)

def maxpool(self, h, padded_neighbor_list):
###Pad with the element-wise minimum over all node features: the dummy row selected by index -1 can never win the max, so padding does not affect max-pooling

dummy = torch.min(h, dim = 0)[0]
h_with_dummy = torch.cat([h, dummy.reshape((1, -1)).to(self.device)])
pooled_rep = torch.max(h_with_dummy[padded_neighbor_list], dim = 1)[0]
return pooled_rep


def next_layer_eps(self, h, layer, padded_neighbor_list = None, Adj_block = None):
###pooling neighboring nodes and center nodes separately by epsilon reweighting.

if self.neighbor_pooling_type == "max":
##If max pooling
pooled = self.maxpool(h, padded_neighbor_list)
else:
#If sum or average pooling
pooled = torch.spmm(Adj_block, h)
if self.neighbor_pooling_type == "average":
#If average pooling
degree = torch.spmm(Adj_block, torch.ones((Adj_block.shape[0], 1)).to(self.device))
pooled = pooled/degree

#Reweights the center node representation when aggregating it with its neighbors
pooled = pooled + (1 + self.eps[layer])*h
pooled_rep = self.mlps[layer](pooled)
h = self.batch_norms[layer](pooled_rep)

#non-linearity
h = F.relu(h)
return h
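# In formula form, this layer computes (matching the code above)
#   h_v^(l+1) = ReLU( BN( MLP_l( (1 + eps_l) * h_v^(l) + AGG_{u in N(v)} h_u^(l) ) ) )
# where AGG is sum, mean, or element-wise max depending on neighbor_pooling_type.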


def next_layer(self, h, layer, padded_neighbor_list = None, Adj_block = None):
###pooling neighboring nodes and center nodes altogether

if self.neighbor_pooling_type == "max":
##If max pooling
pooled = self.maxpool(h, padded_neighbor_list)
else:
#If sum or average pooling
pooled = torch.spmm(Adj_block, h)
if self.neighbor_pooling_type == "average":
#If average pooling
degree = torch.spmm(Adj_block, torch.ones((Adj_block.shape[0], 1)).to(self.device))
pooled = pooled/degree

#representation of neighboring and center nodes
pooled_rep = self.mlps[layer](pooled)

h = self.batch_norms[layer](pooled_rep)

#non-linearity
h = F.relu(h)
return h


def forward(self, batch_graph):
X_concat = torch.cat([graph.node_features for graph in batch_graph], 0).to(self.device)
graph_pool = self.__preprocess_graphpool(batch_graph)

if self.neighbor_pooling_type == "max":
padded_neighbor_list = self.__preprocess_neighbors_maxpool(batch_graph)
else:
Adj_block = self.__preprocess_neighbors_sumavepool(batch_graph)

#list of hidden representation at each layer (including input)
hidden_rep = [X_concat]
h = X_concat

for layer in range(self.num_layers-1):
if self.neighbor_pooling_type == "max" and self.learn_eps:
h = self.next_layer_eps(h, layer, padded_neighbor_list = padded_neighbor_list)
elif not self.neighbor_pooling_type == "max" and self.learn_eps:
h = self.next_layer_eps(h, layer, Adj_block = Adj_block)
elif self.neighbor_pooling_type == "max" and not self.learn_eps:
h = self.next_layer(h, layer, padded_neighbor_list = padded_neighbor_list)
elif not self.neighbor_pooling_type == "max" and not self.learn_eps:
h = self.next_layer(h, layer, Adj_block = Adj_block)

hidden_rep.append(h)

score_over_layer = 0

#perform pooling over all nodes in each graph in every layer
for layer, h in enumerate(hidden_rep):
pooled_h = torch.spmm(graph_pool, h)
score_over_layer += F.dropout(self.linears_prediction[layer](pooled_h), self.final_dropout, training = self.training)

return score_over_layer
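A minimal smoke-test sketch for the model above (illustrative only, not part of this commit). It builds two toy graph objects exposing just the attributes GraphCNN reads (g, node_features, neighbors, edge_mat, max_neighbor) and runs one forward pass on CPU with sum pooling and learn_eps disabled; it assumes the MLP module in models/mlp.py maps a [num_nodes, input_dim] tensor to [num_nodes, output_dim], as its use above implies.

import torch
from types import SimpleNamespace
from models.graphcnn import GraphCNN

def toy_graph(num_nodes, edges, feat_dim=3):
    # edges: list of undirected (u, v) pairs; edge_mat stores both directions
    src = [u for u, v in edges] + [v for u, v in edges]
    dst = [v for u, v in edges] + [u for u, v in edges]
    neighbors = [[] for _ in range(num_nodes)]
    for u, v in zip(src, dst):
        neighbors[u].append(v)
    return SimpleNamespace(
        g=list(range(num_nodes)),                    # only len(g) is used
        label=0,
        node_features=torch.ones(num_nodes, feat_dim),
        neighbors=neighbors,
        edge_mat=torch.LongTensor([src, dst]),
        max_neighbor=max(len(n) for n in neighbors))

batch = [toy_graph(3, [(0, 1), (1, 2)]), toy_graph(2, [(0, 1)])]
model = GraphCNN(num_layers=3, num_mlp_layers=2, input_dim=3, hidden_dim=8,
                 output_dim=2, final_dropout=0.5, learn_eps=False,
                 graph_pooling_type="sum", neighbor_pooling_type="sum",
                 device=torch.device("cpu"))
model.eval()
print(model(batch).shape)   # expected: torch.Size([2, 2]), one score row per graph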
