# Data Exploration
## for covid-19 cases and deaths + hospital capacity data
First, let's confirm that the required libraries are installed. For more information on how to set up your Jupyter Notebook, check out the `README.md`.

In [None]:
# Upgrade pip itself before any package installs are attempted below
%pip install --upgrade pip

In [None]:
import importlib

def install_if_not_exists(module_names):
    """Import every module in `module_names`, pip-installing the ones that are missing.

    Args:
        module_names: iterable of importable module names. Each name is also
            used verbatim as the pip package name, so it only works for
            packages whose import name and distribution name are the same.

    Prints one status line per module. Installation is best-effort: a failed
    install is reported but does not raise, so the remaining modules are
    still processed.
    """
    # Local imports keep this cell self-contained (the notebook only imports importlib above it)
    import subprocess
    import sys

    for module_name in module_names:
        try:
            importlib.import_module(module_name)
            print(f"{module_name} is already installed.")
        except ImportError:
            # Run pip with the interpreter backing this kernel so the package
            # lands in the environment the notebook is actually using.
            result = subprocess.run(
                [sys.executable, "-m", "pip", "install", module_name],
                capture_output=True,
                text=True,
            )
            # Child-process output does not reach the notebook by itself; echo it.
            if result.stdout:
                print(result.stdout)
            if result.returncode == 0:
                # Let the import system notice the freshly installed package
                importlib.invalidate_caches()
                print(f"{module_name} has been installed.")
            else:
                if result.stderr:
                    print(result.stderr)
                print(f"{module_name} could not be installed (pip exited with code {result.returncode}).")

In [None]:
# Import names of every library this notebook relies on; any that cannot be
# imported are handed to pip by install_if_not_exists
required_modules = ["pandas", "matplotlib", "seaborn", "numpy", "scipy", "IPython"]
install_if_not_exists(required_modules)

In [None]:
import os
import inspect
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy.stats import zscore
from IPython.display import display

In [None]:
# Enable inline plotting in Jupyter notebook
%matplotlib inline

In [None]:
# Input CSV locations, relative to the directory the notebook is run from
data_deaths = "data/Weekly_United_States_COVID-19_Cases_and_Deaths_by_County_-_ARCHIVED_20240113.csv"
data_hosp = "data/COVID-19_Reported_Patient_Impact_and_Hospital_Capacity_by_Facility_20240114.csv"

In [None]:
# df_d: cases/deaths dataset, df_h: hospital capacity dataset (both loaded in full)
df_d = pd.read_csv(data_deaths)
df_h = pd.read_csv(data_hosp)

In [None]:
df_d.shape

In [None]:
df_h.shape

In [None]:
list(df_d.columns)

In [None]:
list(df_h.columns)

In [None]:
# might be good for later mapping data
df_h['geocoded_hospital_address'].sample(10)

## Dataset Clean-Up Functions

In [None]:
# Build the aggregation spec handed to DataFrame.agg: every column is summed
def get_agg_dict(columns:list) -> dict:
    """Return a dict that maps each name in `columns` to the string 'sum'."""
    return dict.fromkeys(columns, 'sum')

