# Install packages
Please install any needed packages here.

In [None]:
%pip install numpy matplotlib pandas
%pip install torch
%pip install torch-geometric
%pip install torch-scatter 
%pip install scikit-learn

# Import packages

Please import all the needed packages here

In [None]:
import torch
import torch.nn
import torch.nn.functional as F
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
import torch_geometric.nn as geom_nn
from torch.nn import Sequential as Seq, Linear, ReLU
import numpy as np
import pandas as pd
import torch_scatter
import os.path as osp
from random import sample
from sklearn.metrics.pairwise import euclidean_distances
from os import listdir
from os.path import isfile, join
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt
import torch_geometric
import matplotlib.pyplot as plt

In [None]:
# Seed PyTorch's global random number generator so repeated runs start from the same state
torch.manual_seed(10)

import warnings
# Suppress all Python warnings from this point on
warnings.filterwarnings("ignore")

# Utility functions

All the utility functions required for data processing, model training and post processing are defined below

In [None]:
def min_max_scaler(x, min_val, max_val, device):
    
    """Scales the data column wise using the given max and min values.
    
    Computes (x - min_val) / (max_val - min_val) on the requested device.
    
    Parameters
    -----------
    x : an N dimensional tensor, can be node attributes or target. 
    min_val : column wise min of x, tensor
    max_val : column wise max of x, tensor
    device : GPU or CPU device
    
    Ouput
    ------
    a tensor of the same size as x with scaled values.
    Columns for which max_val equals min_val (zero range) are divided by 1
    instead of 0, so they yield (x - min_val) rather than NaN/inf.
    """ 
    x = x.to(device)
    min_val = min_val.to(device)
    value_range = max_val.to(device) - min_val
    # A constant column has zero range; dividing by it would give 0/0 = NaN
    # (or +/-inf for unseen values), so a unit range is substituted there.
    safe_range = torch.where(value_range == 0, torch.ones_like(value_range), value_range)
    return (x - min_val) / safe_range
    
def min_max_descaler(x, min_val, max_val, device):
    
    """Maps min-max scaled values back to their original range.
    
    Computes x * (max_val - min_val) + min_val on the requested device.
    
    Parameters
    -----------
    x : an N dimensional tensor of scaled values, can be node attributes or target. 
    min_val : column wise min used for scaling, tensor
    max_val : column wise max used for scaling, tensor
    device : GPU or CPU device
    
    Ouput
    ------
    a tensor of the same size as x with descaled values
       
    """ 
    lower = min_val.to(device)
    span = max_val.to(device) - lower
    return (x.to(device) * span) + lower

def get_max_min(data):
    
    """Column wise maximum and minimum values over a whole dataset.
    
    The values are needed for min-max scaling of the node attributes (X),
    targets (Y) and edge attributes (E).
    
    Parameters
    -----------
    data : dataset whose items are lists of three graph Data objects,
        ordered coarse, medium, fine. It is iterated in batches of 20.
    
    Ouput
    ------
    A tuple of 18 tensors, six per resolution, in this order:
    X_max_c, Y_max_c, E_max_c, X_min_c, Y_min_c, E_min_c,   (coarse graph)
    X_max_m, Y_max_m, E_max_m, X_min_m, Y_min_m, E_min_m,   (medium graph)
    X_max_f, Y_max_f, E_max_f, X_min_f, Y_min_f, E_min_f    (fine graph)
    
    """ 
    
    loader = DataLoader(data, batch_size=20)
    
    # one (X, Y, E) triple of accumulators per resolution: coarse, medium, fine
    collected = [([], [], []) for _ in range(3)]
    
    for batch in loader:
        for level, (xs, ys, es) in enumerate(collected):
            graph = batch[level]
            xs.append(graph.x.detach().cpu().numpy())
            ys.append(graph.y.detach().cpu().numpy())
            es.append(graph.edge_attr.detach().cpu().numpy())
    
    extrema = []
    for parts in collected:
        # stack the per-batch arrays row wise, then reduce over the rows
        stacked = [np.vstack(part) for part in parts]
        extrema.extend(torch.Tensor(np.max(arr, 0)) for arr in stacked)
        extrema.extend(torch.Tensor(np.min(arr, 0)) for arr in stacked)
    
    return tuple(extrema)



# Graph data creation

This section creates the graph data using the Data function of pytorch geometric.

Each graph data object contains 3 graphs, of the same simulation but with different resolutions.

We identify them as - coarse, medium and fine (differentiated with _c, _m and _f in the variable names).

Coarse graph is the lowest resolution followed by medium and fine graphs respectively. Fine graph has the highest resolution.


In [None]:
def process_graph_data(source_dir, sim):
    
    """Loads the three resolution files of one simulation from the source directory.
    
    Parameters
    -----------
    source_dir : source directory which contains the original graph data, with the name format 'run_{sim}_{resolution}.pt'.
                resolution can be LF/MF/HF, referring to coarse (low), medium and high resolution data respectively
    sim : simulation/run number, to identify different simulations, int
    
    Ouput
    ------
    data_c - coarse resolution data, as stored in the LF file
    data_m - medium resolution data, as stored in the MF file
    data_f - fine resolution data, as stored in the HF file
       
    """ 
    
    # file name templates, ordered coarse, medium, fine
    templates = ('{}/run_{}_LF.pt', '{}/run_{}_MF.pt', '{}/run_{}_HF.pt')
    data_c, data_m, data_f = (
        torch.load(osp.join(template.format(source_dir, sim))) for template in templates
    )
    
    #DO THE REST OF THE PROCESSING BASED ON YOUR DATA!!
    
#     # Example: if edge index numbering starts from 1 in the original data, shift it so that
#     # it starts from 0, as needed for correct pytorch geometric data processing
#     data_c.edge_index[0,:] = (data_c.edge_index[0,:]-1).long()
#     data_c.edge_index[1,:] = (data_c.edge_index[1,:]-1).long()
#     data_m.edge_index[0,:] = (data_m.edge_index[0,:]-1).long()
#     data_m.edge_index[1,:] = (data_m.edge_index[1,:]-1).long()
#     data_f.edge_index[0,:] = (data_f.edge_index[0,:]-1).long()
#     data_f.edge_index[1,:] = (data_f.edge_index[1,:]-1).long()
    
    return data_c, data_m, data_f

