In [None]:
import wandb
# Authenticate with Weights & Biases up front, before any data or model is loaded.
wandb.login()

In [None]:
from datasets import load_dataset

# Dataset card: https://huggingface.co/datasets/HuggingFaceTB/smoltalk
# Load the "all" configuration. The "train" and "test" splits of the result are
# tokenized further down, reading each row's "messages" column.
dataset = load_dataset("HuggingFaceTB/smoltalk", 'all')

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model card: https://huggingface.co/distilbert/distilgpt2
# The causal-LM weights and the fast tokenizer are both pulled from the same checkpoint id.
checkpoint_id = "distilbert/distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(checkpoint_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(checkpoint_id)

In [None]:
import multiprocessing
# Worker-process count handed to Dataset.map below: one per CPU reported by the OS.
num_proc = multiprocessing.cpu_count()

def chatml_tokenize(batch, max_length=None):
    """Render each conversation as one chat-formatted string and tokenize the batch.

    Every user/assistant turn is written as ``<|role|> content <eos>`` and the
    turns of a conversation are joined with single spaces. Messages whose role
    is neither "user" nor "assistant" are skipped.

    Args:
        batch: Mapping with a "messages" entry holding a list of conversations,
            each a list of dicts with "role" and "content" keys.
        max_length: Optional cap on the number of tokens kept per conversation.
            When None, the tokenizer falls back to its own ``model_max_length``.

    Returns:
        The tokenizer output for the whole batch (no padding, truncated).
    """
    role_tags = {"user": "<|user|>", "assistant": "<|assistant|>"}
    eos = tokenizer.eos_token
    texts = [
        " ".join(
            f"{role_tags[msg['role']]} {msg['content'].strip()} {eos}"
            for msg in messages
            if msg["role"] in role_tags
        )
        for messages in batch["messages"]
    ]
    # Truncate rather than emit conversations of unbounded length (the previous
    # truncation=False): inputs longer than the model's position limit cannot be
    # fed to it during training.
    return tokenizer(texts, padding=False, truncation=True, max_length=max_length)

# Register the chat role markers BEFORE tokenizing. If the datasets are
# tokenized first, "<|user|>" / "<|assistant|>" are encoded as ordinary sub-word
# pieces during training, while any later call made after the markers are
# registered encodes them as single special ids -- a train/inference mismatch.
# Registering the same tokens again later is harmless; the embedding matrix
# still has to be resized to len(tokenizer) before the model sees these ids.
tokenizer.add_special_tokens({"additional_special_tokens": ["<|user|>", "<|assistant|>"]})

# Identical mapping options for both splits: batched tokenization across
# num_proc workers, dropping the raw "messages" column afterwards.
map_options = dict(batched=True, batch_size=1000, num_proc=num_proc, remove_columns=["messages"])
tokenized_train = dataset["train"].map(chatml_tokenize, **map_options)
tokenized_test = dataset["test"].map(chatml_tokenize, **map_options)

In [None]:
import torch

# Pick the accelerator in the same priority order as before: Apple MPS, then
# CUDA, then CPU. The MPS probe goes through torch.backends.mps, which is the
# documented availability check; torch.mps.is_available is not present in every
# PyTorch release and raises AttributeError where it is missing.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(device)

In [None]:
# Register the two chat role markers as additional special tokens.
special_tokens = ["<|user|>", "<|assistant|>"]
tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
vocab_size = len(tokenizer)

# The embedding resize is done with the whole model on CPU; afterwards the
# model is moved to the selected device.
model = model.to("cpu")
model.resize_token_embeddings(vocab_size)
model = model.to(device)

# Sanity check: report where the input embeddings live and the new vocabulary size.
embedding_device = model.get_input_embeddings().weight.device
print(f"Embeddings device: {embedding_device}")
print(f"New vocab size: {vocab_size}")

