In [None]:
!pip install torch transformers datasets


In [None]:
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

In [None]:
teacher_name = "gpt2"
teacher = AutoModelForCausalLM.from_pretrained(teacher_name)
tokenizer = AutoTokenizer.from_pretrained(teacher_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
teacher.eval()
for p in teacher.parameters():
    p.requires_grad = False

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
class TinyTransformerLM(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(512, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=512)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids):
        B, L = input_ids.shape
        pos = torch.arange(0, L, device=input_ids.device).unsqueeze(0).expand(B, L)
        x = self.embed(input_ids) + self.pos_embed(pos)
        x = self.transformer(x)               # [B, L, d_model]
        logits = self.lm_head(x)
        return logits

In [None]:
vocab_size = tokenizer.vocab_size
student = TinyTransformerLM(vocab_size)
vanilla_student = TinyTransformerLM(vocab_size)



In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
teacher.to(device)
student.to(device)
vanilla_student.to(device)
print("---")

cuda
---


In [None]:
ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")  # small slice
def tok_fn(ex):
    return tokenizer(ex["text"], truncation=True, padding="max_length", max_length=64)
ds = ds.map(tok_fn, batched=True)
ds.set_format(type="torch", columns=["input_ids", "attention_mask"])
dl = DataLoader(ds, batch_size=8, shuffle=True)


In [None]:
def count_parameters(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total:,}")
    print(f"Trainable parameters: {trainable:,}")
    return total, trainable
count_parameters(teacher), count_parameters(student)
print("--")

Total parameters: 124,439,808
Trainable parameters: 0
Total parameters: 26,967,121
Trainable parameters: 26,967,121
--


In [None]:
optimizer = torch.optim.Adam(student.parameters(), lr=1e-4)
T = 2.0
alpha = 0.5

Given the total loss function:

$$
L_{\text{total}} = \alpha L_{\text{CE}} + (1 - \alpha) L_{\text{KL}}
$$

The gradient with respect to the model parameters $\theta$ is:

$$
\nabla_\theta L_{\text{total}}
= \alpha \nabla_\theta L_{\text{CE}}
+ (1 - \alpha) \nabla_\theta L_{\text{KL}}
$$

$\Rightarrow$ The update step depends on both components.Since the derivative of a sum is the sum of the derivatives, each term contributes proportionally to its weight ($\alpha$ and $1 - \alpha$).


In [None]:
# the model is updating its weights with a double objective --> decrease the difference between its own (student) probability distribution overt the vocabulary and
# the distribution of the teacher (we can say the student is punished when its own probability distribution is too different from the one of the teacher) and
# at the same time it has some freedom to update the weights wrt its own loss --> the one that comes from its own prediction.

# given Ltotal​=αLCE​+(1−α)LKL
# ∇θ​Ltotal​=α∇θ​LCE​+(1−α)∇θ​LKL --> update depends on both --> and we know that derivative of sum is sum of derivative
#
# One that says: “adjust yourself to reduce prediction errors vs the ground truth.”
# ∇CE points toward the minimum where you get correct labels.
#
# Another that says: “also align your output distribution with what the teacher believes.”
# ∇KL points toward the minimum where you look like the teacher.


for epoch in range(10):
    for batch in dl:
        input_ids = batch["input_ids"].to(device)
        labels = input_ids[:, 1:].contiguous()
        inputs = input_ids[:, :-1].contiguous()



        with torch.no_grad():
            t_logits = teacher(inputs).logits  # [B, L, V]
        s_logits = student(inputs)

        # Distillation loss
        t_probs = F.softmax(t_logits / T, dim=-1)
        s_log_probs = F.log_softmax(s_logits / T, dim=-1)
        kl = F.kl_div(s_log_probs, t_probs, reduction="batchmean") * (T * T)

        # Hard CE loss
        ce = F.cross_entropy(
            s_logits.view(-1, s_logits.size(-1)),
            labels.view(-1),
            ignore_index=tokenizer.pad_token_id,
        )

        loss = alpha * ce + (1 - alpha) * kl

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch}: loss={loss.item():.4f}")


Epoch 0: loss=54.5110
Epoch 1: loss=43.7234
Epoch 2: loss=55.5211
Epoch 3: loss=18.7598
Epoch 4: loss=57.9055
Epoch 5: loss=40.7901
Epoch 6: loss=63.5930
Epoch 7: loss=27.3593
Epoch 8: loss=57.7921
Epoch 9: loss=27.0922


In [None]:
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-4)
alpha = 1

In [None]:
for epoch in range(10):
    for batch in dl:
        input_ids = batch["input_ids"].to(device)
        labels = input_ids[:, 1:].contiguous()
        inputs = input_ids[:, :-1].contiguous()



        with torch.no_grad():
            t_logits = teacher(inputs).logits  # [B, L, V]
        s_logits = vanilla_student(inputs)

        # Distillation loss
        t_probs = F.softmax(t_logits / T, dim=-1)
        s_log_probs = F.log_softmax(s_logits / T, dim=-1)
        kl = F.kl_div(s_log_probs, t_probs, reduction="batchmean") * (T * T)

        # Hard CE loss
        ce = F.cross_entropy(
            s_logits.view(-1, s_logits.size(-1)),
            labels.view(-1),
            ignore_index=tokenizer.pad_token_id,
        )

        loss = alpha * ce + (1 - alpha) * kl

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch}: loss={loss.item():.4f}")


Epoch 0: loss=10.9947
Epoch 1: loss=10.9750
Epoch 2: loss=10.9391
Epoch 3: loss=10.9614
Epoch 4: loss=11.0002
Epoch 5: loss=10.9829
Epoch 6: loss=11.0245
Epoch 7: loss=11.0257
Epoch 8: loss=10.9148
Epoch 9: loss=10.9978


In [None]:
max_token = 10
prompt = "Once upon a time"

for t in range(max_token):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    out = teacher(input_ids)

    next_token_logits = out.logits[:, -1, :] # batch, seq_len, vocab_size

    next_token_id = next_token_logits.argmax(dim=-1) # take the greatest value of the logits in along the vocab_size

    next_token = tokenizer.decode(next_token_id) # decode and add
    prompt += next_token

print(prompt)

Once upon a time, the world was a place of great beauty and


In [None]:
max_token = 10
prompt = "Once upon a time"

for t in range(max_token):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    out = vanilla_student(input_ids)

    next_token_logits = out[:, -1, :] # batch, seq_len, vocab_size
    next_token_id = next_token_logits.argmax(dim=-1) # take the greatest value of the logits in along the vocab_size

    next_token = tokenizer.decode(next_token_id) # decode and add
    prompt += next_token

print(prompt)

Once upon a timews repeatedlyAutabyte statue circumcision Colourhesis regarded Expansion


In [None]:
max_token = 10
prompt = "Once upon a time"

for t in range(max_token):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    out = student(input_ids)

    next_token_logits = out[:, -1, :] # batch, seq_len, vocab_size
    next_token_id = next_token_logits.argmax(dim=-1) # take the greatest value of the logits in along the vocab_size

    next_token = tokenizer.decode(next_token_id) # decode and add
    prompt += next_token

print(prompt)

Once upon a time and the the the the the the the the the
