In [None]:
import json
file_dir = "test_compare_base_disagg"
file_path = f"./{file_dir}/mseg_res_com_state_2024.json"
# JSON text is UTF-8; state the encoding so the read does not depend on the
# platform's locale default.
with open(file_path, "r", encoding="utf-8") as file:
    dataset = json.load(file)

# Quick look at which electricity end uses exist for one residential and one
# commercial building type in California.
print("RESIDENTIAL")
print(dataset['CA']['single family home']['electricity'].keys())
print("COMMERCIAL")
print(dataset['CA']['large office']['electricity'].keys())

### 1. Data preparation for Scout output using Stock data (electricity)

In [None]:
def is_valid_stock(stock):
    """Return True when *stock* holds usable data.

    Rules:
    * a string is valid unless it is exactly the sentinel "NA";
    * a dict is valid only if it is non-empty and at least one of its
      values is non-zero after conversion with ``float``;
    * any other type is invalid.
    """
    if isinstance(stock, str):
        return stock != "NA"
    if not isinstance(stock, dict):
        return False
    # any() over an empty dict is False, which covers the "empty" case too.
    return any(float(entry) != 0.0 for entry in stock.values())


######################################

def dataprep_compare_elec(geo, strout, energy_or_stock, dir_2023, dir_new):
    """Aggregate electricity energy/stock by state and end use for two datasets.

    Two Scout microsegment JSON files are loaded (which pair depends on
    *strout*). The chosen metric is summed over building types and
    technology subcategories of the "electricity" branch for every state
    present in both files. Each state's total is then converted into its
    share of the all-state total for that end use and year.

    Args:
        geo: Geography tag inserted into the file names (e.g. "EMM", "state").
        strout: Comparison label; selects which two files are compared.
        energy_or_stock: Key read from each leaf, "energy" or "stock".
        dir_2023: Directory holding the 2023 file.
        dir_new: Directory holding the 2024 files.

    Returns:
        ``(allbldgs_dfs, resbldgs_dfs, combldgs_dfs, states)``. Each
        ``*_dfs`` is a two-element list (index 0: first file, index 1:
        second file) of ``{state: DataFrame}``. Each DataFrame has years
        2016-2050 as its index and end uses as its columns, holding shares.
        ``states`` is the set of regions present in both files.

    Raises:
        ValueError: if *strout* is not one of the supported comparisons.
    """
    comparisons = {
        "Elec_End-Use_EUSS2024_vs_End-Use_EUSS2023": (
            f"./{dir_new}/mseg_res_com_{geo}_2024_eu_factors.json",
            f"./{dir_2023}/mseg_res_com_{geo}_2023.json",
        ),
        "Elec_Tech_EUSS2024_vs_End-Use_EUSS2023": (
            f"./{dir_new}/mseg_res_com_{geo}_2024_tech_factors.json",
            f"./{dir_2023}/mseg_res_com_{geo}_2023.json",
        ),
        "Elec_Tech_EUSS2024_vs_End-Use_EUSS2024": (
            f"./{dir_new}/mseg_res_com_{geo}_2024_tech_factors.json",
            f"./{dir_new}/mseg_res_com_{geo}_2024_eu_factors.json",
        ),
    }
    if strout not in comparisons:
        raise ValueError(
            f"Unsupported comparison {strout!r}; expected one of {sorted(comparisons)}"
        )

    datasets = []
    for path in comparisons[strout]:
        with open(path, "r", encoding="utf-8") as file:
            datasets.append(json.load(file))

    fuel = "electricity"
    res_bldgs = {'single family home', 'multi family home', 'mobile home'}
    # For these end uses only the 'supply' technologies are summed.
    end_uses_wsupply = {"heating", "cooling", "secondary heating"}
    # Keys that carry data directly on an end-use node (rather than naming a subcategory).
    leaf_keys = {"energy", "stock"}
    years = [str(year) for year in range(2016, 2051)]
    year_set = set(years)

    states = set(datasets[0].keys()).intersection(datasets[1].keys())

    def fuel_branches(dataset, state):
        """Yield (building_type, end-use dict) for the selected fuel of each building type."""
        for building_type, fuels in dataset[state].items():
            if isinstance(fuels, dict) and isinstance(fuels.get(fuel), dict):
                yield building_type, fuels[fuel]

    # Collect end uses from BOTH files so an end use present only in the
    # second file does not cause a KeyError during accumulation.
    end_uses_all = sorted({
        end_use
        for dataset in datasets
        for state in states
        for _, branch in fuel_branches(dataset, state)
        for end_use in branch
    })

    def empty_totals():
        """Return {state: {end_use: {year: 0}}} initialised to zero."""
        return {state: {end_use: dict.fromkeys(years, 0) for end_use in end_uses_all}
                for state in states}

    allbldgs_data = [empty_totals() for _ in datasets]
    resbldgs_data = [empty_totals() for _ in datasets]
    combldgs_data = [empty_totals() for _ in datasets]

    def add_series(index, state, building_type, end_use, series):
        """Add one year->value mapping into the all-buildings and sector totals."""
        # is_valid_stock accepts any non-"NA" string, but only dicts can be summed.
        if not isinstance(series, dict) or not is_valid_stock(series):
            return
        sector_data = resbldgs_data if building_type in res_bldgs else combldgs_data
        targets = (allbldgs_data[index][state][end_use], sector_data[index][state][end_use])
        for year, value in series.items():
            if year in year_set:
                for target in targets:
                    target[year] += value

    for index, dataset in enumerate(datasets):
        for state in states:
            for building_type, branch in fuel_branches(dataset, state):
                for end_use, node in branch.items():
                    if end_use in end_uses_wsupply:
                        for tech in node['supply'].values():
                            add_series(index, state, building_type, end_use, tech[energy_or_stock])
                        continue
                    # Data stored directly on the end use is added exactly once,
                    # even when both 'energy' and 'stock' keys are present.
                    if leaf_keys & node.keys():
                        add_series(index, state, building_type, end_use, node.get(energy_or_stock))
                    for sub_name, sub in node.items():
                        if sub_name not in leaf_keys:
                            add_series(index, state, building_type, end_use, sub[energy_or_stock])

    def compute_relative_shares(data_dict):
        """Convert absolute values to state-level shares by end use and year.

        Each state's value is divided by the sum over all states. When that
        sum is not positive, every state's share is 0.
        """
        relative_data = {state: {end_use: {} for end_use in end_uses_all} for state in states}
        for end_use in end_uses_all:
            for year in years:
                total = sum(data_dict[state][end_use][year] for state in states)
                for state in states:
                    relative_data[state][end_use][year] = (
                        data_dict[state][end_use][year] / total if total > 0 else 0
                    )
        return relative_data

    def to_frames(share_data):
        """Build [{state: DataFrame(index=years, columns=end uses)}] per dataset."""
        return [
            {state: pd.DataFrame.from_dict(per_dataset[state], orient="index", columns=years).T
             for state in states}
            for per_dataset in share_data
        ]

    # Convert absolute values to shares (% of total within each end use across states)
    allbldgs_data = [compute_relative_shares(d) for d in allbldgs_data]
    resbldgs_data = [compute_relative_shares(d) for d in resbldgs_data]
    combldgs_data = [compute_relative_shares(d) for d in combldgs_data]

    print("COMBINE ALL SECTORS")
    allbldgs_dfs = to_frames(allbldgs_data)
    print("RESIDENTIAL SECTOR")
    resbldgs_dfs = to_frames(resbldgs_data)
    print("COMMERCIAL SECTOR")
    combldgs_dfs = to_frames(combldgs_data)

    return allbldgs_dfs, resbldgs_dfs, combldgs_dfs, states


