# MLflow Model Evaluation

This notebook loads the 'adventurous-asp-289' model from MLflow and evaluates it on the test dataset.

## Tasks:
1. Load the MLflow model
2. Evaluate on 100 batches of size 256 from the test set
3. Plot error PDF histogram
4. Plot scatter of absolute error vs. each model input (the binary `ti_indicator` input is skipped)


In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from datasets import load_dataset
from torch.utils.data import DataLoader
import mlflow
import mlflow.pytorch
import pandas as pd
from typing import Dict, List, Tuple

# Plot appearance: start from matplotlib's default style, then apply the
# notebook-wide overrides in a single update.
plt.style.use('default')
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 10,
    'axes.grid': True,
    'grid.alpha': 0.3,
    'axes.spines.top': False,
    'axes.spines.right': False,
})

# Seed the NumPy and PyTorch global random number generators
np.random.seed(42)
torch.manual_seed(42)

# Report the runtime environment; the device name is only queried when CUDA is usable
cuda_available = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_available}")
if cuda_available:
    print(f"CUDA device: {torch.cuda.get_device_name()}")


## 1. Load MLflow Model


In [None]:
# Point MLflow at the local file-based tracking store. The path is relative,
# so it is resolved against the notebook's working directory.
mlflow.set_tracking_uri("file:../mlruns")

# Run whose logged model is evaluated. This cell used to assign
# "b7cf16c3f43b447c86f021363048226d" first and overwrite it on the next line;
# that dead assignment is removed so only the run that is actually loaded is
# named here.
# NOTE(review): confirm this run is the 'adventurous-asp-289' model named in
# the notebook introduction.
run_id = "27cb8cf4393c4a63adfae778115d7732"
print(f"Loading model from run: {run_id}")

# Load the PyTorch model stored under the run's "model" artifact path
model_uri = f"runs:/{run_id}/model"
model = mlflow.pytorch.load_model(model_uri)

# Switch the model to evaluation mode
model.eval()

# The device is taken from the model's first parameter; later cells create
# tensors on this device and move each batch to it.
device = next(model.parameters()).device

print("Model loaded successfully!")
print(f"Model type: {type(model)}")
print(f"Model device: {device}")
print(f"Using device: {device}")


## 2. Load and Prepare Test Data


In [None]:
# Column name lists; per the original note these mirror the ones in train.py.
# PARAM_COLS names the eight entries stacked by build_params_on_device below.
PARAM_COLS = [
    "alpha",
    "gamma",
    "beta",
    "var0",
    "eta",
    "lam",
    "ti",
    "ti_indicator",
]
# ORIGINAL_COLS lists the raw dataset columns requested from the test split.
ORIGINAL_COLS = [
    "alpha",
    "gamma",
    "beta",
    "var0",
    "eta",
    "lam",
    "ti",
    "p",
    "x",
]

def build_params_on_device(batch, device):
    """Move a raw batch to ``device`` and assemble the model inputs and targets.

    The caller's ``batch`` is not modified: the moved tensors are collected in
    a new dict, so read-only mappings are accepted as well.

    Args:
        batch: Mapping from column name to tensor. It must provide "alpha",
            "gamma", "beta", "var0", "eta", "lam", "ti" and "x"; any extra
            entries are moved to ``device`` too but otherwise ignored.
        device: Target device for all tensors (GPU or CPU).

    Returns:
        params: The eight model inputs stacked along dim=1, in the order
            alpha, gamma, beta, log(var0), log(eta) - 1, lam,
            log(ti - 0.8) - 1, ti_indicator.
        targets: ``batch["x"]`` on ``device`` (the original comment annotates
            it as already [B, 512]).
    """
    # Non-blocking transfer, as in the original; build a fresh dict rather
    # than writing the moved tensors back into the caller's mapping.
    on_device = {
        key: value.to(device, non_blocking=True)
        for key, value in batch.items()
    }

    ti_original = on_device["ti"]

    # All transforms run on `device`. torch.log yields NaN/-inf for
    # non-positive arguments, so var0 and eta must be > 0 and ti > 0.8.
    params = torch.stack([
        on_device["alpha"],
        on_device["gamma"],
        on_device["beta"],
        torch.log(on_device["var0"]),            # log transform
        torch.log(on_device["eta"]) - 1.0,       # shifted log transform
        on_device["lam"],
        torch.log(ti_original - 0.8) - 1.0,      # shifted log transform
        (ti_original == 1.0).float(),            # ti_indicator: 1.0 where ti is exactly 1.0
    ], dim=1)

    targets = on_device["x"]
    return params, targets

