
This notebook can be used to locally calibrate a simplistic 'SIR' model on the daily variations of fatality and positive counts, when the detection rate is uncertain but assumptions are given for the fatality ratio and the time-to-recovery.

The model needs to be given:
* the size of the population of interest;
* the cumulative number of positive cases and fatalities over time;
* the time-to-recovery of infectious subjects (gamma parameter in SIR)
* the fatality ratio (percentage of infected people who will die)

And the code will automatically adjust the following parameters:
* the number of people that were initially infected;
* R0 and beta: the rate of infection of susceptible people by infected people (daily new cases = beta * infected * susceptible / population);
* detection rate: the percentage of infectious people reported as positives
* time-dependence of the detection rate: assuming that the detection rate increases or decreases over time, as testing capacity is exceeded or built up

The code prints the results, along with charts to compare model vs. data. 
It also runs a long range forecast to estimate the peak of daily fatalities and the final cumulative fatalities.



In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load in 

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

from scipy.optimize import curve_fit

import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.ticker as mtick
from matplotlib.ticker import NullFormatter
from matplotlib.ticker import FuncFormatter
from matplotlib.dates import MonthLocator
from matplotlib.dates import AutoDateLocator
from matplotlib.pyplot import cm

import seaborn as sns

import datetime
from datetime import timedelta  

import math

#formatting functions for charts
def millions(x, pos):
    """Chart tick formatter: express the tick value in millions with one decimal.

    Args:
        x: the tick value.
        pos: the tick position (not used by the formatting).

    Returns:
        A string such as '1.5M'.
    """
    in_millions = x * 1e-6
    return '%1.1fM' % in_millions

#formatting functions for charts
def thousands(x, pos):
    """Chart tick formatter: express the tick value in thousands with one decimal.

    Args:
        x: the tick value.
        pos: the tick position (not used by the formatting).

    Returns:
        A string such as '2.5T' (the suffix 'T' stands for thousands here).
    """
    in_thousands = x * 1e-3
    return '%1.1fT' % in_thousands


# Input data files are available in the "../input/" directory.
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input directory and print the full path of every file found there.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# Any results you write to the current directory are saved as output.

In [None]:
# Disabled cell: loading of the week-2 competition files, kept inside a bare string literal so that none of it runs.
'''
train = pd.read_csv("../input/covid19-global-forecasting-week-2/train.csv")
display(train.head())

test = pd.read_csv("../input/covid19-global-forecasting-week-2/test.csv")
display(test.head())

submission = pd.read_csv("../input/covid19-global-forecasting-week-2/submission.csv")
display(submission.head())
'''

In [None]:
# Load the week-4 training data and parse the 'Date' strings into timestamps.
train = pd.read_csv("../input/covid19-global-forecasting-week-4/train.csv")
train['Date'] = pd.to_datetime(train['Date'], format='%Y-%m-%d')

# Daily increments are computed per location. The frame holds the cumulative series of
# many locations (Country_Region / Province_State), so a plain diff() over the whole
# column would subtract values that belong to two different locations at every boundary.
# Province_State is still NaN for countries without provinces at this point and groupby
# drops NaN keys, hence the fillna('') applied to the grouping key only.
location = [train['Country_Region'], train['Province_State'].fillna('')]
train['NewFatalities'] = train.groupby(location)['Fatalities'].diff()
train['NewCases'] = train.groupby(location)['ConfirmedCases'].diff()

print("Count of Country_Region: ", train['Country_Region'].nunique())
print("Countries with Province/State: ", train[train['Province_State'].notna()]['Country_Region'].unique())
print("Date range: ", train['Date'].min(), " - ", train['Date'].max())

display(train.head())

In [None]:
#create two new columns: 'Region' and 'State' to bring European countries into a single region. 
#All other Country_Regions with Province_State are also captured in these two columns Region, State

#EU data from https://www.google.com/publicdata/explore?ds=mo4pjipima872_&met_y=population&idim=country_group:eu&hl=en&dl=en#!ctype=l&strail=false&bcs=d&nselm=h&met_y=population&scale_y=lin&ind_y=false&rdim=country_group&idim=country_group:eu&idim=country:ea18:at:be:bg&ifdim=country_group&hl=en_US&dl=en&ind=false
#US States from https://worldpopulationreview.com/states/#statesTable

# Countries that are regrouped under the single 'EU' region.
# The names must match the values of train['Country_Region'] exactly, otherwise the
# country silently stays outside the 'EU' region ('Liechtenstein' was misspelled before).
Europe=[
    'Albania',
    'Armenia',
    'Azerbaijan',
    'Austria', 
    'Belgium', 
    'Bulgaria',
    'Croatia',
    'Cyprus',
    'Czechia',
    'Denmark',
    'Estonia',
    'Finland', 
    'France', 
    'Germany', 
    'Greece', 
    'Hungary',
    'Iceland', 
    'Ireland', 
    'Italy',
    'Latvia',
    'Liechtenstein',
    'Lithuania',
    'Luxembourg',
    'Malta',
    'Montenegro',
    'Netherlands',
    'North Macedonia',
    'Norway', 
    'Poland',
    'Portugal',
    'Romania',
    'Slovakia',
    'Slovenia',
    'Spain', 
    'Sweden', 
    'Switzerland', 
    'United Kingdom'
]

# Assign the filled column back instead of calling fillna(..., inplace=True) on the
# column selection: the in-place form acts on an intermediate object and is not
# guaranteed to update `train` itself in recent pandas versions.
train['Province_State'] = train['Province_State'].fillna('')

# By default a location keeps its own country as 'Region' and its province as 'State'.
train['State'] = train['Province_State']
train['Region'] = train['Country_Region']

# European countries are regrouped under the 'EU' region, each country becoming a 'State'.
in_europe = train['Country_Region'].isin(Europe)
train.loc[in_europe, 'Region'] = 'EU'
train.loc[in_europe, 'State'] = train.loc[in_europe, 'Country_Region']

