In [None]:
import os
import re
os.environ["KERAS_BACKEND"] = "jax"

import numpy as np
import pandas as pd
import bayesflow as bf
import keras
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import colorcet as cc

from FyeldGenerator import generate_field
from resnet import ResNetSummary

In [None]:
def generate_power_spectrum(alpha, scale):
    """Build a power-law spectrum ``P(k) = scale**2 * k**(-alpha)``.

    Args:
        alpha: Spectral slope; larger values make the spectrum fall off
            faster with increasing ``k``.
        scale: Amplitude factor; the spectrum is multiplied by ``scale**2``.

    Returns:
        A callable ``power_spectrum(k)`` that evaluates the spectrum
        elementwise on ``k`` (scalar or array-like).
    """
    amplitude = scale**2

    def power_spectrum(k):
        # Cast to float: np.power raises ValueError for integer bases raised to
        # negative integer exponents (e.g. integer k with alpha=2).
        # NOTE(review): k == 0 evaluates to inf; presumably the caller
        # (FyeldGenerator) excludes the zero mode — confirm.
        return np.power(np.asarray(k, dtype=float), -alpha) * amplitude

    return power_spectrum


def distribution(shape=(128, 128), rng=None):
    """Draw complex Gaussian white noise of the given shape.

    Real and imaginary parts are independent standard normal draws. This
    function is passed to ``generate_field`` as the noise statistic.

    Args:
        shape: Shape of the returned array.
        rng: Optional ``numpy.random.Generator`` (or any object exposing a
            compatible ``normal``). When ``None``, NumPy's legacy global RNG is
            used, which matches the previous (unseeded) behaviour.

    Returns:
        Complex ndarray of shape ``shape``.
    """
    # Fall back to the legacy global RNG so existing callers are unchanged;
    # passing a seeded Generator makes the noise reproducible.
    source = np.random if rng is None else rng
    a = source.normal(loc=0, scale=1., size=shape)
    b = source.normal(loc=0, scale=1., size=shape)
    return a + 1j * b

In [None]:
# Grid size used for every simulated field, plus the spectral slopes shown in
# the example figure (n_examples values evenly spaced from 2 to 5 inclusive).
shape = (128, 128)
n_examples = 5
alphas = np.linspace(2, 5, n_examples)
spectra = list(map(lambda a: generate_power_spectrum(a, 1), alphas))

In [None]:
def plot_distribution(shape=(128, 128)):
    """Complex standard-normal white noise from a generator reseeded with 42.

    A fresh ``default_rng(seed=42)`` is created on every call, so repeated
    calls with the same ``shape`` return identical noise.

    Args:
        shape: Shape of the returned array.

    Returns:
        Complex ndarray ``real + 1j * imag`` of shape ``shape``.
    """
    generator = np.random.default_rng(seed=42)
    real_part, imag_part = (
        generator.normal(loc=0, scale=1., size=shape) for _ in range(2)
    )
    return real_part + 1j * imag_part
# squeeze=False keeps a 2-D array of Axes even when n_examples == 1; otherwise
# plt.subplots returns a bare Axes and the zip below would fail.
fig, axs = plt.subplots(1, n_examples, figsize=(n_examples * 3, 4), squeeze=False)

for power_spectrum, alpha, ax in zip(spectra, alphas, axs[0]):
    # plot_distribution reseeds on every call, so the panels should share the
    # same noise realisation and differ only in the spectral slope alpha.
    field = generate_field(plot_distribution, power_spectrum, shape)
    max_magnitude = np.max(np.abs(field))
    # Symmetric colour limits centre the diverging colormap on zero.
    ax.imshow(field, cmap=cc.cm.coolwarm, vmin=-max_magnitude, vmax=max_magnitude)
    ax.set_title(f"$\\alpha={alpha:.2f}$")
    ax.set_axis_off()

In [None]:
rng = np.random.default_rng(seed=42)


