# In [None]:
import os
import re
from math import sqrt
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from colour.colorimetry import LMS_ConeFundamentals
from colour.colorimetry import SpectralDistribution
from colour.colorimetry import sd_to_XYZ
from colour.hints import ArrayLike
# colour-0.4.4
from colour.quality import colour_fidelity_index_ANSIIESTM3018, ColourQuality_Specification_ANSIIESTM3018
from colour.utilities import as_float_array
from scipy.optimize import minimize, LinearConstraint, Bounds, lsq_linear
%matplotlib inline

#问题一
# Problem 1: CIE 1931 chromaticity, polynomial CCT estimate and CIE 1976
# u'v' coordinates of the measured spectrum.
df = pd.read_excel(r'C:\Users\23163\Desktop\数模大赛\\Problem 1\Problem 1.xlsx')
# Wavelength labels may carry units (e.g. "380nm"). Keep an optional decimal
# part so fractional grids are not collapsed onto duplicate integers, and use
# expand=False so a Series (not a one-column DataFrame) is assigned.
df['wavelength'] = (
    df['波长'].astype(str)
    .str.extract(r'(\d+(?:\.\d+)?)', expand=False)
    .astype(float)
)
df['光强'] = pd.to_numeric(df['光强'], errors='coerce')
# np.trapz integrates along the abscissa in the given order, so drop
# unparsable rows and sort by wavelength before integrating.
df = (df.dropna(subset=['wavelength', '光强'])
        .sort_values('wavelength')
        .reset_index(drop=True))
wls = df['wavelength'].values
spd = df['光强'].values

df_cmf = pd.read_csv(r'C:\Users\23163\Desktop\数模大赛\\Problem 1\ciexyzjv.csv', header=None,
                     names=['wavelength','xbar','ybar','zbar'])
df_cmf = (df_cmf.apply(pd.to_numeric, errors='coerce')
                .dropna()
                .sort_values('wavelength'))
wls_cmf = df_cmf['wavelength'].values
xbar = df_cmf['xbar'].values
ybar = df_cmf['ybar'].values
zbar = df_cmf['zbar'].values

# Interpolate from the full CMF table: filtering it to the sample grid first
# discarded the tabulated points around any off-grid sample wavelength.
# Outside the tabulated range the observer sensitivity is taken as zero
# instead of holding the edge value.
xbar_i = np.interp(wls, wls_cmf, xbar, left=0.0, right=0.0)
ybar_i = np.interp(wls, wls_cmf, ybar, left=0.0, right=0.0)
zbar_i = np.interp(wls, wls_cmf, zbar, left=0.0, right=0.0)
X = np.trapz(spd * xbar_i, wls)
Y = np.trapz(spd * ybar_i, wls)
Z = np.trapz(spd * zbar_i, wls)

sum_xyz = X + Y + Z
x = X / sum_xyz
y = Y / sum_xyz
# Cubic CCT estimate in n = (x - 0.3320) / (y - 0.1858).
# NOTE(review): these coefficients differ from McCamy (1992), i.e.
# CCT = 449n^3 + 3525n^2 + 6823.3n + 5520.33 with n = (x - 0.3320)/(0.1858 - y);
# confirm the source of this variant.
n = (x - 0.3320) / (y - 0.1858)

CCT_mccamy = -437 * n**3 + 3601 * n**2 - 6861 * n + 5514.31
print(x, y)
print(f"Computed CCT (McCamy) = {CCT_mccamy:.2f} K")

# CIE 1976 UCS chromaticity (u', v')
up = 4 * X / (X + 15 * Y + 3 * Z)
vp = 9 * Y / (X + 15 * Y + 3 * Z)

def u_bb(T):
    """CIE 1960 u chromaticity of a Planckian radiator at temperature ``T`` (K).

    Rational-function approximation of the Planckian locus; operates
    elementwise, so ``T`` may be a scalar or a NumPy array.
    """
    numerator = 0.860117757 + 1.54118254e-4 * T + 1.28641212e-7 * T**2
    denominator = 1 + 8.42420235e-4 * T + 7.08145163e-7 * T**2
    return numerator / denominator

def v_bb(T):
    """CIE 1960 v chromaticity of a Planckian radiator at temperature ``T`` (K).

    Companion of ``u_bb``; operates elementwise on scalars or NumPy arrays.
    """
    numerator = 0.317398726 + 4.22806245e-5 * T + 4.20481691e-8 * T**2
    denominator = 1 - 2.89741816e-5 * T + 1.61456053e-7 * T**2
    return numerator / denominator

# Planckian-locus point at the estimated CCT. Using the polynomial CCT as the
# foot point approximates Duv; strictly, the foot point is the locus point
# closest to the sample's (u, v).
ubb_val = u_bb(CCT_mccamy)       # CIE 1960 u
vbb_val = v_bb(CCT_mccamy)       # CIE 1960 v
up_bb = ubb_val                  # CIE 1976 u' equals CIE 1960 u
vp_bb = 1.5 * vbb_val            # CIE 1976 v' = (9/6) * CIE 1960 v

# Duv is a distance in the CIE 1960 (u, v) diagram (ANSI C78.377). Measuring
# it in u'v' stretches the v-axis by 1.5 and overstates Duv, so convert the
# sample back to 1960 coordinates first.
u_1960 = up                      # u = u'
v_1960 = vp / 1.5                # v = (2/3) * v'
Duv = np.sqrt((u_1960 - ubb_val)**2 + (v_1960 - vbb_val)**2)
# Conventional sign: positive above the Planckian locus, negative below.
Duv_signed = float(np.copysign(Duv, v_1960 - vbb_val))
print(f"Duv (unsigned) = {Duv:.4f}")
print(f"Duv (signed) = {Duv_signed:.4f}")


