### README

Current configuration: only the GRU-only model (`msst_classes.GRUOnly`) is active; the micro (county-level) graphs are turned off in the fusion module of `msst_classes`; and the only input features are the variant columns. The static and exogenous covariates are not used.

In [1]:
import pandas as pd
import pickle
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv
from torch_geometric.utils import dense_to_sparse
import numpy as np
from math import prod
from torch.utils.data import Dataset
import msst_classes 
from torch.utils.data import DataLoader, Subset



## Load State Data

In [2]:
# State-level covariate table. `location` is read as text so the FIPS code can be
# zero-padded, and `date` is parsed to datetime.
state_df = pd.read_csv(
    "../processed data/state_level/rolled_covariates_state_level.csv",
    parse_dates=['date'],
    dtype={'location': str},
)
# Left-pad FIPS codes to two characters (e.g. "1" -> "01").
state_df['location'] = state_df['location'].str.zfill(2)
state_df.head()

Unnamed: 0,date,location,people_vaccinated,people_fully_vaccinated,school_closing,workplace_closing,cancel_events,gatherings_restrictions,transport_closing,stay_home_restrictions,...,population,Alpha,Beta,Delta,Epsilon,Gamma,Iota,Omicron,deaths,Other
0,2021-01-01,1,53829.0,272.0,2.0,1.0,1.0,0.0,0.0,1.0,...,4903185.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,45.0,4521.0
1,2021-01-02,1,53940.5,276.5,2.0,1.0,1.0,0.0,0.0,1.0,...,4903185.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,22.5,4116.0
2,2021-01-03,1,54309.0,279.0,2.0,1.0,1.0,0.0,0.0,1.0,...,4903185.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,15.333333,3569.333333
3,2021-01-04,1,56004.75,309.75,2.0,1.0,1.0,0.0,0.0,1.0,...,4903185.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,12.75,3217.25
4,2021-01-05,1,58707.8,416.6,2.0,1.0,1.0,0.0,0.0,1.0,...,4903185.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,11.8,3673.4


In [3]:
# Static per-state characteristics; `state_fips` is read as text so it is not parsed
# as a number.
state_covs = pd.read_csv("../processed data/state_level/state_level_characteristics.csv", dtype={'state_fips':str}, index_col=0)
# Zero-pad the join key the same way `state_df.location` is padded, so the merge on
# (location, state_fips) cannot silently lose states when the file stores "1" rather
# than "01". This is a no-op when the codes are already two characters wide.
state_covs['state_fips'] = state_covs['state_fips'].str.zfill(2)
state_covs.head()

Unnamed: 0,state_fips,Population,median_age,Income,Density_per_mile
0,1,5157699,39.6,65560,101.0
1,4,7582384,39.4,84700,65.0
2,5,3088354,39.1,64840,59.0
3,6,39431263,38.4,100600,250.0
4,8,5957494,38.0,106500,57.0


In [4]:
# Attach the static state characteristics to every row of the panel. The inner join
# keeps only rows whose `location` has a matching `state_fips`; the `population`,
# `deaths` and join-key `state_fips` columns are then dropped.
# NOTE(review): not re-runnable on the already-merged frame (the dropped columns no
# longer exist on a second run).
state_df = state_df.merge(state_covs, left_on='location', right_on='state_fips', how='inner').drop(['population', 'deaths', 'state_fips'], axis=1)

In [5]:
# Canonical state ordering (sorted FIPS). `state_index` maps a FIPS code to its
# position in that ordering and `S` is the number of states.
state_fips_list = sorted(state_df['location'].unique())
state_index = {fips: i for i, fips in enumerate(state_fips_list)}
S = len(state_fips_list)

In [6]:
# Gap-free daily calendar from the first to the last date present in the state table.
all_dates = pd.date_range(state_df['date'].min(), state_df['date'].max(), freq='D')

In [7]:
# Order rows by date, then by state FIPS.
state_df = state_df.sort_values(['date', 'location'])

In [8]:
# Index by (state, date) so one state's rows can be selected with .xs(..., level='location').
state_df = state_df.set_index(['location', 'date'])

In [9]:
# Per-variant columns. They are used as model inputs and, via
# target_variant_indices, as the forecast targets.
variant_cols = ['Alpha', 'Beta', 'Delta', 'Epsilon', 'Gamma', 'Iota', 'Omicron', 'Other']

In [10]:
# Rescale the variant columns by 1/1000.
# NOTE(review): in-place and not idempotent. Re-running this cell divides by 1000
# again. The same factor is hard-coded when the county tensor is built and as
# `case_scale = 1000.0` in the evaluation cell; keep all three in sync.
state_df[variant_cols] /= 1000

