<!-- ABSTRACT -->

The goal of this script is to check how well the model performs on the test set. To do so, we evaluate the model on the overall test set and then visualize its predictions for some specific cases.

In [None]:
import os
import sys
import copy
import json
import joblib

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import geopandas as gpd

import torch
from torch_geometric.profile import count_parameters
from torch.utils.data import Subset, DataLoader

# Make the parent of the working directory importable so that the project
# packages imported right below (evaluation, gnn, training, ...) can be found.
scripts_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
already_on_path = scripts_path in sys.path
if not already_on_path:
    sys.path.append(scripts_path)

import evaluation.help_functions as hf
import evaluation.plot_functions as pf

import gnn.gnn_io as gio
from gnn.help_functions import compute_spearman_pearson
from gnn.models.trans_conv import TransConv
from training.help_functions import seed_worker, normalize_x_features_with_scaler, normalize_dataset
from data_preprocessing.help_functions import highway_mapping

In [None]:
# Project root: two directory levels above the current working directory.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))

# Everything read by this notebook lives below <project_root>/data.
data_root = os.path.join(project_root, "data")

# Input locations -- adjust as needed for the run that should be evaluated.
run_path = os.path.join(data_root, "TR-C_Benchmarks", "tc_54x_part_4")
base_case_path = os.path.join(data_root, "links_and_stats", "pop_1pct_basecase_average_output_links.geojson")
districts = gpd.read_file(os.path.join(data_root, "visualisation", "districts_paris.geojson"))

# Folder handed to the plotting helpers as `result_path` (relative to the working directory).
result_path = 'results/'

# GNN Parameters (Others are default values)
# These values are passed to the TransConv constructor further down, and the
# resulting model then receives the saved state dict of the run under `run_path`.
# NOTE(review): they presumably have to match the configuration the run was
# trained with, otherwise loading the state dict should fail -- confirm against
# the training configuration before changing any of them.
in_channels = 5  # NOTE(review): presumably one input channel per entry of `node_features` -- confirm
out_channels = 1
use_dropout = False
use_graph_norm = True
use_residuals = True
num_heads = 4
# Kept as an explicit literal on purpose: every entry is handed to the model as-is.
hidden_channels = [32,64,128,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,128,64,32]

# Artefacts (scalers, test split, loader parameters) written by the training run.
data_created_during_training = os.path.join(run_path, 'data_created_during_training')

# Base-case road links; used below as `original_gdf` when building per-sample GeoDataFrames.
# NOTE(review): `crs` is not a documented parameter of gpd.read_file; it is
# forwarded to the I/O engine and may not set the CRS of the result -- confirm.
links_base_case = gpd.read_file(base_case_path, crs="EPSG:4326")

In [None]:
###########################################
### Load test data from the run itself! ###
###########################################

# Scalers stored by the training run; scaler_x is used further down to invert
# the normalisation of the node features.
scaler_x, scaler_pos = (
    joblib.load(os.path.join(data_created_during_training, scaler_file))
    for scaler_file in ('test_x_scaler.pkl', 'test_pos_scaler.pkl')
)

# Test dataset that was written to disk during training.
test_set_dl = torch.load(os.path.join(data_created_during_training, 'test_dl.pt'))

# Keyword arguments the DataLoader was created with, stored as JSON.
loader_params_file = os.path.join(data_created_during_training, 'test_loader_params.json')
with open(loader_params_file, 'r') as f:
    test_set_dl_loader_params = json.load(f)

# A collate_fn that is only available as a string cannot be called by the
# DataLoader, so it is dropped and the default collate function is used instead.
if isinstance(test_set_dl_loader_params.get('collate_fn'), str):
    test_set_dl_loader_params.pop('collate_fn')

test_set_loader = torch.utils.data.DataLoader(test_set_dl, **test_set_dl_loader_params)

In [None]:
################################
### Load separate test data! ###
################################

dataset_path = os.path.join(project_root, "data", "test_data", "pst_roads_in_inner_districts")

# The data is stored in consecutively numbered files (datalist_batch_1.pt,
# datalist_batch_2.pt, ...); read them until the first missing number.
datalist = []
batch_num = 1
while True:
    batch_file = os.path.join(dataset_path, f'datalist_batch_{batch_num}.pt')
    if not os.path.exists(batch_file):
        break
    # Announce a batch only once we know its file exists, so the log does not
    # end with a batch number that was never processed.
    print(f"Processing batch number: {batch_num}")
    batch_data = torch.load(batch_file, map_location='cpu')
    if isinstance(batch_data, list):
        datalist.extend(batch_data)
    else:
        # Still skipped (as before), but no longer silently.
        print(f"Skipping batch number {batch_num}: expected a list, got {type(batch_data).__name__}")
    batch_num += 1
