# Cell 0: 环境设置（Kaggle P100）

- **P100 兼容**：Kaggle 预装 PyTorch cu128 不支持 P100 (sm_60)，需先运行下方安装 Cell 安装 PyTorch cu118
- **安装后**：若仍报 GPU 不兼容，请 **Restart Session** 后重新 Run All
- P100 不支持 BF16，自动用 FP16

In [None]:
# P100 (sm_60) needs the cu118 build of PyTorch. transformers 4.46 is installed
# into its own --target directory so it does not clash with the system 5.x copy.
import subprocess
import sys

TF_ENV = "/kaggle/working/transformers_4.46"


def _pip_install(*args):
    """Run ``pip install -q <args>`` and report a failure instead of hiding it.

    Output is still captured to keep the notebook quiet, but a non-zero exit
    code prints the tail of stderr so a broken install is visible right here
    rather than surfacing later as an import or GPU-compatibility error.

    Returns:
        True when pip exited with status 0, False otherwise.
    """
    result = subprocess.run(
        ["pip", "install", "-q", *args],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        print(f"[pip] install failed (exit {result.returncode}): {' '.join(args)}")
        print(result.stderr[-2000:])
    return result.returncode == 0


_pip_install(
    "torch==2.7.1",
    "torchvision==0.22.1",
    "torchaudio==2.7.1",
    "--index-url",
    "https://download.pytorch.org/whl/cu118",
)
_pip_install("pillow>=9.0,<12", "jinja2")
_pip_install("--target", TF_ENV, "--no-cache-dir", "transformers==4.46.0", "radgraph")
_pip_install("bitsandbytes")

# Put the private transformers directory first on the import path; the guard
# keeps re-running this cell from stacking duplicate entries.
if TF_ENV not in sys.path:
    sys.path.insert(0, TF_ENV)

# Cell 1: W4A8 量化 + F1 + RadGraph

## W4A8 演算逻辑
- **W4**：权重 4-bit 量化（bitsandbytes NF4，因 MedGemma 无官方 AWQ）
- **A8**：激活 8-bit 量化。对每层 Linear 输入做 per-tensor 对称量化：
  - 范围 [-128, 127]，scale = 127/max(|x|)
  - q = round(x * scale).clamp(-128, 127)，反量化 x' = q/scale
- **运用方法**：`forward_pre_hook` 在 Linear 前对输入做 fake 量化；真实加速需 QServe 等 INT4×INT8 GEMM kernel。
- **与 W4A4 对比**：A8 精度更高，F1 更接近原始；A4 显存更省，F1 略降。

# Cell 2: 环境检查（需先运行上方安装）

In [None]:
import sys
import torch

print(f"Python: {sys.version}")
print(f"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}")

# CPU-only defaults. USE_BF16 is defined on every path so later cells can read
# it without first checking whether a GPU was found.
USE_BF16 = False
DTYPE = torch.float32

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # BF16 is selected only for compute capability major version >= 8;
    # anything older (the P100 is sm_60) falls back to FP16.
    USE_BF16 = torch.cuda.get_device_capability(0)[0] >= 8
    DTYPE = torch.bfloat16 if USE_BF16 else torch.float16
    print(f"精度: {'BF16' if USE_BF16 else 'FP16 (P100)'}")

# Cell 3: 激活量化工具（W4A8: 8-bit 激活）

In [None]:
import torch
import torch.nn as nn
from typing import List

def _fake_quant_activation(x: torch.Tensor, bits: int) -> torch.Tensor:
    """Per-tensor 对称 fake 量化: 8-bit [-128,127], 4-bit [-8,7]"""
    if not x.is_floating_point():
        return x
    max_val = x.abs().max().clamp(min=1e-8)
    if bits == 8:
        scale = 127.0 / max_val
        q = (x * scale).round().clamp(-128, 127)
    else:
        scale = 7.0 / max_val
        q = (x * scale).round().clamp(-8, 7)
    return q / scale

def _is_linear_like(module):
    return hasattr(module, "weight") and hasattr(module.weight, "shape") and len(module.weight.shape) == 2

def register_activation_quant_hooks(model: nn.Module, bits: int = 8) -> List:
    """Attach activation fake-quantization pre-hooks to linear-like submodules.

    Every submodule accepted by ``_is_linear_like`` gets a forward pre-hook
    that replaces its first positional input with a ``bits``-bit fake-quantized
    copy. Submodules whose qualified name contains "lm_head" are skipped.

    Args:
        model: Module whose ``named_modules()`` are scanned.
        bits: Bit width forwarded to ``_fake_quant_activation``.

    Returns:
        A list of ``(module_name, hook_handle)`` tuples, one per registered
        hook, in ``named_modules()`` order.
    """

    def _quantize_first_arg(_module, args):
        # Only a floating-point tensor in the first position is rewritten;
        # every other input is handed back unchanged.
        if args and isinstance(args[0], torch.Tensor) and args[0].is_floating_point():
            return (_fake_quant_activation(args[0], bits), *args[1:])
        return args

    return [
        (name, module.register_forward_pre_hook(_quantize_first_arg))
        for name, module in model.named_modules()
        if "lm_head" not in name and _is_linear_like(module)
    ]

# Cell 4: 导入与路径配置

In [None]:
import os
import gc
import pandas as pd
from PIL import Image
from tqdm import tqdm

# Root directory that relative image paths are resolved against.
DATASET_ROOT = "/kaggle/input/mimic-cxr-dataset/official_data_iccv_final"

# Places the evaluation CSV may live, in priority order.
CSV_CANDIDATES = [
    "/kaggle/input/mimic-cxr-dataset/mimic_eval_single_image_final_233.csv",
    "/kaggle/input/mimic-cxr-dataset/official_data_iccv_final/mimic_eval_single_image_final_233.csv",
    "/kaggle/input/mimic-eval-233/mimic_eval_single_image_final_233.csv",
    "/kaggle/working/mimic_eval_single_image_final_233.csv",
    "./mimic_eval_single_image_final_233.csv",
]

# Take the first candidate that exists on disk; if none does, fall back to the
# first entry so the path printed below shows where the file was expected.
for _candidate in CSV_CANDIDATES:
    if os.path.exists(_candidate):
        CSV_PATH = _candidate
        break
else:
    CSV_PATH = CSV_CANDIDATES[0]
print(f"CSV: {CSV_PATH}")

# Cell 5: 清理 GPU 显存（gc + empty_cache；若仍持有原始/其他模型的引用，需先手动 del），再加载 W4A8

In [None]:
# Collect unreachable Python objects first so the CUDA memory they hold can be
# handed back by the allocator in the next step.
gc.collect()
if torch.cuda.is_available():
    # Release cached blocks, then zero the peak counter so the figure read
    # after loading the model does not include earlier allocations.
    for _reset in (torch.cuda.empty_cache, torch.cuda.reset_peak_memory_stats):
        _reset()
print("GPU 已清空，准备加载 W4A8 模型")

# Cell 6: 加载 W4A8（4-bit 权重 + 8-bit 激活）

In [None]:
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

model_id = "google/medgemma-1.5-4b-it"

# W4: bitsandbytes NF4 weight quantization, computing in the dtype chosen by
# the environment-check cell.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=DTYPE,
    load_in_4bit=True,
)