In [11]:
# Time-varying exogenous covariates (vaccination and policy columns).
# Not part of `feature_cols` in the current configuration.
exo_cols = ['people_vaccinated', 'people_fully_vaccinated', 'school_closing',
       'workplace_closing', 'cancel_events', 'gatherings_restrictions',
       'transport_closing', 'stay_home_restrictions',
       'internal_movement_restrictions', 'international_movement_restrictions',
       'information_campaigns', 'testing_policy', 'contact_tracing',
       'facial_coverings', 'vaccination_policy', 'elderly_people_protection',
       'government_response_index', 'stringency_index',
       'containment_health_index', 'economic_support_index']
# Static per-state characteristics merged in from `state_covs`.
# Also not part of `feature_cols` in the current configuration.
static_cols = ['Population',
       'median_age', 'Income', 'Density_per_mile']

### Key testing change - we've set feature cols to only have variants

In [12]:
# Variants only. Copy rather than alias: `variant_cols` also sizes F_micro and v_out,
# so an in-place edit of `feature_cols` must never be able to change it.
feature_cols = list(variant_cols)

In [13]:
# feature_cols = variant_cols + static_cols + exo_cols

In [14]:
# Dense state tensor X_state[t, s, f]:
#   t = position of the day in `all_dates`
#   s = position of the state in `state_fips_list`
#   f = position of the column in `feature_cols`
T = len(all_dates)
F_macro = len(feature_cols)
X_state = np.zeros((T, S, F_macro), dtype=np.float32)

for s_idx, s_fips in enumerate(state_fips_list):
    sub = state_df.xs(s_fips, level='location')[feature_cols]

    # A state whose series has gaps would otherwise surface as a cryptic
    # shape-broadcast error; name the state and the number of missing days instead.
    missing_dates = all_dates.difference(sub.index)
    if len(missing_dates) > 0:
        raise ValueError(
            f"State {s_fips} is missing {len(missing_dates)} of {T} dates "
            f"(first missing: {missing_dates[0].date()})"
        )

    # Align on the calendar explicitly so each row lands on its own day, whatever
    # order the frame happens to be sorted in.
    X_state[:, s_idx, :] = sub.reindex(all_dates).to_numpy(dtype=np.float32)


## Build County Adjacency Matrix

In [15]:
# First county-level weight source: a dict keyed by state FIPS whose values are
# county-by-county tables (they are accessed via .index / .reindex further down).
# NOTE(review): pickle.load can execute arbitrary code. Only load files produced by
# this project's own preprocessing.
with open('../processed data/county level/county_adj_by_state.pkl', 'rb') as file:
    adj_1 = pickle.load(file)

In [16]:
# Second county-level weight source (airport weights), keyed by state FIPS. A state
# may be absent from this dict; the combining loop looks it up with .get().
# NOTE(review): pickle.load can execute arbitrary code. Only load trusted files.
with open('../processed data/county level/county_airport_weights_by_state.pkl', 'rb') as file:
    adj_2 = pickle.load(file)

In [17]:
# Third county-level weight source (highway weights), keyed by state FIPS. A state
# may be absent from this dict; the combining loop looks it up with .get().
# NOTE(review): pickle.load can execute arbitrary code. Only load trusted files.
with open('../processed data/county level/county_highway_weights_by_state.pkl', 'rb') as file:
    adj_3 = pickle.load(file)

In [18]:
# Per-state outputs of the adjacency-combining loop below.
county_ids_dict = {}   # state_fips -> sorted list of county IDs; defines the row/col order of that state's matrix
county_adj_dict = {}   # state_fips -> np.ndarray (M_s x M_s), the summed normalised weights

In [19]:
# From here on `state_fips_list` is taken from the county adjacency dict. However
# `state_index`, `S` and the state axis of `X_state` were built from `state_df`, and
# later cells use `state_fips_list` to order A_state and to label per-state metrics.
# The two orderings therefore have to be identical; verify that before rebinding
# instead of letting states silently misalign.
_adj_state_fips = sorted(adj_1.keys())
if _adj_state_fips != list(state_index):
    _only_adj = sorted(set(_adj_state_fips) - set(state_index))
    _only_df = sorted(set(state_index) - set(_adj_state_fips))
    raise ValueError(
        "States in the county adjacency data do not match the states in state_df: "
        f"only in adjacency={_only_adj}, only in state_df={_only_df}"
    )
state_fips_list = _adj_state_fips

