In [1]:
%matplotlib inline
import pandas as pd
from itertools import product
from pmdarima import auto_arima
import matplotlib.pyplot as plt
import ipywidgets as widgets
import plotly.express as px
import warnings
from sklearn.metrics import mean_squared_error

warnings.simplefilter('ignore')

# Loading data

In [2]:
# Download the JHU CSSE global confirmed-cases time series (wide format:
# one row per region, one column per date).
df = pd.read_csv(
    r'https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_confirmed_global.csv'
)

In [3]:
# Preview the first five rows of the raw wide-format table (one column per date).
df.head()

Unnamed: 0,Province/State,Country/Region,Lat,Long,1/22/20,1/23/20,1/24/20,1/25/20,1/26/20,1/27/20,...,5/17/21,5/18/21,5/19/21,5/20/21,5/21/21,5/22/21,5/23/21,5/24/21,5/25/21,5/26/21
0,,Afghanistan,33.93911,67.709953,0,0,0,0,0,0,...,63598,63819,64122,64575,65080,65486,65728,66275,66903,67743
1,,Albania,41.1533,20.1683,0,0,0,0,0,0,...,132032,132071,132095,132118,132153,132176,132209,132215,132229,132244
2,,Algeria,28.0339,1.6596,0,0,0,0,0,0,...,125485,125693,125896,126156,126434,126651,126860,127107,127361,127646
3,,Andorra,42.5063,1.5218,0,0,0,0,0,0,...,13555,13569,13569,13569,13569,13569,13569,13569,13664,13671
4,,Angola,-11.2027,17.8739,0,0,0,0,0,0,...,30787,31045,31438,31661,31909,32149,32441,32623,32933,33338


## Drop unnecessary columns

In [4]:
# Keep only the country label and the per-date counts; rebinding `df`
# rather than mutating it in place.
df = df.drop(columns=['Province/State', 'Lat', 'Long'])

## Grouping the data by country

In [5]:
# Sum the province/state rows so that each country has a single row of
# cumulative counts, indexed by country name.
df = df.groupby('Country/Region').sum()
df.head()

Unnamed: 0_level_0,1/22/20,1/23/20,1/24/20,1/25/20,1/26/20,1/27/20,1/28/20,1/29/20,1/30/20,1/31/20,...,5/17/21,5/18/21,5/19/21,5/20/21,5/21/21,5/22/21,5/23/21,5/24/21,5/25/21,5/26/21
Country/Region,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Afghanistan,0,0,0,0,0,0,0,0,0,0,...,63598,63819,64122,64575,65080,65486,65728,66275,66903,67743
Albania,0,0,0,0,0,0,0,0,0,0,...,132032,132071,132095,132118,132153,132176,132209,132215,132229,132244
Algeria,0,0,0,0,0,0,0,0,0,0,...,125485,125693,125896,126156,126434,126651,126860,127107,127361,127646
Andorra,0,0,0,0,0,0,0,0,0,0,...,13555,13569,13569,13569,13569,13569,13569,13569,13664,13671
Angola,0,0,0,0,0,0,0,0,0,0,...,30787,31045,31438,31661,31909,32149,32441,32623,32933,33338


In [6]:
# Append one extra row holding the column-wise total over all countries.
# Built with pd.concat because DataFrame.append is no longer available in
# current pandas. The new row has no 'Country/Region' value yet.
world_totals = df.sum(axis=0).to_frame().T
t = pd.concat([df.reset_index(), world_totals], ignore_index=True)

In [7]:
# Inspect the last rows; the final one is the appended column-wise total.
t.tail()

Unnamed: 0,Country/Region,1/22/20,1/23/20,1/24/20,1/25/20,1/26/20,1/27/20,1/28/20,1/29/20,1/30/20,...,5/17/21,5/18/21,5/19/21,5/20/21,5/21/21,5/22/21,5/23/21,5/24/21,5/25/21,5/26/21
188,West Bank and Gaza,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,304074.0,304074.0,304532.0,304532.0,304968.0,305201.0,305201.0,305777.0,306334.0,306795.0
189,Yemen,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,6568.0,6586.0,6593.0,6613.0,6632.0,6649.0,6658.0,6662.0,6670.0,6688.0
190,Zambia,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,92460.0,92520.0,92630.0,92754.0,92920.0,93106.0,93201.0,93279.0,93428.0,93627.0
191,Zimbabwe,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,38572.0,38595.0,38612.0,38635.0,38664.0,38679.0,38682.0,38696.0,38706.0,38819.0
192,,557.0,655.0,941.0,1433.0,2118.0,2927.0,5578.0,6167.0,8235.0,...,163609626.0,164231810.0,164902902.0,165182315.0,165808190.0,166385985.0,166862060.0,167316360.0,167848205.0,168197107.0


