In [1]:
# %%
import pandas as pd
import plotly.express as px
from ipywidgets import interact, widgets
import warnings
import os
import requests

# Silence every FutureWarning for the rest of the session so deprecation
# notices do not clutter the notebook output.
# NOTE(review): this is process-wide and also hides FutureWarnings raised by
# code outside this notebook - confirm that is intended.
warnings.simplefilter(action='ignore', category=FutureWarning)

def get_us_state_to_abbrev():
    """Return a new dict mapping full US state/territory names to postal codes.

    The mapping lists the 50 states first, then the District of Columbia and
    the territories. Both "U.S. Virgin Islands" and the shorter
    "Virgin Islands" spelling are included and resolve to "VI". A fresh dict
    is built on every call, so callers may mutate the result freely.
    """
    name_code_pairs = (
        # States
        ("Alabama", "AL"), ("Alaska", "AK"), ("Arizona", "AZ"),
        ("Arkansas", "AR"), ("California", "CA"), ("Colorado", "CO"),
        ("Connecticut", "CT"), ("Delaware", "DE"), ("Florida", "FL"),
        ("Georgia", "GA"), ("Hawaii", "HI"), ("Idaho", "ID"),
        ("Illinois", "IL"), ("Indiana", "IN"), ("Iowa", "IA"),
        ("Kansas", "KS"), ("Kentucky", "KY"), ("Louisiana", "LA"),
        ("Maine", "ME"), ("Maryland", "MD"), ("Massachusetts", "MA"),
        ("Michigan", "MI"), ("Minnesota", "MN"), ("Mississippi", "MS"),
        ("Missouri", "MO"), ("Montana", "MT"), ("Nebraska", "NE"),
        ("Nevada", "NV"), ("New Hampshire", "NH"), ("New Jersey", "NJ"),
        ("New Mexico", "NM"), ("New York", "NY"), ("North Carolina", "NC"),
        ("North Dakota", "ND"), ("Ohio", "OH"), ("Oklahoma", "OK"),
        ("Oregon", "OR"), ("Pennsylvania", "PA"), ("Rhode Island", "RI"),
        ("South Carolina", "SC"), ("South Dakota", "SD"), ("Tennessee", "TN"),
        ("Texas", "TX"), ("Utah", "UT"), ("Vermont", "VT"),
        ("Virginia", "VA"), ("Washington", "WA"), ("West Virginia", "WV"),
        ("Wisconsin", "WI"), ("Wyoming", "WY"),
        # Federal district and territories
        ("District of Columbia", "DC"),
        ("American Samoa", "AS"),
        ("Guam", "GU"),
        ("Northern Mariana Islands", "MP"),
        ("Puerto Rico", "PR"),
        ("United States Minor Outlying Islands", "UM"),
        ("U.S. Virgin Islands", "VI"),  # official long form
        ("Virgin Islands", "VI"),       # common short form
    )
    return dict(name_code_pairs)

