# Environment Setup

### rtx5060ti环境安装

    建议： 重新构建一个新的环境

    pip install -U pip wheel setuptools
    pip install numpy packaging psutil ninja
    
    pip install torch torchvision torchaudio \
      --index-url https://download.pytorch.org/whl/cu128
    
    pip install -U unsloth

    pip install vllm==0.10.2

    pip install flash-attn --no-build-isolation

    pip install wandb

In [None]:
# Starting the notebook under Jupyter may fail on a stale TorchInductor cache, so wipe it first.
import os
import shutil

cache = os.path.expanduser('~/.cache/torchinductor')
# ignore_errors=True: a missing cache directory is not an error
shutil.rmtree(path=cache, ignore_errors=True)
print('cache removed')

from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import SFTConfig, GRPOConfig, SFTTrainer, GRPOTrainer
from vllm import SamplingParams

import gc
import re
import time
import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F
from safetensors import safe_open
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextStreamer

# Model Setup

In [None]:
# Base (not instruction-tuned) checkpoint; the SFT stage below teaches it the reasoning/answer format.
model_id = 'unsloth/Qwen3-0.6B-Base'
# Lower-cased last path component, used to name output directories (e.g. './qwen3-0.6b-base_sft').
model_name = model_id.rsplit('/', 1)[-1].lower()
# Total token budget; the GRPO completion length is derived from it later. Increase for longer reasoning traces.
max_seq_length = 2048
# LoRA rank: higher means more adapter capacity, but slower training and more memory.
lora_rank = 32

In [None]:
# Load the base model once (with vLLM fast inference enabled) and wrap it with LoRA adapters.
# NOTE: FastLanguageModel.from_pretrained should be called only once per process; calling it
# again (see the commented-out reload cells later) can run out of GPU memory.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_id,
    max_seq_length=max_seq_length,
    load_in_4bit=False,  # False for LoRA 16bit
    fast_inference=True,  # Enable vLLM fast inference (used later via model.fast_generate)
    max_lora_rank=lora_rank,  # must be >= the rank passed to get_peft_model below
    gpu_memory_utilization=0.8,  # Reduce if out of memory
)
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,  # Rank: adaptation capacity (lora_rank, i.e. 32 in this notebook)
    lora_alpha=lora_rank * 2,  # Scaling factor (typically 2x rank)
    lora_dropout=0.1,  # Regularization to prevent overfitting
    target_modules=[  # Remove QKVO if out of memory
        'q_proj', 'k_proj', 'v_proj', 'o_proj',
        'gate_proj', 'up_proj', 'down_proj',
    ],
    use_gradient_checkpointing='unsloth',  # Reduces memory usage
    random_state=2025,  # fixed seed for reproducible adapter initialization
)

In [None]:
import torch

# GPU memory currently occupied by this process's tensors (MB)
allocated = torch.cuda.memory_allocated() / 1024 ** 2
# Total GPU memory reserved by this process's caching allocator (MB, includes cached blocks)
reserved = torch.cuda.memory_reserved() / 1024 ** 2

print(f"Allocated: {allocated:.0f} MB")
print(f"Reserved:  {reserved:.0f} MB")

free, total = torch.cuda.mem_get_info()  # units: bytes
# The two lines below print whole-card used and whole-card free memory (MB).
print(f"整卡已用  {(total - free) / 1024 ** 2:6.1f} MB")
print(f"整卡剩余  {free / 1024 ** 2:6.1f} MB")

In [None]:
# Load the medical verifier model (3B SequenceClassification) used by the correctness reward.
# It judges whether a model answer agrees with the reference (semantic equivalence / aliases);
# `verify` later reads the softmax probability of label index 1 as the "True" probability.
verifier_path = 'FreedomIntelligence/medical_o1_verifier_3B'
verifier_tokenizer = AutoTokenizer.from_pretrained(verifier_path)
verifier_model = AutoModelForSequenceClassification.from_pretrained(
    verifier_path,
    dtype='auto',  # use the dtype stored in the checkpoint
    device_map='auto',
    attn_implementation='flash_attention_2',  # needs the flash-attn package installed in the setup section
    num_labels=2  # binary classification head
)
verifier_model.eval()  # Set to evaluation mode

# Verifier template from model card.
# Placeholders, in order: model response, reference answer, verifier eos token (filled in `verify`).
VERIFIER_TEMPLATE = '''<Model Response>
{}
</Model Response>

<Reference Answer>
{}
</Reference Answer>

Your task is to evaluate the model response by comparing it to the reference answer. If the model response is correct and aligns with the reference answer, output "True". If it is incorrect or fails to select the correct option (if options are provided), output "False". {}'''

# Chat Template

In [None]:
# Tags that delimit the structured output the model is trained to produce.
REASONING_START, REASONING_END = '<THINK>', '</THINK>'  # reasoning section
ANSWER_START, ANSWER_END = '<ANSWER>', '</ANSWER>'  # final answer section

# Chinese reference translation of SYSTEM_PROMPT below (kept for readers; unused at runtime).
'''
你是一名医学推理助手。当给出一个医学问题时：
在 {REASONING_START} 和 {REASONING_END} 之间展示你逐步的复杂推理过程（包括反思、回溯以及替代路径）。
在 {ANSWER_START} 和 {ANSWER_END} 之间给出你的最终答案。
表达要精确，需考虑医学中的别名/同义词，并清晰展示所有推理与思考步骤。
'''