In [None]:
class MFGraphDataset(Dataset):
    
    """
    Custom multi-fidelity graph dataset for the Multi-fidelity Graph U-Net architecture.
    
    Every item holds the same simulation at three resolutions and is returned
    as a list of length 3: [coarse, medium, fine].
    
    process() copies each simulation from source_dir into '<root>/processed/',
    numbering the files by their position in sim_list; test data gets the
    'test_' file name prefix.
    """ 
    
    # file name templates relative to root, ordered coarse, medium, fine
    _TRAIN_FILES = ('processed/run_{}_LF.pt',
                    'processed/run_{}_MF.pt',
                    'processed/run_{}_HF.pt')
    _TEST_FILES = ('processed/test_run_{}_LF.pt',
                   'processed/test_run_{}_MF.pt',
                   'processed/test_run_{}_HF.pt')
    
    def __init__(self, root, source_dir, sim_list, test=False, transform=None, pre_transform=None):
        
        self.root = root # root directory where processed data is stored
        self.sims = sim_list # list of simulation numbers (different for train and test data)
        self.test = test # flag to identify test data
        self.source_dir = source_dir # source directory for raw data
        
        super().__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        return []
    
    @property
    def processed_file_names(self):
        return []

    def download(self):
        pass
    
    def len(self):
        return len(self.sims)
    
    def _graph_paths(self, position):
        """Paths of the coarse, medium and fine files stored for the given position."""
        templates = self._TEST_FILES if self.test else self._TRAIN_FILES
        return [osp.join(self.root, template.format(position)) for template in templates]
    
    def process(self):
        for position in range(len(self.sims)):
            print("processing simulation {}!".format(position))
            graphs = process_graph_data(self.source_dir, self.sims[position])
            
            # graphs and paths are both ordered coarse, medium, fine
            for graph, path in zip(graphs, self._graph_paths(position)):
                torch.save(graph, path)

    def get(self,idx):
        # returns [data_c, data_m, data_f]
        return [torch.load(path) for path in self._graph_paths(idx)]

# Data



In [None]:
# collecting the simulation numbers available in the dataset folder
### change this based on your data!!!
mypath = 'dataset/'
onlyfiles = [entry for entry in listdir(mypath) if isfile(join(mypath, entry))]
# the simulation number is the second '_'-separated field of the file name
file_nums = np.unique(np.array([int(name.split('_')[1]) for name in onlyfiles]))

# root directory/processed is the folder where the processed graph data is stored
root_dir = "dataset/"
# source directory contains the raw graph data
source_dir = "dataset"

# the first 500 simulation numbers are used for training, all remaining ones for validation
# you can change these numbers
train_list, val_list = file_nums[:500], file_nums[500:]

# change batch size, if needed
batch_size = 2

In [None]:
# Training and validation datasets (validation files are stored with the test flag set)
train_dataset = MFGraphDataset(root=root_dir, source_dir=source_dir, sim_list=train_list, test=False)
val_dataset = MFGraphDataset(root=root_dir, source_dir=source_dir, sim_list=val_list, test=True)

## Data loaders, both shuffled
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

## Min and max values of node attributes (x), targets (y) and edge attributes (e)
## per resolution (_c, _m, _f), computed on the training data for scaling purposes
(
    xmax_c, ymax_c, emax_c, xmin_c, ymin_c, emin_c,
    xmax_m, ymax_m, emax_m, xmin_m, ymin_m, emin_m,
    xmax_f, ymax_f, emax_f, xmin_f, ymin_f, emin_f,
) = get_max_min(train_dataset)

# GNN Model Architecture

Next section contains the functions and classes which define the Multi-fidelity Graph U-Net architecture

Classes EdgeModel and NodeModel contain the edge and node update MLP functions for message passing and propagation

## Multi-fidelity Graph U-Net architecture (Net_MFGUNN_v1)

This architecture is explained using three levels of graph resolution. This can be easily extrapolated to any number of resolutions.

We consider three resolutions of graph, namely coarse, medium and fine with the graph data identified as data_c, data_m and data_f respectively. 

The node attributes of the medium graph data contain the indices of the N nearest nodes (calculated using Euclidean distance during the data creation stage) of the corresponding fine graph, in addition to the node features used for training. Similarly, the node attributes of the coarse graph data contain the indices of the N nearest nodes of the corresponding medium graph as additional node attributes. The fine graph contains only the node features used for training. These indices, which are the additional node attributes, are extracted from the coarse and medium graphs and stored in indices_c and indices_m respectively, during training.

We start with the fine graph, whose node and edge features are encoded to a latent space using an encoder function. 
The encoded node and edge attributes are passed through a number of GN blocks (composed of node and edge update modules) for message passing and propagation to update the node and edge attributes of the fine graph.

Similar to the fine graph, we encode the node and edge features of the medium graph. Then, the updated node attributes of the fine graph are added to the encoded node attributes of the medium graph through a nearest neighbor downsampling method. For this, we use the indices_m list, which contains the N nearest nodes from the fine graph for every node in the medium graph. For every node in the medium graph, we aggregate the node attributes of its N nearest nodes in the fine graph using a MEAN or MAX function and add the result, passed through a learnable weight (weight_m), to the encoded node attributes of that node. These updated node attributes and the edge attributes are then passed through a number of GN blocks for message passing and aggregation.

Similar process is done to transfer node attributes from medium graph to coarse graph, followed by a number of GN blocks for message passing and aggregation. 

This completes the downward flow of information of the U-Net architecture. This is followed by the upward flow direction, where nodal attribute information is upsampled and passed from coarse to medium to fine graph. For this, the same set of N nearest nodes are used for upsampling the information from one level to the other as can be seen in the function 'get_upsample_attr', which is added to the current node attributes of all the nodes in that level. 

Finally, the updated node attributes of coarse, medium and fine graphs are passed through the decoder function for target prediction at that level. For loss function, all the three target prediction outputs are used. 

The function 'get_upsample_attr' is used for getting the indices for upsampling values from the coarse graph to medium graph as well as from medium graph to fine graph

## Multi-fidelity Graph U-Net architecture (Net_MFGUNN_v2)

This is the version used for the results in the paper. Here, the GN blocks are shared across different resolution levels.

First, the encoded node attributes of the high resolution graph are passed through k GN blocks (k is defined by the variable couple_point), then downsampled and added to the encoded node attributes of the medium resolution graph. The same process is done between the medium and low resolution graphs. The node attributes of the low resolution graph are passed through all GN blocks and then upsampled and added to the node attributes of the medium resolution graph. These summed node attributes of the medium graph are passed through the rest of the GN blocks. The same process is done from the medium to the high resolution graph.

Output from all levels are used for the loss function. 

## Multi-fidelity Graph U-Net Lite architecture (Net_MFGUNN_uni)

This architecture is very similar to Net_MFGUNN_v2. The only difference is in the flow of information, which is uni-directional here - from low to high resolution. So all the downsampling steps are avoided.

In [None]:
class EdgeModel(torch.nn.Module):
    """
    Edge update module of a GN block.
    
    A MLP with a single hidden layer and ReLU activation. Its input is the
    concatenation of the node attributes of the two nodes connected by an edge
    and the attributes of that edge; its output is the updated edge attributes
    for all the edges of the graph.
    
    If residuals is True, the previous edge attributes are added to the MLP
    output before it is returned, similar to a residual network.
    
    """
    
    def __init__(self, n_features, n_edge_features, hiddens, n_targets, residuals):
        super().__init__()
        self.residuals = residuals
        # input width: source node + destination node + edge attributes
        in_width = 2*n_features + n_edge_features
        layers = [Linear(in_width, hiddens), ReLU(), Linear(hiddens, n_targets)]
        self.edge_mlp = Seq(*layers)
        
    def forward(self, src, dest, edge_attr, u=None, batch=None):
        # src and dest are the node attributes of the two nodes of every edge,
        # edge_attr the attributes of the edge connecting them; u and batch are unused
        stacked = torch.cat((src, dest, edge_attr), dim=1)
        updated = self.edge_mlp(stacked)
        return updated + edge_attr if self.residuals else updated