In [20]:
def normalize_df(df):
    """
    Convert a weight table to a float32 array scaled by a robust maximum.

    The scale is the 95th percentile of all entries, which limits the influence of
    a few very large weights. Values above that percentile end up greater than 1.
    When the matrix is so sparse that the 95th percentile is not positive (more
    than 95% of the entries are zero), the plain maximum is used instead; without
    that fallback a sparse matrix would be returned un-normalised and its raw
    weights would dominate any sum with properly scaled matrices.

    Args:
        df: pandas DataFrame of numeric weights.

    Returns:
        np.ndarray (float32) with the same shape as `df`. The array is returned
        unscaled when it is empty or has no positive entry.
    """
    arr = df.to_numpy(dtype=np.float32)
    if arr.size == 0:
        # Nothing to scale; np.percentile is not defined on an empty array.
        return arr

    scale = np.percentile(arr, 95)
    if not scale > 0:
        # Sparse case: fall back to the largest entry.
        scale = arr.max()
    if scale > 0:
        # Cast the scalar so the result stays float32 under every NumPy promotion rule.
        arr = arr / np.float32(scale)
    return arr

In [21]:
# For every state: put the three county weight tables on one shared county ordering,
# normalise each with normalize_df, and add them with equal weight.
for s_fips in state_fips_list:
    # adj_1 must contain the state; the other two sources are optional (None when absent).
    layers = (adj_1[s_fips], adj_2.get(s_fips, None), adj_3.get(s_fips, None))

    # 1) County ordering = sorted union of the row labels of every available table.
    counties = sorted(
        set().union(*(set(layer.index) for layer in layers if layer is not None))
    )

    # 2) Re-index each available table to that ordering; pairs not present get weight 0.
    aligned = [
        None if layer is None
        else layer.reindex(index=counties, columns=counties, fill_value=0)
        for layer in layers
    ]

    # 3) Normalise and sum; a missing source contributes an all-zero matrix.
    W_adj = normalize_df(aligned[0])
    W_air, W_hwy = (
        np.zeros_like(W_adj) if table is None else normalize_df(table)
        for table in aligned[1:]
    )
    A_s = W_adj + W_air + W_hwy

    # 4) Remember the ordering together with the combined matrix.
    county_ids_dict[s_fips] = counties   # row/col order of A_s
    county_adj_dict[s_fips] = A_s        # (M_s, M_s)

## Build global county graph

In [22]:
# Stack the per-state county graphs into one block-diagonal matrix.
#   county_fips_global[i] : county ID of global row/column i
#   state_of_county[i]    : index (per state_index) of the state that county belongs to
#   A_county_global       : (M, M) float32; counties in different states are not connected
county_fips_global = []

total_counties = sum(len(county_ids_dict[s_fips]) for s_fips in state_fips_list)
M = total_counties

state_of_county = np.zeros(M, dtype=np.int64)
A_county_global = np.zeros((M, M), dtype=np.float32)

offset = 0
for s_fips in state_fips_list:
    counties_s = county_ids_dict[s_fips]
    n_s = len(counties_s)
    A_s = county_adj_dict[s_fips].astype(np.float32)

    # the stored ordering and the stored matrix must describe the same counties
    assert A_s.shape == (n_s, n_s), f"Adjacency size mismatch for state {s_fips}"

    # this state occupies the contiguous global range [offset, offset + n_s)
    block = slice(offset, offset + n_s)
    county_fips_global.extend(counties_s)
    state_of_county[block] = state_index[s_fips]
    A_county_global[block, block] = A_s

    offset += n_s

# every county must have been placed exactly once
assert offset == M


## Building covariates

In [23]:
# Normalise county IDs to strings and map each one to its position in the global
# county ordering (the row/column index used by A_county_global).
county_fips_global = [str(c) for c in county_fips_global]
county_to_global = {c_fips: idx for idx, c_fips in enumerate(county_fips_global)}

# Dimensions of the county tensor built below: counties, days, per-variant features.
M = len(county_fips_global)
T = len(all_dates)
F_micro = len(variant_cols)


In [24]:
# County time series: a dict keyed by state FIPS, one table per state with `fips`,
# `date` and the variant columns.
# NOTE(review): pickle.load can execute arbitrary code. Only load trusted files.
with open('../processed data/county level/rolled_county_cases.pkl', 'rb') as file:
    county_cases_dict = pickle.load(file)

In [25]:
# Inspect one state's county table (key '01') to check its columns and date range.
county_cases_dict['01']

