# Analysis of Prediction Results

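This notebook loads and analyzes the prediction results saved by `demo_working.py`, comparing the performance of the standard model with that of the model adapted via Test-Time Training (TTT). Run `demo_working.py` first so that `results/prediction_results.npy` exists, and run this notebook from the project root.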

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import pandas as pd

# Use a seaborn-like white-grid look for every figure in this notebook.
# Matplotlib versions that do not ship this style name raise OSError; in that
# case we keep the default style and just report it.
PLOT_STYLE = 'seaborn-v0_8-whitegrid'
try:
    plt.style.use(PLOT_STYLE)
except OSError:
    print(f"Seaborn style '{PLOT_STYLE}' not found, using default matplotlib style.")

## 1. Load Prediction Data

In [None]:
# Path to the results file, relative to the project root (where this notebook is expected to run from).
results_path = 'results/prediction_results.npy'

# Fail early with an actionable message instead of a bare FileNotFoundError from np.load.
if not os.path.exists(results_path):
    raise FileNotFoundError(
        f"'{results_path}' not found. Run demo_working.py first, "
        "and run this notebook from the project root."
    )

# The file stores a Python dict inside a 0-d object array, hence allow_pickle=True and .item().
# NOTE(review): allow_pickle can execute arbitrary code on load — only open result files you produced yourself.
results = np.load(results_path, allow_pickle=True).item()

# Ground truth plus, for each model, its predictions and precomputed metrics.
y_true = results['true']
std_pred = results['standard']['pred']
std_metrics = results['standard']['metrics']
ttt_pred = results['ttt']['pred']
ttt_metrics = results['ttt']['metrics']

# Later cells compare these arrays element-wise (scatter plots, residuals). Catch misaligned
# arrays here rather than letting NumPy broadcasting silently produce meaningless residuals.
assert y_true.size == std_pred.size == ttt_pred.size, (
    f"Size mismatch: y_true={y_true.shape}, std_pred={std_pred.shape}, ttt_pred={ttt_pred.shape}"
)

print("Data loaded successfully.")
print(f"Shape of y_true: {y_true.shape}")
print(f"Shape of std_pred: {std_pred.shape}")
print(f"Shape of ttt_pred: {ttt_pred.shape}")

## 2. Display Stored Metrics

In [None]:
# Raw metrics for each model, exactly as stored by the producing script.
# The TTT header starts with "\n" to leave a blank line between sections.
for header, metrics in (("--- Standard Prediction Metrics ---", std_metrics),
                        ("\n--- TTT Prediction Metrics ---", ttt_metrics)):
    print(header)
    for metric, value in metrics.items():
        print(f"{metric.upper()}: {value:.4f}")

# Relative improvement of TTT over the standard model, in percent.
# MSE / MAE / RMSE are errors (lower is better), so a reduction yields a positive improvement.
print("\n--- Improvement with TTT ---")
for metric_key in ['mse', 'mae', 'rmse']:
    baseline = std_metrics[metric_key]
    if baseline != 0:  # avoid division by zero
        improvement = (baseline - ttt_metrics[metric_key]) / baseline * 100
        print(f"Improvement in {metric_key.upper()}: {improvement:.2f}%")
    else:
        print(f"Improvement in {metric_key.upper()}: N/A (standard metric is zero)")

# R² is higher-is-better. Dividing by |R²_std| keeps the sign meaningful
# (positive = TTT better) even when the baseline R² is negative.
if abs(std_metrics['r2']) > 1e-9:  # avoid division by zero or near-zero
    r2_improvement = (ttt_metrics['r2'] - std_metrics['r2']) / abs(std_metrics['r2']) * 100
    print(f"Improvement in R²: {r2_improvement:.2f}%")
else:
    print("Improvement in R²: N/A (standard R² is near zero)")

## 3. Scatter Plots: True vs. Predicted Values

In [None]:
def plot_true_vs_pred(ax, y_true, y_pred, metrics, title):
    """Scatter predictions against ground truth on `ax`, with a y=x reference line and a metrics box.

    Args:
        ax: Matplotlib Axes to draw on.
        y_true: Array of ground-truth values (also sets the extent of the reference line).
        y_pred: Array of predicted values, aligned element-wise with `y_true`.
        metrics: Mapping containing at least the keys 'r2', 'mae' and 'mse'.
        title: Axes title.
    """
    ax.scatter(y_true, y_pred, alpha=0.6, edgecolors='w', linewidth=0.5, label='Predictions')
    # Points on this dashed line are perfect predictions.
    y_min, y_max = y_true.min(), y_true.max()
    ax.plot([y_min, y_max], [y_min, y_max], 'r--', lw=2, label='Ideal (y=x)')
    ax.set_title(title, fontsize=14)
    ax.set_xlabel('True Values', fontsize=12)
    ax.set_ylabel('Predicted Values', fontsize=12)
    metrics_text = f"R²: {metrics['r2']:.3f}\nMAE: {metrics['mae']:.3f}\nMSE: {metrics['mse']:.3f}"
    # Anchor the box in the upper-left corner using axes-fraction coordinates.
    ax.text(0.05, 0.95, metrics_text, transform=ax.transAxes, fontsize=10,
            verticalalignment='top', bbox=dict(boxstyle='round,pad=0.5', fc='aliceblue', alpha=0.8))
    ax.legend(loc='lower right')
    ax.grid(True, linestyle='--', alpha=0.7)


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))
plot_true_vs_pred(ax1, y_true, std_pred, std_metrics, 'Standard Prediction Performance')
plot_true_vs_pred(ax2, y_true, ttt_pred, ttt_metrics, 'TTT Prediction Performance')

plt.tight_layout(rect=[0, 0, 1, 0.96])  # leave the top 4% of the figure free for the suptitle
fig.suptitle('Comparison of Prediction Accuracy: Standard vs. TTT', fontsize=16)
plt.show()

## 4. Histograms of Residuals (True - Predicted)

In [None]:
def plot_residual_hist(ax, residuals, title, color):
    """Draw a 30-bin histogram of `residuals` on `ax` with dashed lines at their mean and median.

    Args:
        ax: Matplotlib Axes to draw on.
        residuals: 1-D array of residuals (true - predicted).
        title: Axes title.
        color: Fill color for the histogram bars.
    """
    mean, median = residuals.mean(), np.median(residuals)
    ax.hist(residuals, bins=30, alpha=0.75, color=color, edgecolor='black')
    ax.axvline(mean, color='red', linestyle='dashed', linewidth=2, label=f'Mean: {mean:.3f}')
    ax.axvline(median, color='darkorange', linestyle='dashed', linewidth=2, label=f'Median: {median:.3f}')
    ax.set_title(title, fontsize=14)
    ax.set_xlabel('Residual (True - Predicted)', fontsize=12)
    ax.legend()
    ax.grid(axis='y', linestyle='--', alpha=0.7)


# Flatten both operands so the subtraction is element-wise regardless of the stored array shapes.
std_residuals = y_true.flatten() - std_pred.flatten()
ttt_residuals = y_true.flatten() - ttt_pred.flatten()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7), sharey=True)
plot_residual_hist(ax1, std_residuals, 'Residuals Distribution (Standard Prediction)', 'cornflowerblue')
plot_residual_hist(ax2, ttt_residuals, 'Residuals Distribution (TTT Prediction)', 'salmon')
# The y-axis is shared, so only the left panel needs a label.
ax1.set_ylabel('Frequency', fontsize=12)

plt.tight_layout(rect=[0, 0, 1, 0.96])  # leave the top 4% of the figure free for the suptitle
fig.suptitle('Distribution of Prediction Residuals: Standard vs. TTT', fontsize=16)
plt.show()

# Numeric summary; std() is the population standard deviation (NumPy default ddof=0).
for label, residuals in (("Standard Residuals", std_residuals), ("TTT Residuals", ttt_residuals)):
    print(f"{label:<18} - Mean: {residuals.mean():.4f}, Std Dev: {residuals.std():.4f}, "
          f"Median: {np.median(residuals):.4f}")