In [1]:
import torch
import random
import pandas as pd
import torch_scatter
import torch.nn as nn
from torch.nn import Linear, Sequential, LayerNorm, ReLU
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.data import DataLoader
from torch.utils.data import random_split

import numpy as np
import time
import torch.optim as optim
from tqdm import trange
import pandas as pd
import copy
import matplotlib.pyplot as plt

## Import Data ##

In [2]:
# Load the pre-generated graph samples from disk and keep only the first 300
# entries (the slice assumes the file holds a sliceable sequence of graphs — TODO confirm).
# NOTE(review): torch.load unpickles the file, so only load files from a trusted
# source; the absolute path also ties this notebook to one machine.
dataset = torch.load('/workspace/data_gen/futurepred_graphs.pt')[:300]

## Building the Model ##

In [3]:
class MeshGraphNet(torch.nn.Module):
    """
    Encode-process-decode graph network, built upon DeepMind's 2021 MeshGraphNets paper.

    The model consists of three parts:
      (1) Encoder: one MLP for node features and one MLP for edge features,
          each mapping the raw inputs to ``hidden_dim`` latent embeddings.
      (2) Processor: ``args.num_layers`` message passing layers. Each layer
          returns updated node AND edge embeddings, which feed the next layer.
      (3) Decoder: an MLP applied to the final node embeddings only.
    """

    def __init__(self, input_dim_node, input_dim_edge, hidden_dim, output_dim, args, emb=False):
        """
        Args:
            input_dim_node: number of raw features per node.
            input_dim_edge: number of raw features per edge.
            hidden_dim: width of every latent node/edge embedding.
            output_dim: number of values decoded per node.
            args: object exposing ``num_layers`` (number of message passing layers).
            emb: unused by this class; kept so existing call sites keep working.

        Raises:
            AssertionError: if ``args.num_layers`` is smaller than 1.
        """
        super(MeshGraphNet, self).__init__()

        self.num_layers = args.num_layers
        # Validate up front, before any layer is allocated.
        assert (self.num_layers >= 1), 'Number of message passing layers must be at least 1'

        # encoder convert raw inputs into latent embeddings
        self.node_encoder = Sequential(Linear(input_dim_node, hidden_dim),
                                       ReLU(),
                                       Linear(hidden_dim, hidden_dim),
                                       LayerNorm(hidden_dim))

        self.edge_encoder = Sequential(Linear(input_dim_edge, hidden_dim),
                                       ReLU(),
                                       Linear(hidden_dim, hidden_dim),
                                       LayerNorm(hidden_dim))

        # processor: a stack of independent message passing layers
        self.processor = nn.ModuleList()
        processor_layer = self.build_processor_model()
        for _ in range(self.num_layers):
            self.processor.append(processor_layer(hidden_dim, hidden_dim))

        # decoder: only for node embeddings (no LayerNorm on the output)
        self.decoder = Sequential(Linear(hidden_dim, hidden_dim),
                                  ReLU(),
                                  Linear(hidden_dim, output_dim))

    def build_processor_model(self):
        """Return the class used for each message passing layer."""
        return ProcessorLayer

    def forward(self, data):
        """
        Run encoder, processor stack and decoder on one graph (or batch of graphs).

        Args:
            data: object exposing ``x`` (node features), ``edge_index`` and
                ``edge_attr`` (edge features).

        Returns:
            The decoder output for every node, with ``output_dim`` values per node.
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        # Step 1: encode node/edge features into latent node/edge embeddings
        x = self.node_encoder(x)  # output shape is the specified hidden dimension
        edge_attr = self.edge_encoder(edge_attr)  # output shape is the specified hidden dimension

        # Step 2: perform message passing with latent node/edge embeddings;
        # each layer's node and edge outputs are the next layer's inputs
        for layer in self.processor:
            x, edge_attr = layer(x, edge_index, edge_attr)

        # Step 3: decode latent node embeddings into the quantities of interest
        return self.decoder(x)

## Edge message passing, aggregation, and node update ##

In [4]:
class ProcessorLayer(MessagePassing):
    """
    One message passing step: edge embeddings are updated first (from the two
    endpoint node embeddings plus the edge's own embedding), then summed per
    node and used to update the node embeddings. Both updates are residual.
    """

    def __init__(self, in_channels, out_channels,  **kwargs):
        """
        in_channels: dim of node embeddings, out_channels: dim of edge embeddings.

        The edge MLP input is ``3 * in_channels`` and both updates add a
        residual, so node and edge embeddings must share one width
        (``in_channels == out_channels``).
        """
        super(ProcessorLayer, self).__init__(  **kwargs )

        # Note that the node and edge encoders both have the same hidden dimension
        # size. This means that the input of the edge processor will always be
        # three times the specified hidden dimension
        # (input: adjacent node embeddings and self embeddings)
        self.edge_mlp = Sequential(Linear( 3* in_channels , out_channels),
                                   ReLU(),
                                   Linear( out_channels, out_channels),
                                   LayerNorm(out_channels))

        # Node MLP input: the node's own embedding concatenated with its
        # aggregated edge messages.
        self.node_mlp = Sequential(Linear( 2* in_channels , out_channels),
                                   ReLU(),
                                   Linear( out_channels, out_channels),
                                   LayerNorm(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """
        Reset parameters of every learnable layer in both stacked MLPs
        (the Linear layers and the trailing LayerNorm).
        """
        for mlp in (self.edge_mlp, self.node_mlp):
            for layer in mlp:
                # ReLU has no parameters and no reset_parameters method.
                if hasattr(layer, 'reset_parameters'):
                    layer.reset_parameters()

    def forward(self, x, edge_index, edge_attr, size = None):
        """
        Handle the pre and post-processing of node features/embeddings,
        as well as initiates message passing by calling the propagate function.

        Message computation and aggregation are handled by ``propagate``
        (via ``message`` and ``aggregate``); the node update happens here.

        x has shape [node_num , in_channels] (node embeddings)
        edge_index: [2, edge_num]
        edge_attr: [E, in_channels]

        Returns the updated node embeddings and the updated edge embeddings.
        """

        # out: per-node sum of updated edge embeddings, [node_num, out_channels]
        out, updated_edges = self.propagate(edge_index, x = x, edge_attr = edge_attr, size = size)

        updated_nodes = torch.cat([x,out],dim=1)        # Complete the aggregation through self-aggregation

        updated_nodes = x + self.node_mlp(updated_nodes) # residual connection

        return updated_nodes, updated_edges

    def message(self, x_i, x_j, edge_attr):
        """
        x_i has the shape of [E, in_channels]
        x_j has the shape of [E, in_channels]
        edge_attr has the shape of [E, out_channels]

        With PyG's default flow, x_i holds the embeddings of edge_index[1]
        (target nodes) and x_j those of edge_index[0] (source nodes).

        Returns the updated edge embeddings (edge MLP output plus residual).
        """

        updated_edges=torch.cat([x_i, x_j, edge_attr], dim = 1) # concatenated input has the shape of [E, 3 * in_channels]
        updated_edges=self.edge_mlp(updated_edges)+edge_attr

        return updated_edges

    def aggregate(self, updated_edges, edge_index, dim_size = None):
        """
        Sum the updated edge embeddings per node.

        Returns both the per-node sums and the untouched edge embeddings so
        that ``forward`` can hand the latter to the next layer.
        """

        # The axis along which to index number of nodes.
        node_dim = 0

        # dim_size pins the output to one row per node. Without it, scatter
        # sizes the output from the largest index present, which is too short
        # whenever the highest-numbered node never appears in edge_index[0].
        # NOTE(review): messages are summed at edge_index[0] (the x_j side);
        # confirm this is intended if the graph's edges are not bidirectional.
        out = torch_scatter.scatter(updated_edges, edge_index[0, :], dim=node_dim,
                                    dim_size=dim_size, reduce = 'sum')

        return out, updated_edges

## Build optimizer ##

In [5]:
def build_optimizer(args, params):
    """
    Create the optimizer (and optional learning-rate scheduler) described by ``args``.

    Args:
        args: object exposing ``opt`` ('adam', 'sgd', 'rmsprop' or 'adagrad'),
            ``opt_scheduler`` ('none', 'step' or 'cos'), ``lr`` and
            ``weight_decay``. The 'step' scheduler additionally reads
            ``opt_decay_step`` and ``opt_decay_rate``; 'cos' reads ``opt_restart``.
        params: iterable of parameters; only those with ``requires_grad`` are optimized.

    Returns:
        ``(scheduler, optimizer)``; ``scheduler`` is ``None`` when
        ``args.opt_scheduler == 'none'``.

    Raises:
        ValueError: if ``args.opt`` or ``args.opt_scheduler`` is not a supported name.
    """
    weight_decay = args.weight_decay
    # Frozen parameters are excluded from the optimizer.
    trainable_params = [p for p in params if p.requires_grad]

    if args.opt == 'adam':
        optimizer = optim.Adam(trainable_params, lr=args.lr, weight_decay=weight_decay)
    elif args.opt == 'sgd':
        optimizer = optim.SGD(trainable_params, lr=args.lr, momentum=0.95, weight_decay=weight_decay)
    elif args.opt == 'rmsprop':
        optimizer = optim.RMSprop(trainable_params, lr=args.lr, weight_decay=weight_decay)
    elif args.opt == 'adagrad':
        optimizer = optim.Adagrad(trainable_params, lr=args.lr, weight_decay=weight_decay)
    else:
        # Previously fell through and failed later with an UnboundLocalError.
        raise ValueError("Unsupported optimizer: %r" % (args.opt,))

    if args.opt_scheduler == 'none':
        return None, optimizer
    elif args.opt_scheduler == 'step':
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.opt_decay_step, gamma=args.opt_decay_rate)
    elif args.opt_scheduler == 'cos':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.opt_restart)
    else:
        raise ValueError("Unsupported scheduler: %r" % (args.opt_scheduler,))
    return scheduler, optimizer

## Training and Testing ##

In [6]:
class objectview:
    """Expose the entries of a dict as attributes (``view.key`` instead of ``d['key']``)."""

    def __init__(self, d):
        # Adopt the caller's dict as the instance namespace (no copy is made),
        # so the view and the dict always show the same entries.
        self.__dict__ = d

In [7]:
# Hyper-parameters for this run, wrapped so they can be read as attributes
# (args.lr, args.num_layers, ...). The training loop reads `epochs`.
args = objectview({
    'hidden_dim': 32,
    'num_layers': 3,
    'batch_size': 1,
    'lr': 0.001,
    'opt': 'adam',
    'opt_scheduler': 'none',
    'opt_restart': 0,
    'weight_decay': 5e-4,
    'num_epochs': 5000,
    'seed': 42,
    'epochs': 5000,
})

In [8]:
# Seed every RNG in use so the split, the weight initialisation and the
# training shuffle are repeatable.
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)


#Build the data loaders
#Split the dataset 80/20 into training and held-out (test) sets
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
# Evaluation only averages the loss over batches, so sample order is irrelevant;
# not shuffling keeps evaluation deterministic.
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

# Feature widths are read off the first sample.
num_node_features = dataset[0].x.shape[1]
num_edge_features = dataset[0].edge_attr.shape[1]
num_classes = 1 #only one prediction per node: the spike value at the next time step

device = 'cuda' if torch.cuda.is_available() else 'cpu'

#Build the model
model = MeshGraphNet(num_node_features, num_edge_features,
                     args.hidden_dim, num_classes, args).to(device)


#Build the optimizer (scheduler is None when args.opt_scheduler == 'none')
scheduler, optimizer = build_optimizer(args, model.parameters())

#Build the loss function: regression on the per-node target
loss_fn = nn.MSELoss()

#Define a pandas dataframe to store the training results
df = pd.DataFrame(columns=['epoch', 'loss', 'accuracy', 'test_loss', 'test_accuracy'])



In [9]:
import sys


def _align_target(pred, target):
    """Reshape ``target`` to ``pred``'s shape when both hold the same number of elements.

    Prevents MSELoss from silently broadcasting e.g. a [N] target against a
    [N, 1] prediction into an [N, N] comparison. Targets with a different
    element count are returned unchanged.
    """
    if target.shape != pred.shape and target.numel() == pred.numel():
        return target.reshape(pred.shape)
    return target


# One record per epoch; the DataFrame is rebuilt from this list rather than
# grown with pd.concat on every iteration.
history = []
result_columns = ['epoch', 'loss', 'accuracy', 'test_loss', 'test_accuracy']

# Start at +inf so the first epoch already counts as an improvement and a
# checkpoint is always written.
best_loss = float('inf')

#Train the model
for epoch in range(args.epochs):
    model.train()

    total_loss = 0
    accuracy = 0  # placeholder: no accuracy metric is computed for this regression task
    num_batches = 0
    for i, data in enumerate(train_loader):

        #Print the epoch and the batch number. Erase the previous line to avoid cluttering the terminal
        if(i%10==0):
            print("\rEpoch: %d, Batch: %d" % (epoch, i), end='')
            sys.stdout.flush()

        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = loss_fn(out, _align_target(out, data.y))
        total_loss += loss.item()
        num_batches += 1

        loss.backward()
        optimizer.step()

    if scheduler is not None:
        scheduler.step()

    #Find the performance on the test set
    model.eval()
    test_accuracy = 0  # placeholder, see `accuracy` above
    test_loss = 0
    test_num_batches = 0
    # No gradients are needed for evaluation; skipping autograd saves memory and time.
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            out = model(data)
            test_loss += loss_fn(out, _align_target(out, data.y)).item()
            test_num_batches += 1

    epoch_train_loss = total_loss/num_batches
    epoch_train_accuracy = accuracy/num_batches
    epoch_test_loss = test_loss/test_num_batches
    epoch_test_accuracy = test_accuracy/test_num_batches

    print('\nEpoch: {:03d}, Train Loss: {:.7f}, Train Accuracy: {:.3}, Test Loss: {:.7f}, Test Accuracy: {:.3}'.format(epoch,
            epoch_train_loss, epoch_train_accuracy, epoch_test_loss, epoch_test_accuracy))

    #Store the results and refresh the dataframe
    history.append({'epoch': epoch, 'loss': epoch_train_loss,
                    'accuracy': epoch_train_accuracy, 'test_loss': epoch_test_loss,
                    'test_accuracy': epoch_test_accuracy})
    df = pd.DataFrame(history, columns=result_columns)
    #Save the dataframe to a csv file every epoch so progress survives an interrupted run
    df.to_csv('results_forwardpred.csv', index=False)

    #Checkpoint whenever the held-out loss improves
    if(epoch_test_loss < best_loss):
        best_loss = epoch_test_loss
        torch.save(model.state_dict(), 'model_forwardpred.pt')

Epoch: 0, Batch: 230
Epoch: 000, Train Loss: 0.8862426, Train Accuracy: 0.0, Test Loss: 0.8748832, Test Accuracy: 0.0
Epoch: 1, Batch: 230
Epoch: 001, Train Loss: 0.8624209, Train Accuracy: 0.0, Test Loss: 0.8735591, Test Accuracy: 0.0
Epoch: 2, Batch: 100