def get_consumption_data():
    """Load per-state energy consumption and return it in long format.

    The CSV is cached at ``data/state_consumption.csv`` (relative to the
    current working directory). If the file is missing it is downloaded once
    from the course repository and written to that path.

    Returns:
        pandas.DataFrame: columns ``State``, ``Fuel Source`` and
        ``Trillions BTU Consumed``. There is one row per state and fuel
        source, followed by one ``"Total"`` row per state holding the sum
        over that state's fuel sources. The index is a fresh 0..n-1 range.

    Raises:
        Any exception raised while downloading (connection problems, HTTP
        error statuses via ``raise_for_status``) or while writing the cache
        file propagates unchanged, so callers can tell the failures apart.
    """
    file_loc = 'data/state_consumption.csv'
    source_url = 'https://raw.githubusercontent.com/riverar9/cuny-msds/refs/heads/main/data608-knowledge-and-visual-analytics/stories/story-7/data/state_consumption.csv'

    # exist_ok=True replaces the racy "check, then create" sequence.
    os.makedirs(os.path.dirname(file_loc), exist_ok=True)

    if not os.path.exists(file_loc):
        # The timeout stops a stalled connection from hanging the notebook.
        response = requests.get(source_url, timeout=30)
        response.raise_for_status()
        with open(file_loc, 'wb') as file:
            file.write(response.content)

    df = pd.read_csv(file_loc)

    # 'Rank' plays no part in the reshaping below.
    df = df.drop(columns=['Rank'])

    # Each fuel source has its own (state, BTU) column pair in the wide file.
    fuel_sources_map = {
        'Coal': ('State_Coal', 'Trillion_Btu_Coal'),
        'Natural Gas': ('State_Natural_Gas', 'Trillion_Btu_Natural_Gas'),
        'Petroleum': ('State_Petroleum', 'Trillion_Btu_Petroleum'),
        'Nuclear': ('State_Nuclear', 'Trillion_Btu_Nuclear'),
        'Total Renewable Energy': ('State_Total_Renewable_Energy', 'Trillion_Btu_Total_Renewable_Energy')
    }

    dfs_to_concat = []
    for fuel_name, (state_col, btu_col) in fuel_sources_map.items():
        temp_df = df[[state_col, btu_col]].copy()
        # Give every slice the same column names so the slices stack cleanly.
        temp_df.columns = ['State', 'Trillions BTU Consumed']
        temp_df['Fuel Source'] = fuel_name
        dfs_to_concat.append(temp_df)

    result_df = pd.concat(dfs_to_concat, ignore_index=True)

    # Values may be text with thousands separators ("1,234"); strip the commas
    # before converting to numbers.
    result_df['Trillions BTU Consumed'] = pd.to_numeric(
        result_df['Trillions BTU Consumed'].astype(str).str.replace(',', '', regex=False)
    )

    result_df = result_df[['State', 'Fuel Source', 'Trillions BTU Consumed']]

    # Per-state total across all fuel sources.
    total_df = result_df[['State', 'Trillions BTU Consumed']].groupby('State').sum().reset_index()
    total_df['Fuel Source'] = "Total"

    # ignore_index keeps the returned index unique, matching the concat above.
    return pd.concat([result_df, total_df], ignore_index=True)

def get_generation_data():
    """Load per-state energy generation and return it in long format.

    The CSV is cached at ``data/state_generation.csv`` (relative to the
    current working directory). If the file is missing it is downloaded once
    from the course repository and written to that path.

    After the file's own ``Total`` column is dropped, the remaining eight
    columns are relabelled by position: the state, four individual fuel
    sources, and three columns that all receive the label
    ``Total Renewable Energy`` and are therefore summed together.

    Returns:
        pandas.DataFrame: columns ``State``, ``Fuel Source`` and
        ``Trillions BTU Generated``. There is one row per state and fuel
        source, followed by one ``"Total"`` row per state holding the sum
        over that state's fuel sources. The index is a fresh 0..n-1 range.

    Raises:
        Any exception raised while downloading (connection problems, HTTP
        error statuses via ``raise_for_status``) or while writing the cache
        file propagates unchanged, so callers can tell the failures apart.
    """
    file_loc = 'data/state_generation.csv'
    source_url = 'https://raw.githubusercontent.com/riverar9/cuny-msds/refs/heads/main/data608-knowledge-and-visual-analytics/stories/story-7/data/state_generation.csv'

    # exist_ok=True replaces the racy "check, then create" sequence.
    os.makedirs(os.path.dirname(file_loc), exist_ok=True)

    if not os.path.exists(file_loc):
        # The timeout stops a stalled connection from hanging the notebook.
        response = requests.get(source_url, timeout=30)
        response.raise_for_status()
        with open(file_loc, 'wb') as file:
            file.write(response.content)

    df = pd.read_csv(file_loc)

    # The file's own total is discarded; totals are recomputed below.
    df = df.drop(columns='Total')

    # Positional relabel; raises ValueError if the file does not have exactly
    # eight remaining columns. The repeated label is deliberate: it makes the
    # groupby below fold those three columns into one fuel source.
    df.columns = ['State', 'Coal', 'Natural Gas', 'Petroleum', 'Nuclear', 'Total Renewable Energy', 'Total Renewable Energy', 'Total Renewable Energy']

    df = df.melt(id_vars=['State'], value_name='Trillions BTU Generated', var_name='Fuel Source')

    # Collapse the duplicated fuel-source labels into a single row per state.
    df = df.groupby(['State', 'Fuel Source']).sum().reset_index()

    # Per-state total across all fuel sources.
    total_df = df[['State', 'Trillions BTU Generated']].groupby('State').sum().reset_index()
    total_df['Fuel Source'] = 'Total'

    # ignore_index keeps the returned index unique.
    return pd.concat([df, total_df], ignore_index=True)

