# IoT GNN Demo

## Env Setup
Before running this notebook, create a conda environment. The commands below create it and install the minimal set of required packages:
```bash
conda create -n iot_env python=3.11
conda activate iot_env
conda install pytorch
conda install pandas
conda install scikit-learn
conda install matplotlib
conda install pytorch::torchdata
conda install conda-forge::dgl
```  
Then, register the conda environment as a Jupyter kernel:  
```bash
conda install ipykernel
python -m ipykernel install --user --name iot_env --display-name "Python (iot_env)"
```  
Select the "Python (iot_env)" kernel to run this notebook.

## Imports

In [21]:
from configuration import *

from models import e_graphsage, fnn_model, e_graphsage_hembed
from data import IoTDataset
import train
from train import ModelTrainer
import tester

# NF-BoT-IoT

## Randomized IP Addresses and Ports

### Load Data
This assumes the datasets have been downloaded into a "data" folder at the root of the IoT_GNN repo.  
Datasets can be found at: https://drive.google.com/drive/folders/14t41P09gXTsCqPx3YFN1Pruwb2eZQrkT?usp=share_link

In [15]:
multiclass = True
# Common constructor arguments; only the split differs between the three datasets.
_ds_kwargs = {'version': 1, 'multiclass': multiclass}
# No `split` argument -> IoTDataset's default split (presumably train; confirm in data.py).
randomized_ip_train_data = IoTDataset(**_ds_kwargs)
randomized_ip_val_data = IoTDataset(split='val', **_ds_kwargs)
randomized_ip_test_data = IoTDataset(split='test', **_ds_kwargs)
del _ds_kwargs

### Model Initialization

In [22]:
# Multiclass
# Input width (edge attributes) and output width (classes) come from the training split.
_edge_dim = randomized_ip_train_data.num_features
_n_classes = len(randomized_ip_train_data.classes)

model_egs = e_graphsage.E_GraphSAGE(
    numLayers=2, dim_node_embed=128, num_edge_attr=_edge_dim, num_classes=_n_classes,
)
# Same architecture, trained later by egs_trainer_cpu (egs_training_config_cpu).
# NOTE(review): no seed is set in this notebook, so this instance likely starts
# from different weights than model_egs.
model_egs_cpu = e_graphsage.E_GraphSAGE(
    numLayers=2, dim_node_embed=128, num_edge_attr=_edge_dim, num_classes=_n_classes,
)
model_fnn = fnn_model.TestFNN(
    num_hidden_layers=2,
    hidden_layer_widths=[128, 192],  # Should be approximately comparable to EGS
    num_edge_attr=_edge_dim,
    num_classes=_n_classes,
)
model_egsh = e_graphsage_hembed.E_GraphSAGE_hEmbed(
    numLayers=2,
    dim_node_embed=96,  # Approximately equal parameter count as EGS
    num_edge_attr=_edge_dim,
    num_classes=_n_classes,
)
del _edge_dim, _n_classes

In [23]:
# Multiclass
def _make_cfg(num_epochs, gpu):
    """Return a fresh trainer config; only the epoch budget and device flag vary here."""
    return {
        'num_epochs': num_epochs,
        'lr': 1e-3,
        'gpu': gpu,
        # Appears to be used as a divisor on plateau (LR 1e-3 -> 3.16e-4 in the
        # training log); confirm in train.py.
        'lr_sched_factor': np.sqrt(10),
        'lr_sched_patience': 10,
    }

# 175 epochs sufficient for EGS to get plateau of validation risk (300 configured here).
egs_training_config = _make_cfg(300, gpu=True)
egs_training_config_cpu = _make_cfg(300, gpu=False)
# 205 epochs for the FNN
fnn_training_config = _make_cfg(205, gpu=False)
egsh_training_config = _make_cfg(205, gpu=False)
del _make_cfg

### Training

In [20]:
# Reload train.py after editing it.
# `ModelTrainer` was bound with `from train import ModelTrainer`, and
# importlib.reload() does not update names imported that way. Rebind the name
# explicitly so trainers created after this cell use the freshly loaded class.
import importlib
train = importlib.reload(train)
ModelTrainer = train.ModelTrainer
train

<module 'train' from '/home/rowleyra/IoT_GNN/train.py'>