#census populations
#add entries to this table in order to run simulations
Population = {
    'China-': 1386e6,
    'US-': 327e6,
    'EU-': 512e6 + (10+9+5+3+3+2+0.5+0.4)*1e6,

    'US-California':39937489,
    'US-Texas':29472295,
    'US-Florida':21992985,
    'US-New York':19440469,
    'US-Pennsylvania':12820878,
    'US-Illinois':12659682,
    'US-Ohio':11747694,
    'US-Georgia':10736059,
    'US-North Carolina':10611862,
    'US-Michigan':10045029,
    'US-New Jersey':8936574,
    'US-Virginia':8626207,
    'US-Washington':7797095,
    'US-Arizona':7378494,
    'US-Massachusetts':6976597,
    'US-Tennessee':6897576,
    'US-Indiana':6745354,
    'US-Missouri':6169270,
    'US-Maryland':6083116,
    'US-Wisconsin':5851754,
    'US-Colorado':5845526,
    'US-Minnesota':5700671,
    'US-South Carolina':5210095,
    'US-Alabama':4908621,
    'US-Louisiana':4645184,
    'US-Kentucky':4499692,
    'US-Oregon':4301089,
    'US-Oklahoma':3954821,
    'US-Connecticut':3563077,
    'US-Utah':3282115,
    'US-Iowa':3179849,
    'US-Nevada':3139658,
    'US-Arkansas':3038999,
    'US-Puerto Rico':3032165,
    'US-Mississippi':2989260,
    'US-Kansas':2910357,
    'US-New Mexico':2096640,
    'US-Nebraska':1952570,
    'US-Idaho':1826156,
    'US-West Virginia':1778070,
    'US-Hawaii':1412687,
    'US-New Hampshire':1371246,
    'US-Maine':1345790,
    'US-Montana':1086759,
    'US-Rhode Island':1056161,
    'US-Delaware':982895,
    'US-South Dakota':903027,
    'US-North Dakota':761723,
    'US-Alaska':734002,
    'US-District of Columbia':720687,
    'US-Vermont':628061,
    'US-Wyoming':567025,
    
    'EU-Vatican City':801,
    'EU-United Kingdom':67886011,
    'EU-Ukraine':43733762,
    'EU-Turkey':84339067,
    'EU-Switzerland':8654622,
    'EU-Sweden':10099265,
    'EU-Spain':46754778,
    'EU-Slovenia':2078938,
    'EU-Slovakia':5459642,
    'EU-Serbia':8737371,
    'EU-San Marino':33931,
    'EU-Russia':145934462,
    'EU-Romania':19237691,
    'EU-Portugal':10196709,
    'EU-Poland':37846611,
    'EU-Norway':5421241,
    'EU-Netherlands':17134872,
    'EU-Montenegro':628066,
    'EU-Monaco':39242,
    'EU-Moldova':4033963,
    'EU-Malta':441543,
    'EU-Luxembourg':625978,
    'EU-Lithuania':2722289,
    'EU-Liechtenstein':38128,
    'EU-Latvia':1886198,
    'EU-Kazakhstan':18776707,
    'EU-Italy':60461826,
    'EU-Ireland':4937786,
    'EU-Iceland':341243,
    'EU-Hungary':9660351,
    'EU-Greece':10423054,
    'EU-Germany':83783942,
    'EU-Georgia':3989167,
    'EU-France':65273511,
    'EU-Finland':5540720,
    'EU-Faroe Islands':48863,
    'EU-Estonia':1326535,
    'EU-Denmark':5792202,
    'EU-Czech Republic':10708981,
    'EU-Cyprus':1207359,
    'EU-Croatia':4105267,
    'EU-Bulgaria':6948445,
    'EU-Bosnia and Herzegovina':3280819,
    'EU-Belgium':11589623,
    'EU-Belarus':9449323,
    'EU-Azerbaijan':10139177,
    'EU-Austria':9006398,
    'EU-Armenia':2963243,
    'EU-Andorra':77265,
    'EU-Albania':2877797,
    
    'China-Hubei':59e6, #wuhan=11, hubei=59 59e6
    'China-Guangdong':104e6,
    'China-Shandong':100e6,
    'China-Henan':94e6,
    'China-Beijing':20e6,
    'China-Hong Kong':7e6,
    
    'Singapore-': 5.6e6, #not enough data to calibrate
    'Japan-': 127e6
}



In [None]:
# Plot cumulative fatalities (log scale) for the hardest-hit provinces/states of one country.
# NOTE: `region`, `c`, `states`, `minDate` and `ax` are notebook-level names shared with other cells.
region = 'US'

# Keep the rows of the selected country and add up the values per province and date.
c = train[train['Country_Region']==region]
c = c.groupby(['Province_State', 'Date']).sum().reset_index()

# Keep only the provinces/states that reported more than 500 fatalities at some point.
states = c[c['Fatalities']>500]['Province_State'].unique()
c = c[c['Province_State'].isin(states)]

# Start the chart on the first date where any of the kept provinces has more than 1 fatality.
minDate = c[c['Fatalities']>1]['Date'].min()
c = c[c['Date']>=minDate].copy()

ax = sns.lineplot(data=c, x='Date',y='Fatalities', hue='Province_State')
plt.yscale('log')
# Anchor the legend outside the axes, to the right of the chart.
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
plt.grid()

# Month-day tick labels, rotated by autofmt_xdate().
ax.xaxis.set_major_formatter(mdates.DateFormatter('%B-%d'))
plt.gcf().autofmt_xdate()

In [None]:
# Download the daily per-state figures from covidtracking.com and chart one state:
# first the cumulative series, then their day-to-day changes.
import urllib, json
url = 'https://covidtracking.com/api/states/daily'

import requests
# The timeout keeps the cell from hanging indefinitely, and raise_for_status() reports
# an HTTP error directly instead of failing later while decoding the response as JSON.
r = requests.get(url, timeout=30)
r.raise_for_status()

#state, abrv = 'New York', 'NY'
#state, abrv = 'California', 'CA'
state, abrv = 'Georgia', 'GA'

data = pd.DataFrame(r.json())
data['date'] = pd.to_datetime(data['date'], format='%Y%m%d')
data = data.fillna(0)

# Rows of the selected state in chronological order, so that diff() below gives daily changes.
d = data[data['state']==abrv]
d = d.sort_values(by='date')

# Same state in the competition data (only referenced by the commented-out plots below).
d2 = train[train['State']==state]  


#--------------------------------
# Figure 1: cumulative deaths, current hospitalizations and cumulative positives
fig, axs = plt.subplots(1,3, figsize=(12,4))
fig.suptitle('{} \ndata from https://covidtracking.com'.format(state), fontsize=10)
#-----------------------------------
ax = plt.subplot(131)
plt.plot(d['date'], d['death'],'+-',label='death')
#plt.plot(d2['Date'], d2['Fatalities'],'+-',label='JH death')
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%B-%d'))
plt.legend()
plt.grid()

ax = plt.subplot(132)
plt.plot(d['date'], d['hospitalizedCurrently'],'+-',label='hospitalized')
plt.legend()
plt.grid()

ax = plt.subplot(133)
plt.plot(d['date'], d['positive'],'+-',label='positive')
#plt.plot(d2['Date'], d2['ConfirmedCases'],'+-',label='JH confirmed')
plt.legend()
plt.grid()

fig.autofmt_xdate()

#--------------------------------
# Figure 2: day-to-day changes of the same three series
fig, axs = plt.subplots(1,3, figsize=(12,4))
fig.suptitle('{} \ndata from https://covidtracking.com'.format(state), fontsize=10)
#--------------------------------
ax = plt.subplot(131)
plt.plot(d['date'], d['death'].diff(),'+-', label='daily new death')
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%B-%d'))
plt.legend()
plt.grid()

ax = plt.subplot(132)
plt.plot(d['date'], d['hospitalizedCurrently'].diff(),'+-',label='daily hospitalized change')
plt.legend()
plt.grid()

ax = plt.subplot(133)
plt.plot(d['date'], d['positive'].diff(),'+-',label='daily new positive')
plt.legend()
plt.grid()

fig.autofmt_xdate()
plt.show()

display(d.head())

In [None]:
#######################################################
# SIR model with INTERVENTION
#------------------------------------------------------
# params:
#
# x          : array of number of days since inception (not used except to size output); in the calibration below, inception starts on the first day reported fatalities reach a CUTOFF threshold
# i0         : initial percentage of infected population, for 1 per million: i0 = 1e-6
# beta       : initial daily rate of transmission by infected people to susceptible people, for R0=2.7 and gamma=1/21: beta=R0*gamma=2.7/21 
# gamma      : daily rate of recovery or death of infected people, for a 21 day speed of recovery or death: gamma = 1/21
# death_rate : fatality ratio of infected people (0.01 if 1% of infected people die); SIR5 multiplies it by gamma to obtain the daily death rate (0.01/21 for gamma=1/21)
#
# intervention_day : number of days after inception for intervention to start to reduce the initial transmission rate (beta)
# intervention_lag : number of days it takes for intervention to reach full effect (linear interp)
# intervention_effect : percentage reduction of initial transmission  rate, 0.25 for 25% reduction of initial beta after full intervention takes effect
########################################################