Unnamed: 0,fips,date,state,Alpha,Beta,Delta,Epsilon,Gamma,Iota,Omicron,Other
0,01001,2021-01-01,01,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,49.000000
1,01001,2021-01-02,01,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,39.000000
2,01001,2021-01-03,01,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,38.333333
3,01001,2021-01-04,01,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,36.500000
4,01001,2021-01-05,01,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,71.200000
...,...,...,...,...,...,...,...,...,...,...,...
49635,01999,2022-12-27,01,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.000000
49636,01999,2022-12-28,01,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.000000
49637,01999,2022-12-29,01,0.0,0.0,0.0,0.0,0.0,0.0,537.485714,134.371429
49638,01999,2022-12-30,01,0.0,0.0,0.0,0.0,0.0,0.0,537.485714,134.371429


In [26]:
# Dense county tensor X_county[t, m, f]: day t in `all_dates`, county m in the global
# ordering, variant f in `variant_cols`. Anything not filled below stays zero.
X_county = np.zeros((T, M, F_micro), dtype=np.float32)

for s_fips in state_fips_list:
    # A state with an adjacency matrix but no county time series keeps all-zero rows.
    if s_fips in county_cases_dict:
        cases = county_cases_dict[s_fips].copy()   # work on a copy; the loaded dict is left untouched

        # normalise key types, then apply the same 1/1000 scaling used for the state tensor
        cases['date'] = pd.to_datetime(cases['date'])
        cases['fips'] = cases['fips'].astype(str)
        cases[variant_cols] /= 1000

        for c_fips, county_rows in cases.groupby('fips'):
            # counties that are not part of the global adjacency ordering are skipped
            g_idx = county_to_global.get(c_fips)
            if g_idx is None:
                continue

            # put the county's series on the full calendar; days without a row become 0
            by_date = county_rows.set_index('date')[variant_cols]
            aligned = by_date.reindex(all_dates).fillna(0.0)

            X_county[:, g_idx, :] = aligned.to_numpy(dtype=np.float32)


In [27]:
# Sanity check; the layout is (T days, S states, F_macro features).
X_state.shape

(730, 49, 8)

In [28]:
# Sanity check; the layout is (T days, M counties, F_micro variant features).
X_county.shape

(730, 3109, 8)

## Build A_state

In [29]:
# State-by-state airport weights; the first CSV column becomes the row index.
state_airweights = pd.read_csv('../processed data/state_level/state_level_airport_weights.csv', index_col=0, dtype={'STATEFP':str})

In [30]:
# State-by-state border weights; the first CSV column becomes the row index.
state_borderweights = pd.read_csv('../processed data/state_level/state_level_border_weights.csv', index_col=0, dtype={'state_fips':str})

In [31]:
# State-by-state highway weights. The column labels are zero-padded to two characters
# and used for both rows and columns, then the table is re-ordered to
# `state_fips_list`; states absent from the file get all-zero rows/columns.
# NOTE(review): the airport and border tables are not re-indexed this way. The check
# in the next cell compares row labels only, not column labels.
state_highwayweights = pd.read_csv('../processed data/state_level/state_level_highway_weights.csv', index_col=0, dtype={'state_fips':str})
state_highwayweights.index = state_highwayweights.columns = [fips.zfill(2) for fips in state_highwayweights.columns.to_list()]
state_highwayweights = state_highwayweights.reindex(index=state_fips_list,
                                    columns=state_fips_list,
                                    fill_value=0.0)

In [32]:
# The three state weight matrices are added element-wise when A_state is built, so
# their row orderings must agree. Enforce that (a displayed `False` is easy to miss
# and would not stop a Run All), and still show the result as the cell output.
weights_aligned = (
    state_airweights.index.to_list()
    == state_borderweights.index.to_list()
    == state_highwayweights.index.to_list()
)
assert weights_aligned, "State weight matrices are not indexed by the same states in the same order"
weights_aligned

True

In [33]:
# State-level graph: normalise the airport, border and highway weight tables and add
# them with equal weight.
W1, W2, W3 = (
    normalize_df(weights)
    for weights in (state_airweights, state_borderweights, state_highwayweights)
)
A_state = W1 + W2 + W3

In [34]:
# Forecast targets: every variant column. `target_variant_indices` holds each
# variant's column position inside `feature_cols`.
v_out = len(variant_cols)
target_variant_indices = [feature_cols.index(v) for v in variant_cols]

## Train

In [35]:
# Window sizes passed to the dataset: days of history per sample and days to forecast.
input_len = 14
horizon = 7

# Presumably each item is (state window, county window, target); that is how the
# loaders are unpacked in the training loop. See msst_classes.EpidemicDataset.
dataset = msst_classes.EpidemicDataset(
    X_state=X_state,
    X_county=X_county,
    input_len=input_len,
    horizon=horizon,
    target_variant_indices=target_variant_indices,
)

