In [None]:
%%capture
!pip install transformers datasets

In [None]:
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

In [None]:
# Model / tokenizer configuration
# Intended upper bound on sequence length (presumably in tokens).
max_seq_length = 2048

# Checkpoint to fine-tune; a larger alternative is
# deepseek-ai/DeepSeek-R1-Distill-Qwen-32B.
model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"

# Load the pretrained tokenizer and causal-LM weights for `model_name`.
# NOTE(review): no dtype/device arguments are passed, so library defaults
# apply — confirm the memory footprint is acceptable for a model of this size.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

In [None]:
# Format the dataset
# Prompt template used to render each training example as a single string.
# It has three positional `{}` placeholders, filled in order with:
#   1. the question,
#   2. the chain of thought (placed between <think> and </think>),
#   3. the final response.
# Any call to `.format` on this template must therefore supply three arguments.
train_prompt_style = """Below is an instruction that describes a task, paired with an input that provides further context.
Write a response that appropriately completes the request.
Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.

### Instruction:
You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning.
Please answer the following medical question.

### Question:
{}

### Response:
<think>
{}
</think>
{}"""

In [None]:
# Load the first 100 rows of the "en" configuration's train split.
# `trust_remote_code=True` is intentionally not passed: it opts in to executing
# Python code shipped in the dataset repository, and recent `datasets` releases
# no longer support it. NOTE(review): this assumes the repository serves plain
# data files (no loading script) — confirm if loading fails.
dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", "en", split="train[0:100]")

In [None]:
# Show the dataset object's repr as this cell's output.
dataset

In [None]:
# First example's question (row first, then column).
dataset[0]["Question"]

In [None]:
# First example's chain of thought (row first, then column).
dataset[0]["Complex_CoT"]

In [None]:
# First example's final response (row first, then column).
dataset[0]["Response"]

In [None]:
def formatting_prompts_func(examples):
    """Render a batch of rows into complete training prompts.

    Args:
        examples: A batch mapping with "Question", "Complex_CoT" and
            "Response" entries, each a sequence with one item per row.

    Returns:
        A dict ``{"text": [...]}`` holding one string per row: the filled-in
        ``train_prompt_style`` template followed by the tokenizer's EOS token.
    """
    rows = zip(examples["Question"], examples["Complex_CoT"], examples["Response"])
    eos = tokenizer.eos_token
    return {
        "text": [
            train_prompt_style.format(question_text, reasoning, answer) + eos
            for question_text, reasoning, answer in rows
        ],
    }


# Add the rendered "text" column to the dataset, processing rows in batches.
dataset = dataset.map(formatting_prompts_func, batched=True)

In [None]:
# Training setup
#
# The Trainer consumes token ids, not raw strings. The formatted dataset only
# has string columns, none of which match the model's forward() arguments, so
# it must be tokenized before it is handed to the Trainer.

if tokenizer.pad_token is None:
    # The collator pads each batch and therefore needs a pad token.
    tokenizer.pad_token = tokenizer.eos_token


def tokenize_text(examples):
    """Tokenize the formatted "text" column, truncating to `max_seq_length` tokens."""
    return tokenizer(examples["text"], truncation=True, max_length=max_seq_length)


# Keep only the tokenizer outputs; the original string columns are dropped.
train_dataset = dataset.map(
    tokenize_text,
    batched=True,
    remove_columns=dataset.column_names,
)

# mlm=False selects the causal-LM objective: the collator pads every batch and
# builds `labels` from `input_ids`, with padding positions excluded from the loss.
# NOTE(review): when the pad token equals the EOS token, EOS positions are
# excluded from the loss as well — confirm this is acceptable.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir="outputs",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,  # effective batch of 8 sequences per device
    warmup_steps=5,
    # A positive max_steps overrides num_train_epochs, so only max_steps is set
    # to avoid two conflicting definitions of the run length.
    max_steps=60,
    # NOTE(review): every model parameter is trained here; confirm this
    # learning rate is appropriate for full fine-tuning.
    learning_rate=2e-4,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    report_to="none"  # No external reporting
)

# The tokenizer is not passed to the Trainer: the collator already carries it
# for padding, and the `tokenizer=` keyword is deprecated in recent releases.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
)

In [None]:
# Start training: run the fine-tuning loop of the `trainer` built above.
# As the last expression in the cell, the call's return value is displayed.
trainer.train()

In [None]:
# Inference example
question = "A 61-year-old woman with a long history of involuntary urine loss during activities like coughing or sneezing but no leakage at night undergoes a gynecological exam and Q-tip test. Based on these findings, what would cystometry most likely reveal about her residual volume and detrusor contractions?"

# Build the inference prompt from the training template: keep everything up to
# and including the opening <think> tag, so the model writes the reasoning and
# the answer itself. The full template has three placeholders, so formatting it
# with only (question, "") raises IndexError.
template_head, think_tag, _ = train_prompt_style.partition("<think>")
inference_prompt = template_head + think_tag + "\n"

model.eval()  # make sure dropout and other train-only behavior is disabled

# Single sequence: no padding needed. The tensors must sit on the same device
# as the model, which may be a GPU after training.
inputs = tokenizer(
    inference_prompt.format(question),
    return_tensors="pt",
    truncation=True,
).to(model.device)

outputs = model.generate(
    input_ids=inputs['input_ids'],
    attention_mask=inputs['attention_mask'],
    # Budget applies to generated tokens only; `max_length` would also count
    # the prompt and cut the reasoning short.
    max_new_tokens=512,
)

# The decoded text contains the prompt as well; print only what follows the
# "### Response:" marker.
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response.split("### Response:")[1])

# Save model and tokenizer
domain = "domain-medical-o1-reasoning"
save_dir = f"deepseek-r1-distill-llama-8b-{domain}"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)
