In [1]:
from transformers import AutoTokenizer, DataCollatorWithPadding

In [2]:
from datasets import load_dataset

In [3]:
import numpy as np
import torch

In [4]:
from typing import Dict, List

In [5]:
from tqdm.notebook import tqdm

In [6]:
# Load the "stanfordnlp/sst2" dataset. The next cells show its train,
# validation and test splits, each with 'idx', 'sentence' and 'label' columns.
sst2 = load_dataset("stanfordnlp/sst2")

In [7]:
sst2["train"]

Dataset({
    features: ['idx', 'sentence', 'label'],
    num_rows: 67349
})

In [8]:
sst2["validation"]

Dataset({
    features: ['idx', 'sentence', 'label'],
    num_rows: 872
})

In [9]:
sst2["test"]

Dataset({
    features: ['idx', 'sentence', 'label'],
    num_rows: 1821
})

In [10]:
# Tokenizer for the same "google/mobilebert-uncased" checkpoint that the
# classification models in this notebook are loaded from.
tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")

In [11]:
def tokenize(batch: Dict[str, List]):
    """Tokenize the ``"sentence"`` column of a batch with the module-level ``tokenizer``.

    Truncation is enabled with a maximum length of 512; no padding argument is
    passed here, so padding is left to a later step.
    """
    sentences = batch["sentence"]
    return tokenizer(sentences, max_length=512, truncation=True)

In [12]:
# Run `tokenize` over the dataset in batches (batched=True hands the function
# a dict of column lists rather than a single example).
tokenized_sst2 = sst2.map(tokenize, batched=True)

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [13]:
# Batch collator built from the tokenizer; `tokenize` above does not pad, so
# padding of each batch is delegated to this collator.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [14]:
import evaluate

# Accuracy metric from the `evaluate` library; consumed by `compute_metrics`.
accuracy = evaluate.load("accuracy")

In [15]:
def compute_metrics(eval_prediction):
    """Compute accuracy for an evaluation pass.

    ``eval_prediction`` is unpacked into a ``(predictions, labels)`` pair. The
    predicted class is the arg-max of ``predictions`` along axis 1, and the
    result is whatever the module-level ``accuracy`` metric returns.
    """
    scores, references = eval_prediction
    predicted_classes = np.argmax(scores, axis=1)
    return accuracy.compute(predictions=predicted_classes, references=references)

In [16]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

In [17]:
# Baseline model for full fine-tuning: the MobileBERT checkpoint with a
# 2-label sequence-classification head. The load message below reports that
# the classifier weights are newly initialised, i.e. not part of the checkpoint.
model = AutoModelForSequenceClassification.from_pretrained(
    "google/mobilebert-uncased",
    num_labels=2,
)

Some weights of MobileBertForSequenceClassification were not initialized from the model checkpoint at google/mobilebert-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [18]:
model

