In [None]:
import os
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.graph_objs import *

###### Links to the raw files of the Covid-19 dataset provided by JHU CSSE (CSSEGISandData)

In [None]:
# Shared prefixes of the raw-file URLs in the CSSEGISandData/COVID-19 repository.
_JHU_RAW_ROOT = 'https://raw.githubusercontent.com/CSSEGISandData/COVID-19'
_JHU_TIME_SERIES_DIR = _JHU_RAW_ROOT + '/master/csse_covid_19_data/csse_covid_19_time_series'

# Global time-series CSVs, one per case status.
confirmed_cases_file_link = _JHU_TIME_SERIES_DIR + '/time_series_covid19_confirmed_global.csv'
deaths_cases_file_link = _JHU_TIME_SERIES_DIR + '/time_series_covid19_deaths_global.csv'
recovered_cases_file_link = _JHU_TIME_SERIES_DIR + '/time_series_covid19_recovered_global.csv'
# Per-country case file kept on the web-data branch.
country_cases_file_link = _JHU_RAW_ROOT + '/web-data/data/cases_country.csv'

In [None]:
# Load the four tables straight from GitHub
# (https://github.com/CSSEGISandData/COVID-19/tree/master/csse_covid_19_data/csse_covid_19_time_series).
confirmed_df, deaths_df, recovered_df, cases_country_df = (
    pd.read_csv(file_link)
    for file_link in (
        confirmed_cases_file_link,
        deaths_cases_file_link,
        recovered_cases_file_link,
        country_cases_file_link,
    )
)

# Sanity check: (rows, columns) of every table, in load order.
print(confirmed_df.shape, deaths_df.shape, recovered_df.shape, cases_country_df.shape, sep=' | ')

In [None]:
# Column labels of the 'confirmed_df' dataframe.
confirmed_df.columns

In [None]:
# Preview the first rows of the confirmed-cases table.
confirmed_df.head()

In [None]:
# All rows of the confirmed-cases table whose 'Country/Region' is Australia.
confirmed_df[confirmed_df['Country/Region']=='Australia']

In [None]:
# All rows of the confirmed-cases table whose 'Country/Region' is India.
confirmed_df[confirmed_df['Country/Region']=='India']

In [None]:
# Number of distinct values in the 'Country/Region' column.
confirmed_df['Country/Region'].nunique()

In [None]:
# Column labels of the 'cases_country_df' dataframe.
cases_country_df.columns

In [None]:
# Preview the first rows of the per-country table.
cases_country_df.head()

In [None]:
# Count of missing values in each column of the per-country table.
cases_country_df.isna().sum()

### EDA

In [None]:
# World-wide totals: a one-row table with the sum of every numeric column of
# the per-country table. Coordinates, country names and timestamps are
# removed first.
global_data = cases_country_df.drop(columns=['Lat', 'Long_', 'Country_Region', 'Last_Update'])
# numeric_only=True: the "{:,.0f}" formatter below is applied to every cell
# and cannot format text, so non-numeric columns must not reach the summary.
global_summary = global_data.sum(numeric_only=True).to_frame().T
global_summary.style.format("{:,.0f}")

###### For Chart 1 : Total Confirmed Covid-19 Cases (Globally)

In [None]:
# Remove the location columns so only the per-date columns remain, then add
# up all rows to get one global value per date.
confirmed_ts = confirmed_df.drop(columns=['Lat', 'Long', 'Country/Region', 'Province/State'])
confirmed_ts_summary = confirmed_ts.sum()

In [None]:
# Row(s) of the per-country table for India.
cases_country_df[cases_country_df['Country_Region'] == 'India']

In [None]:
# Display the global confirmed series (one summed value per date column).
confirmed_ts_summary

In [None]:
# Chart 1: global confirmed cases over time, drawn as a line with point markers.
global_confirmed_trace = go.Scatter(
    x=confirmed_ts_summary.index,
    y=confirmed_ts_summary.values,
    mode='lines+markers',
)
fig_1 = go.Figure(data=global_confirmed_trace)

fig_1.update_layout(
    title='Total Confirmed Covid-19 Cases (Globally)',
    yaxis_title='Confirmed cases',
    xaxis_tickangle=315,
    plot_bgcolor='rgba(0,0,0,0)',  # fully transparent plot background
)

fig_1.show()

###### Defining a template plot function and color array

In [None]:
# Colour palette shared by the charts across the analysis.
color_arr = px.colors.qualitative.Dark24

