# Tutorial: Comparing Models

This notebook scores several candidate models' forecasts against the same request, using the same metrics (MAE, RMSE, MAPE), and plots the results side by side.

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
from pydantic import ValidationError

from tollama.core.forecast_metrics import compute_forecast_metrics
from tollama.core.schemas import ForecastRequest, ForecastResponse

In [None]:
# Six consecutive daily timestamps: 2025-01-01 through 2025-01-06.
timestamps = [f'2025-01-0{day}' for day in range(1, 7)]

# The single input series: six observed target values plus four actuals
# (the same count as the requested horizon below).
sku_a_series = {
    'id': 'sku_A',
    'freq': 'D',
    'timestamps': timestamps,
    'target': [100, 102, 101, 103, 104, 106],
    'actuals': [107, 108, 110, 109],
}

# Metric selection passed under parameters.metrics.
metric_parameters = {'names': ['mae', 'rmse', 'mape'], 'mase_seasonality': 1}

request_payload = dict(
    model='mock',
    horizon=4,
    series=[sku_a_series],
    parameters={'metrics': metric_parameters},
    options={},
)

# Validate the raw dict into a ForecastRequest; the trailing expression
# displays the validated request as the cell output.
request = ForecastRequest.model_validate(request_payload)
request

In [None]:
# Hard-coded mean forecasts per candidate model (four values each).
candidate_outputs = {
    'chronos2': [106.8, 108.2, 110.4, 109.4],
    'timesfm-2.5-200m': [106.0, 107.8, 109.6, 110.1],
    'granite-ttm-r2': [107.2, 108.5, 109.7, 109.0],
}

# Build one result row per candidate by scoring it against the shared request.
rows = []
for model_name, mean_values in candidate_outputs.items():
    forecast_entry = {
        'id': 'sku_A',
        'freq': 'D',
        'start_timestamp': '2025-01-07',
        'mean': mean_values,
    }
    response = ForecastResponse.model_validate(
        {'model': model_name, 'forecasts': [forecast_entry]}
    )
    metrics, metric_warnings = compute_forecast_metrics(request=request, response=response)

    # Row layout: model name, then any aggregate metrics, then joined warnings.
    row = {'model': model_name}
    if metrics is not None:
        row.update(metrics.aggregate)
    row['warnings'] = '; '.join(metric_warnings)
    rows.append(row)

# Order the comparison table by ascending MAE and display it.
results = pd.DataFrame(rows)
results = results.sort_values(by='mae')
results

In [None]:
# Metrics to plot, in legend order.
metric_names = ['mae', 'rmse', 'mape']

# Reshape the wide results table into long form: one row per (model, metric).
plot_df = pd.melt(
    results,
    id_vars='model',
    value_vars=metric_names,
    var_name='metric',
    value_name='value',
)

# Draw one line per metric across the models.
fig, ax = plt.subplots(figsize=(9, 4))
for metric_name in metric_names:
    metric_rows = plot_df.loc[plot_df['metric'] == metric_name]
    ax.plot(metric_rows['model'], metric_rows['value'], marker='o', label=metric_name)
ax.set(title='Metric comparison by model', ylabel='value')
ax.grid(alpha=0.2)
ax.legend()
fig.tight_layout()

In [None]:
# Grouped bar chart of the long-form metrics: models on the x-axis,
# bars coloured by metric. The trailing expression renders the figure.
bar_fig = px.bar(
    data_frame=plot_df,
    x='model',
    y='value',
    color='metric',
    barmode='group',
    title='Model comparison metrics',
)
bar_fig

In [None]:
# Copy the first series without its 'actuals' field, leaving the original
# request_payload untouched.
bad_series = {
    field: value
    for field, value in request_payload['series'][0].items()
    if field != 'actuals'
}
bad_payload = {**request_payload, 'series': [bad_series]}

# Validate the modified payload; if it is rejected, show the first error message.
try:
    ForecastRequest.model_validate(bad_payload)
except ValidationError as exc:
    first_error = exc.errors()[0]
    print('Validation error example:', first_error['msg'])