In [1]:
#import comet_ml
import numpy as np
import torch
from transformers import TrainingArguments, Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm
2024-10-17 22:38:36.076804: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-10-17 22:38:36.089067: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-17 22:38:36.101918: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-17 22:38:36.105757: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-17 22:38:36.1

In [2]:
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer

dataset = load_dataset("glue", "cola")

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def preprocess_function(examples):
    return tokenizer(examples['sentence'], truncation=True)

tokenized_dataset = dataset.map(preprocess_function, batched=True)




In [3]:
class distillTrainer(Trainer):
    def __init__(self, *args, teacher_model = None, temperature = None, alpha_ce = None, alpha_cos = None, alpha_mlm = None, **kwargs):
        super().__init__(*args,**kwargs)
        self.teacher = teacher_model
        self.temperature = temperature
        self.alpha_ce = alpha_ce
        self.alpha_cos = alpha_cos
        self.alpha_mlm = alpha_mlm
        self.teacher.eval()
        if self.alpha_cos > 0.0:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")
        self.layer_logs = []

    def distillation_loss(self, student_logits, teacher_logits):
        #soft target probabilities
        soft_student = F.log_softmax(student_logits / self.temperature, dim = -1)
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim = -1)
        #Kullback Leibler Divergence
        loss = F.kl_div(soft_student, soft_teacher, reduction = 'batchmean') * (self.temperature**2) 
        return loss

    def cosine_embedding_loss(self, student_outputs, teacher_outputs):
        s_hidden_states = student_outputs.hidden_states[-1]
        t_hidden_states = teacher_outputs.hidden_states[-1]
        assert t_hidden_states.size() == s_hidden_states.size()
        dim = s_hidden_states.size(-1)
        s_hidden_states_slct = s_hidden_states.view(-1, dim)
        t_hidden_states_slct = t_hidden_states.view(-1, dim)

        target = s_hidden_states_slct.new(s_hidden_states_slct.size(0)).fill_(1) 
        loss = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)
        return loss

    def compute_loss(self, model, inputs, return_outputs = False):
        #Distillation loss over soft target probabilities of teacher and student, KL DIV
        #Cosine embedding loss
        #supervised training loss
        #Attention Score Alignment???
        
        student_outputs = model(**inputs)
        student_logits = student_outputs.logits
        
        student_loss = student_outputs.loss
        
        with torch.no_grad():
            teacher_outputs = self.teacher(**inputs)
            teacher_logits = teacher_outputs.logits
            
        l_ce = self.distillation_loss(student_logits, teacher_logits)
        
        l_cos = self.cosine_embedding_loss(student_outputs, teacher_outputs) if self.alpha_cos > 0 else 0

        #Combine losses
        loss = self.alpha_ce * l_ce + l_cos * self.alpha_cos + student_loss * self.alpha_mlm
        
        return (loss, student_outputs) if return_outputs else loss
         
        

In [4]:
from transformers import DistilBertForSequenceClassification, AutoModelForSequenceClassification, DistilBertConfig, DataCollatorWithPadding
from iDistilbert import iDistilBertForSequenceClassification

#Load Models
teacher_id = "JeremiahZ/bert-base-uncased-cola"
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    teacher_id,
    num_labels = 2,
    output_hidden_states = True,
)

student_config = DistilBertConfig(
    output_hidden_states = True,
    distance_metric = "manhattan_distance",
    activation_function = "relu",
    signed_inhibitor =  True,
    center = True,
    num_labels = 2,
    )

student_model = iDistilBertForSequenceClassification(
    config = student_config
)

initialized_weights = torch.load('/shared/Tony/MSc2024/KD_weight_init/models/hiddenstates4_center_inhibitor_init.pth')
student_model.load_state_dict(initialized_weights, strict=False)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
teacher_model.to(device) 
student_model.to(device)

  initialized_weights = torch.load('/shared/Tony/MSc2024/distilbert_init/models/hiddenstates4_center_inhibitor_init.pth')


