In [None]:
import pandas as pd
import os
import matplotlib.pyplot as plt
import numpy as np
import geopandas as gpd
import geoplot as gplot
import geoplot as gplt
import matplotlib as mpl
import gc

In [None]:
def process_date(d):
    """Zero-pad the month and day of an ``M/D/Y`` date label.

    Parameters
    ----------
    d : str
        Date label made of exactly three '/'-separated parts
        (month, day, year), e.g. ``'1/22/20'``.

    Returns
    -------
    str
        The same label with single-character month and day parts prefixed
        by ``'0'``; the year part is passed through untouched.

    Raises
    ------
    ValueError
        If ``d`` does not split into exactly three parts.
    """
    month, day, year = d.split('/')
    # only parts shorter than two characters get a leading zero
    month = month if len(month) >= 2 else '0' + month
    day = day if len(day) >= 2 else '0' + day
    return '{}/{}/{}'.format(month, day, year)

In [None]:
def rename_cols(df,ind):
    """Rename, in place, every column from position ``ind`` onwards.

    Each of those column labels is passed through ``process_date`` so that
    month and day are zero-padded. Columns before ``ind`` are left alone.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame whose columns are renamed in place.
    ind : int
        Position of the first column to rename.
    """
    padded_labels = {label: process_date(label) for label in df.columns[ind:]}
    df.rename(columns=padded_labels, inplace=True)

In [None]:
# load covid data
# The three time series (confirmed cases, deaths, recoveries) live side by
# side in the same remote folder and are read straight from GitHub.
# NOTE(review): these CSV file names may have been retired upstream --
# confirm that the URLs still resolve before relying on this cell.
url = 'https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/'
conf, dead, recs = [
    pd.read_csv(url + csv_name)
    for csv_name in (
        'time_series_19-covid-Confirmed.csv',
        'time_series_19-covid-Deaths.csv',
        'time_series_19-covid-Recovered.csv',
    )
]

In [None]:
# Apply the same in-place clean-up to each of the three time series.
for frame in (conf, dead, recs):
    # missing province names become the placeholder string 'none'
    frame.loc[frame['Province/State'].isna(), 'Province/State'] = 'none'
    # zero-pad the date column labels (everything from the 5th column on)
    rename_cols(frame,4)
    # rows whose province is just the country name are relabelled 'none' too
    for country_as_province in ('France', 'United Kingdom'):
        frame.loc[frame['Province/State'] == country_as_province, 'Province/State'] = 'none'

In [None]:
# load geodata
# Every map ends up with a 'geounit' column holding the region name, and
# rows with empty geometries are dropped wherever the original cell did so.
url = 'https://raw.githubusercontent.com/deldersveld/topojson/master/'

# china -- a few province names are replaced
# (presumably to match the spelling used in the covid tables; verify)
china = gpd.read_file(url + 'countries/china/china-provinces.json')
for old_name, new_name in (
    ('Nei Mongol', 'Inner Mongolia'),
    ('Xinjiang Uygur', 'Xinjiang'),
    ('Xizang', 'Tibet'),
    ('Ningxia Hui', 'Ningxia'),
):
    china.loc[china['NAME_1'] == old_name, 'NAME_1'] = new_name
china.rename(columns={'NAME_1' : 'geounit'}, inplace=True)

# world
world = gpd.read_file(url + 'world-countries.json')
world_empty = world[world['geometry'].is_empty].index
world = world.drop(world_empty)
world.rename(columns = {'name' : 'geounit'}, inplace=True)
world.loc[world['geounit'] == 'United States of America','geounit'] = 'US'

# europe
europe = gpd.read_file(url + 'continents/europe.json')
europe.drop(europe[europe['geometry'].is_empty].index, inplace=True)
# remove Russia for now
europe.drop(europe[europe['geounit'] == 'Russia'].index, inplace=True)
# make the UK: its four constituent nations all get the same geounit label
uk_nations = ['England', 'Scotland', 'Wales', 'Northern Ireland']
europe.loc[europe['geounit'].isin(uk_nations), 'geounit'] = 'United Kingdom'

