### README

Current configuration: only the GRU is active in the model. In the `msst_classes` file, the micro graphs are turned off in the fusion module, and the only features used are the variant counts; the exogenous covariates are not used.

In [1]:
import pandas as pd
import pickle
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv
from torch_geometric.utils import dense_to_sparse
import numpy as np
from math import prod
from torch.utils.data import Dataset
import msst_classes 
from torch.utils.data import DataLoader, Subset



## Load State Data

In [2]:
# Load the daily state-level panel. 'date' is parsed as a datetime and
# 'location' is read as a string so that it can be zero-padded below.
state_df = pd.read_csv("../processed data/state_level/daily_covariates_state_level.csv", parse_dates=['date'], dtype={'location': str})
# Left-pad location codes to two characters (e.g. '1' -> '01').
state_df.location = state_df.location.str.zfill(2)
state_df.head()

Unnamed: 0,date,location,people_vaccinated,people_fully_vaccinated,school_closing,workplace_closing,cancel_events,gatherings_restrictions,transport_closing,stay_home_restrictions,...,population,Alpha,Beta,Delta,Epsilon,Gamma,Iota,Omicron,deaths,Other
0,2021-01-01,1,53829.0,272.0,2.0,1.0,1.0,0.0,0.0,1.0,...,4903185,0.0,0.0,0.0,0.0,0.0,0.0,0.0,45,4521.0
1,2021-01-02,1,54052.0,281.0,2.0,1.0,1.0,0.0,0.0,1.0,...,4903185,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,3711.0
2,2021-01-03,1,55046.0,284.0,2.0,1.0,1.0,0.0,0.0,1.0,...,4903185,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1,2476.0
3,2021-01-04,1,61092.0,402.0,2.0,1.0,1.0,0.0,0.0,1.0,...,4903185,0.0,0.0,0.0,0.0,0.0,0.0,0.0,5,2161.0
4,2021-01-05,1,69520.0,844.0,2.0,1.0,1.0,0.0,0.0,1.0,...,4903185,0.0,0.0,0.0,0.0,0.0,0.0,0.0,8,5498.0


In [3]:
# Load the static per-state characteristics. The first CSV column becomes the
# index and 'state_fips' is kept as a string (it is the merge key used later).
state_covs = pd.read_csv("../processed data/state_level/state_level_characteristics.csv", dtype={'state_fips':str}, index_col=0)
state_covs.head()

Unnamed: 0,state_fips,Population,median_age,Income,Density_per_mile
0,1,5157699,39.6,65560,101.0
1,4,7582384,39.4,84700,65.0
2,5,3088354,39.1,64840,59.0
3,6,39431263,38.4,100600,250.0
4,8,5957494,38.0,106500,57.0


In [4]:
# Attach the static state characteristics to every daily row.
# 'location' was zero-padded to two characters when state_df was loaded; pad
# the right-hand key the same way so that e.g. '1' and '01' match. With an
# inner join, unmatched states would otherwise be dropped without any error.
# Padding is a no-op when the codes are already two characters long.
state_df = (
    state_df
    .merge(
        state_covs.assign(state_fips=state_covs['state_fips'].str.zfill(2)),
        left_on='location',
        right_on='state_fips',
        how='inner',
    )
    # 'state_fips' duplicates 'location' after the merge; 'population' and
    # 'deaths' are not used downstream in this notebook.
    .drop(['population', 'deaths', 'state_fips'], axis=1)
)

In [5]:
# Sorted list of the state codes present after the merge; its order defines
# the state axis of X_state built further down.
state_fips_list = sorted(state_df['location'].unique())
# Map each state code to its position in that ordering.
state_index = {fips: i for i, fips in enumerate(state_fips_list)}
S = len(state_fips_list)  # number of states

In [6]:
# Continuous daily calendar from the first to the last date in state_df;
# its length is used as the time dimension T of the feature tensors.
all_dates = pd.date_range(state_df['date'].min(), state_df['date'].max(), freq='D')

In [7]:
# Order rows by date, then by state, so each state's rows are chronological.
state_df = state_df.sort_values(['date', 'location'])

In [8]:
# Index by (location, date) so one state's rows can be selected with .xs().
state_df = state_df.set_index(['location', 'date'])

In [9]:
# Variant columns: used as the model features and as the forecast targets.
variant_cols = ['Alpha', 'Beta', 'Delta', 'Epsilon', 'Gamma', 'Iota', 'Omicron', 'Other']

In [10]:
# Rescale the variant columns by 1/1000 (the county series get the same
# scaling when X_county is built). The evaluation helpers multiply errors by
# their `case_scale` argument, so pass 1000 there to undo this scaling.
state_df[variant_cols] /= 1000

In [11]:
# exo_cols = ['people_vaccinated', 'people_fully_vaccinated', 'school_closing',
#        'workplace_closing', 'cancel_events', 'gatherings_restrictions',
#        'transport_closing', 'stay_home_restrictions',
#        'internal_movement_restrictions', 'international_movement_restrictions',
#        'information_campaigns', 'testing_policy', 'contact_tracing',
#        'facial_coverings', 'vaccination_policy', 'elderly_people_protection',
#        'government_response_index', 'stringency_index',
#        'containment_health_index', 'economic_support_index']
# static_cols = ['Population',
#        'median_age', 'Income', 'Density_per_mile']

In [12]:
# exo_cols = ['people_vaccinated', 'people_fully_vaccinated', 'stringency_index']

In [13]:
# Active feature set: variants only. Note this binds the same list object as
# variant_cols (no copy), so in-place edits to one affect the other.
feature_cols = variant_cols

In [14]:
# feature_cols = variant_cols + exo_cols

In [15]:
# feature_cols = variant_cols + static_cols + exo_cols

In [16]:
# Dense state-level feature tensor of shape (T, S, F_macro):
# one row per calendar day, one column block per state.
T = len(all_dates)
F_macro = len(feature_cols)
X_state = np.zeros((T, S, F_macro), dtype=np.float32)

for s_idx, s_fips in enumerate(state_fips_list):
    # Align each state's series to the full daily calendar instead of assuming
    # it already has exactly T gap-free rows. Missing days (and NaNs) become 0,
    # which mirrors how the county-level tensor is filled.
    sub = (
        state_df.xs(s_fips, level='location')[feature_cols]
        .reindex(all_dates)
        .fillna(0.0)
    )
    X_state[:, s_idx, :] = sub.to_numpy(dtype=np.float32)


## Build County Adjacency Matrix

In [17]:
# County adjacency per state; used below as a dict keyed by state FIPS whose
# values are DataFrames indexed by county.
# NOTE: pickle.load can execute arbitrary code - only load trusted files.
with open('../processed data/county level/county_adj_by_state.pkl', 'rb') as file:
    adj_1 = pickle.load(file)

In [18]:
# County airport weights per state (a state may be absent; see the .get() below).
# NOTE: pickle.load can execute arbitrary code - only load trusted files.
with open('../processed data/county level/county_airport_weights_by_state.pkl', 'rb') as file:
    adj_2 = pickle.load(file)

In [19]:
# County highway weights per state (a state may be absent; see the .get() below).
# NOTE: pickle.load can execute arbitrary code - only load trusted files.
with open('../processed data/county level/county_highway_weights_by_state.pkl', 'rb') as file:
    adj_3 = pickle.load(file)

In [20]:
# Filled by the per-state loop below.
county_ids_dict = {}   # state_fips -> sorted list of county IDs (defines row/col order)
county_adj_dict = {}   # state_fips -> np.ndarray (M_s x M_s) combined adjacency

In [21]:
# Re-derive the state ordering from the county adjacency data.
state_fips_list = sorted(adj_1.keys())

# state_index, S and the state axis of X_state were all built from the states
# found in state_df. If the adjacency data covers a different set of states,
# every per-state index used below would be silently misaligned, so fail
# loudly here instead.
_state_mismatch = set(state_fips_list) ^ set(state_index)
if _state_mismatch:
    raise ValueError(
        "State FIPS codes differ between state_df and the county adjacency "
        f"data: {sorted(map(str, _state_mismatch))}"
    )

In [22]:
def normalize_df(df):
    """
    Convert a weight DataFrame to a float32 array scaled by a robust maximum.

    The scale is the 95th percentile of all entries. For sparse matrices that
    percentile is 0 (most entries are zero), which previously meant the matrix
    was returned unscaled; in that case the 95th percentile of the strictly
    positive entries is used instead.

    Args:
        df: DataFrame of numeric weights.

    Returns:
        np.ndarray with the same shape as ``df``. Returned unscaled when it is
        empty, has no positive scale (e.g. all zeros), or the percentile is NaN.
    """
    arr = df.to_numpy(dtype=np.float32)
    if arr.size == 0:
        # Nothing to scale; avoid taking a percentile of an empty array.
        return arr

    max_val = np.percentile(arr, 95)
    if max_val == 0:
        # Sparse case: the bulk of the entries are zero, so look only at the
        # non-zero weights to find a meaningful scale.
        positive = arr[arr > 0]
        if positive.size:
            max_val = np.percentile(positive, 95)

    if max_val > 0:
        arr = arr / max_val
    return arr

