In [None]:
# ==============================================
# Twitch LLM QA Experiment (Resume-safe, 1-cell)
#  - NUM_LIVE==1 の時は、EXCLUDE_LIVE_IDS を避けてランダムで1件選択
#  - MCQ生成: 古いプロンプト（そのまま）
#  - 回答: 新しいプロンプト（具体的な数値例を出さない）
#  - 選択肢シャッフル: 安定乱数で permute（GTも更新）
#  - 進捗: tqdm
#  - 保存: Local + GCS（レジューム可：既存があればスキップ）
#  - 追加: 各ステージの pred_index 分布を表示
# ==============================================

from __future__ import annotations
import os, io, re, json, textwrap, datetime, hashlib, time, random
from typing import Dict, Any, List, Tuple, Optional

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import fsspec
import vertexai
from vertexai.generative_models import GenerativeModel, GenerationConfig

# ====================== CONFIG (adjust values here) ======================
PROJECT_ID           = "dena-ai-intern-ds-dev-gcp"
LOCATION             = "us-central1"
MODEL_NAME           = "gemini-2.5-pro"
TEMPERATURE          = 0   # passed to GenerationConfig for every model call

# ---- Input (Twitch pairs CSV, fixed per RUN_ID) ----
RUN_ID               = "20250910_190657"  # ★ set this to your Twitch RUN_ID
PAIRS_CSV_PATH       = f"gs://dena-ai-intern-yoshihara-data/yoshi_LLMQA_twitch_pairs/llmqa_pairs_{RUN_ID}.csv"

# ---- Target live (VOD) selection ----
NUM_LIVE             = 15                 # number of vod_ids to process (see select_live_ids)
LIVE_IDS_EXPLICIT: List[str] = []         # explicit vod_ids, e.g. ["2562056128"]; empty -> automatic selection
RANDOM_ONE_IF_SINGLE = True               # if NUM_LIVE == 1 and LIVE_IDS_EXPLICIT is empty, pick one vod_id at random
EXCLUDE_LIVE_IDS     = ["2561384383"]     # ★ vod_ids excluded from the random single pick (e.g. ones that were too heavy); several allowed

# ---- Filters / misc ----
MIN_COMBINED_CHARS   = 500                # chunks whose stripped combined text is shorter than this are skipped
RNG_SEED_BASE        = 20240917           # base seed for the stable (deterministic) choice shuffle
CHECKPOINT_EVERY     = 100                # intended progress-CSV save interval; NOTE(review): not referenced elsewhere in this cell

# ---- Output (new path; the same layout is saved locally and on GCS) ----
GCS_BASE_PREFIX      = "gs://dena-ai-intern-yoshihara-data/yoshi_LLMQA_twitch_runs"
RESUME_RUN_LABEL     = ""                 # set to an existing run label to resume it; empty -> new label is generated
RUN_LABEL            = ""                 # set only if you want an explicit label for a new run
# ===============================================================


# ====================== Helpers (FS / Model / JSON / Paths) ======================
def _dedent(s: str) -> str:
    return textwrap.dedent(s).strip()

def fs_gcs():
    """Return an fsspec filesystem handle for Google Cloud Storage (protocol "gcs")."""
    gcs_protocol = "gcs"
    return fsspec.filesystem(gcs_protocol)

def read_gcs_text(path: str) -> str:
    """Read the GCS object at *path* in text mode and return its full contents."""
    with fs_gcs().open(path, "r") as handle:
        contents = handle.read()
    return contents

def write_gcs_text(path: str, text: str) -> None:
    """Write *text* to the GCS object at *path* in text mode, replacing any existing content."""
    with fs_gcs().open(path, "w") as handle:
        handle.write(text)

def gcs_exists(path: str) -> bool:
    """Return whether *path* exists on GCS; an error during the check counts as "absent".

    Creating the filesystem handle happens outside the guarded section, so a
    failure there still propagates to the caller.
    """
    filesystem = fs_gcs()
    found = False
    try:
        found = filesystem.exists(path)
    except Exception:
        pass
    return found

def gcs_glob(prefix: str, pattern: str="**") -> List[str]:
    """List GCS paths matching ``<prefix>/<pattern>``, sorted; returns [] on any error."""
    filesystem = fs_gcs()
    query = f"{prefix}/{pattern}"
    try:
        return sorted(filesystem.glob(query))
    except Exception:
        return []

def init_vertex_ai(model_name: str) -> GenerativeModel:
    """Initialise the Vertex AI SDK for PROJECT_ID / LOCATION and return a handle for *model_name*."""
    vertexai.init(location=LOCATION, project=PROJECT_ID)
    model = GenerativeModel(model_name)
    return model

def _robust_json_loads(text: str) -> Optional[Any]:
    """Best-effort extraction of one JSON value from an LLM reply.

    Tries, in order: the whole text; the ``{...}`` inside a ``` / ```json
    fenced block; the span from the first "{" to the last "}". Returns None
    if none of them parses.
    """
    try:
        return json.loads(text)
    except ValueError:  # json.JSONDecodeError is a ValueError
        pass
    fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, flags=re.DOTALL)
    if fenced:
        try:
            return json.loads(fenced.group(1))
        except ValueError:
            pass
    # Greedy DOTALL match already spans from the first "{" through the last "}".
    braced = re.search(r"\{.*\}", text, flags=re.DOTALL)
    if braced:
        try:
            return json.loads(braced.group(0))
        except ValueError:
            pass
    return None


def call_model_return_both(model: GenerativeModel, content: str) -> Tuple[Dict[str, Any], str]:
    """Send *content* to *model* (temperature TEMPERATURE) and parse the reply as JSON.

    Args:
        model: Vertex AI model handle.
        content: full prompt text sent as a single content part.

    Returns:
        ``(obj, raw)``. ``raw`` is the reply text, or "" if the call failed or
        the response had no readable text. ``obj`` is the parsed, non-empty JSON
        object, or ``{"raw_text": raw}`` when no JSON object could be extracted.
        ``obj`` is always a dict; a JSON list/number/string reply is wrapped too.

    API errors are not raised. They are printed and treated as an empty reply,
    so callers keep their best-effort behaviour.
    """
    cfg = GenerationConfig(temperature=TEMPERATURE)
    raw = ""
    try:
        resp = model.generate_content([content], generation_config=cfg)
    except Exception as e:
        print(f"[call_model] generate_content failed: {type(e).__name__}: {e}")
    else:
        try:
            raw = resp.text or ""
        except Exception as e:
            # .text can raise when the response carries no text part
            # (presumably e.g. a blocked/empty candidate) — treat as empty.
            print(f"[call_model] response has no readable text: {type(e).__name__}: {e}")
    parsed = _robust_json_loads(raw)
    obj = parsed if isinstance(parsed, dict) and parsed else {"raw_text": raw}
    return obj, raw

def local_base_dir() -> str:
    """Return this run's local output directory: llmqa_twitch_runs/<RUN_LABEL>."""
    root = "llmqa_twitch_runs"
    return os.path.join(root, RUN_LABEL)