### 2. Comparison of Scout output using different Stock versions (electricity)


In [None]:
### 2. Comparison of Scout output using different Stock versions (electricity)

def plot_scatter_all(sector, allbldgs_dfs, resbldgs_dfs, combldgs_dfs, states, energy_or_stock):
    """Draw a grid of new-vs-old scatter plots, one subplot per electricity end use.

    For ``year`` 2024, each state contributes a point wherever both values
    are finite: x comes from ``dfs[0]`` (new dataset) and y from ``dfs[1]``
    (old dataset). A dashed y = x line is drawn and both axes share the same
    limits. The figure is saved to
    ``./{dir_new}/{geo}_{strout}_{energy_or_stock}_{sector}.png``, shown,
    and then closed.

    Args:
        sector: "residential" selects ``resbldgs_dfs``/``res_end_uses``; any
            other value selects ``combldgs_dfs``/``com_end_uses``.
        allbldgs_dfs: Accepted for interface compatibility; not used.
        resbldgs_dfs: Two-element list of ``{state: DataFrame}`` (residential).
        combldgs_dfs: Two-element list of ``{state: DataFrame}`` (commercial).
        states: Regions to plot.
        energy_or_stock: "energy" or "stock"; affects axis labels and the
            output file name.

    Relies on the module-level names ``res_end_uses``, ``com_end_uses``,
    ``strout``, ``geo`` and ``dir_new``.
    """
    year = "2024"
    if energy_or_stock == "energy":
        strx = "% total MWh across regions [new]"
        stry = "% total MWh across regions [old]"
    else:
        strx = "% total tech-stock across regions [new]"
        stry = "% total tech-stock across regions [old]"

    if sector == "residential":
        dfs, end_uses = resbldgs_dfs, res_end_uses
    else:
        dfs, end_uses = combldgs_dfs, com_end_uses

    num_end_uses = len(end_uses)
    if num_end_uses == 0:
        return
    grid_cols = 3
    # Ceiling division so every end use gets its own subplot.
    grid_rows = (num_end_uses + grid_cols - 1) // grid_cols

    # squeeze=False keeps a 2-D array of axes even when there is a single row.
    fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(18, 6 * grid_rows), squeeze=False)
    axes = axes.flatten()

    for ax, end_use in zip(axes, end_uses):
        x_values, y_values, state_labels = [], [], []
        for state in states:
            if (state in dfs[0] and state in dfs[1]
                    and end_use in dfs[0][state].columns and end_use in dfs[1][state].columns):
                x = dfs[0][state].loc[year, end_use]
                y = dfs[1][state].loc[year, end_use]
                # Only keep points where both values are finite.
                if np.isfinite(x) and np.isfinite(y):
                    x_values.append(x)
                    y_values.append(y)
                    state_labels.append(state)
        x_values = np.array(x_values)
        y_values = np.array(y_values)

        ax.scatter(x_values, y_values, marker='o', color='blue', alpha=0.7, label="States")
        # min()/max() on an empty array raise ValueError, so the reference
        # line and shared limits need at least one point.
        if x_values.size:
            min_val = min(x_values.min(), y_values.min())
            max_val = max(x_values.max(), y_values.max())
            ax.plot([min_val, max_val], [min_val, max_val], linestyle="--", color="red", label="y = x")
            if max_val > min_val:  # matplotlib warns on identical low/high limits
                ax.set_xlim(min_val, max_val)
                ax.set_ylim(min_val, max_val)
        ax.set_aspect('equal', adjustable='box')

        for x, y, state in zip(x_values, y_values, state_labels):
            ax.text(x, y, state, fontsize=9, ha='right', va='bottom')

        ax.set_title(f"{sector} {end_use}\n{strout}")
        # The plotted values are shares for both metrics, so use the
        # share-based labels in both cases.
        ax.set_xlabel(strx)
        ax.set_ylabel(stry)
        ax.grid(True, linestyle="--", alpha=0.6)
        ax.legend()

    # Remove unused cells when the end-use count is not a multiple of grid_cols.
    for extra_ax in axes[num_end_uses:]:
        fig.delaxes(extra_ax)

    plt.tight_layout()

    grid_plot_filename = f"{geo}_{strout}_{energy_or_stock}_{sector}.png"
    fig.savefig(f"./{dir_new}/{grid_plot_filename}")
    plt.show()
    # Called repeatedly from a loop; release the figure.
    plt.close(fig)



