In [5]:
# ==============================================================
#  extract_chain_median.py  ✧  单链中位数导出（与 Simu.py 变量名对齐）
#  --------------------------------------------------------------
#  用法：
#    1) 确认 CHAIN_IDX=3，并将 CHAIN_DRAW_FILE 指向第三条链的 pkl
#    2) python extract_chain_median.py
#  产物：
#    saved_result/adult_params_chain3_median.json
#    saved_result/adult_params_chain3_median.npy
# ==============================================================

import os, pickle, json, datetime
import numpy as np

# === Configuration: variable names kept identical to the existing scripts ===
PARAM_SOURCE = "mcmc"   # one of {"init", "modfit", "mcmc", "file"}; placeholder only, not used in any computation here
CHAIN_IDX = 3           # selects the third chain; also embedded in the input and output file names
CHAIN_DRAW_FILE = f"saved_result/chain{CHAIN_IDX}_draws2025-07-19.pkl"
N_SAMPLES = 500         # placeholder; not used by this script

In [6]:
# === Load the draws of the selected chain ===========================
# Fail early with a readable message if the configured file is missing.
if not os.path.exists(CHAIN_DRAW_FILE):
    raise FileNotFoundError(f"找不到链文件：{CHAIN_DRAW_FILE}")

# NOTE(security): pickle.load can execute arbitrary code while unpickling.
# Only point CHAIN_DRAW_FILE at files produced by a trusted pipeline.
with open(CHAIN_DRAW_FILE, "rb") as f:
    chain_draws = pickle.load(f)   # expected shape: (n_draw, n_param + ...)

# Accept list-like payloads by converting them to an ndarray.
if not isinstance(chain_draws, np.ndarray):
    chain_draws = np.asarray(chain_draws)

# A non-2-D payload (e.g. a dict or a flat list) would otherwise surface as
# an opaque IndexError on shape[1]; report it explicitly instead.
if chain_draws.ndim != 2:
    raise ValueError(
        f"链抽样应为二维数组 (n_draw, n_col)，实际维度为 {chain_draws.ndim}。"
    )

# Keep only the first 10 columns (the model parameters); later columns
# such as sigma are ignored.
if chain_draws.shape[1] < 10:
    raise ValueError(f"此链参数列数不足 10 列，实际为 {chain_draws.shape[1]} 列。")

param_draws = chain_draws[:, :10]

# Discard the first 10% of draws as burn-in. The same fraction is recorded
# as "burn_in_fraction" in the JSON metadata, so keep the two in sync.
start = int(0.1 * len(param_draws))
# int(0.1 * n) < n for every n >= 1, so this guard fires only for an empty chain.
if start >= len(param_draws):
    raise ValueError("样本量过小，无法丢弃 burn-in。")
param_draws = param_draws[start:, :]

In [3]:
# Per-parameter posterior median: one value per retained parameter column.
param_median = np.median(param_draws, axis=0)

# Validate before touching the disk: json.dump would otherwise emit the
# non-standard tokens NaN/Infinity, silently producing an unusable baseline.
if not np.all(np.isfinite(np.asarray(param_median, dtype=float))):
    raise ValueError("参数中位数包含 NaN/Inf，请检查链抽样文件。")

# === Save outputs =====================================================
SAVE_DIR = "saved_result"
os.makedirs(SAVE_DIR, exist_ok=True)
today = datetime.datetime.now().strftime("%Y-%m-%d")

# JSON export: metadata plus parameter names aligned index-by-index with
# the median vector.
out_json = {
    "meta": {
        "chain_idx": CHAIN_IDX,
        "source_file": os.path.basename(CHAIN_DRAW_FILE),
        "burn_in_fraction": 0.10,
        "created_at": today,
        "note": "前10个模型参数的后验中位数，用作成人基线参数。"
    },
    # Names are derived from the vector length so the two lists can never
    # drift apart.
    "param_names": [f"theta{i+1}" for i in range(len(param_median))],
    "param_median": param_median.tolist()
}
json_path = os.path.join(SAVE_DIR, f"adult_params_chain{CHAIN_IDX}_median.json")
# ensure_ascii=False keeps the Chinese note human-readable in the file.
with open(json_path, "w", encoding="utf-8") as f:
    json.dump(out_json, f, ensure_ascii=False, indent=2)

In [4]:
# Also persist the median vector as .npy so later scripts can load it quickly.
npy_path = os.path.join(SAVE_DIR, f"adult_params_chain{CHAIN_IDX}_median.npy")
np.save(npy_path, param_median)

# Console report: header, one line per parameter (1-based, zero-padded
# index), then the paths of both saved files. Built as a list and emitted
# with a single print call.
report_lines = ["✅ 已生成成人基线参数（后验中位数）："]
report_lines.extend(
    f"  theta{position:02d} = {median_value:.6g}"
    for position, median_value in enumerate(param_median, start=1)
)
report_lines.append(f"\n已保存：\n  - {json_path}\n  - {npy_path}")
print("\n".join(report_lines))

✅ 已生成成人基线参数（后验中位数）：
  theta01 = 0.00143049
  theta02 = 0.00154262
  theta03 = 9.58456
  theta04 = 0.21246
  theta05 = 1.69697
  theta06 = 0.854916
  theta07 = 1.56245
  theta08 = 16.7723
  theta09 = 0.0287982
  theta10 = 0.000606328

已保存：
  - saved_result\adult_params_chain3_median.json
  - saved_result\adult_params_chain3_median.npy