def compute_cdf_metrics(model, params, targets, quantile_levels):
    """Compare the model's CDF at the target quantiles with the reference levels.

    ``model.cdf`` is called exactly once; all metrics derive from that result.

    Args:
        model: Object exposing ``cdf(params, targets)`` (described as an MDN
            model in the original documentation).
        params: Model inputs; the original annotation gives the shape [B, 8].
        targets: Target quantile values; annotated as [B, 512].
        quantile_levels: Reference quantile levels; annotated as [512].

    Returns:
        row_losses: Root-mean-square of the CDF differences per row, [B].
        mean_max_diff: Scalar tensor holding the batch mean of each row's
            largest absolute difference.
        diff: Signed differences ``cdf - reference level``, same shape as the
            CDF output.
    """
    # Single CDF evaluation at the target points
    predicted_levels = model.cdf(params, targets)

    # Repeat the reference levels for every row of the batch (view, no copy)
    n_rows = params.shape[0]
    expected_levels = quantile_levels[None].expand(n_rows, -1)

    diff = predicted_levels - expected_levels

    # Per-row RMSE (used as loss and for outlier detection)
    row_losses = diff.pow(2).mean(dim=1).sqrt()

    # Worst absolute deviation in each row, averaged over the batch (for monitoring)
    mean_max_diff = diff.abs().amax(dim=1).mean()

    return row_losses, mean_max_diff, diff

# Progress marker: printed once the helper definitions in this cell have executed
print("Data loading functions defined.")


In [None]:
# Fetch the dataset by its Hub identifier (token=False is passed through to load_dataset)
print("Loading dataset...")
ds = load_dataset("sitmo/garch_densities", token=False)
print(f"Dataset loaded. Available splits: {list(ds.keys())}")

# Expose the test split as torch tensors, restricted to the raw columns
test_dataset = ds["test"].with_format("torch", columns=ORIGINAL_COLS)

# Sequential (unshuffled) loader; the final partial batch is kept.
# Host memory is pinned only when CUDA is available.
batch_size = 256
loader_options = dict(
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
    drop_last=False,
)
test_loader = DataLoader(test_dataset, **loader_options)

# Reference quantile levels for the loss: 512 evenly spaced values from 0.001
# to 0.999, created directly on the model's device
quantile_levels = torch.linspace(0.001, 0.999, 512, device=device)

print(f"Test dataset: {len(test_dataset):,} samples")
print(f"Test batches: {len(test_loader):,} batches")
print(f"Batch size: {batch_size}")
print(f"Quantile levels shape: {quantile_levels.shape}")


## 3. Evaluate Model on Test Data


In [None]:
# Only the first `num_batches` batches of the (unshuffled) test loader are evaluated
num_batches = 100
print(f"Evaluating model on {num_batches} batches of size {batch_size}...")

# Per-batch accumulators; every entry is a NumPy array on the host
all_errors, all_abs_errors, all_params = [], [], []
all_row_losses, all_max_diffs = [], []

