In [None]:
import numpy as np
import torch

In [2]:
from datasets import load_dataset
from transformers import AutoTokenizer

student_id = "distilbert/distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(student_id)

dataset = load_dataset("imdb")

def pre_process(examples):
    return tokenizer(examples["text"], truncation = True, max_length = 512)

tokenized_data = dataset.map(pre_process, batched = True)


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
print(tokenized_data)

DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'input_ids', 'attention_mask'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label', 'input_ids', 'attention_mask'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label', 'input_ids', 'attention_mask'],
        num_rows: 50000
    })
})


In [4]:
from transformers import TrainingArguments, Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F

2024-06-13 14:55:00.358572: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [5]:
class distillTrainer(Trainer):
    def __init__(self, *args, teacher_model = None, temperature = None, alpha_ce = 0.25, alpha_cos = 0.25, **kwargs):
        super().__init__(*args,**kwargs)
        self.teacher = teacher_model
        self.temperature = temperature
        self.alpha_ce = alpha_ce
        self.alpha_cos = alpha_cos
        self.teacher.eval()

    def distillation_loss(self, student_logits, teacher_logits):
        #soft target probabilities
        soft_student = F.log_softmax(student_logits / self.temperature, dim = -1)
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim = -1)
        #distill loss
        distill_loss = F.kl_div(soft_student, soft_teacher, reduction = 'batchmean') * (self.temperature**2) 
        return distill_loss

    def cosine_embedding_loss(self, student_outputs, teacher_outputs):
        #cosine embedding loss
        teacher_hidden = torch.stack(teacher_outputs.hidden_states, dim = -1)
        student_hidden = torch.stack(student_outputs.hidden_states, dim = -1)
        assert student_hidden.size() == teacher_hidden.size(), "Hidden State Size Dont Match"
        cosine_embedding_loss = torch.mean(1 - F.cosine_similarity(student_hidden, teacher_hidden, dim = -2))
        return cosine_embedding_loss

    def compute_loss(self, model, inputs, return_outputs = False):
        #Distillation loss over soft target probabilities of teacher and student, KL DIV
        #Cosine embedding loss
        #supervised training loss
        #Attention Score Alignment???
        
        student_outputs = model(**inputs)
        student_logits = student_outputs.logits
        
        student_loss = student_outputs.loss
        
        with torch.no_grad():
            teacher_outputs = self.teacher(**inputs)
            teacher_logits = teacher_outputs.logits
            
        l_ce = self.distillation_loss(student_logits, teacher_logits)
        l_cos = self.cosine_embedding_loss(student_outputs, teacher_outputs)

        #Combine losses
        loss = self.alpha_ce * l_ce + l_cos * self.alpha_cos + student_loss * (1 - self.alpha_ce + self.alpha_cos) 
        
        return (loss, student_outputs) if return_outputs else loss
         
        

In [6]:
labels = tokenized_data['train'].features['label'].names
num_labels = len(labels)
label2id, id2label = {}, {}

for idx, lbl in enumerate(labels):
    label2id[lbl] = idx
    id2label[idx] = lbl
print(label2id)
print(id2label)

{'neg': 0, 'pos': 1}
{0: 'neg', 1: 'pos'}


In [7]:
from transformers import DistilBertForSequenceClassification, AutoModelForSequenceClassification, DistilBertConfig, DataCollatorWithPadding

#Load Models
teacher_id = "lvwerra/distilbert-imdb"
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    teacher_id,
    num_labels = num_labels,
    id2label = id2label,
    label2id = label2id,
    output_hidden_states=True,
    ignore_mismatched_sizes=True,
)

student_config = DistilBertConfig(output_hidden_states = True)
student_model = DistilBertForSequenceClassification(student_config)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
teacher_model.to(device)
student_model.to(device)

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
 

In [8]:
EPOCHS = 2
BATCH_SIZE = 4
LEARNING_RATE = 0.00004

training_args = TrainingArguments(
    output_dir = './results',
    num_train_epochs = EPOCHS,
    per_device_train_batch_size = BATCH_SIZE,
    per_device_eval_batch_size = BATCH_SIZE,
    learning_rate = LEARNING_RATE,
    logging_dir = './logs',
    load_best_model_at_end= True,
    metric_for_best_model="accuracy",
    eval_strategy="steps",
    save_strategy="steps",
    save_total_limit=2,
)

In [9]:
import evaluate
import numpy as np

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    return accuracy.compute(predictions=predictions[0], references=labels)

def preprocess_logits_for_metrics(logits, labels):
    """
    Preprocess the logits to ensure they are in the correct format for metric computation.
    This function will be called during the evaluation process.
    """
    if isinstance(logits, tuple):  
        logits = logits[0]  # get logit tensors

    pred_ids = torch.argmax(logits, dim=-1)
    
    return pred_ids, labels

In [10]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
trainer = distillTrainer(
    teacher_model=teacher_model,
    model=student_model,                         
    args=training_args,                  
    train_dataset=tokenized_data['train'],         
    eval_dataset=tokenized_data['test'],
    compute_metrics = compute_metrics,
    preprocess_logits_for_metrics = preprocess_logits_for_metrics,
    temperature = 3.0,
    alpha_ce = 0.25,
    alpha_cos = 0.25,
    tokenizer = tokenizer,
    data_collator = data_collator,
)


In [11]:
trainer.train()



Step,Training Loss,Validation Loss,Accuracy
500,1.0142,0.708833,0.82796
1000,0.6443,0.596317,0.86896
1500,0.5806,0.647774,0.861
2000,0.4528,0.582055,0.878
2500,0.4128,0.544595,0.88232
3000,0.4111,0.545567,0.88084




TrainOutput(global_step=3126, training_loss=0.5793175224455518, metrics={'train_runtime': 3624.6267, 'train_samples_per_second': 13.795, 'train_steps_per_second': 0.862, 'total_flos': 6556904415524352.0, 'train_loss': 0.5793175224455518, 'epoch': 2.0})

In [None]:
from torchinfo import summary
summary(teacher_model)

In [58]:
summary(student_model)

Layer (type:depth-idx)                                  Param #
DistilBertForSequenceClassification                     --
├─DistilBertModel: 1-1                                  --
│    └─Embeddings: 2-1                                  --
│    │    └─Embedding: 3-1                              23,440,896
│    │    └─Embedding: 3-2                              393,216
│    │    └─LayerNorm: 3-3                              1,536
│    │    └─Dropout: 3-4                                --
│    └─Transformer: 2-2                                 --
│    │    └─ModuleList: 3-5                             42,527,232
├─Linear: 1-2                                           590,592
├─Linear: 1-3                                           1,538
├─Dropout: 1-4                                          --
Total params: 66,955,010
Trainable params: 66,955,010
Non-trainable params: 0

In [None]:
trainer.eval()