# Split by position: the first 80% of the samples train, the rest validate.
# NOTE(review): assumes dataset indices are in time order — TODO confirm. If so,
# windows on either side of the boundary may overlap in time; consider leaving a gap.
N = len(dataset)
train_ratio = 0.8
N_train = int(N * train_ratio)
train_indices = np.arange(0, N_train)
val_indices = np.arange(N_train, N)

# Debugging aid: train on a handful of windows to check the model can overfit.
# tiny_indices = np.arange(0, 10)   # or even 0..4
# train_ds = Subset(dataset, tiny_indices)
# train_loader = DataLoader(train_ds, batch_size=1, shuffle=True)

# batch_size must stay 1: the training and evaluation code strips the batch
# dimension with .squeeze(0).
train_ds = Subset(dataset, train_indices)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True)


val_ds   = Subset(dataset, val_indices)
val_loader   = DataLoader(val_ds, batch_size=1, shuffle=False)


In [36]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch's global RNG before the model is built, so that anything drawn from it
# afterwards (parameter initialisation, DataLoader shuffling) is repeatable under
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

hidden_gcn = 64
hidden_gru = 64
v_out = len(target_variant_indices)

# Full multi-scale model; currently disabled, the GRU-only model below is used instead.
# model = msst_classes.MSSTVariant(
#     macro_in=F_macro,
#     micro_in=F_micro,
#     hidden_gcn=hidden_gcn,
#     hidden_gru=hidden_gru,
#     horizon=horizon,
#     v_out=v_out,
#     A_state=A_state,
#     A_county_global=A_county_global,
#     state_of_county=state_of_county,
# ).to(device)

# NOTE(review): the GRU-only model is given hidden_gru=32 explicitly, not the
# `hidden_gru` variable (64) defined above — confirm that is intentional.
model = msst_classes.GRUOnly(F_macro, hidden_gru=32, horizon=horizon, v_out=v_out).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0)
loss_fn = nn.MSELoss()


In [37]:
num_epochs = 30

# One pass over the training windows followed by one pass over the validation windows
# per epoch; the mean MSE of each pass is printed.
for epoch in range(1, num_epochs + 1):
    # --- Train ---
    model.train()
    train_losses = []

    for batch in train_loader:
        # Drop the batch dimension (batch_size=1) and move everything to the device:
        #   X_state_seq (T_in, S, F_macro), X_county_seq (T_in, M, F_micro),
        #   y_true (S, horizon, v_out)
        X_state_seq, X_county_seq, y_true = (t.squeeze(0).to(device) for t in batch)

        optimizer.zero_grad()
        y_pred = model(X_state_seq, X_county_seq)          # (S, horizon, v_out)
        loss = loss_fn(y_pred, y_true)
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())

    # --- Validation ---
    model.eval()
    val_losses = []
    with torch.no_grad():
        for batch in val_loader:
            X_state_seq, X_county_seq, y_true = (t.squeeze(0).to(device) for t in batch)
            y_pred = model(X_state_seq, X_county_seq)
            loss = loss_fn(y_pred, y_true)
            val_losses.append(loss.item())

    # An empty loader reports NaN rather than raising.
    mean_train = np.mean(train_losses) if train_losses else float("nan")
    mean_val = np.mean(val_losses) if val_losses else float("nan")
    print(f"Epoch {epoch:03d}: train_loss={mean_train:.4f}, val_loss={mean_val:.4f}")


Epoch 001: train_loss=2.7115, val_loss=0.0367
Epoch 002: train_loss=0.9729, val_loss=0.0337
Epoch 003: train_loss=0.4818, val_loss=0.0332
Epoch 004: train_loss=0.3327, val_loss=0.0457
Epoch 005: train_loss=0.2920, val_loss=0.0321
Epoch 006: train_loss=0.2392, val_loss=0.0374
Epoch 007: train_loss=0.2369, val_loss=0.0436
Epoch 008: train_loss=0.2251, val_loss=0.0324
Epoch 009: train_loss=0.2106, val_loss=0.0402
Epoch 010: train_loss=0.1957, val_loss=0.0545
Epoch 011: train_loss=0.2062, val_loss=0.0344
Epoch 012: train_loss=0.2004, val_loss=0.0322
Epoch 013: train_loss=0.1764, val_loss=0.0289
Epoch 014: train_loss=0.1804, val_loss=0.0318
Epoch 015: train_loss=0.1644, val_loss=0.0371
Epoch 016: train_loss=0.1800, val_loss=0.0336
Epoch 017: train_loss=0.1602, val_loss=0.0296
Epoch 018: train_loss=0.1661, val_loss=0.0278
Epoch 019: train_loss=0.1498, val_loss=0.0322
Epoch 020: train_loss=0.1446, val_loss=0.0298
Epoch 021: train_loss=0.1405, val_loss=0.0299
Epoch 022: train_loss=0.1479, val_