#-------------------------------------------------------
#the number returned by this function will be multiplied with the initial beta in order to estimate the transmission rate each day of the simulation
#the rate is linear with the given slope, starting from one; the slope then increases starting at day0 until day0+lag after which the effect is constant
def intervention(day, day0, lag=5, effect=0.25, slope=0):
    """Return the multiplier applied to the initial beta on a given day.

    The multiplier starts from a linear trend ``1 - slope * day`` clamped to [0, 1].
    After ``day0`` it is further reduced: proportionally to the elapsed time while
    ``day0 < day <= day0 + lag``, and by the full ``effect`` once ``day > day0 + lag``.
    The result is clamped to [0, 1] again.

    Args:
        day: day of the simulation.
        day0: day on which the intervention starts.
        lag: number of days for the intervention to reach its full effect.
        effect: fractional reduction at full effect (0.25 for a 25% reduction).
        slope: daily decrease of the baseline trend.

    Returns:
        The multiplier, between 0 and 1.
    """
    def clamp01(value):
        return max(0, min(1, value))

    multiplier = clamp01(1 - slope * day)

    if day0 < day <= day0 + lag:
        # ramp-up phase: the reduction grows linearly with the days elapsed since day0
        multiplier = multiplier * (1.0 - effect * (day-day0)/lag)
    elif day > day0 + lag:
        # full effect reached
        multiplier = multiplier * (1.0 - effect)

    return clamp01(multiplier)


# Visual check of intervention(): multiplier over 100 days for an intervention starting on day 20,
# reaching its full 75% effect after 10 days, on top of a baseline trend decreasing by 0.01 per day.
days = np.arange(100)
effects = np.zeros(100)
# NOTE: the loop variable `d` is a notebook-level name, so it overwrites any earlier value of `d`.
for d in days:
    effects[d] = intervention(day=d, day0=20, lag=10, effect=0.75, slope=0.01)
plt.plot(days, effects)
plt.show()


#-------------------------------------------------------
# basic daily integration of a classic SIR model with a time-variable beta parameter=beta*intervention(day)
# the function returns a numpy matrix, with a row per day and the following columns (cumulative results since day of inception)
cS  = 0  #Susceptible people
cI  = 1  #Infected people
cR  = 2  #Recovered people
cF  = 3  #Fatalities
cP  = 4  #Positive cases (recovered people are not included)

def SIR5(x, population, i0, p0, f0, mixing, mixing_s, r0, phi, q, gamma, death_rate, intervention_day, intervention_lag, intervention_effect, beta_slope, detection_rate, detection2):
    """Integrate the SIR model one day at a time with a time-dependent transmission rate.

    Args:
        x: numpy array; only its size is used (number of simulated days).
        population: size of the population.
        i0, p0, f0: infected, positive and fatality counts on day 0.
        mixing, mixing_s: exponents applied to the infected count and to the
            susceptible fraction when computing new infections.
        r0: initial reproduction number; the initial beta is r0 * gamma.
        phi, q: beta is scaled by (1-phi)*exp(-q*day) + phi.
        gamma: daily rate at which infected people recover or die.
        death_rate: fatality ratio; multiplied by gamma to get the daily death rate.
        intervention_day, intervention_lag, intervention_effect, beta_slope:
            forwarded to intervention() as day0, lag, effect and slope.
        detection_rate, detection2: share of new infections counted as positives,
            detection_rate * exp(-detection2 * day).

    Returns:
        A (x.size, 5) numpy array with one row per day and the columns
        cS, cI, cR, cF, cP (cumulative values since day 0).
    """
    n_days = x.size
    y = np.zeros((n_days, 5))
    if n_days == 0:
        return y

    columns = [cS, cI, cR, cF, cP]
    beta = r0 * gamma
    daily_death_rate = death_rate * gamma

    #initial conditions (day 0)
    infected = i0
    positives = p0
    fatalities = f0
    recovered = f0 / (daily_death_rate / gamma)
    susceptible = population - infected - fatalities - recovered
    y[0, columns] = (susceptible, infected, recovered, fatalities, positives)

    for day in range(1, n_days):

        #transmission rate of the day: decaying beta, then the intervention multiplier
        rate = beta * ((1 - phi) * math.exp(-q * day) + phi)
        rate = rate * intervention(day=day, day0=intervention_day, lag=intervention_lag, effect=intervention_effect, slope=beta_slope)
        detect = detection_rate * math.exp(-detection2 * day)

        #all daily flows are computed from the state at the end of the previous day
        newly_infected = rate * pow(susceptible / population, mixing_s) * pow(infected, mixing)  #power law to allow sub-exponential growth
        newly_dead = daily_death_rate * infected
        newly_recovered = (gamma - daily_death_rate) * infected
        net_infected = newly_infected - gamma * infected

        #integrate
        susceptible -= newly_infected
        positives += detect * newly_infected   #cumulative detected infections, never reduced by recoveries
        infected += net_infected
        recovered += newly_recovered
        fatalities += newly_dead

        y[day, columns] = (susceptible, infected, recovered, fatalities, positives)

    return y

x = np.arange(300)




In [None]:
#extract the data for the given region or state and prepare it for the calibration
def prep_data(data, region='US', state='New York', cutoff=1, truncate=0):
    """Extract the daily series of one region/state and prepare them for calibration.

    Args:
        data: DataFrame with at least 'Region', 'State', 'Date' and 'Fatalities' columns.
        region: value of 'Region' to keep.
        state: value of 'State' to keep; '' keeps every state of the region.
        cutoff: the series start on the first date where fatalities exceed this value.
        truncate: when not 0, keep only this many days from the start.

    Returns:
        (minDate, s1): the first date above the cutoff, and the per-date totals from
        that date on, with a 'State' label '<region>-<state>' and a 'Days' column
        holding the number of days since minDate.
    """
    selected = data['Region'] == region
    if state != '':
        selected = selected & (data['State'] == state)

    #one row per date, adding up all the selected rows
    totals = data[selected].groupby(['Date']).sum().reset_index()
    totals['State'] = region + '-' + state

    #first date on which the fatalities exceed the cutoff
    above_cutoff = totals['Fatalities'] > cutoff
    minDate = totals.loc[above_cutoff, 'Date'].min()

    s1 = totals[totals['Date'] >= minDate].copy()
    if truncate != 0:
        s1 = s1[:truncate].copy()

    #number of days elapsed since minDate
    s1['Days'] = (s1['Date'] - minDate) / np.timedelta64(1, 'D')

    return minDate, s1

In [None]:

def earlygrowth(data, region, state, cutoff):
    """Fit simple growth curves to the series of one region/state and chart them.

    The series are prepared with prep_data() and the population is looked up in
    Population['<region>-<state>'] (KeyError if the entry is missing). Three log-scale
    panels are drawn: the raw series, the positives with an exponential fit and a
    'decaybeta' fit, and the fatalities with piecewise log-linear fits. The fitted
    parameters are printed; nothing is returned.

    The piecewise log-linear fits are run on the first n-2 and n-1 points, n being the
    number of days available. A separate loop variable is used for this so that the
    positives and fatalities loops both cover the same windows (the loops used to
    rebind `n`, which shifted the window of the second loop).

    Args:
        data: DataFrame accepted by prep_data().
        region: region name.
        state: state name ('' for the whole region).
        cutoff: fatalities threshold passed to prep_data().
    """

    def growthmodel(x, i0, a):
        #plain exponential growth starting from i0
        return np.exp(a * x) * i0

    #https://arxiv.org/abs/1709.00973
    #solution of sub-exponential growth of the form df/dt = r.f(t)^a
    #(kept for reference, not fitted below)
    def subgrowthmodel(x, a, b, r):
        if a==1:
            return b * np.exp(r*x)
        else:
            return (r*(1-a)*x + b**(1-a))**(1/(1-a))

    def piecewiseloglinear(x, a1, a2, b, d0):
        #log-linear curve with slope a1 before day d0 and slope a2 after it, continuous at d0
        return np.where(x<d0, a1 * x + b, a1 * d0 + b + a2 * (x-d0))

    def decaybeta(x, i0, gamma, r0, phi, q):
        #phi and q are part of the fitted signature but the decay term is currently disabled
        beta  = r0*gamma #*((1-phi)*np.exp(-q*x)+phi)
        return i0 * np.exp((beta-gamma)*x)

    minDate, s1 = prep_data(data, region=region, state=state, cutoff=cutoff, truncate=0)
    population = Population[region + '-' + state]

    s1['NewFatalities'] = s1['Fatalities'].diff()
    s1['NewCases'] = s1['ConfirmedCases'].diff()

    x = s1['Days'].copy()
    n = len(x)

    fig, axs = plt.subplots(1,3, figsize=(12,6))

    #panel 1: raw series
    ax = plt.subplot(131)
    plt.plot(s1['Days'],s1['NewFatalities'], label='new fatalities')
    plt.plot(s1['Days'],s1['Fatalities'], label='fatalities')
    plt.plot(s1['Days'],s1['ConfirmedCases'], label='positives')
    plt.yscale('log')
    plt.legend()

    #panel 2: positives and their fits
    ax = plt.subplot(132)
    plt.plot(s1['Days'],s1['ConfirmedCases'], 'k*', label='positives')
    print("{}-{} positives:".format(region,state))
    z = s1['ConfirmedCases'].copy()

    popt, pcov = curve_fit(growthmodel, x, z, p0=(500, 1/7))
    y1 = growthmodel(x, *popt)
    plt.plot(s1['Days'],y1, label='growthmodel')
    print('growthmodel: ',popt)   

    popt, pcov = curve_fit(decaybeta, x, z, 
                  p0=(500, 1/7, 2, 0.1, 1/50),
                  bounds = ((1,1/21,1.1,0,1/1000),(1e6,1/4,4,1,1/2)))

    print('decaybeta: i0={} 1/g={} r0={} phi={} 1/q={}'.format(popt[0], 1/popt[1], popt[2], popt[3], 1/popt[4]))
    y1 = decaybeta(x, *popt)
    plt.plot(s1['Days'],y1, label='decaybeta')

    #piecewise fits of log(positives/population) on the first n-2 and n-1 points (printed only)
    for m in range(n-2, n):
        popt, pcov = curve_fit(piecewiseloglinear,x[:m], np.log(z[:m]/population),
                              p0 = (0.10,0.10,1e-3,40))
        print(popt)
    plt.yscale('log')
    plt.legend()

    #panel 3: fatalities with the piecewise fits and their daily increments
    ax = plt.subplot(133)
    plt.plot(s1['Days'],s1['Fatalities'], 'ko', label='deaths')
    print("{}-{} fatalities:".format(region,state))
    z = s1['Fatalities'].copy()
    for m in range(n-2, n):
        popt, pcov = curve_fit(piecewiseloglinear,x[:m], np.log(z[:m]/population))
        print(popt)
        y1 = population * np.exp(piecewiseloglinear(x, *popt))
        plt.plot(s1['Days'][:m],y1[:m], label='{} - {:.2f}'.format(m, popt[0]))
        plt.plot(s1['Days'][1:m],np.diff(y1[:m]), label='{} - {:.2f}'.format(m, popt[0]))
    plt.yscale('log')
    plt.legend()

    plt.show()

# Run the early-growth fits for Italy, which the Region/State mapping places in the 'EU' region.
# NOTE: `region` and `state` are notebook-level names; ModelLocal.display reads them for its chart titles.
region="EU"
state="Italy"
earlygrowth(train, region=region, state=state, cutoff=10)


In [None]:
# wrapper to make a dict look like a class, to simplify access to members
# https://goodcode.io/articles/python-dict-object/
    
class objectview:
    """Expose the entries of a dict as attributes of an object.

    The dict passed in becomes the instance's attribute dictionary itself (no copy),
    so attribute writes on the view are visible in the original dict.
    """

    def __init__(self, d):
        self.__dict__ = d
        
class objdict(dict):
    """A dict whose keys can also be read, written and deleted as attributes."""

    def __getattr__(self, name):
        # only called when normal attribute lookup fails: fall back to the dict keys
        if name not in self:
            raise AttributeError("No such attribute: " + name)
        return self[name]

    def __setattr__(self, name, value):
        # every attribute assignment is stored as a dict entry
        self[name] = value

    def __delattr__(self, name):
        if name not in self:
            raise AttributeError("No such attribute: " + name)
        del self[name]
            
      
# Quick demonstration of the two wrappers on the same plain dict.
d = {'a': 1, 'b': 2}

o1 = objectview(d)
print(o1.a)

o2 = objdict(d)
print(o2.b)
o2.c = 3
print(o2.c)

# Parameter set stored as an objdict so that the values can be accessed as attributes.
params = objdict(
    mixing=1,
    mixing_s=1,
    phi=1,
    q=1,
    gamma=1/5,
    lag=1,
)

print(params)


In [None]:
def growthmodel(x, r, a):
    """Exponential growth curve: value ``r`` at x=0, growing at rate ``a``."""
    return r * np.exp(a * x)

def linmodel(x, r, a):
    """Straight line with intercept ``r`` and slope ``a``."""
    return r + a * x