In [None]:
# Instantiate Trainer
# Multiclass
# Every trainer in this section shares the same (train, val) dataset pair.
_splits = (randomized_ip_train_data, randomized_ip_val_data)
egs_trainer = ModelTrainer(egs_training_config, *_splits)
egs_trainer_cpu = ModelTrainer(egs_training_config_cpu, *_splits)
fnn_trainer = ModelTrainer(fnn_training_config, *_splits)
egsh_trainer = ModelTrainer(egsh_training_config, *_splits)
del _splits

In [25]:
# Train the GNN (model_egs, E-GraphSAGE) with egs_training_config ('gpu': True).
# NOTE(review): the positional `False` is a train_model flag not visible here; consider passing it by keyword.
_ = egs_trainer.train_model(model_egs, False)

Training E_GraphSAGE_K2_H128:   0%|          | 1/300 [00:00<00:47,  6.32epoch/s, train loss=1.6024, validation loss=1.5696, learning rate=1.00e-03, F1 score=0.64103]

[Epoch 0] logit mean: -0.0157, std: 0.0765


Training E_GraphSAGE_K2_H128:   9%|▊         | 26/300 [00:03<00:37,  7.30epoch/s, train loss=1.1397, validation loss=1.1323, learning rate=1.00e-03, F1 score=0.76149]

[Epoch 25] logit mean: 0.0378, std: 0.7916


Training E_GraphSAGE_K2_H128:  17%|█▋        | 51/300 [00:06<00:32,  7.58epoch/s, train loss=0.9370, validation loss=0.9300, learning rate=1.00e-03, F1 score=0.76478]

[Epoch 50] logit mean: 0.1563, std: 1.4562


Training E_GraphSAGE_K2_H128:  26%|██▌       | 77/300 [00:10<00:26,  8.38epoch/s, train loss=0.7774, validation loss=0.8370, learning rate=1.00e-03, F1 score=0.76728]

[Epoch 75] logit mean: 0.1205, std: 1.8309


Training E_GraphSAGE_K2_H128:  34%|███▍      | 102/300 [00:13<00:25,  7.65epoch/s, train loss=0.7454, validation loss=0.7879, learning rate=1.00e-03, F1 score=0.79155]

[Epoch 100] logit mean: 0.0828, std: 2.2065


Training E_GraphSAGE_K2_H128:  42%|████▏     | 127/300 [00:16<00:20,  8.25epoch/s, train loss=0.6770, validation loss=0.7706, learning rate=1.00e-03, F1 score=0.76241]

[Epoch 125] logit mean: 0.0717, std: 2.5248


Training E_GraphSAGE_K2_H128:  51%|█████     | 152/300 [00:20<00:19,  7.43epoch/s, train loss=0.7002, validation loss=0.7338, learning rate=1.00e-03, F1 score=0.75959]

[Epoch 150] logit mean: 0.0364, std: 2.7996


Training E_GraphSAGE_K2_H128:  59%|█████▉    | 177/300 [00:23<00:15,  7.73epoch/s, train loss=0.6678, validation loss=0.7105, learning rate=1.00e-03, F1 score=0.77774]

[Epoch 175] logit mean: 0.0018, std: 2.9795


Training E_GraphSAGE_K2_H128:  67%|██████▋   | 202/300 [00:27<00:13,  7.52epoch/s, train loss=0.6296, validation loss=0.7014, learning rate=1.00e-03, F1 score=0.74999]

[Epoch 200] logit mean: -0.0163, std: 3.0391


Training E_GraphSAGE_K2_H128:  75%|███████▌  | 226/300 [00:30<00:10,  7.09epoch/s, train loss=0.6289, validation loss=0.6942, learning rate=3.16e-04, F1 score=0.75101]

[Epoch 225] logit mean: -0.0535, std: 3.2864


Training E_GraphSAGE_K2_H128:  84%|████████▍ | 252/300 [00:33<00:05,  8.12epoch/s, train loss=0.6184, validation loss=0.6931, learning rate=3.16e-05, F1 score=0.78105]

[Epoch 250] logit mean: -0.0548, std: 3.2935


Training E_GraphSAGE_K2_H128:  92%|█████████▏| 276/300 [00:36<00:02,  8.13epoch/s, train loss=0.6064, validation loss=0.6925, learning rate=1.00e-06, F1 score=0.78047]

[Epoch 275] logit mean: -0.0546, std: 3.2915