def gcs_base_prefix() -> str:
    """Return this run's GCS output prefix: <GCS_BASE_PREFIX>/<RUN_LABEL>."""
    return "/".join([GCS_BASE_PREFIX, RUN_LABEL])

def ensure_local_dir(path: str) -> None:
    """Create directory *path* together with any missing parents; no-op if it already exists."""
    os.makedirs(path, mode=0o777, exist_ok=True)

def write_local_text(path: str, text: str) -> None:
    """Write *text* to the local file *path* as UTF-8, creating parent directories first.

    A bare filename (no directory component) is written into the current
    working directory. Calling ``os.makedirs("")`` for it would raise
    FileNotFoundError, so directory creation is skipped in that case.
    """
    parent = os.path.dirname(path)
    if parent:
        ensure_local_dir(parent)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)

def write_local_json(path: str, obj: Any) -> None:
    """Serialize *obj* as indented JSON (non-ASCII kept as-is) into the local file *path*."""
    serialized = json.dumps(obj, indent=2, ensure_ascii=False)
    write_local_text(path, serialized)

def write_both_text(rel_path: str, text: str) -> None:
    """Write *text* under *rel_path* in this run's local directory and GCS prefix.

    The local write must succeed, and its errors propagate. The GCS mirror is
    best-effort: a failure is reported on stdout but does not stop the run.
    Previously such failures were swallowed silently, leaving an incomplete
    GCS copy unnoticed.
    """
    lpath = os.path.join(local_base_dir(), rel_path)
    write_local_text(lpath, text)
    gpath = f"{gcs_base_prefix()}/{rel_path}"
    try:
        write_gcs_text(gpath, text)
    except Exception as e:
        print(f"[write_both_text] GCS write failed for {gpath}: {type(e).__name__}: {e}")

def write_both_json(rel_path: str, obj: Any) -> None:
    """Serialize *obj* to indented JSON and write it locally and to GCS under *rel_path*."""
    payload = json.dumps(obj, indent=2, ensure_ascii=False)
    write_both_text(rel_path, payload)

def exists_local(rel_path: str) -> bool:
    """Return whether *rel_path* exists under this run's local output directory."""
    full_path = os.path.join(local_base_dir(), rel_path)
    return os.path.exists(full_path)

def exists_gcs(rel_path: str) -> bool:
    """Return whether *rel_path* exists under this run's GCS output prefix."""
    full_path = f"{gcs_base_prefix()}/{rel_path}"
    return gcs_exists(full_path)

def exists_any(rel_path: str) -> bool:
    """Return whether *rel_path* exists locally or, checked only when not local, on GCS."""
    if exists_local(rel_path):
        return True
    return exists_gcs(rel_path)

def read_local_json(rel_path: str) -> Optional[Any]:
    """Load JSON from *rel_path* under the local run directory; None if missing or unreadable."""
    full_path = os.path.join(local_base_dir(), rel_path)
    try:
        with open(full_path, "r", encoding="utf-8") as handle:
            data = json.load(handle)
    except Exception:
        data = None
    return data

def read_gcs_json(rel_path: str) -> Optional[Any]:
    """Load JSON from *rel_path* under this run's GCS prefix; None if missing or unreadable."""
    full_path = f"{gcs_base_prefix()}/{rel_path}"
    try:
        text = read_gcs_text(full_path)
        return json.loads(text)
    except Exception:
        return None

def read_any_json(rel_path: str) -> Optional[Any]:
    """Load JSON for *rel_path*, preferring the local copy and falling back to GCS.

    Returns None only when no copy could be read (or the file holds JSON
    ``null``). A local file that decodes to a falsy value such as ``{}`` or
    ``[]`` is returned as-is. The previous ``local or gcs`` discarded it and
    could turn it into None.
    """
    local = read_local_json(rel_path)
    if local is not None:
        return local
    return read_gcs_json(rel_path)


# ====================== Prompts ======================
def prompt_for_mcq_fixed() -> str:
    """Return the fixed ("old") MCQ-generation prompt (Japanese), dedented and stripped.

    generate_mcq_for_chunk appends the chunk's combined log text after this
    prompt. The prompt asks for exactly 2 four-choice questions as JSON only:
    Q1 uses a fixed "which topic appeared in the stream" wording. Q2 is a
    template whose {TOPIC} placeholder should be Q1's correct choice.

    NOTE(review): the rules forbid direct quotation in "explanation", yet the
    Q1 output-format example asks for evidence as a direct quote. The two
    instructions conflict.
    NOTE(review): the format example uses answer_index 0 for both questions,
    and the one-shot example uses 0 and 3. This may bias the generated
    indices; the later stable shuffle presumably exists to counter that
    (confirm).
    """
    # ★ The MCQ-generation prompt intentionally uses the "old" version verbatim
    #   (its example contains answer_index: 0).
    return _dedent("""
        以下の『配信ログ本文』だけを根拠に、ライブ配信に関する4択問題を日本語で作成してください。外部知識の持ち込みは禁止です。

        # 出題ルール（厳守）
        - 問題数: 2 問（ちょうど2問。増減しない）
        - 1問目は、次の固定文言をそのまま問題文に使う：
          『配信中に出てきた話題は、以下の四つの選択肢のうちどれが正しいですか？』
        - 2問目は、次のテンプレートの {TOPIC} を 1問目の正解選択肢のテキスト（短い名詞句）に置換して使う：
          『配信中、{TOPIC}に関して行われた会話の内容は、以下のどれか？』
          ※ 出力時に {TOPIC} を残さないこと。
        - 似た選択肢は作らない。正解は各問ちょうど1つ。
        - あいまい表現や主観的解釈は禁止。本文の根拠のみ。
        - 挨拶など普遍的な内容は題材にしない。
        - 出力は JSON のみ。説明文やコードフェンス（```）は禁止。
        - 以下に与えた具体的な例とは異なる問題と選択肢を作成する。
        - 配信者/リスナーの区別や経過時間を問題文に含めない。
        - 説明(explanation)はRecitation回避のため原文の直接引用禁止。根拠は短い要約（100文字以内）。
        - 配信ログ本文を読んだら、必ず正解が選べるような問題を作ってください。

        # 出力フォーマット（厳守）
        {
          "questions": [
            {
              "question": "配信中に出てきた話題は、以下の四つの選択肢のうちどれが正しいですか？",
              "choices": ["選択肢A", "選択肢B", "選択肢C", "選択肢D"],
              "answer_index": 0,
              "explanation": "本文からの根拠（直接引用）"
            },
            {
              "question": "配信中、{TOPIC}に関して行われた会話の内容は、以下のどれか？",
              "choices": ["選択肢A", "選択肢B", "選択肢C", "選択肢D"],
              "answer_index": 0,
              "explanation": "本文からの根拠"
            }
          ]
        }

        # 例（one-shot）
        一つ目の問い
        問題文：配信中に出てきた話題は、以下の四つの選択肢のうちどれが正しいですか？
        選択肢A：唐揚げ
        選択肢B：不動産投資
        選択肢C：大阪の万博
        選択肢D：オーストラリアでのマラソン
        正解のインデックス：0

        二つ目の問い
        問題文：配信中、{TOPIC}に関して行われた会話の内容は、以下のどれか？
        選択肢A：唐揚げはもも肉より胸肉の方が良い
        選択肢B：唐揚げの値上げが嫌だ
        選択肢C：いつも唐揚げ食べると胃もたれする
        選択肢D：唐揚げにはレモンをかけるべきか否か
        正解のインデックス：3

        # 配信ログ本文
    """)