In [None]:
def get_merged_data(weekly_aggregated_cases_STATE:pd.core.frame.DataFrame, weekly_aggregated_hospitals_STATE:pd.core.frame.DataFrame, threshold_count:int=52) -> pd.core.frame.DataFrame:
    """Join the weekly cases/deaths and hospital aggregates into one date-indexed frame.

    Args:
        weekly_aggregated_cases_STATE: weekly aggregate with 'year', 'week' and
            'is_metro_micro' key columns plus the cases/deaths columns.
        weekly_aggregated_hospitals_STATE: weekly aggregate with the same three
            key columns plus the hospital columns.
        threshold_count: minimum number of non-null values a column needs in
            order to be kept (default 52, i.e. at least one year of weekly data).

    Returns:
        The inner join of both inputs on ('year', 'week', 'is_metro_micro'),
        indexed by 'date' (the Monday of each week), with 'year'/'week'
        dropped and only the columns that meet `threshold_count` retained.

    Raises:
        AssertionError: if 'is_metro_micro' holds anything other than 0 or 1.
    """
    merged_df = pd.merge(weekly_aggregated_cases_STATE, weekly_aggregated_hospitals_STATE, on=['year', 'week', 'is_metro_micro'])

    # 'week' is produced by .dt.isocalendar().week in get_df_sub_state_cleaned, i.e. it
    # is an ISO-8601 week number. ISO week 1 is the week that contains January 4th, so
    # the Monday of week w is (Monday of the week containing Jan 4) + (w - 1) weeks.
    # Parsing with '%U' (Sunday-based weeks) shifted the dates by a week in some years.
    years = merged_df['year'].astype(int)
    weeks = merged_df['week'].astype(int)
    jan_4th = pd.to_datetime(years.astype(str) + '-01-04')
    first_iso_monday = jan_4th - pd.to_timedelta(jan_4th.dt.weekday, unit='D')
    merged_df['date'] = first_iso_monday + pd.to_timedelta((weeks - 1) * 7, unit='D')
    merged_df = merged_df.drop(columns=['year', 'week']).set_index('date')

    # Keep only the columns with at least `threshold_count` non-null values
    selected_columns = merged_df.columns[merged_df.count() >= threshold_count]
    merged_df = merged_df[selected_columns]

    print("The shape of the merged_df is", merged_df.shape[0], "by", merged_df.shape[1])
    # Validate every value; the old check only tested that a 1 was present, so it
    # rejected rural-only data and accepted values outside {0, 1}.
    assert merged_df['is_metro_micro'].isin([0, 1]).all(), "The column 'is_metro_micro' should only have 0 or 1 as values"

    return merged_df

In [None]:
def get_df_sub_state_cleaned(df:pd.core.frame.DataFrame, 
                             columns_to_filter:list, 
                             date_col:str='collection_week',
                             state:str='MI',
                             merge_county_data_enabled:bool=False, 
                             county_df:pd.core.frame.DataFrame=None) -> pd.core.frame.DataFrame:
    """Return the cleaned rows of `df` that belong to one state.

    Args:
        df: raw hospital or cases/deaths frame; it is not modified.
        columns_to_filter: numeric columns to keep and clean (NaN -> 0, negative -> 0).
        date_col: name of the date column. 'collection_week' selects the
            hospital layout (hospital_name, is_metro_micro, geocoded_hospital_address);
            any other value selects the layout keyed by 'fips_code'.
        state: value of the 'state' column to keep (default 'MI').
        merge_county_data_enabled: when True, inner-join `county_df` on 'fips_code'.
        county_df: lookup frame with a 'fips_code' column; required when
            `merge_county_data_enabled` is True. In the non-hospital layout it
            is the only source of the 'is_metro_micro' column, so that layout
            fails on the final cast unless the merge is enabled.

    Returns:
        A new frame with the selected columns, `date_col` parsed to datetime,
        ISO 'year' and 'week' columns, and 'is_metro_micro' cast to int.

    Raises:
        AssertionError: if the merge is enabled but `county_df` is None.
    """
    if (date_col == 'collection_week'):
        columns_for_subset = ['collection_week', 'state', 'hospital_name', 'is_metro_micro', 'geocoded_hospital_address'] + columns_to_filter
    else:
        columns_for_subset = ['fips_code', 'state', date_col] + columns_to_filter

    # Select the state first (all cleaning below is row-wise, so the result is the same
    # but only one state's rows are processed) and take an explicit copy so the
    # assignments below never write into a view of the caller's frame.
    df_STATE = df.loc[df['state'] == state, columns_for_subset].copy()
    df_STATE[date_col] = pd.to_datetime(df_STATE[date_col])

    # Use the ISO year together with the ISO week. Pairing the calendar year with the
    # ISO week mislabels boundary days (e.g. 2019-12-30 is ISO week 1 of 2020) and
    # would group them with the wrong end of the year.
    iso_calendar = df_STATE[date_col].dt.isocalendar()
    df_STATE['year'] = iso_calendar['year']
    df_STATE['week'] = iso_calendar['week']

    if (df_STATE[columns_to_filter].isnull().any().any()):
        # Replace NaN values with 0
        print('NaN values present and being replaced with zeros')
        df_STATE[columns_to_filter] = df_STATE[columns_to_filter].fillna(0)

    # Replace negative values with 0 (negative values are insertion errors)
    df_STATE[columns_to_filter] = df_STATE[columns_to_filter].clip(lower=0)

    if merge_county_data_enabled:
        # Check if county_df is None
        assert county_df is not None, "Input DataFrame 'county_df' cannot be None"
        # Keep one row per county: duplicate fips_code rows in the lookup would turn
        # this into a many-to-many join and multiply every data row.
        county_lookup = county_df.drop_duplicates(subset='fips_code')
        df_STATE = pd.merge(df_STATE, county_lookup, on='fips_code')

    df_STATE['is_metro_micro'] = df_STATE['is_metro_micro'].astype(int)

    return df_STATE

