In [None]:
import os

# Work from the parent of the directory the kernel started in (presumably the
# project root, so that relative paths such as "data/processed/" resolve —
# confirm against the repository layout).
#
# A bare os.chdir("..") is not idempotent: every re-run of this cell would
# climb one more level. Remember the starting directory the first time the
# cell runs and always change directory relative to it, so re-running the
# cell is a no-op.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_START_DIR, os.pardir))

In [None]:

import arviz as az
import matplotlib.pyplot as plt
from dotenv import load_dotenv
import os

# python-dotenv: load settings from a .env file (if one is found) into the
# process environment. Nothing in this cell reads those variables directly.
load_dotenv()

# Select ArviZ's "arviz-darkgrid" matplotlib style for subsequent plots.
# NOTE(review): the plotting cell further down calls az.style.use("arviz-doc"),
# which is applied after this one — confirm which style is actually intended.
az.style.use("arviz-darkgrid")


In [None]:
import os
import arviz as az
import matplotlib.pyplot as plt

az.style.use("arviz-doc")

# Directory scanned for saved InferenceData (*.nc) files, relative to the
# current working directory.
PROCESSED_DIR = "data/processed/"

# Group InferenceData objects by NAICS code, then by FIPS:
#   naics_groups[naics][fips] -> InferenceData
naics_groups = {}

# os.listdir() returns entries in arbitrary order; sorting makes the grouping
# (and therefore the subplot order derived from it) identical on every run.
for file in sorted(os.listdir(PROCESSED_DIR)):
    if not file.endswith(".nc"):
        continue

    # File names are split on "_" after dropping the ".nc" suffix; the second
    # and third pieces are taken as the FIPS and NAICS codes.
    parts = file[:-3].split("_")
    if len(parts) < 3:
        # A stray .nc file that does not follow the naming scheme would
        # otherwise abort the whole cell with an IndexError; report and skip.
        print(f"Skipping {file}: expected '<prefix>_<FIPS>_<NAICS>.nc'")
        continue

    fips = parts[1]    # FIPS code
    naics = parts[2]   # NAICS code
    # NOTE(review): "Kaits" may be a misspelling; the plotting loop selects the
    # variable by this exact string, so change both places together if at all.
    varname = f"Kaits for Naics {naics}"

    filepath = os.path.join(PROCESSED_DIR, file)
    print(file)
    idata = az.from_netcdf(filepath)

    # Rename the variable to a per-NAICS name, which the plotting code passes
    # as var_names.
    idata = idata.rename({"log_k_index": varname})

    naics_groups.setdefault(naics, {})[fips] = idata

# Keep only NAICS codes with more than one FIPS group
valid_naics = [n for n in naics_groups if len(naics_groups[n]) > 1]

n = len(valid_naics)
n_cols = 2
if n == 0:
    print("No NAICS code has more than one FIPS group; nothing to plot.")

# Ceiling division that stays correct if n_cols is changed (the previous
# (n + 1) // n_cols is only a ceiling for n_cols == 2). Always keep at least
# one row: plt.subplots() raises ValueError when asked for zero rows.
n_rows = max(1, -(-n // n_cols))

# squeeze=False makes subplots() return a 2-D array of Axes for every grid
# shape, so .flatten() is always available.
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(12, 3 * n_rows), squeeze=False
)
axes = axes.flatten()

# (handles, labels) for the figure-level legend, or None if nothing was plotted.
legend_handles = None

# Give every FIPS code one fixed colour across all subplots. With ArviZ's
# default colours="cycle" the colour depends on a model's position within each
# call, so NAICS codes covering different FIPS sets would reuse the same colour
# for different FIPS codes and a single shared legend would be misleading.
# Colours repeat if there are more FIPS codes than entries in the style's cycle.
cycle_colors = plt.rcParams["axes.prop_cycle"].by_key().get("color", ["k"])
all_fips = sorted(
    {fips for naics in valid_naics for fips in naics_groups[naics]}
)
fips_colors = {
    fips: cycle_colors[i % len(cycle_colors)]
    for i, fips in enumerate(all_fips)
}

# Legend label -> handle, collected from every subplot (first occurrence wins;
# all occurrences share a colour thanks to fips_colors).
legend_entries = {}

for idx, naics in enumerate(valid_naics):
    varname = f"Kaits for Naics {naics}"
    fips_groups = naics_groups[naics]
    fips_sorted = sorted(fips_groups)

    idatas = [fips_groups[fips] for fips in fips_sorted]
    labels = [f"FIPS {fips}" for fips in fips_sorted]

    ax = axes[idx]

    az.plot_density(
        idatas,
        data_labels=labels,
        var_names=[varname],
        colors=[fips_colors[fips] for fips in fips_sorted],
        shade=0.2,
        ax=ax,
        show=False
    )

    ax.set_title(f"NAICS {naics}")

    # Remove the per-subplot legend through the public accessor, tolerating
    # the case where no legend was drawn on this Axes.
    subplot_legend = ax.get_legend()
    if subplot_legend is not None:
        subplot_legend.remove()

    # Labelled artists remain on the Axes after the legend is removed.
    for handle, label in zip(*ax.get_legend_handles_labels()):
        legend_entries.setdefault(label, handle)

if legend_entries:
    legend_labels = sorted(legend_entries)
    legend_handles = (
        [legend_entries[label] for label in legend_labels],
        legend_labels,
    )

# Blank out the grid cells that did not receive a plot
first_unused = len(valid_naics)
for ax in axes[first_unused:]:
    ax.set_axis_off()

# One figure-level legend shared by all subplots
if legend_handles:
    shared_handles, shared_labels = legend_handles
    legend_options = dict(loc="upper right", ncol=2, fontsize=12, frameon=True)
    fig.legend(shared_handles, shared_labels, **legend_options)

fig.suptitle("Posterior Distributions by FIPS Code per NAICS", fontsize=16)

# rect is (left, bottom, right, top) in figure fractions: the subplots are laid
# out inside that box, leaving a 5% margin at the top and at the bottom.
layout_box = [0, 0.05, 1, 0.95]
plt.tight_layout(rect=layout_box)
plt.show()