def tm30_rf_rg(wl, spd):
    """Best-effort ANSI/IES TM-30 fidelity (Rf) and gamut (Rg) indices.

    Args:
        wl: sample wavelengths (nm).
        spd: spectral power at each wavelength.

    Returns:
        ``(Rf, Rg)`` as floats, or ``(None, None)`` when colour-science is
        unavailable or no TM-30 implementation could evaluate the spectrum.
        Failures are deliberately swallowed so callers can treat the metric
        as optional.
    """
    import importlib

    try:
        from colour import SpectralDistribution
        sd = SpectralDistribution(dict(zip(wl, spd)), name="Sample")
    except Exception:
        return None, None

    spec = None
    try:
        # colour >= 0.4 exposes TM-30 through colour.quality.
        from colour.quality import colour_fidelity_index_ANSIIESTM3018
        spec = colour_fidelity_index_ANSIIESTM3018(sd, additional_data=True)
    except Exception:
        # Alternative module layouts that the previous implementation probed.
        for module_name, func_name in (("colour.quality.tm30", "tm30"),
                                       ("colour.quality.tm3018", "tm3018")):
            try:
                spec = getattr(importlib.import_module(module_name), func_name)(sd)
                break
            except Exception:
                continue
    if spec is None:
        return None, None

    def _first_attr(*names):
        # `is not None` (not truthiness) so a legitimate 0.0 is not skipped.
        for name in names:
            value = getattr(spec, name, None)
            if value is not None:
                return value
        return None

    Rf = _first_attr("R_f", "Rf")
    Rg = _first_attr("R_g", "Rg")
    try:
        return float(Rf), float(Rg)
    except (TypeError, ValueError):
        return None, None

def _mel_der_v_lambda(wl):
    """Photopic luminous efficiency V(lambda) sampled at ``wl`` for ``mel_der``.

    Order of preference: a module-level ``cmf_from_csv_or_gaussian`` helper
    (called with a module-level ``CMF_CSV`` if one exists), then the
    colour-science CIE 1931 2-degree y-bar table, then a Gaussian fit.
    """
    helper = globals().get("cmf_from_csv_or_gaussian")
    if helper is not None:
        return np.asarray(helper(wl, cmf_csv=globals().get("CMF_CSV"))[1], dtype=float)
    try:
        import colour
        cmfs = colour.MSDS_CMFS["CIE 1931 2 Degree Standard Observer"]
        return np.interp(wl, np.asarray(cmfs.wavelengths, dtype=float),
                         np.asarray(cmfs.values[:, 1], dtype=float),
                         left=0.0, right=0.0)
    except Exception:
        # Gaussian fit of V with lambda in micrometres: 1.019*exp(-285.4*(lambda-0.559)^2)
        return 1.019 * np.exp(-285.4 * (wl / 1000.0 - 0.559) ** 2)


def mel_der(wl, spd, s026_csv=None):
    """Melanopic daylight efficacy ratio (mel-DER) of a sampled spectrum.

    Args:
        wl: sample wavelengths (nm).
        spd: spectral power at each wavelength.
        s026_csv: optional CSV (wavelength, melanopic sensitivity) used when
            colour-science cannot supply the curve. Defaults to a module-level
            ``S026_MEL_CSV`` if one is defined.

    Returns:
        mel-DER relative to D65 (float) when colour-science provides both D65
        and a melanopsin curve; otherwise, if the CSV is available, the
        sample's un-normalised melanopic/photopic ratio (NOT comparable to
        mel-DER, as no D65 normalisation is applied); otherwise ``None``.
    """
    import os

    wl = np.asarray(wl, dtype=float)
    spd = np.asarray(spd, dtype=float)
    if s026_csv is None:
        # Resolved at call time: using S026_MEL_CSV as the default argument
        # raised NameError when the function was defined.
        s026_csv = globals().get("S026_MEL_CSV")

    V = _mel_der_v_lambda(wl)
    dl = np.gradient(wl)

    def _on_grid(sd):
        # Resample a colour SpectralDistribution onto the sample wavelengths,
        # whatever their spacing (zero outside its tabulated range).
        return np.interp(wl, np.asarray(sd.wavelengths, dtype=float),
                         np.asarray(sd.values, dtype=float), left=0.0, right=0.0)

    # Prefer colour-science when it provides the data.
    try:
        from colour import SDS_ILLUMINANTS
        try:
            from colour import SDS_PHOTO_SENSITIVITIES
            mel_sd = SDS_PHOTO_SENSITIVITIES["S 026-2018 Melanopsin 2 Degree"]
        except Exception:
            from colour import SDS_BIOLOGICAL_ACTION_SPECTRA
            mel_sd = SDS_BIOLOGICAL_ACTION_SPECTRA["CIE S 026/E:2018 Melanopsin 2 Degree"]
        Smel = _on_grid(mel_sd)
        D65 = _on_grid(SDS_ILLUMINANTS["D65"])

        m_s = float(np.sum(spd * Smel * dl)); p_s = float(np.sum(spd * V * dl))
        m_d = float(np.sum(D65 * Smel * dl)); p_d = float(np.sum(D65 * V * dl))
        return (m_s / p_s) / (m_d / p_d)
    except Exception:
        # No usable colour data: fall back to a tabulated S 026 CSV if given.
        if s026_csv and os.path.exists(s026_csv):
            dfm = pd.read_csv(s026_csv)
            Smel = np.interp(wl, dfm.iloc[:, 0].to_numpy(), dfm.iloc[:, 1].to_numpy(),
                             left=0, right=0)
            # Without D65, only the sample's melanopic/photopic ratio is returned.
            m_s = float(np.sum(spd * Smel * dl)); p_s = float(np.sum(spd * V * dl))
            return m_s / p_s
        return None