# System prompt for medical reasoning (inspired by the paper's emphasis on complex reasoning).
# Built line by line; no trailing newline.
SYSTEM_PROMPT = '\n'.join([
    'You are a medical reasoning assistant. When given a medical problem:',
    '1. Show your step-by-step complex reasoning (including reflection, backtracking, and alternative paths) '
    f'between {REASONING_START} and {REASONING_END}.',
    f'2. Provide your final answer between {ANSWER_START} and {ANSWER_END}.',
    '3. Be precise, consider medical aliases/synonyms, and show all deliberation steps clearly.',
])

print(SYSTEM_PROMPT)

### 模板解析举例
    messages = [
      {"role": "system", "content": "You are a medical reasoning assistant."},
      {"role": "user", "content": "什么是ACS？"},
      {"role": "assistant", "content": "ACS 指急性冠脉综合征"},
      {"role": "user", "content": "如何诊断？"},
    ]

    渲染结果（<eos> 表示 tokenizer.eos_token，<THINK> 即 REASONING_START；模板本身不会在各轮之间插入换行或分隔符，下面的换行仅为便于阅读；user 消息后面不加 <eos>）：
        You are a medical reasoning assistant.<eos>
        什么是ACS？
        ACS 指急性冠脉综合征<eos>
        如何诊断？
        <THINK>

    chat_template = (  # Build and assign chat_template to the tokenizer
        # 注意：说明文字必须写成 # 注释，不能写成 ''' ''' 字符串——
        # 相邻的字符串字面量会被 Python 自动拼接进 chat_template，导致说明文字被渲染进每一条 prompt。

        # 含义：
        #     如果用户已经传了 system：直接用它
        #     如果没有：强制注入你定义好的 SYSTEM_PROMPT
        # If the very first message is a SYSTEM role, print it + <eos>:
        "{% if messages[0]['role'] == 'system' %}"
        "{{ messages[0]['content'] + eos_token }}"
        "{% set loop_messages = messages[1:] %}"
        "{% else %}"
        # Otherwise, inject our system_prompt + <eos>:
        "{{ '{system_prompt}' + eos_token }}"
        "{% set loop_messages = messages %}"
        "{% endif %}"

        # 含义：
        #     assistant 的回答是"一个完整的训练样本"，结尾加 <eos>
        #     <eos> = reward / loss 的自然边界
        #     对 PPO / GRPO 特别友好
        #     user 的内容原样输出，后面不加 <eos>
        # Now loop over the remaining messages (either user or assistant):
        "{% for message in loop_messages %}"
        "{% if message['role'] == 'user' %}"
        "{{ message['content'] }}"
        "{% elif message['role'] == 'assistant' %}"
        "{{ message['content'] + eos_token }}"
        "{% endif %}"
        "{% endfor %}"

        # 含义：
        #     add_generation_prompt=True 时，最后一条 user 内容后面直接接上 <THINK>：
        #         ...如何诊断？<THINK>
        # If we asked for "add_generation_prompt", append REASONING_START (<THINK>) to the end:
        "{% if add_generation_prompt %}{{ '{reasoning_start}' }}"
        "{% endif %}"
    )
    # Replace with our specific template:
    tokenizer.chat_template = chat_template \
        .replace("'{system_prompt}'", f"'{SYSTEM_PROMPT}'") \
        .replace("'{reasoning_start}'", f"'{REASONING_START}'")

In [None]:
# Jinja chat template:
# - System turn: emitted first, followed by eos_token.
# - User turns: emitted verbatim.
# - Assistant turns: emitted followed by eos_token.
# - With add_generation_prompt=True, REASONING_START is appended, so generation starts inside
#   the reasoning section.
# No separators or newlines are inserted between turns.
chat_template = (  # Build and assign chat_template to the tokenizer
    # If the very first message is a SYSTEM role, print it + <eos>:
    "{% if messages[0]['role'] == 'system' %}"
    "{{ messages[0]['content'] + eos_token }}"
    "{% set loop_messages = messages[1:] %}"
    "{% else %}"
    # Otherwise, inject our system_prompt + <eos>:
    "{{ '{system_prompt}' + eos_token }}"
    "{% set loop_messages = messages %}"
    "{% endif %}"

    # Now loop over the remaining messages (either user or assistant):
    "{% for message in loop_messages %}"
    "{% if message['role'] == 'user' %}"
    "{{ message['content'] }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ message['content'] + eos_token }}"
    "{% endif %}"
    "{% endfor %}"

    # If we asked for "add_generation_prompt", append REASONING_START to the end:
    "{% if add_generation_prompt %}{{ '{reasoning_start}' }}"
    "{% endif %}"
)
# Replace the placeholders with our specific system prompt and reasoning tag.
# NOTE(review): the values are spliced inside single-quoted Jinja string literals, so they must
# not contain a single quote (true for the current SYSTEM_PROMPT / REASONING_START).
tokenizer.chat_template = chat_template \
    .replace("'{system_prompt}'", f"'{SYSTEM_PROMPT}'") \
    .replace("'{reasoning_start}'", f"'{REASONING_START}'")

