In [1]:
#import comet_ml
import numpy as np
import torch
from transformers import TrainingArguments, Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm
2024-10-14 23:48:55.690408: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-10-14 23:48:55.704311: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-14 23:48:55.718058: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-14 23:48:55.722126: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-14 23:48:55.7

In [2]:
from datasets import load_dataset
from transformers import AutoTokenizer

# Load the STS-B dataset
dataset = load_dataset("glue", "stsb")

# Load the BERT tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Preprocess the dataset
def preprocess_function(examples):
    return tokenizer(
        examples['sentence1'],
        examples['sentence2'],
        truncation=True,
        return_token_type_ids=True,
    )

tokenized_datasets = dataset.map(preprocess_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(['sentence1', 'sentence2', 'idx'])

# Ensure labels are floats
def convert_label_to_float(examples):
    examples['label'] = [float(label) for label in examples['label']]
    return examples

tokenized_datasets = tokenized_datasets.map(convert_label_to_float, batched=True)


Using the latest cached version of the dataset since glue couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'stsb' at /root/.cache/huggingface/datasets/glue/stsb/0.0.0/bcdcba79d07bc864c1c254ccfcedcce55bcc9a8c (last modified on Mon Sep 30 18:19:13 2024).


In [3]:
class distillTrainer(Trainer):
    def __init__(self, *args, teacher_model = None, temperature = None, alpha_ce = None, alpha_cos = None, alpha_mlm = None, **kwargs):
        super().__init__(*args,**kwargs)
        self.teacher = teacher_model
        self.temperature = temperature
        self.alpha_ce = alpha_ce
        self.alpha_cos = alpha_cos
        self.alpha_mlm = alpha_mlm
        self.teacher.eval()
        if self.alpha_cos > 0.0:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")
        self.layer_logs = []

    def distillation_loss(self, student_logits, teacher_logits):
        # Use MSE loss for regression outputs
        loss_fn = nn.MSELoss()
        loss = loss_fn(student_logits, teacher_logits)
        return loss

    def cosine_embedding_loss(self, student_outputs, teacher_outputs):
        s_hidden_states = student_outputs.hidden_states[-1]
        t_hidden_states = teacher_outputs.hidden_states[-1]
        assert t_hidden_states.size() == s_hidden_states.size()
        dim = s_hidden_states.size(-1)
        s_hidden_states_slct = s_hidden_states.view(-1, dim)
        t_hidden_states_slct = t_hidden_states.view(-1, dim)

        target = s_hidden_states_slct.new(s_hidden_states_slct.size(0)).fill_(1) 
        loss = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)
        return loss

    def compute_loss(self, model, inputs, return_outputs = False):
        #Distillation loss over soft target probabilities of teacher and student, KL DIV
        #Cosine embedding loss
        #supervised training loss
        #Attention Score Alignment???

        student_outputs = model(**inputs)
        student_logits = student_outputs.logits
        
        student_loss = student_outputs.loss
        
        with torch.no_grad():
            teacher_outputs = teacher_model(**inputs)
            teacher_logits = teacher_outputs.logits
        
        l_ce = self.distillation_loss(student_logits, teacher_logits)
        
        l_cos = self.cosine_embedding_loss(student_outputs, teacher_outputs) if self.alpha_cos > 0 else 0

        #Combine losses
        loss = self.alpha_ce * l_ce + l_cos * self.alpha_cos + student_loss * self.alpha_mlm
        
        return (loss, student_outputs) if return_outputs else loss
         
        

In [4]:
from transformers import DistilBertForSequenceClassification, AutoModelForSequenceClassification, DistilBertConfig, DataCollatorWithPadding
from iDistilbert import iDistilBertForSequenceClassification

#Load Models
teacher_id = "JeremiahZ/bert-base-uncased-stsb"
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    teacher_id,
    num_labels = 1,
)

student_config = DistilBertConfig(
    distance_metric = "manhattan_distance",
    activation_function = "relu",
    signed_inhibitor =  True,
    center = True,
    num_labels = 1
    )

student_model = iDistilBertForSequenceClassification(
    config = student_config
)

