# Environment Setup

In [None]:
try: import numpy, PIL; get_numpy = f'numpy=={numpy.__version__}'; get_pil = f'pillow=={PIL.__version__}'
except: get_numpy = 'numpy'; get_pil = 'pillow'
try: import subprocess; is_t4 = 'Tesla T4' in str(subprocess.check_output(['nvidia-smi']))
except: is_t4 = False
get_vllm, get_triton = ('vllm==0.9.2', 'triton==3.2.0') if is_t4 else ('vllm==0.10.2', 'triton')
!uv pip install -qqq --upgrade unsloth {get_vllm} {get_numpy} {get_pil} torchvision bitsandbytes xformers
!uv pip install -qqq {get_triton}
!uv pip install flash-attn --no-build-isolation
!uv pip install transformers==4.56.2
!uv pip install --no-deps trl==0.22.2

In [2]:
from google.colab import drive
drive.mount('/content/drive/')
%cd /content/drive/MyDrive/RL/multi-reward-medical-reasoning
!ls

Mounted at /content/drive/
/content/drive/.shortcut-targets-by-id/1UKjVaf_VMR_2xjW4oCOfinHhpq5hpQJx/RL/verifiable-medical-agent
backup.ipynb		      ppo_baselines
base.ipynb		      qwen3-1.7b-base_grpo
grpo_trainer_lora_model       qwen3-1.7b-base_sft
huggingface_tokenizers_cache  unsloth_compiled_cache
instruct.ipynb		      unsloth_training_checkpoints
llama-3.2-1b_sft	      wandb


In [3]:
from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import SFTConfig, GRPOConfig, SFTTrainer, GRPOTrainer
from vllm import SamplingParams

import gc
import re
import time
import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F
from safetensors import safe_open
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextStreamer

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.




INFO 10-29 18:23:09 [__init__.py:216] Automatically detected platform cuda.
Switching to PyTorch attention since your Xformers is broken.

Requires Flash-Attention version >=2.7.1,<=2.8.2 but got 2.8.3.
🦥 Unsloth Zoo will now patch everything to make training faster!


# Model Setup

In [4]:
model_id = 'unsloth/Qwen3-1.7B'               # Select model optimized for instruction-following and reasoning
model_name = model_id.split('/')[-1].lower()  # Extract model name from ID
max_seq_length = 2048                         # Can increase for longer reasoning traces
lora_rank = 32                                # Larger rank = smarter, but slower

In [5]:
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_id,
    max_seq_length=max_seq_length,
    load_in_4bit=False,         # False for LoRA 16bit
    fast_inference=True,        # Enable vLLM fast inference
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.8, # Reduce if out of memory
)
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,                          # Rank: adaptation capacity (16 good for reasoning tasks)
    lora_alpha=lora_rank * 2,             # Scaling factor (typically 2x rank)
    lora_dropout=0.1,                     # Regularization to prevent overfitting
    target_modules=[                      # Remove QKVO if out of memory
        'q_proj', 'k_proj', 'v_proj', 'o_proj',
        'gate_proj', 'up_proj', 'down_proj',
    ],
    use_gradient_checkpointing='unsloth', # Reduces memory usage
    random_state=2025,
)

INFO 10-29 18:23:50 [vllm_utils.py:694] Unsloth: Patching vLLM v1 graph capture
INFO 10-29 18:23:50 [vllm_utils.py:722] Unsloth: Patching vLLM v0 graph capture
==((====))==  Unsloth 2025.10.11: Fast Qwen3 patching. Transformers: 4.56.2. vLLM: 0.10.2.
   \\   /|    NVIDIA A100-SXM4-80GB. Num GPUs = 1. Max memory: 79.318 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu126. CUDA: 8.0. CUDA Toolkit: 12.6. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/Qwen3-1.7B with actual GPU utilization = 79.54%
Unsloth: Your GPU has CUDA compute capability 8.0 with VRAM = 79.32 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 2048. Num Sequences = 368.
Unsloth: vLLM's KV Cache can use up to 59.82 GB. Also swap space = 6 GB.
Unsloth: Not an error, but `device` is not supporte

`torch_dtype` is deprecated! Use `dtype` instead!


INFO 10-29 18:24:14 [__init__.py:1815] Using max model len 2048
INFO 10-29 18:24:16 [scheduler.py:222] Chunked prefill is enabled with max_num_batched_tokens=2048.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/707 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

generation_config.json:   0%|          | 0.00/237 [00:00<?, ?B/s]

INFO 10-29 18:24:20 [core.py:76] Initializing a V1 LLM engine (v0.10.2) with config: model='unsloth/Qwen3-1.7B', speculative_config=None, tokenizer='unsloth/Qwen3-1.7B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=2048, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=unsloth/Qwen3-1.7B, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, pooler_config=None, comp

model.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

INFO 10-29 18:24:28 [weight_utils.py:369] Time spent downloading weights for unsloth/Qwen3-1.7B: 5.347967 seconds
INFO 10-29 18:24:28 [weight_utils.py:406] No model.safetensors.index.json found in remote.


Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 10-29 18:24:29 [default_loader.py:268] Loading weights took 1.27 seconds
INFO 10-29 18:24:29 [punica_selector.py:19] Using PunicaWrapperGPU.
INFO 10-29 18:24:31 [gpu_model_runner.py:2392] Model loading took 3.2939 GiB and 7.739387 seconds
INFO 10-29 18:24:42 [backends.py:539] Using cache directory: /root/.cache/vllm/torch_compile_cache/1268378413/rank_0_0/backbone for vLLM's torch.compile
INFO 10-29 18:24:42 [backends.py:550] Dynamo bytecode transform time: 10.71 s


Unsloth: Compiling kernels: 100%|██████████| 7/7 [00:00<00:00, 11.65it/s, triton_poi_fused_view_6]

INFO 10-29 18:24:49 [backends.py:194] Cache the graph for dynamic shape for later use