# usa
usa = gpd.read_file(url + 'countries/united-states/us-albers.json')
usa_empty = usa[usa['geometry'].is_empty].index
usa = usa.drop(usa_empty)
usa.rename(columns={'name' : 'geounit'},inplace=True)

# africa
africa = gpd.read_file(url + 'continents/africa.json')
africa.drop(africa[africa['geometry'].is_empty].index, inplace=True)

# asia
asia = gpd.read_file(url + 'continents/asia.json')
asia.drop(asia[asia['geometry'].is_empty].index, inplace=True)
# remove Russia for now
asia.drop(asia[asia['geounit'] == 'Russia'].index, inplace=True)

# south america
samerica = gpd.read_file(url + 'continents/south-america.json')
samerica.drop(samerica[samerica['geometry'].is_empty].index, inplace=True)

In [None]:
# some defs to use later

# 'feature' / 'title_annot' entries for the confirmed-cases panels
feature_spread = 'new_cases'
title_spread = 'cumulative cases'

# 'feature' / 'title_annot' entries for the deaths panels
feature_deaths = 'new_deaths'
title_deaths = 'cumulative deaths'

# placeholder; a later cell assigns the actual first date to plot
start_date = None

# figure sizes handed to plt.subplots(figsize=...), and the PNG resolution
FIG_DIMS_H = (12, 12)
FIG_DIMS_V = (8, 12)
FIG_DPI = 150

In [None]:
def print_full(x):
    """Print a pandas object without row/column truncation.

    The display limits are raised to the size of ``x`` only for the duration
    of the ``print`` call. ``pd.option_context`` puts the previous option
    values back afterwards -- also when printing raises -- so display
    settings the user customised earlier are not overwritten.

    Parameters
    ----------
    x : pandas.DataFrame or pandas.Series
        Object to print in full. For a one-dimensional object only the row
        limit is adjusted.
    """
    options = ['display.max_rows', x.shape[0]]
    # a Series has no second axis, so there is no column limit to widen
    if len(x.shape) > 1:
        options += ['display.max_columns', x.shape[1]]
    with pd.option_context(*options):
        print(x)

In [None]:
def log_scale(minval, maxval):
    """Build a scaling function that maps a value to a shifted log10.

    The returned callable computes ``log10(val + abs(minval) + 1)``.
    ``maxval`` is accepted but not used.

    Parameters
    ----------
    minval : number
        Its absolute value is added to every input before the logarithm.
    maxval : number
        Unused.

    Returns
    -------
    callable
        One-argument function applying the shifted logarithm.
    """
    return lambda val: np.log10(val + abs(minval) + 1)

In [None]:
def identity_scale(minval, maxval):
    """Build a scaling function that ignores its input and returns 2.

    Both ``minval`` and ``maxval`` are accepted but not used.

    Returns
    -------
    callable
        One-argument function that always returns the constant ``2``.
    """
    return lambda val: 2