In [None]:
# Quick sanity check of the template. There is no system turn here, so the template injects
# SYSTEM_PROMPT, and add_generation_prompt=True appends REASONING_START after the last user message.
example_messages = [
    {'role': 'user', 'content': 'What is the most severe complication of dengue fever?'},
    {'role': 'assistant', 'content': (
        f'{REASONING_START}'
        'Dengue fever can lead to severe forms like dengue hemorrhagic fever. '
        'Considering symptoms and progression, the most severe is plasma leakage leading to shock.'
        f'{REASONING_END}{ANSWER_START}Dengue shock syndrome{ANSWER_END}'
    )},
    {'role': 'user', 'content': 'What drug is used for hypertension?'},
]
print(tokenizer.apply_chat_template(example_messages, tokenize=False, add_generation_prompt=True))

# Pre Fine-tuning (SFT)

## Data preparation

In [None]:
def format_dataset(x):
    '''Turn one medical-o1-reasoning-SFT row into a system/user/assistant conversation.

    The assistant turn uses the same tagged layout that GRPO rewards later:
    REASONING_START + Complex_CoT + REASONING_END + ANSWER_START + Response + ANSWER_END.

    Args:
        x: A row with 'Question', 'Complex_CoT' and 'Response' fields.

    Returns:
        A list of three chat messages (system, user, assistant).
    '''
    question = x['Question']
    reference = x['Response']  # the final answer, used verbatim
    reasoning = x['Complex_CoT'].strip()  # the chain of thought, whitespace-trimmed
    assistant_turn = ''.join([
        REASONING_START, reasoning, REASONING_END,
        ANSWER_START, reference, ANSWER_END,
    ])
    return [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': question},
        {'role': 'assistant', 'content': assistant_turn},
    ]


# Load the English SFT split as a pandas DataFrame and build one chat conversation per row.
sft_dataset = load_dataset('FreedomIntelligence/medical-o1-reasoning-SFT', 'en', split='train').to_pandas()
sft_dataset['messages'] = sft_dataset.apply(format_dataset, axis=1)
print(tokenizer.apply_chat_template(sft_dataset['messages'][0], tokenize=False))  # render the first example
sft_dataset  # display the DataFrame in the notebook

In [None]:
from datasets import Dataset, DatasetDict
import pandas as pd

# Drop (not truncate) SFT examples longer than max_seq_length / 2 tokens, since we don't want too long reasoning traces
sft_dataset['seq_length'] = sft_dataset['messages'].apply(lambda x: len(tokenizer.apply_chat_template(x)))
print('Token-length percentiles (50/90/99):', np.percentile(sft_dataset['seq_length'], [50, 90, 99]))

threshold = max_seq_length / 2  # a float (1024.0 for max_seq_length=2048)
sft_dataset_filtered = sft_dataset.loc[sft_dataset['seq_length'] <= threshold].copy()
print(f'Remaining for training (<= {threshold} tokens): {len(sft_dataset_filtered)}/{len(sft_dataset)}')

# Render each conversation into one training string; SFTConfig reads it via dataset_text_field='text'
sft_dataset_filtered['text'] = tokenizer.apply_chat_template(sft_dataset_filtered['messages'].values.tolist(),
                                                             tokenize=False)
# NOTE(review): the filtered DataFrame keeps its original (non-contiguous) index; from_pandas may
# store it as an extra '__index_level_0__' column
sft_dataset_filtered = Dataset.from_pandas(sft_dataset_filtered)
sft_dataset_filtered

## Pre fine-tune to understand custom GRPO formatting

In [None]:
n = 1000  # keep only the first n examples to shorten the SFT run
# Option 1: convert to pandas, take head(n), convert back (works even if fewer than n rows remain)
sft_dataset_filtered = Dataset.from_pandas(
    sft_dataset_filtered.to_pandas().head(n)
)

print(len(sft_dataset_filtered))

# Option 2: slice the Dataset directly (faster, no pandas round-trip); requires n <= len(dataset)
# sft_dataset_filtered = sft_dataset_filtered.select(range(n))

In [None]:
# Format pre-fine-tuning: teach the base model the tagged reasoning/answer layout.
trainer = SFTTrainer(
    model=model,  # the base model (with LoRA adapters) to fine-tune
    tokenizer=tokenizer,  # tokenizer paired with the model, used to encode the data
    train_dataset=sft_dataset_filtered,  # filtered training set already rendered with the chat template
    args=SFTConfig(
        dataset_text_field='text',  # dataset column holding the pre-formatted text
        num_train_epochs=3,  # iterate over the full training set 3 times
        per_device_train_batch_size=4,  # samples per GPU per step (effective batch = this x grad accumulation x #GPUs)
        gradient_accumulation_steps=8,  # accumulate 8 steps before each weight update (8x larger effective batch)
        optim='adamw_8bit',  # 8-bit AdamW optimizer to save memory
        weight_decay=0.01,  # weight-decay coefficient, regularization against overfitting
        learning_rate=2e-4,  # peak learning rate; 1e-4 to 5e-4 is common for instruction tuning
        lr_scheduler_type='cosine',  # cosine annealing, smoothly decaying to 0
        warmup_ratio=0.05,  # linear warmup to the peak LR over the first 5% of steps
        logging_steps=100,  # log every 100 steps
        report_to='none',  # disable external loggers such as wandb / tensorboard
    )
)
trainer.train()
trainer.save_model(f'./{model_name}_sft')

## Check if model has learnt to follow the format

