In [1]:
import comet_ml
import torch
from datasets import load_from_disk
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pre-tokenized Wikipedia dataset saved in the Hugging Face `datasets`
# on-disk format (the path refers to the 20231101 dump).
TOKENIZED_DATASET_PATH = '/shared/Tony/MSc2024/data/tokenized_wikipedia_20231101.hf'

tokenized_dataset = load_from_disk(TOKENIZED_DATASET_PATH)

# The dataset is used as a single table; the split lookups stay disabled.
#train_dataset = tokenized_dataset["train"]
#val_dataset = tokenized_dataset["validation"]

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Report how many rows were loaded.
print(len(tokenized_dataset))

8313965




In [3]:
#comet_ml.init(project_name="distilbert_dotprod")

import random


#num_samples = len(train_dataset)
# Number of leading rows of the tokenized dataset used for training.
TRAIN_SUBSET_SIZE = 6313965

# `select` keeps the rows at the given indices, i.e. the first TRAIN_SUBSET_SIZE rows.
train_subset = tokenized_dataset.select(range(TRAIN_SUBSET_SIZE))
print(train_subset)

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask'],
    num_rows: 6313965
})


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer, TrainingArguments
import matplotlib.pyplot as plt
import numpy as np
from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup

class distillTrainer(Trainer):
    """Trainer that distils a student by matching its hidden states to a teacher's.

    The loss is the sum, over paired transformer layers, of the mean-squared
    error between student and teacher hidden states. Layers are paired
    one-to-one in order and the embedding output (index 0) is excluded.
    Min/max/mean of every paired layer are recorded for both models on each
    training forward pass and plotted when training finishes.

    Args:
        teacher_model: Model that supplies the target hidden states. Required.
        temperature: Stored on the instance; not used by the current loss.
        alpha_ce: Stored on the instance; not used by the current loss.
        alpha_hidden: Stored on the instance; not used by the current loss.
        *args, **kwargs: Forwarded unchanged to ``Trainer.__init__``.

    Raises:
        ValueError: If ``teacher_model`` is not provided.
    """

    # Statistics tracked for each layer, in the order they are computed.
    _STAT_NAMES = ('min', 'max', 'mean')

    def __init__(self, *args, teacher_model=None, temperature=2.0, alpha_ce=0.5, alpha_hidden=0.5, **kwargs):
        # Fail fast with a clear message instead of an AttributeError on `None.eval()`.
        if teacher_model is None:
            raise ValueError("distillTrainer requires a `teacher_model`.")

        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        self.temperature = temperature
        self.alpha_ce = alpha_ce
        self.alpha_hidden = alpha_hidden
        self.teacher.eval()  # Teacher model in evaluation mode

        # Per-layer history of hidden-state statistics. Six entries are created up
        # front (6 transformer layers in DistilBERT); further layers are added on
        # demand so deeper models do not raise KeyError.
        self.layer_stats_student = {i: {name: [] for name in self._STAT_NAMES} for i in range(6)}
        self.layer_stats_teacher = {i: {name: [] for name in self._STAT_NAMES} for i in range(6)}

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the hidden-state distillation loss for one batch.

        Args:
            model: The student model being trained.
            inputs: Mapping of keyword arguments for the model's forward pass.
            return_outputs: If true, return ``(loss, student_outputs)``.
            num_items_in_batch: Accepted because newer ``transformers`` releases
                pass it to ``compute_loss``; it is not used here.

        Raises:
            ValueError: If either model does not return ``hidden_states``.
        """
        student_outputs = model(**inputs)

        # Only the teacher's hidden states are consumed, so drop `labels` to keep
        # the teacher from computing a loss that would be thrown away.
        teacher_inputs = {key: value for key, value in inputs.items() if key != 'labels'}
        with torch.no_grad():
            teacher_outputs = self.teacher(**teacher_inputs)

        if student_outputs.hidden_states is None or teacher_outputs.hidden_states is None:
            raise ValueError(
                "Both student and teacher must return hidden states; "
                "set `output_hidden_states=True` in their configs."
            )

        # Hidden state distillation: one-to-one layer mapping, excluding embedding layer (index 0)
        student_hidden_states = student_outputs.hidden_states[1:]
        teacher_hidden_states = teacher_outputs.hidden_states[1:]

        # `zip` pairs layers in order and stops at the shallower model.
        hidden_loss = 0.0
        for student_layer_hidden, teacher_layer_hidden in zip(student_hidden_states, teacher_hidden_states):
            hidden_loss = hidden_loss + F.mse_loss(student_layer_hidden, teacher_layer_hidden)

        # Record statistics for training passes only, so evaluation batches do not
        # end up in the "Training Step" curves.
        if model.training:
            n_pairs = min(len(student_hidden_states), len(teacher_hidden_states))
            self._record_layer_stats(student_hidden_states[:n_pairs], teacher_hidden_states[:n_pairs])

        # Only the hidden-state term is active; no logit/classification terms are added.
        total_loss = hidden_loss

        return (total_loss, student_outputs) if return_outputs else total_loss

    @classmethod
    def _summarize(cls, hidden_states):
        """Return ``[[min, max, mean], ...]`` (one row per layer) with a single host copy."""
        if not hidden_states:
            return []
        with torch.no_grad():
            rows = [
                torch.stack([hidden.min(), hidden.max(), hidden.mean()]).float()
                for hidden in hidden_states
            ]
            # One `.tolist()` for all layers instead of three `.item()` syncs per layer.
            return torch.stack(rows).tolist()

    def _record_layer_stats(self, student_hidden_states, teacher_hidden_states):
        """Append the current min/max/mean of every layer to the tracking dicts."""
        tracked = (
            (self.layer_stats_student, student_hidden_states),
            (self.layer_stats_teacher, teacher_hidden_states),
        )
        for stats_store, hidden_states in tracked:
            for layer_idx, values in enumerate(self._summarize(hidden_states)):
                layer_stats = stats_store.setdefault(layer_idx, {name: [] for name in self._STAT_NAMES})
                for name, value in zip(self._STAT_NAMES, values):
                    layer_stats[name].append(value)

    def train(self, resume_from_checkpoint=None, **kwargs):
        """Print the trainable-parameter count, train, then plot the hidden-state stats.

        Returns whatever ``Trainer.train`` returns.
        """
        print(self.get_num_trainable_parameters())
        res = super().train(resume_from_checkpoint=resume_from_checkpoint, **kwargs)
        self.plot_hidden_state_stats()  # Plot hidden state stats after training
        return res

    def plot_hidden_state_stats(self, save_path='hidden_state_stats.png'):
        """Save a 3x2 grid of min/max/mean curves (rows) for student and teacher (columns).

        One line is drawn per tracked layer. Each point is one recorded forward
        pass, so with gradient accumulation there are several points per
        optimizer step.

        Args:
            save_path: Where the figure is written.
        """
        fig, axs = plt.subplots(3, 2, figsize=(15, 12))

        panels = (
            (self.layer_stats_student, 'Student'),
            (self.layer_stats_teacher, 'Teacher'),
        )
        for col, (stats_store, model_label) in enumerate(panels):
            for row, name in enumerate(self._STAT_NAMES):
                ax = axs[row, col]
                for layer_idx in sorted(stats_store):
                    ax.plot(stats_store[layer_idx][name], label=f'Layer {layer_idx + 1}')
                ax.set_title(f'{model_label} Hidden State {name.capitalize()} Values')
                ax.set_xlabel('Training Step')
                ax.set_ylabel('Value')
                # A fixed location avoids the costly `loc='best'` search on long series.
                ax.legend(loc='upper right')

        fig.tight_layout()
        fig.savefig(save_path)
        plt.close(fig)


2024-10-02 13:24:16.107490: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-10-02 13:24:16.119452: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-02 13:24:16.132481: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-02 13:24:16.136341: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-02 13:24:16.147276: I tensorflow/core/platform/cpu_feature_guar

In [5]:
from transformers import DistilBertForSequenceClassification, AutoModelForSequenceClassification, DistilBertConfig, DistilBertForMaskedLM, DataCollatorForLanguageModeling
from iDistilbert import iDistilBertForMaskedLM
from transformers import AutoTokenizer

student_id = "distilbert/distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(student_id)

# --- Teacher: pretrained DistilBERT that exposes its hidden states -----------
teacher_id = "distilbert/distilbert-base-uncased"
teacher_config = DistilBertConfig(
    # Options below belong to the custom student architecture and stay disabled here.
    #distance_metric = "cosine_distance",
    #activation_function = "softmax",
    #signed_inhibitor =  False,
    #alpha = 0,
    #center = False,
    #output_contexts = False,
    output_hidden_states = True,  # required by the hidden-state distillation loss
)

teacher_model = DistilBertForMaskedLM.from_pretrained(
        teacher_id,
        config=teacher_config,
    )

# --- Student: custom iDistilBert built from its own config -------------------
student_config = DistilBertConfig(
    distance_metric = "manhattan_distance",
    activation_function = "relu",
    signed_inhibitor =  True,
    center = True,
    output_contexts = False,
    output_hidden_states = True,
    )

student_model = iDistilBertForMaskedLM(student_config)

# Load the initial weights onto the CPU so this does not depend on the device
# the checkpoint was saved from. `weights_only=True` restricts unpickling to
# tensors and plain containers, which is all a state dict needs.
INIT_WEIGHTS_PATH = '/shared/Tony/MSc2024/KD_weight_init/models/qkv_center_inhibitor1_init.pth'
initialized_weights = torch.load(INIT_WEIGHTS_PATH, map_location='cpu', weights_only=True)

# `strict=False` tolerates key mismatches; report them so a partially
# initialised student does not go unnoticed.
load_report = student_model.load_state_dict(initialized_weights, strict=False)
if load_report.missing_keys or load_report.unexpected_keys:
    print(f"Missing keys when loading student weights: {load_report.missing_keys}")
    print(f"Unexpected keys when loading student weights: {load_report.unexpected_keys}")

device = 'cuda' if torch.cuda.is_available() else 'cpu'
teacher_model.to(device)
student_model.to(device)

  initialized_weights = torch.load('/shared/Tony/MSc2024/distilbert_init/models/qkv_center_inhibitor1_init.pth')


iDistilBertForMaskedLM(
  (activation): GELUActivation()
  (distilbert): iDistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): iTransformer(
      (layer): ModuleList(
        (0-5): 6 x iTransformerBlock(
          (attention): iMultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout

In [6]:
# Optimisation hyper-parameters.
EPOCHS = 3
BATCH_SIZE = 8          # sequences per device per forward pass
LEARNING_RATE = 3e-4

# Masks 15% of the tokens in every batch for masked-language modelling.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

training_args = TrainingArguments(
    # Output locations and reporting.
    output_dir='./results',
    logging_dir='./logs',
    logging_steps=32,
    report_to=['tensorboard'],
    #report_to=['comet_ml', 'tensorboard'],

    # Batching: 8 sequences x 32 accumulation steps = 256 sequences
    # per device per optimizer step.
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=32,

    # Schedule: cosine decay with warm-up over the first 5% of steps.
    num_train_epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,

    # Checkpointing; evaluation stays disabled.
    save_strategy="steps",
    save_steps=3699,
    save_total_limit=10,
    #eval_strategy="steps",
    #eval_steps = 2014,

    seed=42,
)

trainer = distillTrainer(
    model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train_subset,
    #eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)


In [7]:
# Run the distillation. `distillTrainer.train` first prints the number of
# trainable parameters and, once training ends, writes the hidden-state
# statistics figure to 'hidden_state_stats.png'.
trainer.train()

66985548




Step,Training Loss
32,2.0632
64,1.5276
96,1.269
128,1.1469
160,1.0696
192,1.0085
224,0.9584
256,0.9156
288,0.8761
320,0.8409


  plt.tight_layout()
  plt.savefig('hidden_state_stats.png')


TrainOutput(global_step=36993, training_loss=0.2174487009399197, metrics={'train_runtime': 862982.5078, 'train_samples_per_second': 21.949, 'train_steps_per_second': 0.043, 'total_flos': 2.5107636938611507e+18, 'train_loss': 0.2174487009399197, 'epoch': 2.9997643320333585})

In [8]:
import os

# Directory for exported weights; created on first use.
folder = 'models/'
os.makedirs(folder, exist_ok=True)

# Persist only the student's parameters (its state dict), not the whole module.
student_weights_path = os.path.join(folder, 'hiddenstates4_center_inhibitor_init.pth')
torch.save(student_model.state_dict(), student_weights_path)

In [10]:
import matplotlib.pyplot as plt
import numpy as np

def smooth_data(data, window_size):
    """Return the moving average of ``data`` over windows of ``window_size`` items.

    The average is computed from a cumulative sum, so every full window costs
    O(1). Only complete windows are reported: the result has
    ``len(data) - window_size + 1`` entries, and is empty when the window is
    longer than the data.

    Args:
        data: Sequence of numbers.
        window_size: Number of consecutive items per average; must be >= 1.

    Returns:
        A NumPy array of window means.

    Raises:
        ValueError: If ``window_size`` is smaller than 1. Without this check a
            negative window silently produced meaningless values.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be a positive integer, got {window_size!r}")

    # A leading zero makes cumsum[i] the sum of the first i items, so the sum of
    # any window is a difference of two cumulative sums.
    cumsum = np.cumsum(np.insert(data, 0, 0))
    return (cumsum[window_size:] - cumsum[:-window_size]) / window_size

