# MedGemma 1.5 原始模型 (W4A16/FP16) + RadGraph F1 评估

## 环境说明
- **Kaggle P100**（sm_60）：PyTorch 的 CUDA 构建必须支持 sm_60；本 Notebook 默认直接使用 Kaggle 预装 PyTorch，若预装版本（如 cu128）不支持 sm_60，则需改装 cu118 版本
- **transformers 4.47.1**：本 Notebook 锁定的版本（注意：Hugging Face 文档标注 Gemma 3 自 transformers 4.50.0 起支持；若加载模型时报 `gemma3` 架构无法识别，请相应调高锁定版本）
- **PyTorch**：使用 Kaggle 预装版本（避免网络下载错误）
- **运行流程**：安装 → Restart Session → Run All

## Kaggle 数据集
- **Add Input**：`mimic-cxr-dataset`（含 official_data_iccv_final/files/ 图片）
- **CSV**：`mimic_eval_single_image_final_233.csv`（可单独上传或放 mimic-cxr-dataset 根目录）

# Cell 1: 安装（必须先运行，完成后 Restart Session）

In [None]:
# Kaggle P100 环境：使用预装 PyTorch + transformers 4.47.1
import subprocess, sys, os, shutil

# Isolated install target: the pinned transformers goes here instead of
# overwriting the copy in the system site-packages.
TF_ENV = "/kaggle/working/tf_env"
CONSTRAINTS = "/kaggle/working/constraints.txt"

# Run pip through the interpreter that executes this notebook, so packages are
# installed for this kernel even if a different `pip` comes first on PATH.
PIP = [sys.executable, "-m", "pip"]
PYPI_MIRROR = "https://pypi.tuna.tsinghua.edu.cn/simple"

# Remove any environment left over from a previous run
if os.path.exists(TF_ENV):
    shutil.rmtree(TF_ENV)

# 1. Keep the PyTorch build that ships with Kaggle (avoids network download errors)
print("使用 Kaggle 预装的 PyTorch（通常是 2.x + cu121，兼容 P100）...")

# 2. Write the constraints file (pins only the transformers version)
with open(CONSTRAINTS, "w") as f:
    f.write("transformers==4.47.1\n")

# 3. Install transformers into TF_ENV under that constraint (via the mirror)
subprocess.run([
    *PIP, "install", "--target", TF_ENV, "-q", "-c", CONSTRAINTS,
    "transformers==4.47.1", "pillow<12", "jinja2",
    "-i", PYPI_MIRROR,
], check=True)

# 4. Install radgraph separately into the default environment (via the mirror)
subprocess.run([
    *PIP, "install", "-q", "radgraph",
    "-i", PYPI_MIRROR,
], check=True)

# 5. Put TF_ENV first on the import path for the current session
sys.path.insert(0, TF_ENV)

print("✅ 安装完成，请立刻 Restart Session，然后 Run All")

# Cell 2: HuggingFace 登录（MedGemma 为 gated 模型）