Training E_GraphSAGE_K2_H128: 100%|██████████| 300/300 [00:39<00:00,  7.54epoch/s, train loss=0.6385, validation loss=0.6925, learning rate=1.00e-07, F1 score=0.78047]


In [None]:
# Train the GNN copy (model_egs_cpu) with egs_training_config_cpu ('gpu': False).
# NOTE(review): the positional `False` is a train_model flag not visible here; consider passing it by keyword.
_ = egs_trainer_cpu.train_model(model_egs_cpu, False)

In [None]:
# Train the FCNN baseline (model_fnn, TestFNN) with fnn_training_config.
# NOTE(review): the positional `False` is a train_model flag not visible here; consider passing it by keyword.
_ = fnn_trainer.train_model(model_fnn, False)

In [None]:
# Train the EGSH model (model_egsh, E_GraphSAGE_hEmbed) with egsh_training_config.
# NOTE(review): the positional `False` is a train_model flag not visible here; consider passing it by keyword.
_ = egsh_trainer.train_model(model_egsh, False)

### Testing

Tester cleanup status:  
2025/11/30 — Mostly done; only the binary-classification path still needed cleanup.  
17:00 — Done: the tester is now clean.

In [None]:
# Re-run after editing tester.py. Testers are accessed as `tester.ModelTester`,
# so reloading the module is enough for testers created after this cell.
import importlib
tester = importlib.reload(tester)
tester

In [None]:
# Tester bound to the randomized-IP NF-BoT-IoT test split.
# NOTE(review): the positional `False` is a ModelTester flag not visible here; consider passing it by keyword.
tester_inst = tester.ModelTester(randomized_ip_test_data, False)

In [None]:
# Test model_egs (E-GraphSAGE) on the randomized-IP NF-BoT-IoT test split.
tester_inst.test_model(model_egs)

In [None]:
# Test model_fnn (TestFNN baseline) on the randomized-IP NF-BoT-IoT test split.
tester_inst.test_model(model_fnn)

In [None]:
# Test model_egsh (E_GraphSAGE_hEmbed) on the randomized-IP NF-BoT-IoT test split.
tester_inst.test_model(model_egsh)

## Non-Randomized IP Addresses and Ports

### Load Data

In [None]:
multiclass = True
# Source-IP randomization disabled; only the split differs between the three datasets.
_ds_kwargs = {'version': 1, 'multiclass': multiclass, 'randomize_source_ip': False}
# No `split` argument -> IoTDataset's default split (presumably train; confirm in data.py).
orig_ip_train_data = IoTDataset(**_ds_kwargs)
orig_ip_val_data = IoTDataset(split='val', **_ds_kwargs)
orig_ip_test_data = IoTDataset(split='test', **_ds_kwargs)
del _ds_kwargs

### Model Initialization

In [None]:
# Multiclass
# Input width (edge attributes) and output width (classes) come from the training split.
_edge_dim = orig_ip_train_data.num_features
_n_classes = len(orig_ip_train_data.classes)

model_egs_orig = e_graphsage.E_GraphSAGE(
    numLayers=2, dim_node_embed=128, num_edge_attr=_edge_dim, num_classes=_n_classes,
)
model_fnn_orig = fnn_model.TestFNN(
    num_hidden_layers=2,
    hidden_layer_widths=[128, 192],  # Should be approximately comparable to EGS
    num_edge_attr=_edge_dim,
    num_classes=_n_classes,
)
model_egsh_orig = e_graphsage_hembed.E_GraphSAGE_hEmbed(
    numLayers=2,
    dim_node_embed=96,  # Approximately equal parameter count as EGS
    num_edge_attr=_edge_dim,
    num_classes=_n_classes,
)
del _edge_dim, _n_classes

### Training

In [None]:
# Multiclass
def _make_cfg(num_epochs):
    """Return a fresh CPU trainer config; only the epoch budget varies in this section."""
    return {
        'num_epochs': num_epochs,
        'lr': 1e-3,
        'gpu': False,
        'lr_sched_factor': np.sqrt(10),
        'lr_sched_patience': 10,
    }

# 175 epochs sufficient for EGS to get plateau of validation risk.
egs_orig_training_config = _make_cfg(175)
# 205 epochs for the FNN
fnn_orig_training_config = _make_cfg(205)
egsh_orig_training_config = _make_cfg(205)
del _make_cfg

