In [None]:
import torch

In [1]:
import pandas as pd

In [2]:
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)


In [3]:
import evaluate

In [36]:
import numpy as np

In [6]:
import torch
from torch.utils.data import Dataset

In [4]:
# First student candidate: MobileBERT with a 2-class sequence-classification head.
# NOTE(review): this object is deleted and `model` is rebuilt from a smaller BERT
# checkpoint in later cells — confirm whether this cell is still needed.
model = AutoModelForSequenceClassification.from_pretrained(
    "google/mobilebert-uncased",
    num_labels=2,
)

Some weights of MobileBertForSequenceClassification were not initialized from the model checkpoint at google/mobilebert-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [74]:
# Drop the current `model` binding before a replacement student is created below.
# NOTE(review): presumably meant to release the MobileBERT weights — confirm.
del model

In [67]:
from transformers import BertConfig

In [77]:
# Student architecture hyper-parameters. They are interpolated into the checkpoint
# name below and passed to the config, so the two always agree.
# NOTE(review): presumably these must match the weights stored in that checkpoint — confirm.
num_hidden_layers = 8
hidden_size = 256
num_attention_heads = 4

# Small BERT student with a 2-class classification head.
model = AutoModelForSequenceClassification.from_pretrained(
    f"bert-uncased_L-{num_hidden_layers}_H-{hidden_size}_A-{num_attention_heads}",
    config=BertConfig(
        hidden_size=hidden_size,
        num_hidden_layers=num_hidden_layers,
        num_attention_heads=num_attention_heads,
        # Derived instead of hard-coded so it follows `hidden_size` when that is
        # edited; 4 * 256 == 1024, the value previously written as a literal.
        intermediate_size=4 * hidden_size,
        num_labels=2,
    ),
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-uncased_L-8_H-256_A-4 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [78]:
# Use the GPU when one is visible; otherwise stay on the CPU so this cell does not
# raise on machines without CUDA.
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
# Shared tokenizer used by the tokenize_* helpers, the data collator and the trainer.
# NOTE(review): it is loaded from the MobileBERT checkpoint while the student and
# teacher are loaded from BERT checkpoints — confirm the vocabularies are identical.
tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")

In [11]:
class PretrainedDistillationDataset(Dataset):
    """Map-style dataset backed by a DataFrame read from a parquet file."""

    def __init__(self, parquet_file):
        """Read ``parquet_file`` with pandas and keep the frame in ``self.examples``."""
        super().__init__()
        frame = pd.read_parquet(parquet_file)
        self.examples = frame

    def __len__(self):
        """Number of rows in the loaded frame."""
        row_count = self.examples.shape[0]
        return row_count

    def __getitem__(self, index):
        """Positional lookup: return the row at ``index`` via ``DataFrame.iloc``."""
        rows = self.examples
        return rows.iloc[index]

In [14]:
from datasets import load_dataset


In [17]:
# Load the review texts from a local snappy-compressed parquet file using the
# generic "parquet" builder.
# NOTE(review): the path is relative to the notebook's working directory.
reviews = load_dataset(
    "parquet",
    data_files="reviews-text/reviews-unlabeled.parquet.snappy",
)

Generating train split: 0 examples [00:00, ? examples/s]

In [18]:
# Display the loaded object to inspect its splits, features and row count.
reviews

DatasetDict({
    train: Dataset({
        features: ['reviewText'],
        num_rows: 1697533
    })
})

In [22]:
def tokenize_reviews(batch):
    """Tokenize the ``reviewText`` entries of ``batch`` with the notebook-level tokenizer.

    Truncation is enabled with a maximum length of 512; no padding argument is
    passed here.
    """
    review_texts = batch["reviewText"]
    encoded = tokenizer(review_texts, truncation=True, max_length=512)
    return encoded

In [23]:
# Apply tokenize_reviews over the review dataset in batches (batched=True hands the
# function a batch of rows per call instead of a single row).
tokenized_reviews = reviews.map(tokenize_reviews, batched=True)

Map:   0%|          | 0/1697533 [00:00<?, ? examples/s]

In [24]:
# Load the SST-2 dataset; its "train" and "validation" splits feed the trainer below.
sst2 = load_dataset("stanfordnlp/sst2")

In [25]:
def tokenize_sst2(batch):
    """Tokenize the ``sentence`` entries of ``batch`` with the notebook-level tokenizer.

    Truncation is enabled with a maximum length of 512; no padding argument is
    passed here.
    """
    sentences = batch["sentence"]
    encoded = tokenizer(sentences, truncation=True, max_length=512)
    return encoded

In [26]:
# Apply tokenize_sst2 over every SST-2 split in batches.
tokenized_sst2 = sst2.map(tokenize_sst2, batched=True)

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [21]:
# Batch collator bound to the shared tokenizer; it is handed to the trainer below.
# The tokenize_* helpers request no padding, so padding is left to this collator.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [28]:
from torch import nn

In [27]:
import torch.nn.functional as F

In [44]:
class PretrainedDistillationTrainer(Trainer):
    """Trainer whose loss pulls a student's predictions towards a teacher's.

    The returned loss is the KL divergence between the teacher's softmax
    distribution and the student's log-softmax distribution, both computed from
    the models' logits on the same batch.
    """

    def __init__(self, *args, teacher_model, **kwargs):
        """Create the trainer and register the teacher.

        Args:
            *args: Positional arguments forwarded unchanged to ``Trainer.__init__``.
            teacher_model: Keyword-only. Model whose logits provide the soft
                targets; it is only ever run under ``torch.no_grad()``.
            **kwargs: Keyword arguments forwarded unchanged to ``Trainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        # The teacher receives exactly the batches prepared for the student, so it
        # has to live on the device the trainer uses for those batches.
        self.teacher_model = teacher_model.to(self.args.device)
        # Keep the soft targets deterministic (no dropout) regardless of how the
        # teacher was prepared by the caller.
        self.teacher_model.eval()
        self.kl_loss = nn.KLDivLoss(reduction="batchmean")

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the distillation loss for one batch.

        Args:
            model: The student model being trained.
            inputs: Batch dict; forwarded unchanged to both student and teacher.
            return_outputs: When true, return ``(loss, student_outputs)``.
            num_items_in_batch: Accepted because newer ``Trainer`` versions pass it
                as a keyword; it is not used by this loss.

        Returns:
            The scalar KL loss, or ``(loss, student_outputs)`` when
            ``return_outputs`` is true.
        """
        student_outputs = model(**inputs)
        student_logits = student_outputs.logits
        # No gradients are needed for the teacher; it only supplies targets.
        with torch.no_grad():
            teacher_outputs = self.teacher_model(**inputs)
            teacher_logits = teacher_outputs.logits
        # nn.KLDivLoss takes log-probabilities as input and probabilities as target.
        loss = self.kl_loss(
            F.log_softmax(student_logits, dim=-1),
            F.softmax(teacher_logits, dim=-1),
        )
        return (loss, student_outputs) if return_outputs else loss

In [79]:
# Hyper-parameters for the SST-2 distillation run.
training_args = TrainingArguments(
    # Plain string: the previous f-string had no placeholders.
    output_dir="tmp/sst2-distillation",
    learning_rate=2e-5,
    warmup_ratio=0.1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    # Evaluate and checkpoint on the same per-epoch schedule so the best checkpoint
    # (by eval accuracy) can be reloaded when training finishes.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    push_to_hub=False,
)

In [30]:
# Metric object used by compute_metrics below.
accuracy = evaluate.load("accuracy")


def compute_metrics(eval_prediction):
    """Score a batch of evaluation predictions with the loaded accuracy metric.

    ``eval_prediction`` is unpacked into ``(predictions, labels)``; the predicted
    class is the arg-max of ``predictions`` along axis 1.
    """
    scores, references = eval_prediction
    predicted_classes = np.argmax(scores, axis=1)
    return accuracy.compute(predictions=predicted_classes, references=references)

In [72]:
# Sanity check: show the concrete class that the Auto* factory produced for the student.
type(model)

transformers.models.bert.modeling_bert.BertForSequenceClassification

In [31]:
# Teacher skeleton: BERT-base with a 2-class classification head. The weights loaded
# here are replaced by the state dict loaded in the next cell.
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=2,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [33]:
# Overwrite the teacher's parameters with the state dict stored in sst2-base.pt.
# NOTE(review): presumably a BERT-base checkpoint fine-tuned on SST-2 — confirm provenance.
teacher_model.load_state_dict(torch.load("sst2-base.pt", weights_only=True))

<All keys matched successfully>

In [41]:
# Use the GPU when one is visible; otherwise stay on the CPU so this cell does not
# raise on machines without CUDA.
teacher_model = teacher_model.to("cuda" if torch.cuda.is_available() else "cpu")

In [80]:
# Wire the student, teacher, data and metric together. Training uses the SST-2
# "train" split and evaluation uses the "validation" split.
# NOTE(review): recent transformers releases deprecate the `tokenizer=` argument in
# favour of `processing_class=` — confirm against the installed version.
trainer = PretrainedDistillationTrainer(
    model=model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=tokenized_sst2["train"],
    eval_dataset=tokenized_sst2["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [48]:
# Run distillation training.
# NOTE(review): the output recorded under this cell has execution count 48, earlier
# than the cell that builds the current trainer (count 80), so it likely belongs to a
# previous configuration — confirm before relying on those numbers.
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.1804,0.127186,0.912844
2,0.117,0.185388,0.90367
3,0.0678,0.203811,0.908257
4,0.047,0.168926,0.91055
5,0.0994,0.195241,0.908257


TrainOutput(global_step=21050, training_loss=8471.673765370359, metrics={'train_runtime': 3546.0983, 'train_samples_per_second': 94.962, 'train_steps_per_second': 5.936, 'total_flos': 1452184699411620.0, 'train_loss': 8471.673765370359, 'epoch': 5.0})

In [81]:
# NOTE(review): second trainer.train() call on the same trainer object. On a
# top-to-bottom run this trains the already-trained student again — confirm whether
# one of the two training cells should be removed.
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.278,0.253845,0.868119
2,0.1955,0.266159,0.870413
3,0.1358,0.292334,0.879587
4,0.1244,0.314349,0.875
5,0.0982,0.325385,0.877294


TrainOutput(global_step=21050, training_loss=0.19130823878381145, metrics={'train_runtime': 1057.3821, 'train_samples_per_second': 318.47, 'train_steps_per_second': 19.908, 'total_flos': 454224414765732.0, 'train_loss': 0.19130823878381145, 'epoch': 5.0})