def get_dataset():
    """Combine consumption and generation data into one per-state KPI table.

    The two long-format tables are outer-merged on ``State`` and
    ``Fuel Source``, state names are replaced by their postal abbreviations
    (names without a mapping are left unchanged), and the derived columns
    are added.

    Returns:
        pandas.DataFrame with the columns, in this order:
        ``State``, ``Fuel Source``, ``Energy Consumption (TBTU)``,
        ``Energy Generation (TBTU)``, ``Net Exports (TBTU)`` (generation
        minus consumption), ``tot_consumed``, ``tot_generated`` (per-state
        sums over the individual fuel sources),
        ``Consumption by Source Percentage`` and
        ``Generation by Source Percentage`` (the row's value divided by the
        state total, as a fraction, so the ``"Total"`` row is 1.0).
    """
    con_df = get_consumption_data()
    gen_df = get_generation_data()

    energy_df = con_df.merge(
        gen_df,
        on=['State', 'Fuel Source'],
        how='outer'
    )

    energy_df['State'] = energy_df['State'].replace(get_us_state_to_abbrev())

    energy_df['net_export'] = energy_df['Trillions BTU Generated'] - energy_df['Trillions BTU Consumed']

    # State totals must come from the individual fuel sources only. The
    # pre-aggregated "Total" rows are excluded; summing them too would double
    # every total and halve every percentage below.
    source_rows = energy_df[energy_df['Fuel Source'] != 'Total']
    state_totals = (
        source_rows
        .groupby('State')[['Trillions BTU Consumed', 'Trillions BTU Generated']]
        .sum()
        .reset_index()
        .rename(columns={
            'Trillions BTU Consumed': 'tot_consumed',
            'Trillions BTU Generated': 'tot_generated',
        })
    )

    energy_df = energy_df.merge(
        state_totals,
        on='State'
    )

    energy_df['state_pct_consumed'] = energy_df['Trillions BTU Consumed'] / energy_df['tot_consumed']
    energy_df['state_pct_generated'] = energy_df['Trillions BTU Generated'] / energy_df['tot_generated']

    # Rename by name rather than by position so a change in column order can
    # never silently mislabel a column.
    energy_df = energy_df.rename(columns={
        'Trillions BTU Consumed': 'Energy Consumption (TBTU)',
        'Trillions BTU Generated': 'Energy Generation (TBTU)',
        'net_export': 'Net Exports (TBTU)',
        'state_pct_consumed': 'Consumption by Source Percentage',
        'state_pct_generated': 'Generation by Source Percentage',
    })

    return energy_df[[
        'State',
        'Fuel Source',
        'Energy Consumption (TBTU)',
        'Energy Generation (TBTU)',
        'Net Exports (TBTU)',
        'tot_consumed',
        'tot_generated',
        'Consumption by Source Percentage',
        'Generation by Source Percentage',
    ]]

