# dataset


In [6]:
import json
# Location of the SkinCAP closed-ended QA annotation file (JSON).
data_path = "/root/dataset/skin/SkinCAP/SkinCAP_20250712_121252_close_end_QA.json"

# Parse the whole file into memory; the handle is closed when the block exits.
with open(data_path, encoding="utf-8") as file:
    data = json.load(file)

# Show the first record to check which fields are available.
print(data[0])



{'image_name': '10.png', 'caption_zh': '红色光滑的外生性结节，根部稍缩窄。考虑鳞癌待查', 'caption_zh_polish': '红色光滑的外生性结节，根部稍缩窄，可能是鳞癌的表现。鳞癌是一种常见的皮肤肿瘤，通常起源于表皮层的角化细胞。这种类型的肿瘤通常会表现为皮肤上的病变，如结节或溃疡，需要进行进一步检查以确认诊断。', 'caption_zh_polish_en': 'The red, smooth, exophytic nodule with a slightly narrowed base may indicate squamous cell carcinoma. Squamous cell carcinoma is a common type of skin tumor that typically originates from the keratinocytes in the epidermis. This type of tumor often presents as skin lesions such as nodules or ulcers, and further investigation is needed to confirm the diagnosis.', 'answer': 'squamous cell carcinoma', 'question_type': 'close_end_QA'}


# Build the prompts and the reward function

In [8]:
import os
import datasets

# Wrap the loaded annotation records in a Hugging Face dataset (one row per record).
ds = datasets.Dataset.from_list(data)

# Directory that is joined with each record's "image_name" to locate its image.
BASE_IMG_DIR = "/root/dataset/skin/SkinCAP/skincap"

def add_image_path(ex):
    """Attach the image location to a single record.

    Joins ``BASE_IMG_DIR`` with the record's ``"image_name"`` and stores the
    result under ``"image_path"``. The record is modified in place and also
    returned, which is the shape ``Dataset.map`` is given below.
    """
    file_name = ex["image_name"]
    ex["image_path"] = os.path.join(BASE_IMG_DIR, file_name)
    return ex

# Add the "image_path" column to every row, then retype that column as a
# `datasets.Image()` feature.
ds = ds.map(add_image_path).cast_column("image_path", datasets.Image())

# System message placed in front of every conversation.
SYSTEM = "SYSTEM INSTRUCTION: think silently if needed."

# User message template; `{q}` is filled with the per-record question text.
USER_TEMPLATE = (
    "You are given a clinical image and a question.\n"
    "Return ONLY the disease name in English. No extra words.\n"
    "Question: {q}\n"
)


def to_prompt(ex):
    """Turn one dataset record into the columns used for training.

    The question text is the first non-empty value among
    ``"caption_zh_polish_en"`` and ``"caption_zh"``; when both are missing or
    empty, an empty string is used.

    Returns a dict with:
        "prompt": a two-message conversation (system + user), text only.
        "answer": the record's ``"answer"`` value, unchanged.
        "image":  the record's ``"image_path"`` value, kept in its own column.

    NOTE(review): the caption used as the question can itself mention the
    diagnosis (the sample record printed earlier in this notebook does), so the
    expected answer may be readable from the prompt text -- confirm this is
    intended.
    """
    question = ""
    for field in ("caption_zh_polish_en", "caption_zh"):
        candidate = ex.get(field)
        if candidate:
            question = candidate
            break

    system_msg = {"role": "system", "content": SYSTEM}
    user_msg = {"role": "user", "content": USER_TEMPLATE.format(q=question)}

    return {
        "prompt": [system_msg, user_msg],
        "answer": ex["answer"],
        "image": ex["image_path"],
    }

# Build the "prompt", "answer" and "image" columns for every row.
ds = ds.map(to_prompt)

# Columns to retain. Everything else is dropped, including "image_path",
# whose value is now carried by the "image" column.
keep_cols = ["prompt", "answer", "image", "image_name", "question_type"]
unused_cols = [name for name in ds.column_names if name not in keep_cols]
ds = ds.remove_columns(unused_cols)


Map:   0%|          | 0/1631 [00:00<?, ? examples/s]