In [None]:
def render_sub_frame(params=None):
    """Draw one choropleth panel for a single date onto an existing axis.

    Parameters
    ----------
    params : dict
        Description of the panel. Keys read by this function:

        - ``'mode'``: ``'global'`` (aggregate per 'Country/Region') or
          ``'country'`` (aggregate per 'Province/State' of one country).
        - ``'country'``: value of 'Country/Region' to keep in country mode.
        - ``'data'``: DataFrame with 'Country/Region', 'Province/State' and
          one column per date (plus 'Long'/'Lat' when points are drawn).
        - ``'date'``: label of the date column to plot.
        - ``'points'``: in country mode, also draw rows whose
          'Province/State' contains a comma as points.
        - ``'map_data'``: GeoDataFrame with 'geounit' and 'geometry' columns.
        - ``'color_min'``, ``'color_max'``, ``'cmap'``: colour normalisation
          limits (applied to the log10 values) and colormap.
        - ``'ax'``: matplotlib axis to draw on.
        - ``'annot'``: write the raw count at each geometry's centroid.
        - ``'title_annot'``: text inserted into the panel title.

    Raises
    ------
    ValueError
        If ``params`` is None, the mode is unknown, or country mode is
        requested without a country.
    """
    if params is None:
        raise ValueError('params must be a dict describing the panel.')

    mode = params['mode']
    country = params['country']
    date = params['date']
    points = params['points']

    # Validate before any plotting work: previously these cases only printed
    # a message and then crashed later on an undefined variable.
    if mode not in ('global', 'country'):
        raise ValueError('Invalid mode.')
    if mode == 'country' and country is None:
        raise ValueError('Country not specified.')

    data = params['data']
    ax = params['ax']

    # normalize colormap
    norm = mpl.colors.Normalize(vmin=params['color_min'], vmax=params['color_max'])
    cmap = mpl.cm.ScalarMappable(norm=norm, cmap=params['cmap']).cmap

    sl = None
    if mode == 'global':
        sg = data.loc[:, ['Country/Region', date]].groupby('Country/Region').sum()
    else:
        data = data.loc[data['Country/Region'] == country]
        # rows whose province label contains a comma are kept out of the
        # per-province totals and only used for the optional point overlay
        has_comma = data['Province/State'].str.contains(',')
        s_glob = data.loc[~has_comma]
        sg = s_glob.loc[:, ['Province/State', date]].groupby('Province/State').sum()
        if points:
            s_locl = data.loc[has_comma]
            sl = gpd.GeoDataFrame(s_locl,
                                  geometry=gpd.points_from_xy(s_locl['Long'], s_locl['Lat']))

    # merge w/ map data: look up each map unit's total, 0 where there is none
    geodata = params['map_data'].copy()
    raw = geodata['geounit'].map(sg[date]).fillna(0).astype(float)
    geodata['raw'] = raw
    # log10(0) would be -inf; non-positive counts are plotted as 0 instead
    geodata['log10'] = np.log10(raw.where(raw > 0, 1))

    gplot.choropleth(geodata,
                     hue='log10',
                     cmap=cmap,
                     ax=ax,
                     norm=norm,
                     legend=True,
                     )

    if sl is not None:
        gplot.pointplot(sl,
                        ax=ax,
                        scale=date,
                        scale_func=identity_scale,
                        limits=(2,20)
                        )

    if params['annot']:
        for _, row in geodata.iterrows():
            centroid = row['geometry'].centroid
            ax.text(centroid.x, centroid.y, int(row['raw']), fontsize=10)

    ax.set_title('COVID19 '+ params['title_annot'] +' over time ' + date)

In [None]:
def plot_by_region(start_date=None,
                   end_date=None,
                   dims=(12,12),
                   dpi=150,
                   annot=True,
                   save_path=None,
                   mode='global',
                   country=None,
                   points=False,
                   params=None,
                   verbose=True):
    '''Render one PNG per date, with one stacked panel per entry of params.

    Each panel is drawn by ``render_sub_frame``. Only dates present in the
    data of every panel are plotted.

    Parameters
    ----------
    start_date, end_date : str or None
        First / last date label to plot (inclusive), in the 'MM/DD/YY'
        form of the data columns. None means no limit on that side.
    dims : tuple
        Figure size passed to ``plt.subplots(figsize=...)``.
    dpi : int
        Resolution of the saved PNG files.
    annot, mode, country, points
        Forwarded to ``render_sub_frame`` through each panel's dict.
    save_path : str
        Output directory; created if it does not exist. Files are named
        '<YY><MM><DD>.png'.
    params : list of dict
        One dict per panel, each with at least a 'data' DataFrame whose
        columns from position 4 onwards are date labels. The dicts are
        copied, not modified.
    verbose : bool
        Print the directory when it is created and each file name written.

    Raises
    ------
    ValueError
        If ``params`` is empty, ``save_path`` is None, a requested
        start/end date is not available, or a date label is not of the
        form 'M/D/Y' with integer parts.
    '''
    if not params:
        raise ValueError('params must be a non-empty list of panel dicts.')
    if save_path is None:
        raise ValueError('save_path must name the output directory.')

    def _chronological(label):
        # plain string sorting of 'MM/DD/YY' puts January of a later year
        # before December of an earlier one, so compare (year, month, day)
        month, day, year = label.split('/')
        return (int(year), int(month), int(day))

    n_plots = len(params)

    common_dates = set.intersection(
        *(set(fdict['data'].columns[4:]) for fdict in params))
    dates = sorted(common_dates, key=_chronological)

    start_ind = 0 if start_date is None else dates.index(start_date)
    end_ind = len(dates) if end_date is None else dates.index(end_date) + 1
    dates = dates[start_ind:end_ind]

    if not os.path.exists(save_path):
        os.makedirs(save_path)
        if verbose:
            print('Created %s' % save_path)

    for date in dates:
        # squeeze=False always yields a 2-D array of axes, so one panel and
        # several panels are handled the same way
        fig, axs = plt.subplots(n_plots, figsize=dims, squeeze=False)

        for ax, fdict in zip(axs.ravel(), params):
            # work on a copy so the caller's dicts are left untouched
            frame_params = dict(fdict,
                                date=date,
                                annot=annot,
                                mode=mode,
                                country=country,
                                points=points,
                                ax=ax)
            render_sub_frame(params=frame_params)

        fig.tight_layout()

        # file name is YYMMDD so that files sort chronologically
        month, day, year = date.split('/')
        date_str = year + month.zfill(2) + day.zfill(2)
        fname = os.path.join(save_path, date_str + '.png')
        if verbose:
            print(fname)

        fig.savefig(fname, dpi=dpi)
        fig.clf()
        plt.close(fig)

        # maybe running into RAM problems because creating too many figures
        gc.collect()

