In [1]:
import comet_ml
import torch
from datasets import load_from_disk

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenized_dataset = load_from_disk('/shared/Tony/MSc2024/data/tokenized_preprocessed_wikitext103.hf') #20231101
train_dataset = tokenized_dataset["train"]
val_dataset = tokenized_dataset["validation"]

In [3]:
#comet_ml.init(project_name="distilbert_dotprod")
from transformers import AutoTokenizer
import random

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
num_samples = len(train_dataset)//10
train_subset = train_dataset.select(range(num_samples))
print(train_subset)

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 85965
})




In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer, TrainingArguments, TrainerCallback
import matplotlib.pyplot as plt
import numpy as np
from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup

class distillTrainer(Trainer):
    def __init__(self, *args, teacher_model = None, hidden = False, **kwargs):
        super().__init__(*args,**kwargs)
        self.teacher = teacher_model
        self.teacher.eval()
        self.layer_groups = [f"transformer.layer.{i}" for i in range(6)] 
        self.current_layer_group = 0
        self.unfrozen_layers = set()
        self.layer_logs = []
        self.context_loss_stats = []
        self.hidden = hidden
        self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")

    def hidden_state_loss(self, student_outputs, teacher_outputs):
        s_hidden_states = student_outputs
        t_hidden_states = teacher_outputs
        assert t_hidden_states.size() == s_hidden_states.size()
        dim = s_hidden_states.size(-1)
        s_hidden_states_slct = s_hidden_states.view(-1, dim)
        t_hidden_states_slct = t_hidden_states.view(-1, dim)

        target = s_hidden_states_slct.new(s_hidden_states_slct.size(0)).fill_(1) 
        loss = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)
        #loss = F.mse_loss(s_hidden_states_slct, t_hidden_states_slct)
        return loss
        
    def compute_loss(self, model, inputs, return_outputs = False):
        student_outputs = model(**inputs)
        student_logits = student_outputs.logits

        with torch.no_grad():
            teacher_outputs = self.teacher(**inputs)
            teacher_logits = teacher_outputs.logits

        student_obj = student_outputs.hidden_states[self.current_layer_group + 1] # +1 because idx 0 is the embedding layer.
        teacher_obj = teacher_outputs.hidden_states[self.current_layer_group + 1]


        loss = self.hidden_state_loss(student_obj,teacher_obj)
      
        #student_obj = student_outputs.contexts[self.current_layer_group]
        #teacher_obj = teacher_outputs.contexts[self.current_layer_group]
        #loss = F.mse_loss(student_obj, teacher_obj)

        student_stats = {
            'mean': student_obj.mean().item(),
            'max': student_obj.max().item(),
            'min': student_obj.min().item()
        }
        teacher_stats = {
            'mean': teacher_obj.mean().item(),
            'max': teacher_obj.max().item(),
            'min': teacher_obj.min().item()
        }

        self.context_loss_stats.append({
            'layer': self.current_layer_group,
            'student': student_stats,
            'teacher': teacher_stats,
        })
        
        return (loss, student_outputs) if return_outputs else loss
        
    def train(self, resume_from_checkpoint=None, **kwargs):
        layer_plots = []
        for layer_group in self.layer_groups:
            print(f"Training layer group: {layer_group}")
            self.switch_to_next_layer_group()
            print(self.get_num_trainable_parameters())
            res = super().train(resume_from_checkpoint=resume_from_checkpoint, **kwargs)
            self.layer_logs.append(self.state.log_history.copy())
            self.current_layer_group += 1
            #self.save_model(f"./results/layer_{layer_group}")
        self.plot_layer_losses()
        return res

    def freeze_all_layers(self):
        for param in self.model.parameters():
            param.requires_grad = False
    
    def switch_to_next_layer_group(self):
        self.freeze_all_layers()
        print("Current layer", self.current_layer_group)
        
        if self.current_layer_group < len(self.layer_groups):
            current_layer = self.layer_groups[self.current_layer_group]
            newly_unfrozen_params = []
            
            for name, param in self.model.named_parameters():
                if current_layer in name: #and any(qkv in name for qkv in ['q_lin', 'k_lin', 'v_lin', 'out_lin', 'sa_layer_norm','ffn','output_layer_norm']): #v_lin q_lin k_lin
                    param.requires_grad = True
                    newly_unfrozen_params.append(param)
                        
            print(f"Unfrozen parameters for layer {self.current_layer_group}:")
            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    print(f"  - {name}")

            optimizer_grouped_parameters = [
            {
                "params": [p for p in newly_unfrozen_params if p.requires_grad],
                "weight_decay": self.args.weight_decay,
            }
            ]
            
            self.optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate)
            
            num_training_steps = len(self.train_dataset) // self.args.train_batch_size * self.args.num_train_epochs
            warmup_rate = 0.05  # 10% 
            warmup_steps = int(num_training_steps * warmup_rate)
            self.lr_scheduler = get_cosine_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=warmup_steps, 
                num_training_steps=num_training_steps
            )

    def plot_layer_losses(self):
        fig, axs = plt.subplots(2, 3, figsize=(20, 15))
        for layer, (ax, layer_logs) in enumerate(zip(axs.flatten(), self.layer_logs)):
            train_data = [(log['step'], log['loss']) for log in layer_logs if 'loss' in log]
            eval_data = [(log['step'], log['eval_loss']) for log in layer_logs if 'eval_loss' in log]
            
            if train_data:
                steps, losses = zip(*train_data)
                ax.plot(steps, losses, label='Train Loss')
            if eval_data:
                steps, losses = zip(*eval_data)
                ax.plot(steps, losses, label='Validation Loss')
            
            ax.set_title(f'Layer {layer} Loss')
            ax.set_xlabel('Steps')
            ax.set_ylabel('Loss')
            ax.legend()
        
        plt.tight_layout()
        plt.savefig('hiddenstate_cosineloss_layerloss.png')
        plt.close()
        




    