class ModelLocal:

        def __init__(self, region, state):
            self.region=region
            self.state=state
            
        def SIR_calib(self, assumption, param_p, param_f):
            """Derive SIR parameters from two exponential fits, given a detection rate.

            The fits are expected in the form value(t) = r * exp(a * t):
                dF/dt = death_rate * gamma * I0 * exp[gamma*(R0-1)*t]            = fr * exp(fa * t)
                dP/dt = detection_rate * gamma * R0 * I0 * exp[gamma*(R0-1)*t]   = pr * exp(pa * t)

            Args:
                assumption: dict with 'population', 'gamma_d' (days to recovery),
                    'detection_rate' and 'label'.
                param_p: (pr, pa) fitted on the daily positives.
                param_f: (fr, fa) fitted on the daily fatalities.

            Returns:
                A dict of calibrated parameters, including the inputs of the calibration.
            """
            gamma = 1 / assumption['gamma_d']
            detection_rate = assumption['detection_rate']

            pr, pa = param_p[0], param_p[1]
            fr, fa = param_f[0], param_f[1]

            #R0 implied by each growth rate; the one from the fatalities curve is used
            r0_f = 1 + fa/gamma
            r0_p = 1 + pa/gamma
            r0 = r0_f

            #I0 from the positives curve, given the assumed detection rate
            i0 = pr / (detection_rate * gamma * r0)

            #fatality ratio from the fatalities curve and I0
            death_rate = fr / gamma / i0

            return {
                    'population'      : assumption['population'],
                    'i0'              : i0,
                    'gamma_d'         : 1/gamma,
                    'beta'            : r0*gamma,
                    'r0'              : r0,
                    'death_rate'      : death_rate,
                    'detection_rate'  : detection_rate,
                    'detection2'      : 0,

                    'r0_f'            : r0_f,
                    'r0_p'            : r0_p,

                    'pr'              : pr,
                    'pa'              : pa,
                    'fr'              : fr,
                    'fa'              : fa,
                    'label'           : assumption['label']
            }

        def SIR_calib_2(self, assumption, param_p, param_f):
            """Derive SIR parameters from two exponential fits, given a fatality ratio.

            The fits are expected in the form value(t) = r * exp(a * t):
                dF/dt = death_rate * gamma * I0 * exp[gamma*(R0-1)*t]            = fr * exp(fa * t)
                dP/dt = detection_rate * gamma * R0 * I0 * exp[gamma*(R0-1)*t]   = pr * exp(pa * t)

            Args:
                assumption: dict with 'population', 'gamma_d' (days to recovery),
                    'death_rate' (fatality ratio) and 'label'.
                param_p: (pr, pa) fitted on the daily positives.
                param_f: (fr, fa) fitted on the daily fatalities.

            Returns:
                A dict of calibrated parameters, including the inputs of the calibration.
            """
            gamma = 1 / assumption['gamma_d']
            death_rate = assumption['death_rate']

            pr, pa = param_p[0], param_p[1]
            fr, fa = param_f[0], param_f[1]

            #R0 from the slope of the fatalities curve, under the assumption about gamma
            r0_f = 1 + fa/gamma
            r0_p = 1 + pa/gamma
            r0 = r0_f

            #I0 also from the fatalities curve, under the assumptions for gamma and death_rate
            i0 = fr / (death_rate * gamma)

            #detection rate from the positives curve, using the I0 and R0 obtained above
            detection_rate = pr / (gamma * r0 * i0)

            return {
                    'population'      : assumption['population'],
                    'i0'              : i0,
                    'gamma_d'         : 1/gamma,
                    'beta'            : r0*gamma,
                    'r0'              : r0,
                    'death_rate'      : death_rate,
                    'detection_rate'  : detection_rate,
                    'detection2'      : 0,

                    'r0_f'            : r0_f,
                    'r0_p'            : r0_p,

                    'pr'              : pr,
                    'pa'              : pa,
                    'fr'              : fr,
                    'fa'              : fa,
                    'label'           : assumption['label']
            }

        def SIR_calib_3(self, assumption, param_p, param_f):
            """Derive SIR parameters, including the time-dependence of the detection rate.

            The fits are expected in the form value(t) = r * exp(a * t):
                dF/dt = death_rate * gamma * I0 * exp[gamma*(R0-1)*t]            = fr * exp(fa * t)
                dP/dt = detection_rate * gamma * R0 * I0 * exp[gamma*(R0-1)*t]   = pr * exp(pa * t)

            Args:
                assumption: dict with 'population', 'gamma_d' (days to recovery),
                    'death_rate' (fatality ratio), 'label' and optionally 'detection2'.
                param_p: (pr, pa) fitted on the daily positives.
                param_f: (fr, fa) fitted on the daily fatalities.

            Returns:
                A dict of calibrated parameters, including the inputs of the calibration.
            """
            gamma = 1 / assumption['gamma_d']
            death_rate = assumption['death_rate']

            pr, pa = param_p[0], param_p[1]
            fr, fa = param_f[0], param_f[1]

            #R0 from the slope of the fatalities curve, floored at zero
            r0_f = 1 + fa / gamma
            r0_p = 1 + pa / gamma
            r0 = max(0,r0_f)

            #I0 also from the fatalities curve, under the assumptions for gamma and death_rate
            i0 = fr / (death_rate * gamma)

            #detection(t) = detection_rate * exp(-t * detection2), detection2 being a decay rate per day;
            #when it is not given, it is taken from the difference of slope between fatalities and positives
            detection2 = assumption['detection2'] if 'detection2' in assumption else fa-pa

            #detection rate from the positives curve, using the I0 and R0 obtained above
            detection_rate = pr / (gamma * r0 * i0)

            return {
                    'population'      : assumption['population'],
                    'i0'              : i0,
                    'gamma_d'         : 1/gamma,
                    'beta'            : r0*gamma,
                    'r0'              : r0,
                    'death_rate'      : death_rate,
                    'detection_rate'  : detection_rate,
                    'detection2'      : detection2,

                    'r0_f'            : r0_f,
                    'r0_p'            : r0_p,

                    'pr'              : pr,
                    'pa'              : pa,
                    'fr'              : fr,
                    'fa'              : fa,
                    'label'           : assumption['label']
            }
        
    
        def fit(self, days, positives, fatalities, assumptions):
            """Calibrate the model on the earliest and on the latest data, per assumption.

            For every assumption, growth curves are fitted to the daily increments of
            positives and fatalities, once on the first ``n1`` days (results in
            ``self.calibs_1``) and once on the last ``n2`` days with t=0 reset to the
            start of that window (results in ``self.calibs_2``). The fitted parameters
            are turned into SIR parameters by SIR_calib_3 when the assumption gives a
            'death_rate', by SIR_calib otherwise.

            Args:
                days: day numbers of the observations.
                positives: cumulative positives, one value per day.
                fatalities: cumulative fatalities, one value per day.
                assumptions: list of dicts with 'n1', 'n2', the keys needed by the
                    calibration method, and optionally 'method' ('linear' fits a line
                    to the log of the early daily increments).
            """
            self.assumptions = assumptions
            self.days = days
            self.positives = positives
            self.fatalities = fatalities

            dpositives = np.diff(positives)
            dfatalities = np.diff(fatalities)

            def log_floored(increments):
                #increments below 1 are replaced by 1 before taking the log
                return np.log(np.where(increments<1, 1, increments))

            logdpositives = log_floored(dpositives)
            logdfatalities = log_floored(dfatalities)

            def calibrate(assumption, param_p, param_f):
                #the calibration method depends on which assumption is provided
                method = self.SIR_calib_3 if 'death_rate' in assumption else self.SIR_calib
                return method(assumption=assumption, param_p=param_p, param_f=param_f)

            #early calibration: first n1 days
            self.calibs_1 = []
            for a in self.assumptions:
                n1 = a['n1']

                if a.get('method') == 'linear':
                    #linear fit in log space; the intercept is converted back with exp
                    param_p1, _ = curve_fit(linmodel, days[:n1], logdpositives[:n1], p0=(100, 1/7))
                    param_p1[0] = math.exp(param_p1[0])

                    param_f1, _ = curve_fit(linmodel, days[:n1], logdfatalities[:n1], p0=(1, 1/7))
                    param_f1[0] = math.exp(param_f1[0])
                else:
                    param_p1, _ = curve_fit(growthmodel, days[:n1], dpositives[:n1], p0=(100, 1/7))
                    param_f1, _ = curve_fit(growthmodel, days[:n1], dfatalities[:n1], p0=(1, 1/7))

                calib = calibrate(a, param_p1, param_f1)
                calib.update(p0=positives[0], f0=fatalities[0], n=n1)
                self.calibs_1.append(calib)

            #late calibration: last n2 days, with t=0 at the beginning of that window
            self.calibs_2 = []
            for a in self.assumptions:
                n2 = a['n2']
                window = np.arange(n2)

                param_p2, _ = curve_fit(growthmodel, window, dpositives[-n2:], p0=(100, 1/7))
                param_f2, _ = curve_fit(growthmodel, window, dfatalities[-n2:], p0=(1, 1/7))

                calib = calibrate(a, param_p2, param_f2)
                calib.update(p0=positives[-n2], f0=fatalities[-n2], n=n2)
                self.calibs_2.append(calib)
                
        def display(self, minDate, forecast, ax_diff, ax_p, ax_f, ax_i):
            
            positives = self.positives
            fatalities = self.fatalities
            x = self.days
            xd = minDate + np.arange(len(x)) * timedelta(days=1)

            ax_p.set_title('{}-{} Cumulative Positives'.format(region, state))
            ax_f.set_title('{}-{} Cumulative Fatalities'.format(region,state))
            ax_diff.set_title('{}-{} Daily Values'.format(region,state))
            ax_i.set_title('{}-{} Infectious Cases'.format(region,state))
            
            early_colors = plt.get_cmap('seismic')(np.linspace(0,0.4,len(self.assumptions)))
            late_colors = plt.get_cmap('seismic')(np.linspace(0.6, 1,len(self.assumptions)))