In [23]:
# For every state, combine the three county-level weight tables into a single
# adjacency matrix over a common, sorted county ordering.
for s_fips in state_fips_list:
    df1 = adj_1[s_fips]
    df2 = adj_2.get(s_fips, None)          # None when the state has no airport weights
    df3 = adj_3.get(s_fips, None)          # None when the state has no highway weights

    # 1) Union of the county IDs found in the row index of any available table
    counties = set(df1.index)
    if df2 is not None:
        counties |= set(df2.index)
    if df3 is not None:
        counties |= set(df3.index)
    counties = sorted(counties)

    # 2) Align each table to this county list (rows and columns); counties a
    #    table does not know get 0, and a missing table becomes all zeros
    df1_al = df1.reindex(index=counties, columns=counties, fill_value=0)

    if df2 is not None:
        df2_al = df2.reindex(index=counties, columns=counties, fill_value=0)
    else:
        df2_al = pd.DataFrame(0.0, index=counties, columns=counties)

    if df3 is not None:
        df3_al = df3.reindex(index=counties, columns=counties, fill_value=0)
    else:
        df3_al = pd.DataFrame(0.0, index=counties, columns=counties)

    # 3) Normalize each weight type separately so they are on a comparable scale
    W1 = normalize_df(df1_al)
    W2 = normalize_df(df2_al) if df2 is not None else np.zeros_like(W1)
    W3 = normalize_df(df3_al) if df3 is not None else np.zeros_like(W1)
    
    # Option A: equal weight - plain sum of the three normalized matrices
    A_s = W1 + W2 + W3

    # 4) Save ordering + adjacency
    county_ids_dict[s_fips] = counties   # this defines row/col order
    county_adj_dict[s_fips] = A_s        # (M_s, M_s)

## Build global county graph

In [24]:
# Stack the per-state county graphs into one block-diagonal adjacency matrix,
# so counties are only connected to counties of the same state.
county_fips_global = []   # global list of county FIPS, in block order
state_of_county = []      # placeholder; replaced by an array of length M below

# Also count total counties to size A_county_global later
total_counties = 0
for s_fips in state_fips_list:
    counties_s = county_ids_dict[s_fips]
    total_counties += len(counties_s)

M = total_counties
# state_of_county[g] = index (from state_index) of the state owning global county g
state_of_county = np.zeros(M, dtype=np.int64)

# Fill county_fips_global + state_of_county and build A_county_global
A_county_global = np.zeros((M, M), dtype=np.float32)

offset = 0  # global index of the first county of the current state
for s_fips in state_fips_list:
    A_s = county_adj_dict[s_fips].astype(np.float32)
    counties_s = county_ids_dict[s_fips]
    n_s = len(counties_s)

    # sanity check
    assert A_s.shape == (n_s, n_s), f"Adjacency size mismatch for state {s_fips}"

    # record county FIPS + state indices for these n_s counties
    # (raises KeyError if the state is not a key of state_index)
    s_idx = state_index[s_fips]
    for local_idx, c_fips in enumerate(counties_s):
        global_idx = offset + local_idx
        county_fips_global.append(c_fips)
        state_of_county[global_idx] = s_idx

    # place this state's adjacency in the block-diagonal
    A_county_global[offset:offset+n_s, offset:offset+n_s] = A_s

    offset += n_s

# final sanity: every county slot has been assigned
assert offset == M


## Building covariates

In [25]:
# Normalize county IDs to strings so they can be matched against the 'fips'
# column of the county case data (which is cast to str before the lookup).
county_fips_global = [str(c) for c in county_fips_global]
# County FIPS -> row index in the global county ordering.
county_to_global = {c_fips: idx for idx, c_fips in enumerate(county_fips_global)}

M = len(county_fips_global)   # number of counties
T = len(all_dates)            # number of days
F_micro = len(variant_cols)   # county-level features: variants only


In [26]:
# Daily county-level cases; the code below reads the columns 'state', 'fips',
# 'date' and the variant columns from it.
county_cases_daily = pd.read_parquet('../processed data/county level/daily_cases_by_county.parquet')

In [27]:
# Split the county table into one DataFrame per state. Every state in
# state_fips_list gets a key, even when the filter matches no rows (the value
# is then an empty DataFrame).
county_cases_dict = dict()
for state in state_fips_list:
    df = county_cases_daily[county_cases_daily.state == state]
    county_cases_dict[state] = df

In [28]:
# with open('../processed data/county level/rolled_county_cases.pkl', 'rb') as file:
#     county_cases_dict = pickle.load(file)

In [29]:
county_cases_dict['06']

Unnamed: 0,fips,date,state,Alpha,Beta,Delta,Epsilon,Gamma,Iota,Omicron,Other
116800,06001,2021-01-01,06,14.925956,0.0,0.0,328.006993,0.0,0.0,0.0,542.067051
116801,06001,2021-01-02,06,17.707980,0.0,0.0,387.804754,0.0,0.0,0.0,637.487267
116802,06001,2021-01-03,06,16.257089,0.0,0.0,379.206049,0.0,0.0,0.0,604.536862
116803,06001,2021-01-04,06,5.860160,0.0,0.0,222.257275,0.0,0.0,0.0,326.882565
116804,06001,2021-01-05,06,8.738730,0.0,0.0,334.351426,0.0,0.0,0.0,482.909844
...,...,...,...,...,...,...,...,...,...,...,...
159865,06999,2022-12-27,06,0.000000,0.0,0.0,0.000000,0.0,0.0,0.0,0.000000
159866,06999,2022-12-28,06,0.000000,0.0,0.0,0.000000,0.0,0.0,0.0,0.000000
159867,06999,2022-12-29,06,0.000000,0.0,0.0,0.000000,0.0,0.0,0.0,0.000000
159868,06999,2022-12-30,06,0.000000,0.0,0.0,0.000000,0.0,0.0,0.0,0.000000


In [30]:
# Dense county-level feature tensor of shape (T, M, F_micro); counties without
# any data keep their all-zero rows.
X_county = np.zeros((T, M, F_micro), dtype=np.float32)

for s_fips in state_fips_list:
    if s_fips not in county_cases_dict:
        # If a state has adjacency but no county time series -> stays all zeros
        continue

    # work on a copy so the in-place scaling below does not touch the source table
    df = county_cases_dict[s_fips].copy()
    # ensure types & ordering
    df['date'] = pd.to_datetime(df['date'])
    df['fips'] = df['fips'].astype(str)
    # same 1/1000 scaling that was applied to the state-level variant columns
    df[variant_cols] /= 1000
    # Counties present here but absent from county_to_global are skipped in
    # the loop below rather than raising.

    # group by county and fill each global row
    for c_fips, group in df.groupby('fips'):
        # skip counties that didn't make it into the adjacency for some reason
        if c_fips not in county_to_global:
            continue

        g_idx = county_to_global[c_fips]

        # index by date, select variant columns, align to all_dates;
        # days with no record become 0
        sub = (
            group
            .set_index('date')[variant_cols]
            .reindex(all_dates)
            .fillna(0.0)
        )

        X_county[:, g_idx, :] = sub.to_numpy(dtype=np.float32)


In [31]:
X_state.shape

(730, 49, 8)

In [32]:
X_county.shape

(730, 3109, 8)

## Build A_state

In [33]:
# State-to-state airport weights; the first CSV column becomes the index.
state_airweights = pd.read_csv('../processed data/state_level/state_level_airport_weights.csv', index_col=0, dtype={'STATEFP':str})

In [34]:
# State-to-state border weights; the first CSV column becomes the index.
state_borderweights = pd.read_csv('../processed data/state_level/state_level_border_weights.csv', index_col=0, dtype={'state_fips':str})

In [35]:
# State-to-state highway weights; the first CSV column becomes the index.
state_highwayweights = pd.read_csv('../processed data/state_level/state_level_highway_weights.csv', index_col=0, dtype={'state_fips':str})
# Relabel both rows and columns with zero-padded codes taken from the column headers.
# NOTE(review): this assumes the rows are in the same order as the columns - confirm.
state_highwayweights.index = state_highwayweights.columns = [fips.zfill(2) for fips in state_highwayweights.columns.to_list()]
# Align to state_fips_list on both axes; states missing from the file get 0.0.
state_highwayweights = state_highwayweights.reindex(index=state_fips_list,
                                    columns=state_fips_list,
                                    fill_value=0.0)

In [36]:
# Sanity check: the three weight tables share the same row labels in the same
# order. Only the row index is compared here; column order is not checked.
state_airweights.index.to_list() == state_borderweights.index.to_list() == state_highwayweights.index.to_list()

True

In [37]:
# State-level adjacency: normalize each weight type separately, then sum them
# with equal weight (same scheme as the per-state county adjacency).
# NOTE(review): only the highway table is explicitly reindexed to
# state_fips_list; the air and border tables are used in their file order.
W1 = normalize_df(state_airweights)
W2 = normalize_df(state_borderweights) 
W3 = normalize_df(state_highwayweights) 
A_state = W1 + W2 + W3

In [38]:
# Number of forecast targets and their column positions inside feature_cols.
v_out = len(variant_cols)
target_variant_indices = [feature_cols.index(v) for v in variant_cols]

## Train

In [39]:
# Window sizes passed to the dataset (days of input, days to forecast).
input_len = 14
horizon = 7

