# COVID-19 Time Series Forecasting

In [None]:
from datetime import timedelta
from pathlib import Path
import re
import itertools
import asyncio
import warnings

import matplotlib.pyplot as plt

import ipywidgets as widgets
import numpy as np
import pandas as pd
import plotly.express as px
import folium
from tqdm.notebook import tqdm

from statsmodels.tsa.statespace.sarimax import SARIMAX
import nest_asyncio
nest_asyncio.apply()

# Make matplotlib figures 20 x 10 inches by default.
plt.rcParams.update({"figure.figsize": (20, 10)})

## Load all data

In [None]:
# Read every CSV in the working directory into a DataFrame, keyed by file name.
datasets = {}
for csv_path in tqdm(Path('.').glob('*.csv')):
    datasets[csv_path.name] = pd.read_csv(csv_path)

In [None]:
def _normalize_column(label):
    """Return a normalized column label.

    Labels that start like a date ('m/d/yy') are kept unchanged. Any other
    label is cut at its first '/', lowercased, and stripped of trailing
    underscores (e.g. 'Country/Region' -> 'country', 'Long_' -> 'long').
    """
    if re.match(r'\d{1,2}/\d{1,2}/\d{1,2}', label):
        return label
    return label.split('/')[0].lower().rstrip('_')


datasets = {name: frame.rename(columns=_normalize_column) for name, frame in datasets.items()}

## Clean the dataframes

### Aggregate province/state rows into per-country totals

In [None]:
# Collapse province/state rows into one row per country by summing them.
# NOTE(review): which non-numeric columns (e.g. 'province') survive this sum
# depends on the pandas version, so column positions after this step are not
# stable. Select date columns by label rather than by position.
confirmed_df, deaths_df = (
    datasets[csv_name].groupby('country').sum().reset_index()
    for csv_name in (
        'time_series_covid19_confirmed_global.csv',
        'time_series_covid19_deaths_global.csv',
    )
)

## Explore aggregated data

In [None]:
country_df = datasets['cases_country.csv']


@widgets.interact(num_rows=widgets.IntSlider(min=1, max=len(country_df), continuous_update=False, description='N:'))
def render_df(num_rows):
    """Display the `num_rows` countries with the most confirmed cases as a red-shaded table."""
    shown_columns = ['country_region', 'confirmed', 'deaths', 'recovered', 'active']
    ranked = country_df.sort_values(by='confirmed', ascending=False)
    table = ranked.head(num_rows)[shown_columns]
    display(table.style.background_gradient(cmap='Reds'))

## Worst-affected countries

In [None]:
# NOTE(review): this reuses the name `render_df` from the table cell above.
# The interact widget is already bound, but the module-level name is shadowed.
@widgets.interact(num_rows=widgets.IntSlider(min=1, max=len(country_df), continuous_update=False, description='N:'))
def render_df(num_rows):
    """Bubble chart of deaths vs. recovered for the top `num_rows` countries by confirmed cases.

    Bubble size encodes confirmed cases. Missing values are plotted as 0.
    """
    top = (
        country_df
        .sort_values(by='confirmed', ascending=False)
        .head(num_rows)
        .fillna(0.0)
    )
    px.scatter(
        top,
        x='deaths',
        y='recovered',
        size='confirmed',
        color='country_region',
        hover_name='country_region',
    ).show()

In [None]:
@widgets.interact(country=confirmed_df['country'].unique())
def plot_daily_cases(country):
    """Plot confirmed cases and deaths over time for one country.

    Date columns are selected by their 'm/d/yy' labels rather than by
    position. The result therefore does not depend on which non-date columns
    (lat, long and, depending on the pandas version, possibly 'province')
    survived the per-country groupby().sum().

    Args:
        country: Value of the 'country' column present in both
            ``confirmed_df`` and ``deaths_df``.
    """
    # Same date pattern the column-renaming step uses to keep date labels intact.
    date_cols = r'^\d{1,2}/\d{1,2}/\d{1,2}'
    ts_confirmed = confirmed_df.set_index('country').filter(regex=date_cols).loc[country]
    ts_deaths = deaths_df.set_index('country').filter(regex=date_cols).loc[country]
    # Left join keeps the confirmed series' date order.
    df = ts_confirmed.rename('confirmed').to_frame().join(ts_deaths.rename('deaths'))
    df.index = pd.to_datetime(df.index, format='%m/%d/%y')
    df = df.rename_axis('date').reset_index()
    # Plot only the value columns; passing all columns would also plot 'date' as a series.
    fig = px.line(df, x='date', y=['confirmed', 'deaths'], title=country)
    fig.show()

