In [0]:
from google.colab import drive
from IPython.display import clear_output
import os
drive.mount('/content/drive')
os.chdir('/content/drive/My Drive/kaggle/covid_forecast')
!pip install catboost
clear_output()
import warnings
warnings.filterwarnings('ignore')

In [0]:
import pandas as pd
import numpy as np
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder
from sklearn.linear_model import Ridge
from scipy.optimize import curve_fit
import json
import lightgbm as lgb

In [0]:
# used to not to recalculate all date statistics when predicting recursively
max_shift = 8
max_shift_delta = pd.Timedelta(days=max_shift)

In [0]:
num_features = ['recovered_lag_2',
 'dead_lag_4',
 'lag3_minus_max',
 'recovered_poly_pred',
 'rolling_confirmed_min',
 'testpop',
 'recovered_poly_1',
 'Death rate, crude (per 1,000 people)',
 'dead_lag_1',
 'rolling_recovered_mean',
 'hospibed',
 'lag1_minus_max',
 'Population, total',
 'GDP ($ per capita)',
 'Cause of death, by communicable diseases and maternal, prenatal and nutrition conditions (% of total)',
 'Mortality rate attributed to household and ambient air pollution, age-standardized (per 100,000 population)',
 'rolling_confirmed_std',
 'tests',
 'Mortality from CVD, cancer, diabetes or CRD between exact ages 30 and 70 (%)',
 'rolling_recovered_min',
 'Poverty headcount ratio at $3.20 a day (2011 PPP) (% of population)',
 'Mortality rate, adult, female (per 1,000 female adults)',
 'Tuberculosis treatment success rate (% of new cases)',
 'confirmed_ridge_bias',
 'recovered_poly_0',
 'recovered_lag_7',
 'International migrant stock, total',
 'recovered_lag_1',
 'recovered_poly_3',
 'confirmed_poly_3',
 'recovered_ridge_coef',
 'rolling_dead_min',
 'Infant mortality (per 1000 births)',
 'confirmed_poly_2',
 'recovered_ridge_pred',
 'rolling_dead_mean',
 'Number of people spending more than 25% of household consumption or income on out-of-pocket health care expenditure',
 'Population in urban agglomerations of more than 1 million (% of total population)',
 'dead_lag_7',
 'Survival to age 65, male (% of cohort)',
 'dead_ridge_pred',
 'Number of people spending more than 10% of household consumption or income on out-of-pocket health care expenditure',
 'Labor force participation rate, total (% of total population ages 15+) (modeled ILO estimate)',
 'recovered_expanding_max',
 'dead_lag_2',
 'recovered_lag_4',
 'dead_poly_0',
 'Smoking prevalence, females (% of adults)',
 'Population ages 15-64 (% of total)',
 'Survival to age 65, female (% of cohort)',
 'Cause of death, by non-communicable diseases (% of total)',
 'schools',
 'confirmed_lag_2',
 'GDP per capita, PPP (current international $)',
 'confirmed_poly_0',
 'recovered_poly_2',
 'Life expectancy at birth, total (years)',
 'Hospital beds (per 1,000 people)',
 'Mortality rate, adult, male (per 1,000 male adults)',
 'temperature',
 'rolling_recovered_std',
 'recovered_lag_5',
 'rolling_recovered_max',
 'People using at least basic sanitation services (% of population)',
 'dead_lag_5',
 'Mortality rate attributed to unsafe water, unsafe sanitation and lack of hygiene (per 100,000 population)',
 'Trade (% of GDP)',
 'Diabetes prevalence (% of population ages 20 to 79)',
 'Population density (people per sq. km of land area)',
 'People using safely managed sanitation services (% of population)',
 'Tuberculosis case detection rate (%, all forms)',
 'confirmed_lag_6',
 'dead_expanding_max',
 'rolling_dead_std',
 'PM2.5 air pollution, population exposed to levels exceeding WHO guideline value (% of total)',
 'Incidence of tuberculosis (per 100,000 people)',
 'People with basic handwashing facilities including soap and water (% of population)',
 'rolling_dead_max',
 'Net migration',
 'medianage',
 'recovered_lag_6',
 'confirmed_expanding_max',
 'dead_ridge_coef',
 'Population in the largest city (% of urban population)',
 'lag2_minus_max',
 'Out-of-pocket expenditure (% of current health expenditure)',
 'dead_ridge_bias',
 'confirmed_lag_5',
 'confirmed_lag_1',
 'confirmed_lag_3',
 'dead_poly_2',
 'quarantine',
 'dead_lag_3',
 'confirmed_lag_4',
 'rolling_confirmed_max',
 'restrictions',
 'Current health expenditure per capita, PPP (current international $)',
 'Urban population (% of total)',
 'confirmed_ridge_pred',
 'confirmed_poly_pred',
 'days_since_quar_start',
 'recovered_ridge_bias',
 'confirmed_ridge_coef',
 'dead_lag_6',
 'Air transport, passengers carried',
 'confirmed_lag_7',
 'Population ages 65 and above (% of total)',
 'dead_poly_pred',
 'International tourism, number of departures',
 'Prevalence of HIV, total (% of population ages 15-49)',
 'dead_poly_1',
 'recovered_lag_3',
 'confirmed_poly_1',
 'Smoking prevalence, males (% of adults)',
 'International tourism, number of arrivals',
 'rolling_confirmed_mean',
 'dead_poly_3']

