In [1]:
import time
import logging

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import seaborn as sns

import torch
from torch.functional import F
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, Linear, to_hetero
from torch_geometric.data import Data
import torch_geometric.transforms as T
from torch_geometric.utils.convert import to_networkx
import networkx as nx

import joblib

In [2]:
class GNN_model(torch.nn.Module):
    """Three stacked GCN layers followed by a linear layer over all nodes.

    The convolutions reduce the per-node features ``in_channels -> 4 -> 2 -> 1``.
    The single remaining channel is then flattened across nodes into one vector
    of length ``num_nodes`` and passed through a square linear layer, so the
    model emits one value per node.

    Args:
        in_channels: Number of input features per node (default 8).
        num_nodes: Number of nodes in every input graph (default 118). The
            final linear layer is sized with this value, so graphs with a
            different node count cannot be processed by the same instance.
    """

    def __init__(self, in_channels=8, num_nodes=118):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 4)
        self.conv2 = GCNConv(4, 2)
        self.conv3 = GCNConv(2, 1)
        # start_dim=0 flattens the node axis as well: (num_nodes, 1) -> (num_nodes,)
        self.flatten = torch.nn.Flatten(start_dim=0)
        self.lin = torch.nn.Linear(in_features=num_nodes, out_features=num_nodes)
        # NOTE: a clamp of the outputs to [-594, 594] (physical constraints) was
        # sketched here previously but is currently disabled.

    def forward(self, x, edge_index):
        """Return one prediction per node.

        Args:
            x: Node feature matrix with ``in_channels`` columns and
                ``num_nodes`` rows.
            edge_index: Graph connectivity passed to every GCN layer.

        Returns:
            1-D tensor of length ``num_nodes``.
        """
        x = F.leaky_relu(self.conv1(x, edge_index))
        x = F.leaky_relu(self.conv2(x, edge_index))
        x = F.leaky_relu(self.conv3(x, edge_index))
        x = self.flatten(x)
        x = self.lin(x)

        return x

#### Load trained model

In [3]:
# Load the trained model (stored as a whole pickled module, not a state_dict)
# and switch it to evaluation mode.
# map_location='cpu' places the weights on the CPU regardless of the device the
# checkpoint was saved from; the evaluation cell converts outputs with .numpy(),
# which requires CPU tensors.
# NOTE: torch.load unpickles arbitrary Python objects - only load trusted files.
# On PyTorch >= 2.6 the default is weights_only=True, so loading a whole module
# like this one additionally needs weights_only=False.
model = torch.load('node_prediction_model_trained.pt', map_location='cpu')
model.eval()

GraphModule(
  (conv1): Module(
    (node__branch__node): GCNConv(8, 4)
    (node__trafo__node): GCNConv(8, 4)
  )
  (conv2): Module(
    (node__branch__node): GCNConv(4, 2)
    (node__trafo__node): GCNConv(4, 2)
  )
  (conv3): Module(
    (node__branch__node): GCNConv(2, 1)
    (node__trafo__node): GCNConv(2, 1)
  )
  (flatten): Module(
    (node): Flatten(start_dim=0, end_dim=-1)
  )
  (lin): Module(
    (node): Linear(in_features=118, out_features=118, bias=True)
  )
)

#### Load test dataset

In [4]:
# Load the pickled test set. The loaded object wraps the samples: the cells
# below access the individual graphs through its `.dataset` attribute.
# map_location='cpu' keeps all stored tensors on the CPU, matching the model
# and the later .numpy() conversions.
# NOTE: torch.load unpickles arbitrary Python objects - only load trusted files.
test_dataset = torch.load(
    'train_test_dataset/node_prediction_test_dataset.pt',
    map_location='cpu',
)

In [5]:
# Display the first sample of the underlying dataset to inspect its structure
test_dataset.dataset[0]

HeteroData(
  node={
    x=[118, 8],
    y=[118]
  },
  (node, branch, node)={
    edge_index=[2, 173],
    edge_attr=[173, 3]
  },
  (node, trafo, node)={
    edge_index=[2, 13],
    edge_attr=[13, 5]
  }
)

#### Model evaluation

In [6]:
# Run the model on every test graph and collect the 'node' predictions.
pred_list = []

# Inference only: disabling autograd avoids building a computation graph for
# each forward pass, so the outputs can be converted to NumPy directly.
with torch.no_grad():
    for batch_data in test_dataset.dataset:
        pred = model(batch_data.x_dict, batch_data.edge_index_dict)
        # The model output is indexed by node type; keep the 'node' entry.
        pred = pred['node'].cpu().numpy()
        pred_list.append(pred)

In [7]:
## Assemble the predictions into one array (written to disk in a later cell)
# Stack the per-sample predictions and reverse the axes; for 1-D predictions
# this puts the prediction entries on axis 0 and the samples on axis 1.
node_prediction = np.asarray(pred_list).transpose()
# Keep the leading dimension (the divisor of 1 leaves it unchanged) and
# collapse all remaining axes into the second one.
dim = int(node_prediction.shape[0] / 1)
node_prediction = node_prediction.reshape(dim, -1)

In [8]:
# Collect the ground-truth targets of every test sample
true_rows = []
for data in test_dataset.dataset:
    targets = data['node'].y
    # Drop a trailing singleton axis if present, then convert to NumPy
    true_rows.append(targets.squeeze(dim=-1).detach().numpy())

# Reverse the axes so the layout matches node_prediction
node_true = np.asarray(true_rows).transpose()

In [9]:
# Write predictions and ground truth to CSV files in the working directory
pred_frame = pd.DataFrame(node_prediction)
pred_frame.to_csv('node_pred.csv')

true_frame = pd.DataFrame(node_true)
true_frame.to_csv('node_true.csv')