### Top 10 worst-hit countries

In [None]:
def plot_top10(metric, n=10):
    """Show a bar chart of the `n` countries with the highest `metric` value.

    Args:
        metric: Column of ``country_df`` to rank by (e.g. 'confirmed', 'deaths').
        n: Number of countries to show. The default of 10 matches the
            function name and the original behaviour.
    """
    # sort_values places NaN last by default, so countries missing the metric
    # only appear when fewer than `n` countries have a value.
    top = country_df.sort_values(by=metric, ascending=False).head(n)
    fig = px.bar(top, x='country_region', y=metric, title=f'Top {n} ({metric})')
    fig.show()

In [None]:
plot_top10(metric='confirmed')

In [None]:
plot_top10(metric='deaths')

In [None]:
plot_top10(metric='active')

In [None]:
plot_top10(metric='recovered')

In [None]:
plot_top10(metric='mortality_rate')

### COVID-19 spread on global map

In [None]:
# Hoisted: the largest confirmed count is needed for every marker.
max_confirmed = country_df.confirmed.max()


def scale_f(x):
    """Linearly map a confirmed-case count to a CircleMarker radius.

    Counts in [0, max_confirmed] map to radii in [5, 15].
    """
    return 10 / max_confirmed * x + 5


def _fmt_count(value):
    """Format a count as an integer for the popup, or 'n/a' when it is missing."""
    return 'n/a' if pd.isna(value) else int(value)


# Stamen's tile servers were retired in 2023, so "Stamen Toner" no longer
# renders. CartoDB positron is a similar light basemap that folium supports.
# folium ignores zoom_start when no location is given, so pass both
# explicitly for a whole-world view.
m = folium.Map(location=[0, 0], zoom_start=2, tiles='CartoDB positron')
for r in country_df.dropna(subset=['lat', 'long']).itertuples(index=False):
    folium.CircleMarker(
        location=[r.lat, r.long],
        radius=scale_f(r.confirmed),
        popup=f'''
               <table>
                 <tr>
                   <th colspan="2">{r.country_region}</th>
                 </tr>
                 <tr>
                   <td>confirmed:</td>
                   <td>{_fmt_count(r.confirmed)}</td>
                 </tr>
                 <tr>
                   <td>deaths:</td>
                   <td>{_fmt_count(r.deaths)}</td>
                 </tr>
                 <tr>
                   <td>death rate:</td>
                   <td>{r.mortality_rate:.3f}</td>
                 </tr>
               </table>
        ''',
        color='crimson',
        fill=True,
    ).add_to(m)
m

## Forecasting

In [None]:
# Country-indexed table with one column per 'm/d/yy' date label.
# Date columns are selected by pattern instead of dropping only lat/long. This
# also excludes any other non-date column left by the per-country sum (e.g.
# a 'province' column, which some pandas versions keep), so every column
# label can later be parsed as a date.
confirmed_agg = confirmed_df.set_index('country').filter(regex=r'^\d{1,2}/\d{1,2}/\d{1,2}')

In [None]:
def get_data(country: 'str | None' = None, frame: 'pd.DataFrame | None' = None) -> pd.Series:
    """Return the confirmed-case time series for one country, or summed over all rows.

    Args:
        country: Row label of ``frame``. A falsy value (None or '') sums all
            rows instead.
        frame: Country-indexed table whose columns are 'm/d/yy' date labels.
            Defaults to the notebook's ``confirmed_agg``.

    Returns:
        A Series indexed by a daily ``PeriodIndex``. It is named after the
        country when one is given.

    Raises:
        KeyError: if ``country`` is not a row label of ``frame``.
    """
    if frame is None:
        frame = confirmed_agg
    series = frame.loc[country, :] if country else frame.sum()
    # Parse the 'm/d/yy' labels explicitly rather than relying on format inference.
    dates = pd.to_datetime(series.index, format='%m/%d/%y')
    # set_axis returns a new Series, so `frame` is never touched.
    return series.set_axis(dates.to_period('D'))

