# Setup


In [1]:
# from google.colab import drive
# drive.mount('/content/drive')

In [2]:
# import os
# os.chdir('/content/drive/MyDrive/Projects/Distillation')

In [3]:
# !pip install datasets --quiet
# !pip install evaluate --quiet

In [4]:
import torch

# Pick the compute device once: CUDA when PyTorch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(DEVICE)

cuda


In [5]:
import random
import numpy as np

# Fix the seed of every random number generator used in this notebook
# (Python, NumPy, PyTorch) so that repeated runs draw the same numbers.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Ask cuDNN to use deterministic algorithms only
torch.backends.cudnn.deterministic = True

# Dataset



[HateXplain](https://huggingface.co/datasets/Hate-speech-CNERG/hatexplain) is a benchmark dataset for hate speech detection.   
Each record in the dataset includes four fields:
- id: A unique identifier for the post.
- annotators: A list of annotations made by 3 different annotators. Each annotator entry includes:  
  - label: The label given to the post. 0: hatespeech, 1: normal, 2: offensive.
  - annotator_id: A unique identifier for the annotator.
  - target: The target of the post.
- rationales: A list of binary arrays. 1 indicates a token that can justify the label assigned by the annotator.

- post_tokens: A list of tokens (words) of the post.

In [6]:
from datasets import load_dataset

# Download (or load from cache) every split of HateXplain.
# NOTE(review): trust_remote_code=True lets the dataset repository's own loading
# script run on this machine — confirm the source is trusted before enabling it.
dataset = load_dataset(path="Hate-speech-CNERG/hatexplain", trust_remote_code=True)

In [7]:
# Show the available splits together with their columns and row counts
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['id', 'annotators', 'rationales', 'post_tokens'],
        num_rows: 15383
    })
    validation: Dataset({
        features: ['id', 'annotators', 'rationales', 'post_tokens'],
        num_rows: 1922
    })
    test: Dataset({
        features: ['id', 'annotators', 'rationales', 'post_tokens'],
        num_rows: 1924
    })
})


In [8]:
# Inspect one raw record: the first row of the train split
print(dataset['train'][0])

{'id': '23107796_gab', 'annotators': {'label': [0, 2, 2], 'annotator_id': [203, 204, 233], 'target': [['Hindu', 'Islam'], ['Hindu', 'Islam'], ['Hindu', 'Islam', 'Other']]}, 'rationales': [[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'post_tokens': ['u', 'really', 'think', 'i', 'would', 'not', 'have', 'been', 'raped', 'by', 'feral', 'hindu', 'or', 'muslim', 'back', 'in', 'india', 'or', 'bangladesh', 'and', 'a', 'neo', 'nazi', 'would', 'rape', 'me', 'as', 'well', 'just', 'to', 'see', 'me', 'cry']}


# Tokenizer and Teacher Model (BERT)

In [9]:
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification

# Checkpoint shared by the tokenizer and the teacher model
MODEL_NAME = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Label vocabulary: class index <-> class name
ID2LABEL = dict(enumerate(("hate_speech", "normal", "offensive")))
LABEL2ID = {name: class_id for class_id, name in ID2LABEL.items()}
NUM_LABELS = len(ID2LABEL)

# Teacher: the pretrained checkpoint with a sequence-classification head on top.
# NOTE(review): the load warning reports classifier.weight / classifier.bias as newly
# initialized, and no cell visible here fine-tunes the teacher — confirm it is trained
# on this task before it is used as the distillation target.
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_LABELS,
    id2label=ID2LABEL,
    label2id=LABEL2ID,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Preprocessing



- Use the tokenizer from the pretrained BERT model (bert-base-uncased) to tokenize input text.

- Use majority vote to select the most common label as the final label.

- Create PyTorch DataLoaders.

In [10]:
# from collections import Counter
# from transformers import DataCollatorWithPadding
# from torch.utils.data import DataLoader

# """
# Use DataCollatorWithPadding. Dynamic padding, save memory
# """


# def preprocess_function(example):
#     """
#     Preprocesses a dataset example by tokenizing text and applying majority vote on annotator labels.
#     """
#     texts = " ".join(example["post_tokens"])
#     output = tokenizer(texts, max_length=128, truncation=True)  # padding is False

#     labels = example["annotators"]["label"]
#     output["label"] = Counter(labels).most_common(1)[0][0]

#     return output


# def create_dataloaders(dataset, batch_size):
#     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

#     train_dataset = dataset["train"].shuffle(seed=42)
#     eval_dataset = dataset["validation"].shuffle(seed=42)
#     test_dataset = dataset["test"].shuffle(seed=42)

#     train_loader = DataLoader(
#         train_dataset,
#         batch_size=batch_size,
#         shuffle=True,
#         collate_fn=data_collator
#     )

#     eval_loader = DataLoader(
#         eval_dataset,
#         batch_size=batch_size,
#         collate_fn=data_collator
#     )

#     test_loader = DataLoader(
#         test_dataset,
#         batch_size=batch_size,
#         collate_fn=data_collator
#     )
#     return train_loader, eval_loader, test_loader


# dataset_preprocessed = dataset.map(preprocess_function)  # batched is False
# dataset_preprocessed = dataset_preprocessed.remove_columns(["id", "annotators", "rationales", "post_tokens"])

In [11]:
from collections import Counter
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader

"""
Use DataCollatorWithPadding. Dynamic padding, save memory
"""


def preprocess_function(examples):
    """
    Batched preprocessing step, meant for `dataset.map(..., batched=True)`.

    Joins each post's tokens into one space-separated string, tokenizes it
    (truncated to at most 128 tokens, without padding) and adds a "label"
    entry holding the label most annotators chose for that post.
    """
    joined_posts = []
    for post_tokens in examples["post_tokens"]:
        joined_posts.append(" ".join(post_tokens))

    # No padding here: sequences keep their own length until batches are collated
    encoded = tokenizer(joined_posts, max_length=128, truncation=True)

    majority_labels = []
    for annotation in examples["annotators"]:
        vote_counts = Counter(annotation["label"])
        majority_labels.append(vote_counts.most_common(1)[0][0])
    encoded["label"] = majority_labels

    return encoded


def create_dataloaders(dataset, batch_size):
    """
    Build the DataLoaders for the "train", "validation" and "test" splits.

    Every split is shuffled once with a fixed seed (42) and batched with a
    DataCollatorWithPadding built from the notebook's tokenizer. Only the
    train loader additionally reshuffles on each pass (shuffle=True).

    Params:
        dataset: Mapping from split name to a preprocessed dataset split.
        batch_size: Number of examples per batch.

    Returns:
        (train_loader, eval_loader, test_loader)
    """
    collate = DataCollatorWithPadding(tokenizer=tokenizer)

    def _make_loader(split_name, **loader_kwargs):
        # Fixed-seed shuffle of the split, then wrap it in a DataLoader
        split = dataset[split_name].shuffle(seed=42)
        return DataLoader(split, batch_size=batch_size, collate_fn=collate, **loader_kwargs)

    return (
        _make_loader("train", shuffle=True),
        _make_loader("validation"),
        _make_loader("test"),
    )


# Tokenize and label every split in batches, then drop the raw columns so that
# only the tokenizer outputs and the label remain.
dataset_preprocessed = dataset.map(preprocess_function, batched=True).remove_columns(
    ["id", "annotators", "rationales", "post_tokens"]
)

In [12]:
# Show the preprocessed splits, then the first preprocessed training example
print(
    dataset_preprocessed,
    "-- Example --",
    dataset_preprocessed['train'][0],
    sep="\n",
)

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'label'],
        num_rows: 15383
    })
    validation: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'label'],
        num_rows: 1922
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'label'],
        num_rows: 1924
    })
})
-- Example --
{'input_ids': [101, 1057, 2428, 2228, 1045, 2052, 2025, 2031, 2042, 15504, 2011, 18993, 7560, 2030, 5152, 2067, 1999, 2634, 2030, 7269, 1998, 1037, 9253, 6394, 2052, 9040, 2033, 2004, 2092, 2074, 2000, 2156, 2033, 5390, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'label': 2}


In [13]:
# # Example train_loader
# train_loader, eval_loader, test_loader = create_dataloaders(dataset_preprocessed, batch_size=32)

# for batch in train_loader:
#     for k, v in batch.items():
#         print(f"{k}: {v.shape}")   
#     break

# Model




> The student has the same general architecture as BERT.
> - The token-type embeddings and the pooler are removed while the number of layers is reduced by a factor of 2.
> - Most of the operations used in the Transformer architecture—linear layer and layer normalisation—are highly optimized in modern linear algebra frameworks.
> - Our investigations showed that variations on the last dimension of the tensor (hidden size dimension) have a smaller impact on computation efficiency (for a fixed parameters budget) than variations on other factors like the number of layers.
> - Thus we focus on reducing the number of layers.
>
> — Sanh, Victor, et al. *DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter.* arXiv:1910.01108 (2019)

## Teacher Model (BERT)

In [14]:
# from transformers import AutoModelForSequenceClassification
# help(AutoModelForSequenceClassification.from_pretrained)

# from transformers import PretrainedConfig
# help(PretrainedConfig)

In [15]:
# Display the teacher's module structure
teacher_model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [16]:
# Display the teacher's configuration (layer count, hidden size, label maps, ...)
teacher_model.config

BertConfig {
  "_attn_implementation_autoset": true,
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "hate_speech",
    "1": "normal",
    "2": "offensive"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "hate_speech": 0,
    "normal": 1,
    "offensive": 2
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.51.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

## Student Model (DistilBERT)

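- The student model has the same configuration as the teacher model, but with half the number of encoder layers.
- The student's layers are initialized from the teacher: every second encoder layer is copied, and all other layers (embeddings, pooler, classifier) are copied unchanged. Note that, unlike the DistilBERT described in the quote above, the student built here keeps the token-type embeddings and the pooler.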





In [17]:
from transformers.models.bert.modeling_bert import BertConfig


def build_student_model(teacher_model):
    """
    Create a student with the teacher's architecture but half as many hidden layers.

    Both the config class and the model class are taken from the teacher itself
    (instead of hard-coding `BertConfig`), so the student always matches the
    teacher's model family. No teacher weights are copied here.

    Params:
        teacher_model: Model whose `config` provides `to_dict()` and whose config
            class provides `from_dict()`; the model class must be constructible
            from a config object.

    Returns:
        A new model of the same class as `teacher_model`, built from a copy of
        the teacher's config with `num_hidden_layers` halved (floor division).
        The teacher and its config are left unchanged.
    """
    config_dict = teacher_model.config.to_dict()
    config_dict['num_hidden_layers'] //= 2  # half the number of hidden layers, e.g. 12 -> 6
    student_config = type(teacher_model.config).from_dict(config_dict)
    return type(teacher_model)(student_config)


def init_student_layers(teacher, student, use_layer = "odd"):
    """
    Initialize the student from the teacher's weights.
    Copy one out of two encoder layers, and all other layers
    (embeddings, pooler, classifier) unchanged from the teacher model.

    The teacher's depth is read from the model instead of being assumed to be 12,
    so teachers of any depth are supported.

    Params:
        teacher: The teacher model.
        student: The student model; it is modified in place.
        use_layer: Which teacher encoder layers to copy, given as 0-based indices
            into `teacher.bert.encoder.layer`:
            "odd" copies indices 1, 3, 5, ...;
            any other value (e.g. "even") copies indices 0, 2, 4, ...

    Returns:
        The same `student` object, with the copied weights loaded.
    """

    # Copy one out of two encoder layers
    teacher_encoder_layers = teacher.bert.encoder.layer
    student_encoder_layers = student.bert.encoder.layer
    first_idx = 1 if use_layer == "odd" else 0
    layer_indices = range(first_idx, len(teacher_encoder_layers), 2)
    # zip stops at the shorter sequence, so a student with fewer layers than
    # selected indices simply takes the first ones.
    for student_layer, layer_idx in zip(student_encoder_layers, layer_indices):
        student_layer.load_state_dict(teacher_encoder_layers[layer_idx].state_dict())

    # Copy all other layers
    student.bert.embeddings.load_state_dict(teacher.bert.embeddings.state_dict())
    student.bert.pooler.load_state_dict(teacher.bert.pooler.state_dict())
    student.classifier.load_state_dict(teacher.classifier.state_dict())

    return student


def count_parameters(model):
    """Return the total number of elements across the model's trainable parameters."""
    trainable_total = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            trainable_total += tensor.numel()
    return trainable_total


# student_model = build_student_model(teacher_model)
# print(student_model)
# student_model = init_student_layers(teacher_model, student_model, use_layer="odd")
# print()
# print('Teacher parameters :', count_parameters(teacher_model))
# print('Student parameters :', count_parameters(student_model))
# print('Compression ratio:', count_parameters(student_model) / count_parameters(teacher_model))

# Loss Function

## Weighted Loss
Let $\alpha$, $\beta$ be weighting factors, $0 < \alpha, \beta < 1$ and $\alpha + \beta < 1$,
$$\mathcal{L} = \alpha \mathcal{L}_\text{cls} + \beta \mathcal{L}_\text{KD} + (1- \alpha - \beta)\mathcal{L}_{\text{cos}}$$

## Softmax
The softmax function converts a vector of raw scores (logits) $\pmb{z} = [z^{(1)}, z^{(2)}, \cdots, z^{(k)}]$ into a probability distribution over $k$ classes. The temperature parameter $T > 0$ controls the smoothness of the resulting probability distribution.
$$
\text{softmax}(\pmb{z}, T)^{(j)} = \frac{ \exp ( z^{(j)}  / T) }{\sum_{q=1}^{k} \exp ( z^{(q)} / T )}
$$
$T=1$, standard softmax;  
$T < 1$, sharper distribution, more confident predictions (the largest logits dominate);  
$T > 1$, smoother distribution, more uncertainty (probabilities are more uniform).

## Cross-Entropy Loss
Let $\pmb{p}$ be the target probability distribution and $\pmb{q}$ be the input probability distribution, both over $k$ classes. The Cross-Entropy measures the difference between these two distributions as:
$$
H(\pmb{p},\pmb{q})= - \sum_{j=1}^{k} \pmb{p}^{(j)} \log \pmb{q}^{(j)}
$$

## KL Divergence
$$D_{\mathrm{KL}}(\pmb{p} \,\|\, \pmb{q}) = \sum_{j=1}^{k} \pmb{p}^{(j)} \log \frac{\pmb{p}^{(j)}}{\pmb{q}^{(j)}} = \sum_{j=1}^{k} \pmb{p}^{(j)} (\log \pmb{p}^{(j)} - \log \pmb{q}^{(j)})$$

$$
H(\pmb{p}, \pmb{q}) = H(\pmb{p}) + D_{\mathrm{KL}}(\pmb{p} \,\|\, \pmb{q})
$$
Since $H(\pmb{p})$ is fixed, minimizing cross-entropy is equivalent to minimizing the KL divergence between $\pmb{p}$ and $\pmb{q}$.

## Classification Loss
$$\mathcal{L}_\text{cls} = H( \pmb{y}_\text{true}, \pmb{y}_\text{pred\_student} )$$

## Knowledge Distillation Loss
Both cross-entropy and KL divergence can be used for knowledge distillation loss, as they result in the same gradients and optimization updates.

KL divergence is often preferred because it equals zero when the student's output distribution exactly matches the teacher's, whereas cross-entropy includes an additional offset — the entropy $H(\pmb{p})$ of the teacher's output distribution. This offset does not depend on the student, so it does not affect the optimization; however, it differs from batch to batch, which adds noise to the reported loss value and makes training curves harder to interpret.

The multiplication by $T^2$ corrects for the gradient scaling effect introduced by the temperature during backpropagation.

$$\mathcal{L}_\text{KD} = H(\text{softmax}(\pmb{z}_{\text{teacher}}, T), \text{softmax}(\pmb{z}_{\text{student}}, T)) \cdot T^2 $$

$$\mathcal{L}_\text{KD} = D_{\mathrm{KL}}(\text{softmax}(\pmb{z}_{\text{teacher}}, T) \,\|\, \text{softmax}(\pmb{z}_{\text{student}}, T)) \cdot T^2 $$

In PyTorch, `F.cross_entropy(input, target)` expects logits as `input`, and class indices or class probabilities as `target`.   
`F.kl_div(input, target)` expects log-probabilities as `input`, and probabilities as `target`.

## Cosine Embedding Loss
$$\mathcal{L}_\text{cos} = 1 - \text{cosine\_similarity}(\pmb{z}_\text{student}, \pmb{z}_\text{teacher})$$

In [18]:
import torch.nn as nn
import torch.nn.functional as F

class KDLoss(nn.Module):
    """
    Knowledge distillation loss: temperature-scaled KL divergence between the
    teacher's and the student's class distributions, multiplied by T^2.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits_student, logits_teacher, T):
        """
        Params:
            logits_student: The logits of the student model.
            logits_teacher: The logits of the teacher model.
            T: The temperature parameter; both sets of logits are divided by it.

        Returns:
            The 'batchmean' KL divergence of the softened distributions, times T * T.
        """
        # kl_div takes log-probabilities as input and probabilities as target
        student_log_probs = F.log_softmax(logits_student / T, dim=-1)
        teacher_probs = F.softmax(logits_teacher / T, dim=-1)
        divergence = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')

        return divergence * T * T


# # Loss functions
# criterion_cls = nn.CrossEntropyLoss()  # Classification Loss
# criterion_kd = KDLoss()  # Knowledge Distillation Loss
# criterion_cos = nn.CosineEmbeddingLoss()  # Cosine Embedding Loss

# Trainer

In [19]:
import evaluate

# Metric objects are loaded once and reused by every evaluation pass
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")

def compute_metrics(preds, labels):
    """
    Score predictions against reference labels.

    Returns:
        (accuracy, f1) where f1 is macro-averaged over the classes.
    """
    accuracy_result = accuracy_metric.compute(predictions=preds, references=labels)
    f1_result = f1_metric.compute(predictions=preds, references=labels, average="macro")
    return accuracy_result["accuracy"], f1_result["f1"]

In [20]:
from tqdm.auto import tqdm
from torch.utils.tensorboard import SummaryWriter
import os


class KDTrainer:
    """
    Knowledge-distillation trainer: fits a student model to the labels and to a teacher's outputs.

    Every training step mixes three losses computed on the models' logits:
        - classification: cross-entropy between the student logits and the labels
        - distillation:   KDLoss between student and teacher logits at temperature T
        - cosine:         CosineEmbeddingLoss between teacher and student logits (target +1)
    Per-epoch results are printed and logged to TensorBoard under logs/<save_name>.

    Batches are moved to the device that holds the student's parameters, so the
    teacher must be on that same device.
    """

    def __init__(self, teacher_model, student_model, train_loader, eval_loader, optimizer, scheduler=None, save_name="exp1"):
        """
        Params:
            teacher_model: Model providing the distillation targets; it is only run in eval mode without gradients.
            student_model: Model being trained.
            train_loader: Iterable of training batches (dicts of tensors that include "labels").
            eval_loader: Iterable of evaluation batches with the same layout.
            optimizer: Optimizer over the student's parameters.
            scheduler: Optional learning-rate scheduler, stepped once per training batch.
            save_name: Run name used for the TensorBoard log directory and the saved weights file.
        """
        self.teacher_model = teacher_model
        self.student_model = student_model
        self.train_loader = train_loader
        self.eval_loader = eval_loader
        self.optimizer = optimizer
        self.scheduler = scheduler

        self.criterion_cls = None
        self.criterion_kd = None
        self.criterion_cos = None
        self._init_criterions()

        self.save_name = save_name
        self.writer = SummaryWriter(log_dir=f"logs/{save_name}")  # Tensorboard writer

    def _init_criterions(self):
        """Create the three loss modules combined during training."""
        self.criterion_cls = nn.CrossEntropyLoss()
        self.criterion_kd = KDLoss()
        self.criterion_cos = nn.CosineEmbeddingLoss()

    def _device(self):
        """Return the device holding the student's parameters; batches are moved there."""
        return next(self.student_model.parameters()).device

    @staticmethod
    def _check_hyperparameters(T, alpha, beta):
        """Reject a non-positive temperature and loss weights that would make any term negative."""
        if T <= 0:
            raise ValueError(f"T must be positive, got {T}")
        # The cosine term is weighted by (1 - alpha - beta), so the sum may not exceed 1
        if alpha < 0 or beta < 0 or alpha + beta > 1.0 + 1e-8:
            raise ValueError(
                f"alpha and beta must be non-negative with alpha + beta <= 1, got alpha={alpha}, beta={beta}"
            )

    def _report_epoch(self, epoch, train_history, eval_history):
        """Print one epoch's results and write them to TensorBoard."""
        # print results
        print(f"Train\tLoss: {train_history['train_loss']:.4f}")
        print(f"\tcls: {train_history['train_loss_cls']:.4f}")
        print(f"\tkd: {train_history['train_loss_kd']:.4f}")
        print(f"\tcos: {train_history['train_loss_cos']:.4f}")
        print(f"Eval\tLoss: {eval_history['eval_loss']:.4f}")
        print(f"\tAccuracy: {eval_history['eval_accuracy']:.4f}")
        print(f"\tF1: {eval_history['eval_f1']:.4f}")

        # Log to TensorBoard
        self.writer.add_scalar("Loss/train", train_history['train_loss'], epoch)
        self.writer.add_scalar("Loss/train_cls", train_history['train_loss_cls'], epoch)
        self.writer.add_scalar("Loss/train_kd", train_history['train_loss_kd'], epoch)
        self.writer.add_scalar("Loss/train_cos", train_history['train_loss_cos'], epoch)
        self.writer.add_scalar("Loss/eval", eval_history['eval_loss'], epoch)
        self.writer.add_scalar("Metrics/accuracy", eval_history['eval_accuracy'], epoch)
        self.writer.add_scalar("Metrics/f1", eval_history['eval_f1'], epoch)

    def train(self, num_epochs, T=1, alpha=0.5, beta=0.3):
        """
        Run training followed by evaluation for `num_epochs` epochs.

        Params:
            num_epochs: Number of epochs.
            T: Distillation temperature; must be positive.
            alpha: Weight of the classification loss.
            beta: Weight of the distillation loss. The cosine loss is weighted 1 - alpha - beta.

        Returns:
            (train_histories, eval_histories): one dict per epoch, as returned by
            `train_one_epoch` and `evaluate`.

        Raises:
            ValueError: if T <= 0, if alpha or beta is negative, or if alpha + beta > 1.
        """
        self._check_hyperparameters(T, alpha, beta)

        train_histories, eval_histories = [], []

        try:
            for epoch in range(num_epochs):
                print(f"\n-- Epoch {epoch + 1}/{num_epochs} --")

                train_history = self.train_one_epoch(T, alpha, beta)
                eval_history = self.evaluate()

                train_histories.append(train_history)
                eval_histories.append(eval_history)

                self._report_epoch(epoch, train_history, eval_history)
        finally:
            # Close the TensorBoard writer even when an epoch fails, so logged events are flushed
            self.writer.close()

        return train_histories, eval_histories

    def train_one_epoch(self, T, alpha, beta):
        """
        Train the student for one pass over `train_loader`.

        Returns:
            Dict with the per-example average of the combined loss ("train_loss") and of
            each component ("train_loss_cls", "train_loss_kd", "train_loss_cos").
        """
        self.teacher_model.eval()
        self.student_model.train()
        device = self._device()

        total = 0
        total_loss, total_loss_cls, total_loss_kd, total_loss_cos = 0, 0, 0, 0

        for batch in tqdm(self.train_loader, desc="Training", leave=False):
            # move batch data to device
            batch = {k: v.to(device) for k, v in batch.items()}

            # clear grad
            self.optimizer.zero_grad()

            # forward (the teacher only provides targets, so no gradients are tracked for it)
            outputs_student = self.student_model(**batch)
            with torch.no_grad():
                outputs_teacher = self.teacher_model(**batch)

            # compute loss
            loss_cls = self.criterion_cls(outputs_student.logits, batch["labels"])
            loss_kd = self.criterion_kd(outputs_student.logits, outputs_teacher.logits, T)
            # A target of +1 for every row asks the two logit vectors to point the same way;
            # it is created directly on the logits' device to avoid a host-to-device copy per step.
            cos_target = torch.ones(
                outputs_teacher.logits.size(0),
                device=outputs_teacher.logits.device
            )
            loss_cos = self.criterion_cos(
                outputs_teacher.logits,
                outputs_student.logits,
                cos_target
            )
            loss = alpha * loss_cls + beta * loss_kd + (1.0 - alpha - beta) * loss_cos

            # backward
            loss.backward()
            self.optimizer.step()
            if self.scheduler is not None:
                self.scheduler.step()

            # total loss (weighted by batch size so a smaller last batch is averaged correctly)
            batch_size = batch["labels"].size(0)
            total += batch_size
            total_loss += loss.item() * batch_size
            total_loss_cls += loss_cls.item() * batch_size
            total_loss_kd += loss_kd.item() * batch_size
            total_loss_cos += loss_cos.item() * batch_size

        # average loss
        history = {
            "train_loss": total_loss / total,
            "train_loss_cls": total_loss_cls / total,
            "train_loss_kd": total_loss_kd / total,
            "train_loss_cos": total_loss_cos / total,
        }

        return history

    def evaluate(self):
        """
        Evaluate the student on `eval_loader`.

        Returns:
            Dict with the per-example average classification loss ("eval_loss"),
            "eval_accuracy" and "eval_f1" (as computed by `compute_metrics`).
        """
        self.student_model.eval()
        device = self._device()

        total = 0
        total_loss = 0
        all_preds, all_labels = [], []
        with torch.no_grad():
            for batch in tqdm(self.eval_loader, desc="Evaluating", leave=False):
                # move batch data to device
                batch = {k: v.to(device) for k, v in batch.items()}

                # forward
                outputs = self.student_model(**batch)

                # loss (classification loss only; the teacher is not used here)
                loss = self.criterion_cls(outputs.logits, batch["labels"])

                # total loss
                batch_size = batch["labels"].size(0)
                total += batch_size
                total_loss += loss.item() * batch_size

                preds = outputs.logits.argmax(dim=-1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(batch["labels"].cpu().numpy())

        # average loss
        eval_loss = total_loss / total

        accuracy, f1 = compute_metrics(all_preds, all_labels)

        history = {
            "eval_loss": eval_loss,
            "eval_accuracy": accuracy,
            "eval_f1": f1
        }

        return history

    def save_model(self):
        """Save the student's state dict to models/<save_name>.pt, creating the directory if needed."""
        save_path = f"models/{self.save_name}.pt"
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(self.student_model.state_dict(), save_path)
        print(f"Student model saved to: {save_path}")


# Train

In [None]:
import torch.optim as optim
from transformers import get_scheduler

# Hyperparameters of this single run.
# ALPHA weights the classification loss and BETA the distillation loss;
# the trainer gives the cosine loss the remaining 1 - ALPHA - BETA.
LEARNING_RATE = 5e-5
NUM_EPOCHS = 5
T = 1
BATCH_SIZE = 32
ALPHA = 1/3
BETA = 1/3

# Create dataloaders
train_loader, eval_loader, test_loader = create_dataloaders(dataset_preprocessed, batch_size=BATCH_SIZE)

# Init student model: half-depth copy of the teacher's architecture, then
# initialized from the teacher's weights.
# NOTE(review): no cell visible in this notebook fine-tunes teacher_model before it is
# used here — confirm the teacher has been trained on this task.
student_model = build_student_model(teacher_model)
student_model = init_student_layers(teacher_model, student_model, use_layer="odd")

# To device
teacher_model = teacher_model.to(DEVICE)
student_model = student_model.to(DEVICE)

# Report model sizes (trainable parameters) and the student/teacher ratio
print('Teacher parameters:', count_parameters(teacher_model))
print('Student parameters:', count_parameters(student_model))
compression_ratio = count_parameters(student_model) / count_parameters(teacher_model)
print(f'Compression ratio: {compression_ratio:.4f}')

# Only the student's parameters are optimized.
# The trainer steps the scheduler once per batch, hence epochs * batches-per-epoch steps in total.
optimizer = optim.AdamW(student_model.parameters(), lr=LEARNING_RATE)
num_training_steps = NUM_EPOCHS * len(train_loader)
scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

kdtrainer = KDTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    train_loader=train_loader,
    eval_loader=eval_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    save_name="exp1"
)

# The per-epoch histories are not used in this cell; results are printed and logged by the trainer
_, _ = kdtrainer.train(num_epochs=NUM_EPOCHS, T=T, alpha=ALPHA, beta=BETA)

# Writes the student's weights to models/exp1.pt
kdtrainer.save_model()

In [None]:
# !tensorboard --logdir=logs

# Grid Search

In [23]:
# Hyperparameter grid. Each (alpha, beta) pair is swept as one unit, combined
# with every temperature, learning rate and epoch count below.
ALPHA_BETAS = [
    (0.5, 0.3),
    (0.7, 0.2)
]
TEMPERATURES = [1, 2, 3]
LEARNING_RATES = [5e-5, 3e-5, 1e-5]
# NOTE(review): NUM_EPOCHS is an int in the single-run training cell and is rebound to a
# list here, so re-running that cell after this one would pass a list where an int is expected.
NUM_EPOCHS = [5]
BATCH_SIZE = 32

In [None]:
import torch.optim as optim
from transformers import get_scheduler
from itertools import product
import pandas as pd

# One result row per finished configuration
configs = []

# The teacher and the data loaders do not depend on the swept hyperparameters,
# so they are prepared once instead of once per configuration.
teacher_model = teacher_model.to(DEVICE)
train_loader, eval_loader, test_loader = create_dataloaders(dataset_preprocessed, batch_size=BATCH_SIZE)

for (alpha, beta), T, lr, num_epochs in product(ALPHA_BETAS, TEMPERATURES, LEARNING_RATES, NUM_EPOCHS):
    print(f"\nRunning config: alpha={alpha:.2f}, beta={beta:.2f}, T={T}, lr={lr}, num_epochs={num_epochs}")
    # beta is part of the sweep, so it belongs in the run name: without it, two grid
    # points sharing alpha would write to the same log directory and checkpoint file.
    save_name = f"A{alpha:.2f}_B{beta:.2f}_T{T}_LR{lr}_NE{num_epochs}"

    # Init a fresh student for this configuration
    student_model = build_student_model(teacher_model)
    student_model = init_student_layers(teacher_model, student_model, use_layer="odd")
    student_model = student_model.to(DEVICE)

    # The trainer steps the scheduler once per batch, hence epochs * batches-per-epoch steps
    optimizer = optim.AdamW(student_model.parameters(), lr=lr)
    scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_epochs * len(train_loader)
    )

    kdtrainer = KDTrainer(
        teacher_model=teacher_model,
        student_model=student_model,
        train_loader=train_loader,
        eval_loader=eval_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        save_name=save_name
    )

    _, eval_metrics = kdtrainer.train(
        num_epochs=num_epochs,
        T=T,
        alpha=alpha,
        beta=beta
    )

    kdtrainer.save_model()

    # Score the configuration by the validation F1 of its last epoch
    f1 = eval_metrics[-1]["eval_f1"]
    configs.append({
        "alpha": alpha,
        "beta": beta,
        "T": T,
        "lr": lr,
        "num_epochs": num_epochs,
        "eval_f1": f1
    })

    # Rewrite the results file after every configuration so that finished
    # runs are kept if a later one fails.
    df = pd.DataFrame(configs)
    df.to_csv("grid_results.csv", index=False)