- [申请访问](https://huggingface.co/google/medgemma-1.5-4b-it)
- Kaggle：Add-ons → Secrets → Label 填 `zhuxirui11`，Value 填你的 HF token

In [None]:
from kaggle_secrets import UserSecretsClient
from huggingface_hub import login

# Label of the Kaggle Secret that stores the Hugging Face token.
SECRET_LABEL = "zhuxirui11"

try:
    tok = UserSecretsClient().get_secret(SECRET_LABEL)
    if tok:
        login(token=tok)
        print("✅ HF 登录成功")
    else:
        # Name the secret that is actually read, so the hint matches the setup steps.
        print(f"⚠️ Secret {SECRET_LABEL} 为空")
except Exception as e:
    # Best-effort: keep the notebook running, but show the underlying error,
    # since a failed login ends up here as well as a missing secret.
    print(f"❌ 未配置 {SECRET_LABEL} Secret，或 HF 登录失败（{type(e).__name__}: {e}）")
    print("\n请按以下步骤配置：")
    print("1. 点击右侧 'Add-ons' → 'Secrets'")
    print(f"2. Label 填写：{SECRET_LABEL}")
    print("3. Value 填写你的 HF token（从 https://huggingface.co/settings/tokens 获取）")
    print("4. 保存后重新运行本 Cell")
    print("\n⚠️ MedGemma 需要先在 https://huggingface.co/google/medgemma-1.5-4b-it 申请访问权限")

# Cell 3: 环境检查

In [None]:
import sys, torch

# Print interpreter and framework versions, then choose the inference dtype.
for label, value in (
    ("Python", sys.version),
    ("PyTorch", torch.__version__),
    ("CUDA available", torch.cuda.is_available()),
):
    print(f"{label}: {value}")

if not torch.cuda.is_available():
    # CPU-only session: fall back to full precision.
    DTYPE = torch.float32
    print("⚠️ No GPU, using FP32")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # BF16 is selected only when the compute-capability major version is >= 8.
    USE_BF16 = torch.cuda.get_device_capability(0)[0] >= 8
    if USE_BF16:
        DTYPE = torch.bfloat16
    else:
        DTYPE = torch.float16
    print(f"Precision: {'BF16' if USE_BF16 else 'FP16 (P100)'}")

# Cell 4: 导入与路径配置

In [None]:
import sys, os, gc
import pandas as pd
from PIL import Image
from tqdm import tqdm

# Make sure TF_ENV is the first entry on the import path
TF_ENV = "/kaggle/working/tf_env"
if sys.path[:1] != [TF_ENV]:
    # Drop a lower-priority occurrence (if any) before re-inserting at the front.
    if TF_ENV in sys.path:
        sys.path.remove(TF_ENV)
    sys.path.insert(0, TF_ENV)

# Kaggle dataset locations
DATASET_ROOT = "/kaggle/input/datasets/simhadrisadaram/mimic-cxr-dataset/official_data_iccv_final"
CSV_CANDIDATES = [
    "/kaggle/input/datasets/xiruizhu1111/clean-data/mimic_eval_single_image_final_233.csv",
    "/kaggle/input/datasets/xiruizhu1111/clean-train-data/mimic_eval_single_image_final_233.csv",
    "/kaggle/input/clean-data/mimic_eval_single_image_final_233.csv",
    "/kaggle/working/mimic_eval_single_image_final_233.csv",
]

# Use the first candidate that exists on disk (None when none of them do)
CSV_PATH = next((cand for cand in CSV_CANDIDATES if os.path.exists(cand)), None)

if CSV_PATH is None:
    # List every path that was tried, then abort with a clear error.
    report = ["❌ 未找到 CSV 文件！", "尝试的路径："]
    report.extend(f"  - {cand}" for cand in CSV_CANDIDATES)
    report.extend([
        "\n请检查：",
        "1. 是否已添加包含 CSV 的数据集？（右侧 Add Data）",
        "2. CSV 文件名是否为 'mimic_eval_single_image_final_233.csv'？",
    ])
    print("\n".join(report))
    raise FileNotFoundError("CSV 文件不存在")

# Check the transformers version: a 5.x release is rejected
import transformers

tf_ver = transformers.__version__
if tf_ver[:2] == "5.":
    message = (
        f"❌ transformers {tf_ver} 不兼容！需要 4.47.1\n"
        "解决方法：重新运行 Cell 1 安装，确保看到 'transformers-4.47.1'，然后 Restart Session"
    )
    raise RuntimeError(message)

# Summarize the resolved configuration.
print(
    f"✅ transformers: {tf_ver}",
    f"Dataset: {DATASET_ROOT}",
    f"✅ CSV: {CSV_PATH}",
    sep="\n",
)

# Cell 4.5: 检查 CSV 结构（首次运行时查看）

In [None]:
# 检查 CSV 文件结构
import pandas as pd

# Load the evaluation CSV and print a quick overview of its structure.
df_check = pd.read_csv(CSV_PATH)

print("✅ CSV 文件加载成功！")
print(f"总行数: {len(df_check)}")
print(f"\n列名: {list(df_check.columns)}")
print("\n前 3 行数据预览：")
print(df_check.head(3))

# Check whether the key columns are present
print("\n=== 关键列检查 ===")
has_img_path = "Image_Path" in df_check.columns
has_ground_truth = any(col in df_check.columns for col in ("Ground_Truth", "text"))
has_subject_id = "subject_id" in df_check.columns
has_view = "View" in df_check.columns

# (label, present?, text shown when the column is absent)
column_checks = (
    ("Image_Path 列", has_img_path, "❌ 不存在（将查找 PA/AP/Lateral 列）"),
    ("Ground_Truth/text 列", has_ground_truth, "❌ 不存在"),
    ("subject_id 列", has_subject_id, "⚠️ 不存在"),
    ("View 列", has_view, "⚠️ 不存在（将默认使用 PA）"),
)
for label, present, missing_text in column_checks:
    print(f"{label}: {'✅ 存在' if present else missing_text}")

if not has_img_path:
    # Without Image_Path, look for per-view columns (PA / AP / Lateral)
    view_cols = [c for c in ['PA', 'AP', 'Lateral'] if c in df_check.columns]
    if not view_cols:
        print("\n❌ 警告：未找到图片路径列（Image_Path 或 PA/AP/Lateral）")
    else:
        print(f"\n找到视图列: {view_cols}")
        print(f"示例路径: {df_check[view_cols[0]].iloc[0]}")

# Cell 5: 加载 MedGemma 原始模型

In [None]:
# Drop every cached transformers module so the import below is executed afresh
# (torch was already loaded by the environment check and is left untouched)
stale_modules = [
    name for name in sys.modules
    if name == "transformers" or name.startswith("transformers.")
]
for name in stale_modules:
    del sys.modules[name]

from transformers import AutoProcessor, AutoModelForImageTextToText

model_id = "google/medgemma-1.5-4b-it"
print(f"Loading MedGemma ({DTYPE})...")

# Load the model weights first, then the matching processor.
load_kwargs = {"torch_dtype": DTYPE, "device_map": "auto"}
model = AutoModelForImageTextToText.from_pretrained(model_id, **load_kwargs)
processor = AutoProcessor.from_pretrained(model_id)

if torch.cuda.is_available():
    # Report memory currently allocated on GPU 0, then reset the peak counter.
    mem_gb = torch.cuda.memory_allocated(0) / 2**30
    print(f"✅ Model loaded, GPU memory: {mem_gb:.2f} GB")
    torch.cuda.reset_peak_memory_stats()

# Cell 6: 生成报告

In [None]:
import re

# Simplified prompt (modelled on the official MedGemma example);
# `{view}` is filled in with the view name by generate_report().
PROMPT_TEMPLATE = "Generate a radiology report for this {view} chest X-ray."

def fix_image_path(path):
    """Map an image path from the CSV onto the actual Kaggle datasets path.

    Returns the resolved path if it exists on disk, otherwise None.
    Missing values (NaN / None) and empty values also yield None.
    """
    if pd.isna(path) or not path:
        return None

    candidate = str(path).strip()
    current_prefix = "/kaggle/input/datasets/"
    legacy_prefix = "/kaggle/input/mimic-cxr-dataset/"

    if candidate.startswith(current_prefix):
        # Already under the datasets mount: use as-is.
        pass
    elif candidate.startswith(legacy_prefix):
        # Legacy layout: rewrite onto the simhadrisadaram dataset mount.
        candidate = candidate.replace(
            legacy_prefix,
            "/kaggle/input/datasets/simhadrisadaram/mimic-cxr-dataset/",
        )
    elif not candidate.startswith("/"):
        # Relative path: resolve against DATASET_ROOT.
        candidate = os.path.join(DATASET_ROOT, candidate)

    # Any other absolute path is checked unchanged.
    return candidate if os.path.exists(candidate) else None

def get_single_image_path(cell_val):
    """Resolve the first image path stored in a per-view CSV cell.

    The cell may hold a single path or a stringified list of paths; only the
    first entry is used. Everything from the first "files" onwards (or, when
    "files" is absent, the value with surrounding slashes removed) is joined
    onto DATASET_ROOT and validated by fix_image_path().

    Returns the existing path, or None when the cell is missing, contains no
    path (e.g. "[]"), or the file does not exist.
    """
    if pd.isna(cell_val):
        return None

    # Remove list brackets and quotes, then keep the first comma-separated entry.
    cleaned = str(cell_val).translate(str.maketrans("", "", "[]'\""))
    first = cleaned.split(",")[0].strip()

    if "files" in first:
        rel = "files" + first.split("files", 1)[1]
    else:
        rel = first.strip("/")

    # An empty remainder would resolve to DATASET_ROOT itself, which exists as
    # a directory and would be reported as a valid image; treat it as missing.
    if not rel:
        return None

    # `rel` never starts with "/" here, so it is always relative to DATASET_ROOT.
    return fix_image_path(os.path.join(DATASET_ROOT, rel))

def generate_report(model, processor, img_path, view="PA"):
    """Generate a single-line radiology report for one chest X-ray image.

    Args:
        model: Loaded model exposing `generate()` and `device`.
        processor: Matching processor exposing `apply_chat_template()` and `decode()`.
        img_path: Path of the image file.
        view: View name inserted into PROMPT_TEMPLATE (default "PA").

    Returns:
        The generated text with the "Findings:" / "Impression:" labels removed
        and all whitespace runs collapsed to single spaces. An empty string is
        returned when the image does not exist or cannot be read.
    """
    if not os.path.exists(img_path):
        return ""
    try:
        # The context manager closes the file handle; convert() returns a
        # separate RGB image that stays usable afterwards.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
    except Exception as exc:
        # Best-effort: skip unreadable images, but say so instead of hiding it.
        # (Unlike a bare `except:`, this lets KeyboardInterrupt propagate.)
        print(f"⚠️ 无法读取图片，已跳过: {img_path} ({exc})")
        return ""

    prompt = PROMPT_TEMPLATE.format(view=view)
    msgs = [{"role": "user", "content": [{"type": "image", "image": img}, {"type": "text", "text": prompt}]}]
    inp = processor.apply_chat_template(
        msgs,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=DTYPE)
    # Number of prompt tokens; used below to slice the prompt off the output.
    prompt_len = inp["input_ids"].shape[-1]

    with torch.inference_mode():
        out = model.generate(
            **inp,
            max_length=None,      # disable the default max_length
            max_new_tokens=300,   # bound the length relative to the prompt
            min_new_tokens=5,     # force at least a few generated tokens
            pad_token_id=0,       # set the pad token explicitly
            do_sample=False,
        )

    txt = processor.decode(out[0][prompt_len:], skip_special_tokens=True)
    txt = txt.replace("Findings:", "").replace("Impression:", "")
    return re.sub(r'\s+', ' ', txt).strip()

# Generate reports for the first NUM rows of the evaluation CSV
df = pd.read_csv(CSV_PATH)
GT_COL = "Ground_Truth" if "Ground_Truth" in df.columns else "text"
IMG_COL = "Image_Path" if "Image_Path" in df.columns else None

rows_out = []
NUM = min(50, len(df))

for idx, row in tqdm(df.head(NUM).iterrows(), total=NUM, desc="Generating reports"):
    path, view = None, "PA"
    if IMG_COL:
        path = fix_image_path(row.get(IMG_COL))  # map the CSV path onto the Kaggle mount
        raw_view = row.get("View")
        # A missing / NaN / blank View would otherwise be formatted into the
        # prompt literally (e.g. "nan"); keep the "PA" default in that case.
        if not pd.isna(raw_view) and str(raw_view).strip():
            view = str(raw_view).strip()
    else:
        # No Image_Path column: take the first per-view column with a usable image.
        for c, v in [("PA", "PA"), ("AP", "AP"), ("Lateral", "Lateral")]:
            if c in df.columns and (p := get_single_image_path(row.get(c))):
                path, view = p, v
                break
    if not path:
        continue

    raw_gt = row.get(GT_COL)
    # NaN is truthy, so `raw_gt or ""` would turn a missing reference into the
    # string "nan"; test for missing values explicitly instead.
    gt = "" if pd.isna(raw_gt) else str(raw_gt).strip()
    # Skip rows with no reference, or whose reference starts with "You are".
    if not gt or gt.startswith("You are"):
        continue

    rep = generate_report(model, processor, path, view)
    rows_out.append({
        # .get(): tolerate a CSV without a subject_id column.
        "subject_id": row.get("subject_id"),
        "View": view,
        "Image_Path": path,
        "Ground_Truth": gt,
        "Generated_Report": rep
    })

df_sub = pd.DataFrame(rows_out)
print(f"\n✅ Generated {len(df_sub)} reports")

# Cell 7: RadGraph F1 评估

In [None]:
from radgraph import F1RadGraph
import numpy as np

# Score the generated reports against the ground truth with RadGraph F1.
# An empty df_sub has no columns, so fail here with a clear message rather
# than with a KeyError on the column lookup below.
if df_sub.empty:
    raise RuntimeError("没有可评估的报告：df_sub 为空，请检查 Cell 6 的图片路径与 Ground Truth 列")

refs = df_sub["Ground_Truth"].fillna("").tolist()
hyps = df_sub["Generated_Report"].fillna("").tolist()

f1radgraph = F1RadGraph(reward_level="all", model_type="modern-radgraph-xl")
# The first return value is unpacked into the three scores below; the second
# is kept as `reward_list`, and the remaining two are unused.
mean_reward, reward_list, _, _ = f1radgraph(hyps=hyps, refs=refs)
rg_e, rg_er, rg_er_bar = mean_reward

print("=" * 60)
print("原始 MedGemma (W4A16/FP16) RadGraph F1")
print("-" * 60)
print(f"RG_E (Entity):           {float(rg_e)*100:.2f}")
print(f"RG_ER (Entity+Relation): {float(rg_er)*100:.2f}  ← 论文常用")
print(f"RG_ER_bar (Complete):    {float(rg_er_bar)*100:.2f}")
print("=" * 60)

# Kept as globals for the save cell.
ORIGINAL_SCORES = {"rg_e": float(rg_e), "rg_er": float(rg_er), "rg_er_bar": float(rg_er_bar)}
# Peak allocated memory on GPU 0 in GiB (0 when no GPU is available).
ORIGINAL_GPU_GB = torch.cuda.max_memory_allocated(0) / (1024**3) if torch.cuda.is_available() else 0
print(f"\nPeak GPU memory: {ORIGINAL_GPU_GB:.2f} GB")

# Cell 8: 保存结果

In [None]:
import json

# Release the model and processor, then return cached GPU memory
del model
del processor
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Persist the per-sample results and the aggregate scores
df_sub.to_csv("/kaggle/working/original_medgemma_results.csv", index=False)
summary = {"scores": ORIGINAL_SCORES, "gpu_gb": ORIGINAL_GPU_GB}
with open("/kaggle/working/original_scores.json", "w") as f:
    f.write(json.dumps(summary))

print(
    "✅ 结果已保存至 /kaggle/working/",
    "\n下一步：运行 W4A4/W4A8 Notebook 进行量化对比",
    sep="\n",
)
print("\n下一步：运行 W4A4/W4A8 Notebook 进行量化对比")