# Setup


Make sure to select Kaggle's dual-T4 (T4 ×2) GPU accelerator for training. The older P100 GPUs may run into issues with some torch-dependent libraries.

In [None]:
import os
import wandb
from kaggle_secrets import UserSecretsClient
from typing import Literal

# Training objective used by the fine-tuning cell; it also appears in the
# name of the directory the fine-tuned model is saved to.
TRAIN_TYPE: Literal["SFT", "DPO"] = "DPO"  # SFT or DPO

# When True, the evaluation cell launches lm_eval through `accelerate launch`.
USE_ACCELERATE = True # set to True if you have multiple GPUs; it speeds up evaluation
# Passed to the trainer as `max_steps`.
MAX_STEPS = 1  # set to 500 after debugging
# Passed to lm_eval as `--limit` (the flag is omitted when this is None).
LIMIT = 10  # set to None after debugging

# Best-effort Weights & Biases authentication: fetch the API key from Kaggle
# Secrets, export it for later cells and log in. Any failure is reported on
# stdout instead of being raised, so the rest of the cell still runs.
try:
    user_secrets = UserSecretsClient()
    wandb_api_key = user_secrets.get_secret("WANDB_API_KEY")
    os.environ["WANDB_API_KEY"] = wandb_api_key
    wandb.login(key=wandb_api_key)
except Exception as err:
    print(
        "WANDB_API_KEY not set or failed to load.",
        f"Reason: {err}",
        "In Kaggle, add it via Add-ons → Secrets → Add Secret.",
        sep="\n",
    )

# Base checkpoints (4-bit bitsandbytes builds, per their names), ordered 1B, 3B, 8B.
MODEL_BASE_UIDS = [
    "unsloth/llama-3.2-1B-bnb-4bit",
    "unsloth/llama-3.2-3B-bnb-4bit",
    "unsloth/Meta-Llama-3.1-8B-bnb-4bit",
]
# These are our fine-tuned SFT models, listed in the same 1B / 3B / 8B order
# as MODEL_BASE_UIDS.
MODEL_SFT_UIDS = [
    "alextsiak/llama-3.2-1B-bnb-4bit-mix-500st",
    "rodionzorin/llama-3.2-3B-bnb-4bit_finetuned_tulu-3-sft-mixture",
    "mzarev/Llama-3.1-8B-bnb-4bit-mix-500st",
]
# Individual Tulu 3 SFT persona datasets.
DATASET_SFT_UIDS = [
    "allenai/tulu-3-sft-personas-math-grade",
    "allenai/tulu-3-sft-personas-math",
    "allenai/tulu-3-sft-personas-instruction-following",
    "allenai/tulu-3-sft-personas-algebra",
    "allenai/tulu-3-sft-personas-code",
]
# The combined Tulu 3 SFT mixture.
DATASET_SFT_MIXTURE_UIDS = ["allenai/tulu-3-sft-mixture"]
# Tulu 3 preference mixtures (candidates for DPO training).
DATASET_DPO_UIDS = [
    "allenai/llama-3.1-tulu-3-8b-preference-mixture",
    "allenai/llama-3.1-tulu-3-70b-preference-mixture",
    "allenai/llama-3.1-tulu-3-405b-preference-mixture",
]
# Task names handed to lm_eval via `--tasks`; they are also the score columns
# of the Excel template.
LM_EVAL_UIDS = [
    "hellaswag",
    "gsm8k",
    "arc_easy",
    "truthfulqa",
    "winogrande",
    "humaneval",
]


# MODEL_BASE_UIDS and MODEL_SFT_UIDS are parallel lists (0 = 1B, 1 = 3B, 2 = 8B).
# Selecting both through one index keeps the base model and the fine-tuned SFT
# model on the same size, so they cannot drift apart when only one is edited.
MODEL_INDEX = 1

MODEL_BASE_UID = MODEL_BASE_UIDS[MODEL_INDEX]
# Only read by the DPO branch of the fine-tuning cell.
MODEL_FINETUNED_UID = MODEL_SFT_UIDS[MODEL_INDEX]

# NOTE: the SFT branch of the fine-tuning cell reads a `messages` column, so
# point DATASET_UIDS at one of the DATASET_SFT_* lists when TRAIN_TYPE == "SFT".
DATASET_UIDS = DATASET_DPO_UIDS  # make sure this represents the datasets you're currently interested in
DATASET_UID = DATASET_UIDS[2]  # choose your dataset


print("Config done...")