def organize_data_by_layer(data):
    """Group per-step records into per-layer series.

    Each record in ``data`` is a mapping with a ``'layer'`` key plus
    ``'student'`` and ``'teacher'`` mappings that each hold ``'mean'``,
    ``'max'`` and ``'min'`` values.

    Returns:
        ``{layer: {'student': {'mean': [...], 'max': [...], 'min': [...]},
        'teacher': {...}}}`` with values appended in the order the records
        were seen.
    """
    sources = ('student', 'teacher')
    metrics = ('mean', 'max', 'min')

    grouped = {}
    for record in data:
        # Create the empty series for a layer the first time it appears.
        series = grouped.setdefault(
            record['layer'],
            {source: {metric: [] for metric in metrics} for source in sources},
        )
        for source in sources:
            for metric in metrics:
                series[source][metric].append(record[source][metric])
    return grouped

def plot_all_layers(data, smooth_window=100, save_path='all_layers_stats_smoothed2.png'):
    """Plot student vs. teacher statistics for every layer and save the figure.

    One row of three panels (mean, max, min) is drawn per layer. The raw series
    is drawn faintly; when a series is longer than ``smooth_window`` its moving
    average is drawn on top, otherwise the raw series is drawn again at full
    opacity so the legend still has an entry.

    Args:
        data: Iterable of records accepted by ``organize_data_by_layer``.
        smooth_window: Moving-average window, in recorded steps.
        save_path: Where the PNG is written.

    Raises:
        ValueError: If ``data`` contains no records.
    """
    organized_data = organize_data_by_layer(data)
    num_layers = len(organized_data)
    if num_layers == 0:
        # Clearer than the error matplotlib raises for a zero-row grid.
        raise ValueError("plot_all_layers received no records to plot.")

    # squeeze=False keeps `axs` two-dimensional even when there is one layer.
    fig, axs = plt.subplots(num_layers, 3, figsize=(20, 8 * num_layers), squeeze=False)

    try:
        fig.suptitle('Layer Statistics Comparison (Smoothed)', fontsize=16)

        colors = {'student': 'blue', 'teacher': 'red'}
        stats = ['mean', 'max', 'min']

        for row, layer_num in enumerate(sorted(organized_data)):
            layer_data = organized_data[layer_num]
            for col, stat in enumerate(stats):
                ax = axs[row, col]
                for model in ['student', 'teacher']:
                    original_data = layer_data[model][stat]

                    # Plot original data with low alpha
                    ax.plot(original_data, color=colors[model], alpha=0.3, linewidth=1)

                    # Smooth and plot the data
                    if len(original_data) > smooth_window:
                        smoothed_data = smooth_data(original_data, smooth_window)
                        # The first full window ends at index smooth_window - 1.
                        ax.plot(range(smooth_window - 1, len(original_data)), smoothed_data,
                                color=colors[model], label=f'{model.capitalize()} (Smoothed)')
                    else:
                        ax.plot(original_data, color=colors[model], label=model.capitalize())

                ax.set_title(f'Layer {layer_num} - {stat.capitalize()}', fontsize=12)
                ax.set_xlabel('Steps', fontsize=10)
                ax.set_ylabel('Value', fontsize=10)
                ax.legend(fontsize=8)
                ax.grid(True, linestyle='--', alpha=0.7)

        fig.tight_layout()
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the (potentially very large) figure even if plotting fails.
        plt.close(fig)

In [11]:
# `distillTrainer` does not define `context_loss_stats`; it records its
# statistics in `layer_stats_student` / `layer_stats_teacher`. Convert those
# into the per-step records that `plot_all_layers` expects.
def iter_layer_records(distill_trainer):
    """Yield one record per (layer, recorded forward pass).

    Records are produced lazily so the full history is never duplicated in
    memory. Layers are numbered from 1.
    """
    stat_names = ('mean', 'max', 'min')
    for layer_idx in sorted(distill_trainer.layer_stats_student):
        student_stats = distill_trainer.layer_stats_student[layer_idx]
        teacher_stats = distill_trainer.layer_stats_teacher[layer_idx]
        for step in range(len(student_stats['mean'])):
            yield {
                'layer': layer_idx + 1,
                'student': {name: student_stats[name][step] for name in stat_names},
                'teacher': {name: teacher_stats[name][step] for name in stat_names},
            }

plot_all_layers(iter_layer_records(trainer))