In [1]:
import torch

In [2]:
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)


2024-11-27 14:38:28.243050: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-11-27 14:38:28.260384: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1732718308.281733   96272 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1732718308.290186   96272 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-27 14:38:28.313085: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [3]:
import evaluate

In [4]:
import numpy as np

In [5]:
from transformers import BertConfig

In [6]:
def create_model():
    num_hidden_layers = 8
    hidden_size = 256
    num_attention_heads = 4

    model = AutoModelForSequenceClassification.from_pretrained(
        f"bert-uncased_L-{num_hidden_layers}_H-{hidden_size}_A-{num_attention_heads}",
        config=BertConfig(
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            intermediate_size=1024,
            num_labels=2,
        )
    )
    model = model.to("cuda")
    return model

In [7]:
tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")

In [8]:
from datasets import load_dataset

In [9]:
sst2 = load_dataset("stanfordnlp/sst2")

In [10]:
def tokenize_sst2(batch):
    return tokenizer(batch["sentence"], truncation=True, max_length=512)

In [11]:
tokenized_sst2 = sst2.map(tokenize_sst2, batched=True)

In [12]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [13]:
from torch import nn

In [14]:
import torch.nn.functional as F

In [15]:
from dataclasses import dataclass

In [16]:
@dataclass
class PretrainedDistillationTrainingArguments(TrainingArguments):
    temperature: float = 1.0

In [17]:
class PretrainedDistillationTrainer(Trainer):
    def __init__(self, *args, teacher_model, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model
        self.kl_loss = nn.KLDivLoss(reduction="batchmean")

    def compute_loss(self, model, inputs, return_outputs=False):
        student_outputs = model(**inputs)
        student_logits = student_outputs.logits
        with torch.no_grad():
            teacher_outputs = self.teacher_model(**inputs)
            teacher_logits = teacher_outputs.logits
        loss = self.args.temperature**2.0 * (
            self.kl_loss(
                F.log_softmax(student_logits / self.args.temperature, dim=-1),
                F.softmax(teacher_logits / self.args.temperature, dim=-1),
            )
        )
        return (loss, student_outputs) if return_outputs else loss

In [18]:
accuracy = evaluate.load("accuracy")

def compute_metrics(eval_prediction):
    predictions, labels = eval_prediction
    return accuracy.compute(
        predictions=np.argmax(predictions, axis=1),
        references=labels,
    )

In [20]:
model = create_model()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-uncased_L-8_H-256_A-4 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    "google/mobilebert-uncased",
    num_labels=2,
)

Some weights of MobileBertForSequenceClassification were not initialized from the model checkpoint at google/mobilebert-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [22]:
teacher_model.load_state_dict(torch.load("sst2-mobilebert-base.pt", weights_only=True))

<All keys matched successfully>

In [23]:
teacher_model = teacher_model.to("cuda")

In [54]:
distillation_training_args = PretrainedDistillationTrainingArguments(
    output_dir=f"tmp/sst2-distillation",
    learning_rate=2e-5,
    warmup_ratio=0.1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    push_to_hub=False,
    temperature=10.0
)