In [None]:
!git clone -q https://github.com/EleutherAI/lm-evaluation-harness.git
!pip install -q -e ./lm-evaluation-harness/.
!pip install -q unsloth transformers datasets wandb pandas

# Fine-Tuning


In [None]:
from unsloth import FastLanguageModel
import wandb
from transformers import BitsAndBytesConfig
from trl import SFTTrainer, SFTConfig, DPOTrainer, DPOConfig
import torch
from datasets import load_dataset
import json

torch.cuda.empty_cache()

# Short names (repo id without the owner prefix); used for local directories
# and for the W&B run name.
model_name = MODEL_BASE_UID.split("/")[-1]
dataset_name = DATASET_UID.split("/")[-1]

MAX_SEQ_LENGTH = 2048

# Both training modes start from the same 4-bit base checkpoint.
base_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_BASE_UID,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=None,
    load_in_4bit=True,
)

# Keep a local copy of the base weights AND the tokenizer in both modes: the
# evaluation cell passes this directory as `pretrained=...`, so it has to be
# loadable on its own.
base_model.save_pretrained(model_name)
tokenizer.save_pretrained(model_name)

if TRAIN_TYPE == "SFT":
    # Attach fresh LoRA adapters to the base model.
    model = FastLanguageModel.get_peft_model(
        base_model,
        r=16,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=16,
        lora_dropout=0,  # Supports any, but = 0 is optimized
        bias="none",  # Supports any, but = "none" is optimized
        use_gradient_checkpointing="unsloth",  # True or "unsloth" for very long context
        max_seq_length=MAX_SEQ_LENGTH,
        use_rslora=False,  # We support rank stabilized LoRA
        loftq_config=None,  # And LoftQ
    )
else:
    # The original base model acts as the DPO reference model.
    ref_model = base_model

    # this is our LoRA-adapted SFT model (the policy that DPO keeps training)
    model, _ = FastLanguageModel.from_pretrained(
        model_name=MODEL_FINETUNED_UID,
        max_seq_length=MAX_SEQ_LENGTH,
        dtype=None,
        load_in_4bit=True,
    )

    # NOTE(review): this template only references `user` and `assistant`
    # variables and does not iterate over a message list -- confirm that this
    # matches how the trainer renders the preference data.
    tokenizer.chat_template = """<s>[INST] {{ user }} [/INST] {{ assistant }}</s>"""

    # don't want to train the reference model
    for param in ref_model.parameters():
        param.requires_grad = False

# Record the quantization settings on the model config: 4-bit NF4 weights,
# float16 compute, no double quantization.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False,
)
model.config.quantization_config = bnb_config

train_dataset = load_dataset(DATASET_UID, split="train")
first_example = train_dataset[0]
print(first_example.keys())


if TRAIN_TYPE == "SFT":
    print(json.dumps(first_example["messages"], indent=2))

    def formatting_func(examples):
        """Flatten every conversation of a batch into one training string.

        `examples["messages"]` holds one list of message dicts per example.
        The stripped `content` of each message is joined with newlines; roles
        are not included in the text. Returns the strings under the `text`
        key, which is the column the SFT trainer is configured to read.
        """

        def _flatten(convo):
            return "\n".join(turn["content"].strip() for turn in convo).strip()

        return {"text": [_flatten(convo) for convo in examples["messages"]]}

    train_dataset = train_dataset.map(formatting_func, batched=True)

# The config cell tolerates a missing WANDB_API_KEY secret, so read the key
# defensively instead of indexing os.environ (which would raise KeyError).
api_key = os.environ.get("WANDB_API_KEY")
if api_key:
    wandb.login(key=api_key)
    wandb_mode = None  # keep wandb's default mode
else:
    # Without credentials the run is recorded locally rather than aborting.
    print("WANDB_API_KEY is not set; recording this run in W&B offline mode.")
    wandb_mode = "offline"

# Start the run explicitly so the run name and the logged hyperparameters are
# under our control; the values mirror the trainer configuration.
wandb.init(
    project="pm-pt",
    name=f"{model_name}_{dataset_name}",
    mode=wandb_mode,
    config={
        "model": MODEL_BASE_UID,
        "dataset": DATASET_UID,
        "train_type": TRAIN_TYPE,
        "max_steps": MAX_STEPS,
        "learning_rate": 2e-4,
        "batch_size": 2,
        "gradient_accumulation_steps": 4,
    },
)