cat_features = ['quar_type', 'Region']

features  = num_features + cat_features

In [0]:
data = pd.read_pickle('data.pkl')

In [0]:
last_train_day = pd.to_datetime('2020/04/17')
last_test_day = pd.to_datetime('2020/12/31')
targets = ['confirmed', 'recovered', 'dead']

# Let's try to fit something

In [21]:
max_shift = 8
max_shift_delta = pd.Timedelta(days=max_shift)

def ridge_features(y):
    # returns Ridge params and predictions a week ahead
    pred_day = 8

    if (y.isna().any()) or (y == 0).all():
        return pd.Series(np.zeros((3,)), index=['ridge_bias', 'ridge_coef', 'ridge_pred'])        
    x = np.arange(1, len(y) + 1).reshape(len(y), -1)
    y = y[::-1]
    r = Ridge()
    r.fit(x, y)
    pred = r.predict([[pred_day]])[0]
    return pd.Series([r.coef_[0], r.intercept_, pred], index=['ridge_bias', 'ridge_coef', 'ridge_pred'])

def poly_features(y):
    # fit polynomial to data, to better see growth
    # returns poly coeffs and pred
    deg = 3  # polynomial up to deg order
    pred_day = 8 # day we make predictions for
    cnames = [f'poly_{i}' for i in range(deg + 1)] + ['poly_pred']
    if (y.isna().any()) or (y == 0).all():
        return pd.Series(np.zeros((len(cnames),)), index=cnames)
    x = np.arange(1, len(y) + 1)
    y = y[::-1]
    params = np.polyfit(x, y, deg)[::-1]
    pred = np.polyval(params[::-1], [pred_day])
    return pd.Series(np.append(params, pred), index=cnames)

