In [1]:
import inspect
import os
import random
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.loader import DataLoader

In [2]:
# Seed every RNG the notebook may draw from: Python's `random`, NumPy and
# PyTorch. `random.seed` alone does not affect torch's generator, which is what
# the PyG DataLoader uses when shuffling.
SEED = 30
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

#### Load test dataset

In [3]:
# Dataset class provided by the local MyDataset module.
from MyDataset import MyDataset

# Dataset directory, relative to this notebook's working directory
# (IEEE 118-bus case, UC subset).
root = '../../train_val_test_dataset/IEEE_Case118/UC'

# No new samples are passed in (`data_list=[]`); the intent is to load all
# data. Presumably MyDataset then reads the already-processed samples stored
# under `root` -- confirm in MyDataset.
dataset = MyDataset(data_list=[], root=root)

In [4]:
# Train/test split sizes. Only `test_size` is referenced (by the alternative
# slice described below); both are kept for any later cells that rely on them.
train_size = 7000
test_size = 2000

# Evaluate on the whole dataset. To evaluate only the last `test_size` items
# (presumably the held-out test split), use `dataset[-test_size:]` instead.
test_dataset = dataset[:]

# Evaluation loader: one sample per batch.
# shuffle=False keeps the saved file numbering aligned with the dataset order
# (file i <-> test_dataset[i-1]) and makes re-runs reproducible; ordering has
# no effect on aggregate evaluation metrics.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

#### Load trained model

In [5]:
# The class must be importable so the pickled model below can be reconstructed.
from GNNClassifier import GNNClassifier

MODEL_PATH = './trained_model/UC_model.pt'

# The checkpoint stores a whole pickled nn.Module (torch.load returns the
# model itself, not a state_dict):
#  - map_location='cpu' lets a GPU-saved model load on a CPU-only machine and
#    keeps outputs on the CPU for the later `.numpy()` conversions;
#  - weights_only=False is required on PyTorch >= 2.6, whose default
#    (weights_only=True) refuses to unpickle arbitrary classes. It is passed
#    only when the installed torch.load accepts the argument.
# SECURITY: unpickling can execute arbitrary code -- only load trusted files.
load_kwargs = {'map_location': 'cpu'}
if 'weights_only' in inspect.signature(torch.load).parameters:
    load_kwargs['weights_only'] = False
model = torch.load(MODEL_PATH, **load_kwargs)
model.eval()

GNNClassifier(
  (dropout): Dropout(p=0.3, inplace=False)
  (encoder): Sequential(
    (0): Linear(in_features=32, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=64, bias=True)
    (3): ReLU()
  )
  (gnn): ModuleList(
    (0): GCNConv(64, 64)
    (1): GCNConv(64, 64)
    (2): GCNConv(64, 64)
  )
  (decoder): Sequential(
    (0): Linear(in_features=64, out_features=24, bias=True)
    (1): ReLU()
    (2): Linear(in_features=24, out_features=12, bias=True)
    (3): Sigmoid()
  )
)

#### Model testing

In [6]:
# Output directories for the full-dataset run. For a run on the test split
# only, the alternatives are './model_evaluation/UC_true' and
# './model_evaluation/UC_pred'.
TRUE_DIR = './model_evaluation/UC_true_all'
PRED_DIR = './model_evaluation/UC_pred_all'
THRESHOLD = 0.5  # probability cut-off for predicting label 1

# DataFrame.to_csv does not create missing directories.
os.makedirs(TRUE_DIR, exist_ok=True)
os.makedirs(PRED_DIR, exist_ok=True)

with torch.no_grad():
    # Files are numbered from 1; file i holds the i-th batch of test_loader.
    for i, batch in enumerate(test_loader, start=1):
        # Ground-truth labels for this batch.
        ground_truth = batch.y.cpu().numpy()
        pd.DataFrame(ground_truth).to_csv(
            os.path.join(TRUE_DIR, f'true_{i}.csv'), index=False, header=False
        )

        # The decoder ends in a Sigmoid, so outputs lie in (0, 1); threshold
        # them into hard 0/1 predictions.
        probs = model(batch.x, batch.edge_index, batch.edge_attr)
        pred = (probs > THRESHOLD).float().cpu().numpy()
        pd.DataFrame(pred).to_csv(
            os.path.join(PRED_DIR, f'pred_{i}.csv'), index=False, header=False
        )