#            early_colors = plt.get_cmap('Blues')(np.linspace(0,1,len(assumptions)))
#            late_colors = plt.get_cmap('YlOrRd')(np.linspace(0,1,len(assumptions)))
            
            ax_p.plot(xd, positives, 'kx:', label='positives')
            ax_f.plot(xd, fatalities, 'k+:', label='fatalities')

            ax_diff.plot(xd[:len(positives)-1], np.diff(positives), 'kx:', label='positives')
            ax_diff.plot(xd[:len(positives)-1], np.diff(fatalities), 'k+:', label='fatalities')
            
            xx = np.arange(x[0], forecast)
            xxd = minDate + np.arange(x[0],forecast) * timedelta(days=1)
            for i,c in enumerate(self.calibs_1):

                n1 = c['n']
                
                r, a = c['pr'], c['pa']
                y1 = growthmodel(self.days, r=r, a=a)
                #ax_diff.plot(self.days[:n1], y1[:n1], 'k-', label='')
                ax_diff.plot(xd[:n1], y1[:n1], 'k-', label='')
                
                r, a = c['fr'], c['fa']
                y1 = growthmodel(self.days, r=r, a=a)
                #ax_diff.plot(self.days[:n1], y1[:n1], 'k-', label='')
                ax_diff.plot(xd[:n1], y1[:n1], 'k-', label='')
                
                
                p = objdict(c.copy())
                y2 = SIR5(xx, 
                          population=p.population, i0=p.i0, p0=p.p0, f0=p.f0,
                          mixing=1, mixing_s=1, 
                          r0=p.r0, phi=1, q=1, gamma=1/p.gamma_d, 
                          death_rate=p.death_rate , 
                          intervention_day=0, intervention_lag=1, intervention_effect=0, 
                          beta_slope=0, 
                          detection_rate=p.detection_rate, detection2=p.detection2)
  
                c['current i0'] = y2[len(x)-1, cI]
                c['recovered'] = y2[len(x)-1, cR]

                color = early_colors[i]
                ax_p.plot(xxd, y2[:,cP], c=color, linestyle='-', label=c['label'])
                ax_f.plot(xxd, y2[:,cF], c=color, linestyle='-', label=c['label'])
                ax_diff.plot(xxd[0:n1-1], np.diff(y2[:n1,cF]), c=color, linestyle='-', label=c['label'])
                ax_diff.plot(xxd[0:n1-1], np.diff(y2[:n1,cP]), c=color, linestyle='-', label='')
                ax_i.plot(xxd, y2[:,cI], c=color, linestyle='-', label=c['label'])

            for i,c in enumerate(self.calibs_2):

                n2 = c['n']
                xx = np.arange(x[-n2], forecast)
                xxd = minDate + np.arange(x[-n2], forecast) * timedelta(days=1)
                
                r, a = c['pr'], c['pa']
                y1 = growthmodel(np.arange(n2), r=r, a=a)
                #ax_diff.plot(self.days[-n2:], y1[-n2:], 'k-', label='')
                ax_diff.plot(xd[-n2:], y1[-n2:], 'k-', label='')

                r, a = c['fr'], c['fa']
                y1 = growthmodel(np.arange(n2), r=r, a=a)
                #ax_diff.plot(self.days[-n2:], y1[-n2:], 'k-', label='')
                ax_diff.plot(xd[-n2:], y1[-n2:], 'k-', label='')
                
                
                p = objdict(c.copy())
                y2 = SIR5(xx, 
                          population=p.population, i0=p.i0, p0=p.p0, f0=p.f0, 
                          mixing=1, mixing_s=1, 
                          r0=p.r0, phi=1, q=1, gamma=1/p.gamma_d, 
                          death_rate=p.death_rate , 
                          intervention_day=0, intervention_lag=1, intervention_effect=0, 
                          beta_slope=0, 
                          detection_rate=p.detection_rate, detection2=p.detection2)

                #display('currently infectious {:,.0f}'.format(y2[n2, cI]))
                c['current i0'] = y2[n2, cI]
                c['recovered'] = y2[n2, cR]
                
                color = late_colors[i]
                ax_p.plot(xxd, y2[:,cP], c=color, linestyle='-', label='')
                ax_f.plot(xxd, y2[:,cF], c=color, linestyle='-', label='')
                ax_diff.plot(xxd[1:n2], np.diff(y2[:n2,cF]), c=color, linestyle='-', label='')
                ax_diff.plot(xxd[1:n2], np.diff(y2[:n2,cP]), c=color, linestyle='-', label='')
                ax_i.plot(xxd, y2[:,cI], c=color, linestyle='-', label='')
                
                
            ax_p.set_yscale('log')
            ax_p.legend()
            ax_p.grid(which='both')

            ax_f.set_yscale('log')
            ax_f.legend()
            ax_f.grid(which='both')
            
            ax_diff.set_yscale('log')
            ax_diff.legend()
            ax_diff.grid(which='both')

            ax_i.set_yscale('log')
            ax_i.legend()
            ax_i.grid(which='both')
            

            format_dict = {'population':'{:,.0f}', 
                           'i0': '{:,.0f}',
                           'current i0': '{:,.0f}',
                           'recovered': '{:,.0f}',
                           'p0': '{:,.0f}',
                           'f0': '{:,.0f}',
                           'gamma_d': '{:.1f}',
                           'r0': '{:.2f}',
                           'r0_f': '{:.2f}',
                           'r0_p': '{:.2f}',
                           'death_rate': '{:.1%}',
                           'detection_rate': '{:.1%}',
                           'detection2_d': '{:.0f}',
                          }

            display('early fit')
            r1 = pd.DataFrame(m.calibs_1)
            r1['detection2_d'] = 1/r1['detection2']
            display(r1.style.format(format_dict).hide_index())

            display('late fit')
            r2 = pd.DataFrame(m.calibs_2)
            r2['detection2_d'] = 1/r2['detection2']
            display(r2.style.format(format_dict).hide_index())

        def whatif(self, minDate, forecast, calibs, ax_p, ax_f, ax_i):
            """Plot SIR5 projections of several scenarios on top of the observed data.

            Parameters
            ----------
            minDate : date of the first observation; day offsets are added to it
                to build the date axis.
            forecast : end (exclusive) of the day range of every projection, as
                passed to ``np.arange``.
            calibs : iterable of dicts, one per scenario. Each must provide 'n',
                'population', 'i0', 'p0', 'f0', 'r0', 'gamma_d', 'death_rate',
                'detection_rate', 'detection2' and 'label'. The projection of a
                scenario starts at ``self.days[-n]``, so n=0 starts on the first
                observed day and n=1 on the last one.
            ax_p, ax_f, ax_i : matplotlib axes receiving the cumulative positives,
                the cumulative fatalities and the infectious cases respectively.

            The observed series come from ``self.positives``, ``self.fatalities``
            and ``self.days``; titles use ``self.region`` and ``self.state``.
            Nothing is returned: the method only draws on the given axes.
            """
            x = self.days
            xd = minDate + np.arange(len(x)) * timedelta(days=1)

            ax_p.set_title('{}-{} Cumulative Positives'.format(self.region, self.state))
            ax_f.set_title('{}-{} Cumulative Fatalities'.format(self.region,self.state))
            ax_i.set_title('{}-{} Infectious Cases'.format(self.region,self.state))

            # observed data
            ax_p.plot(xd, self.positives, 'kx:', label='positives')
            ax_f.plot(xd, self.fatalities, 'k+:', label='fatalities')

            for c in calibs:
                # work on a copy so that the caller's dict is never modified
                p = objdict(c.copy())

                # day offsets of this scenario and the matching dates
                xx = np.arange(x[-p.n], forecast)
                xxd = minDate + xx * timedelta(days=1)

                y2 = SIR5(xx,
                          population=p.population, i0=p.i0, p0=p.p0, f0=p.f0,
                          mixing=1, mixing_s=1,
                          r0=p.r0, phi=1, q=1, gamma=1/p.gamma_d,
                          death_rate=p.death_rate,
                          intervention_day=0, intervention_lag=1, intervention_effect=0,
                          beta_slope=0,
                          detection_rate=p.detection_rate, detection2=p.detection2)

                ax_p.plot(xxd, y2[:,cP], linestyle='-', label=p.label)
                ax_f.plot(xxd, y2[:,cF], linestyle='-', label=p.label)
                ax_i.plot(xxd, y2[:,cI], linestyle='-', label=p.label)

            def thousands(value, pos):
                # integer tick labels with a thousands separator
                return format(int(value), ',')

            for axis in (ax_p, ax_f, ax_i):
                axis.legend()
                axis.grid(which='both')
                # one formatter instance per axis
                axis.yaxis.set_major_formatter(FuncFormatter(thousands))
            
#---------------------------------------------------------------
# Calibration target: one country taken from the `train` data set.
region, state = 'EU', 'France'
cutoff   = 100
truncate = 0   #use negative number to remove latest points
population = Population['{}-{}'.format(region, state)]

