In [None]:
import torch
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset, DatasetDict
from peft import LoraConfig
from transformers import AutoTokenizer, EarlyStoppingCallback, BitsAndBytesConfig

from loguru import logger

DATA_FILENAME = "./data/my_custom_data.jsonl"

models = {"phi": "microsoft/Phi-4-mini-instruct", # https://huggingface.co/microsoft/Phi-4-mini-instruct
          "smol-135M": "HuggingFaceTB/SmolLM-135M-Instruct", # https://huggingface.co/HuggingFaceTB/SmolLM-135M-Instruct
          "smol-3B": "HuggingFaceTB/SmolLM3-3B", # https://huggingface.co/HuggingFaceTB/SmolLM3-3B
}
MODEL_NAME = models["smol-135M"]

# Clear any cached memory
torch.cuda.empty_cache()

In [3]:
def load_and_prepare_data(file_path: str, cfg, seed: int = 816):
    """Load dataset and prepare for training with TRL-compatible conversational format."""
    
    # Load raw data
    raw_dataset = load_dataset("json", data_files=file_path)["train"]
    
    # Convert to conversational format that SFTTrainer expects
    def format_to_messages(example):
        messages = [
            {"role": "system", "content": cfg.system_prompt},
            {"role": "user", "content": example["instruction"]},
            {"role": "assistant", "content": example["response"]},
        ]
        return {"messages": messages}
    
    # Apply the formatting
    dataset = raw_dataset.map(format_to_messages, remove_columns=raw_dataset.column_names)
    
    # Split data
    test_size = cfg.test_split + cfg.validation_split
    val_ratio = cfg.validation_split / test_size
    train_val = dataset.train_test_split(test_size=test_size, seed=seed)
    val_test = train_val["test"].train_test_split(test_size=val_ratio, seed=seed)
    
    final_dataset = DatasetDict({
        "train": train_val["train"],
        "validation": val_test["train"],  
        "test": val_test["test"]
    })
    
    logger.info(f"Train: {len(final_dataset['train'])}, Val: {len(final_dataset['validation'])}, Test: {len(final_dataset['test'])}")
    return final_dataset

In [4]:
# Configuration
class Config:
    def __init__(self):
        self.seed = 816
        self.system_prompt = (
            "You are a careful medical assistant providing evidence-based information. "
            "Always end with 'This response was generated by AI. Please check with medical practitioners.'"
        )
        
        # Data settings
        self.data = type('DataConfig', (), {
            'train_file': './data/my_custom_data.jsonl',
            'test_split': 0.1,
            'validation_split': 0.1,
            'system_prompt': self.system_prompt
        })()

        self.output_directory = "my_great_llm_model"

cfg = Config()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load and prepare the dataset in the format SFTTrainer expects
dataset = load_and_prepare_data(DATA_FILENAME, cfg.data, cfg.seed)

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/565 [00:00<?, ?B/s]

