# Cell 0: 环境设置（Kaggle P100）

- **P100 兼容**：Kaggle 预装 PyTorch cu128 不支持 P100 (sm_60)，需先运行下方安装 Cell 安装 PyTorch cu118
- **安装后**：若仍报 GPU 不兼容，请 **Restart Session** 后重新 Run All

In [None]:
# P100 (sm_60) needs the cu118 PyTorch build. transformers 4.46 is installed
# into its own directory so it does not clash with the system-wide 5.x install.
# NOTE(review): confirm that transformers 4.46.0 can load the MedGemma
# checkpoint used by the later cells - TODO verify.
import subprocess
import sys

TF_ENV = "/kaggle/working/transformers_4.46"


def _pip_install(*args):
    """Run ``pip install -q`` with *args* and report failures.

    pip is invoked through the running interpreter so packages land in the
    environment this notebook actually imports from. A failed install is
    printed to stderr (instead of being silently discarded) but does not
    raise, so the remaining installs still run.

    Returns:
        True when pip exited with status 0, False otherwise.
    """
    cmd = [sys.executable, "-m", "pip", "install", "-q", *args]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        print(
            f"pip install failed (exit {result.returncode}): {' '.join(args)}\n{result.stderr}",
            file=sys.stderr,
        )
    return result.returncode == 0


_pip_install(
    "torch==2.7.1",
    "torchvision==0.22.1",
    "torchaudio==2.7.1",
    "--index-url",
    "https://download.pytorch.org/whl/cu118",
)
_pip_install("pillow>=9.0,<12", "jinja2")
_pip_install("--target", TF_ENV, "--no-cache-dir", "transformers==4.46.0", "radgraph")
_pip_install("bitsandbytes", "peft")

# Put the pinned transformers first on the import path; skip when the cell is
# re-run so sys.path does not accumulate duplicates.
if TF_ENV not in sys.path:
    sys.path.insert(0, TF_ENV)

# Cell 1: 蒸馏说明

## 逻辑
- **Teacher**：原始 MedGemma，对 233 张图生成报告（也可直接使用 CSV 中的 Generated_Report 列，因为这 233 条样本是由原模型筛选得到的）
- **Student**：QLoRA（4-bit 量化 + LoRA），学习模仿 teacher 输出
- **蒸馏目标**：以 teacher 生成的 token 序列作为 target（序列级蒸馏，使用硬标签而非 teacher 的 logits），student 用交叉熵（CE）损失逐 token 拟合

# Cell 2: 环境检查（需先运行上方安装）

In [None]:
import sys

import torch

# Report the interpreter and PyTorch build in use.
print(f"Python: {sys.version}")
print(f"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}")

# Choose the working dtype: float32 without CUDA, bfloat16 when the GPU's
# major compute capability is 8 or higher, float16 otherwise.
if not torch.cuda.is_available():
    DTYPE = torch.float32
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    _capability = torch.cuda.get_device_capability(0)
    if _capability[0] >= 8:
        DTYPE = torch.bfloat16
    else:
        DTYPE = torch.float16

# Cell 3: 路径与数据

In [None]:
import os
import gc
import pandas as pd
from PIL import Image

# Locations where the 233-sample evaluation CSV may live; the first existing
# path wins.
CSV_CANDIDATES = [
    "/kaggle/input/mimic-cxr-dataset/mimic_eval_single_image_final_233.csv",
    "/kaggle/input/mimic-cxr-dataset/official_data_iccv_final/mimic_eval_single_image_final_233.csv",
    "/kaggle/input/mimic-eval-233/mimic_eval_single_image_final_233.csv",
    "/kaggle/working/mimic_eval_single_image_final_233.csv",
    "./mimic_eval_single_image_final_233.csv",
]
CSV_PATH = next((p for p in CSV_CANDIDATES if os.path.exists(p)), CSV_CANDIDATES[0])
if not os.path.exists(CSV_PATH):
    # Same exception type pd.read_csv would raise, but listing every location
    # that was searched.
    raise FileNotFoundError(
        "mimic_eval_single_image_final_233.csv not found; looked in:\n  " + "\n  ".join(CSV_CANDIDATES)
    )

df = pd.read_csv(CSV_PATH)
print(f"共 {len(df)} 条")