############
import json
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
# Electricity end uses plotted by plot_scatter_all, per sector.
res_end_uses = ['heating', 'secondary heating', 'cooling', 'fans and pumps', 'ceiling fan', 'lighting', 'water heating',
                'refrigeration', 'cooking', 'drying', 'TVs', 'computers', 'other', 'onsite generation']
com_end_uses = ['heating', 'cooling', 'water heating', 'ventilation', 'cooking', 'lighting', 'refrigeration', 'PCs',
                'non-PC office equipment', 'MELs', 'onsite generation']
# Geographic resolutions, comparison labels (each selects a file pair in
# dataprep_compare_elec) and metrics to loop over.
geos = ["EMM","state"]
strouts = ["Elec_End-Use_EUSS2024_vs_End-Use_EUSS2023",
           "Elec_Tech_EUSS2024_vs_End-Use_EUSS2023",
           "Elec_Tech_EUSS2024_vs_End-Use_EUSS2024"]
energy_or_stocks = ["energy","stock"]
# Input directories; plot_scatter_all also writes its PNGs into dir_new.
dir_2023 = "test_compare_base_disagg/2023_work"
dir_new = "test_compare_base_disagg/jaredstest"

# NOTE: geo, strout and energy_or_stock are deliberately module-level loop
# variables: plot_scatter_all reads geo and strout (and dir_new) as globals
# for its titles and output file names, so do not rename them.
for geo in geos:
    for strout in strouts:
        for energy_or_stock in energy_or_stocks:

            allbldgs_dfs, resbldgs_dfs, combldgs_dfs, states = dataprep_compare_elec(geo, strout, energy_or_stock, dir_2023, dir_new)
            plot_scatter_all("commercial", allbldgs_dfs, resbldgs_dfs, combldgs_dfs, states, energy_or_stock)
            plot_scatter_all("residential", allbldgs_dfs, resbldgs_dfs, combldgs_dfs, states, energy_or_stock)


