In [None]:
import os
from datetime import timedelta

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch

from utils import *
from darts_models import *
from historical_mean_model import *
from historical_drift_model import *
from SIR_EAKF_model import *
from SIR_AH_EAKF_model import *
from peak_predictions import *

In [None]:
# Forecast reference date: the Saturday at the end of the chosen epiweek.
epiyear, epiweek = 2024, 50
ref_date = epiweek_to_dates(epiyear, epiweek).enddate()

# Folder paths keep their trailing '/', because file names are appended directly.
data_dir, results_dir, figures_dir = './data/', './results/', './figures/'

# Location metadata: a name -> abbreviation map, then index the table by abbreviation.
# The list of modelled states is exactly the set of abbreviations in locations.csv.
locations_fname = f"{data_dir}locations.csv"
locations = pd.read_csv(locations_fname)
loc_name2abbr = dict(zip(locations['location_name'], locations['abbreviation']))
locations = locations.set_index('abbreviation')
states = locations.index.values
num_states = len(states)

populations_fname = f"{data_dir}populations.csv"
populations = pd.read_csv(populations_fname)

# AH data. The second return value (df_AH) is used below as the darts future covariate.
AH_daily, df_AH = read_AH(data_dir)

In [None]:
# Load hospital-admission and ILI incidence series for every location.
new_format = True
download_hosp = False
fix_partial_reporting = False
fix_outliers = False
df_hosp = read_hosp_incidence_data(
    data_dir, epiyear, epiweek, states,
    new_format=new_format,
    download=download_hosp,
    fix_partial_reporting=fix_partial_reporting,
    fix_outliers=fix_outliers,
    plot=False,
)

# ILI is read with scale=True and receives df_hosp, so it is presumably rescaled
# against hospitalizations (the plot below labels it "Scaled ILI").
df_ili = read_ili_incidence_data(
    data_dir, epiyear, epiweek, states, df_hosp,
    smooth=False, scale=True, plot=False,
)

# Optional diagnostic: one panel per location overlaying hospitalizations and scaled ILI.
plot_hosp_with_ili = False
if plot_hosp_with_ili:
    fig, axs = plt.subplots(nrows=num_states, ncols=1, figsize=(7, 2 * num_states), sharex=True)
    # With a single row, subplots returns a bare Axes; wrap it so the loop is uniform.
    panel_axes = np.atleast_1d(axs)
    for ax, state in zip(panel_axes, states):
        ax.plot(df_hosp['date'], df_hosp[state], label='Hospitalizations', color='blue')
        ax.plot(df_ili['date'], df_ili[state], label='Scaled ILI', color='red')
        ax.set_title(f"{state} - Hospitalizations vs. ILI (scaled)")
        ax.set_ylabel("Cases")
        ax.legend()
    # x-axis is shared, so only the bottom panel needs the label.
    panel_axes[-1].set_xlabel("Date")
    fig.tight_layout()
    plt.show()

In [None]:
# Build one extended series per location: scaled ILI before the switch date,
# reported hospitalizations from the switch date onward.
switch_epiyear = 2022
switch_epiweek = 26
switch_date = pd.to_datetime(epiweek_to_dates(switch_epiyear, switch_epiweek).enddate())

df_ili_pre_switch = df_ili[df_ili.date < switch_date]
df_hosp_post_switch = df_hosp[df_hosp.date >= switch_date]
df_hosp_ex = pd.concat([df_ili_pre_switch, df_hosp_post_switch]).reset_index()

hosp_ex_start_date = pd.to_datetime('2010-10-09', format="%Y-%m-%d")
# .copy() so the fillna assignment below writes to an independent frame rather than
# a slice of the concatenated one (avoids SettingWithCopyWarning / silent no-ops).
df_hosp_ex = df_hosp_ex[df_hosp_ex['date'] >= hosp_ex_start_date].copy()
df_hosp_ex[states] = df_hosp_ex[states].fillna(0)

# Long format: one row per (date, state).
# value_vars is the explicit state list. reset_index() above prepends an 'index'
# column, so a positional slice such as columns[3:] can pick up non-state columns.
df_hosp_ex_long = pd.melt(df_hosp_ex, id_vars=['date', 'year', 'week'], value_vars=list(states),
                          var_name='state', value_name='cases')