## Eval per State

In [38]:
def evaluate_per_state_rmse(model, loader, device, case_scale=1.0):
    """
    Per-state RMSE of `model` over every window yielded by `loader`.

    Squared errors are pooled over *all* windows, horizon steps and output
    variants, so each state ends up with a single number.

    Args:
        model: module called as model(X_state_seq, X_county_seq) that returns a
            tensor shaped like the target, (S, H, v_out). It is put in eval mode.
        loader: iterable of (X_state_seq, X_county_seq, y_true) triples whose
            tensors carry a leading batch dimension of size 1.
        device: device the tensors are moved to before the forward pass.
        case_scale: factor applied to the errors; if you trained on scaled cases,
            pass the scale to get RMSE in original units.

    Returns:
        rmse_per_state: np.ndarray of shape (S,)

    Raises:
        ValueError: if `loader` yields no batches.
    """
    model.eval()
    sum_sq = None   # running per-state sum of squared errors, shape (S,)
    count = 0       # number of (horizon, variant) elements accumulated per state

    with torch.no_grad():
        for X_state_seq, X_county_seq, y_true in loader:
            # remove batch dimension (batch_size=1)
            X_state_seq = X_state_seq.squeeze(0).to(device)    # (T_in, S, F_macro)
            X_county_seq = X_county_seq.squeeze(0).to(device)  # (T_in, M, F_micro)
            y_true = y_true.squeeze(0).to(device)              # (S, H, v_out)

            y_pred = model(X_state_seq, X_county_seq)          # (S, H, v_out)

            # undo scaling, then reduce over horizon and variant -> (S,)
            diff = (y_pred - y_true) * case_scale
            se_per_state = (diff ** 2).sum(dim=(1, 2))

            sum_sq = se_per_state if sum_sq is None else sum_sq + se_per_state

            # each window contributes H * v_out elements per state
            count += y_pred.shape[1] * y_pred.shape[2]

    if sum_sq is None:
        raise ValueError("loader yielded no batches; cannot compute RMSE")

    rmse_per_state = torch.sqrt(sum_sq / count)                # (S,)
    return rmse_per_state.cpu().numpy()


## Baselines

In [41]:
def evaluate_last_value_baseline(loader, device, target_variant_indices, case_scale=1.0):
    """
    Persistence baseline, scored per state.

    y_pred(s, h, v) = last observed value in the input window for that state and
    variant, repeated for every horizon step.

    Args:
        loader: iterable of (X_state_seq, X_county_seq, y_true) triples with a
            leading batch dimension of size 1; the county tensor is ignored.
        device: device the tensors are moved to.
        target_variant_indices: positions of the target variants along the feature
            axis of X_state_seq.
        case_scale: factor applied to the errors to report RMSE in original units.

    Returns:
        rmse_per_state: np.ndarray of shape (S,)

    Raises:
        ValueError: if `loader` yields no batches.
    """
    sum_sq = None
    count = 0
    target_variant_indices = torch.as_tensor(target_variant_indices, dtype=torch.long, device=device)

    with torch.no_grad():
        for X_state_seq, _, y_true in loader:
            # squeeze batch dim (batch_size = 1)
            X_state_seq = X_state_seq.squeeze(0).to(device)    # (T_in, S, F_macro)
            y_true = y_true.squeeze(0).to(device)              # (S, H, v_out)

            S = X_state_seq.shape[1]
            H, V = y_true.shape[1], y_true.shape[2]

            # last time step, target variants only: (S, v_out)
            last_targets = X_state_seq[-1][:, target_variant_indices]

            # repeat across horizon: (S, H, v_out)
            y_pred = last_targets.unsqueeze(1).expand(S, H, V)

            # undo scaling, then reduce over horizon and variant -> (S,)
            diff = (y_pred - y_true) * case_scale
            se_per_state = (diff ** 2).sum(dim=(1, 2))

            sum_sq = se_per_state if sum_sq is None else sum_sq + se_per_state
            count += H * V

    if sum_sq is None:
        raise ValueError("loader yielded no batches; cannot compute RMSE")

    rmse_per_state = torch.sqrt(sum_sq / count)
    return rmse_per_state.cpu().numpy()