model.eval()
with torch.inference_mode():
    for batch_idx, test_batch in enumerate(test_loader):
        # Stop once the requested number of batches has been processed
        if batch_idx >= num_batches:
            break

        # Move the batch to the model's device and assemble inputs/targets
        params, targets = build_params_on_device(test_batch, device)

        # Signed CDF differences plus the derived per-row / per-batch metrics
        row_losses, max_diff, diff = compute_cdf_metrics(model, params, targets, quantile_levels)

        # Transfer the differences once and take the absolute value on the host
        diff_np = diff.cpu().numpy()
        all_errors.append(diff_np)                        # [B, 512] - raw differences
        all_abs_errors.append(np.abs(diff_np))            # [B, 512] - absolute differences
        all_params.append(params.cpu().numpy())           # [B, 8] - input parameters
        all_row_losses.append(row_losses.cpu().numpy())   # [B] - RMSE per row
        all_max_diffs.append(max_diff.cpu().numpy())      # scalar; not read by later cells

        # Progress message every 10 batches
        if (batch_idx + 1) % 10 == 0:
            print(f"Processed {batch_idx + 1}/{num_batches} batches")

print(f"Evaluation completed! Processed {len(all_errors)} batches")
print(f"Total samples evaluated: {sum(len(err) for err in all_errors):,}")


In [None]:
# Stack the per-batch arrays along the sample axis:
# errors / abs_errors -> [N, 512], params -> [N, 8], row_losses -> [N]
errors, abs_errors, params, row_losses = (
    np.concatenate(chunks, axis=0)
    for chunks in (all_errors, all_abs_errors, all_params, all_row_losses)
)

print("Concatenated results:")
print(f"  Errors shape: {errors.shape}")
print(f"  Abs errors shape: {abs_errors.shape}")
print(f"  Params shape: {params.shape}")
print(f"  Row losses shape: {row_losses.shape}")

# Mean / standard deviation over every element of each array
mean_error, std_error = np.mean(errors), np.std(errors)
mean_abs_error, std_abs_error = np.mean(abs_errors), np.std(abs_errors)
mean_row_loss, std_row_loss = np.mean(row_losses), np.std(row_losses)

print("\nSummary statistics:")
print(f"  Mean error: {mean_error:.6f} ± {std_error:.6f}")
print(f"  Mean absolute error: {mean_abs_error:.6f} ± {std_abs_error:.6f}")
print(f"  Mean row loss (RMSE): {mean_row_loss:.6f} ± {std_row_loss:.6f}")


## 4. Plot 1: Error PDF Histogram


In [None]:
# 1-D copies of the signed and absolute errors, used by the histograms below
errors_flat, abs_errors_flat = errors.flatten(), abs_errors.flatten()

# Extremes of both flattened arrays
error_low, error_high = np.min(errors_flat), np.max(errors_flat)
abs_error_low, abs_error_high = np.min(abs_errors_flat), np.max(abs_errors_flat)

print(f"Flattened errors: {len(errors_flat):,} points")
print(f"Error range: [{error_low:.6f}, {error_high:.6f}]")
print(f"Abs error range: [{abs_error_low:.6f}, {abs_error_high:.6f}]")


In [None]:
# 2x2 grid of density histograms summarising the evaluation errors
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('Model Error Analysis', fontsize=16, fontweight='bold')

# One face colour per panel
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']

# Keyword sets shared by all panels
hist_kwargs = dict(alpha=0.7, density=True, edgecolor='black', linewidth=0.5)
mean_line = dict(color='red', linestyle='--', linewidth=2)
sigma_line = dict(color='orange', linestyle='--', alpha=0.7)

# Top-left: signed errors on fixed bin edges spanning [-0.002, 0.002]
ax = axes[0, 0]
n, bins, patches = ax.hist(errors_flat, bins=np.linspace(-0.002, 0.002, 100), color=colors[0], **hist_kwargs)
ax.axvline(mean_error, label=f'Mean: {mean_error:.4f}', **mean_line)
ax.axvline(mean_error + std_error, label=f'+1σ: {mean_error + std_error:.4f}', **sigma_line)
ax.axvline(mean_error - std_error, label=f'-1σ: {mean_error - std_error:.4f}', **sigma_line)
ax.set_xlabel('Error (CDF difference)')
ax.set_title('Raw Error Distribution')

