In [1]:
import json
import re
from functools import lru_cache
from urllib.request import urlopen

import pandas as pd
import plotly.express as px

In [2]:
# County-level influenza table. `fips` is read as str so codes keep any leading
# zeros and match the GeoJSON feature ids used by `plot` via locations='fips'
# (assumed zero-padded strings — confirm against the CSV).
INFLUENZA_COUNTY_CSV = 'results/Influenza_county.csv'
df = pd.read_csv(INFLUENZA_COUNTY_CSV, dtype={'fips': str})

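## Get 6 weeks in each flu season

For each season starting in `year`, keep every third date column from October of `year` through January of `year + 1`.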

In [3]:
# Leading YYYY-MM-DD date; the two groups capture the year and the month.
# Raw string: '\d' in a plain literal is an invalid escape sequence.
_DATE_PREFIX_RE = re.compile(r'(\d{4})-(\d{2})-\d{2}')


def foo(date, year, start_month=10, end_month=1):
    """Return True if column label `date` falls in the flu season starting in `year`.

    A season runs from `start_month` of `year` through `end_month` of
    `year + 1` (October through January by default). Labels that do not
    start with a YYYY-MM-DD date (e.g. 'fips') return False.

    Parameters
    ----------
    date : str
        Column label to test; only its leading YYYY-MM-DD prefix is inspected.
    year : int
        Calendar year in which the season starts.
    start_month : int, default 10
        First month of the season, counted in `year`.
    end_month : int, default 1
        Last month of the season, counted in `year + 1`.

    Returns
    -------
    bool
    """
    match = _DATE_PREFIX_RE.match(date)
    if match is None:
        return False
    label_year, month = int(match.group(1)), int(match.group(2))
    in_start_year = label_year == year and month >= start_month
    in_end_year = label_year == year + 1 and month <= end_month
    return in_start_year or in_end_year

# Seasons are keyed by their start year. The loop uses years[:-1], i.e. start
# years 2003..2010, so the last season covered is 2010-11.
years = list(range(2003, 2012))
WEEK_STRIDE = 3  # keep every third date column within a season

starts_of_season = []  # first kept date column of each season
seasons = []           # list of kept date columns, one list per season
for year in years[:-1]:
    # Matching labels all start with YYYY-MM-DD, so sorting them is
    # chronological; this makes the stride and cols[0] independent of the
    # CSV's column order.
    cols = sorted(col for col in df.columns if foo(col, year))[::WEEK_STRIDE]
    if not cols:
        raise ValueError(f'no date columns found for the {year}-{year + 1} season')
    starts_of_season.append(cols[0])
    seasons.append(cols)

## Plot county choropleth maps

In [5]:
# US county boundaries from plotly's sample datasets; feature ids are presumably
# FIPS codes (per the file name) — they are matched against df['fips'].
COUNTIES_GEOJSON_URL = 'https://raw.githubusercontent.com/plotly/datasets/master/geojson-counties-fips.json'


@lru_cache(maxsize=1)
def _load_counties_geojson(url=COUNTIES_GEOJSON_URL, timeout=30):
    """Download and parse the county GeoJSON, caching the parsed result.

    Cached so that plotting many maps in a loop fetches the file once per
    kernel session instead of once per figure; `timeout` (seconds) keeps a
    stalled connection from hanging the cell forever.
    """
    with urlopen(url, timeout=timeout) as response:
        return json.load(response)


def plot(
    df,
    plot_col,
    color_min=None,
    color_max=None,
    save_fname=None,
    save_format='pdf',
    width=900,
    height=540):
    """Draw a US county choropleth of ``df[plot_col]`` and display it.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``'fips'`` column (matched against the GeoJSON feature
        ids) and the column ``plot_col``.
    plot_col : hashable
        Column mapped to color; may be a non-string label such as ``0``.
    color_min, color_max : float, optional
        Color range bounds; each defaults to the column's min / max.
    save_fname : str, optional
        If given, also write a static image of the figure to this path.
    save_format : str, default 'pdf'
        Image format passed to ``fig.write_image``.
    width, height : int, default 900 and 540
        Size of the saved image.

    Returns
    -------
    None
        The figure is displayed with ``fig.show``. It is deliberately not
        returned, because a returned figure would render a second time when
        ``plot(...)`` is the last expression of a cell.
    """
    counties = _load_counties_geojson()

    if color_min is None:
        color_min = df[plot_col].min()
    if color_max is None:
        color_max = df[plot_col].max()

    fig = px.choropleth(
        df,
        geojson=counties,
        locations='fips',
        color=plot_col,
        # Trim the two endpoint colours of the RdBu_r scale.
        color_continuous_scale=px.colors.sequential.RdBu_r[1:-1],
        range_color=(color_min, color_max),
        scope="usa",
    )

    # Styling: hide the colour bar, remove margins, and draw no county borders
    # (the line colour only takes effect if marker_line_width is raised above 0).
    fig.update_layout(
        coloraxis_showscale=False,
        margin={'l': 0, 'b': 0, 'r': 0, 't': 0},
    )
    fig.update_traces(marker_line_width=0, marker_line_color='lightgray')

    if save_fname is not None:
        # Static export needs an image engine (e.g. kaleido) installed.
        fig.write_image(
            file=save_fname,
            format=save_format,
            width=width,
            height=height)

    fig.show(config={'scrollZoom': False, 'displayModeBar': False})

In [4]:
# for i, season in enumerate(seasons[:3]):
#     for date in season:
#         print(date)
#         save_fname = f'results/season_{i}_{date}.pdf'
#         plot(df, date, color_max=5, save_fname=save_fname)

In [5]:
# tmp = df[starts_of_season + ['fips']].set_index('fips').mean(axis=1).reset_index()
# save_fname = 'results/season_start_avg.pdf'
# plot(tmp, 0, color_max=5, save_fname=save_fname)