In [None]:
import os 
import numpy as np
import pandas as pd
import matplotlib.dates as mdates
from datetime import datetime, timedelta
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import circstd
from sklearn.preprocessing import MinMaxScaler
from matplotlib.backends.backend_pdf import PdfPages
from PyPDF2 import PdfMerger
from PIL import Image, ImageDraw
import io
from PyPDF2 import PdfWriter, PdfReader


# Project root. Set the BEHAVIOR_COVID_PATH environment variable to run the notebook
# on another machine; without it the original author's location is used unchanged.
path = os.environ.get('BEHAVIOR_COVID_PATH', '/Users/tamannaurmi/Documents/Research/behavior_covid')
path_processed = path + '/data/processed'

In [None]:

def monthly_jhu_cases(dt, plot=True):
    '''
    Aggregate the JHU county-level cumulative confirmed-case table into national
    monthly new-case counts.

    Parameters
    ----------
    dt : pandas.DataFrame
        JHU ``time_series_covid19_confirmed_US`` table: the metadata columns dropped
        below, a ``Province_State`` column and one cumulative-count column per date
        (labels parseable by ``pd.to_datetime``, e.g. '1/22/20').
    plot : bool, default True
        Show a line plot of the national daily new cases (the original behaviour).

    Returns
    -------
    pandas.DataFrame
        Columns ``month_yr`` (monthly Period), ``case_cnt_nat`` (new cases in that
        month summed over all states) and ``case_cnt_norm`` (``case_cnt_nat`` as a
        percentage of the hard-coded US population 336997624).
    '''
    meta_cols = ['UID', 'iso2', 'iso3', 'code3', 'FIPS', 'Admin2', 'Country_Region', 'Lat', 'Long_', 'Combined_Key']
    county = dt.drop(meta_cols, axis = 1)

    # Daily new cases = cumulative[t] - cumulative[t-1] (the first day minus 0).
    # A one-column shift replaces the old hard-coded column index 975, which only
    # matched one particular download of the JHU file.
    cumulative = county.drop(columns = 'Province_State')
    daily = cumulative - cumulative.shift(1, axis = 1, fill_value = 0)

    ## collapse counties to states, then states to the nation
    daily['Province_State'] = county['Province_State']
    by_state = daily.groupby('Province_State').sum()
    national_daily = by_state.sum(axis = 0)   # Series indexed by the date column labels

    by_date = pd.DataFrame({'date': national_daily.index, 'case_cnt_nat': national_daily.values})
    by_date['month_yr'] = pd.to_datetime(by_date['date']).dt.to_period('M')
    monthly_cases = by_date.groupby('month_yr')[['case_cnt_nat']].sum().reset_index()

    us_pop = 336997624
    monthly_cases['case_cnt_norm'] = monthly_cases['case_cnt_nat']*100/us_pop

    if plot:
        plt.figure()
        plt.plot(national_daily)
        plt.show()

    return monthly_cases


def lin_reg(x, y):
    '''
    Fit an ordinary least-squares line y ~ m*x + c.

    Returns, in order: x reshaped to an (n, 1) column array, y flattened, the
    residual anomaly (observed minus fitted), the slope array ``m`` (the model's
    ``coef_``), the intercept ``c``, and 20 evenly spaced (x, y) points along the
    fitted line. The line starts one x-step before the first sample (presumably so
    it reaches the first category when drawn on a categorical axis).
    Needs at least two samples (it reads x[1]).
    '''
    x_col = np.array(x[0:len(x)]).reshape(-1, 1)
    y_flat = y.flatten()

    fitted = LinearRegression().fit(x_col, y_flat)
    slope = fitted.coef_
    intercept = fitted.intercept_
    residual = y_flat - fitted.predict(x_col)

    first_step = x_col[1] - x_col[0]
    line_x = np.linspace(min(x_col) - first_step, max(x_col), 20)
    line_y = slope*line_x + intercept

    return x_col, y_flat, residual, slope, intercept, line_x, line_y

def lin_reg_lag(x, y, lag):
    '''
    Regress y on x after shifting the two series against each other by ``lag``
    samples.

    lag > 0 pairs x[t + lag] with y[t]; lag < 0 pairs x[t] with y[t + |lag|];
    lag == 0 regresses the unshifted series. For equal-length inputs both lagged
    series have length len(x) - |lag|. (Previously the y slice kept only |lag|
    samples, so the lengths never matched, and lag == 0 left x_lag undefined.)

    Returns, in order: the lagged x as an (n, 1) array, the lagged y, the residual
    anomaly (observed minus fitted), the slope array ``m``, the intercept ``c``, and
    20 points of the fitted line spanning the lagged x range.
    '''
    lr = LinearRegression()
    x = np.array(x[0:len(x)]).reshape(-1, 1)
    y = y.flatten()
    if lag > 0:
        # x[t + lag] paired with y[t]
        x_lag = x[lag:]
        y_lag = y[:-lag]
    elif lag < 0:
        # x[t] paired with y[t + |lag|]
        x_lag = x[:lag]
        y_lag = y[-lag:]
    else:
        x_lag, y_lag = x, y

    model = lr.fit(x_lag, y_lag)
    y_pred = model.predict(x_lag)
    anomaly = y_lag - y_pred
    m = model.coef_
    c = model.intercept_
    x_line = np.linspace(min(x_lag), max(x_lag), 20)
    y_line = m*x_line + c

    return x_lag, y_lag, anomaly, m, c, x_line, y_line

def death_behavior_loc_wvend(death_agg_data, location, behavior, interval):
    '''
    Join one behavior's series with aggregated deaths for one location and add the
    behavior's de-trended anomaly.

    Parameters
    ----------
    death_agg_data : pandas.DataFrame
        Must contain ``State``, ``Wave``, ``new_ind``, ``End_Date`` and the
        ``death_agg<interval>`` column.
    location : str
        'National', a two-letter state code, or a full state name.
    behavior : str
        Column of the module-level ``behave_dt`` to use.
    interval : int
        Death aggregation window; one of 7, 14 or 30.

    Returns
    -------
    pandas.DataFrame
        Behavior rows with a matching death record (left merge, then dropna),
        with columns ``Wave``, ``new_ind``, ``end_date``, ``behavior``,
        ``death_agg<interval>``, ``death_pct`` (deaths as % of the location's
        population) and ``behavior_anomaly`` (residual of a linear trend fitted
        over the row order).

    Raises
    ------
    ValueError
        If ``interval`` is not 7, 14 or 30 (this used to surface as an unrelated
        UnboundLocalError further down).

    Reads the module-level ``behave_dt``, ``state_pop``, ``state_to_code``,
    ``code_to_state`` and calls ``lin_reg``.
    '''
    supported = {7: 'death_agg7', 14: 'death_agg14', 30: 'death_agg30'}
    if interval not in supported:
        raise ValueError(f'interval must be one of {sorted(supported)}, got {interval!r}')
    int_var = supported[interval]

    if location == 'National':
        state_name = location
        cur_state_pop = 336997624
        state_code = 'National'
    elif len(location) == 2:
        state_name = code_to_state.get(location)
        cur_state_pop = state_pop[state_pop.NAME == state_name].POPESTIMATE2021.values[0]
        state_code = location
    else:
        state_name = location
        cur_state_pop = state_pop[state_pop.NAME == state_name].POPESTIMATE2021.values[0]
        state_code = state_to_code.get(location)

    deaths = death_agg_data[death_agg_data.State == state_name][['Wave', 'new_ind', 'End_Date', int_var]]

    dt_b_d = behave_dt[behave_dt['state'] == state_code][['Wave', 'new_ind', 'End_Date', behavior]].reset_index(drop = True)
    dt_b_d = dt_b_d.merge(deaths, how = 'left', on = ['Wave', 'new_ind', 'End_Date']).dropna()
    dt_b_d['death_pct'] = dt_b_d[int_var]*100/cur_state_pop

    # De-trend the behavior against its row position 0, 1, 2, ...
    row_index = list(range(len(dt_b_d)))
    _, _, ban, _, _, _, _ = lin_reg(row_index, dt_b_d[behavior].values)
    dt_b_d['behavior_anomaly'] = ban

    return dt_b_d.rename(columns = {'End_Date': 'end_date'})