# Plot the measured spectral power distribution of Problem 1.
fig_spd, ax_spd = plt.subplots(figsize=(8, 4))
ax_spd.plot(wls, spd, linewidth=2)
ax_spd.set_xlabel('Wavelength (nm)')
ax_spd.set_ylabel('Relative Power')
ax_spd.set_title('Original spectral distribution(SPD)')
ax_spd.grid(True, linestyle='--', alpha=0.5)
ax_spd.set_xlim(wls.min(), wls.max())
fig_spd.tight_layout()
plt.show()

#问题二
def compute_tm30_rf_rg(sd_test: SpectralDistribution):
    """Compute ANSI/IES TM-30-18 metrics for ``sd_test`` via colour-science.

    Despite the name, four values are returned and callers unpack all four.

    Args:
        sd_test: spectrum to evaluate.

    Returns:
        tuple: ``(R_f, R_g, CCT, D_uv)``. These are the fidelity index, the
        gamut index, the correlated colour temperature and Duv, read from the
        specification returned with ``additional_data=True``.
    """
    spec: ColourQuality_Specification_ANSIIESTM3018 = \
        colour_fidelity_index_ANSIIESTM3018(sd_test, additional_data=True)
    return spec.R_f, spec.R_g, spec.CCT, spec.D_uv


# --- mel-DER 计算（基于 CIE S-026 ipRGC 灵敏度） ---
def compute_mel_der(sd_test: SpectralDistribution, sd_ref: SpectralDistribution):
    """Return the melanopic daylight efficacy ratio (mel-DER) of ``sd_test``.

    mel-DER = (int S_test*s_mel / int S_test*V) / (int S_ref*s_mel / int S_ref*V)

    The melanopic sensitivity, V(lambda) (CIE 1931 2-degree y-bar) and the
    reference SPD are all resampled onto ``sd_test``'s own wavelengths. The
    test spectrum therefore need not lie on a 1 nm grid; previously the
    weighting functions were aligned to 1 nm while the test values were not,
    which broke for any other spacing.

    Args:
        sd_test: spectrum to evaluate.
        sd_ref: reference spectrum (callers in this file pass CIE D65).

    Returns:
        float: mel-DER of ``sd_test`` relative to ``sd_ref``.
    """
    import colour

    wavelengths = np.asarray(sd_test.wavelengths, dtype=float)
    test_values = np.asarray(sd_test.values, dtype=float)

    def on_test_grid(src_wl, src_values):
        # Linear resampling; np.interp holds the edge values outside src_wl.
        return np.interp(wavelengths,
                         np.asarray(src_wl, dtype=float),
                         np.asarray(src_values, dtype=float))

    # Melanopic sensitivity
    try:
        mel_sd = colour.biochemistry.SDS_PHOTORECEPTOR_SENSITIVITIES['Melanopic']
        mel_sens = on_test_grid(mel_sd.wavelengths, mel_sd.values)
    except Exception:
        # Gaussian stand-in peaking at 490 nm when no tabulated curve exists.
        # NOTE(review): sigma = 13 nm gives a FWHM of about 31 nm, which looks
        # much narrower than the CIE S 026 melanopic curve; consider loading
        # the tabulated S 026 data instead. Verify.
        mel_sens = np.exp(-0.5 * ((wavelengths - 490) / 13) ** 2)

    # Photopic weighting V(lambda) and the reference SPD on the same grid
    cmfs = colour.MSDS_CMFS['CIE 1931 2 Degree Standard Observer']
    v_lambda = on_test_grid(cmfs.wavelengths, cmfs.values[:, 1])
    ref_values = on_test_grid(sd_ref.wavelengths, sd_ref.values)

    mel_t = np.trapz(test_values * mel_sens, wavelengths)
    mel_r = np.trapz(ref_values * mel_sens, wavelengths)
    v_t = np.trapz(test_values * v_lambda, wavelengths)
    v_r = np.trapz(ref_values * v_lambda, wavelengths)

    return float((mel_t / v_t) / (mel_r / v_r))


# --- 读取五通道 SPD ---
def load_channels(path: str, sheets=5) -> tuple[np.ndarray, list[SpectralDistribution]]:
    """Read per-channel LED SPDs from the first ``sheets`` worksheets of a workbook.

    Each sheet needs a '波长' (wavelength) column whose entries contain a number.
    Unit text such as "nm" is ignored and decimals are kept. Every other column
    is one channel SPD; non-numeric cells are read as 0.

    Args:
        path: Excel workbook path.
        sheets: maximum number of worksheets to read, in workbook order.

    Returns:
        (wls, sds): the shared wavelength grid, and one SpectralDistribution
        per channel column in sheet/column order (named after the column).

    Raises:
        ValueError: if the sheets do not share the same wavelength grid.
    """
    book = pd.read_excel(path, sheet_name=None)
    wls = None
    sds = []
    for sheet_name in list(book.keys())[:sheets]:
        sheet = book[sheet_name]
        # Extract the numeric wavelength (strip units); unparsable rows are dropped.
        wl_num = pd.to_numeric(
            sheet['波长'].astype(str).str.extract(r'(\d+(?:\.\d+)?)', expand=False),
            errors='coerce',
        )
        valid = wl_num.notna()
        sheet = sheet.loc[valid]
        wl_values = wl_num.loc[valid].to_numpy(dtype=float)

        if wls is None:
            wls = wl_values
        elif wls.shape != wl_values.shape or not np.allclose(wls, wl_values):
            raise ValueError("五路通道波长不一致!")

        # Every column except the wavelength column is one channel SPD.
        for col in sheet.columns:
            if col == '波长':
                continue
            values = pd.to_numeric(sheet[col], errors='coerce').fillna(0.0).to_numpy(dtype=float)
            sds.append(SpectralDistribution(dict(zip(wls, values)), name=col))
    return wls, sds


