In [None]:
# - Tokenizes prior to training
# - Adds data collator to produce labels
# - Uses BitsAndBytesConfig + device_map auto with safe fallback
# - Memory-conscious defaults (change constants below as needed)

In [None]:
# Install deps (quiet)
! pip install -q transformers accelerate datasets bitsandbytes peft sentencepiece safetensors

In [None]:
import os
import torch

# Optional allocator setting intended to reduce CUDA memory fragmentation.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:64")

# Allow TF32 in CUDA matmuls (speed/precision trade-off).
torch.backends.cuda.matmul.allow_tf32 = True

In [None]:
# -------------------------
# User-tunable constants (everything a reader is expected to adjust lives here)
# -------------------------
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"  # smaller Qwen for 14GB GPU
MAX_LENGTH = 1024        # max tokens per example at tokenization; try 512 if you still hit OOM
LORA_R = 8               # LoRA rank; lower to 4 to save more memory
TRAIN_STEPS = 600        # used as max_steps in TrainingArguments
PER_DEVICE_BATCH = 1     # used as per_device_train_batch_size
GRAD_ACCUM_STEPS = 1     # used as gradient_accumulation_steps
OUTPUT_DIR = "./qwen-medquad-output"  # checkpoints, adapters and the merged model are written here
CSV_PATH = "/content/medquad.csv"   # path to your CSV
SEED = 42                # random seed used for reproducibility


In [None]:
# -------------------------
# 0) Reproducibility
# -------------------------
# torch.manual_seed alone leaves Python's `random` and NumPy unseeded.
# transformers.set_seed seeds `random`, NumPy and torch (CPU and CUDA) in one call.
from transformers import set_seed

set_seed(SEED)


In [None]:
# -------------------------
# 1) Load dataset
# -------------------------
from datasets import load_dataset

# Register the CSV as the "train" split and request that split directly,
# so `dataset` is bound to the Dataset itself rather than a DatasetDict.
dataset = load_dataset("csv", data_files={"train": CSV_PATH}, split="train")
print("Rows:", len(dataset))
print("Example:", dataset[0])

In [None]:
# -------------------------
# 2) Format chat entries
# -------------------------
def format_chat(example):
    """Add a ``text`` field holding one user/assistant exchange for this row.

    The row's ``question`` and ``answer`` values are read (a missing key or a
    falsy value such as ``None`` becomes an empty string) and wrapped in
    ``<|im_start|>`` / ``<|im_end|>`` turn markers. The row is modified in
    place and also returned, which is the shape ``Dataset.map`` expects.
    """
    question = example.get("question") or ""
    answer = example.get("answer") or ""

    user_turn = (
        "<|im_start|>user\n"
        "You are a safe and helpful medical assistant.\n"
        f"Question: {question}\n"
        "<|im_end|>\n"
    )
    assistant_turn = "<|im_start|>assistant\n" + f"{answer}\n" + "<|im_end|>"

    example["text"] = user_turn + assistant_turn
    return example

# Build the "text" column from each row's question/answer.
dataset = dataset.map(format_chat)
# keep only text column
dataset = dataset.remove_columns([c for c in dataset.column_names if c != "text"])
# Hold out 2% for evaluation. The shuffle is seeded so the train/eval
# partition is identical on every run instead of changing each time.
dataset = dataset.train_test_split(test_size=0.02, seed=SEED)
print("Train / Eval sizes:", len(dataset["train"]), len(dataset["test"]))

In [None]:
# -------------------------
# 3) Load tokenizer
# -------------------------
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,  # permits custom code from the model repo to be used when loading
)

# Batching needs a pad token: reuse EOS when the tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:

# -------------------------
# 4) Tokenize dataset
# -------------------------
def tokenize_fn(batch):
    """Tokenize the ``text`` entries of a batch.

    Sequences are truncated to ``MAX_LENGTH`` tokens and are not padded
    (``padding=False``), so each example keeps its own length.
    """
    encode_options = {
        "padding": False,
        "truncation": True,
        "max_length": MAX_LENGTH,
    }
    return tokenizer(batch["text"], **encode_options)

# Tokenize in batches and drop the raw "text" column from the result.
tokenized_dataset = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])

# Optionally do a tiny smoke test subset
# tokenized_dataset["train"] = tokenized_dataset["train"].select(range(200))
# tokenized_dataset["test"] = tokenized_dataset["test"].select(range(50))

In [None]:
# -------------------------
# 5) Prepare BitsAndBytesConfig + load model safely
# -------------------------
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Quantization settings handed to from_pretrained below:
# 4-bit NF4 weights with double quantization, computing in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    # compute dtype: try "bfloat16" on hardware that supports it, else "float16"
    bnb_4bit_compute_dtype="float16",
    # Nominally an int8 option; kept as a fallback for the case where
    # device_map="auto" wants to offload some modules to CPU/disk.
    llm_int8_enable_fp32_cpu_offload=True,
)

# Load the quantized model. First let `device_map="auto"` decide placement;
# if that raises, retry with everything pinned to a single explicit device.
# Both attempts share every argument except the device map.
load_kwargs = dict(quantization_config=bnb_config, trust_remote_code=True)

try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        **load_kwargs,
    )
    print("Loaded model with device_map='auto'")