def predictor(x, y, lag, random_state=None):
    '''
    Predict y[t + lag] from x[t] with a linear regression fitted on a random 80/20
    train/test split, and score it on the held-out 20%.

    Parameters
    ----------
    x : array-like
    y : numpy.ndarray (it is flattened), same length as x
    lag : int >= 0, how many steps y trails x
    random_state : int or None
        Forwarded to ``train_test_split``. Pass an int for reproducible scores;
        the default None keeps the previous unseeded behaviour.

    Returns
    -------
    m, c, mae, mse, mape
        Slope array, intercept, and the mean absolute / squared / absolute
        percentage error on the test split. MAPE adds 1e-6 to the denominator to
        avoid division by zero and is a fraction, not a percentage.

    Raises
    ------
    ValueError
        If ``lag`` is negative (the slicing only supports y trailing x).
    '''
    if lag < 0:
        raise ValueError(f'lag must be non-negative, got {lag}')

    x = np.array(x[0:len(x)-lag]).reshape(-1, 1)
    y = y[lag:].flatten()

    X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.20, random_state=random_state)

    model = LinearRegression().fit(X_train, y_train)
    err = y_test - model.predict(X_test)

    mae = np.mean(np.abs(err))
    mse = np.mean(err**2)
    mape = np.mean(np.abs(err/(y_test+0.000001)))

    return model.coef_, model.intercept_, mae, mse, mape


In [None]:
# Display the processed-data directory as the cell output, to confirm which folder the loaders below read.
path_processed

In [None]:
## Get processed data
# Outputs of the preprocessing steps, read in the same order as before.
waves_interp, behave_dt, death_agg_dt = (
    pd.read_csv(f'{path_processed}/{fname}')
    for fname in ('processed_waves_with_gaps.csv',
                  'behavior_wave_interpolated_all_state.csv',
                  'death_monthly_state.csv')
)
jhu_cases = pd.read_csv(f'{path}/data/time_series_covid19_confirmed_US.csv')
hospit_df = pd.read_csv(f'{path_processed}/hospitalization_state_monthly.csv')
# vax_dt = pd.read_csv(path + '/us_state_vaccinations.csv')

# Census 2021 estimates; the first five rows are skipped so that (presumably) only
# the per-state rows remain -- confirm against the file.
nst_est2021 = pd.read_csv(f'{path}/data/NST-EST2021-alldata.csv')
state_pop = nst_est2021.iloc[5:].loc[:, ['NAME', 'POPESTIMATE2021']]

# State name <-> two-letter code lookups (the same file also carries Political_aff).
state_code_list = pd.read_csv(f'{path}/data/us_state_political_aff.csv')
state_list = list(state_code_list.State.values)
state_to_code = {name: code for name, code in zip(state_code_list.State, state_code_list.Code)}
code_to_state = {code: name for code, name in zip(state_code_list.Code, state_code_list.State)}

behaviors = [
    # risk-averting (see behaviors_type below)
    'Avoiding contact with other people',
    'Avoiding public or crowded places',
    'Frequently washing hands',
    'Wearing a face mask when outside of your home',
    # risk-seeking
    'Go to work',
    'Go to the gym',
    'Go visit a friend',
    'Go to a cafe, bar, or restaurant',
    'Go to a doctor or visit a hospital',
    'Go to church or another place of worship',
    'Take mass transit (e.g. subway, bus, or train)',
    'Been in a room with someone outside of household in the past 24 hours',
    'Been in a room with 5-10 people outside of household in the past 24 hours',
    'Been in a room with 11-50 people outside of household in the past 24 hours',
    'Been in a room with over 50 people outside of household in the past 24 hours',
]


def _month_labels(date_strings, fmt):
    '''Reformat 'YYYY-MM-DD' strings with ``fmt`` (e.g. '%b-%y' gives 'Apr-20').'''
    return [datetime.strptime(d, '%Y-%m-%d').strftime(fmt) for d in date_strings]


# Mid-month and month-end dates, Apr-2020 .. Mar-2023 (includes the irregular
# entries '2020-06-01' and '2020-11-14').
bimonthly = [
    '2020-04-30', '2020-05-15', '2020-05-31', '2020-06-01',
    '2020-06-15', '2020-06-30', '2020-07-15', '2020-07-31', '2020-08-15',
    '2020-08-31', '2020-09-15', '2020-09-30', '2020-10-15', '2020-10-31',
    '2020-11-14', '2020-11-15', '2020-11-30', '2020-12-15', '2020-12-31',
    '2021-01-15', '2021-01-31', '2021-02-15', '2021-02-28', '2021-03-15', '2021-03-31',
    '2021-04-15', '2021-04-30', '2021-05-15', '2021-05-31', '2021-06-15',
    '2021-06-30', '2021-07-15', '2021-07-31', '2021-08-15', '2021-08-31',
    '2021-09-15', '2021-09-30', '2021-10-15', '2021-10-31', '2021-11-15',
    '2021-11-30', '2021-12-15', '2021-12-31', '2022-01-15', '2022-01-31',
    '2022-02-15', '2022-02-28', '2022-03-15', '2022-03-31', '2022-04-15',
    '2022-04-30', '2022-05-15', '2022-05-31', '2022-06-15', '2022-06-30',
    '2022-07-15', '2022-07-31', '2022-08-15', '2022-08-31', '2022-09-15',
    '2022-09-30', '2022-10-15', '2022-10-31', '2022-11-15', '2022-11-30',
    '2022-12-15', '2022-12-31', '2023-01-15', '2023-01-31', '2023-02-15',
    '2023-02-28', '2023-03-15', '2023-03-31',
]
bimonthly_dt = pd.to_datetime(bimonthly)
bimnth_yr = _month_labels(bimonthly, '%Y-%b')

# Month-end dates, Apr-2020 .. Mar-2023
monthly = [
    '2020-04-30', '2020-05-31', '2020-06-30', '2020-07-31', '2020-08-31',
    '2020-09-30', '2020-10-31', '2020-11-30', '2020-12-31',
    '2021-01-31', '2021-02-28', '2021-03-31', '2021-04-30', '2021-05-31', '2021-06-30',
    '2021-07-31', '2021-08-31', '2021-09-30', '2021-10-31', '2021-11-30', '2021-12-31',
    '2022-01-31', '2022-02-28', '2022-03-31', '2022-04-30', '2022-05-31', '2022-06-30',
    '2022-07-31', '2022-08-31', '2022-09-30', '2022-10-31', '2022-11-30', '2022-12-31',
    '2023-01-31', '2023-02-28', '2023-03-31',
]
monthly_dt = pd.to_datetime(monthly)
yr_mnth = _month_labels(monthly, '%b-%y')

# Shorter windows are prefixes of `monthly`: through Feb-2022 and May-2022.
monthly_lin = monthly[:23]
monthly_lin_dt = pd.to_datetime(monthly_lin)
yr_mnth_lin = _month_labels(monthly_lin, '%b-%y')

_monthly_lin = monthly[:26]
_yr_mnth_lin = _month_labels(_monthly_lin, '%b-%y')

# First four behaviors are risk-averting, the remaining eleven risk-seeking.
behaviors_type = ['Risk-averting'] * 4 + ['Risk-seeking'] * 11
beh_category_map = dict(zip(behaviors, behaviors_type))
beh_category_map


### Plots

##### Main diagrams for the paper

In [None]:
newpath = path + '/behavior_series/'

## assign variables
interval = 30
death_agg = death_agg_dt

yr_mnth = [datetime.strptime(date, '%Y-%m-%d').strftime('%b-%y') for date in monthly]

# Line colour for each value of state_code_list.Political_aff
party_colors = {'REP': 'r', 'IND': 'black', 'DEM': 'b'}


def _split_title_at_mid_space(text):
    '''
    Return ``text`` broken onto two lines at the first space at or after its
    midpoint (that space starts the second line), or ``text`` unchanged when the
    second half has no space. The previous inline search left its split index
    unset, or stale from the preceding panel, in that case.
    '''
    half = int(np.round(len(text)/2))
    split_at = text.find(' ', half, 2*half)
    if split_at == -1:
        return text
    return text[:split_at] + '\n' + text[split_at:]


# plot behavior time series, state-wise and distinguished by political affiliation
fig, axes = plt.subplots(nrows = 4, ncols = 4, figsize = (26, 27), sharex= False)
plt.subplots_adjust(top=1.0, hspace = 0.5)
axes = axes.flatten()
column_names = ['behavior', 'political_aff', 'min', 'max', 'start', 'end']
state_means = pd.DataFrame(columns = column_names)

