<a href="https://colab.research.google.com/github/palindromeRice/Knowledge_Distillation_Demo/blob/main/KD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🧠 T5 Knowledge Distillation & Evaluation

This notebook shows how to:

1. **Distill** a fine‑tuned T5‑Base teacher into a T5‑Small student on a local CSV.  
2. **Evaluate** both models using ROUGE and BERTScore, and report student retention.


## 📥 Downloading the Dataset from Kaggle

We use the `kagglehub` library to download the **News Summary** dataset.

In [1]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("sunnysai12345/news-summary")

print("Path to dataset files:", path)

Path to dataset files: /kaggle/input/news-summary
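
Before loading anything, it can help to confirm which CSV files actually landed in the download directory. The helper below is a minimal sketch using only the standard library; `list_csv_files` is a hypothetical name introduced here, not part of `kagglehub`.

```python
from pathlib import Path


def list_csv_files(dataset_dir):
    """Return the sorted names of CSV files directly inside dataset_dir."""
    # Hypothetical helper: only looks at the top level of the directory
    return sorted(p.name for p in Path(dataset_dir).glob("*.csv"))
```

Calling `list_csv_files(path)` with the `path` returned by `kagglehub.dataset_download` should list the two files inspected in the next section, assuming the dataset layout has not changed.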


## 📊 Dataset Analysis

In this cell, we load and inspect the two CSV files from the News Summary dataset:

- **`news_summary.csv`**: The main dataset (4,514 rows, 6 columns) used for distillation and evaluation.
- **`news_summary_more.csv`**: A larger file (98,401 rows) that contains only the `headlines` and `text` columns.

We perform:
- Basic structure inspection (shape, column names)
- Random sample previews for quick content checks
- Column name comparison to identify overlaps and differences between the two files

Only `news_summary.csv` includes the `ctext` column that the later cells use as the target, so it is the file used for both distillation and evaluation.


In [4]:
import pandas as pd

# 📊 Load News Summary Datasets
summary_df = pd.read_csv(f"{path}/news_summary.csv", encoding="latin-1")
more_df = pd.read_csv(f"{path}/news_summary_more.csv", encoding="latin-1")

# 📝 Overview of news_summary.csv
print("="*40)
print("🗂️ Dataset 1: news_summary.csv")
print("="*40)
print("🔍 Shape:", summary_df.shape)
print("🧾 Columns:", list(summary_df.columns))
print("\n🧠 Sample Rows:")
display(summary_df.sample(3))
print("\n📋 Data Info:")
summary_df.info()

# 📝 Overview of news_summary_more.csv
print("\n" + "="*40)
print("🗂️ Dataset 2: news_summary_more.csv")
print("="*40)
print("🔍 Shape:", more_df.shape)
print("🧾 Columns:", list(more_df.columns))
print("\n🧠 Sample Rows:")
display(more_df.sample(3))
print("\n📋 Data Info:")
more_df.info()

# 🔁 Compare Column Names
print("\n" + "="*40)
print("🔍 Column Name Comparison")
print("="*40)
common_cols = set(summary_df.columns).intersection(set(more_df.columns))
unique_summary = set(summary_df.columns) - set(more_df.columns)
unique_more = set(more_df.columns) - set(summary_df.columns)

print(f"✅ Common Columns: {list(common_cols)}")
print(f"📁 Only in news_summary.csv: {list(unique_summary)}")
print(f"📁 Only in news_summary_more.csv: {list(unique_more)}")


🗂️ Dataset 1: news_summary.csv
🔍 Shape: (4514, 6)
🧾 Columns: ['author', 'date', 'headlines', 'read_more', 'text', 'ctext']

🧠 Sample Rows:


|      | author | date | headlines | read_more | text | ctext |
|------|--------|------|-----------|-----------|------|-------|
| 4053 | Chhavi Tyagi | 27 Mar 2017,Monday | AAP doing nothing for cows: Ved Prakash after ... | http://indiatoday.intoday.in/story/delhi-civic... | Former AAP MLA from Bawana Ved Prakash Satish,... | In the run up to Delhi's municipal corporation... |
| 21 | Ayushi Ahluwalia | 03 Aug 2017,Thursday | Delhi's AIIMS, Safdarjung to be declared no-ha... | http://indiatoday.intoday.in/story/delhi-aiims... | The North Delhi Municipal Corporation has said... | Dozens of street vendors will have to take the... |
| 1289 | Chhavi Tyagi | 06 Feb 2017,Monday | Jayalalithaa died of multiple organ failure: D... | http://indiatoday.intoday.in/story/jayalalitha... | Over two months after the death of former Tami... | Over two months after the death of former Tami... |



📋 Data Info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 4514 entries, 0 to 4513
Data columns (total 6 columns):
 #   Column     Non-Null Count  Dtype 
---  ------     --------------  ----- 
 0   author     4514 non-null   object
 1   date       4514 non-null   object
 2   headlines  4514 non-null   object
 3   read_more  4514 non-null   object
 4   text       4514 non-null   object
 5   ctext      4396 non-null   object
dtypes: object(6)
memory usage: 211.7+ KB

🗂️ Dataset 2: news_summary_more.csv
🔍 Shape: (98401, 2)
🧾 Columns: ['headlines', 'text']

🧠 Sample Rows:


|       | headlines | text |
|-------|-----------|------|
| 16483 | 22-year-old tries to steal passenger plane in ... | A 22-year-old student pilot has been arrested ... |
| 59446 | Aditya and I talk about having our 2nd baby, n... | Rani Mukerji said she doesn't talk about work ... |
| 71945 | Woman who suffered burns after crew spilled te... | A woman is suing JetBlue on suffering burns af... |



📋 Data Info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 98401 entries, 0 to 98400
Data columns (total 2 columns):
 #   Column     Non-Null Count  Dtype 
---  ------     --------------  ----- 
 0   headlines  98401 non-null  object
 1   text       98401 non-null  object
dtypes: object(2)
memory usage: 1.5+ MB

🔍 Column Name Comparison
✅ Common Columns: ['text', 'headlines']
📁 Only in news_summary.csv: ['read_more', 'ctext', 'author', 'date']
📁 Only in news_summary_more.csv: []


## 📦 Installing Dependencies

Before running distillation or evaluation, we need to install the necessary libraries:



In [5]:
!pip install -q transformers
!pip install -q datasets
!pip install -q rouge-score
!pip install -q bert_score
!pip install -q pyemd
!pip install -q hf_xet

In [6]:
import warnings
warnings.filterwarnings('ignore')

## 🔧 T5 Knowledge Distillation

In this section, we distill a larger T5-Base model (fine-tuned for news summarization) into a smaller, faster T5-Small model.

**Why distillation?**  
The goal is to retain most of the teacher model's performance while significantly reducing model size and inference time, which suits deployment in resource-constrained environments.

We use:
- A fine-tuned T5-Base model as the **teacher**
- A pre-trained T5-Small model as the **student**
- A weighted sum of **Forward KL Divergence** (against the teacher's soft labels) and **Cross-Entropy** (against the true labels) as the training loss, balanced by `--alpha`

After training, the student model is saved and used for evaluation.
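
The combined objective described above can be written compactly. The symbols are notation introduced here for illustration and do not appear in the notebook: $z_T$ and $z_S$ are the teacher and student logits, $T$ is the temperature, $\alpha$ is the mixing weight, and $y$ are the reference tokens. This is a sketch of what the training cell computes, under the assumption that the defaults ($\alpha = 0.5$, $T = 1$) are left unchanged.

```latex
p_T = \operatorname{softmax}(z_T / T), \qquad
p_S = \operatorname{softmax}(z_S / T)

\mathcal{L} = \alpha \, T^{2} \, \mathrm{KL}\!\left(p_T \,\|\, p_S\right)
            + (1 - \alpha) \, \mathrm{CE}\!\left(y, z_S\right)
```

The cross-entropy term uses the unscaled student logits, while both distributions in the KL term are softened by the same temperature.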


In [7]:
import os
import argparse
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    T5ForConditionalGeneration,
    T5TokenizerFast,
    get_linear_schedule_with_warmup
)
from torch.optim import AdamW
from datasets import load_dataset
from tqdm.auto import tqdm


def parse_args():
    parser = argparse.ArgumentParser(description="T5 Knowledge Distillation on Local CSV")
    parser.add_argument("--dataset_file", type=str,
                        default="/kaggle/input/news-summary/news_summary.csv",
                        help="Path to local CSV (default: %(default)s)")
    parser.add_argument("--teacher_model", type=str,
                        default="mrm8488/t5-base-finetuned-summarize-news",
                        help="HF repo for fine-tuned T5-Base teacher")
    parser.add_argument("--student_model", type=str,
                        default="google-t5/t5-small",
                        help="HF repo for T5-Small student")
    parser.add_argument("--source_prefix", type=str, default="",
                        help="Task prefix, e.g. 'summarize: '")
    parser.add_argument("--max_source_length", type=int, default=512)
    parser.add_argument("--max_target_length", type=int, default=64)
    parser.add_argument("--output_dir", type=str,
                        default="./distilled_t5",
                        help="Directory to save distilled student model (default: %(default)s)")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--alpha", type=float, default=0.5,
                        help="Weight on the KD loss; the CE loss gets 1 - alpha (default: %(default)s)")
    parser.add_argument("--temperature", type=float, default=1.0,
                        help="Softmax temperature applied to teacher and student logits (default: %(default)s)")
    parser.add_argument("--warmup_steps", type=int, default=0)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    args, _ = parser.parse_known_args()
    return args


In [8]:
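class ForwardKLLoss(torch.nn.Module):
    """Return the forward-KL distillation term and the cross-entropy term separately."""

    def __init__(self, ignore_index=-100, temperature=1.0):
        super().__init__()
        self.ignore_index = ignore_index  # label id skipped by the CE term
        self.temperature = temperature

    def forward(self, student_logits, teacher_logits, labels):
        # Soft targets: temperature-scaled teacher probabilities and student log-probabilities
        t_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
        s_logprobs = F.log_softmax(student_logits / self.temperature, dim=-1)
        # KL(teacher || student), scaled by T^2. 'batchmean' divides the summed KL
        # by the batch size only, and padded target positions are included here.
        kd = F.kl_div(s_logprobs, t_probs, reduction='batchmean') * (self.temperature ** 2)
        # Cross-entropy on the true labels, skipping positions equal to ignore_index
        ce = F.cross_entropy(
            student_logits.view(-1, student_logits.size(-1)),
            labels.view(-1),
            ignore_index=self.ignore_index
        )
        return kd, ce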


def preprocess_batch(examples, tokenizer, args):
    # Cast inputs to str, prepend the task prefix, and tokenize to a fixed length
    inputs = [args.source_prefix + str(txt) for txt in examples.get("text", [])]
    model_inputs = tokenizer(
        inputs,
        max_length=args.max_source_length,
        truncation=True,
        padding='max_length'
    )
    # Tokenize targets via `text_target`, which replaces the deprecated
    # `as_target_tokenizer()` context manager
    targets = [str(t) for t in examples.get("ctext", [])]
    labels = tokenizer(
        text_target=targets,
        max_length=args.max_target_length,
        truncation=True,
        padding='max_length'
    )
    # Labels keep the pad token id; the CE term skips it via ignore_index
    model_inputs['labels'] = labels['input_ids']
    return model_inputs
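
To make the distillation term concrete, here is a small pure-Python sketch of temperature-scaled forward KL for a single token, plus a variant that averages only over non-padding positions. Note that the masked, per-token average is a different reduction from the `batchmean` used in `ForwardKLLoss` above; it is shown as an illustration of one alternative, and the function names and `pad_id=0` default are assumptions made for this example.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax of logits divided by temperature."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def forward_kl(teacher_logits, student_logits, temperature=1.0):
    """KL(teacher || student) for one token, scaled by temperature**2."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    return kl * temperature ** 2


def masked_mean_kl(teacher_seq, student_seq, labels, pad_id=0, temperature=1.0):
    """Average the per-token KL over positions whose label is not pad_id."""
    kept = [
        forward_kl(t, s, temperature)
        for t, s, y in zip(teacher_seq, student_seq, labels)
        if y != pad_id
    ]
    return sum(kept) / len(kept) if kept else 0.0
```

When teacher and student logits agree, `forward_kl` is zero; it grows as the student places less probability on tokens the teacher favours.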

In [9]:
def main():
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load and filter dataset
    raw = load_dataset(
        'csv', data_files={'train': args.dataset_file}, encoding='latin-1'
    )['train']
    raw = raw.filter(
        lambda x: isinstance(x.get('text'), str) and x['text'].strip() and
                  isinstance(x.get('ctext'), str) and x['ctext'].strip()
    )

    # Load tokenizer and models
    tokenizer = T5TokenizerFast.from_pretrained(args.teacher_model)
    teacher = T5ForConditionalGeneration.from_pretrained(
        args.teacher_model
    ).to(device)
    student = T5ForConditionalGeneration.from_pretrained(
        args.student_model
    ).to(device)
    teacher.eval()
    for param in teacher.parameters():
        param.requires_grad = False

    # Tokenize dataset
    tokenized = raw.map(
        lambda examples: preprocess_batch(examples, tokenizer, args),
        batched=True,
        remove_columns=raw.column_names
    )
    tokenized.set_format(
        type='torch',
        columns=['input_ids', 'attention_mask', 'labels']
    )

    train_loader = DataLoader(
        tokenized,
        batch_size=args.batch_size,
        shuffle=True
    )

    # Optimizer & scheduler
    optimizer = AdamW(
        student.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay
    )
    total_steps = (
        len(train_loader) // args.gradient_accumulation_steps
    ) * args.num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, args.warmup_steps, total_steps
    )

    kd_loss_fn = ForwardKLLoss(
        ignore_index=tokenizer.pad_token_id,
        temperature=args.temperature
    )

    global_step = 0
    # Training loop
    for epoch in range(1, args.num_epochs + 1):
        student.train()
        total_loss = 0.0
        progress = tqdm(
            train_loader, desc=f"Epoch {epoch}/{args.num_epochs}"
        )
        optimizer.zero_grad()

        for step, batch in enumerate(progress, start=1):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Teacher forward (with labels) to get logits
            with torch.no_grad():
                teacher_outputs = teacher(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                teacher_logits = teacher_outputs.logits

            # Student forward (with labels)
            student_outputs = student(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            student_logits = student_outputs.logits

            # Compute losses
            kd_loss, ce_loss = kd_loss_fn(
                student_logits, teacher_logits, labels
            )
            loss = args.alpha * kd_loss + (1 - args.alpha) * ce_loss
            loss = loss / args.gradient_accumulation_steps
            loss.backward()
            total_loss += loss.item()

            # Step optimizer
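            if step % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1
                # total_loss resets every epoch, so average over this epoch's
                # batches and undo the gradient-accumulation scaling
                progress.set_postfix(
                    avg_loss=total_loss * args.gradient_accumulation_steps / step
                )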

        # Save the student after each epoch (overwrites the previous save in output_dir)
        student.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)
        print(f"Epoch {epoch} complete. Model saved to {args.output_dir}")

    print("Distillation finished.")


if __name__ == '__main__':
    main()

Generating train split: 0 examples [00:00, ? examples/s]

Filter:   0%|          | 0/4514 [00:00<?, ? examples/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Map:   0%|          | 0/4396 [00:00<?, ? examples/s]

Epoch 1/3:   0%|          | 0/550 [00:00<?, ?it/s]

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch 1 complete. Model saved to ./distilled_t5


Epoch 2/3:   0%|          | 0/550 [00:00<?, ?it/s]

Epoch 2 complete. Model saved to ./distilled_t5


Epoch 3/3:   0%|          | 0/550 [00:00<?, ?it/s]

Epoch 3 complete. Model saved to ./distilled_t5
Distillation finished.
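
The `0/550` shown in the progress bars and the scheduler's total step count follow directly from the dataset size and the defaults. The arithmetic is sketched below; `schedule_steps` is a hypothetical helper that mirrors the `total_steps` expression in the training cell.

```python
import math


def schedule_steps(num_examples, batch_size, num_epochs, grad_accum=1):
    """Return (batches per epoch, total optimizer steps)."""
    # The DataLoader keeps the final partial batch, hence the ceiling
    batches_per_epoch = math.ceil(num_examples / batch_size)
    # Same expression as total_steps in the training cell
    optimizer_steps = (batches_per_epoch // grad_accum) * num_epochs
    return batches_per_epoch, optimizer_steps


# 4,396 filtered examples, batch size 8, 3 epochs
print(schedule_steps(4396, 8, 3))  # (550, 1650)
```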


## 📊 T5 Evaluation

In this section, we evaluate how well the **T5-Small student model**, trained via distillation, retains the summarization performance of the original **T5-Base teacher model**.

**Why evaluate?**  
To validate that the student model remains competitive while being more lightweight and efficient.

We use:
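- The first 100 filtered rows of the **same dataset** (`text` → `ctext`) used in the training phase, so the scores are in-sample rather than measured on a held-out test split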
- **ROUGE** scores for lexical overlap
- **BERTScore** for semantic similarity

The script also calculates **retention percentages**, showing how much of the teacher’s performance the student preserves after distillation.
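
To give a feel for what "lexical overlap" means, the sketch below computes a simplified ROUGE-1 F1 from unigram counts. It is an illustration only: it splits on whitespace and skips the stemming and tokenization that the `rouge_score` package applies, so its numbers will not match the scores reported later. The function name `rouge1_f1` is an assumption made for this example.

```python
from collections import Counter


def rouge1_f1(reference, prediction):
    """Simplified unigram-overlap F1 between a reference and a prediction."""
    ref_counts = Counter(reference.lower().split())
    pred_counts = Counter(prediction.lower().split())
    # Clipped overlap: each word counts at most as often as it appears in both
    overlap = sum((ref_counts & pred_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)
```

For the reference "the cat sat on the mat" and the prediction "the cat sat", precision is 1.0 and recall is 0.5, giving an F1 of about 0.667.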


In [10]:
import os
import argparse
import torch
from transformers import T5ForConditionalGeneration, T5TokenizerFast
from datasets import load_dataset
from rouge_score import rouge_scorer, scoring
from bert_score import score as bert_score
from tqdm.auto import tqdm


def parse_args():
    parser = argparse.ArgumentParser("Evaluate T5 Teacher vs. Student on ROUGE and BERTScore")
    parser.add_argument("--dataset_file", type=str,
                        default="/kaggle/input/news-summary/news_summary.csv",
                        help="Path to local CSV with test examples")
    parser.add_argument("--teacher_model", type=str,
                        default="mrm8488/t5-base-finetuned-summarize-news",
                        help="HF repo or local dir for teacher model")
    parser.add_argument("--student_model", type=str,
                        default="./distilled_t5",
                        help="Local dir for distilled student model")
    parser.add_argument("--prefix", type=str, default="summarize: ",
                        help="Task prefix (e.g., 'summarize: ')")
    parser.add_argument("--max_source_length", type=int, default=512,
                        help="Max input tokens")
    parser.add_argument("--max_target_length", type=int, default=64,
                        help="Max generated tokens")
    parser.add_argument("--batch_size", type=int, default=8,
                        help="Batch size for generation (currently unused; examples are generated one at a time)")
    parser.add_argument("--num_samples", type=int, default=100,
                        help="Number of examples to evaluate (start of dataset)")
    args, _ = parser.parse_known_args()
    return args


In [11]:
def load_models_and_tokenizer(teacher_path, student_path, device):
    tokenizer = T5TokenizerFast.from_pretrained(teacher_path)
    teacher = T5ForConditionalGeneration.from_pretrained(teacher_path).to(device)
    student = T5ForConditionalGeneration.from_pretrained(student_path).to(device)
    teacher.eval()
    student.eval()
    return tokenizer, teacher, student


def prepare_dataset(path, num_samples):
    raw = load_dataset('csv', data_files={'test': path}, split='test', encoding='latin-1')
    raw = raw.filter(lambda x: isinstance(x.get('text'), str) and x['text'].strip() and
                             isinstance(x.get('ctext'), str) and x['ctext'].strip())
    return raw.select(range(min(len(raw), num_samples)))


def evaluate(models, tokenizer, dataset, args):
    device = next(iter(models.values())).device
    rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    aggregator = scoring.BootstrapAggregator()

    references = []
    teacher_preds = []
    student_preds = []

    for ex in tqdm(dataset, desc="Evaluating"):
        input_text = args.prefix + ex['text']
        inputs = tokenizer(input_text, return_tensors='pt', max_length=args.max_source_length,
                           truncation=True).to(device)
        reference = ex['ctext']
        references.append(reference)

        for name, model in models.items():
            summary_ids = model.generate(input_ids=inputs['input_ids'],
                                         attention_mask=inputs['attention_mask'],
                                         max_length=args.max_target_length,
                                         num_beams=4,
                                         early_stopping=True)
            pred = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
            if name == 'teacher': teacher_preds.append(pred)
            else: student_preds.append(pred)

            scores = rouge.score(reference, pred)
            aggregator.add_scores({f"{name}_{k}": v for k, v in scores.items()})

    rouge_results = {k: float(v.mid.fmeasure) for k, v in aggregator.aggregate().items()}

    # BERTScore: bert_score expects candidates first, then references
    _, _, bs_teacher = bert_score(teacher_preds, references, model_type='roberta-large', lang='en', rescale_with_baseline=True)
    _, _, bs_student = bert_score(student_preds, references, model_type='roberta-large', lang='en', rescale_with_baseline=True)

    bs_results = {
        "teacher_bertscore": bs_teacher.mean().item(),
        "student_bertscore": bs_student.mean().item()
    }

    return rouge_results, bs_results

In [12]:
def main():
    args = parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer, teacher, student = load_models_and_tokenizer(
        args.teacher_model, args.student_model, device
    )
    dataset = prepare_dataset(args.dataset_file, args.num_samples)
    models = {'teacher': teacher, 'student': student}

    rouge_scores, bs_scores = evaluate(models, tokenizer, dataset, args)

    print("\n=== ROUGE F1 Scores ===")
    for metric, score in sorted(rouge_scores.items()):
        print(f"{metric}: {score:.4f}")

    print("\n=== ROUGE Retention (% of Teacher) ===")
    for m in ['rouge1', 'rouge2', 'rougeL']:
        t = rouge_scores[f'teacher_{m}']
        s = rouge_scores[f'student_{m}']
        print(f"{m}: {(s/t*100 if t > 0 else 0):.1f}% retained")

    print("\n=== BERTScore F1 ===")
    for metric, score in sorted(bs_scores.items()):
        print(f"{metric}: {score:.4f}")

    print("\n=== BERTScore Retention ===")
    t = bs_scores['teacher_bertscore']
    s = bs_scores['student_bertscore']
    print(f"bertscore: {(s/t*100 if t > 0 else 0):.1f}% retained")

if __name__ == '__main__':
    main()

Generating test split: 0 examples [00:00, ? examples/s]

Filter:   0%|          | 0/4514 [00:00<?, ? examples/s]

Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



=== ROUGE F1 Scores ===
student_rouge1: 0.2086
student_rouge2: 0.1259
student_rougeL: 0.1599
teacher_rouge1: 0.2297
teacher_rouge2: 0.1407
teacher_rougeL: 0.1771

=== ROUGE Retention (% of Teacher) ===
rouge1: 90.8% retained
rouge2: 89.5% retained
rougeL: 90.3% retained

=== BERTScore F1 ===
student_bertscore: 0.1480
teacher_bertscore: 0.1747

=== BERTScore Retention ===
bertscore: 84.7% retained
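
The retention figures can be re-derived from the printed scores. The sketch below uses the same ratio as the evaluation cell; since the printed scores are rounded to four decimals, the recomputation is approximate, although here it agrees with the reported percentages. `retention` is a hypothetical helper name.

```python
def retention(student, teacher):
    """Student score as a percentage of the teacher score (0 if teacher <= 0)."""
    return student / teacher * 100 if teacher > 0 else 0.0


# (student, teacher) pairs copied from the output above
reported = {
    "rouge1": (0.2086, 0.2297),
    "rouge2": (0.1259, 0.1407),
    "rougeL": (0.1599, 0.1771),
    "bertscore": (0.1480, 0.1747),
}

for name, (student, teacher) in reported.items():
    print(f"{name}: {retention(student, teacher):.1f}% retained")
```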
