In [None]:
# --- New ipynb cell: Shuffle choices & GT with full logging (local+GCS), resumable execution ---

from __future__ import annotations
import os, io, json, re, textwrap, datetime, hashlib
from typing import Dict, Any, List, Tuple, Optional

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import fsspec
import vertexai
from vertexai.generative_models import GenerativeModel, GenerationConfig

# ====================== CONFIG ======================
# GCP project / region used for Vertex AI.
PROJECT_ID = "dena-ai-intern-ds-dev-gcp"
LOCATION = "us-central1"

# ---- Input: an existing run, treated as READ-ONLY ----
RUN_TAG = "run_1050"
RUNLOG_PREFIX = f"gs://dena-ai-intern-yoshihara-data/yoshi_LLMQA_run_logs/{RUN_TAG}"
ADOPT_LOG_PATH = f"{RUNLOG_PREFIX}/adopted_chunks_top100.csv"

# ---- Which lives to evaluate ----
NUM_LIVE = 40
# Explicit live_ids to evaluate; when empty, the first NUM_LIVE "ok" lives are picked automatically.
LIVE_IDS_EXPLICIT: List[int] = []

# ---- Shuffle & model ----
# Base seed for the stable shuffle; a per-item seed is derived from it.
RNG_SEED_BASE = 20240917
MODEL_NAME = "gemini-2.5-pro"
TEMPERATURE = 0

# ---- Output (LOCAL + GCS) ----
# Existing GCS artifacts are never modified; all output goes under this dedicated prefix.
GCS_BASE_PREFIX = "gs://dena-ai-intern-yoshihara-data/yoshi_LLMQA_shuffle_eval"

# Resume label: set to an existing run label (e.g. "run_1050_10lives_20250916_094017")
# to resume that run; leave empty to start a new one.
RESUME_RUN_LABEL = "run_1050_40lives_20250917_061555"
# Used only when RESUME_RUN_LABEL is empty; if this is empty too, a label such as
# run_1050_10lives_YYYYmmdd_HHMMSS is generated.
RUN_LABEL = ""

# Save a sample file for every question in every stage (correct and incorrect alike).
SAVE_EVERY_QA_SAMPLE = True

# ====================== Helpers (FS/Vertex/JSON) ======================
def _dedent(s: str) -> str:
    return textwrap.dedent(s).strip()

def fs_gcs():
    """Return an fsspec filesystem handle for Google Cloud Storage (the "gcs" protocol)."""
    gcs_protocol = "gcs"
    return fsspec.filesystem(gcs_protocol)

def init_vertex_ai(model_name: str) -> GenerativeModel:
    """Initialise the Vertex AI SDK for PROJECT_ID / LOCATION and return a handle for *model_name*."""
    vertexai.init(location=LOCATION, project=PROJECT_ID)
    model = GenerativeModel(model_name)
    return model

def read_gcs_text(path: str) -> str:
    """Read the whole object at *path* (a gs:// URL) and return it as text."""
    with fs_gcs().open(path, "r") as handle:
        text = handle.read()
    return text

def write_gcs_text(path: str, text: str) -> None:
    """Overwrite the object at *path* (a gs:// URL) with *text*.

    gcsfs reliably supports only whole-object (over)writes; append is often
    unavailable, so this project stores one record per file.
    """
    gcs = fs_gcs()
    with gcs.open(path, "w") as handle:
        handle.write(text)

def gcs_exists(path: str) -> bool:
    """Return whether *path* exists on GCS; any error during the lookup is reported as False."""
    gcs = fs_gcs()
    try:
        found = gcs.exists(path)
    except Exception:
        found = False
    return found

def gcs_glob(prefix: str, pattern: str="**") -> List[str]:
    """Return the sorted GCS paths matching ``{prefix}/{pattern}``; any error yields []."""
    gcs = fs_gcs()
    query = f"{prefix}/{pattern}"
    try:
        return sorted(gcs.glob(query))
    except Exception:
        return []

def robust_json_loads(text: str) -> Optional[Dict[str, Any]]:
    """Best-effort extraction of a JSON object from an LLM response.

    Strategies, tried in order:
      1. parse the whole text as JSON;
      2. parse the body of a ```json ... ``` (or bare ```) code fence;
      3. parse the span from the first "{" to the last "}".

    Only JSON *objects* are accepted. A strategy that yields a list, number or
    string counts as a miss, and the next strategy is tried.

    Args:
        text: Raw model output. Non-string or blank input yields None.

    Returns:
        The parsed dict, or None if no strategy produced a JSON object.
    """
    if not isinstance(text, str) or not text.strip():
        return None

    candidates: List[str] = [text]
    fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, flags=re.DOTALL)
    if fenced:
        candidates.append(fenced.group(1))
    # Greedy: spans from the first "{" to the last "}" in the text.
    braced = re.search(r"(\{.*\})", text, flags=re.DOTALL)
    if braced:
        candidates.append(braced.group(1))

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except (ValueError, RecursionError):  # JSONDecodeError is a ValueError
            continue
        if isinstance(parsed, dict):
            return parsed
    return None

