In [None]:
import os

# Move from the notebook's folder up to the project root so that the `src`
# package and relative paths such as "data/processed/" resolve.
# The guard keeps this cell idempotent: an unconditional os.chdir("..") climbs
# one more directory every time the cell is re-run, which silently breaks all
# relative paths used further down.
# NOTE(review): assumes the project root is the directory containing `src/` — confirm.
if not os.path.isdir("src"):
    os.chdir("..")

In [None]:
from src.data.data_process import DataReg
import arviz as az
import bambi as bmb
import matplotlib.pyplot as plt
from dotenv import load_dotenv
import os

# Load variables from a local .env file (if one exists) into the process environment.
load_dotenv()

# Apply the ArviZ "arviz-darkgrid" plotting style.
# NOTE(review): a later cell also calls az.style.use("arviz-doc") — confirm which
# style is actually intended for the figures.
az.style.use("arviz-darkgrid")


In [None]:
# NOTE(review): os, arviz and matplotlib are already imported in the cells above;
# they are repeated here, presumably so this cell can be run on its own.
import os
import arviz as az
import matplotlib.pyplot as plt

# An earlier cell already applied "arviz-darkgrid"; this call applies "arviz-doc"
# afterwards — confirm which style is intended.
az.style.use("arviz-doc")

# Group InferenceData objects by NAICS code.
# Result: naics_groups[naics][group] -> InferenceData, where `group` and `naics`
# are the 2nd and 3rd "_"-separated parts of the file name (extension removed),
# i.e. files are expected to look like "<prefix>_<group>_<naics>[...].nc".
naics_groups = {}

# os.listdir returns entries in arbitrary order; sorting makes the grouping
# (and therefore the panel order of the figure built from it) reproducible.
for file in sorted(os.listdir("data/processed/")):
    if not file.endswith(".nc"):
        continue

    parts = file[:-3].split("_")
    if len(parts) < 3:
        # Not a "<prefix>_<group>_<naics>" file: skip it before loading rather
        # than failing with an IndexError on parts[1] / parts[2].
        print(f"Skipping {file}: unexpected file name")
        continue

    filepath = os.path.join("data/processed/", file)
    print(file)
    idata = az.from_netcdf(filepath)

    group = parts[1]   # local or foreign
    naics = parts[2]
    varname = f"Kaits for Naics {naics}"

    # Give the variable a NAICS-specific name so it can be selected by that
    # name when the local and foreign posteriors are plotted together.
    idata = idata.rename({"log_k_index": varname})

    naics_groups.setdefault(naics, {})[group] = idata

# Keep only NAICS codes with both local and foreign data
valid_naics = [
    code
    for code, by_group in naics_groups.items()
    if "local" in by_group and "foreign" in by_group
]

n = len(valid_naics)
if n == 0:
    # plt.subplots(0, ...) would fail with an opaque matplotlib error; say
    # what is actually wrong instead.
    raise ValueError("No NAICS code has both local and foreign data; nothing to plot.")

n_cols = 2
# Ceiling division: the smallest number of rows that fits n panels for any n_cols.
n_rows = (n + n_cols - 1) // n_cols

# squeeze=False guarantees a 2-D array of Axes regardless of the grid shape,
# so flattening it into a simple list of panels is always valid.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 3 * n_rows), squeeze=False)
axes = axes.flatten()

# (handles, labels) for the single figure-level legend; filled from the first
# panel that exposes any labelled artists, None until then.
legend_handles = None

# One panel per NAICS code: local vs. foreign posterior of the renamed variable.
for idx, naics in enumerate(valid_naics):
    varname = f"Kaits for Naics {naics}"
    local_idata = naics_groups[naics]["local"]
    foreign_idata = naics_groups[naics]["foreign"]

    ax = axes[idx]

    # Draw both posteriors on this panel. plot_density may attach its own
    # legend to the panel; it is dropped below in favour of one global legend.
    az.plot_density(
        [local_idata, foreign_idata],
        data_labels=["Local", "Foreign"],
        var_names=[varname],
        shade=0.2,
        ax=ax,
        show=False
    )

    ax.set_title(f"NAICS {naics}")

    # Capture handles for the global legend. Only keep them when the panel
    # actually has labelled artists, so an empty pair is never stored.
    if legend_handles is None:
        handles, labels = ax.get_legend_handles_labels()
        if handles:
            legend_handles = (handles, labels)

    # Remove the per-panel legend. The None check avoids an AttributeError on
    # panels where no legend was attached.
    panel_legend = ax.get_legend()
    if panel_legend is not None:
        panel_legend.remove()

# Switch off the axes of any grid cells left over after the last NAICS panel
for ax in axes[len(valid_naics):]:
    ax.axis("off")

# One shared legend for the whole figure, drawn only if handles were captured
if legend_handles:
    handles, labels = legend_handles
    fig.legend(handles, labels, loc="upper right", ncol=2, fontsize=12, frameon=True)

fig.suptitle("Posterior Distributions: Local vs Foreign by NAICS", fontsize=16)

# Lay the panels out inside a reduced rectangle (left, bottom, right, top),
# leaving a margin at the top and bottom of the figure
layout_rect = [0, 0.05, 1, 0.95]
plt.tight_layout(rect=layout_rect)
plt.show()