In [None]:
def draw_plot(ts_array, ts_label, title, colors, mode_size, line_size, x_axis_title, y_axis_title, tickangle = 0, yaxis_type='', additional_annotations=None):
    """Build a multi-series line chart from a list of pandas Series.

    Parameters
    ----------
    ts_array : list of Series; each one becomes a trace (index -> x, values -> y).
    ts_label : legend names, in the same order as ``ts_array``.
    title : text placed as an annotation above the top-left of the plot area.
    colors : line colours, in the same order as ``ts_array``.
    mode_size : accepted for call compatibility; not used by this function.
    line_size : line widths, in the same order as ``ts_array``.
    x_axis_title, y_axis_title : axis titles; falsy values leave the axis untitled.
    tickangle : x tick label angle in degrees; 0 keeps plotly's default.
    yaxis_type : plotly axis type (e.g. 'log'); '' keeps plotly's default.
    additional_annotations : optional annotation dict, or iterable of
        annotation dicts, added after the title annotation.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    fig = go.Figure()

    # One line trace per series.
    for index, ts in enumerate(ts_array):
        fig.add_trace(go.Scatter(x = ts.index,
                                 y = ts.values,
                                 name = ts_label[index],
                                 line = dict(color = colors[index], width = line_size[index]),
                                 connectgaps = True))

    # Base x-axis properties.
    x_axis_dict = dict(showline = True,
                       showgrid = True,
                       showticklabels = True,
                       linecolor = 'rgb(204, 204, 204)',
                       linewidth = 2,
                       ticks = 'outside',
                       tickfont = dict(family = 'Arial', size = 12, color = 'rgb(82,82,82)'))
    if x_axis_title:
        x_axis_dict['title'] = x_axis_title
    # Any non-zero angle is honoured; the previous "> 0" test silently
    # dropped negative angles.
    if tickangle != 0:
        x_axis_dict['tickangle'] = tickangle

    # Base y-axis properties.
    y_axis_dict = dict(showline = True,
                       showgrid = True,
                       showticklabels = True,
                       linecolor = 'rgb(204, 204, 204)',
                       linewidth = 2)
    if yaxis_type != '':
        y_axis_dict['type'] = yaxis_type
    if y_axis_title:
        y_axis_dict['title'] = y_axis_title

    fig.update_layout(xaxis = x_axis_dict,
                      yaxis = y_axis_dict,
                      autosize = True,
                      margin = dict(autoexpand=True, l=100, r=20, t=110),
                      showlegend = True,
                      plot_bgcolor = 'rgba(0,0,0,0)',
                      legend = dict(x=0, y=1))

    # The title is rendered as an annotation anchored above the plot area.
    annotations = [dict(xref='paper', yref='paper', x=0.0, y=1.05, xanchor='left', yanchor='bottom',
                        text = title,
                        font = dict(family = 'Arial', size = 16, color = 'rgb(37,37,37)'),
                        showarrow = False)]

    # Caller-supplied annotations are added one by one. Appending the whole
    # list would nest a list inside the annotations list, which is not a
    # valid annotation entry.
    if additional_annotations:
        if isinstance(additional_annotations, dict):
            annotations.append(additional_annotations)
        else:
            annotations.extend(additional_annotations)

    fig.update_layout(annotations=annotations)

    return fig

###### For Chart 2 : Covid-19 Case Status

In [None]:
def _global_daily_totals(cases_df):
    """Sum a time-series table over all rows, leaving one value per date column."""
    return cases_df.drop(columns=['Lat', 'Long', 'Country/Region', 'Province/State']).sum()


confirmed_agg_ts = _global_daily_totals(confirmed_df)
death_agg_ts = _global_daily_totals(deaths_df)
recovered_agg_ts = _global_daily_totals(recovered_df)

# There is no time series for the active cases, so it is derived separately:
# active = confirmed - deaths - recovered.
# Values are paired by date label, on the confirmed series' dates. A
# positional zip would silently truncate or mis-pair the values if the three
# files carried different date columns; here a missing date shows up as NaN.
active_agg_ts = (
    confirmed_agg_ts
    - death_agg_ts.reindex(confirmed_agg_ts.index)
    - recovered_agg_ts.reindex(confirmed_agg_ts.index)
)

# The aggregated series are plotted in the next cell.

In [None]:
# Chart 2 inputs: one entry per series, in the same order in every list.
ts_array = [confirmed_agg_ts, active_agg_ts, recovered_agg_ts, death_agg_ts]
labels = ['Confirmed', 'Active', 'Recovered', 'Deaths']
colors = list(color_arr[:4])
mode_size = [8, 8, 12, 8]
line_size = [2, 2, 4, 2]

# Draw the chart with the draw_plot function defined above.
chart_2_options = dict(
    ts_array=ts_array,
    ts_label=labels,
    title='',
    colors=colors,
    mode_size=mode_size,
    line_size=line_size,
    x_axis_title='Date',
    y_axis_title='Case Count',
    tickangle=315,
    yaxis_type='',
    additional_annotations=[],
)
fig_2 = draw_plot(**chart_2_options)

fig_2.show()

In [None]:
# One two-column frame ('Date' plus the status name) per aggregated series.
global_status_series = {
    'Confirmed': confirmed_agg_ts,
    'Active': active_agg_ts,
    'Recovered': recovered_agg_ts,
    'Deaths': death_agg_ts,
}
a, b, c, d = (
    status_series.rename_axis('Date').reset_index(name=status_name)
    for status_name, status_series in global_status_series.items()
)

# Left-join everything onto the confirmed frame, matching rows by date.
df = a
for status_frame in (b, c, d):
    df = df.merge(status_frame, on='Date', how='left')
df

In [None]:
# Inspect the confirmed-cases frame ('Date', 'Confirmed').
a

In [None]:
# Rebuild the combined global table: one row per date with the Confirmed,
# Active, Recovered and Deaths totals.
a = pd.DataFrame(confirmed_agg_ts).rename(columns = {0:'Confirmed'}).reset_index().rename(columns = {'index':'Date'})
b = pd.DataFrame(active_agg_ts).rename(columns = {0:'Active'}).reset_index().rename(columns = {'index':'Date'})
c = pd.DataFrame(recovered_agg_ts).rename(columns = {0:'Recovered'}).reset_index().rename(columns = {'index':'Date'})
d = pd.DataFrame(death_agg_ts).rename(columns = {0:'Deaths'}).reset_index().rename(columns = {'index':'Date'})
df = a.merge(b, on='Date', how = 'left')
df = df.merge(c,on='Date', how = 'left')
df = df.merge(d,on='Date', how = 'left')
# The dates come from the CSV column labels and are plain strings. They are
# converted to timestamps so that min()/max() are chronological and the
# .date() calls in the axis title work (a string has no .date()).
df.index = pd.to_datetime(df.Date)
df = df.drop(columns = 'Date')

# Stacked area chart; the traces are added in stacking order (bottom first).
fig = go.Figure()
for status_column, status_color in (('Deaths', 'red'), ('Recovered', 'green'), ('Confirmed', 'grey')):
    fig.add_trace(go.Scatter(
        x=df.index.tolist(), y=df[status_column],
        hoverinfo='x+y',
        mode='lines+markers',
        line_color=status_color,
        stackgroup='one',
        name = status_column + " (Total =" + str(df[status_column].max()) + ')',
    ))

# The covered date range is spelled out in the axis title because the tick
# labels themselves are hidden.
fig.update_layout(template = 'plotly_white',yaxis=dict(title='Number of cases' ),xaxis=(dict(title='Date (' + str(df.index.min().date()) + ' to ' + str(df.index.max().date())+ ')', showticklabels=False)), legend=dict(x=0, y=1))
fig.show()

In [None]:
# source : https://commons.wikipedia.org/wiki/File:COVID-19_India_Total_Cases_Animated_Map.gif
from IPython.display import HTML

# Embed the animated map as an <img> tag in the cell output.
india_map_gif_url = 'https://upload.wikimedia.org/wikipedia/commons/9/95/COVID-19_India_Total_Cases_Animated_Map.gif'
HTML(f'<img src="{india_map_gif_url}" height ="600" width="400">')

###### For Country Level Drill Down

In [None]:
# Per-country table sorted by confirmed cases (largest first), styled with orange bars.
(
    cases_country_df
    .drop(columns=['Lat', 'Long_', 'Last_Update', 'People_Tested', 'People_Hospitalized'])
    .sort_values('Confirmed', ascending=False)
    .reset_index(drop=True)
    .style.bar(align='left', width=98, color='orange')
)

In [None]:
# The ten rows with the most confirmed cases.
(
    cases_country_df
    .drop(columns=['Lat', 'Long_', 'Last_Update', 'People_Tested', 'People_Hospitalized'])
    .sort_values('Confirmed', ascending=False)
    .reset_index(drop=True)
    .head(10)
)

In [None]:
# Per-country table sorted by recovered cases (largest first), styled with light-green bars.
(
    cases_country_df
    .drop(columns=['Lat', 'Long_', 'Last_Update', 'People_Tested', 'People_Hospitalized'])
    .sort_values('Recovered', ascending=False)
    .reset_index(drop=True)
    .style.bar(align='left', width=98, color='lightgreen')
)

In [None]:
# Per-country table sorted by deaths (largest first), styled with red bars.
(
    cases_country_df
    .drop(columns=['Lat', 'Long_', 'Last_Update', 'People_Tested', 'People_Hospitalized'])
    .sort_values('Deaths', ascending=False)
    .reset_index(drop=True)
    .style.bar(align='left', width=98, color='red')
)

In [None]:
# Per-country table sorted by active cases (largest first), styled with purple bars.
(
    cases_country_df
    .drop(columns=['Lat', 'Long_', 'Last_Update', 'People_Tested', 'People_Hospitalized'])
    .sort_values('Active', ascending=False)
    .reset_index(drop=True)
    .style.bar(align='left', width=98, color='purple')
)

In [None]:
# Recovery rate: recovered cases as a percentage of confirmed cases, per row.
cases_country_df1 = cases_country_df.copy()  # snapshot taken before the new column is added
cases_country_df['%Recovered'] = 100 * cases_country_df['Recovered'] / cases_country_df['Confirmed']

# Same styled table as before (sorted by confirmed cases), now including the
# '%Recovered' column; bars are green.
(
    cases_country_df
    .drop(columns=['Lat', 'Long_', 'Last_Update', 'People_Tested', 'People_Hospitalized'])
    .sort_values('Confirmed', ascending=False)
    .reset_index(drop=True)
    .style.bar(align='left', width=98, color='green')
)

In [None]:
# del cases_country_df['%Deaths']

In [None]:
# Overall recovery rate in percent: total recovered over total confirmed, rounded to 2 decimals.
((100*cases_country_df['Recovered'].sum())/cases_country_df['Confirmed'].sum()).round(2)

In [None]:
# Overall death rate in percent: total deaths over total confirmed, rounded to 2 decimals.
((100*cases_country_df['Deaths'].sum())/cases_country_df['Confirmed'].sum()).round(2)

In [None]:
# Display the per-country table, which now carries the '%Recovered' column added above.
cases_country_df

#### Focus : India

###### For Chart 3: 'Covid-19 Case' Trend in India

In [None]:
def _country_daily_totals(cases_df, country):
    """Return one country's per-date totals from a time-series table.

    All rows whose 'Country/Region' equals ``country`` are summed, and the
    date column labels are converted to timestamps for the result's index.
    """
    country_rows = cases_df[cases_df['Country/Region'] == country]
    totals = country_rows.drop(columns=['Lat', 'Long', 'Country/Region', 'Province/State']).sum()
    totals.index = pd.to_datetime(totals.index)
    return totals


confirmed_India_ts = _country_daily_totals(confirmed_df, 'India')
deaths_India_ts = _country_daily_totals(deaths_df, 'India')
recovered_India_ts = _country_daily_totals(recovered_df, 'India')

# active = confirmed - deaths - recovered, paired by date on the confirmed
# series' dates. A positional zip would silently truncate or mis-pair the
# values if the three files carried different date columns; here a missing
# date shows up as NaN.
active_India_ts = (
    confirmed_India_ts
    - deaths_India_ts.reindex(confirmed_India_ts.index)
    - recovered_India_ts.reindex(confirmed_India_ts.index)
)

In [None]:
# Chart 3 inputs: one entry per series, in the same order in every list.
ts_array = [confirmed_India_ts, active_India_ts, recovered_India_ts, deaths_India_ts]
labels = ['Confirmed', 'Active', 'Recovered', 'Deaths']
colors = list(color_arr[:4])
mode_size = [8, 8, 12, 8]
line_size = [2, 2, 4, 2]

# Draw the chart with the draw_plot function defined above.
chart_3_options = dict(
    ts_array=ts_array,
    ts_label=labels,
    title='',
    colors=colors,
    mode_size=mode_size,
    line_size=line_size,
    x_axis_title='Date',
    y_axis_title='Case Count',
    tickangle=315,
    yaxis_type='',
    additional_annotations=[],
)
fig_3 = draw_plot(**chart_3_options)

fig_3.show()

In [None]:
# One two-column frame ('Date' plus the status name) per India series.
india_status_series = {
    'Confirmed': confirmed_India_ts,
    'Active': active_India_ts,
    'Recovered': recovered_India_ts,
    'Deaths': deaths_India_ts,
}
a, b, c, d = (
    status_series.rename_axis('Date').reset_index(name=status_name)
    for status_name, status_series in india_status_series.items()
)

# Left-join everything onto the confirmed frame by date, then index by date.
df = a
for status_frame in (b, c, d):
    df = df.merge(status_frame, on='Date', how='left')
df = df.set_index('Date')
df

In [None]:
# Rebuild the combined India table: one row per date with the Confirmed,
# Active, Recovered and Deaths totals.
a = pd.DataFrame(confirmed_India_ts).rename(columns = {0:'Confirmed'}).reset_index().rename(columns = {'index':'Date'})
b = pd.DataFrame(active_India_ts).rename(columns = {0:'Active'}).reset_index().rename(columns = {'index':'Date'})
c = pd.DataFrame(recovered_India_ts).rename(columns = {0:'Recovered'}).reset_index().rename(columns = {'index':'Date'})
# The deaths series is named deaths_India_ts; the former 'death_India_ts'
# does not exist and raised NameError.
d = pd.DataFrame(deaths_India_ts).rename(columns = {0:'Deaths'}).reset_index().rename(columns = {'index':'Date'})
df = a.merge(b, on='Date', how = 'left')
df = df.merge(c,on='Date', how = 'left')
df = df.merge(d,on='Date', how = 'left')
df.index = df.Date
df = df.drop(columns = 'Date')

# Stacked area chart; the traces are added in stacking order (bottom first).
fig = go.Figure()
for status_column, status_color in (('Deaths', 'red'), ('Recovered', 'green'), ('Confirmed', 'grey')):
    fig.add_trace(go.Scatter(
        x=df.index.tolist(), y=df[status_column],
        hoverinfo='x+y',
        mode='lines+markers',
        line_color=status_color,
        stackgroup='one',
        name = status_column + " (Total =" + str(df[status_column].max()) + ')',
    ))

# The covered date range is spelled out in the axis title because the tick
# labels themselves are hidden.
fig.update_layout(template = 'plotly_white',yaxis=dict(title='Number of cases' ),xaxis=(dict(title='Date (' + str(df.index.min().date()) + ' to ' + str(df.index.max().date())+ ')', showticklabels=False)), legend=dict(x=0, y=1))
fig.show()

###### Chart 4 : Covid-19 Transmission Timeline in India - In different LockDown Phases

In [None]:
# Transmission Timeline in India in Different lockdown phases
# LockDown 1.O = '03/25/2020' - '04/14/2020'
# LockDown 2.O = '04/15/2020' - '05/03/2020'
# LockDown 3.O = '05/04/2020' - '05/17/2020'
# LockDown 4.O = '05/18/2020' - '05/31/2020'

def draw_plot1(ts_array, ts_label, title, colors, mode_size, line_size, x_axis_title, y_axis_title, tickangle = 0, yaxis_type='', additional_annotations=None):
    """Build a multi-series chart whose traces are drawn as lines with point markers.

    Parameters
    ----------
    ts_array : list of Series; each one becomes a trace (index -> x, values -> y).
    ts_label : legend names, in the same order as ``ts_array``.
    title : text placed as an annotation above the top-left of the plot area.
    colors : line colours, in the same order as ``ts_array``.
    mode_size : accepted for call compatibility; not used by this function.
    line_size : line widths, in the same order as ``ts_array``.
    x_axis_title, y_axis_title : axis titles; falsy values leave the axis untitled.
    tickangle : x tick label angle in degrees; 0 keeps plotly's default.
    yaxis_type : plotly axis type (e.g. 'log'); '' keeps plotly's default.
    additional_annotations : optional annotation dict, or iterable of
        annotation dicts, added after the title annotation.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    fig = go.Figure()

    # One 'lines+markers' trace per series.
    for index, ts in enumerate(ts_array):
        fig.add_trace(go.Scatter(x = ts.index,
                                 y = ts.values,
                                 name = ts_label[index],
                                 line = dict(color = colors[index], width = line_size[index]),
                                 connectgaps = True,
                                 mode = 'lines+markers'))

    # Base x-axis properties.
    x_axis_dict = dict(showline = True,
                       showgrid = True,
                       showticklabels = True,
                       linecolor = 'rgb(204, 204, 204)',
                       linewidth = 2,
                       ticks = 'outside',
                       tickfont = dict(family = 'Arial', size = 12, color = 'rgb(82,82,82)'))
    if x_axis_title:
        x_axis_dict['title'] = x_axis_title
    # Any non-zero angle is honoured; the previous "> 0" test silently
    # dropped negative angles.
    if tickangle != 0:
        x_axis_dict['tickangle'] = tickangle

    # Base y-axis properties.
    y_axis_dict = dict(showline = True,
                       showgrid = True,
                       showticklabels = True,
                       linecolor = 'rgb(204, 204, 204)',
                       linewidth = 2)
    if yaxis_type != '':
        y_axis_dict['type'] = yaxis_type
    if y_axis_title:
        y_axis_dict['title'] = y_axis_title

    fig.update_layout(xaxis = x_axis_dict,
                      yaxis = y_axis_dict,
                      autosize = True,
                      margin = dict(autoexpand=True, l=100, r=20, t=110),
                      showlegend = True,
                      plot_bgcolor = 'rgba(0,0,0,0)',
                      legend = dict(x=0, y=1))

    # The title is rendered as an annotation anchored above the plot area.
    annotations = [dict(xref='paper', yref='paper', x=0.0, y=1.05, xanchor='left', yanchor='bottom',
                        text = title,
                        font = dict(family = 'Arial', size = 16, color = 'rgb(37,37,37)'),
                        showarrow = False)]

    # Caller-supplied annotations are added one by one. Appending the whole
    # list would nest a list inside the annotations list, which is not a
    # valid annotation entry.
    if additional_annotations:
        if isinstance(additional_annotations, dict):
            annotations.append(additional_annotations)
        else:
            annotations.extend(additional_annotations)

    fig.update_layout(annotations=annotations)

    return fig