# Hyperparameters that are identical for the SFT and the DPO configuration.
shared_training_args = dict(
    learning_rate=2e-4,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    warmup_steps=5,
    max_steps=MAX_STEPS,
    report_to="wandb",
    run_name=f"{model_name}_{dataset_name}",
    output_dir="outputs",
    optim="adamw_8bit",
)

if TRAIN_TYPE == "SFT":
    # Supervised fine-tuning on the flattened `text` column.
    sft_config = SFTConfig(
        dataset_text_field="text",
        max_seq_length=2048,
        **shared_training_args,
    )
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        args=sft_config,
    )
else:
    # Preference optimisation of `model` against the frozen `ref_model`.
    dpo_config = DPOConfig(
        beta=0.1,
        max_length=2048,
        **shared_training_args,
    )
    trainer = DPOTrainer(
        model=model,
        ref_model=ref_model,
        args=dpo_config,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
    )

# Always close the W&B run, even when training fails or the cell is
# interrupted, so a crashed run is not left open; failures are marked with a
# non-zero exit code and then re-raised.
try:
    trainer.train()
except BaseException:
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()

# The evaluation cell rebuilds this exact directory name and loads it as the
# PEFT adapter, so the naming scheme must stay in sync with it.
finetuned_dir = f"{model_name}_{TRAIN_TYPE}_finetuned_{dataset_name}"
model.save_pretrained(finetuned_dir)
tokenizer.save_pretrained(finetuned_dir)

# Evaluation


In [None]:
import os
import subprocess

# Rebuild the directory names written by the fine-tuning cell: the local copy
# of the base model and the fine-tuned adapter directory.
model_name = MODEL_BASE_UID.rsplit("/", 1)[-1]
dataset_name = DATASET_UID.rsplit("/", 1)[-1]
peft_path = f"./{model_name}_{TRAIN_TYPE}_finetuned_{dataset_name}"

# Presumably required by code-executing tasks such as humaneval; it goes
# together with the --confirm_run_unsafe_code flag below.
os.environ["HF_ALLOW_CODE_EVAL"] = "1"

tasks_str = ",".join(LM_EVAL_UIDS)
model_args = f"pretrained={model_name},peft={peft_path}"

# Only restrict the number of examples per task when a limit is configured.
limit_args = [] if LIMIT is None else ["--limit", str(LIMIT)]

# Arguments shared by the plain and the accelerate-launched invocation.
base_args = [
    "--model", "hf",
    "--model_args", model_args,
    "--tasks", tasks_str,
    "--confirm_run_unsafe_code",
    "--device", "cuda",
    "--batch_size", "auto",
] + limit_args

# Imported here so this cell also works when the fine-tuning cell (which is
# the only other place torch is imported) has not been run in this session.
import torch

num_processes = torch.cuda.device_count()

if USE_ACCELERATE and num_processes > 1:
    # One evaluation process per visible GPU.
    command = [
        "accelerate", "launch",
        "--multi_gpu",
        f"--num_processes={num_processes}",
        "-m", "lm_eval"
    ] + base_args
else:
    if USE_ACCELERATE:
        # --multi_gpu needs at least two processes, so fall back to a plain run.
        print(
            f"USE_ACCELERATE is set but only {num_processes} GPU(s) are visible; "
            "running lm_eval without accelerate."
        )
    command = ["lm_eval"] + base_args

print(command)

# Arguments are passed as a list (no shell). Report a failed evaluation
# explicitly instead of ignoring the exit status.
result = subprocess.run(command)
if result.returncode != 0:
    print(f"lm_eval exited with status {result.returncode}; see the output above.")

# Creating Excel Sheet Template


In [None]:
import pandas as pd
from itertools import product

# One row per (base model, dataset) pair and one still-empty score column per
# evaluation task.
columns = ["model_uid", "dataset_uid"] + LM_EVAL_UIDS

model_dataset_pairs = list(product(MODEL_BASE_UIDS, DATASET_UIDS))

# Build all rows first and create the frame in one step rather than growing
# it row by row; dict.fromkeys leaves every task score as None.
rows = [
    {"model_uid": model_uid, "dataset_uid": dataset_uid, **dict.fromkeys(LM_EVAL_UIDS)}
    for model_uid, dataset_uid in model_dataset_pairs
]
empty_eval_df = pd.DataFrame(rows, columns=columns)

output_path = "empty_eval_results.xlsx"
empty_eval_df.to_excel(output_path, index=False)

# Report the file name that was actually written.
print(f"Created {output_path}")