print("加载 W4A8: 4-bit 权重 (bitsandbytes) + 8-bit 激活 (hook)")
_load_kwargs = {"quantization_config": bnb_config, "device_map": "auto"}
model = AutoModelForImageTextToText.from_pretrained(model_id, **_load_kwargs)
processor = AutoProcessor.from_pretrained(model_id)

# A8: fake-quantize the input of every linear-like layer via forward pre-hooks.
hooks = register_activation_quant_hooks(model, bits=8)
print(f"已注册 {len(hooks)} 个 Linear 层的 8-bit 激活量化")

# Peak allocation on device 0 since the last reset of the peak counter.
if torch.cuda.is_available():
    W4A8_GPU_GB = torch.cuda.max_memory_allocated(0) / (1024**3)
else:
    W4A8_GPU_GB = 0
print(f"W4A8 模型 GPU 占用: {W4A8_GPU_GB:.2f} GB")

# Cell 7: 图像到报告生成

In [None]:
import ast, re
# Instruction sent with every image; "{view}" is filled with the projection
# name (e.g. PA, AP, Lateral) when the prompt is formatted.
PROMPT_TEMPLATE = (
    "You are an expert radiologist. "
    "Describe this {view} view chest X-ray. "
    "Provide a concise report consisting of Findings and Impression. "
    "Focus on the heart, lungs, mediastinum, pleural space, and bones. "
    "Do NOT use bullet points, asterisks, or section headers. "
    "Do NOT include disclaimers or 'AI' warnings. "
    "Output pure medical text only."
)
def get_single_image_path(cell_val):
    """Resolve a CSV cell to the path of one existing image file.

    The cell may hold a bare path or the string form of a list of paths
    (brackets and quotes are stripped); only the first comma-separated entry
    is used. Resolution order:

    1. An absolute path that already points at a file is returned as is.
    2. If the entry contains "files", everything from the first "files" onward
       is re-rooted under ``DATASET_ROOT``.
    3. Otherwise the entry, minus surrounding slashes, is joined to
       ``DATASET_ROOT``.

    Args:
        cell_val: Raw cell value; NaN/None means "no image".

    Returns:
        The resolved path as a string, or None when the cell is empty or the
        resolved path is not an existing regular file.
    """
    if pd.isna(cell_val):
        return None

    cleaned = str(cell_val).strip()
    for ch in "[]'\"":
        cleaned = cleaned.replace(ch, "")
    first = cleaned.split(",")[0].strip()
    if not first:
        # An empty entry (e.g. "[]") would otherwise resolve to DATASET_ROOT
        # itself, which exists but is a directory, not an image.
        return None

    if os.path.isabs(first) and os.path.isfile(first):
        return first

    if "files" in first:
        rel = "files" + first.split("files", 1)[1]
    else:
        rel = first.strip("/")
    full = os.path.join(DATASET_ROOT, rel)
    # isfile rather than exists: a directory is never a usable image path.
    return full if os.path.isfile(full) else None
