# Lightweight Fine-tuning: Gemma-2B on TinyStories

This notebook demonstrates fine-tuning Google's Gemma-2B model on the TinyStories dataset using QLoRA. This is a lightweight example that should train quickly on most GPUs.

## 1. Setup and Installation

First, check GPU availability and install dependencies.

In [None]:
# Check GPU availability
!nvidia-smi

In [None]:
# Clone the repository
!git clone https://github.com/vmm/llm-trainer.git
%cd llm-trainer

In [None]:
# Install dependencies
!pip install -r requirements.txt

In [None]:
# Fix module import issues
import os
import sys

# Check and fix the working directory
if not os.path.exists('src'):
    # If we're not in the repo root, try to find it
    if os.path.exists('llm-trainer'):
        %cd llm-trainer
    else:
        # If we can't find it, raise an error
        raise FileNotFoundError("Cannot find repository root directory with 'src' folder")

# Add the current directory to Python's path
sys.path.append('.')
print(f"Working directory: {os.getcwd()}")
print(f"Python path includes current directory: {'./' in sys.path or '.' in sys.path}")

In [None]:
# Mount Google Drive for persistent storage
from google.colab import drive
drive.mount('/content/drive')

# Configure output directory in Google Drive (change this to your preferred location)
DRIVE_OUTPUT_DIR = "llm-trainer-output"  # Will be created under /content/drive/MyDrive/

# Full path to the output directory
DRIVE_BASE_PATH = f"/content/drive/MyDrive/{DRIVE_OUTPUT_DIR}"

# Specific paths for different components
DRIVE_DATASET_PATH = f"{DRIVE_BASE_PATH}/datasets/tinystories_processed"
DRIVE_MODEL_PATH = f"{DRIVE_BASE_PATH}/models/gemma_tinystories"
DRIVE_EVAL_PATH = f"{DRIVE_BASE_PATH}/evaluation/tinystories_results"
DRIVE_ADAPTER_PATH = f"{DRIVE_BASE_PATH}/lora_adapter"
DRIVE_ADAPTER_ZIP = f"{DRIVE_BASE_PATH}/gemma_tinystories_adapter.zip"

# Create project directories in Drive
!mkdir -p {DRIVE_BASE_PATH}/datasets
!mkdir -p {DRIVE_BASE_PATH}/models
!mkdir -p {DRIVE_BASE_PATH}/evaluation
!mkdir -p {DRIVE_BASE_PATH}/logs

print(f"All outputs will be saved to Google Drive under: {DRIVE_BASE_PATH}")

## 2. Authenticate with Hugging Face

Authenticate to access the Gemma model from Hugging Face.

In [None]:
# Authenticate with Hugging Face
import os
from huggingface_hub import login

# Replace with your actual token
HF_TOKEN = "your_huggingface_token_here"  

# Log in to Hugging Face
login(token=HF_TOKEN)

# Set environment variable for other libraries
os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
os.environ["HF_TOKEN"] = HF_TOKEN

In [None]:
# Set up logging and checkpoint saving to Google Drive
import time
import threading
import os
import datetime
import sys
import logging

# Create a log directory in Drive
DRIVE_LOG_PATH = f"{DRIVE_BASE_PATH}/logs"
!mkdir -p {DRIVE_LOG_PATH}

# Set up logging to both console and file
log_file = f"{DRIVE_LOG_PATH}/training_log_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler(sys.stdout)
    ]
)

print(f"Logging enabled to {log_file}")
logging.info(f"Logging enabled to {log_file}")