### 1. Data preparation for Scout output using Stock data (non-electricity)

In [None]:


def is_valid_stock(stock):
    """Decide whether a stock/energy entry contains meaningful data.

    * Strings: only the sentinel "NA" counts as missing; any other string
      is accepted.
    * Dicts (e.g. year -> value): valid when at least one value converts to
      a non-zero float; an empty dict is invalid.
    * Anything else: invalid.
    """
    if isinstance(stock, str):
        return not stock == "NA"
    if isinstance(stock, dict):
        for amount in stock.values():
            if float(amount) != 0.0:
                return True
        return False
    return False


def dataprep_compare_nonelec(geo, strout, energy_or_stock, dir_2023, dir_new, fueltype=None):
    """Aggregate one non-electric fuel's energy/stock by state and end use for two datasets.

    Loads ``./{dir_new}/mseg_res_com_{geo}_2024_eu_factors.json`` (index 0)
    and ``./{dir_2023}/mseg_res_com_{geo}_2023.json`` (index 1). The chosen
    metric is summed over building types and technology subcategories of the
    selected fuel for every state present in both files. Each state's total
    is then converted into its share of the all-state total for that end use
    and year.

    Args:
        geo: Geography tag inserted into the file names.
        strout: Comparison label; accepted for call compatibility but not
            used here.
        energy_or_stock: Key read from each leaf, "energy" or "stock".
        dir_2023: Directory holding the 2023 file.
        dir_new: Directory holding the 2024 file.
        fueltype: Fuel branch to aggregate (e.g. "natural gas"). Defaults to
            the module-level ``fueltype`` variable set by the driver loop.

    Returns:
        ``(allbldgs_dfs, resbldgs_dfs, combldgs_dfs, states)``. Each
        ``*_dfs`` is a two-element list (index 0: 2024 file, index 1: 2023
        file) of ``{state: DataFrame}``. Each DataFrame has years 2016-2050
        as its index and end uses as its columns, holding shares. ``states``
        is the set of regions present in both files.

    Raises:
        ValueError: if no fuel type is passed and none is defined globally.
    """
    if fueltype is None:
        # Backward compatibility: the notebook loop sets a global ``fueltype``.
        fueltype = globals().get("fueltype")
    if fueltype is None:
        raise ValueError("fueltype must be passed or defined at module level")

    file_paths = [f"./{dir_new}/mseg_res_com_{geo}_2024_eu_factors.json",
                  f"./{dir_2023}/mseg_res_com_{geo}_2023.json"]
    datasets = []
    for path in file_paths:
        with open(path, "r", encoding="utf-8") as file:
            datasets.append(json.load(file))

    res_bldgs = {'single family home', 'multi family home', 'mobile home'}
    # For these end uses only the 'supply' technologies are summed.
    end_uses_wsupply = {"heating", "cooling", "secondary heating"}
    # Keys that carry data directly on an end-use node (rather than naming a subcategory).
    leaf_keys = {"energy", "stock"}
    years = [str(year) for year in range(2016, 2051)]
    year_set = set(years)

    states = set(datasets[0].keys()).intersection(datasets[1].keys())

    def fuel_branches(dataset, state):
        """Yield (building_type, end-use dict) for the selected fuel of each building type."""
        for building_type, fuels in dataset[state].items():
            if isinstance(fuels, dict) and isinstance(fuels.get(fueltype), dict):
                yield building_type, fuels[fueltype]

    # Collect end uses from BOTH files so an end use present only in the
    # second file does not cause a KeyError during accumulation.
    end_uses_all = sorted({
        end_use
        for dataset in datasets
        for state in states
        for _, branch in fuel_branches(dataset, state)
        for end_use in branch
    })

    def empty_totals():
        """Return {state: {end_use: {year: 0}}} initialised to zero."""
        return {state: {end_use: dict.fromkeys(years, 0) for end_use in end_uses_all}
                for state in states}

    allbldgs_data = [empty_totals() for _ in datasets]
    resbldgs_data = [empty_totals() for _ in datasets]
    combldgs_data = [empty_totals() for _ in datasets]

    def add_series(index, state, building_type, end_use, series):
        """Add one year->value mapping into the all-buildings and sector totals."""
        # is_valid_stock accepts any non-"NA" string, but only dicts can be summed.
        if not isinstance(series, dict) or not is_valid_stock(series):
            return
        sector_data = resbldgs_data if building_type in res_bldgs else combldgs_data
        targets = (allbldgs_data[index][state][end_use], sector_data[index][state][end_use])
        for year, value in series.items():
            if year in year_set:
                for target in targets:
                    target[year] += value

    for index, dataset in enumerate(datasets):
        for state in states:
            for building_type, branch in fuel_branches(dataset, state):
                for end_use, node in branch.items():
                    if end_use in end_uses_wsupply:
                        for tech in node['supply'].values():
                            add_series(index, state, building_type, end_use, tech[energy_or_stock])
                        continue
                    # Data stored directly on the end use is added exactly once,
                    # even when both 'energy' and 'stock' keys are present.
                    if leaf_keys & node.keys():
                        add_series(index, state, building_type, end_use, node.get(energy_or_stock))
                    for sub_name, sub in node.items():
                        if sub_name not in leaf_keys:
                            add_series(index, state, building_type, end_use, sub[energy_or_stock])

    def compute_relative_shares(data_dict):
        """Convert absolute values to state-level shares by end use and year.

        Each state's value is divided by the sum over all states. When that
        sum is not positive, every state's share is 0.
        """
        relative_data = {state: {end_use: {} for end_use in end_uses_all} for state in states}
        for end_use in end_uses_all:
            for year in years:
                total = sum(data_dict[state][end_use][year] for state in states)
                for state in states:
                    relative_data[state][end_use][year] = (
                        data_dict[state][end_use][year] / total if total > 0 else 0
                    )
        return relative_data

    def to_frames(share_data):
        """Build [{state: DataFrame(index=years, columns=end uses)}] per dataset."""
        return [
            {state: pd.DataFrame.from_dict(per_dataset[state], orient="index", columns=years).T
             for state in states}
            for per_dataset in share_data
        ]

    # Convert absolute values to shares (% of total within each end use across states)
    allbldgs_data = [compute_relative_shares(d) for d in allbldgs_data]
    resbldgs_data = [compute_relative_shares(d) for d in resbldgs_data]
    combldgs_data = [compute_relative_shares(d) for d in combldgs_data]

    print("COMBINE ALL SECTORS")
    allbldgs_dfs = to_frames(allbldgs_data)
    print("RESIDENTIAL SECTOR")
    resbldgs_dfs = to_frames(resbldgs_data)
    print("COMMERCIAL SECTOR")
    combldgs_dfs = to_frames(combldgs_data)

    return allbldgs_dfs, resbldgs_dfs, combldgs_dfs, states