# --- 合成 SPD ---
def mix_spd(weights: ArrayLike, sds: list[SpectralDistribution]) -> SpectralDistribution:
    """Weighted sum of channel SPDs, returned as a new SpectralDistribution named "Mix".

    All channels are combined value-by-value on ``sds[0]``'s wavelength
    domain, so they are expected to share that grid.

    Args:
        weights: one weight per channel.
        sds: channel spectra.

    Raises:
        ValueError: if ``sds`` is empty, the weight count differs from the
            channel count, or channels have different lengths.
    """
    w = np.asarray(weights, dtype=float).ravel()
    if len(sds) == 0:
        raise ValueError("sds must contain at least one SpectralDistribution")
    # Validate explicitly (assert statements are stripped under python -O).
    if len(w) != len(sds):
        raise ValueError(f"权重数量 {len(w)} 与光谱分布数量 {len(sds)} 不匹配")

    domain = sds[0].domain
    # (n_channels, n_wavelengths); np.stack raises if channel lengths differ.
    channel_values = np.stack([np.asarray(sd.values, dtype=float) for sd in sds])
    mixed_values = w @ channel_values

    return SpectralDistribution(dict(zip(domain, mixed_values)), name="Mix")


# --- 优化：日间模式 ---
def optimize_day(sds: list[SpectralDistribution], sd_ref: SpectralDistribution):
    """Daytime mode: choose channel weights that maximise TM-30 Rf.

    Constraints: 0 <= w_i <= 1, sum(w) = 1 and 5500 K <= CCT <= 6500 K.
    Rg outside [95, 105] is penalised by 5 per unit of excursion instead of
    being a hard constraint. Emits a RuntimeWarning if SLSQP does not converge.

    Returns:
        (w_opt, Rf, Rg, CCT, Duv, mel): optimised weights and the metrics of
        the resulting mix; ``mel`` is mel-DER relative to ``sd_ref``.
    """
    import warnings

    tm30_cache = {}

    def tm30(w):
        # The objective and both CCT constraints are evaluated at the same
        # weight vectors; TM-30 is expensive, so compute it once per vector.
        key = np.asarray(w, dtype=float).tobytes()
        if key not in tm30_cache:
            tm30_cache[key] = compute_tm30_rf_rg(mix_spd(w, sds))
        return tm30_cache[key]

    # Objective: maximise Rf -> minimise -Rf, plus the Rg penalty.
    def obj(w):
        Rf, Rg, _, _ = tm30(w)
        pen = abs(Rg - np.clip(Rg, 95, 105)) * 5  # zero while Rg is inside [95, 105]
        return -Rf + pen

    # Nonlinear constraints: CCT in [5500, 6500] K (SLSQP 'ineq' means fun(w) >= 0).
    def constr_CCT_low(w):
        return tm30(w)[2] - 5500

    def constr_CCT_high(w):
        return 6500 - tm30(w)[2]

    n = len(sds)
    x0 = np.ones(n) / n
    bounds = Bounds(0, 1)
    lincon = LinearConstraint(np.ones((1, n)), [1], [1])  # weights sum to 1
    cons = [
        {'type': 'ineq', 'fun': constr_CCT_low},
        {'type': 'ineq', 'fun': constr_CCT_high},
    ]

    res = minimize(obj, x0, method='SLSQP', bounds=bounds,
                   constraints=[lincon, *cons], options={'ftol': 1e-6})
    if not res.success:
        warnings.warn(f"optimize_day: SLSQP did not converge: {res.message}",
                      RuntimeWarning)
    w_opt = res.x
    sd_opt = mix_spd(w_opt, sds)
    Rf, Rg, CCT, Duv = compute_tm30_rf_rg(sd_opt)
    mel = compute_mel_der(sd_opt, sd_ref)
    return w_opt, Rf, Rg, CCT, Duv, mel


