In [1]:
import os
import sys
import gym
import math
import datetime
import json

from gym import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env 

import numpy as np
import pandas as pd
import seaborn as sns
import pickle as pkl

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.legend import Legend
from matplotlib.ticker import MultipleLocator
from matplotlib import colors


from lib.train import train_standard
from lib.plots import dynamics_cmap, comparison_cmap, set_plot_env, plot_model, plot_variance, plot_variance_comparison, plot_debug
from lib.pandemic_model_tests import SEIVHRD_Env  
from lib.data_structures import build_model_data, build_variance_struct, update_variance_struct

In [2]:
MONTH = {
    1: 'Jan',
    2: 'Feb',
    3: 'Mar',
    4: 'Apr',
    5: 'May',
    6: 'Jun',
    7: 'Jul',
    8: 'Aug',
    9: 'Sep',
    10: 'Oct',
    11: 'Nov',
    12: 'Dec',
}

Some functions to quickly preprocess the data

In [3]:
def extract_month(date):
    [y, m] =  date.split('-')[:2]
    return f"{MONTH[int(m)]} {y[2:]}"


# Routines to calculate the number of days between two dates. Useful for data preprocessing
def convert_str_date_to_ints(date1, date2):
    # Date is formatted as "yyyy-mm-dd"
    [y1, m1, d1] = [int(elem) for elem in date1.split('-')]
    [y2, m2, d2] = [int(elem) for elem in date2.split('-')]
    if y1 > y2:
        return  (y2, m2, d2), (y1, m1, d1)
    elif y1 < y2:
        return  (y1, m1, d1), (y2, m2, d2)
    
    if m1 > m2:
        return  (y2, m2, d2), (y1, m1, d1)
    elif m1 < m2:
        return  (y1, m1, d1), (y2, m2, d2)
    
    if d1 > d2:
        return  (y2, m2, d2), (y1, m1, d1)
    
    return  (y1, m1, d1), (y2, m2, d2)
    
def is_leap(year):
    if (year % 4 == 0):
        if (year % 100 != 0):
            return True
        else:
            return (year % 400 == 0)
    return False


def get_number_of_days(y, m):
    if m in {1,3,5,7,8,10,12}:
        return 31
    elif m in {4,6,9,11}:
        return 30
    else:
        return 29 if  is_leap(y) else 28

def daysBetweenDates(y1, m1, d1, y2, m2, d2) -> int:
    # Assumes y2 >= y1
    days = 0
    # Years
    for y in range(y1 + 1, y2):
        for m in range(1, 13):
            days += get_number_of_days(y, m)
    # Months
    if y1 < y2:
        for m in range(m1 + 1, 13):
            days += get_number_of_days(y1, m)

        for m in range(1, m2):
            days += get_number_of_days(y2, m)

    else:
        for m in range(m1 + 1, m2):
            days += get_number_of_days(y1, m)
    # Days
    if y1 == y2 and m1 == m2:
        days += d2 - d1
    else:
        days += get_number_of_days(y1, m1) - d1  + d2
    return days

# Function used to keep selected periods
def cut_off_dates(df, countries_dic):
    new_df = df.copy(deep=True)
    for country, (location_name, _, [start, end]) in countries_dic.items():
        #print(f"Length of {country} before drop == {len(new_df.loc[new_df['location']==country])}")
        index = new_df.loc[(new_df['location'] == location_name) & ((new_df['date'] < start) | (new_df['date'] > end))].index
        new_df = new_df.drop(index , inplace=False)
        #print(f"Length of {country} after drop == {len(new_df.loc[new_df['location']==country])}")
    return new_df
    

def load_and_merge_dataframes(path_all, path_missing):
    df_gcp_missing = pd.read_csv(path_missing)
    df_gcp_missing = df_gcp_missing.drop(df_gcp_missing.columns[df_gcp_missing.columns.str.contains('unnamed',case = False)],axis = 1, inplace = False)
    df_gcp_all = pd.read_csv(path_all)
    df_gcp_all = df_gcp_all.drop(df_gcp_all.columns[df_gcp_all.columns.str.contains('unnamed',case = False)],axis = 1, inplace = False)
    df_all = pd.concat([df_gcp_all, df_gcp_missing], axis=0)
    return df_all