2024-09-21 23:41:51.900069: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-09-21 23:41:51.911890: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-21 23:41:51.924491: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-21 23:41:51.928283: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-21 23:41:51.938977: I tensorflow/core/platform/cpu_feature_guar

In [5]:
from transformers import DistilBertForSequenceClassification, AutoModelForSequenceClassification, DistilBertConfig, DistilBertForMaskedLM, DataCollatorForLanguageModeling
from iDistilbert import iDistilBertForMaskedLM
from transformers import AutoTokenizer

student_id = "distilbert/distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(student_id)
#Load Models

teacher_id = "distilbert/distilbert-base-uncased"
teacher_config = DistilBertConfig(    
    distance_metric = "cosine_distance",
    activation_function = "softmax",
    signed_inhibitor =  False,
    alpha = 0,
    center = False,
    output_contexts = False,
    output_hidden_states = True,
)
    
teacher_model = iDistilBertForMaskedLM.from_pretrained(
        teacher_id,
        config=teacher_config,
    )

student_config = DistilBertConfig(
    distance_metric = "manhattan_distance",
    activation_function = "relu",
    signed_inhibitor =  True,
    alpha = 0,
    center = True,
    output_contexts = False,
    output_hidden_states = True,
    )

student_model = iDistilBertForMaskedLM(student_config)

initialized_weights = torch.load('/shared/Tony/MSc2024/distilbert_init/models/qk_inhibitor_init.pth')
student_model.load_state_dict(initialized_weights, strict=False)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
teacher_model.to(device)
student_model.to(device)

  initialized_weights = torch.load('/shared/Tony/MSc2024/distilbert_init/models/qk_inhibitor_init.pth')