iDistilBertForSequenceClassification(
  (distilbert): iDistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): iTransformer(
      (layer): ModuleList(
        (0-5): 6 x iTransformerBlock(
          (attention): iMultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=Fal

In [5]:
import evaluate
import numpy as np


glue_metric = evaluate.load("glue", "cola")
def compute_metrics(eval_pred):

    logits, labels = eval_pred
    
    if isinstance(logits, tuple):  
        logits = logits[0]  # get logit tensors
    predictions = np.argmax(logits, axis=-1)

    return glue_metric.compute(predictions=predictions, references=labels)

In [6]:

%env COMET_MODE=ONLINE
%env COMET_LOG_ASSETS=TRUE

EPOCHS = 10
BATCH_SIZE = 8
LEARNING_RATE = 2e-5

training_args = TrainingArguments(
    output_dir = './results',
    num_train_epochs = EPOCHS,
    per_device_train_batch_size = BATCH_SIZE,
    per_device_eval_batch_size = BATCH_SIZE,
    learning_rate = LEARNING_RATE,
    logging_dir = './logs',
    load_best_model_at_end= True,
    metric_for_best_model="matthews_correlation",
    eval_strategy="epoch",
    save_strategy="epoch",
    #report_to=['comet_ml', 'tensorboard'],
    report_to=['tensorboard'],
    lr_scheduler_type="linear",
    #gradient_accumulation_steps=4,
    weight_decay = 0.01,
)

env: COMET_MODE=ONLINE
env: COMET_LOG_ASSETS=TRUE


In [7]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
trainer = distillTrainer(
    teacher_model=teacher_model,
    model=student_model,                         
    args=training_args,                  
    train_dataset=tokenized_dataset['train'],         
    eval_dataset=tokenized_dataset['validation'],
    compute_metrics = compute_metrics,
    temperature = 4,
    alpha_ce = 0,
    alpha_cos = 0,
    alpha_mlm = 1,
    tokenizer = tokenizer,
    data_collator = data_collator,
)


In [8]:
trainer.train()



Epoch,Training Loss,Validation Loss,Matthews Correlation
1,0.6079,0.606416,0.0
2,0.5504,0.54834,0.29146
3,0.4301,0.600234,0.334423
4,0.3437,0.724782,0.367835
5,0.2667,0.79062,0.372226
6,0.233,0.875612,0.386991
7,0.1951,1.011619,0.376831
8,0.1615,1.212003,0.379151
9,0.1437,1.31686,0.388616
10,0.1256,1.368353,0.383503




TrainOutput(global_step=5350, training_loss=0.29495446606217146, metrics={'train_runtime': 388.2611, 'train_samples_per_second': 220.238, 'train_steps_per_second': 13.779, 'total_flos': 457565871244848.0, 'train_loss': 0.29495446606217146, 'epoch': 10.0})

In [9]:
#FT 0.42 0.386991
#KD 0.474883 0.481270 0.469735

In [None]:
import matplotlib.pyplot as plt

def plot_trainer_loss(trainer):
    # Extract the logged values
    log_history = trainer.state.log_history
    
    train_loss = []
    val_loss = []
    train_steps = []
    val_steps = []
    
    for entry in log_history:
        if 'loss' in entry:
            train_loss.append(entry['loss'])
            train_steps.append(entry['step'])
        if 'eval_loss' in entry:
            val_loss.append(entry['eval_loss'])
            val_steps.append(entry['step'])
    
    # Create the plot
    plt.figure(figsize=(10, 6))
    
    # Plot training loss
    plt.plot(train_steps, train_loss, label='Training Loss')
    
    # Plot validation loss
    plt.plot(val_steps, val_loss, label='Validation Loss')
    
    plt.xlabel('Steps')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True)
    plt.show()

plot_trainer_loss(trainer)