In [2]:
!pip install -q "protobuf==3.20.3"

# ============================================================
# 0. INSTALL + IMPORTS
# ============================================================
!pip install -q transformers datasets sentencepiece accelerate

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import (
    T5ForConditionalGeneration,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)

from datasets import load_dataset

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ============================================================
# 2. LOAD DATASET
# ============================================================
raw = load_dataset("ScaDSAI/ParaDeHate")["train"]
df = raw.train_test_split(test_size=0.1, seed=42)

train_ds = df["train"]
test_ds  = df["test"]

SRC_COL = "Original Text"
TGT_COL = "Converted Text"

# ============================================
# 1. LOAD TEACHER AND STUDENT
# ============================================
teacher_name = "/kaggle/input/distilbart-t5-finetuned/t5-base_finetuned"   # <-- your path
student_name = "t5-small"

teacher_tok  = AutoTokenizer.from_pretrained(teacher_name)
student_tok  = AutoTokenizer.from_pretrained(student_name)

teacher_model = T5ForConditionalGeneration.from_pretrained(teacher_name).to(DEVICE)
student_model = T5ForConditionalGeneration.from_pretrained(student_name).to(DEVICE)

teacher_model.eval()
for p in teacher_model.parameters():
    p.requires_grad = False


# ============================================
# 2. PREPROCESS FUNCTION
# ============================================

PREFIX = "detoxify: "
MAX_SRC = 96
MAX_TGT = 96


def preprocess(batch):
    src = [PREFIX + x for x in batch[SRC_COL]]
    tgt = batch[TGT_COL]

    # ----------------------------- STUDENT -----------------------------
    s_in = student_tok(
        src,
        truncation=True,
        max_length=MAX_SRC,
        padding="max_length"
    )

    s_out = student_tok(
        tgt,
        truncation=True,
        max_length=MAX_TGT,
        padding="max_length"
    )

    labels_s = [
        [(tok if tok != student_tok.pad_token_id else -100) for tok in seq]
        for seq in s_out["input_ids"]
    ]

    # ----------------------------- TEACHER -----------------------------
    t_in = teacher_tok(
        src,
        truncation=True,
        max_length=MAX_SRC,
        padding="max_length"
    )

    t_out = teacher_tok(
        tgt,
        truncation=True,
        max_length=MAX_TGT,
        padding="max_length"
    )

    labels_t = [
        [(tok if tok != teacher_tok.pad_token_id else -100) for tok in seq]
        for seq in t_out["input_ids"]
    ]

    return {
        "input_ids_s": s_in["input_ids"],
        "attention_mask_s": s_in["attention_mask"],
        "labels_s": labels_s,

        "input_ids_t": t_in["input_ids"],
        "attention_mask_t": t_in["attention_mask"],
        "labels_t": labels_t,
    }


# ============================================================
# 3. APPLY PREPROCESS
# ============================================================
train_tok = train_ds.map(
    preprocess,
    batched=True,
    remove_columns=train_ds.column_names
)

test_tok = test_ds.map(
    preprocess,
    batched=True,
    remove_columns=test_ds.column_names
)


# ============================================================
# 4. COLLATOR (NO PADDING HERE — FIXED)
# ============================================================
class STCollator:
    def __init__(self, student_model):
        self.student_model = student_model

    def __call__(self, batch):

        input_ids = torch.tensor([b["input_ids_s"] for b in batch], dtype=torch.long)
        attention_mask = torch.tensor([b["attention_mask_s"] for b in batch], dtype=torch.long)
        labels = torch.tensor([b["labels_s"] for b in batch], dtype=torch.long)

        decoder_input_ids = self.student_model._shift_right(labels)

        input_ids_t = torch.tensor([b["input_ids_t"] for b in batch], dtype=torch.long)
        attention_mask_t = torch.tensor([b["attention_mask_t"] for b in batch], dtype=torch.long)
        labels_t = torch.tensor([b["labels_t"] for b in batch], dtype=torch.long)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "decoder_input_ids": decoder_input_ids,

            "input_ids_t": input_ids_t,
            "attention_mask_t": attention_mask_t,
            "labels_t": labels_t,
        }

collator = STCollator(student_model)


