<a href="https://colab.research.google.com/github/upriser72/Bridges-Distribution-Gap-in-Language-Model-Fine-Tuning/blob/main/mT5_Summarization_SDFT_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers[torch] datasets pandas sentencepiece sacremoses

In [None]:
import pandas as pd
import torch
import torch.nn.functional as F
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
)
import os

# Disable the Weights & Biases integration so the Trainer does not try to
# log runs to wandb. (This switches the integration off; it does not silence
# Python warnings.)
os.environ["WANDB_DISABLED"] = "true"

# ===================================================================================
# 1. SETUP AND DATA PREPARATION
# ===================================================================================
print("Step 1: Setting up data...")

# CSV file holding the training data; it must be in the working directory.
DATASET_PATH = "Summarization_dataset.csv"

# Only the read itself can raise FileNotFoundError, so keep the try body to
# that single call.
try:
    df = pd.read_csv(DATASET_PATH)
except FileNotFoundError:
    print(f"Error: The file '{DATASET_PATH}' was not found.")
    print("Please make sure your dataset file is in the correct directory and named correctly.")
    # Raise SystemExit directly instead of calling exit(): that helper is
    # injected by the `site` module for interactive use, is not guaranteed to
    # exist, and reports success (status 0) to the calling process.
    raise SystemExit(1) from None

# preprocess_function reads these two columns; fail early with a clear
# message instead of a KeyError in the middle of Dataset.map.
missing_columns = [name for name in ("article", "highlights") if name not in df.columns]
if missing_columns:
    print(f"Error: '{DATASET_PATH}' is missing required column(s): {missing_columns}")
    raise SystemExit(1)

# Rows without an article or a summary cannot be tokenized into a training
# pair. Resetting the index keeps a plain RangeIndex so that from_pandas does
# not add an extra index column to the dataset.
df = df.dropna(subset=["article", "highlights"]).reset_index(drop=True)

# Convert the pandas DataFrame to a Hugging Face Dataset object
raw_dataset = Dataset.from_pandas(df)
print(f"Successfully loaded dataset from '{DATASET_PATH}'.")
print("Dataset preview:")
print(raw_dataset)


# Base checkpoint used for the tokenizer, the teacher, and the student, plus
# the output directories of the two training runs.
MODEL_NAME = "google/mt5-small"
TEACHER_MODEL_PATH = "./mt5_small_teacher"
STUDENT_MODEL_PATH = "./mt5_small_student_distilled"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def preprocess_function(examples):
    """Tokenize a batch of articles and summaries for seq2seq training.

    Args:
        examples: Batch mapping with "article" and "highlights" entries, each
            a list holding one value per row (the layout ``Dataset.map``
            passes when ``batched=True``).

    Returns:
        The tokenizer output for the prefixed articles (truncated/padded to
        512 tokens) with an added "labels" entry holding the summary token
        ids (truncated/padded to 128 tokens). Padding positions in "labels"
        are set to -100.
    """
    prefix = "summarize: "
    inputs = [prefix + str(doc) for doc in examples["article"]]
    # Coerce targets to str as well so a non-string cell is handled the same
    # way on both sides of the pair.
    targets = [str(summary) for summary in examples["highlights"]]

    # Tokenize inputs and labels
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length")
    labels = tokenizer(text_target=targets, max_length=128, truncation=True, padding="max_length")

    # Labels are padded to a fixed length here, so the collator never gets a
    # chance to mask them. Replace pad ids with -100 (the index the
    # cross-entropy loss ignores) so training is not dominated by padding.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [token if token != pad_id else -100 for token in sequence]
        for sequence in labels["input_ids"]
    ]
    return model_inputs

# Run preprocess_function over the whole dataset, handing it rows in batches
# (lists of column values) rather than one example at a time.
tokenized_dataset = raw_dataset.map(function=preprocess_function, batched=True)
print("\nData setup and tokenization complete.\n")


# ===================================================================================
# 2. FINE-TUNE THE TEACHER MODEL
# ===================================================================================
print("Step 2: Fine-tuning the Teacher Model... 👨‍🏫")

# The teacher starts from the pretrained base checkpoint.
teacher_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

# Collator that assembles tokenized examples into seq2seq training batches.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=teacher_model)

# Hyperparameters for the teacher run, collected in one place.
teacher_hparams = dict(
    output_dir=TEACHER_MODEL_PATH,
    num_train_epochs=4,  # Use more epochs for a real dataset (e.g., 3-5)
    per_device_train_batch_size=4,  # Adjust based on your GPU memory
    save_steps=1000,
    save_total_limit=2,
    logging_steps=100,
    predict_with_generate=True,
    fp16=torch.cuda.is_available(),  # Mixed precision only when a GPU is present
)
training_args = Seq2SeqTrainingArguments(**teacher_hparams)