In [None]:
# # The FastLanguageModel.from_pretrained should be only called once. Otherwise, it will be OOM
# model, tokenizer = FastLanguageModel.from_pretrained(
#     model_name=f'./{model_name}_sft', # Reload LoRA weights
#     max_seq_length=max_seq_length,
#     load_in_4bit=False,         # False for LoRA 16bit
#     fast_inference=True,        # Enable vLLM fast inference
#     max_lora_rank=lora_rank,
#     gpu_memory_utilization=0.8, # Reduce if out of memory
# )
# FastLanguageModel.for_inference(model)

In [None]:
# Render the system + user messages of one SFT example; add_generation_prompt appends REASONING_START
text = tokenizer.apply_chat_template(
    sft_dataset_filtered[1]['messages'][:2],
    tokenize=False,
    add_generation_prompt=True,  # Append the final REASONING_START tag
)

print('-----------------------------------------------')
print(text)
print('-----------------------------------------------')
_ = model.generate(
    **tokenizer(text, return_tensors='pt').to('cuda'),
    temperature=0,  # NOTE(review): greedy decoding intended; confirm sampling is off, since temperature=0 may be rejected when do_sample=True
    max_new_tokens=1024,
    streamer=TextStreamer(tokenizer, skip_prompt=False),  # Stream prompt + the model's generations (CoT + solution)
)
print('-----------------------------------------------')

In [None]:
# Free the SFT data and release cached GPU memory before the RL stage
del sft_dataset, sft_dataset_filtered
gc.collect()
torch.cuda.empty_cache()

# Post Fine-tuning (RL)

## Data preparation

In [None]:
def process_dataset_sample(example):
    '''Map a medical-o1-verifiable-problem row to the prompt/answer pair used for GRPO.

    Args:
        example: Row with 'Open-ended Verifiable Question' and 'Ground-True Answer' fields.

    Returns:
        {'prompt': [system message, user message], 'answer': stripped reference text}.
        The answer is free text (not a number); the reward functions compare against it.
    '''
    question = example['Open-ended Verifiable Question']
    conversation = [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': question},
    ]
    reference = example['Ground-True Answer'].strip()
    return {'prompt': conversation, 'answer': reference}

In [None]:
# Load the verifiable medical problems dataset for RL and add 'prompt' / 'answer' columns
train_dataset = load_dataset('FreedomIntelligence/medical-o1-verifiable-problem', split='train')
train_dataset = train_dataset.map(process_dataset_sample)
print(f'Training samples: {len(train_dataset):,}\n'  # ~20K as per paper
      f"- Sample question: {train_dataset[0]['prompt'][1]['content']}\n"
      f"- Sample answer (ground truth for rewards): {train_dataset[0]['answer']}\n"
      f"- Prompt (system + user):\n{train_dataset[0]['prompt']}")

In [None]:
n = 1000  # keep only the first n RL problems to shorten the GRPO run
# Option 1: convert to pandas, take head(n), convert back
train_dataset = Dataset.from_pandas(
    train_dataset.to_pandas().head(n)
)

print(len(train_dataset))

In [None]:
# Use the 90th-percentile prompt length as max_prompt_length so kept prompts are never truncated,
# i.e. drop the ~10% longest prompts instead
tokenized_dataset = train_dataset.map(
    lambda x: {'tokens': tokenizer.apply_chat_template(x['prompt'], add_generation_prompt=True, tokenize=True)},
    batched=True,  # x['prompt'] is a batch of conversations here
).map(lambda x: {'length': len(x['tokens'])})
print(tokenizer.decode(tokenized_dataset[0]['tokens']))  # show one fully rendered prompt

thresholds = np.percentile(tokenized_dataset['length'], [50, 90, 99])
max_prompt_length = int(thresholds[1])  # 90th percentile
print('Token-length percentiles (50/90/99):', thresholds, '=> Choose max_prompt_length =', max_prompt_length)

# Keep only samples no longer than the 90th-percentile length
train_dataset = train_dataset.select(np.where(np.array(tokenized_dataset['length']) <= max_prompt_length)[0])
print(f'Remaining for training (<= {max_prompt_length} tokens): {len(train_dataset)}/{len(tokenized_dataset)}')
del tokenized_dataset

## Multi-reward design

In [None]:
# Regex for a well-formed completion: REASONING_END, then ANSWER_START ... ANSWER_END (answer captured
# in group 1), optionally followed by whitespace and the tokenizer's EOS token.
# NOTE(review): with re.MULTILINE, `$` also matches right before any newline, so extra text on later
# lines still matches; the tags are interpolated without re.escape (safe for the current tag strings).
match_format = re.compile(  # Match the reasoning sections and answers
    rf'{REASONING_END}.*?'  # REASONING_START is not required here: the chat template already appends it to the prompt
    rf'{ANSWER_START}(.+?){ANSWER_END}'  # Answer section with capture group (text, not number)
    rf'[\s]{{0,}}(?:{re.escape(tokenizer.eos_token)})?'  # Add optional EOS token matching
    rf'[\s]{{0,}}$',  # Optional whitespace at end
    flags=re.MULTILINE | re.DOTALL,  # Multi-line matching with . matching newlines
)


