# Environment Setup

In [None]:
try: import numpy, PIL; get_numpy = f'numpy=={numpy.__version__}'; get_pil = f'pillow=={PIL.__version__}'
except: get_numpy = 'numpy'; get_pil = 'pillow'
try: import subprocess; is_t4 = 'Tesla T4' in str(subprocess.check_output(['nvidia-smi']))
except: is_t4 = False
get_vllm, get_triton = ('vllm==0.9.2', 'triton==3.2.0') if is_t4 else ('vllm==0.10.2', 'triton')
!uv pip install -qqq --upgrade unsloth {get_vllm} {get_numpy} {get_pil} torchvision bitsandbytes xformers
!uv pip install -qqq {get_triton}
!uv pip install flash-attn --no-build-isolation
!uv pip install transformers==4.56.2
!uv pip install --no-deps trl==0.22.2

In [None]:
from google.colab import drive
drive.mount('/content/drive/')
%cd /content/drive/MyDrive/RL/multi-reward-medical-reasoning
!ls

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).
/content/drive/.shortcut-targets-by-id/1UKjVaf_VMR_2xjW4oCOfinHhpq5hpQJx/RL/verifiable-medical-agent
grpo_trainer_lora_model       qwen3-1.7b-base_sft
huggingface_tokenizers_cache  Qwen3.ipynb
llama-3.2-1b_sft	      unsloth_compiled_cache
Llama3.ipynb		      unsloth_training_checkpoints
qwen3-1.7b-base_grpo	      wandb


In [None]:
from unsloth import FastLanguageModel, is_bfloat16_supported  # import unsloth first so it can patch trl/transformers
from trl import SFTConfig, GRPOConfig, SFTTrainer, GRPOTrainer
from vllm import SamplingParams

import gc
import re
import time
import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F
from safetensors import safe_open
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextStreamer

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.




INFO 10-29 10:21:52 [__init__.py:216] Automatically detected platform cuda.
Switching to PyTorch attention since your Xformers is broken.

Requires Flash-Attention version >=2.7.1,<=2.8.2 but got 2.8.3.
🦥 Unsloth Zoo will now patch everything to make training faster!


# Model Setup

In [None]:
# Model choice and core size knobs.
model_id = 'unsloth/Qwen3-1.7B-Base'              # Base (not instruction-tuned) checkpoint; the SFT stage teaches the tag format
model_name = model_id.rsplit('/', 1)[-1].lower()  # e.g. 'qwen3-1.7b-base'; used to name output directories
max_seq_length = 2048                             # Total token budget (prompt + completion); raise for longer reasoning traces
lora_rank = 32                                    # LoRA rank: higher = more adapter capacity, but slower

In [None]:
# Load the base model in 16-bit with the vLLM fast-inference engine attached.
base_load_kwargs = dict(
    model_name=model_id,
    max_seq_length=max_seq_length,
    load_in_4bit=False,          # 16-bit LoRA rather than 4-bit QLoRA
    fast_inference=True,         # enable the vLLM engine for fast generation
    max_lora_rank=lora_rank,     # largest LoRA rank the vLLM engine is configured for
    gpu_memory_utilization=0.8,  # fraction of VRAM handed to vLLM; reduce if out of memory
)
model, tokenizer = FastLanguageModel.from_pretrained(**base_load_kwargs)

# Wrap the model with LoRA adapters on every attention and MLP projection.
lora_kwargs = dict(
    r=lora_rank,                          # adapter rank (capacity)
    lora_alpha=lora_rank * 2,             # scaling factor, here 2x the rank
    # NOTE: Unsloth only applies its fast LoRA kernels with dropout 0 (see the load log);
    # 0.1 trades that speed for regularization.
    lora_dropout=0.1,
    target_modules=[                      # drop the q/k/v/o projections if out of memory
        'q_proj', 'k_proj', 'v_proj', 'o_proj',
        'gate_proj', 'up_proj', 'down_proj',
    ],
    use_gradient_checkpointing='unsloth', # Unsloth's memory-saving checkpointing
    random_state=2025,
)
model = FastLanguageModel.get_peft_model(model, **lora_kwargs)

INFO 10-29 10:22:01 [vllm_utils.py:694] Unsloth: Patching vLLM v1 graph capture
INFO 10-29 10:22:01 [vllm_utils.py:722] Unsloth: Patching vLLM v0 graph capture
==((====))==  Unsloth 2025.10.11: Fast Qwen3 patching. Transformers: 4.55.4. vLLM: 0.10.2.
   \\   /|    NVIDIA A100-SXM4-80GB. Num GPUs = 1. Max memory: 79.318 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu126. CUDA: 8.0. CUDA Toolkit: 12.6. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/Qwen3-1.7B-Base with actual GPU utilization = 79.54%