def generate_report(model, processor, img_path, view="PA"):
    """Generate a cleaned, single-line radiology report for one image.

    The image is loaded as RGB, wrapped in a one-turn chat message together
    with ``PROMPT_TEMPLATE`` formatted for ``view``, and decoded greedily
    (``do_sample=False``) for at most 300 new tokens. Only the tokens after
    the prompt are decoded. The literal markers "Findings:" and "Impression:"
    are removed and every run of whitespace is collapsed to one space.

    Args:
        model: Generation model; must provide ``device`` and ``generate``.
        processor: Processor providing ``apply_chat_template`` and ``decode``.
        img_path: Path of the image file.
        view: View name inserted into the prompt.

    Returns:
        The cleaned report text, or "" when the image is missing or cannot be
        opened.
    """
    if not os.path.exists(img_path):
        return ""
    try:
        # The context manager closes the file; convert() returns a separate
        # in-memory image that remains valid afterwards.
        with Image.open(img_path) as fh:
            img = fh.convert("RGB")
    except Exception:
        # Best effort: an unreadable image yields an empty report, but
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        return ""

    msgs = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": PROMPT_TEMPLATE.format(view=view)},
            ],
        }
    ]
    inputs = processor.apply_chat_template(
        msgs,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=DTYPE)
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        out = model.generate(**inputs, max_new_tokens=300, do_sample=False)

    text = processor.decode(out[0][prompt_len:], skip_special_tokens=True)
    text = text.replace("Findings:", "").replace("Impression:", "")
    # r"\s+" matches whitespace; the previous r'\\s+' matched a literal
    # backslash followed by "s", so newlines and double spaces survived.
    return re.sub(r"\s+", " ", text).strip()
# Number of leading CSV rows to evaluate.
MAX_SAMPLES = 50