def call_model_return_both(model: GenerativeModel, content: str,
                           max_retries: int = 3,
                           retry_backoff_sec: float = 5.0) -> Tuple[Dict[str, Any], str]:
    """Send *content* to *model* and return ``(parsed_obj, raw_text)``.

    Exceptions raised by ``generate_content`` are retried up to *max_retries*
    extra times, sleeping ``retry_backoff_sec * attempt`` seconds between tries.
    This matters because stage_answer() writes a final log for every call and
    skips already-logged keys on resume. Without retries, a transient API
    error would be recorded permanently as an invalid answer.

    Args:
        model: Vertex AI generative model handle.
        content: Full prompt text.
        max_retries: Extra attempts after the first failure (0 disables retries).
        retry_backoff_sec: Base delay for the linear backoff between attempts.

    Returns:
        ``(obj, raw)`` where *raw* is the response text ("" if the call failed or
        the response text was unavailable). *obj* is the parsed JSON object, or
        ``{"raw_text": raw}`` when parsing fails. When every attempt raised, that
        dict also carries ``"error"`` with the last exception's repr.
    """
    import time  # local import: only needed for the retry backoff

    cfg = GenerationConfig(temperature=TEMPERATURE)
    attempts = max(0, max_retries) + 1
    resp = None
    last_error: Optional[Exception] = None
    for attempt in range(attempts):
        try:
            resp = model.generate_content([content], generation_config=cfg)
        except Exception as err:  # e.g. transient API errors from the SDK
            last_error = err
            if attempt + 1 < attempts:
                time.sleep(retry_backoff_sec * (attempt + 1))
            continue
        last_error = None
        break

    raw = ""
    if resp is not None:
        try:
            raw = resp.text or ""
        except Exception:
            # Accessing .text can raise (presumably when no text part is returned); treat as empty.
            raw = ""

    obj = robust_json_loads(raw)
    if obj is None:
        obj = {"raw_text": raw}
        if last_error is not None:
            obj["error"] = repr(last_error)
    return obj, raw

# ====================== Prompts ======================
def prompt_for_answering(qname: str, source_label: str) -> str:
    """Build the instruction header of the answering prompt for one question.

    The text tells the model to answer question *qname* using only *source_label*
    and to reply with JSON of the form {"answers": {"<qname lower>_index": <0..3>}}.
    It ends with a "# 本文" heading, after which build_content() appends the
    question JSON and, optionally, the body.

    NOTE: the wording is part of the experiment. Logs of a resumed run were produced
    with exactly this text, so edits here change the stimulus mid-run.
    """
    index_field = f"{qname.lower()}_index"
    template = f"""
        次の『{source_label}だけ』を根拠に、与えられた1問({qname})の4択に回答してください。
        - 出力は JSON のみ（コードフェンス禁止）。
        - 回答は 0..3 の整数インデックス（choices 配列の添字）。
        - {{0から3までの数値}}の部分はLLMの回答の数値で置き換えて下さい

        # 出力フォーマット
        {{
          "answers": {{
            "{index_field}":  {{0から3までの数値を入れてください}}
          }}
        }}

        # {qname}（question & choices のみ）
        # この後に {qname} のJSONを貼ります（answer_index/explanation は含みません）。

        # 本文
    """
    return textwrap.dedent(template).strip()

def build_content(qname: str, source_label: str, qjson: Dict[str, Any], body_text: Optional[str]=None) -> str:
    """Assemble the full prompt text sent to the model for one question.

    Layout:
      1. the instruction header from prompt_for_answering();
      2. a "# {qname} JSON" section with *qjson* (indented, non-ASCII kept);
      3. only if *body_text* has non-whitespace content, a trailing "# 本文" section.

    NOTE(review): the header already ends with a "# 本文" heading, so prompts that
    include a body contain that heading twice. Left as-is so prompts stay identical
    to those already logged in resumed runs.
    """
    question_json = json.dumps(qjson, ensure_ascii=False, indent=2)
    pieces = [
        prompt_for_answering(qname, source_label),
        "",
        f"# {qname} JSON",
        question_json,
        "",
    ]
    content = "\n".join(pieces)
    if not (body_text and body_text.strip()):
        return content
    return content + f"\n# 本文\n{body_text}"

# ====================== Run directory layout (LOCAL + GCS) ======================
def make_run_label() -> str:
    """Return a fresh run label ``{RUN_TAG}_{n}lives_{YYYYmmdd_HHMMSS}`` (local wall-clock time).

    n is len(LIVE_IDS_EXPLICIT) when explicit ids are configured, otherwise NUM_LIVE.
    """
    n_lives = len(LIVE_IDS_EXPLICIT) if LIVE_IDS_EXPLICIT else NUM_LIVE
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    return f"{RUN_TAG}_{n_lives}lives_{stamp}"

def local_base_dir() -> str:
    """Local output root of the current run: ``llmqa_shuffle_eval/<RUN_LABEL>`` (relative path)."""
    root = "llmqa_shuffle_eval"
    return os.path.join(root, RUN_LABEL)

def gcs_base_prefix() -> str:
    """GCS output prefix of the current run: ``<GCS_BASE_PREFIX>/<RUN_LABEL>``."""
    return "/".join((GCS_BASE_PREFIX, RUN_LABEL))

def ensure_local_dir(path: str) -> None:
    """Create directory *path* (including parents) if it does not exist yet.

    An empty *path* denotes the current directory, which always exists, so it is a
    no-op. Without this guard os.makedirs("") raises FileNotFoundError, which bites
    callers that pass os.path.dirname() of a bare file name.
    """
    if not path:
        return
    os.makedirs(path, exist_ok=True)

def write_local_text(path: str, text: str) -> None:
    """Write *text* (UTF-8) to local file *path*, creating parent directories as needed.

    The text goes to a sibling ``<path>.tmp`` first and is then moved into place with
    os.replace(), so an interrupted write never leaves a truncated file. This matters
    for resumption: a key counts as done as soon as its log file exists. A half-written
    JSON would therefore be skipped on every resume, yet still fail to parse.
    """
    parent = os.path.dirname(path)
    if parent:  # dirname is "" for a bare file name; os.makedirs("") would raise
        os.makedirs(parent, exist_ok=True)
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(text)
        os.replace(tmp_path, path)
    except BaseException:
        # Clean up the partial temp file (also on KeyboardInterrupt), then re-raise.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