### 2. Comparison of Scout output using different Stock versions (non-electricity fuels)


In [None]:
### 2. Comparison of Scout output using different Stock versions (non-electricity fuels)


def plot_scatter_all(sector, allbldgs_dfs, resbldgs_dfs, combldgs_dfs, states, energy_or_stock):
    """Draw a grid of new-vs-old scatter plots for each end use of the current fuel.

    Non-electricity variant: the end uses drawn depend on the module-level
    ``fueltype`` ("natural gas", "distillate" or "other fuel"). For
    ``year`` 2024, each state contributes a point wherever both values are
    finite: x comes from ``dfs[0]`` (new dataset) and y from ``dfs[1]``
    (old dataset). A dashed y = x line is drawn. The figure is saved to
    ``./{dir_new}/{geo}_{strout}_{energy_or_stock}_{sector}.png`` and
    closed; it is not shown.

    Args:
        sector: "residential" selects ``resbldgs_dfs``; any other value
            selects ``combldgs_dfs``.
        allbldgs_dfs: Accepted for interface compatibility; not used.
        resbldgs_dfs: Two-element list of ``{state: DataFrame}`` (residential).
        combldgs_dfs: Two-element list of ``{state: DataFrame}`` (commercial).
        states: Regions to plot.
        energy_or_stock: "energy" or "stock"; affects axis labels and the
            output file name.

    Relies on the module-level names ``fueltype``, ``strout``, ``geo`` and
    ``dir_new``.

    Raises:
        ValueError: if ``fueltype`` has no end-use list defined here.
    """
    year = "2024"
    if energy_or_stock == "energy":
        strx = "% total MWh across regions [new]"
        stry = "% total MWh across regions [old]"
    else:
        strx = "% total tech-stock across regions [new]"
        stry = "% total tech-stock across regions [old]"

    # fuel -> (residential end uses, commercial end uses)
    end_uses_by_fuel = {
        "natural gas": (['heating', 'cooling', 'water heating', 'cooking', 'drying', 'other'],
                        ['heating', 'cooling', 'water heating', 'cooking', 'other']),
        "distillate": (['heating', 'secondary heating', 'water heating'],
                       ['heating', 'water heating', 'other']),
        "other fuel": (['heating', 'water heating', 'secondary heating', 'cooking', 'other'],
                       []),
    }
    if fueltype not in end_uses_by_fuel:
        raise ValueError(
            f"No end-use list for fueltype {fueltype!r}; expected one of {sorted(end_uses_by_fuel)}"
        )
    fuel_res_end_uses, fuel_com_end_uses = end_uses_by_fuel[fueltype]

    if sector == "residential":
        dfs, end_uses = resbldgs_dfs, fuel_res_end_uses
    else:
        dfs, end_uses = combldgs_dfs, fuel_com_end_uses

    num_end_uses = len(end_uses)
    if num_end_uses == 0:
        return
    grid_cols = 3
    # Ceiling division so every end use gets its own subplot.
    grid_rows = (num_end_uses + grid_cols - 1) // grid_cols

    # squeeze=False keeps a 2-D array of axes even when there is a single row.
    fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(18, 6 * grid_rows), squeeze=False)
    axes = axes.flatten()

    for ax, end_use in zip(axes, end_uses):
        x_values, y_values, state_labels = [], [], []
        for state in states:
            if (state in dfs[0] and state in dfs[1]
                    and end_use in dfs[0][state].columns and end_use in dfs[1][state].columns):
                x = dfs[0][state].loc[year, end_use]
                y = dfs[1][state].loc[year, end_use]
                # Only keep points where both values are finite.
                if np.isfinite(x) and np.isfinite(y):
                    x_values.append(x)
                    y_values.append(y)
                    state_labels.append(state)
        x_values = np.array(x_values)
        y_values = np.array(y_values)

        ax.scatter(x_values, y_values, marker='o', color='blue', alpha=0.7, label="States")
        # min()/max() on an empty array raise ValueError; draw the y = x line
        # only when there is data.
        if x_values.size:
            min_val = min(x_values.min(), y_values.min())
            max_val = max(x_values.max(), y_values.max())
            ax.plot([min_val, max_val], [min_val, max_val], linestyle="--", color="red", label="y = x")
        ax.set_aspect('equal', adjustable='box')

        for x, y, state in zip(x_values, y_values, state_labels):
            ax.text(x, y, state, fontsize=9, ha='right', va='bottom')

        ax.set_title(f"{sector} {end_use}\n{strout}")
        # The plotted values are shares for both metrics, so use the
        # share-based labels in both cases.
        ax.set_xlabel(strx)
        ax.set_ylabel(stry)
        ax.grid(True, linestyle="--", alpha=0.6)
        ax.legend()

    # Remove unused cells when the end-use count is not a multiple of grid_cols.
    for extra_ax in axes[num_end_uses:]:
        fig.delaxes(extra_ax)

    plt.tight_layout()

    grid_plot_filename = f"{geo}_{strout}_{energy_or_stock}_{sector}.png"
    fig.savefig(f"./{dir_new}/{grid_plot_filename}")
    # The figure is not shown, so close it explicitly; otherwise every call
    # in the driver loop leaves another open figure behind.
    plt.close(fig)