df = pd.read_csv(CSV_PATH)
GT_COL = "Ground_Truth" if "Ground_Truth" in df.columns else "text"
IMG_COL = "Image_Path" if "Image_Path" in df.columns else None
rows_out = []
for idx, row in tqdm(df.head(MAX_SAMPLES).iterrows(), total=min(MAX_SAMPLES, len(df))):
    path, view = None, "PA"
    if IMG_COL:
        raw_path = row.get(IMG_COL)
        # A missing cell is NaN, which is truthy; accept only real, non-empty
        # strings so NaN never reaches os.path.exists().
        if isinstance(raw_path, str) and raw_path.strip():
            path = raw_path
            raw_view = row.get("View", "PA")
            view = raw_view if isinstance(raw_view, str) and raw_view else "PA"
    if not path:
        # Fall back to the per-view columns, first resolvable image wins.
        for col, col_view in [("PA", "PA"), ("AP", "AP"), ("Lateral", "Lateral")]:
            if col in df.columns and (p := get_single_image_path(row.get(col))):
                path, view = p, col_view
                break
    if not path:
        continue

    raw_gt = row.get(GT_COL)
    # NaN would otherwise be stringified to "nan" and scored as a reference.
    gt = "" if pd.isna(raw_gt) else str(raw_gt).strip()
    # Skip rows without a reference, or whose reference is a leaked prompt.
    if not gt or gt.startswith("You are"):
        continue

    rep = generate_report(model, processor, path, view)
    rows_out.append(
        {
            "subject_id": row["subject_id"],
            "View": view,
            "Image_Path": path,
            "Ground_Truth": gt,
            "Generated_Report": rep,
        }
    )
df_sub = pd.DataFrame(rows_out)
print(f"生成 {len(df_sub)} 条")

# Cell 8: RadGraph F1 评估

In [None]:
from radgraph import F1RadGraph

# References and hypotheses as plain string lists; missing cells become "".
refs = df_sub["Ground_Truth"].fillna("").tolist()
hyps = df_sub["Generated_Report"].fillna("").tolist()

f1radgraph = F1RadGraph(model_type="modern-radgraph-xl", reward_level="all")
mean_reward, _, _, _ = f1radgraph(hyps=hyps, refs=refs)
# mean_reward is unpacked into the three scores reported below.
rg_e, rg_er, rg_er_bar = mean_reward

W4A8_SCORES = {"rg_e": float(rg_e), "rg_er": float(rg_er), "rg_er_bar": float(rg_er_bar)}

_rule = "=" * 50
print(_rule)
print("W4A8 RadGraph F1 分数")
print("-" * 50)
# Labels are left-padded to 13 columns so the numbers line up.
for _label, _key in (("RG_E:", "rg_e"), ("RG_ER:", "rg_er"), ("RG_ER_bar:", "rg_er_bar")):
    print(f"{_label:<13}{W4A8_SCORES[_key]*100:.2f}")
print(_rule)

# Cell 9: 与原始模型对比

In [None]:
import json

# Baseline metrics, read here if the file exists.
orig_path = "/kaggle/working/original_scores.json"
if not os.path.exists(orig_path):
    print("未找到 original_scores.json，请先运行 01 原始模型 Notebook")
else:
    with open(orig_path) as f:
        orig = json.load(f)
    # Missing keys default to 0 so a partial baseline file still prints.
    o_scores = orig.get("scores", {})
    o_gpu = orig.get("gpu_gb", 0)

    _rule = "=" * 55
    print(_rule)
    print("W4A8 与原始模型对比")
    print("-" * 55)
    print(f"{'指标':<12} {'原始':>10} {'W4A8':>10} {'变化':>10}")
    for k in ("rg_e", "rg_er", "rg_er_bar"):
        # Scores are stored as fractions; show them as percentages.
        o, q = o_scores.get(k, 0) * 100, W4A8_SCORES.get(k, 0) * 100
        delta = q - o
        print(f"{k:<12} {o:>9.2f} {q:>9.2f} {delta:>+9.2f}")
    _gpu_delta = W4A8_GPU_GB - o_gpu
    print(f"{'GPU (GB)':<12} {o_gpu:>9.2f} {W4A8_GPU_GB:>9.2f} {_gpu_delta:>+9.2f}")
    print(_rule)

# The per-sample results are saved whether or not a baseline was found.
df_sub.to_csv("/kaggle/working/w4a8_results.csv", index=False)