def prompt_for_answering(qname: str, source_label: str) -> str:
    """Return the (Japanese) answering prompt for a single question, dedented and stripped.

    Args:
        qname: question name such as "Q1"/"Q2". Its lowercase form names the
            JSON key the model must fill ("<qname lower>_index", e.g. "q1_index").
        source_label: Japanese description of the allowed evidence. It is
            inserted into the instruction sentence.

    The model is asked to reply with JSON only, of the form
    {"answers": {"<qname lower>_index": <integer 0..3>}}. Doubled braces in the
    f-string emit literal "{" / "}".

    NOTE(review): the returned text ends with a "# 本文" (body) heading.
    build_answer_content appends the question JSON right after it and then adds
    its own "# 本文" section. Confirm this layout is intended.
    """
    # ★ New version: shows only a placeholder, never a concrete example value like 0.
    return _dedent(f"""
        次の『{source_label}だけ』を根拠に、与えられた1問({qname})の4択に回答してください。
        - 出力は JSON のみ（コードフェンス禁止）。
        - 回答は 0..3 の整数インデックス（choices 配列の添字）。
        - {{0から3までの数値}}の部分はLLMの回答の数値で置き換えてください。

        # 出力フォーマット
        {{
          "answers": {{
            "{qname.lower()}_index": {{0から3までの数値を入れてください}}
          }}
        }}

        # {qname}（question & choices のみ）
        # この後に {qname} のJSONを貼ります（answer_index/explanation は含みません）。

        # 本文
    """)

# ====================== Pairs CSV & Task building ======================
def load_pairs_df() -> pd.DataFrame:
    """Load the Twitch pairs CSV (PAIRS_CSV_PATH) from GCS and normalise its key columns.

    Expected columns: vod_id, chunk_idx, combined_path, chatonly_path.

    Returns:
        DataFrame with ``vod_id`` as str and ``chunk_idx`` as int. Rows whose
        chunk_idx is missing or non-numeric are dropped, and the number dropped
        is printed. Previously ``errors="coerce"`` produced NaN, which then made
        ``.astype(int)`` raise.

    Raises:
        ValueError: if any expected column is missing.
    """
    txt = read_gcs_text(PAIRS_CSV_PATH)
    # Read vod_id as text so that long numeric IDs are not parsed as floats
    # ("2562056128.0") when the column has blanks, and leading zeros survive.
    df = pd.read_csv(io.StringIO(txt), dtype={"vod_id": str})
    need = ["vod_id", "chunk_idx", "combined_path", "chatonly_path"]
    miss = [c for c in need if c not in df.columns]
    if miss:
        raise ValueError(f"pairs CSV missing columns: {miss}")
    df["vod_id"] = df["vod_id"].astype(str)
    chunk = pd.to_numeric(df["chunk_idx"], errors="coerce")
    bad = chunk.isna()
    if bad.any():
        print(f"[load_pairs_df] dropping {int(bad.sum())} row(s) with non-numeric chunk_idx")
        df = df.loc[~bad].copy()
        chunk = chunk.loc[~bad]
    df["chunk_idx"] = chunk.astype(int)
    return df

def select_live_ids(df_pairs: pd.DataFrame, k: int) -> List[str]:
    """Choose which vod_ids (lives) to process.

    Selection priority:
      1. LIVE_IDS_EXPLICIT (if non-empty): its IDs that occur in *df_pairs*,
         de-duplicated, in the given order, at most *k*.
      2. k == 1 and RANDOM_ONE_IF_SINGLE: one vod_id picked at random among
         those not in EXCLUDE_LIVE_IDS, or among all IDs if that leaves none.
         If the current run already saved ``selected_live_ids.json`` (i.e. it
         is being resumed), that earlier pick is reused. A resume previously
         picked a new random live and orphaned the work already done.
      3. Otherwise: the first *k* distinct vod_ids in CSV order.

    Args:
        df_pairs: pairs table with a ``vod_id`` column.
        k: number of lives requested (NUM_LIVE).

    Returns:
        List of vod_id strings; empty when *df_pairs* has no vod_ids or k <= 0.
    """
    # Distinct vod_ids in first-appearance order (dicts keep insertion order).
    uniq_all = list(dict.fromkeys(df_pairs["vod_id"].astype(str).tolist()))
    available = set(uniq_all)
    limit = max(0, k)

    # 1) Explicit IDs take precedence.
    if LIVE_IDS_EXPLICIT:
        wanted = dict.fromkeys(str(v) for v in LIVE_IDS_EXPLICIT)
        return [v for v in wanted if v in available][:limit]

    # 2) Single random pick, excluding EXCLUDE_LIVE_IDS.
    if k == 1 and RANDOM_ONE_IF_SINGLE:
        if not uniq_all:
            return []
        saved = read_any_json("selected_live_ids.json")
        prev_ids = saved.get("live_ids") if isinstance(saved, dict) else None
        if isinstance(prev_ids, list):
            reused = [str(v) for v in prev_ids if str(v) in available]
            if reused:
                print(f"[select_live_ids] RESUME: reusing saved pick {reused[0]}")
                return reused[:1]
        excluded = {str(v) for v in EXCLUDE_LIVE_IDS}
        candidates = [v for v in uniq_all if v not in excluded] or uniq_all
        # A private Random (seeded from OS entropy) differs on every run without
        # reseeding the global `random` module state as random.seed(time) did.
        choice = random.Random().choice(candidates)
        print(f"[select_live_ids] RANDOM pick (excluding {EXCLUDE_LIVE_IDS}): {choice}")
        return [choice]

    # 3) Otherwise the first k distinct vod_ids.
    return uniq_all[:limit]

def make_run_label() -> str:
    """Build a fresh run label: run_<RUN_ID>_<n>lives_<YYYYmmdd_HHMMSS>.

    n is len(LIVE_IDS_EXPLICIT) when explicit IDs are configured, else NUM_LIVE.
    """
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    if LIVE_IDS_EXPLICIT:
        n_lives = len(LIVE_IDS_EXPLICIT)
    else:
        n_lives = NUM_LIVE
    return f"run_{RUN_ID}_{n_lives}lives_{stamp}"