class NodeModel (torch.nn.Module):
    """
    Class for updating the node attributes of the graph. 
    
    It consists of two MLP networks. Both the networks have a single hidden layer with ReLU activation function.
    
    First MLP network, node_mlp_1, takes the node attributes of the neighboring nodes and the updated edge attributes of the
    edges connecting these nodes as the input. Output is the message from the neighboring nodes, of width n_targets.
    
    Second MLP network, node_mlp_2, takes the node attributes of the current node and the message from node_mlp_1
    aggregated (summed) across all the neighboring nodes as the input. Output is the updated node attributes. 
    
    If residuals is True, the updated values are added to the previous node attributes before they are returned.
    This is similar to the residual network, and requires n_targets to equal n_features.
    """ 
    
    def __init__(self, n_features, n_edge_features, hiddens, n_targets, residuals):
        super(NodeModel, self).__init__()
        
        self.residuals = residuals
        
        #message calculation MLP
        self.node_mlp_1 = Seq(
            Linear(n_features + n_edge_features, hiddens),
            ReLU(),
            Linear(hiddens, n_targets),
        )
        
        #node attribute update MLP
        #its input is [x, aggregated message]; the message comes out of node_mlp_1 and is
        #therefore n_targets wide (not hiddens wide), so the input width is n_features + n_targets
        self.node_mlp_2 = Seq(
            Linear(n_features + n_targets, hiddens),
            ReLU(),
            Linear(hiddens, n_targets),
        )

    def forward(self, x, edge_index, edge_attr, u, batch):
        # u and batch are accepted for the caller's signature but not used here
        row, col = edge_index
        out = torch.cat([x[col], edge_attr], dim=1)
        out = self.node_mlp_1(out) #message calculation
        out = torch_scatter.scatter_add(out, row, dim=0, dim_size=x.size(0)) #message aggregation, aggregation function is SUM
        out = torch.cat([x, out], dim=1)
        out = self.node_mlp_2(out) #node update
        if self.residuals:
            out = out + x
        return out


def build_layer(n_features, n_edge_features, hiddens, n_targets, batchnorm=False, residuals=True):
    """Builds one GN block: a MetaLayer combining the edge and node update modules.
    
    batchnorm is accepted but not used.
    """ 
    mlp_sizes = (n_features, n_edge_features, hiddens, n_targets)
    # the edge module is created first, then the node module
    edge_update = EdgeModel(*mlp_sizes, residuals=residuals)
    node_update = NodeModel(*mlp_sizes, residuals=residuals)
    return geom_nn.MetaLayer(edge_model=edge_update, node_model=node_update)




In [None]:
def get_upsample_attr(indices, node_attr, N, channel, device):
    
    """Upsamples node attributes to the nodes of the next higher resolution level.
    
    Every node of the current level lists its N nearest nodes of the next higher
    resolution level. Each listed higher-level node receives the mean of the
    attributes of all current-level nodes that list it.
    
    Parameters
    -----------
    indices : a tensor containing the indices of all the nearest nodes of the next high level resolution of data, Tensor of shape k x N
    node_attr : a tensor with the updated node attributes for all the nodes for the current level resolution, Tensor of size k x channel, Float
    N : number of nearest neighbors considered, int
    channel : latent space size of the node attributes, int
    device : CPU/GPU
    
    Ouput
    ------
    upsampled_attr : Tensor of shape U x (1 + channel), where U is the number of unique values in indices.
                    Column 0 holds the index of the higher-level node, the remaining columns hold the
                    mean node attributes to be added to that node.
       
    """ 
    # k x N tensor converted to kN x 1 tensor
    flat_indices = indices.reshape(-1,1).to(device)
    
    # uniq_index_list : the sorted unique indices
    # uniq_index_pos  : for every entry of flat_indices, its position in uniq_index_list
    # counts          : how many times each unique index occurs
    # (a single call instead of two separate torch.unique passes over the same tensor)
    uniq_index_list, uniq_index_pos, counts = torch.unique(
        flat_indices, return_inverse=True, return_counts=True)
    
    # repeats every node attribute row N times, so that row i lines up with flat_indices[i]
    node_attr_repeat = torch.repeat_interleave(node_attr, N, dim=0).to(device)
    
    # sums, for every unique index, all the node attribute rows that point to it
    attr_sum = torch.zeros(uniq_index_list.shape[0], channel, dtype=torch.float32, device=device)
    attr_sum.index_add_(0, uniq_index_pos[:,0], node_attr_repeat)
    
    # dividing the sums by the number of contributions gives the mean
    upsampled_attr = attr_sum / counts[:, None]
    
    # the upsampled node attributes are concatenated with the indices to which they belong
    return torch.hstack((uniq_index_list.reshape(-1,1), upsampled_attr))

