In [4]:
import time
import logging

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import seaborn as sns

import torch
from torch.functional import F
from torch_geometric.nn import GCNConv, to_hetero, GATConv
from torch_geometric.data import Data
from torch_geometric import transforms as T
from torch_geometric.utils import is_undirected
from torch_geometric.utils.convert import to_networkx
import networkx as nx


from prettytable import PrettyTable

In [5]:
# Load the serialized training split and inspect its first graph.
# NOTE(review): torch.load unpickles the file, so only point it at trusted files.
# Recent PyTorch releases default to weights_only=True, which may reject
# non-tensor objects such as this dataset — confirm against the installed version.
train_dataset = torch.load(
    'train_test_dataset/node_prediction_train_dataset.pt'
)
data = train_dataset.dataset[0]  # first graph of the training split
data  # rich-display the graph summary

HeteroData(
  node={
    x=[1355, 8],
    y=[1354],
  },
  (node, branch, node)={
    edge_index=[2, 1751],
    edge_attr=[1751, 3],
  },
  (node, trafo, node)={
    edge_index=[2, 240],
    edge_attr=[240, 5],
  }
)

In [6]:
## Custom loss function
class customLoss(torch.nn.Module):
    """MSE loss plus linear penalties for predictions outside given bounds.

    loss = mse(y_pred, y_true)
           + penalty_weight * sum(max(y_min_gen - y_pred, 0))
           + penalty_weight * sum(max(y_pred - y_max_gen, 0))

    Args:
        penalty_weight: Multiplier applied to both bound-violation terms.
            Defaults to 1.5, the value that was previously hard-coded.
    """

    def __init__(self, penalty_weight=1.5):
        super().__init__()
        self.penalty_weight = penalty_weight

    def forward(self, y_pred, y_max_gen, y_min_gen, y_true):
        """Compute the penalized loss.

        Args:
            y_pred: Model predictions (any shape).
            y_max_gen: Upper bound, broadcastable to ``y_pred``.
            y_min_gen: Lower bound, broadcastable to ``y_pred``.
            y_true: Targets; should have the same shape as ``y_pred``
                (as ``F.mse_loss`` expects).

        Returns:
            Scalar loss tensor.
        """
        loss_fit = F.mse_loss(y_pred, y_true)  # Equality constraint
        # Inequality constraints: how far each prediction falls below the lower
        # bound / exceeds the upper bound (zero when within bounds). clamp keeps
        # y_pred's dtype, device and shape, unlike comparing against a freshly
        # allocated CPU zeros vector of length y_pred.shape[0].
        below_min = torch.clamp(y_min_gen - y_pred, min=0).sum()
        above_max = torch.clamp(y_pred - y_max_gen, min=0).sum()

        loss = loss_fit + self.penalty_weight * below_min + self.penalty_weight * above_max
        return loss

In [7]:
## Clamp layer
class ClampLayer(torch.nn.Module):
    """Element-wise clamp of the input into ``[min_value, max_value]``.

    The default bounds are -594 and 594. Their meaning is not stated here;
    presumably a physical output limit — TODO confirm.
    """

    def __init__(self, min_value=-594, max_value=594):
        super().__init__()
        self.min_value, self.max_value = min_value, max_value

    def forward(self, x):
        lower, upper = self.min_value, self.max_value
        return x.clamp(lower, upper)

In [8]:
class GNN_model(torch.nn.Module):
    """Three GCN layers producing one value per node, then a dense layer over all nodes.

    The final Linear layer mixes the flattened per-node outputs, so the model
    only accepts graphs with exactly ``num_nodes`` nodes, in a fixed node order.

    Args:
        in_channels: Number of input features per node (default 8).
        num_nodes: Number of nodes in every graph passed to the model
            (default 118, the previously hard-coded size).
            NOTE(review): the graph loaded earlier in this notebook has
            1355 nodes, which does not match the default.
    """

    def __init__(self, in_channels=8, num_nodes=118):
        super(GNN_model, self).__init__()
        self.conv1 = GCNConv(in_channels, 4)
        self.conv2 = GCNConv(4, 2)
        self.conv3 = GCNConv(2, 1)
        # Flatten from dim 0: with a (num_nodes, 1) conv output this yields a
        # single (num_nodes,) vector, i.e. the whole graph is one sample.
        self.flatten = torch.nn.Flatten(start_dim=0)
        self.lin = torch.nn.Linear(in_features=num_nodes, out_features=num_nodes)
        # Physical output limits (ClampLayer) are currently not applied.

    def forward(self, x, edge_index):
        """Run the GCN stack and the dense head.

        Args:
            x: Node features; assumed (num_nodes, in_channels) — TODO confirm.
            edge_index: Graph connectivity passed unchanged to each GCNConv.

        Returns:
            Tensor of shape (num_nodes,) under the shape assumption above.
        """
        x = F.leaky_relu(self.conv1(x, edge_index))
        x = F.leaky_relu(self.conv2(x, edge_index))
        x = F.leaky_relu(self.conv3(x, edge_index))
        x = self.flatten(x)
        x = self.lin(x)
        return x

AssertionError: 