# ====================== MCQ生成（古いプロンプト） ======================
def generate_mcq_for_chunk(model: GenerativeModel, combined_text: str, save_label: str) -> Optional[Dict[str,Any]]:
    """Generate the two-question MCQ for one chunk using the fixed (old) prompt.

    Args:
        model: Vertex AI model handle.
        combined_text: the chunk's combined log text, appended after the prompt.
        save_label: unused; kept for call-site compatibility.

    Returns:
        On success, the model's JSON object with ``questions`` truncated to two
        entries. Each entry is a dict with a str ``question``, a 4-item
        ``choices`` list and an int ``answer_index`` in 0..3. Q2's topic
        placeholder is replaced by Q1's correct choice.
        On any validation failure, ``{"raw_text": <raw model reply>}``; callers
        detect it by the missing ``questions`` key. Never returns None.

    Validation is explicit rather than via ``assert``, which is stripped under
    ``-O``. It rejects a missing or out-of-range ``answer_index`` and a missing
    ``question``. Those used to pass here and crash later (KeyError/IndexError
    in the shuffle step), or a negative index silently picked the wrong topic.
    """
    prompt = prompt_for_mcq_fixed()
    content = prompt + "\n" + (combined_text or "")
    obj, raw = call_model_return_both(model, content)
    failure = {"raw_text": raw}  # returned so the reply can still be logged

    if not isinstance(obj, dict) or not isinstance(obj.get("questions"), list):
        return failure
    qs = obj["questions"][:2]  # the prompt asks for exactly 2; drop extras
    if len(qs) != 2:
        return failure

    for q in qs:
        if not isinstance(q, dict) or not isinstance(q.get("question"), str):
            return failure
        choices = q.get("choices")
        if not isinstance(choices, list) or len(choices) != 4:
            return failure
        try:
            aidx = int(q["answer_index"])
        except (KeyError, TypeError, ValueError, OverflowError):
            return failure
        if not 0 <= aidx < 4:
            return failure
        q["answer_index"] = aidx

    # Substitute Q1's correct choice into Q2's topic placeholder(s).
    topic = qs[0]["choices"][qs[0]["answer_index"]]
    if isinstance(topic, str) and topic.strip():
        q2q = qs[1]["question"]
        for ph in ("{TOPIC}", "<Q1正解>", "{一つ目の問の答え}"):
            q2q = q2q.replace(ph, topic)
        qs[1]["question"] = q2q
    obj["questions"] = qs
    return obj

def q_full_path(lid: str, idx: int) -> str:
    """Relative path of the full MCQ JSON (including answers) for live *lid*, chunk *idx*."""
    return "/".join(("questions", "full", f"{lid}_{idx}.json"))
def q_pub_path(lid: str, idx: int) -> str:
    """Relative path of the public MCQ JSON (questions and choices only) for live *lid*, chunk *idx*."""
    return "/".join(("questions", "public", f"{lid}_{idx}.json"))

# ====================== Distributions ======================
def gt_distribution_from_full(paths: List[str]) -> Dict[str, Dict[int,float]]:
    """Share of each ground-truth ``answer_index`` (0..3) for Q1 and Q2 before shuffling.

    Args:
        paths: relative paths of full MCQ JSONs (as built by q_full_path).

    Returns:
        ``{"Q1": {0: p0, 1: p1, 2: p2, 3: p3}, "Q2": {...}}``. All values are 0.0
        for a question with no usable records.

    Files that are unreadable or do not hold exactly two questions are skipped.
    Each question is counted independently, and indices outside 0..3 are
    ignored. Previously a negative index was silently counted into a later
    bucket via Python negative indexing, and one bad question also dropped its
    sibling.
    """
    cnt = {"Q1": [0, 0, 0, 0], "Q2": [0, 0, 0, 0]}
    for rel in paths:
        mcq = read_any_json(rel)
        if not (isinstance(mcq, dict) and isinstance(mcq.get("questions"), list)
                and len(mcq["questions"]) == 2):
            continue
        for qname, q in zip(("Q1", "Q2"), mcq["questions"]):
            try:
                a = int(q["answer_index"])
            except (KeyError, TypeError, ValueError, OverflowError):
                continue
            if 0 <= a < 4:
                cnt[qname][a] += 1
    out = {}
    for qname, counts in cnt.items():
        t = max(1, sum(counts))  # avoid division by zero for empty questions
        out[qname] = {i: float(counts[i]) / t for i in range(4)}
    return out

def gt_distribution_from_shuffled(keys: List[str]) -> Dict[str, Dict[int,float]]:
    """Share of each post-shuffle ground-truth index (0..3) for Q1 and Q2.

    Reads shuffled/by_key/<key>.json for every key. Records that are missing,
    are not dicts, or carry an unknown qname or a non-int / out-of-range
    new_gt_index are skipped. A question with no usable records maps to all
    zeros.
    """
    tallies = {qn: [0] * 4 for qn in ("Q1", "Q2")}
    totals = dict.fromkeys(("Q1", "Q2"), 0)
    for key in keys:
        record = read_any_json(f"shuffled/by_key/{key}.json")
        if not isinstance(record, dict):
            continue
        qn = record.get("qname")
        gi = record.get("new_gt_index")
        if qn not in ("Q1", "Q2") or not isinstance(gi, int) or not 0 <= gi < 4:
            continue
        tallies[qn][gi] += 1
        totals[qn] += 1
    return {
        qn: {i: float(tallies[qn][i]) / max(1, totals[qn]) for i in range(4)}
        for qn in ("Q1", "Q2")
    }

# ====================== Shuffle (stable) ======================
def key_tuple_str(lid: str, chunk_idx: int, qname: str) -> str:
    """Build the per-question key "<lid>_<chunk_idx>_<qname>" used in file names."""
    return "_".join((str(lid), str(chunk_idx), str(qname)))

def seed_for_item(base_seed: int, lid: str, chunk_idx: int, qname: str) -> int:
    """Derive a deterministic unsigned 64-bit seed from the item identity using BLAKE2b.

    The hashed material is "<base_seed>|<lid>|<chunk_idx>|<qname>" (UTF-8); the
    8-byte digest is read as a little-endian unsigned integer.
    """
    material = "|".join(str(part) for part in (base_seed, lid, chunk_idx, qname)).encode("utf-8")
    digest = hashlib.blake2b(material, digest_size=8).digest()
    return int.from_bytes(digest, byteorder="little", signed=False)

def stable_shuffle(choices: List[str], gt_index: int, lid: str, chunk_idx: int, qname: str) -> Tuple[List[str], int, List[int]]:
    """Deterministically permute four choices and track where the correct one lands.

    The permutation depends only on (RNG_SEED_BASE, lid, chunk_idx, qname).

    Returns:
        ``(new_choices, new_gt_index, perm)`` where ``perm[new_pos] == old_pos``.
    """
    seed = seed_for_item(RNG_SEED_BASE, lid, chunk_idx, qname)
    perm = np.random.default_rng(seed).permutation(4)
    shuffled = list(map(choices.__getitem__, perm))
    # Position in the new order that holds the original correct choice.
    new_gt = int(np.flatnonzero(perm == gt_index)[0])
    assert shuffled[new_gt] == choices[gt_index]
    return shuffled, new_gt, perm.tolist()