In [None]:
class Net_MFGUNN_v1(torch.nn.Module):
    """
    Class defining the Multi-Fidelity Graph U-Net architecture for three
    graph resolutions (coarse, medium, fine).
    
    Information flows downward (fine -> medium -> coarse) through nearest
    neighbor downsampling and then upward (coarse -> medium -> fine) through
    upsampling. Each of the five stages has its own set of GN blocks, and all
    three levels share one node encoder, one edge encoder and one decoder.
    
    Parameters
    -----------
    node_att_num : number of node features fed to the node encoder, int
    edge_att_num : number of edge features fed to the edge encoder, int
    output_dim : number of outputs/targets per node, int
    gn_depth_list : number of GN blocks for each of the five stages, list of 5 ints
    encoder_neurons : [hidden size, latent size] of the encoder networks
    N : number of nearest nodes stored per node, int
    node_index_list : column indices of the node features used for training
    index_start : column of the node attributes from which the nearest node indices are stored, int
    dropout : stored but not applied by this class
    batch_norm : accepted but not used by this class
    """ 
    def __init__(self, node_att_num, edge_att_num, output_dim, 
                 gn_depth_list, encoder_neurons, N, node_index_list,
                 index_start, dropout=0.1, batch_norm = False):
        super(Net_MFGUNN_v1, self).__init__()

        self.node_att_num = node_att_num #number of node attributes
        self.edge_att_num = edge_att_num #number of edge attributes
        self.encoder_neurons = encoder_neurons #number of neurons in each layer of encoder network
        self.output_dim = output_dim #number of outputs/targets
        self.gn_depth_list = gn_depth_list #list containing the number of GN blocks for each level
        self.node_index_list = node_index_list #indices of the node features to be used for training
        self.dropout = dropout #dropout ratio
        self.N = N #number of nearest nodes considered
        self.channel = self.encoder_neurons[1] #latent space size after encoding
        self.index_start = index_start #first node attribute column holding nearest node indices
        # defining the encoder network for node attributes
        self.encoder_node = Seq(Linear(self.node_att_num, self.encoder_neurons[0]),
                                ReLU(),
                                Linear(self.encoder_neurons[0], self.encoder_neurons[1]))
        # defining the encoder network for edge attributes
        self.encoder_edge = Seq(Linear(self.edge_att_num, self.encoder_neurons[0]),
                                ReLU(),
                                Linear(self.encoder_neurons[0], self.encoder_neurons[1]))
        # defining the decoder network 
        self.decoder = Seq(Linear(self.encoder_neurons[1], self.encoder_neurons[0]),
                                ReLU(),
                                Linear(self.encoder_neurons[0], self.output_dim))
        # batch norm layer; it is created here but not applied in forward
        self.bn = torch.nn.BatchNorm1d(self.channel)
        
        # learnable weights with which the node attributes are added from fine to medium and medium to coarse graphs
        self.weight_c = Linear(self.channel, self.channel)
        self.weight_m = Linear(self.channel, self.channel)
        
        # defining the GN blocks for message passing and aggregation
        self.GN_Blocks= torch.nn.ModuleList()
        for k in range(len(self.gn_depth_list)):
            # kept as a module attribute (rather than a local) so that the registered
            # submodules, and with them the state_dict keys, stay as they were
            self.GN_Blocks_k= torch.nn.ModuleList()
            for i in range(self.gn_depth_list[k]):
                self.GN_Blocks_k.append(build_layer(self.channel,self.channel,self.channel,self.channel))
            self.GN_Blocks.append(self.GN_Blocks_k)
        
    def _aggregate_neighbors(self, finer_x, indices):
        """Mean of the attributes of the N nearest finer-level nodes of every node.
        
        finer_x : node attributes of the finer graph, Tensor with self.channel columns
        indices : nearest finer-level node indices, N per node
        
        Returns a tensor with one row per node and self.channel columns.
        """
        gathered = finer_x[indices.reshape(-1).long(),:].reshape(-1, self.N, self.channel)
        # torch.mean returns the reduced tensor itself (torch.max with a dim returns a
        # (values, indices) pair instead), so the result must not be indexed with [0]:
        # that would select the first node's mean and broadcast it to every node
        return torch.mean(gathered, 1)
    
    def _add_upsampled(self, target_x, indices, source_x, device):
        """Adds the upsampled attributes of the coarser level to the finer level nodes.
        
        target_x is updated in place on the rows returned by get_upsample_attr
        and is also returned.
        """
        upsampled = get_upsample_attr(indices.to(device), source_x.to(device), self.N, self.channel, device).to(device)
        rows = upsampled[:,0].long() # column 0 holds the target node index
        target_x[rows,:] = target_x[rows,:] + upsampled[:,1:]
        return target_x
        
    def forward(self, data, device, scale=True):
        """Predicts the targets on all three resolution levels.
        
        Parameters
        -----------
        data : sequence of three graphs ordered [coarse, medium, fine]
        device : CPU/GPU
        scale : if True, node and edge attributes are min-max scaled first. This reads the
                notebook-level variables xmin_*/xmax_*/emin_*/emax_* (for _c, _m, _f),
                which must exist before calling
        
        Ouput
        ------
        x_c, x_m, x_f : decoded predictions for the coarse, medium and fine graph
        """
        
        assert len(self.gn_depth_list) == 5, "Invalid GN list length" # GN list length should be 5 for 3 levels of resolution 
        
        data_c = data[0].to(device) #coarse graph
        data_m = data[1].to(device) #medium graph
        data_f = data[2].to(device) #fine graph

        x_c, edge_index_c, edge_attr_c = data_c.x, data_c.edge_index, data_c.edge_attr
        x_m, edge_index_m, edge_attr_m = data_m.x, data_m.edge_index, data_m.edge_attr
        x_f, edge_index_f, edge_attr_f = data_f.x, data_f.edge_index, data_f.edge_attr
        
        # extracting the indices stored as node attributes in the coarse and medium graphs
        # (taken before scaling, so that the raw index values are kept)
        # NOTE(review): the indices are used exactly as stored and no per-graph offset is
        # applied here - confirm they stay valid when several graphs are batched together
        indices_c = x_c[:,self.index_start:].to(device) 
        indices_m = x_m[:,self.index_start:].to(device)
        
        if scale:
            # scaling node attributes
            x_c = min_max_scaler(x_c, xmin_c, xmax_c, device)
            x_m = min_max_scaler(x_m, xmin_m, xmax_m, device)
            x_f = min_max_scaler(x_f, xmin_f, xmax_f, device)
            # scaling edge attributes
            edge_attr_c = min_max_scaler(edge_attr_c, emin_c, emax_c, device)
            edge_attr_m = min_max_scaler(edge_attr_m, emin_m, emax_m, device)
            edge_attr_f = min_max_scaler(edge_attr_f, emin_f, emax_f, device)

        # encoding the node attributes for all three graphs
        x_c = self.encoder_node(x_c[:,self.node_index_list])
        x_m = self.encoder_node(x_m[:,self.node_index_list])
        x_f = self.encoder_node(x_f[:,self.node_index_list])
        
        # encoding the edge attributes for all three graphs
        # (.to casts an existing tensor; torch.tensor(tensor) would copy-construct it)
        edge_attr_c = self.encoder_edge(edge_attr_c.to(torch.float))
        edge_attr_m = self.encoder_edge(edge_attr_m.to(torch.float))
        edge_attr_f = self.encoder_edge(edge_attr_f.to(torch.float))

        # encoded attributes of fine graph are passed through the first set of GN blocks
        for i in range(self.gn_depth_list[0]):
            x_f, edge_attr_f, _ = self.GN_Blocks[0][i](x_f, edge_index_f, edge_attr=edge_attr_f)
        
        # for every medium node, the attributes of its nearest fine nodes are aggregated
        # (mean is used, max is possible as well), weighted using a learnable weight (weight_m)
        # and added to the encoded node attributes of the medium graph
        x_m = x_m + self.weight_m(self._aggregate_neighbors(x_f, indices_m).to(device))
          
        # these node attributes of medium graph are then passed through the second set of GN blocks
        # (edge attributes are carried from block to block, as on every other level)
        for i in range(self.gn_depth_list[1]):
            x_m, edge_attr_m, _ = self.GN_Blocks[1][i](x_m, edge_index_m, edge_attr=edge_attr_m)

        # same downsampling step from the medium graph to the coarse graph
        x_c = x_c + self.weight_c(self._aggregate_neighbors(x_m, indices_c).to(device))
        
        # these node attributes of coarse graph are then passed through the third set of GN blocks
        for i in range(self.gn_depth_list[2]):
            x_c, edge_attr_c, _ = self.GN_Blocks[2][i](x_c, edge_index_c, edge_attr=edge_attr_c)
            
        # upsampling starts from here
        # the upsampled attributes of the coarse graph are added to the medium graph
        x_m = self._add_upsampled(x_m, indices_c, x_c, device)
        
        # updated node attributes from medium graph passed through fourth set of GN blocks
        for i in range(self.gn_depth_list[3]):
            x_m, edge_attr_m, _ = self.GN_Blocks[3][i](x_m, edge_index_m, edge_attr=edge_attr_m)

        # the upsampled attributes of the medium graph are added to the fine graph
        x_f = self._add_upsampled(x_f, indices_m, x_m, device)

        # updated node attributes from fine graph passed through final set of GN blocks
        for i in range(self.gn_depth_list[4]):
            x_f, edge_attr_f, _ = self.GN_Blocks[4][i](x_f, edge_index_f, edge_attr=edge_attr_f)
        
        # all the updated node attributes from fine, coarse and medium graphs passed through decoder
        # to get the predicted responses for each level of resolution
        x_f = self.decoder(x_f)
        x_m = self.decoder(x_m)
        x_c = self.decoder(x_c)
        
        return x_c, x_m, x_f

