In [1]:
from datasets import load_dataset

# Raw medical reasoning SFT corpus, English configuration. Exposes a single
# "train" split, which is carved into train/validation further below.
dataset = load_dataset(path="FreedomIntelligence/medical-o1-reasoning-SFT", name="en")


def preprocess_function(example):
    """Turn one raw dataset row into the chat format consumed by SFTTrainer.

    The assistant turn wraps the chain-of-thought in ``<think>...</think>``
    tags, immediately followed (no separator) by the final answer.

    Args:
        example: A single row providing ``Question``, ``Complex_CoT`` and
            ``Response`` fields; any other fields are ignored.

    Returns:
        A dict with one ``messages`` key holding a user turn followed by an
        assistant turn, each as a ``{"role", "content"}`` dict.
    """
    question = example["Question"]
    reasoning = example["Complex_CoT"]
    answer = example["Response"]

    user_turn = {"role": "user", "content": question}
    assistant_turn = {
        "role": "assistant",
        "content": f"<think>{reasoning}</think>{answer}",
    }
    return {"messages": [user_turn, assistant_turn]}


# Format every row as a chat conversation. The result gets a new name instead of
# overwriting `dataset`, so the raw load stays inspectable. The raw columns to
# drop are read from the data rather than hard-coded, so none are left behind.
SPLIT_SEED = 816  # fixed seed -> reproducible train/validation split
raw_columns = dataset["train"].column_names
formatted_dataset = dataset.map(preprocess_function, remove_columns=raw_columns)

# Split the single "train" split into 90% train / 10% validation.
split_dataset = formatted_dataset["train"].train_test_split(test_size=0.1, seed=SPLIT_SEED)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]

# Sanity check: the two splits must partition the formatted rows exactly.
assert len(train_dataset) + len(eval_dataset) == len(formatted_dataset["train"])

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(eval_dataset)}")
print("Sample:", train_dataset[0])

Training samples: 17733
Validation samples: 1971
Sample: {'messages': [{'content': 'Based on the chest radiograph and abdominal CT scan of a middle-aged male complaining of nagging abdominal pain for the past 2 weeks, what is the probable diagnosis that should be considered?', 'role': 'user'}, {'content': "<think>Okay, let's break this down. We have a middle-aged male who's been dealing with this persistent abdominal pain for a couple of weeks. That doesn't sound fun. I guess we should start by thinking about what could cause that sort of ongoing discomfort in someone like him. Abdominal pain can mean a lot of things, so let's see how imaging can help us out here. \n\nFirst, there's the chest radiograph and the abdominal CT scan to look into. The chest radiograph might seem odd at first when we're dealing with abdominal pain, but sometimes issues in the chest, like lung problems or even things pressing against the diaphragm, can cause referred pain to the abdomen. It's all connected! \

In [None]:
import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

In [None]:
# 4-bit NF4 quantization of the base model weights (QLoRA-style fine-tuning).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  # NormalFloat4 storage format for the weights
    bnb_4bit_use_double_quant=True,  # also quantize the quantization constants to save memory
    # Dequantized matmuls run in bfloat16: it has float32's exponent range, so it
    # is less prone to overflow than float16, and it matches bf16=True training.
    bnb_4bit_compute_dtype=torch.bfloat16,
)

In [None]:
# Load the base model quantized to 4 bits with `bnb_config`; LoRA adapters are
# attached later by SFTTrainer.
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM-135M-Instruct",
    # `dtype` is the current name of the deprecated `torch_dtype` argument.
    # bfloat16 matches the 4-bit compute dtype and bf16 training.
    dtype=torch.bfloat16,
    # The KV cache only speeds up autoregressive generation. During training it
    # just costs memory, and it is incompatible with gradient checkpointing
    # (recent TRL SFTConfig versions enable that by default — confirm for your
    # version). Set `model.config.use_cache = True` again before generating.
    use_cache=False,
    quantization_config=bnb_config,
    # Read ONLY from the local Hugging Face cache: this raises if the model has
    # not been downloaded yet. Set to False to allow fetching from the Hub.
    local_files_only=True,
    device_map="auto",
    attn_implementation="flash_attention_2",  # requires the `flash-attn` package and a supported GPU
)

In [None]:
# LoRA adapter settings. Only the attention projections receive adapters; the
# 4-bit base weights stay frozen.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["self_attn.q_proj", "self_attn.v_proj", "self_attn.k_proj", "self_attn.o_proj"],
    r=16,  # adapter rank
    lora_alpha=32,  # adapter updates are scaled by lora_alpha / r (= 2 here)
    lora_dropout=0.05,
    bias="none",  # leave all bias terms untrained
)

Step,Training Loss,Validation Loss,Entropy,Num Tokens,Mean Token Accuracy
100,2.7303,2.733198,2.369173,201246.0,0.453621


KeyboardInterrupt: 

In [None]:
# Supervised fine-tuning hyperparameters, grouped by concern.
sft_config = SFTConfig(
    # --- checkpoints ---
    output_dir="./results",
    save_strategy="steps",
    save_steps=200,
    # --- optimization ---
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,  # effective batch = 1 x 4 per device
    # NOTE(review): LoRA adapters are usually trained with a higher LR (~1e-4 to 2e-4); confirm 2e-5 is intended.
    learning_rate=2e-5,
    loss_type="dft",  # Dynamic Fine-Tuning loss instead of the default NLL
    bf16=True,
    # --- sequence length ---
    # NOTE(review): longer samples are truncated to this many tokens. With long
    # Complex_CoT traces this may cut off the final Response; check the
    # token-length distribution.
    max_length=512,
    # --- logging / evaluation ---
    logging_steps=50,
    eval_strategy="steps",
    eval_steps=100,
)

In [None]:
from transformers import DataCollatorWithFlattening

# Padding-free collation: the examples in a batch are concatenated into one
# sequence, and position_ids are returned so flash-attention can tell where each
# example starts. NOTE(review): with per_device_train_batch_size=1 each batch
# holds a single example, so flattening currently has no effect.
data_collator = DataCollatorWithFlattening(return_position_ids=True)

In [None]:
# Wire everything together. Because `peft_config` is passed, SFTTrainer wraps the
# quantized `model` with the LoRA adapters itself.
trainer = SFTTrainer(
    model=model,
    peft_config=peft_config,
    args=sft_config,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

In [None]:
# Run fine-tuning. After an interruption, pass `resume_from_checkpoint=True` to
# continue from the latest checkpoint saved in `output_dir`.
trainer.train(resume_from_checkpoint=None)