In [None]:
def save_cleaned_data_per_state(states:list):
    """Write the cleaned hospital and cases/deaths CSVs for every state in `states`.

    Reads the notebook-level variables df_h, df_d, columns_h, columns_d and
    county_df, so those cells must have been run first. For each state, two
    files are written to <cwd>/data/outputs/<state>/; the folder is created
    when it does not exist yet.
    """
    # Normalise Windows separators so the printed and joined paths use '/'
    current_directory = str(os.getcwd()).replace('\\','/')
    print(f"saving files to {current_directory}")

    for state in states:
        output_folder_path = current_directory + f"/data/outputs/{state}"
        hospital_csv_path = output_folder_path + f"/df_h_sub_cleaned_{state}.csv"
        deaths_csv_path = output_folder_path + f"/df_d_sub_cleaned_{state}.csv"

        # Clean both datasets before touching the file system
        df_h_sub_cleaned = get_df_sub_state_cleaned(df_h, columns_h, date_col='collection_week', state=state)
        df_d_sub_cleaned = get_df_sub_state_cleaned(
            df_d,
            columns_d,
            date_col='date',
            state=state,
            merge_county_data_enabled=True,
            county_df=county_df,
        )

        # Create the per-state folder on first use and report where it is
        folder_missing = not os.path.exists(output_folder_path)
        if folder_missing:
            os.makedirs(output_folder_path)
            print(output_folder_path)

        print(f"saving data for {state} ...")
        df_h_sub_cleaned.to_csv(hospital_csv_path, index=False)
        df_d_sub_cleaned.to_csv(deaths_csv_path, index=False)

In [None]:
def get_weekly_aggregated_data(df_sub_STATE:pd.core.frame.DataFrame, columns:list, replace_zeros:bool=False) -> pd.core.frame.DataFrame:  
    """Sum `columns` per (year, week, is_metro_micro) group.

    Args:
        df_sub_STATE: cleaned per-state frame containing 'year', 'week',
            'is_metro_micro' and every name in `columns`.
        columns: columns to sum within each group.
        replace_zeros: when True, turn every 0 in the summed columns into NaN
            so they are ignored by later correlation calculations.

    Returns:
        One row per (year, week, is_metro_micro) with the group keys restored
        as ordinary columns.

    Raises:
        AssertionError: if 'is_metro_micro' holds anything other than 0 or 1.
    """
    weekly_aggregated_STATE = df_sub_STATE.groupby(['year', 'week', 'is_metro_micro']).agg(get_agg_dict(columns)).reset_index()
    
    if replace_zeros:
        # Replace all 0 values back with NaN so that they do not affect the correlation data
        weekly_aggregated_STATE[columns] = weekly_aggregated_STATE[columns].replace(0, np.nan)
        
    # Validate every value; the old check only tested that a 1 was present, so it
    # rejected rural-only data and accepted values outside {0, 1}.
    assert weekly_aggregated_STATE['is_metro_micro'].isin([0, 1]).all(), "The column 'is_metro_micro' should only have 0 or 1 as values"
    
    return weekly_aggregated_STATE