iDistilBertForMaskedLM(
  (activation): GELUActivation()
  (distilbert): iDistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): iTransformer(
      (layer): ModuleList(
        (0-5): 6 x iTransformerBlock(
          (attention): iMultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout

In [6]:
EPOCHS = 2
BATCH_SIZE = 8
LEARNING_RATE = 5e-4
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

training_args = TrainingArguments(
    output_dir = './results',
    num_train_epochs = EPOCHS,
    per_device_train_batch_size = BATCH_SIZE,
    per_device_eval_batch_size = BATCH_SIZE,
    learning_rate = LEARNING_RATE,
    logging_dir = './logs',
    load_best_model_at_end= True,
    metric_for_best_model="loss",
    eval_strategy="steps",
    save_strategy="steps",
    eval_steps = 268, 
    logging_steps = 20,
    save_steps=268,
    save_total_limit=2,
    seed = 42,
    #report_to=['comet_ml', 'tensorboard'],
    report_to=['tensorboard'],
    warmup_ratio=0.05,
    gradient_accumulation_steps=4,
    lr_scheduler_type="cosine",
)

trainer = distillTrainer(
    teacher_model=teacher_model,
    model=student_model,                         
    args=training_args,                  
    train_dataset=train_subset,         
    eval_dataset=val_dataset,
    tokenizer = tokenizer,
    data_collator = data_collator,
    hidden = True,
)


In [7]:
trainer.train()

Training layer group: transformer.layer.0
Current layer 0
Unfrozen parameters for layer 0:
  - distilbert.transformer.layer.0.attention.q_lin.weight
  - distilbert.transformer.layer.0.attention.q_lin.bias
  - distilbert.transformer.layer.0.attention.k_lin.weight
  - distilbert.transformer.layer.0.attention.k_lin.bias
  - distilbert.transformer.layer.0.attention.v_lin.weight
  - distilbert.transformer.layer.0.attention.v_lin.bias
  - distilbert.transformer.layer.0.attention.out_lin.weight
  - distilbert.transformer.layer.0.attention.out_lin.bias
  - distilbert.transformer.layer.0.sa_layer_norm.weight
  - distilbert.transformer.layer.0.sa_layer_norm.bias
  - distilbert.transformer.layer.0.ffn.lin1.weight
  - distilbert.transformer.layer.0.ffn.lin1.bias
  - distilbert.transformer.layer.0.ffn.lin2.weight
  - distilbert.transformer.layer.0.ffn.lin2.bias
  - distilbert.transformer.layer.0.output_layer_norm.weight
  - distilbert.transformer.layer.0.output_layer_norm.bias
7087872




Step,Training Loss,Validation Loss
268,0.0663,0.046045
536,0.0494,0.035637
804,0.0437,0.031754
1072,0.0401,0.028751
1340,0.0359,0.024684
1608,0.033,0.022203
1876,0.0316,0.020833
2144,0.0304,0.02006
2412,0.0299,0.01968
2680,0.0292,0.019082


There were missing keys in the checkpoint model loaded: ['vocab_projector.weight'].


Training layer group: transformer.layer.1
Current layer 1
Unfrozen parameters for layer 1:
  - distilbert.transformer.layer.1.attention.q_lin.weight
  - distilbert.transformer.layer.1.attention.q_lin.bias
  - distilbert.transformer.layer.1.attention.k_lin.weight
  - distilbert.transformer.layer.1.attention.k_lin.bias
  - distilbert.transformer.layer.1.attention.v_lin.weight
  - distilbert.transformer.layer.1.attention.v_lin.bias
  - distilbert.transformer.layer.1.attention.out_lin.weight
  - distilbert.transformer.layer.1.attention.out_lin.bias
  - distilbert.transformer.layer.1.sa_layer_norm.weight
  - distilbert.transformer.layer.1.sa_layer_norm.bias
  - distilbert.transformer.layer.1.ffn.lin1.weight
  - distilbert.transformer.layer.1.ffn.lin1.bias
  - distilbert.transformer.layer.1.ffn.lin2.weight
  - distilbert.transformer.layer.1.ffn.lin2.bias
  - distilbert.transformer.layer.1.output_layer_norm.weight
  - distilbert.transformer.layer.1.output_layer_norm.bias
7087872




Step,Training Loss,Validation Loss
268,0.1037,0.091156
536,0.0948,0.083577
804,0.0906,0.079658
1072,0.0884,0.077167
1340,0.0861,0.075552
1608,0.0842,0.074474
1876,0.0843,0.073789
2144,0.0827,0.072651
2412,0.0813,0.070331
2680,0.074,0.060477


There were missing keys in the checkpoint model loaded: ['vocab_projector.weight'].


Training layer group: transformer.layer.2
Current layer 2
Unfrozen parameters for layer 2:
  - distilbert.transformer.layer.2.attention.q_lin.weight
  - distilbert.transformer.layer.2.attention.q_lin.bias
  - distilbert.transformer.layer.2.attention.k_lin.weight
  - distilbert.transformer.layer.2.attention.k_lin.bias
  - distilbert.transformer.layer.2.attention.v_lin.weight
  - distilbert.transformer.layer.2.attention.v_lin.bias
  - distilbert.transformer.layer.2.attention.out_lin.weight
  - distilbert.transformer.layer.2.attention.out_lin.bias
  - distilbert.transformer.layer.2.sa_layer_norm.weight
  - distilbert.transformer.layer.2.sa_layer_norm.bias
  - distilbert.transformer.layer.2.ffn.lin1.weight
  - distilbert.transformer.layer.2.ffn.lin1.bias
  - distilbert.transformer.layer.2.ffn.lin2.weight
  - distilbert.transformer.layer.2.ffn.lin2.bias
  - distilbert.transformer.layer.2.output_layer_norm.weight
  - distilbert.transformer.layer.2.output_layer_norm.bias
7087872




Step,Training Loss,Validation Loss
268,0.1032,0.091226
536,0.0964,0.085249
804,0.0942,0.082875
1072,0.0931,0.081589
1340,0.0914,0.080625
1608,0.0902,0.079876
1876,0.09,0.079083
2144,0.0888,0.078664
2412,0.0888,0.07777
2680,0.0875,0.076184


There were missing keys in the checkpoint model loaded: ['vocab_projector.weight'].


Training layer group: transformer.layer.3
Current layer 3
Unfrozen parameters for layer 3:
  - distilbert.transformer.layer.3.attention.q_lin.weight
  - distilbert.transformer.layer.3.attention.q_lin.bias
  - distilbert.transformer.layer.3.attention.k_lin.weight
  - distilbert.transformer.layer.3.attention.k_lin.bias
  - distilbert.transformer.layer.3.attention.v_lin.weight
  - distilbert.transformer.layer.3.attention.v_lin.bias
  - distilbert.transformer.layer.3.attention.out_lin.weight
  - distilbert.transformer.layer.3.attention.out_lin.bias
  - distilbert.transformer.layer.3.sa_layer_norm.weight
  - distilbert.transformer.layer.3.sa_layer_norm.bias
  - distilbert.transformer.layer.3.ffn.lin1.weight
  - distilbert.transformer.layer.3.ffn.lin1.bias
  - distilbert.transformer.layer.3.ffn.lin2.weight
  - distilbert.transformer.layer.3.ffn.lin2.bias
  - distilbert.transformer.layer.3.output_layer_norm.weight
  - distilbert.transformer.layer.3.output_layer_norm.bias
7087872




Step,Training Loss,Validation Loss
268,0.1173,0.109103
536,0.113,0.10493
804,0.1113,0.102966
1072,0.1054,0.094659
1340,0.1014,0.09138
1608,0.0995,0.089803
1876,0.0988,0.08865
2144,0.098,0.08742
2412,0.0974,0.086442
2680,0.0971,0.086003


There were missing keys in the checkpoint model loaded: ['vocab_projector.weight'].


Training layer group: transformer.layer.4
Current layer 4
Unfrozen parameters for layer 4:
  - distilbert.transformer.layer.4.attention.q_lin.weight
  - distilbert.transformer.layer.4.attention.q_lin.bias
  - distilbert.transformer.layer.4.attention.k_lin.weight
  - distilbert.transformer.layer.4.attention.k_lin.bias
  - distilbert.transformer.layer.4.attention.v_lin.weight
  - distilbert.transformer.layer.4.attention.v_lin.bias
  - distilbert.transformer.layer.4.attention.out_lin.weight
  - distilbert.transformer.layer.4.attention.out_lin.bias
  - distilbert.transformer.layer.4.sa_layer_norm.weight
  - distilbert.transformer.layer.4.sa_layer_norm.bias
  - distilbert.transformer.layer.4.ffn.lin1.weight
  - distilbert.transformer.layer.4.ffn.lin1.bias
  - distilbert.transformer.layer.4.ffn.lin2.weight
  - distilbert.transformer.layer.4.ffn.lin2.bias
  - distilbert.transformer.layer.4.output_layer_norm.weight
  - distilbert.transformer.layer.4.output_layer_norm.bias
7087872




Step,Training Loss,Validation Loss
268,0.1015,0.092275
536,0.0975,0.088081
804,0.0962,0.086269
1072,0.0947,0.085084
1340,0.0924,0.083502
1608,0.0914,0.08228
1876,0.0913,0.081612
2144,0.0904,0.081402
2412,0.09,0.081172
2680,0.0898,0.080557


There were missing keys in the checkpoint model loaded: ['vocab_projector.weight'].


Training layer group: transformer.layer.5
Current layer 5
Unfrozen parameters for layer 5:
  - distilbert.transformer.layer.5.attention.q_lin.weight
  - distilbert.transformer.layer.5.attention.q_lin.bias
  - distilbert.transformer.layer.5.attention.k_lin.weight
  - distilbert.transformer.layer.5.attention.k_lin.bias
  - distilbert.transformer.layer.5.attention.v_lin.weight
  - distilbert.transformer.layer.5.attention.v_lin.bias
  - distilbert.transformer.layer.5.attention.out_lin.weight
  - distilbert.transformer.layer.5.attention.out_lin.bias
  - distilbert.transformer.layer.5.sa_layer_norm.weight
  - distilbert.transformer.layer.5.sa_layer_norm.bias
  - distilbert.transformer.layer.5.ffn.lin1.weight
  - distilbert.transformer.layer.5.ffn.lin1.bias
  - distilbert.transformer.layer.5.ffn.lin2.weight
  - distilbert.transformer.layer.5.ffn.lin2.bias
  - distilbert.transformer.layer.5.output_layer_norm.weight
  - distilbert.transformer.layer.5.output_layer_norm.bias
7087872




Step,Training Loss,Validation Loss
268,0.1811,0.162957
536,0.1724,0.156017
804,0.1699,0.153546
1072,0.1687,0.15238
1340,0.1667,0.150864
1608,0.1659,0.150582
1876,0.1653,0.149409
2144,0.1636,0.148962
2412,0.1638,0.148297
2680,0.1631,0.147246


There were missing keys in the checkpoint model loaded: ['vocab_projector.weight'].


TrainOutput(global_step=2686, training_loss=0.17140709515036884, metrics={'train_runtime': 12625.2906, 'train_samples_per_second': 13.618, 'train_steps_per_second': 0.213, 'total_flos': 2.2787395294058496e+16, 'train_loss': 0.17140709515036884, 'epoch': 1.9996277684719894})

In [8]:
import os

folder = 'models/'
os.makedirs(folder, exist_ok=True)
torch.save(student_model.state_dict(), os.path.join(folder, 'hiddenstates2_inhibitor_init.pth'))

In [11]:
import matplotlib.pyplot as plt
import numpy as np

def smooth_data(data, window_size):
    cumsum = np.cumsum(np.insert(data, 0, 0)) 
    return (cumsum[window_size:] - cumsum[:-window_size]) / window_size

def organize_data_by_layer(data):
    organized_data = {}
    for entry in data:
        layer = entry['layer']
        if layer not in organized_data:
            organized_data[layer] = {'student': {'mean': [], 'max': [], 'min': []},
                                     'teacher': {'mean': [], 'max': [], 'min': []}}
        for model in ['student', 'teacher']:
            for stat in ['mean', 'max', 'min']:
                organized_data[layer][model][stat].append(entry[model][stat])
    return organized_data

def plot_all_layers(data, smooth_window=100):
    organized_data = organize_data_by_layer(data)
    num_layers = len(organized_data)
    
    fig, axs = plt.subplots(num_layers, 3, figsize=(20, 8 * num_layers))
    fig.suptitle('Layer Statistics Comparison (Smoothed)', fontsize=16)
    
    if num_layers == 1:
        axs = axs.reshape(1, -1)
    
    colors = {'student': 'blue', 'teacher': 'red'}
    stats = ['mean', 'max', 'min']
    
    for layer, (layer_num, layer_data) in enumerate(sorted(organized_data.items())):
        for col, stat in enumerate(stats):
            ax = axs[layer, col]
            for model in ['student', 'teacher']:
                original_data = layer_data[model][stat]
                
                # Plot original data with low alpha
                ax.plot(original_data, color=colors[model], alpha=0.3, linewidth=1)
                
                # Smooth and plot the data
                if len(original_data) > smooth_window:
                    smoothed_data = smooth_data(original_data, smooth_window)
                    ax.plot(range(smooth_window-1, len(original_data)), smoothed_data, 
                            color=colors[model], label=f'{model.capitalize()} (Smoothed)')
                else:
                    ax.plot(original_data, color=colors[model], label=model.capitalize())
            
            ax.set_title(f'Layer {layer_num} - {stat.capitalize()}', fontsize=12)
            ax.set_xlabel('Steps', fontsize=10)
            ax.set_ylabel('Value', fontsize=10)
            ax.legend(fontsize=8)
            ax.grid(True, linestyle='--', alpha=0.7)
    
    plt.tight_layout()
    plt.savefig('hiddenstates2_context_plot.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

In [12]:
plot_all_layers(trainer.context_loss_stats)