minDate, s1 = prep_data(train, region=region, state=state, cutoff=cutoff, truncate=truncate)
print("calibration starts on: ", minDate)

# independent numpy copies of the three series used by the calibration
x, positives, fatalities = (
    s1[column].to_numpy().copy()
    for column in ('Days', 'ConfirmedCases', 'Fatalities')
)

#---------------------------
# Overview of the raw data. Top row: fatalities, bottom row: positives.
# Columns: cumulative counts, daily increments, and both together on a log scale.
# The axes returned by plt.subplots are used directly instead of re-requesting
# each panel through plt.subplot(23k).
fig, axs = plt.subplots(2, 3, figsize=(18, 6))

for panel_row, cumulative, series_name in ((axs[0], fatalities, 'fatalities'),
                                           (axs[1], positives, 'positives')):
    daily_new = np.diff(cumulative)
    daily_label = 'new daily ' + series_name
    ax_cumulative, ax_daily, ax_log = panel_row

    ax_cumulative.plot(s1['Date'], cumulative, 'k+-', label=series_name)

    # np.diff returns one element less, so increments are dated from the second day
    ax_daily.plot(s1['Date'][1:], daily_new, 'k*-', label=daily_label)

    ax_log.plot(s1['Date'], cumulative, 'k+-', label=series_name)
    ax_log.plot(s1['Date'][1:], daily_new, 'k*-', label=daily_label)
    ax_log.set_yscale('log')

    for ax in panel_row:
        ax.grid()
        ax.legend()
        ax.xaxis.set_major_locator(AutoDateLocator())
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%b-%d'))

fig.autofmt_xdate()
plt.show()
#---------------------------


# Two fatality-ratio scenarios sharing the same gamma_d and fit windows (n1, n2).
# The optional 'method' key (set to 'linear' in other cells) is left unset here.
assumptions = [
    {'population': population, 'gamma_d': 7, 'death_rate': rate, 'n1': 15, 'n2': 20, 'label': label}
    for rate, label in ((0.005, 'g=7, fatality=0.5%'),
                        (0.02,  'g=7, fatality=2%'))
]

m = ModelLocal(state=state, region=region)
c = m.fit(x, positives, fatalities, assumptions)