def compute_initial_conditions(df, labels=
                               ['cum_deceased',  
                                'cum_confirmed',
                                'current_hospitalized_patients',
                                'cum_recovered',
                                'cum_persons_vaccinated']):
    mapping = {'cum_deceased':'D0',
               'cum_confirmed':'I0',
               'current_hospitalized_patients':'H0',
               'cum_recovered':'R0',
               'cum_persons_vaccinated':'V0'}
    res = {'D0':0, 'I0':0, 'H0':0, 'R0':0, 'V0':0}
    for lbl in labels:
        sub_df = df[lbl]
        if (len(sub_df) - sub_df.isna().sum()) < 2:
            continue
        sub_df = sub_df.tolist()
        i = 0
        first = sub_df[i]
        while pd.isnull(first):
            i += 1
            first = sub_df[i]
        """
        second = sub_df[i+1]
        if not pd.isnull(second):
            res[mapping[lbl]] = abs(second-first) # Abs not needed but we never know!
        """
        res[mapping[lbl]] = first
    return res


def verify_df(df, N):
    df = df.reset_index()  # make sure indexes pair with number of rows
    for index, row in df.iterrows():
        s = row['Deceased'] + row['Hospitalized'] + row['Infected_Asym'] + row['Infected_Sym'] + row['Vaccinated']  + row['Exposed'] + row['Recovered'] +  row['Susceptible']
        if abs(s-N) > 1.:
            raise ValueError(f"Sum of variables ({s}) different from number of population ({N})")

### Calculating vaccination rate

In [4]:
countries_data = {
    'France':('france', 'France', ['2021-01-01','2021-12-31']),
    'Italy':('italy' , 'Italy', ['2021-01-01','2021-12-31']),
    'Malaysia':('malaysia' , 'Malaysia', ['2021-01-01','2021-12-31']),
    'Japan':('japan', 'Japan', ['2021-01-01','2021-12-31']),
    'United Kingdom':('united_kingdom' , 'United Kingdom', ['2021-01-01','2021-12-31']),
    'United States':('united_states', 'United States', ['2021-01-01','2021-12-31'])
}

vacc_key = 'new_persons_fully_vaccinated'
print(f"########### RESULTS FOR KEY {vacc_key}")
for country in ['United Kingdom','United States','Japan']: 
    ([d1, d2], N) = countries_data[country]
    df_view = df_gcp_missing[(df_gcp_missing['location']==country) & ((df_gcp_missing['date'] >= d1) & (df_gcp_missing['date'] <= d2))]
    cum_vacc = df_view[vacc_key].cumsum().tolist()
    rate = (cum_vacc[-1] - cum_vacc[0]) / len(df_view) / N
    print(f"Vaccination rate for the {country} is: {round(rate, 4)}")
for country in ['France','Italy','Malaysia']: 
    ([d1, d2], N) = countries_data[country]
    df_view = df_gcp_all[(df_gcp_all['location']==country) & ((df_gcp_all['date'] >= d1) & (df_gcp_all['date'] <= d2))]
    cum_vacc = df_view[vacc_key].cumsum().tolist()
    rate = (cum_vacc[-1] - cum_vacc[0]) / len(df_view) / N
    print(f"Vaccination rate for the {country} is: {round(rate, 4)}")


########### RESULTS FOR KEY new_persons_fully_vaccinated


ValueError: too many values to unpack (expected 2)

### Training models

Below is the training code. We picked 6 countries for our simulations:

1. 'France'
2. 'Italy'
3. 'Japan'
4. 'Malaysia'
5. 'United Kingdom'
6. 'United States'

We load countries' environments and train a PPO model on them. We save the model's weights at the end. 

