# Training notebook

Here we tune the parameters of the models. Of the 7 available seasons, we use the first 4 for training and hold out the remaining 3 for testing. There is no model fitting as such, so we simply evaluate performance over the 4 training seasons to see which parameters work best.

In [28]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
np.random.seed(1234)

import sys
sys.path.append("../src")
from utils import data as udata
from utils import dists as udists
from utils import misc as u
from truth import mask_truths
from predictors import make_predictor
import losses
from pymmwr import Epiweek
from tqdm import tqdm, trange
import ledge.merge as merge
import ledge.update as update
from functools import partial

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [29]:
EXP_DIR = "../data/processed/cdc-flusight-ensemble/"
OUTPUT_DIR = "../models/cdc-flusight-ensemble/"
TARGET = "1-ahead"
MAX_LAG = 29  # truths are fetched for every lag in 0..MAX_LAG (inclusive)

ALL_REGIONS = ["nat", *[f"hhs{i}" for i in range(1, 11)]]
# Evaluating only the national region is much faster; set FAST_MODE = False to
# tune over every region instead of silently overwriting the list.
FAST_MODE = True
REGIONS = ["nat"] if FAST_MODE else ALL_REGIONS

TRAINING_SEASONS = list(range(2010, 2014))  # first 4 of the 7 available seasons
LOSS_FN = losses.ploss

In [30]:
# One Component per model directory found under EXP_DIR, plus the observed data.
components = [
    udata.Component(EXP_DIR, model_name)
    for model_name in udata.available_models(EXP_DIR)
]
ad = udata.ActualData(EXP_DIR)

# Evaluation

In [31]:
def evaluate(predictor, loss_fn=losses.ploss, seasons=None, regions=None):
    """
    Evaluate ``predictor`` over every (season, region) pair and average its loss.

    For each pair, the truth series is fetched from ``ad`` at every lag in
    ``0..MAX_LAG`` and one forecast is taken from each entry of ``components``.
    The prediction ``predictor(truths, component_preds)`` is then scored twice:
    against ``merge.zero(truths)`` (reported as ``first_loss``) and against
    ``merge.latest(truths)`` (reported as ``final_loss``).

    Parameters
    ----------
    predictor : callable
        Called as ``predictor(truths, component_preds)``.
    loss_fn : callable, optional
        Called as ``loss_fn(pred, truth)``; the ``.mean()`` of its result is
        used as the per-pair loss. Defaults to ``losses.ploss``.
    seasons : iterable, optional
        Seasons to evaluate. Defaults to ``TRAINING_SEASONS`` (read at call time).
    regions : iterable, optional
        Regions to evaluate. Defaults to ``REGIONS`` (read at call time).

    Returns
    -------
    dict
        ``{"first_loss": ..., "final_loss": ...}``, each the mean over all
        (season, region) pairs of the per-pair mean loss.

    Raises
    ------
    ValueError
        If there is no (season, region) pair to evaluate.
    """
    seasons = list(TRAINING_SEASONS if seasons is None else seasons)
    regions = list(REGIONS if regions is None else regions)
    if not seasons or not regions:
        # np.mean([]) would otherwise silently return nan (plus a RuntimeWarning)
        raise ValueError("evaluate() needs at least one season and one region")

    first_losses = []
    final_losses = []

    with tqdm(total=len(seasons) * len(regions)) as pbar:
        for season in seasons:
            for region in regions:
                # One truth snapshot per lag, 0..MAX_LAG inclusive
                truths = [ad.get(TARGET, region, season, lag=lag) for lag in range(MAX_LAG + 1)]
                component_preds = [component.get(TARGET, region, season) for component in components]

                first_truth = merge.zero(truths)
                final_truth = merge.latest(truths)
                pred = predictor(truths, component_preds)
                first_losses.append(float(loss_fn(pred, first_truth).mean()))
                final_losses.append(float(loss_fn(pred, final_truth).mean()))
                pbar.update()

    return {
        "first_loss": np.mean(first_losses),
        "final_loss": np.mean(final_losses),
    }

In [26]:
# Baseline predictor: follow-the-leader weight updates with the `merge.zero` strategy.
merge_fn, update_fn = merge.zero, update.ftl
predictor = make_predictor(LOSS_FN, merge_fn, update_fn)

# Tuning