In [None]:
# Instantiate Trainer
# Multiclass
# Every trainer in this section shares the same (train, val) dataset pair.
_splits = (orig_ip_train_data, orig_ip_val_data)
egs_orig_trainer = ModelTrainer(egs_orig_training_config, *_splits)
fnn_orig_trainer = ModelTrainer(fnn_orig_training_config, *_splits)
egsh_orig_trainer = ModelTrainer(egsh_orig_training_config, *_splits)
del _splits

In [None]:
# Train the GNN (model_egs_orig) on the non-randomized-IP NF-BoT-IoT data.
# NOTE(review): the positional `False` is a train_model flag not visible here; consider passing it by keyword.
_ = egs_orig_trainer.train_model(model_egs_orig, False)

In [None]:
# Train the FCNN baseline (model_fnn_orig) on the non-randomized-IP NF-BoT-IoT data.
# NOTE(review): the positional `False` is a train_model flag not visible here; consider passing it by keyword.
_ = fnn_orig_trainer.train_model(model_fnn_orig, False)

In [None]:
# Train the EGSH model (model_egsh_orig) on the non-randomized-IP NF-BoT-IoT data.
# NOTE(review): the positional `False` is a train_model flag not visible here; consider passing it by keyword.
_ = egsh_orig_trainer.train_model(model_egsh_orig, False)

### Testing

In [None]:
# Tester bound to the non-randomized-IP NF-BoT-IoT test split.
# NOTE(review): the positional `False` is a ModelTester flag not visible here; consider passing it by keyword.
tester_inst_orig = tester.ModelTester(orig_ip_test_data, False)

In [None]:
# Test model_egs_orig on the non-randomized-IP NF-BoT-IoT test split.
tester_inst_orig.test_model(model_egs_orig)

In [None]:
# Test model_fnn_orig on the non-randomized-IP NF-BoT-IoT test split.
tester_inst_orig.test_model(model_fnn_orig)

In [None]:
# Test model_egsh_orig on the non-randomized-IP NF-BoT-IoT test split.
tester_inst_orig.test_model(model_egsh_orig)

# NF-ToN-IoT

## Randomized IP Addresses and Ports

### Load Data
This assumes the datasets have been downloaded into a "data" folder at the root of the IoT_GNN repo.  
Datasets can be found at: https://drive.google.com/drive/folders/14t41P09gXTsCqPx3YFN1Pruwb2eZQrkT?usp=share_link

In [None]:
multiclass = True
# Common constructor arguments; only the split differs between the three datasets.
_ds_kwargs = {'dataset': 'NF-ToN-IoT', 'version': 1, 'multiclass': multiclass}
# No `split` argument -> IoTDataset's default split (presumably train; confirm in data.py).
ton_randomized_ip_train_data = IoTDataset(**_ds_kwargs)
ton_randomized_ip_val_data = IoTDataset(split='val', **_ds_kwargs)
ton_randomized_ip_test_data = IoTDataset(split='test', **_ds_kwargs)
del _ds_kwargs

### Model Initialization

In [None]:
# Multiclass
# Input width (edge attributes) and output width (classes) come from the training split.
_edge_dim = ton_randomized_ip_train_data.num_features
_n_classes = len(ton_randomized_ip_train_data.classes)

model_egs_ton = e_graphsage.E_GraphSAGE(
    numLayers=2, dim_node_embed=128, num_edge_attr=_edge_dim, num_classes=_n_classes,
)
model_fnn_ton = fnn_model.TestFNN(
    num_hidden_layers=2,
    hidden_layer_widths=[128, 192],  # Should be approximately comparable to EGS
    num_edge_attr=_edge_dim,
    num_classes=_n_classes,
)
model_egsh_ton = e_graphsage_hembed.E_GraphSAGE_hEmbed(
    numLayers=2,
    dim_node_embed=96,  # Approximately equal parameter count as EGS
    num_edge_attr=_edge_dim,
    num_classes=_n_classes,
)
del _edge_dim, _n_classes

### Training

In [None]:
# Multiclass
# All three models share the same optimizer / LR-scheduler settings and a
# 1000-epoch CPU budget on NF-ToN-IoT. The 175-epoch (EGS) / 205-epoch (FNN)
# figures were noted for NF-BoT-IoT and are not the budgets used here.
def _make_cfg(num_epochs=1000):
    """Return a fresh CPU trainer config for the NF-ToN-IoT randomized-IP runs."""
    return {
        'num_epochs': num_epochs,
        'lr': 1e-3,
        'gpu': False,
        'lr_sched_factor': np.sqrt(10),
        'lr_sched_patience': 10,
    }