fig, axs = plt.subplots(1, 4, figsize=(24, 6))
fig.autofmt_xdate()
for ax in axs:
    ax.xaxis.set_major_locator(AutoDateLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%b-%d'))

# panels, left to right: daily values, positives, fatalities, infectious cases
m.display(minDate=minDate, forecast=180,
          **dict(zip(('ax_diff', 'ax_p', 'ax_f', 'ax_i'), axs)))

plt.show()


#----------------------------------------
# What-if comparison of three scenarios built from the first assumption's fits.
fig, axs = plt.subplots(1, 3, figsize=(18, 6))
fig.autofmt_xdate()
for ax in axs:
    ax.xaxis.set_major_locator(AutoDateLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%b-%d'))

# early fit, projected from the first observed day (n=0)
calib0 = dict(m.calibs_1[0], n=0, label='initial')

# late fit, used as calibrated
calib1 = m.calibs_2[0].copy()

# late fit restarted on the last observed day (n=1) with the early r0.
# 'current i0' is stored on the calibration by m.display(), which therefore
# has to run before this code.
calib2 = dict(
    calib1,
    n=1,
    r0=m.calibs_1[0]['r0'],
    i0=calib1['current i0'],
    p0=positives[-1],
    f0=fatalities[-1],
    label='what-if',
)
calibs = [calib0, calib1, calib2]

m.whatif(minDate=minDate, calibs=calibs, ax_p=axs[0], ax_f=axs[1], ax_i=axs[2], forecast=180)

plt.show()

In [None]:
def prep_trackingdata(data, state, abrv, cutoff, truncate):
    """Extract the calibration series of one US state from the tracking data.

    Parameters
    ----------
    data : DataFrame with at least the columns 'state', 'date', 'death' and
        'positive'. 'date' is assumed to be datetime-like, since it is
        subtracted and divided by ``np.timedelta64`` -- TODO confirm against
        the loading code.
    state : full state name, used to look up ``Population['US-' + state]``.
    abrv : value matched against ``data['state']``.
    cutoff : the series starts on the first date whose 'death' value is
        strictly greater than this number.
    truncate : 0 keeps every row; any other value is applied as the slice
        ``[:truncate]`` (negative removes the latest rows, positive keeps
        that many rows from the start).

    Returns
    -------
    (population, minDate, x, positives, fatalities) where ``x`` is the number
    of days elapsed since ``minDate`` and the last three items are numpy arrays.

    Raises
    ------
    KeyError : if ``Population`` has no entry for the state.
    ValueError : if no row of the state exceeds ``cutoff`` (previously this
        silently produced empty arrays).
    """
    population = Population['US-'+state]

    state_rows = data[data['state']==abrv].sort_values(by='date', ascending=True)

    #first date on which this state exceeded the fatalities cutoff
    minDate = state_rows.loc[state_rows['death']>cutoff, 'date'].min()
    if pd.isnull(minDate):
        raise ValueError(
            "no record for state '{}' ({}) exceeds the cutoff of {} fatalities".format(state, abrv, cutoff))

    #keep only the records from that date onwards
    s1 = state_rows[state_rows['date']>=minDate].copy()

    #keep only the given number of days from the beginning, or remove the given number of days from the end
    if truncate != 0:
        s1 = s1[:truncate].copy()

    #number of days since the first day fatalities exceeded the cutoff
    s1['Days'] = (s1['date'] - minDate) / np.timedelta64(1, 'D')

    x, positives, fatalities = (
        s1[column].to_numpy().copy() for column in ('Days', 'positive', 'death')
    )

    return population, minDate, x, positives, fatalities
    

In [None]:
#---------------------------------------------------------------
# Calibration target: one US state from the tracking data set.
# Other (state, abrv) pairs that can be swapped in:
#   ('New York', 'NY'), ('California', 'CA'), ('New Jersey', 'NJ'),
#   ('Connecticut', 'CT'), ('Massachusetts', 'MA'), ('Florida', 'FL'),
#   ('Louisiana', 'LA'), ('Washington', 'WA'), ('Mississippi', 'MS'),
#   ('Ohio', 'OH'), ('District of Columbia', 'DC')
region = 'US'
state, abrv = 'Georgia', 'GA'

cutoff   = 25
truncate = 0   #use negative number to remove latest points

population, minDate, x, positives, fatalities = prep_trackingdata(
    data=data, state=state, abrv=abrv, cutoff=cutoff, truncate=truncate)
print("Calibration starts on {} for a cutoff of {} fatalities".format(minDate, cutoff))

# Two fatality-ratio scenarios; a gamma_d=14 variant was tried and is disabled.
assumptions = [
    {'population': population, 'gamma_d': 7, 'death_rate': rate, 'n1': 15, 'n2': 15,
     'label': label, 'method': 'linear'}
    for rate, label in ((0.005, 'n=15, g=7, fatality=0.5%'),
                        (0.02,  'n=15, g=7, fatality=2%'))
]

m = ModelLocal(state=state, region=region)
c = m.fit(x, positives, fatalities, assumptions)

fig, axs = plt.subplots(1, 4, figsize=(24, 6))
fig.autofmt_xdate()
for ax in axs:
    ax.xaxis.set_major_locator(AutoDateLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%b-%d'))

# panels, left to right: daily values, positives, fatalities, infectious cases
m.display(minDate=minDate, forecast=50,
          **dict(zip(('ax_diff', 'ax_p', 'ax_f', 'ax_i'), axs)))

plt.show()


#----------------------------------------
# What-if comparison of three scenarios built from the first assumption's fits.
fig, axs = plt.subplots(1, 3, figsize=(18, 6))
fig.autofmt_xdate()
for ax in axs:
    ax.xaxis.set_major_locator(AutoDateLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%b-%d'))

# early fit, projected from the first observed day (n=0); keeps its own label
calib0 = dict(m.calibs_1[0], n=0)

# late fit, used as calibrated
calib1 = dict(m.calibs_2[0], label='current forecast')

# late fit restarted on the last observed day (n=1) with the early r0.
# 'current i0' is stored on the calibration by m.display(), which therefore
# has to run before this code.
calib2 = dict(
    calib1,
    n=1,
    r0=m.calibs_1[0]['r0'],
    i0=calib1['current i0'],
    p0=positives[-1],
    f0=fatalities[-1],
    label='what-if',
)
calibs = [calib0, calib1, calib2]

m.whatif(minDate=minDate, calibs=calibs, ax_p=axs[0], ax_f=axs[1], ax_i=axs[2], forecast=180)
plt.show()





In [None]:
#try to calibrate perfect data: a synthetic epidemic is generated with SIR5
#from known parameters and then fed to the calibration

# NOTE(review): this rebinds `datetime` from the module imported at the top of
# the notebook to the class, for every cell executed afterwards.
from datetime import datetime, timedelta

region, state = 'US', 'New York'
population = Population[region+'-'+state]
i0 = 100

# parameters handed to SIR5
r0, gamma = 2, 1/14
mixing, mixing_s = 1, 1
phi, q = 1, 1
death_rate = 0.01
intervention_day, intervention_lag, intervention_effect = 25, 1, 0.6
beta_slope = 0
detection_rate, detection2 = 0.3, 1/100

# NOTE(review): p0 and f0 are computed here, but the SIR5 call below is given
# p0=i0 and f0=i0 -- confirm which one is intended.
p0 = i0 * detection_rate
f0 = i0 * death_rate

n = 50
x = np.arange(n)
y0 = SIR5(x,
          population=population, i0=i0, p0=i0, f0=i0,
          mixing=mixing, mixing_s=mixing_s,
          r0=r0, phi=phi, q=q, gamma=gamma,
          death_rate=death_rate,
          intervention_day=intervention_day, intervention_lag=intervention_lag,
          intervention_effect=intervention_effect,
          beta_slope=beta_slope,
          detection_rate=detection_rate, detection2=detection2)

# shape the model output like the `train` data so that prep_data can consume it
test = pd.DataFrame({'Fatalities': y0[:,cF], 'ConfirmedCases': y0[:,cP]})
test['NewFatalities'] = test['Fatalities'].diff()
test['NewCases'] = test['ConfirmedCases'].diff()
test['Date'] = datetime(2020,1,1) + np.arange(n) * timedelta(days=1)
test['Region'] = region
test['State'] = state
test['I'] = y0[:,cI]

display(test.tail())

#########
minDate, s1 = prep_data(test, region=region, state=state, cutoff=10, truncate=0)
x, positives, fatalities = (
    s1[column].to_numpy().copy()
    for column in ('Days', 'ConfirmedCases', 'Fatalities')
)

# three fatality-ratio hypotheses; the synthetic data was generated with 1%
assumptions = [
    {'population': population, 'gamma_d': 14, 'death_rate': rate, 'n1': 20, 'n2': 20, 'label': label}
    for rate, label in ((0.005, '0.5%'), (0.01, '1%'), (0.02, '2%'))
]

m = ModelLocal(region,state)
c = m.fit(x, positives, fatalities, assumptions)

fig, axs = plt.subplots(1, 4, figsize=(24, 6))
m.display(minDate=minDate, forecast=70,
          **dict(zip(('ax_diff', 'ax_p', 'ax_f', 'ax_i'), axs)))

fig.autofmt_xdate()
plt.show()


#----------------------------------------
# What-if comparison of three scenarios built from the first assumption's fits.
fig, axs = plt.subplots(1, 3, figsize=(18, 6))
fig.autofmt_xdate()
for ax in axs:
    ax.xaxis.set_major_locator(AutoDateLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%b-%d'))

# early fit, projected from the first observed day (n=0)
calib0 = dict(m.calibs_1[0], n=0)

# late fit, used as calibrated
calib1 = m.calibs_2[0].copy()

# late fit restarted on the last observed day (n=1) with the early r0.
# 'current i0' is stored on the calibration by m.display(), which therefore
# has to run before this code.
calib2 = dict(
    calib1,
    n=1,
    r0=m.calibs_1[0]['r0'],
    i0=calib1['current i0'],
    p0=positives[-1],
    f0=fatalities[-1],
    label='what-if',
)
calibs = [calib0, calib1, calib2]

m.whatif(minDate=minDate, calibs=calibs, ax_p=axs[0], ax_f=axs[1], ax_i=axs[2], forecast=60)

plt.show()



In [None]:

region = 'US'

# Daily totals per state. numeric_only keeps text columns out of the sum,
# whatever the default of the installed pandas version is.
c = train[train['Country_Region']==region]
c = c.groupby(['Province_State', 'Date']).sum(numeric_only=True).reset_index()

# only the states whose cumulative fatalities exceeded 1000 are calibrated
states = c[c['Fatalities']>1000]['Province_State'].unique()
c = c[c['Province_State'].isin(states)]

countries = c['Province_State'].unique()

results = []
for state in countries:
    try:
        m = ModelLocal(state=state, region=region)
        population = Population[region+'-'+state]

        minDate, s1 = prep_data(train, region=region, state=state, cutoff=100, truncate=0)

        x = s1['Days'].to_numpy().copy()
        positives = s1['ConfirmedCases'].to_numpy().copy()
        fatalities = s1['Fatalities'].to_numpy().copy()

        assumptions = [
            {'population':population, 'gamma_d':7, 'death_rate':0.002, 'n1':10, 'n2':10, 'label':'fatality=0.2%', 'method':'linear'},
            {'population':population, 'gamma_d':7, 'death_rate':0.01,  'n1':10, 'n2':10, 'label':'fatality=1%',   'method':'linear'},
        ]

        m.fit(x, positives, fatalities, assumptions)

        for i in range(len(assumptions)):
            # 'early' values are those of the first calibrated day; they were
            # previously indexed by the assumption number i, which made them
            # differ between assumptions of the same state.
            row = {'region'           : m.region,
                   'state'            : m.state,
                   'label'            : m.calibs_1[i]['label'],
                   'early_r0'         : m.calibs_1[i]['r0'],
                   'late_r0'          : m.calibs_2[i]['r0'],
                   'early_death_rate' : m.calibs_1[i]['death_rate'],
                   'late_death_rate'  : m.calibs_2[i]['death_rate'],
                   'early_fatalities' : fatalities[0],
                   'late_fatalities'  : fatalities[-1],
                   'early_positives'  : positives[0],
                   'late_positives'   : positives[-1],
                   'population'       : m.calibs_1[i]['population']
                  }
            results.append(row)
    except Exception as err:
        # best effort: a state that cannot be calibrated is skipped, but the
        # reason is reported instead of being silently discarded
        print('skipping {}-{}: {!r}'.format(region, state, err))

res = pd.DataFrame(results)
display(res)

# Summary per state: fitted R0 (early vs late) and the death-rate assumptions.
fig, axs = plt.subplots(ncols=3, figsize=(18,6))

# left panel: late and early R0 share the same axes
for r0_col, r0_label in (('late_r0', 'late R0'), ('early_r0', 'early R0')):
    sns.lineplot(data=res, x='state', y=r0_col, label=r0_label, marker='o', ax=axs[0])
axs[0].tick_params(axis='x', labelrotation=90)

# middle and right panels: death rates, shown as percentages on a 0-10% scale
for ax, rate_col, rate_label in ((axs[1], 'late_death_rate',  'late death rate'),
                                 (axs[2], 'early_death_rate', 'early death rate')):
    sns.lineplot(data=res, x='state', y=rate_col, err_style="bars", label=rate_label, marker='o', ax=ax)
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(1))
    ax.tick_params(axis='x', labelrotation=90)
    ax.set_ylim(0,0.1)