def evaluate_moving_average_baseline(loader, device, target_variant_indices, case_scale=1.0, K=7):
    """
    Moving-average baseline, scored per state.

    y_pred(s, h, v) = mean of the last K days in the input window (per state and
    variant), repeated for every horizon step. When the window is shorter than K
    days the whole window is averaged.

    Args:
        loader: iterable of (X_state_seq, X_county_seq, y_true) triples with a
            leading batch dimension of size 1; the county tensor is ignored.
        device: device the tensors are moved to.
        target_variant_indices: positions of the target variants along the feature
            axis of X_state_seq.
        case_scale: factor applied to the errors to report RMSE in original units.
        K: number of trailing days to average; must be at least 1.

    Returns:
        rmse_per_state: np.ndarray of shape (S,)

    Raises:
        ValueError: if K < 1 or if `loader` yields no batches.
    """
    if K < 1:
        # K == 0 would slice [-0:], i.e. silently average the entire window, and a
        # negative K would select an unrelated slice.
        raise ValueError(f"K must be >= 1, got {K}")

    sum_sq = None
    count = 0
    target_variant_indices = torch.as_tensor(target_variant_indices, dtype=torch.long, device=device)

    with torch.no_grad():
        for X_state_seq, _, y_true in loader:
            X_state_seq = X_state_seq.squeeze(0).to(device)    # (T_in, S, F_macro)
            y_true = y_true.squeeze(0).to(device)              # (S, H, v_out)

            T_in, S = X_state_seq.shape[0], X_state_seq.shape[1]
            H, V = y_true.shape[1], y_true.shape[2]

            # mean over the last k days, target variants only: (S, v_out)
            k = min(K, T_in)
            mean_targets = X_state_seq[-k:].mean(dim=0)[:, target_variant_indices]

            # repeat across horizon: (S, H, v_out)
            y_pred = mean_targets.unsqueeze(1).expand(S, H, V)

            diff = (y_pred - y_true) * case_scale
            se_per_state = (diff ** 2).sum(dim=(1, 2))         # (S,)

            sum_sq = se_per_state if sum_sq is None else sum_sq + se_per_state
            count += H * V

    if sum_sq is None:
        raise ValueError("loader yielded no batches; cannot compute RMSE")

    rmse_per_state = torch.sqrt(sum_sq / count)
    return rmse_per_state.cpu().numpy()


def rmse_by_horizon_model(model, loader, device, case_scale=1.0):
    """
    RMSE of `model` for each forecast day, pooled over windows, states and variants.

    Args:
        model: module called as model(X_state_seq, X_county_seq) that returns a
            tensor shaped like the target, (S, H, V). It is put in eval mode.
        loader: iterable of (X_state_seq, X_county_seq, y_true) triples with a
            leading batch dimension of size 1. X_county_seq may be None, in which
            case None is passed to the model.
        device: device the tensors are moved to.
        case_scale: factor applied to the errors to report RMSE in original units.

    Returns:
        np.ndarray of shape (H,), one RMSE per horizon step.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    model.eval()
    sum_sq = None  # running per-horizon sum of squared errors, shape (H,)
    count = 0      # number of (state, variant) elements accumulated per horizon step
    with torch.no_grad():
        for X_state_seq, X_county_seq, y_true in loader:
            X_state_seq = X_state_seq.squeeze(0).to(device)   # (T_in, S, F_macro)
            y_true = y_true.squeeze(0).to(device)             # (S, H, V)
            if X_county_seq is not None:
                X_county_seq = X_county_seq.squeeze(0).to(device)
            y_pred = model(X_state_seq, X_county_seq)

            # undo scaling, then reduce over state and variant -> (H,)
            diff = (y_pred - y_true) * case_scale             # (S, H, V)
            se_h = (diff ** 2).sum(dim=(0, 2))

            sum_sq = se_h if sum_sq is None else sum_sq + se_h

            # each window contributes S * V elements per horizon step
            count += y_true.shape[0] * y_true.shape[2]

    if sum_sq is None:
        raise ValueError("loader yielded no batches; cannot compute RMSE")

    mse_h = sum_sq / count                                    # (H,)
    return torch.sqrt(mse_h).cpu().numpy()                    # (H,)

def rmse_by_horizon_last(loader, device, target_variant_indices, case_scale=1.0, K=7):
    """
    RMSE of the persistence baseline for each forecast day.

    The prediction for every horizon step is the last observed value in the input
    window; errors are pooled over windows, states and variants.

    Args:
        loader: iterable of (X_state_seq, X_county_seq, y_true) triples with a
            leading batch dimension of size 1; the county tensor is ignored.
        device: device the tensors are moved to.
        target_variant_indices: positions of the target variants along the feature
            axis of X_state_seq.
        case_scale: factor applied to the errors to report RMSE in original units.
        K: unused; kept so existing call sites stay valid.

    Returns:
        np.ndarray of shape (H,), one RMSE per horizon step.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    target_variant_indices = torch.as_tensor(target_variant_indices, dtype=torch.long, device=device)
    sum_sq = None
    count = 0
    with torch.no_grad():
        for X_state_seq, _, y_true in loader:
            X_state_seq = X_state_seq.squeeze(0).to(device)   # (T_in, S, F_macro)
            y_true = y_true.squeeze(0).to(device)             # (S, H, V)

            S = X_state_seq.shape[1]
            H, V = y_true.shape[1], y_true.shape[2]

            # last time step, target variants only, repeated across the horizon
            last_targets = X_state_seq[-1][:, target_variant_indices]  # (S, V)
            y_pred = last_targets.unsqueeze(1).expand(S, H, V)

            # undo scaling, then reduce over state and variant -> (H,)
            diff = (y_pred - y_true) * case_scale
            se_h = (diff ** 2).sum(dim=(0, 2))

            sum_sq = se_h if sum_sq is None else sum_sq + se_h
            count += S * V

    if sum_sq is None:
        raise ValueError("loader yielded no batches; cannot compute RMSE")

    mse_h = sum_sq / count
    return torch.sqrt(mse_h).cpu().numpy()


