### Fine-Tuning TinyLlama with Extended Context Length

In this notebook, I fine-tuned the `TinyLlama-1.1B-Chat-v1.0` model using Hugging Face Transformers and PEFT (Parameter-Efficient Fine-Tuning). The goal was to extend its context window from 2,048 to 8,192 tokens and adapt it to a specific task. Key steps include:

- Installing necessary libraries such as `transformers`, `datasets`, `bitsandbytes`, and `peft`.
- Loading the TinyLlama model with 4-bit quantization for memory-efficient training.
- Preparing the dataset using Hugging Face's `datasets` library and formatting it appropriately.
- Applying LoRA (Low-Rank Adaptation) using PEFT to fine-tune specific projection layers.
- Configuring training parameters including extended context size, learning rate, and batch settings.
- Training the model and evaluating its capability to handle longer context sequences.
- Saving the fine-tuned model for future inference or deployment.

This process showcases how to efficiently fine-tune and extend the capabilities of compact LLMs like TinyLlama using PEFT and quantized loading.


In [None]:
!pip install -q transformers datasets accelerate bitsandbytes sentencepiece wandb einops peft
!pip install bitsandbytes-cpu


### Import Python Modules
- Import essential libraries like `torch`, `transformers`, `datasets`, and `peft`.
- These modules are used for model loading, dataset handling, and training.


In [None]:
import torch
import transformers
import os
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training

In [None]:
# Set random seed for reproducibility
transformers.set_seed(42)

### Define Configuration
- Define module-level constants for the hyperparameters and training settings.
- Includes the model ID, original and target context lengths, LoRA rank, learning rate, batch size, and epoch count.


In [None]:
# Define parameters
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
ORIGINAL_CTX_LENGTH = 2048  # Original context length of TinyLlama
TARGET_CTX_LENGTH = 8192   # Target context length after extension
BATCH_SIZE = 4
LORA_RANK = 16
LEARNING_RATE = 2e-4
NUM_EPOCHS = 3
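
The same settings can also be grouped into one object so they travel together. The following is a minimal sketch, assuming a hypothetical `FinetuneConfig` dataclass that is not part of this notebook; the field values simply copy the constants above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FinetuneConfig:
    """Hypothetical container; values mirror the notebook's constants."""
    model_id: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    original_ctx_length: int = 2048
    target_ctx_length: int = 8192
    batch_size: int = 4
    lora_rank: int = 16
    learning_rate: float = 2e-4
    num_epochs: int = 3

    @property
    def ctx_scale(self) -> float:
        # Ratio between the target and original context lengths
        return self.target_ctx_length / self.original_ctx_length


cfg = FinetuneConfig()
print(f"Context scale factor: {cfg.ctx_scale}x")
```

A frozen dataclass keeps the values from being changed by accident later in the notebook.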

### Configure Quantization
- Set up 4-bit quantization using `BitsAndBytesConfig` to reduce memory usage.
- Enables efficient model training on limited hardware.


In [None]:
# Configure 4-bit NF4 quantization with bfloat16 compute
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
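
To see roughly why 4-bit loading helps, the weight memory can be estimated from the parameter count alone. This is a back-of-the-envelope sketch: the 1.1 billion figure is an approximation, and the estimate ignores quantization constants, any layers left unquantized, activations, and optimizer state.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Approximate memory for the weights alone, in gigabytes."""
    return num_params * bits_per_param / 8 / 1e9


APPROX_PARAMS = 1.1e9  # Assumed approximate parameter count for TinyLlama-1.1B

bf16_gb = weight_memory_gb(APPROX_PARAMS, 16)  # 16 bits per weight
nf4_gb = weight_memory_gb(APPROX_PARAMS, 4)    # 4 bits per weight

print(f"bf16 weights: ~{bf16_gb:.2f} GB")
print(f"4-bit weights: ~{nf4_gb:.2f} GB (lower bound)")
```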

In [None]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token




### Load Base Model
- Load the TinyLlama base model in 4-bit using the quantization config defined above.
- `device_map="auto"` places the weights on the available device; the PEFT wrapping happens in later cells.


In [None]:
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)


### Prepare Model for LoRA
- Use `prepare_model_for_kbit_training` to make the quantized model trainable.
- This step freezes the base weights, casts layer norms to float32 for stability, and enables gradient checkpointing.


In [None]:
# Prepare model for training
model = prepare_model_for_kbit_training(model)

### Define LoRA Configuration
- Create a `LoraConfig` specifying target modules and LoRA hyperparameters.
- Used to inject LoRA adapters into the base model for efficient fine-tuning.


In [None]:
lora_config = LoraConfig(
    r=LORA_RANK,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)

In [None]:
# Apply LoRA to model
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 12,615,680 || all params: 1,112,664,064 || trainable%: 1.1338
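
The trainable parameter count can be checked by hand: a LoRA adapter on a linear layer with `in` inputs and `out` outputs adds `r * (in + out)` parameters. The sketch below assumes TinyLlama's layer shapes (22 layers, hidden size 2048, MLP size 5632, and 256-wide key/value projections from 4 key/value heads of 64 dimensions); these dimensions are not stated in this notebook, so treat them as assumptions.

```python
LORA_RANK = 16
NUM_LAYERS = 22       # Assumed TinyLlama depth
HIDDEN = 2048         # Assumed hidden size
INTERMEDIATE = 5632   # Assumed MLP size
KV_DIM = 256          # Assumed: 4 key/value heads x 64 dims

# (in_features, out_features) for each targeted projection
shapes = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(in_features: int, out_features: int, rank: int) -> int:
    """Parameters in the A (rank x in) and B (out x rank) matrices."""
    return rank * (in_features + out_features)


per_layer = sum(lora_params(i, o, LORA_RANK) for i, o in shapes.values())
total = per_layer * NUM_LAYERS
print(f"LoRA parameters per layer: {per_layer:,}")
print(f"LoRA parameters in total:  {total:,}")
```

Under these assumptions the total agrees with the `trainable params` figure printed above.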


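### Extend Context Length and Load Dataset
- Raise the model's configured maximum sequence length from 2048 to 8192 tokens.
- Load training data using Hugging Face `datasets`.
- This dataset will be tokenized and formatted for fine-tuning.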


In [None]:
# Function to raise the configured context limit
def extend_position_embeddings(model, target_length):
    """
    Raise the model's configured maximum sequence length.

    This only updates the config limits and any cached rotary-embedding length.
    It does not rescale the RoPE frequencies, so no position interpolation
    is applied here.
    """
    print(f"Extending position embeddings from {ORIGINAL_CTX_LENGTH} to {target_length}")

    # For models with rotary embeddings, like TinyLlama
    if hasattr(model, "config") and hasattr(model.config, "max_position_embeddings"):
        model.config.max_position_embeddings = target_length

    if hasattr(model, "config") and hasattr(model.config, "max_sequence_length"):
        model.config.max_sequence_length = target_length

    # Update rotary-embedding modules wherever they live. Walking model.modules()
    # reaches them through the PEFT wrapper, which iterating
    # model.base_model._modules for names starting with "layers" does not.
    for module in model.modules():
        rotary = getattr(module, "rotary_emb", None)
        # The attribute name varies across Transformers versions; skip if absent
        if rotary is not None and hasattr(rotary, "max_seq_len"):
            rotary.max_seq_len = target_length

    return model

In [None]:
# Extend model's position embeddings
model = extend_position_embeddings(model, TARGET_CTX_LENGTH)

Extending position embeddings from 2048 to 8192
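
For reference, linear position interpolation (one common way to stretch a RoPE model's context) divides each position index by the scale factor before computing the rotation angles, so positions up to 8192 map back into the 0 to 2048 range seen in pretraining. The sketch below is standalone arithmetic, not code from this notebook; the head dimension of 64 and base of 10000 are assumed values. Transformers exposes a `rope_scaling` config field for Llama-style models, though its exact keys depend on the library version.

```python
def rope_angles(position: float, head_dim: int = 64, base: float = 10000.0,
                scale: float = 1.0) -> list:
    """Rotation angles for one position; scale > 1 applies linear interpolation."""
    scaled_position = position / scale
    # One angle per pair of dimensions in the attention head
    return [scaled_position / (base ** (2 * i / head_dim)) for i in range(head_dim // 2)]


ORIGINAL_CTX_LENGTH = 2048
TARGET_CTX_LENGTH = 8192
scale = TARGET_CTX_LENGTH / ORIGINAL_CTX_LENGTH

# The last position of the extended window, with and without interpolation
interpolated = rope_angles(TARGET_CTX_LENGTH - 1, scale=scale)
unscaled = rope_angles(TARGET_CTX_LENGTH - 1)

print(f"Scale factor: {scale}")
print(f"First angle without scaling: {unscaled[0]}")
print(f"First angle with scaling:    {interpolated[0]}")
```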


In [None]:
# Load the WikiText-103 training split
dataset = load_dataset("wikitext", "wikitext-103-v1", split="train")


In [None]:
# Function to preprocess the data
def preprocess_function(examples):
    # Tokenize the examples
    tokenized_examples = tokenizer(
        examples["text"],
        truncation=True,
        max_length=TARGET_CTX_LENGTH,
        return_overflowing_tokens=True,
        return_length=True,
    )

    # Filter out examples that are too short (< 1024 tokens)
    long_enough = [length >= 1024 for length in tokenized_examples["length"]]

    result = {
        "input_ids": [ids for ids, is_long in zip(tokenized_examples["input_ids"], long_enough) if is_long],
        "attention_mask": [mask for mask, is_long in zip(tokenized_examples["attention_mask"], long_enough) if is_long]
    }

    return result

In [None]:
# Apply preprocessing
processed_dataset = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=dataset.column_names,
    num_proc=4,
    desc="Preprocessing dataset"
)

print(f"Original dataset size: {len(dataset)}")
print(f"Processed dataset size: {len(processed_dataset)}")


Original dataset size: 1801350
Processed dataset size: 21
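
Only 21 sequences survive the length filter. One common alternative is to pack consecutive short texts into fixed-length blocks instead of discarding them. The sketch below shows the idea on plain lists of token IDs; `pack_token_ids` is a hypothetical helper, not something this notebook uses, and the toy IDs stand in for real tokenizer output.

```python
from typing import Iterable, List


def pack_token_ids(sequences: Iterable[List[int]], block_size: int,
                   eos_id: int) -> List[List[int]]:
    """Concatenate sequences (separated by EOS) and cut into equal blocks."""
    buffer: List[int] = []
    blocks: List[List[int]] = []
    for seq in sequences:
        buffer.extend(seq)
        buffer.append(eos_id)  # Mark the boundary between documents
        while len(buffer) >= block_size:
            blocks.append(buffer[:block_size])
            buffer = buffer[block_size:]
    # Any trailing remainder shorter than block_size is dropped
    return blocks


toy_sequences = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
blocks = pack_token_ids(toy_sequences, block_size=4, eos_id=0)
print(blocks)
```

With real data, `block_size` would be the target context length and `eos_id` the tokenizer's EOS token ID.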


In [None]:
# Split the dataset
train_dataset = processed_dataset.shuffle(seed=42).select(range(int(0.9 * len(processed_dataset))))
eval_dataset = processed_dataset.shuffle(seed=42).select(range(int(0.9 * len(processed_dataset)), len(processed_dataset)))

print(f"Train dataset size: {len(train_dataset)}")
print(f"Eval dataset size: {len(eval_dataset)}")

Train dataset size: 18
Eval dataset size: 3


In [None]:
# Configure training arguments
training_args = TrainingArguments(
    output_dir="./results/tinyllama-context-extension",
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    # No evaluation strategy is set, so eval_steps has no effect and
    # eval_dataset is not evaluated during training
    eval_steps=100,
    logging_steps=10,
    gradient_accumulation_steps=4,   # Effective batch size per device: 4 x 4 = 16
    num_train_epochs=NUM_EPOCHS,
    weight_decay=0.01,
    warmup_steps=100,                # Longer than this run's total steps, so warmup never completes
    lr_scheduler_type="cosine",
    learning_rate=LEARNING_RATE,
    save_steps=500,
    fp16=False,
    bf16=True,
    optim="paged_adamw_8bit",
    # report_to is not set; in the recorded run the Trainer logged to W&B
    run_name="tinyllama-context-extension",
    push_to_hub=False,
)

# Check the installed Transformers version
from packaging import version
print(f"Transformers version: {transformers.__version__}")

# Fall back to a minimal config on very old releases.
# Compare parsed versions: plain string comparison misorders e.g. "10.0.0" and "4.0.0".
if version.parse(transformers.__version__) < version.parse("4.0.0"):
    training_args = TrainingArguments(
        output_dir="./results/tinyllama-context-extension",
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=NUM_EPOCHS,
        weight_decay=0.01,
        learning_rate=LEARNING_RATE,
        save_steps=500,
    )

Transformers version: 4.51.1
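
The interaction between `warmup_steps=100` and a very short run is worth checking numerically. The sketch below reimplements the usual linear-warmup-then-cosine multiplier as a standalone function (it is my own restatement, not an import from Transformers) and evaluates it at step 3, the `global_step` the recorded run ends on.

```python
import math


def cosine_warmup_multiplier(step: int, warmup_steps: int, total_steps: int,
                             num_cycles: float = 0.5) -> float:
    """Learning-rate multiplier: linear warmup, then cosine decay."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * 2.0 * num_cycles * progress)))


LEARNING_RATE = 2e-4
WARMUP_STEPS = 100
TOTAL_STEPS = 3  # The recorded run ends at global_step=3

lr_at_end = LEARNING_RATE * cosine_warmup_multiplier(TOTAL_STEPS, WARMUP_STEPS, TOTAL_STEPS)
print(f"Learning rate at step {TOTAL_STEPS}: {lr_at_end:.2e}")
```

With so few steps, the learning rate stays far below the configured `2e-4`, so a much smaller warmup (or a larger dataset) would be needed for the schedule to take effect.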


In [None]:
# Create data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

In [None]:
# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [None]:
# Train the model
print("Starting training...")
trainer.train()

Starting training...


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msyedanida-khader[0m ([33msyedanida-khader-san-jose-state-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  return fn(*args, **kwargs)




TrainOutput(global_step=3, training_loss=2.3735100428263345, metrics={'train_runtime': 299.3453, 'train_samples_per_second': 0.18, 'train_steps_per_second': 0.01, 'total_flos': 291156052451328.0, 'train_loss': 2.3735100428263345, 'epoch': 1.8})

In [None]:
# Save the model
trainer.save_model("./final_model")
print("Model saved.")

Model saved.


In [None]:
# Evaluation and demonstration section
print("Evaluating the model with extended context...")

Evaluating the model with extended context...


In [None]:
# Load the trained model for evaluation without device_map
from peft import PeftModelForCausalLM # Import PeftModelForCausalLM from peft

# Load the base model first
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, # Assuming MODEL_ID is defined and refers to the base model
    quantization_config=bnb_config, # If quantization was used
    torch_dtype=torch.bfloat16,
)
# Then load the PEFT weights
trained_model = PeftModelForCausalLM.from_pretrained(
    base_model, # Pass the base model instance
    "./final_model", # Path to the PEFT weights
)


# Move the model to the desired device if needed
if torch.cuda.is_available():
    trained_model.to("cuda")

In [None]:
# Test context window capability
def test_context_window(model, tokenizer, max_length):
    """
    Test how the model handles a long context window by feeding it
    a prompt and checking for coherent completion.
    """
    # Build a synthetic document of 50 paragraphs grouped into 5 sections
    long_text = ""
    for i in range(50):
        if i % 10 == 0:
            long_text += f"\n\n==== SECTION {i//10 + 1} ====\n\n"
        long_text += f"This is paragraph {i+1} in our test of the extended context window. "
        long_text += f"If the model can see this far, it should remember we are in section {i//10 + 1}. "

    # Add a question at the end
    prompt = long_text + "\n\nQuestion: Which section number did we start with? Answer: "

    # Tokenize and check length
    tokens = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_length = tokens.input_ids.shape[1]
    print(f"Input length: {input_length} tokens")

    if input_length > max_length:
        print(f"Input exceeds maximum length of {max_length}, truncating...")
        tokens = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_length).to(model.device)

    # Generate response
    with torch.no_grad():
        output = model.generate(
            **tokens,
            max_new_tokens=50,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )

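    # Decode the full output, then print only the newly generated tokens.
    # Slicing by token count is more reliable than slicing the decoded string
    # by len(prompt), because decoding may not reproduce the prompt exactly.
    prompt_length = tokens.input_ids.shape[1]
    decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
    completion = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    print(f"Model output:\n{completion}")
    return input_length, decoded_output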

In [None]:
# Run the context-window test three times.
# Note: test_length only labels each run. test_context_window builds the same
# fixed prompt every time, so all three runs see the same input length.
print("\nTesting with different context lengths:")
for test_length in [1000, 4000, 7000]:
    print(f"\n==== Testing with approximately {test_length} tokens ====")
    actual_length, _ = test_context_window(trained_model, tokenizer, TARGET_CTX_LENGTH)

# Context length summary (no GPU memory is measured here)
print("\nMemory usage analysis:")
print(f"Original context length: {ORIGINAL_CTX_LENGTH}")
print(f"Extended context length: {TARGET_CTX_LENGTH}")
print(f"Ratio: {TARGET_CTX_LENGTH/ORIGINAL_CTX_LENGTH}x")


Testing with different context lengths:

==== Testing with approximately 1000 tokens ====
Input length: 1718 tokens
Model output:
1.

==== Testing with approximately 4000 tokens ====
Input length: 1718 tokens
Model output:
1

==== Testing with approximately 7000 tokens ====
Input length: 1718 tokens
Model output:
1

Memory usage analysis:
Original context length: 2048
Extended context length: 8192
Ratio: 4.0x
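
All three runs above report the same 1718-token input, which is still inside the original 2048-token window. To probe longer contexts, the prompt has to grow with the requested length. The sketch below is one way to do that; `build_prompt` is a hypothetical helper, and its default `count_tokens` is a whitespace word count used only as a stand-in for a real tokenizer.

```python
from typing import Callable

QUESTION = "\n\nQuestion: Which section number did we start with? Answer: "


def build_prompt(target_tokens: int,
                 count_tokens: Callable[[str], int] = lambda t: len(t.split())) -> str:
    """Add paragraphs until the next one would push the prompt past target_tokens."""
    pieces = []
    i = 0
    while True:
        piece = ""
        if i % 10 == 0:
            piece += f"\n\n==== SECTION {i // 10 + 1} ====\n\n"
        piece += f"This is paragraph {i + 1} in our test of the extended context window. "
        piece += f"If the model can see this far, it should remember we are in section {i // 10 + 1}. "
        candidate = "".join(pieces) + piece + QUESTION
        # Always keep at least one paragraph, then stop before exceeding the budget
        if pieces and count_tokens(candidate) > target_tokens:
            break
        pieces.append(piece)
        i += 1
    return "".join(pieces) + QUESTION


for target in (100, 400):
    prompt = build_prompt(target)
    print(f"target={target}, words={len(prompt.split())}")
```

With the notebook's tokenizer, passing `count_tokens=lambda t: len(tokenizer(t).input_ids)` would size the prompt in real tokens instead of words.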


In [None]:
# Visualize context handling (with sequence lengths)
import matplotlib.pyplot as plt

# Plot attention patterns (visualization code)
def visualize_attention(model, tokenizer, text, layer_idx=0, head_idx=0):
    """
    Visualize the attention pattern for a given input at a specific layer and head.
    """
    # Reduce max_length in tokenizer call
    tokens = tokenizer(text, return_tensors="pt", max_length=512, truncation=True).to(model.device)
    input_length = tokens.input_ids.shape[1]

    # Forward pass to get attention
    with torch.no_grad():
        outputs = model(**tokens, output_attentions=True)

    # Get attention for the specified layer and head
    # Convert attention weights to float32 before converting to NumPy array
    attention = outputs.attentions[layer_idx][0, head_idx].type(torch.float32).cpu().numpy()

    # Plot the attention pattern
    plt.figure(figsize=(10, 8))
    plt.imshow(attention, cmap='viridis')
    plt.title(f"Attention Pattern - Layer {layer_idx}, Head {head_idx}")
    plt.xlabel("Key Sequence Position")
    plt.ylabel("Query Sequence Position")
    plt.colorbar(label="Attention Weight")
    plt.tight_layout()
    plt.savefig(f"attention_l{layer_idx}_h{head_idx}.png")
    plt.close()

    return input_length

# Generate a paragraph of text at different lengths for visualization
short_text = "This is a short test."
medium_text = " ".join(["This is paragraph " + str(i) for i in range(100)])
long_text = " ".join(["This is paragraph " + str(i) for i in range(500)])

print("\nVisualizing attention patterns:")
for text, name in [(short_text, "short"), (medium_text, "medium"), (long_text, "long")]:
    length = visualize_attention(trained_model, tokenizer, text)
    print(f"{name.capitalize()} text: {length} tokens")


Visualizing attention patterns:
Short text: 7 tokens
Medium text: 512 tokens
Long text: 512 tokens


In [None]:
# Modified performance comparison function with memory protection
def measure_performance_safe(model, tokenizer, inputs):
    """
    Measure inference time and memory usage for different input lengths
    with protection against OOM errors.
    """
    results = []

    for text in inputs:
        try:
            # Tokenize
            tokens = tokenizer(text, return_tensors="pt").to(model.device)
            input_length = tokens.input_ids.shape[1]

            # If the input is too long, truncate it
            if input_length > TARGET_CTX_LENGTH:
                tokens = tokenizer(text, return_tensors="pt", truncation=True,
                                  max_length=TARGET_CTX_LENGTH).to(model.device)
                input_length = tokens.input_ids.shape[1]

            # Print expected memory requirement (rough estimate)
            estimated_memory = (input_length**2) * 4 * 2 / 1e9  # Very rough estimate in GB
            print(f"Processing sequence of length {input_length}, estimated memory: {estimated_memory:.2f}GB")

            # Free up GPU memory before measurement
            torch.cuda.empty_cache()

            # Measure GPU memory before
            torch.cuda.synchronize()
            mem_before = torch.cuda.memory_allocated() / 1e9  # GB

            # Measure inference time
            start_time = torch.cuda.Event(enable_timing=True)
            end_time = torch.cuda.Event(enable_timing=True)

            # Use smaller max_new_tokens for longer sequences
            max_tokens = 20 if input_length < 1000 else 5

            start_time.record()
            with torch.no_grad():
                outputs = model.generate(
                    **tokens,
                    max_new_tokens=max_tokens,
                    # Use more memory-efficient generation settings
                    use_cache=True,
                    do_sample=False  # Deterministic generation uses less memory
                )
            end_time.record()

            torch.cuda.synchronize()
            inference_time = start_time.elapsed_time(end_time) / 1000  # seconds

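            # Measure GPU memory after.
            # Note: temporaries such as the KV cache are freed once generate()
            # returns, so this difference reads close to zero. Calling
            # torch.cuda.reset_peak_memory_stats() before generation and reading
            # torch.cuda.max_memory_allocated() afterwards would capture the peak.
            mem_after = torch.cuda.memory_allocated() / 1e9  # GB
            mem_used = mem_after - mem_before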

            results.append({
                "length": input_length,
                "time": inference_time,
                "memory": mem_used
            })

            print(f"Input length: {input_length}, Inference time: {inference_time:.4f}s, Memory used: {mem_used:.4f}GB")

            # Free memory after each run
            torch.cuda.empty_cache()

        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                print(f"Skipping sequence length {input_length} due to OOM error")
                # Free memory after error
                torch.cuda.empty_cache()
                # Add partial result
                results.append({
                    "length": input_length,
                    "time": float('nan'),
                    "memory": float('nan')
                })
            else:
                raise e

        # Add a small delay between tests to help memory recovery
        import time
        time.sleep(2)

    return results

# Use smaller and fewer test inputs to avoid OOM
test_inputs = [
    "This is a short test.",
    " ".join(["Sentence " + str(i) for i in range(50)]),
    " ".join(["Sentence " + str(i) for i in range(200)]),
    " ".join(["Sentence " + str(i) for i in range(500)])  # 2891 tokens in the recorded run
]

print("\nMeasuring performance metrics (safely):")
if torch.cuda.is_available():
    print(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f}GB")
    print(f"Current memory usage: {torch.cuda.memory_allocated() / 1e9:.2f}GB")

    # Force garbage collection
    import gc
    gc.collect()
    torch.cuda.empty_cache()

    perf_results = measure_performance_safe(trained_model, tokenizer, test_inputs)

    # Plot results (skip NaN values if any)
    valid_results = [(r["length"], r["time"], r["memory"])
                     for r in perf_results
                     if not (np.isnan(r["time"]) or np.isnan(r["memory"]))]

    if valid_results:
        lengths, times, memories = zip(*valid_results)

        plt.figure(figsize=(12, 5))

        plt.subplot(1, 2, 1)
        plt.plot(lengths, times, 'o-')
        plt.title("Inference Time vs. Sequence Length")
        plt.xlabel("Sequence Length (tokens)")
        plt.ylabel("Time (seconds)")
        plt.grid(True)

        plt.subplot(1, 2, 2)
        plt.plot(lengths, memories, 'o-')
        plt.title("Memory Usage vs. Sequence Length")
        plt.xlabel("Sequence Length (tokens)")
        plt.ylabel("Memory (GB)")
        plt.grid(True)

        plt.tight_layout()
        plt.savefig("performance_metrics.png")
        plt.close()
    else:
        print("No valid performance measurements to plot")
else:
    print("CUDA not available, skipping performance measurements")


Measuring performance metrics (safely):
Available GPU memory: 15.83GB
Current memory usage: 11.13GB
Processing sequence of length 7, estimated memory: 0.00GB
Input length: 7, Inference time: 3.4475s, Memory used: 0.0000GB
Processing sequence of length 241, estimated memory: 0.00GB
Input length: 241, Inference time: 2.4050s, Memory used: 0.0000GB
Processing sequence of length 1091, estimated memory: 0.01GB
Input length: 1091, Inference time: 1.5944s, Memory used: 0.0000GB
Processing sequence of length 2891, estimated memory: 0.07GB
Input length: 2891, Inference time: 4.1950s, Memory used: 0.0000GB
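
Since the measured memory deltas read as zero, a rough analytical estimate of the key/value cache gives a better sense of how inference memory grows with sequence length. The sketch below is plain arithmetic for a batch size of 1; the layer count (22), key/value head count (4), head dimension (64), and 2-byte bfloat16 values are assumed TinyLlama settings, not numbers measured in this notebook.

```python
def kv_cache_bytes(seq_len: int, num_layers: int = 22, num_kv_heads: int = 4,
                   head_dim: int = 64, bytes_per_value: int = 2) -> int:
    """Bytes needed to cache keys and values for one sequence."""
    # Factor of 2 covers the separate key and value tensors
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value * seq_len


for n in (2048, 8192):
    print(f"{n} tokens: ~{kv_cache_bytes(n) / 1e6:.1f} MB of KV cache")
```

The cache grows linearly with sequence length, so moving from 2048 to 8192 tokens multiplies it by four under these assumptions.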