Unsloth: Your GPU has CUDA compute capability 8.0 with VRAM = 79.32 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 2048. Num Sequences = 368.
Unsloth: vLLM's KV Cache can use up to 59.82 GB. Also swap space = 6 GB.
Unsloth: Not an error, but `device` is not sup

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 10-29 10:22:33 [default_loader.py:268] Loading weights took 1.05 seconds
INFO 10-29 10:22:33 [punica_selector.py:19] Using PunicaWrapperGPU.
INFO 10-29 10:22:34 [gpu_model_runner.py:2392] Model loading took 3.2919 GiB and 2.278134 seconds
INFO 10-29 10:22:45 [backends.py:539] Using cache directory: /root/.cache/vllm/torch_compile_cache/fdfa5fae6b/rank_0_0/backbone for vLLM's torch.compile
INFO 10-29 10:22:45 [backends.py:550] Dynamo bytecode transform time: 10.61 s
INFO 10-29 10:22:51 [backends.py:161] Directly load the compiled graph(s) for dynamic shape from the cache, took 3.665 s
INFO 10-29 10:22:53 [monitor.py:34] torch.compile takes 10.61 s in total
INFO 10-29 10:22:54 [gpu_worker.py:298] Available KV cache memory: 57.77 GiB
INFO 10-29 10:22:55 [kv_cache_utils.py:864] GPU KV cache size: 540,864 tokens
INFO 10-29 10:22:55 [kv_cache_utils.py:868] Maximum concurrency for 2,048 tokens per request: 264.09x
INFO 10-29 10:22:55 [vllm_utils.py:699] Unsloth: Running patched vLLM v1 `

Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|██████████| 67/67 [00:09<00:00,  7.27it/s]
Capturing CUDA graphs (decode, FULL): 100%|██████████| 49/49 [00:06<00:00,  7.40it/s]

INFO 10-29 10:23:11 [gpu_model_runner.py:3118] Graph capturing finished in 16 secs, took 1.02 GiB
INFO 10-29 10:23:11 [vllm_utils.py:706] Unsloth: Patched vLLM v1 graph capture finished in 16 secs.





INFO 10-29 10:23:13 [gpu_worker.py:391] Free memory on device (78.79/79.32 GiB) on startup. Desired GPU memory utilization is (0.7954304147054094, 63.09 GiB). Actual usage is 3.29 GiB for weight, 2.01 GiB for peak activation, 0.02 GiB for non-torch memory, and 1.02 GiB for CUDAGraph memory. Replace gpu_memory_utilization config with `--kv-cache-memory=60776112844` to fit into requested memory, or `--kv-cache-memory=77632574976` to fully utilize gpu memory. Current kv cache memory in use is 62032306892 bytes.
INFO 10-29 10:23:13 [core.py:218] init engine (profile, create kv cache, warmup model) took 38.68 seconds
INFO 10-29 10:23:14 [llm.py:295] Supported_tasks: ('generate',)
INFO 10-29 10:23:14 [__init__.py:36] No IOProcessor plugins requested by the model
Unsloth: Just some info: will skip parsing ['post_layernorm', 'post_feedforward_layernorm', 'pre_feedforward_layernorm', 'layer_norm1', 'attention_norm', 'post_attention_layernorm', 'layer_norm2', 'norm2', 'input_layernorm', 'norm1',

Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.1.
Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.
Unsloth 2025.10.11 patched 28 layers with 0 QKV layers, 0 O layers and 0 MLP layers.


In [None]:
# Prompt template for the verifier (taken from its model card). Its three slots are
# filled with the model response, the reference answer and the verifier's EOS token.
VERIFIER_TEMPLATE = '''<Model Response>
{}
</Model Response>

<Reference Answer>
{}
</Reference Answer>

Your task is to evaluate the model response by comparing it to the reference answer. If the model response is correct and aligns with the reference answer, output "True". If it is incorrect or fails to select the correct option (if options are provided), output "False". {}'''

# Medical answer verifier, loaded as a 2-label sequence classifier. Its label-1 probability
# is used by the reward functions as "response agrees with the reference".
verifier_path = 'FreedomIntelligence/medical_o1_verifier_3B'
verifier_tokenizer = AutoTokenizer.from_pretrained(verifier_path)
verifier_model = AutoModelForSequenceClassification.from_pretrained(
    verifier_path,
    dtype='auto',
    device_map='auto',
    attn_implementation='flash_attention_2',
    num_labels=2,
)
verifier_model.eval()  # inference only: disables dropout

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

# Chat Template

In [None]:
# Tags that delimit the structured output for medical reasoning:
# reasoning goes inside <THINK>...</THINK>, the final answer inside <ANSWER>...</ANSWER>.
REASONING_START, REASONING_END = '<THINK>', '</THINK>'
ANSWER_START, ANSWER_END = '<ANSWER>', '</ANSWER>'

# System prompt instructing the model to use the tags above.
SYSTEM_PROMPT = '\n'.join([
    'You are a medical reasoning assistant. When given a medical problem:',
    f'1. Show your step-by-step complex reasoning (including reflection, backtracking, and alternative paths) between {REASONING_START} and {REASONING_END}.',
    f'2. Provide your final answer between {ANSWER_START} and {ANSWER_END}.',
    '3. Be precise, consider medical aliases/synonyms, and show all deliberation steps clearly.',
])
print(SYSTEM_PROMPT)

You are a medical reasoning assistant. When given a medical problem:
1. Show your step-by-step complex reasoning (including reflection, backtracking, and alternative paths) between <THINK> and </THINK>.
2. Provide your final answer between <ANSWER> and </ANSWER>.
3. Be precise, consider medical aliases/synonyms, and show all deliberation steps clearly.


In [None]:
def _jinja_string_literal(text):
    '''Return `text` as a single-quoted Jinja string literal.

    Backslashes and single quotes are escaped, so prompts containing apostrophes
    (e.g. "patient's") or backslashes do not break template parsing. Jinja decodes
    these escapes back to the original characters when it renders the template.
    '''
    return "'" + text.replace('\\', '\\\\').replace("'", "\\'") + "'"

chat_template = ( # Build the chat template; the quoted placeholders are substituted below
    # If the very first message is a SYSTEM role, print it + <eos>:
    "{% if messages[0]['role'] == 'system' %}"
      "{{ messages[0]['content'] + eos_token }}"
      "{% set loop_messages = messages[1:] %}"
    "{% else %}"
      # Otherwise, inject our system_prompt + <eos>:
      "{{ '{system_prompt}' + eos_token }}"
      "{% set loop_messages = messages %}"
    "{% endif %}"

    # Now loop over the remaining messages (either user or assistant):
    "{% for message in loop_messages %}"
      "{% if message['role'] == 'user' %}"
        "{{ message['content'] }}"
      "{% elif message['role'] == 'assistant' %}"
        "{{ message['content'] + eos_token }}"
      "{% endif %}"
    "{% endfor %}"

    # If add_generation_prompt is set, append REASONING_START so generation begins inside the reasoning section:
    "{% if add_generation_prompt %}{{ '{reasoning_start}' }}"
    "{% endif %}"
)
# Replace the placeholders with our specific, safely escaped, values:
tokenizer.chat_template = chat_template\
    .replace("'{system_prompt}'",   _jinja_string_literal(SYSTEM_PROMPT))\
    .replace("'{reasoning_start}'", _jinja_string_literal(REASONING_START))

In [None]:
# Sanity check of the chat template. The conversation has no system turn (so the default
# system prompt is injected), one completed assistant turn, and a trailing user turn
# rendered with the generation prompt.
demo_reasoning = (
    'Dengue fever can lead to severe forms like dengue hemorrhagic fever. '
    'Considering symptoms and progression, the most severe is plasma leakage leading to shock.'
)
demo_reply = f'{REASONING_START}{demo_reasoning}{REASONING_END}{ANSWER_START}Dengue shock syndrome{ANSWER_END}'
example_messages = [
    dict(role='user', content='What is the most severe complication of dengue fever?'),
    dict(role='assistant', content=demo_reply),
    dict(role='user', content='What drug is used for hypertension?'),
]
print(tokenizer.apply_chat_template(example_messages, tokenize=False, add_generation_prompt=True))

You are a medical reasoning assistant. When given a medical problem:
1. Show your step-by-step complex reasoning (including reflection, backtracking, and alternative paths) between <THINK> and </THINK>.
2. Provide your final answer between <ANSWER> and </ANSWER>.
3. Be precise, consider medical aliases/synonyms, and show all deliberation steps clearly.<|endoftext|>What is the most severe complication of dengue fever?<THINK>Dengue fever can lead to severe forms like dengue hemorrhagic fever. Considering symptoms and progression, the most severe is plasma leakage leading to shock.</THINK><ANSWER>Dengue shock syndrome</ANSWER><|endoftext|>What drug is used for hypertension?<THINK>


# Pre Fine-tuning (SFT)

## Data preparation

In [None]:
def format_dataset(x):
    '''Turn one medical-o1-reasoning-SFT row into a system/user/assistant conversation.

    Fields used: 'Question' (user turn), 'Complex_CoT' (reasoning, whitespace-stripped) and
    'Response' (final answer). The assistant turn wraps them in the reasoning/answer tags,
    so SFT teaches the same format the GRPO rewards check.
    '''
    expected_answer = x['Response']
    problem = x['Question']
    thoughts = x['Complex_CoT'].strip()
    assistant_text = ''.join((REASONING_START, thoughts, REASONING_END,
                              ANSWER_START, expected_answer, ANSWER_END))
    roles = ('system', 'user', 'assistant')
    contents = (SYSTEM_PROMPT, problem, assistant_text)
    return [{'role': role, 'content': content} for role, content in zip(roles, contents)]

# Load the English SFT split as a DataFrame and attach a chat-formatted `messages` column.
sft_dataset = (
    load_dataset('FreedomIntelligence/medical-o1-reasoning-SFT', 'en', split='train')
    .to_pandas()
)
sft_dataset['messages'] = sft_dataset.apply(format_dataset, axis=1)

# Show the first conversation rendered through the chat template.
print(tokenizer.apply_chat_template(sft_dataset.at[0, 'messages'], tokenize=False))
sft_dataset

README.md: 0.00B [00:00, ?B/s]

medical_o1_sft.json:   0%|          | 0.00/58.2M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/19704 [00:00<?, ? examples/s]

You are a medical reasoning assistant. When given a medical problem:
1. Show your step-by-step complex reasoning (including reflection, backtracking, and alternative paths) between <THINK> and </THINK>.
2. Provide your final answer between <ANSWER> and </ANSWER>.
3. Be precise, consider medical aliases/synonyms, and show all deliberation steps clearly.<|endoftext|>Given the symptoms of sudden weakness in the left arm and leg, recent long-distance travel, and the presence of swollen and tender right lower leg, what specific cardiac abnormality is most likely to be found upon further evaluation that could explain these findings?<THINK>Okay, let's see what's going on here. We've got sudden weakness in the person's left arm and leg - and that screams something neuro-related, maybe a stroke?

But wait, there's more. The right lower leg is swollen and tender, which is like waving a big flag for deep vein thrombosis, especially after a long flight or sitting around a lot.

So, now I'm thinkin

Unnamed: 0,Question,Complex_CoT,Response,messages
0,Given the symptoms of sudden weakness in the l...,"Okay, let's see what's going on here. We've go...",The specific cardiac abnormality most likely t...,"[{'role': 'system', 'content': 'You are a medi..."
1,A 33-year-old woman is brought to the emergenc...,"Okay, let's figure out what's going on here. A...","In this scenario, the most likely anatomical s...","[{'role': 'system', 'content': 'You are a medi..."
2,A 61-year-old woman with a long history of inv...,"Okay, let's think about this step by step. The...",Cystometry in this case of stress urinary inco...,"[{'role': 'system', 'content': 'You are a medi..."
3,A 45-year-old man with a history of alcohol us...,"Alright, let’s break this down. We have a 45-y...",Considering the clinical presentation of sudde...,"[{'role': 'system', 'content': 'You are a medi..."
4,A 45-year-old man presents with symptoms inclu...,"Okay, so here's a 45-year-old guy who's experi...",Based on the clinical findings presented—wide-...,"[{'role': 'system', 'content': 'You are a medi..."
...,...,...,...,...
19699,A 21-year-old man presents with painful urinat...,"Okay, let's think this through. We've got a 21...",The most likely infectious cause of the sympto...,"[{'role': 'system', 'content': 'You are a medi..."
19700,In a 7-month-old child diagnosed with H. influ...,"Alright, so we're dealing with a 7-month-old c...",In a 7-month-old child who has been diagnosed ...,"[{'role': 'system', 'content': 'You are a medi..."
19701,What is the treatment of choice for a 40-year-...,"Okay, so we have a 40-year-old woman, it's her...",The treatment of choice in this scenario is la...,"[{'role': 'system', 'content': 'You are a medi..."
19702,What is the most likely underlying mechanism c...,"Alright, so we've got a 25-year-old woman deal...",In a 25-year-old woman with systemic lupus ery...,"[{'role': 'system', 'content': 'You are a medi..."


In [None]:
# Drop (not truncate) SFT examples longer than half the context window, because we don't want
# the model to learn overly long reasoning traces.
sft_dataset['seq_length'] = sft_dataset['messages'].apply(lambda x: len(tokenizer.apply_chat_template(x)))
print('Token-length percentiles (50/90/99):', np.percentile(sft_dataset['seq_length'], [50, 90, 99]))

threshold = max_seq_length // 2  # integer token budget per example
sft_dataset_filtered = sft_dataset.loc[sft_dataset['seq_length'] <= threshold].copy()
print(f'Remaining for training (<= {threshold} tokens): {len(sft_dataset_filtered)}/{len(sft_dataset)}')

# Render every conversation to the plain text the SFT trainer consumes.
sft_dataset_filtered['text'] = tokenizer.apply_chat_template(sft_dataset_filtered['messages'].values.tolist(), tokenize=False)
# preserve_index=False: the filtered frame's index has gaps, which would otherwise be stored as a
# spurious '__index_level_0__' column in the HF dataset.
sft_dataset_filtered = Dataset.from_pandas(sft_dataset_filtered, preserve_index=False)
sft_dataset_filtered

Token-length percentiles (50/90/99): [ 686.    889.   1115.97]
Remaining for training (<= 1024.0 tokens): 19171/19704


Dataset({
    features: ['Question', 'Complex_CoT', 'Response', 'messages', 'seq_length', 'text', '__index_level_0__'],
    num_rows: 19171
})

## Pre fine-tune to understand custom GRPO formatting

In [None]:
# Supervised warm-up so the base model learns the reasoning/answer tag format before GRPO.
sft_config = SFTConfig(
    dataset_text_field='text',          # train on the fully rendered chat-template strings
    # Schedule: 3 epochs, effective batch 4 x 8 = 32 sequences per optimizer step
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    # Optimizer
    optim='adamw_8bit',
    weight_decay=0.01,
    learning_rate=2e-4,
    lr_scheduler_type='cosine',
    warmup_ratio=0.05,
    # Logging
    logging_steps=100,
    report_to='none',
)
trainer = SFTTrainer(
    model=model,
    # NOTE(review): recent TRL names this argument `processing_class`; `tokenizer=` is accepted
    # here (the training log shows the run), presumably via Unsloth's patched trainer.
    tokenizer=tokenizer,
    train_dataset=sft_dataset_filtered,
    args=sft_config,
)
trainer.train()
trainer.save_model(f'./{model_name}_sft')

Unsloth: Tokenizing ["text"] (num_proc=16):   0%|          | 0/19171 [00:00<?, ? examples/s]

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 19,171 | Num Epochs = 3 | Total steps = 1,800
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 8 x 1) = 32
 "-____-"     Trainable parameters = 34,865,152 of 1,755,440,128 (1.99% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss
100,1.4455
200,1.3135
300,1.3001
400,1.2882
500,1.2759
600,1.2758
700,1.2262
800,1.2227
900,1.2162
1000,1.2198


## Check if model has learnt to follow the format

In [None]:
# # The FastLanguageModel.from_pretrained should be only called once. Otherwise, it will be OOM
# model, tokenizer = FastLanguageModel.from_pretrained(
#     model_name=f'./{model_name}_sft', # Reload LoRA weights
#     max_seq_length=max_seq_length,
#     load_in_4bit=False,         # False for LoRA 16bit
#     fast_inference=True,        # Enable vLLM fast inference
#     max_lora_rank=lora_rank,
#     gpu_memory_utilization=0.8, # Reduce if out of memory
# )
# FastLanguageModel.for_inference(model)

In [None]:
# Render the system + user turns of one SFT example. add_generation_prompt=True appends
# REASONING_START, so the model continues from inside the reasoning section.
text = tokenizer.apply_chat_template(
    sft_dataset_filtered[1]['messages'][:2],
    tokenize=False, add_generation_prompt=True,
)
_ = model.generate(
    **tokenizer(text, return_tensors='pt').to('cuda'),
    # Explicit greedy decoding. `temperature=0` is not a greedy switch in HF `generate`: it is
    # ignored (with a warning) when not sampling and rejected when sampling is enabled.
    do_sample=False,
    max_new_tokens=1024,
    streamer=TextStreamer(tokenizer, skip_prompt=False),  # stream prompt + generated CoT/answer
)

You are a medical reasoning assistant. When given a medical problem:
1. Show your step-by-step complex reasoning (including reflection, backtracking, and alternative paths) between <THINK> and </THINK>.
2. Provide your final answer between <ANSWER> and </ANSWER>.
3. Be precise, consider medical aliases/synonyms, and show all deliberation steps clearly.<|endoftext|>A 33-year-old woman is brought to the emergency department 15 minutes after being stabbed in the chest with a screwdriver. Given her vital signs of pulse 110/min, respirations 22/min, and blood pressure 90/65 mm Hg, along with the presence of a 5-cm deep stab wound at the upper border of the 8th rib in the left midaxillary line, which anatomical structure in her chest is most likely to be injured?<THINK>Okay, so we have a 33-year-old woman who was stabbed in the chest with a screwdriver. That's pretty serious. Let's think about what's going on here. She's got a 5-cm deep stab wound at the upper border of the 8th rib in the le

In [None]:
# Release the SFT data before the RL stage: drop the names, run the garbage collector, then
# release PyTorch's cached but unused CUDA memory.
# NOTE(review): `trainer` still references the SFT training dataset (and its optimizer state),
# so little memory may actually be freed until `trainer` is rebound or deleted.
del sft_dataset
del sft_dataset_filtered
gc.collect()
torch.cuda.empty_cache()

# Post Fine-tuning (RL)

## Data preparation

In [None]:
def process_dataset_sample(example):
    '''Map one medical-o1-verifiable-problem row to the GRPO input format.

    Returns a dict with:
    - 'prompt': system + user conversation, where the user turn is the row's
      'Open-ended Verifiable Question';
    - 'answer': the row's 'Ground-True Answer' (whitespace-stripped text), used by the
      reward functions as the reference.
    '''
    question = example['Open-ended Verifiable Question']
    reference = example['Ground-True Answer'].strip()
    conversation = [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': question},
    ]
    return {'prompt': conversation, 'answer': reference}

In [None]:
# Load the verifiable medical problems used for RL and convert every row to prompt/answer form.
train_dataset = (
    load_dataset('FreedomIntelligence/medical-o1-verifiable-problem', split='train')
    .map(process_dataset_sample)
)
first_example = train_dataset[0]
print(f'Training samples: {len(train_dataset):,}\n'
      f"- Sample question: {first_example['prompt'][1]['content']}\n"
      f"- Sample answer (ground truth for rewards): {first_example['answer']}\n"
      f"- Prompt (system + user):\n{first_example['prompt']}")

Training samples: 40,644
- Sample question: An 88-year-old woman with osteoarthritis is experiencing mild epigastric discomfort and has vomited material resembling coffee grounds multiple times. Considering her use of naproxen, what is the most likely cause of her gastrointestinal blood loss?
- Sample answer (ground truth for rewards): Gastric ulcer
- Prompt (system + user):
[{'content': 'You are a medical reasoning assistant. When given a medical problem:\n1. Show your step-by-step complex reasoning (including reflection, backtracking, and alternative paths) between <THINK> and </THINK>.\n2. Provide your final answer between <ANSWER> and </ANSWER>.\n3. Be precise, consider medical aliases/synonyms, and show all deliberation steps clearly.', 'role': 'system'}, {'content': 'An 88-year-old woman with osteoarthritis is experiencing mild epigastric discomfort and has vomited material resembling coffee grounds multiple times. Considering her use of naproxen, what is the most likely cause of

In [None]:
# Measure each prompt's token length with the generation prompt appended, i.e. exactly as it
# is fed during GRPO. The 90th percentile becomes max_prompt_length, and longer prompts
# (roughly the top 10%) are dropped so the rest fit within that budget.
def _tokenize_prompts(batch):
    '''Tokenize a batch of prompts and record each token count.'''
    token_lists = tokenizer.apply_chat_template(batch['prompt'], add_generation_prompt=True, tokenize=True)
    return {'tokens': token_lists, 'length': [len(ids) for ids in token_lists]}

tokenized_dataset = train_dataset.map(_tokenize_prompts, batched=True)
print(tokenizer.decode(tokenized_dataset[0]['tokens']))

prompt_lengths = np.array(tokenized_dataset['length'])
thresholds = np.percentile(prompt_lengths, [50, 90, 99])
max_prompt_length = int(thresholds[1])
print('Token-length percentiles (50/90/99):', thresholds, '=> Choose max_prompt_length =', max_prompt_length)

# Keep only prompts no longer than the chosen budget
keep_indices = np.flatnonzero(prompt_lengths <= max_prompt_length)
train_dataset = train_dataset.select(keep_indices)
print(f'Remaining for training (<= {max_prompt_length} tokens): {len(train_dataset)}/{len(tokenized_dataset)}')
del tokenized_dataset

You are a medical reasoning assistant. When given a medical problem:
1. Show your step-by-step complex reasoning (including reflection, backtracking, and alternative paths) between <THINK> and </THINK>.
2. Provide your final answer between <ANSWER> and </ANSWER>.
3. Be precise, consider medical aliases/synonyms, and show all deliberation steps clearly.<|endoftext|>An 88-year-old woman with osteoarthritis is experiencing mild epigastric discomfort and has vomited material resembling coffee grounds multiple times. Considering her use of naproxen, what is the most likely cause of her gastrointestinal blood loss?<THINK>
Token-length percentiles (50/90/99): [116. 160. 204.] => Choose max_prompt_length = 160
Remaining for training (<= 160 tokens): 36664/40644


## Multi-reward design

In [None]:
# Strict output-format pattern. Generation starts after REASONING_START (the chat template
# appends it), so a well-formed completion looks like
#   <reasoning> REASONING_END ... ANSWER_START <answer> ANSWER_END [whitespace] [EOS] [whitespace]
# with nothing else after the answer. Group 1 captures the answer text.
match_format = re.compile(
    rf'{re.escape(REASONING_END)}.*?'
    rf'{re.escape(ANSWER_START)}(.+?){re.escape(ANSWER_END)}'
    rf'\s*(?:{re.escape(tokenizer.eos_token)})?'  # optional EOS token
    rf'\s*$',                                     # only whitespace may follow, up to end of string
    # DOTALL lets `.` span newlines in the reasoning/answer. MULTILINE is deliberately NOT used:
    # with it, `$` also matches before any inner newline, so extra lines of text after the
    # closing answer tag would still pass as "strict" format.
    flags=re.DOTALL,
)
def verify(guess, true_answer):
    '''Score a model answer against the reference with the medical verifier.

    Fills VERIFIER_TEMPLATE with the stripped guess, the stripped reference and the verifier
    tokenizer's EOS token, then runs one gradient-free forward pass. Returns the softmax
    probability of label 1 (used by the rewards as "correct") as a Python float.
    '''
    prompt = VERIFIER_TEMPLATE.format(guess.strip(), true_answer.strip(), verifier_tokenizer.eos_token)
    encoded = verifier_tokenizer([prompt], return_tensors='pt').to(verifier_model.device)
    with torch.no_grad():
        scores = verifier_model(**encoded, return_dict=True).logits
    return F.softmax(scores, dim=-1)[0, 1].item()

In [None]:
def match_format_strictly(completions, **kwargs) -> list[float]:
    ''' Reward Function 1: Exact Format Compliance
    Returns 3.0 for each completion whose text matches `match_format`, else 0.0,
    so the model is pushed toward the complete structured output pattern.
    '''
    def _score(completion):
        text = completion[0]['content']
        return 3.0 if match_format.search(text) else 0.0
    return list(map(_score, completions))

In [None]:
# If it fails, reward the model if it at least follows the format partially, by counting each symbol
def match_format_softly(completions, **kwargs) -> list[float]:
    ''' Reward Function 2: Partial Format Credit
    Graduated scoring for format elements. Encourage learning individual components even if not perfect
    '''
    rewards = []
    for completion in completions:
        reward = 0
        response = completion[0]['content']

        # Count how many keywords are seen - we penalize if too many!
        # Award +0.5 for correct token count, -0.5 for wrong count
        # reward += 0.5 if response.count(REASONING_START) == 1 else -0.5  # Prepended
        reward += 0.5 if response.count(REASONING_END) == 1 else -0.5
        reward += 0.5 if response.count(ANSWER_START) == 1 else -0.5
        reward += 0.5 if response.count(ANSWER_END) == 1 else -0.5
        rewards.append(reward)
    return rewards

In [None]:
# Extract the generated answer and reward or penalize it using the verifier's correctness probability
def check_answer_correctness(completions, answer, **kwargs) -> list[float]:
    ''' Reward Function 3: graded medical-accuracy reward from the verifier.

    The answer is capture group 1 of `match_format`. It is scored by `verify`, which gives
    the probability p that it agrees with the matching reference in `answer`:
    - -2.0: no answer could be extracted (format not matched)
    -  5.0: p > 0.9
    -  3.5: 0.7 < p <= 0.9
    -  2.0: 0.5 < p <= 0.7
    -  1.5: 0.3 < p <= 0.5   NOTE(review): positive reward although the verifier leans "False"; confirm intended
    - -2.5: p <= 0.3
    Semantic verification lets aliases/synonyms of the reference count as correct.
    '''
    tiers = ((0.9, 5.0), (0.7, 3.5), (0.5, 2.0), (0.3, 1.5))  # (exclusive lower bound, reward), high to low
    matches = [match_format.search(c[0]['content']) for c in completions]
    scores = []
    for found, reference in zip(matches, answer):
        if found is None:  # no extractable answer
            scores.append(-2.0)
            continue
        prob_true = verify(found.group(1), reference)
        scores.append(next((reward for bound, reward in tiers if prob_true > bound), -2.5))
    return scores

## GRPO training setup

In [None]:
# Give the prompt budget one extra token of slack (presumably headroom over the prompt
# length measured earlier - TODO confirm). Completions get whatever is left of max_seq_length.
# Guarded so that re-running this cell does not keep growing the budget. Re-running the
# cell that sets max_prompt_length resets it, and the +1 is applied again.
if globals().get('_padded_max_prompt_length') != max_prompt_length:
    max_prompt_length += 1
    _padded_max_prompt_length = max_prompt_length
max_completion_length = max_seq_length - max_prompt_length

# vLLM rollout sampling for GRPO. min_p=0.1 drops tokens whose probability is below 10% of
# the most likely token; top_p=1.0 / top_k=-1 leave those filters disabled. Temperature is
# set in GRPOConfig. Generation stops at the EOS token and keeps it in the output text.
vllm_sampling_params = SamplingParams(
    min_p = 0.1,
    top_p = 1.0,
    top_k = -1,
    stop = [tokenizer.eos_token],
    include_stop_str_in_output = True,
)

In [None]:
training_args = GRPOConfig(          # GRPO hyper-parameters for medical-reasoning fine-tuning
    output_dir=f'/tmp/{model_name}', # Checkpoints and logs (local /tmp, not the mounted Drive)
    vllm_sampling_params=vllm_sampling_params,  # Rollout sampling settings from the previous cell
    # Training length & batch size
    num_train_epochs=1,              # Single pass over the training set
    per_device_train_batch_size=4,   # Completions per micro-batch (one prompt x num_generations)
    gradient_accumulation_steps=16,  # 4 x 16 = 64 completions per optimizer step = 16 prompts x 4 generations
    # Advantage & loss
    scale_rewards='batch',           # Centre rewards on the group mean, scale by the std over the whole batch
    loss_type='dr_grpo',             # Dr. GRPO: divide by a constant instead of each sequence's length (removes length bias)
    # Optimisation
    optim='adamw_8bit',              # 8-bit AdamW optimizer states to save GPU memory
    weight_decay=0.1,                # Regularization
    max_grad_norm=0.1,               # Aggressive gradient clipping for stable training
    bf16=is_bfloat16_supported(),    # bf16 mixed precision only when the GPU supports it
    learning_rate=1e-5,              # Conservative LR to avoid destabilising the model's reasoning
    lr_scheduler_type='cosine',
    warmup_ratio=0.05,               # Warm up over the first 5% of steps
    # Generation control
    temperature=1.0,                 # Full-temperature sampling for exploration during rollouts
    num_generations=4,               # Completions sampled per prompt (the GRPO group size)
    max_prompt_length=max_prompt_length,
    max_completion_length=max_completion_length,
    # Reporting and saving
    report_to='wandb',
    logging_steps=79,                # NOTE(review): unexplained magic number - document why 79
    logging_strategy='steps',
    save_total_limit=1,              # Keep at most one checkpoint on disk
    # max_steps=100,                 # Uncomment for a short smoke-test run
)

## Train the model

In [None]:
%%time
trainer = GRPOTrainer(            # Initialize GRPO trainer with multi-reward system (improvement over paper's single-reward PPO)
    model=model,                  # LoRA-adapted quantized model
    processing_class=tokenizer,
    train_dataset=train_dataset,  # Processed medical dataset
    args=training_args,           # Training configuration
    reward_funcs=[                # 3 complementary reward functions
        match_format_strictly,    # Perfect structure compliance
        match_format_softly,      # Partial format credit
        check_answer_correctness, # Semantic accuracy via verifier (handles aliases)
    ]
)
trainer.train()
trainer.save_model(f'./{model_name}_grpo')

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 36,664 | Num Epochs = 1 | Total steps = 2,291
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 16
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 16 x 1) = 64
 "-____-"     Trainable parameters = 34,865,152 of 1,755,440,128 (1.99% trained)
  | |_| | '_ \/ _` / _` |  _/ -_)


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33m18520339[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Detected [huggingface_hub.inference, openai] in use.
[34m[1mwandb[0m: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
[34m[1mwandb[0m: For more information, check out the docs at: https://weave-docs.wandb.ai/


Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,kl,rewards / match_format_strictly / mean,rewards / match_format_strictly / std,rewards / match_format_softly / mean,rewards / match_format_softly / std,rewards / check_answer_correctness / mean,rewards / check_answer_correctness / std
79,0.0301,3.860759,3.001184,522.364122,340.708861,819.64557,0.000198,522.095166,340.708861,813.265823,0.529842,2.999407,0.004747,1.499407,0.004747,-0.638054,2.999443
158,0.0235,3.864023,2.98145,513.598299,331.797468,810.658228,0.000198,513.327259,331.797468,797.316456,0.526748,2.998813,0.009494,1.499209,0.006329,-0.633999,2.978941
237,0.0265,4.013944,3.106013,486.568829,318.974684,763.037975,0.0,486.568829,318.974684,763.037975,0.533123,3.0,0.0,1.5,0.0,-0.486056,3.106013
316,0.0223,4.090882,3.10523,461.945807,303.063291,701.544304,0.0,461.945807,303.063291,701.544304,0.503646,3.0,0.0,1.5,0.0,-0.409118,3.10523
395,0.0148,4.108386,3.090155,439.083267,282.025316,676.898734,0.0,439.083267,282.025316,676.898734,0.504607,3.0,0.0,1.5,0.0,-0.391614,3.090155
474,0.0176,4.03392,3.080857,438.765229,286.0,677.367089,0.0,438.765229,286.0,677.367089,0.500037,3.0,0.0,1.5,0.0,-0.46608,3.080857
553,0.0185,4.293908,3.184084,415.195214,263.050633,638.455696,0.0,415.195214,263.050633,638.455696,0.484424,2.999407,0.004747,1.499802,0.001582,-0.205301,3.183212
632,0.0185,4.168315,3.115968,398.099288,256.367089,615.0,0.0,398.099288,256.367089,615.0,0.470487,3.0,0.0,1.5,0.0,-0.331685,3.115968
711,0.017,4.258801,3.1739,404.322785,258.974684,639.506329,0.000198,404.030406,258.974684,625.189873,0.467929,2.999407,0.004747,1.499407,0.004747,-0.240012,3.172238
790,0.0222,4.29371,3.186585,387.652888,243.493671,597.518987,0.0,387.652888,243.493671,597.518987,0.488015,3.0,0.0,1.5,0.0,-0.20629,3.186585


Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,kl,rewards / match_format_strictly / mean,rewards / match_format_strictly / std,rewards / match_format_softly / mean,rewards / match_format_softly / std,rewards / check_answer_correctness / mean,rewards / check_answer_correctness / std
79,0.0301,3.860759,3.001184,522.364122,340.708861,819.64557,0.000198,522.095166,340.708861,813.265823,0.529842,2.999407,0.004747,1.499407,0.004747,-0.638054,2.999443
158,0.0235,3.864023,2.98145,513.598299,331.797468,810.658228,0.000198,513.327259,331.797468,797.316456,0.526748,2.998813,0.009494,1.499209,0.006329,-0.633999,2.978941
237,0.0265,4.013944,3.106013,486.568829,318.974684,763.037975,0.0,486.568829,318.974684,763.037975,0.533123,3.0,0.0,1.5,0.0,-0.486056,3.106013
316,0.0223,4.090882,3.10523,461.945807,303.063291,701.544304,0.0,461.945807,303.063291,701.544304,0.503646,3.0,0.0,1.5,0.0,-0.409118,3.10523
395,0.0148,4.108386,3.090155,439.083267,282.025316,676.898734,0.0,439.083267,282.025316,676.898734,0.504607,3.0,0.0,1.5,0.0,-0.391614,3.090155
474,0.0176,4.03392,3.080857,438.765229,286.0,677.367089,0.0,438.765229,286.0,677.367089,0.500037,3.0,0.0,1.5,0.0,-0.46608,3.080857
553,0.0185,4.293908,3.184084,415.195214,263.050633,638.455696,0.0,415.195214,263.050633,638.455696,0.484424,2.999407,0.004747,1.499802,0.001582,-0.205301,3.183212
632,0.0185,4.168315,3.115968,398.099288,256.367089,615.0,0.0,398.099288,256.367089,615.0,0.470487,3.0,0.0,1.5,0.0,-0.331685,3.115968
711,0.017,4.258801,3.1739,404.322785,258.974684,639.506329,0.000198,404.030406,258.974684,625.189873,0.467929,2.999407,0.004747,1.499407,0.004747,-0.240012,3.172238
790,0.0222,4.29371,3.186585,387.652888,243.493671,597.518987,0.0,387.652888,243.493671,597.518987,0.488015,3.0,0.0,1.5,0.0,-0.20629,3.186585


CPU times: user 11h 5min 5s, sys: 1min 55s, total: 11h 7min
Wall time: 11h 3min 4s


In [None]:
# Free memory before evaluation: first collect unreachable Python objects, then hand
# PyTorch's cached-but-unused CUDA blocks back to the driver.
# NOTE(review): `trainer` is still referenced, so anything it holds is not released here.
for release_memory in (gc.collect, torch.cuda.empty_cache):
    release_memory()

# Evaluation

In [None]:
# Evaluation decoding settings, following the official recommendations linked here:
# https://docs.unsloth.ai/models/qwen3-how-to-run-and-fine-tune#official-recommended-settings
sampling_params = SamplingParams(
    max_tokens = max_completion_length,  # Same completion budget as during GRPO training
    temperature = 0.6,
    top_p = 0.95,
    top_k = 20,
    min_p = 0.0,
)
# Regex to extract the chosen option letter (group 1) after ANSWER_START.
# A plain `.*?([A-D])` grabs the first capital A-D, for example the "A" of "Answer: C".
# Instead, prefer the first *standalone* option letter (a word boundary on both sides).
# Only when no standalone letter exists, fall back to the first capital A-D, so .search()
# still succeeds on exactly the responses it matched before.
# ANSWER_START is escaped so regex metacharacters in the tag are taken literally.
match_answer_letter = re.compile(
    re.escape(ANSWER_START) + r'''
    (?:
        (?=.*?\b[A-D]\b) .*?\b(?=[A-D]\b)   # a standalone letter follows: skip to the first one
      | (?!.*?\b[A-D]\b) .*?                # none does: fall back to the first capital A-D
    )
    ([A-D])
    ''',
    flags=re.MULTILINE | re.DOTALL | re.VERBOSE,
)

## Influence of LoRA

In [None]:
example_text = 'What drug is used for hypertension?'
# Baseline: the raw question (no chat template, no system prompt) sent to the base
# weights only - lora_request=None means no GRPO adapter is applied.
print(
    model.fast_generate(example_text, sampling_params=sampling_params, lora_request=None)
    [0].outputs[0].text
)

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 A: Nifedipine B: Clonidine C: Atenolol D: Losartan
The correct answer is:

**D: Losartan**

**Losartan** is an angiotensin II receptor blocker (ARB) commonly used to treat hypertension (high blood pressure). It works by blocking the action of angiotensin II, a hormone that causes blood vessels to narrow, thereby lowering blood pressure.

The other options are also used in hypertension treatment but fall into different classes:

- **A: Nifedipine** is a calcium channel blocker.
- **B: Clonidine** is a centrally acting alpha-2 adrenergic agonist.
- **C: Atenolol** is a beta-blocker.


In [None]:
# Sanity-check the saved adapter: no LoRA tensor (A or B) may be entirely zero.
# An all-zero tensor would mean training never moved that weight away from zero.
with safe_open(f'./{model_name}_grpo/adapter_model.safetensors', framework='pt') as f:
    for key in f.keys():
        tensor = f.get_tensor(key)
        num_zeros = (tensor == 0).sum().item()
        # Compare a count with a count: the tensor must have at least one non-zero entry
        assert num_zeros < tensor.numel(), f'LoRA tensor {key} is entirely zero'

In [None]:
# Load the GRPO LoRA and prompt WITHOUT the system prompt. Ideally the adapter leaves the
# model's original reasoning on plain chat prompts unchanged, or changes it only minimally.
text = tokenizer.apply_chat_template(
    [{'role': 'user', 'content': example_text}],
    tokenize=False,
    add_generation_prompt=True,
)
print(
    model.fast_generate(
        text,
        sampling_params=sampling_params,
        lora_request=model.load_lora(f'./{model_name}_grpo'),
    )[0].outputs[0].text
)

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Alright, let's think about how hypertension, or high blood pressure, is managed. It's a condition where the blood pressure is consistently too high, and it can lead to serious health issues if not properly controlled. So, the main goal is to reduce this pressure to keep the heart, brain, and other organs healthy.

The first thing to do is to identify the cause of the high blood pressure. Sometimes, it's due to lifestyle factors like diet and exercise, and sometimes, it's related to underlying conditions like diabetes or kidney problems. If the cause is medication-related, we'll need to adjust the treatment plan accordingly.

Now, when it comes to medications, there are several types of drugs that are commonly used. These include:

1. **ACE inhibitors** – these work by preventing the formation of angiotensin II, a hormone that narrows blood vessels, thus helping to lower blood pressure.

2. **Beta-blockers** – they help to reduce the heart rate and the force of the heartbeat, which can 

In [None]:
# Same question, now wrapped with the training-time system prompt
text = tokenizer.apply_chat_template([
    {'role': 'system', 'content': SYSTEM_PROMPT},
    {'role': 'user'  , 'content': example_text},
], add_generation_prompt=True, tokenize=False)

# 1) Base weights only: no LoRA adapter
print(model.fast_generate(
    text, sampling_params=sampling_params,
    lora_request=None,
)[0].outputs[0].text)

# 2) Same prompt with the GRPO LoRA adapter. A single sampled answer is anecdotal and not
# always correct; the benchmark section gives aggregate numbers. A longer max sequence
# length and more training steps may improve it further.
print(model.fast_generate(
    text, sampling_params=sampling_params,
    lora_request=model.load_lora(f'./{model_name}_grpo'),
)[0].outputs[0].text)

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 
To determine the appropriate drug for hypertension, we need to consider several factors including the patient's age, weight, gender, comorbidities, and specific blood pressure readings. Here’s a step-by-step approach:

1. **Initial Assessment**: Gather detailed information about the patient's medical history, current medications, and any previous treatments for hypertension. This includes checking for any contraindications or interactions with other drugs.
2. **Blood Pressure Monitoring**: Measure the patient’s blood pressure regularly to establish a baseline. This helps in assessing the severity of hypertension and guiding treatment decisions.
3. **Risk Factors**: Evaluate the patient’s risk factors such as age, family history, smoking status, alcohol consumption, and physical activity levels. These factors influence the choice of antihypertensive therapy.
4. **Comorbid Conditions**: Consider any coexisting conditions like diabetes, kidney disease, heart disease, or sleep apnea, whi

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Okay, let's think about hypertension. It's a condition where blood pressure is consistently too high. That's not good because it can lead to serious health issues like heart disease and stroke. So, managing it is crucial.

Now, there are several drugs used to treat hypertension, known as antihypertensives. These include things like diuretics, ACE inhibitors, beta-blockers, calcium channel blockers, and angiotensin II receptor blockers.

Each of these drugs works in different ways to lower blood pressure. For example, diuretics help by getting rid of excess sodium and water in the body, reducing blood volume. ACE inhibitors and ARBs help relax blood vessels by preventing the angiotensin II from causing constriction.

Choosing the right medication depends on various factors, like the patient's age, other health conditions, and any medications they might be taking. So, it's not just about one drug but often involves a combination of medications tailored to the individual.

In conclusion, 

## Performance on benchmark datasets

In [None]:
def format_mcq_prompt(example, dataset_name):
    '''Turn one benchmark row into a chat prompt plus its gold option letter.

    Args:
        example: one dataset row. Fields read per benchmark:
            'medqa'    -> 'question', 'options' (letter -> text mapping), 'answer_idx'
            'medmcqa'  -> 'question', 'opa'..'opd', 'cop' (0-based index of the correct option)
            'pubmedqa' -> 'question', 'context'['contexts'], 'final_decision'
                          ('yes' -> A, 'no' -> B, anything else -> C)
        dataset_name: one of 'medqa', 'medmcqa', 'pubmedqa'.

    Returns:
        dict with 'prompt' ([system, user] chat messages) and 'answer' (stripped gold letter).

    Raises:
        ValueError: if dataset_name is not a supported benchmark.
    '''
    user_content = 'Please answer the following multiple-choice question, ensuring your response only '
    user_content += 'contain the correct option letter with no extra text:\n' + example['question'] + '\n'

    if dataset_name == 'medqa':
        user_content += '\n'.join([f'{k}. {v}' for k, v in example['options'].items()])
        gold_letter = example['answer_idx']
    elif dataset_name == 'medmcqa':
        choices = [example['opa'], example['opb'], example['opc'], example['opd']]
        user_content += '\n'.join([f'{letter}. {choice}' for letter, choice in zip('ABCD', choices)])
        gold_letter = chr(ord('A') + example['cop'])
    elif dataset_name == 'pubmedqa':
        contexts = ' '.join(example['context']['contexts'])  # Join the context passages
        user_content = f'Context: {contexts}\n{user_content}A. yes\nB. no\nC. maybe'
        decision = example['final_decision']
        gold_letter = 'A' if decision == 'yes' else 'B' if decision == 'no' else 'C'
    else:
        # Previously fell through to an UnboundLocalError on gold_letter
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; expected 'medqa', 'medmcqa' or 'pubmedqa'"
        )
    return {
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': user_content},
        ],
        'answer': gold_letter.strip()
    }

In [None]:
# Evaluation benchmarks as in the paper: MedQA-USMLE (test), MedMCQA (validation), PubMedQA (PQA-L).
# format_mcq_prompt turns every row into a chat prompt plus a gold option letter. The
# evaluation loop compares that letter with the one extracted from the model's answer.
medqa_dataset = load_dataset('GBaker/MedQA-USMLE-4-options', split='test') # 1273 samples; options in 'options', gold letter in 'answer_idx'
medmcqa_dataset = load_dataset('openlifescienceai/medmcqa', split='validation') # 4183 samples; gold in 'cop' (0-based choice index)
pubmedqa_dataset = load_dataset('pubmed_qa', 'pqa_labeled', split='train') # 1000 questions in the PQA-L are used as the test set
eval_datasets = {  # NOTE: key spelled 'USLME' (sic) - later cells look it up by this exact name
    'MedQA_USLME_test': medqa_dataset.map(lambda x: format_mcq_prompt(x, 'medqa')),
    'MedMCQA_validation': medmcqa_dataset.map(lambda x: format_mcq_prompt(x, 'medmcqa')),
    'PubMedQA_test': pubmedqa_dataset.map(lambda x: format_mcq_prompt(x, 'pubmedqa')),
}

Map:   0%|          | 0/1273 [00:00<?, ? examples/s]

Map:   0%|          | 0/4183 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [None]:
%%time
results = {}
for ds_name, dataset in eval_datasets.items():
    num_samples = len(dataset)
    print(f'\n{ds_name}: {num_samples} samples')

    test_texts = [
        tokenizer.apply_chat_template(sample['prompt'], add_generation_prompt=True, tokenize=False)
        for sample in dataset
    ]
    outputs_with_lora = model.fast_generate(
        test_texts, sampling_params=sampling_params,
        lora_request=model.load_lora(f'./{model_name}_grpo'),
    )
    outputs_without_lora = model.fast_generate(
        test_texts, sampling_params=sampling_params,
        lora_request=None,
    )

    # Compare the correct amount of using and not using LoRA
    no_lora_format_cnt = lora_format_cnt = 0
    no_lora_answer_cnt = lora_answer_cnt = 0
    no_lora_all_cnt = lora_all_cnt = 0

    for output_with_lora, output_without_lora, gt_answer in zip(outputs_with_lora, outputs_without_lora, dataset['answer']):
        # With LoRA
        response_lora = output_with_lora.outputs[0].text
        correct_format_lora = match_format.search(response_lora) is not None
        extracted_guess_lora = match_answer_letter.search(response_lora).group(1) if correct_format_lora else None
        correct_answer_lora = extracted_guess_lora == gt_answer
        correct_all_lora = correct_format_lora and correct_answer_lora

        if correct_format_lora: lora_format_cnt += 1
        if correct_answer_lora: lora_answer_cnt += 1
        if correct_all_lora: lora_all_cnt += 1

        # Without LoRA
        response_no_lora = output_without_lora.outputs[0].text
        correct_format_no_lora = match_format.search(response_no_lora) is not None
        extracted_guess_no_lora = match_answer_letter.search(response_no_lora).group(1) if correct_format_no_lora else None
        correct_answer_no_lora = extracted_guess_no_lora == gt_answer
        correct_all_no_lora = correct_format_no_lora and correct_answer_no_lora

        if correct_format_no_lora: no_lora_format_cnt += 1
        if correct_answer_no_lora: no_lora_answer_cnt += 1
        if correct_all_no_lora: no_lora_all_cnt += 1

    results[ds_name] = {
        'Without LoRA': {
            'Correct Format': f'{no_lora_format_cnt}/{num_samples} ({no_lora_format_cnt / num_samples * 100:.2f}%)',
            'Correct Answer': f'{no_lora_answer_cnt}/{num_samples} ({no_lora_answer_cnt / num_samples * 100:.2f}%)',
            'Correct Both': f'{no_lora_all_cnt}/{num_samples} ({no_lora_all_cnt / num_samples * 100:.2f}%)',
        },
        'With LoRA': {
            'Correct Format': f'{lora_format_cnt}/{num_samples} ({lora_format_cnt / num_samples * 100:.2f}%)',
            'Correct Answer': f'{lora_answer_cnt}/{num_samples} ({lora_answer_cnt / num_samples * 100:.2f}%)',
            'Correct Both': f'{lora_all_cnt}/{num_samples} ({lora_all_cnt / num_samples * 100:.2f}%)',
        },
        'Improvement': {
            'Correct Format': f'+{lora_format_cnt - no_lora_format_cnt} ({(lora_format_cnt - no_lora_format_cnt) / num_samples * 100:.2f}%)',
            'Correct Answer': f'+{lora_answer_cnt - no_lora_answer_cnt} ({(lora_answer_cnt - no_lora_answer_cnt) / num_samples * 100:.2f}%)',
            'Correct Both': f'+{lora_all_cnt - no_lora_all_cnt} ({(lora_all_cnt - no_lora_all_cnt) / num_samples * 100:.2f}%)',
        }
    }
    display(pd.DataFrame(results[ds_name]).T)


MedQA_USLME_test: 1273 samples


Adding requests:   0%|          | 0/1273 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1273 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s…

Adding requests:   0%|          | 0/1273 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1273 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s…

Unnamed: 0,Correct Format,Correct Answer,Correct Both
Without LoRA,531/1273 (41.71%),239/1273 (18.77%),239/1273 (18.77%)
With LoRA,1271/1273 (99.84%),570/1273 (44.78%),570/1273 (44.78%)
Improvement,+740 (58.13%),+331 (26.00%),+331 (26.00%)



MedMCQA_validation: 4183 samples


Adding requests:   0%|          | 0/4183 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/4183 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s…

Adding requests:   0%|          | 0/4183 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/4183 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s…

Unnamed: 0,Correct Format,Correct Answer,Correct Both
Without LoRA,1543/4183 (36.89%),655/4183 (15.66%),655/4183 (15.66%)
With LoRA,4180/4183 (99.93%),1858/4183 (44.42%),1858/4183 (44.42%)
Improvement,+2637 (63.04%),+1203 (28.76%),+1203 (28.76%)



PubMedQA_test: 1000 samples


Adding requests:   0%|          | 0/1000 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1000 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s…

Adding requests:   0%|          | 0/1000 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1000 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s…

Unnamed: 0,Correct Format,Correct Answer,Correct Both
Without LoRA,415/1000 (41.50%),280/1000 (28.00%),280/1000 (28.00%)
With LoRA,1000/1000 (100.00%),653/1000 (65.30%),653/1000 (65.30%)
Improvement,+585 (58.50%),+373 (37.30%),+373 (37.30%)


CPU times: user 11min 43s, sys: 3.07 s, total: 11min 46s
Wall time: 11min 39s


# Inference

In [None]:
# # The FastLanguageModel.from_pretrained should be only called once. Otherwise, it will be OOM
# model, tokenizer = FastLanguageModel.from_pretrained(
#     model_name=f'./{model_name}_grpo', # Reload LoRA weights
#     max_seq_length=max_seq_length,
#     load_in_4bit=False,         # False for LoRA 16bit
#     fast_inference=True,        # Enable vLLM fast inference
#     max_lora_rank=lora_rank,
#     gpu_memory_utilization=0.8, # Reduce if out of memory
# )
# FastLanguageModel.for_inference(model)

In [None]:
def generate_with_reasoning(questions, max_completion_length=1024, system_prompt=SYSTEM_PROMPT):
    '''Answer `questions` with the Hugging Face `model.generate` API.

    Each question becomes a [system, user] chat rendered with the tokenizer's chat template.
    When exactly one question is given, tokens are streamed to stdout as they are produced.

    Args:
        questions: list of user question strings.
        max_completion_length: maximum number of new tokens per answer.
        system_prompt: system message placed before every question.

    Returns:
        tuple (responses, inference_duration, num_generated_tokens):
            responses: decoded completions (prompt removed, special tokens skipped).
            inference_duration: wall-clock seconds spent inside `model.generate`.
            num_generated_tokens: positions appended after the padded prompt. For a batch
                this is the longest completion, including padding of shorter ones.
    '''
    conversations = [
        [
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': question},
        ]
        for question in questions
    ]
    prompts = [
        tokenizer.apply_chat_template(
            conversation,
            add_generation_prompt=True,  # Append the assistant header so the model starts answering
            tokenize=False,              # Return strings; they are tokenised as one batch below
        )
        for conversation in conversations
    ]

    # The chat template already contains the model's special tokens, so don't add them twice.
    # Decoder-only models need left padding for batched generation, so that every prompt
    # ends exactly where its completion begins.
    inputs = tokenizer(
        prompts, return_tensors='pt', padding=True,
        padding_side='left', add_special_tokens=False,
    ).to(model.device)
    prompt_length = inputs['input_ids'].shape[1]

    # TextStreamer only supports batch size 1, so stream only single-question calls
    streamer = TextStreamer(tokenizer, skip_prompt=True) if len(questions) == 1 else None

    start_time = time.time()
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_completion_length,
            temperature=0.6,                # Balance creativity and consistency
            top_p=0.95,                     # Nucleus sampling for quality
            top_k=20,
            do_sample=True,                 # Sampling (num_beams=1), so beam-only options like
                                            # length_penalty / early_stopping would have no effect
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,         # Reduce repetitive reasoning steps
            streamer=streamer,
        )
    inference_duration = time.time() - start_time
    num_generated_tokens = output_ids.shape[1] - prompt_length

    # Decode only the newly generated part of every row
    responses = tokenizer.batch_decode(output_ids[:, prompt_length:], skip_special_tokens=True)
    return responses, inference_duration, num_generated_tokens

In [None]:
# Pick one MedQA test item (index 2) and stream the model's answer to it
test_dataset = eval_datasets['MedQA_USLME_test']
sample_row = test_dataset[2]
medical_question = sample_row['prompt'][-1]['content']  # The user turn (index 0 is the system prompt)
expected_answer = sample_row['answer']

print(medical_question, '\n\n===== Response =====')
medical_responses, inference_duration, num_generated_tokens = generate_with_reasoning(
    [medical_question], max_completion_length
)
medical_response = medical_responses[0]
print('Inference time (secs):', inference_duration)
print('Generated tokens:', num_generated_tokens)

Please answer the following multiple-choice question, ensuring your response only contain the correct option letter with no extra text:
Two weeks after undergoing an emergency cardiac catherization with stenting for unstable angina pectoris, a 61-year-old man has decreased urinary output and malaise. He has type 2 diabetes mellitus and osteoarthritis of the hips. Prior to admission, his medications were insulin and naproxen. He was also started on aspirin, clopidogrel, and metoprolol after the coronary intervention. His temperature is 38°C (100.4°F), pulse is 93/min, and blood pressure is 125/85 mm Hg. Examination shows mottled, reticulated purplish discoloration of the feet. Laboratory studies show:
Hemoglobin count 14 g/dL
Leukocyte count 16,400/mm3
Segmented neutrophils 56%
Eosinophils 11%
Lymphocytes 31%
Monocytes 2%
Platelet count 260,000/mm3
Erythrocyte sedimentation rate 68 mm/h
Serum
Urea nitrogen 25 mg/dL
Creatinine 4.2 mg/dL
Renal biopsy shows intravascular spindle-shaped vac

In [None]:
# Validate format compliance of the single response generated above
has_answer = ANSWER_START in medical_response and ANSWER_END in medical_response
print('Reasoning section:', REASONING_END in medical_response)
print('Answer section:', has_answer)

if has_answer: # Check answer accuracy if answer section exists
    letter_match = match_answer_letter.search(medical_response)
    # .search() returns None when no A-D letter follows ANSWER_START; report that instead of crashing
    answer_text = letter_match.group(1).upper() if letter_match else None
    print('Extracted:', answer_text)
    print('Expected:', expected_answer)
    print('Correct:', answer_text == expected_answer)

Reasoning section: True
Answer section: True
Extracted: B
Expected: B
Correct: True