print(f"Loaded {len(datalist)} items into datalist")

# Fail early with a clear message instead of an unrelated error further down.
if not datalist:
    raise RuntimeError(f"No test samples could be loaded from {dataset_path}")

dataset_length = len(datalist)
test_indices = range(dataset_length)
test_subset = Subset(datalist, test_indices)

node_features = ["VOL_BASE_CASE",
                 "CAPACITY_BASE_CASE",
                 "CAPACITY_REDUCTION",
                 "FREESPEED",
                 "LENGTH"]

### Use a new scaler! ###
# test_set_normalized, scalers_test = normalize_dataset(dataset_input=test_subset, node_features=node_features)
# test_set_loader = DataLoader(dataset=test_set_normalized, batch_size=8,
#                              shuffle=True, num_workers=4, collate_fn=gio.collate_fn, worker_init_fn=seed_worker)
# scaler_x = scalers_test['x_scaler']
# scaler_pos = scalers_test['pos_scaler']
##########################

##### Load scalers from the run for consistency! #####
scaler_x = joblib.load(os.path.join(data_created_during_training, 'test_x_scaler.pkl'))
# Work on deep copies so the samples held in `datalist` are not shared with the normalisation step.
data_list = [copy.deepcopy(test_subset.dataset[idx]) for idx in test_subset.indices]
test_set_normalized = normalize_x_features_with_scaler(data_list, node_features, scaler_x)
# NOTE(review): shuffle=True only affects iteration over the loader; the cells
# below index `test_set_loader.dataset` directly, so sample indices stay stable.
test_set_loader = DataLoader(dataset=test_set_normalized, batch_size=8,
                             shuffle=True, num_workers=4, collate_fn=gio.collate_fn, worker_init_fn=seed_worker)
######################################################

In [None]:
# Build the network from the GNN parameters defined at the top of the notebook.
model = TransConv(use_dropout=use_dropout, use_graph_norm=use_graph_norm, use_residuals=use_residuals,
                  in_channels=in_channels, out_channels=out_channels, num_heads=num_heads,
                  hidden_channels=hidden_channels)

print(f"Trainable model parameters: {round(count_parameters(model) / 1e6, 2)} M")

# Choose the device first so the checkpoint can be mapped onto it while loading.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the model state dictionary. map_location makes sure a checkpoint that was
# saved on a GPU can still be restored when only the CPU is available.
model_path = os.path.join(run_path, 'trained_model', 'model.pth')
model.load_state_dict(torch.load(model_path, map_location=device))

model = model.to(device)
# This notebook only evaluates the model, so switch it to inference mode.
# NOTE(review): the validation helper may already do this; calling eval() twice is harmless.
model.eval()

loss_fct = torch.nn.MSELoss().to(dtype=torch.float32).to(device)

In [None]:
# Basic test: evaluate the model on the complete test set.
# The dataset behind the loader is passed, not the loader itself.
overall_results = hf.validate_model_on_test_set(model, test_set_loader.dataset, loss_fct, device)
test_loss, r_squared, actual_vals, predictions, baseline_loss = overall_results

# Correlation between predicted and actual values.
spearman, pearson = compute_spearman_pearson(predictions, actual_vals)

print(
    f"Test Loss: {test_loss}",
    f"R-squared: {r_squared}",
    f"Spearman Correlation: {spearman}",
    f"Pearson Correlation: {pearson}",
    sep="\n",
)

### Next, we will look at single elements of the test set and visualize the performance of the model.


In [None]:
i = 2  # position of the test-set sample to inspect

fixed_norm_max = 50  # passed to the plots as `fixed_norm_max` so both use the same value

my_test_data = test_set_loader.dataset[i]
my_test_x = my_test_data.x.to('cpu')

# Run the model on this single sample.
(test_loss_my_test_data,
 r_squared_my_test_data,
 actual_vals_my_test_data,
 predictions_my_test_data,
 baseline_loss_my_test_data) = hf.validate_model_on_test_set(model, my_test_data, loss_fct, device)

print(f"Sample {i}")
print(f"Test Loss: {test_loss_my_test_data}")
print(f"R-squared: {r_squared_my_test_data}")
print(f"Baseline Loss: {baseline_loss_my_test_data}")

# Undo the feature normalisation before the values are attached to the GeoDataFrame.
inversed_x = scaler_x.inverse_transform(my_test_x)

