In [None]:
from __future__ import absolute_import
from __future__ import print_function

import requests
import io

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torchdiffeq import odeint_adjoint as odeint
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')


In [None]:
# Display whether torch can see a CUDA device; the `device` chosen in the first cell depends on this.
torch.cuda.is_available()

In [None]:
class SIRHCDQ(nn.Module):
    """Compartment ODE model: S, I, R, H (hospitalized), C (critical), D (dead)
    plus a tracker T driven by a learned quarantine term Q.

    ``forward(t, y)`` takes the 1-D state ``y = [S, I, R, H, C, D, T]`` (length 7)
    and returns dy/dt in the same order, so the module can be handed directly to
    ``torchdiffeq.odeint`` / ``odeint_adjoint``.

    Flows encoded by the rate methods:
        S -> I   at rate ((R_t + Q) / t_inc) * I * S
        I -> R   fraction m_a of I / t_inc; the rest (1 - m_a) goes to H
        H -> C   fraction c_a of H / t_hosp; the rest (1 - c_a) goes to R
        C -> D   fraction f_a of C / t_crit; the rest (1 - f_a) returns to H
        T grows  at Q * I (it is not subtracted from any other compartment)

    ``t_inc`` and ``R_t`` are fixed; ``t_hosp``, ``t_crit``, ``m_a``, ``c_a`` and
    ``f_a`` are learnable scalars; Q comes from a small MLP over the state.
    """

    def __init__(self):
        super(SIRHCDQ, self).__init__()
        # Fixed constants are registered as buffers so that `.to(device)` and
        # `.double()` move/cast them together with the parameters (plain tensor
        # attributes stay on the CPU and break the forward pass on CUDA).
        # persistent=False keeps state_dict() keys identical to earlier
        # checkpoints, which did not contain these constants.
        self.register_buffer('t_inc', torch.tensor([14.0], dtype=torch.double), persistent=False)
        self.register_buffer('R_t', torch.tensor([1.0], dtype=torch.double), persistent=False)

        # Learnable scalars. Trailing comments give the *initial* sampling
        # range only; nothing constrains the values during training.
        self.t_hosp = nn.Parameter(9.5*torch.rand(1)+0.5, requires_grad=True)  # (0.5, 10)
        self.t_crit = nn.Parameter(18*torch.rand(1)+2, requires_grad=True)  # (2, 20)
        self.m_a = nn.Parameter(0.5*torch.rand(1)+0.5, requires_grad=True)  # (0.5, 1)
        self.c_a = nn.Parameter(torch.rand(1), requires_grad=True)  # (0, 1)
        # f_a is a fraction like c_a, so draw it uniformly from (0, 1);
        # torch.randn could start it negative (a negative death flow) or above 1.
        self.f_a = nn.Parameter(torch.rand(1), requires_grad=True)  # (0, 1)

        # MLP mapping the 7-element state to the scalar quarantine term Q.
        self.net = nn.Sequential(
            nn.Linear(7, 10, bias=True),
            nn.ReLU(),
            nn.Linear(10, 20, bias=True),
            nn.ReLU(),
            nn.Linear(20, 1, bias=True),
        )

        # Small-weight, zero-bias initialisation so Q starts near zero.
        for m in self.net.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.1)
                nn.init.constant_(m.bias, val=0)

    def dS_dt(self, S, I, Q, R_t, t_inc):
        """Susceptible: lose ((R_t + Q) / t_inc) * I * S to infection."""
        return -((R_t + Q) / t_inc) * I * S

    def dI_dt(self, S, I, Q, R_t, t_inc):
        """Infected: gain the new infections, lose I / t_inc to R and H."""
        return ((R_t + Q) / t_inc) * I * S - (I / t_inc)

    def dH_dt(self, I, C, H, t_inc, t_hosp, t_crit, m_a, f_a):
        """Hospitalized: inflow from I and surviving C, outflow H / t_hosp."""
        return ((1 - m_a) * (I / t_inc)) + ((1 - f_a) * C / t_crit) - (H / t_hosp)

    def dC_dt(self, H, C, t_hosp, t_crit, c_a):
        """Critical: fraction c_a of the H outflow, outflow C / t_crit."""
        return (c_a * H / t_hosp) - (C / t_crit)

    def dR_dt(self, I, H, t_inc, t_hosp, m_a, c_a):
        """Recovered: fraction m_a of the I outflow plus (1 - c_a) of the H outflow."""
        return (m_a * I / t_inc) + (1 - c_a) * (H / t_hosp)

    def dD_dt(self, C, t_crit, f_a):
        """Dead: fraction f_a of the C outflow."""
        return f_a * C / t_crit

    # Quarantined population equation
    def dT_dt(self, I, Q):
        """Quarantine tracker: grows at Q * I."""
        return Q * I

    def forward(self, t, y):
        """Return dy/dt for the 1-D state y = [S, I, R, H, C, D, T].

        ``t`` is accepted for the odeint interface but not used (the system is
        autonomous). Each rate term has shape (1,), so the result has shape (7,).
        """
        S, I, R, H, C, D, T = y.double()
        Q = self.net(y)
        S_out = self.dS_dt(S, I, Q, self.R_t, self.t_inc)
        I_out = self.dI_dt(S, I, Q, self.R_t, self.t_inc)
        R_out = self.dR_dt(I, H, self.t_inc, self.t_hosp, self.m_a, self.c_a)
        H_out = self.dH_dt(I, C, H, self.t_inc, self.t_hosp, self.t_crit, self.m_a, self.f_a)
        C_out = self.dC_dt(H, C, self.t_hosp, self.t_crit, self.c_a)
        D_out = self.dD_dt(C, self.t_crit, self.f_a)
        T_out = self.dT_dt(I, Q)
        # Same order as the input state.
        results = torch.cat([S_out, I_out, R_out, H_out, C_out, D_out, T_out])

        return results

