# TRAIN

In [1]:
from dataprocessing.utils.normalization import get_stats
import torch
import numpy as np
import random
from train import train
from utils.visualization import save_plots
from torch_geometric.loader import DataLoader

## SETUP

In [24]:
class objectview:
    """Wrap a dict so its keys can be read and written as attributes.

    The wrapped dict itself becomes the instance ``__dict__`` (no copy is
    made), so ``view.key = value`` also updates the original dict and vice
    versa.
    """

    def __init__(self, d):
        # Install the caller's dict directly as the attribute namespace.
        object.__setattr__(self, "__dict__", d)

# Hyperparameters / run configuration, exposed as attributes (args.lr, ...).
# A single config is built directly; the previous one-element `for` loop that
# rebound its own loop variable produced the same object less readably.
args = objectview({
    'model_type': 'autoencoder',
    'num_layers': 4,
    'batch_size': 16,
    'hidden_dim': 64,
    'epochs': 500,
    'opt': 'adam',
    'opt_scheduler': 'none',
    'opt_restart': 0,
    'weight_decay': 5e-4,
    'lr': 0.001,
    'train_size': 100,
    'test_size': 40,
    # NOTE: placeholder; the device-detection cell reassigns args.device.
    'device': 'cuda',
    'shuffle': True,
    'save_velo_val': True,
    'save_best_model': False,
    'checkpoint_dir': './best_models/',
    'postprocess_dir': './2d_loss_plots/',
})

# To make runs as reproducible as possible, seed every random number
# generator used in this notebook.
# For more information, see: https://pytorch.org/docs/stable/notes/randomness.html
torch.manual_seed(5)  # Torch
random.seed(5)        # Python
np.random.seed(5)     # NumPy

In [25]:
# Report whether PyTorch's MPS (Apple Metal) backend is available / built.
print(torch.backends.mps.is_available())
print(torch.backends.mps.is_built())
# Choose the compute device: CUDA when present (the config above asks for
# 'cuda'), otherwise MPS, otherwise CPU. Previously CUDA machines silently
# fell back to CPU because only MPS was checked.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
# Overwrite the placeholder so downstream code sees the device actually used.
args.device = device
print(device)

True
True
mps


In [26]:
# Keep only as many snapshots as the train and test splits need.
num_snapshots = args.train_size + args.test_size
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
# files. On newer torch releases weights_only defaults to True, which may
# reject pickled graph objects; confirm against the installed version.
dataset = torch.load('data/trajectories/trajectory_0.pt')[:num_snapshots]
# Normalization statistics from the project helper; unpacked later into
# mean/std tensors for x, edges and y.
stats_list = get_stats(dataset)
# Default DataLoader settings (batch size 1, no shuffling); in this section it
# is only used to pull a single sample.
loader = DataLoader(dataset)

## Multiscale Graph Neural Network

In [27]:
import torch
import torch_scatter
import torch.nn as nn
from torch.nn import Linear, Sequential, LayerNorm, ReLU
from torch_geometric.nn.conv import MessagePassing

In [28]:

class ProcessorLayer(MessagePassing):
    """One residual message-passing step over node and edge embeddings.

    Each call:
      1. updates every edge embedding from its two endpoint node embeddings
         and its current value (``edge_mlp`` plus a residual connection);
      2. sums the updated edge embeddings onto nodes (see ``aggregate``);
      3. updates every node embedding from its current value and that sum
         (``node_mlp`` plus a residual connection).

    Because of the residual connections and the ``2 * in_channels`` /
    ``3 * in_channels`` MLP input sizes, ``in_channels`` must equal
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        """
        Args:
            in_channels: width of the incoming node and edge embeddings.
            out_channels: width of the updated node and edge embeddings;
                must equal ``in_channels`` (see class docstring).
            **kwargs: forwarded unchanged to ``MessagePassing``.
        """
        super(ProcessorLayer, self).__init__(**kwargs)

        # Node and edge embeddings share the same width, so the edge MLP's
        # input [x_i, x_j, edge_attr] is 3 * in_channels wide.
        self.edge_mlp = Sequential(Linear(3 * in_channels, out_channels),
                                   ReLU(),
                                   Linear(out_channels, out_channels),
                                   LayerNorm(out_channels))

        # The node MLP's input [x, aggregated edges] is 2 * in_channels wide.
        self.node_mlp = Sequential(Linear(2 * in_channels, out_channels),
                                   ReLU(),
                                   Linear(out_channels, out_channels),
                                   LayerNorm(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the Linear layers of both MLPs.

        The LayerNorm modules (index 3) are not reset here.
        """
        self.edge_mlp[0].reset_parameters()
        self.edge_mlp[2].reset_parameters()

        self.node_mlp[0].reset_parameters()
        self.node_mlp[2].reset_parameters()

    def forward(self, x, edge_index, edge_attr, size=None):
        """Run one processor step.

        Message computation (``message``) and aggregation (``aggregate``)
        happen inside ``propagate``; the node update is done here.

        Args:
            x: node embeddings, shape [num_nodes, in_channels].
            edge_index: graph connectivity, shape [2, num_edges].
            edge_attr: edge embeddings, shape [num_edges, in_channels].
            size: optional size hint forwarded to ``propagate``.

        Returns:
            (updated_nodes, updated_edges) with shapes
            [num_nodes, out_channels] and [num_edges, out_channels].
        """
        # propagate returns whatever aggregate returns (the default update()
        # passes it through): the per-node sum [num_nodes, out_channels] and
        # the updated edge embeddings [num_edges, out_channels].
        out, updated_edges = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)

        # Combine each node's own embedding with its aggregated edge messages.
        updated_nodes = torch.cat([x, out], dim=1)

        updated_nodes = x + self.node_mlp(updated_nodes)  # residual connection

        return updated_nodes, updated_edges

    def message(self, x_i, x_j, edge_attr):
        """Compute updated edge embeddings (these are the messages).

        With PyG's default ``flow="source_to_target"``:
            x_i: embeddings of the target nodes, x[edge_index[1]], [E, in_channels].
            x_j: embeddings of the source nodes, x[edge_index[0]], [E, in_channels].
            edge_attr: current edge embeddings, [E, in_channels].

        Returns:
            Updated edge embeddings, [E, out_channels].
        """
        updated_edges = torch.cat([x_i, x_j, edge_attr], dim=1)  # [E, 3 * in_channels]
        updated_edges = self.edge_mlp(updated_edges) + edge_attr  # residual connection

        return updated_edges

    def aggregate(self, updated_edges, edge_index, dim_size=None):
        """Sum updated edge embeddings onto nodes.

        Each edge's embedding is added to the node in ``edge_index[0]``.
        NOTE(review): PyG's built-in aggregation uses the target node
        (edge_index[1]) under the default flow; confirm edge_index[0] is
        intended.

        ``dim_size`` (the number of nodes, supplied by ``propagate``) is passed
        to ``scatter`` so the output always has one row per node, even when the
        highest-numbered nodes have no edges indexed here. Otherwise the
        concatenation with ``x`` in ``forward`` would fail.

        Returns:
            (per-node sums [num_nodes, out_channels], updated_edges unchanged).
        """
        # The axis along which to index the number of nodes.
        node_dim = 0

        out = torch_scatter.scatter(updated_edges, edge_index[0, :], dim=node_dim,
                                    dim_size=dim_size, reduce='sum')

        return out, updated_edges