import numpy as np
import matplotlib.pyplot as plt
import json
import matplotlib.pyplot as plt
import pandas as pd


# Geographic resolutions, non-electric fuels and metrics to loop over.
geos = ["EMM","state"]
fueltypes = ["natural gas", "distillate","other fuel"]
energy_or_stocks = ["energy","stock"]
# Input directories; plot_scatter_all also writes its PNGs into dir_new.
dir_2023 = "test_compare_base_disagg/2023_work"
dir_new = "test_compare_base_disagg/jaredstest"


# NOTE: geo, fueltype, energy_or_stock and strout are deliberately
# module-level names. dataprep_compare_nonelec reads the global fueltype to
# pick the fuel branch, and plot_scatter_all reads fueltype, geo and strout
# (and dir_new) as globals, so do not rename them.
for geo in geos:
    for fueltype in fueltypes:
        for energy_or_stock in energy_or_stocks:
            # Label used in plot titles and output file names.
            strout = f"Non-Elec_{fueltype}_End-Use_EUSS2024_vs_EIA"
            allbldgs_dfs, resbldgs_dfs, combldgs_dfs, states = \
                dataprep_compare_nonelec(geo, strout, energy_or_stock, dir_2023, dir_new)
            
            # plot_scatter_all defines no commercial end uses for "other fuel".
            if fueltype != "other fuel": 
                plot_scatter_all("commercial", allbldgs_dfs, resbldgs_dfs, combldgs_dfs, states, energy_or_stock)
            plot_scatter_all("residential", allbldgs_dfs, resbldgs_dfs, combldgs_dfs, states, energy_or_stock)

