# Fine-tune 3 Models with Llama-3.1-8B-Instruct

Trains 3 LoRA fine-tunes of the same base model, sequentially:
1. **Combined**: Balanced dataset (10k one-liners + 10k short stories)
2. **One-Liner**: Full dataset (~2k)
3. **Short-Story**: Full dataset (~50k)

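**Prerequisites:** Run `prepare_training_data.ipynb` first to generate the JSONL files this notebook reads from `../data/train`: `{name}_train.jsonl` and `{name}_val.jsonl` for each of `combined`, `one_liner`, and `short_story`.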

## Setup

In [None]:
import gc
import sys
from pathlib import Path

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    DataCollatorForSeq2Seq,
)
from peft import LoraConfig, get_peft_model

# Make modules in the parent of the current working directory importable.
sys.path.append('..')

# Directory holding the *_train.jsonl / *_val.jsonl splits, and the base
# checkpoint that every fine-tune in this notebook starts from.
TRAIN_DIR = Path("../data/train")
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"

for setting_label, setting_value in (
    ("Using model", MODEL_NAME),
    ("Training data directory", TRAIN_DIR),
):
    print(f"{setting_label}: {setting_value}")

## Define Fine-tuning Function

In [None]:
def finetune_model(dataset_name, output_name, num_epochs=3, max_seq_length=1024):
    """
    Fine-tune Llama-3.1-8B-Instruct with LoRA.

    Loads ``{dataset_name}_train.jsonl`` and ``{dataset_name}_val.jsonl`` from
    ``TRAIN_DIR``. Each record needs a ``messages`` field that the tokenizer's
    chat template accepts. Trains a LoRA adapter on top of ``MODEL_NAME`` and
    saves the adapter plus tokenizer to ``./models/{output_name}``.

    Args:
        dataset_name: Name of dataset (e.g., 'combined', 'one_liner', 'short_story')
        output_name: Name for output directory (e.g., 'llama-8b-combined')
        num_epochs: Number of training epochs (default: 3)
        max_seq_length: Maximum tokens kept per example; longer examples are
            truncated (default: 1024)

    Returns:
        The output directory (str) that the adapter and tokenizer were saved to.
    """
    print(f"\n{'='*80}")
    print(f"Starting training: {output_name}")
    print(f"{'='*80}\n")

    # Load tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Reuse EOS as the pad token. Padding is excluded from the loss by the
    # collator's label_pad_token_id (below), NOT by matching pad_token_id.
    # This means the real EOS / end-of-turn tokens inside each example are
    # still trained on and the model learns when to stop.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # Load model
    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    # The KV cache is not used in training and is incompatible with gradient
    # checkpointing (enabled in TrainingArguments below).
    model.config.use_cache = False

    # Configure LoRA
    print("Configuring LoRA...")
    lora_config = LoraConfig(
        r=64,
        lora_alpha=128,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Load datasets. Each JSONL file is loaded as a single "train" split, which
    # is why split='train' is used for the validation file too.
    print(f"Loading {dataset_name} dataset...")
    train_dataset = load_dataset('json', data_files=str(TRAIN_DIR / f"{dataset_name}_train.jsonl"), split='train')
    val_dataset = load_dataset('json', data_files=str(TRAIN_DIR / f"{dataset_name}_val.jsonl"), split='train')

    print(f"Train examples: {len(train_dataset):,}")
    print(f"Val examples: {len(val_dataset):,}")

    # Format and tokenize
    print("Formatting and tokenizing...")
    def format_chat_template(example):
        # Render the conversation to a single string; no generation prompt
        # because the assistant reply is already part of `messages`.
        messages = example['messages']
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
        return {"text": text}

    def tokenize_function(examples):
        # The rendered chat template already begins with the BOS token, so
        # don't let the tokenizer prepend a second one.
        outputs = tokenizer(
            examples["text"],
            truncation=True,
            max_length=max_seq_length,
            add_special_tokens=False,
        )
        # Causal-LM targets are the input ids themselves (the model shifts them
        # internally). No padding happens here; the collator pads per batch
        # and fills label padding with -100 so it is ignored by the loss.
        outputs["labels"] = [list(ids) for ids in outputs["input_ids"]]
        return outputs

    train_dataset = train_dataset.map(format_chat_template, remove_columns=train_dataset.column_names)
    val_dataset = val_dataset.map(format_chat_template, remove_columns=val_dataset.column_names)

    train_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
    val_dataset = val_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    # Training arguments
    output_dir = f"./models/{output_name}"
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=2e-4,
        warmup_ratio=0.03,
        logging_steps=25,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        load_best_model_at_end=True,
        bf16=True,
        report_to="none",
        gradient_checkpointing=True,
    )

    # Create trainer. DataCollatorForSeq2Seq pads input_ids/attention_mask
    # dynamically to the longest example in each batch and pads the
    # precomputed labels with -100. Unlike DataCollatorForLanguageModeling,
    # it does not overwrite labels by masking every pad_token_id (== EOS here).
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        padding=True,
        label_pad_token_id=-100,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
    )

    # Train
    print("\nStarting training...\n")
    trainer.train()

    # Save. With load_best_model_at_end=True the trainer has already loaded the
    # best checkpoint's weights into `model` before this point.
    print(f"\nSaving model to {output_dir}...")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    print(f"\n✓ Completed: {output_name}\n")

    # Clean up VRAM before the next sequential run. gc.collect() clears any
    # reference cycles still holding the model/optimizer, so that
    # empty_cache() can actually return the memory.
    del model
    del trainer
    gc.collect()
    torch.cuda.empty_cache()

    return output_dir

## Train All 3 Models Sequentially

In [None]:
# Model 1: trained on the balanced "combined" dataset (one-liners + short stories).
model_1 = finetune_model("combined", "llama-8b-combined", num_epochs=3)

In [None]:
# Model 2: trained on the one-liner dataset alone.
model_2 = finetune_model("one_liner", "llama-8b-one-liner", num_epochs=3)

In [None]:
# Model 3: trained on the short-story dataset alone.
model_3 = finetune_model("short_story", "llama-8b-short-story", num_epochs=3)

## Summary

In [None]:
# Report where each adapter was saved. model_1..model_3 hold the output
# directories returned by finetune_model in the training cells.
divider = "=" * 80
print("\n" + divider)
print("ALL MODELS TRAINED SUCCESSFULLY")
print(divider)
saved_dirs = (model_1, model_2, model_3)
print("\n" + "\n".join(f"{rank}. {path}" for rank, path in enumerate(saved_dirs, start=1)))
print("\nNext steps: Test the models with inference code.")