[32m2026-01-04 23:01:35.521[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_and_prepare_data[0m:[36m31[0m - [1mTrain: 80, Val: 10, Test: 10[0m


In [5]:
# Configure PEFT with LoRA - optimized for better generalization
peft_config = LoraConfig(
    r=32,                         # Increased LoRA rank for better capacity 
    lora_alpha=64,               # Increased scaling factor proportionally 
    lora_dropout=0.1,            # Increased dropout for better regularization 
    bias="none",                 # Bias training strategy
    task_type="CAUSAL_LM",       # Task type for causal language modeling
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],  # Added more target modules
)

In [6]:
# Configure BitsAndBytes for memory-efficient quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                    # Enable 4-bit quantization
    bnb_4bit_quant_type="nf4",           # Use nf4 quantization type (recommended)
    bnb_4bit_compute_dtype=torch.bfloat16, # Computation dtype
    bnb_4bit_use_double_quant=True,      # Use double quantization for even better memory savings
    #load_in_8bit=True,                 # Alternative: use 8-bit instead of 4-bit (uncomment if preferred)
)

In [7]:
training_args = SFTConfig(
    model_init_kwargs={
        "quantization_config": bnb_config,   # Add BitsAndBytes quantization
        "device_map": None,                  # Let the trainer handle device placement
        "attn_implementation": "flash_attention_2",  # Enable flash attention for packing compatibility
        "dtype": torch.bfloat16,       # Ensure consistent dtype for flash attention
    },
    packing=True,
    num_train_epochs=100,               # Reduced from 100 to prevent overfitting
    learning_rate=5e-5,               # Reduced learning rate (was 2e-4)
    lr_scheduler_type="cosine",       # Cosine annealing for better convergence
    warmup_ratio=0.1,                 # Warmup for 10% of training steps
    weight_decay=0.01,                # Added weight decay for regularization
    loss_type="dft",                  # Dynamic fine tuning to scale loss by probability of token (https://arxiv.org/pdf/2508.05629)
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    auto_find_batch_size=True,
    output_dir=cfg.output_directory, 
    max_length=1024,
    bf16=True,                             # Enable bfloat16 training for consistency with flash attention
    dataloader_pin_memory=False,           # Disable pin memory to avoid dtype conflicts
    
    # Evaluation configuration
    eval_strategy="steps",             # Evaluate every eval_steps
    eval_steps=25,                     # More frequent evaluation (was 50)
    save_strategy="steps",             # Save model every save_steps
    save_steps=25,                     # More frequent saves (was 50)
    logging_steps=5,                   # More frequent logging (was 10)
    load_best_model_at_end=True,       # Load the best model at the end of training
    metric_for_best_model="eval_loss", # Use validation loss to determine best model
    greater_is_better=False,           # Lower loss is better
    save_total_limit=3,                # Reduced checkpoint limit (was 5)
)

In [8]:
# Import early stopping callback
from transformers import EarlyStoppingCallback

# Initialize early stopping callback
early_stopping = EarlyStoppingCallback(
    early_stopping_patience=5,      # Stop after 5 evaluations without improvement
    early_stopping_threshold=0.001   # Minimum improvement required
)

In [9]:
# Initialize trainer with PEFT configuration
# SFTTrainer will automatically handle the conversational format and apply chat templates
trainer = SFTTrainer(
    model=MODEL_NAME,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    processing_class=tokenizer,  # Pass tokenizer as processing_class
    peft_config=peft_config,
    args=training_args,
    callbacks=[early_stopping],  # Use our configured early stopping callback
    #completion_only=True,  # Only calculate loss on the assistant's response, not the full conversation
)

config.json:   0%|          | 0.00/723 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/269M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

Tokenizing train dataset:   0%|          | 0/80 [00:00<?, ? examples/s]

Packing train dataset:   0%|          | 0/80 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/10 [00:00<?, ? examples/s]

Packing eval dataset:   0%|          | 0/10 [00:00<?, ? examples/s]

In [10]:
print(f"GPU memory allocated: {torch.cuda.memory_allocated()/1024**3:.2f} GB")
print(f"GPU memory reserved : {torch.cuda.memory_reserved()/1024**3:.2f} GB")
print(f"PEFT config: LoRA with rank={peft_config.r}, alpha={peft_config.lora_alpha}")
print(f"Target modules: {peft_config.target_modules}")

GPU memory allocated: 0.12 GB
GPU memory reserved : 0.22 GB
PEFT config: LoRA with rank=32, alpha=64
Target modules: {'down_proj', 'k_proj', 'gate_proj', 'o_proj', 'v_proj', 'q_proj', 'up_proj'}


In [11]:
trainer.train()

Step,Training Loss,Validation Loss,Entropy,Num Tokens,Mean Token Accuracy
25,0.1291,0.127001,2.591401,87979.0,0.394315
50,0.1191,0.113858,2.107718,176387.0,0.419135
75,0.0843,0.081411,1.369299,264428.0,0.445556
100,0.0666,0.067269,1.106834,352689.0,0.466373
125,0.0561,0.059532,0.962033,440741.0,0.481986
150,0.0487,0.056064,0.910817,526925.0,0.488791
175,0.0423,0.052892,0.863394,615106.0,0.497198
200,0.0387,0.050877,0.825032,703026.0,0.5004
225,0.0352,0.048552,0.788505,791237.0,0.500801
250,0.032,0.046988,0.764803,879599.0,0.500801


TrainOutput(global_step=375, training_loss=0.0547679342230161, metrics={'train_runtime': 763.3183, 'train_samples_per_second': 2.882, 'train_steps_per_second': 0.786, 'total_flos': 917212849233408.0, 'train_loss': 0.0547679342230161, 'epoch': 62.54545454545455})

In [12]:
# Inference using TRL - directly from the trainer
from transformers import GenerationConfig

# Use the trained model directly from trainer
model = trainer.model
tokenizer = trainer.processing_class

# Example question for inference
question = "What are the common symptoms of diabetes?"

# Format the input using the same system prompt as a conversational format
messages = [
    {"role": "system", "content": cfg.system_prompt},
    {"role": "user", "content": question},
]

# Apply chat template (SFTTrainer should have set this up automatically)
formatted_input = tokenizer.apply_chat_template(
    messages, 
    tokenize=False, 
    add_generation_prompt=True
)

# Tokenize input
inputs = tokenizer(
    formatted_input, 
    return_tensors="pt"
).to(model.device)

# Better generation configuration to prevent repetition
generation_config = GenerationConfig(
    max_new_tokens=150,
    temperature=0.8,
    top_p=0.9,
    top_k=50,
    do_sample=True,
    repetition_penalty=1.1,
    no_repeat_ngram_size=3,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

print("Generating response...")

# Generate response using the LoRA-tuned model
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        generation_config=generation_config
    )

# Better approach: Extract only the newly generated tokens
input_length = inputs.input_ids.shape[1]
generated_tokens = outputs[0][input_length:]
generated_response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

print("=" * 60)
print("QUESTION:")
print(question)
print("\n" + "=" * 60)
print("AI ASSISTANT RESPONSE:")
print(generated_response if generated_response else "[NO RESPONSE GENERATED]")
print("=" * 60)

# Debug: Show the formatted input for comparison
print("\n" + "=" * 30 + " DEBUG " + "=" * 30)
print("FORMATTED INPUT:")
print(repr(formatted_input))
print(f"\nInput token length: {input_length}")
print(f"Total output tokens: {outputs[0].shape[0]}")
print(f"Generated tokens: {len(generated_tokens)}")

Generating response...
QUESTION:
What are the common symptoms of diabetes?

AI ASSISTANT RESPONSE:
Diabetes is often symptomless, but it can cause:

1. **Weight loss**: Difficulty eating or gaining weight.
2. ** Hunger**: Irritability, malnutrition, or undereating.
3 ** pedal numbness** (tingling in the hands and feet).
4 **blindness** ( cataracts, glaucoma, or Kane's disease).
5 **hearing loss** ( deafness, vertigo, or tinnitus).
6 **vision loss**( cataracts, macular degeneration).
7 **diabetic retinopathy** ( retinal damage).
8 **nerve neuropathy** ( nerve damage).Course: Understanding Financial Markets and Institutions

Welcome to this course on financial markets and institutions! You've likely heard about banks, stock exchanges

FORMATTED INPUT:
"<|im_start|>system\nYou are a careful medical assistant providing evidence-based information. Always end with 'This response was generated by AI. Please check with medical practitioners.'<|im_end|>\n<|im_start|>user\nWhat are the common sy