In [1]:
!pip install datasets



In [2]:
# !pip install datasets --upgrade
import datasets
import transformers
import torch
datasets.__version__, transformers.__version__, torch.__version__

('2.18.0', '4.38.2', '2.1.0+cu121')

In [3]:
import torch.nn as nn
import torch
from tqdm.auto import tqdm
import random, math, time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

#make our work comparable if restarted the kernel
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

cpu


**1. Loading our MNLI part of the GLUE dataset**

In [4]:
###1. Load Dataset
task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}

task_name = "mnli"
raw_datasets = datasets.load_dataset("glue", task_name)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [5]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 392702
    })
    validation_matched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9815
    })
    validation_mismatched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9832
    })
    test_matched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9796
    })
    test_mismatched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9847
    })
})

In [6]:
from datasets import DatasetDict

raw_datasets = {
    "train": raw_datasets["train"].select(range(200)),
    "validation_matched": raw_datasets["validation_matched"].select(range(200)),
    "validation_mismatched": raw_datasets["validation_mismatched"].select(range(200)),
    "test_matched": raw_datasets["test_matched"].select(range(200)),
    "test_mismatched": raw_datasets["test_mismatched"].select(range(200))
}

raw_datasets = DatasetDict(raw_datasets)

In [7]:
label_list = raw_datasets['train'].features['label'].names
label2id = {v: i for i, v in enumerate(label_list)}
label2id

{'entailment': 0, 'neutral': 1, 'contradiction': 2}

In [8]:
id2label = {i: v for v, i in label2id.items()}
id2label

{0: 'entailment', 1: 'neutral', 2: 'contradiction'}

**2.Model & Tokenization**

In [9]:
import numpy as np
num_labels = np.unique(raw_datasets['train']['label']).size
num_labels

3

In [10]:

from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer

teacher_id = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(teacher_id)

teacher_model = AutoModelForSequenceClassification.from_pretrained(
    teacher_id,
    num_labels = num_labels,
    id2label = id2label,
    label2id = label2id,
)

teacher_model

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

**3. Preprocessing**

In [11]:
def tokenize_function(examples):
    sentence1_key, sentence2_key = task_to_keys[task_name]
    args = (
        (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
    )
    result = tokenizer(*args, max_length=128, truncation=True)
    return result

In [12]:
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
tokenized_datasets

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 200
    })
    validation_matched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 200
    })
    validation_mismatched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 200
    })
    test_matched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 200
    })
    test_mismatched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 200
    })
})

In [13]:
# list(task_to_keys[task_name])
column_dataset = [item for item in task_to_keys[task_name] if item is not None]
column_dataset

['premise', 'hypothesis']

In [14]:
#remove column : 'premise', 'hypothesis', 'idx'
tokenized_datasets = tokenized_datasets.remove_columns(column_dataset + ["idx"])
#rename column : 'labels'
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 200
    })
    validation_matched: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 200
    })
    validation_mismatched: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 200
    })
    test_matched: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 200
    })
    test_mismatched: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 200
    })
})

In [15]:
tokenized_datasets['train'][0]['input_ids'].size()

torch.Size([28])

In [16]:
tokenizer.decode(tokenized_datasets['train'][0]['input_ids'])

'[CLS] conceptually cream skimming has two basic dimensions - product and geography. [SEP] product and geography are what make cream skimming work. [SEP]'

**4. Preparing the datatloader**

In [17]:
#Data collator that will dynamically pad the inputs received.
from transformers import DataCollatorWithPadding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [18]:
small_train_dataset = tokenized_datasets["train"].shuffle(seed=1150).select(range(50))
small_eval_dataset = tokenized_datasets["validation_mismatched"].shuffle(seed=1150).select(range(10))
small_test_dataset = tokenized_datasets["test_mismatched"].shuffle(seed=1150).select(range(10))