# --- 优化：夜间模式 ---
def optimize_night(sds: list[SpectralDistribution], sd_ref: SpectralDistribution):
    """Night mode: choose channel weights that minimise mel-DER.

    Constraints: 0 <= w_i <= 1, sum(w) = 1 and 2500 K <= CCT <= 3500 K.
    Rf below 80 is penalised by 5 per point instead of being a hard
    constraint. Emits a RuntimeWarning if SLSQP does not converge.

    Returns:
        (w_opt, Rf, Rg, CCT, Duv, mel): optimised weights and the metrics of
        the resulting mix; ``mel`` is mel-DER relative to ``sd_ref``.
    """
    import warnings

    tm30_cache = {}

    def tm30(w):
        # Memoise TM-30 per weight vector: the objective and both constraints
        # request it at the same points.
        key = np.asarray(w, dtype=float).tobytes()
        if key not in tm30_cache:
            tm30_cache[key] = compute_tm30_rf_rg(mix_spd(w, sds))
        return tm30_cache[key]

    # Objective: minimise mel-DER, plus the Rf < 80 penalty.
    def obj(w):
        mel = compute_mel_der(mix_spd(w, sds), sd_ref)
        Rf = tm30(w)[0]
        pen = 0.0
        if Rf < 80:
            pen += (80 - Rf) * 5
        return mel + pen

    # Nonlinear constraints: CCT in [2500, 3500] K (fun(w) >= 0).
    def constr_CCT_low(w):
        return tm30(w)[2] - 2500

    def constr_CCT_high(w):
        return 3500 - tm30(w)[2]

    n = len(sds)
    x0 = np.ones(n) / n
    bounds = Bounds(0, 1)
    lincon = LinearConstraint(np.ones((1, n)), [1], [1])  # weights sum to 1
    cons = [
        {'type': 'ineq', 'fun': constr_CCT_low},
        {'type': 'ineq', 'fun': constr_CCT_high},
    ]
    res = minimize(obj, x0, method='SLSQP', bounds=bounds,
                   constraints=[lincon, *cons], options={'ftol': 1e-6})
    if not res.success:
        warnings.warn(f"optimize_night: SLSQP did not converge: {res.message}",
                      RuntimeWarning)
    w_opt = res.x
    sd_opt = mix_spd(w_opt, sds)
    Rf, Rg, CCT, Duv = compute_tm30_rf_rg(sd_opt)
    mel = compute_mel_der(sd_opt, sd_ref)
    return w_opt, Rf, Rg, CCT, Duv, mel


# --- 主流程 ---
if __name__ == "__main__":
    # Reference illuminant for mel-DER: CIE D65
    from colour.colorimetry import SDS_ILLUMINANTS

    path = "Problem 2.xlsx"
    wls, sds = load_channels(path)
    sd_ref = SDS_ILLUMINANTS["D65"]
    # Label weights with the spreadsheet's channel column names rather than a
    # hard-coded five-name list, so labels always match the weight order and
    # any number of channels is reported.
    channel_names = [sd.name for sd in sds]

    def _print_mode_report(w, Rf, Rg, CCT, Duv, mel):
        """Print weights and metrics of one optimised lighting mode."""
        print(f"权重分配:")
        for ch_name, wi in zip(channel_names, w):
            print(f"  {ch_name}: {wi:.4f} ({wi*100:.2f}%)")
        print(f"\n性能指标:")
        print(f"  CCT (相关色温): {CCT:.1f} K")
        print(f"  Duv (色度偏移): {Duv:.4f}")
        print(f"  Rf (色彩保真度指数): {Rf:.2f}")
        print(f"  Rg (色彩饱和度指数): {Rg:.2f}")
        print(f"  mel-DER (褪黑素抑制等效比): {mel:.3f}")
        print(f"\n权重和: {np.sum(w):.6f}")

    print("=== 日间模式 (CCT ∈ [5500,6500]) 最优化 ===")
    wd, Rf_d, Rg_d, CCT_d, Duv_d, mel_d = optimize_day(sds, sd_ref)
    _print_mode_report(wd, Rf_d, Rg_d, CCT_d, Duv_d, mel_d)

    print("\n=== 夜间模式 (CCT ∈ [2500,3500]) 最优化 ===")
    wn, Rf_n, Rg_n, CCT_n, Duv_n, mel_n = optimize_night(sds, sd_ref)
    _print_mode_report(wn, Rf_n, Rg_n, CCT_n, Duv_n, mel_n)

#问题三
# --- Problem 3 configuration ---
SUN_XLSX: Path = Path("Problem 3.xlsx")   # workbook with the solar SPD time series (SUN_SPD)
LED_XLSX: Path = Path("Problem 2.xlsx")   # workbook with the 5-channel LED SPDs
BETA: float = 0.15                        # temporal smoothing weight (larger = smoother)
WMAX: float = 1.0                         # per-channel weight upper bound (adjust as needed)
USE_GLOBAL_SCALE: bool = True             # per-timestep scalar fit to reduce brightness offset
REP_TIMES: list[str] = ["08:30", "12:00", "19:00"]  # representative times for comparison plots

def first_number(s: str):
    """Return the first signed decimal number found in ``str(s)``, or NaN if none."""
    found = re.search(r'[-+]?\d*\.?\d+', str(s))
    if found is None:
        return np.nan
    return float(found.group())

def find_wl_col(df: pd.DataFrame):
    """Return the first column whose name looks like a wavelength header.

    Matches "wave", "λ", "lambda" (case-insensitive) or "波长"; falls back to
    the first column when nothing matches.
    """
    candidates = (
        col for col in df.columns
        if re.search(r'wave|λ|lambda|波长', str(col), flags=re.IGNORECASE)
    )
    return next(candidates, df.columns[0])