In [None]:
#################### Running training ##############################
countries_data = {
    'France':('france_new', 'France', ['2021-01-01','2021-12-31']),
    'Italy':('italy' , 'Italy', ['2021-01-01','2021-12-31']),
    'Malaysia':('malaysia' , 'Malaysia', ['2021-01-01','2021-12-31']),
    'Japan':('japan', 'Japan', ['2021-01-01','2021-12-31']),
    'United Kingdom':('united_kingdom' , 'United Kingdom', ['2021-01-01','2021-12-31']),
    'United States':('united_states', 'United States', ['2021-01-01','2021-12-31'])
}


countries_data = {
    #'France':('france_new', 'France', ['2021-01-01','2021-12-31']),
    #'France':('france', 'France',  ['2021-01-01','20210-12-31']),
    #'France_Health_Reward':('france_health_reward', 'France',  ['2021-01-01','20210-12-31']),
    #'France_Health_Reward_Hosp_Weight':('france_health_reward_hosp_weight', 'France',  ['2021-01-01','20210-12-31']),
    #'France_Health_Reward_Hosp_Weight_Max':('france_health_reward_hosp_weight_max', 'France', ['2021-01-01','20210-12-31']),
    #'France_Health_Reward_Max':('france_health_reward_max', 'France',  ['2021-01-01','20210-12-31']),
    #'France_Hosp_Dec_Health_Weights_Max':('france_hosp_dec_health_weights_max', 'France',  ['2021-01-01','20210-12-31']),
    #'France_Hosp_Dec_Weights':('france_hosp_dec_weights', 'France',  ['2021-01-01','20210-12-31']),
    #'France_Hosp_Dec_Weights_Max':('france_hosp_dec_weights_max', 'France',  ['2021-01-01','20210-12-31']),
    #'France_Init_Cond':('france_init_cond', 'France',  ['2021-01-01','20210-12-31']),
    #'France_Init_Cond_Reward_Hosp_Weight_Max':('france_init_cond_reward_hosp_weight_max', 'France',  ['2021-01-01','20210-12-31']),
    #'France_Init_Cond_Hosp_Dec_Health_Weights_Max':('france_init_cond_hosp_dec_health_weights_max', 'France',  ['2021-01-01','20210-12-31']),
}

countries_data = {
    'Paris':('paris', 'Paris', ['2020-03-18', '2021-03-18']),
    #'Paris_Health_Reward_Hosp_Weight_Max':('paris_health_reward_hosp_weight_max', 'Paris', ['2020-03-18', '2021-03-18']),
    #'Paris_Health_Reward_Max':('paris_health_reward_max', 'Paris', ['2020-03-18', '2021-03-18']),
    #'Paris_Hosp_Dec_Health_Weights_Max':('paris_hosp_dec_health_weights_max', 'Paris', ['2020-03-18', '2021-03-18']),
    #'Paris_Init_Cond_Hosp_Dec_Health_Weights_Max':('paris_init_cond_hosp_dec_health_weights_max', 'Paris', ['2020-03-18', '2021-03-18']),
    #'IDF':('idf', 'Île-de-France',  ['2020-03-18', '2021-03-18']),
    #'France_Health_Reward':('france_health_reward', 'France',  ['2021-01-01','20210-12-31']),
 }


countries_data = {
    'Singapore':('singapore', 'Singapore', ['2020-04-20', '2021-04-20']),
    'Paris':('paris', 'Paris', ['2020-03-18', '2021-03-18']),
    'New-York':('new-york', 'New-York', ['2020-04-11', '2021-04-11']),
    'Tokyo':('tokyo', 'Tokyo', ['2020-05-08', '2021-05-08']),
 }

"""
Effective range for singapore city: ['2020-04-20', '2021-08-04']
Effective range for paris city: ['2020-03-18', '2023-06-30']
Effective range for new-york city: ['2020-04-11', '2023-03-23']
Effective range for tokyo city: ['2020-05-08', '2023-05-07']
"""