In [None]:
# Relative directory prefixes: inputs are read from train_path; output_path is
# reserved for generated files (not used in the cells shown here).
train_path, output_path = 'data/train/', 'data/output/'

In [None]:
# Load the daily hospitalization CSV, parse its dates and add a running total
# of new hospitalizations as 'Cum_Hosp'.
hos_data_raw = pd.read_csv(train_path + 'Hopitalization_HOU.csv')
hos_data_raw = hos_data_raw.assign(
    Date=pd.to_datetime(hos_data_raw['Date']),
    Cum_Hosp=hos_data_raw['New_Hosp'].cumsum(),
)
hos_data_raw

In [None]:
# Cumulative county-level cases and deaths from the NYT COVID-19 repository,
# restricted to Texas and to dates up to 2020-06-24.
url = "https://raw.githubusercontent.com/nytimes/covid-19-data/master/us-counties.csv"
response = requests.get(url, timeout=60)
# Fail loudly instead of parsing an HTML error page as CSV.
response.raise_for_status()
s = response.content
data_raw = pd.read_csv(io.StringIO(s.decode('utf-8')))
data_raw = data_raw[data_raw['state']=='Texas']
data_raw = data_raw.rename(columns={'date': 'Date', 'cases': 'ConfirmedCases', 'deaths': 'Fatalities'})
data_raw['Date'] = pd.to_datetime(data_raw['Date'])

data_raw = data_raw[data_raw['Date']<='2020-06-24']

# County populations. Names such as "Harris", "Montgomery" or "Houston" occur in
# several states, so keep Texas only; otherwise the last matching row from any
# state would win in the name -> population map.
# NOTE(review): assumes the usafacts file has a 'State' column of two-letter codes.
pop_info = pd.read_csv(train_path + 'covid_county_population_usafacts.csv')
pop_info = pop_info[pop_info['State'].str.strip() == 'TX'].copy()
# Drop the trailing word of the name (e.g. "Harris County" -> "Harris").
pop_info['County Name'] = pop_info['County Name'].apply(lambda x: ' '.join(x.split(' ')[:-1]))
county_lookup = dict(zip(pop_info['County Name'], pop_info['population']))