def prep_sun_matrix(xlsx: Path):
    """Load the solar SPD time series as a (wavelength x time) matrix.

    Uses the first sheet whose name contains "sun" (case-insensitive), else
    the first sheet. Wavelengths are parsed with ``first_number``; rows
    without one are dropped and the rest sorted by wavelength. Every other
    column is one time step; non-numeric cells become 0.

    Returns:
        (wl, sun_cols, S): wavelength array, time-column names, and the
        (n_wavelengths, n_times) SPD matrix.
    """
    book = pd.read_excel(xlsx, sheet_name=None)
    sheet_names = list(book.keys())
    chosen = next((nm for nm in book if re.search(r'sun', nm, re.I)), sheet_names[0])

    table = book[chosen].copy()
    table.columns = [str(c).strip() for c in table.columns]
    wl_col = find_wl_col(table)
    table[wl_col] = table[wl_col].map(first_number)
    table = (table.dropna(subset=[wl_col])
                  .sort_values(wl_col)
                  .reset_index(drop=True))

    sun_cols = [c for c in table.columns if c != wl_col]
    S = table[sun_cols].apply(pd.to_numeric, errors="coerce").fillna(0.0).to_numpy()
    return table[wl_col].to_numpy(), sun_cols, S

def prep_led_basis(xlsx: Path, target_wl: np.ndarray):
    """Load LED channel SPDs from the first sheet and resample them onto ``target_wl``.

    Returns:
        led_cols: channel column names (every non-wavelength column).
        L: (len(target_wl), n_channels) basis matrix. It is linearly
           interpolated, set to 0 outside each channel's tabulated range and
           clipped to be non-negative.
    """
    book = pd.read_excel(xlsx, sheet_name=None)
    # The first sheet holds the LED SPDs (Sheet1 of Problem 2.xlsx).
    table = book[list(book.keys())[0]].copy()
    table.columns = [str(c).strip() for c in table.columns]
    wl_col = find_wl_col(table)
    table[wl_col] = table[wl_col].map(first_number)
    table = table.dropna(subset=[wl_col]).sort_values(wl_col).reset_index(drop=True)

    src_wl = table[wl_col].to_numpy()
    led_cols = [c for c in table.columns if c != wl_col]
    L = np.zeros((len(target_wl), len(led_cols)))
    for j, col in enumerate(led_cols):
        channel = pd.to_numeric(table[col], errors="coerce").fillna(0.0).to_numpy()
        L[:, j] = np.interp(target_wl, src_wl, channel, left=0.0, right=0.0)
    return led_cols, np.clip(L, 0, None)

def solve_weights(S, L, beta=BETA, wmax=WMAX, use_global_scale=USE_GLOBAL_SCALE):
    """Fit LED channel weights to each solar spectrum with temporal smoothing.

    For every column t of S this solves the bounded least-squares problem

        min_w ||L w - s_t||^2 + beta * ||w - w_{t-1}||^2,   0 <= w <= wmax

    by stacking sqrt(beta)*I under L. Here w_{-1} = 0, so the first step is
    ridge-regularised towards zero. With ``use_global_scale``, s_t is first
    normalised to unit mean. Each reconstruction is then rescaled by the
    least-squares scalar a_t that best matches the raw s_t.

    Args:
        S: (n_wavelengths, n_times) target spectra.
        L: (n_wavelengths, n_channels) LED basis.
        beta: smoothing weight.
        wmax: per-channel upper bound.
        use_global_scale: normalise targets and rescale reconstructions.

    Returns:
        W: (n_channels, n_times) weights (NOT multiplied by ``scales``).
        R: (n_wavelengths, n_times) float reconstructions (scaled when
           ``use_global_scale``).
        scales: (n_times,) factors a_t (all 1 when not scaling).
    """
    k = L.shape[1]
    T = S.shape[1]
    W = np.zeros((k, T))
    # Always float: zeros_like(S) inherited an int dtype from integral S and
    # silently truncated the reconstructed spectra.
    R = np.zeros(S.shape, dtype=float)
    scales = np.ones(T)
    rt = sqrt(beta)
    # The stacked design matrix is the same for every time step.
    A = np.vstack([L, rt * np.eye(k)])
    prev = np.zeros(k)
    for t in range(T):
        s = S[:, t]
        s_scaled = s / max(s.mean(), 1e-8) if use_global_scale else s
        b = np.hstack([s_scaled, rt * prev])
        res = lsq_linear(A, b, bounds=(0, wmax), max_iter=200, lsmr_tol='auto')
        w = res.x
        W[:, t] = w
        r = L @ w
        if use_global_scale:
            # Scalar match back to the original brightness of s_t.
            a = (r @ s) / (r @ r + 1e-12)
            r = r * a
            scales[t] = a
        R[:, t] = r
        prev = w
    return W, R, scales