for country, (f_name, location_key, _) in countries_data.items():
    print(f"############# Loading configuration file '{f_name}.json'")
    cfg_file = f_name + '.json'
    with open(os.path.join(f'./config',f'{cfg_file}'), 'r') as ftp:
        cfg = json.load(ftp)
    
    ### Pandemic Environnement 
    params = cfg['spec-params']
    params_reg = cfg['env-params']
    print(f"############# Creating SEIVHRD environment")
    env = SEIVHRD_Env(params = params, **params_reg)
    
    
    """
    if 'init-conds' in cfg:
        print(f"############# Updating SEIVHRD environment with initial conditions")
        params_cond = cfg['init-conds']
        env.update_initial_conditions(V0=params_cond['V0'], 
                                  E0=params_cond['E0'], 
                                  I_a0=params_cond['I_a0'], 
                                  I_s0=params_cond['I_s0'], 
                                  H0=params_cond['H0'], 
                                  R0=params_cond['R0'], 
                                  D0=params_cond['D0'])
    """
    check_env(env)
    
    print(f"############# Creating RL model")
    agent_type = "PPO"
    save_dir = f"./outputs/simulations/{agent_type}/{f_name}"
    os.makedirs(save_dir, exist_ok=True)  
    weights_dir = os.path.join(save_dir, "weights")
    log_dir = os.path.join(save_dir, "logs")
    
    model_ppo = PPO("MlpPolicy", env, verbose=1, tensorboard_log=log_dir)

    TIMESTEPS, EPOCHS = 730, 200 # 2 years
    env.max_steps = round(TIMESTEPS / env.days)
    
    print(f"############# Training RL model for {TIMESTEPS} steps")
    for index in range(0, EPOCHS):
        print(f"##### Epoch number {index}")
        model_ppo.learn(total_timesteps=TIMESTEPS, reset_num_timesteps=False, tb_log_name=f"PPO_{f_name}")
        if index % 10 == 0:
            model_ppo.save(f"{weights_dir}_{index}_{TIMESTEPS}")
    print(f"##### Saving final model at '{weights_dir}_{index+1}_{TIMESTEPS}'")
    model_ppo.save(f"{weights_dir}_{index+1}_{TIMESTEPS}")


### Running predictions

We load the preprocessed data

In [6]:
df_all = None
countries_data = {
    'Singapore':('singapore', 'Singapore', ['2020-04-20', '2021-04-20']),
    'Paris':('paris', 'Paris', ['2020-03-18', '2021-03-18']),
    'New-York':('new-york', 'New-York', ['2020-04-11', '2021-04-11']),
    'Tokyo':('tokyo', 'Tokyo', ['2020-05-08', '2021-05-08']),
 }
for city, (f_name, location_key, _) in countries_data.items():
    df = pd.read_csv(os.path.join(f"./data/cities/preprocessed",f"{city}.csv"))
    #print(f"############# Number of rows for city {f_name} == {len(df)}")
    #for col in df.columns:
    #    print(f"===>>> Number of NaN for column '{col}' = {df[col].isna().sum()}")
    #print(f"columns for city {f_name} are : {df.columns}")
    if df_all is None:
        df_all = df
    else:
        df_all = pd.concat([df_all, df], ignore_index=True)
df_truncated = df_all.drop(df_all.columns[df_all.columns.str.contains('Unnamed',case = False)], axis = 1, inplace = False)
#df_truncated = cut_off_dates(df_truncated, countries_data)
df_truncated = df_truncated.interpolate(method='linear', limit_direction='both', axis=0)
df_truncated['month'] = df_truncated['date'].apply(extract_month)
df_truncated