def get_or_make_shuffle_record(lid: str, chunk_idx: int, qname: str, qjson_full: Dict[str,Any]) -> Dict[str,Any]:
    """Return the shuffle record for one question, creating and persisting it when absent.

    An existing shuffled/by_key/<key>.json (local or GCS) is returned unchanged.
    Otherwise the question (index 0 for "Q1", index 1 for any other qname) is
    shuffled with stable_shuffle. The resulting record is written locally and to
    GCS, then returned.
    """
    key = key_tuple_str(lid, chunk_idx, qname)
    rel = f"shuffled/by_key/{key}.json"
    cached = read_any_json(rel)
    if cached is not None:
        return cached
    q = qjson_full["questions"][0 if qname == "Q1" else 1]
    orig_gt = int(q["answer_index"])
    new_choices, new_gt, perm = stable_shuffle(q["choices"], orig_gt, lid, chunk_idx, qname)
    rec = {
        "key": key,
        "live_id": lid,
        "chunk_idx": chunk_idx,
        "qname": qname,
        "question": q["question"],
        "orig_choices": q["choices"],
        "new_choices": new_choices,
        "orig_gt_index": orig_gt,
        "new_gt_index": int(new_gt),
        "perm_new_to_old": perm,
    }
    write_both_json(rel, rec)
    return rec

# ====================== Answering ======================
def build_answer_content(qname: str, source_label: str, qjson_public: Dict[str, Any], body_text: Optional[str]) -> str:
    """Assemble the full answering request: prompt, question JSON, then (optionally) the body.

    The body section is appended only when *body_text* contains non-whitespace text.
    """
    header = prompt_for_answering(qname, source_label)
    payload = json.dumps(qjson_public, ensure_ascii=False, indent=2)
    base = f"""{header}

# {qname} JSON
{payload}
"""
    has_body = bool(body_text and body_text.strip())
    return base + f"\n# 本文\n{body_text}" if has_body else base

def parse_pred_index(obj: Dict[str, Any], qname: str) -> Tuple[Optional[int], int]:
    """Extract the predicted choice index from a parsed model reply.

    Looks up ``obj["answers"]["<qname lower>_index"]`` and falls back to
    ``obj["answers"]["index"]`` when that key is absent. Ints, integral floats
    (2.0) and numeric strings ("2") are accepted. Booleans and non-integral
    floats (2.5) are rejected; ``int()`` previously coerced them into a
    seemingly valid index.

    Returns:
        ``(pred, 0)`` with pred in 0..3 on success, otherwise ``(None, 1)``.
        The second element is the invalid-format flag.
    """
    invalid = (None, 1)
    if not isinstance(obj, dict):
        return invalid
    ans = obj.get("answers", {})
    if not isinstance(ans, dict):
        return invalid
    val = ans.get(f"{qname.lower()}_index", ans.get("index"))
    if isinstance(val, bool):
        return invalid
    if isinstance(val, float) and not val.is_integer():
        return invalid
    try:
        pred = int(val)
    except (TypeError, ValueError, OverflowError):
        return invalid
    return (pred, 0) if 0 <= pred <= 3 else invalid

def stage_answer(model: GenerativeModel, tasks: List[Dict[str, Any]],
                 stage_code: str, source_label: str, include_body: bool) -> List[Dict[str, Any]]:
    """Run one answering stage over *tasks*, persisting one log JSON per question key.

    Resume-safe: a task is skipped when its log (logs/<stage_code>/by_key/<key>.json)
    already exists locally or on GCS, or when it has no shuffle record.

    When *include_body* is true, the evidence text is chosen by the stage code.
    Codes containing "combined" use the task's combined text/path; otherwise
    codes containing "chat" use the chat text/path. In-memory text is preferred;
    otherwise the path is read from GCS, and a read failure yields an empty body.

    Returns:
        The log records written by this call (skipped tasks are not included).
    """
    total = len(tasks)
    progress = tqdm(total=total, desc=f"{stage_code}: 0/{total}", dynamic_ncols=True)

    lowered = stage_code.lower()
    if "combined" in lowered:
        text_field, path_field = "combined_text", "combined_path"
    elif "chat" in lowered:
        text_field, path_field = "chat_text", "chat_path"
    else:
        text_field, path_field = None, None

    def body_for(task: Dict[str, Any]) -> str:
        # Evidence text for this task; "" when no body is wanted or none is available.
        if not include_body or text_field is None:
            return ""
        text = task.get(text_field, "") or ""
        if text or not isinstance(task.get(path_field), str):
            return text
        try:
            return read_gcs_text(task[path_field])
        except Exception:
            return ""

    written: List[Dict[str, Any]] = []
    for finished, task in enumerate(tasks, start=1):
        key = key_tuple_str(task["lid"], task["chunk_idx"], task["qname"])
        log_rel = f"logs/{stage_code}/by_key/{key}.json"
        shuffled = None if exists_any(log_rel) else read_any_json(f"shuffled/by_key/{key}.json")
        if isinstance(shuffled, dict):
            q_pub = {"question": shuffled["question"], "choices": shuffled["new_choices"]}
            content = build_answer_content(task["qname"], source_label, q_pub, body_text=body_for(task))
            obj, raw = call_model_return_both(model, content)
            pred, invalid = parse_pred_index(obj, task["qname"])
            gt_index = int(shuffled["new_gt_index"])
            scorable = invalid == 0 and pred is not None
            correct = int(pred == gt_index) if scorable else None
            log_rec = {
                "key": key,
                "stage": stage_code,
                "live_id": task["lid"],
                "chunk_idx": task["chunk_idx"],
                "qname": task["qname"],
                "source_label": source_label,
                "include_body": bool(include_body),
                "question": shuffled["question"],
                "choices": shuffled["new_choices"],
                "new_gt_index": gt_index,
                "pred_index": (int(pred) if pred is not None else None),
                "invalid_format": int(invalid),
                "correct": correct,
                "prompt_sent": content,
                "raw_response": raw,
                "ts": datetime.datetime.now().isoformat(timespec="seconds"),
            }
            write_both_json(log_rel, log_rec)
            written.append(log_rec)
        progress.set_description_str(f"{stage_code}: {finished}/{total}")
        progress.update(1)
    progress.close()
    return written

def load_stage_logs(stage_code: str) -> pd.DataFrame:
    """Collect every per-key answer log of one stage into a DataFrame.

    Local files under <local run dir>/logs/<stage_code>/by_key/*.json are read
    first. GCS logs are added only for keys not already loaded; de-duplication
    now uses a set instead of rescanning all records for each GCS file.
    Unreadable files and JSON that is not an object are skipped.

    Returns:
        DataFrame with exactly the columns key, stage, live_id, chunk_idx,
        qname, new_gt_index, pred_index, invalid_format, correct. A field
        missing from every record becomes a NaN column instead of raising
        KeyError. An empty frame with these columns is returned if no logs exist.
    """
    cols = ["key","stage","live_id","chunk_idx","qname","new_gt_index","pred_index","invalid_format","correct"]
    recs: List[Dict[str, Any]] = []
    seen_keys = set()

    def _keep(rec: Any) -> None:
        # Only JSON objects are usable as log rows.
        if isinstance(rec, dict):
            recs.append(rec)
            k = rec.get("key")
            if isinstance(k, str):
                seen_keys.add(k)

    # local
    base = os.path.join(local_base_dir(), f"logs/{stage_code}/by_key")
    if os.path.isdir(base):
        for fn in os.listdir(base):
            if not fn.endswith(".json"):
                continue
            try:
                with open(os.path.join(base, fn), "r", encoding="utf-8") as f:
                    rec = json.load(f)
            except Exception:
                continue  # best effort: skip unreadable/corrupt files
            _keep(rec)
    # gcs: only keys not already loaded
    for gp in gcs_glob(f"{gcs_base_prefix()}/logs/{stage_code}/by_key", pattern="*.json"):
        key = os.path.basename(gp).replace(".json","")
        if key in seen_keys:
            continue
        try:
            rec = json.loads(read_gcs_text(gp))
        except Exception:
            continue
        _keep(rec)
    if not recs:
        return pd.DataFrame(columns=cols)
    return pd.DataFrame(recs).reindex(columns=cols)