dataset = msst_classes.EpidemicDataset(
    X_state=X_state,
    X_county=X_county,
    input_len=input_len,
    horizon=horizon,
    target_variant_indices=target_variant_indices,
)

# Split by position, without shuffling: the first 80% of the dataset items are
# used for training and the remaining 20% for validation.
N = len(dataset)
train_ratio = 0.8
N_train = int(N * train_ratio)
train_indices = np.arange(0, N_train)
val_indices = np.arange(N_train, N)

# Debug option: overfit a handful of windows
# tiny_indices = np.arange(0, 10)   # or even 0..4
# train_ds = Subset(dataset, tiny_indices)
# train_loader = DataLoader(train_ds, batch_size=1, shuffle=True)

# batch_size=1: the training/eval loops squeeze the batch dimension away and
# feed one window at a time to the model.
train_ds = Subset(dataset, train_indices)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True)


val_ds   = Subset(dataset, val_indices)
val_loader   = DataLoader(val_ds, batch_size=1, shuffle=False)


In [40]:
# Use the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Layer sizes for the graph encoders and the GRU.
hidden_gcn_macro = 64
hidden_gcn_micro = 32
macro_out_dim = 32
micro_out_dim = 32
hidden_gru = 64

# Number of forecast targets; derived from the indices the dataset was built
# with so the model output always matches the dataset targets.
v_out = len(target_variant_indices)

### GRU Only
# model = msst_classes.GRUOnly(F_macro, hidden_gru=32, horizon=horizon, v_out=v_out).to(device)


### GCN+MacroGCN concatenated
# NOTE: hidden_gcn and gcn_out_dim are not defined in this notebook; they must
# be set before this variant is re-enabled.
# macro_gcn = msst_classes.MacroGCN(
#     in_dim=F_macro,
#     hidden_dim=hidden_gcn,
#     out_dim=gcn_out_dim,
#     A_state=A_state
# ).to(device)
# 
# model = msst_classes.MacroGCNGRUResidualSkip(
#     macro_gcn=macro_gcn,
#     in_dim=F_macro,
#     gcn_out_dim=gcn_out_dim,
#     hidden_gru=hidden_gru,
#     horizon=horizon,
#     v_out=v_out
# ).to(device)

### Macro+Micro+GRU
macro_gcn = msst_classes.MacroGCN(
    in_dim=F_macro,
    hidden_dim=hidden_gcn_macro,
    out_dim=macro_out_dim,
    A_state=A_state
).to(device)

micro_gcn = msst_classes.MicroGCN(
    in_dim=F_micro,
    hidden_dim=hidden_gcn_micro,
    out_dim=micro_out_dim,
    A_county_global=A_county_global,
    state_of_county=state_of_county  # length M, 0..S-1
).to(device)

model = msst_classes.MacroMicroGCNGRUResidual(
    macro_gcn=macro_gcn,
    micro_gcn=micro_gcn,
    in_dim_state=F_macro,
    macro_out_dim=macro_out_dim,
    micro_out_dim=micro_out_dim,
    hidden_gru=hidden_gru,
    horizon=horizon,
    # was len(variant_cols); identical today, but v_out is the value that is
    # tied to the dataset's target_variant_indices
    v_out=v_out,
).to(device)


In [41]:
# Adam without weight decay; mean-squared error on the residual targets.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0)
loss_fn = nn.MSELoss()


In [42]:
num_epochs = 30

# One pass over train_loader followed by one pass over val_loader per epoch;
# the mean loss of each pass is printed.
for epoch in range(1, num_epochs + 1):
    # --- Train ---
    model.train()
    train_losses = []

    for X_state_seq, X_county_seq, y_resid_true, baseline_full in train_loader:
    
        # Remove batch dimension (batch_size=1)
        X_state_seq   = X_state_seq.squeeze(0).to(device)      # (T_in, S, F_macro)
        X_county_seq  = X_county_seq.squeeze(0).to(device)     # (T_in, M, F_micro)
        y_resid_true  = y_resid_true.squeeze(0).to(device)     # (S, horizon, v_out)
        baseline_full = baseline_full.squeeze(0).to(device)    # (S, horizon, v_out); not used by the training loss
    
        optimizer.zero_grad()
    
        # Model predicts residuals
        y_resid_pred = model(X_state_seq, X_county_seq)        # (S, horizon, v_out)
    
        # Training loss is taken directly on the residuals
        loss = loss_fn(y_resid_pred, y_resid_true)
        loss.backward()
        optimizer.step()
    
        train_losses.append(loss.item())

    # --- Validation ---
    model.eval()
    val_losses = []
    with torch.no_grad():
        for X_state_seq, X_county_seq, y_resid_true, baseline_full in val_loader:
            X_state_seq = X_state_seq.squeeze(0).to(device)
            X_county_seq = X_county_seq.squeeze(0).to(device)
            y_resid_true = y_resid_true.squeeze(0).to(device)

            # The same baseline is added to prediction and target, so it
            # cancels in the squared error: this is the same quantity as the
            # residual-space MSE used for training. baseline_full is not
            # squeezed here, so both sums carry a leading batch dimension.
            y_resid_pred = model(X_state_seq, X_county_seq)         # (S, H, V)
            y_pred_raw = y_resid_pred + baseline_full.to(device)    # (S, H, V)
            y_true_raw = y_resid_true + baseline_full.to(device)    # or store y_raw separately
            loss = loss_fn(y_pred_raw, y_true_raw)
            val_losses.append(loss.item())

    # NaN is reported for a pass whose loader yielded nothing
    mean_train = np.mean(train_losses) if train_losses else float("nan")
    mean_val = np.mean(val_losses) if val_losses else float("nan")
    print(f"Epoch {epoch:03d}: train_loss={mean_train:.4f}, val_loss={mean_val:.4f}")


Epoch 001: train_loss=2.8399, val_loss=0.9475
Epoch 002: train_loss=2.7402, val_loss=0.8493
Epoch 003: train_loss=2.5107, val_loss=0.8152
Epoch 004: train_loss=2.2919, val_loss=0.7393
Epoch 005: train_loss=2.1590, val_loss=0.6823
Epoch 006: train_loss=2.0443, val_loss=0.6306
Epoch 007: train_loss=1.9074, val_loss=0.6450
Epoch 008: train_loss=1.8332, val_loss=0.5874
Epoch 009: train_loss=1.7413, val_loss=0.5861
Epoch 010: train_loss=1.6950, val_loss=0.7284
Epoch 011: train_loss=1.5850, val_loss=0.5226
Epoch 012: train_loss=1.5139, val_loss=0.5305
Epoch 013: train_loss=1.4690, val_loss=0.5234
Epoch 014: train_loss=1.4105, val_loss=0.5955
Epoch 015: train_loss=1.4192, val_loss=0.5279
Epoch 016: train_loss=1.2693, val_loss=0.4929
Epoch 017: train_loss=1.2959, val_loss=0.4877
Epoch 018: train_loss=1.1899, val_loss=0.5904
Epoch 019: train_loss=1.1540, val_loss=0.5131
Epoch 020: train_loss=1.0841, val_loss=0.4749
Epoch 021: train_loss=1.0296, val_loss=0.5455
Epoch 022: train_loss=1.0429, val_

## Eval (Daily)

In [43]:
### Global RMSE by variant, then summed
def overall_rmse_model_from_residual(model, loader, device, case_scale=1.0):
    """
    Pooled RMSE of the model over every element of every window in ``loader``.

    The error is computed in residual space; adding the baseline to both the
    prediction and the target would cancel in the difference.

    Args:
        model: module called as ``model(X_state_seq, X_county_seq)``; expected
            to return residual predictions shaped like ``y_resid_true``
            (``(S, H, V)`` after the batch dimension is squeezed). It is put
            in eval mode.
        loader: iterable yielding
            ``(X_state_seq, X_county_seq, y_resid_true, baseline_full)`` with
            a leading batch dimension of size 1.
        device: device the tensors are moved to.
        case_scale: factor the errors are multiplied by. With the default of
            1.0 the result is in the units of the loader's data; pass the
            factor the data was divided by to obtain raw counts.

    Returns:
        The RMSE as a float, or NaN when ``loader`` yields no elements
        (instead of raising ZeroDivisionError).
    """
    model.eval()
    sum_sq = 0.0
    count = 0

    with torch.no_grad():
        for X_state_seq, X_county_seq, y_resid_true, _baseline_full in loader:
            X_state_seq = X_state_seq.squeeze(0).to(device)
            X_county_seq = X_county_seq.squeeze(0).to(device)
            y_resid_true = y_resid_true.squeeze(0).to(device)

            y_resid_pred = model(X_state_seq, X_county_seq)

            # baseline cancels, so the residual error is the forecast error
            se = ((y_resid_pred - y_resid_true) * case_scale) ** 2

            sum_sq += se.sum().item()
            count += se.numel()

    if count == 0:
        # same convention as the MAPE helpers: nothing to evaluate -> NaN
        return float("nan")
    return (sum_sq / count) ** 0.5