for i, behavior in enumerate(behaviors):
    ax = axes[i]
    rows = []
    for st in state_code_list.State.values:
        location = st
        dt0 = death_behavior_loc_wvend(death_agg, location, behavior, interval)
        dt = dt0[dt0.end_date.isin(monthly)]

        aff = state_code_list[state_code_list.State == st].Political_aff.values[0]
        if aff not in party_colors:
            # previously the preceding state's colour was silently reused
            raise ValueError(f'Unexpected political affiliation {aff!r} for {st}')
        col = party_colors[aff]

        behvr = dt[behavior].values
        dates = dt['end_date'].values

        # one row per state: its monthly behavior values followed by its colour
        rows.append(list(behvr) + [col])
        ax.plot(dates, behvr, color = col, alpha = 0.18, label = st)

    # build the frame once instead of concatenating inside the state loop
    df_b = pd.DataFrame(rows)
    column_names = df_b.columns.tolist()
    column_names[-1] = 'color'
    df_b.columns = column_names
    for value_col in column_names[0:-1]:
        df_b[value_col] = pd.to_numeric(df_b[value_col])

    # mean behavior over the states sharing a political affiliation
    b_mean = df_b.groupby('color').mean()
    print(behavior)

    for color, party in (('b', 'Democratic'), ('black', 'Independent'), ('r', 'Republican')):
        ax.plot(dates, b_mean.loc[color].values, color = color, marker = '.', label = party)

    if len(behavior) > 27:
        ax.set_title(_split_title_at_mid_space(behavior), fontsize = 18, wrap=True)
    else:
        ax.set_title(behavior, fontsize = 18, wrap=True)

    ax.set_xlabel('Time [month]', fontsize = 16)
    ax.set_ylabel('Behavior adoption [%]', fontsize = 16)
    handles, labels = ax.get_legend_handles_labels()

    # Show every fifth month label (indices 0, 5, 10, ...)
    ax.set_xticks(monthly)
    ax.set_xticklabels(yr_mnth, fontsize = 14)
    for j, label in enumerate(ax.get_xticklabels()):
        label.set_visible(j % 5 == 0)
    ax.set_yticklabels(ax.get_yticks(), fontsize = 14, fontweight = 'bold')
    ax.tick_params(axis='x', rotation=90)

fig.legend(handles[-3:], labels[-3:], loc='upper left', fontsize = 14)
fig.suptitle('Behavior time-series segregated by state political-leaning', y = 0.95, fontsize = 30)
fig.tight_layout(pad=1.2)
# Remove grid slots left over after one panel per behavior (4x4 grid, 15 behaviors).
for unused_ax in axes[len(behaviors):]:
    fig.delaxes(unused_ax)
plt.savefig(newpath + 'timeseries_behavior_seg_state_political_30dayprior.png', facecolor='white')
plt.show()

print("-----------------------------------------------------------------------------------------------------------------------------------------")




In [None]:
newpath = path + '/main paper/'

# Month-end dates, Apr-2020 .. May-2022, used by the three main-paper figures
_monthly_lin = [
    '2020-04-30', '2020-05-31', '2020-06-30', '2020-07-31', '2020-08-31',
    '2020-09-30', '2020-10-31', '2020-11-30', '2020-12-31',
    '2021-01-31', '2021-02-28', '2021-03-31', '2021-04-30', '2021-05-31', '2021-06-30',
    '2021-07-31', '2021-08-31', '2021-09-30', '2021-10-31', '2021-11-30', '2021-12-31',
    '2022-01-31', '2022-02-28', '2022-03-31', '2022-04-30', '2022-05-31',
]
_yr_mnth_lin = [datetime.strptime(d, '%Y-%m-%d').strftime('%b-%y') for d in _monthly_lin]

## Analysis parameters (alternatives kept for quick switching)
behavior, suffix = 'Avoiding contact with other people', 'avoid_contact'
# behavior, suffix = 'Take mass transit (e.g. subway, bus, or train)', 'take_mass_transit'
# behavior, suffix = 'Go visit a friend', 'visit_friend'

lag, location, interval = 0, 'National', 30
death_agg = death_agg_dt

## Style parameters
titlefont, labelfonts, legendfonts = 18, 16, 12
risk_seeking_col, risk_averting_col = 'g', 'darkorange'

## Data: series for the chosen behavior and location, restricted to the figure months
dt0 = death_behavior_loc_wvend(death_agg, location, behavior, interval)
dt_lin = dt0[dt0.end_date.isin(_monthly_lin)]

# fig, ax = plt.subplots(nrows= )

#######

## Fig 1: Behavior and death time series juxtaposed
filename_prefix = 'behvr_death_y_axis_flipped_'

if lag < 0:
    raise ValueError('Fig 1 supports only non-negative lags')

# Apply the lag exactly once: behavior at month t is drawn against death at t + lag.
# (It used to be sliced off twice, so any lag > 0 produced arrays shorter than
# `dates_lag` and the plot call failed.)
n_months = len(dt_lin)
death = dt_lin['death_pct'].values[lag:]
behvr = dt_lin[behavior].values[0:n_months - lag]
dates = dt_lin['end_date'].values

b_lag = behvr.reshape(-1, 1)
d_lag = death.flatten()
dates_lag = dates[lag:]

# Rising adoption is coloured as risk-seeking, falling as risk-averting.
if b_lag[0] < b_lag[-1]:
    col = risk_seeking_col
    legend_loc = 'upper left'
else:
    col = risk_averting_col
    legend_loc = 'upper right'

fig, axes = plt.subplots(figsize = (7, 5))
ax2 = axes.twinx()

axes.plot(dates_lag, d_lag, color = 'black')
ax2.plot(dates_lag, b_lag, color = col)
# Empty proxy lines on the primary axis so a single legend lists both series.
axes.plot(np.nan, color = 'black', label = 'Death')
axes.plot(np.nan, color = col, label = 'Behavior')

if len(behavior) < 30:
    fig_title = behavior
else:
    # break long titles between words rather than mid-word
    words = behavior.split(' ')
    mid = len(words) // 2
    fig_title = ' '.join(words[:mid]) + '\n' + ' '.join(words[mid:])
axes.set_title(fig_title, fontsize = titlefont, wrap = True)

axes.set_xlabel('Time [month]', fontsize = labelfonts)
ax2.set_ylabel('Behavior adoption [%]', fontsize = labelfonts, color = col)
axes.set_ylabel('Proportion of deaths [%]', fontsize = labelfonts)

handles, labels = axes.get_legend_handles_labels()

# Show every fifth month label (indices 0, 5, 10, ...)
axes.set_xticks(_monthly_lin)
axes.set_xticklabels(_yr_mnth_lin, fontsize = labelfonts)
for j, label in enumerate(axes.get_xticklabels()):
    label.set_visible(j % 5 == 0)

axes.tick_params(axis='x', rotation=90)
axes.set_yticklabels(axes.get_yticks(), fontsize = labelfonts)
ax2.set_yticklabels(ax2.get_yticks(), fontsize = labelfonts, fontweight = 'bold', color = col)
axes.legend(loc = legend_loc, fontsize = legendfonts)

plt.savefig(newpath + filename_prefix + suffix +'.png', facecolor='white', bbox_inches='tight')

#### Fig 2: Behavior trend line
filename_prefix = 'trend_lin_'

dt_lin = dt0[dt0.end_date.isin(_monthly_lin)]
behvr_lin = dt_lin[behavior].values
# Time index 1..n sized from the rows actually present (not from _monthly_lin),
# so a missing month no longer causes a length mismatch.
t_lin = range(1, len(behvr_lin)+1)

t, b, an, m, c, t_line, b_line = lin_reg(t_lin, behvr_lin)

if b[0] < b[-1]:
    col = risk_seeking_col
else:
    col = risk_averting_col

fig, ax = plt.subplots(figsize = (7, 5))

ax.scatter(dt_lin['end_date'].values, list(behvr_lin), color = col, marker = 'o', alpha = 0.8, label = 'Data')
# The dates form a categorical axis where month t (t = 1..n) sits at position t - 1,
# so the fitted line is shifted by -1 to line up with the points it was fitted to.
ax.plot(t_line - 1, b_line, color = col, alpha = 0.8, linewidth= 1.2, label = 'Trend')
annotation_text = 'y = ' + str(round(m[0], 3)) +' x + ' + str(round(c, 3))

## Plot info labeling
if len(behavior) < 30:
    fig_title = behavior
else:
    # break long titles between words rather than mid-word
    words = behavior.split(' ')
    mid = len(words) // 2
    fig_title = ' '.join(words[:mid]) + '\n' + ' '.join(words[mid:])
ax.set_title(fig_title, fontsize = titlefont, wrap = True)

ax.set_xlabel('Time [month]', fontsize = labelfonts)
handles, labels = ax.get_legend_handles_labels()

