In [14]:
import comet_ml
import numpy as np
import torch

In [15]:
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer

student_id = "distilbert/distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(student_id)

dataset = load_dataset("imdb")

def pre_process(examples):
    return tokenizer(examples["text"], truncation = True, max_length = 512)

tokenized_data = dataset.map(pre_process, batched = True)

test_valid = tokenized_data['test'].train_test_split(test_size=0.5)
tokenized_data = DatasetDict({
    'train': tokenized_data['train'],
    'test': test_valid['train'],
    'validation': test_valid['test']
})

print(tokenized_data)

DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'input_ids', 'attention_mask'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label', 'input_ids', 'attention_mask'],
        num_rows: 12500
    })
    validation: Dataset({
        features: ['text', 'label', 'input_ids', 'attention_mask'],
        num_rows: 12500
    })
})


In [16]:
comet_ml.init(project_name="distilbert_dotprod")

In [17]:
from transformers import TrainingArguments, Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F

In [18]:
class distillTrainer(Trainer):
    def __init__(self, *args, teacher_model = None, temperature = None, alpha_ce = None, alpha_cos = None, **kwargs):
        super().__init__(*args,**kwargs)
        self.teacher = teacher_model
        self.temperature = temperature
        self.alpha_ce = alpha_ce
        self.alpha_cos = alpha_cos
        self.teacher.eval()

    def distillation_loss(self, student_logits, teacher_logits):
        #soft target probabilities
        soft_student = F.log_softmax(student_logits / self.temperature, dim = -1)
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim = -1)
        #Kullback Leibler Divergence
        distill_loss = F.kl_div(soft_student, soft_teacher, reduction = 'batchmean') * (self.temperature**2) 
        return distill_loss

    def cosine_embedding_loss(self, student_outputs, teacher_outputs):
        #cosine embedding loss
        teacher_hidden = torch.stack(teacher_outputs.hidden_states, dim = -1)
        student_hidden = torch.stack(student_outputs.hidden_states, dim = -1)
        assert student_hidden.size() == teacher_hidden.size(), "Hidden State Size Dont Match"
        cosine_embedding_loss = torch.mean(1 - F.cosine_similarity(student_hidden, teacher_hidden, dim = -2))
        return cosine_embedding_loss

    def compute_loss(self, model, inputs, return_outputs = False):
        #Distillation loss over soft target probabilities of teacher and student, KL DIV
        #Cosine embedding loss
        #supervised training loss
        #Attention Score Alignment???
        
        student_outputs = model(**inputs)
        student_logits = student_outputs.logits
        
        student_loss = student_outputs.loss
        
        with torch.no_grad():
            teacher_outputs = self.teacher(**inputs)
            teacher_logits = teacher_outputs.logits
            
        l_ce = self.distillation_loss(student_logits, teacher_logits)
        
        l_cos = 0
        if self.alpha_cos > 0:
            l_cos += self.cosine_embedding_loss(student_outputs, teacher_outputs)

        #Combine losses
        loss = self.alpha_ce * l_ce + l_cos * self.alpha_cos + student_loss * (1 - (self.alpha_ce + self.alpha_cos)) 
        
        return (loss, student_outputs) if return_outputs else loss
         
        

In [19]:
labels = tokenized_data['train'].features['label'].names
num_labels = len(labels)
label2id, id2label = {}, {}

for idx, lbl in enumerate(labels):
    label2id[lbl] = idx
    id2label[idx] = lbl
print(label2id)
print(id2label)

{'neg': 0, 'pos': 1}
{0: 'neg', 1: 'pos'}


In [20]:
from transformers import DistilBertForSequenceClassification, AutoModelForSequenceClassification, DistilBertConfig, DataCollatorWithPadding

#Load Models
teacher_id = "lvwerra/distilbert-imdb"
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    teacher_id,
    num_labels = num_labels,
    id2label = id2label,
    label2id = label2id,
    output_hidden_states=True,
)

student_config = DistilBertConfig(output_hidden_states = True)
student_model = DistilBertForSequenceClassification(student_config)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
teacher_model.to(device)
student_model.to(device)

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
 