def verify(guess, true_answer):
    '''Return the verifier's probability that `guess` agrees with `true_answer`.

    Fills VERIFIER_TEMPLATE with the stripped model answer, the stripped reference and the
    verifier's EOS token. It then runs the sequence-classification verifier without gradients
    and returns the softmax probability of label index 1, which callers treat as "True".
    '''
    prompt_text = VERIFIER_TEMPLATE.format(
        guess.strip(), true_answer.strip(), verifier_tokenizer.eos_token,
    )
    encoded = verifier_tokenizer([prompt_text], return_tensors='pt')
    encoded = encoded.to(verifier_model.device)
    with torch.no_grad():
        model_output = verifier_model(**encoded, return_dict=True)
    class_probs = F.softmax(model_output.logits, dim=-1)
    return class_probs[0, 1].item()

In [None]:
def match_format_strictly(completions, **kwargs) -> list[float]:
    ''' Reward Function 1: exact format compliance.

    Each completion whose text matches `match_format` earns 3.0, anything else earns 0.0.
    The large reward pushes the model to learn the complete structured output pattern.
    '''
    rewards = []
    for completion in completions:
        text = completion[0]['content']
        rewards.append(0.0 if match_format.search(text) is None else 3.0)
    return rewards

In [None]:
# If it fails, reward the model if it at least follows the format partially, by counting each symbol
def match_format_softly(completions, **kwargs) -> list[float]:
    ''' Reward Function 2: partial format credit.

    Each of REASONING_END, ANSWER_START and ANSWER_END adds +0.5 when it occurs exactly once
    in the completion, and -0.5 otherwise (missing or repeated). Rewards therefore range from
    -1.5 to +1.5. REASONING_START is not scored because the chat template already appends it
    to the prompt.
    '''
    scored_tags = (REASONING_END, ANSWER_START, ANSWER_END)
    return [
        sum(0.5 if completion[0]['content'].count(tag) == 1 else -0.5 for tag in scored_tags)
        for completion in completions
    ]

In [None]:
# Extract the generated answer, and reward or penalize it
def check_answer_correctness(completions, answer, **kwargs) -> list[float]:
    ''' Reward Function 3: graduated medical-accuracy reward from the verifier model.

    The answer text captured by `match_format` is scored with `verify`, which returns the
    verifier's probability p that it agrees with the reference answer. The verifier judges
    semantically, so aliases and synonyms are handled. Reward per completion:

    - -2.0: no extractable answer (completion does not match `match_format`)
    -  5.0: p > 0.9         high-confidence correct
    -  3.5: 0.7 < p <= 0.9  strong alignment
    -  2.0: 0.5 < p <= 0.7  partial / approximate
    -  1.5: 0.3 < p <= 0.5  reasonable attempt
    - -2.5: p <= 0.3        incorrect

    Args:
        completions: Generated completions; the text is read from ``completion[0]['content']``.
        answer: Reference answers, aligned element-wise with ``completions``.
        **kwargs: Any extra keyword arguments from the trainer (ignored).

    Returns:
        One float reward per completion.
    '''
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [  # Extract answers using format pattern
        guess.group(1) if (guess := match_format.search(r)) else None
        for r in responses
    ]
    rewards = []
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:  # No extractable answer
            rewards.append(-2.0)
            continue

        prob_true = verify(guess, true_answer)  # Use verifier to check semantic alignment
        if prob_true > 0.9:
            rewards.append(5.0)  # High confidence correct
        elif prob_true > 0.7:
            rewards.append(3.5)  # Strong alignment
        elif prob_true > 0.5:
            rewards.append(2.0)  # Partial/approximate
        elif prob_true > 0.3:
            # NOTE(review): p <= 0.5 means the verifier leans "False", yet this tier is still
            # rewarded positively; confirm this is the intended reward shaping.
            rewards.append(1.5)  # Reasonable attempt
        else:
            rewards.append(-2.5)  # Incorrect
    return rewards

## GRPO training setup

In [None]:
# Add one token of headroom to the 90th-percentile prompt length (the reason is not stated here).
# NOTE(review): this mutates max_prompt_length in place, so re-running the cell keeps incrementing it.
max_prompt_length += 1
max_completion_length = max_seq_length - max_prompt_length  # remaining token budget for generation

# vLLM sampling for GRPO rollouts: encourage exploration (temperature is set in GRPOConfig below)
vllm_sampling_params = SamplingParams(
    min_p=0.1,  # min-p filtering relative to the most likely token
    top_p=1.0,  # no nucleus truncation
    top_k=-1,  # no top-k limit
    stop=[tokenizer.eos_token],
    include_stop_str_in_output=True,  # keep EOS in the text; match_format accepts an optional trailing EOS
)

In [None]:
training_args = GRPOConfig(  # Configure GRPO training parameters for medical reasoning
    output_dir=f'/tmp/{model_name}',  # Directory for checkpoints and logs
    vllm_sampling_params=vllm_sampling_params,  # rollout sampling settings from the previous cell
    # Training speed control
    num_train_epochs=1,  # Total number of training epochs
    per_device_train_batch_size=4,  # Small batch for GPU memory constraints
    gradient_accumulation_steps=16,  # Effective batch size = 4 * 16 = 64 per device
    # Precision & Optimization
    scale_rewards='batch',
    # Calculate mean at local/group level and std at global/batch level enables more robust reward shaping
    loss_type='dr_grpo',  # Fully remove response length bias, dividing by a constant instead of the sequence length
    optim='adamw_8bit',
    weight_decay=0.1,  # Regularization
    max_grad_norm=0.1,  # Aggressive gradient clipping for stable training
    bf16=is_bfloat16_supported(),  # Use bf16 mixed precision when the GPU supports it (faster, less memory)
    learning_rate=1e-5,  # Conservative LR to prevent destabilizing reasoning
    lr_scheduler_type='cosine',
    warmup_ratio=0.05,
    # Generation control
    temperature=1.0,  # Encourage exploration during training
    num_generations=4,  # completions sampled per prompt (GRPO group size)
    max_prompt_length=max_prompt_length,
    max_completion_length=max_completion_length,
    # Reporting and saving
    report_to='none',
    logging_steps=79,
    logging_strategy='steps',
    save_total_limit=1,
    # max_steps=100,
)