# Lockdown phase to display (1-4).
Phase = 4

# Lockdown phase -> (first day, last day), both inclusive.
LOCKDOWN_PHASES = {
    1: ('2020-03-25', '2020-04-14'),
    2: ('2020-04-15', '2020-05-03'),
    3: ('2020-05-04', '2020-05-17'),
    4: ('2020-05-18', '2020-05-31'),
}

# Fail loudly on an unknown phase. The previous if/elif chain left
# start_date/end_date undefined, or silently reused the values of an earlier
# run of this cell.
if Phase not in LOCKDOWN_PHASES:
    raise ValueError(f"Phase must be one of {sorted(LOCKDOWN_PHASES)}, got {Phase!r}")
start_date, end_date = LOCKDOWN_PHASES[Phase]


def _within_phase(ts):
    """Keep the observations dated from start_date to end_date, both inclusive."""
    return ts[(ts.index >= start_date) & (ts.index <= end_date)]


confirmed_India_ts1 = _within_phase(confirmed_India_ts)
active_India_ts1 = _within_phase(active_India_ts)
recovered_India_ts1 = _within_phase(recovered_India_ts)
deaths_India_ts1 = _within_phase(deaths_India_ts)

# labels, colors and mode_size are reused from the Chart 3 cell.
line_size = [2,2,2,2]
ts_array1 = [confirmed_India_ts1, active_India_ts1, recovered_India_ts1, deaths_India_ts1]
fig_4 = draw_plot1(ts_array = ts_array1,
                 ts_label = labels,
                  title = '',
                 colors = colors, mode_size = mode_size,
                 line_size = line_size,
                 x_axis_title = 'Date',
                 y_axis_title = 'Case Count',
                 tickangle = 315,
                 yaxis_type = '',additional_annotations =[])

