In [1]:
import comet_ml
import numpy as np
import torch

In [2]:
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer

student_id = "distilbert/distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(student_id)

dataset = load_dataset("imdb")

def pre_process(examples):
    return tokenizer(examples["text"], truncation = True, max_length = 512)

tokenized_data = dataset.map(pre_process, batched = True)

#test_valid = tokenized_data['test'].train_test_split(test_size=0.5)
#tokenized_data = DatasetDict({
#    'train': tokenized_data['train'],
#    'test': test_valid['train'],
#    'validation': test_valid['test']
#})


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
#comet_ml.init(project_name="distilbert_dotprod")

In [4]:
from transformers import TrainingArguments, Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F

2024-08-20 04:17:48.495594: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-20 04:17:48.519275: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-20 04:17:48.519298: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-08-20 04:17:48.535186: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [5]:
class distillTrainer(Trainer):
    def __init__(self, *args, teacher_model = None, temperature = None, alpha_ce = None, alpha_cos = None, **kwargs):
        super().__init__(*args,**kwargs)
        self.teacher = teacher_model
        self.temperature = temperature
        self.alpha_ce = alpha_ce
        self.alpha_cos = alpha_cos
        self.teacher.eval()
        if self.alpha_cos > 0.0:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")

    def distillation_loss(self, student_logits, teacher_logits):
        #soft target probabilities
        soft_student = F.log_softmax(student_logits / self.temperature, dim = -1)
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim = -1)
        #Kullback Leibler Divergence
        distill_loss = F.kl_div(soft_student, soft_teacher, reduction = 'batchmean') * (self.temperature**2) 
        return distill_loss

    def cosine_embedding_loss(self, student_outputs, teacher_outputs, attention_mask):
        #cosine embedding loss
        s_hidden_states = student_outputs.hidden_states[-1]  # (bs, seq_length, dim)
        t_hidden_states = teacher_outputs.hidden_states[-1]  # (bs, seq_length, dim)
        
        attention_mask = attention_mask.bool()
        mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states)  # (bs, seq_length, dim)
        assert s_hidden_states.size() == t_hidden_states.size()
        dim = s_hidden_states.size(-1)

        s_hidden_states_slct = torch.masked_select(s_hidden_states, mask)  # (bs * seq_length * dim)
        s_hidden_states_slct = s_hidden_states_slct.view(-1, dim)  # (bs * seq_length, dim)
        t_hidden_states_slct = torch.masked_select(t_hidden_states, mask)  # (bs * seq_length * dim)
        t_hidden_states_slct = t_hidden_states_slct.view(-1, dim)  # (bs * seq_length, dim)

        target = s_hidden_states_slct.new(s_hidden_states_slct.size(0)).fill_(1)  # (bs * seq_length,)
        loss_cos = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)
        return loss_cos

    def compute_loss(self, model, inputs, return_outputs = False):
        #Distillation loss over soft target probabilities of teacher and student, KL DIV
        #Cosine embedding loss
        #supervised training loss
        #Attention Score Alignment???
        
        student_outputs = model(**inputs)
        student_logits = student_outputs.logits
        
        student_loss = student_outputs.loss
        
        with torch.no_grad():
            teacher_outputs = self.teacher(**inputs)
            teacher_logits = teacher_outputs.logits
            
        l_ce = self.distillation_loss(student_logits, teacher_logits)
        
        l_cos = self.cosine_embedding_loss(student_outputs, teacher_outputs, inputs["attention_mask"]) if self.alpha_cos > 0 else 0

        #Combine losses
        loss = self.alpha_ce * l_ce + l_cos * self.alpha_cos + student_loss * (1 - (self.alpha_ce + self.alpha_cos)) 
        
        return (loss, student_outputs) if return_outputs else loss
         
        

In [6]:
labels = tokenized_data['train'].features['label'].names
num_labels = len(labels)
label2id, id2label = {}, {}