def prior():
    """Sample the inference targets from their Gaussian priors.

    log_std ~ Normal(0, 0.3) and alpha ~ Normal(3, 0.5), both drawn from the
    seeded module-level ``rng``.
    """
    return {"log_std": rng.normal(scale=0.3), "alpha": rng.normal(loc=3, scale=0.5)}


def _seeded_complex_noise(noise_shape):
    """Complex standard-normal white noise drawn from the seeded ``rng``."""
    # Uses the same seeded generator as prior(); the unseeded global-RNG
    # `distribution` would make simulations irreproducible despite the seed.
    real = rng.normal(loc=0, scale=1., size=noise_shape)
    imag = rng.normal(loc=0, scale=1., size=noise_shape)
    return real + 1j * imag


def likelihood(log_std, alpha):
    """Simulate one power-law random field for the given parameters.

    Args:
        log_std: Log of the spectrum's amplitude factor (``scale = exp(log_std)``).
        alpha: Spectral slope of the power-law spectrum.

    Returns:
        ``{"field": array}`` where the field of grid size ``shape`` gets an
        extra trailing singleton axis (presumably the channel axis expected by
        the ResNet summary network — confirm).
    """
    field = generate_field(
        _seeded_complex_noise, generate_power_spectrum(alpha, np.exp(log_std)), shape
    )

    return {"field": field[..., None]}


simulator = bf.make_simulator([prior, likelihood])

## Training

In [None]:
# Simulate the offline datasets once: 5000 training and 500 validation draws
# (evaluated left to right, so training data is sampled first).
training_data, validation_data = simulator.sample(5000), simulator.sample(500)

In [None]:
# Summary network: compresses each field into an 8-dimensional summary.
summary_network = ResNetSummary(summary_dim=8, widths=[16, 32], use_batch_norm=False, dropout=0.0)

# Flow-matching posterior estimator with a three-layer, 32-unit subnet.
inference_network = bf.networks.FlowMatching(
    subnet_kwargs={"widths": (32, 32, 32), "dropout": 0.0},
)

# No simulator is attached: training runs offline on the pre-simulated datasets.
workflow = bf.workflows.BasicWorkflow(
    summary_network=summary_network,
    inference_network=inference_network,
    inference_variables=["log_std", "alpha"],
    summary_variables=["field"],
    standardize="summary_variables",
)

In [None]:
# Print the approximator's layer/parameter summary.
# NOTE(review): this runs before fit_offline, so the networks may not be built
# yet and parameter counts may be reported as unbuilt — confirm.
workflow.approximator.summary()

In [None]:
# Train on the pre-simulated datasets; validation loss is tracked on validation_data.
history = workflow.fit_offline(
    data=training_data,
    validation_data=validation_data,
    batch_size=32,
    epochs=100,
)

In [None]:
# Plot the loss curves recorded in `history` (presumably including validation
# loss, since validation_data was passed to fit_offline).
f = bf.diagnostics.plots.loss(history)

In [None]:
# First 100 simulations of every variable (not referenced elsewhere in this notebook).
small_training_data = dict((key, array[:100]) for key, array in training_data.items())

# Diagnostics are computed on the held-out validation simulations.
test_data = validation_data
workflow.plot_custom_diagnostics(
    test_data=test_data,
    plot_fns=dict(
        recovery=bf.diagnostics.recovery,
        calibration=bf.diagnostics.calibration_ecdf,
    ),
)

In [None]:
# Re-print the approximator summary after training; by now fit_offline has run,
# so layer shapes and parameter counts should be populated.
workflow.approximator.summary()

# Evaluations

In [None]:
# Reload one saved approximator checkpoint from the flow_matching/NPE run.
# NOTE(review): deserialising custom layers such as ResNetSummary presumably
# relies on them being registered as Keras-serializable — confirm.
checkpoint_path = "flow_matching/NPE/checkpoints/8_shape_config_8_16.keras"
model = keras.saving.load_model(filepath=checkpoint_path)
model.summary()