except Exception as e:
    # Deliberately broad: any failure of the automatic placement triggers the retry.
    # If the retry fails too, its exception propagates to the caller.
    print("Auto load failed:", e)

    # Choose the target first so the message matches the device actually used.
    if torch.cuda.is_available():
        device_map = {"": 0}
        print("Falling back to loading all on CUDA:0 (may OOM).")
    else:
        device_map = {"": "cpu"}
        # NOTE(review): 4-bit bitsandbytes loading on CPU may be unsupported,
        # depending on the installed bitsandbytes version — confirm.
        print("CUDA is not available; falling back to loading all on CPU.")

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map=device_map,
        **load_kwargs,
    )
    print("Loaded with explicit device_map:", device_map)

In [None]:
# -------------------------
# 6) Apply LoRA (PEFT)
# -------------------------
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],  # good defaults for Qwen-like
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",  # have PEFT wrap the model as a causal language model
)

# Prepare the quantized base model for training before attaching adapters.
# Per the PEFT docs this helper also enables gradient checkpointing and input
# gradients, which replaces the earlier best-effort try/except that would
# silently skip both steps on any error.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
model = get_peft_model(model, lora_config)

# Sanity check: report how many parameters are trainable after wrapping.
model.print_trainable_parameters()


In [None]:

# -------------------------
# 7) Data collator -> creates labels for causal LM
# -------------------------
from transformers import DataCollatorForLanguageModeling

# mlm=False selects causal-LM collation, which is what produces the labels for training.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)


In [None]:
# -------------------------
# 8) Trainer + TrainingArguments
# -------------------------
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=PER_DEVICE_BATCH,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    max_steps=TRAIN_STEPS,
    learning_rate=2e-4,
    # fp16 mixed precision needs a GPU; stay in full precision when the model
    # loader had to fall back to CPU.
    fp16=torch.cuda.is_available(),
    gradient_checkpointing=True,
    logging_steps=20,
    save_steps=200,
    # Tie the Trainer's own seeding to the notebook-level SEED so changing the
    # constant actually changes (and fixes) the training randomness.
    seed=SEED,
    report_to="none",
    # remove_unused_columns default True is fine since we tokenized
)

# free some GPU memory before Trainer allocs
torch.cuda.empty_cache()

# Newer transformers releases accept the tokenizer as `processing_class` and
# deprecate the `tokenizer` keyword. Pick whichever name the installed
# Trainer accepts so the cell works on both older and newer versions.
import inspect

_trainer_params = inspect.signature(Trainer.__init__).parameters
_tokenizer_kwarg = "processing_class" if "processing_class" in _trainer_params else "tokenizer"

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    # NOTE: training_args sets no evaluation strategy, so this split is only
    # used if trainer.evaluate() is called explicitly.
    eval_dataset=tokenized_dataset["test"],   # set to None to skip eval if OOM
    data_collator=data_collator,
    **{_tokenizer_kwarg: tokenizer},
)

In [None]:
# -------------------------
# 9) Train
# -------------------------
# Run fine-tuning with the Trainer built above (bounded by max_steps=TRAIN_STEPS).
# As the cell's last expression, the value returned by train() is shown as cell output.
trainer.train()

In [None]:
# -------------------------
# 10) Merge adapters and save
# -------------------------
# PEFT: merge adapters (makes a standalone model) and save the merged model & tokenizer.
# If merging or saving the merged model fails, fall back to saving the PEFT
# weights and tokenizer directly under OUTPUT_DIR.
merged_dir = os.path.join(OUTPUT_DIR, "merged")
try:
    merged = model.merge_and_unload()
    merged.save_pretrained(merged_dir)
    tokenizer.save_pretrained(merged_dir)
    print("Saved merged model at", merged_dir)
except Exception as e:
    # Report the actual failure instead of guessing at the cause.
    print(f"Warning: could not merge adapters ({type(e).__name__}: {e}). Saving PEFT weights only.")
    model.save_pretrained(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)


In [None]:
# -------------------------
# 11) Quick inference sanity-check
# -------------------------
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load from the "merged" sub-directory when it exists; otherwise load from OUTPUT_DIR itself.
merged_path = os.path.join(OUTPUT_DIR, "merged")
if os.path.isdir(merged_path):
    mp = merged_path
else:
    mp = OUTPUT_DIR

tok = AutoTokenizer.from_pretrained(mp, trust_remote_code=True)
m = AutoModelForCausalLM.from_pretrained(
    mp,
    device_map="auto",
    trust_remote_code=True,
)

def ask_medical(question: str) -> str:
    """Generate an answer to ``question`` with the reloaded model.

    The prompt uses the same layout as the training text (system-style
    instruction and ``Question:`` prefix inside the user turn), ending with an
    open assistant turn for the model to complete.

    Args:
        question: The medical question to ask.

    Returns:
        The decoded text of the newly generated tokens only, stripped of
        surrounding whitespace.
    """
    prompt = (
        "<|im_start|>user\n"
        "You are a safe and helpful medical assistant.\n"
        f"Question: {question}\n"
        "<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    inputs = tok(prompt, return_tensors="pt").to(m.device)

    # Greedy decoding: `temperature` is not passed because it is only
    # meaningful when sampling is enabled.
    out = m.generate(**inputs, max_new_tokens=256, do_sample=False)

    # The generated sequence starts with the prompt tokens, so slice them off
    # by length. Splitting the decoded text on "<|im_start|>assistant" cannot
    # work, because skip_special_tokens=True removes that marker from the text.
    prompt_len = inputs["input_ids"].shape[1]
    answer_ids = out[0][prompt_len:]
    return tok.decode(answer_ids, skip_special_tokens=True).strip()

# Smoke test: ask one sample question and print the text returned by ask_medical.
print(ask_medical("What are the symptoms of iron deficiency anemia?"))