This notebook was authored by **Tolulope Ale** for the iHARP AI Workshop Series.

# Self-Supervised Learning with LoRA Fine-Tuning

This notebook demonstrates a complete pipeline for:
1. **Self-Supervised Learning (SSL)** - Pre-training a model with a denoising-autoencoder objective
2. **LoRA (Low-Rank Adaptation)** - Parameter-efficient fine-tuning
3. **Evaluation** - Comparing SSL+LoRA against baseline models

## Approach Overview
- **Phase 1**: SSL pre-training on the CNN/DailyMail dataset using a denoising-autoencoder objective
- **Phase 2**: LoRA fine-tuning on the DialogSum dataset with a limited amount of labeled data
- **Phase 3**: Evaluation using ROUGE, BLEU, and BERTScore metrics

In [None]:
from google.colab import drive
drive.mount('/content/drive')

## Load Base Model

We load **FLAN-T5 Base**, a text-to-text transformer model from Google that has been instruction-tuned. This will serve as our starting point for both SSL pre-training and subsequent fine-tuning.

**FLAN-T5**: FLAN (Finetuned Language Net) instruction tuning applied to T5, the Text-to-Text Transfer Transformer (the "5" counts the five T's in that name; it is not a version number)

In [None]:
from huggingface_hub import HfApi
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

api = HfApi()

# Load the FLAN-T5 Base model (an instruction-tuned, general-purpose text-to-text model)

model_name = 'google/flan-t5-base'
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) #FLAN-T5 tokenizer

## Load SSL Pre-training Dataset

We load the **CNN/DailyMail** dataset which contains news articles and summaries. This large dataset (287k+ training examples) will be used for self-supervised learning.

**Why CNN/DailyMail?** It's a large, diverse corpus, well suited to SSL pre-training because we don't need high-quality labels: we create our own training signal by corrupting the articles and asking the model to restore them.

In [None]:
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

# SSL pre-training corpus: CNN/DailyMail (config '3.0.0') from the Hugging Face Hub
ssl_dataset_name = "abisee/cnn_dailymail"
ssl_dataset = load_dataset(ssl_dataset_name, '3.0.0')

In [None]:
from datasets import load_dataset_builder #inspection of dataset features

data_builder = load_dataset_builder(ssl_dataset_name, '3.0.0')

print(data_builder.info.features)

In [None]:
ssl_dataset

Examine a few examples from the dataset to understand the article structure and baseline summaries.

In [None]:
# Print a few sample articles with their reference summaries (highlights)
dialogues = [10, 505, 800, 750]  # test-set article indices

dash_line = '*' * 100

for i, index in enumerate(dialogues):
    print(dash_line)
    print('Sample ', i + 1)
    print(dash_line)
    print('INPUT ARTICLE:')
    print(ssl_dataset['test'][index]['article'])
    print(dash_line)
    print('BASELINE SUMMARY:')
    print(ssl_dataset['test'][index]['highlights'])
    print(dash_line)
    print()

## Baseline Model Performance

Test the **out-of-the-box** FLAN-T5 model without any training. This establishes our baseline - what the model can do with zero task-specific training.

**Note**: No prompt engineering is used here - we're testing raw model capability.

In [None]:
# Base LLM summarization without prompt engineering
for i, index in enumerate(dialogues):
    article = ssl_dataset['test'][index]['article']
    summary = ssl_dataset['test'][index]['highlights']

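    # Truncate to the tokenizer's 512-token limit (the same max_length used for SSL training below)
    inputs = tokenizer(article, return_tensors='pt', truncation=True, max_length=512)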
    output = tokenizer.decode(
        model.generate(
            inputs["input_ids"],
            max_new_tokens=500,
        )[0],
        skip_special_tokens=True
    )

    print(dash_line)
    print('Sample ', i + 1)
    print(dash_line)
    print(f'INPUT ARTICLE:\n{article}')
    print(dash_line)
    print(f'BASELINE SUMMARY:\n{summary}')
    print(dash_line)
    print(f'MODEL GENERATION:\n{output}\n')

---

# Self-Supervised Learning (SSL)  

## Pretext Task

We'll pre-train the model using a **denoising autoencoder** approach.


In [None]:
from transformers import TrainingArguments, Trainer
import torch
import time
#import evaluate
import pandas as pd
import numpy as np

## Reload Model with Optimized Settings

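We reload the FLAN-T5 model with **bfloat16** precision to reduce memory usage and speed up training while largely preserving model quality.

**bfloat16** (Brain Floating Point, 16-bit): a reduced-precision format that keeps float32's 8-bit exponent, and therefore its dynamic range, but has only 7 mantissa bits. Because it is far less prone to overflow and underflow than float16, it is generally more stable for training.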
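To make the range-versus-precision trade-off concrete, here is a small NumPy sketch. NumPy has no native bfloat16 type, so the hypothetical helper `bf16_truncate` emulates it by keeping the top 16 bits of a float32 bit pattern; real conversions round to nearest-even rather than truncating, so treat this as an illustration only.

```python
import numpy as np

def bf16_truncate(x: float) -> float:
    """Approximate bfloat16 by keeping only the top 16 bits of the float32 encoding.

    Simplification: real bfloat16 conversion rounds to nearest-even; truncation is
    enough to show the range (8 exponent bits) and precision (7 mantissa bits).
    """
    bits = np.array([x], dtype=np.float32).view(np.uint32)
    bits &= np.uint32(0xFFFF0000)  # keep sign, exponent, and the top 7 mantissa bits
    return float(bits.view(np.float32)[0])

# Range: 1e5 overflows float16 (max 65504) but fits, coarsely, in bfloat16
print(np.float16(1e5))           # inf
print(bf16_truncate(1e5))        # 99840.0

# Small values: 1e-8 underflows to zero in float16 but not in bfloat16
print(np.float16(1e-8))          # 0.0
print(bf16_truncate(1e-8) > 0)   # True

# Precision: bfloat16 cannot distinguish 1 + 2**-8 from 1.0, float16 can
print(bf16_truncate(1 + 2**-8))            # 1.0
print(np.float16(1 + 2**-8) == 1 + 2**-8)  # True
```

Large activations and tiny gradients are exactly the values that overflow or underflow in float16, which is why bfloat16 usually needs no loss scaling even though it is less precise.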

In [None]:
# Reload FLAN-T5 Base in bfloat16 (dtype=torch.bfloat16) to reduce memory use and speed up training

model_name='google/flan-t5-base'

original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, dtype=torch.bfloat16)
original_tokenizer = AutoTokenizer.from_pretrained(model_name)

Create a utility function to examine how many parameters are trainable vs frozen. This is crucial for understanding computational requirements and training dynamics.

**Key Insight**: FLAN-T5 Base has ~248M parameters - training all of them requires significant compute!

In [None]:
# Examining the number of trainable model parameters

def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f"trainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"

print(print_number_of_trainable_model_parameters(original_model))

## Create Denoising Autoencoder Setup

**Self-Supervised Learning** core implementation!

### What is a Denoising Autoencoder?
A denoising autoencoder learns to reconstruct clean input from corrupted input. The model learns robust representations by recovering original data from noisy versions.

### Our Approach:
- **Input**: Article with each word dropped independently with probability 0.2 (about 20% of words on average)
- **Target**: Original clean article
- **Learning Signal**: The model learns to restore the missing words, which requires modeling context and semantics

This creates a training signal from **unlabeled data**; no human annotations are needed.

In [None]:
import random

# Create noise function for denoising autoencoder
def add_noise(text, drop_prob=0.2):
    """Drop each word independently with probability drop_prob to create a noisy input"""
    words = text.split()
    noisy = [w for w in words if random.random() > drop_prob]
    return ' '.join(noisy) if noisy else text  # Return original if all words were dropped

# SSL tokenization function
def tokenize_function_ssl(example):
    """Create noisy inputs and clean targets for the denoising task"""
    # Apply noise to each article in the batch
    noisy_articles = [add_noise(a) for a in example["article"]]

    # Noisy article as input; keep the attention mask so the encoder ignores padding
    model_inputs = tokenizer(
        noisy_articles,
        padding="max_length",
        truncation=True,
        max_length=512,
    )
    example['input_ids'] = model_inputs['input_ids']
    example['attention_mask'] = model_inputs['attention_mask']

    # Clean article as target (labels)
    labels = tokenizer(
        example["article"],
        padding="max_length",
        truncation=True,
        max_length=512,
    ).input_ids
    # Replace pad tokens with -100 so the loss ignores padded positions
    example['labels'] = [
        [tok if tok != tokenizer.pad_token_id else -100 for tok in seq]
        for seq in labels
    ]

    return example

# Apply SSL tokenization to dataset
print("Applying SSL tokenization...")
ssl_datasets = ssl_dataset.map(tokenize_function_ssl, batched=True)
ssl_datasets = ssl_datasets.remove_columns(['id', 'article', 'highlights'])

print(f"SSL Dataset shapes:")
print(f"Training: {ssl_datasets['train'].shape}")
print(f"Validation: {ssl_datasets['validation'].shape}")

Verify that the corruption is meaningful: neither too much nor too little.

In [None]:
# Displaying samples of noisy vs clean data
print("="*100)
print("EXAMPLES OF NOISY DATA vs CLEAN DATA")
print("="*100)

num_examples = 3
sample_indices = [0, 50, 100]  # Show different samples

for i, idx in enumerate(sample_indices):
    clean_text = ssl_dataset['train'][idx]['article']
    noisy_text = add_noise(clean_text, drop_prob=0.2)

    # Calculate how many words were dropped
    clean_words = len(clean_text.split())
    noisy_words = len(noisy_text.split())
    dropped_words = clean_words - noisy_words

    print(f"\n{'*'*100}")
    print(f"SAMPLE {i+1} (Index {idx})")
    print(f"{'*'*100}")
    print(f"\nCLEAN TEXT ({clean_words} words):")
    print(f"{clean_text[:500]}...")  # Show first 500 chars
    print(f"\n{'-'*100}")
    print(f"\nNOISY TEXT ({noisy_words} words, {dropped_words} words dropped, {dropped_words/clean_words*100:.1f}% reduction):")
    print(f"{noisy_text[:500]}...")  # Show first 500 chars
    print(f"\n{'='*100}")

print(f"\n Noise function randomly drops ~20% of words to create denoising task")

## Configure SSL Training

Set up training parameters for self-supervised pre-training:
- **10 epochs** on the large CNN/DailyMail dataset
- **Early stopping** to prevent overfitting
- **Evaluation every 2000 steps** to monitor progress

This teaches the model general language understanding.
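A quick back-of-the-envelope check of what this configuration implies. The sketch below assumes one GPU, no gradient accumulation, Trainer's default of keeping the last partial batch, and 287,113 training articles (the size of the CNN/DailyMail 3.0.0 train split); `training_schedule` is a hypothetical helper, not a Trainer API.

```python
import math

def training_schedule(num_examples, batch_size, epochs, eval_steps):
    """Approximate Trainer step counts on one device with no gradient accumulation.

    Assumes the last partial batch is kept (dataloader_drop_last=False, the default).
    """
    steps_per_epoch = math.ceil(num_examples / batch_size)
    total_steps = steps_per_epoch * epochs
    num_evals = total_steps // eval_steps  # one evaluation every eval_steps optimizer steps
    return steps_per_epoch, total_steps, num_evals

# CNN/DailyMail 3.0.0 train split with the ssl_training_args values
print(training_schedule(287_113, batch_size=16, epochs=10, eval_steps=2000))  # (17945, 179450, 89)
```

Under these assumptions, ten full epochs is roughly 179k optimizer steps, and a checkpoint such as `checkpoint-30000` corresponds to about 30000 / 17945 ≈ 1.7 epochs.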

In [None]:
from transformers import EarlyStoppingCallback

# Load fresh model for SSL training
ssl_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

# SSL training configuration
ssl_output_dir = f'/content/drive/MyDrive/Mdata_lab/Dataset/Self_Supervised_Learning/model/flan-t5-ssl-{str(int(time.time()))}'

ssl_training_args = TrainingArguments(
    output_dir=ssl_output_dir,
    num_train_epochs=10,  # SSL pre-training
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    logging_steps=500,
    eval_steps=2000,
    save_steps=2000,
    eval_strategy="steps",
    bf16=True,
    fp16=False,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
)

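ssl_trainer = Trainer(
    model=ssl_model,
    args=ssl_training_args,
    train_dataset=ssl_datasets['train'],
    eval_dataset=ssl_datasets['validation'],
    # Early stopping on eval_loss; a patience of 3 evaluations is an assumed value, tune as needed
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)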

print("Starting SSL pre-training (Denoising Autoencoder)...")
print(f"Training on {len(ssl_datasets['train'])} examples")

## Execute SSL Pre-training

**WARNING**: This cell takes several hours to complete!

The model learns to:
- Understand context and word relationships
- Predict missing words from surrounding context
- Build robust internal representations

The expectation is that these learned representations transfer to downstream tasks with relatively little fine-tuning.

In [None]:
# Train SSL model
ssl_trainer.train()

# Save SSL model
ssl_trainer.save_model(ssl_output_dir)
tokenizer.save_pretrained(ssl_output_dir)
print(f"\nSSL model saved to: {ssl_output_dir}")

---

# LoRA Fine-Tuning

Now we'll take our SSL pre-trained model and fine-tune it using **LoRA** (Low-Rank Adaptation) on a small labeled dataset.

### What is LoRA?
**Low-Rank Adaptation** freezes the pretrained weights and adds small trainable low-rank "adapter" matrices alongside selected layers. Instead of updating all ~248M parameters, we train only about 3.5M adapter parameters.

### Why LoRA?
- **Parameter Efficient**: Trains only about 1.4% of the model's parameters
- **Fast**: No weight gradients or optimizer state are needed for the frozen weights, so training is typically faster than full fine-tuning
- **Memory Efficient**: Reduces GPU memory requirements
- **Modular**: LoRA adapters can be swapped without retraining the base model
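A minimal NumPy sketch of the LoRA update for a single projection, assuming FLAN-T5 Base's hidden size of 768 and the `r=16`, `lora_alpha=32` settings used below. The frozen weight `W` is augmented with the scaled low-rank product `B @ A`; `B` starts at zero so the adapted layer initially behaves exactly like the pretrained one. The random initialization of `A` here is illustrative only (PEFT uses its own initializer), and `lora_linear` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)

d_in, d_out, r, lora_alpha = 768, 768, 16, 32  # FLAN-T5 Base d_model; LoRA settings
scaling = lora_alpha / r                        # LoRA scales the low-rank update by alpha / r

W = rng.normal(size=(d_out, d_in))             # frozen pretrained weight
A = rng.normal(scale=0.01, size=(r, d_in))     # trainable down-projection (illustrative init)
B = np.zeros((d_out, r))                       # trainable up-projection, zero-initialized

def lora_linear(x):
    """y = W x + (alpha / r) * B A x, computed without materializing the d_out x d_in product B A."""
    return W @ x + scaling * (B @ (A @ x))

x = rng.normal(size=d_in)
# With B = 0 the adapter is a no-op, so the layer starts out identical to the frozen model
print(np.allclose(lora_linear(x), W @ x))  # True

# Trainable (A + B) vs frozen (W) parameters for this one projection
print(A.size + B.size, W.size)             # 24576 589824
```

Only `A` and `B` receive gradient updates; after training, `W + scaling * B @ A` can be folded into a single matrix, so the adapter adds no inference latency.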

## Load SSL Model and Apply LoRA

### LoRA Configuration:
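- **r=16**: Rank of the adaptation matrices (higher = more capacity, more parameters)
- **lora_alpha=32**: Scaling factor; the LoRA update is multiplied by lora_alpha / r = 2
- **lora_dropout=0.1**: Dropout applied to the input of the LoRA branch during training
- **target_modules**: Apply LoRA to the attention query, key, value, and output projections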

The SSL-pretrained checkpoint provides our starting point.
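The adapter size can be checked by hand before running the cell. This sketch assumes FLAN-T5 Base's configuration (12 encoder and 12 decoder layers, d_model = 768, and q/k/v/o projections that all map 768 to 768) and assumes the base model reports 247,577,856 parameters (the "~248M" above). Encoder layers contain one attention block (self-attention); decoder layers contain two (self- and cross-attention).

```python
# FLAN-T5 Base configuration values (assumed from the model config)
d_model = 768            # q, k, v, o are all 768 x 768 (12 heads x 64 dims)
num_encoder_layers = 12
num_decoder_layers = 12
r = 16                   # LoRA rank from lora_config

# Each adapted projection gets A (r x d_model) and B (d_model x r)
params_per_module = r * d_model + d_model * r

# Encoder layers: 1 attention block (self); decoder layers: 2 (self + cross)
attention_blocks = num_encoder_layers * 1 + num_decoder_layers * 2
lora_modules = attention_blocks * 4          # q, k, v, o in every attention block
lora_params = lora_modules * params_per_module

base_params = 247_577_856  # assumed FLAN-T5 Base total (the "~248M" above)
total_params = base_params + lora_params

print(lora_modules, lora_params)                   # 144 3538944
print(f"{100 * lora_params / total_params:.2f}%")  # 1.41%
```

If `get_peft_model` wraps the model as configured, `print_number_of_trainable_model_parameters(peft_model)` should report figures matching this arithmetic.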

In [None]:
from peft import LoraConfig, get_peft_model, TaskType, PeftModel

# Load SSL pre-trained model
#print(f"Loading SSL model from: {ssl_output_dir}")
#ssl_pretrained_model = AutoModelForSeq2SeqLM.from_pretrained(ssl_output_dir, torch_dtype=torch.bfloat16)

latest_checkpoint = '/content/drive/MyDrive/Mdata_lab/Dataset/Self_Supervised_Learning/model/flan-t5-ssl-1767904233/checkpoint-30000'

# Load model from checkpoint
ssl_pretrained_model = AutoModelForSeq2SeqLM.from_pretrained(
    latest_checkpoint,
    torch_dtype=torch.bfloat16
)

# Configure LoRA
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    inference_mode=False,
    r=16,  # LoRA rank
    lora_alpha=32,  # LoRA scaling factor
    lora_dropout=0.1,
    target_modules=["q", "k", "v", "o"]  # Apply LoRA to query, key, value, and output projections
)

# Apply PEFT to SSL model
peft_model = get_peft_model(ssl_pretrained_model, lora_config)

# Print trainable parameters
print("\n" + print_number_of_trainable_model_parameters(peft_model))

## Load Target Task Dataset

Now we load **DialogSum** - a smaller, high-quality dataset of conversations and summaries. This is our target task: dialogue summarization.

**Key Point**: We use DialogSum for supervised fine-tuning with much less data than the SSL pre-training phase. The hypothesis is that what the model learned during SSL transfers to this new task.

In [None]:
dataset_name = "knkarthick/dialogsum" #loading the dataset from huggingface
dataset = load_dataset(dataset_name)
dataset

## Tokenize Supervised Learning Data

Convert dialogue-summary pairs into a format the model can learn from:
- **Input**: "Summarize the following conversation.\n\n[dialogue]\n\nSummary: "
- **Target**: The actual summary

This is **supervised learning** - we have explicit input-output pairs with labels.
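Because `padding="max_length"` pads the short summaries to the full sequence length, most label positions are pad tokens. PyTorch's `CrossEntropyLoss` skips targets equal to its `ignore_index` (default -100), and Hugging Face seq2seq models compute their loss this way, so a common practice is to replace pad ids in the labels with -100. The NumPy toy below (5-token vocabulary; `masked_cross_entropy` is a hypothetical helper) shows why: when pads are left in, the easy-to-predict padding positions pull the average loss down and dilute the signal from the real summary tokens.

```python
import numpy as np

IGNORE_INDEX = -100  # default ignore_index of PyTorch's CrossEntropyLoss

def masked_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean token-level cross-entropy over positions whose label != ignore_index."""
    # Numerically stable log-softmax over the vocabulary axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    mask = labels != ignore_index
    safe_labels = np.where(mask, labels, 0)  # placeholder index at ignored positions
    nll = -np.take_along_axis(log_probs, safe_labels[:, None], axis=-1)[:, 0]
    return float(nll[mask].mean())

PAD_ID = 0  # T5's pad token id
vocab_size = 5
# Two real target tokens the model is unsure about, then two pad positions it predicts confidently
logits = np.zeros((4, vocab_size))
logits[2:, PAD_ID] = 10.0

padded_labels = np.array([1, 2, PAD_ID, PAD_ID])
masked_labels = np.where(padded_labels == PAD_ID, IGNORE_INDEX, padded_labels)

print(masked_cross_entropy(logits, padded_labels))  # about 0.80: pads pull the loss down
print(masked_cross_entropy(logits, masked_labels))  # log(5), about 1.609: real tokens only
```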

In [None]:
# Preprocess the dataset by converting the dialogue-summary (prompt-response) pairs into explicit instructions
# Then convert the prompt-response pairs into tokens

def tokenize_function(example):
    start_prompt = 'Summarize the following conversation.\n\n'
    end_prompt = '\n\nSummary: '
    prompt = [start_prompt + dialogue + end_prompt for dialogue in example["dialogue"]]
    # Inputs: keep the attention mask so the encoder ignores padding
    model_inputs = original_tokenizer(prompt, padding="max_length", truncation=True)
    example['input_ids'] = model_inputs['input_ids']
    example['attention_mask'] = model_inputs['attention_mask']
    # Labels: replace pad tokens with -100 so the loss ignores padded positions
    labels = original_tokenizer(example["summary"], padding="max_length", truncation=True).input_ids
    example['labels'] = [
        [tok if tok != original_tokenizer.pad_token_id else -100 for tok in seq]
        for seq in labels
    ]

    return example

# The dataset contains three splits: train, validation, and test.
# With batched=True, map applies tokenize_function to every split in batches.
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(['id', 'dialogue', 'summary',])

In [None]:
# Shape of the tokenized datasets

print(f"Shapes of the datasets:")
print(f"Training: {tokenized_datasets['train'].shape}")
print(f"Validation: {tokenized_datasets['validation'].shape}")
print(f"Test: {tokenized_datasets['test'].shape}")

print(tokenized_datasets)

## Create Small Labeled Subset

**Critical Experiment Design**: We use only **50% of the training data** to test whether SSL pre-training enables strong performance with limited labeled data. You can experiment with fewer samples.

**Hypothesis**: SSL pre-training provides general language understanding, so we need less task-specific labeled data for fine-tuning.

In [None]:
# Use only 50% of training data for supervised fine-tuning
subset_size = len(tokenized_datasets['train']) // 2
small_train = tokenized_datasets['train'].select(range(subset_size))

print(f"Using {len(small_train)} examples for supervised fine-tuning (50% of training data)")
print(f"Full training set: {len(tokenized_datasets['train'])} examples")

## Configure LoRA Fine-Tuning

Setup training for the PEFT (Parameter-Efficient Fine-Tuning) phase:
- **10 epochs** with early stopping
- **Higher learning rate** (1e-4, vs. 5e-5 for SSL pre-training) because we're only training adapter layers
- **Small batch size** (16) to fit in memory
- Only the LoRA adapter weights are updated; the base model stays frozen!
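Early stopping here is driven by `eval_loss` (`metric_for_best_model`) and `early_stopping_patience=10`. The sketch below is a simplified model of patience-based stopping, not the library's code: the real `EarlyStoppingCallback` also supports an `early_stopping_threshold`, and `early_stopping_step` is a hypothetical helper.

```python
def early_stopping_step(eval_losses, patience):
    """Return the index of the evaluation at which training would stop, or None.

    Simplified patience rule on eval_loss (lower is better): the counter resets whenever
    the loss beats the best value so far, and training stops once the counter reaches `patience`.
    """
    best = float("inf")
    evals_without_improvement = 0
    for i, loss in enumerate(eval_losses):
        if loss < best:
            best = loss
            evals_without_improvement = 0
        else:
            evals_without_improvement += 1
            if evals_without_improvement >= patience:
                return i
    return None

# Toy loss curve: improves twice, then plateaus
print(early_stopping_step([1.0, 0.9, 0.95, 0.96, 0.97], patience=3))  # 4
print(early_stopping_step([1.0, 0.9, 0.8], patience=3))              # None

# With eval_steps=200 and patience 10, stopping requires 10 evaluations without improvement
eval_steps, patience = 200, 10
print(eval_steps * patience)  # 2000 optimizer steps
```

With these settings, training runs for at least 2,000 steps past the best checkpoint before stopping, and `load_best_model_at_end=True` then restores that best checkpoint.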

In [None]:
from transformers import EarlyStoppingCallback

# PEFT training configuration
peft_output_dir = f'/content/drive/MyDrive/Mdata_lab/Dataset/Self_Supervised_Learning/model/flan-t5-peft-{str(int(time.time()))}'

peft_training_args = TrainingArguments(
    output_dir=peft_output_dir,
    num_train_epochs=10,
    learning_rate=1e-4,  # Higher learning rate for adapter layers
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    logging_steps=100,
    eval_steps=200,
    save_steps=200,
    eval_strategy="steps",
    bf16=True,
    fp16=False,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
)

peft_trainer = Trainer(
    model=peft_model,
    args=peft_training_args,
    train_dataset=small_train,  # Small labeled subset
    eval_dataset=tokenized_datasets['validation'],
    callbacks=[EarlyStoppingCallback(early_stopping_patience=10)]
)

print("Starting PEFT fine-tuning on summarization task...")
print(f"Only LoRA adapters will be trained, base model is frozen")

## Train LoRA Adapters

This should be much faster than SSL pre-training since we're:
1. Training far fewer parameters (only the LoRA adapters)
2. Using a much smaller dataset
3. Starting from SSL-learned representations, which may help the model converge in fewer steps

In [None]:
# Train PEFT model
peft_trainer.train()

# Save PEFT model
peft_trainer.save_model(peft_output_dir)
original_tokenizer.save_pretrained(peft_output_dir)
print(f"\nPEFT model saved to: {peft_output_dir}")

In [None]:
# Load the latest checkpoint to resume training
#latest_checkpoint = '/content/drive/MyDrive/Mdata_lab/Dataset/Self_Supervised_Learning/model/flan-t5-peft-1768103269/checkpoint-49500'

# resume training
#peft_trainer.train(resume_from_checkpoint=latest_checkpoint)
#original_tokenizer.save_pretrained(peft_output_dir)

---

# Evaluation

Now we'll evaluate our SSL+LoRA approach against the baseline model using standard summarization metrics.

## Load Final Trained Model

Load the best checkpoint from our PEFT fine-tuning for evaluation. This model combines:
1. FLAN-T5 base weights
2. SSL-learned representations
3. LoRA adapters fine-tuned on DialogSum dataset

In [None]:
peft_checkpoint = '/content/drive/MyDrive/Mdata_lab/Dataset/Self_Supervised_Learning/model/flan-t5-peft-1768318334/checkpoint-3600'
peft_tokenizer_dir = '/content/drive/MyDrive/Mdata_lab/Dataset/Self_Supervised_Learning/model/flan-t5-peft-1768318334'
peft_model = AutoModelForSeq2SeqLM.from_pretrained(peft_checkpoint, dtype=torch.bfloat16)
peft_tokenizer = AutoTokenizer.from_pretrained(peft_tokenizer_dir)

## Metrics
- **ROUGE-1**: Unigram (1-word) overlap between generated and reference summary
- **ROUGE-2**: Bigram (2-word) overlap - captures phrase similarity
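- **ROUGE-L**: Longest common subsequence - rewards words that appear in the same order in both summaries, without requiring them to be contiguous (sentence-level structure rather than fluency)
- **BLEU**: Precision-focused metric from machine translation (n-gram precision with a brevity penalty)
- **BERTScore**: Semantic similarity using contextual embeddings from a BERT-style model (DistilBERT here) - captures meaning even with different words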

**Note**: **ROUGE-L** and **BERTScore** are the main evaluation metrics.
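To see how ROUGE-1 and ROUGE-L differ, here is a simplified pure-Python sketch of their F-measure form. It is an illustration under stated assumptions, not the `rouge-score` implementation: it lowercases and splits on whitespace, with no stemming or punctuation handling, and the helper names are hypothetical.

```python
from collections import Counter

def rouge1_f1(prediction, reference):
    """Unigram-overlap F1 (simplified ROUGE-1: lowercase whitespace tokens, no stemming)."""
    pred, ref = prediction.lower().split(), reference.lower().split()
    overlap = sum((Counter(pred) & Counter(ref)).values())  # clipped unigram matches
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(pred), overlap / len(ref)
    return 2 * precision * recall / (precision + recall)

def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists (dynamic programming)."""
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, tok_a in enumerate(a, 1):
        for j, tok_b in enumerate(b, 1):
            dp[i][j] = dp[i-1][j-1] + 1 if tok_a == tok_b else max(dp[i-1][j], dp[i][j-1])
    return dp[-1][-1]

def rougeL_f1(prediction, reference):
    """LCS-based F1 (simplified ROUGE-L): rewards words that appear in the same order."""
    pred, ref = prediction.lower().split(), reference.lower().split()
    lcs = lcs_length(pred, ref)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(pred), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)

ref = "the cat sat on the mat"
pred = "on the mat the cat sat"
print(round(rouge1_f1(pred, ref), 3))  # 1.0  (same words)
print(round(rougeL_f1(pred, ref), 3))  # 0.5  (order differs; LCS has length 3)
```

The reordered prediction scores perfectly on ROUGE-1 but only 0.5 on ROUGE-L, which is why ROUGE-L is the more order-sensitive of the two.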

In [None]:
%pip install evaluate rouge-score bert-score
import evaluate

In [None]:
rouge = evaluate.load('rouge')
bleu = evaluate.load('bleu')
bertscore = evaluate.load('bertscore')

In [None]:
# Evaluation function
def evaluate_summaries(predictions, references, method_name):
    """Evaluate predictions against references using multiple metrics"""

    # ROUGE scores
    rouge_results = rouge.compute(
        predictions=predictions,
        references=references,
        use_stemmer=True
    )

    # BLEU score
    bleu_results = bleu.compute(
        predictions=predictions,
        references=references
    )

    # BERTScore (semantic similarity)
    bert_results = bertscore.compute(
        predictions=predictions,
        references=references,
        lang='en',
        model_type='distilbert-base-uncased'  # Faster variant
    )

    # Calculate average BERTScore
    avg_bertscore = sum(bert_results['f1']) / len(bert_results['f1'])

    print(f"\n{'='*70}")
    print(f"Results for: {method_name}")
    print(f"{'='*70}")
    print(f"ROUGE-1:      {rouge_results['rouge1']:.4f}")
    print(f"ROUGE-2:      {rouge_results['rouge2']:.4f}")
    print(f"ROUGE-L:      {rouge_results['rougeL']:.4f}")
    print(f"BLEU:         {bleu_results['bleu']:.4f}")
    print(f"BERTScore:    {avg_bertscore:.4f}")

    return {
        'Method': method_name,
        'ROUGE-1': round(rouge_results['rouge1'], 4),
        'ROUGE-2': round(rouge_results['rouge2'], 4),
        'ROUGE-L': round(rouge_results['rougeL'], 4),
        'BLEU': round(bleu_results['bleu'], 4),
        'BERTScore': round(avg_bertscore, 4)
    }

print("Evaluation function ready!")

## Generate Predictions for Comparison

Generate summaries from both models on 10 randomly selected test examples:
1. **Original Model**: FLAN-T5 with no additional training
2. **PEFT Model**: FLAN-T5 + SSL pre-training + LoRA fine-tuning

This allows us to directly compare the impact of our SSL+LoRA approach.

In [None]:
# Generate predictions from both models on randomly selected test samples
import random

# Randomly select 10 samples from the test set (fixed seed for reproducibility)
random.seed(48)
test_indices = random.sample(range(len(dataset['test'])), 10)

# Storage for predictions
baseline_summaries = []
original_predictions = []
peft_predictions = []

print("Generating predictions for all methods...")
print(f"Testing on {len(test_indices)} samples\n")

for idx, index in enumerate(test_indices):
    article = dataset['test'][index]['dialogue']
    summary = dataset['test'][index]['summary']

    # Store baseline
    baseline_summaries.append(summary)

    print(f"Processing sample {idx + 1}/{len(test_indices)}...")

    # Original prediction
    zero_shot_prompt = f"Summarize the following conversation.\n\n{article}\n\nSummary: "
    inputs = original_tokenizer(zero_shot_prompt, return_tensors='pt').to(original_model.device)
    output = original_tokenizer.decode(
        original_model.generate(inputs["input_ids"], max_new_tokens=2000, do_sample=False)[0],
        skip_special_tokens=True
    )
    original_predictions.append(output)

    # PEFT prediction
    peft_prompt = f"Summarize the following conversation.\n\n{article}\n\nSummary: "
    inputs = peft_tokenizer(peft_prompt, return_tensors='pt')
    inputs = {k: v.to(peft_model.device) for k, v in inputs.items()}
    output = peft_tokenizer.decode(
        peft_model.generate(input_ids=inputs["input_ids"], max_new_tokens=2000, do_sample=False)[0],
        skip_special_tokens=True
    )
    peft_predictions.append(output)

print(f"\n Generated {len(test_indices)} predictions for each method")

In [None]:
# Evaluate all methods
results_list = []

print("\nEvaluating all methods against baseline summaries...")
print("="*70)

# Evaluate the original model (zero-shot)
results_list.append(evaluate_summaries(original_predictions, baseline_summaries, "Original Model"))

# Evaluate PEFT (SSL + LoRA)
results_list.append(evaluate_summaries(peft_predictions, baseline_summaries, "PEFT (SSL+LoRA)"))

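The PEFT model shows improvement across all metrics, which is consistent with:
1. SSL pre-training providing useful representations
2. LoRA enabling effective fine-tuning with limited labeled data

This comparison uses only 10 test samples and includes neither a LoRA-only model (without SSL pre-training) nor a model trained from scratch, so on its own it cannot separate the contribution of SSL from that of LoRA fine-tuning, or establish that the combined approach is more data-efficient than training from scratch.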