In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import timedelta, datetime
import torch

from utils import *
from models import *

In [None]:
# --- Reference week for this forecasting round ---------------------------
epiyear = 2024
epiweek = 46
# Saturday at the end of epiweek
ref_date = epiweek_to_dates(epiyear, epiweek).enddate()

# --- Input / output folders (relative paths) -----------------------------
data_dir = '../../data2025/'
results_dir = '../../results2025/'
figures_dir = '../../figures2025/'

# --- Location table -------------------------------------------------------
# Build the full-name -> abbreviation lookup *before* the frame is re-indexed
# by abbreviation, because 'abbreviation' stops being a regular column then.
locations_fname = f"{data_dir}locations.csv"
locations = pd.read_csv(locations_fname)
loc_name2abbr = {
    name: abbr
    for name, abbr in zip(locations['location_name'], locations['abbreviation'])
}
locations = locations.set_index('abbreviation')

# The set of modelled locations is whatever the locations file lists.
states = locations.index.values
# Alternative kept for reference: hard-coded list of locations.
# states = np.array(["AK", "AL", "AR", "AZ", "CA", "CO", "CT", "DE", "DC", "FL", "GA", "HI", "ID", "IL", "IN", "IA", "KS", "KY", "LA", "ME", 
#           "MD", "MA", "MI", "MN", "MS", "MO", "MT", "NE", "NV", "NH", "NJ", "NM", "NY", "NC", "ND", "OH", "OK", "OR", "PA", "RI", 
#           "SC", "SD", "TN", "TX", "UT", "VT", "VA", "WA", "WV", "WI", "WY", "PR", "US"])
num_states = len(states)


In [None]:
# Load the hospitalization frame, then the ILI frame (the ILI reader also
# receives df_hosp and is called with scale=True).
new_format = True
download_hosp = False
df_hosp = read_hosp_incidence_data(
    data_dir, epiyear, epiweek, states,
    new_format=new_format, download=download_hosp, plot=False,
)

df_ili = read_ili_incidence_data(
    data_dir, epiyear, epiweek, states, df_hosp,
    smooth=False, scale=True, plot=False,
)

# Optional diagnostic: one stacked panel per state comparing both series.
plot_hosp_with_ili = False
if plot_hosp_with_ili:
    fig, axs = plt.subplots(nrows=num_states, ncols=1, figsize=(7, 2 * num_states), sharex=True)
    # plt.subplots returns a bare Axes (not an array) when there is one panel.
    axes_list = list(axs) if num_states > 1 else [axs]
    for ax, state in zip(axes_list, states):
        ax.plot(df_hosp['date'], df_hosp[state], label='Hospitalizations', color='blue')
        ax.plot(df_ili['date'], df_ili[state], label='Scaled ILI', color='red')
        ax.set_title(f"{state} - Hospitalizations vs. ILI (scaled)")
        ax.set_ylabel("Cases")
        ax.legend()
    plt.xlabel("Date")
    plt.tight_layout()
    plt.show()

In [None]:
# Date at which the combined series switches from ILI rows to
# hospitalization rows: the end date of epiweek 26 of 2022.
switch_epiyear = 2022
switch_epiweek = 26
switch_date = pd.to_datetime(epiweek_to_dates(switch_epiyear, switch_epiweek).enddate())

# Strictly before the switch date -> ILI; on/after the switch date -> hospitalizations.
df1 = df_ili[df_ili.date < switch_date]
df2 = df_hosp[(df_hosp.date >= switch_date)]

# NOTE(review): reset_index() is called without drop=True, so the previous
# index is kept as an extra leading column; the positional columns[3:] slice
# used for the melt below depends on that layout -- confirm this is intended.
df_hosp_ex = pd.concat([df1, df2]).reset_index()

# Long format: one row per (date, state) with the value in 'cases'.
df_hosp_ex_long = pd.melt(
    df_hosp_ex,
    id_vars=['date', 'year', 'week'],
    value_vars=df_hosp_ex.columns[3:],
    var_name='state',
    value_name='cases',
)

# Optional diagnostic: small-multiples line plot of the combined series.
plot_hosp_ex = False
if plot_hosp_ex:
    g = sns.FacetGrid(df_hosp_ex_long, col="state", col_wrap=5, hue="state", sharey=False, sharex=True, height=3, aspect=1.33)
    g.map(sns.lineplot, "date", "cases")
    for ax in g.axes.flat:
        plt.setp(ax.get_xticklabels(), rotation=90)
    plt.subplots_adjust(hspace=0.4, wspace=0.4)

