In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=3

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=3


In [3]:
from drone import compute_and_save_bert_activations

In [4]:
import torch
from copy import deepcopy
from transformers import AutoModel

device = "cuda:0"
model = AutoModel.from_pretrained("microsoft/deberta-v3-base").to(device)

Some weights of the model checkpoint at microsoft/deberta-v3-base were not used when initializing DebertaV2Model: ['lm_predictions.lm_head.LayerNorm.bias', 'mask_predictions.LayerNorm.bias', 'mask_predictions.dense.weight', 'lm_predictions.lm_head.dense.weight', 'mask_predictions.LayerNorm.weight', 'mask_predictions.dense.bias', 'mask_predictions.classifier.bias', 'lm_predictions.lm_head.LayerNorm.weight', 'mask_predictions.classifier.weight', 'lm_predictions.lm_head.dense.bias', 'lm_predictions.lm_head.bias']
- This IS expected if you are initializing DebertaV2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DebertaV2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
class TokenizedDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data
    
    def __getitem__(self, index):
        return {"input_ids": self.data[index], "attention_mask": torch.ones_like(self.data[index])}

    def __len__(self):
        return len(self.data)

In [6]:
data = torch.arange(500 * 28).reshape(500, 28) % 97 + 100

dataset = TokenizedDataset(data)

In [7]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128)

In [8]:
!mkdir bert_activations
!rm bert_activations/*.pth

mkdir: cannot create directory ‘bert_activations’: File exists


In [9]:
compute_and_save_bert_activations(model, dataloader, output_dir="./bert_activations", device=device, subsample=1.0)

computing activations: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.02s/it]


In [10]:
!ls bert_activations/

batch_0_layer_0_attn-mask.pth	       batch_2_layer_0_attn-mask.pth
batch_0_layer_0_attn-out.pth	       batch_2_layer_0_attn-out.pth
batch_0_layer_0_embeds.pth	       batch_2_layer_0_embeds.pth
batch_0_layer_0_keys.pth	       batch_2_layer_0_keys.pth
batch_0_layer_0_mlp-intermediate.pth   batch_2_layer_0_mlp-intermediate.pth
batch_0_layer_0_mlp-output.pth	       batch_2_layer_0_mlp-output.pth
batch_0_layer_0_queries.pth	       batch_2_layer_0_queries.pth
batch_0_layer_0_values.pth	       batch_2_layer_0_values.pth
batch_0_layer_10_attn-out.pth	       batch_2_layer_10_attn-out.pth
batch_0_layer_10_keys.pth	       batch_2_layer_10_keys.pth
batch_0_layer_10_mlp-intermediate.pth  batch_2_layer_10_mlp-intermediate.pth
batch_0_layer_10_mlp-output.pth        batch_2_layer_10_mlp-output.pth
batch_0_layer_10_queries.pth	       batch_2_layer_10_queries.pth
batch_0_layer_10_values.pth	       batch_2_layer_10_values.pth
batch_0_layer_11_attn-out.pth	       batch_2_layer_11_attn-out.pt

In [11]:
from drone import concat_activations_by_layers

concat_activations_by_layers("./bert_activations/")

In [12]:
!ls bert_activations/ 

batch_0_layer_0_attn-mask.pth	       batch_2_layer_3_values.pth
batch_0_layer_0_attn-out.pth	       batch_2_layer_4_attn-out.pth
batch_0_layer_0_embeds.pth	       batch_2_layer_4_keys.pth
batch_0_layer_0_keys.pth	       batch_2_layer_4_mlp-intermediate.pth
batch_0_layer_0_mlp-intermediate.pth   batch_2_layer_4_mlp-output.pth
batch_0_layer_0_mlp-output.pth	       batch_2_layer_4_queries.pth
batch_0_layer_0_queries.pth	       batch_2_layer_4_values.pth
batch_0_layer_0_values.pth	       batch_2_layer_5_attn-out.pth
batch_0_layer_10_attn-out.pth	       batch_2_layer_5_keys.pth
batch_0_layer_10_keys.pth	       batch_2_layer_5_mlp-intermediate.pth
batch_0_layer_10_mlp-intermediate.pth  batch_2_layer_5_mlp-output.pth
batch_0_layer_10_mlp-output.pth        batch_2_layer_5_queries.pth
batch_0_layer_10_queries.pth	       batch_2_layer_5_values.pth
batch_0_layer_10_values.pth	       batch_2_layer_6_attn-out.pth
batch_0_layer_11_attn-out.pth	       batch_2_layer_6_keys.pth
batch_0_l

In [13]:
from drone import run_drone_compression_for_bert

compressed_model = run_drone_compression_for_bert(deepcopy(model), 300, "bert_activations", device)

Relative error: 2.4758108571683106
Relative error: 1.4365284705374777
Relative error: 1.4476874611829198
Relative error: 0.6711083809911085
Relative error: 0.8354931261727883
Relative error: 120.35973890625857
Relative error: 67.09788560879295
Relative error: 54.797522223862856
Relative error: 153.67660941360057
Relative error: 3.364251157350531
Relative error: 0.9953075481532908
Relative error: 7.88361819425119
Relative error: 0.8870916649548979
Relative error: 1.0244650358474787
Relative error: 0.9661249307555151
Relative error: 0.58574131768592
Relative error: 0.24868560647445295
Relative error: 10.025016196711704
Relative error: 0.6380854711827554
Relative error: 0.8028771896215232
Relative error: 0.7023145738453228
Relative error: 0.5683557680394715
Relative error: 0.24256982922778922
Relative error: 3.9076505350091595
Relative error: 0.4440861980940897
Relative error: 0.6097666038098561
Relative error: 0.38583672584089695
Relative error: 0.562466175795647
Relative error: 0.263031

In [14]:
def count_params(model):
    return sum(p.numel() for p in model.parameters())

count_params(compressed_model) / count_params(model)

0.8086930365468491

In [15]:
compressed_model

DebertaV2Model(
  (embeddings): DebertaV2Embeddings(
    (word_embeddings): Embedding(128100, 768, padding_idx=0)
    (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)
    (dropout): StableDropout()
  )
  (encoder): DebertaV2Encoder(
    (layer): ModuleList(
      (0): DebertaV2Layer(
        (attention): DebertaV2Attention(
          (self): DisentangledSelfAttention(
            (query_proj): LowRankLinear(768_300_768)
            (key_proj): LowRankLinear(768_300_768)
            (value_proj): LowRankLinear(768_300_768)
            (pos_dropout): StableDropout()
            (dropout): StableDropout()
          )
          (output): DebertaV2SelfOutput(
            (dense): LowRankLinear(768_300_768)
            (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)
            (dropout): StableDropout()
          )
        )
        (intermediate): DebertaV2Intermediate(
          (dense): LowRankLinear(768_300_3072)
          (intermediate_act_fn): GELU

In [16]:
count_params(compressed_model.encoder) / count_params(model.encoder)

0.5884309108231023

In [25]:
for batch in dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    break

with torch.inference_mode():
    output = model(**batch, output_hidden_states=True)

    embeds = model.embeddings(batch["input_ids"])
    other_output = model.encoder(embeds, torch.ones_like(batch["input_ids"]))

    print(torch.allclose(output.last_hidden_state, output.hidden_states[-1]))
    print(torch.allclose(output.hidden_states[0], embeds))

True
True