fig_4.show()

###### Chart 5 : Covid-19 Transmission in India (Semi-Log Plot)

In [None]:
# Chart 5: the Chart 3 series again (ts_array, labels and colors are reused
# from that cell), this time on a logarithmic y-axis.
mode_size = [8, 8, 12, 8]
line_size = [4, 2, 2, 2]

chart_5_options = dict(
    ts_array=ts_array,
    ts_label=labels,
    title='',
    colors=colors,
    mode_size=mode_size,
    line_size=line_size,
    x_axis_title='Date',
    y_axis_title='Case Count',
    tickangle=315,
    yaxis_type='log',
    additional_annotations=[],
)
fig_5 = draw_plot(**chart_5_options)

fig_5.show()

#### Modelling & Prediction

###### SIR Model for Spread of Disease -- The Differential equation Model

In [None]:
from scipy.integrate import solve_ivp
from scipy.optimize import minimize
import matplotlib.pyplot as plt
from datetime import timedelta, datetime

In [None]:
# Start date per country, written in the same m/d/yy form as the default
# start_date of Learner below.
#     'Italy' : '1/31/20'  (disabled)
START_DATE = dict(India='1/30/20')
class Learner(object):
    """Fit a two-parameter SIR model to one country's case counts and project it forward.

    The confirmed, recovered and death series are downloaded from the CSV
    links defined earlier in the notebook. ``train`` searches for the
    (beta, gamma) pair minimising ``loss``, integrates the model over
    ``predict_range`` days, and writes the resulting table and chart to
    ``<country>.csv`` and ``<country>.png`` in the working directory.
    """

    def __init__(self,country,loss,start_date = '1/22/20', predict_range = 366, s_O = 1000000, i_O = 2, r_O = 10):
        """Store the configuration.

        country: value matched against the 'Country/Region' column.
        loss: callable ``loss(point, data, recovered, s_O, i_O, r_O)`` to minimise.
        start_date: first date column to use, written like the CSV headers (m/d/yy).
        predict_range: total number of days covered by the projection.
        s_O, i_O, r_O: initial susceptible / infected / recovered counts.
        """
        self.country = country
        self.loss = loss
        self.start_date = start_date
        self.predict_range = predict_range
        self.s_O = s_O
        self.i_O = i_O
        self.r_O = r_O

    def _load_series(self, file_link, country):
        """Download one time-series CSV and return ``country``'s row from start_date onwards."""
        df = pd.read_csv(file_link)
        df = df.drop(['Province/State'], axis = 1)
        country_df = df[df['Country/Region'] == country]
        # NOTE: only the first matching row is used; a country split over
        # several rows is therefore not summed.
        return country_df.iloc[0].loc[self.start_date:]

    def load_confirmed(self, country):
        """Confirmed-case series of ``country`` from start_date onwards."""
        return self._load_series(confirmed_cases_file_link, country)

    def load_recovered(self, country):
        """Recovered-case series of ``country`` from start_date onwards."""
        return self._load_series(recovered_cases_file_link, country)

    def load_deaths(self, country):
        """Death series of ``country`` from start_date onwards."""
        return self._load_series(deaths_cases_file_link, country)

    def extend_index(self, index, new_size):
        """Return the labels of ``index`` followed by consecutive days until ``new_size`` labels exist.

        The last label is parsed as '%m/%d/%y'. Added labels are produced by
        strftime with the same pattern, so they are zero-padded.
        """
        values = index.values
        current = datetime.strptime(index[-1], '%m/%d/%y')
        while len(values) < new_size:
            current = current + timedelta(days=1)
            values = np.append(values, datetime.strftime(current, '%m/%d/%y'))
        return values

    def predict(self, beta, gamma, data, recovered, death, country, s_O, i_O, r_O):
        """Integrate the SIR equations and pad the observations to the projection length.

        Returns ``(new_index, extended_actual, extended_recovered,
        extended_death, solution)``. The three ``extended_*`` arrays are the
        observed values followed by ``None`` for the days that have no
        observation, and ``solution`` is the ``solve_ivp`` result sampled
        once per day. ``country`` is accepted but not used.
        """
        new_index = self.extend_index(data.index, self.predict_range)
        size = len(new_index)

        def SIR(t, y):
            S, I = y[0], y[1]
            return [-beta*S*I , beta*S*I-gamma*I , gamma*I]

        extended_actual = np.concatenate((data.values, [None] * (size - len(data.values))))
        extended_recovered = np.concatenate((recovered.values, [None] * (size - len(recovered.values))))
        extended_death = np.concatenate((death.values, [None] * (size - len(death.values))))
        return new_index, extended_actual, extended_recovered, extended_death, solve_ivp(SIR, [0,size], [s_O, i_O, r_O], t_eval = np.arange(0,size,1))

    def train(self):
        """Fit beta and gamma, project the epidemic, save the results and return ``(df, fig)``."""
        recovered = self.load_recovered(self.country)
        death = self.load_deaths(self.country)
        # Currently infected = confirmed - recovered - deaths.
        data = (self.load_confirmed(self.country) - recovered - death)

        # Use the loss passed to the constructor, not whatever global
        # happens to be called `loss`.
        optimal = minimize(self.loss, [0.001,0.001], args=(data, recovered, self.s_O, self.i_O, self.r_O), method = 'L-BFGS-B', bounds = [(0.00000001,0.4), (0.00000001,0.4)])
        print(optimal)
        beta, gamma = optimal.x
        new_index, extended_actual, extended_recovered, extended_death, prediction = self.predict(beta, gamma, data, recovered, death, self.country, self.s_O, self.i_O, self.r_O)
        # 'Infected data' holds the observed infections (extended_actual);
        # it previously duplicated the death series.
        df= pd.DataFrame({'Infected data' : extended_actual, 'Recovered data' : extended_recovered, 'Death data' : extended_death, 'Susceptible' : prediction.y[0], 'Infected' : prediction.y[1], 'Recovered' : prediction.y[2]}, index = new_index)
        # Save the results to CSV.
        df.to_csv(f"{self.country}.csv")
        fig, ax = plt.subplots(figsize=(15,10))
        ax.set_title(self.country)
        df.plot(ax=ax)
        print(f"country={self.country}, beta = {beta:.8f}, gamma = {gamma:.8f}, r_O:{(beta/gamma):.8f}")
        fig.savefig(f"{self.country}.png")

        return df, fig