# Distillation target: the teacher output. If the CSV's Generated_Report column
# came from the original model it can be used directly; otherwise the teacher
# has to generate the reports first. A missing column also means "generate".
if "Generated_Report" in df.columns:
    _reports = df["Generated_Report"]
    USE_CSV_TEACHER = bool(_reports.notna().all() and (_reports.astype(str).str.len() > 10).all())
else:
    USE_CSV_TEACHER = False
print(f"使用 CSV 中 Generated_Report 作为 teacher 目标: {USE_CSV_TEACHER}")

# Cell 4: 获取 Teacher 目标（若 CSV 无则用原模型生成）

In [None]:
if not USE_CSV_TEACHER:
    # The CSV has no usable teacher reports: generate them with the original model.
    from transformers import AutoProcessor, AutoModelForImageTextToText
    from tqdm import tqdm

    model_id = "google/medgemma-1.5-4b-it"
    print("加载 Teacher (原模型)...")
    teacher = AutoModelForImageTextToText.from_pretrained(model_id, torch_dtype=DTYPE, device_map="auto")
    proc = AutoProcessor.from_pretrained(model_id)

    def gen(img_path, prompt="Describe this chest X-ray in a radiology report format."):
        """Generate a report for the image at *img_path* with the teacher model.

        Returns an empty string when the file does not exist or cannot be
        opened as an image; otherwise the greedily decoded continuation
        (at most 512 new tokens) with special tokens removed and whitespace
        stripped.
        """
        if not os.path.exists(img_path):
            return ""
        try:
            image = Image.open(img_path).convert("RGB")
        except Exception:
            return ""
        user_turn = {
            "role": "user",
            "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}],
        }
        encoded = proc.apply_chat_template(
            [user_turn],
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(teacher.device, dtype=DTYPE)
        # Number of prompt tokens, so only the newly generated part is decoded.
        prompt_len = encoded["input_ids"].shape[-1]
        with torch.inference_mode():
            generated = teacher.generate(**encoded, max_new_tokens=512, do_sample=False)
        new_tokens = generated[0][prompt_len:]
        return proc.decode(new_tokens, skip_special_tokens=True).strip()

    teacher_targets = [gen(image_path) for image_path in tqdm(df["Image_Path"], total=len(df))]

    # Drop the teacher and its processor so the GPU memory can be reused.
    del teacher, proc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("Teacher 已释放")
else:
    # Reuse the reports already stored in the CSV.
    teacher_targets = df["Generated_Report"].fillna("").tolist()
    print("使用 CSV 中已有 teacher 输出")

df["teacher_target"] = teacher_targets

# Cell 5: 构建蒸馏数据集

In [None]:
from torch.utils.data import Dataset

class DistillDataset(Dataset):
    """Image + teacher-report pairs rendered through the processor's chat template.

    Each item is a dict with:
      * ``"text"``  - the chat-template string for a user turn (image + prompt)
        followed by an assistant turn holding the teacher target;
      * ``"image"`` - the RGB ``PIL.Image`` for that row.

    Args:
        df: DataFrame with one row per sample; it is re-indexed from 0.
        processor: object exposing ``apply_chat_template``.
        image_col: column holding the image path.
        target_col: column holding the teacher text.
        prompt: instruction placed in the user turn.
    """

    def __init__(self, df, processor, image_col="Image_Path", target_col="teacher_target", prompt="Describe this chest X-ray in a radiology report format."):
        self.df = df.reset_index(drop=True)
        self.processor = processor
        self.image_col = image_col
        self.target_col = target_col
        self.prompt = prompt

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _clean_target(value):
        """Return *value* as text, mapping missing values (None / NaN) to ``""``.

        NaN is truthy, so a plain ``str(value or "")`` would turn a missing
        report into the literal string ``"nan"``.
        """
        if value is None or (isinstance(value, float) and value != value):
            return ""
        return str(value or "")

    @staticmethod
    def _load_image(path):
        """Load *path* as RGB, or return a black 224x224 placeholder.

        The placeholder is used when *path* is not a path-like value (e.g. a
        NaN cell) or the file does not exist. Errors while decoding an
        existing file still propagate.
        """
        if isinstance(path, (str, os.PathLike)) and os.path.exists(path):
            # convert() returns a new in-memory image, so the file handle can
            # be closed right away.
            with Image.open(path) as handle:
                return handle.convert("RGB")
        return Image.new("RGB", (224, 224), 0)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        target = self._clean_target(row[self.target_col])
        img = self._load_image(row[self.image_col])
        msgs = [{"role": "user", "content": [{"type": "image", "image": img}, {"type": "text", "text": self.prompt}]}]
        full = msgs + [{"role": "assistant", "content": target}]
        # Render only; tokenisation is left to the collator / training loop.
        text = self.processor.apply_chat_template(full, tokenize=False, add_generation_prompt=False)
        return {"text": text, "image": img}

from transformers import AutoProcessor

# Processor used by the dataset here and by the collate / training code in
# later cells.
proc = AutoProcessor.from_pretrained("google/medgemma-1.5-4b-it")

# Wrap the full dataframe; default column names and prompt are used.
ds = DistillDataset(df=df, processor=proc)
_n_distill = len(ds)
print(f"蒸馏样本数: {_n_distill}")

# Cell 6: 加载 Student（QLoRA）并训练

In [None]:
from transformers import AutoModelForImageTextToText, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model, TaskType
from torch.utils.data import DataLoader

def collate_fn(batch):
    """Collate ``DistillDataset`` items into one padded training batch.

    Uses each item's pre-rendered ``"text"`` (user prompt plus teacher answer)
    so the teacher target is actually part of the batch, and adds ``labels``
    so a loss can be computed.

    Args:
        batch: list of dicts with ``"text"`` and ``"image"`` keys.

    Returns:
        The processor output with an extra ``"labels"`` tensor: a copy of
        ``input_ids`` in which padded positions are set to -100.
    """
    texts = [b["text"] for b in batch]
    # One image list per sample so every image stays paired with its own text.
    images = [[b["image"]] for b in batch]
    out = proc(
        text=texts,
        images=images,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
    )
    labels = out["input_ids"].clone()
    # -100 is the label value the cross-entropy loss ignores.
    labels[out["attention_mask"] == 0] = -100
    out["labels"] = labels
    return out

# Release whatever earlier cells left behind before loading the student.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

model_id = "google/medgemma-1.5-4b-it"
print("加载 Student (QLoRA)...")

from transformers import BitsAndBytesConfig

# 4-bit NF4 quantisation; matmuls are computed in DTYPE.
bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=DTYPE, bnb_4bit_quant_type="nf4")