def write_local_json(path: str, obj: Any) -> None:
    """Serialise *obj* as indented JSON (non-ASCII preserved) into local file *path*."""
    payload = json.dumps(obj, ensure_ascii=False, indent=2)
    write_local_text(path, payload)

def write_both_text(rel_path: str, text: str) -> None:
    """Write *text* to ``rel_path`` under both the local run dir and the run's GCS prefix.

    The local copy is written first, and any error there propagates. The GCS upload
    is best-effort: a failure does not abort the run, since the local copy exists.
    It is reported rather than silently ignored, because reads fall back to GCS
    and a missing upload would otherwise go unnoticed.
    """
    local_path = os.path.join(local_base_dir(), rel_path)
    write_local_text(local_path, text)

    gcs_path = f"{gcs_base_prefix()}/{rel_path}"
    try:
        write_gcs_text(gcs_path, text)
    except Exception as err:
        print(f"[WARN] GCS write failed for {gcs_path}: {err!r}")

def write_both_json(rel_path: str, obj: Any) -> None:
    """Serialise *obj* as indented JSON and store it at ``rel_path`` locally and on GCS."""
    serialized = json.dumps(obj, ensure_ascii=False, indent=2)
    write_both_text(rel_path, serialized)

def exists_local(rel_path: str) -> bool:
    """Return whether ``rel_path`` exists under the local run directory."""
    full_path = os.path.join(local_base_dir(), rel_path)
    return os.path.exists(full_path)

def exists_gcs(rel_path: str) -> bool:
    """Return whether ``rel_path`` exists under the run's GCS prefix (lookup errors count as absent)."""
    full_path = f"{gcs_base_prefix()}/{rel_path}"
    return gcs_exists(full_path)

def exists_any(rel_path: str) -> bool:
    """Return whether ``rel_path`` exists locally or on GCS; GCS is only queried if the local check fails."""
    if exists_local(rel_path):
        return True
    return exists_gcs(rel_path)

def read_local_json(rel_path: str) -> Optional[Any]:
    """Load JSON from ``rel_path`` under the local run dir; None if the file is missing or unparseable."""
    full_path = os.path.join(local_base_dir(), rel_path)
    try:
        with open(full_path, "r", encoding="utf-8") as handle:
            loaded = json.load(handle)
    except Exception:
        loaded = None
    return loaded

def read_gcs_json(rel_path: str) -> Optional[Any]:
    """Load JSON from ``rel_path`` under the run's GCS prefix; None on any read or parse error."""
    full_path = f"{gcs_base_prefix()}/{rel_path}"
    try:
        text = read_gcs_text(full_path)
        return json.loads(text)
    except Exception:
        return None

def read_any_json(rel_path: str) -> Optional[Any]:
    """Return the JSON stored at ``rel_path``, preferring the local copy and falling back to GCS."""
    local_obj = read_local_json(rel_path)
    return local_obj if local_obj is not None else read_gcs_json(rel_path)

def list_stage_log_files_local(stage_code: str) -> List[str]:
    """Sorted local paths of the ``*.json`` files in ``logs/<stage_code>/by_key`` ([] if the dir is absent)."""
    log_dir = os.path.join(local_base_dir(), f"logs/{stage_code}/by_key")
    if not os.path.isdir(log_dir):
        return []
    json_names = (name for name in os.listdir(log_dir) if name.endswith(".json"))
    return sorted(os.path.join(log_dir, name) for name in json_names)

def list_stage_log_files_gcs(stage_code: str) -> List[str]:
    """Sorted GCS paths of the ``*.json`` files in ``logs/<stage_code>/by_key`` under the run prefix."""
    by_key_prefix = f"{gcs_base_prefix()}/logs/{stage_code}/by_key"
    return gcs_glob(by_key_prefix, pattern="*.json")

def load_stage_logs(stage_code: str) -> pd.DataFrame:
    """Load every per-key log record of *stage_code* into a DataFrame.

    Local files are read first. GCS files are then added only for keys not already
    loaded locally (the GCS key is the file name without ".json"). Unreadable files
    are skipped. When nothing is found, an empty frame with the core log columns is
    returned so that downstream merges and filters still work.
    """
    loaded: List[Dict[str, Any]] = []

    # Local copies take precedence.
    for local_path in list_stage_log_files_local(stage_code):
        try:
            with open(local_path, "r", encoding="utf-8") as handle:
                loaded.append(json.load(handle))
        except Exception:
            continue

    # GCS fills in keys that are missing locally.
    known_keys = {rec.get("key") for rec in loaded if "key" in rec}
    remote_paths = list_stage_log_files_gcs(stage_code)
    gcs = fs_gcs()
    for remote_path in remote_paths:
        remote_key = os.path.basename(remote_path).replace(".json","")
        if remote_key in known_keys:
            continue
        try:
            with gcs.open(remote_path, "r") as handle:
                loaded.append(json.load(handle))
                known_keys.add(remote_key)
        except Exception:
            continue

    if loaded:
        return pd.DataFrame(loaded)
    return pd.DataFrame(columns=[
        "key","stage","live_id","chunk_idx","qname",
        "new_gt_index","pred_index","correct","invalid_format"
    ])

# ====================== Input loading & task build ======================
def load_runlog_df() -> pd.DataFrame:
    """Read the adopted-chunks CSV of the source run (ADOPT_LOG_PATH on GCS) into a DataFrame."""
    print(f"RUN_TAG: {RUN_TAG}")
    print(f"Loading adopted log: {ADOPT_LOG_PATH}")
    csv_text = read_gcs_text(ADOPT_LOG_PATH)
    buffer = io.StringIO(csv_text)
    return pd.read_csv(buffer)