In [21]:
import evaluate
import numpy as np

experiment = comet_ml.get_global_experiment()

accuracy = evaluate.load("accuracy")

def preprocess_logits_for_metrics(logits, labels):
    """
    Preprocess the logits to ensure they are in the correct format for metric computation.
    This function will be called during the evaluation process.
    """
    if isinstance(logits, tuple):  
        logits = logits[0]  # get logit tensors

    pred_ids = torch.argmax(logits, dim=-1)
    
    return pred_ids, labels
    
def compute_metrics(eval_pred):
    
    predictions, labels = eval_pred

    return accuracy.compute(predictions=predictions[0], references=labels)



In [22]:
%env COMET_MODE=ONLINE
%env COMET_LOG_ASSETS=TRUE

EPOCHS = 2
BATCH_SIZE = 2
LEARNING_RATE = 0.00002

training_args = TrainingArguments(
    output_dir = './task_specific/results',
    num_train_epochs = EPOCHS,
    per_device_train_batch_size = BATCH_SIZE,
    per_device_eval_batch_size = BATCH_SIZE,
    learning_rate = LEARNING_RATE,
    logging_dir = './task_specific/logs',
    load_best_model_at_end= True,
    metric_for_best_model="accuracy",
    eval_strategy="steps",
    eval_steps = 500,
    save_strategy="steps",
    save_total_limit=2,
    report_to=['comet_ml', 'tensorboard'],
)

env: COMET_MODE=ONLINE
env: COMET_LOG_ASSETS=TRUE


In [23]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
trainer = distillTrainer(
    teacher_model=teacher_model,
    model=student_model,                         
    args=training_args,                  
    train_dataset=tokenized_data['train'],         
    eval_dataset=tokenized_data['validation'],
    compute_metrics = compute_metrics,
    preprocess_logits_for_metrics = preprocess_logits_for_metrics,
    temperature = 5,
    alpha_ce = 0.25,
    alpha_cos = 0.25,
    tokenizer = tokenizer,
    data_collator = data_collator,
)


In [24]:
trainer.train()

[1;38;5;39mCOMET INFO:[0m Experiment is live on comet.com https://www.comet.com/tonytonfisk2/distilbert-dotprod/020618a9575a4505b402da6b3e35da3d



Step,Training Loss,Validation Loss,Accuracy
500,0.8703,0.592735,0.81328
1000,0.5511,0.466201,0.85584
1500,0.4824,0.469772,0.85584
2000,0.4067,0.474402,0.86888
2500,0.3653,0.414717,0.87376
3000,0.3667,0.41112,0.87648


[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m Comet.ml Experiment Summary
[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m   Data:
[1;38;5;39mCOMET INFO:[0m     display_summary_level : 1
[1;38;5;39mCOMET INFO:[0m     name                  : hilarious_singularity_2458
[1;38;5;39mCOMET INFO:[0m     url                   : https://www.comet.com/tonytonfisk2/distilbert-dotprod/020618a9575a4505b402da6b3e35da3d
[1;38;5;39mCOMET INFO:[0m   Metrics [count] (min, max):
[1;38;5;39mCOMET INFO:[0m     epoch [13]                     : (0.3198976327575176, 2.0)
[1;38;5;39mCOMET INFO:[0m     eval/accuracy [6]              : (0.81328, 0.87648)
[1;38;5;39mCOMET INFO:[0m     eval/loss [6]                  : (0.4111196994781494, 0.5927345156669617)
[1;38;5;39mCOMET INFO:[0m     eval/runtime [6]     

TrainOutput(global_step=3126, training_loss=0.5011479358831736, metrics={'train_runtime': 2513.4656, 'train_samples_per_second': 19.893, 'train_steps_per_second': 1.244, 'total_flos': 6556904415524352.0, 'train_loss': 0.5011479358831736, 'epoch': 2.0})

In [1]:
trainer.evaluate(tokenized_data['test'])

NameError: name 'trainer' is not defined