## Train the model

In [None]:
%%time
trainer = GRPOTrainer(  # Initialize GRPO trainer with multi-reward system (improvement over paper's single-reward PPO)
    model=model,  # LoRA-adapted 16-bit model (loaded with load_in_4bit=False, i.e. not quantized)
    processing_class=tokenizer,
    train_dataset=train_dataset,  # Processed, length-filtered medical dataset ('prompt' + 'answer')
    args=training_args,  # Training configuration
    reward_funcs=[  # 3 complementary reward functions
        match_format_strictly,  # Perfect structure compliance
        match_format_softly,  # Partial format credit
        check_answer_correctness,  # Semantic accuracy via verifier (handles aliases)
    ]
)
trainer.train()
trainer.save_model(f'./{model_name}_grpo')  # the evaluation cells load the adapter from this directory

In [None]:
# Release unreferenced objects and cached GPU memory after training
gc.collect()
torch.cuda.empty_cache()

# Evaluation

In [None]:
# Evaluation sampling settings (Qwen3 recommended values):
# https://docs.unsloth.ai/models/qwen3-how-to-run-and-fine-tune#official-recommended-settings
sampling_params = SamplingParams(
    temperature=0.6,
    min_p=0.0,
    top_p=0.95,
    top_k=20,
    max_tokens=max_completion_length,  # same completion budget as GRPO training
)
# Regex to extract the answer letter: the first capital A-D anywhere after the first ANSWER_START.
# NOTE(review): with re.DOTALL the lazy `.*?` can run past ANSWER_END, and the closing tag itself
# contains a capital "A", so an answer with no A-D letter (e.g. "yes") is extracted as "A". It also
# picks the first letter of words like "Based". Prefer searching only inside the answer section
# captured by match_format, with word boundaries (callers must then handle a None result).
match_answer_letter = re.compile(  # Regex to extract the answer letter
    rf'{ANSWER_START}.*?[\s]{{0,}}([A-D])',
    flags=re.MULTILINE | re.DOTALL
)

## Influence of LoRA

In [None]:
example_text = 'What drug is used for hypertension?'
# Raw text prompt (no chat template, no system prompt) and no LoRA adapter
print(model.fast_generate(  # Try the model without any GRPO trained
    example_text, sampling_params=sampling_params,
    lora_request=None
)[0].outputs[0].text)

In [None]:
# Sanity-check the saved GRPO adapter: every LoRA tensor (A and B) must contain at least one
# non-zero value. If either factor of a pair were all zeros, its B @ A update would be zero.
adapter_file = f'./{model_name}_grpo/adapter_model.safetensors'
all_zero_keys = []
with safe_open(adapter_file, framework='pt') as f:
    for key in f.keys():
        tensor = f.get_tensor(key)
        # Count non-zeros directly; a zero *fraction* can never equal the element count,
        # so comparing those two always passes.
        if torch.count_nonzero(tensor).item() == 0:
            all_zero_keys.append(key)
if all_zero_keys:
    raise RuntimeError(f'LoRA tensors are entirely zero: {all_zero_keys}')
print(f'All LoRA tensors in {adapter_file} contain non-zero values')

In [None]:
# Load the LoRA and prompt with only a user message, to check the adapter does not (or only
# minimally) hurt the model's original reasoning ability.
# NOTE(review): the custom chat template injects SYSTEM_PROMPT whenever the first message is not a
# system message, so this prompt still contains the system prompt (same text as the next cell).
# Pass an explicit empty system message to really test without it.
text = tokenizer.apply_chat_template(
    [{'role': 'user', 'content': example_text}],
    add_generation_prompt=True, tokenize=False,
)
print(model.fast_generate(
    text, sampling_params=sampling_params,
    lora_request=model.load_lora(f'./{model_name}_grpo'),
)[0].outputs[0].text)

In [None]:
# Test using an explicit system prompt
text = tokenizer.apply_chat_template([
    {'role': 'system', 'content': SYSTEM_PROMPT},
    {'role': 'user', 'content': example_text},
], add_generation_prompt=True, tokenize=False)

# Compare results with system prompt but without LoRA
print(model.fast_generate(
    text, sampling_params=sampling_params,
    lora_request=None,
)[0].outputs[0].text)

# Same prompt with the GRPO LoRA. The reasoning model is much better, though not always correct,
# since it was only trained briefly; a longer sequence length and longer training should help.
print(model.fast_generate(
    text, sampling_params=sampling_params,
    lora_request=model.load_lora(f'./{model_name}_grpo'),
)[0].outputs[0].text)

## Performance on benchmark datasets