# torch_dtype=DTYPE keeps the non-quantised modules in the same dtype that the
# teacher was loaded with and that pixel_values are cast to in the training loop.
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    quantization_config=bnb,
    torch_dtype=DTYPE,
    device_map="auto",
)

# LoRA adapters on the attention projections.
lora = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)
model = get_peft_model(model, lora)
model.print_trainable_parameters()

# Cell 7: 训练循环（蒸馏 CE 损失）

In [None]:
from transformers import DataCollatorForSeq2Seq

# Simplified setup: SFT-style training where the target is the teacher sequence.
train_args = TrainingArguments(
    output_dir="/kaggle/working/medgemma_distill_lora",
    num_train_epochs=2,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    # Enable mixed precision only for the dtype actually selected; with
    # DTYPE == float32 (no CUDA) both flags stay off.
    fp16=(DTYPE == torch.float16),
    bf16=(DTYPE == torch.bfloat16),
    logging_steps=10,
    save_strategy="epoch",
    remove_unused_columns=False,
)

# A custom data collator is needed for image+text; if the processor cannot
# batch images, train one sample at a time.
# Simplified version here: keep only rows with a usable teacher target
# (longer than 20 characters), capped at the first 100.
train_df = df[df["teacher_target"].str.len() > 20].head(100)
train_ds = DistillDataset(train_df, proc)

def data_collator(batch):
    """Placeholder collator: hand the list of samples back unchanged."""
    # Simplified collator - no padding or tensor conversion is performed.
    collated = batch
    return collated

# NOTE: for MedGemma the Trainer needs a custom compute_loss to handle
# image+text inputs; this cell only wires up the scaffold.
# A complete implementation would subclass Trainer and override compute_loss,
# or use trl's SFTTrainer.
_trainer_kwargs = {
    "model": model,
    "args": train_args,
    "train_dataset": train_ds,
    "data_collator": data_collator,
}
trainer = Trainer(**_trainer_kwargs)