def select_live_ids(df: pd.DataFrame, k: int) -> List[int]:
    """Pick up to *k* distinct live_ids among run-log rows whose ``status`` is "ok".

    If LIVE_IDS_EXPLICIT is non-empty, its ids are taken in the given order, keeping
    only ids that have at least one "ok" row. Otherwise ids are taken in row order
    of *df*. Duplicates are dropped in both cases.

    Args:
        df: run log with at least ``status`` and ``live_id`` columns.
        k: maximum number of ids to return; ``k <= 0`` returns [].

    Returns:
        The selected live_ids, in selection order.
    """
    # Computed once; the original recomputed this filtered array for every explicit id.
    ok_ids = df.loc[df["status"]=="ok", "live_id"].dropna().astype(int).tolist()
    if LIVE_IDS_EXPLICIT:
        allowed = set(ok_ids)
        candidates = [lid for lid in LIVE_IDS_EXPLICIT if lid in allowed]
    else:
        candidates = ok_ids

    selected: List[int] = []
    seen = set()
    for lid in candidates:
        # Check the limit before appending so that k <= 0 really yields nothing.
        if len(selected) >= k:
            break
        if lid in seen:
            continue
        seen.add(lid)
        selected.append(lid)
    return selected

def _load_json_from_gcs_or_none(path: Optional[str]) -> Optional[Dict[str, Any]]:
    """Parse the JSON at GCS *path*; None for a non-string or empty path, or on any read/parse error."""
    if isinstance(path, str) and path:
        try:
            return json.loads(read_gcs_text(path))
        except Exception:
            pass
    return None

def _read_gcs_text_or_empty(path: Any) -> str:
    """Return the text at GCS *path*, or "" when *path* is not a string or the read fails."""
    if not isinstance(path, str):
        return ""
    try:
        return read_gcs_text(path)
    except Exception:
        return ""


def _gt_index_or_none(q_full: Dict[str, Any], n_choices: int) -> Optional[int]:
    """Return ``q_full["answer_index"]`` as an int if it indexes one of *n_choices* options, else None."""
    try:
        idx = int(q_full["answer_index"])
    except (KeyError, TypeError, ValueError):
        return None
    return idx if 0 <= idx < n_choices else None


def build_tasks_from_runlog(df: pd.DataFrame, live_ids: List[int]) -> List[Dict[str, Any]]:
    """Build one task dict per question (Q1 and Q2) for the selected lives.

    For every run-log row with status "ok", a live_id in *live_ids* and an
    ``mcq_full_path``:
      * the full MCQ JSON (with answers) must hold exactly 2 questions, else the row is skipped;
      * the public MCQ JSON (question + choices) supplies the text shown to the model;
        if it is missing or malformed it is rebuilt from the full JSON;
      * combined / chat texts are read from GCS when their paths are strings ("" on failure).

    A single question is skipped, with one summary warning at the end, when it has
    fewer than 4 choices or its ground-truth index is missing or outside 0..3.
    Such items would otherwise crash the shuffle step later.

    Returns:
        Tasks in row order, Q1 before Q2, each with the keys live_id, chunk_idx, qname,
        question, choices (first 4), orig_gt_index, combined_text, chat_text and has_chat.
    """
    n_choices = 4
    tasks: List[Dict[str, Any]] = []
    skipped: List[str] = []

    sub = df[(df["status"]=="ok") & (df["live_id"].isin(live_ids))].copy()
    sub = sub.dropna(subset=["mcq_full_path"])
    for _, r in sub.iterrows():
        full = _load_json_from_gcs_or_none(r["mcq_full_path"])
        pub = _load_json_from_gcs_or_none(r.get("mcq_public_path"))
        if not (isinstance(full, dict) and full.get("questions") and len(full["questions"]) == 2):
            continue
        if not (isinstance(pub, dict) and pub.get("questions") and len(pub["questions"]) == 2):
            # Public version unusable: derive it from the full one (question + first 4 choices).
            pub = {
                "questions": [
                    {"question": q["question"], "choices": q["choices"][:n_choices]}
                    for q in full["questions"]
                ]
            }

        live_id = int(r["live_id"])
        chunk_idx = int(r["chunk_idx"]) if pd.notna(r["chunk_idx"]) else None
        combined_text = _read_gcs_text_or_empty(r.get("combined_path"))
        chat_text = _read_gcs_text_or_empty(r.get("chat_path"))
        has_chat = bool(isinstance(r.get("chat_path"), str) and r["chat_path"])

        for qname, q_full, q_pub in zip(("Q1", "Q2"), full["questions"], pub["questions"]):
            choices = list(q_pub["choices"])[:n_choices]
            gt_index = _gt_index_or_none(q_full, n_choices)
            if len(choices) < n_choices or gt_index is None:
                skipped.append(_key_tuple_str(live_id, chunk_idx, qname))
                continue
            tasks.append({
                "live_id": live_id,
                "chunk_idx": chunk_idx,
                "qname": qname,
                "question": str(q_pub["question"]),
                "choices": choices,
                "orig_gt_index": gt_index,
                "combined_text": combined_text,
                "chat_text": chat_text,
                "has_chat": has_chat,
            })

    if skipped:
        print(f"[WARN] skipped {len(skipped)} malformed question(s): {skipped}")
    return tasks

# ====================== Stable shuffle ======================
def _key_tuple_str(lid: int, chunk_idx: Optional[int], qname: str) -> str:
    return f"{lid}_{'NA' if chunk_idx is None else chunk_idx}_{qname}"

def _seed_for_item(base_seed: int, lid: int, chunk_idx: Optional[int], qname: str) -> int:
    # 安定乱数用に blake2b から 64bit int を生成
    h = hashlib.blake2b(digest_size=8)
    h.update(f"{base_seed}|{lid}|{chunk_idx}|{qname}".encode("utf-8"))
    return int.from_bytes(h.digest(), "little", signed=False)