In [None]:
def loss(point, data, recovered, s_O, i_O, r_O):
    """Weighted RMSE between an SIR simulation and the observed series.

    point: (beta, gamma) pair being evaluated.
    data: observed infected counts, one value per day.
    recovered: observed recovered counts, one value per day.
    s_O, i_O, r_O: initial susceptible / infected / recovered values.

    Returns ``alpha * rmse(infected) + (1 - alpha) * rmse(recovered)`` with
    alpha = 0.1.
    """
    size = len(data)
    beta, gamma = point

    def SIR(t,y):
        S, I = y[0], y[1]
        return [-beta*S*I , beta*S*I-gamma*I , gamma*I]

    solution = solve_ivp(SIR, [0,size], [s_O,i_O,r_O], t_eval = np.arange(0,size,1), vectorized = True)
    # Infected compartment against the observed infections.
    l1 = np.sqrt(np.mean((solution.y[1] - data)**2))
    # Recovered compartment against the observed recoveries. This was
    # previously compared with `data`, leaving `recovered` unused.
    l2 = np.sqrt(np.mean((solution.y[2] - recovered)**2))
    alpha = 0.1
    return alpha*l1 + (1-alpha)*l2

###### For Italy

In [None]:
# italy_learner = Learner(country = 'Italy', loss = loss)