plot_hosp_ex = False
if plot_hosp_ex:
    g = sns.FacetGrid(df_hosp_ex_long, col="state", col_wrap=5, hue="state", sharey=False,
                      sharex=True, height=3, aspect=1.33)
    g.map(sns.lineplot, "date", "cases")
    for ax in g.axes.flat:
        plt.setp(ax.get_xticklabels(), rotation=90)
    plt.subplots_adjust(hspace=0.4, wspace=0.4)

In [None]:
# Forecast / sampling configuration.
num_samples = 1000
# Alpha levels passed to calc_and_plot_pred_results_fit when scoring forecasts.
alpha_vals = [0.02, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
# 23 quantile levels: 0.01, 0.025, 0.05..0.95 (step 0.05), 0.975, 0.99.
# np.arange accumulates float error (e.g. 0.15000000000000002). Rounding snaps each
# level to its intended decimal value, so exact lookups (e.g. the 0.5 median) and
# any quantile labels written to output are clean.
quantiles = np.round(
    np.concatenate([[0.01, 0.025], np.arange(0.05, 0.95 + 0.05, 0.05), [0.975, 0.99]]),
    3,
)
weeks_to_predict = 4

# Set past_weeks > 0 to also regenerate forecasts for earlier reference dates.
past_weeks = 0
ref_dates = [pd.Timestamp(ref_date) + timedelta(weeks=w) for w in range(-past_weeks, 1)]

# For reproducibility
np.random.seed(230098)
torch.manual_seed(42095703)
torch.set_float32_matmul_precision('medium')
darts_local_models, darts_global_models = get_darts_models(quantiles)

# Model descriptions used as labels when saving/loading the non-darts models' results.
sir_eakf_model = "SIR-EAKF"
his_drift_model = "historical_drift"
his_mean_model = "historical_mean"

In [None]:
# Generate SIR-EAKF forecasts for each reference date (the AH variant when use_AH is set).
use_AH = False
sir_eakf_start_date = pd.to_datetime('2024-09-14', format="%Y-%m-%d")
for ref_date1 in ref_dates:
    # Rows for the week before the reference date, passed on as the change-rate reference.
    prev_week = ref_date1 - timedelta(weeks=1)
    dat_changerate_ref = df_hosp_ex[df_hosp_ex.date == prev_week]
    if not use_AH:
        generate_sir_eakf_pred(df_hosp_ex, sir_eakf_start_date, ref_date1, weeks_to_predict,
                               locations, quantiles, num_samples,
                               dat_changerate_ref, results_dir, model_desc=sir_eakf_model)
        continue
    generate_sir_ah_eakf_pred(df_hosp_ex, sir_eakf_start_date, ref_date1, weeks_to_predict,
                              locations, AH_daily, quantiles, num_samples,
                              dat_changerate_ref, results_dir, model_desc=sir_eakf_model)

In [None]:
# Generate historical-drift forecasts for each reference date.
epiweek_window = 1
for ref_date1 in ref_dates:
    # Rows for the week before the reference date, passed on as the change-rate reference.
    week_before = ref_date1 - timedelta(weeks=1)
    dat_changerate_ref = df_hosp_ex[df_hosp_ex.date == week_before]
    generate_historical_drift_pred(
        df_hosp_ex, ref_date1, weeks_to_predict, locations, quantiles, num_samples,
        epiweek_window, dat_changerate_ref, results_dir,
        model_desc=his_drift_model,
    )

In [None]:
# Historical-mean model: build the predictions once, then save them for each reference date.
df_his_pred = generate_historical_mean_pred(df_hosp_ex, states, quantiles, plot=False)
for ref_date1 in ref_dates:
    df_his = save_historical_mean_pred_results_to_file(
        df_his_pred, ref_date1, weeks_to_predict, locations,
        quantiles, results_dir, model_desc=his_mean_model,
    )

# df_his_pred_mean = df_his_pred.pivot(index='date', columns='state', values='mean').reset_index()
# To use as covariate we need to continue this forward up to max horizon
# for i in range(1,5):
#     next = pd.Timestamp(epiweek_to_dates(epiyear, epiweek+i).enddate())
#     next_last_year = pd.Timestamp(epiweek_to_dates(epiyear-1, epiweek+i).enddate())
#     next_last_year_vals = df_his_pred_mean[df_his_pred_mean['date']==next_last_year].values
#     df_his_pred_mean.loc[len(df_his_pred_mean)] = [next] + next_last_year_vals[0,1:].tolist()
# df_his_pred_mean

In [None]:
# Lab surveillance covariate: one selected column ('percent_positive') per location.
sel_lab_col = 'percent_positive'
df_lab = read_lab_selected_data(data_dir, epiyear, epiweek, states, loc_name2abbr, sel_lab_col, plot=False)

# Add NA rows for hospitalization dates that have no lab row, so the covariate
# covers the same history as df_hosp_ex.
missing_dates_only = df_hosp_ex.loc[~df_hosp_ex['date'].isin(df_lab['date']), ['date']].copy()
for col in df_lab.columns:
    if col != 'date':
        missing_dates_only[col] = pd.NA
df_lab = pd.concat([df_lab, missing_dates_only], ignore_index=True)

# To use this as a covariate it must continue forward up to the max horizon:
# add NA rows for ref_date, ref_date + 1w, ... (weeks_to_predict dates in total).
# Dates already present are skipped so no duplicate timestamps are created.
# Rows are built by column name, so this does not assume df_lab is laid out as
# ['date'] + exactly one column per state.
existing_dates = set(df_lab['date'])
future_dates = [pd.Timestamp(ref_date) + timedelta(weeks=w) for w in range(weeks_to_predict)]
future_dates = [d for d in future_dates if d not in existing_dates]
if future_dates:
    future_rows = pd.DataFrame({col: [pd.NA] * len(future_dates) for col in df_lab.columns})
    future_rows['date'] = future_dates
    df_lab = pd.concat([df_lab, future_rows], ignore_index=True)
# Sort after all rows are added, so the covariate is chronological end to end.
df_lab = df_lab.sort_values(by='date').reset_index(drop=True)

# Set to 0 the values for states with missing lab data, for now (originally intended: pd.NA).
states_without_lab = ['RI', 'PR', 'DC', 'NJ', 'NH', 'NV', 'AK', 'DE', 'UT']
for abbr in states_without_lab:
    df_lab[abbr] = 0

In [None]:
# Run fit and predict for the darts models.
# Local models go through fit_and_predict_univariate; global models go through
# fit_and_predict_multivariate.

fit_single_local_model = False  # if True, fit one model to all states for local models

# Optional covariates: lab data as past covariates, AH data as future covariates.
use_lab_covar = True
use_ah_covar = True
df_past_covar = df_lab if use_lab_covar else None
df_future_covar = df_AH if use_ah_covar else None
# To distinguish covariate/no-covariate runs in the results, the model dicts can be
# re-keyed with a '_AH' / '_NoCov' suffix before this point.


def _fit_predict_and_save(models, fit_and_predict):
    """Fit every model in ``models`` for every reference date and save its predictions.

    Args:
        models: dict mapping model description -> darts model used as a template.
            The template itself is never fitted; each run uses template.untrained_model().
        fit_and_predict: callable(model, model_desc, pred_start_date) returning the
            prediction object passed to save_darts_pred_results_to_file.
    """
    for model_desc, template in models.items():
        print("-----------Model: {}-----------".format(model_desc))
        for ref_date1 in ref_dates:
            model = template.untrained_model()
            pred = fit_and_predict(model, model_desc, ref_date1)
            dat_changerate_ref = df_hosp_ex[df_hosp_ex.date == ref_date1 - timedelta(weeks=1)]
            save_darts_pred_results_to_file(pred, ref_date1, weeks_to_predict, locations, quantiles,
                                            dat_changerate_ref, results_dir, model_desc)


_fit_predict_and_save(
    darts_local_models,
    lambda model, model_desc, pred_start_date: fit_and_predict_univariate(
        df_hosp_ex, states, model, model_desc,
        pred_start_date, weeks_to_predict, num_samples,
        fit_single_local_model,
        df_past_covar, df_future_covar),
)

_fit_predict_and_save(
    darts_global_models,
    lambda model, model_desc, pred_start_date: fit_and_predict_multivariate(
        df_hosp_ex, states, model, model_desc,
        pred_start_date, weeks_to_predict, num_samples,
        df_past_covar, df_future_covar),
)

In [None]:
def calc_and_plot_models(models, plot=True):
    """Load each model's saved forecasts and compute per-location evaluation metrics.

    Reads the notebook globals ``results_dir``, ``locations``, ``df_hosp_ex``,
    ``alpha_vals`` and ``figures_dir``.

    Args:
        models: iterable of model descriptions whose result files are loaded with
            ``load_pred_result_files``.
        plot: forwarded to ``calc_and_plot_pred_results_fit``.

    Returns:
        DataFrame with (at least) the columns location, model, horizon, target_date,
        metric, value. It is empty, with that schema, when ``models`` is empty.
    """
    # Typed empty frame: fixes the column set/dtypes of the result even if nothing is loaded.
    df_metrics_schema = pd.DataFrame({
        'location': pd.Series(dtype='str'),
        'model': pd.Series(dtype='str'),
        'horizon': pd.Series(dtype='int'),
        'target_date': pd.Series(dtype='str'),
        'metric': pd.Series(dtype='str'),
        'value': pd.Series(dtype='float')
    })
    metric_frames = []
    for model in models:
        print("-----------Loading model: {}-----------".format(model))
        df_results = load_pred_result_files(results_dir, model, locations)
        # Only the weekly incident flu hospitalization target is scored.
        df_results = df_results[df_results['target'] == 'wk inc flu hosp']
        print("-----------Calculating metrics for model: {}-----------".format(model))
        for loc_abbr in locations.index:
            df_metrics = calc_and_plot_pred_results_fit(df_results, df_hosp_ex, locations, loc_abbr,
                                                        alpha_vals, figures_dir, model, plot=plot)
            metric_frames.append(df_metrics)
    # Concatenate once instead of inside the loop, which re-copied the growing frame each time.
    return pd.concat([df_metrics_schema, *metric_frames], ignore_index=True)

In [None]:
# Ensemble members: every darts model plus the SIR-EAKF and historical-drift models.
# The expression must be parenthesized: a line ending in a bare `+` is a SyntaxError.
ensemble_models = (
    list(darts_local_models.keys())
    + list(darts_global_models.keys())
    + [sir_eakf_model, his_drift_model]
)

ensemble_dict = {}
# load models predictions
for model in ensemble_models:
    print("-----------Loading model: {}-----------".format(model))
    df_results = load_pred_result_files(results_dir, model, locations)
    ensemble_dict[model] = df_results

ensemble_name = 'CU_Ensemble'
use_weighted_ensemble = True
if use_weighted_ensemble:
    # Weights are derived from each member's evaluation metrics.
    df_metrics = calc_and_plot_models(ensemble_models, plot=False)
    df_weights = generate_pred_weights(ensemble_dict, df_metrics, locations)
    print("-----------Generating weighted ensemble: {}-----------".format(ensemble_name))
    generate_weighted_pred_results(ensemble_dict, df_weights, locations, results_dir, ensemble_name)
else:
    generate_mean_ensemble_pred_results(ensemble_dict, results_dir, ensemble_name)

In [None]:
# Peak-week prediction from historical statistics and the current hospitalization data.
df_his_stats = calculate_historical_stats(df_hosp_ex_long, states, quantiles, populations)
first_date = pd.to_datetime('2024-11-23', format="%Y-%m-%d")
last_date = pd.to_datetime('2025-05-31', format="%Y-%m-%d")
min_peak_date = None  # alternative: pd.Timestamp(ref_date) + timedelta(weeks=1)
df_pred_peak = generate_peak_week_pred(df_his_stats, df_hosp, locations, ref_date,
                                       first_date, last_date, min_peak_date,
                                       kde_bandwith=0.333, plot_peak_week_prob=False)

# Join paths with os.path.join (results_dir already ends in '/'). Create the output
# folder first, because DataFrame.to_csv does not create missing directories.
peak_pred_dir = os.path.join(results_dir, '_peak_pred')
os.makedirs(peak_pred_dir, exist_ok=True)
peak_pred_fname = os.path.join(
    peak_pred_dir, "{}-CU-{}.csv".format(format(ref_date, '%Y-%m-%d'), 'peak_pred'))
df_pred_peak.to_csv(peak_pred_fname, index=False)