df = data_raw[['Date', 'county', 'ConfirmedCases', 'Fatalities']].reset_index(drop=True)


# Counties aggregated into the "Houston" area.
pan_houston = ["Harris","Fort Bend","Montgomery","Brazoria","Galveston","Liberty","Waller","Chambers","Austin"]
# Sum only the numeric columns per date (summing the county names is meaningless).
sum_fatalities_df = (
    df[df['county'].isin(pan_houston)]
    .groupby(['Date'])[['ConfirmedCases', 'Fatalities']]
    .sum()
)

sum_fatalities_df = sum_fatalities_df[sum_fatalities_df.index>=hos_data_raw['Date'].min()]
# Align the area totals to the hospitalization rows by date rather than by
# position, so a different end date or a missing day cannot shift the series;
# dates without an NYT total become NaN.
hos_data_raw['Fatalities'] = hos_data_raw['Date'].map(sum_fatalities_df['Fatalities'])
hos_data_raw['ConfirmedCases'] = hos_data_raw['Date'].map(sum_fatalities_df['ConfirmedCases'])
hos_data_raw['county'] = 'Houston'
# The model is fitted to the nine-county aggregate, so 'Houston' must map to that
# region's total population, not to a single county named "Houston".
county_lookup['Houston'] = sum(county_lookup[name] for name in pan_houston)

hos_data_raw

In [None]:
# Smooth each area's cumulative hospitalizations and deaths with a trailing
# 7-day mean, then apply cummax so the smoothed cumulative counts never
# decrease. The first 6 rows of every area have no full window and are NaN.
smoothed = (
    hos_data_raw
    .groupby('county', sort=False)[['Cum_Hosp', 'Fatalities']]
    .transform(lambda s: s.rolling(window=7).mean().cummax())
)
# Whole-column replacement (instead of masked .loc writes into int columns)
# lets the columns become float without pandas' deprecated incompatible-dtype
# setitem path.
hos_data_raw = hos_data_raw.assign(
    Cum_Hosp=smoothed['Cum_Hosp'],
    Fatalities=smoothed['Fatalities'],
)
# Keep rows with a positive smoothed hospitalization count (this also removes
# the NaN warm-up rows), drop any row still missing a value, and index by date.
# NOTE: this cell is not re-runnable: 'Date' stops being a column, and the
# series would otherwise be smoothed twice.
hos_data_raw = (
    hos_data_raw[hos_data_raw['Cum_Hosp'] > 0]
    .dropna()
    .set_index('Date')
)

hos_data_raw

In [None]:
# Spot-check a single row (positional index 65) of the date-indexed frame.
hos_data_raw.iloc[65, :]

In [None]:
def log_fun(vals):
    """Return a copy of ``vals`` with every non-zero entry replaced by its natural log.

    Zero entries are left at 0 (instead of -inf); negative entries become NaN.
    The input tensor is not modified.
    """
    nonzero = vals != 0
    result = vals.clone()
    result[nonzero] = vals[nonzero].log()
    return result

def MSE_loss(pred, targets):
    """Score modeled cumulative hospitalizations and deaths against observations.

    Args:
        pred: 2-D tensor of model states (one row per time step). In this
            notebook it is the scaled odeint output, whose columns follow the
            SIRHCDQ order [S, I, R, H, C, D, T].
        targets: 2-D tensor with observed cumulative hospitalizations in
            column 0 and cumulative deaths in column 1, one row per time step.

    Returns:
        (MSE, MSLE, ME): the mean over time of the summed squared errors, of the
        summed squared log errors (zeros stay 0 via ``log_fun``), and of the
        summed absolute errors, for the hospitalization and death series.
    """
    # Modeled cumulative hospitalizations are taken as R + H + C + D.
    # NOTE(review): R also receives m_a * I / t_inc (recoveries that never
    # entered H), so this sum may overcount hospitalizations; confirm intent.
    H_pred, D_pred = pred[:, 2] + pred[:, 3] + pred[:, 4] + pred[:, 5], pred[:, 5]
    H_true, D_true = targets[:, 0], targets[:, 1]

    H_diff = H_pred - H_true
    D_diff = D_pred - D_true
    MSE = torch.mean(torch.pow(H_diff, 2) + torch.pow(D_diff, 2))

    H_error_log = torch.pow(log_fun(H_pred) - log_fun(H_true), 2)
    D_error_log = torch.pow(log_fun(D_pred) - log_fun(D_true), 2)
    MSLE = torch.mean(H_error_log + D_error_log)

    # ME is the mean absolute error; before this change it was a verbatim copy
    # of the MSE expression and added no information.
    ME = torch.mean(torch.abs(H_diff) + torch.abs(D_diff))
    return MSE, MSLE, ME