### Differences between datasets (bar plots)

In [None]:
def plot_comparison(dfs, sector, end_uses=None):
    """Save, per state, stacked line plots comparing both datasets over the years.

    For every state in the module-level ``states``, one subplot per end use
    shows the column from ``dfs[0]`` ("Using New Factors") and ``dfs[1]``
    ("Using Old Factors") against the year index. Each figure is written to
    ``./{file_dir}/state_bars/{energy_or_stock}_{sector}_{state}_energy_plot.png``
    and then closed.

    Args:
        dfs: Two-element list of ``{state: DataFrame}`` as returned by the
            data-prep functions.
        sector: Tag inserted into the output file name.
        end_uses: End uses (columns) to plot. Defaults to every column found
            for the state in either dataset, in first-seen order.

    Relies on the module-level names ``states``, ``energy_or_stock`` and
    ``file_dir``.
    """
    import os
    import numpy as np

    out_dir = f"./{file_dir}/state_bars"
    # savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    # The data-prep functions store shares of the all-region total, not MWh.
    if energy_or_stock == "energy":
        ylabel = "Share of total energy across regions"
    else:
        ylabel = "Share of total stock across regions"
    dataset_labels = ["Using New Factors", "Using Old Factors"]

    for state in states:
        if end_uses is None:
            # end_uses_all exists only inside the data-prep functions, so
            # derive the list from the frames themselves.
            columns = [column for frames in dfs if state in frames for column in frames[state].columns]
            state_end_uses = list(dict.fromkeys(columns))
        else:
            state_end_uses = list(end_uses)
        if not state_end_uses:
            continue

        num_end_uses = len(state_end_uses)
        fig, axes = plt.subplots(nrows=num_end_uses, ncols=1, figsize=(10, 5 * num_end_uses), sharex=True)
        # A single subplot comes back as a bare Axes; make it iterable.
        axes = np.atleast_1d(axes)

        for ax, end_use in zip(axes, state_end_uses):
            for dataset_index, dataset_label in enumerate(dataset_labels):
                frames = dfs[dataset_index]
                if state in frames and end_use in frames[state].columns:
                    frames[state][end_use].plot(ax=ax, linestyle='-', label=dataset_label)

            ax.set_title(f"{state} - {end_use}")
            ax.set_xlabel("Year")
            ax.set_ylabel(ylabel)
            ax.grid(True)
            ax.legend()

        plt.tight_layout()
        fig.savefig(f"{out_dir}/{energy_or_stock}_{sector}_{state}_energy_plot.png")
        plt.close(fig)



# Per-state time-series comparison for all buildings combined; the
# sector-specific variants below are currently disabled.
plot_comparison(dfs=allbldgs_dfs, sector='all')
# plot_comparison(resbldgs_dfs, 'res')
# plot_comparison(combldgs_dfs, 'com')

print("COMPLETE!")



In [None]:
# Bar chart of how each state's share changes between the two datasets
# (dataset 1 minus dataset 0, i.e. old minus new) for one year and sector.
# Uses dfs/states/energy_or_stock left by the most recent data-prep run.

res_end_uses = ['heating', 'secondary heating', 'cooling', 'fans and pumps', 'ceiling fan', 'lighting', 'water heating',
                'refrigeration', 'cooking', 'drying', 'TVs', 'computers', 'other', 'onsite generation']