# Top-right: absolute errors, 100 automatic bins
ax = axes[0, 1]
n, bins, patches = ax.hist(abs_errors_flat, bins=100, color=colors[1], **hist_kwargs)
ax.axvline(mean_abs_error, label=f'Mean: {mean_abs_error:.4f}', **mean_line)
ax.axvline(mean_abs_error + std_abs_error, label=f'+1σ: {mean_abs_error + std_abs_error:.4f}', **sigma_line)
ax.set_xlabel('Absolute Error')
ax.set_title('Absolute Error Distribution')

# Bottom-left: log10 of the absolute errors, floored at 1e-10 to avoid log(0).
# The marker sits at log10 of the arithmetic mean, not at the mean of the logs.
ax = axes[1, 0]
log_abs_errors = np.log10(np.maximum(abs_errors_flat, 1e-10))
n, bins, patches = ax.hist(log_abs_errors, bins=100, color=colors[2], **hist_kwargs)
ax.axvline(np.log10(mean_abs_error), label=f'Mean: {mean_abs_error:.4f}', **mean_line)
ax.set_xlabel('log₁₀(Absolute Error)')
ax.set_title('Log-Scale Absolute Error Distribution')

# Bottom-right: per-row RMSE
ax = axes[1, 1]
n, bins, patches = ax.hist(row_losses, bins=100, color=colors[3], **hist_kwargs)
ax.axvline(mean_row_loss, label=f'Mean: {mean_row_loss:.4f}', **mean_line)
ax.axvline(mean_row_loss + std_row_loss, label=f'+1σ: {mean_row_loss + std_row_loss:.4f}', **sigma_line)
ax.set_xlabel('Row Loss (RMSE)')
ax.set_title('Row Loss Distribution')

# Finishing touches common to every panel
for ax in axes.flat:
    ax.set_ylabel('Density')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Numeric companion to the figure
print("\nError Statistics:")
print(f"  Raw error - Mean: {mean_error:.6f}, Std: {std_error:.6f}")
print(f"  Raw error - Min: {np.min(errors_flat):.6f}, Max: {np.max(errors_flat):.6f}")
print(f"  Abs error - Mean: {mean_abs_error:.6f}, Std: {std_abs_error:.6f}")
print(f"  Abs error - Min: {np.min(abs_errors_flat):.6f}, Max: {np.max(abs_errors_flat):.6f}")
print(f"  Row loss - Mean: {mean_row_loss:.6f}, Std: {std_row_loss:.6f}")
print(f"  Row loss - Min: {np.min(row_losses):.6f}, Max: {np.max(row_losses):.6f}")


## 5. Plot 2: Scatter of Absolute Error vs Model Inputs


In [None]:
# Display labels for the eight stacked model inputs, in column order
param_names = [
    "alpha",
    "gamma",
    "beta",
    "log(var0)",
    "log(eta)-1",
    "lam",
    "log(ti-0.8)-1",
    "ti_indicator",
]

# Average the absolute error over the quantile axis: one value per sample, [N]
mean_abs_error_per_sample = abs_errors.mean(axis=1)

sample_error_low = np.min(mean_abs_error_per_sample)
sample_error_high = np.max(mean_abs_error_per_sample)
print(f"Mean absolute error per sample shape: {mean_abs_error_per_sample.shape}")
print(f"Mean absolute error per sample range: [{sample_error_low:.6f}, {sample_error_high:.6f}]")


In [None]:
# Scatter the per-sample mean absolute error against each input parameter.
# The x axis is the parameter's rank scaled to (0, 1], so every panel is
# spread evenly regardless of the parameter's own distribution.
fig, axes = plt.subplots(2, 4, figsize=(20, 10))
fig.suptitle('Absolute Error vs Model Inputs', fontsize=16, fontweight='bold')