for idx, lbl in enumerate(labels):
    label2id[lbl] = idx
    id2label[idx] = lbl
print(label2id)
print(id2label)

{'neg': 0, 'pos': 1}
{0: 'neg', 1: 'pos'}


In [7]:
from transformers import DistilBertForSequenceClassification, AutoModelForSequenceClassification, DistilBertConfig, DataCollatorWithPadding
from iDistilbert import iDistilBertForSequenceClassification

#Load Models
teacher_id = "textattack/bert-base-uncased-imdb"
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    teacher_id,
    num_labels = num_labels,
    id2label = id2label,
    label2id = label2id,
    output_hidden_states=True,
)

student_config = DistilBertConfig(
    output_hidden_states = True,
    distance_metric = "manhattan_distance",
    activation_function = "relu",
    signed_inhibitor =  True,
    alpha = 0,
    center = True,
    num_labels = num_labels,
    )
student_model = iDistilBertForSequenceClassification(student_config)

initialized_weights = torch.load('/mnt/tony/MSc2024/distilbert_init/models/weight_opt_iDistilbert.pth')
student_model.load_state_dict(initialized_weights, strict=False)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
teacher_model.to(device)
student_model.to(device)

iDistilBertForSequenceClassification(
  (distilbert): iDistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): iTransformer(
      (layer): ModuleList(
        (0-5): 6 x iTransformerBlock(
          (attention): iMultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=Fal

In [8]:
import evaluate
import numpy as np

#experiment = comet_ml.get_global_experiment()

accuracy = evaluate.load("accuracy")

def preprocess_logits_for_metrics(logits, labels):
    """
    Preprocess the logits to ensure they are in the correct format for metric computation.
    This function will be called during the evaluation process.
    """
    if isinstance(logits, tuple):  
        logits = logits[0]  # get logit tensors

    pred_ids = torch.argmax(logits, dim=-1)
    
    return pred_ids, labels
    
def compute_metrics(eval_pred):
    
    predictions, labels = eval_pred

    return accuracy.compute(predictions=predictions[0], references=labels)



In [9]:

%env COMET_MODE=ONLINE
%env COMET_LOG_ASSETS=TRUE

EPOCHS = 4
BATCH_SIZE = 2
LEARNING_RATE = 2e-5

training_args = TrainingArguments(
    output_dir = './results',
    num_train_epochs = EPOCHS,
    per_device_train_batch_size = BATCH_SIZE,
    per_device_eval_batch_size = BATCH_SIZE,
    learning_rate = LEARNING_RATE,
    logging_dir = './logs',
    load_best_model_at_end= True,
    metric_for_best_model="accuracy",
    eval_strategy="steps",
    eval_steps = 500,
    save_strategy="steps",
    save_total_limit=2,
    seed = 42,
    #report_to=['comet_ml', 'tensorboard'],
    report_to=['tensorboard'],
)

env: COMET_MODE=ONLINE
env: COMET_LOG_ASSETS=TRUE


In [10]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
trainer = distillTrainer(
    teacher_model=teacher_model,
    model=student_model,                         
    args=training_args,                  
    train_dataset=tokenized_data['train'],         
    eval_dataset=tokenized_data['test'],
    compute_metrics = compute_metrics,
    preprocess_logits_for_metrics = preprocess_logits_for_metrics,
    temperature = 5,
    alpha_ce = 0.3,
    alpha_cos = 0.2,
    tokenizer = tokenizer,
    data_collator = data_collator,
)


In [None]:
trainer.train()



Step,Training Loss,Validation Loss,Accuracy
500,1.7761,1.351676,0.7984
1000,1.4143,1.196869,0.83796
1500,1.3543,1.167667,0.84968
2000,1.2405,1.031955,0.8548
2500,1.1434,1.132635,0.82816
3000,1.125,1.085184,0.8456
3500,0.988,1.088729,0.87552




In [None]:
student_config

In [13]:
trainer.save_model('./models')