In [None]:
def get_correlation_columns(merged_df:pd.core.frame.DataFrame, columns_to_correlate:list, correlation_threshold:float=0.7, min_data_points:int=10) -> tuple:
    """Find the columns that correlate strongly with `columns_to_correlate`, per category.

    The rows of `merged_df` are split on 'is_metro_micro' (0, then 1). For each
    split a correlation matrix is computed, shown as a heat map, and scanned
    for columns whose absolute correlation with any of `columns_to_correlate`
    is above `correlation_threshold` and below 1.

    Constant columns are dropped before correlating: a column with zero
    standard deviation makes the correlation denominator zero and produces
    NaN entries in the matrix.

    Args:
        merged_df: frame with an 'is_metro_micro' column and numeric data columns.
        columns_to_correlate: target columns (the rows of the heat map).
        correlation_threshold: lower bound (exclusive) on the absolute correlation.
        min_data_points: a column must have more than this many non-null values
            within the category to be reported.

    Returns:
        (columns for is_metro_micro == 0, columns for is_metro_micro == 1,
        correlation rows for `columns_to_correlate`).
        NOTE: the third element comes from the last category that was
        processed only (is_metro_micro == 1 when that category has data); it
        does not describe the other category.
    """
    high_correlation_arr = []
    # Defined up front so the return value exists even if no category can be processed
    target_correlations = pd.DataFrame()

    for category in range(2):

        metro_micro_df = merged_df[merged_df['is_metro_micro'] == category]

        # Identify and drop constant columns
        constant_columns = metro_micro_df.columns[metro_micro_df.nunique() == 1]
        metro_micro_df = metro_micro_df.drop(columns=constant_columns)

        # A category with no rows, or one where a target column is constant, cannot be
        # correlated; report it and move on instead of failing with a KeyError.
        missing_targets = [col for col in columns_to_correlate if col not in metro_micro_df.columns]
        if metro_micro_df.empty or missing_targets:
            print(f"Skipping correlation for is_metro_micro == {category}: not enough data")
            high_correlation_arr.append([])
            continue

        # Target columns first, followed by every other remaining column
        selected_columns = columns_to_correlate + metro_micro_df.columns.difference(columns_to_correlate).tolist()

        # Calculate the correlation matrix and strip rows/columns that are entirely NaN
        correlation_matrix = metro_micro_df[selected_columns].corr() # method='spearman'
        correlation_matrix = correlation_matrix.dropna(how='all')
        correlation_matrix = correlation_matrix.dropna(axis=1, how='all')

        target_correlations = correlation_matrix.loc[columns_to_correlate]

        # display correlation heat map
        show_heat_map(target_correlations, category)

        # A column qualifies when any target row is above the threshold but not exactly 1
        strength = target_correlations.abs()
        is_high = ((strength > correlation_threshold) & (strength < 1)).any()

        high_correlation_columns = [
            col for col in target_correlations.columns[is_high].tolist()
            if col not in columns_to_correlate
            # check that the column has a reasonable amount of data for that category
            and metro_micro_df[col].dropna().shape[0] > min_data_points
        ]

        high_correlation_arr.append(high_correlation_columns)

    return high_correlation_arr[0], high_correlation_arr[1], target_correlations

## Data Visualization Functions

In [None]:
def show_heat_map(correlation_matrix:pd.core.frame.DataFrame, category:int=1):
    """Draw an annotated heat map of `correlation_matrix`.

    `category` only affects the title: 1 is labelled "Urban", any other value "Rural".
    """
    title_text = "Urban" if category == 1 else "Rural"

    # Wide, short canvas with the cell annotations rotated by 90 degrees
    plt.figure(figsize=(30, 3))
    sns.heatmap(
        correlation_matrix,
        annot=True,
        cmap='coolwarm',
        fmt=".2f",
        linewidths=.5,
        annot_kws={'rotation': 90},
    )
    plt.title(f'Correlation Heatmap In {title_text} Areas')
    plt.show()

In [None]:
# For high_correlation_columns_urban/rural let's check out their distributions
def get_norm_dist_of_high_corr(merged_df:pd.core.frame.DataFrame, 
                               corr_matrix_sub:pd.core.frame.DataFrame, 
                               high_correlation_arr:list, 
                               is_metro_micro:bool):
    """Plot the z-score distribution of each highly correlated column for one category.

    Args:
        merged_df: merged frame containing 'is_metro_micro' and the data columns.
        corr_matrix_sub: correlation rows as returned by get_correlation_columns;
            every name in `high_correlation_arr` must be one of its columns.
        high_correlation_arr: column names to plot.
        is_metro_micro: category to plot (0 or 1).

    NOTE(review): the correlation shown on the x-axis is read positionally from
    `corr_matrix_sub`: row 0 when `is_metro_micro` is 0, row 1 when it is 1. The
    rows of that matrix are target columns, not categories, so confirm this is
    the value that should be displayed.
    """
    category = 'is_metro_micro'
    # The row filter does not depend on the column, so compute it once
    category_rows = merged_df[merged_df[category] == is_metro_micro]
    
    for col in high_correlation_arr:
        sample_data = category_rows[col]
        # Explicit positional access; integer keys on a label-indexed Series are deprecated
        corr_val = corr_matrix_sub[col].iloc[int(is_metro_micro)]

        if sample_data.isnull().any():
            # Impute with the mean on a new Series instead of writing into a slice of merged_df
            sample_data = sample_data.fillna(sample_data.mean())

        df_z_score = zscore(sample_data)

        sns.histplot(df_z_score, kde=True, stat='density', label=col + '- Standardized Data')
        plt.title('Standard Normal Distribution Curve for \n' + col + '\n is_urban = ' + str(is_metro_micro))
        plt.xlabel(f'For a given correlation of {round(corr_val,2)}')
        
        plt.show()