def stable_shuffle_choices_and_gt(choices: List[str], gt_index: int, base_seed: int,
                                  lid: int, chunk_idx: Optional[int], qname: str
                                 ) -> Tuple[List[str], int, List[int]]:
    """Deterministically shuffle the first 4 choices of one question and remap its answer.

    The permutation comes from a NumPy Generator seeded by _seed_for_item(), so the
    same (base_seed, lid, chunk_idx, qname) always gives the same order.

    Args:
        choices: answer options; only the first 4 are used.
        gt_index: index of the correct option in *choices* (0..3).
        base_seed, lid, chunk_idx, qname: inputs for the per-item seed.

    Returns:
        ``(new_choices, new_gt_index, perm)`` where ``perm[new_idx] == old_idx``.

    Raises:
        ValueError: fewer than 4 choices, or gt_index outside 0..3.
        RuntimeError: the correct option did not land at new_gt_index.
    """
    n_choices = 4
    if len(choices) < n_choices:
        raise ValueError(f"expected at least {n_choices} choices, got {len(choices)}")
    if not 0 <= gt_index < n_choices:
        raise ValueError(f"gt_index must be in 0..{n_choices - 1}, got {gt_index}")

    seed = _seed_for_item(base_seed, lid, chunk_idx, qname)
    perm = np.random.default_rng(seed).permutation(n_choices)
    new_choices = [choices[old] for old in perm]
    new_gt_index = int(np.where(perm == gt_index)[0][0])
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if new_choices[new_gt_index] != choices[gt_index]:
        raise RuntimeError("Shuffle mismatch: correct not preserved")
    return new_choices, new_gt_index, perm.tolist()

def get_or_make_shuffle_record(t: Dict[str, Any]) -> Dict[str, Any]:
    """Return the stored shuffle record for task *t* (one Q1/Q2 question), creating it if absent.

    The record lives at ``shuffled/by_key/<key>.json`` and is looked up locally first,
    then on GCS. If none exists, the choices are shuffled with the stable per-item seed,
    and the new record is written locally and to GCS before it is returned. Reusing
    stored records keeps the option order fixed across resumed sessions.
    """
    lid, chunk_idx, qname = t["live_id"], t["chunk_idx"], t["qname"]
    key = _key_tuple_str(lid, chunk_idx, qname)
    rel = f"shuffled/by_key/{key}.json"

    cached = read_any_json(rel)
    if cached is not None:
        return cached

    shuffled, new_gt, perm = stable_shuffle_choices_and_gt(
        t["choices"], t["orig_gt_index"], RNG_SEED_BASE, lid, chunk_idx, qname
    )
    record = {
        "key": key,
        "live_id": lid,
        "chunk_idx": chunk_idx,
        "qname": qname,
        "question": t["question"],
        "orig_choices": t["choices"],
        "new_choices": shuffled,
        "orig_gt_index": t["orig_gt_index"],
        "new_gt_index": new_gt,
        "perm_new_to_old": perm,  # perm[new_idx] == old_idx
    }
    write_both_json(rel, record)
    return record

# ====================== Prediction parsing ======================
def parse_pred_index(obj: Dict[str, Any], qname: str) -> Tuple[Optional[int], int]:
    """Extract the predicted choice index from a parsed model response.

    Looks at ``obj["answers"]["<qname lower>_index"]`` and, only if that key is
    absent, at ``obj["answers"]["index"]``. The first key found decides the result.

    Accepted values are ints, integral floats (2.0) and numeric strings ("2") in 0..3.
    Booleans and non-integral floats are rejected, because int() would silently turn
    True into 1 and 2.7 into 2.

    Returns:
        ``(index, 0)`` for a valid answer, otherwise ``(None, 1)``. The second
        element is the invalid-format flag.
    """
    invalid: Tuple[Optional[int], int] = (None, 1)
    if not isinstance(obj, dict):
        return invalid
    answers = obj.get("answers", {})
    if not isinstance(answers, dict):
        return invalid
    for field in (f"{qname.lower()}_index", "index"):
        if field in answers:
            idx = _coerce_choice_index(answers[field])
            return (idx, 0) if idx is not None else invalid
    return invalid


def _coerce_choice_index(value: Any, n_choices: int = 4) -> Optional[int]:
    """Convert a raw answer value to a choice index in ``0..n_choices-1``, or None if it is not one."""
    if isinstance(value, bool):  # bool is an int subclass; True/False are not answers
        return None
    if isinstance(value, float) and not value.is_integer():
        return None
    try:
        idx = int(value)
    except (TypeError, ValueError, OverflowError):
        return None
    return idx if 0 <= idx < n_choices else None