def overall_rmse_ma_baseline_from_residual(loader, device, case_scale=1.0):
    """
    Pooled RMSE of the baseline that ships with the loader.

    Predicting a residual of 0 means forecasting ``baseline_full`` itself, so
    the forecast error is simply ``-y_resid_true``.

    Args:
        loader: iterable yielding
            ``(X_state_seq, X_county_seq, y_resid_true, baseline_full)`` with
            a leading batch dimension of size 1; only ``y_resid_true`` is used.
        device: device ``y_resid_true`` is moved to.
        case_scale: factor the errors are multiplied by (1.0 keeps the units
            of the loader's data).

    Returns:
        The RMSE as a float, or NaN when ``loader`` yields no elements
        (instead of raising ZeroDivisionError).
    """
    sum_sq = 0.0
    count = 0

    with torch.no_grad():
        for _X_state_seq, _X_county_seq, y_resid_true, _baseline_full in loader:
            y_resid_true = y_resid_true.squeeze(0).to(device)

            # error of the zero-residual forecast, rescaled
            se = ((-y_resid_true) * case_scale) ** 2

            sum_sq += se.sum().item()
            count += se.numel()

    if count == 0:
        # same convention as the MAPE helpers: nothing to evaluate -> NaN
        return float("nan")
    return (sum_sq / count) ** 0.5

def overall_rmse_last_value_baseline_from_residual(loader, device, target_variant_indices, case_scale=1.0):
    """
    Pooled RMSE of the last-value (persistence) baseline.

    For each state and target variant, the value in the last step of the input
    window is repeated for every forecast day and compared with the true
    value, reconstructed as ``y_resid_true + baseline_full``.

    Args:
        loader: iterable yielding
            ``(X_state_seq, X_county_seq, y_resid_true, baseline_full)`` with
            a leading batch dimension of size 1.
        device: device the tensors are moved to.
        target_variant_indices: positions of the target variants on the last
            axis of ``X_state_seq``, in the order of the target axis.
        case_scale: factor the errors are multiplied by (1.0 keeps the units
            of the loader's data).

    Returns:
        The RMSE as a float, or NaN when ``loader`` yields no elements
        (instead of raising ZeroDivisionError).
    """
    target_variant_indices = torch.as_tensor(
        target_variant_indices, dtype=torch.long, device=device
    )

    sum_sq = 0.0
    count = 0

    with torch.no_grad():
        for X_state_seq, _X_county_seq, y_resid_true, baseline_full in loader:
            X_state_seq = X_state_seq.squeeze(0).to(device)
            y_resid_true = y_resid_true.squeeze(0).to(device)
            baseline_full = baseline_full.squeeze(0).to(device)

            # true future values
            y_true = y_resid_true + baseline_full

            _, S, _ = X_state_seq.shape
            H, V = y_true.shape[1], y_true.shape[2]

            # last observed values of the target variants, repeated over the horizon
            last_targets = X_state_seq[-1][:, target_variant_indices]
            last_full = last_targets.unsqueeze(1).expand(S, H, V)

            se = ((last_full - y_true) * case_scale) ** 2

            sum_sq += se.sum().item()
            count += se.numel()

    if count == 0:
        # same convention as the MAPE helpers: nothing to evaluate -> NaN
        return float("nan")
    return (sum_sq / count) ** 0.5


In [44]:
### RMSE per day
def rmse_by_day_model(model, loader, device, case_scale=1.0):
    """
    RMSE of the model for each forecast day, pooled over states, variants and
    windows.

    The error is computed in residual space; adding the baseline to both the
    prediction and the target would cancel in the difference.

    Args:
        model: module called as ``model(X_state_seq, X_county_seq)``; expected
            to return residual predictions shaped ``(S, H, V)`` like
            ``y_resid_true``. It is put in eval mode.
        loader: iterable yielding
            ``(X_state_seq, X_county_seq, y_resid_true, baseline_full)`` with
            a leading batch dimension of size 1.
        device: device the tensors are moved to.
        case_scale: factor the errors are multiplied by (1.0 keeps the units
            of the loader's data).

    Returns:
        NumPy array of length H with one RMSE per forecast day, or an empty
        array when ``loader`` yields no elements (instead of raising
        TypeError).
    """
    model.eval()
    sum_sq_per_day = None   # (H,) running sum of squared errors
    count_per_day = None    # (H,) number of terms behind each sum

    with torch.no_grad():
        for X_state_seq, X_county_seq, y_resid_true, _baseline_full in loader:
            # Remove batch dim (batch_size=1)
            X_state_seq = X_state_seq.squeeze(0).to(device)
            X_county_seq = X_county_seq.squeeze(0).to(device)
            y_resid_true = y_resid_true.squeeze(0).to(device)

            y_resid_pred = model(X_state_seq, X_county_seq)

            se = ((y_resid_pred - y_resid_true) * case_scale) ** 2

            S, H, V = se.shape
            se_day = se.sum(dim=(0, 2))     # collapse states & variants
            n_day = S * V                   # terms per day in this window

            if sum_sq_per_day is None:
                sum_sq_per_day = se_day.cpu()
                count_per_day = torch.full((H,), n_day, dtype=torch.long)
            else:
                sum_sq_per_day += se_day.cpu()
                count_per_day += n_day

    if sum_sq_per_day is None:
        # empty loader: the horizon is unknown, so there are no days to report
        return torch.empty(0).numpy()

    mse_per_day = sum_sq_per_day / count_per_day
    return torch.sqrt(mse_per_day).numpy()

def rmse_by_day_ma_baseline(loader, device, case_scale=1.0):
    """
    RMSE for each forecast day of the baseline that ships with the loader.

    Predicting a residual of 0 means forecasting ``baseline_full`` itself, so
    the forecast error is simply ``-y_resid_true``.

    Args:
        loader: iterable yielding
            ``(X_state_seq, X_county_seq, y_resid_true, baseline_full)`` with
            a leading batch dimension of size 1; only ``y_resid_true`` is used.
        device: device ``y_resid_true`` is moved to.
        case_scale: factor the errors are multiplied by (1.0 keeps the units
            of the loader's data).

    Returns:
        NumPy array of length H with one RMSE per forecast day, or an empty
        array when ``loader`` yields no elements (instead of raising
        TypeError).
    """
    sum_sq_per_day = None   # (H,) running sum of squared errors
    count_per_day = None    # (H,) number of terms behind each sum

    with torch.no_grad():
        for _X_state_seq, _X_county_seq, y_resid_true, _baseline_full in loader:
            y_resid_true = y_resid_true.squeeze(0).to(device)

            # error of the zero-residual forecast, rescaled
            se = ((-y_resid_true) * case_scale) ** 2

            S, H, V = se.shape
            se_day = se.sum(dim=(0, 2))     # collapse states & variants
            n_day = S * V

            if sum_sq_per_day is None:
                sum_sq_per_day = se_day.cpu()
                count_per_day = torch.full((H,), n_day, dtype=torch.long)
            else:
                sum_sq_per_day += se_day.cpu()
                count_per_day += n_day

    if sum_sq_per_day is None:
        # empty loader: the horizon is unknown, so there are no days to report
        return torch.empty(0).numpy()

    mse_per_day = sum_sq_per_day / count_per_day
    return torch.sqrt(mse_per_day).numpy()

def rmse_by_day_last_value_baseline_from_residual(loader, device, target_variant_indices, case_scale=1.0):
    """
    RMSE for each forecast day of the last-value (persistence) baseline.

    For each state and target variant, the value in the last step of the input
    window is repeated for every forecast day and compared with the true
    value, reconstructed as ``y_resid_true + baseline_full``.

    Args:
        loader: iterable yielding
            ``(X_state_seq, X_county_seq, y_resid_true, baseline_full)`` with
            a leading batch dimension of size 1.
        device: device the tensors are moved to.
        target_variant_indices: positions of the target variants on the last
            axis of ``X_state_seq``, in the order of the target axis.
        case_scale: factor the errors are multiplied by (1.0 keeps the units
            of the loader's data).

    Returns:
        NumPy array of length H with one RMSE per forecast day, or an empty
        array when ``loader`` yields no elements (instead of raising
        TypeError).
    """
    target_variant_indices = torch.as_tensor(
        target_variant_indices, dtype=torch.long, device=device
    )

    sum_sq_per_day = None   # (H,) running sum of squared errors
    count_per_day = None    # (H,) number of terms behind each sum

    with torch.no_grad():
        for X_state_seq, _X_county_seq, y_resid_true, baseline_full in loader:
            X_state_seq = X_state_seq.squeeze(0).to(device)
            y_resid_true = y_resid_true.squeeze(0).to(device)
            baseline_full = baseline_full.squeeze(0).to(device)

            # true future values
            y_true = y_resid_true + baseline_full

            _, S, _ = X_state_seq.shape
            H, V = y_true.shape[1], y_true.shape[2]

            # last observed values of the target variants, repeated over the horizon
            last_targets = X_state_seq[-1][:, target_variant_indices]
            last_full = last_targets.unsqueeze(1).expand(S, H, V)

            se = ((last_full - y_true) * case_scale) ** 2

            se_day = se.sum(dim=(0, 2))     # collapse states & variants
            n_day = S * V

            if sum_sq_per_day is None:
                sum_sq_per_day = se_day.cpu()
                count_per_day = torch.full((H,), n_day, dtype=torch.long)
            else:
                sum_sq_per_day += se_day.cpu()
                count_per_day += n_day

    if sum_sq_per_day is None:
        # empty loader: the horizon is unknown, so there are no days to report
        return torch.empty(0).numpy()

    mse_per_day = sum_sq_per_day / count_per_day
    return torch.sqrt(mse_per_day).numpy()