In [None]:
# Evaluate WITHOUT ChatML formatting
def base_model_eval(question):
    """Complete a raw prompt (no chat markers) with at most 20 new tokens.

    Args:
        question: Prompt text passed to the tokenizer unchanged.

    Returns:
        The decoded first generated sequence with special tokens removed.
    """
    inputs = tokenizer(question, return_tensors="pt").to(device)
    output_ids = model.generate(**inputs, max_new_tokens=20)
    first_sequence = output_ids[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

# Baseline completions before any fine-tuning, separated by a "..." divider.
baseline_prompts = ("The capital of France is", "What is the capital of France?")

print("BEFORE TRAINING (Raw model):")
for position, prompt in enumerate(baseline_prompts):
    if position > 0:
        print('\n...\n')
    print(base_model_eval(prompt))

In [None]:
import random

# Draw a small evaluation subset from the tokenized test split.
# A dedicated, seeded generator keeps the subset identical across kernel
# restarts (so eval losses are comparable between runs) without touching the
# global `random` state. The size is capped at the split length because
# random.sample raises ValueError when asked for more items than exist.
EVAL_SAMPLE_SIZE = 50
EVAL_SAMPLE_SEED = 42

eval_rng = random.Random(EVAL_SAMPLE_SEED)
random_indices = eval_rng.sample(
    range(len(tokenized_test)),
    min(EVAL_SAMPLE_SIZE, len(tokenized_test)),
)

# New Dataset containing only the sampled rows.
sampled_eval_dataset = tokenized_test.select(random_indices)

In [None]:
from peft import LoraConfig

# LoRA hyperparameters.
rank_dimension = 6    # r: rank of the LoRA update matrices
lora_alpha = 12       # scaling factor for the LoRA layers (2x the rank here)
lora_dropout = 0.05   # dropout probability applied inside the LoRA layers

# NOTE(review): "all-linear" targets linear projection layers only; the
# embedding rows added for the new chat tokens are presumably not trained by
# this config -- confirm whether they need to be made trainable as well.
lora_settings = dict(
    r=rank_dimension,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias="lora_only",             # bias handling mode passed to LoRA
    target_modules="all-linear",  # which modules receive LoRA adapters
    task_type="CAUSAL_LM",        # task type for the model architecture
)
peft_config = LoraConfig(**lora_settings)

In [None]:
from transformers import DataCollatorForLanguageModeling
from trl import SFTConfig, SFTTrainer

# Memory optimization setup: recompute activations in the backward pass and
# switch off the generation KV cache while training.
model.gradient_checkpointing_enable()
model.config.use_cache = False

# Fall back to EOS as the padding token when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# NOTE(review): with pad == eos, this collator presumably masks every EOS
# position out of the labels along with the padding, so the model may never be
# trained to emit EOS -- confirm, and consider a dedicated pad token.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = SFTConfig(
    output_dir="./trainer_output",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,  # effective batch size of 4 per device
    # max_steps alone bounds the run; a positive max_steps overrides
    # num_train_epochs, so the redundant num_train_epochs=1 was dropped.
    max_steps=1000,
    learning_rate=1e-5,
    bf16=True,
    logging_steps=10,
    save_total_limit=2,  # Keep last 2 checkpoints
    save_strategy="steps",
    save_steps=100,  # Save every 100 steps
    eval_strategy="steps",
    eval_steps=100,
    dataloader_num_workers=1,
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    max_grad_norm=5,  # Gradient clipping to combat exploding gradients
    run_name="m2-lora",
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    peft_config=peft_config,  # LoRA configuration
    train_dataset=tokenized_train,
    eval_dataset=sampled_eval_dataset,
    data_collator=data_collator,
    # Hand over the tokenizer that carries the added chat tokens. Without it the
    # trainer loads a fresh tokenizer from the base checkpoint name, and that is
    # the one written next to the saved checkpoints.
    processing_class=tokenizer,
)

trainer.train()

In [None]:
# Evaluate WITH ChatML formatting
def chatml_eval(question):
    """Answer a question using the same chat layout the training data was rendered in.

    Args:
        question: User question; surrounding whitespace is stripped, as it is
            for message content in chatml_tokenize.

    Returns:
        The decoded first generated sequence (prompt plus up to 100 new tokens)
        with special tokens removed.
    """
    # In the training text every turn is closed by the EOS token before the
    # next role marker, so the prompt closes the user turn the same way.
    formatted_prompt = f"<|user|> {question.strip()} {tokenizer.eos_token} <|assistant|>"
    # Make sure dropout is disabled while generating, whatever mode training left the model in.
    model.eval()
    encoded = tokenizer(formatted_prompt, return_tensors="pt").to(device)
    generated = model.generate(**encoded, max_new_tokens=100)
    return tokenizer.decode(generated[0], skip_special_tokens=True)

# Same two prompts as the baseline run, now routed through the chat-formatted helper.
chat_prompts = ("The capital of France is", "What is the capital of France?")

print("\nAFTER TRAINING (ChatML-formatted):")
for position, prompt in enumerate(chat_prompts):
    if position > 0:
        print('\n...\n')
    print(chatml_eval(prompt))