In [None]:
# Scatter plots of each highly correlated column against the cases/deaths columns
def get_corr_plot(merged_df:pd.core.frame.DataFrame, 
                  corr_matrix_sub:pd.core.frame.DataFrame, 
                  high_correlation_arr:list, 
                  columns_d:list, 
                  is_metro_micro:bool):
    """Draw one figure per highly correlated column, with one scatter plot per entry of `columns_d`.

    Args:
        merged_df: merged frame containing 'is_metro_micro' and the data columns.
        corr_matrix_sub: kept for backward compatibility; no longer read. The
            correlation in each x-axis label is computed from the plotted data,
            because this matrix only describes a single category.
        high_correlation_arr: columns plotted on the y-axis, one figure each.
        columns_d: columns plotted on the x-axis, one subplot each.
        is_metro_micro: category to plot (0 or 1).
    """
    category = 'is_metro_micro'
    sample_data = merged_df[merged_df[category] == is_metro_micro]    
    
    for y_column in high_correlation_arr:

        plt.figure(figsize=(14, 6))
        
        for i, x_col in enumerate(columns_d):
            # Pearson correlation of exactly the two series shown in this subplot
            corr_val = sample_data[x_col].corr(sample_data[y_column])

            # One subplot column per entry of columns_d (was hard-coded to 2)
            plt.subplot(1, len(columns_d), i+1)
            sns.scatterplot(data=sample_data, x=x_col, y=y_column, hue=category, palette={0: 'blue', 1: 'orange'}, s=100)
            plt.title(f'{y_column} vs {x_col}')
            plt.xlabel(f'{x_col} given a {round(corr_val,2)} correlation')

        plt.tight_layout()
        plt.show()

## Main Function

In [None]:
def main(df_d:pd.core.frame.DataFrame, df_h:pd.core.frame.DataFrame, county_df:pd.core.frame.DataFrame, columns_d:list, columns_h:list, state:str='MI'):
    """Run the whole per-state pipeline: clean, aggregate, merge, correlate and plot.

    Args:
        df_d: cases/deaths frame.
        df_h: hospital frame.
        county_df: lookup merged onto the cases/deaths rows on 'fips_code'.
        columns_d: cases/deaths columns to aggregate and correlate against.
        columns_h: hospital columns to aggregate.
        state: value of the 'state' column to analyse (default 'MI').

    Returns:
        (merged_df, high-correlation columns for is_metro_micro == 0,
        high-correlation columns for is_metro_micro == 1).
    """
    # Step 1: cleaned per-state subsets of both datasets
    print(f"\ncreating sub hospital dataframe for {state}")
    df_h_sub = get_df_sub_state_cleaned(df_h, columns_h, date_col='collection_week', state=state)
    print("\nsub dataframe complete")

    print(f"\ncreating sub covid cases/deaths dataframe for {state}")
    df_d_sub = get_df_sub_state_cleaned(
        df_d,
        columns_d,
        date_col='date',
        state=state,
        merge_county_data_enabled=True,
        county_df=county_df,
    )
    print("\nsub dataframe complete")

    # Step 2: weekly sums (zeros become NaN for the hospital data only)
    print("\naggregating hospital data")
    hospitals_by_week = get_weekly_aggregated_data(df_h_sub, columns_h, replace_zeros=True)
    display(hospitals_by_week.sample(5))

    print("\naggregating covid cases/deaths data")
    cases_by_week = get_weekly_aggregated_data(df_d_sub, columns_d)
    display(cases_by_week.sample(5))

    # Step 3: join both weekly aggregates
    print("\nmerging dataframes on week, year, and is_metro_micro")
    merged_df = get_merged_data(
        weekly_aggregated_cases_STATE=cases_by_week,
        weekly_aggregated_hospitals_STATE=hospitals_by_week,
    )
    display(merged_df.head())

    # Step 4: columns that correlate strongly with columns_d, for rural (0) and urban (1) rows
    print("\nfinding high correlating columns to covid cases and deaths")
    rural_columns, urban_columns, corr_matrix_sub = get_correlation_columns(merged_df, columns_d)

    # (heading, is_metro_micro value, columns) for the two categories, rural first
    per_category = (
        ("\nAll highly correlated columns for RURAL AREAs:", 0, rural_columns),
        ("\nAll highly correlated columns for URBAN AREAs:", 1, urban_columns),
    )

    for heading, flag, found_columns in per_category:
        print(heading)
        print("---")
        category_rows = merged_df[merged_df['is_metro_micro'] == flag]
        for col in found_columns:
            print(col, "has", category_rows[col].dropna().shape[0], "data points")

    # Step 5: plots
    print("\ndisplaying normal distribution of highly correlated columns")
    for _, flag, found_columns in per_category:
        get_norm_dist_of_high_corr(merged_df, corr_matrix_sub, found_columns, flag)

    print("\ndisplaying correlation plots of highly correlated columns")
    for _, flag, found_columns in per_category:
        get_corr_plot(merged_df, corr_matrix_sub, found_columns, columns_d, is_metro_micro=flag)

    return merged_df, rural_columns, urban_columns
    