In [None]:
# Model families, grid sizes (8 ... 256) and inference target to compare.
models = ["consistency_model", "diffusion_edm_vp", "flow_matching"]
scales = [pow(2, exponent) for exponent in range(3, 9)]
target = "NPE"
# NOTE(review): model_configs is not used below, yet the checkpoint loaded in the
# previous cell is named "8_shape_config_8_16.keras" (config suffix "8_16"),
# whereas these paths use "{scale}_shape_config_{scale}" — confirm which naming
# the saved files actually follow.
model_configs = ["8_16", "32_64_128_256"]

# One checkpoint per (model, scale).
checkpoint_paths = [
    "{0}/{1}/checkpoints/{2}_shape_config_{2}.keras".format(name, target, size)
    for name in models
    for size in scales
]
print(checkpoint_paths)

# One metrics archive per (model, scale, split), train before validation.
numbers_paths = [
    "{0}/{1}/numbers_{2}_{3}_shape_config_{3}.npz".format(name, target, split, size)
    for name in models
    for size in scales
    for split in ("train", "validation")
]
print(numbers_paths)

In [None]:
# Peek at one metrics archive to see which arrays it stores.
# The context manager closes the archive's file handle, which NpzFile otherwise keeps open.
# allow_pickle=True is kept because the archives may hold object arrays (e.g. metric
# names) — only load trusted files, since unpickling can execute arbitrary code.
with np.load(numbers_paths[0], allow_pickle=True) as z:
    for k, v in z.items():
        print(k, v)

In [None]:
# Matches file names such as "numbers_train_64_shape_config_64.npz" -> ("train", "64").
_rx = re.compile(r"numbers_(train|validation)_(\d+)_", re.IGNORECASE)


def parse_model_mode_scale(p: str):
    """Split a metrics-archive path into ``(model, mode, scale)``.

    ``model`` is the first "/"-separated path component (paths are built as
    ``"{model}/{target}/numbers_{mode}_{scale}_..."``); ``mode`` (lower-cased)
    and ``scale`` (int) are parsed from the file name.

    Raises:
        ValueError: if the file name does not match ``_rx``.
    """
    model = p.split("/", 1)[0]
    match = _rx.search(os.path.basename(p))
    if not match:
        raise ValueError(f"Could not parse mode/scale from {p}")
    return model, match.group(1).lower(), int(match.group(2))


def row_from_npz(p):
    """Load one metrics archive and flatten it into a single table row.

    For each metric in ("nrmse", "ce", "clg") the archive is expected to hold
    ``{metric}_values`` and ``{metric}_names`` arrays; each (name, value) pair
    becomes a ``"{metric}_{name}"`` float entry. Metrics absent from the archive
    are skipped, so the corresponding columns end up missing (NaN) in the table.

    Args:
        p: Path of the ``.npz`` archive.

    Returns:
        dict with keys "model", "mode", "scale" plus one key per metric entry.

    Raises:
        ValueError: if ``p`` does not follow the expected naming scheme.
    """
    # Context manager closes the file handle NpzFile would otherwise keep open.
    # allow_pickle=True: names may be stored as object arrays; only load trusted files.
    with np.load(p, allow_pickle=True) as z:
        model, mode, scale = parse_model_mode_scale(p)
        row = {"model": model, "mode": mode, "scale": scale}
        for metric in ("nrmse", "ce", "clg"):
            values_key, names_key = f"{metric}_values", f"{metric}_names"
            if values_key not in z.files or names_key not in z.files:
                continue
            # ravel() also accepts 0-d arrays (a single scalar metric).
            names = np.ravel(z[names_key]).tolist()
            vals = np.ravel(z[values_key]).tolist()
            for name, val in zip(names, vals):
                row[f"{metric}_{name}"] = float(val)
    return row

rows = [row_from_npz(p) for p in numbers_paths]

# Build the table; rename clg_* -> log_gamma_* for readability.
df = pd.DataFrame(rows).rename(columns=lambda c: c.replace("clg_", "log_gamma_"))

wanted_cols = [
    "model", "mode", "scale",
    "nrmse_alpha", "nrmse_log_std",
    "ce_alpha", "ce_log_std",
    "log_gamma_alpha", "log_gamma_log_std",
]
# Some files may miss a metric -> add every absent column, filled with NaN.
missing_cols = [name for name in wanted_cols if name not in df.columns]
for col_name in missing_cols:
    df[col_name] = np.nan