axes = axes.flatten()

# The rank-space bins are identical for every panel, so build them once
n_bins = 20
bin_edges = np.linspace(0, 1, n_bins + 1)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

for i in range(8):
    ax = axes[i]
    if i == 7:
        # The last input (ti_indicator) is a 0/1 flag and is not plotted;
        # hide its axes instead of leaving an empty grid in the figure.
        ax.set_visible(False)
        continue

    # Rank of each sample for this parameter, scaled to (0, 1]
    # (double argsort; tied values receive distinct consecutive ranks)
    param_values = params[:, i]
    param_ranks = (np.argsort(np.argsort(param_values)) + 1) / len(param_values)

    ax.scatter(param_ranks, mean_abs_error_per_sample,
               alpha=0.6, s=1, c='C0', edgecolors='none')

    # Mean and 98th percentile of the error within each rank bin; bins without
    # samples stay NaN and are dropped before plotting.
    bin_means = np.full(n_bins, np.nan)
    bin_98th = np.full(n_bins, np.nan)
    for j in range(n_bins):
        lower, upper = bin_edges[j], bin_edges[j + 1]
        if j == n_bins - 1:
            # The last bin is closed on the right so that rank 1.0 is included
            mask = (param_ranks >= lower) & (param_ranks <= upper)
        else:
            mask = (param_ranks >= lower) & (param_ranks < upper)

        if mask.any():
            bin_errors = mean_abs_error_per_sample[mask]
            bin_means[j] = np.mean(bin_errors)
            bin_98th[j] = np.percentile(bin_errors, 98)

    # Binned trend lines, labelled so the legend explains both colours
    valid_mask = ~np.isnan(bin_means)
    if np.any(valid_mask):
        ax.plot(bin_centers[valid_mask], bin_means[valid_mask],
                color='red', linewidth=2, alpha=0.8, label='Bin mean')
        ax.plot(bin_centers[valid_mask], bin_98th[valid_mask],
                color='black', linewidth=2, alpha=0.8, label='Bin 98th percentile')
        ax.legend(loc='upper right')

    ax.set_xlabel(f'{param_names[i]} (rank)')
    ax.set_ylabel('Mean Absolute Error')
    ax.set_title(f'{param_names[i]}')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Pearson correlation of each input column (including ti_indicator) with the
# per-sample mean absolute error
print("\nCorrelation coefficients between inputs and mean absolute error:")
for i, name in enumerate(param_names):
    corr = np.corrcoef(params[:, i], mean_abs_error_per_sample)[0, 1]
    print(f"  {name:15s}: {corr:7.4f}")


## 6. Additional Analysis: Error vs Quantile Level


In [None]:
# Mean and spread of the absolute error at each quantile level (over samples)
quantile_levels_np = quantile_levels.cpu().numpy()
mean_error_per_quantile = abs_errors.mean(axis=0)  # [512]
std_error_per_quantile = abs_errors.std(axis=0)    # [512]

# Edges of the ±1σ band around the mean
band_lower = mean_error_per_quantile - std_error_per_quantile
band_upper = mean_error_per_quantile + std_error_per_quantile

fig, axes = plt.subplots(1, 2, figsize=(15, 5))
ax_linear, ax_log = axes

# Panel colours
blue_color = '#1f77b4'
red_color = '#d62728'

# Left panel: linear y axis
ax_linear.plot(quantile_levels_np, mean_error_per_quantile, color=blue_color, linewidth=2, label='Mean')
ax_linear.fill_between(quantile_levels_np, band_lower, band_upper,
                       alpha=0.3, color=blue_color, label='±1σ')
ax_linear.set_ylabel('Mean Absolute Error')
ax_linear.set_title('Error vs Quantile Level')

# Right panel: logarithmic y axis; the lower band edge is floored at 1e-10
# so it stays positive
ax_log.semilogy(quantile_levels_np, mean_error_per_quantile, color=red_color, linewidth=2, label='Mean')
ax_log.fill_between(quantile_levels_np, np.maximum(band_lower, 1e-10), band_upper,
                    alpha=0.3, color=red_color, label='±1σ')