print("训练框架已就绪；完整 image+text 蒸馏建议用 trl SFTTrainer 或自定义 Trainer")

# Cell 8: 手动蒸馏训练循环（兼容 image+text）

In [None]:
from torch.optim import AdamW

IGNORE_INDEX = -100  # label value skipped by the cross-entropy loss
TRAIN_PROMPT = "Describe this chest X-ray in a radiology report format."
MAX_TRAIN_SAMPLES = 50  # cap on samples visited per epoch
MAX_TARGET_TOKENS = 512  # teacher targets are truncated to this many tokens


def _build_example(sample):
    """Turn one ``DistillDataset`` item into model inputs with masked labels.

    The prompt (image + instruction, ending with the generation prompt) is
    tokenised through the chat template; the teacher answer is tokenised
    separately without special tokens so no extra BOS lands mid-sequence.
    ``labels`` has the same length as ``input_ids`` (the model shifts them
    internally) and every prompt position is set to IGNORE_INDEX, so the loss
    covers the teacher tokens only.

    Returns:
        Dict of tensors on ``model.device``, or None when the target
        tokenises to nothing.
    """
    msgs = [{"role": "user", "content": [{"type": "image", "image": sample["image"]}, {"type": "text", "text": TRAIN_PROMPT}]}]
    prompt_inputs = proc.apply_chat_template(msgs, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt")

    # The rendered text holds both turns; keep what follows the last
    # "model\n" marker, i.e. the assistant turn.
    rendered = sample["text"]
    target_text = rendered.split("model\n")[-1] if "model\n" in rendered else rendered
    target_ids = proc.tokenizer(
        target_text,
        add_special_tokens=False,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_TARGET_TOKENS,
    )["input_ids"]
    if target_ids.shape[1] == 0:
        return None

    prompt_ids = prompt_inputs["input_ids"]
    input_ids = torch.cat([prompt_ids, target_ids], dim=1)
    labels = input_ids.clone()
    labels[:, : prompt_ids.shape[1]] = IGNORE_INDEX

    example = {
        "input_ids": input_ids,
        "attention_mask": torch.ones_like(input_ids),
        "labels": labels,
    }
    if "token_type_ids" in prompt_inputs:
        # Extend the prompt's token types with zeros for the appended target.
        prompt_types = prompt_inputs["token_type_ids"]
        example["token_type_ids"] = torch.cat([prompt_types, prompt_types.new_zeros(target_ids.shape)], dim=1)
    example = {key: value.to(model.device) for key, value in example.items()}
    if "pixel_values" in prompt_inputs:
        example["pixel_values"] = prompt_inputs["pixel_values"].to(model.device, dtype=DTYPE)
    return example


model.train()
# Only parameters that require gradients (the LoRA adapters) are optimised.
opt = AdamW([p for p in model.parameters() if p.requires_grad], lr=2e-5)
BATCH = 2
EPOCHS = 2
n_train = min(MAX_TRAIN_SAMPLES, len(train_ds))

for ep in range(EPOCHS):
    total_loss = 0.0
    n_seen = 0  # samples that contributed to total_loss
    for i in range(0, n_train, BATCH):
        examples = [_build_example(train_ds[j]) for j in range(i, min(i + BATCH, n_train))]
        examples = [ex for ex in examples if ex is not None]
        if not examples:
            continue
        opt.zero_grad()
        # Samples differ in length, so each is forwarded on its own; gradients
        # are accumulated and a single optimizer step is taken per batch.
        for ex in examples:
            out = model(**ex)
            loss = out.loss if hasattr(out, "loss") else out[0]
            (loss / len(examples)).backward()
            total_loss += loss.item()
            n_seen += 1
        opt.step()
        if (i // BATCH) % 5 == 0:
            print(f"Epoch {ep+1} batch {i//BATCH} loss {total_loss / max(n_seen, 1):.4f}")
    print(f"Epoch {ep+1} avg loss: {total_loss / max(n_seen, 1):.4f}")

# Save the LoRA adapter and the processor side by side.
model.save_pretrained("/kaggle/working/medgemma_distill_lora")
proc.save_pretrained("/kaggle/working/medgemma_distill_lora")