# Show every fifth month label (indices 0, 5, 10, ...)
ax.set_xticks(_monthly_lin)
ax.set_xticklabels(_yr_mnth_lin, fontsize = labelfonts)
for j, label in enumerate(ax.get_xticklabels()):
    label.set_visible(j % 5 == 0)
ax.set_yticklabels(ax.get_yticks(), fontsize = labelfonts, fontweight = 'bold', color = col)
ax.tick_params(axis='x', rotation=90)
# legend_loc is the one chosen for Fig 1 from the same behavior's direction.
ax.legend(loc = legend_loc, fontsize = legendfonts)
plt.savefig(newpath + filename_prefix + suffix +'.png', facecolor='white', bbox_inches='tight')

#### Fig 3: Behavior anomaly death
filename_prefix = 'shift_'

dates_lin = dt_lin['end_date'].values
# Tick labels derived from the plotted dates, so label and tick counts always agree.
yr_mnth_lin = [datetime.strptime(date, '%Y-%m-%d').strftime('%b-%y') for date in dates_lin]
behvr_lin = dt_lin[behavior].values
death_lin = dt_lin['death_pct'].values
t_lin = range(1, len(dt_lin)+1)

# Only `d` (the observed deaths, flattened) is used from the death regression;
# `ban` is the behavior's residual from its linear trend.
t, d, an, m, c, t_line, d_line = lin_reg(t_lin, death_lin)
t, b, ban, m, c, t_line, b_line = lin_reg(t_lin, behvr_lin)

## Plotting

if b[0] < b[-1]:
    col = risk_seeking_col
    anomaly_label = 'Behavior anomaly \n (risk-seeking)'
else:
    col = risk_averting_col
    anomaly_label = 'Behavior anomaly \n (risk-averting)'

fig, ax1 = plt.subplots(figsize = (7, 5))

ax1.plot(dates_lin, ban, color = col, linewidth= 2)

ax2 = ax1.twinx()
ax2.plot(dates_lin, d, color = 'black')
# Empty proxy lines on ax1 so one legend covers both axes. The behavior proxy uses
# `col`; it was hard-coded to 'orange' while the line itself is 'darkorange'.
ax1.plot(np.nan, color = 'black', label = 'Death')
ax1.plot(np.nan, color = col, label = anomaly_label)

if len(behavior) < 30:
    fig_title = behavior
else:
    # break long titles between words rather than mid-word
    words = behavior.split(' ')
    mid = len(words) // 2
    fig_title = ' '.join(words[:mid]) + '\n' + ' '.join(words[mid:])
ax1.set_title(fig_title, fontsize = titlefont, wrap = True)

ax1.set_xlabel('Time [month]', fontsize = labelfonts)
handles, labels = ax1.get_legend_handles_labels()

# One tick per plotted month (categorical positions 0..n-1); show every fifth label.
ax1.set_xticks(np.array(t_lin)-1)
ax1.set_xticklabels(yr_mnth_lin, fontsize = labelfonts)
for j, label in enumerate(ax1.get_xticklabels()):
    label.set_visible(j % 5 == 0)
ax1.set_yticklabels(ax1.get_yticks(), fontsize = labelfonts, fontweight = 'bold', color = col)
ax2.set_yticklabels(ax2.get_yticks(), fontsize = labelfonts)

ax1.tick_params(axis='x', rotation=90)
ax1.set_ylabel('Behavior anomaly', fontsize = labelfonts, color = col)
ax1.tick_params(axis='y', labelcolor=col)
ax2.set_ylabel('Proportion of deaths [%]', fontsize = labelfonts, color = 'black')
ax2.tick_params(axis='y', labelcolor = 'black')

ax1.legend(loc = legend_loc, fontsize = legendfonts)
plt.savefig(newpath + filename_prefix + suffix +'.png', facecolor='white', bbox_inches='tight')


In [None]:
newpath = path + '/behavior_anom_and_death_hospit_corr_lags'

import matplotlib.gridspec as gridspec

# lags = [-9, -8, -6, -4, -2, -1, 0, 1, 2, 4, 6, 8]
lags = [-2, -1, 0, 1, 2]

interval = 30
location_list = list(state_code_list.State.values)
location_list.append('National')
death_agg = death_agg_dt

# 'YYYY-MM' keys of the analysis months, used to match the hospitalization table
hosp_yr_month = [datetime.strptime(date, '%Y-%m-%d').strftime('%Y-%m') for date in _monthly_lin]


def _lagged_pair(lead, follow, lag):
    '''
    Pair lead[t] with follow[t + lag]; for negative lag this pairs lead[t + |lag|]
    with follow[t]. Both returned 1-D arrays have length len - |lag| when the
    inputs have equal length.
    '''
    if lag >= 0:
        return lead[:len(lead) - lag].flatten(), follow[lag:].flatten()
    return lead[-lag:].flatten(), follow[:len(follow) + lag].flatten()


for location in location_list:
    loc = 'National' if location == 'National' else state_to_code.get(location)
    hospt_dt = hospit_df[hospit_df.mnth_yr.isin(hosp_yr_month) & (hospit_df['state_code'] == loc)]

    # rows: lags; one column per behavior
    dlag_corr = pd.DataFrame(index = lags)
    hlag_corr = pd.DataFrame(index = lags)
    fig, ax = plt.subplots(figsize = (5, 7))

    for behavior in behaviors:
        dt0 = death_behavior_loc_wvend(death_agg, location, behavior, interval)
        # copy so the new column is added to an independent frame (no SettingWithCopy)
        dt = dt0[dt0.end_date.isin(_monthly_lin)].copy()
        dt['mnth_yr'] = dt['end_date'].apply(lambda row: datetime.strptime(row, '%Y-%m-%d').strftime('%Y-%m'))

        # Hospitalization and anomaly are both taken from the merged, month-sorted frame
        # so they stay aligned (hospitalizations used to come from the unmerged table).
        hospt_beh_anom = pd.merge(hospt_dt, dt, how = 'inner', on = 'mnth_yr').sort_values('mnth_yr', kind = 'stable')

        death = dt['death_pct'].values
        be_an = dt['behavior_anomaly'].values
        hospt = hospt_beh_anom['monthly_hospitalization'].values
        h_be_an = hospt_beh_anom['behavior_anomaly'].values

        dcorr_ = []
        hcorr_ = []
        for lag in lags:
            ban_lag, d_lag = _lagged_pair(be_an, death, lag)
            h_ban_lag, h_lag = _lagged_pair(h_be_an, hospt, lag)
            dcorr_.append(np.corrcoef(ban_lag, d_lag)[0, 1])
            hcorr_.append(np.corrcoef(h_ban_lag, h_lag)[0, 1])

        dlag_corr[behavior] = dcorr_
        hlag_corr[behavior] = hcorr_

    # Heatmap rows: behaviors vs death, then the same behaviors vs hospitalization.
    sns.heatmap(pd.concat([dlag_corr.T, hlag_corr.T], axis = 0), annot = True, cmap = 'coolwarm',
                annot_kws={'size': 8}, ax = ax)

    ax.set_xlabel("Lag [month]")
    ax.set_title(f"{location} | Oscillations in behavior trend vs Mortality & Hospitalization", fontsize = 12)

    plt.tight_layout()
    plt.savefig(newpath + '/behavior_anom_&_death_corr_lags_short_' + location + '.png', bbox_inches='tight')
    # One figure per location is saved to disk; close it so dozens don't accumulate.
    plt.close(fig)


##### Supplementary

In [None]:
newpath = path + '/params_table_3plots/'

# Month-end dates, Apr-2020 .. May-2022, shown in the supplementary panels
_monthly_lin = [
    '2020-04-30', '2020-05-31', '2020-06-30', '2020-07-31', '2020-08-31',
    '2020-09-30', '2020-10-31', '2020-11-30', '2020-12-31',
    '2021-01-31', '2021-02-28', '2021-03-31', '2021-04-30', '2021-05-31', '2021-06-30',
    '2021-07-31', '2021-08-31', '2021-09-30', '2021-10-31', '2021-11-30', '2021-12-31',
    '2022-01-31', '2022-02-28', '2022-03-31', '2022-04-30', '2022-05-31',
]
_yr_mnth_lin = [datetime.strptime(d, '%Y-%m-%d').strftime('%b-%y') for d in _monthly_lin]

lag, location, interval = 0, 'National', 30
death_agg = death_agg_dt

## Style parameters
titlefont, labelfonts, legendfonts = 10, 8, 6
risk_seeking_col, risk_averting_col = 'g', 'darkorange'

# The 15 behaviors split into three pages of five rows each
behaviors_set1, behaviors_set2, behaviors_set3 = (behaviors[k:k + 5] for k in (0, 5, 10))