# ====================== Samples (correct/incorrect 各2件をローカル保存) ======================
def save_stage_samples(stage_code: str, k_each: int = 2) -> None:
    """Save up to *k_each* correct and *k_each* incorrect examples of a stage as local text files.

    Each file (samples/<stage_code>/<correct|incorrect>/<key>.txt under the local
    run dir) contains metadata, the shuffled question, the exact prompt sent and
    the raw model response. Existing sample files are not overwritten.
    """
    logs = load_stage_logs(stage_code)
    for label, flag in (("correct", 1), ("incorrect", 0)):
        picked = logs[logs["correct"] == flag].head(k_each)
        for _, row in picked.iterrows():
            key = row["key"]
            detail = read_any_json(f"logs/{stage_code}/by_key/{key}.json") or {}
            question_view = json.dumps(
                {"question": detail.get("question"), "choices": detail.get("choices")},
                ensure_ascii=False, indent=2,
            )
            text = "\n".join([
                f"### META\nstage: {stage_code}\nkey: {key}\nlive_id: {row['live_id']}\nchunk_idx: {row['chunk_idx']}\nqname: {row['qname']}\n",
                f"GT(new): {row['new_gt_index']}  pred: {row['pred_index']}  correct: {row['correct']}  invalid: {row['invalid_format']}",
                "\n### QUESTION (shuffled)\n" + question_view,
                "\n### PROMPT SENT (exact)\n" + (detail.get("prompt_sent","") or ""),
                "\n### RAW RESPONSE\n" + (detail.get("raw_response","") or ""),
            ])
            out_path = os.path.join(local_base_dir(), f"samples/{stage_code}/{label}/{key}.txt")
            if not os.path.exists(out_path):
                write_local_text(out_path, text)

# ====================== Metrics ======================
def index_distribution(series: pd.Series, k: int=4) -> Dict[int,float]:
    """Return the share of each index 0..k-1 among the values of *series*.

    Values outside 0..k-1 (and NaN) are ignored, including in the denominator;
    if nothing falls in range, every share is 0.0.
    """
    counts = series.value_counts().reindex(range(k), fill_value=0).sort_index()
    total = int(counts.sum())
    if total == 0:
        return dict.fromkeys(range(k), 0.0)
    return {idx: float(counts.iloc[idx]) / total for idx in range(k)}

def summarize_stage(df: pd.DataFrame, qname: str, stage_code: str) -> Dict[str,Any]:
    """Aggregate one stage's logs for one question type.

    Counts rows provided, rows with a valid answer (invalid_format == 0), and
    correct answers among valid rows (missing ``correct`` counts as 0). Accuracy
    is correct/valid, or None when nothing is valid. Also returns the
    distribution of valid predicted indices and the invalid-format count.
    """
    rows = df.loc[(df["qname"] == qname) & (df["stage"] == stage_code)].copy()
    is_valid = rows["invalid_format"] == 0
    n_valid = int(is_valid.sum())
    n_correct = int(rows.loc[is_valid, "correct"].fillna(0).astype(int).sum())
    if n_valid:
        accuracy = float(n_correct / n_valid)
        distribution = index_distribution(rows.loc[is_valid, "pred_index"].dropna().astype(int), 4)
    else:
        accuracy = None
        distribution = {i: 0.0 for i in range(4)}
    return {
        "stage": stage_code,
        "qname": qname,
        "provided": int(len(rows)),
        "answered_valid": int(n_valid),
        "correct_total": int(n_correct),
        "accuracy": accuracy,
        "pred_index_dist": distribution,
        "invalid_format_total": int((rows["invalid_format"] == 1).sum()),
    }

# ====================== Main Flow ======================
# 0) Run label: a non-empty RESUME_RUN_LABEL wins; otherwise an explicit
#    RUN_LABEL is kept, and if that is empty a fresh label is generated.
#    NOTE(review): an explicit RUN_LABEL is used as-is (not .strip()-ed).
if RESUME_RUN_LABEL.strip():
    RUN_LABEL = RESUME_RUN_LABEL.strip()
elif not RUN_LABEL.strip():
    RUN_LABEL = make_run_label()

ensure_local_dir(local_base_dir())
# The manifest is written only if none exists yet (locally or on GCS), so a
# resumed run keeps its original creation metadata.
if not exists_any("manifest.json"):
    manifest = {
        "run_label": RUN_LABEL,
        "run_id": RUN_ID,
        "num_live_requested": NUM_LIVE,
        "rng_seed_base": RNG_SEED_BASE,
        "model_name": MODEL_NAME,
        "temperature": TEMPERATURE,
        "created_at": datetime.datetime.now().isoformat(timespec="seconds"),
        "pairs_csv": PAIRS_CSV_PATH,
        "random_one_if_single": RANDOM_ONE_IF_SINGLE,
        "exclude_live_ids": EXCLUDE_LIVE_IDS,
    }
    write_both_json("manifest.json", manifest)

print(f"RUN_LABEL: {RUN_LABEL}")
print(f"Local out: {local_base_dir()}")
print(f"GCS out:   {gcs_base_prefix()}")
print(f"Pairs CSV: {PAIRS_CSV_PATH}")

# 1) Load pairs & select lives (random if single & not explicit).
#    The selection is (re)written to selected_live_ids.json on every run.
pairs_df = load_pairs_df()
target_lives = select_live_ids(pairs_df, NUM_LIVE)
write_both_json("selected_live_ids.json", {"live_ids": target_lives})
print("選択 live_id:", target_lives)

# 2) Build chunk list for target lives, ordered by vod_id then chunk_idx.
#    vod_id is a str column (see load_pairs_df), so vod_id ordering is lexicographic.
sub_pairs = pairs_df[pairs_df["vod_id"].isin(target_lives)].copy()
sub_pairs = sub_pairs.sort_values(["vod_id","chunk_idx"])

# 3) MCQ generation (per chunk, using combined).
#    Resume-safe: a chunk whose full and public MCQ JSONs both exist is reused.
#    Chunks that cannot be read, are shorter than MIN_COMBINED_CHARS, or whose
#    model output fails validation are skipped without saving anything, so they
#    are retried on the next run. Skips are counted and reported rather than
#    being silent.
model = init_vertex_ai(MODEL_NAME)