In [None]:
# italy_df, italy_fig = italy_learner.train()

###### For India

In [None]:
# i_O is set to 3 as there were only 3 in february.
india_learner = Learner(
    country='India',
    loss=loss,
    i_O=3,
)

In [None]:
# Fit the model for India; train() also writes India.csv and India.png to the working directory.
india_df, india_fig = india_learner.train()

###### Loading the saved results from the SIR model

In [None]:
# Read back the file written by Learner.train(), which saves to
# "<country>.csv" in the working directory. A hard-coded absolute path only
# works on the original author's machine.
india_sir = pd.read_csv(f"{india_learner.country}.csv")

In [None]:
# The saved index comes back as an unnamed first column: name it 'Datetime'
# and make it the index again.
india_sir = (
    india_sir
    .rename(columns={'Unnamed: 0': 'Datetime'})
    .set_index('Datetime')
)
india_sir.head()

In [None]:
def plot_sir_prediction(title, df_sir, remove_series=[], yaxis_type='',yaxis_title=''):
    """Plot the observed and simulated SIR columns of ``df_sir`` as line traces.

    title: text shown as an annotation above the top-left of the plot area.
    df_sir: frame holding the columns 'Infected data', 'Recovered data',
        'Death data', 'Susceptible', 'Infected' and 'Recovered'; its index
        is used for the x-axis.
    remove_series: column names to leave out of the chart.
    yaxis_type: plotly axis type (e.g. 'log'); '' keeps plotly's default.
    yaxis_title: y-axis title; falls back to "Case Count" when empty.

    Returns the plotly Figure.
    """
    # Column name -> line colour, in drawing order. Every line has width 2.
    series_colors = [
        ('Infected data', color_arr[0]),
        ('Recovered data', color_arr[9]),
        ('Death data', color_arr[3]),
        ('Susceptible', color_arr[8]),
        ('Infected', color_arr[10]),
        ('Recovered', color_arr[13]),
    ]

    fig = go.Figure()
    for series_name, series_color in series_colors:
        if series_name in remove_series:
            continue
        fig.add_trace(go.Scatter(x = df_sir.index,
                                 y = df_sir[series_name],
                                 name = series_name,
                                 line = dict(color = series_color, width = 2),
                                 connectgaps = True))

    axis_line_color = 'rgb(204, 204, 204)'
    xaxis = dict(title = 'Date',
                 showline = True,
                 showgrid = True,
                 showticklabels = True,
                 linecolor = axis_line_color,
                 linewidth = 2,
                 ticks = 'outside',
                 tickangle = 280,
                 tickfont = dict(family = 'Arial', size = 12, color = 'rgb(82,82,82)'))
    yaxis = dict(title = yaxis_title if yaxis_title else "Case Count",
                 showline = True,
                 showgrid = True,
                 showticklabels = True,
                 linecolor = axis_line_color,
                 linewidth = 2)
    if yaxis_type != '':
        yaxis['type'] = yaxis_type

    # The title is rendered as an annotation anchored above the plot area.
    title_annotation = dict(xref='paper', yref='paper', x=0.0, y=1.05, xanchor='left', yanchor='bottom',
                            text = title,
                            font = dict(family = 'Arial', size = 16, color = 'rgb(37,37,37)'),
                            showarrow = False)

    fig.update_layout(xaxis = xaxis,
                      yaxis = yaxis,
                      autosize = True,
                      margin = dict(autoexpand=True, l=100, r=20, t=110),
                      showlegend = True,
                      annotations = [title_annotation],
                      plot_bgcolor = 'rgba(0,0,0,0)')

    return fig

###### For Chart 6 : SIR Model -- Covid-19 Transmission -- Prediction -- INDIA

In [None]:
# Chart 6: observed and simulated SIR series for India.
fig_6 = plot_sir_prediction(
    df_sir=india_sir,
    title='SIR Model -- Covid-19 Transmission -- Prediction -- INDIA',
)
fig_6.show()