In [52]:
### Global MAPE
def overall_mape_model_from_residual(model, loader, device, case_scale=1.0, min_cases=0):
    """
    Pooled MAPE (in percent) of the model's forecasts.

    Truth and forecast are rebuilt as ``(residual + baseline_full) *
    case_scale``. Only positions whose absolute true value is strictly greater
    than ``min_cases`` enter the average, which with the default of 0 excludes
    exactly the zero targets.

    Args:
        model: module called as ``model(X_state_seq, X_county_seq)``; put in
            eval mode here.
        loader: iterable yielding
            ``(X_state_seq, X_county_seq, y_resid_true, baseline_full)`` with
            a leading batch dimension of size 1.
        device: device the tensors are moved to.
        case_scale: factor applied to truth and forecast before comparing.
        min_cases: threshold on the scaled absolute true value.

    Returns:
        The MAPE in percent, or NaN when no position passes the threshold.
    """
    model.eval()
    pct_total = 0.0
    n_terms = 0

    with torch.no_grad():
        for batch in loader:
            state_in, county_in, resid_true, baseline = (
                part.squeeze(0).to(device) for part in batch
            )

            resid_pred = model(state_in, county_in)

            # rebuild truth and forecast from the residuals
            truth = (resid_true + baseline) * case_scale
            forecast = (resid_pred + baseline) * case_scale

            abs_err = (forecast - truth).abs()
            magnitude = truth.abs()

            keep = magnitude > min_cases
            if not keep.any():
                continue
            pct_total += (abs_err[keep] / magnitude[keep]).sum().item()
            n_terms += keep.sum().item()

    return np.nan if n_terms == 0 else 100.0 * (pct_total / n_terms)

def overall_mape_ma_baseline_from_residual(loader, device, case_scale=1.0, min_cases=0):
    """
    Pooled MAPE (in percent) of the baseline that ships with the loader.

    The forecast is ``baseline_full * case_scale`` and the truth is
    ``(y_resid_true + baseline_full) * case_scale``. Only positions whose
    absolute true value is strictly greater than ``min_cases`` are averaged.

    Args:
        loader: iterable yielding
            ``(X_state_seq, X_county_seq, y_resid_true, baseline_full)`` with
            a leading batch dimension of size 1; the first two are ignored.
        device: device the tensors are moved to.
        case_scale: factor applied to truth and forecast before comparing.
        min_cases: threshold on the scaled absolute true value.

    Returns:
        The MAPE in percent, or NaN when no position passes the threshold.
    """
    pct_total = 0.0
    n_terms = 0

    with torch.no_grad():
        for _, _, resid_true, baseline in loader:
            resid_true = resid_true.squeeze(0).to(device)
            baseline = baseline.squeeze(0).to(device)

            truth = (resid_true + baseline) * case_scale
            forecast = baseline * case_scale

            abs_err = (forecast - truth).abs()
            magnitude = truth.abs()

            keep = magnitude > min_cases
            if not keep.any():
                continue
            pct_total += (abs_err[keep] / magnitude[keep]).sum().item()
            n_terms += keep.sum().item()

    return np.nan if n_terms == 0 else 100.0 * (pct_total / n_terms)

def overall_mape_last_value_baseline_from_residual(loader, device, target_variant_indices, case_scale=1.0, min_cases=0):
    """
    Pooled MAPE (in percent) of the last-value (persistence) baseline.

    For each state and target variant, the value in the last step of the input
    window is repeated for every forecast day. The truth is
    ``(y_resid_true + baseline_full) * case_scale``. Only positions whose
    absolute true value is strictly greater than ``min_cases`` are averaged.

    Args:
        loader: iterable yielding
            ``(X_state_seq, X_county_seq, y_resid_true, baseline_full)`` with
            a leading batch dimension of size 1.
        device: device the tensors are moved to.
        target_variant_indices: positions of the target variants on the last
            axis of ``X_state_seq``.
        case_scale: factor applied to truth and forecast before comparing.
        min_cases: threshold on the scaled absolute true value.

    Returns:
        The MAPE in percent, or NaN when no position passes the threshold.
    """
    variant_idx = torch.as_tensor(
        target_variant_indices, dtype=torch.long, device=device
    )

    pct_total = 0.0
    n_terms = 0

    with torch.no_grad():
        for state_in, _, resid_true, baseline in loader:
            state_in = state_in.squeeze(0).to(device)
            resid_true = resid_true.squeeze(0).to(device)
            baseline = baseline.squeeze(0).to(device)

            truth = (resid_true + baseline) * case_scale

            _, n_states, _ = state_in.shape
            n_days, n_variants = truth.shape[1], truth.shape[2]

            # persistence forecast: last input step, repeated over the horizon
            last_obs = state_in[-1][:, variant_idx]
            forecast = last_obs.unsqueeze(1).expand(n_states, n_days, n_variants) * case_scale

            abs_err = (forecast - truth).abs()
            magnitude = truth.abs()

            keep = magnitude > min_cases
            if not keep.any():
                continue
            pct_total += (abs_err[keep] / magnitude[keep]).sum().item()
            n_terms += keep.sum().item()

    return np.nan if n_terms == 0 else 100.0 * (pct_total / n_terms)


In [58]:
### RMSE for total cases 
def overall_rmse_model_total_from_residual(model, loader, device, case_scale=1.0):
    """
    Pooled RMSE of the model on TOTAL cases (summed over variants).

    Per window:
      - rebuild per-variant truth and forecast as
        ``(residual + baseline_full) * case_scale``,
      - sum both over the variant axis,
      - accumulate the squared differences over all states and forecast days.

    Args:
        model: module called as ``model(X_state_seq, X_county_seq)``; expected
            to return residual predictions shaped ``(S, H, V)``. It is put in
            eval mode.
        loader: iterable yielding
            ``(X_state_seq, X_county_seq, y_resid_true, baseline_full)`` with
            a leading batch dimension of size 1.
        device: device the tensors are moved to.
        case_scale: factor applied to truth and forecast (1.0 keeps the units
            of the loader's data).

    Returns:
        The RMSE as a float, or NaN when ``loader`` yields no elements
        (instead of raising ZeroDivisionError).
    """
    model.eval()
    sum_sq = 0.0
    count = 0

    with torch.no_grad():
        for X_state_seq, X_county_seq, y_resid_true, baseline_full in loader:
            X_state_seq = X_state_seq.squeeze(0).to(device)
            X_county_seq = X_county_seq.squeeze(0).to(device)
            y_resid_true = y_resid_true.squeeze(0).to(device)
            baseline_full = baseline_full.squeeze(0).to(device)

            y_resid_pred = model(X_state_seq, X_county_seq)

            # per-variant values, then totals over the variant axis
            y_true_total = ((y_resid_true + baseline_full) * case_scale).sum(dim=2)
            y_hat_total = ((y_resid_pred + baseline_full) * case_scale).sum(dim=2)

            se = (y_hat_total - y_true_total) ** 2

            sum_sq += se.sum().item()
            count += se.numel()

    if count == 0:
        # same convention as the MAPE helpers: nothing to evaluate -> NaN
        return float("nan")
    return (sum_sq / count) ** 0.5

def overall_rmse_ma_total_from_residual(loader, device, case_scale=1.0):
    """
    Pooled RMSE on TOTAL cases (summed over variants) of the baseline that
    ships with the loader.

    The per-variant forecast is ``baseline_full * case_scale`` and the truth
    is ``(y_resid_true + baseline_full) * case_scale``; both are summed over
    the variant axis before the error is taken.

    Args:
        loader: iterable yielding
            ``(X_state_seq, X_county_seq, y_resid_true, baseline_full)`` with
            a leading batch dimension of size 1; the first two are ignored.
        device: device the tensors are moved to.
        case_scale: factor applied to truth and forecast (1.0 keeps the units
            of the loader's data).

    Returns:
        The RMSE as a float, or NaN when ``loader`` yields no elements
        (instead of raising ZeroDivisionError).
    """
    sum_sq = 0.0
    count = 0

    with torch.no_grad():
        for _X_state_seq, _X_county_seq, y_resid_true, baseline_full in loader:
            y_resid_true = y_resid_true.squeeze(0).to(device)
            baseline_full = baseline_full.squeeze(0).to(device)

            # per-variant values, then totals over the variant axis
            y_true_total = ((y_resid_true + baseline_full) * case_scale).sum(dim=2)
            y_hat_total = (baseline_full * case_scale).sum(dim=2)

            se = (y_hat_total - y_true_total) ** 2

            sum_sq += se.sum().item()
            count += se.numel()

    if count == 0:
        # same convention as the MAPE helpers: nothing to evaluate -> NaN
        return float("nan")
    return (sum_sq / count) ** 0.5