def _forecast_frame(pred_y, data, area_name, predict_days):
    """Turn a population-scaled trajectory into a Date-indexed frame of modeled
    'Cum_Hosp' (R + H + C + D) and 'Fatalities' (D).

    ``pred_y`` must have ``len(data) + predict_days`` rows; the extra rows are
    labelled with the calendar days following the last observed date.
    """
    future_dates = [pd.to_datetime(data.index.max()) + pd.DateOffset(i + 1) for i in range(predict_days)]
    pred_y_np = pred_y.cpu().numpy()
    test_data = pd.DataFrame({
        'Date': list(data.index) + future_dates,
        'county': area_name,
        'Cum_Hosp': pred_y_np[:, 2] + pred_y_np[:, 3] + pred_y_np[:, 4] + pred_y_np[:, 5],
        'Fatalities': pred_y_np[:, 5],
    })
    return test_data.set_index('Date')


def _plot_fit(data, test_data):
    """Plot observed vs modeled series (top row) and the forecast tail (bottom row)."""
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(10, 10))
    ax1.set_title('Hopitalization')
    ax2.set_title('Fatalities')
    ax3.set_title('Hosptalization forecast')
    ax4.set_title('Fatalities forecast')

    data['Cum_Hosp'].plot(label='Hosptalization (train)', color='g', ax=ax1)
    test_data.loc[data.index, 'Cum_Hosp'].plot(label='Modeled Cases', color='r', ax=ax1)

    data['Fatalities'].plot(label='Fatalities (train)', color='g', ax=ax2)
    test_data.loc[data.index, 'Fatalities'].plot(label='Modeled Fatalities', color='r', ax=ax2)

    forecast = test_data.iloc[len(data):]
    if len(forecast):  # pandas cannot plot an empty series (predict_days == 0)
        forecast['Cum_Hosp'].plot(label='Hosptalization(forecast)', color='r', linestyle=':', ax=ax3)
        forecast['Fatalities'].plot(label='Fatalities (forecast)', color='g', linestyle=':', ax=ax4)

    ax1.legend(loc='best')
    ax2.legend(loc='best')
    fig.tight_layout()
    plt.draw()
    plt.pause(0.001)
    # fit_model plots every 20 iterations; release each figure so hundreds of
    # open figures do not accumulate over a long run.
    plt.close(fig)


