In [None]:
import os

# Presumably the notebook lives one level below the project root; move up so
# `src.*` imports and relative "data/..." paths resolve. Guarded so that
# re-running this cell does not keep climbing the directory tree (a bare
# os.chdir("..") moves up one more level on every execution).
if not (os.path.isdir("src") and os.path.isdir("data")):
    os.chdir("..")

In [None]:
from src.data.data_process import DataReg
import arviz as az
import bambi as bmb
import matplotlib.pyplot as plt
from dotenv import load_dotenv
import os

# Populate os.environ from the project's .env file (python-dotenv).
load_dotenv()

# Default ArviZ/matplotlib theme for the figures produced below.
PLOT_STYLE = "arviz-darkgrid"
az.style.use(PLOT_STYLE)


In [None]:
# Posterior of log_k_index for every result file, one panel per file:
# local results in the left column, foreign results in the right column.
# Files are expected to be named "<prefix>_<group>_<naics>.nc"
# (e.g. "results_local_31-33.nc"); <group> picks the column.

# Separate lists of (InferenceData, variable name) for each group
local_results = []
foreign_results = []

# First pass: load and rename data. Only *.nc entries are read, so .gitkeep or a
# sub-directory is skipped. They are read in sorted order so the panel layout is
# identical on every run (os.listdir returns entries in arbitrary order).
for file in sorted(os.listdir("data/processed/")):
    if not file.endswith(".nc"):
        continue
    results = az.from_netcdf(f"data/processed/{file}")
    parts = file[:-3].split("_")  # strip ".nc", then split on "_"

    name = f"Kaits for Naics {parts[2]}"
    # Rename log_k_index to a readable label. This returns a new object instead
    # of mutating in place, matching the other loading cells.
    results = results.rename({"log_k_index": name})

    # Anything not tagged "local" is plotted in the foreign column.
    if parts[1] == "local":
        local_results.append((results, name))
    else:
        foreign_results.append((results, name))

# Figure size based on number of plots
n_local = len(local_results)
n_foreign = len(foreign_results)
n_rows = max(n_local, n_foreign)

if n_rows == 0:
    # plt.subplots(0, 2) would raise an unhelpful error; report what is missing.
    print("No .nc result files found in data/processed/ - nothing to plot.")
else:
    # squeeze=False keeps `axes` 2-D so axes[i, j] works even for a single row.
    fig, axes = plt.subplots(n_rows, 2, figsize=(12, 3 * n_rows), squeeze=False)

    # Plot local (column 0)
    for i, (idata, varname) in enumerate(local_results):
        az.plot_posterior(idata, var_names=[varname], ax=axes[i, 0])
        axes[i, 0].set_title(f"Local: {varname}")

    # Plot foreign (column 1)
    for i, (idata, varname) in enumerate(foreign_results):
        az.plot_posterior(idata, var_names=[varname], ax=axes[i, 1])
        axes[i, 1].set_title(f"Foreign: {varname}")

    # Hide unused axes: the shorter column leaves empty cells at the bottom.
    for ax in axes[n_local:, 0]:
        ax.axis("off")
    for ax in axes[n_foreign:, 1]:
        ax.axis("off")

    # Overall formatting
    fig.suptitle("Posterior Distributions of log_k_index by Group", fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.97])
    plt.show()

In [None]:
# One density figure per NAICS code, overlaying the local and foreign posteriors
# of the renamed log_k_index. Files are expected to be named
# "<prefix>_<group>_<naics>.nc", e.g. "results_local_31-33.nc".
naics_groups = {}

# Load and group data as {naics: {group: idata}}. Files are sorted so the figures
# come out in the same order on every run (os.listdir order is arbitrary).
for file in sorted(os.listdir("data/processed")):
    if not file.endswith(".nc"):
        continue

    filepath = os.path.join("data/processed", file)
    idata = az.from_netcdf(filepath)
    parts = file[:-3].split("_")  # remove .nc, then split

    group = parts[1]   # local or foreign
    naics = parts[2]
    varname = f"Kaits for Naics {naics}"

    idata = idata.rename({'log_k_index': varname})
    naics_groups.setdefault(naics, {})[group] = idata