In [None]:
def format_mcq_prompt(example, dataset_name, system_prompt=None):
    '''Convert one benchmark row into a chat prompt plus its gold option letter.

    Args:
        example: A dataset row. The fields read depend on ``dataset_name``:
            'medqa'    -> question, options (letter -> option text mapping), answer_idx;
            'medmcqa'  -> question, opa..opd, cop (0-based index of the correct option);
            'pubmedqa' -> question, context['contexts'] (list of strings), final_decision.
        dataset_name: One of 'medqa', 'medmcqa', 'pubmedqa'.
        system_prompt: System message to use; defaults to the module-level SYSTEM_PROMPT.

    Returns:
        {'prompt': [system message, user message], 'answer': gold option letter}.

    Raises:
        ValueError: If ``dataset_name`` is not a supported benchmark. Previously this fell
            through and failed with an UnboundLocalError on ``gold_letter``.
    '''
    user_content = 'Please answer the following multiple-choice question, ensuring your response only '
    user_content += 'contain the correct option letter with no extra text:\n' + example['question'] + '\n'

    if dataset_name == 'medqa':
        user_content += '\n'.join(f'{k}. {v}' for k, v in example['options'].items())
        gold_letter = example['answer_idx']
    elif dataset_name == 'medmcqa':
        choices = [example['opa'], example['opb'], example['opc'], example['opd']]
        user_content += '\n'.join(f'{letter}. {choice}' for letter, choice in zip('ABCD', choices))
        gold_letter = chr(ord('A') + example['cop'])
    elif dataset_name == 'pubmedqa':
        contexts = ' '.join(example['context']['contexts'])  # Join abstracts as context
        user_content = f'Context: {contexts}\n{user_content}A. yes\nB. no\nC. maybe'
        gold_letter = {'yes': 'A', 'no': 'B'}.get(example['final_decision'], 'C')
    else:
        raise ValueError(f'Unsupported dataset_name: {dataset_name!r}')

    if system_prompt is None:
        system_prompt = SYSTEM_PROMPT
    return {
        'prompt': [
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': user_content},
        ],
        'answer': gold_letter.strip()
    }

In [None]:
# Load evaluation benchmarks as per paper: MedQA-USMLE (test), MedMCQA (validation), PubMedQA (PQA-L test equivalent)
medqa_dataset = load_dataset('GBaker/MedQA-USMLE-4-options',split='test')  # 1273 samples; gold option letter in 'answer_idx'
medmcqa_dataset = load_dataset('openlifescienceai/medmcqa',split='validation')  # 4183 samples; gold in 'cop' (choice index)
pubmedqa_dataset = load_dataset('pubmed_qa', 'pqa_labeled',split='train')  # 1000 questions in the PQA-L are used as the test set
# NOTE(review): the 'MedQA_USLME_test' key misspells "USMLE"; later cells look it up verbatim, so rename all uses together.
eval_datasets = {
    'MedQA_USLME_test': medqa_dataset.map(lambda x: format_mcq_prompt(x, 'medqa')),
    'MedMCQA_validation': medmcqa_dataset.map(lambda x: format_mcq_prompt(x, 'medmcqa')),
    'PubMedQA_test': pubmedqa_dataset.map(lambda x: format_mcq_prompt(x, 'pubmedqa')),
}

In [None]:
%%time
# Evaluate every benchmark with and without the GRPO LoRA adapter and tabulate format/answer accuracy.
results = {}
# Load the trained adapter once; the same request is reused for every benchmark.
grpo_lora_request = model.load_lora(f'./{model_name}_grpo')


def _extract_answer_letter(response):
    '''Return the option letter (A-D) inside a well-formed response's answer section, or None.

    Only the text that `match_format` captures between ANSWER_START and ANSWER_END is searched,
    and only for a standalone capital letter. Searching the raw response can run past ANSWER_END
    and pick up the "A" of the closing tag, silently scoring letter-less answers as option A.
    '''
    formatted = match_format.search(response)
    if formatted is None:
        return None
    letter = re.search(r'\b([A-D])\b', formatted.group(1))
    return letter.group(1) if letter else None


def _count_correct(outputs, gold_answers):
    '''Return (correct format, correct answer, correct both) counts over vLLM outputs.'''
    format_cnt = answer_cnt = both_cnt = 0
    for output, gt_answer in zip(outputs, gold_answers):
        response = output.outputs[0].text
        correct_format = match_format.search(response) is not None
        correct_answer = _extract_answer_letter(response) == gt_answer
        format_cnt += int(correct_format)
        answer_cnt += int(correct_answer)
        both_cnt += int(correct_format and correct_answer)
    return format_cnt, answer_cnt, both_cnt


def _ratio(count, total):
    '''Format a count as "count/total (xx.xx%)".'''
    return f'{count}/{total} ({count / total * 100:.2f}%)'


def _delta(diff, total):
    '''Format a signed count difference; the format spec supplies the sign, so regressions read "-n".'''
    return f'{diff:+d} ({diff / total * 100:.2f}%)'