In [8]:
# Name the totals row 'World'. It is always the last row of `t`, so address
# it by position instead of a hard-coded row number that silently breaks
# (relabelling a real country) when the number of countries changes.
t.at[t.index[-1], 'Country/Region'] = 'World'

In [9]:
# Reshape from wide (one column per date) to long format:
# one row per (country, date) with the count in 'confirmed'.
dft = pd.melt(
    t,
    id_vars=['Country/Region'],
    var_name='date',
    value_name='confirmed',
)

In [10]:
# Preview the long-format table: one row per (country, date) pair.
dft.head()

Unnamed: 0,Country/Region,date,confirmed
0,Afghanistan,1/22/20,0.0
1,Albania,1/22/20,0.0
2,Algeria,1/22/20,0.0
3,Andorra,1/22/20,0.0
4,Angola,1/22/20,0.0


In [11]:
def forecast(country):
    """Fit an auto-ARIMA model to one country's confirmed cases and plot a 30-day forecast.

    Reads the long-format ``dft`` frame from the notebook namespace, fits
    ``auto_arima`` on the full history of ``country``, displays the model
    summary and shows a Plotly line chart of the history followed by the
    forecast with its lower/upper confidence bounds.

    Parameters
    ----------
    country : str
        A value of ``dft['Country/Region']`` (for example ``'World'``).

    Raises
    ------
    ValueError
        If ``dft`` has no rows for ``country``.
    """
    def data_by_country(country):
        """Return the ``confirmed`` column of ``country`` indexed by daily dates."""
        data = dft[dft['Country/Region'] == country].copy()
        if data.empty:
            # Fail early with a clear message instead of an obscure error
            # from the index construction or the model fit.
            raise ValueError(f'No data found for country {country!r}')
        # Date labels look like '1/22/20'; parse them with an explicit
        # month/day/two-digit-year format rather than relying on inference.
        dates = pd.to_datetime(data['date'], format='%m/%d/%y')
        data.index = pd.DatetimeIndex(dates, freq='D')
        return data.drop(columns=['date', 'Country/Region'])

    def split_train_test(df_in, train_percentage=.95):
        """Split ``df_in`` chronologically into a leading train part and a trailing test part."""
        size = int(len(df_in) * train_percentage)
        return (df_in[:size], df_in[size:])

    def rmse(y_actual, y_prediction):
        """Return the root-mean-squared error between actual and predicted values."""
        # Square root of the MSE: the exponent 0.5, not a multiplication by 0.5.
        return mean_squared_error(y_actual, y_prediction) ** 0.5

    def get_prediction_df(model, data, n_periods):
        """Forecast ``n_periods`` days past the end of ``data``.

        Returns a frame with columns ``pred``, ``lower`` and ``upper``
        indexed by the forecast dates.
        """
        result = model.predict(n_periods=n_periods, return_conf_int=True, dynamic=False, typ='levels')
        # Build n_periods + 1 dates starting at the last observed day, then
        # drop the first so the forecast begins on the following day.
        index_ = pd.date_range(start=data.index[-1], periods=n_periods+1)[1:]
        # Use the raw values so the forecast is assigned positionally,
        # whether predict returned a plain array or an indexed Series.
        point_forecast = pd.Series(result[0]).to_numpy()
        pred = pd.DataFrame(
            {'pred': point_forecast, 'lower': result[1][:, 0], 'upper': result[1][:, 1]},
            index=index_,
        )
        return pred

    def plot_predictions(train, test, pred):
        """Plot the training history followed by the predictions next to the held-out actuals."""
        pred = pred.join(test).rename(columns={'confirmed': 'actual'})
        combined = pd.concat([train, pred])
        fig = px.line(combined, x=combined.index, y=combined.columns)
        fig.show()

    def plot_future_predictions(data, pred):
        """Plot the full history followed by the forecast and its bounds."""
        combined = pd.concat([data, pred])
        fig = px.line(combined, x=combined.index, y=combined.columns)
        fig.show()

    data = data_by_country(country)

    # Back-test on a held-out tail (currently disabled):
    # train, test = split_train_test(data)
    # model = auto_arima(train['confirmed'])
    # display(model.summary())
    # model.plot_diagnostics(figsize=(15,15));
    # pred = get_prediction_df(model, train, n_periods=len(test))
    # plot_predictions(train, test, pred)

    # future prediction
    future_model = auto_arima(data['confirmed'])
    display(future_model.summary())
    future_pred = get_prediction_df(future_model, data, n_periods=30)
    plot_future_predictions(data, future_pred)

In [12]:
# Dropdown that re-runs `forecast` for the selected country; the trailing
# semicolon suppresses the returned object's repr.
widgets.interact(
    forecast,
    country=['World', 'US', 'India', 'United Kingdom', 'Brazil'],
);

interactive(children=(Dropdown(description='country', options=('World', 'US', 'India', 'United Kingdom', 'Braz…