In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import datetime as dt
import sklearn as skl
import scipy.stats as stats
import sys
local_rel_path = '../data/'
sys.path.insert(0, local_rel_path)
import nytimes
import importlib
import models
import features

In [None]:
# Fetch the raw NYT tables: one state-level frame and one county-level frame.
state_df, county_df = nytimes.get_nyt_data()

# Convert the county table to time-series form, once per quantity
# (cases, deaths, fips -- in that order).
county_cases_ts, county_deaths_ts, county_fips = [
    nytimes.convert_county_df_to_ts(county_df, quantity=name)
    for name in ('cases', 'deaths', 'fips')
]

# Same conversion for the state table (cases and deaths only).
state_cases_ts, state_deaths_ts = [
    nytimes.convert_state_df_to_ts(state_df, quantity=name)
    for name in ('cases', 'deaths')
]

def preprocess_df(df):
    """Return a cleaned copy of a wide time-series frame.

    Steps, in order:
      1. drop the 'county' and 'state' columns when present,
      2. replace missing values with 0,
      3. convert every remaining column label from a date string in
         '%m/%d/%y' format into a datetime object.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame whose column labels (other than 'county'/'state') are date
        strings such as '1/22/20'.

    Returns
    -------
    pandas.DataFrame
        A new frame; the input frame is not modified.

    Raises
    ------
    ValueError
        If a column label does not match the '%m/%d/%y' format.
    """
    # drop county, state columns if they exist
    df = df.drop(columns=['county', 'state'], errors='ignore')
    # fill NAs with zeros
    df = df.fillna(0)
    # Replace the string column labels with datetime objects.  The datetime
    # module is imported as `dt` in this notebook, so it must be referenced
    # through that alias (the bare name `datetime` is not defined).
    df = df.rename(
        columns=lambda str_date: dt.datetime.strptime(str_date, '%m/%d/%y'))

    return df

# Run the standard preprocessing over every cases/deaths time-series frame:
# the two state-level frames first, then the two county-level frames.
state_cases_ts, state_deaths_ts = [
    preprocess_df(frame) for frame in (state_cases_ts, state_deaths_ts)
]
county_cases_ts, county_deaths_ts = [
    preprocess_df(frame) for frame in (county_cases_ts, county_deaths_ts)
]


In [None]:
# demonstrate base forecast for Middlesex, Massachusetts

demo_county = ('Massachusetts', 'Middlesex')

# Case-count series for the demo county, copied so that it is independent of
# county_cases_ts.
demo_ts = county_cases_ts.loc[demo_county, :].copy()
# First difference of the demo series.  (This previously referenced the
# undefined name `test_ts`, which fails on a fresh kernel.)
demo_ts_daily = demo_ts.diff()
# since we are demonstrating a base, not rolling forecast, shorten the
# time series to the relevant portion, i.e. where case count is large
# (previously referenced the undefined name `test_ts_daily`)
demo_ts_daily_short = demo_ts_daily.loc[demo_ts > 10]
# now compute the forecasts
out = models.base_forecast_linear(demo_ts_daily_short)  # point
out90 = models.base_forecast_linear(demo_ts_daily_short, quantile=0.9)  # upper quantile
out10 = models.base_forecast_linear(demo_ts_daily_short, quantile=0.1)  # lower quantile

In [None]:
# plot stuff nicely
def my_format_dates(ax):
    """Tilt the x tick labels of ``ax`` by 45 degrees and right-align them.

    Parameters
    ----------
    ax : object exposing ``get_xticklabels()``
        Typically a matplotlib Axes; each returned label is modified in place.

    Returns
    -------
    None
    """
    for tick_label in ax.get_xticklabels():
        tick_label.set_horizontalalignment('right')
        tick_label.set_rotation(45)
            
fig, ax = plt.subplots()

# Daily series as unconnected blue square markers.
ax.plot(demo_ts_daily, linestyle='', marker='s', color='xkcd:blue')
# Point forecast as a red line.
ax.plot(out, color='xkcd:red')

# Shade the band between the two quantile forecasts, restricted to the
# positions where out10 lies strictly below out90.
band_is_ordered = out10 < out90
ax.fill_between(
    out.index,
    out10,
    out90,
    where=band_is_ordered,
    color='xkcd:red',
    alpha=0.3,
)

# Logarithmic y axis, then rotate the date tick labels for readability.
ax.set_yscale('log')
my_format_dates(ax)