In [None]:
# first date to plot; must match a zero-padded 'MM/DD/YY' column label of the data
start_date = '03/20/20'

In [None]:
# Europe
# colour-scale ceilings are log10 exponents of the saturating count
color_max_europe_spread = 4   # 10^4 = 10,000
color_max_europe_deaths = 3   # 10^3 = 1,000

# top panel: confirmed cases
europe_spread = dict(
    data=conf,
    color_min=0,
    color_max=color_max_europe_spread,
    cmap='Greens',
    map_data=europe,
    feature=feature_spread,
    title_annot=title_spread,
)

# bottom panel: deaths
europe_death = dict(
    data=dead,
    color_min=0,
    color_max=color_max_europe_deaths,
    cmap='Reds',
    map_data=europe,
    feature=feature_deaths,
    title_annot=title_deaths,
)

plot_by_region(params=[europe_spread, europe_death],
               mode='global',
               start_date=start_date,
               save_path='./europe/',
               dims=FIG_DIMS_H,
               dpi=FIG_DPI,
               annot=True,
               verbose=False)

In [None]:
# USA
# colour-scale ceilings are log10 exponents of the saturating count
color_max_usa_spread = 3   # 10^3 = 1,000
color_max_usa_deaths = 1   # 10^1 = 10

# top panel: confirmed cases per state
usa_spread = dict(
    data=conf,
    color_min=0,
    color_max=color_max_usa_spread,
    cmap='Greens',
    map_data=usa,
    feature=feature_spread,
    title_annot=title_spread,
)

# bottom panel: deaths per state
usa_death = dict(
    data=dead,
    color_min=0,
    color_max=color_max_usa_deaths,
    cmap='Reds',
    map_data=usa,
    feature=feature_deaths,
    title_annot=title_deaths,
)

plot_by_region(params=[usa_spread, usa_death],
               mode='country',
               country='US',
               start_date=start_date,
               save_path='./usa/',
               dims=FIG_DIMS_H,
               dpi=FIG_DPI,
               annot=True,
               verbose=False)

In [None]:
# Africa
# colour-scale ceilings are log10 exponents of the saturating count
color_max_africa_spread = 2   # 10^2 = 100
color_max_africa_deaths = 1   # 10^1 = 10

# top panel: confirmed cases
africa_spread = dict(
    data=conf,
    color_min=0,
    color_max=color_max_africa_spread,
    cmap='Greens',
    map_data=africa,
    feature=feature_spread,
    title_annot=title_spread,
)

# bottom panel: deaths
africa_death = dict(
    data=dead,
    color_min=0,
    color_max=color_max_africa_deaths,
    cmap='Reds',
    map_data=africa,
    feature=feature_deaths,
    title_annot=title_deaths,
)