# Plot each NAICS code: local vs foreign. Codes missing either group are skipped.
for naics in sorted(naics_groups):
    group_data = naics_groups[naics]
    if "local" not in group_data or "foreign" not in group_data:
        continue

    varname = f"Kaits for Naics {naics}"
    density_axes = az.plot_density(
        [group_data["local"], group_data["foreign"]],
        data_labels=["Local", "Foreign"],
        var_names=[varname],
        shade=0.2
    )
    fig = density_axes.flatten()[0].get_figure()
    fig.suptitle(f"NAICS {naics}: Local vs Foreign Posterior for {varname}", fontsize=14)

plt.show()


In [None]:
import os
import arviz as az
import matplotlib.pyplot as plt

# NOTE: az.style.use changes global matplotlib state, so this "arviz-doc" style
# also applies to every cell executed after this one.
az.style.use("arviz-doc")

# Group InferenceData objects by NAICS code as {naics: {group: idata}}.
# Files are expected to be named "<prefix>_<group>_<naics>.nc".
naics_groups = {}

# Sorted so the grid layout is the same on every run (os.listdir order is arbitrary).
for file in sorted(os.listdir("data/processed/")):
    if not file.endswith(".nc"):
        continue

    filepath = os.path.join("data/processed/", file)
    idata = az.from_netcdf(filepath)
    parts = file[:-3].split("_")

    group = parts[1]   # local or foreign
    naics = parts[2]
    varname = f"Kaits for Naics {naics}"

    idata = idata.rename({"log_k_index": varname})
    naics_groups.setdefault(naics, {})[group] = idata

# Keep only NAICS codes with both local and foreign data
valid_naics = [
    naics for naics in sorted(naics_groups)
    if "local" in naics_groups[naics] and "foreign" in naics_groups[naics]
]

n_valid = len(valid_naics)
n_cols = 2
n_rows = -(-n_valid // n_cols)  # ceiling division: enough rows for every panel

if n_valid == 0:
    # plt.subplots(0, n_cols) would raise; report why there is no figure instead.
    print("No NAICS code has both local and foreign results - nothing to plot.")
else:
    # squeeze=False keeps `axes` 2-D for any grid shape before flattening.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 3 * n_rows), squeeze=False)
    axes = axes.flatten()

    legend_handles = None

    for ax, naics in zip(axes, valid_naics):
        varname = f"Kaits for Naics {naics}"

        az.plot_density(
            [naics_groups[naics]["local"], naics_groups[naics]["foreign"]],
            data_labels=["Local", "Foreign"],
            var_names=[varname],
            colors=["blue", "green"],
            shade=0.2,
            ax=ax,
            show=False
        )

        ax.set_title(f"NAICS {naics}")
        # Drop the per-panel legend (if one was drawn); a single figure legend is added below.
        panel_legend = ax.get_legend()
        if panel_legend is not None:
            panel_legend.remove()

        # Capture handles for the global legend from the first panel that has any.
        if legend_handles is None:
            handles, labels = ax.get_legend_handles_labels()
            if handles:
                legend_handles = (handles, labels)

    # Hide unused subplots
    for ax in axes[n_valid:]:
        ax.axis("off")

    # Add one global legend in the bottom strip that tight_layout's rect reserves
    # (placing it "upper right" can overlap the suptitle / top-right panel).
    if legend_handles is not None:
        fig.legend(
            *legend_handles,
            loc="lower center",
            ncol=2,
            fontsize=12,
            frameon=True
        )

    fig.suptitle("Posterior Distributions: Local vs Foreign by NAICS", fontsize=16)
    plt.tight_layout(rect=[0, 0.05, 1, 0.95])
    plt.show()


In [None]:
# Trace diagnostics for a single model: local group, NAICS 31-33.
# Read from the same data/processed/results_<group>_<naics>.nc layout as the
# cells above (the original path repeated the "processed" directory).
results = az.from_netcdf("data/processed/results_local_31-33.nc")
az.plot_trace(results);  # trailing ";" suppresses the axes-array repr

In [None]:
# Tabular posterior summary for the model loaded in the trace-plot cell;
# left as the last expression so Jupyter renders it.
az.summary(data=results)