Map:   0%|          | 0/1631 [00:00<?, ? examples/s]

In [9]:
import re
from difflib import SequenceMatcher

def _get_completion_text(completion) -> str:
    # conversational: [[{"role":"assistant","content":"..."}]]
    if isinstance(completion, list) and completion and isinstance(completion[0], dict):
        return completion[0].get("content", "") or ""
    # non-conversational: "..."
    return str(completion) if completion is not None else ""

def _normalize_disease(s: str) -> str:
    s = (s or "").strip().lower()
    # drop wrappers like quotes
    s = s.strip('"').strip("'")
    # remove common prefixes
    s = re.sub(r"^\s*(final\s*answer\s*:\s*)", "", s)
    # keep letters/numbers/spaces
    s = re.sub(r"[^a-z0-9\s\-]", " ", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s

# 你可以按你的任务继续扩充
_ALIAS = {
    "scc": "squamous cell carcinoma",
    "squamous cell ca": "squamous cell carcinoma",
    "squamous cell cancer": "squamous cell carcinoma",
    "bcc": "basal cell carcinoma",
    "basal cell cancer": "basal cell carcinoma",
    "mm": "melanoma",
}

def _canonicalize(s: str) -> str:
    """Normalize a disease string and resolve it through the alias table.

    The input is normalized with ``_normalize_disease``; if the result is a key
    of ``_ALIAS`` it is replaced by the mapped name. The outcome is normalized
    once more before being returned.
    """
    key = _normalize_disease(s)
    resolved = _ALIAS[key] if key in _ALIAS else key
    return _normalize_disease(resolved)

def _phrase_within(phrase: str, text: str) -> bool:
    """Return True when ``phrase`` occurs in ``text`` on whole-word boundaries.

    Both strings are expected to be canonicalized already (words separated by
    single spaces, no surrounding whitespace). An empty ``phrase`` or ``text``
    never matches.
    """
    if not phrase or not text:
        return False
    # Padding with spaces makes a plain substring test respect word edges.
    return f" {phrase} " in f" {text} "


def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Score each completion against its ground-truth disease name.

    Scoring, per (completion, answer) pair after canonicalization:
      * 1.0 -- exact match.
      * 0.5 -- one string contains the other as a whole-word phrase (e.g. the
               model added extra words), or the two are near-identical
               (``SequenceMatcher`` ratio >= 0.92, which tolerates small typos).
      * 0.0 -- otherwise, including an empty prediction or an empty/missing
               ground truth.

    Containment is checked on word boundaries rather than raw substrings, so a
    degenerate output such as a single letter cannot collect partial credit
    just because that letter appears somewhere in the label.

    Args:
        prompts: Prompts for the batch (unused; part of the reward signature).
        completions: One completion per sample, conversational or plain text.
        answer: Ground-truth disease names, aligned with ``completions``.
        **kwargs: Extra columns passed by the caller (ignored).

    Returns:
        A list of floats, one reward per (completion, answer) pair.
    """
    rewards = []
    for comp, gt in zip(completions, answer):
        pred = _canonicalize(_get_completion_text(comp))
        gt_norm = _canonicalize("" if gt is None else str(gt))

        # Nothing to compare: either the model said nothing usable or there is
        # no label. (An empty label would otherwise be "contained" in anything.)
        if not pred or not gt_norm:
            rewards.append(0.0)
            continue

        # Strict match.
        if pred == gt_norm:
            rewards.append(1.0)
            continue

        # Partial credit when one is a whole-word phrase of the other.
        if _phrase_within(gt_norm, pred) or _phrase_within(pred, gt_norm):
            rewards.append(0.5)
            continue

        # Fuzzy fallback that protects against minor typos.
        sim = SequenceMatcher(None, pred, gt_norm).ratio()
        rewards.append(0.5 if sim >= 0.92 else 0.0)

    return rewards


# GRPO config

In [None]:
import torch
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

# Base checkpoint to fine-tune and the directory that receives training outputs.
ckpt = "/root/model/medgemma-1.5-4b-it"
output_dir = "/root/model/GRPO_medgemma4b"

training_args = GRPOConfig(
    # --- outputs / reporting ---
    output_dir=output_dir,
    report_to="tensorboard",
    logging_steps=20,
    save_steps=100,
    # NOTE(review): this uploads the result to the Hugging Face Hub -- confirm
    # that is intended for this model.
    push_to_hub=True,

    # --- optimisation schedule ---
    learning_rate=5e-6,                # Initial learning rate for the optimizer.
    max_steps=1700,
    per_device_train_batch_size=1,
    # Gradients are accumulated over this many steps, simulating a batch of
    # per_device_train_batch_size * gradient_accumulation_steps.
    gradient_accumulation_steps=4,

    # --- evaluation ---
    eval_on_start=False,               # Do not run an evaluation before the first training step.
    eval_strategy="steps",
    eval_steps=100,

    # --- generation ---
    num_generations=2,                 # Completions generated per prompt.
    # NOTE(review): prompts here are paired with an image; check that this
    # limit cannot truncate image tokens for the chosen model.
    max_prompt_length=512,             # Maximum token length for input prompts.
    max_completion_length=1024,        # Maximum token length for generated completions.
    use_vllm=True,                     # Generate with vLLM.
    vllm_mode="colocate",              # Run vLLM on the same GPU(s) as the trainer.
    vllm_gpu_memory_utilization=0.30,  # Fraction of GPU memory vLLM may use.

    # --- precision / memory ---
    bf16=True,                         # bfloat16 mixed-precision training.
    gradient_checkpointing=True,       # Trade compute for memory.
    gradient_checkpointing_kwargs={"use_reentrant": False},

    # --- model loading ---
    model_init_kwargs={
        "device_map": "auto",
        "dtype": torch.bfloat16,          # Load parameters as bfloat16.
        "attn_implementation": "eager",   # Per the original note: recommended for Gemma 3.
    },
)

# LoRA adapter settings: rank and alpha are both 64, and adapters are attached
# to every linear layer ("all-linear").
lora_config = LoraConfig(
    r=64,
    lora_alpha=64,
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)

# train model

In [None]:
# Train/eval split: hold out EVAL_SIZE rows for evaluation and train on up to
# TRAIN_SIZE rows that precede them. The boundaries are derived from the actual
# dataset length so the split also works when the dataset has fewer than
# TRAIN_SIZE + EVAL_SIZE rows (the map step above reported 1631); with at least
# that many rows the result is rows 0-3899 / 3900-3999.
TRAIN_SIZE = 3900
EVAL_SIZE = 100

num_rows = len(ds)
if num_rows <= EVAL_SIZE:
    raise ValueError(
        f"Dataset has {num_rows} rows; need more than {EVAL_SIZE} to hold out an eval split."
    )

train_end = min(TRAIN_SIZE, num_rows - EVAL_SIZE)
train_dataset = ds.select(range(train_end))
eval_dataset = ds.select(range(train_end, train_end + EVAL_SIZE))

# GRPO trainer: loads `ckpt`, wraps it with the LoRA adapter config, and scores
# generations with `correctness_reward_func`.
trainer = GRPOTrainer(
    model=ckpt,
    reward_funcs=[correctness_reward_func],
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=lora_config,
)
trainer.train()

# Persist the final model to the configured output directory.
trainer.save_model(output_dir=training_args.output_dir)

In [6]:
import json
# Sanity check: every record's text fields must be strings (or absent).
# NOTE(review): this path is relative and names a different file than the one
# loaded at the top of the notebook -- confirm which file is meant to be checked.
p = "dataset/skin/SkinCAP/SkinCAP_20260208_173640_close_end_QA.json"

# Use a context manager so the file handle is closed after parsing.
with open(p, "r", encoding="utf-8") as fh:
    data = json.load(fh)

# Number of records that have at least one non-string value in a checked field.
bad = 0
for i, ex in enumerate(data):
    for k in ["answer", "image_name", "question_type", "caption_zh_polish_en"]:
        v = ex.get(k, None)
        # Missing / None values are tolerated; only wrong types are reported.
        if v is not None and not isinstance(v, str):
            print("BAD", i, k, type(v), v)
            bad += 1
            # Count each record once, even if several fields are wrong.
            break

print("bad_count=", bad)

bad_count= 0