com_end_uses = ['heating', 'cooling', 'water heating', 'ventilation', 'cooking', 'lighting', 'refrigeration', 'PCs',
                'non-PC office equipment', 'MELs', 'onsite generation']

year = "2024"
sector = "residential"

if sector == "residential":
    dfs = resbldgs_dfs
    end_uses = res_end_uses
else:
    dfs = combldgs_dfs
    end_uses = com_end_uses

# differences[end_use][state] = share in dataset 1 (old) - share in dataset 0 (new)
differences = {end_use: {} for end_use in end_uses}
for end_use in end_uses:
    for state in states:
        if (state in dfs[0] and state in dfs[1]
                and end_use in dfs[0][state].columns and end_use in dfs[1][state].columns):
            differences[end_use][state] = dfs[1][state].loc[year, end_use] - dfs[0][state].loc[year, end_use]

# Rows are states, columns are end uses.
diff_df = pd.DataFrame(differences)

fig, ax = plt.subplots(figsize=(12, 6))
diff_df.plot(kind="bar", ax=ax, width=0.7)

# The data-prep functions convert totals to shares of the all-region total,
# so the plotted quantity is a difference in shares, not in MWh.
metric = "Energy Use" if energy_or_stock == "energy" else "Stock"
ax.set_title(f"Difference in {metric} Share using New and Old Factors ({sector}, Year {year})")
ax.set_xlabel("State")
ax.set_ylabel(f"{metric} share difference (old - new)")
ax.legend(title="End Use", bbox_to_anchor=(1.05, 1), loc="upper left")
ax.grid(axis="y", linestyle="--")

diff_plot_filename = f"{energy_or_stock}_{sector}_energy_difference_{year}.png"
plt.tight_layout()
fig.savefig(f"./{file_dir}/{diff_plot_filename}")
plt.show()

print("COMPLETE!")

### Scatter plots by end use

In [None]:
import matplotlib.pyplot as plt
import os

res_end_uses = ['heating', 'secondary heating', 'cooling', 'fans and pumps', 'ceiling fan', 'lighting', 'water heating',
                'refrigeration', 'cooking', 'drying', 'TVs', 'computers', 'other', 'onsite generation']
com_end_uses = ['heating', 'cooling', 'water heating', 'ventilation', 'cooking', 'lighting', 'refrigeration', 'PCs',
                'non-PC office equipment', 'MELs', 'onsite generation']

year = "2024"
sector = "commercial"

if sector == "residential":
    dfs = resbldgs_dfs
    end_uses = res_end_uses
else:
    dfs = combldgs_dfs
    end_uses = com_end_uses

scatter_dir = f"./{file_dir}/scatter_enduse"
# savefig does not create missing directories.
os.makedirs(scatter_dir, exist_ok=True)

# One scatter plot per end use: x = dataset 0 (new), y = dataset 1 (old).
for end_use in end_uses:
    x_values = []
    y_values = []
    state_labels = []

    for state in states:
        if (state in dfs[0] and state in dfs[1]
                and end_use in dfs[0][state].columns and end_use in dfs[1][state].columns):
            x_values.append(dfs[0][state].loc[year, end_use])
            y_values.append(dfs[1][state].loc[year, end_use])
            state_labels.append(state)

    # min()/max() below raise ValueError on an empty sequence.
    if not x_values:
        print(f"No data for {sector} {end_use}; skipping.")
        continue

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(x_values, y_values, marker='o', color='blue', alpha=0.7)

    min_val = min(min(x_values), min(y_values))
    max_val = max(max(x_values), max(y_values))
    ax.plot([min_val, max_val], [min_val, max_val], linestyle="--", color="red", label="y = x")

    for x, y, state in zip(x_values, y_values, state_labels):
        ax.text(x, y, state, fontsize=9, ha='right', va='bottom')

    ax.set_title(f"Scatter Plot for {sector} {end_use} (Year {year})")
    # Values are shares of the all-region total produced by the data-prep step.
    ax.set_xlabel(f"Using new factors: share of total {energy_or_stock} across regions")
    ax.set_ylabel(f"Using old factors: share of total {energy_or_stock} across regions")
    ax.grid(True, linestyle="--", alpha=0.6)
    ax.legend()

    scatter_plot_filename = f"{energy_or_stock}_{sector}_scatter_{end_use}_{year}.png"
    fig.savefig(f"{scatter_dir}/{scatter_plot_filename}")
    plt.show()
    plt.close(fig)
print("COMPLETE!")