# Wire model, data, and arguments together.
teacher_trainer = Seq2SeqTrainer(
    model=teacher_model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# Fine-tune, then persist the result for the soft-label step.
teacher_trainer.train()
teacher_trainer.save_model(TEACHER_MODEL_PATH)
print("Teacher model fine-tuned and saved.\n")


# ===================================================================================
# 3. GENERATE "SOFT LABELS" (LOGITS) FROM THE TEACHER
# ===================================================================================
print("Step 3: Generating soft labels (logits) from the teacher...")

# Reload the fine-tuned teacher from disk and switch it to evaluation mode.
teacher_model = AutoModelForSeq2SeqLM.from_pretrained(TEACHER_MODEL_PATH)
teacher_model.eval()

# Decide the device once instead of re-checking CUDA for every example.
teacher_device = "cuda" if torch.cuda.is_available() else "cpu"
teacher_model.to(teacher_device)

# NOTE: every entry stores one full-vocabulary score vector per target
# position, so memory use grows quickly with the number of examples.
teacher_logits_list = []
for example in tokenized_dataset:
    input_ids = torch.tensor(example['input_ids']).unsqueeze(0).to(teacher_device)
    attention_mask = torch.tensor(example['attention_mask']).unsqueeze(0).to(teacher_device)
    labels = torch.tensor(example['labels']).unsqueeze(0).to(teacher_device)

    with torch.no_grad():
        # An encoder-decoder forward pass needs decoder inputs; calling the
        # model with only encoder inputs raises an error. Passing the
        # reference summary as `labels` lets the model build the decoder
        # inputs by shifting the labels right (teacher forcing), so the
        # logits at position i score labels[i] -- the same alignment the
        # student sees during training.
        outputs = teacher_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
    teacher_logits_list.append(outputs.logits.cpu().numpy())

# Add the teacher logits to the dataset
tokenized_dataset = tokenized_dataset.add_column("teacher_logits", teacher_logits_list)
print("Soft labels generated and added to the dataset.\n")


# ===================================================================================
# 4. SELF-DISTILLATION FOR THE STUDENT MODEL
# ===================================================================================
print("Step 4: Training the Student Model with Self-Distillation... 🧑‍🎓")


class DistillationTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer whose loss mixes teacher imitation with label supervision.

    The training loss is::

        alpha * T**2 * KL(teacher || student) + (1 - alpha) * CE(student, labels)

    where both distributions in the KL term are softened by the temperature
    ``T``. Every batch must contain a "teacher_logits" entry next to
    "labels".

    Args:
        *args: Positional arguments forwarded unchanged to ``Seq2SeqTrainer``.
        alpha: Weight of the KL term; the cross-entropy term gets
            ``1 - alpha``.
        temperature: Value the student and teacher logits are divided by
            before the softmax in the KL term.
        **kwargs: Keyword arguments forwarded unchanged to
            ``Seq2SeqTrainer``.
    """

    def __init__(self, *args, alpha=0.5, temperature=2.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.temperature = temperature

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the combined distillation and cross-entropy loss for a batch.

        Args:
            model: The student model being trained.
            inputs: Batch dict. "labels" and "teacher_logits" are removed
                from it; everything that remains is passed to ``model``.
                "teacher_logits" is expected to hold exactly as many values
                as the student's logits for the same batch.
            return_outputs: When true, return ``(loss, student_outputs)``
                instead of the loss alone.
            num_items_in_batch: Accepted so the method can be called by
                Trainer versions that pass this argument; it is not used.

        Returns:
            The scalar loss, or a ``(loss, student_outputs)`` tuple when
            ``return_outputs`` is true.
        """
        # Pop both so the model neither computes its own loss nor receives an
        # unexpected keyword argument.
        labels = inputs.pop("labels")
        teacher_logits = inputs.pop("teacher_logits")

        outputs_student = model(**inputs)
        # Compute the losses in float32 even when the forward pass ran in
        # reduced precision.
        student_logits = outputs_student.get("logits").float()

        # The stored teacher logits carry an extra singleton axis per example.
        # Reshaping to the student's shape removes it for any batch size and
        # fails loudly if the two tensors do not line up.
        teacher_logits = (
            teacher_logits.to(student_logits.device).float().reshape(student_logits.shape)
        )

        # Positions labelled -100 are ignored by the cross-entropy term, so
        # exclude them from the KL term as well.
        valid_positions = labels.ne(-100)
        num_valid = valid_positions.sum().clamp(min=1)

        # 1. Distillation loss: per-position KL divergence averaged over the
        #    valid positions. Averaging per token (instead of summing over the
        #    whole sequence) puts it on the same scale as the cross-entropy,
        #    so alpha really balances the two terms.
        soft_log_probs_student = F.log_softmax(student_logits / self.temperature, dim=-1)
        soft_probs_teacher = F.softmax(teacher_logits / self.temperature, dim=-1)
        kl_per_position = F.kl_div(
            soft_log_probs_student, soft_probs_teacher, reduction="none"
        ).sum(dim=-1)
        loss_kl = (kl_per_position * valid_positions).sum() / num_valid
        # Multiplying by T^2 keeps the gradient magnitude of the softened
        # term comparable across temperatures.
        loss_kl = loss_kl * (self.temperature ** 2)

        # 2. Standard cross-entropy against the ground-truth labels.
        vocab_size = student_logits.size(-1)
        loss_ce = F.cross_entropy(
            student_logits.reshape(-1, vocab_size),
            labels.reshape(-1),
            ignore_index=-100,
        )

        # 3. Weighted combination of the two terms.
        loss = self.alpha * loss_kl + (1.0 - self.alpha) * loss_ce

        return (loss, outputs_student) if return_outputs else loss

# Initialize a NEW mt5-small model for the student
student_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

# Training arguments for the student model
student_training_args = Seq2SeqTrainingArguments(
    output_dir=STUDENT_MODEL_PATH,
    num_train_epochs=15, # Student can sometimes benefit from more epochs
    per_device_train_batch_size=1, # This MUST be 1 due to how we stored logits
    save_steps=1000,
    save_total_limit=2,
    logging_steps=100,
    fp16=torch.cuda.is_available(),
    remove_unused_columns=False, # IMPORTANT: Keep 'teacher_logits'
)

# With remove_unused_columns=False the Trainer no longer strips any column,
# so the raw text columns would reach the data collator, which cannot turn
# strings into tensors. Keep only the columns the distillation step consumes.
student_columns = {"input_ids", "attention_mask", "labels", "teacher_logits"}
student_dataset = tokenized_dataset.remove_columns(
    [name for name in tokenized_dataset.column_names if name not in student_columns]
)

# Give the student its own collator so batching does not depend on (or keep
# alive) the teacher model object the earlier collator was built with.
student_data_collator = DataCollatorForSeq2Seq(tokenizer, model=student_model)

# Instantiate the custom trainer
distillation_trainer = DistillationTrainer(
    model=student_model,
    args=student_training_args,
    train_dataset=student_dataset,
    tokenizer=tokenizer,
    data_collator=student_data_collator,
    alpha=0.5,       # Balance between teacher and ground-truth (0.0 to 1.0)
    temperature=2.0,  # Softens the teacher's predictions (usually > 1.0)
)

# Train the student model
distillation_trainer.train()
distillation_trainer.save_model(STUDENT_MODEL_PATH)
print("Student model trained with distillation and saved.\n")


# ===================================================================================
# 5. INFERENCE WITH THE DISTILLED MODEL
# ===================================================================================
print("Step 5: Performing inference with the final distilled model... 🧪")

# Load the final distilled student model and tokenizer
distilled_model = AutoModelForSeq2SeqLM.from_pretrained(STUDENT_MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Example article for summarization (taken from your sample data)
article_text = "Ever noticed how linear television has been overtaken by streaming? Experts say the primary reason is the on-demand nature of platforms like Netflix and Hulu. Viewers can watch what they want, when they want, without being tied to a broadcast schedule. This flexibility is key, allowing for binge-watching and personalized content discovery, which traditional TV cannot offer."

# Prepare the input. A single sequence needs no padding, so only truncation
# is applied; padding it to 512 tokens would just add work for the encoder.
inputs = tokenizer(
    "summarize: " + article_text,
    return_tensors="pt",
    max_length=512,
    truncation=True,
)

# Generate the summary. The attention mask is passed explicitly so that any
# padding present in the input is never attended to; passing input_ids alone
# leaves the model to treat pad tokens as real content.
summary_ids = distilled_model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    num_beams=4,
    max_length=60,
    early_stopping=True,
)

# Decode and print the summary
generated_summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

print(f"\nOriginal Article:\n{article_text}")
print(f"\nGenerated Summary:\n{generated_summary}")