def add_time_features(data):
    num_features = []
    cat_features = []
    
    # add lag features
    lags = [1, 2, 3, 4, 5, 6, 7]
    for lag in lags:
        lag_features = data.groupby('Country/Region')[['confirmed', 'dead', 'recovered']].shift(lag)
        lag_features.columns = [f'{col}_lag_{lag}' for col in lag_features.columns]
        num_features += lag_features.columns.to_list()
        data.drop(lag_features.columns, axis=1, inplace=True, errors='ignore')
        data = pd.concat([data, lag_features], axis=1)

    # rolling statistics
    window = 7
    rollings = data.groupby('Country/Region')[['confirmed', 'dead', 'recovered']].shift(1).rolling(window).agg(['mean', 'std', 'max', 'min'])
    rollings.columns = ['rolling_' + '_'.join(i) for i in rollings.columns]
    num_features += rollings.columns.to_list()
    data.drop(rollings.columns, axis=1, inplace=True, errors='ignore')
    data = pd.concat([data, rollings], axis=1)

    # max target delta expanding
    for t in targets:
        data[t + '_expanding_max'] = data.groupby('Country/Region')[t].transform(lambda x: x.shift(1).expanding().max())
        num_features += [t + '_expanding_max']

    #  max_confirmed - lag1_confirmed, max_confirmed - lag2_confirmed

    lag_minus_max_features = ['lag1_minus_max', 'lag2_minus_max', 'lag3_minus_max']
    lag_minus_max = [1, 2, 3]
    for l, col_name in zip(lag_minus_max, lag_minus_max_features):
        data[col_name] = data['confirmed_expanding_max'] - data[f'confirmed_lag_{l}']
    num_features += lag_minus_max_features

    if True:
        # ridge regression and polynomial on lags for confirmed, recovered and dead
        for t in ['confirmed_', 'recovered_', 'dead_']:
            t_lags = data.columns[data.columns.str.contains(t + 'lag_')]
            rfs = data[t_lags].apply(ridge_features, axis=1, result_type='expand')
            rfs.columns = [t + i for i in rfs.columns]
            data.drop(rfs.columns, axis=1, inplace=True, errors='ignore')
            data = pd.concat([data, rfs], axis=1)
            num_features += rfs.columns.to_list()

            pfs = data[t_lags].apply(poly_features, axis=1, result_type='expand')
            pfs.columns = [t + i for i in pfs.columns]
            data.drop(pfs.columns, axis=1, inplace=True, errors='ignore')
            data = pd.concat([data, pfs], axis=1)
            num_features += pfs.columns.to_list()

    return data, num_features, cat_features


class CovidPredictor(object):

    def __init__(self, last_train_day, data=None):
        self.targets = targets
        self.models = {}
        for t in self.targets:
            self.models[t] = lgb.Booster(model_file=f'models/{t}.txt')
        self.last_train_day = last_train_day
        if data is None:
            self.data = pd.read_pickle('data.pkl')
        else:
            self.data = data
        self.simulation_data = self.data.copy()


    def dump_simulation(self):
        self.simulaton_data = self.data.copy()


    def _run_lgb(self, test):
        preds = pd.DataFrame()
        for key in self.models.keys():
            pred = np.maximum(0, self.models[key].predict(test[features]))
            pred = np.round(pred).astype('int')
            preds[key] = pred
        return preds


    def predict(self, h):
        # returns train_data + h predicted days
        # h - forecast horizont
        data = self.simulation_data.copy()

        for i in range(1, h + 1):
            pred_day = self.last_train_day + pd.Timedelta(days=i)
            print(pred_day)

            # recalculate time_features for pred_day
            recalc_window = (data['date'] >= pred_day - max_shift_delta) & (data['date'] <= pred_day)
            recalced = add_time_features(data.loc[recalc_window])[0]
            data.loc[data['date'] == pred_day, :] = recalced.loc[recalced['date'] == pred_day, :]
            xt = data[data['date'] == pred_day]
            data.loc[data['date'] == pred_day, self.targets] = self._run_lgb(xt).values

        return data[data['date'] <= pred_day]

predictor = CovidPredictor(last_train_day)
preds = predictor.predict(3)

2020-04-18 00:00:00
2020-04-19 00:00:00
2020-04-20 00:00:00


In [24]:
preds[preds['Country/Region'] == 'US']