egs_training_config_ton = _make_cfg()
fnn_training_config_ton = _make_cfg()
egsh_training_config_ton = _make_cfg()
del _make_cfg

In [None]:
# Instantiate Trainer
# Multiclass
# Every trainer in this section shares the same (train, val) dataset pair.
_splits = (ton_randomized_ip_train_data, ton_randomized_ip_val_data)
egs_trainer_ton = ModelTrainer(egs_training_config_ton, *_splits)
fnn_trainer_ton = ModelTrainer(fnn_training_config_ton, *_splits)
egsh_trainer_ton = ModelTrainer(egsh_training_config_ton, *_splits)
del _splits

In [None]:
# Train the GNN (model_egs_ton) on the randomized-IP NF-ToN-IoT data.
# NOTE(review): the positional `False` is a train_model flag not visible here; consider passing it by keyword.
_ = egs_trainer_ton.train_model(model_egs_ton, False)

In [None]:
# Train the FCNN baseline (model_fnn_ton) on the randomized-IP NF-ToN-IoT data.
# NOTE(review): the positional `False` is a train_model flag not visible here; consider passing it by keyword.
_ = fnn_trainer_ton.train_model(model_fnn_ton, False)

In [None]:
# Train the EGSH model (model_egsh_ton) on the randomized-IP NF-ToN-IoT data.
# NOTE(review): the positional `False` is a train_model flag not visible here; consider passing it by keyword.
_ = egsh_trainer_ton.train_model(model_egsh_ton, False)

### Testing

In [None]:
# Re-run after editing tester.py. Testers are accessed as `tester.ModelTester`,
# so reloading the module is enough for testers created after this cell.
import importlib
tester = importlib.reload(tester)
tester

In [None]:
# Tester bound to the randomized-IP NF-ToN-IoT test split.
# NOTE(review): the positional `False` is a ModelTester flag not visible here; consider passing it by keyword.
tester_inst_ton = tester.ModelTester(ton_randomized_ip_test_data, False)

In [None]:
# Test model_egs_ton on the randomized-IP NF-ToN-IoT test split.
tester_inst_ton.test_model(model_egs_ton)

In [None]:
# Test model_fnn_ton on the randomized-IP NF-ToN-IoT test split.
tester_inst_ton.test_model(model_fnn_ton)

In [None]:
# Test model_egsh_ton on the randomized-IP NF-ToN-IoT test split.
tester_inst_ton.test_model(model_egsh_ton)

## Non-Randomized IP Addresses and Ports

### Load Data

In [None]:
multiclass = True
# Source-IP randomization disabled; only the split differs between the three datasets.
_ds_kwargs = {
    'dataset': 'NF-ToN-IoT',
    'version': 1,
    'multiclass': multiclass,
    'randomize_source_ip': False,
}
# No `split` argument -> IoTDataset's default split (presumably train; confirm in data.py).
ton_orig_ip_train_data = IoTDataset(**_ds_kwargs)
ton_orig_ip_val_data = IoTDataset(split='val', **_ds_kwargs)
ton_orig_ip_test_data = IoTDataset(split='test', **_ds_kwargs)
del _ds_kwargs

### Model Initialization

In [None]:
# Multiclass
# Input width (edge attributes) and output width (classes) come from the training split.
_edge_dim = ton_orig_ip_train_data.num_features
_n_classes = len(ton_orig_ip_train_data.classes)

model_egs_orig_ton = e_graphsage.E_GraphSAGE(
    numLayers=2, dim_node_embed=128, num_edge_attr=_edge_dim, num_classes=_n_classes,
)
model_fnn_orig_ton = fnn_model.TestFNN(
    num_hidden_layers=2,
    hidden_layer_widths=[128, 192],  # Should be approximately comparable to EGS
    num_edge_attr=_edge_dim,
    num_classes=_n_classes,
)
model_egsh_orig_ton = e_graphsage_hembed.E_GraphSAGE_hEmbed(
    numLayers=2,
    dim_node_embed=96,  # Approximately equal parameter count as EGS
    num_edge_attr=_edge_dim,
    num_classes=_n_classes,
)
del _edge_dim, _n_classes

### Training