In [None]:
def train_test_split(data: pd.Series, train_ratio: float = 0.66):
    """Split a series in order into a leading train part and a trailing test part.

    Args:
        data: Series to split. Its order is preserved (no shuffling).
        train_ratio: Fraction of rows, strictly between 0 and 1, assigned to
            the train part (the row count is rounded down).

    Returns:
        ``(train, test)`` Series named ``'train'`` and ``'test'``. ``data``
        itself is not modified.

    Raises:
        ValueError: if ``train_ratio`` is not in the open interval (0, 1).
    """
    # Explicit raise rather than assert: asserts are stripped under `python -O`.
    if not 0.0 < train_ratio < 1.0:
        raise ValueError('train_ratio must be in (0.0, 1.0)')
    bound = int(len(data) * train_ratio)
    # .iloc makes the split positional regardless of the index type.
    return data.iloc[:bound].rename('train'), data.iloc[bound:].rename('test')

In [None]:
def MAPE(y, y_hat, eps = 1e-10):
    """Mean absolute percentage error of `y_hat` against `y`, as a fraction (not x100).

    `eps` is added to the actual values in the denominator so that zero
    actuals do not cause division by zero.
    """
    residual = y - y_hat
    denominator = y + eps
    return np.mean(np.abs(residual / denominator))

In [None]:
progress = widgets.IntProgress(min=0)
display(progress)


def _fit_best_arima(train, orders):
    """Fit SARIMAX(train, order=o) for each o in `orders` and return the best results object.

    "Best" means the lowest finite AIC. Ties go to the earliest order in
    `orders`. Orders whose fit raises are skipped, so the search is
    best-effort. `progress` is reset and advanced once per attempted order.

    Raises:
        RuntimeError: if no order produced a fit with a finite AIC.
    """
    orders = list(orders)
    progress.value = 0
    progress.max = len(orders)
    best = None
    for order in orders:
        try:
            with warnings.catch_warnings():
                # Convergence/stationarity warnings are expected for many grid
                # points and would flood the notebook output.
                warnings.filterwarnings("ignore")
                fitted = SARIMAX(train, order=order).fit(disp=False)
        except Exception:
            # Deliberately broad: some (p, d, q) combinations fail to fit, and
            # one bad order should not abort the whole search.
            fitted = None
        progress.value += 1
        if fitted is None or not np.isfinite(fitted.aic):
            continue
        if best is None or fitted.aic < best.aic:
            best = fitted
    if best is None:
        raise RuntimeError('no ARIMA order could be fitted to the training data')
    return best


@widgets.interact(country=['World'] + confirmed_agg.index.values.tolist())
def autoarima(country):
    """Select an ARIMA(p, d, q) order by AIC, report test MAPE, and plot the forecast.

    The grid covers p, d, q in {0, 1, 2}. Models are fitted on the first 95%
    of the country's confirmed-case series ('World' sums all countries). The
    forecast runs from the start of the test period to 30 days past its end.
    """
    data = get_data(None if country == 'World' else country)
    train, test = train_test_split(data, 0.95)

    # SARIMAX fitting is synchronous and CPU-bound, so the grid is searched in
    # a plain loop; the winning fit is reused instead of being refitted.
    model = _fit_best_arima(train, itertools.product(range(3), repeat=3))

    forecast = model.predict(start=test.index[0], end=test.index[-1] + timedelta(days=30))
    print(f'test MAPE: {MAPE(test, forecast.iloc[:len(test)]):.3f}')

    # One column each for train, test and forecast, so the train line does not
    # also cover the test period.
    result = pd.concat({'train': train, 'test': test, 'predicted': forecast}, axis=1)
    result.plot()