Unsloth: Compiling kernels: 100%|██████████| 11/11 [00:00<00:00, 19.69it/s, triton_poi_fused_view_10]
Unsloth: Compiling kernels: 100%|██████████| 11/11 [00:00<00:00, 529.86it/s, triton_poi_fused_view_10]
Unsloth: Compiling kernels: 100%|██████████| 11/11 [00:00<00:00, 523.84it/s, triton_poi_fused_view_10]
Unsloth: Compiling kernels: 100%|██████████| 11/11 [00:00<00:00, 492.59it/s, triton_poi_fused_view_10]
Unsloth: Compiling kernels: 100%|██████████| 11/11 [00:00<00:00, 512.85it/s, triton_poi_fused_view_10]
Unsloth: Compiling kernels: 100%|██████████| 11/11 [00:00<00:00, 505.12it/s, triton_poi_fused_view_10]
Unsloth: Compiling kernels: 100%|██████████| 11/11 [00:00<00:00, 536.04it/s, triton_poi_fused_view_10]
Unsloth: Compiling kernels: 100%|██████████| 11/11 [00:00<00:00, 496.35it/s, triton_poi_fused_view_10]
Unsloth: Compiling kernels: 100%|██████████| 11/11 [00:00<00:00, 520.43it/s, triton_poi_fused_view_10]
Unsloth: Compiling kernels: 100%|██████████| 11/11 [00:00<00:00, 534.29it

INFO 10-29 18:25:24 [backends.py:215] Compiling a graph for dynamic shape takes 40.45 s





INFO 10-29 18:25:37 [monitor.py:34] torch.compile takes 51.16 s in total
INFO 10-29 18:25:39 [gpu_worker.py:298] Available KV cache memory: 57.77 GiB
INFO 10-29 18:25:40 [kv_cache_utils.py:864] GPU KV cache size: 540,832 tokens
INFO 10-29 18:25:40 [kv_cache_utils.py:868] Maximum concurrency for 2,048 tokens per request: 264.08x
INFO 10-29 18:25:40 [vllm_utils.py:699] Unsloth: Running patched vLLM v1 `capture_model`.


Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|██████████| 67/67 [00:17<00:00,  3.81it/s]
Capturing CUDA graphs (decode, FULL): 100%|██████████| 49/49 [00:06<00:00,  7.47it/s]

INFO 10-29 18:26:04 [gpu_model_runner.py:3118] Graph capturing finished in 24 secs, took 1.03 GiB
INFO 10-29 18:26:04 [vllm_utils.py:706] Unsloth: Patched vLLM v1 graph capture finished in 24 secs.





INFO 10-29 18:26:06 [gpu_worker.py:391] Free memory on device (78.79/79.32 GiB) on startup. Desired GPU memory utilization is (0.7954304147054094, 63.09 GiB). Actual usage is 3.29 GiB for weight, 2.01 GiB for peak activation, 0.02 GiB for non-torch memory, and 1.03 GiB for CUDAGraph memory. Replace gpu_memory_utilization config with `--kv-cache-memory=60769821388` to fit into requested memory, or `--kv-cache-memory=77626283520` to fully utilize gpu memory. Current kv cache memory in use is 62028112588 bytes.
INFO 10-29 18:26:06 [core.py:218] init engine (profile, create kv cache, warmup model) took 95.13 seconds
INFO 10-29 18:26:07 [llm.py:295] Supported_tasks: ('generate',)
INFO 10-29 18:26:07 [__init__.py:36] No IOProcessor plugins requested by the model
Unsloth: Just some info: will skip parsing ['post_attention_layernorm', 'input_layernorm', 'layer_norm2', 'post_feedforward_layernorm', 'k_norm', 'norm2', 'post_layernorm', 'ffn_norm', 'q_norm', 'pre_feedforward_layernorm', 'layer_no

Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.1.
Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.
Unsloth 2025.10.11 patched 28 layers with 0 QKV layers, 0 O layers and 0 MLP layers.


In [6]:
# Load the medical verifier model (3B SequenceClassification) for reward functions
# This handles semantic equivalence and aliases, outputting True/False based on alignment
verifier_path = 'FreedomIntelligence/medical_o1_verifier_3B'
verifier_tokenizer = AutoTokenizer.from_pretrained(verifier_path)
verifier_model = AutoModelForSequenceClassification.from_pretrained(
    verifier_path, dtype='auto', device_map='auto',
    attn_implementation='flash_attention_2', num_labels=2
)
verifier_model.eval()  # Set to evaluation mode

# Verifier template from model card
VERIFIER_TEMPLATE = '''<Model Response>
{}
</Model Response>

<Reference Answer>
{}
</Reference Answer>

Your task is to evaluate the model response by comparing it to the reference answer. If the model response is correct and aligns with the reference answer, output "True". If it is incorrect or fails to select the correct option (if options are provided), output "False". {}'''

tokenizer_config.json:   0%|          | 0.00/54.6k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/332 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/962 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/21.0k [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.46G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

# Chat Template

In [7]:
# Define structured output format for mathematical reasoning
REASONING_START = '<THINK>' # Begin reasoning section
REASONING_END = '</THINK>'  # End reasoning section
ANSWER_START = '<ANSWER>'   # Begin final answer
ANSWER_END = '</ANSWER>'    # End final answer

# System prompt adapted for medical reasoning (inspired by paper's complex reasoning emphasis)
SYSTEM_PROMPT = f'''You are a medical reasoning assistant. When given a medical problem:
1. Show your step-by-step complex reasoning (including reflection, backtracking, and alternative paths) between {REASONING_START} and {REASONING_END}.
2. Provide your final answer between {ANSWER_START} and {ANSWER_END}.
3. Be precise, consider medical aliases/synonyms, and show all deliberation steps clearly.'''
print(SYSTEM_PROMPT)

You are a medical reasoning assistant. When given a medical problem:
1. Show your step-by-step complex reasoning (including reflection, backtracking, and alternative paths) between <THINK> and </THINK>.
2. Provide your final answer between <ANSWER> and </ANSWER>.
3. Be precise, consider medical aliases/synonyms, and show all deliberation steps clearly.


In [8]:
chat_template = ( # Build and assign chat_template to the tokenizer
    # If the very first message is a SYSTEM role, print it + <eos>:
    "{% if messages[0]['role'] == 'system' %}"
      "{{ messages[0]['content'] + eos_token }}"
      "{% set loop_messages = messages[1:] %}"
    "{% else %}"
      # Otherwise, inject our system_prompt + <eos>:
      "{{ '{system_prompt}' + eos_token }}"
      "{% set loop_messages = messages %}"
    "{% endif %}"

    # Now loop over the remaining messages (either user or assistant):
    "{% for message in loop_messages %}"
      "{% if message['role'] == 'user' %}"
        "{{ message['content'] }}"
      "{% elif message['role'] == 'assistant' %}"
        "{{ message['content'] + eos_token }}"
      "{% endif %}"
    "{% endfor %}"

    # If we asked for "add_generation_prompt", append <REASONING> to the end:
    "{% if add_generation_prompt %}{{ '{reasoning_start}' }}"
    "{% endif %}"
)
# Replace with out specific template:
tokenizer.chat_template = chat_template\
    .replace("'{system_prompt}'",   f"'{SYSTEM_PROMPT}'")\
    .replace("'{reasoning_start}'", f"'{REASONING_START}'")

In [9]:
example_messages = [ # Quick sanity check of the template
    {'role': 'user', 'content': 'What is the most severe complication of dengue fever?'},
    {'role': 'assistant', 'content': (
        f'{REASONING_START}'
        'Dengue fever can lead to severe forms like dengue hemorrhagic fever. '
        'Considering symptoms and progression, the most severe is plasma leakage leading to shock.'
        f'{REASONING_END}{ANSWER_START}Dengue shock syndrome{ANSWER_END}'
    )},
    {'role': 'user', 'content': 'What drug is used for hypertension?'},
]
print(tokenizer.apply_chat_template(example_messages, tokenize=False, add_generation_prompt=True))

You are a medical reasoning assistant. When given a medical problem:
1. Show your step-by-step complex reasoning (including reflection, backtracking, and alternative paths) between <THINK> and </THINK>.
2. Provide your final answer between <ANSWER> and </ANSWER>.
3. Be precise, consider medical aliases/synonyms, and show all deliberation steps clearly.<|im_end|>What is the most severe complication of dengue fever?<THINK>Dengue fever can lead to severe forms like dengue hemorrhagic fever. Considering symptoms and progression, the most severe is plasma leakage leading to shock.</THINK><ANSWER>Dengue shock syndrome</ANSWER><|im_end|>What drug is used for hypertension?<THINK>


# Pre Fine-tuning (SFT)

## Data preparation

In [10]:
def format_dataset(x): # Format the dataset to follow our GRPO style formatting (adapted for medical SFT data)
    # The medical-o1-reasoning-SFT has 'Complex CoT' and 'Response' fields
    expected_answer, problem = x['Response'], x['Question']  # 'Response' is the final answer, 'question' is the problem
    thoughts = x['Complex_CoT'].strip()  # Use the provided Complex CoT as reasoning
    final_prompt = REASONING_START + thoughts + REASONING_END + ANSWER_START + expected_answer + ANSWER_END  # Custom formatting
    return [
        {'role': 'system'   , 'content': SYSTEM_PROMPT},
        {'role': 'user'     , 'content': problem},
        {'role': 'assistant', 'content': final_prompt},
    ]

sft_dataset = load_dataset('FreedomIntelligence/medical-o1-reasoning-SFT', 'en', split='train').to_pandas()
sft_dataset['messages'] = sft_dataset.apply(format_dataset, axis=1)
print(tokenizer.apply_chat_template(sft_dataset['messages'][0], tokenize=False))
sft_dataset

README.md: 0.00B [00:00, ?B/s]

medical_o1_sft.json:   0%|          | 0.00/58.2M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/19704 [00:00<?, ? examples/s]

You are a medical reasoning assistant. When given a medical problem:
1. Show your step-by-step complex reasoning (including reflection, backtracking, and alternative paths) between <THINK> and </THINK>.
2. Provide your final answer between <ANSWER> and </ANSWER>.
3. Be precise, consider medical aliases/synonyms, and show all deliberation steps clearly.<|im_end|>Given the symptoms of sudden weakness in the left arm and leg, recent long-distance travel, and the presence of swollen and tender right lower leg, what specific cardiac abnormality is most likely to be found upon further evaluation that could explain these findings?<THINK>Okay, let's see what's going on here. We've got sudden weakness in the person's left arm and leg - and that screams something neuro-related, maybe a stroke?

But wait, there's more. The right lower leg is swollen and tender, which is like waving a big flag for deep vein thrombosis, especially after a long flight or sitting around a lot.

So, now I'm thinking, 

Unnamed: 0,Question,Complex_CoT,Response,messages
0,Given the symptoms of sudden weakness in the l...,"Okay, let's see what's going on here. We've go...",The specific cardiac abnormality most likely t...,"[{'role': 'system', 'content': 'You are a medi..."
1,A 33-year-old woman is brought to the emergenc...,"Okay, let's figure out what's going on here. A...","In this scenario, the most likely anatomical s...","[{'role': 'system', 'content': 'You are a medi..."
2,A 61-year-old woman with a long history of inv...,"Okay, let's think about this step by step. The...",Cystometry in this case of stress urinary inco...,"[{'role': 'system', 'content': 'You are a medi..."
3,A 45-year-old man with a history of alcohol us...,"Alright, let’s break this down. We have a 45-y...",Considering the clinical presentation of sudde...,"[{'role': 'system', 'content': 'You are a medi..."
4,A 45-year-old man presents with symptoms inclu...,"Okay, so here's a 45-year-old guy who's experi...",Based on the clinical findings presented—wide-...,"[{'role': 'system', 'content': 'You are a medi..."
...,...,...,...,...
19699,A 21-year-old man presents with painful urinat...,"Okay, let's think this through. We've got a 21...",The most likely infectious cause of the sympto...,"[{'role': 'system', 'content': 'You are a medi..."
19700,In a 7-month-old child diagnosed with H. influ...,"Alright, so we're dealing with a 7-month-old c...",In a 7-month-old child who has been diagnosed ...,"[{'role': 'system', 'content': 'You are a medi..."
19701,What is the treatment of choice for a 40-year-...,"Okay, so we have a 40-year-old woman, it's her...",The treatment of choice in this scenario is la...,"[{'role': 'system', 'content': 'You are a medi..."
19702,What is the most likely underlying mechanism c...,"Alright, so we've got a 25-year-old woman deal...",In a 25-year-old woman with systemic lupus ery...,"[{'role': 'system', 'content': 'You are a medi..."


In [11]:
# Truncate pre fine-tuning sft_dataset to max_seq_length / 2 since we don't want too long reasoning traces
sft_dataset['seq_length'] = sft_dataset['messages'].apply(lambda x: len(tokenizer.apply_chat_template(x)))
print('Token-length percentiles (50/90/99):', np.percentile(sft_dataset['seq_length'], [50, 90, 99]))

threshold = max_seq_length / 2
sft_dataset_filtered = sft_dataset.loc[sft_dataset['seq_length'] <= threshold].copy()
print(f'Remaining for training (<= {threshold} tokens): {len(sft_dataset_filtered)}/{len(sft_dataset)}')

sft_dataset_filtered['text'] = tokenizer.apply_chat_template(sft_dataset_filtered['messages'].values.tolist(), tokenize=False)
sft_dataset_filtered = Dataset.from_pandas(sft_dataset_filtered)
sft_dataset_filtered

Token-length percentiles (50/90/99): [ 686.    889.   1115.97]
Remaining for training (<= 1024.0 tokens): 19171/19704


Dataset({
    features: ['Question', 'Complex_CoT', 'Response', 'messages', 'seq_length', 'text', '__index_level_0__'],
    num_rows: 19171
})

## Pre fine-tune to understand custom GRPO formatting

In [12]:
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=sft_dataset_filtered,
    args=SFTConfig(
        dataset_text_field='text',
        num_train_epochs=3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=8,
        optim='adamw_8bit',
        weight_decay=0.01,
        learning_rate=2e-4,
        lr_scheduler_type='cosine',
        warmup_ratio=0.05,
        logging_steps=100,
        report_to='none',
    )
)
trainer.train()
trainer.save_model(f'./{model_name}_sft')

Unsloth: Tokenizing ["text"] (num_proc=16):   0%|          | 0/19171 [00:00<?, ? examples/s]

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 151654}.
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 19,171 | Num Epochs = 3 | Total steps = 1,800
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 8 x 1) = 32
 "-____-"     Trainable parameters = 34,865,152 of 1,755,440,128 (1.99% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss
100,1.6115
200,1.3525
300,1.3339
400,1.3192
500,1.304
600,1.3026
700,1.2464
800,1.2424
900,1.2354
1000,1.2384


Step,Training Loss
100,1.6115
200,1.3525
300,1.3339
400,1.3192
500,1.304
600,1.3026
700,1.2464
800,1.2424
900,1.2354
1000,1.2384


## Check if model has learnt to follow the format

In [None]:
# # The FastLanguageModel.from_pretrained should be only called once. Otherwise, it will be OOM
# model, tokenizer = FastLanguageModel.from_pretrained(
#     model_name=f'./{model_name}_sft', # Reload LoRA weights
#     max_seq_length=max_seq_length,
#     load_in_4bit=False,         # False for LoRA 16bit
#     fast_inference=True,        # Enable vLLM fast inference
#     max_lora_rank=lora_rank,
#     gpu_memory_utilization=0.8, # Reduce if out of memory
# )
# FastLanguageModel.for_inference(model)

In [13]:
text = tokenizer.apply_chat_template( # Render into a single string and append <REASONING> for generation
    sft_dataset_filtered[1]['messages'][:2],
    tokenize=False, add_generation_prompt=True, # Append the final <REASONING>
)
_ = model.generate(
    **tokenizer(text, return_tensors='pt').to('cuda'),
    temperature=0, max_new_tokens=1024,
    streamer=TextStreamer(tokenizer, skip_prompt=False), # Stream the model's generations (CoT + solution)
)

You are a medical reasoning assistant. When given a medical problem:
1. Show your step-by-step complex reasoning (including reflection, backtracking, and alternative paths) between <THINK> and </THINK>.
2. Provide your final answer between <ANSWER> and </ANSWER>.
3. Be precise, consider medical aliases/synonyms, and show all deliberation steps clearly.<|im_end|>A 33-year-old woman is brought to the emergency department 15 minutes after being stabbed in the chest with a screwdriver. Given her vital signs of pulse 110/min, respirations 22/min, and blood pressure 90/65 mm Hg, along with the presence of a 5-cm deep stab wound at the upper border of the 8th rib in the left midaxillary line, which anatomical structure in her chest is most likely to be injured?<THINK>Okay, so we have a 33-year-old woman who just got stabbed in the chest with a screwdriver. That's pretty serious. She's showing some concerning signs too, like a fast pulse and shallow breathing. Her blood pressure is low, which 

In [None]:
del sft_dataset, sft_dataset_filtered
gc.collect()
torch.cuda.empty_cache()

# Post Fine-tuning (RL)

## Data preparation

In [15]:
def process_dataset_sample(example):  # Convert medical example to conversation format for GRPO training
    return {
        'prompt': [  # Create conversation with system prompt for structured reasoning
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': example['Open-ended Verifiable Question']},
        ],
        # Ground truth answer for reward functions (text, not numerical)
        'answer': example['Ground-True Answer'].strip()
    }

In [16]:
# Load the verifiable medical problems dataset for RL
train_dataset = load_dataset('FreedomIntelligence/medical-o1-verifiable-problem', split='train')
train_dataset = train_dataset.map(process_dataset_sample)
print(f'Training samples: {len(train_dataset):,}\n'  # ~20K as per paper
      f"- Sample question: {train_dataset[0]['prompt'][1]['content']}\n"
      f"- Sample answer (ground truth for rewards): {train_dataset[0]['answer']}\n"
      f"- Prompt (system + user):\n{train_dataset[0]['prompt']}")

README.md: 0.00B [00:00, ?B/s]

medical_o1_verifiable_problem.json:   0%|          | 0.00/12.2M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/40644 [00:00<?, ? examples/s]

Map:   0%|          | 0/40644 [00:00<?, ? examples/s]

Training samples: 40,644
- Sample question: An 88-year-old woman with osteoarthritis is experiencing mild epigastric discomfort and has vomited material resembling coffee grounds multiple times. Considering her use of naproxen, what is the most likely cause of her gastrointestinal blood loss?
- Sample answer (ground truth for rewards): Gastric ulcer
- Prompt (system + user):
[{'content': 'You are a medical reasoning assistant. When given a medical problem:\n1. Show your step-by-step complex reasoning (including reflection, backtracking, and alternative paths) between <THINK> and </THINK>.\n2. Provide your final answer between <ANSWER> and </ANSWER>.\n3. Be precise, consider medical aliases/synonyms, and show all deliberation steps clearly.', 'role': 'system'}, {'content': 'An 88-year-old woman with osteoarthritis is experiencing mild epigastric discomfort and has vomited material resembling coffee grounds multiple times. Considering her use of naproxen, what is the most likely cause of

In [17]:
# Get the top 90% prompt length so we don't accidentally truncate them, i.e. we'll remove the top 10% long prompts
tokenized_dataset = train_dataset.map(
    lambda x: {'tokens': tokenizer.apply_chat_template(x['prompt'], add_generation_prompt=True, tokenize=True)},
    batched=True,
).map(lambda x: {'length': len(x['tokens'])})
print(tokenizer.decode(tokenized_dataset[0]['tokens']))

thresholds = np.percentile(tokenized_dataset['length'], [50, 90, 99])
max_prompt_length = int(thresholds[1])
print('Token-length percentiles (50/90/99):', thresholds, '=> Choose max_prompt_length =', max_prompt_length)

# Filter only samples smaller than 90% max length
train_dataset = train_dataset.select(np.where(np.array(tokenized_dataset['length']) <= max_prompt_length)[0])
print(f'Remaining for training (<= {max_prompt_length} tokens): {len(train_dataset)}/{len(tokenized_dataset)}')
del tokenized_dataset

Map:   0%|          | 0/40644 [00:00<?, ? examples/s]

Map:   0%|          | 0/40644 [00:00<?, ? examples/s]

You are a medical reasoning assistant. When given a medical problem:
1. Show your step-by-step complex reasoning (including reflection, backtracking, and alternative paths) between <THINK> and </THINK>.
2. Provide your final answer between <ANSWER> and </ANSWER>.
3. Be precise, consider medical aliases/synonyms, and show all deliberation steps clearly.<|im_end|>An 88-year-old woman with osteoarthritis is experiencing mild epigastric discomfort and has vomited material resembling coffee grounds multiple times. Considering her use of naproxen, what is the most likely cause of her gastrointestinal blood loss?<THINK>
Token-length percentiles (50/90/99): [116. 160. 204.] => Choose max_prompt_length = 160
Remaining for training (<= 160 tokens): 36664/40644


## Multi-reward design

In [None]:
match_format = re.compile(                              # Match the reasoning sections and answers
    rf'{REASONING_END}.*?'                              # We always prepend REASONING_START
    rf'{ANSWER_START}(.+?){ANSWER_END}'                 # Answer section with capture group (text, not number)
    rf'[\s]{{0,}}(?:{re.escape(tokenizer.eos_token)})?' # Add optional EOS token matching
    rf'[\s]{{0,}}$',                                    # Optional whitespace at end
    flags=re.MULTILINE | re.DOTALL,                     # Multi-line matching with . matching newlines
)
def verify(guess, true_answer):
    verifier_input = VERIFIER_TEMPLATE.format(guess.strip(), true_answer.strip(), verifier_tokenizer.eos_token)
    input_batch = verifier_tokenizer([verifier_input], return_tensors='pt').to(verifier_model.device)
    with torch.no_grad():
        logits = verifier_model(**input_batch, return_dict=True).logits
    probabilities = F.softmax(logits, dim=-1)
    return probabilities[0, 1].item()

In [None]:
def match_format_strictly(completions, **kwargs) -> list[float]:
    ''' Reward Function 1: Exact Format Compliance
    High reward (3.0) for perfect format adherence. Ensure model learns the complete structured output pattern
    '''
    return [
        3.0 if match_format.search(completion[0]['content']) else 0.0
        for completion in completions
    ]

In [None]:
# If it fails, reward the model if it at least follows the format partially, by counting each symbol
def match_format_softly(completions, **kwargs) -> list[float]:
    ''' Reward Function 2: Partial Format Credit
    Graduated scoring for format elements. Encourage learning individual components even if not perfect
    '''
    rewards = []
    for completion in completions:
        reward = 0
        response = completion[0]['content']

        # Count how many keywords are seen - we penalize if too many!
        # Award +0.5 for correct token count, -0.5 for wrong count
        # reward += 0.5 if response.count(REASONING_START) == 1 else -0.5  # Prepended
        reward += 0.5 if response.count(REASONING_END) == 1 else -0.5
        reward += 0.5 if response.count(ANSWER_START) == 1 else -0.5
        reward += 0.5 if response.count(ANSWER_END) == 1 else -0.5
        rewards.append(reward)
    return rewards

In [None]:
# Extract the generated answer, and reward or penalize it
def check_answer_correctness(completions, answer, **kwargs) -> list[float]:
    ''' Reward Function 3: Graduated scoring for medical accuracy using verifier
    - 5.0: Verifier confirms full alignment (True)
    - 2.0: Partial credit if verifier prob > 0.5 but not full (use prob for gradation)
    - -2.5: Wrong answer (False)
    Handles aliases via semantic verification
    '''
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [  # Extract answers using format pattern
        guess.group(1) if (guess := match_format.search(r)) else None
        for r in responses
    ]
    rewards = []
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:    # No extractable answer
            rewards.append(-2.0)
            continue

        prob_true = verify(guess, true_answer)     # Use verifier to check semantic alignment
        if prob_true > 0.9: rewards.append(5.0)    # High confidence correct
        elif prob_true > 0.7: rewards.append(3.5)  # Strong alignment
        elif prob_true > 0.5: rewards.append(2.0)  # Partial/approximate
        elif prob_true > 0.3: rewards.append(1.5)  # Reasonable attempt
        else: rewards.append(-2.5)                 # Incorrect
    return rewards

## GRPO training setup

In [22]:
max_prompt_length += 1
max_completion_length = max_seq_length - max_prompt_length

# Encourage exploration during training
vllm_sampling_params = SamplingParams(
    min_p = 0.1,
    top_p = 1.0,
    top_k = -1,
    stop = [tokenizer.eos_token],
    include_stop_str_in_output = True,
)

In [23]:
training_args = GRPOConfig(          # Configure GRPO training parameters for mathematical reasoning
    output_dir=f'/tmp/{model_name}', # Directory for checkpoints and logs
    vllm_sampling_params=vllm_sampling_params,
    # Training speed control
    num_train_epochs=1,              # Total number of training epochs
    per_device_train_batch_size=4,   # Small batch for GPU memory constraints
    gradient_accumulation_steps=16,  # Effective batch size = 4 * 16 = 64
    # Precision & Optimization
    scale_rewards='batch',           # Calculate mean at local/group level and std at global/batch level enables more robust reward shaping
    loss_type='dr_grpo',             # Fully remove response length bias, dividing by a constant instead of the sequence length
    optim='adamw_8bit',
    weight_decay=0.1,                # Regularization
    max_grad_norm=0.1,               # Aggressive gradient clipping for stable training
    bf16=is_bfloat16_supported(),    # Enable mixed-precision training if a CUDA GPU is available (faster, less memory)
    learning_rate=1e-5,              # Conservative LR to prevent destabilizing reasoning
    lr_scheduler_type='cosine',
    warmup_ratio=0.05,
    # Generation control
    temperature=1.0,                 # Encourage exploration during training
    num_generations=4,
    max_prompt_length=max_prompt_length,
    max_completion_length=max_completion_length,
    # Reporting and saving
    report_to='wandb',
    logging_steps=79,
    logging_strategy='steps',
    save_total_limit=1,
    # max_steps=100,
)

## Train the model

In [24]:
%%time
trainer = GRPOTrainer(            # Initialize GRPO trainer with multi-reward system (improvement over paper's single-reward PPO)
    model=model,                  # LoRA-adapted quantized model
    processing_class=tokenizer,
    train_dataset=train_dataset,  # Processed medical dataset
    args=training_args,           # Training configuration
    reward_funcs=[                # 3 complementary reward functions
        match_format_strictly,    # Perfect structure compliance
        match_format_softly,      # Partial format credit
        check_answer_correctness, # Semantic accuracy via verifier (handles aliases)
    ]
)
trainer.train()
trainer.save_model(f'./{model_name}_grpo')

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 36,664 | Num Epochs = 1 | Total steps = 2,291
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 16
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 16 x 1) = 64
 "-____-"     Trainable parameters = 34,865,152 of 1,755,440,128 (1.99% trained)
  | |_| | '_ \/ _` / _` |  _/ -_)


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33m18520339[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Detected [huggingface_hub.inference, openai] in use.
[34m[1mwandb[0m: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
[34m[1mwandb[0m: For more information, check out the docs at: https://weave-docs.wandb.ai/


Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,kl,rewards / match_format_strictly / mean,rewards / match_format_strictly / std,rewards / match_format_softly / mean,rewards / match_format_softly / std,rewards / check_answer_correctness / mean,rewards / check_answer_correctness / std
79,0.0733,3.305874,2.651911,635.779272,366.860759,1245.594937,0.001187,634.298561,366.860759,1192.911392,0.860294,2.994066,0.047468,1.495649,0.03481,-1.183841,2.640739
158,0.0531,3.363331,2.639507,590.170886,351.746835,1034.417722,0.000198,589.919348,351.746835,1027.734177,0.858004,2.997627,0.018987,1.49822,0.013009,-1.132516,2.635342
237,0.0364,3.56784,2.813767,549.913568,322.632911,946.64557,0.0,549.913568,322.632911,946.64557,0.870355,2.997033,0.023734,1.497824,0.017405,-0.927017,2.808044
316,0.028,3.737935,2.906902,546.492286,332.594937,907.936709,0.000198,546.229474,332.594937,896.405063,0.853846,2.999407,0.004747,1.497627,0.018987,-0.759098,2.90374
395,0.0291,3.780953,2.927679,590.951345,336.455696,1013.101266,0.0,590.951345,336.455696,1013.101266,0.835061,2.998813,0.009494,1.498616,0.011076,-0.716475,2.924647
474,0.0519,3.758604,2.905591,597.991495,339.544304,1050.670886,0.000198,597.734557,339.544304,1040.379747,0.827254,2.997033,0.023734,1.497429,0.02057,-0.735858,2.899819
553,0.0393,3.991594,3.036073,557.750593,322.734177,981.746835,0.000396,557.223218,322.734177,959.101266,0.835554,2.997627,0.018987,1.498616,0.011076,-0.504648,3.031571
632,0.0424,3.949565,3.028092,561.871835,327.379747,966.873418,0.000198,561.606886,327.379747,954.481013,0.825105,2.998813,0.009494,1.498616,0.011076,-0.547864,3.025849
711,0.0227,3.9643,3.007468,566.715388,319.443038,983.506329,0.0,566.715388,319.443038,983.506329,0.799076,2.999407,0.004747,1.499802,0.001582,-0.534909,3.006602
790,0.0345,4.050534,3.056634,547.531843,308.278481,925.189873,0.000198,547.267393,308.278481,913.379747,0.800785,2.998813,0.009494,1.498022,0.015823,-0.446301,3.053034


Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,kl,rewards / match_format_strictly / mean,rewards / match_format_strictly / std,rewards / match_format_softly / mean,rewards / match_format_softly / std,rewards / check_answer_correctness / mean,rewards / check_answer_correctness / std
79,0.0733,3.305874,2.651911,635.779272,366.860759,1245.594937,0.001187,634.298561,366.860759,1192.911392,0.860294,2.994066,0.047468,1.495649,0.03481,-1.183841,2.640739
158,0.0531,3.363331,2.639507,590.170886,351.746835,1034.417722,0.000198,589.919348,351.746835,1027.734177,0.858004,2.997627,0.018987,1.49822,0.013009,-1.132516,2.635342
237,0.0364,3.56784,2.813767,549.913568,322.632911,946.64557,0.0,549.913568,322.632911,946.64557,0.870355,2.997033,0.023734,1.497824,0.017405,-0.927017,2.808044
316,0.028,3.737935,2.906902,546.492286,332.594937,907.936709,0.000198,546.229474,332.594937,896.405063,0.853846,2.999407,0.004747,1.497627,0.018987,-0.759098,2.90374
395,0.0291,3.780953,2.927679,590.951345,336.455696,1013.101266,0.0,590.951345,336.455696,1013.101266,0.835061,2.998813,0.009494,1.498616,0.011076,-0.716475,2.924647
474,0.0519,3.758604,2.905591,597.991495,339.544304,1050.670886,0.000198,597.734557,339.544304,1040.379747,0.827254,2.997033,0.023734,1.497429,0.02057,-0.735858,2.899819
553,0.0393,3.991594,3.036073,557.750593,322.734177,981.746835,0.000396,557.223218,322.734177,959.101266,0.835554,2.997627,0.018987,1.498616,0.011076,-0.504648,3.031571
632,0.0424,3.949565,3.028092,561.871835,327.379747,966.873418,0.000198,561.606886,327.379747,954.481013,0.825105,2.998813,0.009494,1.498616,0.011076,-0.547864,3.025849
711,0.0227,3.9643,3.007468,566.715388,319.443038,983.506329,0.0,566.715388,319.443038,983.506329,0.799076,2.999407,0.004747,1.499802,0.001582,-0.534909,3.006602
790,0.0345,4.050534,3.056634,547.531843,308.278481,925.189873,0.000198,547.267393,308.278481,913.379747,0.800785,2.998813,0.009494,1.498022,0.015823,-0.446301,3.053034


CPU times: user 15h 31min 43s, sys: 2min 45s, total: 15h 34min 28s
Wall time: 15h 28min 15s


In [28]:
gc.collect()
torch.cuda.empty_cache()

# Evaluation

In [29]:
# https://docs.unsloth.ai/models/qwen3-how-to-run-and-fine-tune#official-recommended-settings
sampling_params = SamplingParams(
    temperature = 0.6,
    min_p = 0.0,
    top_p = 0.95,
    top_k = 20,
    max_tokens = max_completion_length,
)
match_answer_letter = re.compile( # Regex to extract the answer letter
    rf'{ANSWER_START}.*?[\s]{{0,}}([A-D])',
    flags=re.MULTILINE | re.DOTALL
)

## Influence of LoRA

In [30]:
example_text = 'What drug is used for hypertension?'
print(model.fast_generate( # Try the model without any GRPO trained
    example_text, sampling_params=sampling_params,
    lora_request=None
)[0].outputs[0].text)

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 Are there any side effects of using a drug for hypertension? What are the possible side effects of a drug for hypertension?

Are there any side effects of using a drug for hypertension? What are the possible side effects of a drug for hypertension?

Are there any side effects of using a drug for hypertension? What are the possible side effects of a drug for hypertension?

Are there any side effects of using a drug for hypertension? What are the possible side effects of a drug for hypertension?

Are there any side effects of using a drug for hypertension? What are the possible side effects of a drug for hypertension?

Are there any side effects of using a drug for hypertension? What are the possible side effects of a drug for hypertension?

Are there any side effects of using a drug for hypertension? What are the possible side effects of a drug for hypertension?

Are there any side effects of using a drug for hypertension? What are the possible side effects of a drug for hypertension?


In [31]:
tensors = {}
with safe_open(f'./{model_name}_grpo/adapter_model.safetensors', framework='pt') as f:
    for key in f.keys(): # Verify both A and B are non zero
        tensor = f.get_tensor(key)
        n_zeros = (tensor == 0).sum() / tensor.numel()
        assert(n_zeros.item() != tensor.numel())

In [32]:
# Load the LoRA and test without using system prompt
# which should not (or minimal) affect the model's original reasoning ability
text = tokenizer.apply_chat_template(
    [{'role': 'user', 'content': example_text}],
    add_generation_prompt=True, tokenize=False,
)
print(model.fast_generate(
    text, sampling_params=sampling_params,
    lora_request=model.load_lora(f'./{model_name}_grpo'),
)[0].outputs[0].text)

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Okay, so I'm trying to figure out which drug is commonly used for treating hypertension. Let's start by thinking about what hypertension is. It's when blood pressure is higher than it should be, and it can lead to serious health issues if not managed. So, the main goal here is to lower that blood pressure to a safer level.

Now, there are a few types of drugs used for this. I remember that there are different classes, like ACE inhibitors, beta-blockers, calcium channel blockers, and diuretics. Each of these works a bit differently and has its own benefits and side effects.

Let's break it down. ACE inhibitors, for example, are known to help reduce blood pressure by relaxing blood vessels. They're often prescribed for people with heart failure or diabetes. Then there are beta-blockers, which are used to slow the heart rate and reduce the force of the heartbeat. They're helpful for those with high blood pressure and heart issues.

Calcium channel blockers are another category. They work 

In [33]:
# Test using system prompt
text = tokenizer.apply_chat_template([
    {'role': 'system', 'content': SYSTEM_PROMPT},
    {'role': 'user'  , 'content': example_text},
], add_generation_prompt=True, tokenize=False)

# Compare results with system prompt but without LoRA
print(model.fast_generate(
    text, sampling_params=sampling_params,
    lora_request=None,
)[0].outputs[0].text)

# Reasoning model is much better - it's not always correct, since we only trained it for an hour
# It'll be better if we extend the sequence length and train for longer
print(model.fast_generate(
    text, sampling_params=sampling_params,
    lora_request=model.load_lora(f'./{model_name}_grpo'),
)[0].outputs[0].text)

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Okay, the user is asking about the drug used for hypertension. Let me start by recalling the main classes of antihypertensive medications. The first thing that comes to mind is ACE inhibitors, like lisinopril. Then there are beta-blockers, such as metoprolol. Diuretics are another class, like hydrochlorothiazide. Also, calcium channel blockers, like amlodipine. 

Wait, but the user might be looking for a specific drug, so maybe they want the most commonly prescribed one. I should consider the most common and widely used drugs. For example, ACE inhibitors are often first-line for certain patients, especially those with diabetes or heart failure. Beta-blockers are used in cases where the patient has risk factors like CAD or hypertension with diabetes. Diuretics are commonly used as first-line for patients with kidney disease or in combination with other drugs.

I should also think about the patient's comorbidities. For example, if the patient has renal disease, diuretics might be preferr

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Alright, let's think about hypertension. It's when blood pressure is higher than it should be, and it's a common issue that many people face. So, what are the main ways to manage it? Well, there are several types of medications. I know there are different classes like ACE inhibitors, beta-blockers, diuretics, and calcium channel blockers. Each of these works in a slightly different way to help lower blood pressure. 

Now, if I had to pick one drug that's widely used for hypertension, I'd probably go with a diuretic like thiazide or a calcium channel blocker like amlodipine. But I need to be more specific. Let me think. Oh, I remember that thiazide diuretics are often used because they help by reducing sodium and water retention. This makes the blood pressure drop. 

But wait, maybe there are other drugs like ACE inhibitors that are also commonly used. I've heard about them being beneficial for certain types of hypertension. For example, they can be especially helpful in people with hea

## Performance on benchmark datasets

In [34]:
def format_mcq_prompt(example, dataset_name):
    user_content = 'Please answer the following multiple-choice question, ensuring your response only '
    user_content += 'contain the correct option letter with no extra text:\n' + example['question'] + '\n'

    if dataset_name == 'medqa':
        user_content += '\n'.join([f'{k}. {v}' for k, v in example['options'].items()])
        gold_letter = example['answer_idx']
    elif dataset_name == 'medmcqa':
        choices = [example['opa'], example['opb'], example['opc'], example['opd']]
        user_content += '\n'.join([f'{chr(65+i)}. {choice}' for i, choice in enumerate(choices)])
        gold_letter = chr(ord('A') + example['cop'])
    elif dataset_name == 'pubmedqa':
        contexts = ' '.join(example['context']['contexts'])  # Join abstracts as context
        user_content = f'Context: {contexts}\n{user_content}A. yes\nB. no\nC. maybe'
        gold_letter = 'A' if example['final_decision'] == 'yes' else 'B' if example['final_decision'] == 'no' else 'C'
    return {
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': user_content},
        ],
        'answer': gold_letter.strip()
    }

In [35]:
# Load evaluation benchmarks as per paper: MedQA-USMLE (test), MedMCQA (validation), PubMedQA (PQA-L test equivalent)
medqa_dataset = load_dataset('GBaker/MedQA-USMLE-4-options', split='test') # 1273 samples; gold in 'answer_idx' but use open-ended, verifier with options['answer_idx']
medmcqa_dataset = load_dataset('openlifescienceai/medmcqa', split='validation') # 4183 samples; gold in 'cop' (choice index)
pubmedqa_dataset = load_dataset('pubmed_qa', 'pqa_labeled', split='train') # 1000 questions in the PQA-L are used as the test set
eval_datasets = {
    'MedQA_USLME_test': medqa_dataset.map(lambda x: format_mcq_prompt(x, 'medqa')),
    'MedMCQA_validation': medmcqa_dataset.map(lambda x: format_mcq_prompt(x, 'medmcqa')),
    'PubMedQA_test': pubmedqa_dataset.map(lambda x: format_mcq_prompt(x, 'pubmedqa')),
}

README.md:   0%|          | 0.00/654 [00:00<?, ?B/s]

phrases_no_exclude_train.jsonl:   0%|          | 0.00/16.2M [00:00<?, ?B/s]

phrases_no_exclude_test.jsonl: 0.00B [00:00, ?B/s]

Generating train split:   0%|          | 0/10178 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1273 [00:00<?, ? examples/s]

README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/85.9M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/936k [00:00<?, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/1.48M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/182822 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/6150 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/4183 [00:00<?, ? examples/s]

README.md: 0.00B [00:00, ?B/s]

pqa_labeled/train-00000-of-00001.parquet:   0%|          | 0.00/1.08M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1273 [00:00<?, ? examples/s]

Map:   0%|          | 0/4183 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [None]:
%%time
results = {}
for ds_name, dataset in eval_datasets.items():
    num_samples = len(dataset)
    print(f'\n{ds_name}: {num_samples} samples')

    test_texts = [
        tokenizer.apply_chat_template(sample['prompt'], add_generation_prompt=True, tokenize=False)
        for sample in dataset
    ]
    outputs_with_lora = model.fast_generate(
        test_texts, sampling_params=sampling_params,
        lora_request=model.load_lora(f'./{model_name}_grpo'),
    )
    outputs_without_lora = model.fast_generate(
        test_texts, sampling_params=sampling_params,
        lora_request=None,
    )

    # Compare the correct amount of using and not using LoRA
    no_lora_format_cnt = lora_format_cnt = 0
    no_lora_answer_cnt = lora_answer_cnt = 0
    no_lora_all_cnt = lora_all_cnt = 0

    for output_with_lora, output_without_lora, gt_answer in zip(outputs_with_lora, outputs_without_lora, dataset['answer']):
        # With LoRA
        response_lora = output_with_lora.outputs[0].text
        correct_format_lora = match_format.search(response_lora) is not None
        extracted_guess_lora = match_answer_letter.search(response_lora).group(1) if correct_format_lora else None
        correct_answer_lora = extracted_guess_lora == gt_answer
        correct_all_lora = correct_format_lora and correct_answer_lora

        if correct_format_lora: lora_format_cnt += 1
        if correct_answer_lora: lora_answer_cnt += 1
        if correct_all_lora: lora_all_cnt += 1

        # Without LoRA
        response_no_lora = output_without_lora.outputs[0].text
        correct_format_no_lora = match_format.search(response_no_lora) is not None
        extracted_guess_no_lora = match_answer_letter.search(response_no_lora).group(1) if correct_format_no_lora else None
        correct_answer_no_lora = extracted_guess_no_lora == gt_answer
        correct_all_no_lora = correct_format_no_lora and correct_answer_no_lora

        if correct_format_no_lora: no_lora_format_cnt += 1
        if correct_answer_no_lora: no_lora_answer_cnt += 1
        if correct_all_no_lora: no_lora_all_cnt += 1

    results[ds_name] = {
        'Without LoRA': {
            'Correct Format': f'{no_lora_format_cnt}/{num_samples} ({no_lora_format_cnt / num_samples * 100:.2f}%)',
            'Correct Answer': f'{no_lora_answer_cnt}/{num_samples} ({no_lora_answer_cnt / num_samples * 100:.2f}%)',
            'Correct Both': f'{no_lora_all_cnt}/{num_samples} ({no_lora_all_cnt / num_samples * 100:.2f}%)',
        },
        'With LoRA': {
            'Correct Format': f'{lora_format_cnt}/{num_samples} ({lora_format_cnt / num_samples * 100:.2f}%)',
            'Correct Answer': f'{lora_answer_cnt}/{num_samples} ({lora_answer_cnt / num_samples * 100:.2f}%)',
            'Correct Both': f'{lora_all_cnt}/{num_samples} ({lora_all_cnt / num_samples * 100:.2f}%)',
        },
        'Improvement': {
            'Correct Format': f'+{lora_format_cnt - no_lora_format_cnt} ({(lora_format_cnt - no_lora_format_cnt) / num_samples * 100:.2f}%)',
            'Correct Answer': f'+{lora_answer_cnt - no_lora_answer_cnt} ({(lora_answer_cnt - no_lora_answer_cnt) / num_samples * 100:.2f}%)',
            'Correct Both': f'+{lora_all_cnt - no_lora_all_cnt} ({(lora_all_cnt - no_lora_all_cnt) / num_samples * 100:.2f}%)',
        }
    }
    display(pd.DataFrame(results[ds_name]).T)


MedQA_USLME_test: 1273 samples


Adding requests:   0%|          | 0/1273 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1273 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s…

Adding requests:   0%|          | 0/1273 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1273 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s…

Unnamed: 0,Correct Format,Correct Answer,Correct Both
Without LoRA,646/1273 (50.75%),375/1273 (29.46%),375/1273 (29.46%)
With LoRA,1273/1273 (100.00%),629/1273 (49.41%),629/1273 (49.41%)
Improvement,+627 (49.25%),+254 (19.95%),+254 (19.95%)



MedMCQA_validation: 4183 samples


Adding requests:   0%|          | 0/4183 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/4183 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s…

Adding requests:   0%|          | 0/4183 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/4183 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s…

Unnamed: 0,Correct Format,Correct Answer,Correct Both
Without LoRA,3152/4183 (75.35%),1601/4183 (38.27%),1601/4183 (38.27%)
With LoRA,4183/4183 (100.00%),1927/4183 (46.07%),1927/4183 (46.07%)
Improvement,+1031 (24.65%),+326 (7.79%),+326 (7.79%)



PubMedQA_test: 1000 samples


Adding requests:   0%|          | 0/1000 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1000 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s…

Adding requests:   0%|          | 0/1000 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1000 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s…

# Inference

In [None]:
# # The FastLanguageModel.from_pretrained should be only called once. Otherwise, it will be OOM
# model, tokenizer = FastLanguageModel.from_pretrained(
#     model_name=f'./{model_name}_grpo', # Reload LoRA weights
#     max_seq_length=max_seq_length,
#     load_in_4bit=False,         # False for LoRA 16bit
#     fast_inference=True,        # Enable vLLM fast inference
#     max_lora_rank=lora_rank,
#     gpu_memory_utilization=0.8, # Reduce if out of memory
# )
# FastLanguageModel.for_inference(model)

In [37]:
def generate_with_reasoning(questions, max_completion_length=1024, system_prompt=SYSTEM_PROMPT,):
    conversations = [[ # Format input using conversation template
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': question},
    ] for question in questions]

    prompts = [tokenizer.apply_chat_template( # Apply chat template and tokenize
        conversation,
        add_generation_prompt=True,         # Add assistant prompt
        tokenize=False,                     # Return string, not tokens
    ) for conversation in conversations]

    # Generate response with reasoning-optimized parameters
    inputs = tokenizer(prompts, return_tensors='pt', padding=True).to(model.device)
    start_time = time.time()
    with torch.no_grad():
        output_ids = model.generate(           # Generate response with reasoning-optimized parameters
            **inputs,
            max_new_tokens=max_completion_length,
            temperature=0.6,                # Balance creativity and consistency
            top_p=0.95,                     # Nucleus sampling for quality
            top_k=20,
            do_sample=True,                 # Enable sampling for varied reasoning paths
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,         # Reduce repetitive reasoning steps
            length_penalty=1.0,             # Neutral preference for response length
            early_stopping=True,            # Stop at natural completion
            streamer=TextStreamer(tokenizer, skip_prompt=True),
        )
    end_time = time.time()
    inference_duration = end_time - start_time
    num_generated_tokens = output_ids.shape[1] - inputs['input_ids'].shape[1]

    output_ids = output_ids[:, inputs['input_ids'][0].shape[-1]:output_ids.shape[-1]]
    responses = tokenizer.batch_decode(output_ids, skip_special_tokens=True) # Decode and extract only the generated portion
    return responses, inference_duration, num_generated_tokens

In [38]:
test_dataset = eval_datasets['MedQA_USLME_test']
medical_question = test_dataset[2]['prompt'][-1]['content']
expected_answer = test_dataset[2]['answer']

print(medical_question, '\n\n===== Response =====')
medical_responses, inference_duration, num_generated_tokens = generate_with_reasoning([medical_question], max_completion_length)
medical_response = medical_responses[0]
print('Inference time (secs):', inference_duration)
print('Generated tokens:', num_generated_tokens)

Please answer the following multiple-choice question, ensuring your response only contain the correct option letter with no extra text:
Two weeks after undergoing an emergency cardiac catherization with stenting for unstable angina pectoris, a 61-year-old man has decreased urinary output and malaise. He has type 2 diabetes mellitus and osteoarthritis of the hips. Prior to admission, his medications were insulin and naproxen. He was also started on aspirin, clopidogrel, and metoprolol after the coronary intervention. His temperature is 38°C (100.4°F), pulse is 93/min, and blood pressure is 125/85 mm Hg. Examination shows mottled, reticulated purplish discoloration of the feet. Laboratory studies show:
Hemoglobin count 14 g/dL
Leukocyte count 16,400/mm3
Segmented neutrophils 56%
Eosinophils 11%
Lymphocytes 31%
Monocytes 2%
Platelet count 260,000/mm3
Erythrocyte sedimentation rate 68 mm/h
Serum
Urea nitrogen 25 mg/dL
Creatinine 4.2 mg/dL
Renal biopsy shows intravascular spindle-shaped vac

In [39]:
# Validate format compliance
has_answer = ANSWER_START in medical_response and ANSWER_END in medical_response
print('Reasoning section:', REASONING_END in medical_response)
print('Answer section:', has_answer)

if has_answer: # Check answer accuracy if answer section exists
    # answer_text = medical_response.split(ANSWER_START)[1].split(ANSWER_END)[0].strip()
    answer_text = match_answer_letter.search(medical_response).group(1).upper()
    print('Extracted:', answer_text)
    print('Expected:', expected_answer)
    print('Correct:', answer_text == expected_answer)

Reasoning section: True
Answer section: True
Extracted: B
Expected: B
Correct: True