def fit_model(train_full, area_name, predict_days=14, epochs=30000, make_plot=True,
              checkpoint_path="SIRHCDQ_2.th"):
    """Fit an SIRHCDQ ODE model to one area's cumulative hospitalizations and deaths.

    The last ``predict_days`` observed days are held out for validation. Every
    20 iterations, and after the final one, the model is integrated
    ``predict_days`` past the last observation, the held-out loss is printed and
    (optionally) diagnostic plots are drawn.

    Args:
        train_full: Date-indexed frame with 'Cum_Hosp' and 'Fatalities' columns.
        area_name: key into the global ``county_lookup`` population map.
        predict_days: number of held-out days; also the forecast horizon.
        epochs: number of optimizer steps.
        make_plot: draw diagnostic figures at each evaluation.
        checkpoint_path: file the lowest-training-MSE ``state_dict`` is saved to.

    Returns:
        Date-indexed DataFrame with modeled 'Cum_Hosp' and 'Fatalities' for the
        observed dates plus ``predict_days`` future days, from the last
        evaluation (None when ``epochs`` is 0).

    Raises:
        ValueError: if ``predict_days`` is negative or leaves no training days.
    """
    data = train_full.copy()
    n_obs = len(data)
    if not 0 <= predict_days < n_obs:
        raise ValueError("predict_days must be in [0, {}), got {}".format(n_obs, predict_days))

    population = county_lookup[area_name]
    # Initial compartments (counts) derived from the first observed day.
    n_deaths = data['Fatalities'].iloc[0]
    n_recovered = 50
    n_hosp = data['Cum_Hosp'].iloc[0] - n_recovered - n_deaths
    n_crit = n_hosp * 0.05
    n_hosp = n_hosp * 0.95
    n_infected = n_hosp
    n_quarantined = 10
    n_susceptible = population - n_infected - n_deaths - n_recovered - n_hosp - n_crit

    # State order must match SIRHCDQ.forward: [S, I, R, H, C, D, T]; the ODE
    # works on population fractions.
    initial_state = (torch.tensor(
        [n_susceptible, n_infected, n_recovered, n_hosp, n_crit, n_deaths, n_quarantined],
        dtype=torch.double,
    ) / population).to(device)

    time_length_train = n_obs - predict_days
    time_length_pre = n_obs + predict_days
    # One output per day, t = 0, 1, ..., n-1. linspace(0, n, n) would stretch
    # the step to n / (n - 1) days and make the two grids disagree.
    times = torch.arange(time_length_train, dtype=torch.double, device=device)
    times_pr = torch.arange(time_length_pre, dtype=torch.double, device=device)

    # Observed targets, shape (n_obs, 2): [Cum_Hosp, Fatalities]; loop-invariant.
    y_true_cases = torch.stack([
        torch.tensor(data['Cum_Hosp'].values, dtype=torch.double),
        torch.tensor(data['Fatalities'].values, dtype=torch.double),
    ], dim=1).to(device)
    y_train = y_true_cases[:time_length_train]

    model = SIRHCDQ().double().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    best_cost = None
    test_data = None
    for itr in range(1, epochs + 1):
        optimizer.zero_grad()
        pred_y = odeint(model, initial_state, times, rtol=0.01) * population

        cost, logcost, ME = MSE_loss(pred_y, y_train)
        # Checkpoint on the optimized objective (training MSE). The comparison
        # previously mixed MSLE with a stored MSE and never saved iteration 1.
        # Saving happens before optimizer.step(), so the weights match `cost`.
        if best_cost is None or cost.item() < best_cost:
            best_cost = cost.item()
            torch.save(model.state_dict(), checkpoint_path)

        print("Iteration {:d} train MSE loss {:.6f} train MSLE loss {:.6f},train ME loss {:.6f}".format(itr, cost, logcost, ME))
        cost.backward()
        optimizer.step()

        if itr % 20 == 0 or itr == epochs:
            with torch.no_grad():
                pred_pr = odeint(model, initial_state, times_pr, rtol=0.01) * population
            test_data = _forecast_frame(pred_pr, data, area_name, predict_days)
            if predict_days > 0:
                # Validate on the held-out days only (training days are scored above).
                val_cost, val_logcost, val_ME = MSE_loss(
                    pred_pr[time_length_train:n_obs], y_true_cases[time_length_train:]
                )
                print("Iteration {:d} val MSE loss {:.6f} val MSLE loss {:.6f} val ME loss {:.6f}".format(itr, val_cost, val_logcost, val_ME))
            if make_plot:
                _plot_fit(data, test_data)

    return test_data


In [None]:
# Fit the ODE model to the Houston-area series; `pr` holds the modeled frame.
county = 'Houston'
# Not passed to fit_model, which looks the population up again by area name.
population = county_lookup[county]
pr = fit_model(train_full=hos_data_raw, area_name=county)