## Hospital Dataset Clean-Up

In [None]:
# Hospital columns to aggregate: everything after the first 11 columns of df_h
# (presumably identifier/metadata fields - verify against the dataset), minus
# any column whose name contains influenza or other non-relevant text.
substrings_to_exclude = ['influenza', 'hhs_ids', 'is_corrected', 'total_personnel_covid_vaccinated', 'geocoded_hospital_address']
columns_h = [
    column_name
    for column_name in df_h.columns[11:]
    if all(excluded not in column_name for excluded in substrings_to_exclude)
]

## COVID-19 Cases & Deaths Dataset Clean-Up

In [None]:
# Bring the two count columns in line with the snake_case names used everywhere else
column_mapping_d = {
    'New cases': 'new_cases',
    'New deaths': 'new_deaths',
}
df_d = df_d.rename(columns=column_mapping_d)

In [None]:
# covid death data
df_d.sample(10)

### merging in urban vs rural data
You may notice an issue here compared to our other dataset. The hospital dataset has an 'is_metro_micro' column, but this dataset of COVID-19 cases and deaths does not make that distinction. We therefore cannot combine the two datasets without losing that information, and we cannot assume how the cases and deaths are split between rural and urban areas. So we first need to extract some more information from the hospital dataset before we can combine it with our current dataset.

The easiest way to do this is to use the location information that both datasets share in the 'fips_code' column: we will merge the 'is_metro_micro' data onto our deaths/cases dataset on 'fips_code'.

In [None]:
# County lookup: one row per fips_code with its is_metro_micro flag.
# df_h holds many rows per county, so without de-duplicating, merging this lookup
# onto the cases/deaths data on 'fips_code' would repeat every cases/deaths row
# once per matching hospital row and inflate the weekly sums.
county_df = df_h[['fips_code', 'is_metro_micro']].dropna().drop_duplicates(subset='fips_code')
county_df.head()

In [None]:
# let's grab the covid data columns for only the ones we want to aggregate
columns_d = ['new_cases', 'new_deaths']

Let's save some of the data to our output folder for future manipulation. Uncomment the cell below to run it; it should take roughly 20 minutes to go through all 50 states.

In [None]:
# save_cleaned_data_per_state(list(df_d['state'].unique()))

## Display Correlation Data Per State

### Michigan

In [None]:
# Michigan: run the full pipeline and keep the merged frame plus the rural/urban column lists
merged_df_MI, high_correlation_columns_rural_MI, high_correlation_columns_urban_MI = main(df_d, df_h, county_df, columns_d, columns_h, state='MI')

In [None]:
# New York
merged_df_NY, high_correlation_columns_rural_NY, high_correlation_columns_urban_NY = main(df_d, df_h, county_df, columns_d, columns_h, state='NY')

In [None]:
# Washington
merged_df_WA, high_correlation_columns_rural_WA, high_correlation_columns_urban_WA = main(df_d, df_h, county_df, columns_d, columns_h, state='WA')

In [None]:
# Texas
merged_df_TX, high_correlation_columns_rural_TX, high_correlation_columns_urban_TX = main(df_d, df_h, county_df, columns_d, columns_h, state='TX')

In [None]:
# Florida
merged_df_FL, high_correlation_columns_rural_FL, high_correlation_columns_urban_FL = main(df_d, df_h, county_df, columns_d, columns_h, state='FL')