plot_by_region(params=[africa_spread, africa_death],
               mode='global',
               start_date=start_date,
               save_path='./africa/',
               dims=FIG_DIMS_V,
               dpi=FIG_DPI,
               annot=True,
               verbose=False)

In [None]:
# South America
# colour-scale ceilings are log10 exponents of the saturating count
color_max_samerica_spread = 2.7   # 10^2.7 ~ 500
color_max_samerica_deaths = 1     # 10^1 = 10

# top panel: confirmed cases
samerica_spread = dict(
    data=conf,
    color_min=0,
    color_max=color_max_samerica_spread,
    cmap='Greens',
    map_data=samerica,
    feature=feature_spread,
    title_annot=title_spread,
)

# bottom panel: deaths
samerica_death = dict(
    data=dead,
    color_min=0,
    color_max=color_max_samerica_deaths,
    cmap='Reds',
    map_data=samerica,
    feature=feature_deaths,
    title_annot=title_deaths,
)

plot_by_region(params=[samerica_spread, samerica_death],
               mode='global',
               start_date=start_date,
               save_path='./samerica/',
               dims=FIG_DIMS_V,
               dpi=FIG_DPI,
               annot=True,
               verbose=False)

In [None]:
# Asia
# colour-scale ceilings are log10 exponents of the saturating count
color_max_asia_spread = 4   # 10^4 = 10,000
color_max_asia_deaths = 3   # 10^3 = 1,000

# top panel: confirmed cases
asia_spread = dict(
    data=conf,
    color_min=0,
    color_max=color_max_asia_spread,
    cmap='Greens',
    map_data=asia,
    feature=feature_spread,
    title_annot=title_spread,
)

# bottom panel: deaths
asia_death = dict(
    data=dead,
    color_min=0,
    color_max=color_max_asia_deaths,
    cmap='Reds',
    map_data=asia,
    feature=feature_deaths,
    title_annot=title_deaths,
)

plot_by_region(params=[asia_spread, asia_death],
               mode='global',
               start_date=start_date,
               save_path='./asia/',
               dims=FIG_DIMS_H,
               dpi=FIG_DPI,
               annot=True,
               verbose=False)

In [None]:
# entire world
# colour-scale ceilings are log10 exponents of the saturating count
color_max_world_spread = 4   # 10^4 = 10,000
color_max_world_deaths = 3   # 10^3 = 1,000

# top panel: confirmed cases
world_spread = dict(
    data=conf,
    color_min=0,
    color_max=color_max_world_spread,
    cmap='Greens',
    map_data=world,
    feature=feature_spread,
    title_annot=title_spread,
)

# bottom panel: deaths
world_death = dict(
    data=dead,
    color_min=0,
    color_max=color_max_world_deaths,
    cmap='Reds',
    map_data=world,
    feature=feature_deaths,
    title_annot=title_deaths,
)

# per-country count labels are switched off for the world map
plot_by_region(params=[world_spread, world_death],
               mode='global',
               start_date=start_date,
               save_path='./world/',
               dims=FIG_DIMS_H,
               dpi=FIG_DPI,
               annot=False,
               verbose=False)

In [None]:
# China
# colour-scale ceilings are log10 exponents of the saturating count
color_max_china_spread = 4   # 10^4 = 10,000
color_max_china_deaths = 3   # 10^3 = 1,000

# top panel: confirmed cases per province
china_spread = dict(
    data=conf,
    color_min=0,
    color_max=color_max_china_spread,
    cmap='Greens',
    map_data=china,
    feature=feature_spread,
    title_annot=title_spread,
)

# bottom panel: deaths per province
china_death = dict(
    data=dead,
    color_min=0,
    color_max=color_max_china_deaths,
    cmap='Reds',
    map_data=china,
    feature=feature_deaths,
    title_annot=title_deaths,
)

plot_by_region(params=[china_spread, china_death],
               mode='country',
               country='China',
               start_date=start_date,
               save_path='./china/',
               dims=FIG_DIMS_H,
               dpi=FIG_DPI,
               annot=True,
               verbose=False)