# Every state followed by the national aggregate
location_list = list(state_code_list.State.values) + ['National']

for location in location_list:

    ms = []
    cs = []

    #######
    i = 0
    for b_set in [behaviors_set1, behaviors_set2, behaviors_set3]:
        fig, axes = plt.subplots(nrows= 5, ncols = 3, sharex=True, figsize = (12, 15))
        for b in range(len(b_set)):
            
            ## Fig 1: Behavior and death time series juxtaposed
            
            behavior = b_set[b]
            df_b = pd.DataFrame()
            df_b_an = pd.DataFrame()

            ## Data
            dt0 = death_behavior_loc_wvend(death_agg, location, behavior, interval)
            dt_lin = dt0[dt0.end_date.isin(_monthly_lin)]
            t_lin = range(1, len(_monthly_lin)+1)

            death = dt_lin['death_pct'].values[lag:]
            behvr = dt_lin[behavior].values[0:len(dt_lin[behavior]) - lag]
            dates = dt_lin['end_date'].values

            b_lag = behvr[0:(len(behvr)-lag)].reshape(-1, 1)
            d_lag = death[lag:].flatten()
            dates_lag = dates[lag:]

            if b_lag[0] < b_lag[-1]:
                col = risk_seeking_col
                legend_loc = 'upper left'
            else:
                col = risk_averting_col
                legend_loc = 'upper right'

            # fig, axes = plt.subplots(figsize = (7, 5))

            ax2 = axes[b, 0].twinx()
            # axes.set_ylim(0, 100)

            axes[b, 0].plot(dates_lag, d_lag, color = 'black')
            ax2.plot(dates_lag, b_lag, color = col)
            axes[b, 0].plot(np.nan, color = 'black', label = 'Death') ## no data, this is added only to combine two legendfonts

            if col == 'g':
                axes[b, 0].plot(np.nan, color = col, label = 'Behavior') ## no data, this is added only to combine two legends
                # added to primary axis for simple concatenation on legend
            elif col == 'darkorange':
                axes[b, 0].plot(np.nan, color = col, label = 'Behavior') ## no data, this is added only to combine two legends

            if len(behavior) < 30:
                axes[b, 0].set_title(behavior, fontsize = titlefont, wrap = True)
            else:
                text_list = behavior.split(" ")
                # Calculate the midpoint of the list
                midpoint = len(text_list) // 2

                # Split the list into two halves
                first_half = text_list[:midpoint]
                second_half = text_list[midpoint:]

                # Concatenate the first half and the second half, with spaces between elements
                title_first_half = ' '.join(first_half)
                title_second_half = ' '.join(second_half)

                axes[b, 0].set_title(title_first_half + '\n' + title_second_half, fontsize = titlefont, wrap=True)

            # axes.set_title(behavior, fontsize = 14, wrap=True)

            axes[b, 0].set_xlabel('Time [month]', fontsize = labelfonts)
            ax2.set_ylabel('Behavior adoption [%]', fontsize = labelfonts, color = col)
            axes[b, 0].set_ylabel('Mortality due to COVID-19 [%]', fontsize = labelfonts)

            handles, labels = axes[b, 0].get_legend_handles_labels()

            # Show every fourth label
            axes[b, 0].set_xticks(np.array(t_lin)-1)
            axes[b, 0].set_xticklabels(_yr_mnth_lin, fontsize = labelfonts)

            for j, label in enumerate(axes[b, 0].get_xticklabels()):
                if (j == 1) | (j == len(_yr_mnth_lin)):
                    label.set_visible(True)
                if 1 < j < 3:
                    label.set_visible(False)
                elif j % 5 != 0:
                    label.set_visible(False)

            axes[b, 0].tick_params(axis='x', rotation=90)
            axes[b, 0].set_yticklabels([np.round(elem, 3) for elem in axes[b, 0].get_yticks()], fontsize = labelfonts)
            ax2.set_yticklabels([np.round(elem, 3) for elem in ax2.get_yticks()], fontsize = labelfonts, fontweight = 'bold', color = col)
            axes[b, 0].legend(loc = legend_loc, fontsize = legendfonts)

            # plt.savefig(newpath + filename_prefix + suffix +'.png', facecolor='white', bbox_inches='tight')

            #### Fig 2: Behavior trend line
            filename_prefix = 'trend_lin_'

            dt_lin = dt0[dt0.end_date.isin(_monthly_lin)]
            behvr_lin = dt_lin[behavior].values[0:len(dt_lin)]
            t_lin = range(1, len(_monthly_lin)+1)

            t, bh, an, m, c, t_line, b_line = lin_reg(t_lin, behvr_lin) 
            ms.append(np.round(m[0], 3))
            cs.append(np.round(c, 3))

            if bh[0] < bh[-1]:
                col = risk_seeking_col
            else:
                col = risk_averting_col 

            axes[b, 1].scatter(_monthly_lin, list(behvr_lin), color = col, marker = 'o', alpha = 0.8, label = 'Data' )
            axes[b, 1].plot(t_line, b_line, color = col, alpha = 0.8, linewidth= 1.2, label = 'Trend')

            if len(behavior) < 30:
                axes[b, 1].set_title(behavior, fontsize = titlefont, wrap = True)
            else:
                text_list = behavior.split(" ")
                # Calculate the midpoint of the list
                midpoint = len(text_list) // 2

                # Split the list into two halves
                first_half = text_list[:midpoint]
                second_half = text_list[midpoint:]

                # Concatenate the first half and the second half, with spaces between elements
                title_first_half = ' '.join(first_half)
                title_second_half = ' '.join(second_half)

                axes[b, 1].set_title(title_first_half + '\n' + title_second_half, fontsize = titlefont, wrap=True)

            handles, labels = axes[b, 1].get_legend_handles_labels()
            
            axes[b, 1].set_xlabel('Time [month]', fontsize = labelfonts)
            
            annotation_text = 'y = ' + str(round(m[0], 3)) +' x + ' + str(round(c, 3))
            # axes[i].annotate(annotation_text, xy=(0.1, 0.90), xycoords='axes fraction', fontsize=22, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
            
            # Show every fourth label
            axes[b, 1].set_xticks(np.array(t_lin)-1)
            axes[b, 1].set_xticklabels(_yr_mnth_lin, fontsize = labelfonts)
            
            for j, label in enumerate(axes[b, 1].get_xticklabels()):
                if (j == 1) | (j == len(_yr_mnth_lin)):
                    label.set_visible(True)
                if 1 < j < 3:
                    label.set_visible(False)
                elif j % 5 != 0:
                    label.set_visible(False)

            axes[b, 1].tick_params(axis='x', rotation=90)
            axes[b, 1].set_yticklabels(axes[b, 1].get_yticks(),fontsize = labelfonts, fontweight = 'bold', color = col)
            axes[b, 1].legend(loc = legend_loc, fontsize = legendfonts)

            #### Fig 3: Behavior anomaly death
            # filename_prefix = 'shift_'

            dates_lin = dt_lin['end_date'].values
            yr_mnth_lin = [datetime.strptime(date, '%Y-%m-%d').strftime('%b-%y') for date in dates_lin]
            behvr_lin = dt_lin[behavior].values
            death_lin = dt_lin['death_pct'].values
            t_lin = range(1, len(dt_lin)+1)

            t, d, an, m, c, t_line, d_line = lin_reg(t_lin, death_lin)
            t, bh, ban, m, c, t_line, b_line = lin_reg(t_lin, behvr_lin) 

            ## Plotting

            if bh[0] < bh[-1]:
                col = risk_seeking_col
            else:
                col = risk_averting_col

            axes[b, 2].plot(dates_lin, ban, color = col, linewidth= 2)

            ax2 = axes[b, 2].twinx()  
            ax2.plot(dates_lin, d, color = 'black')
            axes[b, 2].plot(np.nan, color = 'black', label = 'Death') ## no data, this is added only to combine two legends

            if col == 'g':
                # axes[b, 2].plot(dates_lin[0:-6], ban[6:], color = 'lightgreen', linewidth= 2, linestyle = '--')
                axes[b, 2].plot(np.nan, color = 'g', label = 'Oscillations in behavior trend \n (risk-seeking)') ## no data, this is added only to combine two legends
                # axes[b, 2].plot(np.nan, color = 'lightgreen', label = 'Behavior anomaly \n 6-month shifted')
            elif col == 'darkorange':
                axes[b, 2].plot(np.nan, color = 'orange', label = 'Oscillations in behavior trend \n (risk-averting)') ## no data, this is added only to combine two legends

            if len(behavior) < 30:
                axes[b, 2].set_title(behavior, fontsize = titlefont, wrap = True)
            else:
                text_list = behavior.split(" ")
                # Calculate the midpoint of the list
                midpoint = len(text_list) // 2

                # Split the list into two halves
                first_half = text_list[:midpoint]
                second_half = text_list[midpoint:]

                # Concatenate the first half and the second half, with spaces between elements
                title_first_half = ' '.join(first_half)
                title_second_half = ' '.join(second_half)

                axes[b, 2].set_title(title_first_half + '\n' + title_second_half, fontsize = titlefont, wrap=True)

            handles, labels = axes[b, 2].get_legend_handles_labels()

            axes[b, 2].set_xlabel('Time [month]', fontsize = labelfonts)
            # Show every fourth label
            axes[b, 2].set_xticks(np.array(t_lin)-1)
            axes[b, 2].set_xticklabels(_yr_mnth_lin, fontsize = labelfonts)

            for j, label in enumerate(axes[b, 2].get_xticklabels()):
                if (j == 1) | (j == len(_yr_mnth_lin)):
                    label.set_visible(True)
                if 1 < j < 3:
                    label.set_visible(False)
                elif j % 5 != 0:
                    label.set_visible(False)

            axes[b, 2].set_yticklabels([np.round(elem, 3) for elem in axes[b, 2].get_yticks()],fontsize = labelfonts, fontweight = 'bold', color = col)
            axes[b, 2].tick_params(axis='x', rotation=90)
            axes[b, 2].set_ylabel('Oscillations in behavior trend', fontsize = labelfonts, color = col)
            # axes[b, 2].tick_params(axis='y', labelcolor=col)

            ax2.set_yticklabels([np.round(elem, 3) for elem in ax2.get_yticks()],fontsize = labelfonts)
            ax2.set_ylabel('Mortality due to COVID-19 [%]', fontsize = labelfonts, color = 'black')
            ax2.tick_params(axis='y', labelcolor = 'black')
            axes[b, 2].legend(loc = legend_loc, fontsize = legendfonts)
        
        fig.tight_layout(pad=0.7)
        if i == 1:
            fig.suptitle(f"State name: {location}",fontsize = 40, fontweight = 'bold')
        plt.subplots_adjust(top=0.9) 
        plt.savefig(newpath + location+'_3plots_'+str(i)+'.svg', format = 'svg')
        i = i + 1

    params_df = pd.DataFrame()
    params_df['Behaviors'] = behaviors
    params_df['Slopes'] = ms
    params_df['Y-intercepts'] = cs

    # ax = render_mpl_table(params, header_columns=0, col_width = 12)
    # set_row_edge_color(ax, len(behaviors), 'k')
    # Plot the table
    fig, ax = plt.subplots(figsize=(30, 2.5*len(params_df)))  # Adjust height based on rows

    # Hide the axes
    ax.axis('tight')
    ax.axis('off')

    # Create the table
    table = ax.table(cellText=params_df.values, 
                    colLabels=params_df.columns, 
                    cellLoc='center', 
                    loc='center')

    # Adjust font size
    table.auto_set_font_size(False)
    table.set_fontsize(25)

    # Set the width for the first column (index 0)
    table.scale(1, 8)  # Increase the second argument (height scaling)

    # Iterate over each cell in the header row (row index 0) and set the background color to light blue
    for j in range(len(params_df.columns)):
        table[(0, j)].set_facecolor('#D3D3D3')  # Light blue color

    # Bold the header text
    for j in range(len(params_df.columns)):
        table[(0, j)].set_text_props(fontsize = 32, weight='bold')

    # Set the width of the first column
    for i in range(len(params_df) + 1):  # +1 for the header row
        table[(i, 0)].set_width(0.8)  # Adjust as needed

    # Reduce white space by using tight layout
    plt.subplots_adjust(left=0.2, right=0.8, top=0.75, bottom=0.05)

    # # Display the plot
    # plt.show()
    plt.savefig(newpath + location+'_param_table.svg', bbox_inches='tight', format='svg')
    
   