def overall_rmse_last_total_from_residual(loader, device, target_variant_indices, case_scale=1.0):
    """
    Pooled RMSE (raw units) of the last-value baseline on TOTAL daily cases.

    For each batch the forecast is built per state by:
      1. taking the final time step of X_state_seq,
      2. selecting the columns listed in `target_variant_indices`,
      3. scaling by `case_scale` and summing them into one total,
      4. repeating that total for every forecast day.

    The truth is (y_resid_true + baseline_full) * case_scale summed over the
    variant axis (dim 2).

    Args:
        loader: iterable of (X_state_seq, X_county_seq, y_resid_true,
            baseline_full) tuples; X_county_seq is ignored. After
            `squeeze(0)`, X_state_seq must be 3-D (time, state, feature)
            and the other two are treated as (S, H, V).
        device: device the tensors and the index tensor are placed on.
        target_variant_indices: feature-column indices of the variant
            counts inside X_state_seq.
        case_scale: multiplier applied to both truth and forecast.

    Returns:
        float: square root of the mean squared total error over all
        elements seen. An empty `loader` raises ZeroDivisionError.
    """
    variant_cols = torch.as_tensor(
        target_variant_indices, dtype=torch.long, device=device
    )

    sq_err_sum = 0.0
    n_terms = 0

    with torch.no_grad():
        for state_seq, _county_seq, resid, base in loader:
            state_seq = state_seq.squeeze(0).to(device)   # (T_in, S, F_macro)
            resid = resid.squeeze(0).to(device)           # (S, H, V)
            base = base.squeeze(0).to(device)             # (S, H, V)

            # Ground-truth totals per (state, day).
            truth_total = ((resid + base) * case_scale).sum(dim=2)   # (S, H)

            # The 3-way unpack doubles as a check that state_seq is 3-D.
            _, n_states, _ = state_seq.shape
            horizon = truth_total.shape[1]

            # Last observed variant counts, scaled and summed per state.
            last_total = (state_seq[-1][:, variant_cols] * case_scale).sum(dim=1)  # (S,)

            # Same total for every forecast day.
            forecast_total = last_total.unsqueeze(1).expand(n_states, horizon)    # (S, H)

            sq_err = (forecast_total - truth_total) ** 2
            sq_err_sum += sq_err.sum().item()
            n_terms += sq_err.numel()

    return (sq_err_sum / n_terms) ** 0.5


In [60]:
### RMSE by state for total cases
def rmse_by_state_total_model_from_residual(model, loader, device, case_scale=1.0):
    """
    Per-state RMSE (raw units) for the MODEL on TOTAL daily cases.

    For every batch the model's residual prediction is added to the
    baseline, scaled by `case_scale`, and summed over the variant axis
    (dim 2); the same is done for the true residual. Squared differences
    are accumulated per state over all forecast days and all batches.

    Args:
        model: callable module invoked as model(X_state_seq, X_county_seq);
            its output is used as a residual with the same shape as
            y_resid_true. It is switched to eval mode and left there.
        loader: iterable of (X_state_seq, X_county_seq, y_resid_true,
            baseline_full) tuples. After `squeeze(0)` the last two must be
            3-D (S, H, V) (presumably batch_size=1 -- confirm with caller).
        device: device the tensors are moved to before computing.
        case_scale: multiplier converting model units to raw counts.

    Returns:
        np.ndarray of shape (S,), in the same order as the state dimension
        of y_resid_true.

    Raises:
        ValueError: if `loader` yields no batches (previously this surfaced
            as an opaque AttributeError on a None accumulator).
    """
    model.eval()
    sum_sq_state = None
    count_state = None

    with torch.no_grad():
        for X_state_seq, X_county_seq, y_resid_true, baseline_full in loader:
            X_state_seq   = X_state_seq.squeeze(0).to(device)
            X_county_seq  = X_county_seq.squeeze(0).to(device)
            y_resid_true  = y_resid_true.squeeze(0).to(device)    # (S, H, V)
            baseline_full = baseline_full.squeeze(0).to(device)   # (S, H, V)

            S, H, _ = y_resid_true.shape
            if sum_sq_state is None:
                # Accumulators are sized lazily from the first batch.
                sum_sq_state = torch.zeros(S, device=device)
                count_state  = torch.zeros(S, device=device)

            # Model residual preds
            y_resid_pred = model(X_state_seq, X_county_seq)       # (S, H, V)

            # Raw per-variant counts
            y_true_raw = (y_resid_true + baseline_full) * case_scale  # (S, H, V)
            y_hat_raw  = (y_resid_pred + baseline_full) * case_scale  # (S, H, V)

            # Totals per state & horizon
            y_true_total = y_true_raw.sum(dim=2)  # (S, H)
            y_hat_total  = y_hat_raw.sum(dim=2)   # (S, H)

            se = (y_hat_total - y_true_total) ** 2   # (S, H)

            sum_sq_state += se.sum(dim=1)         # sum over horizon -> (S,)
            count_state  += H                     # H terms per state per batch

    if sum_sq_state is None:
        raise ValueError("loader yielded no batches; per-state RMSE is undefined")

    # clamp guards the H == 0 case against a 0/0 division.
    rmse_state = torch.sqrt(sum_sq_state / count_state.clamp(min=1)).cpu().numpy()
    return rmse_state

def rmse_by_state_total_ma_from_residual(loader, device, case_scale=1.0):
    """
    Per-state RMSE (raw units) for the MA baseline on TOTAL daily cases.

    The baseline forecast is `baseline_full * case_scale`; the truth is
    `(y_resid_true + baseline_full) * case_scale`. Both are summed over the
    variant axis (dim 2) and the squared differences are accumulated per
    state over all forecast days and all batches.

    Args:
        loader: iterable of (X_state_seq, X_county_seq, y_resid_true,
            baseline_full) tuples; the first two entries are ignored.
            After `squeeze(0)` the last two must be 3-D (S, H, V)
            (presumably batch_size=1 -- confirm with caller).
        device: device the tensors are moved to before computing.
        case_scale: multiplier converting model units to raw counts.

    Returns:
        np.ndarray of shape (S,), ordered like the state dimension of
        y_resid_true.

    Raises:
        ValueError: if `loader` yields no batches (previously this surfaced
            as an opaque AttributeError on a None accumulator).
    """
    sum_sq_state = None
    count_state = None

    with torch.no_grad():
        for _X_state_seq, _X_county_seq, y_resid_true, baseline_full in loader:
            y_resid_true  = y_resid_true.squeeze(0).to(device)    # (S, H, V)
            baseline_full = baseline_full.squeeze(0).to(device)   # (S, H, V)

            S, H, _ = y_resid_true.shape
            if sum_sq_state is None:
                # Accumulators are sized lazily from the first batch.
                sum_sq_state = torch.zeros(S, device=device)
                count_state  = torch.zeros(S, device=device)

            # Raw per-variant counts
            y_true_raw = (y_resid_true + baseline_full) * case_scale  # (S, H, V)
            y_hat_raw  = baseline_full * case_scale                   # (S, H, V)

            y_true_total = y_true_raw.sum(dim=2)  # (S, H)
            y_hat_total  = y_hat_raw.sum(dim=2)   # (S, H)

            se = (y_hat_total - y_true_total) ** 2   # (S, H)

            sum_sq_state += se.sum(dim=1)         # sum over horizon -> (S,)
            count_state  += H                     # H terms per state per batch

    if sum_sq_state is None:
        raise ValueError("loader yielded no batches; per-state RMSE is undefined")

    # clamp guards the H == 0 case against a 0/0 division.
    rmse_state = torch.sqrt(sum_sq_state / count_state.clamp(min=1)).cpu().numpy()
    return rmse_state

def rmse_by_state_total_last_from_residual(loader, device, target_variant_indices, case_scale=1.0):
    """
    Per-state RMSE (raw units) for the last-value baseline on TOTAL daily cases.

    Baseline: for each state, total forecast = sum of last observed per-variant
    daily counts in the input window (columns `target_variant_indices` of the
    final time step of X_state_seq, times `case_scale`), repeated across all
    H forecast days. The truth is (y_resid_true + baseline_full) * case_scale
    summed over the variant axis (dim 2).

    Args:
        loader: iterable of (X_state_seq, X_county_seq, y_resid_true,
            baseline_full) tuples; X_county_seq is ignored. After
            `squeeze(0)`, X_state_seq must be 3-D (T_in, S, F_macro) and
            the other two 3-D (S, H, V) (presumably batch_size=1 -- confirm
            with caller).
        device: device the tensors and the index tensor are placed on.
        target_variant_indices: feature-column indices of the variant
            counts inside X_state_seq.
        case_scale: multiplier converting model units to raw counts.

    Returns:
        np.ndarray of shape (S,), ordered like the state dimension of
        y_resid_true.

    Raises:
        ValueError: if `loader` yields no batches, or if X_state_seq is not
            3-D with the same number of states as y_resid_true.
    """
    target_variant_indices = torch.as_tensor(
        target_variant_indices, dtype=torch.long, device=device
    )

    sum_sq_state = None
    count_state = None

    with torch.no_grad():
        for X_state_seq, _X_county_seq, y_resid_true, baseline_full in loader:
            X_state_seq   = X_state_seq.squeeze(0).to(device)     # (T_in, S, F_macro)
            y_resid_true  = y_resid_true.squeeze(0).to(device)    # (S, H, V)
            baseline_full = baseline_full.squeeze(0).to(device)   # (S, H, V)

            S, H, _ = y_resid_true.shape
            if sum_sq_state is None:
                # Accumulators are sized lazily from the first batch.
                sum_sq_state = torch.zeros(S, device=device)
                count_state  = torch.zeros(S, device=device)

            # Explicit check instead of `assert`: asserts vanish under -O,
            # and a singleton state axis would then be silently broadcast
            # to every state by expand() below.
            if X_state_seq.dim() != 3 or X_state_seq.shape[1] != S:
                raise ValueError(
                    f"X_state_seq has shape {tuple(X_state_seq.shape)}; expected "
                    f"(T_in, {S}, F_macro) to match y_resid_true"
                )

            # True totals
            y_true_raw = (y_resid_true + baseline_full) * case_scale  # (S, H, V)
            y_true_total = y_true_raw.sum(dim=2)                      # (S, H)

            # Last observed per-variant daily counts
            last_state = X_state_seq[-1]                             # (S, F_macro)
            last_targets = last_state[:, target_variant_indices]     # (S, V)
            last_targets_raw = last_targets * case_scale             # (S, V)

            # Total last-value per state
            last_total = last_targets_raw.sum(dim=1)                 # (S,)

            # Repeat across horizon
            y_hat_total = last_total.unsqueeze(1).expand(S, H)       # (S, H)

            se = (y_hat_total - y_true_total) ** 2                   # (S, H)

            sum_sq_state += se.sum(dim=1)         # sum over horizon -> (S,)
            count_state  += H                     # H terms per state per batch

    if sum_sq_state is None:
        raise ValueError("loader yielded no batches; per-state RMSE is undefined")

    # clamp guards the H == 0 case against a 0/0 division.
    rmse_state = torch.sqrt(sum_sq_state / count_state.clamp(min=1)).cpu().numpy()
    return rmse_state