def plot_interactive_energy_map(selected_fuel_source, selected_kpi):
    """Draw a US choropleth of one KPI for one fuel source.

    The dataset is loaded through ``get_dataset()``, narrowed to the rows
    whose ``Fuel Source`` equals ``selected_fuel_source``, and coloured by
    the ``selected_kpi`` column. If no rows match, a message is printed and
    nothing is drawn. The function shows the figure and returns ``None``.
    """
    all_rows = get_dataset()

    # Work on an independent copy of the matching rows.
    map_rows = all_rows.loc[all_rows['Fuel Source'] == selected_fuel_source].copy()

    if map_rows.empty:
        print(f"No data available for Fuel Source: {selected_fuel_source} and KPI: {selected_kpi}")
        return

    # Choose the palette: fallback first, then override per KPI family.
    generation_kpis = ('Energy Generation (TBTU)', 'Generation by Source Percentage')
    consumption_kpis = ('Energy Consumption (TBTU)', 'Consumption by Source Percentage')

    palette, midpoint = 'Viridis', None
    if selected_kpi == 'Net Exports (TBTU)':
        # Diverging red-yellow-green scale centred on zero.
        palette, midpoint = px.colors.diverging.RdYlGn, 0
    elif selected_kpi in generation_kpis:
        # Sequential greens for generation.
        palette = px.colors.sequential.Greens
    elif selected_kpi in consumption_kpis:
        # Sequential yellow-orange-red for consumption.
        palette = px.colors.sequential.YlOrRd

    # Hover box: the selected KPI comes right after the fuel source, followed
    # by the three absolute measures, all shown with two decimals.
    hover_fields = {'Fuel Source': True, selected_kpi: ':.2f'}
    for measure in ('Energy Consumption (TBTU)', 'Energy Generation (TBTU)', 'Net Exports (TBTU)'):
        hover_fields[measure] = ':.2f'

    fig = px.choropleth(
        map_rows,
        locations=map_rows['State'],
        locationmode='USA-states',
        color=selected_kpi,
        scope='usa',
        color_continuous_scale=palette,
        color_continuous_midpoint=midpoint,
        hover_name='State',
        hover_data=hover_fields,
        title=f'US Energy Data: {selected_kpi} for {selected_fuel_source}'
    )

    # Tight margins (leaving room for the title) and white lakes.
    fig.update_layout(
        margin={"r": 0, "t": 40, "l": 0, "b": 0},
        geo=dict(lakecolor='rgb(255, 255, 255)')
    )

    fig.show()

# %%
def main():
    """Create the fuel-source and KPI dropdowns and bind them to the map plot.

    The fuel-source options are the sorted distinct ``Fuel Source`` values
    of the dataset returned by ``get_dataset()``. The dropdowns start on
    ``"Total"`` and ``'Energy Generation (TBTU)'`` respectively.
    """
    fuel_sources = sorted(get_dataset()['Fuel Source'].unique())

    kpi_columns = [
        'Energy Consumption (TBTU)',
        'Energy Generation (TBTU)',
        'Net Exports (TBTU)',
        'Consumption by Source Percentage',
        'Generation by Source Percentage',
    ]

    # 'initial' description width keeps the full label visible.
    fuel_source_dropdown = widgets.Dropdown(
        options=fuel_sources,
        value="Total",  # start on the all-sources aggregate
        description='Fuel Source:',
        style={'description_width': 'initial'}
    )

    kpi_dropdown = widgets.Dropdown(
        options=kpi_columns,
        value='Energy Generation (TBTU)',  # start on absolute generation
        description='Select KPI:',
        style={'description_width': 'initial'}
    )

    # Re-draw the map whenever either dropdown changes.
    interact(
        plot_interactive_energy_map,
        selected_fuel_source=fuel_source_dropdown,
        selected_kpi=kpi_dropdown
    )


In [2]:
# Build the dropdown widgets and display the interactive map.
main()

interactive(children=(Dropdown(description='Fuel Source:', index=4, options=('Coal', 'Natural Gas', 'Nuclear',…