# Sort within each model: train before validation (unknown modes last), then by scale.
mode_order = {"train": 0, "validation": 1}
df["_mode_order"] = df["mode"].map(mode_order).fillna(99)
df = (
    df[wanted_cols + ["_mode_order"]]
    .sort_values(["model", "_mode_order", "scale"])
    .drop(columns="_mode_order")
    .reset_index(drop=True)
)

print(df)

In [None]:
# Expected df columns: model, mode, scale, nrmse_alpha, nrmse_log_std,
# ce_alpha, ce_log_std, log_gamma_alpha, log_gamma_log_std.

# Scales are plotted at equally spaced positions 0..n-1; tick labels show the real values.
scales = sorted(df["scale"].unique())
pos = np.arange(len(scales))

# One colour per model, cycling through Matplotlib's active colour cycle.
models = sorted(df["model"].unique())
cycle_colors = plt.rcParams["axes.prop_cycle"].by_key().get("color", ["C0","C1","C2","C3"])
color_map = dict(
    (name, cycle_colors[idx % len(cycle_colors)]) for idx, name in enumerate(models)
)

# One line style per evaluation split.
style_map = {"train": "-.", "validation": "-"}

def plot_metric(ax, col, title):
    """Draw one metric column for every model and split onto ``ax``.

    Reads the module-level ``df``, ``models``, ``scales``, ``pos``,
    ``color_map`` and ``style_map``. Each (model, split) pair becomes one line
    over the equally spaced positions ``pos``; scales missing for that pair are
    reindexed to NaN and so appear as gaps.

    Args:
        ax: Matplotlib Axes to draw into.
        col: Name of the metric column in ``df``.
        title: Axes title.
    """
    def _aligned_values(model_name, split):
        # Rows for this model/split, aligned to the full list of scales.
        selected = df[(df["model"] == model_name) & (df["mode"] == split)]
        aligned = selected.set_index("scale").reindex(scales).sort_index()
        return aligned[col].values

    for model_name in models:
        # Validation (solid) is drawn before train (dash-dot).
        for split in ("validation", "train"):
            ax.plot(
                pos,
                _aligned_values(model_name, split),
                marker="o",
                linestyle=style_map.get(split, "-"),
                color=color_map[model_name],
                label=f"{model_name} ({split})",
            )

    ax.set_title(title)
    ax.set_xlabel("scale")
    ax.set_ylabel("metric")
    ax.set_xticks(pos)
    ax.set_xticklabels(list(map(str, scales)))
    ax.grid(True, alpha=0.3)

fig, axes = plt.subplots(2, 3, figsize=(12, 6), sharex=True)

# (row, column, metric column, panel title): top row alpha, bottom row log_std.
panels = [
    (0, 0, "nrmse_alpha", "NRMSE (alpha)"),
    (0, 1, "ce_alpha", "Calibration Error (alpha)"),
    (0, 2, "log_gamma_alpha", "Log Gamma (alpha)"),
    (1, 0, "nrmse_log_std", "NRMSE (log_std)"),
    (1, 1, "ce_log_std", "Calibration Error (log_std)"),
    (1, 2, "log_gamma_log_std", "Log Gamma (log_std)"),
]
for row_idx, col_idx, metric_col, panel_title in panels:
    plot_metric(axes[row_idx, col_idx], metric_col, panel_title)

# Two legends: colours identify models, line styles identify the split.
model_handles = [
    Line2D([0], [0], color=color_map[name], lw=2, label=name) for name in models
]
mode_handles = [
    Line2D([0], [0], color="black", lw=2, linestyle=style_map[split], label=split)
    for split in ("validation", "train")
]
axes[0, 0].legend(handles=model_handles, title="model", loc="best")
axes[1, 0].legend(handles=mode_handles, title="mode", loc="best")

plt.tight_layout()
plt.show()