In [None]:
class Net_MFGUNN_v2(torch.nn.Module):
    """
    Multi-Fidelity Graph U-Net (v2): one stack of GN blocks shared by the
    coarse, medium and fine graphs.

    The fine and medium graphs go through the first ``gn_depth - couple_point``
    blocks on the way down and through the remaining ``couple_point`` blocks on
    the way up. The coarse graph goes through all ``gn_depth`` blocks at once.

    Parameters
    -----------
    node_att_num : number of node attributes fed to the node encoder, int
    edge_att_num : number of edge attributes fed to the edge encoder, int
    output_dim : number of outputs/targets per node, int
    gn_depth : total number of GN blocks in the shared stack, int
    encoder_neurons : [hidden size, latent size] of the encoder/decoder MLPs, list
    N : number of nearest nodes stored per node to couple two resolutions, int
    node_index_list : column indices of data.x used as node features, list
    index_start : first column of data.x holding the nearest-node indices
                  (computed offline during data creation), int
    couple_point : number of GN blocks used on the upsampling path, int
    dropout : stored on the module; not used inside forward, float
    batch_norm : accepted for interface compatibility; not used, boolean
    """
    def __init__(self, node_att_num, edge_att_num, output_dim, 
                 gn_depth, encoder_neurons, N, node_index_list,
                 index_start,couple_point, dropout=0.1, batch_norm = False):
        super(Net_MFGUNN_v2, self).__init__()

        # a couple point outside [0, gn_depth] would silently index GN_Blocks with
        # negative numbers in forward, so reject it up front
        if not 0 <= couple_point <= gn_depth:
            raise ValueError("couple_point must satisfy 0 <= couple_point <= gn_depth")

        self.node_att_num = node_att_num #number of node attributes
        self.edge_att_num = edge_att_num #number of edge attributes
        self.encoder_neurons = encoder_neurons #number of neurons in each layer of encoder network
        self.output_dim = output_dim #number of outputs/targets
        self.gn_depth = gn_depth #total number of GN blocks in the shared stack
        self.node_index_list = node_index_list #indices of the node features to be used for training
        self.dropout = dropout #dropout ratio (not used in forward)
        self.N = N #number of nearest nodes considered
        self.channel = self.encoder_neurons[1] #latent space size after encoding
        self.index_start = index_start #first column of data.x that stores the nearest-node indices
        self.couple_point = couple_point #number of GN blocks used on the upsampling path

        # encoder network for node attributes
        self.encoder_node = Seq(Linear(self.node_att_num, self.encoder_neurons[0]),
                                ReLU(),
                                Linear(self.encoder_neurons[0], self.encoder_neurons[1]))
        # encoder network for edge attributes
        self.encoder_edge = Seq(Linear(self.edge_att_num, self.encoder_neurons[0]),
                                ReLU(),
                                Linear(self.encoder_neurons[0], self.encoder_neurons[1]))
        # decoder network, shared by all three resolutions
        self.decoder = Seq(Linear(self.encoder_neurons[1], self.encoder_neurons[0]),
                                ReLU(),
                                Linear(self.encoder_neurons[0], self.output_dim))
        # batch norm layer (registered but not applied in forward)
        self.bn = torch.nn.BatchNorm1d(self.channel)

        # learnable weights applied to the aggregated attributes that are passed
        # from fine to medium (weight_m) and from medium to coarse (weight_c)
        self.weight_c = Linear(self.channel, self.channel)
        self.weight_m = Linear(self.channel, self.channel)

        # GN blocks for message passing and aggregation
        self.GN_Blocks= torch.nn.ModuleList()
        for i in range(self.gn_depth):
            self.GN_Blocks.append(build_layer(self.channel,self.channel,self.channel,self.channel))

    def _run_blocks(self, x, edge_index, edge_attr, start, stop):
        """Pass node and edge attributes through GN_Blocks[start:stop]; return both."""
        for i in range(start, stop):
            # each block returns three values; the first two are the updated
            # node and edge attributes, the third is not needed here
            x, edge_attr, _ = self.GN_Blocks[i](x, edge_index, edge_attr=edge_attr)
        return x, edge_attr

    def forward(self, data, device, scale=True):
        """Run the U-Net on one (coarse, medium, fine) triple of graphs.

        Parameters
        -----------
        data : sequence holding the coarse, medium and fine graph (in that order)
        device : GPU or CPU device
        scale : if True, node and edge attributes are min-max scaled with the
                notebook-level statistics (xmin_*, xmax_*, emin_*, emax_*), boolean

        Ouput
        ------
        x_c, x_m, x_f : decoded predictions for the coarse, medium and fine graph
        """
        data_c = data[0].to(device) #coarse graph
        data_m = data[1].to(device) #medium graph
        data_f = data[2].to(device) #fine graph

        x_c, edge_index_c, edge_attr_c = data_c.x, data_c.edge_index, data_c.edge_attr
        x_m, edge_index_m, edge_attr_m = data_m.x, data_m.edge_index, data_m.edge_attr
        x_f, edge_index_f, edge_attr_f = data_f.x, data_f.edge_index, data_f.edge_attr

        # nearest-node indices stored as trailing node attributes of the coarse and
        # medium graph; they must be read before scaling changes their values
        indices_c = x_c[:,self.index_start:].to(device)
        indices_m = x_m[:,self.index_start:].to(device)

        if scale:
            # scaling node attributes
            x_c = min_max_scaler(x_c, xmin_c, xmax_c, device)
            x_m = min_max_scaler(x_m, xmin_m, xmax_m, device)
            x_f = min_max_scaler(x_f, xmin_f, xmax_f, device)
            # scaling edge attributes
            edge_attr_c = min_max_scaler(edge_attr_c, emin_c, emax_c, device)
            edge_attr_m = min_max_scaler(edge_attr_m, emin_m, emax_m, device)
            edge_attr_f = min_max_scaler(edge_attr_f, emin_f, emax_f, device)

        # encoding the node attributes for all three graphs
        x_c = self.encoder_node(x_c[:,self.node_index_list])
        x_m = self.encoder_node(x_m[:,self.node_index_list])
        x_f = self.encoder_node(x_f[:,self.node_index_list])

        # encoding the edge attributes for all three graphs
        # (as_tensor converts the dtype without re-wrapping an existing tensor)
        edge_attr_c = self.encoder_edge(torch.as_tensor(edge_attr_c, dtype=torch.float))
        edge_attr_m = self.encoder_edge(torch.as_tensor(edge_attr_m, dtype=torch.float))
        edge_attr_f = self.encoder_edge(torch.as_tensor(edge_attr_f, dtype=torch.float))

        # number of blocks on the downsampling path; the rest form the upsampling path
        split = self.gn_depth - self.couple_point

        # downsampling path, fine graph
        x_f, edge_attr_f = self._run_blocks(x_f, edge_index_f, edge_attr_f, 0, split)

        # gather the attributes of the N nearest fine nodes for every medium node
        x_m_downsample = x_f[indices_m.reshape(-1).long(),:].reshape(-1, self.N, self.channel)
        # aggregate them per medium node (mean over the N neighbours), weight them
        # with weight_m and add them to the encoded medium attributes.
        # torch.mean returns a plain tensor, so it must not be indexed with [0]
        x_m = x_m + self.weight_m(torch.mean(x_m_downsample, 1))

        # downsampling path, medium graph
        x_m, edge_attr_m = self._run_blocks(x_m, edge_index_m, edge_attr_m, 0, split)

        # gather the attributes of the N nearest medium nodes for every coarse node,
        # aggregate per coarse node and add to the encoded coarse attributes
        x_c_downsample = x_m[indices_c.reshape(-1).long(),:].reshape(-1, self.N, self.channel)
        x_c = x_c + self.weight_c(torch.mean(x_c_downsample, 1))

        # the coarse graph goes through the entire stack of GN blocks
        x_c, edge_attr_c = self._run_blocks(x_c, edge_index_c, edge_attr_c, 0, self.gn_depth)

        # upsampling starts from here
        # get_upsample_attr returns rows whose first column is used as the target
        # node index in the medium graph and whose remaining columns are added to it
        x_c_upsample = get_upsample_attr(indices_c.to(device), x_c.to(device), self.N, self.channel, device).to(device)
        rows_m = x_c_upsample[:,0].long()
        x_m[rows_m,:] = x_m[rows_m,:] + x_c_upsample[:,1:]

        # upsampling path, medium graph
        x_m, edge_attr_m = self._run_blocks(x_m, edge_index_m, edge_attr_m, split, self.gn_depth)

        # same transfer from the medium graph to the fine graph
        x_m_upsample = get_upsample_attr(indices_m.to(device), x_m.to(device), self.N, self.channel, device).to(device)
        rows_f = x_m_upsample[:,0].long()
        x_f[rows_f,:] = x_f[rows_f,:] + x_m_upsample[:,1:]

        # upsampling path, fine graph
        x_f, edge_attr_f = self._run_blocks(x_f, edge_index_f, edge_attr_f, split, self.gn_depth)

        # decode the node attributes of every resolution into predicted responses
        x_f = self.decoder(x_f)
        x_m = self.decoder(x_m)
        x_c = self.decoder(x_c)

        return x_c, x_m, x_f