Unnamed: 0,date,cum_confirmed,cum_deceased,cum_recovered,cum_persons_vaccinated,current_hospitalized_patients,population,location,month
0,2020-04-20,8014.0,11.0,801.0,2.0,1364.0,5638676,singapore,Apr 20
1,2020-04-21,9125.0,11.0,839.0,2.0,1381.0,5638676,singapore,Apr 20
2,2020-04-22,10141.0,12.0,896.0,2.0,1571.0,5638676,singapore,Apr 20
3,2020-04-23,11178.0,12.0,924.0,2.0,1342.0,5638676,singapore,Apr 20
4,2020-04-24,12075.0,12.0,956.0,2.0,1205.0,5638676,singapore,Apr 20
...,...,...,...,...,...,...,...,...,...
3151,2023-05-03,4382724.0,8111.0,3093983.0,15748511.0,47585.0,13942856,tokyo,May 23
3152,2023-05-04,4383630.0,8114.0,3093983.0,15748511.0,47592.0,13942856,tokyo,May 23
3153,2023-05-05,4384692.0,8117.0,3093983.0,15748511.0,47599.0,13942856,tokyo,May 23
3154,2023-05-06,4387037.0,8120.0,3093983.0,15748511.0,47562.0,13942856,tokyo,May 23


We now load both the model and environment and run predictions. We save them at the end.

In [14]:
#################### Running predictions ##############################
countries_data = {
    'Singapore':('singapore', 'Singapore', ['2020-04-20', '2021-04-20']),
    'Paris':('paris', 'Paris', ['2020-03-18', '2021-03-18']),
    'New-York':('new-york', 'New-York', ['2020-04-11', '2021-04-11']),
    'Tokyo':('tokyo', 'Tokyo', ['2020-05-08', '2021-05-08']),
 }

TIMESTEPS, EPOCHS = 730, 200 # 3 years
WINDOW = 1095

for country, (f_name, location_key, [d1, d2]) in countries_data.items():
    country = country.split('_')[0]
    print(f"############# Loading configuration for country {country}")
    cfg_file = f_name + '.json'
    with open(os.path.join(f'./config',f'{cfg_file}'), 'r') as ftp:
        cfg = json.load(ftp)
    
    ### Pandemic Environnement 
    params = cfg['spec-params']
    params_reg = cfg['env-params']
    print(f"############# Creating SEIVHRD environment")
    env = SEIVHRD_Env(params = params, **params_reg)
    env.max_steps = round(WINDOW / env.days)
    check_env(env)
    
    
    ### Model 
    print(f"############# Loading trained RL model")
    agent_type = "PPO"
    save_dir = f"./outputs/simulations/{agent_type}/{f_name}"
    weights_dir = os.path.join(save_dir, "weights")
    log_dir = os.path.join(save_dir, "logs")
    model_ppo = PPO("MlpPolicy", env, verbose=1, tensorboard_log=log_dir)
    model_ppo.load(f"{weights_dir}_{EPOCHS}_{TIMESTEPS}")


    ### Inference
    print(f"############# Running inference")
    data_variance, rounds = build_variance_struct(), 1000
    

    """
    cum_confirmed, 
    cum_deceased,
    cum_recovered, 
    vaccines, 
    cum_persons_vaccinated, 
    current_hospitalized_patients
    """
   
    df = df_truncated.loc[(df_truncated['location']==f_name)]
    """
    init_conds = compute_initial_conditions(
        df, 
        labels = ['cum_deceased',
                  'cum_confirmed',
                  'current_hospitalized_patients',
                  'cum_recovered',
                  'cum_persons_vaccinated']
    )
    print(f"Initial conditions == {init_conds}")
    env.update_initial_conditions(V0=init_conds['V0'], 
                                  E0=0, 
                                  I_a0=0, 
                                  I_s0=init_conds['I0'], 
                                  H0=init_conds['H0'], 
                                  R0=init_conds['R0'], 
                                  D0=init_conds['D0'])
    """

    for i in range(rounds):
        run_data = train_standard(env, model=model_ppo, T_max=TIMESTEPS, mode="Predict")
        data_variance = update_variance_struct(data_variance, run_data)


    ### Saving the structure
    print(f"############# Saving the results...")
    dataframe_path = f"./outputs/simulations/{agent_type}/results"
    os.makedirs(dataframe_path, exist_ok=True)
    if 'Infected_cumul' in data_variance:
        del data_variance['Infected_cumul'] # We don't need this one
    df = pd.DataFrame.from_dict(data_variance)
    df = df.drop(df.columns[df.columns.str.contains('Unnamed',case = False)], axis = 1, inplace = False)
    df.to_csv(f"{os.path.join(dataframe_path,f_name)}_results.csv")