In [29]:
class MMPLayer(torch.nn.Module):
    """MLP encoders followed by a stack of ``ProcessorLayer`` message-passing steps.

    Raw node and edge features are lifted to ``hidden_dim`` by two-layer MLP
    encoders (each ending in LayerNorm). They are then refined by
    ``args.num_layers`` processor layers of width ``hidden_dim``.

    Args:
        input_dim_node: number of raw node features.
        input_dim_edge: number of raw edge features.
        hidden_dim: width of the latent node/edge embeddings.
        output_dim: currently unused; kept for interface compatibility.
        args: config object; only ``args.num_layers`` is read.
        emb: currently unused; kept for interface compatibility.
        verbose: when True (the default, matching the previous behavior),
            ``forward`` prints the encoded tensor shapes on every call.
    """

    def __init__(self, input_dim_node, input_dim_edge, hidden_dim, output_dim, args, emb=False,
                 verbose=True):
        super(MMPLayer, self).__init__()
        self.num_layers = args.num_layers
        self.verbose = verbose

        self.processor = nn.ModuleList()
        assert (self.num_layers >= 1), 'Number of message passing layers is not >=1'

        self.node_encoder = Sequential(Linear(input_dim_node, hidden_dim),
                                       ReLU(),
                                       Linear(hidden_dim, hidden_dim),
                                       LayerNorm(hidden_dim))

        self.edge_encoder = Sequential(Linear(input_dim_edge, hidden_dim),
                                       ReLU(),
                                       Linear(hidden_dim, hidden_dim),
                                       LayerNorm(hidden_dim))

        processor_layer = self.build_processor_model()
        for _ in range(self.num_layers):
            self.processor.append(processor_layer(hidden_dim, hidden_dim))

    def build_processor_model(self):
        """Return the class used for each processor layer."""
        return ProcessorLayer

    def forward(self, data):
        """Encode ``data`` and run all processor layers.

        Reads only ``data.x``, ``data.edge_index`` and ``data.edge_attr``.
        (The previously read but unused ``data.p`` is no longer required.)

        Returns:
            (x, edge_attr): latent node and edge embeddings, each ``hidden_dim`` wide.
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        # Step 1: encode node/edge features into latent node/edge embeddings.
        x = self.node_encoder(x)  # output width is hidden_dim
        if self.verbose:
            print(f'Encoded x: {x.shape}')
        edge_attr = self.edge_encoder(edge_attr)  # output width is hidden_dim
        if self.verbose:
            print(f'Encoded edge: {edge_attr.shape}')
        # Step 2: perform message passing with the latent node/edge embeddings.
        for layer in self.processor:
            x, edge_attr = layer(x, edge_index, edge_attr)
        return x, edge_attr

In [33]:
# Smoke test: run one graph through an untrained MMPLayer and inspect the
# shapes of the resulting latent node and edge embeddings.
sample = next(iter(loader))
input_dim_node = sample.num_features
input_dim_edge = sample.edge_attr.shape[1]
print(f'Input dim node: {input_dim_node}\nInput dim edge: {input_dim_edge}\nHidden dim: {args.hidden_dim}')
# The positional 1 is MMPLayer's (currently unused) output_dim argument.
model = MMPLayer(input_dim_node, input_dim_edge, args.hidden_dim, 1, args)
x, edge_attr = model(sample)
print(x.shape)
print(edge_attr.shape)

Input dim node: 11
Input dim edge: 3
Hidden dim: 64
Encoded x: torch.Size([1876, 64])
Encoded edge: torch.Size([10788, 64])
torch.Size([1876, 64])
torch.Size([10788, 64])


# PLOTS


In [35]:
import os
import collections
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.sparse as sp
import torch
from torch import Tensor
import torch_geometric
from torch_geometric.utils import to_networkx
from torch_geometric.datasets import Planetoid
import networkx as nx
from networkx.algorithms import community
from torch_geometric.nn import GAE, VGAE, GCNConv
import copy
from torch import tensor
from torch_geometric.loader import DataLoader
from torchmetrics import Accuracy, AveragePrecision, Dice


## Make Predictions

In [38]:
# Move every normalization statistic onto the compute device. The unpacking
# order must match the order of the list returned by get_stats.
(mean_vec_x, std_vec_x, mean_vec_edge, std_vec_edge,
 mean_vec_y, std_vec_y) = (stat.to(device) for stat in stats_list)
sample = dataset[0].to(device)
# NOTE(review): auto_best_model / mesh_best_model are not defined in this
# section; presumably loaded in cells not shown here. `loss` is not used in
# this cell.
with torch.no_grad():
    loss = torch.nn.CrossEntropyLoss()
    auto_pred = auto_best_model(sample, mean_vec_x, std_vec_x).T
    mesh_pred = mesh_best_model(sample, mean_vec_x, std_vec_x,
                                mean_vec_edge, std_vec_edge).T


In [29]:
# Move the normalization statistics onto the compute device (repeated here so
# this cell can run on its own).
[mean_vec_x,std_vec_x,mean_vec_edge,std_vec_edge,mean_vec_y,std_vec_y] = stats_list
(mean_vec_x,std_vec_x,mean_vec_edge,std_vec_edge,mean_vec_y,std_vec_y)=(mean_vec_x.to(device),
            std_vec_x.to(device),mean_vec_edge.to(device),std_vec_edge.to(device),mean_vec_y.to(device),std_vec_y.to(device))

# Compare node positions of snapshots 0 and 10 (presumably timesteps of the
# same trajectory -- confirm).
sample_1 = dataset[0].to(device)
sample_2 = dataset[10].to(device)
# Subtract sample_1 from sample_2. The old code subtracted sample_2.mesh_pos
# from itself, which is identically zero whatever the data.
diff = sample_2.mesh_pos - sample_1.mesh_pos
print(sample_1.mesh_pos)
print(sample_2.mesh_pos)
print(diff)
# Column-wise (per-coordinate) sum of the position differences.
print(np.sum(diff.cpu().detach().numpy(), axis=0))
# pred gives the learned acceleration between two timesteps:
# next_vel = curr_vel + pred * delta_t

tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        ...,
        [0., 0.],
        [0., 0.],
        [0., 0.]], device='mps:0')
[0. 0.]


In [None]:
# Draw the graph, emphasizing its highest-degree-centrality nodes.
# NOTE(review): cylinder_data is not defined in this section; presumably a
# torch_geometric Data graph (e.g. a dataset snapshot). Confirm before running.
G = to_networkx(cylinder_data, to_undirected=True)
pos = nx.spring_layout(G, seed=42)  # fixed seed -> reproducible layout
cent = nx.degree_centrality(G)
cent_array = np.array(list(cent.values()))
node_size = (cent_array * 500).tolist()  # marker size proportional to centrality
# Threshold = 11th-largest centrality, so about the top 11 nodes (plus ties)
# are highlighted. Clamp the index so graphs with fewer than 11 nodes no
# longer raise IndexError; they fall back to the smallest centrality.
top_index = min(10, len(cent_array) - 1)
threshold = np.sort(cent_array)[::-1][top_index]
print("threshold", threshold)
# Highlighted nodes get color value / opacity 1; all others 0.1.
cent_bin = np.where(cent_array >= threshold, 1, 0.1)
plt.figure(figsize=(12, 12))
nodes = nx.draw_networkx_nodes(G, pos, node_size=node_size,
                               cmap=plt.cm.plasma,
                               node_color=cent_bin,
                               nodelist=list(cent.keys()),
                               alpha=cent_bin)
edges = nx.draw_networkx_edges(G, pos, width=0.25, alpha=0.3)
plt.show()