In [45]:
# RMSE report on val_loader: one line per forecast day, then the pooled
# ("global") figures. case_scale is forwarded to every helper.
case_scale = 1000

rmse_model_day = rmse_by_day_model(
    model, val_loader, device, case_scale=case_scale
)
rmse_ma_day = rmse_by_day_ma_baseline(
    val_loader, device, case_scale=case_scale
)
rmse_last_day = rmse_by_day_last_value_baseline_from_residual(
    val_loader, device, target_variant_indices, case_scale=case_scale
)

# The model's result length drives the loop; the baselines are indexed
# with the same day index.
for d in range(len(rmse_model_day)):
    day_line = (
        f"Day {d+1}: model={rmse_model_day[d]:7.2f}, "
        f"MA={rmse_ma_day[d]:7.2f}, "
        f"last={rmse_last_day[d]:7.2f}"
    )
    print(day_line)

rmse_model_global = overall_rmse_model_from_residual(
    model, val_loader, device, case_scale
)
rmse_ma_global = overall_rmse_ma_baseline_from_residual(
    val_loader, device, case_scale
)
rmse_last_global = overall_rmse_last_value_baseline_from_residual(
    val_loader, device, target_variant_indices, case_scale
)

for metric_label, metric_value in (
    ("Global RMSE (model):      ", rmse_model_global),
    ("Global RMSE (MA baseline):", rmse_ma_global),
    ("Global RMSE (last value): ", rmse_last_global),
):
    print(metric_label, metric_value)


Day 1: model= 723.30, MA=1024.32, last=1528.55
Day 2: model= 752.84, MA=1015.24, last=1560.28
Day 3: model= 714.74, MA=1014.22, last=1492.44
Day 4: model= 762.93, MA=1015.57, last=1495.43
Day 5: model= 715.68, MA=1018.60, last=1559.44
Day 6: model= 727.88, MA=1007.76, last=1503.08
Day 7: model= 694.57, MA=1000.90, last= 646.39
Global RMSE (model):       727.74196639087
Global RMSE (MA baseline): 1013.8242516281312
Global RMSE (last value):  1431.4600919432226


In [55]:
# Thresholded MAPE on val_loader for the model and both baselines;
# min_cases is passed to each helper as the threshold argument.
min_cases = 100

mape_model = overall_mape_model_from_residual(
    model, val_loader, device, case_scale, min_cases
)
mape_ma = overall_mape_ma_baseline_from_residual(
    val_loader, device, case_scale, min_cases
)
mape_last = overall_mape_last_value_baseline_from_residual(
    val_loader, device, target_variant_indices, case_scale, min_cases
)

for metric_label, metric_value in (
    (f"MAPE Thresholded at {min_cases} (model):      ", mape_model),
    (f"MAPE Thresholded at {min_cases} (MA baseline):", mape_ma),
    (f"MAPE Thresholded at {min_cases} (last value): ", mape_last),
):
    print(metric_label, metric_value)


MAPE Thresholded at 100 (model):       79.32596404106931
MAPE Thresholded at 100 (MA baseline): 76.62154538740546
MAPE Thresholded at 100 (last value):  105.43070804415895


In [59]:
# Pooled RMSE on TOTAL cases (variants summed) for the model and both
# baselines, evaluated on val_loader.
rmse_model_total = overall_rmse_model_total_from_residual(
    model, val_loader, device, case_scale
)
rmse_ma_total = overall_rmse_ma_total_from_residual(
    val_loader, device, case_scale
)
rmse_last_total = overall_rmse_last_total_from_residual(
    val_loader, device, target_variant_indices, case_scale
)

for metric_label, metric_value in (
    ("Global RMSE on TOTAL cases (model):      ", rmse_model_total),
    ("Global RMSE on TOTAL cases (MA baseline):", rmse_ma_total),
    ("Global RMSE on TOTAL cases (last value): ", rmse_last_total),
):
    print(metric_label, metric_value)


Global RMSE on TOTAL cases (model):       2123.238342545057
Global RMSE on TOTAL cases (MA baseline): 2951.7399352207817
Global RMSE on TOTAL cases (last value):  4158.839079426377


In [61]:
# Per-state RMSE on TOTAL cases for the model and both baselines.
rmse_model_state_total = rmse_by_state_total_model_from_residual(
    model, val_loader, device, case_scale=case_scale
)
rmse_ma_state_total = rmse_by_state_total_ma_from_residual(
    val_loader, device, case_scale=case_scale
)
rmse_last_state_total = rmse_by_state_total_last_from_residual(
    val_loader, device, target_variant_indices, case_scale=case_scale
)

# state_fips_list is expected to list the state FIPS codes in the same
# order as the state axis of the arrays above.
per_state_rows = zip(
    state_fips_list,
    rmse_model_state_total,
    rmse_ma_state_total,
    rmse_last_state_total,
)
for fips, r_model, r_ma, r_last in per_state_rows:
    print(f"State {fips}: model={r_model:.1f}, MA={r_ma:.1f}, last={r_last:.1f}")


State 01: model=2085.4, MA=2222.8, last=3158.7
State 04: model=1551.5, MA=3124.0, last=4442.6
State 05: model=316.0, MA=255.3, last=340.7
State 06: model=8635.8, MA=11589.1, last=16446.3
State 08: model=720.4, MA=702.9, last=976.8
State 09: model=505.2, MA=452.3, last=632.0
State 10: model=247.6, MA=209.7, last=287.8
State 11: model=202.7, MA=146.0, last=207.3
State 12: model=6588.2, MA=6222.4, last=8396.5
State 13: model=1977.9, MA=4042.3, last=5796.1
State 16: model=362.3, MA=359.2, last=508.7
State 17: model=1713.0, MA=2350.3, last=3317.7
State 18: model=1517.5, MA=2194.6, last=3105.0
State 19: model=689.5, MA=779.4, last=1101.9
State 20: model=789.2, MA=1263.7, last=1799.7
State 21: model=2105.0, MA=2960.9, last=4247.0
State 22: model=2013.2, MA=2345.5, last=3217.2
State 23: model=157.5, MA=161.3, last=224.1
State 24: model=704.8, MA=766.1, last=1072.4
State 25: model=1417.5, MA=3004.4, last=4232.7
State 26: model=2587.6, MA=3749.6, last=5260.8
State 27: model=1343.8, MA=2338.2, la

In [66]:
# Fraction of states where the model's total-case RMSE is strictly lower
# than (both baselines, the MA baseline, the last-value baseline).
beats_ma = sum(
    1
    for m, ma, _last in zip(rmse_model_state_total, rmse_ma_state_total, rmse_last_state_total)
    if m < ma
)
beats_last = sum(
    1
    for m, _ma, last in zip(rmse_model_state_total, rmse_ma_state_total, rmse_last_state_total)
    if m < last
)
beats_all = sum(
    1
    for m, ma, last in zip(rmse_model_state_total, rmse_ma_state_total, rmse_last_state_total)
    if (m < ma) and (m < last)
)
beats_all/len(rmse_model_state_total), beats_ma/len(rmse_model_state_total), beats_last/len(rmse_model_state_total)

(0.7346938775510204, 0.7346938775510204, 1.0)

#### Global RMSE results
GRU+GCN_Macro = 724, GRU = 807, GRU+GCN_Macro/Micro = 727

## Eval per State (Rolling)