############# Loading configuration for country Singapore
############# Creating SEIVHRD environment
############# Loading trained RL model
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
############# Running inference
############# Saving the results...
############# Loading configuration for country Paris
############# Creating SEIVHRD environment
############# Loading trained RL model
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
############# Running inference
############# Saving the results...
############# Loading configuration for country New-York
############# Creating SEIVHRD environment
############# Loading trained RL model
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
############# Running inference
############# Saving the results...
############# Loading configuration for country Tokyo
############# Creating SEIVHRD environment
##

### Plotting curves

We now plot the curves together (real vs simulated). We compare cumulated deaths and hospitalized patients.

In [12]:
agent_type = "PPO"

comparison_cmap = {
    'Hospitalized':'xkcd:blue',
    'Hospitalized_ref':'xkcd:red',
    'Recovered':'xkcd:blue',
    'Recovered_ref': 'xkcd:red',
    'Deceased':'xkcd:blue',
    'Deceased_ref':'xkcd:red',
    'Infected':'xkcd:blue',
    'Infected_cum':'xkcd:blue',
    'Infected_ref': 'xkcd:red',
}

def plot_variance(env, data, filepath, xlim=366):

    _, axis = plt.subplots(ncols=1, nrows=1, figsize=(10, 8), facecolor="#ffffff")
    #figure.suptitle('Variations', fontsize=18)
    sns.lineplot(x='Days', y='Economy', label='Economy',  color=dynamics_cmap['Economy'], data=data, ax=axis, linewidth=2.5).set(ylabel=None, xlabel="Days")
    axis.legend(loc="best")
    plt.xlim(left=0, right=xlim)
    plt.savefig(filepath+f"_eco.pdf", pad_inches=0, bbox_inches='tight', transparent=True)
    plt.close()

    _, axis = plt.subplots(ncols=1, nrows=1, figsize=(10, 8), facecolor="#ffffff")
    sns.lineplot(x='Days', y='Actions', label='Actions', color=dynamics_cmap['Actions'], data=data, ax=axis, linewidth=2.5).set(ylabel=None, xlabel="Days")
    axis.legend(loc="best")
    plt.xlim(left=0, right=xlim)
    
    plt.savefig(filepath+f"_actions.pdf", pad_inches=0, bbox_inches='tight', transparent=True)
    plt.close()

    _, axis = plt.subplots(ncols=1, nrows=1, figsize=(10, 8), facecolor="#ffffff")
    sns.lineplot(x='Days', y='Infected', data=data,  color=dynamics_cmap['Infected'], ax=axis, label="Infected", linewidth=2.5).set(xlabel="Days", ylabel=None)
    axis.legend(loc="best")
    plt.xlim(left=0, right=xlim)
    
    plt.savefig(filepath+f"_infections.pdf", pad_inches=0, bbox_inches='tight', transparent=True)
    plt.close()

    _, axis = plt.subplots(ncols=1, nrows=1, figsize=(10, 8), facecolor="#ffffff")
    sns.lineplot(x='Days', y='Deceased', data=data,  color=dynamics_cmap['Deceased'], ax=axis, label='Deceased', linewidth=2.5).set(xlabel="Days", ylabel=None)
    axis.legend(loc="best")
    plt.xlim(left=0, right=xlim)
    
    plt.savefig(filepath+f"_deceased.pdf", pad_inches=0, bbox_inches='tight', transparent=True)
    plt.close()

    _, axis = plt.subplots(ncols=1, nrows=1, figsize=(10,8), facecolor="#ffffff")
    p = sns.lineplot(x='Days', y='Hospitalized', data=data, color=dynamics_cmap['Hospitalized'], ax=axis, label='Hospitalized', linewidth=2.5)
    p.set(xlabel="Days", ylabel=None)
    p.axhline(env.hospital_capacity, color="r", linestyle='--',label=r'$C_{h}$', linewidth=2.5)
    p.legend(loc="best")
    plt.xlim(left=0, right=xlim)
    
    plt.savefig(filepath+f"_hospitalized.pdf", pad_inches=0, bbox_inches='tight', transparent=True)
    plt.close()

    ### All dynamics
    _, axis = plt.subplots(ncols=1, nrows=1, figsize=(10, 8), facecolor="#ffffff")
    sns.lineplot(x='Days', y='Infected', data=data,  color=dynamics_cmap['Infected'], ax=axis, label="Infected", linewidth=2.5).set(xlabel="Days", ylabel=None)
    sns.lineplot(x='Days', y='Deceased', data=data,  color=dynamics_cmap['Deceased'], ax=axis, label='Deceased', linewidth=2.5).set(xlabel="Days", ylabel=None)
    sns.lineplot(x='Days', y='Hospitalized', data=data, color=dynamics_cmap['Hospitalized'], ax=axis, label='Hospitalized', linewidth=2.5).set(xlabel="Days", ylabel=None)
    axis.axhline(env.hospital_capacity, color="r", linestyle='--',label=r'$C_{h}$', linewidth=2.5)
    axis.legend(loc="best")
    plt.xlim(left=0, right=xlim)
    
    plt.savefig(filepath+f"_dynamics.pdf", pad_inches=0, bbox_inches='tight', transparent=True)
    plt.close()

