Commit
Weihua Hu authored and committed on Dec 28, 2018
1 parent 0636984 · commit 6bce28c
Showing 7 changed files with 630 additions and 0 deletions.
Binary file not shown.
@@ -0,0 +1,159 @@
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from pathlib import Path

from tqdm import tqdm

from util import load_data, separate_data
from models.graphcnn import GraphCNN

criterion = nn.CrossEntropyLoss()

def train(args, model, device, train_graphs, optimizer, epoch):
    model.train()

    total_iters = args.iters_per_epoch
    pbar = tqdm(range(total_iters), unit='batch')

    loss_accum = 0
    for pos in pbar:
        selected_idx = np.random.permutation(len(train_graphs))[:args.batch_size]

        batch_graph = [train_graphs[idx] for idx in selected_idx]
        output = model(batch_graph)

        labels = torch.LongTensor([graph.label for graph in batch_graph]).to(device)

        # compute loss
        loss = criterion(output, labels)

        # backprop
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        loss = loss.detach().cpu().numpy()
        loss_accum += loss

        # report
        pbar.set_description('epoch: %d' % (epoch))

    average_loss = loss_accum / total_iters
    print("loss training: %f" % (average_loss))

    return average_loss

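A side note on the sampling scheme above (an illustration, not part of the commit): each iteration draws batch_size distinct graph indices from a fresh permutation, so an "epoch" here is iters_per_epoch random batches rather than one full pass over the training set. A minimal sketch:

# Minimal sketch of the batch sampling used in train(): indices within one
# batch are distinct, but graphs may repeat (or be skipped) across iterations.
import numpy as np

num_graphs, batch_size = 10, 4
selected_idx = np.random.permutation(num_graphs)[:batch_size]
print(selected_idx)  # e.g. [7 2 9 0]
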
### Feed data to the model in minibatches at test time to avoid memory overflow (no backpropagation is performed)
def pass_data_iteratively(model, graphs, minibatch_size=64):
    model.eval()
    output = []
    idx = np.arange(len(graphs))
    for i in range(0, len(graphs), minibatch_size):
        sampled_idx = idx[i:i+minibatch_size]
        if len(sampled_idx) == 0:
            continue
        output.append(model([graphs[j] for j in sampled_idx]).detach())
    return torch.cat(output, 0)

def test(args, model, device, train_graphs, test_graphs, epoch):
    model.eval()

    output = pass_data_iteratively(model, train_graphs)
    pred = output.max(1, keepdim=True)[1]
    labels = torch.LongTensor([graph.label for graph in train_graphs]).to(device)
    correct = pred.eq(labels.view_as(pred)).sum().cpu().item()
    acc_train = correct / float(len(train_graphs))

    output = pass_data_iteratively(model, test_graphs)
    pred = output.max(1, keepdim=True)[1]
    labels = torch.LongTensor([graph.label for graph in test_graphs]).to(device)
    correct = pred.eq(labels.view_as(pred)).sum().cpu().item()
    acc_test = correct / float(len(test_graphs))

    print("accuracy train: %f test: %f" % (acc_train, acc_test))

    return acc_train, acc_test

def main():
    # Training settings
    # Note: Hyper-parameters need to be tuned to obtain the results reported in the paper.
    parser = argparse.ArgumentParser(description='PyTorch graph convolutional neural net for whole-graph classification')
    parser.add_argument('--dataset', type=str, default="MUTAG",
                        help='name of dataset (default: MUTAG)')
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--iters_per_epoch', type=int, default=50,
                        help='number of iterations per epoch (default: 50)')
    parser.add_argument('--epochs', type=int, default=350,
                        help='number of epochs to train (default: 350)')
    parser.add_argument('--lr', type=float, default=0.01,
                        help='learning rate (default: 0.01)')
    parser.add_argument('--decay', type=float, default=0,
                        help='weight decay (default: 0.0)')
    parser.add_argument('--seed', type=int, default=0,
                        help='random seed for splitting the dataset into 10 folds (default: 0)')
    parser.add_argument('--fold_idx', type=int, default=0,
                        help='the index of the fold in 10-fold validation. Should be less than 10.')
    parser.add_argument('--num_layers', type=int, default=5,
                        help='number of layers INCLUDING the input one (default: 5)')
    parser.add_argument('--num_mlp_layers', type=int, default=2,
                        help='number of layers for MLP EXCLUDING the input one (default: 2). 1 means linear model.')
    parser.add_argument('--hidden_dim', type=int, default=64,
                        help='number of hidden units (default: 64)')
    parser.add_argument('--final_dropout', type=float, default=0.5,
                        help='final layer dropout (default: 0.5)')
    parser.add_argument('--graph_pooling_type', type=str, default="sum", choices=["sum", "average"],
                        help='pooling over nodes in a graph: sum or average')
    parser.add_argument('--neighbor_pooling_type', type=str, default="sum", choices=["sum", "average", "max"],
                        help='pooling over neighboring nodes: sum, average or max')
    parser.add_argument('--learn_eps', action="store_true",
                        help='whether to learn the epsilon weighting for the center nodes (does not affect training accuracy)')
    parser.add_argument('--degree_as_tag', action="store_true",
                        help='let the input node features be the degree of nodes (a heuristic for unlabeled graphs)')
    parser.add_argument('--filename', type=str, default="",
                        help='output file')
    args = parser.parse_args()

    # set up seeds and gpu device
    torch.manual_seed(0)
    np.random.seed(0)
    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    graphs, num_classes = load_data(args.dataset, args.degree_as_tag)

    ## 10-fold cross validation. Conduct an experiment on the fold specified by args.fold_idx.
    train_graphs, test_graphs = separate_data(graphs, args.seed, args.fold_idx)

    model = GraphCNN(args.num_layers, args.num_mlp_layers, train_graphs[0].node_features.shape[1], args.hidden_dim, num_classes, args.final_dropout, args.learn_eps, args.graph_pooling_type, args.neighbor_pooling_type, device).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

    for epoch in range(1, args.epochs + 1):
        # halve the learning rate every 50 epochs (pre-1.1 PyTorch convention: step the scheduler at the start of the epoch)
        scheduler.step()

        avg_loss = train(args, model, device, train_graphs, optimizer, epoch)
        acc_train, acc_test = test(args, model, device, train_graphs, test_graphs, epoch)

        if not args.filename == "":
            # note: mode 'w' rewrites the file each epoch, so it holds only the latest epoch's results
            with open(args.filename, 'w') as f:
                f.write("%f %f %f" % (avg_loss, acc_train, acc_test))
                f.write("\n")
        print("")

        print(model.eps)


if __name__ == '__main__':
    main()
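For context, the script above would typically be launched from the command line; the script's file name is not shown in this diff, but assuming it is saved as main.py, a hypothetical invocation could look like `python main.py --dataset MUTAG --fold_idx 0 --learn_eps --filename result.txt`, with all remaining hyper-parameters taking the defaults listed above.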
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,227 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import sys
sys.path.append("models/")
from mlp import MLP

class GraphCNN(nn.Module):
    def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim, output_dim, final_dropout, learn_eps, graph_pooling_type, neighbor_pooling_type, device):
        '''
        num_layers: number of layers in the neural networks (INCLUDING the input layer)
        num_mlp_layers: number of layers in the MLPs (EXCLUDING the input layer)
        input_dim: dimensionality of input features
        hidden_dim: dimensionality of hidden units at ALL layers
        output_dim: number of classes for prediction
        final_dropout: dropout ratio on the final linear layer
        learn_eps: if True, learn epsilon to distinguish center nodes from neighboring nodes; if False, aggregate neighbors and center nodes altogether
        neighbor_pooling_type: how to aggregate neighbors (sum, average, or max)
        graph_pooling_type: how to aggregate all nodes in a graph (sum or average)
        device: which device to use
        '''

        super(GraphCNN, self).__init__()

        self.final_dropout = final_dropout
        self.device = device
        self.num_layers = num_layers
        self.graph_pooling_type = graph_pooling_type
        self.neighbor_pooling_type = neighbor_pooling_type
        self.learn_eps = learn_eps
        self.eps = nn.Parameter(torch.zeros(self.num_layers - 1))

        ### List of MLPs
        self.mlps = torch.nn.ModuleList()

        ### List of batchnorms applied to the output of each MLP (input of the final prediction linear layer)
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(self.num_layers - 1):
            if layer == 0:
                self.mlps.append(MLP(num_mlp_layers, input_dim, hidden_dim, hidden_dim))
            else:
                self.mlps.append(MLP(num_mlp_layers, hidden_dim, hidden_dim, hidden_dim))

            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        # Linear functions that map the hidden representation at different layers to a prediction score
        self.linears_prediction = torch.nn.ModuleList()
        for layer in range(num_layers):
            if layer == 0:
                self.linears_prediction.append(nn.Linear(input_dim, output_dim))
            else:
                self.linears_prediction.append(nn.Linear(hidden_dim, output_dim))

    def __preprocess_neighbors_maxpool(self, batch_graph):
        ### create padded_neighbor_list over the concatenated (batched) graph

        # compute the maximum number of neighbors over the graphs in the current minibatch
        max_deg = max([graph.max_neighbor for graph in batch_graph])

        padded_neighbor_list = []
        start_idx = [0]

        for i, graph in enumerate(batch_graph):
            start_idx.append(start_idx[i] + len(graph.g))
            padded_neighbors = []
            for j in range(len(graph.neighbors)):
                # add offset values to the neighbor indices
                pad = [n + start_idx[i] for n in graph.neighbors[j]]
                # padding; dummy entries are stored as -1
                pad.extend([-1] * (max_deg - len(pad)))

                # Add the center node to the max-pooling if learn_eps is False, i.e., aggregate center and neighbor nodes altogether.
                if not self.learn_eps:
                    pad.append(j + start_idx[i])

                padded_neighbors.append(pad)
            padded_neighbor_list.extend(padded_neighbors)

        return torch.LongTensor(padded_neighbor_list)

    def __preprocess_neighbors_sumavepool(self, batch_graph):
        ### create a block-diagonal sparse adjacency matrix over the batched graphs

        edge_mat_list = []
        start_idx = [0]
        for i, graph in enumerate(batch_graph):
            start_idx.append(start_idx[i] + len(graph.g))
            edge_mat_list.append(graph.edge_mat + start_idx[i])
        Adj_block_idx = torch.cat(edge_mat_list, 1)
        Adj_block_elem = torch.ones(Adj_block_idx.shape[1])

        # Add self-loops to the adjacency matrix if learn_eps is False, i.e., aggregate center and neighbor nodes altogether.
        if not self.learn_eps:
            num_node = start_idx[-1]
            self_loop_edge = torch.LongTensor([range(num_node), range(num_node)])
            elem = torch.ones(num_node)
            Adj_block_idx = torch.cat([Adj_block_idx, self_loop_edge], 1)
            Adj_block_elem = torch.cat([Adj_block_elem, elem], 0)

        Adj_block = torch.sparse.FloatTensor(Adj_block_idx, Adj_block_elem, torch.Size([start_idx[-1], start_idx[-1]]))

        return Adj_block.to(self.device)

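As an illustration (not part of the commit), a minimal sketch of what this block-diagonal construction produces for two hypothetical toy graphs with 2 and 3 nodes; the second graph's node indices are offset by 2:

# Toy sketch: batch two graphs into one 5x5 block-diagonal sparse adjacency.
import torch

edges_g0 = torch.LongTensor([[0, 1], [1, 0]])        # graph 0: edge 0-1, both directions
edges_g1 = torch.LongTensor([[0, 1, 2], [1, 2, 0]])  # graph 1: directed cycle 0->1->2->0
start_idx = [0, 2, 5]                                # cumulative node offsets

Adj_block_idx = torch.cat([edges_g0 + start_idx[0], edges_g1 + start_idx[1]], 1)
Adj_block_elem = torch.ones(Adj_block_idx.shape[1])
Adj_block = torch.sparse.FloatTensor(Adj_block_idx, Adj_block_elem, torch.Size([5, 5]))
print(Adj_block.to_dense())  # nonzeros at (0,1), (1,0) and (2,3), (3,4), (4,2)
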
    def __preprocess_graphpool(self, batch_graph):
        ### create a sum or average pooling sparse matrix over all nodes in each graph (num graphs x num nodes)

        start_idx = [0]

        # compute the cumulative node offset of each graph
        for i, graph in enumerate(batch_graph):
            start_idx.append(start_idx[i] + len(graph.g))

        idx = []
        elem = []
        for i, graph in enumerate(batch_graph):
            ### average pooling
            if self.graph_pooling_type == "average":
                elem.extend([1. / len(graph.g)] * len(graph.g))

            else:
                ### sum pooling
                elem.extend([1] * len(graph.g))

            idx.extend([[i, j] for j in range(start_idx[i], start_idx[i + 1], 1)])
        elem = torch.FloatTensor(elem)
        idx = torch.LongTensor(idx).transpose(0, 1)
        graph_pool = torch.sparse.FloatTensor(idx, elem, torch.Size([len(batch_graph), start_idx[-1]]))

        return graph_pool.to(self.device)

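Again as an illustration (hypothetical toy values, not part of the commit), the pooling matrix for the same two toy graphs under average pooling: row i averages the node-feature rows belonging to graph i:

# Toy sketch: a 2x5 average-pooling matrix for graphs with 2 and 3 nodes.
import torch

idx = torch.LongTensor([[0, 0, 1, 1, 1],
                        [0, 1, 2, 3, 4]])
elem = torch.FloatTensor([1/2, 1/2, 1/3, 1/3, 1/3])
graph_pool = torch.sparse.FloatTensor(idx, elem, torch.Size([2, 5]))

h = torch.arange(10.).reshape(5, 2)  # 5 nodes, 2 features each
print(torch.spmm(graph_pool, h))     # per-graph feature means, shape (2, 2)
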
    def maxpool(self, h, padded_neighbor_list):
        ### The appended element-wise minimum row serves as padding: it can never affect the max-pooling result

        dummy = torch.min(h, dim=0)[0]
        h_with_dummy = torch.cat([h, dummy.reshape((1, -1)).to(self.device)])
        pooled_rep = torch.max(h_with_dummy[padded_neighbor_list], dim=1)[0]
        return pooled_rep

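To see why the minimum row is safe padding (a hypothetical toy check, not part of the commit): the -1 pad indices select the appended row, which holds the column-wise minimum over all nodes, so a padded slot can never beat a real neighbor in the max:

# Toy sketch of the dummy-row trick: index -1 hits the appended minimum row.
import torch

h = torch.tensor([[1., 5.], [3., 2.], [4., 0.]])       # 3 nodes, 2 features
dummy = torch.min(h, dim=0)[0]                         # tensor([1., 0.])
h_with_dummy = torch.cat([h, dummy.reshape(1, -1)])    # 4th row is the minimum
padded_neighbor_list = torch.LongTensor([[1, 2, -1]])  # node 0's neighbors {1, 2} plus one pad
print(torch.max(h_with_dummy[padded_neighbor_list], dim=1)[0])  # tensor([[4., 2.]])
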
    def next_layer_eps(self, h, layer, padded_neighbor_list=None, Adj_block=None):
        ### pool neighboring nodes and center nodes separately via epsilon reweighting

        if self.neighbor_pooling_type == "max":
            ## max pooling
            pooled = self.maxpool(h, padded_neighbor_list)
        else:
            # sum or average pooling
            pooled = torch.spmm(Adj_block, h)
            if self.neighbor_pooling_type == "average":
                # average pooling: divide by the node degrees
                degree = torch.spmm(Adj_block, torch.ones((Adj_block.shape[0], 1)).to(self.device))
                pooled = pooled / degree

        # reweight the center node representation when aggregating it with its neighbors
        pooled = pooled + (1 + self.eps[layer]) * h
        pooled_rep = self.mlps[layer](pooled)
        h = self.batch_norms[layer](pooled_rep)

        # non-linearity
        h = F.relu(h)
        return h

    def next_layer(self, h, layer, padded_neighbor_list=None, Adj_block=None):
        ### pool neighboring nodes and center nodes altogether (center entries were added during preprocessing)

        if self.neighbor_pooling_type == "max":
            ## max pooling
            pooled = self.maxpool(h, padded_neighbor_list)
        else:
            # sum or average pooling
            pooled = torch.spmm(Adj_block, h)
            if self.neighbor_pooling_type == "average":
                # average pooling: divide by the node degrees
                degree = torch.spmm(Adj_block, torch.ones((Adj_block.shape[0], 1)).to(self.device))
                pooled = pooled / degree

        # representation of neighboring and center nodes
        pooled_rep = self.mlps[layer](pooled)

        h = self.batch_norms[layer](pooled_rep)

        # non-linearity
        h = F.relu(h)
        return h

    def forward(self, batch_graph):
        X_concat = torch.cat([graph.node_features for graph in batch_graph], 0).to(self.device)
        graph_pool = self.__preprocess_graphpool(batch_graph)

        if self.neighbor_pooling_type == "max":
            padded_neighbor_list = self.__preprocess_neighbors_maxpool(batch_graph)
        else:
            Adj_block = self.__preprocess_neighbors_sumavepool(batch_graph)

        # list of hidden representations at each layer (including the input)
        hidden_rep = [X_concat]
        h = X_concat

        for layer in range(self.num_layers - 1):
            if self.neighbor_pooling_type == "max" and self.learn_eps:
                h = self.next_layer_eps(h, layer, padded_neighbor_list=padded_neighbor_list)
            elif not self.neighbor_pooling_type == "max" and self.learn_eps:
                h = self.next_layer_eps(h, layer, Adj_block=Adj_block)
            elif self.neighbor_pooling_type == "max" and not self.learn_eps:
                h = self.next_layer(h, layer, padded_neighbor_list=padded_neighbor_list)
            elif not self.neighbor_pooling_type == "max" and not self.learn_eps:
                h = self.next_layer(h, layer, Adj_block=Adj_block)

            hidden_rep.append(h)

        score_over_layer = 0

        # pool over all nodes in each graph at every layer, and sum the per-layer predictions
        for layer, h in enumerate(hidden_rep):
            pooled_h = torch.spmm(graph_pool, h)
            score_over_layer += F.dropout(self.linears_prediction[layer](pooled_h), self.final_dropout, training=self.training)

        return score_over_layer
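In equation form, this is what forward computes when learn_eps is True and both poolings are sums (a restatement of the code above, with $W_k$ standing for linears_prediction[layer], $K$ for num_layers, and $h^{(0)}$ for the input features):

h_v^{(k)} = \mathrm{ReLU}\Big(\mathrm{BN}\big(\mathrm{MLP}^{(k)}\big((1 + \epsilon_k)\, h_v^{(k-1)} + \textstyle\sum_{u \in \mathcal{N}(v)} h_u^{(k-1)}\big)\big)\Big),
\qquad
\mathrm{score}(G) = \sum_{k=0}^{K-1} \mathrm{Dropout}\Big(W_k \textstyle\sum_{v \in G} h_v^{(k)}\Big)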