In [46]:
# def evaluate_per_state_rmse(model, loader, device, case_scale=1.0):
#     """
#     Returns: rmse_per_state: np.ndarray of shape (S,)
#     RMSE is computed over *all* windows, horizons, and output variants.
#     If you trained on scaled cases, pass case_scale to get RMSE in original units.
#     """
#     model.eval()
#     sum_sq = None   # will become tensor of shape (S,)
#     count = 0       # total number of (horizon * v_out * num_windows)
# 
#     with torch.no_grad():
#         for X_state_seq, X_county_seq, y_true in loader:
#             # remove batch dimension (batch_size=1)
#             X_state_seq = X_state_seq.squeeze(0).to(device)    # (T_in, S, F_macro)
#             X_county_seq = X_county_seq.squeeze(0).to(device)  # (T_in, M, F_micro)
#             y_true = y_true.squeeze(0).to(device)              # (S, H, v_out)
# 
#             y_pred = model(X_state_seq, X_county_seq)          # (S, H, v_out)
# 
#             # Undo scaling if needed so RMSE is in original case units
#             diff = (y_pred - y_true) * case_scale              # (S, H, v_out)
# 
#             se = diff ** 2                                     # (S, H, v_out)
#             # sum over horizon + variant dimensions -> per-state sum squared error
#             se_per_state = se.sum(dim=(1, 2))                  # (S,)
# 
#             if sum_sq is None:
#                 sum_sq = se_per_state
#             else:
#                 sum_sq += se_per_state
# 
#             # each window contributes H * v_out elements per state
#             H = y_pred.shape[1]
#             V = y_pred.shape[2]
#             count += H * V
# 
#     mse_per_state = sum_sq / count                             # (S,)
#     rmse_per_state = torch.sqrt(mse_per_state)                 # (S,)
# 
#     return rmse_per_state.cpu().numpy()
# 
# def evaluate_last_value_baseline(loader, device, target_variant_indices, case_scale=1.0):
#     """
#     Baseline: y_pred(s, h, v) = last observed value in the input window for that state & variant.
#     Returns: rmse_per_state: np.ndarray of shape (S,)
#     """
#     sum_sq = None
#     count = 0
#     target_variant_indices = torch.as_tensor(target_variant_indices, dtype=torch.long, device=device)
# 
#     with torch.no_grad():
#         for X_state_seq, X_county_seq, y_true in loader:
#             # squeeze batch dim (batch_size = 1)
#             X_state_seq = X_state_seq.squeeze(0).to(device)    # (T_in, S, F_macro)
#             y_true = y_true.squeeze(0).to(device)              # (S, H, v_out)
# 
#             T_in, S, F_macro = X_state_seq.shape
#             H = y_true.shape[1]
#             V = y_true.shape[2]
# 
#             # last time step: (S, F_macro)
#             last_state = X_state_seq[-1]                       # (S, F_macro)
# 
#             # pick only target variants: (S, v_out)
#             last_targets = last_state[:, target_variant_indices]   # (S, v_out)
# 
#             # repeat across horizon: (S, H, v_out)
#             y_pred = last_targets.unsqueeze(1).expand(S, H, V)
# 
#             # undo scaling if needed
#             diff = (y_pred - y_true) * case_scale              # (S, H, v_out)
#             se = diff ** 2
#             se_per_state = se.sum(dim=(1, 2))                  # (S,)
# 
#             if sum_sq is None:
#                 sum_sq = se_per_state
#             else:
#                 sum_sq += se_per_state
# 
#             count += H * V
# 
#     mse_per_state = sum_sq / count
#     rmse_per_state = torch.sqrt(mse_per_state)
#     return rmse_per_state.cpu().numpy()
# 
# def evaluate_moving_average_baseline(loader, device, target_variant_indices, case_scale=1.0, K=7):
#     """
#     Baseline: y_pred(s, h, v) = mean of last K days in the input window (per state & variant).
#     Returns: rmse_per_state: np.ndarray of shape (S,)
#     """
#     sum_sq = None
#     count = 0
#     target_variant_indices = torch.as_tensor(target_variant_indices, dtype=torch.long, device=device)
# 
#     with torch.no_grad():
#         for X_state_seq, X_county_seq, y_true in loader:
#             X_state_seq = X_state_seq.squeeze(0).to(device)    # (T_in, S, F_macro)
#             y_true = y_true.squeeze(0).to(device)              # (S, H, v_out)
# 
#             T_in, S, F_macro = X_state_seq.shape
#             H = y_true.shape[1]
#             V = y_true.shape[2]
# 
#             k = min(K, T_in)
#             # last k days: (k, S, F_macro)
#             last_k = X_state_seq[-k:]                          # (k, S, F_macro)
# 
#             # mean over time: (S, F_macro)
#             mean_state = last_k.mean(dim=0)
# 
#             # pick only target variants: (S, v_out)
#             mean_targets = mean_state[:, target_variant_indices]  # (S, v_out)
# 
#             # repeat across horizon: (S, H, v_out)
#             y_pred = mean_targets.unsqueeze(1).expand(S, H, V)
# 
#             diff = (y_pred - y_true) * case_scale              # (S, H, v_out)
#             se = diff ** 2
#             se_per_state = se.sum(dim=(1, 2))                  # (S,)
# 
#             if sum_sq is None:
#                 sum_sq = se_per_state
#             else:
#                 sum_sq += se_per_state
# 
#             count += H * V
# 
#     mse_per_state = sum_sq / count
#     rmse_per_state = torch.sqrt(mse_per_state)
#     return rmse_per_state.cpu().numpy()
# 
# def rmse_by_horizon_model(model, loader, device, case_scale=1.0):
#     model.eval()
#     sum_sq = None  # (H,)
#     count = 0
#     with torch.no_grad():
#         for X_state_seq, X_county_seq, y_true in loader:
#             X_state_seq = X_state_seq.squeeze(0).to(device)   # (T_in, S, F_macro)
#             y_true = y_true.squeeze(0).to(device)             # (S, H, V)
#             y_pred = model(X_state_seq, X_county_seq.squeeze(0).to(device)
#                            if X_county_seq is not None else None)
# 
#             diff = (y_pred - y_true) * case_scale             # (S, H, V)
#             se = diff ** 2                                    # (S, H, V)
#             se_h = se.sum(dim=(0, 2))                         # (H,)
# 
#             if sum_sq is None:
#                 sum_sq = se_h
#             else:
#                 sum_sq += se_h
# 
#             S = y_true.shape[0]
#             V = y_true.shape[2]
#             count += S * V
# 
#     mse_h = sum_sq / count                                    # (H,)
#     return torch.sqrt(mse_h).cpu().numpy()                    # (H,)
# 
# def rmse_by_horizon_last(loader, device, target_variant_indices, case_scale=1.0, K=7):
#     target_variant_indices = torch.as_tensor(target_variant_indices, dtype=torch.long, device=device)
#     sum_sq = None
#     count = 0
#     with torch.no_grad():
#         for X_state_seq, X_county_seq, y_true in loader:
#             X_state_seq = X_state_seq.squeeze(0).to(device)   # (T_in, S, F_macro)
#             y_true = y_true.squeeze(0).to(device)             # (S, H, V)
# 
#             T_in, S, F_macro = X_state_seq.shape
#             H = y_true.shape[1]
#             V = y_true.shape[2]
# 
#             last_state = X_state_seq[-1]                      # (S, F_macro)
#             last_targets = last_state[:, target_variant_indices]  # (S, V)
#             y_pred = last_targets.unsqueeze(1).expand(S, H, V)
# 
#             diff = (y_pred - y_true) * case_scale
#             se = diff ** 2
#             se_h = se.sum(dim=(0, 2))                         # (H,)
# 
#             if sum_sq is None:
#                 sum_sq = se_h
#             else:
#                 sum_sq += se_h
# 
#             count += S * V
# 
#     mse_h = sum_sq / count
#     return torch.sqrt(mse_h).cpu().numpy()


In [47]:
# case_scale = 1000.0  # or 1.0 if you didn't scale cases
# device = next(model.parameters()).device
# 
# # Model RMSE
# rmse_model = evaluate_per_state_rmse(model, val_loader, device, case_scale=case_scale)
# 
# # Baselines
# rmse_last = evaluate_last_value_baseline(val_loader, device, target_variant_indices, case_scale=case_scale)
# rmse_ma   = evaluate_moving_average_baseline(val_loader, device, target_variant_indices, case_scale=case_scale, K=7)
# 
# print("Overall mean RMSE (model):      ", rmse_model.mean())
# print("Overall mean RMSE (last value):", rmse_last.mean())
# print("Overall mean RMSE (7-day MA):  ", rmse_ma.mean())
# 
# print("\nPer-state comparison (first few):")
# for s_idx, s_fips in enumerate(state_fips_list[:10]):  # or all of them
#     print(
#         f"State {s_fips}: "
#         f"model={rmse_model[s_idx]:8.1f}, "
#         f"last={rmse_last[s_idx]:8.1f}, "
#         f"ma={rmse_ma[s_idx]:8.1f}"
#     )


In [48]:
# rmse_h_model = rmse_by_horizon_model(model, val_loader, device, case_scale)
# rmse_h_last  = rmse_by_horizon_last(val_loader, device, target_variant_indices, case_scale)
# 
# for d in range(len(rmse_h_model)):
#     print(f"Day {d+1}: model={rmse_h_model[d]:.2f}, last={rmse_h_last[d]:.2f}")