def plot_debug(data, hosp_cap = None, file_path='./out.pdf', xlim = None, 
               labels=[
                'Hospitalized', 
                'Deceased', 
                'Recovered', 
                #'Infected_cum',
                'Infected',
                'Susceptible',
                'Vaccinated',
                'Exposed']):
    ### All labels
    _, axis = plt.subplots(ncols=1, nrows=1, figsize=(10, 8), facecolor="#ffffff")
    for label in labels:
        # Results of agent's management
        sns.lineplot(x='Days', y=label, data=data,  color=dynamics_cmap[label], ax=axis, label=label, linewidth=2.5).set(xlabel="Days", ylabel=None)
    # Actual numbers
    axis.legend(loc="best")
    if hosp_cap:
        axis.axhline(hosp_cap, color="r", linestyle='--',label=r'$C_{h}$', linewidth=2.5)
    if xlim is not None:
        plt.xlim(left=0, right=xlim)
    else:
        plt.xlim(left=0)
    plt.savefig(f"{file_path}.pdf", pad_inches=0, bbox_inches='tight', transparent=True)
    plt.close()

def plot_variance_comparison(data, reference, hosp_cap = None, file_path='./out.pdf', labels=['Hospitalized', 'Deceased']):
    label_map = {
        'Hospitalized': 'current_hospitalized_patients',
        'Deceased': 'cum_deceased',
    }

    ### All labels
    for label in labels:
        _, axis = plt.subplots(ncols=1, nrows=1, figsize=(10, 8), facecolor="#ffffff")
        # Results of agent's management
        sns.lineplot(
            x='Days', y=label, data=data, color=comparison_cmap[label], ax=axis, label="Simulation", 
            linewidth=2.5).set(xlabel=None, ylabel=label) #.set(xlabel="Days", ylabel=label)

        # Actual numbers
        sns.lineplot(
            x='Days', y=label_map[label], data=reference,  color=comparison_cmap[f"{label}_ref"], ax=axis, 
            label=f"Real data", linewidth=1.8).set(xlabel=None, ylabel=label) #.set(xlabel="Days", ylabel=None)

        #axis.set_xticks(reference['Days'], minor=True, direction='in')
        step, rot = 90, 60
        plt.xticks(
            list(range(0, len(reference), step)), 
            [reference['month'].iloc[i] for i in range(0, len(reference), step)],
            minor=False,
            rotation=rot)
        if label == "Hospitalized":
            if not hosp_cap:
                raise ValueError(f"Hospitals number 'hosp_cap' should be provided when plotting hospitalized counts.")
            axis.axhline(hosp_cap, color="grey", linestyle='--',label=r'$C_{h}$', linewidth=2.5)
        axis.set_yscale('log')
        axis.legend(loc="best")
        plt.xlim(right=min(len(data), len(reference)) + 1)
        plt.savefig(f"{file_path}_{label}.pdf", pad_inches=0, bbox_inches='tight', transparent=True)
        plt.close()