# ============================================================
# 5. DISTILLATION MODEL (CORRECT KL + MASKING)
# ============================================================
class DistillT5(nn.Module):
    def __init__(self, student, teacher, T=2.0, kl_weight=0.7):
        super().__init__()
        self.student = student
        self.teacher = teacher
        self.T = T
        self.kl_weight = kl_weight

    def forward(
        self,
        input_ids,
        attention_mask,
        labels,
        decoder_input_ids,
        input_ids_t,
        attention_mask_t,
        labels_t,
    ):

        # ---------------- TEACHER ----------------
        with torch.no_grad():
            t_out = self.teacher(
                input_ids=input_ids_t,
                attention_mask=attention_mask_t,
                decoder_input_ids=decoder_input_ids,
                return_dict=True,
            )
            t_probs = F.softmax(t_out.logits / self.T, dim=-1)

        # ---------------- STUDENT ----------------
        s_out = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            labels=labels,
            return_dict=True,
        )

        ce = s_out.loss
        s_logits = s_out.logits / self.T
        s_log_probs = F.log_softmax(s_logits, dim=-1)

        # mask padding
        mask = (labels != -100).unsqueeze(-1).expand_as(s_logits)

        kl = F.kl_div(
            s_log_probs[mask],
            t_probs[mask],
            reduction="batchmean",
        ) * (self.T ** 2)

        return {"loss": ce + self.kl_weight * kl}

model = DistillT5(student_model, teacher_model)


# ============================================================
# 6. TRAINER
# ============================================================
class KDTrainer(Seq2SeqTrainer):
    def compute_loss(self, model, inputs, return_outputs=False, *args, **kwargs):
        return model(**inputs)["loss"]


# ============================================================
# 7. TRAINING ARGS
# ============================================================
args = Seq2SeqTrainingArguments(
    output_dir="./t5-small-distilled",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,
    learning_rate=2e-4,
    num_train_epochs=3,
    fp16=torch.cuda.is_available(),
    logging_steps=50,
    save_strategy="no",
    remove_unused_columns=False,
    report_to="none",
    label_names=["labels", "labels_t"],
)


# ============================================================
# 8. TRAIN
# ============================================================
trainer = KDTrainer(
    model=model,
    args=args,
    train_dataset=train_tok,
    eval_dataset=test_tok,
    data_collator=collator,
    tokenizer=student_tok,
)

trainer.train()

# ============================================================
# 8. GENERATION FUNCTION
# ============================================================

student_model.save_pretrained("t5-small-distilled")
student_tok.save_pretrained("t5-small-distilled")

student_model = T5ForConditionalGeneration.from_pretrained("t5-small-distilled").to(DEVICE)
student_model.eval()

BAD_PHRASES = ["Click here", "For more information", "Click to read", "Read more", "Visit", "Learn more"]
bad_ids = [student_tok(b, add_special_tokens=False)["input_ids"] for b in BAD_PHRASES]

def generate(text):
    # if a single string is given, make it a list
    single = False
    if isinstance(text, str):
        text = [text]
        single = True

    preds = []
    for t in text:
        enc = student_tok(PREFIX + t, return_tensors="pt").to(DEVICE)
        out = student_model.generate(
            **enc,
            max_length=96,
            min_length=6,
            num_beams=4,
            length_penalty=0.8,
            no_repeat_ngram_size=3,
            early_stopping=True,
            bad_words_ids=bad_ids,
            pad_token_id=student_tok.pad_token_id,
            eos_token_id=student_tok.eos_token_id,
        )
        preds.append(student_tok.decode(out[0], skip_special_tokens=True))

    # return string if single input, otherwise list
    return preds[0] if single else preds


sample_texts = [
    "You are so stupid and annoying.",
    "I hate you. Don't ever talk to me again.",
    "This is the worst idea ever.",
    "Shut up bro you know nothing.",
    "Why are you acting like an idiot?",
    "You are a stupid piece of trash.",
]

print("=== SAMPLE DETOX OUTPUTS ===\n")
for t in sample_texts:
    print("Input:", t)
    print("Output:", generate(t))
    print()

# ============================================================
# 9. METRICS
# ============================================================

print("Generating predictions...")
test_src = list(test_ds[SRC_COL])
test_ref = list(test_ds[TGT_COL])
preds = generate(test_src)