In [None]:
class Net_MFGUNN_uni(torch.nn.Module):
    """
    Uni-directional variant of the Multi-Fidelity Graph U-Net.

    One stack of GN blocks is shared by the coarse, medium and fine graphs.
    Unlike Net_MFGUNN_v2 there is no fine-to-medium or medium-to-coarse
    transfer: information only flows coarse -> medium -> fine through the
    upsampling step. The fine and medium graphs use the first
    ``gn_depth - couple_point`` blocks before upsampling and the remaining
    blocks after it; the coarse graph uses all ``gn_depth`` blocks.

    Parameters
    -----------
    node_att_num : number of node attributes fed to the node encoder, int
    edge_att_num : number of edge attributes fed to the edge encoder, int
    output_dim : number of outputs/targets per node, int
    gn_depth : total number of GN blocks in the shared stack, int
    encoder_neurons : [hidden size, latent size] of the encoder/decoder MLPs, list
    N : number of nearest nodes stored per node to couple two resolutions, int
    node_index_list : column indices of data.x used as node features, list
    index_start : first column of data.x holding the nearest-node indices, int
    couple_point : number of GN blocks applied after the upsampling step, int
    dropout : stored on the module; not used inside forward, float
    batch_norm : accepted for interface compatibility; not used, boolean
    """
    def __init__(self, node_att_num, edge_att_num, output_dim, 
                 gn_depth, encoder_neurons, N, node_index_list,
                 index_start,couple_point, dropout=0.1, batch_norm = False):
        super(Net_MFGUNN_uni, self).__init__()

        # a couple point outside [0, gn_depth] would silently index GN_Blocks with
        # negative numbers in forward, so reject it up front
        if not 0 <= couple_point <= gn_depth:
            raise ValueError("couple_point must satisfy 0 <= couple_point <= gn_depth")

        self.node_att_num = node_att_num #number of node attributes
        self.edge_att_num = edge_att_num #number of edge attributes
        self.encoder_neurons = encoder_neurons #number of neurons in each layer of encoder network
        self.output_dim = output_dim #number of outputs/targets
        self.gn_depth = gn_depth #total number of GN blocks in the shared stack
        self.node_index_list = node_index_list #indices of the node features to be used for training
        self.dropout = dropout #dropout ratio (not used in forward)
        self.N = N #number of nearest nodes considered
        self.channel = self.encoder_neurons[1] #latent space size after encoding
        self.index_start = index_start #first column of data.x that stores the nearest-node indices
        self.couple_point = couple_point #number of GN blocks applied after upsampling

        # encoder network for node attributes
        self.encoder_node = Seq(Linear(self.node_att_num, self.encoder_neurons[0]),
                                ReLU(),
                                Linear(self.encoder_neurons[0], self.encoder_neurons[1]))
        # encoder network for edge attributes
        self.encoder_edge = Seq(Linear(self.edge_att_num, self.encoder_neurons[0]),
                                ReLU(),
                                Linear(self.encoder_neurons[0], self.encoder_neurons[1]))
        # decoder network, shared by all three resolutions
        self.decoder = Seq(Linear(self.encoder_neurons[1], self.encoder_neurons[0]),
                                ReLU(),
                                Linear(self.encoder_neurons[0], self.output_dim))
        # batch norm layer (registered but not applied in forward)
        self.bn = torch.nn.BatchNorm1d(self.channel)

        # these two layers are registered but not used by this variant's forward
        self.weight_c = Linear(self.channel, self.channel)
        self.weight_m = Linear(self.channel, self.channel)

        # GN blocks for message passing and aggregation
        self.GN_Blocks= torch.nn.ModuleList()
        for i in range(self.gn_depth):
            self.GN_Blocks.append(build_layer(self.channel,self.channel,self.channel,self.channel))

    def _run_blocks(self, x, edge_index, edge_attr, start, stop):
        """Pass node and edge attributes through GN_Blocks[start:stop]; return both."""
        for i in range(start, stop):
            # each block returns three values; the first two are the updated
            # node and edge attributes, the third is not needed here
            x, edge_attr, _ = self.GN_Blocks[i](x, edge_index, edge_attr=edge_attr)
        return x, edge_attr

    def forward(self, data, device, scale=True):
        """Run the network on one (coarse, medium, fine) triple of graphs.

        Parameters
        -----------
        data : sequence holding the coarse, medium and fine graph (in that order)
        device : GPU or CPU device
        scale : if True, node and edge attributes are min-max scaled with the
                notebook-level statistics (xmin_*, xmax_*, emin_*, emax_*), boolean

        Ouput
        ------
        x_c, x_m, x_f : decoded predictions for the coarse, medium and fine graph
        """
        data_c = data[0].to(device) #coarse graph
        data_m = data[1].to(device) #medium graph
        data_f = data[2].to(device) #fine graph

        x_c, edge_index_c, edge_attr_c = data_c.x, data_c.edge_index, data_c.edge_attr
        x_m, edge_index_m, edge_attr_m = data_m.x, data_m.edge_index, data_m.edge_attr
        x_f, edge_index_f, edge_attr_f = data_f.x, data_f.edge_index, data_f.edge_attr

        # nearest-node indices stored as trailing node attributes of the coarse and
        # medium graph; they must be read before scaling changes their values
        indices_c = x_c[:,self.index_start:].to(device)
        indices_m = x_m[:,self.index_start:].to(device)

        if scale:
            # scaling node attributes
            x_c = min_max_scaler(x_c, xmin_c, xmax_c, device)
            x_m = min_max_scaler(x_m, xmin_m, xmax_m, device)
            x_f = min_max_scaler(x_f, xmin_f, xmax_f, device)
            # scaling edge attributes
            edge_attr_c = min_max_scaler(edge_attr_c, emin_c, emax_c, device)
            edge_attr_m = min_max_scaler(edge_attr_m, emin_m, emax_m, device)
            edge_attr_f = min_max_scaler(edge_attr_f, emin_f, emax_f, device)

        # encoding the node attributes for all three graphs
        x_c = self.encoder_node(x_c[:,self.node_index_list])
        x_m = self.encoder_node(x_m[:,self.node_index_list])
        x_f = self.encoder_node(x_f[:,self.node_index_list])

        # encoding the edge attributes for all three graphs
        # (as_tensor converts the dtype without re-wrapping an existing tensor)
        edge_attr_c = self.encoder_edge(torch.as_tensor(edge_attr_c, dtype=torch.float))
        edge_attr_m = self.encoder_edge(torch.as_tensor(edge_attr_m, dtype=torch.float))
        edge_attr_f = self.encoder_edge(torch.as_tensor(edge_attr_f, dtype=torch.float))

        # number of blocks applied before upsampling; the rest are applied after it
        split = self.gn_depth - self.couple_point

        # fine graph: first part of the stack
        x_f, edge_attr_f = self._run_blocks(x_f, edge_index_f, edge_attr_f, 0, split)
        # medium graph: first part of the stack
        x_m, edge_attr_m = self._run_blocks(x_m, edge_index_m, edge_attr_m, 0, split)
        # coarse graph: the entire stack
        x_c, edge_attr_c = self._run_blocks(x_c, edge_index_c, edge_attr_c, 0, self.gn_depth)

        # upsampling starts from here
        # get_upsample_attr returns rows whose first column is used as the target
        # node index in the medium graph and whose remaining columns are added to it
        x_c_upsample = get_upsample_attr(indices_c.to(device), x_c.to(device), self.N, self.channel, device).to(device)
        rows_m = x_c_upsample[:,0].long()
        x_m[rows_m,:] = x_m[rows_m,:] + x_c_upsample[:,1:]

        # medium graph: remaining part of the stack
        x_m, edge_attr_m = self._run_blocks(x_m, edge_index_m, edge_attr_m, split, self.gn_depth)

        # same transfer from the medium graph to the fine graph
        x_m_upsample = get_upsample_attr(indices_m.to(device), x_m.to(device), self.N, self.channel, device).to(device)
        rows_f = x_m_upsample[:,0].long()
        x_f[rows_f,:] = x_f[rows_f,:] + x_m_upsample[:,1:]

        # fine graph: remaining part of the stack
        x_f, edge_attr_f = self._run_blocks(x_f, edge_index_f, edge_attr_f, split, self.gn_depth)

        # decode the node attributes of every resolution into predicted responses
        x_f = self.decoder(x_f)
        x_m = self.decoder(x_m)
        x_c = self.decoder(x_c)

        return x_c, x_m, x_f