# ====================== Stage execution (resumable, logs+samples saved) ======================
def stage_answer(model: GenerativeModel, tasks: List[Dict[str, Any]],
                 stage_code: str,
                 source_label: str, include_body: bool) -> List[Dict[str, Any]]:
    """Run one answering stage over *tasks* and return the log records created by this call.

    The stage is resumable. A task whose ``logs/<stage_code>/by_key/<key>.json``
    already exists (locally or on GCS) is skipped and is NOT part of the returned list.
    For every other task, the model answers the shuffled question. The outcome is
    saved as a per-key log JSON and, when SAVE_EVERY_QA_SAMPLE is set, as a
    readable sample under ``samples/<stage_code>/<correct|incorrect|invalid>/<key>.txt``.
    Every write goes to both the local run dir and GCS.

    When *include_body* is true, the body text is chosen from the stage code:
    "combined" -> task["combined_text"], "chat" -> task["chat_text"], otherwise none.
    """
    total = len(tasks)
    created: List[Dict[str, Any]] = []
    progress = tqdm(total=total, desc=f"{stage_code}: 0/{total}", dynamic_ncols=True)
    finished = 0

    # The task field that supplies the body is the same for every task of the stage.
    lowered = stage_code.lower()
    body_field: Optional[str] = None
    if include_body:
        if "combined" in lowered:
            body_field = "combined_text"
        elif "chat" in lowered:
            body_field = "chat_text"

    for task in tasks:
        qname = task["qname"]
        key = _key_tuple_str(task["live_id"], task["chunk_idx"], qname)
        log_rel = f"logs/{stage_code}/by_key/{key}.json"

        if not exists_any(log_rel):
            # Shuffled question (record is created on first use).
            shuffled = get_or_make_shuffle_record(task)
            qjson = {"question": shuffled["question"], "choices": shuffled["new_choices"]}
            new_gt = int(shuffled["new_gt_index"])
            body = (task.get(body_field, "") or "") if body_field else ""

            content = build_content(qname, source_label, qjson, body_text=body)
            obj, raw = call_model_return_both(model, content)
            pred, invalid = parse_pred_index(obj, qname)
            if invalid == 0 and pred is not None:
                correct = int(pred == new_gt)
            else:
                correct = None  # unparseable / out-of-range answer

            # Final record for this (stage, question); its existence marks the key as done.
            log_rec = {
                "key": key,
                "stage": stage_code,
                "live_id": task["live_id"],
                "chunk_idx": task["chunk_idx"],
                "qname": qname,
                "source_label": source_label,
                "include_body": bool(include_body),
                "question": shuffled["question"],
                "choices": shuffled["new_choices"],
                "orig_gt_index": int(shuffled["orig_gt_index"]),
                "new_gt_index": int(new_gt),
                "pred_index": None if pred is None else int(pred),
                "invalid_format": int(invalid),
                "correct": None if correct is None else int(correct),
                "prompt_sent": content,
                "raw_response": raw,
                "ts": datetime.datetime.now().isoformat(timespec="seconds"),
            }
            write_both_json(log_rel, log_rec)

            if SAVE_EVERY_QA_SAMPLE:
                if correct == 1:
                    verdict = "correct"
                elif correct == 0:
                    verdict = "incorrect"
                else:
                    verdict = "invalid"
                sections = [
                    f"### META\nstage: {stage_code}\nkey: {key}\nlive_id: {task['live_id']}\nchunk_idx: {task['chunk_idx']}\nqname: {qname}\n",
                    f"orig_gt_index: {shuffled['orig_gt_index']}  new_gt_index: {new_gt}  pred_index: {pred}  correct: {correct}  invalid_format: {invalid}",
                    "\n### QUESTION (shuffled)\n" + json.dumps(qjson, ensure_ascii=False, indent=2),
                ]
                if include_body and body:
                    sections.append("\n### BODY\n" + body)
                sections.append("\n### PROMPT SENT (exact)\n" + content)
                sections.append("\n### RAW RESPONSE\n" + raw)
                write_both_text(f"samples/{stage_code}/{verdict}/{key}.txt", "\n".join(sections))

            created.append(log_rec)

        finished += 1
        progress.set_description_str(f"{stage_code}: {finished}/{total}")
        progress.update(1)

    progress.close()
    return created

# ====================== Metrics & distributions ======================
def index_distribution(series: pd.Series, k: int=4) -> Dict[int, float]:
    """Return the share of each index ``0..k-1`` in *series*; values outside that range are ignored.

    The result always has one entry per index. It is all zeros when no value of
    *series* falls in range.
    """
    counts = series.value_counts().reindex(range(k), fill_value=0).sort_index()
    n_in_range = int(counts.sum())
    if n_in_range == 0:
        return dict.fromkeys(range(k), 0.0)
    return {int(i): float(counts.iloc[i] / n_in_range) for i in range(k)}

def summarize_stage_df(df: pd.DataFrame, qname: str, stage_code: str) -> Dict[str, Any]:
    """Aggregate the log rows of one (qname, stage) pair into a summary dict.

    Returned keys:
      * provided: number of rows for the pair;
      * answered_valid: rows with invalid_format == 0;
      * correct_total: sum of ``correct`` over valid rows (missing values count as 0);
      * accuracy: correct_total / answered_valid, or None without valid rows;
      * pred_index_dist: share of predicted indices 0..3 among valid rows;
      * invalid_format_total: rows with invalid_format == 1.
    """
    rows = df[(df["qname"]==qname) & (df["stage"]==stage_code)].copy()
    is_valid = rows["invalid_format"] == 0
    valid_rows = rows.loc[is_valid]
    n_valid = int(is_valid.sum())
    n_correct = int(valid_rows["correct"].fillna(0).astype(int).sum())

    if n_valid:
        accuracy = float(n_correct / n_valid)
        pred_dist = index_distribution(valid_rows["pred_index"].dropna().astype(int), k=4)
    else:
        accuracy = None
        pred_dist = {i: 0.0 for i in range(4)}

    return {
        "stage": stage_code,
        "qname": qname,
        "provided": int(len(rows)),
        "answered_valid": int(n_valid),
        "correct_total": int(n_correct),
        "accuracy": accuracy,
        "pred_index_dist": pred_dist,
        "invalid_format_total": int((rows["invalid_format"]==1).sum()),
    }

# ====================== Main flow ======================
# 0) Run label (new or resume)
# RESUME_RUN_LABEL takes precedence; otherwise keep a preset RUN_LABEL or generate a new one.
# RUN_LABEL is a module global read by local_base_dir() / gcs_base_prefix(), so it must be
# fixed here, before any helper that builds output paths is called.
if RESUME_RUN_LABEL.strip():
    RUN_LABEL = RESUME_RUN_LABEL.strip()
else:
    RUN_LABEL = RUN_LABEL or make_run_label()

# Prepare base dirs
ensure_local_dir(local_base_dir())

