In [1]:
import torch
from transformers import DistilBertForMaskedLM, AutoTokenizer, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from datasets import load_dataset
import datasets
import os
import numpy as np
import re

  from .autonotebook import tqdm as notebook_tqdm
2024-08-09 01:58:22.738701: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-09 01:58:22.760921: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-09 01:58:22.760946: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-08-09 01:58:22.776712: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler f

In [2]:
def preprocess_and_tokenize(examples):
    processed_texts = []
    for text in examples['text']:
        # Preprocess text
        if not text.strip():
            continue
        
        text = text.strip()
        
        text = re.sub(r'@.@', '-', text)
        processed_texts.append(text)
        
    # Tokenize the processed texts
    tokenized = tokenizer(
        processed_texts,
        truncation=True,
        padding='max_length',
        max_length=512,
        return_special_tokens_mask=True
    )
    return tokenized

# Load the tokenizer
student_id = "distilbert/distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(student_id)

# Load the dataset
dataset = load_dataset("Salesforce/wikitext", "wikitext-103-v1")

# Apply preprocessing and tokenization
tokenized_data = dataset.map(
    preprocess_and_tokenize,
    batched=True,
    remove_columns=dataset['train'].column_names
)


Map: 100%|██████████| 1801350/1801350 [07:37<00:00, 3933.70 examples/s]


In [3]:
class distillTrainer(Trainer):
    def __init__(self, *args, teacher_model = None, temperature = None, alpha_ce = None, alpha_cos = None, **kwargs):
        super().__init__(*args,**kwargs)
        self.teacher = teacher_model
        self.temperature = temperature
        self.alpha_ce = alpha_ce
        self.alpha_cos = alpha_cos
        self.teacher.eval()
        self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
        if self.alpha_cos > 0.0:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")

    def distillation_loss(self, student_outputs, teacher_outputs, attention_mask):
        #soft target probabilities
        s_logits = student_outputs.logits  # (bs, seq_length, voc_size)
        t_logits = teacher_outputs.logits  # (bs, seq_length, voc_size)

        attention_mask = attention_mask.bool()
        mask = attention_mask.unsqueeze(-1).expand_as(s_logits)  # (bs, seq_length, voc_size)
        
        s_logits_slct = torch.masked_select(s_logits, mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1))  # (bs * seq_length, voc_size) modulo the 1s in mask
        t_logits_slct = torch.masked_select(t_logits, mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1))  # (bs * seq_length, voc_size) modulo the 1s in mask
        assert t_logits_slct.size() == s_logits_slct.size()
        
        soft_student = F.log_softmax(s_logits_slct / self.temperature, dim = -1)
        soft_teacher = F.softmax(t_logits_slct / self.temperature, dim = -1)
        #Kullback Leibler Divergence
        distill_loss = self.ce_loss_fct(soft_student, soft_teacher, reduction = 'batchmean') * (self.temperature**2) 
        return distill_loss

    def cosine_embedding_loss(self, student_outputs, teacher_outputs, attention_mask):
        #cosine embedding loss
        s_hidden_states = student_outputs.hidden_states[-1]  # (bs, seq_length, dim)
        t_hidden_states = teacher_outputs.hidden_states[-1]  # (bs, seq_length, dim)
        
        attention_mask = attention_mask.bool()
        mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states)  # (bs, seq_length, dim)
        assert s_hidden_states.size() == t_hidden_states.size()
        dim = s_hidden_states.size(-1)

        s_hidden_states_slct = torch.masked_select(s_hidden_states, mask)  # (bs * seq_length * dim)
        s_hidden_states_slct = s_hidden_states_slct.view(-1, dim)  # (bs * seq_length, dim)
        t_hidden_states_slct = torch.masked_select(t_hidden_states, mask)  # (bs * seq_length * dim)
        t_hidden_states_slct = t_hidden_states_slct.view(-1, dim)  # (bs * seq_length, dim)

        target = s_hidden_states_slct.new(s_hidden_states_slct.size(0)).fill_(1)  # (bs * seq_length,)
        loss_cos = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)
        return loss_cos

    def compute_loss(self, model, inputs, return_outputs = False):
        #Distillation loss over soft target probabilities of teacher and student, KL DIV
        #Cosine embedding loss
        #supervised training loss
        #Attention Score Alignment???
        
        student_outputs = model(**inputs)
        student_loss = student_outputs.loss
        
        with torch.no_grad():
            teacher_outputs = self.teacher(**inputs)
            
        l_ce = self.distillation_loss(student_outputs, teacher_outputs, inputs['attention_mask'])
        
        l_cos = 0
        if self.alpha_cos > 0:
            l_cos += self.cosine_embedding_loss(student_outputs, teacher_outputs, inputs['attention_mask'])

        #Combine losses
        loss = self.alpha_ce * l_ce + l_cos * self.alpha_cos + student_loss * (1 - (self.alpha_ce + self.alpha_cos)) 
        
        return (loss, student_outputs) if return_outputs else loss