# Model training and evaluation

In [None]:
def model_train(data_loader, loss_all, device, scale):
    """Train the GNN model for one epoch.

    Relies on the notebook-level objects `model`, `optimizer`, `my_lr_scheduler`
    and `loss`, and (when scale is True) on the target statistics
    ymin_c/ymax_c, ymin_m/ymax_m and ymin_f/ymax_f.

    Parameters
    -----------
    data_loader : Data loader object from pytorch geometric, it contains all the graphs for training
    loss_all : running loss value the epoch loss is added to, float
    device : GPU/CPU
    scale : if True, scaling is done on node and edge attributes as well as the target, boolean

    Ouput
    ------
    loss_all : loss value after a single epoch, float

    """
    model.train()
    n_batches = 0
    for data in data_loader:
        # get the predicted outputs for the coarse, medium and fine graph
        out_c, out_m, out_f = model(data, device, scale)

        optimizer.zero_grad(set_to_none=True)

        # scale target responses if scale is True
        if scale:
            y_c =  min_max_scaler(data[0].y, ymin_c, ymax_c, device).reshape(-1,1)
            y_m =  min_max_scaler(data[1].y, ymin_m, ymax_m, device).reshape(-1,1)
            y_f =  min_max_scaler(data[2].y, ymin_f, ymax_f, device).reshape(-1,1)
        else:
            # move the raw targets to the same device as the predictions
            y_c = data[0].y.to(device).reshape(-1,1)
            y_m = data[1].y.to(device).reshape(-1,1)
            y_f = data[2].y.to(device).reshape(-1,1)

        # loss calculation
        # loss is calculated for predictions from coarse, medium and fine graphs
        # they are weighted in the ratio 1:5:10 for coarse:medium:fine
        loss_calc = (loss(out_c.reshape(-1,1), y_c)
                     + 5*loss(out_m.reshape(-1,1), y_m)
                     + 10*loss(out_f.reshape(-1,1), y_f))
        loss_all += data[2].num_graphs * loss_calc.item()
        loss_calc.backward()

        optimizer.step()
        n_batches += 1

    # the scheduler is built with T_max = n_epochs, so it is advanced once per
    # epoch; stepping it per batch would run through the cosine cycle many times.
    # It is skipped for an empty loader so it never steps before the optimizer.
    if n_batches > 0:
        my_lr_scheduler.step()

    return loss_all