In [None]:
from PIL import Image, ImageDraw, ImageFont
from PyPDF2 import PdfReader, PdfWriter
import cairosvg
from io import BytesIO

# Function to load and convert SVG to a format PIL can handle
def load_svg_as_image(svg_path):
    """
    Rasterize the SVG file at `svg_path` to PNG bytes with cairosvg and open them as a PIL image.
    """
    return Image.open(BytesIO(cairosvg.svg2png(url=svg_path)))

# Function to add left-aligned text to an image
# Font files used for every page annotation (macOS system-font locations). When a file is
# missing, PIL's built-in default font is used instead, which ignores the requested size.
_LATO_REGULAR_PATH = "/Library/Fonts/Lato-Regular.ttf"
_LATO_BOLD_PATH = "/Library/Fonts/Lato-Bold.ttf"


def _load_font(font_size, bold=False, warn=True):
    """
    Load Lato (bold or regular) at `font_size`, falling back to PIL's default font.

    If `warn` is True, the fallback is reported on stdout.
    """
    font_path = _LATO_BOLD_PATH if bold else _LATO_REGULAR_PATH
    try:
        return ImageFont.truetype(font_path, font_size)
    except OSError:  # IOError is an alias of OSError
        font = ImageFont.load_default()
        if warn:
            print("Unable to load custom font. Default font is being used.")
        return font


# Function to add left-aligned text to an image
def add_left_aligned_text(image, text, y_position, font_size=70, bold=False):
    """
    Draw black `text` on `image` (in place), starting 20 px from the left edge at `y_position`.

    Uses Lato at `font_size` (bold if `bold=True`). If the font file cannot be loaded,
    PIL's default font is used and a warning is printed.
    """
    draw = ImageDraw.Draw(image)
    font = _load_font(font_size, bold=bold)
    # x = 20 is an arbitrary left margin.
    draw.text((20, y_position), text, fill=(0, 0, 0), font=font)


# Function to add centered text to an image
def add_centered_text(image, text, y_position, font_size=70, bold=False):
    """
    Draw black `text` on `image` (in place), horizontally centered, at `y_position`.

    Uses Lato at `font_size` (bold if `bold=True`). If the font file cannot be loaded,
    PIL's default font is used and a warning is printed.
    """
    draw = ImageDraw.Draw(image)
    font = _load_font(font_size, bold=bold)

    # Measure the rendered text so its width can be centered on the image.
    text_bbox = draw.textbbox((0, 0), text, font=font)
    text_width = text_bbox[2] - text_bbox[0]
    x_position = (image.width - text_width) // 2

    draw.text((x_position, y_position), text, fill=(0, 0, 0), font=font)


# Function to add a page number to the image
def add_page_number(image, page_num, font_size=60, padding=100):
    """
    Stamp `page_num` near the bottom-right corner of `image` (in place).

    The right edge of the number sits `padding` px from the right border. Its top sits
    `font_size + padding` px above the bottom border. If the font is missing, it falls
    back silently to PIL's default font.
    """
    draw = ImageDraw.Draw(image)
    font = _load_font(font_size, bold=False, warn=False)

    text = f"{page_num}"
    text_bbox = draw.textbbox((0, 0), text, font=font)
    text_width = text_bbox[2] - text_bbox[0]
    x_position = image.width - text_width - padding
    y_position = image.height - font_size - padding

    draw.text((x_position, y_position), text, fill=(0, 0, 0), font=font)