countries_data = {
    'Singapore':('singapore', 'Singapore', ['2020-04-20', '2021-04-20']),
    'Paris': ('paris', 'Paris', ['2020-03-18', '2021-03-18']),
    'New-York':('new-york', 'New-York', ['2020-04-11', '2021-04-11']),
    'Tokyo':('tokyo', 'Tokyo', ['2020-05-08', '2021-05-08']),
 }

for key, (country, location_key, _) in countries_data.items():
    dataframe_path = f"./outputs/simulations/{agent_type}/results/{country}_results.csv"
    print(f"############# Loading the dataframe from '{dataframe_path}'...")
    df_sim = pd.read_csv(f"{dataframe_path}")
    df_sim = df_sim.groupby('Days', as_index=False).mean()
    
    print(f"############# Loading configuration for country {country}")
    cfg_file = country + '.json'
    with open(os.path.join(f'./config',f'{cfg_file}'), 'r') as ftp:
        cfg = json.load(ftp)
    
    ### Pandemic Environnement 
    params = cfg['spec-params']
    params_reg = cfg['env-params']
    print(f"############# Creating SEIVHRD environment")
    env = SEIVHRD_Env(params = params, **params_reg)
    check_env(env)
    
    #print(f"############# Verifying the saved predictions")
    #verify_df(df_sim, params_reg['N'])
    
    print(f"############# Plotting simulations")
    df_ref = df_truncated.loc[df_truncated['location']==country]
    #df_ref = add_cummulative_col(df_ref, 'new_deceased', 'new_deceased_cum')
    #df_ref.insert(1, "new_deceased_cum", df_ref['new_deceased'].cumsum())
    df_ref.insert(0, 'Days', range(1, 1 + len(df_ref)))
    
    dir_path = f"./outputs/simulations/plots/{agent_type}"
    os.makedirs(dir_path, exist_ok=True)
    print(f"## Comparison with real data")
    plot_variance_comparison(df_sim, df_ref, cfg['env-params']['hosp_cap'], os.path.join(dir_path, f"{country}") , labels=['Hospitalized', 'Deceased'])    
    print(f"## Everything together")
    #df_sim.insert(0, "Infected_cum", df_sim['Infected_Asym'] + df_sim['Infected_Sym'])
    #plot_debug(df_sim, cfg['env-params']['hosp_cap'], os.path.join(dir_path, f"{country}_all"), xlim=366)
    plot_debug(df_sim, None, os.path.join(dir_path, f"{country}_all"), xlim=366)
    print(f"### Variances")
    plot_variance(env, df_sim, os.path.join(dir_path, f"{country}_variances"))

############# Loading the dataframe from './outputs/simulations/PPO/results/singapore_results.csv'...
############# Loading configuration for country singapore
############# Creating SEIVHRD environment
############# Plotting simulations
## Comparison with real data
## Everything together
### Variances
############# Loading the dataframe from './outputs/simulations/PPO/results/paris_results.csv'...
############# Loading configuration for country paris
############# Creating SEIVHRD environment
############# Plotting simulations
## Comparison with real data
## Everything together
### Variances
############# Loading the dataframe from './outputs/simulations/PPO/results/new-york_results.csv'...
############# Loading configuration for country new-york
############# Creating SEIVHRD environment
############# Plotting simulations
## Comparison with real data
## Everything together
### Variances
############# Loading the dataframe from './outputs/simulations/PPO/results/tokyo_results.csv'...


In [None]:
df_ref

In [None]:
df_sim

In [None]:
df_sim.groupby('Days', as_index=False).mean()