created_mcq = 0
skipped_unreadable = 0   # combined text could not be read from GCS
skipped_short = 0        # combined text shorter than MIN_COMBINED_CHARS
failed_mcq = 0           # model output did not pass validation
paths_full = []
done = 0
with tqdm(total=len(sub_pairs), desc="MCQ: 0/chunks", dynamic_ncols=True) as pbar:
    for row in sub_pairs.itertuples(index=False):
        try:
            lid = str(row.vod_id); idx = int(row.chunk_idx)
            rel_full = q_full_path(lid, idx)
            rel_pub  = q_pub_path(lid, idx)
            if exists_any(rel_full) and exists_any(rel_pub):
                paths_full.append(rel_full)
                continue
            try:
                combined_text = read_gcs_text(row.combined_path)
            except Exception as e:
                skipped_unreadable += 1
                tqdm.write(f"[MCQ] skip {lid}_{idx}: cannot read combined text ({type(e).__name__}: {e})")
                continue
            if len((combined_text or "").strip()) < MIN_COMBINED_CHARS:
                skipped_short += 1
                continue
            mcq = generate_mcq_for_chunk(model, combined_text, save_label=f"mcq_{lid}_{idx}")
            if isinstance(mcq, dict) and mcq.get("questions") and len(mcq["questions"])==2:
                write_both_json(rel_full, mcq)
                q1 = mcq["questions"][0]; q2 = mcq["questions"][1]
                pub = {"questions":[{"question":q1["question"],"choices":q1["choices"][:4]},
                                    {"question":q2["question"],"choices":q2["choices"][:4]}]}
                write_both_json(rel_pub, pub)
                created_mcq += 1
                paths_full.append(rel_full)
            else:
                failed_mcq += 1
        finally:
            # Advance the progress bar exactly once per row, whichever branch ran
            # (including `continue`).
            done += 1
            pbar.set_description_str(f"MCQ: {done}/{len(sub_pairs)}")
            pbar.update(1)
print(f"MCQ created (new): {created_mcq}")
print(f"MCQ skipped: unreadable={skipped_unreadable} too_short={skipped_short} invalid_output={failed_mcq}")

# 4) Distribution BEFORE shuffle
dist_before = gt_distribution_from_full(paths_full)
write_both_json("metrics_dist_before.json", dist_before)
print("\n=== GT index distribution (BEFORE shuffle) ===")
print(json.dumps(dist_before, ensure_ascii=False, indent=2))

# 5) Shuffle choices (stable) & record.
#    A malformed cached question (e.g. answer_index missing or outside 0..3)
#    used to abort the whole run inside stable_shuffle. Such questions are now
#    reported and skipped; their answering tasks then find no shuffle record
#    and are skipped too.
shuffle_keys = []
done = 0
with tqdm(total=len(paths_full), desc="Shuffle: 0/chunks", dynamic_ncols=True) as pbar:
    for rel_full in paths_full:
        # File name is "<lid>_<idx>.json" (see q_full_path). Split on the LAST
        # underscore so that a live id which itself contains "_" still parses.
        base = os.path.basename(rel_full).replace(".json","")
        lid, idx_str = base.rsplit("_", 1)
        idx = int(idx_str)
        mcq = read_any_json(rel_full)
        if isinstance(mcq, dict) and mcq.get("questions") and len(mcq["questions"])==2:
            for qname in ("Q1","Q2"):
                try:
                    rec = get_or_make_shuffle_record(lid, idx, qname, mcq)
                    rec_key = rec["key"]
                except (KeyError, IndexError, TypeError, ValueError) as e:
                    tqdm.write(f"[Shuffle] skip {lid}_{idx}_{qname}: {type(e).__name__}: {e}")
                    continue
                shuffle_keys.append(rec_key)
        done += 1
        pbar.set_description_str(f"Shuffle: {done}/{len(paths_full)}")
        pbar.update(1)

# 6) Distribution AFTER shuffle
dist_after = gt_distribution_from_shuffled(shuffle_keys)
write_both_json("metrics_dist_after.json", dist_after)
print("\n=== GT index distribution (AFTER shuffle) ===")
print(json.dumps(dist_after, ensure_ascii=False, indent=2))

# 7) Build answering tasks
def build_tasks_from_pairs_and_mcq(sub_pairs: pd.DataFrame) -> List[Dict[str,Any]]:
    """Build answering tasks for every chunk in *sub_pairs* that has a valid saved MCQ.

    Each such chunk yields two tasks, Q1 then Q2. Each task records the live id,
    chunk index, combined path, and chat path (None when chatonly_path is not a
    non-empty string), plus a has_chat flag.

    A separate exists_any() check is unnecessary: read_any_json already returns
    None for a missing file. Dropping it saves one GCS round trip per chunk.
    """
    tasks = []
    for row in sub_pairs.itertuples(index=False):
        lid = str(row.vod_id); idx = int(row.chunk_idx)
        mcq = read_any_json(q_full_path(lid, idx))
        if not (isinstance(mcq, dict) and mcq.get("questions") and len(mcq["questions"])==2):
            continue
        has_chat = isinstance(row.chatonly_path, str) and len(row.chatonly_path) > 0
        for qname in ("Q1", "Q2"):
            tasks.append({
                "lid": lid, "chunk_idx": idx,
                "combined_path": row.combined_path,
                "chat_path": row.chatonly_path if has_chat else None,
                "has_chat": bool(has_chat),
                "qname": qname,
            })
    return tasks

all_tasks = build_tasks_from_pairs_and_mcq(sub_pairs)

# 8) Primary stages: every Q1/Q2 task is answered twice — once with the
#    combined log as evidence and once from the question text only.
logs_q1_combined = stage_answer(model, [t for t in all_tasks if t["qname"]=="Q1"],
                                stage_code="Q1_LLM2_combined",
                                source_label="配信ログ本文（チャット＋音声書き起こし）",
                                include_body=True)
logs_q1_nothing  = stage_answer(model, [t for t in all_tasks if t["qname"]=="Q1"],
                                stage_code="Q1_LLM3_nothing",
                                source_label="問題文のみ",
                                include_body=False)

logs_q2_combined = stage_answer(model, [t for t in all_tasks if t["qname"]=="Q2"],
                                stage_code="Q2_LLM6_combined",
                                source_label="配信ログ本文（チャット＋音声書き起こし）",
                                include_body=True)
logs_q2_nothing  = stage_answer(model, [t for t in all_tasks if t["qname"]=="Q2"],
                                stage_code="Q2_LLM7_nothing",
                                source_label="問題文のみ",
                                include_body=False)

# 9) Subset selection (Q1: LLM2 correct & LLM3 incorrect).
#    Logs are reloaded from storage so that results from earlier (resumed) runs
#    are included. An invalid-format LLM3 answer has correct=None (NaN here), and
#    is treated as "incorrect". The merge is an inner join, so only keys present
#    in both stages are considered.
df_q1_c = load_stage_logs("Q1_LLM2_combined")
df_q1_n = load_stage_logs("Q1_LLM3_nothing")
subset_q1 = df_q1_c.merge(df_q1_n, on=["key","live_id","chunk_idx","qname"], suffixes=("_llm2","_llm3"))
subset_q1 = subset_q1[(subset_q1["correct_llm2"]==1) & ((subset_q1["correct_llm3"].isna()) | (subset_q1["correct_llm3"]==0))]
subset_q1_keys = set(subset_q1["key"].tolist())