ax_log.set_ylabel('Mean Absolute Error (log scale)')
ax_log.set_title('Error vs Quantile Level (Log Scale)')

# Settings shared by both panels
for panel in axes:
    panel.set_xlabel('Quantile Level')
    panel.legend()
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Quantile levels at which the mean absolute error is smallest / largest
best_level = quantile_levels_np[np.argmin(mean_error_per_quantile)]
worst_level = quantile_levels_np[np.argmax(mean_error_per_quantile)]

print("\nError by quantile level:")
print(f"  Min error at quantile {best_level:.3f}: {np.min(mean_error_per_quantile):.6f}")
print(f"  Max error at quantile {worst_level:.3f}: {np.max(mean_error_per_quantile):.6f}")
print(f"  Mean error across all quantiles: {np.mean(mean_error_per_quantile):.6f}")


## 7. Summary Statistics


In [None]:
# Number of evaluated samples = rows of `errors`. The flattened array holds one
# value per (sample, quantile) pair, so its length must not be reported as a
# sample count; using the row count also matches the total printed right after
# the evaluation loop.
n_samples = errors.shape[0]

# Summary table: one formatted value per metric
summary_stats = {
    'Metric': [
        'Total samples evaluated',
        'Mean raw error',
        'Std raw error',
        'Mean absolute error',
        'Std absolute error',
        'Mean row loss (RMSE)',
        'Std row loss (RMSE)',
        'Max absolute error',
        'Min absolute error'
    ],
    'Value': [
        f"{n_samples:,}",
        f"{mean_error:.6f}",
        f"{std_error:.6f}",
        f"{mean_abs_error:.6f}",
        f"{std_abs_error:.6f}",
        f"{mean_row_loss:.6f}",
        f"{std_row_loss:.6f}",
        f"{np.max(abs_errors_flat):.6f}",
        f"{np.min(abs_errors_flat):.6f}"
    ]
}

summary_df = pd.DataFrame(summary_stats)
print("\n=== EVALUATION SUMMARY ===")
print(summary_df.to_string(index=False))

# Pearson correlation of every input column with the per-sample mean absolute error
print("\n=== CORRELATION ANALYSIS ===")
print("Correlation between input parameters and mean absolute error:")
for i, name in enumerate(param_names):
    corr = np.corrcoef(params[:, i], mean_abs_error_per_sample)[0, 1]
    print(f"  {name:20s}: {corr:7.4f}")

# Best / worst quantile levels by mean absolute error
print("\n=== QUANTILE-LEVEL ANALYSIS ===")
print("Error varies across quantile levels:")
print(f"  Lowest error: {np.min(mean_error_per_quantile):.6f} at quantile {quantile_levels_np[np.argmin(mean_error_per_quantile)]:.3f}")
print(f"  Highest error: {np.max(mean_error_per_quantile):.6f} at quantile {quantile_levels_np[np.argmax(mean_error_per_quantile)]:.3f}")
print(f"  Error range: {np.max(mean_error_per_quantile) - np.min(mean_error_per_quantile):.6f}")

# Qualitative verdict from fixed thresholds on the overall mean absolute error
print("\n=== MODEL PERFORMANCE ===")
if mean_abs_error < 0.01:
    print("✅ Model shows excellent performance (mean abs error < 0.01)")
elif mean_abs_error < 0.05:
    print("✅ Model shows good performance (mean abs error < 0.05)")
elif mean_abs_error < 0.1:
    print("⚠️  Model shows moderate performance (mean abs error < 0.1)")
else:
    print("❌ Model shows poor performance (mean abs error >= 0.1)")

print("\nEvaluation completed successfully!")
print(f"Model: {run_id}")
print(f"Device: {device}")
print(f"Samples evaluated: {n_samples:,}")
