In [19]:
import json
import os
import polars as pl
import numpy as np
import matplotlib.pyplot as plt

# Output directory for generated plots, and path of the JSON design log.
out_dir = "analysis_plots"
log_path = "cavity_design_log.json"

# Create the plot directory up front; an already-existing directory is fine.
os.makedirs(name=out_dir, exist_ok=True)

def r_int(x):
    """Round ``x`` to the nearest integer and return it as a Python ``int``.

    Rounding is delegated to ``np.round``, which rounds exact halves to the
    nearest even integer (``2.5 -> 2``, ``3.5 -> 4``).

    Returns ``None`` when ``x`` is ``None`` or when the rounded value is not
    finite (NaN or +/-infinity). ``json.load`` accepts the ``NaN`` and
    ``Infinity`` tokens, so such values can reach this helper; mapping them
    to ``None`` lets them become nulls instead of aborting the caller.
    """
    if x is None:
        return None
    rounded = np.round(x)
    try:
        return int(rounded)
    except (ValueError, OverflowError):
        # int() raises ValueError for NaN and OverflowError for infinity.
        return None

# 1) Load the log and flatten it into one row per design-history entry.
with open(log_path, "r", encoding="utf-8") as f:
    data = json.load(f)

rows = []
for cfg in data.values():
    unit = cfg.get("unit_cell", {})
    # unit_cell dimensions are scaled by 1e9 to get the *_nm columns
    # (presumably stored in metres -- confirm against the log writer).
    # A missing or falsy (e.g. 0) value yields None.
    unit_wg_width_nm = unit.get("wg_width") * 1e9 if unit.get("wg_width") else None
    unit_wg_height_nm = unit.get("wg_height") * 1e9 if unit.get("wg_height") else None

    for h in cfg["design_history"]:
        h_result = h["result"]
        h_params = h["params"]
        rows.append({
            "Q": r_int(h_result["Q"]),
            "V": r_int(h_result["V"]),
            "qv_ratio": r_int(h_result["qv_ratio"]),
            "resonance_nm": r_int(h_result.get("resonance_nm", 0)),
            "period_nm": r_int(h_params["period_nm"]),
            # Per-entry width wins; otherwise fall back to the unit-cell width.
            "wg_width_nm": r_int(h_params.get("wg_width_nm", unit_wg_width_nm)),
            "wg_height_nm": r_int(unit_wg_height_nm),
            "hole_rx_nm": r_int(h_params["hole_rx_nm"]),
            "hole_ry_nm": r_int(h_params["hole_ry_nm"]),
            "num_taper_holes": r_int(h_params["num_taper_holes"]),
            "num_mirror_holes": r_int(h_params["num_mirror_holes"]),
            "taper_type": h_params.get("taper_type", "quadratic"),
            "min_a_percent": r_int(h_params["min_a_percent"]),
            "min_rx_percent": r_int(h_params.get("min_rx_percent", 100)),
            "min_ry_percent": r_int(h_params.get("min_ry_percent", 100)),
        })

# Scan every row when inferring column dtypes. By default polars only looks
# at the first 100 rows, so a column that is all-None in that window (e.g.
# wg_height_nm for configs without unit_cell data) would be typed from nulls
# alone and clash with integer values appearing later.
df = pl.DataFrame(rows, infer_schema_length=None)

# 2) Export every row to CSV, highest Q first.
# nulls_last=True explicitly places rows without a Q value at the end, so
# they can never appear ahead of the best designs.
df_sorted = df.sort("Q", descending=True, nulls_last=True)
df_sorted.write_csv("cavity_design_all_sorted_by_Q.csv")

# 3) Choose the two parameters to plot against each other.
x, y = "period_nm", "hole_rx_nm"

# All design-parameter columns; result columns are not listed here.
param_cols = [
    "period_nm",
    "wg_width_nm",
    "wg_height_nm",
    "hole_rx_nm",
    "hole_ry_nm",
    "num_taper_holes",
    "num_mirror_holes",
    "taper_type",
    "min_a_percent",
    "min_rx_percent",
    "min_ry_percent",
]

# Every parameter that is not on a plot axis is held fixed.
fixed_cols = [name for name in param_cols if name != x and name != y]