def pick_idx_by_label(labels, want):
    """Index of the first label whose string form contains ``want``.

    If no label matches, "08:30" falls back to the first index, "12:00" to
    the middle one and anything else to the last one.
    """
    hit = next((i for i, lbl in enumerate(labels) if want in str(lbl)), None)
    if hit is not None:
        return hit
    fallback = {"08:30": 0, "12:00": len(labels) // 2}
    return fallback.get(want, len(labels) - 1)

# --------------------
# 主流程
# --------------------
def main():
    """Problem 3 pipeline: fit LED channel weights to the solar SPD series.

    Writes weights.csv, reconstructed_SPD.csv, metrics.csv, weights_plot.png
    and compare_<HHMM>.png to the working directory.
    """
    assert SUN_XLSX.exists(), f"缺少 {SUN_XLSX}"
    assert LED_XLSX.exists(), f"缺少 {LED_XLSX}"

    wl, sun_cols, S = prep_sun_matrix(SUN_XLSX)
    led_cols, L = prep_led_basis(LED_XLSX, wl)
    W, R, scales = solve_weights(S, L)

    # --- tables ---
    weights_df = pd.DataFrame(W.T, columns=led_cols, index=sun_cols)
    weights_df.to_csv("weights.csv", encoding="utf-8-sig")

    recon_df = pd.DataFrame(R, columns=sun_cols)
    recon_df.insert(0, "wavelength", wl)
    recon_df.to_csv("reconstructed_SPD.csv", index=False, encoding="utf-8-sig")

    rmse = np.sqrt(((R - S) ** 2).mean(axis=0))
    fit_df = pd.DataFrame({"time": sun_cols, "rmse": rmse, "scale": scales})
    fit_df.to_csv("metrics.csv", index=False, encoding="utf-8-sig")

    # --- channel weight trajectories ---
    plt.figure()
    for weight_row, channel in zip(W, led_cols):
        plt.plot(weight_row, label=channel)
    plt.xlabel("time index")
    plt.ylabel("channel weight")
    plt.title("LED channel weights over time")
    plt.legend()
    plt.tight_layout()
    plt.savefig("weights_plot.png", dpi=160)
    plt.close()

    # --- solar vs reconstructed spectrum at representative times ---
    for tag in REP_TIMES:
        col_idx = pick_idx_by_label(sun_cols, tag)
        plt.figure()
        plt.plot(wl, S[:, col_idx], label=f"SUN {sun_cols[col_idx]}")
        plt.plot(wl, R[:, col_idx], label="LED recon")
        plt.xlabel("Wavelength (nm)")
        plt.ylabel("SPD (a.u.)")
        plt.title(f"Spectrum comparison @ {tag}")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"compare_{tag.replace(':','')}.png", dpi=160)
        plt.close()

    print("完成：weights.csv, reconstructed_SPD.csv, metrics.csv, weights_plot.png, compare_*.png")


if __name__ == "__main__":
    main()
    
#问题四
import pandas as pd, numpy as np, matplotlib.pyplot as plt
from pathlib import Path

# === 1) Read and reshape (wide sheet -> long table) ===
book = pd.read_excel(Path("Problem 4.xlsx"), sheet_name=None)
dfw = book["Sheet1"].copy()
# Row 0 of Sheet1 is read as a per-column night label (e.g. "Night 2");
# it is removed from the data further below.
night_labels = dfw.iloc[0].tolist()

# Build a (subject, night) label for every column. A header containing
# "被试" ("subject") starts a new subject, and that column is taken as its
# "Night 1"; the following columns take their night from row 0 (an empty
# label cell becomes the string "nan").
# NOTE(review): columns before the first "被试" header would get subject "S0".
multi_cols, subj_id = [], 0
for i,col in enumerate(dfw.columns):
    if "被试" in str(col):
        subj_id += 1; night = "Night 1"
    else:
        night = str(night_labels[i])
    multi_cols.append((f"S{subj_id}", night))

# Drop the label row and attach the (subject, night) column MultiIndex.
dfw = dfw.iloc[1:].reset_index(drop=True)
dfw.columns = pd.MultiIndex.from_tuples(multi_cols, names=["subject","night"])
# Move both column levels into the index: one row per (row, subject, night)
# cell. Whether empty cells are dropped here depends on pandas' stack()
# defaults (verify for the installed version); the dropna on "code" below
# removes them either way.
long = dfw.stack(level=[0,1]).reset_index()
long.columns = ["epoch","subject","night","code"]
# Renumber epochs 0..k-1 within each subject/night, in row order.
long["epoch"] = long.groupby(["subject","night"]).cumcount()
# Stage codes as nullable integers; non-numeric entries become <NA> and are dropped.
long["code"] = pd.to_numeric(long["code"], errors="coerce").astype("Int64")
long = long.dropna(subset=["code"])
def map_state(c):
    """Map a raw sleep-stage code to a model state.

    4 -> "Wake", 5 -> "REM", 3 -> "N3"; every other code (1/2 expected) -> "N1N2".
    """
    for code, label in ((4, "Wake"), (5, "REM"), (3, "N3")):
        if c == code:
            return label
    return "N1N2"                      # 1/2 -> N1N2
# Stage code -> model state ("Wake", "N1N2", "N3", "REM").
long = long.assign(state=long["code"].map(map_state))

# Night -> experimental condition; edit this mapping to reassign conditions.
# Nights missing from the mapping get a NaN condition.
cond_map = {"Night 1":"A", "Night 2":"B", "Night 3":"C"}
long = long.assign(condition=long["night"].map(cond_map))

# === 2) Empirical transition matrices (per condition) ===
states = ["Wake","N1N2","N3","REM"]
long = long.sort_values(["subject", "night", "epoch"])
# Successor state within the same subject/night; each night's last epoch has none.
long["state_next"] = long.groupby(["subject", "night"])["state"].shift(-1)
trans = long.dropna(subset=["state_next"]).copy()

def empirical_P(cond, trans_df=None, state_order=None, pseudocount=1e-6):
    """Row-stochastic transition matrix estimated from observed transitions.

    Args:
        cond: condition label to select (compared with the "condition" column).
        trans_df: DataFrame with "condition", "state" and "state_next"
            columns; defaults to the module-level ``trans``.
        state_order: state names fixing the row/column order; defaults to the
            module-level ``states``.
        pseudocount: added to every count so a state never seen as "current"
            gets a uniform row instead of a division by zero.

    Returns:
        (k, k) ndarray with P[i, j] = p(next = state_order[j] | current = state_order[i]).
    """
    if trans_df is None:
        trans_df = trans
    if state_order is None:
        state_order = states

    sub = trans_df[trans_df["condition"] == cond]
    k = len(state_order)
    P = np.zeros((k, k))
    for i, s in enumerate(state_order):
        counts = (sub.loc[sub["state"] == s, "state_next"]
                  .value_counts()
                  .reindex(state_order)
                  .fillna(0)
                  .to_numpy(dtype=float)) + pseudocount
        P[i, :] = counts / counts.sum()
    return P