In [None]:
# Multiclass
# All three models share the same optimizer / LR-scheduler settings and a
# 1000-epoch CPU budget on non-randomized-IP NF-ToN-IoT. The 175-epoch (EGS) /
# 205-epoch (FNN) figures were noted for NF-BoT-IoT and are not the budgets used here.
def _make_cfg(num_epochs=1000):
    """Return a fresh CPU trainer config for the NF-ToN-IoT original-IP runs."""
    return {
        'num_epochs': num_epochs,
        'lr': 1e-3,
        'gpu': False,
        'lr_sched_factor': np.sqrt(10),
        'lr_sched_patience': 10,
    }

egs_orig_training_config_ton = _make_cfg()
fnn_orig_training_config_ton = _make_cfg()
egsh_orig_training_config_ton = _make_cfg()
del _make_cfg

In [None]:
# Instantiate Trainer
# Multiclass
# Every trainer in this section shares the same (train, val) dataset pair.
_splits = (ton_orig_ip_train_data, ton_orig_ip_val_data)
egs_orig_trainer_ton = ModelTrainer(egs_orig_training_config_ton, *_splits)
fnn_orig_trainer_ton = ModelTrainer(fnn_orig_training_config_ton, *_splits)
egsh_orig_trainer_ton = ModelTrainer(egsh_orig_training_config_ton, *_splits)
del _splits

In [None]:
# Train the GNN (model_egs_orig_ton) on the non-randomized-IP NF-ToN-IoT data.
# NOTE(review): the positional `False` is a train_model flag not visible here; consider passing it by keyword.
_ = egs_orig_trainer_ton.train_model(model_egs_orig_ton, False)

In [None]:
# Train the FCNN baseline (model_fnn_orig_ton) on the non-randomized-IP NF-ToN-IoT data.
# NOTE(review): the positional `False` is a train_model flag not visible here; consider passing it by keyword.
_ = fnn_orig_trainer_ton.train_model(model_fnn_orig_ton, False)

In [None]:
# Train the EGSH model (model_egsh_orig_ton) on the non-randomized-IP NF-ToN-IoT data.
# NOTE(review): the positional `False` is a train_model flag not visible here; consider passing it by keyword.
_ = egsh_orig_trainer_ton.train_model(model_egsh_orig_ton, False)

### Testing

In [None]:
# Tester bound to the non-randomized-IP NF-ToN-IoT test split.
# NOTE(review): the positional `False` is a ModelTester flag not visible here; consider passing it by keyword.
tester_inst_orig_ton = tester.ModelTester(ton_orig_ip_test_data, False)

In [None]:
# Test model_egs_orig_ton on the non-randomized-IP NF-ToN-IoT test split.
tester_inst_orig_ton.test_model(model_egs_orig_ton)

In [None]:
# Test model_fnn_orig_ton on the non-randomized-IP NF-ToN-IoT test split.
tester_inst_orig_ton.test_model(model_fnn_orig_ton)

In [None]:
# Test model_egsh_orig_ton on the non-randomized-IP NF-ToN-IoT test split.
tester_inst_orig_ton.test_model(model_egsh_orig_ton)

# NF-UNSW-NB15

## Randomized IP Addresses and Ports

### Load Data
This assumes the datasets have been downloaded into a "data" folder at the root of the IoT_GNN repo.  
Datasets can be found at: https://drive.google.com/drive/folders/14t41P09gXTsCqPx3YFN1Pruwb2eZQrkT?usp=share_link

In [None]:
multiclass = True
# Common constructor arguments; only the split differs between the three datasets.
_ds_kwargs = {'dataset': 'NF-UNSW-NB15', 'version': 1, 'multiclass': multiclass}
# No `split` argument -> IoTDataset's default split (presumably train; confirm in data.py).
nsw_randomized_ip_train_data = IoTDataset(**_ds_kwargs)
nsw_randomized_ip_val_data = IoTDataset(split='val', **_ds_kwargs)
nsw_randomized_ip_test_data = IoTDataset(split='test', **_ds_kwargs)
del _ds_kwargs

### Model Initialization

In [None]:
# Multiclass
# Input width (edge attributes) and output width (classes) come from the training split.
_edge_dim = nsw_randomized_ip_train_data.num_features
_n_classes = len(nsw_randomized_ip_train_data.classes)