!pip install -U evaluate bert-score
import evaluate, bert_score

print("Computing BERTScore...")
P, R, F1 = bert_score.score(preds, test_src, lang="en", rescale_with_baseline=True)
bert_f1 = F1.cpu().tolist()

# -------- Style model --------
from transformers import AutoTokenizer, AutoModelForSequenceClassification
style_tok = AutoTokenizer.from_pretrained("unitary/toxic-bert")
style_clf = AutoModelForSequenceClassification.from_pretrained("unitary/toxic-bert").to(DEVICE)
style_clf.eval()

def style_score(texts):
    scores = []
    with torch.no_grad():
        for i in range(0, len(texts), 16):
            batch = texts[i:i+16]
            enc = style_tok(batch, return_tensors="pt", padding=True, truncation=True, max_length=256).to(DEVICE)
            out = style_clf(**enc)
            probs = out.logits.softmax(-1)[:, 1]   # toxic prob
            scores.extend((1 - probs).cpu().tolist())
    return scores

style_s = style_score(preds)

# -------- GPT-2 Fluency --------
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
gpt_tok = GPT2TokenizerFast.from_pretrained("gpt2")
gpt = GPT2LMHeadModel.from_pretrained("gpt2").to(DEVICE)
gpt_tok.pad_token = gpt_tok.eos_token
gpt.config.pad_token_id = gpt_tok.pad_token_id
gpt.eval()

def fluency(texts):
    fs = []
    with torch.no_grad():
        for i in range(0, len(texts), 8):
            batch = texts[i:i+8]
            enc = gpt_tok(batch, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE)
            loss = gpt(**enc, labels=enc["input_ids"]).loss
            ppl = torch.exp(loss)
            fs += [1 / ppl.item()] * len(batch)
    return fs

flu_s = fluency(preds)


# ============================================================
# 10. SAVE RESULTS
# ============================================================

import pandas as pd

out = pd.DataFrame({
    "original_text": test_src,
    "gold_detoxified": test_ref,
    "student_output": preds,
    "content_preservation": bert_f1,
    "style_score": style_s,
    "fluency": flu_s,
})

path = "./t5small_distilled_results.csv"
out.to_csv(path, index=False)

tokenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Map:   0%|          | 0/7448 [00:00<?, ? examples/s]

Map:   0%|          | 0/828 [00:00<?, ? examples/s]

  trainer = KDTrainer(
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step,Training Loss
50,2.4504
100,2.1072
150,1.9992
200,1.9561
250,1.9112
300,1.8728
350,1.8678
400,1.8929
450,1.8369
500,1.7672


=== SAMPLE DETOX OUTPUTS ===

Input: You are so stupid and annoying.
Output: You are so upset and annoying.

Input: I hate you. Don't ever talk to me again.
Output: I dislike you. Don't hesitate to talk to me again.

Input: This is the worst idea ever.
Output: This is the worst idea ever.

Input: Shut up bro you know nothing.
Output: Shut up, bro, you know nothing.

Input: Why are you acting like an idiot?
Output: Why are you acting like an idiot?

Input: You are a stupid piece of trash.
Output: You are a foolish person.

Generating predictions...


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting evaluate
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Collecting bert-score
  Downloading bert_score-0.3.13-py3-none-any.whl.metadata (15 kB)
Downloading evaluate-0.4.6-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading bert_score-0.3.13-py3-none-any.whl (61 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate, bert-score
Successfully installed bert-score-0.3.13 evaluate-0.4.6
Computing BERTScore...


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/1.42G [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tokenizer_config.json:   0%|          | 0.00/174 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/811 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


In [3]:
print("Student total params:", student_model.num_parameters())
print("Student trainable params:", student_model.num_parameters(only_trainable=True))

print("Teacher total params:", teacher_model.num_parameters())

Student total params: 60506624
Student trainable params: 60506624
Teacher total params: 222903552


In [4]:
df = pd.read_csv('/kaggle/working/t5small_distilled_results.csv')
print(f"Style Accuracy: {df['style_score'].mean():.3f}")
print(f"Content Preservation: {df['content_preservation'].mean():.3f}")

Style Accuracy: 0.978
Content Preservation: 0.678