# === 3) 模拟 hypnogram 并计算指标 ===
def simulate_from_P(P, n_epochs, state_names=None, start=0, seed=0):
    """Simulate a hypnogram of ``n_epochs`` epochs from transition matrix ``P``.

    Args:
        P: (k, k) row-stochastic matrix; row i holds p(next | state i).
        n_epochs: number of epochs to generate; <= 0 returns an empty list.
        state_names: labels for the k states; defaults to the module-level
            ``states``.
        start: index of the initial state (0 = first state, "Wake" by default).
        seed: RNG seed. A fresh generator is created per call, so repeated
            calls with the same arguments return the same trajectory.

    Returns:
        list of state labels, one per epoch.
    """
    if state_names is None:
        state_names = states
    if n_epochs <= 0:
        return []
    P = np.asarray(P, dtype=float)
    rng = np.random.default_rng(seed)
    x = start
    traj = [state_names[x]]
    for _ in range(n_epochs - 1):
        x = rng.choice(len(state_names), p=P[x])
        traj.append(state_names[x])
    return traj

def metrics(arr, epoch_min=0.5):
    """Sleep metrics for one hypnogram.

    Args:
        arr: sequence of state labels ("Wake", "N1N2", "N3", "REM"), one per epoch.
        epoch_min: epoch length in minutes (default 0.5, i.e. 30 s epochs).

    Returns:
        dict with:
          TST: total sleep time (hours).
          SE: sleep efficiency (% of the recording); NaN for an empty recording.
          SOL: minutes until the first non-Wake epoch; NaN if there is no sleep.
          N3p, REMp: N3 and REM share of sleep epochs (%).
          Awakenings: number of sleep -> Wake transitions.
    """
    arr = np.asarray(arr)
    if arr.size == 0:
        # Avoid the 0/0 for SE on an empty recording.
        return dict(TST=0.0, SE=np.nan, SOL=np.nan, N3p=0.0, REMp=0.0, Awakenings=0)

    sleep = arr != "Wake"
    n_sleep = int(sleep.sum())
    T_total = len(arr) * epoch_min / 60
    TST = n_sleep * epoch_min / 60
    SE = 100 * TST / T_total
    SOL = (np.argmax(sleep) * epoch_min) if sleep.any() else np.nan
    N3p = 100 * (arr == "N3").sum() / max(1, n_sleep)
    REMp = 100 * (arr == "REM").sum() / max(1, n_sleep)
    Aw = ((arr[:-1] != "Wake") & (arr[1:] == "Wake")).sum()
    return dict(TST=TST, SE=SE, SOL=SOL, N3p=N3p, REMp=REMp, Awakenings=Aw)

# Simulated night length: median (truncated) of the per-night last epoch index, plus one.
n_epochs = long.groupby(["subject","night"])["epoch"].max().median().astype(int)+1
# One simulated hypnogram per condition, summarised with the same metrics as observed data.
rows = [
    dict(condition=cond, **metrics(simulate_from_P(empirical_P(cond), n_epochs)))
    for cond in ["A","B","C"]
]
pd.DataFrame(rows).to_csv("simulated_metrics.csv", index=False, encoding="utf-8-sig")

# === 4) Observed metrics (per subject x night) ===
obs=[]
for (subj, night), g in long.groupby(["subject","night"]):
    arr = g.sort_values("epoch")["state"].to_numpy()
    # .get keeps nights that have no mapped condition (condition=None) instead
    # of raising KeyError, consistent with long["night"].map(cond_map).
    obs.append(dict(subject=subj, condition=cond_map.get(night), **metrics(arr)))
pd.DataFrame(obs).to_csv("observed_metrics.csv", index=False, encoding="utf-8-sig")

# === 5) Visualisation ===
for cond in ["A","B","C"]:
    P = empirical_P(cond)

    # Heat map of this condition's empirical transition matrix
    fig_p, ax_p = plt.subplots(figsize=(4,3))
    im = ax_p.imshow(P, vmin=0, vmax=1, aspect="auto")
    ax_p.set_xticks(range(4))
    ax_p.set_xticklabels(states)
    ax_p.set_yticks(range(4))
    ax_p.set_yticklabels([f"{s}→" for s in states])
    ax_p.set_title(f"Empirical transition matrix - {cond}")
    fig_p.colorbar(im, ax=ax_p, label="p(next|current)")
    fig_p.tight_layout()
    fig_p.savefig(f"P_{cond}.png", dpi=160)

    # One simulated hypnogram; simulate_from_P seeds its RNG on every call,
    # so this repeats the trajectory used for simulated_metrics.csv.
    traj = simulate_from_P(P, n_epochs)
    y = [states.index(s) for s in traj]
    fig_h, ax_h = plt.subplots(figsize=(7,1.8))
    ax_h.plot(y, lw=0.7)
    ax_h.set_yticks(range(4))
    ax_h.set_yticklabels(states)
    ax_h.set_xlabel("Epoch (30s)")
    ax_h.set_title(f"Hypnogram (sim) - {cond}")
    fig_h.tight_layout()
    fig_h.savefig(f"hypnogram_{cond}.png", dpi=160)