In [40]:
# Must match the /1000 applied to the variant columns when the inputs were built, so
# that the RMSE is reported in original units.
case_scale = 1000.0  # or 1.0 if you didn't scale cases
# Evaluate on whichever device the model's parameters live on.
device = next(model.parameters()).device

# Model RMSE
rmse_model = evaluate_per_state_rmse(model, val_loader, device, case_scale=case_scale)

# Baselines
rmse_last = evaluate_last_value_baseline(val_loader, device, target_variant_indices, case_scale=case_scale)
rmse_ma   = evaluate_moving_average_baseline(val_loader, device, target_variant_indices, case_scale=case_scale, K=7)

# Unweighted mean of the per-state RMSEs.
print("Overall mean RMSE (model):      ", rmse_model.mean())
print("Overall mean RMSE (last value):", rmse_last.mean())
print("Overall mean RMSE (7-day MA):  ", rmse_ma.mean())

# NOTE(review): the labels come from `state_fips_list`, while the RMSE arrays follow
# the state axis of X_state; the printed labels are only right if both orderings agree.
print("\nPer-state comparison (first few):")
for s_idx, s_fips in enumerate(state_fips_list[:10]):  # or all of them
    print(
        f"State {s_fips}: "
        f"model={rmse_model[s_idx]:8.1f}, "
        f"last={rmse_last[s_idx]:8.1f}, "
        f"ma={rmse_ma[s_idx]:8.1f}"
    )


Overall mean RMSE (model):       113.54932
Overall mean RMSE (last value): 98.323654
Overall mean RMSE (7-day MA):   107.30355

Per-state comparison (first few):
State 01: model=   168.4, last=   165.9, ma=   157.5
State 04: model=   102.9, last=    95.4, ma=   116.8
State 05: model=    36.7, last=    24.2, ma=    34.1
State 06: model=   409.3, last=   419.9, ma=   525.9
State 08: model=    60.5, last=    52.9, ma=    58.7
State 09: model=    53.9, last=    40.8, ma=    42.5
State 10: model=    31.2, last=    10.3, ma=    11.3
State 11: model=    28.7, last=     8.6, ma=     8.7
State 12: model=   793.8, last=   743.0, ma=   743.9
State 13: model=   114.2, last=   113.7, ma=   132.4


In [42]:
# Error growth over the forecast horizon: the model against the persistence baseline,
# one line per forecast day.
rmse_h_model = rmse_by_horizon_model(model, val_loader, device, case_scale)
rmse_h_last = rmse_by_horizon_last(val_loader, device, target_variant_indices, case_scale)

for day, (err_model, err_last) in enumerate(zip(rmse_h_model, rmse_h_last), start=1):
    print(f"Day {day}: model={err_model:.2f}, last={err_last:.2f}")


Day 1: model=111.76, last=90.93
Day 2: model=133.15, last=116.73
Day 3: model=152.03, last=137.73
Day 4: model=169.90, last=156.79
Day 5: model=190.17, last=175.51
Day 6: model=208.59, last=192.67
Day 7: model=224.34, last=210.79
