In [None]:
from datasets import load_dataset

dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", "en")


def preprocess_function(example):
    # Format as a conversation for SFTTrainer
    messages = [
        {"role": "user", "content": example["Question"]},
        {
            "role": "assistant",
            "content": f"<think>{example['Complex_CoT']}</think>{example['Response']}",
        },
    ]
    return {"messages": messages}


dataset = dataset.map(preprocess_function, remove_columns=["Question", "Response", "Complex_CoT"])

# Split the training dataset to create train/validation/test sets
# (80% train, 10% validation, 10% test)
first_split = dataset["train"].train_test_split(test_size=0.2, seed=816)  # 80% train, 20% temp
temp_dataset = first_split["test"]
second_split = temp_dataset.train_test_split(test_size=0.5, seed=816)  # Split the 20% into 10% each

train_dataset = first_split["train"]  # 80%
eval_dataset = second_split["train"]  # 10%
test_dataset = second_split["test"]  # 10%

print("Sample:", next(iter(train_dataset)))

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Validation samples: {len(eval_dataset)}")

Sample: {'messages': [{'content': 'A patient presents with microcytic hypochromic anemia, hemoglobin level of 9%, serum iron of 20 µg/dL, ferritin level of 800 ng/mL, and transferrin percentage saturation of 64%. Based on these laboratory findings, what is the possible diagnosis?', 'role': 'user'}, {'content': "<think>Okay, so we have a case of microcytic hypochromic anemia. That generally means the red blood cells are small and pale, which can occur in a few different conditions.\n\nLet's start by looking at the serum iron level. It’s reported at 20 µg/dL, which is definitely on the low side. Low serum iron is commonly seen in iron deficiency anemia, but it can also happen due to chronic diseases or other less common conditions.\n\nNext, there’s the ferritin level to consider. It’s really high at 800 ng/mL. Ferritin being high makes me think more about conditions like an inflammatory state or chronic disease, rather than iron deficiency, where ferritin would typically be low.\n\nNow, 

In [2]:
from os.path import join

import yaml

In [None]:
# Load configuration from config.yaml
with open("config.yaml") as f:
    config = yaml.safe_load(f)

MODEL_NAME = config["base_model_name"]
print(f"Using model: {MODEL_NAME}")

adapter_dir = join(config["adapter_dir_prefix"], MODEL_NAME)
print(f"LoRA adapter directory will be saved to: {adapter_dir}")

lora_rank = config["lora_rank"]
lora_alpha = config["lora_alpha"]
print(f"LoRA rank is {lora_rank} and LoRA alpha is {lora_alpha}")

Using model: HuggingFaceTB/SmolLM-135M-Instruct
LoRA adapter directory will be saved to: lora_adapter/HuggingFaceTB/SmolLM-135M-Instruct


In [None]:
import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

In [3]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # Changed from float16 to bfloat16
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

In [4]:
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.bfloat16,  # Changed from float16 to bfloat16 and dtype to torch_dtype
    use_cache=True,  # Whether to cache attention outputs to speed up inference
    quantization_config=bnb_config,
    local_files_only=True,  # Use cache first
    device_map="auto",
    attn_implementation="flash_attention_2",
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    local_files_only=True,
)
tokenizer.pad_token = tokenizer.eos_token

In [None]:
peft_config = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_alpha,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["self_attn.q_proj", "self_attn.v_proj", "self_attn.k_proj", "self_attn.o_proj"],
)

In [6]:
# Configure the SFT training parameters
sft_config = SFTConfig(
    output_dir="./results",
    num_train_epochs=1,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    max_length=512,
    logging_steps=10,
    save_strategy="steps",
    eval_strategy="steps",
    eval_steps=20,
    save_steps=20,
    bf16=True,
    gradient_checkpointing=True,
    loss_type="dft",  # Dynamic fine tuning
    completion_only_loss=True,  # Train only on assistant responses
)

In [7]:
from transformers import DataCollatorWithFlattening

data_collator = DataCollatorWithFlattening()

In [8]:
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=sft_config,
    peft_config=peft_config,
    data_collator=data_collator,
)

In [9]:
# Check GPU memory usage before training
GB = 2**30
if torch.cuda.is_available():
    print(f"GPU Memory allocated: {torch.cuda.memory_allocated() / GB:.2f} GB")
    print(f"GPU Memory reserved: {torch.cuda.memory_reserved() / GB:.2f} GB")
    print(f"GPU Memory available: {torch.cuda.get_device_properties(0).total_memory / GB:.2f} GB")
else:
    print("CUDA is not available")

GPU Memory allocated: 0.11 GB
GPU Memory reserved: 0.18 GB
GPU Memory available: 8.00 GB


In [10]:
trainer.train()

Step,Training Loss,Validation Loss,Entropy,Num Tokens,Mean Token Accuracy
20,0.145,0.144826,2.234576,320470.0,0.449162
40,0.144,0.143242,2.178591,640524.0,0.450605
60,0.1418,0.141499,2.120491,962519.0,0.452169
80,0.1404,0.139755,2.066365,1284493.0,0.45377


KeyboardInterrupt: 

In [None]:
# Save the LoRA adapter
print(f"Saving LoRA adapter to {adapter_dir}")

trainer.model.save_pretrained(adapter_dir)
print(f"LoRA adapter saved successfully to {adapter_dir}!")
tokenizer.save_pretrained(adapter_dir)

In [1]:
print(
    "Now run the notebook `trl_medical_reasoning_inference.ipynb` to use the LoRA fine-tuned model."
)

Now run the notebook `trl_medical_reasoning_inference.ipynb` to use the LoRA fine-tuned model.