In [34]:
# Merge strategies
# Merge strategies to sweep: (label, function) pairs, label == attribute name in `merge`.
merge_st = [(name, getattr(merge, name)) for name in ("zero", "latest")]

## Follow the leader

In [35]:
# Follow the leader: one evaluation per merge strategy.
# The result is bound to `scores`, not `losses`: rebinding `losses` would shadow the
# imported `losses` module and break re-running cells that use `losses.ploss`.
for merge_id, merge_fn in merge_st:
    scores = evaluate(make_predictor(LOSS_FN, merge_fn, update.ftl), LOSS_FN)
    print(f"{merge_id}, {scores['first_loss'], scores['final_loss']}")

100%|██████████| 4/4 [00:51<00:00, 13.00s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

zero, (0.8572893443772136, 0.9007682099397212)


100%|██████████| 4/4 [00:20<00:00,  5.08s/it]

latest, (0.8558966113027653, 0.8934405764400863)





## Multiplicative weights (MW)

In [36]:
# Multiplicative-weights update: sweep the learning rate `eta` for each merge strategy.
# The result is bound to `scores`, not `losses`, so the imported `losses` module is not shadowed.
MW_ETAS = np.linspace(0.5, 1.0, 11)
for merge_id, merge_fn in merge_st:
    for eta in MW_ETAS:
        update_fn = partial(update.mw, eta=eta)
        scores = evaluate(make_predictor(LOSS_FN, merge_fn, update_fn), LOSS_FN)
        # :.2f hides linspace round-off such as 0.8500000000000001 in the printout
        print(f"{merge_id}, {eta:.2f}: {scores['first_loss'], scores['final_loss']}")

100%|██████████| 4/4 [00:19<00:00,  4.86s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

zero, 0.5: (0.8800236126573013, 0.8978411657567614)


100%|██████████| 4/4 [00:20<00:00,  5.14s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

zero, 0.55: (0.8777744189700727, 0.8970100292176808)


100%|██████████| 4/4 [00:19<00:00,  4.97s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

zero, 0.6: (0.8755608329764345, 0.896247947204854)


100%|██████████| 4/4 [00:20<00:00,  5.36s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

zero, 0.65: (0.8734250196673613, 0.8955600217249031)


100%|██████████| 4/4 [00:23<00:00,  5.78s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

zero, 0.7: (0.8714124461204742, 0.894950121159422)


100%|██████████| 4/4 [00:20<00:00,  5.08s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

zero, 0.75: (0.8695717976653177, 0.8944245182416244)


100%|██████████| 4/4 [00:20<00:00,  5.15s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

zero, 0.8: (0.8679648876150415, 0.8940058327929925)


100%|██████████| 4/4 [00:21<00:00,  5.28s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

zero, 0.8500000000000001: (0.8667059246596749, 0.8937735359512097)


100%|██████████| 4/4 [00:20<00:00,  4.99s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

zero, 0.9: (0.8661076418539102, 0.8939530424696256)


100%|██████████| 4/4 [00:21<00:00,  5.38s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

zero, 0.95: (0.8678538998560871, 0.8946198995365171)


100%|██████████| 4/4 [00:22<00:00,  5.67s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

zero, 1.0: (0.8714403289507634, 0.8943696669853193)


100%|██████████| 4/4 [00:20<00:00,  5.18s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

latest, 0.5: (0.8839089347666891, 0.8973592565817979)


100%|██████████| 4/4 [00:22<00:00,  5.66s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

latest, 0.55: (0.8819961697712663, 0.8963151587345148)


100%|██████████| 4/4 [00:25<00:00,  6.34s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

latest, 0.6: (0.8800134450318927, 0.8953043358961883)


100%|██████████| 4/4 [00:21<00:00,  5.50s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

latest, 0.65: (0.8779576307057846, 0.8943303403079039)


100%|██████████| 4/4 [00:21<00:00,  5.42s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

latest, 0.7: (0.8758292610184053, 0.8933921188079093)


100%|██████████| 4/4 [00:21<00:00,  5.32s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

latest, 0.75: (0.87363937369659, 0.8924841399100167)


100%|██████████| 4/4 [00:22<00:00,  5.53s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

latest, 0.8: (0.8714203371579674, 0.8916023694191065)


100%|██████████| 4/4 [00:23<00:00,  5.78s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

latest, 0.8500000000000001: (0.8692405518642256, 0.8907705654642719)


100%|██████████| 4/4 [00:22<00:00,  5.71s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

latest, 0.9: (0.8672534561248718, 0.8901393386791897)


100%|██████████| 4/4 [00:21<00:00,  5.34s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

latest, 0.95: (0.8662375708334357, 0.8903008346009428)


100%|██████████| 4/4 [00:21<00:00,  5.32s/it]

latest, 1.0: (0.8697122091091656, 0.8942047744704603)





## Hedge

In [39]:
# Hedge update: sweep the learning rate `eta` for each merge strategy.
# The result is bound to `scores`, not `losses`, so the imported `losses` module is not shadowed.
# NOTE(review): in the recorded run, first_loss (both strategies) and the "latest"
# final_loss were still decreasing at eta=5.0, so the optimum may lie beyond this
# grid. TODO consider extending the upper bound.
HEDGE_ETAS = np.linspace(0.5, 5.0, 11)
for merge_id, merge_fn in merge_st:
    for eta in HEDGE_ETAS:
        update_fn = partial(update.hedge, eta=eta)
        scores = evaluate(make_predictor(LOSS_FN, merge_fn, update_fn), LOSS_FN)
        # :.2f keeps the printed eta free of floating-point round-off
        print(f"{merge_id}, {eta:.2f}: {scores['first_loss'], scores['final_loss']}")

100%|██████████| 4/4 [00:18<00:00,  4.68s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

zero, 0.5: (0.8866194723930647, 0.9007701611618466)


100%|██████████| 4/4 [00:19<00:00,  4.76s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

zero, 0.95: (0.8786988854604574, 0.8977764475960712)


100%|██████████| 4/4 [00:18<00:00,  4.67s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

zero, 1.4: (0.8733773154344155, 0.896411911459502)


100%|██████████| 4/4 [00:19<00:00,  4.84s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

zero, 1.85: (0.8698247612121197, 0.8957915176301239)


100%|██████████| 4/4 [00:18<00:00,  4.57s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

zero, 2.3: (0.8674130386009042, 0.8954886666826891)


100%|██████████| 4/4 [00:19<00:00,  4.78s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

zero, 2.75: (0.8657448546867752, 0.8953536608484746)


100%|██████████| 4/4 [00:18<00:00,  4.67s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

zero, 3.2: (0.8645798794605682, 0.8953492701788603)


100%|██████████| 4/4 [00:18<00:00,  4.60s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

zero, 3.65: (0.8637636425743663, 0.8954523134793623)


100%|██████████| 4/4 [00:19<00:00,  4.80s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

zero, 4.1: (0.8631903787159638, 0.8956322606827222)


100%|██████████| 4/4 [00:19<00:00,  4.93s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

zero, 4.55: (0.8627863699383512, 0.8958583160325466)


100%|██████████| 4/4 [00:19<00:00,  4.79s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

zero, 5.0: (0.8625005939688857, 0.896106312958935)


100%|██████████| 4/4 [00:23<00:00,  5.96s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

latest, 0.5: (0.8892561550495424, 0.9007019452019019)


100%|██████████| 4/4 [00:24<00:00,  6.25s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

latest, 0.95: (0.882868569951269, 0.8970867948698384)


100%|██████████| 4/4 [00:23<00:00,  5.94s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

latest, 1.4: (0.8780479514699451, 0.8949693759852758)


100%|██████████| 4/4 [00:23<00:00,  5.89s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

latest, 1.85: (0.874225362996002, 0.8936493076880561)


100%|██████████| 4/4 [00:22<00:00,  5.67s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

latest, 2.3: (0.8711662560948767, 0.8927680903584527)


100%|██████████| 4/4 [00:24<00:00,  6.04s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

latest, 2.75: (0.868762938771623, 0.892150846201724)


100%|██████████| 4/4 [00:22<00:00,  5.56s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

latest, 3.2: (0.866920941552163, 0.8917069769143513)


100%|██████████| 4/4 [00:21<00:00,  5.28s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

latest, 3.65: (0.8655294667451761, 0.8913817949776875)


100%|██████████| 4/4 [00:21<00:00,  5.43s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

latest, 4.1: (0.864476030473778, 0.8911387987320996)


100%|██████████| 4/4 [00:23<00:00,  5.78s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

latest, 4.55: (0.8636649135021445, 0.8909539240942543)


100%|██████████| 4/4 [00:21<00:00,  5.39s/it]

latest, 5.0: (0.86302364913953, 0.890812073202538)