Unnamed: 0,Country/Region,Lat_x,Long_x,date,dead,Lat_y,Long_y,confirmed,Lat,Long,recovered,Country_x,Place,Start date,End date,quar_type,days_since_quar_start,iso_alpha2,iso_alpha3,iso_numeric,name,official_name,ccse_name,density,fertility_rate,land_area,median_age,migrants,population,urban_pop_rate,world_share,Country Code,"Air transport, passengers carried","Cause of death, by communicable diseases and maternal, prenatal and nutrition conditions (% of total)","Cause of death, by non-communicable diseases (% of total)","Current health expenditure per capita, PPP (current international $)","Death rate, crude (per 1,000 people)",Diabetes prevalence (% of population ages 20 to 79),"GDP per capita, PPP (current international $)","Hospital beds (per 1,000 people)",...,rolling_confirmed_max,rolling_confirmed_min,rolling_dead_mean,rolling_dead_std,rolling_dead_max,rolling_dead_min,rolling_recovered_mean,rolling_recovered_std,rolling_recovered_max,rolling_recovered_min,confirmed_expanding_max,dead_expanding_max,recovered_expanding_max,lag1_minus_max,lag2_minus_max,lag3_minus_max,confirmed_ridge_bias,confirmed_ridge_coef,confirmed_ridge_pred,confirmed_poly_0,confirmed_poly_1,confirmed_poly_2,confirmed_poly_3,confirmed_poly_pred,recovered_ridge_bias,recovered_ridge_coef,recovered_ridge_pred,recovered_poly_0,recovered_poly_1,recovered_poly_2,recovered_poly_3,recovered_poly_pred,dead_ridge_bias,dead_ridge_coef,dead_ridge_pred,dead_poly_0,dead_poly_1,dead_poly_2,dead_poly_3,dead_poly_pred
2021,US,37.0902,-95.7129,2020-02-01,0.0,37.0902,-95.7129,1.0,37.0902,-95.7129,0.0,US,California,2020-03-19,2020-03-29,12,0.0,US,USA,840.0,United States,United States of America,US,36.0,1.8,9147420.0,38.0,954806.0,331002651.0,0.83,0.0425,USA,824039000.0,5.2,88.3,9869.742382,8.493,0.0,57904.204329,0.0,...,5.0,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.0,3.0,0.0,0.0,1.0,3.0,3.0,0.000000,0.714286,0.714286,-2.857143,4.742063,-1.547619,0.138889,7.142857,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
2206,US,37.0902,-95.7129,2020-02-02,0.0,37.0902,-95.7129,0.0,37.0902,-95.7129,0.0,US,California,2020-03-19,2020-03-29,12,0.0,US,USA,840.0,United States,United States of America,US,36.0,1.8,9147420.0,38.0,954806.0,331002651.0,0.83,0.0425,USA,824039000.0,5.2,88.3,9869.742382,8.493,0.0,57904.204329,0.0,...,1.0,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.0,3.0,0.0,0.0,2.0,1.0,3.0,-0.068966,1.133005,0.581281,8.000000,-6.531746,1.571429,-0.111111,-0.571429,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
2391,US,37.0902,-95.7129,2020-02-03,0.0,37.0902,-95.7129,3.0,37.0902,-95.7129,0.0,US,California,2020-03-19,2020-03-29,12,0.0,US,USA,840.0,United States,United States of America,US,36.0,1.8,9147420.0,38.0,954806.0,331002651.0,0.83,0.0425,USA,824039000.0,5.2,88.3,9869.742382,8.493,0.0,57904.204329,0.0,...,0.0,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.0,3.0,0.0,0.0,3.0,2.0,1.0,0.137931,-0.123153,0.980296,2.000000,-2.702381,0.928571,-0.083333,-2.857143,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
2576,US,37.0902,-95.7129,2020-02-04,0.0,37.0902,-95.7129,0.0,37.0902,-95.7129,0.0,US,California,2020-03-19,2020-03-29,12,0.0,US,USA,840.0,United States,United States of America,US,36.0,1.8,9147420.0,38.0,954806.0,331002651.0,0.83,0.0425,USA,824039000.0,5.2,88.3,9869.742382,8.493,0.0,57904.204329,0.0,...,3.0,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.0,3.0,0.0,0.0,0.0,3.0,2.0,0.344828,-0.522167,2.236453,-2.000000,2.253968,-0.619048,0.055556,4.857143,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
2761,US,37.0902,-95.7129,2020-02-05,0.0,37.0902,-95.7129,0.0,37.0902,-95.7129,0.0,US,California,2020-03-19,2020-03-29,12,0.0,US,USA,840.0,United States,United States of America,US,36.0,1.8,9147420.0,38.0,954806.0,331002651.0,0.83,0.0425,USA,824039000.0,5.2,88.3,9869.742382,8.493,0.0,57904.204329,0.0,...,6.0,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.0,3.0,0.0,0.0,3.0,0.0,3.0,0.137931,0.305419,1.408867,-0.142857,-0.043651,0.214286,-0.027778,-1.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15896,US,37.0902,-95.7129,2020-04-16,4591.0,37.0902,-95.7129,31451.0,37.0902,-95.7129,2607.0,US,California,2020-03-19,2020-03-29,12,28.0,US,USA,840.0,United States,United States of America,US,36.0,1.8,9147420.0,38.0,954806.0,331002651.0,0.83,0.0425,USA,824039000.0,5.2,88.3,9869.742382,8.493,0.0,57904.204329,0.0,...,28680.0,1.0,373.142857,936.181683,2494.0,0.0,757.857143,1608.755153,4333.0,0.0,35098.0,2494.0,10494.0,6418.0,8047.0,9792.0,-1095.310345,33995.241379,25232.758621,28705.571429,6697.694444,-2678.095238,247.138889,37424.142857,595.241379,1695.748768,6457.679803,5918.857143,-5293.400794,1967.261905,-178.694444,-2015.142857,74.344828,1649.763547,2244.522167,1470.714286,606.257937,-234.785714,24.527778,3852.714286
16081,US,37.0902,-95.7129,2020-04-17,3857.0,37.0902,-95.7129,31905.0,37.0902,-95.7129,3842.0,US,California,2020-03-19,2020-03-29,12,29.0,US,USA,840.0,United States,United States of America,US,36.0,1.8,9147420.0,38.0,954806.0,331002651.0,0.83,0.0425,USA,824039000.0,5.2,88.3,9869.742382,8.493,0.0,57904.204329,0.0,...,31451.0,0.0,674.714286,1727.534246,4591.0,0.0,589.857143,1030.571917,2607.0,0.0,35098.0,4591.0,10494.0,3647.0,6418.0,8047.0,-523.068966,31572.847291,27388.295567,41335.714286,-7283.845238,957.345238,-16.666667,35801.714286,136.206897,3639.886700,4729.541872,4518.285714,-2907.861111,1372.071429,-144.138889,-4731.142857,325.206897,1047.458128,3649.113300,2150.428571,98.952381,-183.964286,31.083333,7083.000000
16266,US,37.0902,-95.7129,2020-04-18,1984.0,37.0902,-95.7129,31983.0,37.0902,-95.7129,3976.0,US,California,2020-03-19,2020-03-29,12,30.0,US,USA,840.0,United States,United States of America,US,36.0,1.8,9147420.0,38.0,954806.0,331002651.0,0.83,0.0425,USA,824039000.0,5.2,88.3,9869.742382,8.493,0.0,57904.204329,0.0,...,31905.0,0.0,569.142857,1450.566605,3857.0,0.0,783.285714,1463.625489,3842.0,0.0,35098.0,4591.0,10494.0,3193.0,3647.0,6418.0,502.551724,27014.221675,31034.635468,36330.428571,-7562.833333,1748.333333,-107.333333,32766.428571,-10.241379,4291.679803,4209.748768,-6625.857143,10401.448413,-2568.416667,184.277778,6557.285714,448.137931,805.591133,4390.694581,3843.000000,-2565.865079,767.952381,-56.611111,3480.142857
16451,US,37.0902,-95.7129,2020-04-19,1982.0,37.0902,-95.7129,27082.0,37.0902,-95.7129,3576.0,US,California,2020-03-19,2020-03-29,12,31.0,US,USA,840.0,United States,United States of America,US,36.0,1.8,9147420.0,38.0,954806.0,331002651.0,0.83,0.0425,USA,824039000.0,5.2,88.3,9869.742382,8.493,0.0,57904.204329,0.0,...,31983.0,1.0,302.857143,742.955007,1984.0,0.0,755.714286,1487.932874,3976.0,0.0,31983.0,4591.0,10494.0,0.0,78.0,532.0,924.000000,25631.571429,33023.571429,35471.714286,-9437.234127,2814.261905,-220.361111,27261.714286,-282.896552,5596.014778,3332.842365,-6308.428571,12668.238095,-3641.404762,294.000000,12515.571429,285.000000,1473.571429,3753.571429,3791.428571,-3264.785714,1248.797619,-116.916667,-2265.142857