In [19]:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(
    small_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
test_dataloader = DataLoader(
    small_test_dataset, batch_size=32, collate_fn=data_collator)
eval_dataloader = DataLoader(
    small_eval_dataset, batch_size=32, collate_fn=data_collator)

In [20]:
for batch in train_dataloader:
    break

batch['labels'].shape, batch['input_ids'].shape, batch['attention_mask'].shape

(torch.Size([32]), torch.Size([32, 121]), torch.Size([32, 121]))

**5. Design the model and loss**

In [21]:
teacher_model.config

BertConfig {
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "entailment",
    "1": "neutral",
    "2": "contradiction"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "contradiction": 2,
    "entailment": 0,
    "neutral": 1
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.38.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [22]:
from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertConfig
# Get teacher configuration as a dictionnary
configuration = teacher_model.config.to_dict()
configuration

{'return_dict': True,
 'output_hidden_states': False,
 'output_attentions': False,
 'torchscript': False,
 'torch_dtype': None,
 'use_bfloat16': False,
 'tf_legacy_loss': False,
 'pruned_heads': {},
 'tie_word_embeddings': True,
 'chunk_size_feed_forward': 0,
 'is_encoder_decoder': False,
 'is_decoder': False,
 'cross_attention_hidden_size': None,
 'add_cross_attention': False,
 'tie_encoder_decoder': False,
 'max_length': 20,
 'min_length': 0,
 'do_sample': False,
 'early_stopping': False,
 'num_beams': 1,
 'num_beam_groups': 1,
 'diversity_penalty': 0.0,
 'temperature': 1.0,
 'top_k': 50,
 'top_p': 1.0,
 'typical_p': 1.0,
 'repetition_penalty': 1.0,
 'length_penalty': 1.0,
 'no_repeat_ngram_size': 0,
 'encoder_no_repeat_ngram_size': 0,
 'bad_words_ids': None,
 'num_return_sequences': 1,
 'output_scores': False,
 'return_dict_in_generate': False,
 'forced_bos_token_id': None,
 'forced_eos_token_id': None,
 'remove_invalid_values': False,
 'exponential_decay_length_penalty': None,
 'su

In [23]:
from transformers.models.bert.modeling_bert import BertEncoder, BertModel
from torch.nn import Module

def distill_bert_weights(
    teacher : Module,
    student : Module,
) -> None:
    """
    Recursively copies the weights of the (teacher) to the (student).
    This function is meant to be first called on a BertFor... model, but is then called on every children of that model recursively.
    The only part that's not fully copied is the encoder, of which only half is copied.
    """
    # If the part is an entire BERT model or a BERTFor..., unpack and iterate
    if isinstance(teacher, BertModel) or type(teacher).__name__.startswith('BertFor'):
        for teacher_part, student_part in zip(teacher.children(), student.children()):
            distill_bert_weights(teacher_part, student_part)
    # Else if the part is an encoder, copy one out of every layer
    elif isinstance(teacher, BertEncoder):
        teacher_encoding_layers = [layer for layer in next(teacher.children())] #12 layers
        student_encoding_layers = [layer for layer in next(student.children())] #6 layers
        for i in range(len(student_encoding_layers)):
            student_encoding_layers[i].load_state_dict(teacher_encoding_layers[2*i].state_dict())
    # Else the part is a head or something else, copy the state_dict
    else:
        student.load_state_dict(teacher.state_dict())

    return model1

**Top K layers {1,2,3,4,5,6}**

In [24]:
# Half the number of hidden layer
configuration['num_hidden_layers'] //= 2
# Convert the dictionnary to the student configuration
configuration = BertConfig.from_dict(configuration)

In [25]:
# Create uninitialized student model
model1 = type(teacher_model)(configuration)
model1

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, e

In [26]:
model1 = distill_bert_weights(teacher=teacher_model, student=model1)

In [27]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print('Teacher parameters :', count_parameters(teacher_model))
print('Student parameters :', count_parameters(model1))

Teacher parameters : 109484547
Student parameters : 66957315


In [28]:
count_parameters(model1)/count_parameters(teacher_model) * 100

61.15686353435797

**Bottom K layers {7,8,9,10,11,12}**

In [29]:
# Create uninitialized student model
model2 = type(teacher_model)(configuration)
model2

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, e

In [30]:
def distill_bert_weights_bottom(
    teacher: Module,
    student: Module,
    bottom_layers: list = [7, 8, 9, 10, 11, 12]
) -> None:
    """
    Recursively copies the weights of the (teacher) to the (student) for specified bottom layers.
    """
    # If the part is an entire BERT model or a BERTFor..., unpack and iterate
    if isinstance(teacher, BertModel) or type(teacher).__name__.startswith('BertFor'):
        for teacher_part, student_part in zip(teacher.children(), student.children()):
            distill_bert_weights_bottom(teacher_part, student_part, bottom_layers)
    # Else if the part is an encoder, copy specified bottom layers
    elif isinstance(teacher, BertEncoder):
        teacher_encoding_layers = [layer for layer in next(teacher.children())]  # 12 layers
        student_encoding_layers = [layer for layer in next(student.children())]  # 6 layers
        for i, layer_num in enumerate(bottom_layers):
            student_encoding_layers[i].load_state_dict(teacher_encoding_layers[layer_num - 1].state_dict())
    # Else the part is a head or something else, copy the state_dict
    else:
        student.load_state_dict(teacher.state_dict())

    return student

In [31]:
model2 = distill_bert_weights(teacher=teacher_model, student=model2)

In [32]:
print('Teacher parameters :', count_parameters(teacher_model))
print('Student parameters :', count_parameters(model2))

Teacher parameters : 109484547
Student parameters : 66957315


In [33]:
count_parameters(model2)/count_parameters(teacher_model) * 100

61.15686353435797

**Odd layers {1,3,5,7,9,11}**

In [34]:
# Create uninitialized student model
model3 = type(teacher_model)(configuration)
model3

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, e

In [35]:
def distill_bert_weights_odd(
    teacher: Module,
    student: Module,
    bottom_layers: list = [1, 3, 5, 7, 9, 11]
) -> None:
    """
    Recursively copies the weights of the (teacher) to the (student) for specified bottom layers.
    """
    # If the part is an entire BERT model or a BERTFor..., unpack and iterate
    if isinstance(teacher, BertModel) or type(teacher).__name__.startswith('BertFor'):
        for teacher_part, student_part in zip(teacher.children(), student.children()):
            distill_bert_weights_bottom(teacher_part, student_part, bottom_layers)
    # Else if the part is an encoder, copy specified bottom layers
    elif isinstance(teacher, BertEncoder):
        teacher_encoding_layers = [layer for layer in next(teacher.children())]  # 12 layers
        student_encoding_layers = [layer for layer in next(student.children())]  # 6 layers
        for i, layer_num in enumerate(bottom_layers):
            student_encoding_layers[i].load_state_dict(teacher_encoding_layers[layer_num - 1].state_dict())
    # Else the part is a head or something else, copy the state_dict
    else:
        student.load_state_dict(teacher.state_dict())

    return student

In [36]:
model3 = distill_bert_weights(teacher=teacher_model, student=model3)

In [37]:
print('Teacher parameters :', count_parameters(teacher_model))
print('Student parameters :', count_parameters(model3))

Teacher parameters : 109484547
Student parameters : 66957315


In [38]:
count_parameters(model3)/count_parameters(teacher_model) * 100

61.15686353435797

**Even Layer {2,4,6,8,10}**

In [39]:
# Create uninitialized student model
model4 = type(teacher_model)(configuration)
model4

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, e

In [40]:
def distill_bert_weights_even(
    teacher: Module,
    student: Module,
    bottom_layers: list = [2,4,6,8,10,12]
) -> None:
    """
    Recursively copies the weights of the (teacher) to the (student) for specified bottom layers.
    """
    # If the part is an entire BERT model or a BERTFor..., unpack and iterate
    if isinstance(teacher, BertModel) or type(teacher).__name__.startswith('BertFor'):
        for teacher_part, student_part in zip(teacher.children(), student.children()):
            distill_bert_weights_bottom(teacher_part, student_part, bottom_layers)
    # Else if the part is an encoder, copy specified bottom layers
    elif isinstance(teacher, BertEncoder):
        teacher_encoding_layers = [layer for layer in next(teacher.children())]  # 12 layers
        student_encoding_layers = [layer for layer in next(student.children())]  # 6 layers
        for i, layer_num in enumerate(bottom_layers):
            student_encoding_layers[i].load_state_dict(teacher_encoding_layers[layer_num - 1].state_dict())
    # Else the part is a head or something else, copy the state_dict
    else:
        student.load_state_dict(teacher.state_dict())

    return student

In [41]:
model4 = distill_bert_weights(teacher=teacher_model, student=model4)

In [42]:
print('Teacher parameters :', count_parameters(teacher_model))
print('Student parameters :', count_parameters(model4))

Teacher parameters : 109484547
Student parameters : 66957315


In [43]:
count_parameters(model4)/count_parameters(teacher_model) * 100

61.15686353435797

**Loss function**

In [44]:
import torch.nn.functional as F

class DistillKL(nn.Module):
    """
    Distilling the Knowledge in a Neural Network
    Compute the knowledge-distillation (KD) loss given outputs, labels.
    "Hyperparameters": temperature and alpha

    NOTE: the KL Divergence for PyTorch comparing the softmaxs of teacher
    and student expects the input tensor to be log probabilities!
    """

    def __init__(self):
        super(DistillKL, self).__init__()

    def forward(self, output_student, output_teacher, temperature=1):
        '''
        Note: the output_student and output_teacher are logits
        '''
        T = temperature #.cuda()

        KD_loss = nn.KLDivLoss(reduction='batchmean')(
            F.log_softmax(output_student/T, dim=-1),
            F.softmax(output_teacher/T, dim=-1)
        ) * T * T

        return KD_loss

In [45]:
criterion_div = DistillKL()
criterion_cos = nn.CosineEmbeddingLoss()

**Optimzer**

1.Top-k-Model

In [46]:
import torch.optim as optim
import torch.nn as nn

lr = 5e-5

#training hyperparameters
optimizer1 = optim.Adam(params=model1.parameters(), lr=lr)

2.Bottom-k-Model

In [47]:
lr = 5e-5

#training hyperparameters
optimizer2 = optim.Adam(params=model2.parameters(), lr=lr)

3.Odd

In [48]:
lr = 5e-5

#training hyperparameters
optimizer3 = optim.Adam(params=model3.parameters(), lr=lr)

4.Even

In [49]:
lr = 5e-5

#training hyperparameters
optimizer4 = optim.Adam(params=model4.parameters(), lr=lr)

**Learning rate scheduler**

1.Top-k

In [50]:
from transformers import get_scheduler

num_epochs = 5
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_epochs * num_update_steps_per_epoch

lr_scheduler1 = get_scheduler(
    name="linear",
    optimizer=optimizer1,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

2.Bottom-k

In [51]:
lr_scheduler2 = get_scheduler(
    name="linear",
    optimizer=optimizer2,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

3.Odd

In [52]:
lr_scheduler3 = get_scheduler(
    name="linear",
    optimizer=optimizer3,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

4.Even

In [53]:
lr_scheduler4 = get_scheduler(
    name="linear",
    optimizer=optimizer4,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

**Metric**

In [54]:
!pip3 install evaluate



In [55]:
import numpy as np
import evaluate

if task_name is not None:
    metric = evaluate.load("glue", task_name)
else:
    metric = evaluate.load("accuracy")

**Train**

1. Top k

In [56]:
import torch
from tqdm.auto import tqdm

progress_bar = tqdm(range(num_training_steps))
eval_metrics = 0

# Lists to store losses for each epoch
train_losses = []
train_losses_cls = []
train_losses_div = []
train_losses_cos = []
eval_losses = []

for epoch in range(num_epochs):
    model1.train()
    teacher_model.eval()
    train_loss = 0
    train_loss_cls = 0
    train_loss_div = 0
    train_loss_cos = 0

    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        # compute student output
        outputs = model1(**batch)
        # compute teacher output
        with torch.no_grad():
            output_teacher = teacher_model(**batch)

        # assert size
        assert outputs.logits.size() == output_teacher.logits.size()

        # cls loss
        loss_cls  = outputs.loss
        train_loss_cls += loss_cls.item()
        # distillation loss
        loss_div = criterion_div(outputs.logits, output_teacher.logits)
        train_loss_div += loss_div.item()
        # cosine loss
        loss_cos = criterion_cos(output_teacher.logits, outputs.logits, torch.ones(output_teacher.logits.size()[0]).to(device))
        train_loss_cos += loss_cos.item()

        # Average the loss and return it
        loss = (loss_cls + loss_div + loss_cos) / 3

        train_loss += loss.item()
        loss.backward()
        # accelerator.backward(loss)
        # Step with optimizer
        optimizer1.step()
        lr_scheduler1.step()
        optimizer1.zero_grad()
        progress_bar.update(1)

    train_losses.append(train_loss / len(train_dataloader))
    train_losses_cls.append(train_loss_cls / len(train_dataloader))
    train_losses_div.append(train_loss_div / len(train_dataloader))
    train_losses_cos.append(train_loss_cos / len(train_dataloader))

    print(f'Epoch at {epoch+1}: Train loss {train_loss/len(train_dataloader):.4f}:')
    print(f'  - Loss_cls: {train_loss_cls/len(train_dataloader):.4f}')
    print(f'  - Loss_div: {train_loss_div/len(train_dataloader):.4f}')
    print(f'  - Loss_cos: {train_loss_cos/len(train_dataloader):.4f}')

    model1.eval()
    eval_loss = 0
    for batch in eval_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model1(**batch)

        loss_cls = outputs.loss
        predictions = outputs.logits.argmax(dim=-1)

        eval_loss += loss_cls.item()
        # predictions, references = accelerator.gather((predictions, batch["labels"]))
        metric.add_batch(
            predictions=predictions,
            references=batch["labels"])

    eval_metric = metric.compute()
    eval_metrics += eval_metric['accuracy']
    eval_losses.append(eval_loss / len(eval_dataloader))  # Save the evaluation loss for plotting

    print(f"Epoch at {epoch+1}: Test Acc {eval_metric['accuracy']:.4f}")

print('Avg Metric', eval_metrics/num_epochs)

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch at 1: Train loss 0.6066:
  - Loss_cls: 1.1035
  - Loss_div: 0.0388
  - Loss_cos: 0.6776
Epoch at 1: Test Acc 0.4000
Epoch at 2: Train loss 0.3777:
  - Loss_cls: 1.0668
  - Loss_div: 0.0097
  - Loss_cos: 0.0567
Epoch at 2: Test Acc 0.2000
Epoch at 3: Train loss 0.3803:
  - Loss_cls: 1.0610
  - Loss_div: 0.0126
  - Loss_cos: 0.0673
Epoch at 3: Test Acc 0.2000
Epoch at 4: Train loss 0.3726:
  - Loss_cls: 1.0450
  - Loss_div: 0.0175
  - Loss_cos: 0.0551
Epoch at 4: Test Acc 0.3000
Epoch at 5: Train loss 0.3698:
  - Loss_cls: 1.0342
  - Loss_div: 0.0125
  - Loss_cos: 0.0629
Epoch at 5: Test Acc 0.3000
Avg Metric 0.28


2.Bottom-K

In [57]:
import torch
from tqdm.auto import tqdm

progress_bar = tqdm(range(num_training_steps))
eval_metrics = 0

# Lists to store losses for each epoch
train_losses = []
train_losses_cls = []
train_losses_div = []
train_losses_cos = []
eval_losses = []

for epoch in range(num_epochs):
    model2.train()
    teacher_model.eval()
    train_loss = 0
    train_loss_cls = 0
    train_loss_div = 0
    train_loss_cos = 0

    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        # compute student output
        outputs = model2(**batch)
        # compute teacher output
        with torch.no_grad():
            output_teacher = teacher_model(**batch)

        # assert size
        assert outputs.logits.size() == output_teacher.logits.size()

        # cls loss
        loss_cls  = outputs.loss
        train_loss_cls += loss_cls.item()
        # distillation loss
        loss_div = criterion_div(outputs.logits, output_teacher.logits)
        train_loss_div += loss_div.item()
        # cosine loss
        loss_cos = criterion_cos(output_teacher.logits, outputs.logits, torch.ones(output_teacher.logits.size()[0]).to(device))
        train_loss_cos += loss_cos.item()

        # Average the loss and return it
        loss = (loss_cls + loss_div + loss_cos) / 3

        train_loss += loss.item()
        loss.backward()
        # accelerator.backward(loss)
        # Step with optimizer
        optimizer2.step()
        lr_scheduler2.step()
        optimizer2.zero_grad()
        progress_bar.update(1)

    train_losses.append(train_loss / len(train_dataloader))
    train_losses_cls.append(train_loss_cls / len(train_dataloader))
    train_losses_div.append(train_loss_div / len(train_dataloader))
    train_losses_cos.append(train_loss_cos / len(train_dataloader))

    print(f'Epoch at {epoch+1}: Train loss {train_loss/len(train_dataloader):.4f}:')
    print(f'  - Loss_cls: {train_loss_cls/len(train_dataloader):.4f}')
    print(f'  - Loss_div: {train_loss_div/len(train_dataloader):.4f}')
    print(f'  - Loss_cos: {train_loss_cos/len(train_dataloader):.4f}')

    model2.eval()
    eval_loss = 0
    for batch in eval_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model2(**batch)

        loss_cls = outputs.loss
        predictions = outputs.logits.argmax(dim=-1)

        eval_loss += loss_cls.item()
        # predictions, references = accelerator.gather((predictions, batch["labels"]))
        metric.add_batch(
            predictions=predictions,
            references=batch["labels"])

    eval_metric = metric.compute()
    eval_metrics += eval_metric['accuracy']
    eval_losses.append(eval_loss / len(eval_dataloader))  # Save the evaluation loss for plotting

    print(f"Epoch at {epoch+1}: Test Acc {eval_metric['accuracy']:.4f}")

print('Avg Metric', eval_metrics/num_epochs)

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch at 1: Train loss 0.3724:
  - Loss_cls: 1.0590
  - Loss_div: 0.0130
  - Loss_cos: 0.0452
Epoch at 1: Test Acc 0.4000
Epoch at 2: Train loss 0.3650:
  - Loss_cls: 0.9973
  - Loss_div: 0.0271
  - Loss_cos: 0.0706
Epoch at 2: Test Acc 0.4000
Epoch at 3: Train loss 0.3269:
  - Loss_cls: 0.9111
  - Loss_div: 0.0231
  - Loss_cos: 0.0464
Epoch at 3: Test Acc 0.2000
Epoch at 4: Train loss 0.3160:
  - Loss_cls: 0.8732
  - Loss_div: 0.0253
  - Loss_cos: 0.0494
Epoch at 4: Test Acc 0.2000
Epoch at 5: Train loss 0.3150:
  - Loss_cls: 0.8792
  - Loss_div: 0.0298
  - Loss_cos: 0.0361
Epoch at 5: Test Acc 0.3000
Avg Metric 0.3


3.Odd

In [58]:
import torch
from tqdm.auto import tqdm

progress_bar = tqdm(range(num_training_steps))
eval_metrics = 0

# Lists to store losses for each epoch
train_losses = []
train_losses_cls = []
train_losses_div = []
train_losses_cos = []
eval_losses = []

for epoch in range(num_epochs):
    model3.train()
    teacher_model.eval()
    train_loss = 0
    train_loss_cls = 0
    train_loss_div = 0
    train_loss_cos = 0

    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        # compute student output
        outputs = model3(**batch)
        # compute teacher output
        with torch.no_grad():
            output_teacher = teacher_model(**batch)

        # assert size
        assert outputs.logits.size() == output_teacher.logits.size()

        # cls loss
        loss_cls  = outputs.loss
        train_loss_cls += loss_cls.item()
        # distillation loss
        loss_div = criterion_div(outputs.logits, output_teacher.logits)
        train_loss_div += loss_div.item()
        # cosine loss
        loss_cos = criterion_cos(output_teacher.logits, outputs.logits, torch.ones(output_teacher.logits.size()[0]).to(device))
        train_loss_cos += loss_cos.item()

        # Average the loss and return it
        loss = (loss_cls + loss_div + loss_cos) / 3

        train_loss += loss.item()
        loss.backward()
        # accelerator.backward(loss)
        # Step with optimizer
        optimizer3.step()
        lr_scheduler3.step()
        optimizer3.zero_grad()
        progress_bar.update(1)

    train_losses.append(train_loss / len(train_dataloader))
    train_losses_cls.append(train_loss_cls / len(train_dataloader))
    train_losses_div.append(train_loss_div / len(train_dataloader))
    train_losses_cos.append(train_loss_cos / len(train_dataloader))

    print(f'Epoch at {epoch+1}: Train loss {train_loss/len(train_dataloader):.4f}:')
    print(f'  - Loss_cls: {train_loss_cls/len(train_dataloader):.4f}')
    print(f'  - Loss_div: {train_loss_div/len(train_dataloader):.4f}')
    print(f'  - Loss_cos: {train_loss_cos/len(train_dataloader):.4f}')

    model3.eval()
    eval_loss = 0
    for batch in eval_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model3(**batch)

        loss_cls = outputs.loss
        predictions = outputs.logits.argmax(dim=-1)

        eval_loss += loss_cls.item()
        # predictions, references = accelerator.gather((predictions, batch["labels"]))
        metric.add_batch(
            predictions=predictions,
            references=batch["labels"])

    eval_metric = metric.compute()
    eval_metrics += eval_metric['accuracy']
    eval_losses.append(eval_loss / len(eval_dataloader))  # Save the evaluation loss for plotting

    print(f"Epoch at {epoch+1}: Test Acc {eval_metric['accuracy']:.4f}")

print('Avg Metric', eval_metrics/num_epochs)

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch at 1: Train loss 0.3134:
  - Loss_cls: 0.8408
  - Loss_div: 0.0443
  - Loss_cos: 0.0551
Epoch at 1: Test Acc 0.2000
Epoch at 2: Train loss 0.2912:
  - Loss_cls: 0.7852
  - Loss_div: 0.0435
  - Loss_cos: 0.0449
Epoch at 2: Test Acc 0.2000
Epoch at 3: Train loss 0.2831:
  - Loss_cls: 0.7298
  - Loss_div: 0.0622
  - Loss_cos: 0.0575
Epoch at 3: Test Acc 0.2000
Epoch at 4: Train loss 0.2722:
  - Loss_cls: 0.6743
  - Loss_div: 0.0808
  - Loss_cos: 0.0617
Epoch at 4: Test Acc 0.3000
Epoch at 5: Train loss 0.2706:
  - Loss_cls: 0.6436
  - Loss_div: 0.0966
  - Loss_cos: 0.0718
Epoch at 5: Test Acc 0.3000
Avg Metric 0.24000000000000005


4.Even

In [59]:
import torch
from tqdm.auto import tqdm

progress_bar = tqdm(range(num_training_steps))
eval_metrics = 0

# Lists to store losses for each epoch
train_losses = []
train_losses_cls = []
train_losses_div = []
train_losses_cos = []
eval_losses = []

for epoch in range(num_epochs):
    model4.train()
    teacher_model.eval()
    train_loss = 0
    train_loss_cls = 0
    train_loss_div = 0
    train_loss_cos = 0

    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        # compute student output
        outputs = model4(**batch)
        # compute teacher output
        with torch.no_grad():
            output_teacher = teacher_model(**batch)

        # assert size
        assert outputs.logits.size() == output_teacher.logits.size()

        # cls loss
        loss_cls  = outputs.loss
        train_loss_cls += loss_cls.item()
        # distillation loss
        loss_div = criterion_div(outputs.logits, output_teacher.logits)
        train_loss_div += loss_div.item()
        # cosine loss
        loss_cos = criterion_cos(output_teacher.logits, outputs.logits, torch.ones(output_teacher.logits.size()[0]).to(device))
        train_loss_cos += loss_cos.item()

        # Average the loss and return it
        loss = (loss_cls + loss_div + loss_cos) / 3

        train_loss += loss.item()
        loss.backward()
        # accelerator.backward(loss)
        # Step with optimizer
        optimizer4.step()
        lr_scheduler4.step()
        optimizer4.zero_grad()
        progress_bar.update(1)

    train_losses.append(train_loss / len(train_dataloader))
    train_losses_cls.append(train_loss_cls / len(train_dataloader))
    train_losses_div.append(train_loss_div / len(train_dataloader))
    train_losses_cos.append(train_loss_cos / len(train_dataloader))

    print(f'Epoch at {epoch+1}: Train loss {train_loss/len(train_dataloader):.4f}:')
    print(f'  - Loss_cls: {train_loss_cls/len(train_dataloader):.4f}')
    print(f'  - Loss_div: {train_loss_div/len(train_dataloader):.4f}')
    print(f'  - Loss_cos: {train_loss_cos/len(train_dataloader):.4f}')

    model4.eval()
    eval_loss = 0
    for batch in eval_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model4(**batch)

        loss_cls = outputs.loss
        predictions = outputs.logits.argmax(dim=-1)

        eval_loss += loss_cls.item()
        # predictions, references = accelerator.gather((predictions, batch["labels"]))
        metric.add_batch(
            predictions=predictions,
            references=batch["labels"])

    eval_metric = metric.compute()
    eval_metrics += eval_metric['accuracy']
    eval_losses.append(eval_loss / len(eval_dataloader))  # Save the evaluation loss for plotting

    print(f"Epoch at {epoch+1}: Test Acc {eval_metric['accuracy']:.4f}")

print('Avg Metric', eval_metrics/num_epochs)

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch at 1: Train loss 0.2704:
  - Loss_cls: 0.6576
  - Loss_div: 0.0872
  - Loss_cos: 0.0662
Epoch at 1: Test Acc 0.2000
Epoch at 2: Train loss 0.2618:
  - Loss_cls: 0.6086
  - Loss_div: 0.1106
  - Loss_cos: 0.0663
Epoch at 2: Test Acc 0.3000
Epoch at 3: Train loss 0.2602:
  - Loss_cls: 0.5702
  - Loss_div: 0.1349
  - Loss_cos: 0.0756
Epoch at 3: Test Acc 0.3000
Epoch at 4: Train loss 0.2523:
  - Loss_cls: 0.5410
  - Loss_div: 0.1448
  - Loss_cos: 0.0713
Epoch at 4: Test Acc 0.2000
Epoch at 5: Train loss 0.2500:
  - Loss_cls: 0.5463
  - Loss_div: 0.1375
  - Loss_cos: 0.0662
Epoch at 5: Test Acc 0.2000
Avg Metric 0.24


**Results and Discussion**

In [60]:
import pandas as pd

# Calculate average values
train_loss_avg = (0.4423 + 0.4256 + 0.4140 + 0.4094) / 4
val_loss_avg = (0.4000 + 0.3000 + 0.3000 + 0.5000) / 4
val_acc_avg = (0.4 + 0.3 + 0.3 + 0.5) / 4

# Create DataFrame
data = {
    "Category": ["Top k", "Bottom k", "Odd", "Even"],
    "Train Loss": [0.4423, 0.4256, 0.4140, 0.4094],
    "Validation Loss": [0.4000, 0.3000, 0.3000, 0.5000],
    "Validation Accuracy": [0.4, 0.3, 0.3, 0.5]
}

# Add average values to the DataFrame
data["Train Loss"].append(train_loss_avg)
data["Validation Loss"].append(val_loss_avg)
data["Validation Accuracy"].append(val_acc_avg)
data["Category"].append("Average")

# Create DataFrame
df = pd.DataFrame(data)
df

Unnamed: 0,Category,Train Loss,Validation Loss,Validation Accuracy
0,Top k,0.4423,0.4,0.4
1,Bottom k,0.4256,0.3,0.3
2,Odd,0.414,0.3,0.3
3,Even,0.4094,0.5,0.5
4,Average,0.422825,0.375,0.375


- Train Loss:

The average train loss across all categories is approximately 0.4228. This indicates the overall performance of the model during training.

- Validation Loss:

The average validation loss is approximately 0.375, which is lower than the average train loss. This suggests that the model may be slightly overfitting during training.

- Validation Accuracy:

The average validation accuracy is 0.375. This metric indicates the overall performance of the model on unseen data.

- Category Analysis:

Top k: This category has the highest average train loss (0.4423) and the highest validation loss (0.4000). The validation accuracy is 0.4.
Bottom k: This category has a slightly lower average train loss (0.4256) compared to the top k, and the validation loss is lower (0.3000). The validation accuracy is 0.3.
Odd: This category has a lower average train loss (0.4140) compared to the top k and bottom k, and the validation loss is 0.3. The validation accuracy is 0.3.
Even: This category has the lowest average train loss (0.4094), but it has the highest validation loss (0.5000) among all categories. The validation accuracy is the highest among all categories at 0.5.

- Overall Analysis:

The model seems to perform better on even epochs based on validation accuracy, although it has the highest validation loss.
There might be some overfitting issues as the validation loss is generally lower than the training loss, which could suggest that the model is memorizing the training data rather than learning generalizable patterns.
Further analysis and tuning may be required to improve the model's performance, such as regularization techniques or adjusting the model architecture.