# Combine the figure and table into one PDF with added centered text and page number
def combine_images_into_pdf(image1, image2, image3, table_img, output_pdf_path, location, figure_number, page_num, add_section_title=False):
    """
    Compose one supplementary page (three figure panels + parameter table) and save it as a PDF.

    Layout, top to bottom:
      - an optional bold section heading (only when `add_section_title` is True);
      - `image1`, `image2`, `image3` pasted left to right, separated by thin black vertical lines;
      - a bold "Figure S7.<figure_number> ..." caption mentioning `location`;
      - `table_img`, horizontally centered;
      - a bold "Table S7.<figure_number> ..." caption, and `page_num` in the bottom-right corner.

    Parameters
    ----------
    image1, image2, image3 : PIL.Image.Image
        Figure panels, in left-to-right order.
    table_img : PIL.Image.Image
        Rendered slope / y-intercept table.
    output_pdf_path : str
        Destination of the single-page PDF.
    location : str
        Location name substituted into both captions.
    figure_number : int
        Index used in the "S7.<n>" captions.
    page_num : int
        Number stamped in the bottom-right corner.
    add_section_title : bool, default False
        Draw the section-7 heading and push the panels further down.
    """
    # Flatten every input to RGB (via compress_image) before composing.
    image1 = compress_image(image1)
    image2 = compress_image(image2)
    image3 = compress_image(image3)
    table_img = compress_image(table_img)
    
    # Tallest panel height: sets the height of the panel row and of the divider lines.
    max_height = max(image1.height, image2.height, image3.height)

    line_width = 2  # Adjust the line width as needed
    line_color = (0, 0, 0)  # Black color for the dividing line

    # Calculate the combined width, accounting for the lines between images
    combined_width = image1.width + image2.width + image3.width + (2 * line_width)

    # Panel-row canvas; the extra 500 px hold the heading and the figure caption.
    combined_image_height = max_height + 500  # Adding space for larger text and titles
    combined_image = Image.new('RGB', (combined_width, combined_image_height), (255, 255, 255))

    # If this is the first location, add a bold section title (left-aligned)
    if add_section_title:
        add_left_aligned_text(combined_image, f"7 State-level behavioral trends and their relationships with epidemic severity", y_position=20, font_size=80, bold=True)
        y_offset = 200  # Move images lower due to the section title
    else:
        y_offset = 50  # No section title, so less offset

    # Paste the first image (below the title)
    combined_image.paste(image1, (0, y_offset + 80))

    # Create a draw object to add lines
    draw = ImageDraw.Draw(combined_image)

    # Divider lines span 80% of the panel height and are centered on the panel row.
    line_height = int(max_height * 0.8)  # 80% of the image height
    vertical_offset = (max_height - line_height) // 2 + y_offset + 80  # Center the line vertically

    # Draw the first dividing line between image1 and image2, with the offset height
    draw.line([(image1.width, vertical_offset), (image1.width, vertical_offset + line_height)], fill=line_color, width=line_width)

    # Paste the second image after the first image and the line
    combined_image.paste(image2, (image1.width + line_width, y_offset + 80))

    # Draw the second dividing line between image2 and image3, with the offset height
    draw.line([(image1.width + image2.width + line_width, vertical_offset), (image1.width + image2.width + line_width, vertical_offset + line_height)], fill=line_color, width=line_width)

    # Paste the third image after the second image and the line
    combined_image.paste(image3, (image1.width + image2.width + 2 * line_width, y_offset + 80))

    # Figure caption, drawn 100 px below the bottom of the panel row.
    add_centered_text(combined_image, f"Figure S7.{figure_number} shows behavior and mortality on 1st, 4th and 7th column, behavior trends in 2nd, 5th and 8th column, \n and behavior oscillatory components juxtaposed with mortality rates on 3rd, 6th and 9th columns for {location}.", max_height + y_offset + 180, font_size=50, bold=True)

    # Keep the finished panel block; `combined_image` is rebound to the full page below.
    fig_img = combined_image

    # Full page: panel block on top, table centered underneath, 400 px spare for the table caption.
    max_width = max(fig_img.width, table_img.width)
    total_height = fig_img.height + table_img.height + 400  # Adding space for larger text

    # Create a blank image with the calculated dimensions
    combined_image = Image.new('RGB', (max_width, total_height), (255, 255, 255))

    # Paste the figure and table images into the combined image
    combined_image.paste(fig_img, (0, 10))
    combined_image.paste(table_img, ((max_width - table_img.width)//2, fig_img.height))

    # Table caption drawn at total_height - 450.
    # NOTE(review): that is fig_img.height + table_img.height - 50, i.e. 50 px above the bottom
    # of the pasted table, so the caption may overlap the table's last rows. Confirm the layout.
    bottom_text = f"Table S7.{figure_number} shows the slopes and y-intercept of each behavior's trend in {location}."
    add_centered_text(combined_image, bottom_text, total_height - 450, font_size=50, bold=True)  # Table caption near the bottom of the page

    # Add the page number
    add_page_number(combined_image, page_num)

    # Save as PDF with specified resolution
    combined_image.save(output_pdf_path, "PDF", resolution=70.0)

def compress_image(image):
    """
    Flatten `image` to an opaque RGB image of the same size.

    No lossy compression or quality setting is applied. The only saving comes from dropping
    the alpha channel before the image is embedded in a PDF. Images with transparency are
    alpha-composited onto white so that transparent areas match the white page. Without this,
    they would come out black. Images without an alpha channel are copied unchanged.
    """
    flattened = Image.new('RGB', image.size, (255, 255, 255))
    if 'A' in image.getbands() or 'transparency' in image.info:
        rgba = image.convert('RGBA')
        # Use the alpha channel as the paste mask to blend onto the white canvas.
        flattened.paste(rgba, (0, 0), rgba)
    else:
        flattened.paste(image)
    return flattened

# Build one supplementary page per location (section S7): three multi-panel figures plus the
# slope / y-intercept table, then merge all pages into a single PDF.
# NOTE(review): `newpath` and `location_list` come from earlier cells. The lag-correlation cell
# below rebinds `newpath`, so re-running this cell after that one reads a different folder.
output_pdf_paths = []
figure_number = 1  # caption index for "Figure S7.<n>" / "Table S7.<n>"; one per location
page_num = 17  # number printed on the first generated page (continues the document's numbering)
first_location = True  # only the first location's page carries the section heading

pdf_writer = PdfWriter()
for location in location_list:
    # SVGs written by the plotting cell above (L984 / L1031 style paths).
    figure1 = load_svg_as_image(newpath + location + '_3plots_0.svg')
    figure2 = load_svg_as_image(newpath + location + '_3plots_1.svg')
    figure3 = load_svg_as_image(newpath + location + '_3plots_2.svg')
    table = load_svg_as_image(newpath + location + '_param_table.svg')

    state_param_pdf_path = newpath + location + '_param_fig_tbl.pdf'
    output_pdf_paths.append(state_param_pdf_path)

    combine_images_into_pdf(figure1, figure2, figure3, table, state_param_pdf_path, location,
                            figure_number, page_num, add_section_title=first_location)
    first_location = False

    # One figure/table pair and one page per location.
    figure_number += 1
    page_num += 1

    # Each per-location PDF is a single page; append it to the combined document.
    state_param_pdf = PdfReader(state_param_pdf_path)
    page = state_param_pdf.pages[0]
    pdf_writer.add_page(page)

# NOTE(review): the per-location paths above concatenate `newpath + location` with no separator,
# while this path adds '/'. Confirm whether `newpath` ends with a slash.
combined_output_pdf_path = newpath + '/supp_state_figs_params_tbl_combined.pdf'

with open(combined_output_pdf_path, 'wb') as output_pdf:
    pdf_writer.write(output_pdf)


# def combine_pdfs(pdf_list, output_path):
#     pdf_writer = PdfWriter()
#     for pdf in pdf_list:
#         pdf_reader = PdfReader(pdf)
#         for pg_num in range(len(pdf_reader.pages)):
#             page = pdf_reader.pages[0]
#             pdf_writer.add_page(page)
#             # break
#     with open(output_path, 'wb') as output_pdf:
#         pdf_writer.write(output_pdf)



# Combine all the generated PDFs into a single PDF
# combine_pdfs(output_pdf_paths, combined_output_pdf_path)

Supplementary: Correlations of the oscillatory component of behavior with mortality, hospitalizations, and cases at various lags

In [None]:
# Official case counts per state, with the month column renamed to the 'mnth_yr' join key.
offc_st_df = pd.read_csv(path + '/data/all_state_official_case_cnt.csv').rename(
    columns={'month_yr': 'mnth_yr'}
)
offc_st_df

In [None]:
import matplotlib.gridspec as gridspec
newpath = path + '/behavior_anom_and_mortality_hospit_cases_corr_lags'

# Month offsets at which each behavior's oscillatory component is correlated with mortality,
# hospitalizations and cases. lag > 0 pairs behavior at month t with the metric at month t + lag;
# lag < 0 pairs behavior at month t with the metric at month t - |lag|.
lags = [-2, -1, 0, 1, 2]

interval = 30
location_list = list(state_code_list.State.values)
location_list.append('National')
death_agg = death_agg_dt
case_nat = offc_st_df[['mnth_yr', 'case_cnt_norm']].groupby('mnth_yr').sum().reset_index()
all_corr_table = pd.DataFrame()  # not filled by this cell (see commented export below)


def _to_mnth_yr(date_str):
    """Convert a 'YYYY-MM-DD' string to the 'YYYY-MM' key used by the hospitalization/case tables."""
    return datetime.strptime(date_str, '%Y-%m-%d').strftime('%Y-%m')


def _lagged_corr(behavior_series, metric_series, lag):
    """
    Pearson correlation between a behavior series and a severity-metric series at `lag` months.

    lag >= 0: behavior[t] is paired with metric[t + lag];
    lag <  0: behavior[t] is paired with metric[t - |lag|].
    Both inputs must be month-aligned and of equal length.
    """
    bh = np.asarray(behavior_series, dtype=float).flatten()
    mt = np.asarray(metric_series, dtype=float).flatten()
    if lag >= 0:
        bh, mt = bh[:len(bh) - lag], mt[lag:]
    else:
        k = -lag
        bh, mt = bh[k:], mt[:len(mt) - k]
    return np.corrcoef(bh, mt)[0, 1]


# 'YYYY-MM' keys of the study months; identical for every location and behavior.
study_mnth_yr = [_to_mnth_yr(date) for date in _monthly_lin]

for location in location_list:
    # Key used by the hospitalization and case tables.
    loc = 'National' if location == 'National' else state_to_code.get(location)

    # Severity tables restricted to this location and the study months.
    hospt_dt = hospit_df[(hospit_df.mnth_yr.isin(study_mnth_yr)) & (hospit_df['state_code'] == loc)]
    if loc == 'National':
        case_dt = case_nat[case_nat.mnth_yr.isin(study_mnth_yr)]
    else:
        case_dt = offc_st_df[(offc_st_df.mnth_yr.isin(study_mnth_yr)) & (offc_st_df['state_code'] == loc)]

    dlag_corr = pd.DataFrame(index=lags)
    clag_corr = pd.DataFrame(index=lags)
    hlag_corr = pd.DataFrame(index=lags)
    fig, ax = plt.subplots(figsize=(5, 10))

    for behavior in behaviors:
        dt0 = death_behavior_loc_wvend(death_agg, location, behavior, interval)
        # .copy(): add 'mnth_yr' to an independent frame, not to a view of dt0.
        dt = dt0[dt0.end_date.isin(_monthly_lin)].copy()
        dt['mnth_yr'] = dt['end_date'].apply(_to_mnth_yr)

        # Join each severity table to the behavior data by month. Both series of a pair are read
        # from the joined, month-sorted frame so they stay aligned even if a month is missing.
        hospt_beh_anom = pd.merge(hospt_dt, dt, how='inner', on='mnth_yr').sort_values('mnth_yr', kind='stable')
        case_beh_anom = pd.merge(case_dt, dt, how='inner', on='mnth_yr').sort_values('mnth_yr', kind='stable')

        be_an = dt['behavior_anomaly'].values
        death = dt['death_pct'].values

        # The same lag convention is used for all three metrics. Previously, hospitalizations and
        # cases sliced both series with [lag:] for lag >= 0, which is a zero-lag correlation.
        dlag_corr[behavior] = [_lagged_corr(be_an, death, lag) for lag in lags]
        hlag_corr[behavior] = [
            _lagged_corr(hospt_beh_anom['behavior_anomaly'].values,
                         hospt_beh_anom['monthly_hospitalization'].values, lag)
            for lag in lags
        ]
        clag_corr[behavior] = [
            _lagged_corr(case_beh_anom['behavior_anomaly'].values,
                         case_beh_anom['case_cnt_norm'].values, lag)
            for lag in lags
        ]

    # Stack the three lag tables into one (behavior, source) x lag matrix.
    concatenated_df = pd.concat([
        dlag_corr.T.assign(Source=': mortality'),
        hlag_corr.T.assign(Source=': hospitalizations'),
        clag_corr.T.assign(Source=': cases')], axis=0).set_index('Source', append=True)

    sns.heatmap(concatenated_df, annot=True, cmap='coolwarm', annot_kws={'size': 8}, ax=ax)

    # Shade the row labels by severity source (mortality / hospitalizations / cases).
    for tick_label in ax.get_yticklabels():
        label_text = tick_label.get_text()
        if 'mort' in label_text:
            tick_label.set_color('black')
        elif 'lizatio' in label_text:
            tick_label.set_color('dimgray')
        elif 'case' in label_text:
            tick_label.set_color('darkgray')
    # df = pd.concat([dlag_corr.T, hlag_corr.T, clag_corr.T], axis = 0)
    # df['severity_type'] = ['mortality']*15 + ['hospitalization']*15 + ['cases']*15
    # df['state'] = location
    # all_corr_table = pd.concat([all_corr_table, df], axis = 0)

    ax.set_xlabel("Lag [month]")
    ax.set_ylabel("Behaviors (oscillatory component)")
    ax.set_title(location, fontsize = 15)

    fig.savefig(newpath + '/behavior_anom_vs_mort_hospt_case_corr_lags_short_' + location + '.svg', bbox_inches='tight', format = 'svg')
    # One figure per location is produced; close each once saved to keep memory bounded.
    plt.close(fig)
# all_corr_table.to_csv(newpath + '/all_state_ba_lag_correlations_with_mort_hospt_case.csv', index = True)



In [None]:
# NOTE(review): bare display of `image_path`. In this part of the notebook the name is only
# assigned inside the PDF-assembly loop of the next cell. On a fresh kernel this raises NameError
# unless an earlier cell defines it. It looks like leftover debugging; consider removing this cell.
image_path

In [None]:
import os
from PIL import Image
import cairosvg
from fpdf import FPDF

# Directory containing the images
# Section S8: place the per-location lag-correlation heatmaps two per A4 page, each with a
# caption, and number the pages.

# Directory holding the heatmaps written by the lag-correlation cell (derived from `path`).
image_dir = path + '/behavior_anom_and_mortality_hospit_cases_corr_lags'

locations = location_list

pdf = FPDF(orientation='P', unit='mm', format='A4')  # A4 portrait page (210 x 297 mm)

# Layout, in mm
image_width = 130   # width of each placed image
image_height = 104  # height of each placed image
margin = 18         # top offset of the first image slot on each page

# A4 page size
page_width = 210
page_height = 297

# Left edge that centers an image horizontally
x_center = (page_width - image_width) / 2

images_per_page = 2

# The section heading is printed only on the first page.
first_page = True

page_num = 69   # number printed on the first page of this section
figure_num = 1  # "Figure S8.<n>" caption index; advances only for figures actually placed

image_count = 0  # images placed so far


def _stamp_page_number(pdf, number):
    """Print `number` near the bottom-right corner of the PDF's current page."""
    pdf.set_font("Arial", size=10)
    pdf.set_xy(-30, -38)
    pdf.cell(0, 10, f"{number}", 0, 0, 'R')


for location in locations:
    image_name = f"behavior_anom_vs_mort_hospt_case_corr_lags_short_{location}.svg"
    image_path = os.path.join(image_dir, image_name)

    if not os.path.exists(image_path):
        continue  # no heatmap for this location: skip it without consuming a figure number

    # Start a new page before the first image and whenever the previous page is full.
    if image_count % images_per_page == 0:
        pdf.add_page()
        if first_page:
            pdf.set_font("Arial", size=12, style='B')
            pdf.cell(0, 10, f"8 State-level correlations between oscillations in behavior trends and disease severity metrics", ln=True, align='L')
            first_page = False

    # FPDF embeds raster images from a file, so write a PNG rendering of the SVG next to it.
    png_image_data = cairosvg.svg2png(url=image_path)
    png_image_path = os.path.join(image_dir, f"{location}.png")
    with open(png_image_path, "wb") as png_file:
        png_file.write(png_image_data)

    # Vertical slot on the page: each slot is the image height plus 20 mm for the caption.
    y = margin + (image_count % images_per_page) * (image_height + 20)

    pdf.image(png_image_path, x=x_center, y=y, w=image_width, h=image_height)

    # Caption 5 mm below the image; multi_cell wraps it to the image width.
    pdf.set_xy(x_center, y + image_height + 5)
    pdf.set_font("Arial", size=8)
    text = (f"Figure S8.{figure_num} shows the 5 lagged correlations of oscillations in behavior trend with "
            f"mortality (black), hospitalizations (gray), and cases (light gray) for {location}.")
    pdf.multi_cell(image_width, 5, text, align='L')

    image_count += 1
    figure_num += 1

    # A full page gets its number now.
    if image_count % images_per_page == 0:
        _stamp_page_number(pdf, page_num)
        page_num += 1

# A last page holding fewer than `images_per_page` images still needs its number.
if image_count % images_per_page != 0:
    _stamp_page_number(pdf, page_num)
    page_num += 1

# Save the final PDF
pdf_output_path = image_dir + '/behavior_anom_and_mortality_hospit_cases_corr_lags_all_state.pdf'
pdf.output(pdf_output_path)

print(f"PDF created successfully at {pdf_output_path}")