# RAFT: Retrieval Augmented Fine-Tuning
### *Implementation of [arXiv:2403.10131](https://arxiv.org/pdf/2403.10131)*

This notebook implements the **RAFT** paper from scratch. It demonstrates how to fine-tune a model to answer from the relevant ("oracle") document while ignoring "distractor" documents during Retrieval Augmented Generation (RAG).

**Technical Note:** We use **Unsloth** to optimize the fine-tuning process. Unsloth reports roughly 60% lower VRAM usage and about 2x faster training than a standard Hugging Face setup, which allows this entire pipeline to run on a free Google Colab T4 GPU.

In [None]:
%pip install unsloth

In [None]:
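import random
from datasets import Dataset

random.seed(3407)  # make distractor sampling and shuffling reproducible

# 1. THE ORACLE DOCUMENTS (Signal)
# These represent the "Oracle" chunks retrieved from our Postgres vector DB.
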
medical_facts = [
    # Emergency Protocols
    "Sepsis Protocol (SEP-1): Administer 30mL/kg crystalloid fluid challenge within 3 hours of presentation for hypotension or lactate >= 4mmol/L.",
    "Acute Coronary Syndrome (ACS): Dual antiplatelet therapy (Aspirin 325mg + P2Y12 inhibitor) should be administered immediately upon diagnosis of NSTEMI.",
    "Stroke (Ischemic): tPA (Alteplase) is indicated within 4.5 hours of symptom onset if no hemorrhage is detected on CT head.",
    "Anaphylaxis: First-line treatment is IM Epinephrine 0.01 mg/kg (max 0.5 mg) into the mid-outer thigh, repeatable every 5-15 minutes.",

    # Chronic Management
    "Type 2 Diabetes: Metformin is first-line therapy. Add SGLT2 inhibitor if patient has established ASCVD or Heart Failure.",
    "Hypertension: Stage 1 is 130-139/80-89 mmHg. Start monotherapy (ACEi/ARB, CCB, or Thiazide) if ASCVD risk > 10%.",
    "Asthma: GINA 2023 guidelines recommend ICS-Formoterol as the preferred reliever for all severity steps, replacing SABA-only treatment.",

    # Drug Specifics
    "Vancomycin Dosing: Target trough levels of 15-20 mcg/mL for complicated infections (endocarditis, osteomyelitis, meningitis).",
    "Warfarin Reversal: For major bleeding with elevated INR, administer 4-factor Prothrombin Complex Concentrate (PCC) and Vitamin K IV.",
    "Hyperkalemia: For K+ > 6.5 with ECG changes, give Calcium Gluconate 1g IV immediately to stabilize cardiac membrane."
]


# 2. THE DISTRACTORS (Noise)
# These simulate "Bad Retrieval" - documents that appeared in the search
# but are irrelevant to the specific question asked.

distractors = [
    "Hospital Policy 101: The cafeteria is open from 06:00 to 20:00. Staff discount applies with ID badge.",
    "Visitor Policy: ICU visiting hours are restricted to immediate family members between 10:00 and 14:00.",
    "IT Support: To reset your EMR password, contact the helpdesk at extension 5555. Do not share credentials.",
    "Billing: ICD-10 code R07.9 (Chest pain, unspecified) requires additional documentation for reimbursement.",
    "Parking: Staff parking in Lot B is prohibited during construction (Jan-Mar 2025). Use Lot C shuttle.",
    "Pediatrics: The pediatric dosage for Amoxicillin is 20-40mg/kg/day divided q8h.",
    "OB/GYN: Pre-eclampsia prophylaxis with Aspirin 81mg should start at 12 weeks gestation for high-risk patients.",
    "Orthopedics: Post-op hip replacement patients require DVT prophylaxis for 35 days.",
    "Grand Rounds: Dr. Oghalai will present on 'Cochlear Mechanics' this Friday at noon in the main auditorium.",
    "Fire Safety: In case of Code Red, adhere to the RACE protocol (Rescue, Alarm, Contain, Extinguish)."
]




# 3. RAFT DATA GENERATOR

raft_data = []

def get_context_from_postgres_simulation(target_fact):
    """
    SIMULATING POSTGRES PGVECTOR RETRIEVAL:
    In production, this function would look like this:

    def get_real_postgres_context(question_embedding):
        sql = \"\"\"
        (
            -- 1. Get the "Gold" Document (The closest vector)
            SELECT content, embedding <=> %s as dist, 'gold' as type
            FROM clinical_guidelines
            ORDER BY dist ASC LIMIT 1
        )
        UNION ALL
        (
            -- 2. Get distractors: random chunks that are NOT too close to the question (distance > 0.4)
            SELECT content, embedding <=> %s as dist, 'distractor' as type
            FROM clinical_guidelines
            WHERE embedding <=> %s > 0.4 -- Filter for things not TOO close
            ORDER BY RANDOM() LIMIT 2
        )
        \"\"\"
        cursor.execute(sql, (question_embedding, question_embedding, question_embedding))
        return cursor.fetchall()
    """

    # Since we don't have a live DB connection here, we simulate the result:
    # 1 Gold Fact + 2 Random Distractors

    noise_docs = random.sample(distractors, 2)
    context_docs = noise_docs + [target_fact]
    random.shuffle(context_docs) # Shuffle so model doesn't just learn "Answer is always last"
    return context_docs


# Manually pairing Questions with Facts for the Training Set
# (In a real pipeline, we'd use GPT-4 to generate questions FROM the facts)

qa_pairs = [
    ("What is the protocol for sepsis fluid resuscitation?", medical_facts[0], "Give 30mL/kg crystalloid within 3 hrs."),
    ("When should DAPT be started for NSTEMI?", medical_facts[1], "Immediately upon diagnosis."),
    ("What is the time window for tPA in ischemic stroke?", medical_facts[2], "Within 4.5 hours of symptom onset."),
    ("How do you treat anaphylaxis?", medical_facts[3], "IM Epinephrine 0.01 mg/kg (max 0.5 mg) into the mid-outer thigh."),
    ("What is the first-line drug for Type 2 Diabetes?", medical_facts[4], "Metformin."),
    ("When should drug therapy start for Stage 1 Hypertension?", medical_facts[5], "Start monotherapy at 130-139/80-89 mmHg if ASCVD risk > 10%."),
    ("What is the preferred reliever for Asthma in GINA 2023?", medical_facts[6], "ICS-Formoterol."),
    ("What is the target trough for Vancomycin in endocarditis?", medical_facts[7], "15-20 mcg/mL."),
    ("How do you reverse Warfarin bleeding?", medical_facts[8], "4-factor PCC and IV Vitamin K."),
    ("What is the emergency treatment for Hyperkalemia with ECG changes?", medical_facts[9], "IV Calcium Gluconate 1g.")
]

for question, fact, short_answer in qa_pairs:

    # 1. Retrieve context (simulating the Postgres RAG step)
    retrieved_docs = get_context_from_postgres_simulation(fact)

    # 2. Format the context string with 1-based document IDs
    context_str = "\n".join(f"Document [{i+1}]: {doc}" for i, doc in enumerate(retrieved_docs))

    # 3. Generate the chain of thought (the core of RAFT)
    # Cite the GOLD document by its actual position and quote it, then name the
    # DISTRACTORS that were actually sampled for this example as irrelevant.
    gold_id = retrieved_docs.index(fact) + 1
    distractor_ids = " and ".join(
        f"[{i+1}]" for i, doc in enumerate(retrieved_docs) if doc != fact
    )
    cot = (
        f"Document [{gold_id}] is relevant: '{fact}' "
        f"Documents {distractor_ids} do not address the question and are ignored."
    )

    raft_sample = {
        "instruction": "You are an expert medical scribe. Answer the question based strictly on the provided context docs, citing your source.",
        "input": f"Context:\n{context_str}\n\nQuestion:\n{question}",
        "output": f"Thinking: {cot}\nAnswer: {short_answer}"
    }
    raft_data.append(raft_sample)

# Convert to Dataset
dataset = Dataset.from_list(raft_data)
print(f"Generated {len(dataset)} RAFT training samples.")
print("Sample Input:\n", dataset[0]['input'])
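The SQL in the docstring of `get_context_from_postgres_simulation` labels the nearest chunk as the gold document and samples two distractors whose cosine distance (pgvector's `<=>` operator) exceeds 0.4. The sketch below mirrors that selection logic in plain Python over toy embeddings, so it can be checked without a database. The function `select_gold_and_distractors` and the 3-d vectors are hypothetical stand-ins; real embeddings would come from an embedding model.

```python
import math
import random

def cosine_distance(a, b):
    """1 - cosine similarity, the same quantity pgvector's <=> operator returns."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)

def select_gold_and_distractors(query_vec, chunks, n_distractors=2, min_dist=0.4, rng=random):
    """Hypothetical in-memory mirror of the docstring's UNION ALL query.

    chunks: list of (content, embedding) pairs.
    Returns [(content, dist, type), ...] with one 'gold' row first.
    """
    scored = [(content, cosine_distance(query_vec, emb)) for content, emb in chunks]
    # Part 1: ORDER BY dist ASC LIMIT 1 -> the nearest chunk is labeled gold
    gold_content, gold_dist = min(scored, key=lambda row: row[1])
    # Part 2: WHERE dist > min_dist ORDER BY RANDOM() LIMIT n -> random "far" chunks
    far = [row for row in scored if row[1] > min_dist]
    picked = rng.sample(far, min(n_distractors, len(far)))
    rows = [(gold_content, gold_dist, "gold")]
    rows += [(content, dist, "distractor") for content, dist in picked]
    return rows

# Toy 3-d "embeddings" (hypothetical; real ones come from an embedding model)
chunks = [
    ("Sepsis: 30mL/kg crystalloid", [1.0, 0.0, 0.0]),
    ("Cafeteria hours", [0.0, 1.0, 0.0]),
    ("Parking Lot B closed", [0.0, 0.0, 1.0]),
    ("Sepsis lactate note", [0.9, 0.1, 0.0]),
]
rows = select_gold_and_distractors([1.0, 0.05, 0.0], chunks, rng=random.Random(0))
for content, dist, kind in rows:
    print(f"{kind:<10} {dist:.3f}  {content}")
```

Here "Sepsis lactate note" sits close to the query (distance below 0.4), so it is neither gold nor a distractor; only clearly unrelated chunks qualify as noise. Because the gold row is simply the nearest neighbour, a retrieval mistake would mislabel it, which is one reason to pair each training question with its oracle fact by hand, as this notebook does.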

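The generator above always includes the oracle document. The RAFT paper also trains on a mix: for a fraction P of questions the oracle is kept alongside distractors, and for the remaining fraction (1 - P) only distractors are shown, so the model cannot always count on finding the answer in context. A minimal sketch of that sampling step, assuming a hypothetical helper `build_context_with_p`; the default `p_oracle = 0.8` is an illustrative choice, not a value used elsewhere in this notebook.

```python
import random

def build_context_with_p(oracle, distractor_pool, p_oracle=0.8, k=3, rng=random):
    """Return (docs, has_oracle) for one RAFT training sample.

    With probability p_oracle: the oracle plus k-1 distractors.
    Otherwise: k distractors and no oracle.
    """
    if rng.random() < p_oracle:
        docs = rng.sample(distractor_pool, k - 1) + [oracle]
        has_oracle = True
    else:
        docs = rng.sample(distractor_pool, k)
        has_oracle = False
    rng.shuffle(docs)  # avoid a fixed position for the oracle
    return docs, has_oracle

pool = ["Cafeteria hours", "Parking Lot B", "IT helpdesk", "Fire safety"]
rng = random.Random(0)
samples = [build_context_with_p("Sepsis: 30mL/kg", pool, rng=rng) for _ in range(1000)]
frac = sum(has for _, has in samples) / len(samples)
print(f"oracle present in {frac:.1%} of samples")
```

For samples where `has_oracle` is False, the chain of thought can no longer cite a gold document by number, so the target answer would have to come from what the model has learned rather than from a quoted passage.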
In [None]:
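from unsloth import FastLanguageModel, is_bfloat16_supported  # import Unsloth before trl/transformers so its patches apply
import psutil
import builtins
from trl import SFTTrainer
from transformers import TrainingArguments
builtins.psutil = psutil  # workaround: some Unsloth versions reference `psutil` without importing it

# `model` and `tokenizer` are used by every cell below; the checkpoint and LoRA settings are example choices.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-bnb-4bit", max_seq_length = 2048, load_in_4bit = True)
model = FastLanguageModel.get_peft_model(model, r = 16, lora_alpha = 16, lora_dropout = 0, random_state = 3407,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])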

In [None]:
# 1. FORMATTING THE PROMPT
raft_prompt = """### Instruction:
{}

### Input:
{}

### Response:
{}"""

EOS_TOKEN = tokenizer.eos_token

def formatting_prompts_func(examples):
    instructions = examples["instruction"]
    inputs       = examples["input"]
    outputs      = examples["output"]
    texts = []
    for inst, inp, out in zip(instructions, inputs, outputs):
        text = raft_prompt.format(inst, inp, out) + EOS_TOKEN
        texts.append(text)
    return { "text" : texts }

# Apply the format to the new dataset
train_dataset = dataset.map(formatting_prompts_func, batched = True)
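With `batched = True`, `Dataset.map` hands `formatting_prompts_func` a dict of column lists and expects a dict of equal-length lists back. The self-contained sketch below runs the same logic on a plain dict so the final training string can be inspected; `format_batch` is a hypothetical name, and `"</s>"` is a placeholder for `tokenizer.eos_token`, whose real value depends on the model.

```python
RAFT_PROMPT = """### Instruction:
{}

### Input:
{}

### Response:
{}"""

EOS = "</s>"  # placeholder; the notebook uses tokenizer.eos_token

def format_batch(examples):
    """Same contract as formatting_prompts_func: dict of lists in, dict of lists out."""
    texts = [
        RAFT_PROMPT.format(inst, inp, out) + EOS
        for inst, inp, out in zip(examples["instruction"], examples["input"], examples["output"])
    ]
    return {"text": texts}

batch = {
    "instruction": ["Answer from context."],
    "input": ["Context:\nDocument [1]: Metformin is first-line.\n\nQuestion:\nFirst-line T2D drug?"],
    "output": ["Thinking: Document [1] answers it.\nAnswer: Metformin."],
}
out = format_batch(batch)
print(out["text"][0])
```

Appending the EOS token marks where a response ends during training; without it, a fine-tuned model tends to keep generating until it hits `max_new_tokens`.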


In [None]:
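# 2. THE TRAINER (Fine-Tuning Loop)
# Passing tokenizer / dataset_text_field / max_seq_length directly to SFTTrainer
# matches older TRL releases; newer TRL versions move these options into SFTConfig.
trainer = SFTTrainer(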
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_dataset,
    dataset_text_field = "text",
    max_seq_length = 2048,
    dataset_num_proc = 2,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        max_steps = 60,
        learning_rate = 2e-4,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "raft_outputs",
    ),
)

print("Starting RAFT Fine-Tuning...")
trainer.train()
print("Training Complete!")

In [None]:
# 3. INFERENCE TEST

# (Run ONLY after the training step finishes)

from unsloth import FastLanguageModel

FastLanguageModel.for_inference(model)  # switch the model to Unsloth's inference mode

# A test case with two distractors (pizza and parking)

test_context = """Context:
Document [1]: The hospital cafeteria serves pizza on Fridays.
Document [2]: Acute Coronary Syndrome (ACS): Dual antiplatelet therapy (Aspirin + P2Y12) should be administered immediately.
Document [3]: Parking structure B is closed for maintenance.

Question:
When should DAPT be started for NSTEMI?"""

inputs = tokenizer(
    [
        raft_prompt.format(
            "You are an expert medical scribe. Answer the question based strictly on the provided context docs, citing your source.",  # Instruction
            test_context,  # Input
            "",  # Output (left blank for generation)
        )
    ],
    return_tensors = "pt",
).to("cuda")


outputs = model.generate(**inputs, max_new_tokens = 128, use_cache = True)


# Decode the output (dropping special tokens) and keep only the text after the prompt
print("\n=== MODEL OUTPUT ===")
print(tokenizer.batch_decode(outputs, skip_special_tokens = True)[0].split("### Response:")[-1].strip())
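Because the training targets follow a `Thinking: ...` / `Answer: ...` layout, the generated text can be split into its reasoning and its final answer, and the document IDs the model cites can be extracted for comparison with the gold document (Document [2] in the test above). A minimal parsing sketch, assuming the model reproduces that layout; `parse_raft_output` is a hypothetical helper, and it returns `None` for any part the model fails to produce.

```python
import re

def parse_raft_output(text):
    """Split a 'Thinking: ... Answer: ...' completion and collect cited document IDs."""
    thinking = re.search(r"Thinking:\s*(.*?)\s*(?=Answer:|$)", text, re.DOTALL)
    answer = re.search(r"Answer:\s*(.*)", text, re.DOTALL)
    # Matches the singular "Document [n]" form used when citing the relevant source
    cited = sorted({int(n) for n in re.findall(r"Document \[(\d+)\]", text)})
    return {
        "thinking": thinking.group(1) if thinking else None,
        "answer": answer.group(1).strip() if answer else None,
        "cited_docs": cited,
    }

sample = (
    "Thinking: Document [2] covers ACS and says DAPT is given immediately. "
    "Documents [1] and [3] are irrelevant.\nAnswer: Immediately upon diagnosis."
)
print(parse_raft_output(sample))
```

Pass it decoded text with special tokens removed; otherwise an end-of-sequence marker ends up inside the parsed answer.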

In [None]:
# 4. MINI EVALUATION

import pandas as pd
eval_cases = [
    {
        "topic": "Sepsis",
        "context": """Document [A]: The cafeteria closes at 8pm.\nDocument [B]: Sepsis Protocol: Administer 30mL/kg crystalloid fluid within 3 hours.\nDocument [C]: Dr. Smith is on vacation.""",
        "question": "What is the fluid resuscitation protocol for sepsis?"
    },
    {
        "topic": "Stroke",
        "context": """Document [A]: Stroke (Ischemic): tPA is indicated within 4.5 hours.\nDocument [B]: Parking Lot C is for visitors only.\nDocument [C]: To reset wifi password, call IT.""",
        "question": "What is the time window for tPA?"
    },
    {
        "topic": "Anaphylaxis",
        "context": """Document [A]: Fire drill at noon.\nDocument [B]: Anaphylaxis: First-line treatment is IM Epinephrine 0.01 mg/kg.\nDocument [C]: Bagels are available in the break room.""",
        "question": "What is the first-line treatment for anaphylaxis?"
    }
]

results = []

print("Running Mini-Eval...")
for case in eval_cases:
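    inputs = tokenizer(
        [
            raft_prompt.format(
                # Reuse the exact training instruction to avoid a train/test prompt mismatch
                "You are an expert medical scribe. Answer the question based strictly on the provided context docs, citing your source.",
                f"Context:\n{case['context']}\n\nQuestion:\n{case['question']}",
                "",
            )
        ], return_tensors = "pt").to("cuda")

    # 128 new tokens leaves room for the "Thinking:" step before the "Answer:" line
    outputs = model.generate(**inputs, max_new_tokens = 128, use_cache = True)
    response = tokenizer.batch_decode(outputs, skip_special_tokens = True)[0].split("### Response:")[-1].strip()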

    results.append({
        "Topic": case['topic'],
        "Model Response": response
    })

# RESULTS

df = pd.DataFrame(results)
print("\n=== EVALUATION REPORT ===")
print(df.to_markdown(index=False))  # to_markdown requires the optional `tabulate` package
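The report above only lists raw responses. One simple way to turn it into a pass/fail signal is to check that each response contains a key phrase from the gold document and none from the distractors. The sketch below is a hypothetical keyword check: `score_response`, `add_scores`, and the keyword choices are assumptions picked from the eval contexts, and the demo responses are made up to show the mechanics. It is a rough proxy, since a correct answer phrased differently (e.g. "3 hrs") would miss a literal keyword.

```python
import pandas as pd

def score_response(response, must_contain, must_not_contain=()):
    """True if every expected keyword appears and no distractor keyword does (case-insensitive)."""
    text = response.lower()
    hits = all(k.lower() in text for k in must_contain)
    leaks = any(k.lower() in text for k in must_not_contain)
    return hits and not leaks

# Hypothetical expectations derived from the eval contexts: (gold keywords, distractor keywords)
expectations = {
    "Sepsis": (["30mL/kg"], ["cafeteria", "vacation"]),
    "Stroke": (["4.5 hours"], ["parking", "wifi"]),
    "Anaphylaxis": (["epinephrine"], ["fire drill", "bagel"]),
}

def add_scores(df):
    """Return a copy of the report with a boolean 'Pass' column."""
    df = df.copy()
    df["Pass"] = [
        score_response(resp, *expectations[topic])
        for topic, resp in zip(df["Topic"], df["Model Response"])
    ]
    return df

# Made-up responses, only to demonstrate the scoring
demo = pd.DataFrame({
    "Topic": ["Sepsis", "Stroke", "Anaphylaxis"],
    "Model Response": [
        "Answer: Give 30mL/kg crystalloid within 3 hours.",
        "Answer: Check the parking lot.",
        "Answer: IM Epinephrine 0.01 mg/kg.",
    ],
})
scored = add_scores(demo)
print(scored["Pass"].tolist())  # [True, False, True]
```

Applied to the real report, this would be `add_scores(df)` on the `df` built in the evaluation cell.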