# Write manifest if not exists (a resumed run keeps its original manifest untouched)
manifest_rel = "manifest.json"
if not exists_any(manifest_rel):
    manifest = {
        "run_label": RUN_LABEL,
        "run_tag": RUN_TAG,
        "num_live_requested": NUM_LIVE,
        "rng_seed_base": RNG_SEED_BASE,
        "model_name": MODEL_NAME,
        "temperature": TEMPERATURE,
        "created_at": datetime.datetime.now().isoformat(timespec="seconds"),
    }
    write_both_json(manifest_rel, manifest)
print(f"RUN_LABEL: {RUN_LABEL}")
print(f"Local out: {local_base_dir()}")
print(f"GCS out:   {gcs_base_prefix()}")

# 1) Load runlog & select lives
run_df = load_runlog_df()
target_live_ids = select_live_ids(run_df, NUM_LIVE)
print(f"選択 live_id: {target_live_ids}")

# Persist selected lives (for resume traceability).
# The file is written only if it does not exist yet. On resume, the ids are
# re-selected from the run log above rather than read back from this file.
# NOTE(review): a config / run-log change between sessions is therefore not detected here.
selected_rel = "selected_live_ids.json"
if not exists_any(selected_rel):
    write_both_json(selected_rel, {"live_ids": target_live_ids})

# 2) Build tasks (one task per question: Q1 and Q2 for every usable run-log row)
all_tasks = build_tasks_from_runlog(run_df, target_live_ids)
print(f"総タスク数(Q1/Q2合計): {len(all_tasks)}")

# 3) Generate/Save shuffle records for all tasks (resumable)
# Existing records are reused, so the option order stays fixed across sessions.
print("シャッフル（安定）を決定・保存中...")
for t in tqdm(all_tasks, total=len(all_tasks), desc="shuffle map", dynamic_ncols=True):
    _ = get_or_make_shuffle_record(t)

# 3-1) Show GT distributions before/after
def collect_gt_dists(tasks: List[Dict[str,Any]]) -> Tuple[Dict[int,float], Dict[int,float]]:
    """Return the (original, shuffled) ground-truth index distributions over *tasks*.

    Each task's shuffle record is read, or created if missing, to get both indices.
    """
    orig_idx: List[int] = []
    new_idx: List[int] = []
    for task in tasks:
        rec = get_or_make_shuffle_record(task)
        orig_idx.append(int(rec["orig_gt_index"]))
        new_idx.append(int(rec["new_gt_index"]))
    return index_distribution(pd.Series(orig_idx)), index_distribution(pd.Series(new_idx))

# Split the tasks by question slot; each stage below runs on one of these lists.
q1_tasks = [t for t in all_tasks if t["qname"]=="Q1"]
q2_tasks = [t for t in all_tasks if t["qname"]=="Q2"]
q1_orig_dist, q1_new_dist = collect_gt_dists(q1_tasks)
q2_orig_dist, q2_new_dist = collect_gt_dists(q2_tasks)
print("\n=== GT index distribution (original vs shuffled) ===")
print("Q1 original:", q1_orig_dist)
print("Q1 shuffled:", q1_new_dist)
print("Q2 original:", q2_orig_dist)
print("Q2 shuffled:", q2_new_dist)

# 4) Init model (single model used across stages; names map to LLM2..LLM9 conceptually)
model = init_vertex_ai(MODEL_NAME)

# 5) Run stages (resumable; each stage skips already-logged keys)
# The logs_*_new values hold only the records created in this session. Everything
# downstream re-reads the complete per-stage logs via load_stage_logs() instead.
# Q1 primary: LLM2 (combined), LLM3 (nothing)
logs_q1_llm2_new = stage_answer(model, q1_tasks, stage_code="Q1_LLM2_combined",
                                source_label="配信ログ本文（チャット＋音声書き起こし）", include_body=True)
logs_q1_llm3_new = stage_answer(model, q1_tasks, stage_code="Q1_LLM3_nothing",
                                source_label="問題文のみ", include_body=False)

# Load full logs (existing + new) to compute subset
df_q1_llm2 = load_stage_logs("Q1_LLM2_combined")
df_q1_llm3 = load_stage_logs("Q1_LLM3_nothing")
# Inner join: only questions logged in BOTH stages can enter the subset.
subset_q1 = df_q1_llm2.merge(df_q1_llm3, on=["key","live_id","chunk_idx","qname"], suffixes=("_llm2","_llm3"))
# Subset = correct WITH the combined body, but not correct WITHOUT any body.
# An invalid LLM3 answer (correct is None/NaN) also counts as "not correct".
subset_q1 = subset_q1[(subset_q1["correct_llm2"]==1) & ((subset_q1["correct_llm3"].isna()) | (subset_q1["correct_llm3"]==0))]

subset_q1_keys = set(subset_q1["key"].tolist())
# Restrict to the current task list, in case the logs contain keys outside it.
subset_q1_tasks = [t for t in q1_tasks if _key_tuple_str(t["live_id"], t["chunk_idx"], t["qname"]) in subset_q1_keys]
# The chat-only stage runs only on items whose run-log row had a non-empty chat_path.
subset_q1_tasks_with_chat = [t for t in subset_q1_tasks if t["has_chat"]]

# Q1 extra: LLM4 (chat only), LLM5 (nothing)
logs_q1_llm4_new = stage_answer(model, subset_q1_tasks_with_chat, stage_code="Q1_LLM4_chat",
                                source_label="チャット本文", include_body=True)
logs_q1_llm5_new = stage_answer(model, subset_q1_tasks, stage_code="Q1_LLM5_nothing",
                                source_label="問題文のみ", include_body=False)

# Q2 primary: LLM6 (combined), LLM7 (nothing)
logs_q2_llm6_new = stage_answer(model, q2_tasks, stage_code="Q2_LLM6_combined",
                                source_label="配信ログ本文（チャット＋音声書き起こし）", include_body=True)