MobileBertForSequenceClassification(
  (mobilebert): MobileBertModel(
    (embeddings): MobileBertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 512)
      (token_type_embeddings): Embedding(2, 512)
      (embedding_transformation): Linear(in_features=384, out_features=512, bias=True)
      (LayerNorm): NoNorm()
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): MobileBertEncoder(
      (layer): ModuleList(
        (0-23): 24 x MobileBertLayer(
          (attention): MobileBertAttention(
            (self): MobileBertSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=512, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): MobileBertSelfOutput(
              (dense): Linear(in_fe

In [19]:
# Prefer Apple's MPS backend (the saved output shows this notebook ran on
# mps:0), but fall back to CUDA and then CPU so the cell does not fail on
# machines where MPS is unavailable.
model = model.to(
    "mps"
    if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available() else "cpu"
)
model.device

device(type='mps', index=0)

In [20]:
# Full fine-tuning baseline: no parameters of `model` are frozen in this
# notebook, so the Trainer updates all of them.
training_args = TrainingArguments(
    output_dir="tmp/mobilebert",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    # Evaluate and checkpoint once per epoch; the two strategies are kept
    # identical, presumably because load_best_model_at_end needs them to match.
    eval_strategy="epoch",
    save_strategy="epoch",
    # NOTE(review): no metric_for_best_model is set, so the "best" checkpoint is
    # presumably chosen by evaluation loss rather than accuracy. In the saved
    # run the lowest validation loss (epoch 1) and the highest accuracy
    # (epoch 3) are different epochs -- confirm which one is intended.
    load_best_model_at_end=True,
    push_to_hub=False,
)

# NOTE(review): newer transformers releases deprecate the `tokenizer=` argument
# of Trainer in favour of `processing_class=` -- check the installed version.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_sst2["train"],
    eval_dataset=tokenized_sst2["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Last expression of the cell, so the returned TrainOutput is displayed.
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.1948,0.279231,0.917431
2,0.1577,0.367442,0.909404
3,0.1121,0.348334,0.922018
4,0.1033,0.399022,0.916284
5,0.0746,0.477361,0.909404


TrainOutput(global_step=21050, training_loss=898.4853809164256, metrics={'train_runtime': 9077.9515, 'train_samples_per_second': 37.095, 'train_steps_per_second': 2.319, 'total_flos': 1452184699411620.0, 'train_loss': 898.4853809164256, 'epoch': 5.0})

In [21]:
model

MobileBertForSequenceClassification(
  (mobilebert): MobileBertModel(
    (embeddings): MobileBertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 512)
      (token_type_embeddings): Embedding(2, 512)
      (embedding_transformation): Linear(in_features=384, out_features=512, bias=True)
      (LayerNorm): NoNorm()
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): MobileBertEncoder(
      (layer): ModuleList(
        (0-23): 24 x MobileBertLayer(
          (attention): MobileBertAttention(
            (self): MobileBertSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=512, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): MobileBertSelfOutput(
              (dense): Linear(in_fe

In [29]:
list(model.mobilebert.encoder.layer[0].attention.self.query.parameters())

[Parameter containing:
 tensor([[-0.2266, -0.0468, -0.1313,  ..., -0.0694,  0.1024,  0.0643],
         [ 0.2212, -0.1479,  0.0529,  ...,  0.1253, -0.0353, -0.0575],
         [ 0.0316, -0.0239,  0.1146,  ..., -0.0174, -0.0099,  0.0407],
         ...,
         [ 0.0746,  0.1113,  0.1695,  ...,  0.1416,  0.0800,  0.0718],
         [-0.1779, -0.0636, -0.2137,  ...,  0.0496, -0.0058,  0.1994],
         [ 0.1006,  0.0032,  0.1494,  ..., -0.0088,  0.1172, -0.1506]],
        device='mps:0', requires_grad=True),
 Parameter containing:
 tensor([ 9.9882e-02, -2.1004e-01, -7.4925e-01, -3.6873e-01,  1.5776e-01,
         -1.2681e-01, -1.4546e-01, -3.0952e-02, -3.1803e-01,  1.3049e-01,
          2.8325e-01,  1.2784e-01,  1.0947e-01,  8.4597e-03,  1.9864e-01,
          1.9554e-01, -1.0397e-01,  8.0338e-02,  2.3664e+00,  4.5677e-02,
         -1.5654e-01, -2.9233e-01, -2.4647e-01, -4.6836e-01, -7.0618e-02,
          1.2230e-02,  2.3433e-01, -1.3977e-03,  8.0074e-02,  2.1029e-01,
          3.6164e-02,  8

In [79]:
# Number of trainable scalars in the fully fine-tuned model.
sum(p.numel() for p in model.parameters() if p.requires_grad)

24582914

In [19]:
del model

In [35]:
import torch
from torch import nn

In [42]:
class LoRAUpdate(nn.Module):
    """Low-rank additive update ``(x @ A @ B) * alpha / rank`` for a linear map.

    ``A`` has shape ``(in_features, rank)`` and ``B`` has shape
    ``(rank, out_features)``, so an input whose last dimension is
    ``in_features`` produces an output whose last dimension is
    ``out_features``. ``A`` is Xavier-normal initialised and ``B`` starts at
    zero, so the update is exactly zero before any training step.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        rank: inner dimension of the factorisation; must be positive.
        alpha: numerator of the ``alpha / rank`` output scaling.
        device: optional device on which to allocate ``A`` and ``B``.

    Raises:
        ValueError: if ``rank`` is not positive.
    """

    def __init__(self, in_features, out_features, rank, alpha, device=None):
        super().__init__()
        if rank <= 0:
            raise ValueError(f"rank must be a positive integer, got {rank!r}")
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.alpha = alpha
        # A projects the input down to `rank` dimensions. nn.init functions
        # already run without gradient tracking, so no no_grad() block is needed.
        self.A = nn.Parameter(data=torch.empty((in_features, rank), device=device))
        nn.init.xavier_normal_(self.A)
        # B projects back up to the output size; zeros make the initial update a no-op.
        self.B = nn.Parameter(data=torch.zeros((rank, out_features), device=device))

    def forward(self, input):
        """Return the scaled low-rank update for ``input``."""
        return (input @ self.A @ self.B) * self.alpha / self.rank

    def extra_repr(self):
        """Show the configuration in the module repr instead of a bare ``LoRAUpdate()``."""
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"rank={self.rank}, alpha={self.alpha}"
        )

In [98]:
class LoRALayer(nn.Module):
    """Wrap an ``nn.Linear`` and add a trainable low-rank update to its output.

    The forward pass returns ``original_layer(input) + lora_update(input)``.
    This wrapper does not change ``requires_grad`` on the wrapped layer; freeze
    it beforehand if only the LoRA factors should be trained.

    Args:
        original_layer: the linear layer to wrap; its ``in_features`` and
            ``out_features`` size the low-rank update.
        rank: rank of the update, forwarded to ``LoRAUpdate``.
        alpha: scaling numerator, forwarded to ``LoRAUpdate``.
        device: device for the LoRA factors. When ``None`` they are created on
            the same device as ``original_layer.weight``.
    """

    def __init__(self, original_layer: nn.Linear, rank: int, alpha: int, device=None):
        super().__init__()
        if device is None:
            # Keep the new factors next to the wrapped weights; otherwise
            # wrapping a layer that already lives on an accelerator would
            # leave them on the default device and break the addition in forward().
            device = original_layer.weight.device
        self.original_layer = original_layer
        self.lora_update = LoRAUpdate(
            in_features=original_layer.in_features,
            out_features=original_layer.out_features,
            rank=rank,
            alpha=alpha,
            device=device,
        )

    def forward(self, input):
        """Return the wrapped layer's output plus the low-rank update."""
        return self.original_layer(input) + self.lora_update(input)

In [116]:
# Fresh copy of the same checkpoint for the LoRA experiment (the fully
# fine-tuned `model` was deleted above). As the load message below reports,
# the classifier weights are newly initialised here as well.
lora_model = AutoModelForSequenceClassification.from_pretrained(
    "google/mobilebert-uncased",
    num_labels=2,
)

Some weights of MobileBertForSequenceClassification were not initialized from the model checkpoint at google/mobilebert-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [117]:
# Freeze every weight currently in the model, so that only parameters added
# after this cell can receive gradients.
# NOTE(review): this also freezes the classifier head, which the load message
# above reports as newly initialised -- confirm that training with a frozen,
# randomly initialised head is intended.
for param in lora_model.parameters():
    param.requires_grad_(False)

In [118]:
# Sanity check: trainable parameter count right after freezing.
sum(p.numel() for p in lora_model.parameters() if p.requires_grad)

0

In [119]:
rank = 2
alpha = 2


def _add_lora(parent, attr_name, rank, alpha):
    """Replace ``parent.<attr_name>`` with a ``LoRALayer`` wrapping it.

    Layers that are already wrapped are left untouched, so re-running this
    cell is a no-op instead of failing on (or double-wrapping) a LoRALayer.
    """
    target = getattr(parent, attr_name)
    if isinstance(target, LoRALayer):
        return
    setattr(parent, attr_name, LoRALayer(target, rank=rank, alpha=alpha))


# Wrap the query and key projections and the attention output projection of
# every encoder layer, in that order.
for layer in lora_model.mobilebert.encoder.layer:
    _add_lora(layer.attention.self, "query", rank, alpha)
    _add_lora(layer.attention.self, "key", rank, alpha)
    # The value projection is deliberately left unwrapped. Per the model repr
    # it is Linear(512 -> 128), unlike the 128 -> 128 query/key projections.
    # NOTE(review): confirm LoRAUpdate handles in_features != out_features
    # before enabling:  _add_lora(layer.attention.self, "value", rank, alpha)
    _add_lora(layer.attention.output, "dense", rank, alpha)

In [120]:
lora_model

MobileBertForSequenceClassification(
  (mobilebert): MobileBertModel(
    (embeddings): MobileBertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 512)
      (token_type_embeddings): Embedding(2, 512)
      (embedding_transformation): Linear(in_features=384, out_features=512, bias=True)
      (LayerNorm): NoNorm()
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): MobileBertEncoder(
      (layer): ModuleList(
        (0-23): 24 x MobileBertLayer(
          (attention): MobileBertAttention(
            (self): MobileBertSelfAttention(
              (query): LoRALayer(
                (original_layer): Linear(in_features=128, out_features=128, bias=True)
                (lora_update): LoRAUpdate()
              )
              (key): LoRALayer(
                (original_layer): Linear(in_features=128, out_features=128, bias=True)
                (lora_update): LoRAUpdate()
              )
              

In [121]:
# Trainable parameter count after injecting the LoRA layers.
sum(p.numel() for p in lora_model.parameters() if p.requires_grad)

36864

In [122]:
# Prefer Apple's MPS backend (the saved output shows this notebook ran on
# mps:0), but fall back to CUDA and then CPU so the cell does not fail on
# machines where MPS is unavailable.
lora_model = lora_model.to(
    "mps"
    if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available() else "cpu"
)
lora_model.device

device(type='mps', index=0)

In [123]:
# LoRA run: same batch sizes, weight decay and evaluation/checkpoint schedule
# as the full fine-tuning run, but with a larger learning rate (5e-4 vs 2e-5)
# and more epochs (30 vs 5).
lora_training_args = TrainingArguments(
    output_dir="tmp/mobilebert_lora",
    learning_rate=5e-4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=30,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    # NOTE(review): no metric_for_best_model is set, so the "best" checkpoint is
    # presumably chosen by evaluation loss rather than accuracy -- confirm.
    load_best_model_at_end=True,
    push_to_hub=False,
)

# NOTE(review): newer transformers releases deprecate the `tokenizer=` argument
# of Trainer in favour of `processing_class=` -- check the installed version.
lora_trainer = Trainer(
    model=lora_model,
    args=lora_training_args,
    train_dataset=tokenized_sst2["train"],
    eval_dataset=tokenized_sst2["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# NOTE(review): in the saved run the logged training loss spikes in some epochs
# (e.g. 16.26 at epoch 8) while validation loss stays near 0.27, and the run
# was interrupted after epoch 10. Worth investigating (lower learning rate or
# tighter gradient clipping?) before trusting the comparison.
lora_trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.2684,0.266855,0.895642
2,1.2747,0.263192,0.895642
3,0.6445,0.272071,0.90367
4,0.2164,0.258067,0.901376
5,0.2057,0.250308,0.917431
6,0.2035,0.285823,0.900229
7,0.1903,0.276891,0.905963
8,16.2636,0.275548,0.911697
9,13.5932,0.286364,0.90711
10,0.165,0.276279,0.904817


KeyboardInterrupt: 