# 4) Find the most common combination of the fixed (non-plotted) parameters.
# group_by does not guarantee output order, so without maintain_order a tie
# in counts would pick an arbitrary combination from run to run. Keeping the
# order on both steps resolves ties to the combination seen first in df.
mode_row = (
    df.group_by(fixed_cols, maintain_order=True)
      .len()
      .sort("len", descending=True, maintain_order=True)
      .head(1)
)

fixed_values = {c: mode_row[c][0] for c in fixed_cols}
print("Fixed params (most common combo):", fixed_values)

# 5) Keep only the rows that match that combination.
# Comparing a column to None with == evaluates to null, which filter()
# treats as "drop the row"; a null fixed value therefore has to be matched
# with is_null(), otherwise every row would be filtered out.
df_fixed = df
for c, v in fixed_values.items():
    matches = pl.col(c).is_null() if v is None else pl.col(c) == v
    df_fixed = df_fixed.filter(matches)

print("Rows after fixed filter:", df_fixed.height)

# 6) Draw a 2D heatmap of the mean of `value` over the (x, y) grid.
def heatmap_pair(df, x, y, value="Q"):
    """Plot the mean of ``value`` for each (x, y) combination and save it as a PNG.

    Parameters
    ----------
    df : polars DataFrame containing the columns named by ``x``, ``y`` and
        ``value``.
    x, y : column names used for the horizontal and vertical axes.
    value : column whose per-cell mean is shown (default ``"Q"``).

    The image is written to
    ``<out_dir>/<value.lower()>_mean_<x>_vs_<y>_fixed.png`` (for the default
    ``value="Q"`` this is ``q_mean_<x>_vs_<y>_fixed.png``). Relies on the
    module-level ``out_dir`` and ``r_int``. Returns ``None``.
    """
    agg = df.group_by([x, y]).agg(pl.mean(value).alias("mean_q"))
    xs = sorted(agg[x].unique().to_list())
    ys = sorted(agg[y].unique().to_list())

    # Value -> grid position lookups; avoids a linear list.index scan per row.
    x_pos = {v: i for i, v in enumerate(xs)}
    y_pos = {v: i for i, v in enumerate(ys)}

    # Cells with no data stay NaN and are left blank/un-annotated.
    grid = np.full((len(ys), len(xs)), np.nan)
    for row in agg.iter_rows(named=True):
        cell_mean = r_int(row["mean_q"])
        if cell_mean is not None:
            grid[y_pos[row[y]], x_pos[row[x]]] = cell_mean

    def _tick_label(v):
        # Non-numeric axes (e.g. taper_type) are shown verbatim.
        return v if isinstance(v, str) else str(r_int(v))

    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        im = ax.imshow(grid, origin="lower", aspect="auto")
        ax.set_xticks(range(len(xs)))
        ax.set_xticklabels([_tick_label(v) for v in xs], rotation=45)
        ax.set_yticks(range(len(ys)))
        ax.set_yticklabels([_tick_label(v) for v in ys])
        ax.set_xlabel(x)
        ax.set_ylabel(y)
        # Labels follow `value` so a non-Q plot is not mislabelled as Q.
        ax.set_title(f"Mean {value} for {x} vs {y}")
        cbar = plt.colorbar(im, ax=ax)
        cbar.set_label(f"Mean {value}")

        # Annotate each populated cell with its integer mean (red, rotated).
        for (i, j), cell in np.ndenumerate(grid):
            if not np.isnan(cell):
                ax.text(j, i, f"{int(cell)}",
                        ha="center", va="center",
                        fontsize=7, color="red", rotation=45)

        fig.tight_layout()
        # Include `value` in the file name so plots of different quantities
        # do not overwrite each other.
        out_path = os.path.join(out_dir, f"{value.lower()}_mean_{x}_vs_{y}_fixed.png")
        fig.savefig(out_path, dpi=200)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

# Render the heatmap for the filtered rows using the chosen axes.
heatmap_pair(df=df_fixed, x=x, y=y)

# Report the output locations.
print("CSV saved:", "cavity_design_all_sorted_by_Q.csv")
print("Plot saved in:", out_dir)


Fixed params (most common combo): {'wg_width_nm': 450, 'wg_height_nm': 200, 'hole_ry_nm': 120, 'num_taper_holes': 10, 'num_mirror_holes': 7, 'taper_type': 'quadratic', 'min_a_percent': 89, 'min_rx_percent': 100, 'min_ry_percent': 100}
Rows after fixed filter: 57
CSV saved: cavity_design_all_sorted_by_Q.csv
Plot saved in: analysis_plots