# Function to save checkpoints to Google Drive
def save_checkpoint_periodically(interval=300):  # 300 seconds = 5 minutes
    while True:
        time.sleep(interval)
        timestamp = time.strftime('%Y-%m-%d %H:%M:%S')
        print(f"\n[{timestamp}] Saving checkpoint to Google Drive...")
        logging.info(f"[{timestamp}] Saving checkpoint to Google Drive...")
        
        # Ensure directories exist
        !mkdir -p {DRIVE_BASE_PATH}/models 2>/dev/null || true
        !mkdir -p {DRIVE_BASE_PATH}/datasets 2>/dev/null || true
        !mkdir -p {DRIVE_BASE_PATH}/logs 2>/dev/null || true
        
        # Copy training state file to know where we left off
        if os.path.exists('output/gemma_tinystories'):
            # Check for training state files
            if os.path.exists('output/gemma_tinystories/trainer_state.json'):
                !cp output/gemma_tinystories/trainer_state.json {DRIVE_MODEL_PATH}/
                print(f"  - Saved trainer state file")
                logging.info(f"  - Saved trainer state file")
            
            # Check for checkpoint directories
            checkpoints = !ls -d output/gemma_tinystories/checkpoint-* 2>/dev/null || true
            if checkpoints:
                for checkpoint in checkpoints:
                    checkpoint_name = os.path.basename(checkpoint)
                    checkpoint_drive_path = f"{DRIVE_MODEL_PATH}/{checkpoint_name}"
                    # Only copy if it doesn't exist or is newer
                    if not os.path.exists(checkpoint_drive_path):
                        !mkdir -p {checkpoint_drive_path}
                        !cp -r {checkpoint}/* {checkpoint_drive_path}/
                        print(f"  - Saved new checkpoint: {checkpoint_name}")
                        logging.info(f"  - Saved new checkpoint: {checkpoint_name}")
            
            # Check for adapter model
            if os.path.exists('output/gemma_tinystories/adapter_model'):
                adapter_drive_path = f"{DRIVE_MODEL_PATH}/adapter_model"
                !mkdir -p {adapter_drive_path}
                !cp -r output/gemma_tinystories/adapter_model/* {adapter_drive_path}/
                print(f"  - Saved adapter model")
                logging.info(f"  - Saved adapter model")
        
        # Save any processed datasets
        datasets = !ls -d data/*_processed 2>/dev/null || true
        for dataset in datasets:
            dataset_name = os.path.basename(dataset)
            dataset_drive_path = f"{DRIVE_BASE_PATH}/datasets/{dataset_name}"
            if not os.path.exists(dataset_drive_path):
                !mkdir -p {dataset_drive_path}
                !cp -r {dataset}/* {dataset_drive_path}/
                print(f"  - Saved dataset: {dataset_name}")
                logging.info(f"  - Saved dataset: {dataset_name}")
                
        # Update the log file (copy the most recent version)
        !cp {log_file} {DRIVE_LOG_PATH}/
        
        print(f"[{timestamp}] Checkpoint save completed")
        logging.info(f"[{timestamp}] Checkpoint save completed")

# Start the checkpoint thread
checkpoint_thread = threading.Thread(target=save_checkpoint_periodically, daemon=True)
checkpoint_thread.start()
print(f"Automatic checkpointing to Drive enabled (every 5 minutes)")
logging.info(f"Automatic checkpointing to Drive enabled (every 5 minutes)")
print(f"All outputs will persist in: {DRIVE_BASE_PATH}")
logging.info(f"All outputs will persist in: {DRIVE_BASE_PATH}")

## 3. Process the TinyStories Dataset

Process the dataset and prepare it for training.

In [None]:
# Check if dataset already exists in Drive
import os
from datasets import load_from_disk, DatasetDict

if os.path.exists(DRIVE_DATASET_PATH):
    print(f"Dataset found at {DRIVE_DATASET_PATH}")
    
    # Load the dataset to check for validation split
    dataset = load_from_disk(DRIVE_DATASET_PATH)
    
    # Check if validation split exists
    if 'validation' not in dataset.keys():
        print("No validation split found in dataset. Creating validation split...")
        
        # Create validation split (10% of train data)
        if 'train' in dataset:
            split_dataset = dataset['train'].train_test_split(test_size=0.1, seed=42)
            
            # Create new dataset with validation split
            updated_dataset = DatasetDict({
                'train': split_dataset['train'],
                'validation': split_dataset['test']
            })
            
            # Save the updated dataset back to the same location
            updated_dataset.save_to_disk(DRIVE_DATASET_PATH)
            print(f"✅ Created validation split from train data. Updated dataset saved to {DRIVE_DATASET_PATH}")
            
            # Update the dataset variable
            dataset = updated_dataset
    
    # Create a symlink to local directory for easier access
    !mkdir -p data
    !ln -sf {DRIVE_DATASET_PATH} data/TinyStories_processed
    print(f"✅ Using dataset from Drive: {DRIVE_DATASET_PATH}")
    
else:
    # Process the dataset and save directly to Drive
    print(f"Processing dataset and saving to {DRIVE_DATASET_PATH}...")
    !python -m src.data_processors.tinystories_processor --config configs/gemma_tinystories.yaml --output_path {DRIVE_DATASET_PATH}
    
    # Check if validation split was created during processing
    dataset = load_from_disk(DRIVE_DATASET_PATH)
    if 'validation' not in dataset.keys():
        print("No validation split was created during processing. Creating one now...")
        
        # Create validation split (10% of train data)
        if 'train' in dataset:
            split_dataset = dataset['train'].train_test_split(test_size=0.1, seed=42)
            
            # Create new dataset with validation split
            updated_dataset = DatasetDict({
                'train': split_dataset['train'],
                'validation': split_dataset['test']
            })
            
            # Save the updated dataset back to the same location
            updated_dataset.save_to_disk(DRIVE_DATASET_PATH)
            print(f"✅ Created validation split from train data. Updated dataset saved to {DRIVE_DATASET_PATH}")
    
    # Create a symlink to local directory for easier access
    !mkdir -p data
    !ln -sf {DRIVE_DATASET_PATH} data/TinyStories_processed

In [None]:
# Update config to use Google Drive for output
import yaml

with open('configs/gemma_tinystories.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Update output directory to use our Drive path
config['training']['output_dir'] = DRIVE_MODEL_PATH

# Save updated config
with open('configs/gemma_tinystories_drive.yaml', 'w') as f:
    yaml.dump(config, f)

print(f"Updated config saved to configs/gemma_tinystories_drive.yaml with output_dir={DRIVE_MODEL_PATH}")

In [None]:
# Verify the processed dataset
try:
    dataset = load_from_disk("data/TinyStories_processed")
    
    # Print info about the dataset
    print(f"Dataset splits: {dataset.keys()}")
    if 'train' in dataset:
        print(f"Train size: {len(dataset['train'])}")
    if 'validation' in dataset:
        print(f"Validation size: {len(dataset['validation'])}")
    
    # See the first example
    print("\nExample data:")
    print(dataset[list(dataset.keys())[0]][0])
except Exception as e:
    print(f"Error loading dataset: {e}")

In [None]:
# Set up heartbeat monitoring to detect session disconnections
import threading
import time
import os
import datetime

# Create a heartbeat directory
HEARTBEAT_PATH = f"{DRIVE_BASE_PATH}/heartbeat"
!mkdir -p {HEARTBEAT_PATH}

# Write initial heartbeat file
heartbeat_file = f"{HEARTBEAT_PATH}/heartbeat.txt"

def update_heartbeat():
    """Update heartbeat file every minute to track if Colab is still running"""
    while True:
        # Write current timestamp to heartbeat file
        with open(heartbeat_file, 'w') as f:
            timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            f.write(f"Last heartbeat: {timestamp}\n")
            f.write(f"If you're seeing this file, it means the Colab session was running at {timestamp}.\n")
            f.write(f"If this timestamp is old, the session likely disconnected at that time.\n")
        
        # Copy to Drive
        !cp {heartbeat_file} {HEARTBEAT_PATH}/
        
        # Wait for 60 seconds
        time.sleep(60)

# Start heartbeat thread
heartbeat_thread = threading.Thread(target=update_heartbeat, daemon=True)
heartbeat_thread.start()

print(f"Heartbeat monitoring enabled - tracking session activity at {heartbeat_file}")
print(f"If Colab disconnects, you can check when it happened by looking at this file in your Drive.")

## 4. Fine-tune with QLoRA

Fine-tune the Gemma-2B model using QLoRA.

In [None]:
# Clean up memory before training
import gc
import torch

# Clear CUDA cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print("CUDA cache cleared")
    
# Run garbage collection
gc.collect()
print("Garbage collection completed")

# Show current GPU memory usage
if torch.cuda.is_available():
    print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
    
# Print current GPU usage
!nvidia-smi | grep MiB

In [None]:
# Function to backup all artifacts in case of manual shutdown
def backup_all_training_artifacts():
    """
    Perform a complete backup of all training artifacts to Google Drive.
    Call this manually when you want to ensure everything is saved.
    """
    import os
    import time
    
    print(f"\n{'='*40}")
    print(f"PERFORMING FINAL BACKUP TO GOOGLE DRIVE")
    print(f"{'='*40}\n")
    
    # Create all required directories
    os.makedirs(f"{DRIVE_BASE_PATH}/models", exist_ok=True)
    os.makedirs(f"{DRIVE_BASE_PATH}/datasets", exist_ok=True)
    os.makedirs(f"{DRIVE_BASE_PATH}/logs", exist_ok=True)
    os.makedirs(f"{DRIVE_BASE_PATH}/evaluation", exist_ok=True)
    
    # Check for local output directory
    if os.path.exists('output'):
        # Copy all outputs (models, checkpoints, logs)
        !rsync -av --progress output/ {DRIVE_BASE_PATH}/models/ 2>/dev/null || cp -r output/* {DRIVE_BASE_PATH}/models/ 2>/dev/null || true
        print(f"✓ Backed up output directory to {DRIVE_BASE_PATH}/models/")
    
    # Check for local data directory 
    if os.path.exists('data'):
        # Copy all datasets
        !rsync -av --progress data/ {DRIVE_BASE_PATH}/datasets/ 2>/dev/null || cp -r data/* {DRIVE_BASE_PATH}/datasets/ 2>/dev/null || true
        print(f"✓ Backed up data directory to {DRIVE_BASE_PATH}/datasets/")
    
    # Check for local logs
    if os.path.exists('logs'):
        # Copy all logs
        !rsync -av --progress logs/ {DRIVE_BASE_PATH}/logs/ 2>/dev/null || cp -r logs/* {DRIVE_BASE_PATH}/logs/ 2>/dev/null || true
        print(f"✓ Backed up logs directory to {DRIVE_BASE_PATH}/logs/")
    
    # Check for local evaluation results
    if os.path.exists('evaluation'):
        # Copy all evaluation results
        !rsync -av --progress evaluation/ {DRIVE_BASE_PATH}/evaluation/ 2>/dev/null || cp -r evaluation/* {DRIVE_BASE_PATH}/evaluation/ 2>/dev/null || true
        print(f"✓ Backed up evaluation directory to {DRIVE_BASE_PATH}/evaluation/")
    
    # Copy tensorboard logs if they exist
    if os.path.exists('runs'):
        os.makedirs(f"{DRIVE_BASE_PATH}/tensorboard", exist_ok=True)
        !rsync -av --progress runs/ {DRIVE_BASE_PATH}/tensorboard/ 2>/dev/null || cp -r runs/* {DRIVE_BASE_PATH}/tensorboard/ 2>/dev/null || true
        print(f"✓ Backed up tensorboard logs to {DRIVE_BASE_PATH}/tensorboard/")
    
    print(f"\n{'='*40}")
    print(f"BACKUP COMPLETED - ALL TRAINING ARTIFACTS SAVED")
    print(f"{'='*40}\n")
    
    # List all backed up directories
    print("Contents of Drive backup directory:")
    !find {DRIVE_BASE_PATH} -type d | sort

# Register this function for manual use
print("Run 'backup_all_training_artifacts()' at any time to ensure all artifacts are backed up to Drive")

In [None]:
# Function to check for existing training state and set up for resuming
def check_for_resume_point():
    """Check if there's an existing training state to resume from"""
    import os
    import glob
    import json
    import re
    import yaml
    import torch
    import gc

    # First check if the fine-tuned model already exists (complete training)
    if os.path.exists(os.path.join(DRIVE_MODEL_PATH, "adapter_model")):
        print(f"✓ Fine-tuned model already exists at {DRIVE_MODEL_PATH}/adapter_model")
        print("Skipping training step. If you want to retrain, delete this directory from your Drive.")
        return True
    
    # If not complete, check for checkpoints to resume from
    print("Looking for checkpoints to resume training...")
    checkpoints = glob.glob(f"{DRIVE_MODEL_PATH}/checkpoint-*")
    
    if checkpoints:
        # Find the latest checkpoint by sorting (checkpoint numbers should be sequential)
        checkpoints.sort(key=lambda x: int(re.search(r'checkpoint-(\d+)', x).group(1)), reverse=True)
        latest_checkpoint = checkpoints[0]
        checkpoint_num = re.search(r'checkpoint-(\d+)', latest_checkpoint).group(1)
        
        print(f"✓ Found checkpoint: {latest_checkpoint}")
        
        # Check if trainer state exists
        trainer_state_path = os.path.join(DRIVE_MODEL_PATH, "trainer_state.json")
        if os.path.exists(trainer_state_path):
            try:
                with open(trainer_state_path, 'r') as f:
                    trainer_state = json.load(f)
                total_steps = trainer_state.get('max_steps', 'unknown')
                completed_steps = trainer_state.get('global_step', 0)
                print(f"✓ Training was at step {completed_steps}/{total_steps}")
            except Exception as e:
                print(f"Could not parse trainer state: {e}")
        
        # Create local output dir if needed
        !mkdir -p output/gemma_tinystories
        
        # Copy checkpoint to local storage for use
        local_checkpoint = f"output/gemma_tinystories/checkpoint-{checkpoint_num}"
        if not os.path.exists(local_checkpoint):
            print(f"Copying checkpoint from Drive to local storage for resuming...")
            !mkdir -p {local_checkpoint}
            !cp -r {latest_checkpoint}/* {local_checkpoint}/
        
        # Add resume flag to config
        print(f"Modifying config to resume from checkpoint...")
        with open('configs/gemma_tinystories_drive.yaml', 'r') as f:
            config = yaml.safe_load(f)
        
        # Point to local checkpoint for resuming
        config['model']['adapter_name_or_path'] = local_checkpoint
        
        with open('configs/gemma_tinystories_resume.yaml', 'w') as f:
            yaml.dump(config, f)
            
        print(f"⏳ Resuming training from checkpoint-{checkpoint_num}...")
        print(f"Model will continue saving to {DRIVE_MODEL_PATH}")
        
        try:
            !python -m src.trainers.qlora_trainer configs/gemma_tinystories_resume.yaml
        except Exception as e:
            print(f"Training failed: {e}")
            print("Trying again with clean memory...")
            
            # Clear memory before retrying
            torch.cuda.empty_cache()
            gc.collect()
            !python -m src.trainers.qlora_trainer configs/gemma_tinystories_resume.yaml
        
        return True  # Training was resumed
    
    return False  # No resumption point found

In [None]:
# Main training execution - checks for resume point or starts fresh
if not check_for_resume_point():
    print("No existing checkpoints found. Starting new training...")
    print(f"Model will be saved to {DRIVE_MODEL_PATH}")
    
    # Start fresh training
    !python -m src.trainers.qlora_trainer configs/gemma_tinystories_drive.yaml

# Backup everything when complete
backup_all_training_artifacts()

## 5. Test the Fine-tuned Model

Try out the fine-tuned model by generating some stories.

In [None]:
# Load and test model from Google Drive
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel, PeftConfig

# Load the adapter config
config = PeftConfig.from_pretrained(DRIVE_MODEL_PATH)

# Load base model with authentication
base_model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    load_in_8bit=True,
    device_map="auto",
    trust_remote_code=True,
    token=HF_TOKEN
)

# Load adapter model
model = PeftModel.from_pretrained(base_model, DRIVE_MODEL_PATH, is_trainable=False)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    config.base_model_name_or_path, 
    trust_remote_code=True,
    token=HF_TOKEN
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Create text generation pipeline
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=200,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
)

In [None]:
# Test the model with some story starters
story_starters = [
    "Once upon a time, there was a little rabbit who",
    "The small dog was very happy because",
    "In a tiny house at the edge of the forest"
]

for starter in story_starters:
    print(f"Prompt: {starter}")
    result = pipe(starter, return_full_text=True)[0]["generated_text"]
    print(f"Generated story:\n{result}")
    print("-" * 80)

## 6. Compare with Base Model

Compare the fine-tuned model with the base model.

In [None]:
# Load the base model to compare
base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    device_map="auto",
    load_in_8bit=True,
    trust_remote_code=True,
    token=HF_TOKEN
)

# Create a pipeline for the base model
base_pipe = pipeline(
    "text-generation",
    model=base_model,
    tokenizer=tokenizer,
    max_new_tokens=200,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
)

In [None]:
# Compare base model and fine-tuned model
test_prompt = "Once upon a time, there was a little rabbit who"

print("BASE MODEL OUTPUT:")
base_result = base_pipe(test_prompt, return_full_text=True)[0]["generated_text"]
print(base_result)

print("\n" + "-"*80 + "\n")

print("FINE-TUNED MODEL OUTPUT:")
ft_result = pipe(test_prompt, return_full_text=True)[0]["generated_text"]
print(ft_result)

In [None]:
# Display all output locations in Google Drive
print(f"\n=== ALL OUTPUT LOCATIONS (in Google Drive) ===\n")
print(f"Root directory:     {DRIVE_BASE_PATH}")
print(f"Processed Dataset:  {DRIVE_DATASET_PATH}")
print(f"Fine-tuned Model:   {DRIVE_MODEL_PATH}")
print(f"Evaluation Results: {DRIVE_EVAL_PATH}")

# List all saved directories in Drive
print("\n=== DIRECTORIES CREATED IN GOOGLE DRIVE ===\n")
!find {DRIVE_BASE_PATH} -type d | sort

# Display a summary of what was created
print("\n=== LLM Fine-tuning Summary ===\n")
print(f"Dataset: {'✓' if os.path.exists(DRIVE_DATASET_PATH) else '✗'}")
print(f"Trained Model: {'✓' if os.path.exists(DRIVE_MODEL_PATH) else '✗'}")
print(f"Adapter Files: {'✓' if os.path.exists(os.path.join(DRIVE_MODEL_PATH, 'adapter_model')) else '✗'}")
print("\nAll files are stored in your Google Drive and will be available after this Colab session ends.")

## 7. Summary

This notebook demonstrated a lightweight fine-tuning of Gemma-2B on TinyStories data with Google Drive integration. Key highlights:

1. Successfully fine-tuned Gemma-2B using QLoRA in 1-2 hours
2. Used a small dataset subset for quick training
3. Enabled proper validation during training
4. Added automatic checkpointing to Google Drive every 5 minutes
5. Provided checkpoint resumption for interrupted training
6. Compared base model vs fine-tuned model outputs
7. All results safely stored in Google Drive