In [55]:
distillation_trainer = PretrainedDistillationTrainer(
    model=model,
    teacher_model=teacher_model,
    args=distillation_training_args,
    train_dataset=tokenized_sst2["train"],
    eval_dataset=tokenized_sst2["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [56]:
distillation_trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.0108,0.023131,0.813073
2,0.0088,0.023418,0.818807
3,0.0073,0.023603,0.822248
4,0.0074,0.023367,0.818807
5,0.0064,0.023607,0.817661


TrainOutput(global_step=21050, training_loss=0.008899580046286775, metrics={'train_runtime': 1046.4749, 'train_samples_per_second': 321.79, 'train_steps_per_second': 20.115, 'total_flos': 454224414765732.0, 'train_loss': 0.008899580046286775, 'epoch': 5.0})

In [None]:
del model
model = create_model()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-uncased_L-8_H-256_A-4 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [24]:
distillation_training_args = PretrainedDistillationTrainingArguments(
    output_dir=f"tmp/sst2-distillation",
    learning_rate=2e-5,
    warmup_ratio=0.1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    push_to_hub=False,
    temperature=2.0
)

In [25]:
distillation_trainer = PretrainedDistillationTrainer(
    model=model,
    teacher_model=teacher_model,
    args=distillation_training_args,
    train_dataset=tokenized_sst2["train"],
    eval_dataset=tokenized_sst2["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [26]:
distillation_trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.4755,0.444152,0.868119
2,0.2373,0.362129,0.883028
3,0.1505,0.340881,0.885321
4,0.126,0.350254,0.887615
5,0.0965,0.320552,0.897936
6,0.081,0.305657,0.897936
7,0.069,0.288588,0.896789
8,0.0612,0.28369,0.902523
9,0.0549,0.271675,0.902523
10,0.0509,0.273902,0.899083


TrainOutput(global_step=42100, training_loss=0.18526952814885952, metrics={'train_runtime': 2525.0718, 'train_samples_per_second': 266.721, 'train_steps_per_second': 16.673, 'total_flos': 907846567642716.0, 'train_loss': 0.18526952814885952, 'epoch': 10.0})

In [66]:
del model
model = create_model()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-uncased_L-8_H-256_A-4 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [67]:
distillation_training_args = PretrainedDistillationTrainingArguments(
    output_dir=f"tmp/sst2-distillation",
    learning_rate=2e-5,
    warmup_ratio=0.1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    push_to_hub=False,
    temperature=2.0
)

In [68]:
distillation_trainer = PretrainedDistillationTrainer(
    model=model,
    teacher_model=teacher_model,
    args=distillation_training_args,
    train_dataset=tokenized_sst2["train"],
    eval_dataset=tokenized_sst2["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [69]:
distillation_trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.2016,0.185791,0.858945
2,0.1362,0.18907,0.862385
3,0.0869,0.164493,0.880734
4,0.079,0.178829,0.87844
5,0.0601,0.172882,0.888761


TrainOutput(global_step=21050, training_loss=0.13676211759200288, metrics={'train_runtime': 1048.7449, 'train_samples_per_second': 321.093, 'train_steps_per_second': 20.072, 'total_flos': 454224414765732.0, 'train_loss': 0.13676211759200288, 'epoch': 5.0})

In [20]:
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=2,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
from lora import wrap_bert_model_with_lora

In [22]:
teacher_model = wrap_bert_model_with_lora(teacher_model, rank=8, alpha=8)

In [23]:
teacher_model.load_state_dict(torch.load("sst2-lora.pt", weights_only=True))

<All keys matched successfully>

In [25]:
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=2,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [26]:
teacher_model.load_state_dict(torch.load("sst2-base.pt", weights_only=True))

<All keys matched successfully>

In [27]:
teacher_model = teacher_model.to("cuda")

In [25]:
del model
model = create_model()

NameError: name 'model' is not defined

In [28]:
distillation_training_args = PretrainedDistillationTrainingArguments(
    output_dir=f"tmp/sst2-distillation",
    learning_rate=2e-5,
    warmup_ratio=0.1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    push_to_hub=False,
    temperature=2.0,
)

In [29]:
distillation_trainer = PretrainedDistillationTrainer(
    model=model,
    teacher_model=teacher_model,
    args=distillation_training_args,
    train_dataset=tokenized_sst2["train"],
    eval_dataset=tokenized_sst2["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [35]:
distillation_trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.3396,0.373598,0.873853
2,0.2015,0.343292,0.876147
3,0.1326,0.342818,0.888761
4,0.1126,0.319433,0.893349
5,0.0844,0.299705,0.889908


TrainOutput(global_step=21050, training_loss=0.23170526409375697, metrics={'train_runtime': 1097.813, 'train_samples_per_second': 306.742, 'train_steps_per_second': 19.174, 'total_flos': 454224414765732.0, 'train_loss': 0.23170526409375697, 'epoch': 5.0})

In [41]:
del model
model = create_model()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-uncased_L-8_H-256_A-4 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [30]:
distillation_training_args = PretrainedDistillationTrainingArguments(
    output_dir=f"tmp/sst2-distillation",
    learning_rate=2e-5,
    warmup_ratio=0.1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    push_to_hub=False,
    temperature=2.0,
)

In [31]:
distillation_trainer = PretrainedDistillationTrainer(
    model=model,
    teacher_model=teacher_model,
    args=distillation_training_args,
    train_dataset=tokenized_sst2["train"],
    eval_dataset=tokenized_sst2["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [32]:
distillation_trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.4054,0.410278,0.860092
2,0.2282,0.339943,0.888761
3,0.1438,0.340901,0.893349
4,0.1115,0.345065,0.879587
5,0.0797,0.346559,0.879587
6,0.0677,0.293289,0.895642
7,0.0532,0.287055,0.893349
8,0.0522,0.281259,0.897936
9,0.0456,0.280736,0.895642
10,0.0405,0.273102,0.896789


TrainOutput(global_step=42100, training_loss=0.15657943215902514, metrics={'train_runtime': 2195.1596, 'train_samples_per_second': 306.807, 'train_steps_per_second': 19.179, 'total_flos': 907846567642716.0, 'train_loss': 0.15657943215902514, 'epoch': 10.0})

In [37]:
del model
model = create_model()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-uncased_L-8_H-256_A-4 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [38]:
distillation_training_args = PretrainedDistillationTrainingArguments(
    output_dir=f"tmp/sst2-distillation",
    learning_rate=2e-5,
    warmup_ratio=0.1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=20,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    push_to_hub=False,
    temperature=2.0,
)

In [39]:
distillation_trainer = PretrainedDistillationTrainer(
    model=model,
    teacher_model=teacher_model,
    args=distillation_training_args,
    train_dataset=tokenized_sst2["train"],
    eval_dataset=tokenized_sst2["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [40]:
distillation_trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.4836,0.527755,0.830275
2,0.2877,0.401884,0.861239
3,0.1805,0.334275,0.885321
4,0.1355,0.328135,0.886468
5,0.0921,0.311932,0.891055
6,0.0789,0.303881,0.888761
7,0.0637,0.30279,0.886468
8,0.0549,0.29501,0.896789
9,0.0461,0.293332,0.896789
10,0.0395,0.310456,0.895642


TrainOutput(global_step=84200, training_loss=0.10467038980289196, metrics={'train_runtime': 4416.4305, 'train_samples_per_second': 304.993, 'train_steps_per_second': 19.065, 'total_flos': 1815796264157928.0, 'train_loss': 0.10467038980289196, 'epoch': 20.0})

In [28]:
sum([param.numel() for param in model.parameters() if param.requires_grad])

14330114

In [23]:
training_args = TrainingArguments(
    output_dir=f"tmp/sst2-compact",
    learning_rate=2e-5,
    warmup_ratio=0.1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    push_to_hub=False,
)

In [24]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_sst2["train"],
    eval_dataset=tokenized_sst2["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [25]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.3041,0.353995,0.864679
2,0.2346,0.372156,0.875
3,0.1751,0.414538,0.884174
4,0.1631,0.455341,0.877294
5,0.1487,0.491209,0.875


TrainOutput(global_step=21050, training_loss=0.2303163706265266, metrics={'train_runtime': 543.0132, 'train_samples_per_second': 620.141, 'train_steps_per_second': 38.765, 'total_flos': 454224414765732.0, 'train_loss': 0.2303163706265266, 'epoch': 5.0})

In [31]:
del model
model = create_model()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-uncased_L-8_H-256_A-4 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [32]:
training_args = TrainingArguments(
    output_dir=f"tmp/sst2-compact",
    learning_rate=2e-5,
    warmup_ratio=0.1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    push_to_hub=False,
)

In [33]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_sst2["train"],
    eval_dataset=tokenized_sst2["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [34]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.3358,0.377583,0.84289
2,0.2524,0.377797,0.870413
3,0.1837,0.43905,0.87844
4,0.1757,0.471805,0.873853
5,0.1489,0.511784,0.872706
6,0.1273,0.548129,0.873853
7,0.1143,0.542696,0.875
8,0.1063,0.577406,0.872706
9,0.0965,0.601633,0.884174
10,0.084,0.621975,0.885321


TrainOutput(global_step=42100, training_loss=0.17282748371858211, metrics={'train_runtime': 1113.7587, 'train_samples_per_second': 604.7, 'train_steps_per_second': 37.8, 'total_flos': 907846567642716.0, 'train_loss': 0.17282748371858211, 'epoch': 10.0})

In [39]:
del model
model = create_model()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-uncased_L-8_H-256_A-4 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [40]:
training_args = TrainingArguments(
    output_dir=f"tmp/sst2-compact",
    learning_rate=2e-6,
    warmup_ratio=0.1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=30,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    push_to_hub=False,
)

In [41]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_sst2["train"],
    eval_dataset=tokenized_sst2["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [42]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.6142,0.597795,0.71445
2,0.444,0.452075,0.787844
3,0.3715,0.399034,0.818807
4,0.3554,0.382024,0.833716
5,0.3202,0.37313,0.844037
6,0.2947,0.35537,0.855505
7,0.2788,0.353704,0.864679
8,0.2631,0.350249,0.872706
9,0.2543,0.35343,0.877294
10,0.2209,0.357953,0.875


TrainOutput(global_step=126300, training_loss=0.24467212997158652, metrics={'train_runtime': 3368.1214, 'train_samples_per_second': 599.88, 'train_steps_per_second': 37.499, 'total_flos': 2723487563851080.0, 'train_loss': 0.24467212997158652, 'epoch': 30.0})

In [None]:
del model
model = create_model()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-uncased_L-8_H-256_A-4 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
training_args = TrainingArguments(
    output_dir=f"tmp/sst2-compact",
    learning_rate=2e-6,
    warmup_ratio=0.1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    push_to_hub=False,
)

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_sst2["train"],
    eval_dataset=tokenized_sst2["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.4852,0.472094,0.779817
2,0.3788,0.407089,0.817661
3,0.337,0.379842,0.837156
4,0.3352,0.374131,0.841743
5,0.3129,0.371728,0.845183
6,0.2939,0.360264,0.854358
7,0.2864,0.359505,0.862385
8,0.2792,0.355645,0.856651
9,0.2775,0.356874,0.856651
10,0.2529,0.356474,0.860092


TrainOutput(global_step=42100, training_loss=0.34087565510403234, metrics={'train_runtime': 1115.5485, 'train_samples_per_second': 603.73, 'train_steps_per_second': 37.739, 'total_flos': 907846567642716.0, 'train_loss': 0.34087565510403234, 'epoch': 10.0})