In [1]:
import numpy as np
import tensorflow as tf
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from data_generator import input_fn
import os

def calculate_metrics(y_true, y_pred):
    """
    Calculate MAPE, MAE, and R2 between true and predicted values.

    Both inputs are converted to NumPy arrays and flattened, so they are
    treated as one series of samples. Flattening keeps a ``(N,)`` / ``(N, 1)``
    pair from broadcasting into an ``N x N`` matrix inside the MAPE
    expression, which would silently produce a wrong percentage.

    Args:
        y_true: Array-like of true values
        y_pred: Array-like of predicted values (one per true value)

    Returns:
        Dict with the keys:
            'MAPE': mean absolute percentage error in percent, computed only
                over samples whose true value is non-zero (NaN if there are
                no such samples)
            'MAE': mean absolute error
            'R2': coefficient of determination

    Raises:
        ValueError: If the inputs are empty or contain a different number
            of elements.
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()

    if y_true.size != y_pred.size:
        raise ValueError(
            "y_true and y_pred must have the same number of elements "
            f"(got {y_true.size} and {y_pred.size})"
        )
    if y_true.size == 0:
        raise ValueError("y_true and y_pred must not be empty")

    # Handle potential division by zero in MAPE: zero-valued targets are
    # excluded from the percentage error instead of producing inf/NaN terms.
    mask = y_true != 0
    if mask.any():
        relative_error = (y_true[mask] - y_pred[mask]) / y_true[mask]
        mape = np.mean(np.abs(relative_error)) * 100
    else:
        # Every target is zero, so a percentage error is undefined.
        mape = np.nan

    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)

    return {
        'MAPE': mape,
        'MAE': mae,
        'R2': r2,
    }

def evaluate_model(
    model_type,
    traffic_model,
    batch_limit=2000,
    data_root='/home/verma198/Public/RouteNet-Fermi/data/traffic_models',
):
    """
    Evaluate predictions for a specific model and traffic type

    Loads the saved predictions from
    ``predictions_delay_{traffic_model}_{model_type}.npy`` (relative to the
    current working directory), reads up to ``batch_limit`` batches from the
    test dataset, and prints MAPE, MAE and R2 for the pair.

    Args:
        model_type: String indicating model type (GNN, LSTM, RNN)
        traffic_model: String indicating traffic model type
        batch_limit: Maximum number of batches to process
        data_root: Directory containing one sub-directory per traffic model;
            the test set is read from ``{data_root}/{traffic_model}/test``

    Returns:
        The metrics dict produced by ``calculate_metrics``, or ``None`` if
        anything failed. Failures are reported on stdout rather than raised,
        so one broken combination does not stop a sweep over many.
    """
    # Construct paths
    prediction_file = f'predictions_delay_{traffic_model}_{model_type}.npy'
    test_path = f'{data_root}/{traffic_model}/test'

    print(f"\nEvaluating {model_type} on {traffic_model} traffic model")
    print("-" * 50)

    try:
        # Load predictions
        predictions = np.load(prediction_file)

        # Load test data (unshuffled, so the order is the same on every run)
        ds_test = input_fn(test_path, shuffle=False)

        # Process batches; element 1 of each dataset item is used as the
        # ground truth for that batch.
        true_values_list = []
        for i, x in enumerate(ds_test):
            if i >= batch_limit:
                break
            print(f"Processing batch {i+1}/{batch_limit}")
            true_values_list.append(x[1])

        if not true_values_list:
            raise ValueError(f"No test batches were read from {test_path}")

        true_values = tf.concat(true_values_list, axis=0).numpy()

        # The prediction file may cover more samples than were read here, so
        # extra predictions are dropped; too few is a genuine mismatch.
        if len(predictions) < len(true_values):
            raise ValueError(
                f"{prediction_file} holds {len(predictions)} predictions but "
                f"{len(true_values)} true values were read"
            )

        # Calculate metrics
        metrics = calculate_metrics(true_values, predictions[:len(true_values)])

        print(f"\nMetrics for {model_type} on {traffic_model}:")
        print(f"MAPE: {metrics['MAPE']:.2f}%")
        print(f"MAE: {metrics['MAE']:.6f}")
        print(f"R²: {metrics['R2']:.6f}")

        return metrics

    except Exception as e:
        # Deliberately broad: report the failure and let the caller continue
        # with the remaining combinations.
        print(f"Error processing {model_type} for {traffic_model}: {str(e)}")
        return None

def main():
    """Run evaluate_model once for every traffic model / model type pair."""
    model_types = ['GNN', 'LSTM', 'RNN']
    traffic_models = [
        'constant_bitrate',
        'onoff',
        'autocorrelated',
        'modulated',
        'all_multiplexed',
    ]

    # Traffic model varies slowest, so all model types for one traffic model
    # are evaluated back to back.
    runs = [
        (kind, traffic)
        for traffic in traffic_models
        for kind in model_types
    ]
    for kind, traffic in runs:
        evaluate_model(kind, traffic)


if __name__ == "__main__":
    main()


Evaluating GNN on constant_bitrate traffic model
--------------------------------------------------
Processing batch 1/2000
Processing batch 2/2000
Processing batch 3/2000
Processing batch 4/2000


2024-12-01 05:09:02.814632: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-12-01 05:09:02.816260: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.
2024-12-01 05:09:02.835229: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)


Processing batch 5/2000
Processing batch 6/2000
Processing batch 7/2000
Processing batch 8/2000
Processing batch 9/2000
Processing batch 10/2000
Processing batch 11/2000
Processing batch 12/2000
Processing batch 13/2000
Processing batch 14/2000
Processing batch 15/2000
Processing batch 16/2000
Processing batch 17/2000
Processing batch 18/2000
Processing batch 19/2000
Processing batch 20/2000
Processing batch 21/2000
Processing batch 22/2000
Processing batch 23/2000
Processing batch 24/2000
Processing batch 25/2000
Processing batch 26/2000
Processing batch 27/2000
Processing batch 28/2000
Processing batch 29/2000
Processing batch 30/2000
Processing batch 31/2000
Processing batch 32/2000
Processing batch 33/2000
Processing batch 34/2000
Processing batch 35/2000
Processing batch 36/2000
Processing batch 37/2000
Processing batch 38/2000
Processing batch 39/2000
Processing batch 40/2000
Processing batch 41/2000
Processing batch 42/2000
Processing batch 43/2000
Processing batch 44/2000
Proce