model_egs_nsw = e_graphsage.E_GraphSAGE(
    numLayers=2, dim_node_embed=128, num_edge_attr=_edge_dim, num_classes=_n_classes,
)
model_fnn_nsw = fnn_model.TestFNN(
    num_hidden_layers=2,
    hidden_layer_widths=[128, 192],  # Should be approximately comparable to EGS
    num_edge_attr=_edge_dim,
    num_classes=_n_classes,
)
model_egsh_nsw = e_graphsage_hembed.E_GraphSAGE_hEmbed(
    numLayers=2,
    dim_node_embed=96,  # Approximately equal parameter count as EGS
    num_edge_attr=_edge_dim,
    num_classes=_n_classes,
)
del _edge_dim, _n_classes

### Training

In [None]:
# Multiclass
# All three models share the same optimizer / LR-scheduler settings and a
# 1000-epoch CPU budget on NF-UNSW-NB15. The 175-epoch (EGS) / 205-epoch (FNN)
# figures were noted for NF-BoT-IoT and are not the budgets used here.
def _make_cfg(num_epochs=1000):
    """Return a fresh CPU trainer config for the NF-UNSW-NB15 randomized-IP runs."""
    return {
        'num_epochs': num_epochs,
        'lr': 1e-3,
        'gpu': False,
        'lr_sched_factor': np.sqrt(10),
        'lr_sched_patience': 10,
    }

egs_training_config_nsw = _make_cfg()
fnn_training_config_nsw = _make_cfg()
egsh_training_config_nsw = _make_cfg()
del _make_cfg

In [None]:
# Instantiate Trainer
# Multiclass
# Every trainer in this section shares the same (train, val) dataset pair.
_splits = (nsw_randomized_ip_train_data, nsw_randomized_ip_val_data)
egs_trainer_nsw = ModelTrainer(egs_training_config_nsw, *_splits)
fnn_trainer_nsw = ModelTrainer(fnn_training_config_nsw, *_splits)
egsh_trainer_nsw = ModelTrainer(egsh_training_config_nsw, *_splits)
del _splits

In [None]:
# Train the GNN (model_egs_nsw) on the randomized-IP NF-UNSW-NB15 data.
# NOTE(review): the positional `False` is a train_model flag not visible here; consider passing it by keyword.
_ = egs_trainer_nsw.train_model(model_egs_nsw, False)

In [None]:
# Train the FCNN baseline (model_fnn_nsw) on the randomized-IP NF-UNSW-NB15 data.
# NOTE(review): the positional `False` is a train_model flag not visible here; consider passing it by keyword.
_ = fnn_trainer_nsw.train_model(model_fnn_nsw, False)

In [None]:
# Train the EGSH model (model_egsh_nsw) on the randomized-IP NF-UNSW-NB15 data.
# NOTE(review): the positional `False` is a train_model flag not visible here; consider passing it by keyword.
_ = egsh_trainer_nsw.train_model(model_egsh_nsw, False)

### Testing

In [None]:
# Re-run after editing tester.py. Testers are accessed as `tester.ModelTester`,
# so reloading the module is enough for testers created after this cell.
import importlib
tester = importlib.reload(tester)
tester

In [None]:
# Tester bound to the randomized-IP NF-UNSW-NB15 test split.
# NOTE(review): the positional `False` is a ModelTester flag not visible here; consider passing it by keyword.
tester_inst_nsw = tester.ModelTester(nsw_randomized_ip_test_data, False)

In [None]:
# Test model_egs_nsw on the randomized-IP NF-UNSW-NB15 test split.
tester_inst_nsw.test_model(model_egs_nsw)

In [None]:
# Test model_fnn_nsw on the randomized-IP NF-UNSW-NB15 test split.
tester_inst_nsw.test_model(model_fnn_nsw)

In [None]:
# Test model_egsh_nsw on the randomized-IP NF-UNSW-NB15 test split.
tester_inst_nsw.test_model(model_egsh_nsw)

## Non-Randomized IP Addresses and Ports

### Load Data

In [None]:
multiclass = True
# Source-IP randomization disabled; only the split differs between the three datasets.
_ds_kwargs = {
    'dataset': 'NF-UNSW-NB15',
    'version': 1,
    'multiclass': multiclass,
    'randomize_source_ip': False,
}
# No `split` argument -> IoTDataset's default split (presumably train; confirm in data.py).
nsw_orig_ip_train_data = IoTDataset(**_ds_kwargs)
nsw_orig_ip_val_data = IoTDataset(split='val', **_ds_kwargs)
nsw_orig_ip_test_data = IoTDataset(split='test', **_ds_kwargs)
del _ds_kwargs