gdf_with_og_values = hf.data_to_geodataframe_with_og_values(
    data=my_test_data,
    original_gdf=links_base_case,
    predicted_values=predictions_my_test_data,
    inversed_x=inversed_x,
)
gdf_with_og_values['capacity_reduction_rounded'] = gdf_with_og_values['capacity_reduction'].round(decimals=3)
gdf_with_og_values['highway'] = gdf_with_og_values['highway'].map(highway_mapping)

# gdf_with_og_values['district'] = gdf_with_og_values.apply(lambda row: districts[districts.contains(row.geometry)].iloc[0]['c_ar'] if not districts[districts.contains(row.geometry)].empty else 'Unknown', axis=1)
# gdf_with_og_values = gpd.sjoin(gdf_with_og_values, districts, how='left', op='intersects')

# Both maps use the same settings; only the plotted column and the
# `is_predicted` flag differ between the predicted and the actual plot.
shared_plot_kwargs = dict(
    gdf_input=gdf_with_og_values,
    save_it=False, number_to_plot=i, result_path=result_path,
    use_fixed_norm=True, fixed_norm_max=fixed_norm_max,
    known_districts=False, districts_of_interest=None,
    plot_contour_lines=True, plot_policy_roads=False, with_legend=True,
)
for heading, plotted_column, predicted_flag in (
    ("\nPredicted:", "vol_car_change_predicted", True),
    ("Actual:", "vol_car_change_actual", False),
):
    print(heading)
    pf.plot_combined_output(column_to_plot=plotted_column, is_predicted=predicted_flag, **shared_plot_kwargs)

### Plot results across the entire test set.

In [None]:
# Build one GeoDataFrame per test sample; the list is consumed by the
# average-difference plots below.
gdfs = []
test_dataset = test_set_loader.dataset

for i in tqdm(range(len(test_dataset))):
    my_test_data = test_dataset[i]
    my_test_x = my_test_data.x.to('cpu')

    # Only the predictions are used in this loop; the other outputs are unpacked
    # under the same names as in the single-sample cell.
    (test_loss_my_test_data,
     r_squared_my_test_data,
     actual_vals_my_test_data,
     predictions_my_test_data,
     baseline_loss_my_test_data) = hf.validate_model_on_test_set(model, my_test_data, loss_fct, device)

    # Undo the feature normalisation before attaching the values to the GeoDataFrame.
    inversed_x = scaler_x.inverse_transform(my_test_x)

    gdf = hf.data_to_geodataframe_with_og_values(
        data=my_test_data,
        original_gdf=links_base_case,
        predicted_values=predictions_my_test_data,
        inversed_x=inversed_x,
    )

    # gdf = gpd.sjoin(gdf, districts, how='left', op='intersects')
    # gdf = gdf.rename(columns={"c_ar": "district"})

    gdf['capacity_reduction_rounded'] = gdf['capacity_reduction'].round(decimals=3)
    gdf['highway'] = gdf['highway'].map(highway_mapping)

    gdfs.append(gdf)

In [None]:
# Prediction error averaged over all test samples, drawn on a discrete scale.
# Absolute differences in number of vehicles (use_percentage=False).
discrete_thresholds = (2.5, 5, 7.5)

absolute_error_plot_options = {
    "scale_type": "discrete",
    "save_it": True,
    "use_fixed_norm": True,
    "fixed_norm_max": 100,
    "use_absolute_value_of_difference": True,
    "use_percentage": False,
    "disagreement_threshold": None,
    "loss_fct": "l1",
    "cmap": 'Spectral_r',
}

result_gdf = pf.plot_average_prediction_differences(
    gdf_inputs=gdfs,
    discrete_thresholds=discrete_thresholds,
    result_path=result_path,
    **absolute_error_plot_options,
)

In [None]:
# Prediction error averaged over all test samples as a relative difference
# (use_percentage=True). The scale is discrete (scale_type="discrete"), with
# the thresholds below presumably given in percent.
discrete_thresholds = (10, 25, 50)

relative_error_plot_options = {
    "scale_type": "discrete",
    "save_it": True,
    "use_fixed_norm": True,
    "fixed_norm_max": 100,
    "use_absolute_value_of_difference": True,
    "use_percentage": True,
    "disagreement_threshold": None,
    "loss_fct": "l1",
    "cmap": 'Spectral_r',
}

result_gdf = pf.plot_average_prediction_differences(
    gdf_inputs=gdfs,
    discrete_thresholds=discrete_thresholds,
    result_path=result_path,
    **relative_error_plot_options,
)