def model_eval(data_loader, device, scale):
    """Evaluate the GNN model on the fine graphs of a data loader.

    Relies on the notebook-level `model` and (when scale is True) on the
    fine-graph target statistics ymin_f/ymax_f.

    Parameters
    -----------
    data_loader : Data loader object from pytorch geometric, it contains all the graphs to evaluate
    device : GPU/CPU
    scale : if True, the model output is treated as scaled and is descaled before comparison, boolean

    Ouput
    ------
    l2_err : relative L2 error for the predictions in the fine graph, float

    """
    model.eval()

    predictions = []
    labels = []

    # no gradients are needed for evaluation; disabling them avoids building
    # an autograd graph for every batch
    with torch.no_grad():
        for data in data_loader:

            # getting the prediction from just the fine graph
            _, _, pred = model(data, device, scale)

            if scale:
                pred = min_max_descaler(pred, ymin_f, ymax_f, device)

            predictions.append(pred.detach().cpu().numpy().reshape(-1,1))
            labels.append(data[2].y.detach().cpu().numpy().reshape(-1,1))

    predictions = np.vstack(predictions)
    labels = np.vstack(labels)

    # relative L2 error over all fine-graph nodes; the arrays are flattened so
    # that the vector 2-norm is used rather than the (SVD based) matrix 2-norm,
    # which gives the same value for a single column
    diff_norm = np.linalg.norm((predictions - labels).ravel(), ord=2)
    y_norm = np.linalg.norm(labels.ravel(), ord=2)
    l2_err = diff_norm / y_norm

    return l2_err

### CHANGE TRAINING AND MODEL PARAMETERS HERE!

In [None]:

# ---------------------------------------------------------------
# Optimisation settings
# ---------------------------------------------------------------
n_epochs = 1000        # number of passes over the training set
batch_size = 2         # graphs per mini-batch
lr = 0.001             # initial learning rate
weight_decay = 1e-6    # weight decay passed to the optimizer

# ---------------------------------------------------------------
# Network settings -- adapt these to your data
# ---------------------------------------------------------------
# sizes of the inputs and of the prediction
node_att_num = 10
edge_att_num = 3
output_dim = 1

# depth of the network: the per-level list is used by the v1 model,
# the total depth and couple point by the v2 / uni models
gn_depth_list = [3, 2, 2, 2, 3]
gn_depth = 12
couple_point = 6

# layout of the node attribute matrix: columns used as features, and the
# first column from which the nearest-node indices are stored
node_index_list = [0, 1] + list(range(4, 12))
index_start = 12
N = 5                  # nearest nodes stored per node

encoder_neurons = [64, 128]
dropout = 0.0
batch_norm = False
scale = True           # min-max scale attributes and targets

# ---------------------------------------------------------------
# Output locations (models and losses live inside result_dir)
# ---------------------------------------------------------------
result_dir, model_dir, loss_dir = 'results', 'models', 'losses'


In [None]:
# run on the first GPU when one is available, otherwise on the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#CHOOSE THE MODEL YOU WANT!!
# N and batch_norm are taken from the configuration cell instead of being
# hard-coded here, so that editing the configuration actually takes effect

# model = Net_MFGUNN_v1(node_att_num=node_att_num, edge_att_num=edge_att_num, output_dim=output_dim, N=N,
#                 gn_depth_list=gn_depth_list,node_index_list=node_index_list,encoder_neurons=encoder_neurons, 
#                 index_start=index_start,dropout=dropout, batch_norm=batch_norm).to(device)

model = Net_MFGUNN_v2(node_att_num=node_att_num, edge_att_num=edge_att_num, output_dim=output_dim, N=N,
                gn_depth=gn_depth,couple_point = couple_point, node_index_list=node_index_list,encoder_neurons=encoder_neurons, 
                index_start=index_start,dropout=dropout, batch_norm=batch_norm).to(device)

# model = Net_MFGUNN_uni(node_att_num=node_att_num, edge_att_num=edge_att_num, output_dim=output_dim, N=N,
#                 gn_depth=gn_depth,couple_point = couple_point, node_index_list=node_index_list,encoder_neurons=encoder_neurons, 
#                 index_start=index_start,dropout=dropout, batch_norm=batch_norm).to(device)


# Adam optimizer and a cosine learning-rate schedule whose period (T_max) is
# expressed in epochs
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.99, 0.999), weight_decay=weight_decay)
my_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=n_epochs, eta_min=1e-8)

# we can use MSE or relative L2 error loss for training

#loss = torch.nn.MSELoss()
def loss(pred, actual):
    """Relative L2 error: ||pred - actual||_2 / ||actual||_2 over all elements.

    Both tensors are flattened first, so any matching shapes are accepted.
    """
    target = actual.flatten()
    residual = pred.flatten() - target
    return residual.norm(p=2) / target.norm(p=2)

### MODEL TRAINING

In [None]:
## model training
# local import: the notebook's import cell only brings in `os.path` (as osp)
# and `listdir`, not the `os` module itself
import os

epoch_list = []
train_l2_err = []
val_l2_err = []

# evaluate and checkpoint every `eval_every` epochs
eval_every = 10

# make sure the output folders exist before the first checkpoint is written
model_out_dir = os.path.join(result_dir, model_dir)
loss_out_dir = os.path.join(result_dir, loss_dir)
os.makedirs(model_out_dir, exist_ok=True)
os.makedirs(loss_out_dir, exist_ok=True)

model_path = os.path.join(model_out_dir, 'model_mfgunet.pt')
train_err_path = os.path.join(loss_out_dir, 'train_l2_err.txt')
val_err_path = os.path.join(loss_out_dir, 'val_l2_err.txt')

print('Training started...')

for epoch in range(n_epochs):
    # one pass over the training set; loss_all accumulates the weighted loss
    loss_all = 0
    loss_all = model_train(train_loader, loss_all, device, scale)
    if epoch % eval_every == 0:
        epoch_list.append(epoch)
        # relative L2 error on the fine graphs of the training and validation set
        train_l2_err.append(model_eval(train_loader, device, scale))
        val_l2_err.append(model_eval(val_loader, device, scale))
        print('epoch: ', epoch, 'train error: ', train_l2_err[-1], 'val error: ', val_l2_err[-1])
        print()

        # saving the model (the same file is overwritten at every checkpoint)
        torch.save({
            'epoch':epoch,
            'model_state_dict':model.state_dict(),
            'optimizer_state_dict':optimizer.state_dict(),
        }, model_path)

        # saving the loss results
        np.savetxt(train_err_path, train_l2_err)
        np.savetxt(val_err_path, val_l2_err)
        

In [None]:
### FUNCTION FOR GETTING THE K NEAREST NODES DURING THE DATA GENERATION PART!!
from sklearn.metrics.pairwise import euclidean_distances

def get_topK_nodes(coarse_data, fine_data, K):
    """Find, for every coarse node, the K nearest nodes of the finer graph.

    Parameters
    -----------
    coarse_data : coordinates of the coarse nodes, array of shape (n_coarse, d)
    fine_data : coordinates of the finer nodes, array of shape (n_fine, d)
    K : number of nearest nodes to keep per coarse node, int

    Ouput
    ------
    top12 : array of shape (n_coarse, K); row i holds the indices into
            fine_data of the K nodes closest to coarse node i, nearest first
    """
    # pairwise distances, one row per coarse node and one column per fine node
    dist12 = euclidean_distances(coarse_data, fine_data)
    # sort every row by increasing distance and keep the first K columns.
    # The rows must stay in their original order so that row i still belongs
    # to coarse node i (a [::-1] here would reverse the rows, not the sort order)
    top12 = np.argsort(dist12, axis=1)[:, :K]
    return top12