subset_q1_tasks = [t for t in all_tasks if t["qname"]=="Q1" and key_tuple_str(t["lid"], t["chunk_idx"], "Q1") in subset_q1_keys]
subset_q1_tasks_with_chat = [t for t in subset_q1_tasks if t["has_chat"]]

# Q1 extra: on the subset, answer with chat text only (tasks that have a chat
# file) and again with the question text only.
logs_q1_chat    = stage_answer(model, subset_q1_tasks_with_chat,
                               stage_code="Q1_LLM4_chat",
                               source_label="チャット本文",
                               include_body=True)
logs_q1_nothing2= stage_answer(model, subset_q1_tasks,
                               stage_code="Q1_LLM5_nothing",
                               source_label="問題文のみ",
                               include_body=False)

# 10) Subset selection (Q2: LLM6 correct & LLM7 incorrect); same rules as step 9.
df_q2_c = load_stage_logs("Q2_LLM6_combined")
df_q2_n = load_stage_logs("Q2_LLM7_nothing")
subset_q2 = df_q2_c.merge(df_q2_n, on=["key","live_id","chunk_idx","qname"], suffixes=("_llm6","_llm7"))
subset_q2 = subset_q2[(subset_q2["correct_llm6"]==1) & ((subset_q2["correct_llm7"].isna()) | (subset_q2["correct_llm7"]==0))]
subset_q2_keys = set(subset_q2["key"].tolist())

subset_q2_tasks = [t for t in all_tasks if t["qname"]=="Q2" and key_tuple_str(t["lid"], t["chunk_idx"], "Q2") in subset_q2_keys]
subset_q2_tasks_with_chat = [t for t in subset_q2_tasks if t["has_chat"]]

# Q2 extra: chat only / question text only, on the Q2 subset.
logs_q2_chat    = stage_answer(model, subset_q2_tasks_with_chat,
                               stage_code="Q2_LLM8_chat",
                               source_label="チャット本文",
                               include_body=True)
logs_q2_nothing2= stage_answer(model, subset_q2_tasks,
                               stage_code="Q2_LLM9_nothing",
                               source_label="問題文のみ",
                               include_body=False)

# 11) Samples: for each stage, save up to 2 correct and 2 incorrect examples locally.
for st in ["Q1_LLM2_combined","Q1_LLM3_nothing","Q1_LLM4_chat","Q1_LLM5_nothing",
           "Q2_LLM6_combined","Q2_LLM7_nothing","Q2_LLM8_chat","Q2_LLM9_nothing"]:
    save_stage_samples(st, k_each=2)

# 12) Metrics (distributions, accuracy, etc.), computed from all stored logs.
stage_codes = ["Q1_LLM2_combined","Q1_LLM3_nothing","Q1_LLM4_chat","Q1_LLM5_nothing",
               "Q2_LLM6_combined","Q2_LLM7_nothing","Q2_LLM8_chat","Q2_LLM9_nothing"]
logs_all = []
for st in stage_codes:
    df = load_stage_logs(st)
    if not df.empty:
        logs_all.append(df)
logs_df = pd.concat(logs_all, ignore_index=True) if logs_all else pd.DataFrame(columns=["key","stage","live_id","chunk_idx","qname","new_gt_index","pred_index","invalid_format","correct"])

summaries = {"Q1":{}, "Q2":{}}
for st in ["Q1_LLM2_combined","Q1_LLM3_nothing","Q1_LLM4_chat","Q1_LLM5_nothing"]:
    summaries["Q1"][st] = summarize_stage(logs_df, "Q1", st)
for st in ["Q2_LLM6_combined","Q2_LLM7_nothing","Q2_LLM8_chat","Q2_LLM9_nothing"]:
    summaries["Q2"][st] = summarize_stage(logs_df, "Q2", st)

# Sizes of the conditional subsets built in steps 9 and 10.
subset_sizes = {
    "Q1_total": len(subset_q1_tasks),
    "Q1_with_chat": len(subset_q1_tasks_with_chat),
    "Q2_total": len(subset_q2_tasks),
    "Q2_with_chat": len(subset_q2_tasks_with_chat),
}

def gt_distribution_from_files_or_empty() -> Tuple[Dict[str,Dict[int,float]], Dict[str,Dict[int,float]]]:
    """Reload this run's saved (before, after) GT-index distributions.

    Each file is looked up locally first, then on GCS (via read_any_json).
    The two files are read independently, so one missing or unreadable file no
    longer blanks both. Previously only GCS was consulted, so a failed GCS
    mirror lost both values. A file that cannot be read yields ``{}``.

    Note: after the JSON round trip, the inner index keys are strings ("0".."3").
    """
    before = read_any_json("metrics_dist_before.json")
    after = read_any_json("metrics_dist_after.json")
    return (before if isinstance(before, dict) else {},
            after if isinstance(after, dict) else {})

# Load the saved distributions once; the helper was previously called twice,
# re-reading both files each time. Fall back to the in-memory values from
# steps 4 and 6 if the saved copies could not be read. Once serialized to JSON,
# both produce the same output.
saved_dist_before, saved_dist_after = gt_distribution_from_files_or_empty()
metrics = {
    "run_label": RUN_LABEL,
    "run_id": RUN_ID,
    "num_live": len(target_lives),
    "gt_index_distribution_before": saved_dist_before or dist_before,
    "gt_index_distribution_after": saved_dist_after or dist_after,
    "stage_summaries": summaries,
    "subset_sizes": subset_sizes,
    "generated_at": datetime.datetime.now().isoformat(timespec="seconds"),
}
write_both_json("metrics.json", metrics)

# 13) Human-readable print with pred_dist
def _fmt_stage_line(s: Dict[str,Any]) -> str:
    acc = "NA" if s["accuracy"] is None else f"{s['accuracy']:.3f}"
    return (f"{s['stage']}: provided={s['provided']} valid={s['answered_valid']} "
            f"correct={s['correct_total']} acc={acc} pred_dist={s['pred_index_dist']} "
            f"invalid={s['invalid_format_total']}")

# Per-stage summary lines, grouped by question type.
for qname, stage_list in (
    ("Q1", ["Q1_LLM2_combined","Q1_LLM3_nothing","Q1_LLM4_chat","Q1_LLM5_nothing"]),
    ("Q2", ["Q2_LLM6_combined","Q2_LLM7_nothing","Q2_LLM8_chat","Q2_LLM9_nothing"]),
):
    print(f"\n--- {qname} ---")
    for st in stage_list:
        print(_fmt_stage_line(summaries[qname][st]))

# Sizes of the conditional subsets (combined correct & question-only incorrect).
print("\n=== Conditional subset sizes ===")
print(f"Q1 subset (LLM2 correct & LLM3 incorrect): total={subset_sizes['Q1_total']}; with_chat={subset_sizes['Q1_with_chat']}")
print(f"Q2 subset (LLM6 correct & LLM7 incorrect): total={subset_sizes['Q2_total']}; with_chat={subset_sizes['Q2_with_chat']}")

print(f"\nAll done. Local outputs: {local_base_dir()}")
print(f"GCS outputs:          {gcs_base_prefix()}")