In [4]:
from transformers import DistilBertForSequenceClassification, AutoModelForSequenceClassification, DistilBertConfig, DataCollatorWithPadding, BertForMaskedLM, DistilBertForMaskedLM 

#Load Models
teacher_id = "google-bert/bert-base-uncased"
teacher_model = DistilBertForMaskedLM.from_pretrained(
    teacher_id,
    output_hidden_states=True,
)

student_config = DistilBertConfig(output_hidden_states = True)
student_model = DistilBertForMaskedLM(student_config)
initialized_weights = torch.load('/mnt/tony/MSc2024/distilbert_init/distilbert_init.pth')
student_model.load_state_dict(initialized_weights, strict=False)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
teacher_model.to(device)
student_model.to(device)

Some weights of DistilBertForMaskedLM were not initialized from the model checkpoint at lvwerra/distilbert-imdb and are newly initialized: ['vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_transform.bias', 'vocab_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


DistilBertForMaskedLM(
  (activation): GELUActivation()
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.

In [5]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def preprocess_logits_for_metrics(logits, labels):
    """
    Preprocess the logits to ensure they are in the correct format for metric computation.
    This function will be called during the evaluation process.
    """
    if isinstance(logits, tuple):  
        logits = logits[0]  # get logit tensors
    pred_ids = torch.argmax(logits, dim=-1)
    
    return pred_ids, labels

def compute_metrics(eval_preds):
    predictions, labels = eval_preds

    # Flatten the arrays if they're multi-dimensional
    predictions = predictions[0].flatten()
    labels = labels.flatten()

    # Compute MLM accuracy only on masked tokens
    masked_tokens = labels != -100
    mlm_accuracy = accuracy_score(labels[masked_tokens], predictions[masked_tokens])
    
    # Compute overall accuracy, precision, recall, and F1
    # Ignore padding tokens (-100)
    valid_tokens = labels != -100
    accuracy = accuracy_score(labels[valid_tokens], predictions[valid_tokens])
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels[valid_tokens], 
        predictions[valid_tokens], 
        average='weighted',
        zero_division=0
    )
    
    return { 
        "accuracy": accuracy,
        "precision": precision,
        "f1": f1
    }
    

In [6]:
from transformers import DataCollatorForLanguageModeling
EPOCHS = 2
BATCH_SIZE = 2
LEARNING_RATE = 2e-5

training_args = TrainingArguments(
    output_dir = './results',
    num_train_epochs = EPOCHS,
    per_device_train_batch_size = BATCH_SIZE,
    per_device_eval_batch_size = BATCH_SIZE,
    learning_rate = LEARNING_RATE,
    logging_dir = './logs',
    load_best_model_at_end= True,
    eval_strategy="steps",
    eval_steps = 500,
    save_strategy="steps",
    save_total_limit=2,
    #report_to=['comet_ml', 'tensorboard'],
    report_to=['tensorboard'],
)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

trainer = distillTrainer(
    teacher_model=teacher_model,
    model=student_model,                         
    args=training_args,                  
    train_dataset=tokenized_data['train'],         
    eval_dataset=tokenized_data['validation'],
    temperature = 5,
    alpha_ce = 0.3,
    alpha_cos = 0.2,
    tokenizer = tokenizer,
    data_collator = data_collator,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

In [7]:
trainer.train()



Step,Training Loss,Validation Loss,Mlm Accuracy,Accuracy,Precision,Recall,F1
500,63.9827,,0.067072,0.067072,0.010725,0.067072,0.009636
1000,56.0469,,0.074271,0.074271,0.019811,0.074271,0.017891
1500,54.9976,,0.073492,0.073492,0.019056,0.073492,0.019154
2000,54.5089,,0.077844,0.077844,0.016811,0.077844,0.020753
2500,53.857,,0.073212,0.073212,0.017591,0.073212,0.021553
3000,53.0982,,0.076129,0.076129,0.029193,0.076129,0.026445
3500,53.3226,,0.077201,0.077201,0.057508,0.077201,0.029157
4000,52.4454,,0.079999,0.079999,0.075715,0.079999,0.034903
4500,51.7112,,0.084324,0.084324,0.064318,0.084324,0.040335
5000,50.9222,,0.083149,0.083149,0.049324,0.083149,0.038476


There were missing keys in the checkpoint model loaded: ['vocab_projector.weight'].


TrainOutput(global_step=291258, training_loss=33.77319262173432, metrics={'train_runtime': 194854.3061, 'train_samples_per_second': 11.958, 'train_steps_per_second': 1.495, 'total_flos': 3.088751822507336e+17, 'train_loss': 33.77319262173432, 'epoch': 2.0})

In [8]:
print(preprocess_and_tokenize(dataset['validation'][0:10]))

{'input_ids': [[101, 1031, 2516, 1033, 7570, 7849, 2271, 13091, 7946, 1031, 1013, 2516, 1033, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [9]:
trainer.save_model('./models/distilbert_wikitext')