logs_q2_llm7_new = stage_answer(model, q2_tasks, stage_code="Q2_LLM7_nothing",
                                source_label="問題文のみ", include_body=False)

# Load full logs to compute subset for Q2 (same rule as for Q1)
df_q2_llm6 = load_stage_logs("Q2_LLM6_combined")
df_q2_llm7 = load_stage_logs("Q2_LLM7_nothing")
subset_q2 = df_q2_llm6.merge(df_q2_llm7, on=["key","live_id","chunk_idx","qname"], suffixes=("_llm6","_llm7"))
subset_q2 = subset_q2[(subset_q2["correct_llm6"]==1) & ((subset_q2["correct_llm7"].isna()) | (subset_q2["correct_llm7"]==0))]

subset_q2_keys = set(subset_q2["key"].tolist())
subset_q2_tasks = [t for t in q2_tasks if _key_tuple_str(t["live_id"], t["chunk_idx"], t["qname"]) in subset_q2_keys]
subset_q2_tasks_with_chat = [t for t in subset_q2_tasks if t["has_chat"]]

# Q2 extra: LLM8 (chat only), LLM9 (nothing)
logs_q2_llm8_new = stage_answer(model, subset_q2_tasks_with_chat, stage_code="Q2_LLM8_chat",
                                source_label="チャット本文", include_body=True)
logs_q2_llm9_new = stage_answer(model, subset_q2_tasks, stage_code="Q2_LLM9_nothing",
                                source_label="問題文のみ", include_body=False)

# 6) Collect all logs (existing+new) for metrics
stage_codes = ["Q1_LLM2_combined","Q1_LLM3_nothing","Q1_LLM4_chat","Q1_LLM5_nothing",
               "Q2_LLM6_combined","Q2_LLM7_nothing","Q2_LLM8_chat","Q2_LLM9_nothing"]
all_dfs = [load_stage_logs(st) for st in stage_codes]
# all_dfs always has one entry per stage code here, so the empty fallback is never taken.
logs_df = pd.concat(all_dfs, ignore_index=True) if all_dfs else pd.DataFrame()

# 7) Summaries & distributions
summaries = {"Q1": {}, "Q2": {}}
for st in ["Q1_LLM2_combined","Q1_LLM3_nothing","Q1_LLM4_chat","Q1_LLM5_nothing"]:
    summaries["Q1"][st] = summarize_stage_df(logs_df, "Q1", st)
for st in ["Q2_LLM6_combined","Q2_LLM7_nothing","Q2_LLM8_chat","Q2_LLM9_nothing"]:
    summaries["Q2"][st] = summarize_stage_df(logs_df, "Q2", st)

subset_sizes = {
    "Q1_total": len(subset_q1_tasks),
    "Q1_with_chat": len(subset_q1_tasks_with_chat),
    "Q2_total": len(subset_q2_tasks),
    "Q2_with_chat": len(subset_q2_tasks_with_chat),
}

metrics = {
    "run_label": RUN_LABEL,
    "run_tag": RUN_TAG,
    "num_live": len(target_live_ids),
    "gt_index_distribution": {
        "Q1_original": q1_orig_dist,
        "Q1_shuffled": q1_new_dist,
        "Q2_original": q2_orig_dist,
        "Q2_shuffled": q2_new_dist,
    },
    "stage_summaries": summaries,
    "subset_sizes": subset_sizes,
    "generated_at": datetime.datetime.now().isoformat(timespec="seconds"),
}

# 8) Save metrics (local+GCS); overwritten on every session so it reflects all logs so far
write_both_json("metrics.json", metrics)

# 9) Print human-readable summaries
def short_line(s: Dict[str, Any]) -> str:
    """Format one stage summary (as built by summarize_stage_df) as a single line; accuracy to 3 decimals or "NA"."""
    acc_text = "NA" if s["accuracy"] is None else format(s["accuracy"], ".3f")
    fields = [
        f"provided={s['provided']}",
        f"valid={s['answered_valid']}",
        f"correct={s['correct_total']}",
        f"acc={acc_text}",
        f"pred_dist={s['pred_index_dist']}",
        f"invalid={s['invalid_format_total']}",
    ]
    return f"{s['stage']}: " + " ".join(fields)

# Final report. The GT distributions were already printed in step 3-1; they are
# repeated here so that the closing summary is self-contained.
print("\n=== Index distributions (GT before/after shuffle) ===")
print(f"Q1 original GT: {q1_orig_dist}")
print(f"Q1 shuffled GT: {q1_new_dist}")
print(f"Q2 original GT: {q2_orig_dist}")
print(f"Q2 shuffled GT: {q2_new_dist}")

# Per-stage summaries (computed from all logs of this run label, not just this session).
print("\n--- Q1 ---")
for st in ["Q1_LLM2_combined","Q1_LLM3_nothing","Q1_LLM4_chat","Q1_LLM5_nothing"]:
    print(short_line(summaries["Q1"][st]))

print("\n--- Q2 ---")
for st in ["Q2_LLM6_combined","Q2_LLM7_nothing","Q2_LLM8_chat","Q2_LLM9_nothing"]:
    print(short_line(summaries["Q2"][st]))

# "incorrect" here includes invalid/unparseable answers of the no-body stage.
print("\n=== Conditional subset sizes ===")
print(f"Q1 subset (LLM2 correct & LLM3 incorrect): total={subset_sizes['Q1_total']}; with_chat={subset_sizes['Q1_with_chat']}")
print(f"Q2 subset (LLM6 correct & LLM7 incorrect): total={subset_sizes['Q2_total']}; with_chat={subset_sizes['Q2_with_chat']}")

print(f"\nAll done. Local outputs: {local_base_dir()}")
print(f"GCS outputs:          {gcs_base_prefix()}")