In [None]:
# --- Forecast settings ----------------------------------------------------
num_samples = 1000
alpha_vals = [0.02, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
# Quantile levels: 0.01 and 0.025, then 0.05..0.95 in steps of 0.05,
# then 0.975 and 0.99.
quantiles = np.concatenate((
    [0.01, 0.025],
    np.arange(0.05, 0.95 + 0.05, 0.050),
    [0.975, 0.99],
))
weeks_to_predict = 4

# Set past_weeks > 0 to also produce forecasts for earlier reference dates;
# the list runs from the oldest reference date up to ref_date itself.
past_weeks = 0
ref_dates = [
    pd.Timestamp(ref_date) - timedelta(weeks=weeks_back)
    for weeks_back in range(past_weeks, -1, -1)
]

# --- Reproducibility ------------------------------------------------------
# Seeds are set before the models are constructed.
np.random.seed(230098)
torch.manual_seed(42095703)
torch.set_float32_matmul_precision('medium')
model_list_local, model_list_global = get_forecasting_models(quantiles)

In [None]:
# Load the selected laboratory column; the result is used below as the
# past-covariate frame (a 'date' column plus value columns -- presumably one
# per state; confirm against read_lab_selected_data).
sel_lab_col = 'percent_positive'
df_lab = read_lab_selected_data(data_dir, epiyear, epiweek, states, loc_name2abbr, sel_lab_col, plot=False)

# Add missing past values: every date present in df_hosp_ex but absent from
# df_lab gets a row whose non-date columns are all NA, then re-sort by date.
missing_dates = df_hosp_ex[~df_hosp_ex['date'].isin(df_lab['date'])]
missing_dates_only = missing_dates[['date']].copy()
for col in df_lab.columns:
    if col != 'date':
        missing_dates_only[col] = pd.NA
df_lab = pd.concat([df_lab, missing_dates_only], ignore_index=True)
df_lab = df_lab.sort_values(by='date').reset_index(drop=True)

# Add some future values: for each of the next weeks, reuse the row observed
# the same number of weeks after the same epiweek one year earlier.
# Dates are derived by adding whole weeks to the reference week's end date
# instead of computing `epiweek + i`, which would run past the last epiweek
# of the year late in the season.
num_future_lab_weeks = 4
ref_week_end = pd.Timestamp(epiweek_to_dates(epiyear, epiweek).enddate())
ref_week_end_last_year = pd.Timestamp(epiweek_to_dates(epiyear - 1, epiweek).enddate())
for i in range(1, num_future_lab_weeks + 1):
    next_date = ref_week_end + timedelta(weeks=i)
    next_last_year = ref_week_end_last_year + timedelta(weeks=i)
    last_year_rows = df_lab[df_lab['date'] == next_last_year]
    if last_year_rows.empty:
        # Fail with a clear message rather than an IndexError on an empty array.
        raise ValueError(
            "No lab data row for {} (needed to fill future date {})".format(
                next_last_year.date(), next_date.date()))
    last_year_row = last_year_rows.iloc[0]
    # Build the new row by column name so it does not depend on 'date'
    # being the first column. The index is a RangeIndex at this point, so
    # len(df_lab) is the next free label.
    df_lab.loc[len(df_lab)] = [
        next_date if col == 'date' else last_year_row[col]
        for col in df_lab.columns
    ]

# Set to 0 values for states with missing lab for now
# (pd.NA was the considered alternative).
states_without_lab = ['RI', 'PR', 'DC', 'NJ', 'NH', 'NV', 'AK', 'DE', 'UT']
for state_abbr in states_without_lab:
    df_lab[state_abbr] = 0

In [None]:
# "Local" models: fitted through fit_and_predict_univariate with
# fit_single_model=False. One forecast file is written per
# (model, reference date).
for model_desc, model in model_list_local.items():
    print("-----------Model: {}-----------".format(model_desc))
    for ref_date1 in ref_dates:
        pred_start_date = ref_date1
        # Start every reference date from a fresh, untrained copy.
        model = model.untrained_model()
        pred = fit_and_predict_univariate(
            df_hosp_ex, states, model, model_desc,
            pred_start_date, weeks_to_predict, num_samples,
            fit_single_model=False,  # fit one model to all states or not
            df_past_covar=df_lab, df_future_covar=None,
        )
        # Row of observations dated one week before the reference date,
        # handed to the writer together with the predictions.
        week_before_ref = ref_date1 - timedelta(weeks=1)
        dat_hosp_changerate_ref = df_hosp_ex[df_hosp_ex.date == week_before_ref]
        save_pred_results_to_file(pred, ref_date1, model_desc, locations, quantiles, dat_hosp_changerate_ref, results_dir)

# "Global" models: same procedure, fitted through fit_and_predict_multivariate.
for model_desc, model in model_list_global.items():
    print("-----------Model: {}-----------".format(model_desc))
    for ref_date1 in ref_dates:
        pred_start_date = ref_date1
        model = model.untrained_model()
        pred = fit_and_predict_multivariate(
            df_hosp_ex, states, model, model_desc,
            pred_start_date, weeks_to_predict, num_samples,
            df_past_covar=df_lab, df_future_covar=None,
        )
        week_before_ref = ref_date1 - timedelta(weeks=1)
        dat_hosp_changerate_ref = df_hosp_ex[df_hosp_ex.date == week_before_ref]
        save_pred_results_to_file(pred, ref_date1, model_desc, locations, quantiles, dat_hosp_changerate_ref, results_dir)

In [None]:
def calc_and_plot_models(models, plot=True):
    """Collect per-location metrics for every model into one DataFrame.

    For each model name, the saved prediction files are loaded with
    ``load_pred_result_files`` and ``calc_and_plot_pred_results_fit`` is
    called once per location in ``locations.index``; the frames it returns
    are stacked row-wise.

    Reads these notebook-level globals: ``results_dir``, ``locations``,
    ``df_hosp_ex``, ``alpha_vals`` and ``figures_dir``.

    Parameters
    ----------
    models : iterable of str
        Model names to process.
    plot : bool, default True
        Forwarded to ``calc_and_plot_pred_results_fit``.

    Returns
    -------
    pandas.DataFrame
        Concatenation (fresh 0..n-1 index) of all per-location metric frames,
        led by the columns ``location``, ``model``, ``horizon``,
        ``target_date``, ``metric`` and ``value``. Empty when ``models`` is
        empty.
    """
    # A typed, empty template goes first so the result always carries the
    # expected columns (and their order), even when no model is processed.
    metric_frames = [pd.DataFrame({
        'location': pd.Series(dtype='str'),
        'model': pd.Series(dtype='str'),
        'horizon': pd.Series(dtype='int'),
        'target_date': pd.Series(dtype='str'),
        'metric': pd.Series(dtype='str'),
        'value': pd.Series(dtype='float')
    })]
    for model in models:
        print("-----------Loading model: {}-----------".format(model))
        df_results = load_pred_result_files(results_dir, model, locations)
        print("-----------Calculating metrics for model: {}-----------".format(model))
        for loc_abbr in locations.index:
            df_metrics = calc_and_plot_pred_results_fit(df_results, df_hosp_ex, locations, loc_abbr,
                                                        alpha_vals, figures_dir, model, plot=plot)
            # Collect now and concatenate once at the end: calling pd.concat
            # inside the loop re-copies the accumulated frame every time.
            metric_frames.append(df_metrics)
    return pd.concat(metric_frames, ignore_index=True)

In [None]:
# Ensemble members: every local and global model name plus 'SEIRS'.
# NOTE(review): 'SEIRS' is not in either model list, so its result files
# presumably come from elsewhere and must already exist in results_dir.
ensemble_models = [*model_list_local, *model_list_global, 'SEIRS']

# load models predictions, keyed by model name
ensemble_dict = {}
for model in ensemble_models:
    print("-----------Loading model: {}-----------".format(model))
    df_results = load_pred_result_files(results_dir, model, locations)
    ensemble_dict[model] = df_results

# Metric and window passed to the weight generator.
weights_metric = 'wis'
weights_win = 4
df_metrics = calc_and_plot_models(ensemble_models, plot=False)
df_weights = generate_pred_weights(
    ensemble_dict, df_metrics, weights_metric, weights_win, locations,
)

# Write the weighted ensemble under its own model name.
wis_ensemble_name = 'CU_Ensemble'
print("-----------Generating weighted ensemble: {}-----------".format(wis_ensemble_name))
generate_weighted_pred_results(
    ensemble_dict, df_weights, locations, results_dir, wis_ensemble_name,
)

In [None]:
# Historical statistics from the long-format combined series, computed for the
# given states and quantile levels (project helper; output schema not visible here).
df_his_stats = calculate_historical_stats(df_hosp_ex_long,states,quantiles)

# Hard-coded end date handed to the peak-week predictor together with ref_date.
# NOTE(review): presumably the end of the current season -- update when the season changes.
last_date = pd.to_datetime('2025-05-31', format="%Y-%m-%d")
# 'kde_bandwith' is the keyword name as spelled in the call to the project helper; keep as is.
df_pred_peak = generate_peak_week_pred(df_his_stats, locations, ref_date, last_date,
                                       kde_bandwith=0.5, plot_peak_week_prob=False)

# Save the peak-week prediction under the ensemble's model name for this reference date.
save_pred_peak_with_model_pred(df_pred_peak, results_dir, wis_ensemble_name, locations, ref_date)