initialized_weights = torch.load('/shared/Tony/MSc2024/KD_weight_init/models/hiddenstates4_center_inhibitor_init.pth')
student_model.load_state_dict(initialized_weights, strict=False)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
teacher_model.to(device) 
student_model.to(device)

  initialized_weights = torch.load('/shared/Tony/MSc2024/distilbert_init/models/hiddenstates4_center_inhibitor_init.pth')


iDistilBertForSequenceClassification(
  (distilbert): iDistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): iTransformer(
      (layer): ModuleList(
        (0-5): 6 x iTransformerBlock(
          (attention): iMultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=Fal

In [5]:
import evaluate
import numpy as np

#experiment = comet_ml.get_global_experiment()

pearson_metric = evaluate.load("glue", "stsb")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # Clip predictions between 0 and 5 as STS-B labels are in this range
    predictions = np.clip(predictions, 0, 5)
    # Compute Pearson correlation
    return pearson_metric.compute(predictions=predictions, references=labels)



In [6]:
EPOCHS = 10

BATCH_SIZE = 8
LEARNING_RATE = 4e-5

training_args = TrainingArguments(
    output_dir = './results',
    num_train_epochs = EPOCHS,
    per_device_train_batch_size = BATCH_SIZE,
    per_device_eval_batch_size = BATCH_SIZE,
    learning_rate = LEARNING_RATE,
    logging_dir = './logs',
    load_best_model_at_end= True,
    #metric_for_best_model="accuracy",
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    #report_to=['comet_ml', 'tensorboard'],
    report_to=['tensorboard'],
    lr_scheduler_type="linear",
    weight_decay = 0.01,
    #gradient_accumulation_steps=4,
)

In [7]:
from torch.utils.data import default_collate
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

trainer = distillTrainer(
    teacher_model=teacher_model,
    model=student_model,                         
    args=training_args,                  
    train_dataset=tokenized_datasets['train'],         
    eval_dataset=tokenized_datasets['validation'],
    compute_metrics = compute_metrics,
    temperature = 4,
    alpha_ce = 0.5,
    alpha_cos = 0,
    alpha_mlm = 0.5,
    tokenizer = tokenizer,
    data_collator = data_collator,
)


In [8]:
trainer.train()



Epoch,Training Loss,Validation Loss,Pearson,Spearmanr
1,No log,1.685618,0.799717,0.806885
2,1.712800,1.306577,0.839401,0.836541
3,1.480900,1.284718,0.838017,0.836148
4,1.480900,1.33149,0.836648,0.835506
5,1.410700,1.306381,0.821936,0.820096
6,1.370200,1.293062,0.838949,0.835909
7,1.351300,1.285704,0.836737,0.834546
8,1.351300,1.309229,0.834512,0.832294
9,1.340600,1.305584,0.832826,0.829905
10,1.334500,1.290402,0.834914,0.832125


Could not locate the best model at ./results/checkpoint-1080/pytorch_model.bin, if you are running a distributed training on multiple nodes, you should activate `--save_on_each_node`.


TrainOutput(global_step=3600, training_loss=1.4256152682834202, metrics={'train_runtime': 919.4788, 'train_samples_per_second': 62.525, 'train_steps_per_second': 3.915, 'total_flos': 923781019642740.0, 'train_loss': 1.4256152682834202, 'epoch': 10.0})

In [None]:
#0.837107 0.83717 FT

In [13]:
trainer.save_model('./models')

In [None]:
import matplotlib.pyplot as plt

def plot_trainer_loss(trainer):
    # Extract the logged values
    log_history = trainer.state.log_history
    
    train_loss = []
    val_loss = []
    train_steps = []
    val_steps = []
    
    for entry in log_history:
        if 'loss' in entry:
            train_loss.append(entry['loss'])
            train_steps.append(entry['step'])
        if 'eval_loss' in entry:
            val_loss.append(entry['eval_loss'])
            val_steps.append(entry['step'])
    
    # Create the plot
    plt.figure(figsize=(10, 6))
    
    # Plot training loss
    plt.plot(train_steps, train_loss, label='Training Loss')
    
    # Plot validation loss
    plt.plot(val_steps, val_loss, label='Validation Loss')
    
    plt.xlabel('Steps')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True)
    plt.show()

plot_trainer_loss(trainer)