<a href="https://colab.research.google.com/github/trovatore/covid/blob/master/Covid.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
from google.colab import files
import pandas as pd
import numpy as np
import sys
import io
from datetime import datetime, timedelta
import pytz
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
import matplotlib.patches as mpatches
from sklearn.linear_model import LinearRegression
# sys.version


In [0]:
def date_to_ordinal(date):
  """Convert a 'YYYY-MM-DD' string to its proleptic Gregorian ordinal day number."""
  parsed = datetime.strptime(date, '%Y-%m-%d')
  return parsed.date().toordinal()

def dates_to_elapsed_days(dates, zero_date):
  """Return a numpy array of whole-day offsets of each date from `zero_date`.

  Args:
    dates: Iterable of 'YYYY-MM-DD' strings.
    zero_date: 'YYYY-MM-DD' string that maps to day 0; earlier dates give
      negative offsets.
  """
  origin = date_to_ordinal(zero_date)
  ordinals = np.array([date_to_ordinal(day) for day in dates])
  return ordinals - origin

def find_fips_for_county(data, county, state):
  """Look up the FIPS code for a county.

  Args:
    data: DataFrame with 'county', 'state' and 'fips' columns (the NYT
      us-counties.csv layout).
    county: County name, e.g. 'Santa Clara'.
    state: State name, e.g. 'California'.

  Returns:
    The FIPS code of the first matching row that has one, as an int. Returns
    None when no row matches, or when every matching row has a blank FIPS
    (pandas reads blanks as NaN).
  """
  # Vectorized filter instead of materializing every row as a dict per call.
  matches = data.loc[(data['county'] == county) & (data['state'] == state),
                     'fips']
  # int(NaN) raises ValueError, so skip rows without a FIPS code.
  matches = matches.dropna()
  if matches.empty:
    return None
  return int(matches.iloc[0])

def find_fips_for_state(data, state):
  """Look up the FIPS code for a state.

  Args:
    data: DataFrame with 'state' and 'fips' columns (the NYT us-states.csv
      layout).
    state: State name, e.g. 'California'.

  Returns:
    The FIPS code of the first matching row that has one, as an int. Returns
    None when no row matches or every matching row has a blank (NaN) FIPS.
  """
  # Vectorized filter instead of materializing every row as a dict per call.
  matches = data.loc[data['state'] == state, 'fips'].dropna()
  if matches.empty:
    return None
  return int(matches.iloc[0])


In [0]:
def _do_regression(elapsed_days, data_for_place, mask, ax, color, value_name):
  x = elapsed_days[mask].reshape(-1, 1)
  y = data_for_place[value_name].values[mask].reshape(-1, 1)
  lr = LinearRegression()
  lr.fit(x, np.log2(y))
  y_pred = lr.predict(x)
  ax.plot(x, np.exp(np.log(2) * y_pred), color + "-")
  return 1./lr.coef_[0][0]


In [0]:
# Number of most recent days used for the exponential fit in plot_place when
# no explicit regression_interval is given (Colab form field).
regression_num_days = 5  #@param


In [0]:
def plot_place(data_for_place, title, reference_date, regression_interval=None,
               num_days=None):
  """Plot cumulative cases and deaths on a log scale with exponential fits.

  Args:
    data_for_place: DataFrame with 'date' (YYYY-MM-DD strings), 'cases' and
      'deaths' columns for the particular place.
    title: Title for the plot; the fit interval is appended to it.
    reference_date: YYYY-MM-DD string with a date of interest (for example,
      the start of shelter-in-place for the location). Plotted as day 0.
    regression_interval: List of two YYYY-MM-DD dates giving the inclusive
      interval to run linear regression on. Defaults to the last `num_days`
      days present in the data.
    num_days: Length of the default regression window. Defaults to the
      `regression_num_days` form parameter, read at call time.
  """
  if num_days is None:
    num_days = regression_num_days

  fig = plt.figure(figsize=(15, 10))
  ax = fig.add_subplot(111)
  elapsed_days = dates_to_elapsed_days(data_for_place['date'], reference_date)
  ax.plot(elapsed_days, data_for_place['cases'], 'bo')
  ax.plot(elapsed_days, data_for_place['deaths'], 'ro')

  if regression_interval is None:
    # ISO-format dates sort correctly as strings, so max() is the latest day.
    end_date = max(data_for_place['date'])
    start_date = (datetime.strptime(end_date, '%Y-%m-%d')
                  - timedelta(days=num_days - 1)).strftime('%Y-%m-%d')
    regression_interval = [start_date, end_date]

  title = '{} (curve fit from {} to {})'.format(
      title, regression_interval[0], regression_interval[1])
  fit_days = dates_to_elapsed_days(regression_interval, reference_date)
  start_day, end_day = fit_days[0], fit_days[1]
  mask = (elapsed_days >= start_day) & (elapsed_days <= end_day)
  cases_doubling_time = _do_regression(elapsed_days, data_for_place, mask,
                                       ax, 'b', 'cases')
  deaths_doubling_time = _do_regression(elapsed_days, data_for_place, mask,
                                        ax, 'r', 'deaths')

  ax.set_ylabel('log scale')
  ax.set_xlabel('days since {}'.format(reference_date))
  ax.set_yscale('log')
  ax.yaxis.set_major_formatter(ScalarFormatter())
  ax.set_title(title)
  # A fit is always run above, so the legend always reports doubling times.
  cases_patch = mpatches.Patch(
      color='blue',
      label='cases (doubling time={:0.1f} days)'.format(cases_doubling_time))
  deaths_patch = mpatches.Patch(
      color='red',
      label='deaths (doubling time={:0.1f} days)'.format(deaths_doubling_time))
  ax.legend(handles=[cases_patch, deaths_patch])
  ax.grid()

In [0]:
def plot_daily_new(data_for_place, title, reference_date):
  """Plot day-over-day new cases and deaths on a log scale.

  New counts are the first difference of the cumulative columns with 0
  prepended, so the first day's value equals its cumulative total.

  Args:
    data_for_place: DataFrame with 'date' (YYYY-MM-DD), 'cases' and 'deaths'
      columns for the particular place; rows must be consecutive days in order.
    title: Title for the plot.
    reference_date: YYYY-MM-DD string with a date of interest; day 0 on the
      x-axis.
  """
  fig = plt.figure(figsize=(15, 10))
  ax = fig.add_subplot(111)
  days = dates_to_elapsed_days(data_for_place['date'], reference_date)
  # The differences below are per-day counts only if rows are sorted and
  # have no gaps.
  assert (days == sorted(days)).all()
  assert days[-1] - days[0] + 1 == len(days)
  for column, marker in (('cases', 'bo'), ('deaths', 'ro')):
    ax.plot(days, np.diff(data_for_place[column], prepend=0), marker)

  ax.set_ylabel('log scale')
  ax.set_xlabel('days since {}'.format(reference_date))
  ax.set_yscale('log')
  ax.yaxis.set_major_formatter(ScalarFormatter())
  ax.set_title(title)
  legend_handles = [
      mpatches.Patch(color='blue', label='New cases'),
      mpatches.Patch(color='red', label='New deaths'),
  ]
  ax.legend(handles=legend_handles)
  ax.grid()

In [0]:
def _get_new_and_intercept(events):
  events_new = np.diff(events, prepend=0)
  mask = (events_new > 0)
  events_new_censored = events_new[mask]
  events_censored = events[mask]
  log_10_intercept = np.mean(np.log10(events_new_censored) -
                             np.log10(events_censored))
  return events_new, 10. ** log_10_intercept


In [0]:
def plot_new_versus_cumulative(data_for_place, title, reference_date,
                               min_fit_count=10):
  """Plot daily new counts against cumulative counts on log-log axes.

  Dashed lines show new = k * cumulative, with k from _get_new_and_intercept.
  A constant new/cumulative ratio appears as a slope-1 line; the axes use an
  equal aspect so it is drawn at 45 degrees.

  Args:
    data_for_place: DataFrame with 'cases' and 'deaths' columns of cumulative
      counts for consecutive days.
    title: Title for the plot.
    reference_date: YYYY-MM-DD string with a date of interest. Accepted for
      parity with the other plotting functions; unused here because this plot
      has no time axis.
    min_fit_count: Draw the dashed reference lines only for cumulative counts
      at or above this value.
  """
  fig = plt.figure(figsize=(15, 10))
  ax = fig.add_subplot(111)
  cases = data_for_place['cases']
  deaths = data_for_place['deaths']
  cases_new, cases_intercept = _get_new_and_intercept(cases)
  deaths_new, deaths_intercept = _get_new_and_intercept(deaths)
  ax.plot(cases, cases_new, 'bo')
  ax.plot(deaths, deaths_new, 'ro')

  cases_fit = cases[cases >= min_fit_count]
  ax.plot(cases_fit, cases_fit * cases_intercept, 'b--')
  deaths_fit = deaths[deaths >= min_fit_count]
  ax.plot(deaths_fit, deaths_fit * deaths_intercept, 'r--')

  ax.set_ylabel('new')
  ax.set_xlabel('cumulative')
  ax.set_yscale('log')
  ax.set_xscale('log')
  ax.xaxis.set_major_formatter(ScalarFormatter())
  ax.yaxis.set_major_formatter(ScalarFormatter())
  ax.set_title(title)
  cases_patch = mpatches.Patch(color='blue', label='cases')
  deaths_patch = mpatches.Patch(color='red', label='deaths')
  ax.legend(handles=[cases_patch, deaths_patch])
  ax.set_aspect('equal')
  ax.grid()

In [0]:
def _get_jhu_data_for_county(data, county, state):
  state_names = data['Province_State']
  county_names = data['Admin2']
  mask_state = (state_names == state)
  mask_county = (county_names == county)
  county_dict = data[mask_state][mask_county].to_dict(orient='records')
  assert(len(county_dict) == 1)
  county_dict = county_dict[0]
  reformatted_dict = {}
  for key, value in county_dict.items():
    if str(key[0]).isnumeric():
      date = datetime.strptime(key, '%m/%d/%y').strftime('%Y-%m-%d')
      reformatted_dict[date] = value
  return reformatted_dict

def _get_data_for_county(county, state, use_nyt=True):
  """Return a DataFrame with 'date', 'cases' and 'deaths' for one county.

  Args:
    county: County name, e.g. 'Santa Clara'.
    state: State name, e.g. 'California'.
    use_nyt: If True, use the uploaded NYT county data; otherwise build the
      frame from the uploaded JHU confirmed/deaths time series.

  Raises:
    ValueError: if the needed data was not uploaded, or the county has no
      FIPS code in the NYT data.
  """
  if use_nyt:
    if data_counties_nyt is None:
      raise ValueError('NYT county data not loaded; upload us-counties.csv.')
    fips = find_fips_for_county(data_counties_nyt, county, state)
    # Without this check, filtering on fips == None silently yields an empty
    # frame and the plotting code fails later with an unrelated error.
    if fips is None:
      raise ValueError('No FIPS code found for {}, {} in the NYT county '
                       'data.'.format(county, state))
    return data_counties_nyt[data_counties_nyt['fips'] == fips]

  if cases_jhu is None or deaths_jhu is None:
    raise ValueError('JHU data not loaded; upload the confirmed and deaths '
                     'time series files.')
  cases = _get_jhu_data_for_county(cases_jhu, county, state)
  deaths = _get_jhu_data_for_county(deaths_jhu, county, state)
  assert set(cases) == set(deaths), 'JHU cases and deaths cover different dates'
  # Index every series by date string so concat aligns them row-by-row.
  dates = {date: date for date in cases}
  cases_series = pd.Series(cases, dtype=float, name='cases')
  deaths_series = pd.Series(deaths, dtype=float, name='deaths')
  dates_series = pd.Series(dates, name='date')
  return pd.concat([dates_series, cases_series, deaths_series], axis=1)


    

In [0]:
def _get_data_for_state(state, use_nyt=True):
  """Return the rows of the uploaded NYT state data for `state`.

  Args:
    state: State name, e.g. 'California'.
    use_nyt: Must be True; only the NYT state file is supported here.

  Raises:
    NotImplementedError: if use_nyt is False.
    ValueError: if the NYT state data was not uploaded or the state has no
      FIPS code in it.
  """
  if not use_nyt:
    # There is no JHU state-level path; fail loudly instead of returning None,
    # which would crash later inside the plotting code.
    raise NotImplementedError(
        'State-level plots are only supported with NYT data (use_nyt=True).')
  if data_states_nyt is None:
    raise ValueError('NYT state data not loaded; upload us-states.csv.')
  fips = find_fips_for_state(data_states_nyt, state)
  if fips is None:
    raise ValueError(
        'No FIPS code found for state {} in the NYT state data.'.format(state))
  return data_states_nyt[data_states_nyt['fips'] == fips]

In [0]:
def plot_county(county, state, reference_date,
                use_nyt=True,
                regression_interval=None):
  """Draw the cumulative, daily-new, and new-vs-cumulative plots for a county.

  Args:
    county: County name without the word 'County', e.g. 'Santa Clara'.
    state: Full state name, e.g. 'California'.
    reference_date: YYYY-MM-DD date used as day 0 on the time axes.
    use_nyt: Use NYT data if True, otherwise JHU data.
    regression_interval: Optional pair of dates for the fit in plot_place.
  """
  county_data = _get_data_for_county(county, state, use_nyt)
  heading = '{} County, {} ({} data)'.format(
      county, state, 'NYT' if use_nyt else 'JHU')
  plot_place(county_data, heading, reference_date, regression_interval)
  for draw in (plot_daily_new, plot_new_versus_cumulative):
    draw(county_data, heading, reference_date)


In [0]:
def plot_state(state, reference_date, use_nyt=True, regression_interval=None):
  """Draw the cumulative, daily-new, and new-vs-cumulative plots for a state.

  Args:
    state: Full state name, e.g. 'California'.
    reference_date: YYYY-MM-DD date used as day 0 on the time axes.
    use_nyt: Data-source flag passed to _get_data_for_state.
    regression_interval: Optional pair of dates for the fit in plot_place.
  """
  state_data = _get_data_for_state(state, use_nyt)
  heading = 'State of {} ({} data)'.format(state, 'NYT' if use_nyt else 'JHU')
  plot_place(state_data, heading, reference_date, regression_interval)
  for draw in (plot_daily_new, plot_new_versus_cumulative):
    draw(state_data, heading, reference_date)

In [0]:
#@title Flags for what to load

# Colab form fields: each flag gates the corresponding upload cell below
# (NYT and JHU respectively).
load_nyt = True  #@param
load_jhu = True  #@param

**Upload NYT files**

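To use NYT data, download `us-states.csv` and `us-counties.csv` from https://github.com/nytimes/covid-19-data and upload them with the file picker that the next cell opens. State plots use `us-states.csv`; county plots use `us-counties.csv`.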


In [0]:
# Reset first so re-running this cell without an upload leaves no stale data.
us_counties_csv_nyt = None
us_states_csv_nyt = None
data_counties_nyt = None
data_states_nyt = None

if load_nyt:
  # files.upload() opens Colab's upload widget; it returns {filename: bytes}.
  nytfiles = files.upload()
  counties_name = 'us-counties.csv'
  if counties_name in nytfiles:
    us_counties_csv_nyt = nytfiles[counties_name].decode('utf-8')
    data_counties_nyt = pd.read_csv(io.StringIO(us_counties_csv_nyt))
  else:
    # Report it now instead of failing later on a None frame.
    print('Warning: {} not among uploaded files {}; NYT county plots will '
          'not work.'.format(counties_name, sorted(nytfiles)))
  states_name = 'us-states.csv'
  if states_name in nytfiles:
    us_states_csv_nyt = nytfiles[states_name].decode('utf-8')
    data_states_nyt = pd.read_csv(io.StringIO(us_states_csv_nyt))
  else:
    print('Warning: {} not among uploaded files {}; NYT state plots will '
          'not work.'.format(states_name, sorted(nytfiles)))

**Upload JHU files**

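To use JHU data, extract `time_series_covid19_confirmed_global.csv` and `time_series_covid19_deaths_global.csv` from https://github.com/CSSEGISandData/COVID-19 and upload them with the file picker that the next cell opens. If you are planning to analyze only US places, you can save time by uploading `time_series_covid19_confirmed_US.csv` and `time_series_covid19_deaths_US.csv` instead. Note that the JHU county lookup reads the `Province_State` and `Admin2` columns (present in the `_US` files), and that if both the global and the US version of a file are uploaded, the global one is used.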

In [0]:
# Reset first so re-running this cell without an upload leaves no stale data.
cases_jhu = None
deaths_jhu = None

if load_jhu:
  # files.upload() opens Colab's upload widget; it returns {filename: bytes}.
  jhufiles = files.upload()
  # NOTE(review): the global file takes precedence when both are uploaded, but
  # _get_jhu_data_for_county reads 'Province_State'/'Admin2' columns; confirm
  # the global file provides them before relying on it for county plots.
  cases_name_us = 'time_series_covid19_confirmed_US.csv'
  cases_name_global = 'time_series_covid19_confirmed_global.csv'
  if cases_name_global in jhufiles:
    cases_jhu = pd.read_csv(io.StringIO(
        jhufiles[cases_name_global].decode('utf-8')))
  elif cases_name_us in jhufiles:
    cases_jhu = pd.read_csv(io.StringIO(
        jhufiles[cases_name_us].decode('utf-8')))
  else:
    print('Warning: neither {} nor {} among uploaded files {}.'.format(
        cases_name_global, cases_name_us, sorted(jhufiles)))
  deaths_name_us = 'time_series_covid19_deaths_US.csv'
  deaths_name_global = 'time_series_covid19_deaths_global.csv'
  if deaths_name_global in jhufiles:
    deaths_jhu = pd.read_csv(io.StringIO(
        jhufiles[deaths_name_global].decode('utf-8')))
  elif deaths_name_us in jhufiles:
    deaths_jhu = pd.read_csv(io.StringIO(
        jhufiles[deaths_name_us].decode('utf-8')))
  else:
    print('Warning: neither {} nor {} among uploaded files {}.'.format(
        deaths_name_global, deaths_name_us, sorted(jhufiles)))


In [0]:
plot_county(county='Santa Clara', state='California',
            reference_date='2020-03-17', use_nyt=False)  # JHU time series

In [0]:
plot_county(county='Los Angeles', state='California',
            reference_date='2020-03-19')

In [0]:
plot_state(state='California', reference_date='2020-03-19')

In [0]:
plot_state(state='Washington', reference_date='2020-03-20')

In [0]:
plot_state(state='South Dakota', reference_date='2020-04-02')