metric_names = ('Correct Format', 'Correct Answer', 'Correct Both')
for ds_name, dataset in eval_datasets.items():
    num_samples = len(dataset)
    print(f'\n{ds_name}: {num_samples} samples')

    test_texts = [
        tokenizer.apply_chat_template(sample['prompt'], add_generation_prompt=True, tokenize=False)
        for sample in dataset
    ]
    outputs_with_lora = model.fast_generate(
        test_texts, sampling_params=sampling_params,
        lora_request=grpo_lora_request,
    )
    outputs_without_lora = model.fast_generate(
        test_texts, sampling_params=sampling_params,
        lora_request=None,
    )

    # Compare the correct amount of using and not using LoRA
    gold_answers = dataset['answer']
    lora_counts = _count_correct(outputs_with_lora, gold_answers)
    no_lora_counts = _count_correct(outputs_without_lora, gold_answers)

    results[ds_name] = {
        'Without LoRA': {m: _ratio(c, num_samples) for m, c in zip(metric_names, no_lora_counts)},
        'With LoRA': {m: _ratio(c, num_samples) for m, c in zip(metric_names, lora_counts)},
        'Improvement': {
            m: _delta(with_c - without_c, num_samples)
            for m, with_c, without_c in zip(metric_names, lora_counts, no_lora_counts)
        },
    }
    display(pd.DataFrame(results[ds_name]).T)

# Inference

In [None]:
# # The FastLanguageModel.from_pretrained should be only called once. Otherwise, it will be OOM
# model, tokenizer = FastLanguageModel.from_pretrained(
#     model_name=f'./{model_name}_grpo', # Reload LoRA weights
#     max_seq_length=max_seq_length,
#     load_in_4bit=False,         # False for LoRA 16bit
#     fast_inference=True,        # Enable vLLM fast inference
#     max_lora_rank=lora_rank,
#     gpu_memory_utilization=0.8, # Reduce if out of memory
# )
# FastLanguageModel.for_inference(model)

In [None]:
def generate_with_reasoning(questions, max_completion_length=1024, system_prompt=SYSTEM_PROMPT):
    '''Generate tagged-reasoning answers for a list of questions with the HF `model.generate`.

    Args:
        questions: List of user question strings.
        max_completion_length: Maximum number of new tokens to generate per question.
        system_prompt: System message prepended to every conversation.

    Returns:
        A tuple (responses, inference_duration, num_generated_tokens):
        - responses: decoded completions with the prompt removed and special tokens skipped.
        - inference_duration: wall-clock seconds spent in `generate`.
        - num_generated_tokens: output width beyond the prompt. For a batch this is the longest
          completion, including padding.
    '''
    conversations = [
        [{'role': 'system', 'content': system_prompt}, {'role': 'user', 'content': question}]
        for question in questions
    ]
    prompts = [
        tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
        for conversation in conversations
    ]

    # Decoder-only models must be left-padded for batched generation; restore the caller's setting afterwards.
    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = 'left'
    try:
        inputs = tokenizer(prompts, return_tensors='pt', padding=True).to(model.device)
    finally:
        tokenizer.padding_side = original_padding_side
    prompt_length = inputs['input_ids'].shape[1]

    # TextStreamer only supports batch size 1, so stream only single-question calls.
    streamer = TextStreamer(tokenizer, skip_prompt=True) if len(prompts) == 1 else None

    start_time = time.perf_counter()
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_completion_length,
            temperature=0.6,  # Balance creativity and consistency
            top_p=0.95,  # Nucleus sampling for quality
            top_k=20,
            do_sample=True,  # Enable sampling for varied reasoning paths
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,  # Reduce repetitive reasoning steps
            # length_penalty / early_stopping were dropped: they only apply to beam search.
            streamer=streamer,
        )
    inference_duration = time.perf_counter() - start_time
    num_generated_tokens = output_ids.shape[1] - prompt_length

    # Decode only the generated continuation (everything after the shared padded prompt width)
    responses = tokenizer.batch_decode(output_ids[:, prompt_length:], skip_special_tokens=True)
    return responses, inference_duration, num_generated_tokens

In [None]:
# Run generate_with_reasoning on one MedQA test question (with the trained adapter active on `model`)
test_dataset = eval_datasets['MedQA_USLME_test']
medical_question = test_dataset[2]['prompt'][-1]['content']  # user message only; the helper re-adds the system prompt
expected_answer = test_dataset[2]['answer']  # gold option letter

print(medical_question, '\n\n===== Response =====')
# max_completion_length here is the GRPO completion budget computed earlier
medical_responses, inference_duration, num_generated_tokens = generate_with_reasoning([medical_question],
                                                                                      max_completion_length)
medical_response = medical_responses[0]
print('Inference time (secs):', inference_duration)
print('Generated tokens:', num_generated_tokens)

In [None]:
# Validate format compliance of the generated response
has_answer = ANSWER_START in medical_response and ANSWER_END in medical_response
print('Reasoning section:', REASONING_END in medical_response)
print('Answer section:', has_answer)

if has_answer:  # Check answer accuracy if answer section exists
    # Search only between the first ANSWER_START and the next ANSWER_END, for a standalone capital
    # option letter. Searching the whole response can reach the closing tag (whose "A" would be read
    # as option A), and calling .group() on a failed search raises AttributeError.
    answer_section = medical_response.split(ANSWER_START, 1)[1].split(ANSWER_END, 1)[0]
    letter_match = re.search(r'\b([A-D])\b', answer_section)
    answer_text = letter_match.group(1) if letter_match else None
    print('Extracted:', answer_text)
    print('Expected:', expected_answer)
    print('Correct:', answer_text == expected_answer)