### Model Initialization

In [None]:
# Multiclass
# Input width (edge attributes) and output width (classes) come from the training split.
_edge_dim = nsw_orig_ip_train_data.num_features
_n_classes = len(nsw_orig_ip_train_data.classes)

model_egs_orig_nsw = e_graphsage.E_GraphSAGE(
    numLayers=2, dim_node_embed=128, num_edge_attr=_edge_dim, num_classes=_n_classes,
)
model_fnn_orig_nsw = fnn_model.TestFNN(
    num_hidden_layers=2,
    hidden_layer_widths=[128, 192],  # Should be approximately comparable to EGS
    num_edge_attr=_edge_dim,
    num_classes=_n_classes,
)
model_egsh_orig_nsw = e_graphsage_hembed.E_GraphSAGE_hEmbed(
    numLayers=2,
    dim_node_embed=96,  # Approximately equal parameter count as EGS
    num_edge_attr=_edge_dim,
    num_classes=_n_classes,
)
del _edge_dim, _n_classes

### Training

In [None]:
# Multiclass
# All three models share the same optimizer / LR-scheduler settings and a
# 1000-epoch CPU budget on non-randomized-IP NF-UNSW-NB15. The 175-epoch (EGS) /
# 205-epoch (FNN) figures were noted for NF-BoT-IoT and are not the budgets used here.
def _make_cfg(num_epochs=1000):
    """Return a fresh CPU trainer config for the NF-UNSW-NB15 original-IP runs."""
    return {
        'num_epochs': num_epochs,
        'lr': 1e-3,
        'gpu': False,
        'lr_sched_factor': np.sqrt(10),
        'lr_sched_patience': 10,
    }

egs_orig_training_config_nsw = _make_cfg()
fnn_orig_training_config_nsw = _make_cfg()
egsh_orig_training_config_nsw = _make_cfg()
del _make_cfg

In [None]:
# Instantiate Trainer
# Multiclass
# Every trainer in this section shares the same (train, val) dataset pair.
_splits = (nsw_orig_ip_train_data, nsw_orig_ip_val_data)
egs_orig_trainer_nsw = ModelTrainer(egs_orig_training_config_nsw, *_splits)
fnn_orig_trainer_nsw = ModelTrainer(fnn_orig_training_config_nsw, *_splits)
egsh_orig_trainer_nsw = ModelTrainer(egsh_orig_training_config_nsw, *_splits)
del _splits

In [None]:
# Train the GNN (model_egs_orig_nsw) on the non-randomized-IP NF-UNSW-NB15 data.
# NOTE(review): the positional `False` is a train_model flag not visible here; consider passing it by keyword.
_ = egs_orig_trainer_nsw.train_model(model_egs_orig_nsw, False)

In [None]:
# Train the FCNN baseline (model_fnn_orig_nsw) on the non-randomized-IP NF-UNSW-NB15 data.
# NOTE(review): the positional `False` is a train_model flag not visible here; consider passing it by keyword.
_ = fnn_orig_trainer_nsw.train_model(model_fnn_orig_nsw, False)

In [None]:
# Train the EGSH model (model_egsh_orig_nsw) on the non-randomized-IP NF-UNSW-NB15 data.
# NOTE(review): the positional `False` is a train_model flag not visible here; consider passing it by keyword.
_ = egsh_orig_trainer_nsw.train_model(model_egsh_orig_nsw, False)

### Testing

In [None]:
# Tester bound to the non-randomized-IP NF-UNSW-NB15 test split.
# NOTE(review): the positional `False` is a ModelTester flag not visible here; consider passing it by keyword.
tester_inst_orig_nsw = tester.ModelTester(nsw_orig_ip_test_data, False)

In [None]:
# Test model_egs_orig_nsw on the non-randomized-IP NF-UNSW-NB15 test split.
tester_inst_orig_nsw.test_model(model_egs_orig_nsw)

In [None]:
# Test model_fnn_orig_nsw on the non-randomized-IP NF-UNSW-NB15 test split.
tester_inst_orig_nsw.test_model(model_fnn_orig_nsw)

In [None]:
# Test model_egsh_orig_nsw on the non-randomized-IP NF-UNSW-NB15 test split.
tester_inst_orig_nsw.test_model(model_egsh_orig_nsw)