This notebook runs in Google Colab and encodes the movie reviews in the IMDB dataset using the DistilBERT model.
All of the intermediate embeddings are also stored for further analysis (e.g., to track how individual
token embeddings gradually change into the output embeddings as inference progresses through the layers).

In [1]:
! pip install datasets



In [19]:
# Colab only: mount Google Drive so inference outputs can be written to it.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [10]:
import torch
import pickle, os
from tqdm.notebook import tqdm
from transformers import AutoModel
from datasets import load_dataset
from transformers import AutoTokenizer
import random

In [21]:
# Base folder (on the mounted Drive) for all experiment artifacts. Note the
# trailing slash: sub-paths are formed by appending directly to it.
root_dir = "/content/drive/My Drive/ML_Experiments/"
inference_dir = root_dir + "imdb_distillbert_inference"
os.makedirs(inference_dir, exist_ok=True)

In [12]:
# Keep only the 'train' and 'test' splits; the extra 'unsupervised' split
# is not used by this analysis.
DROPPED_SPLIT = 'unsupervised'
imdb = load_dataset("imdb")
del imdb[DROPPED_SPLIT]

In [7]:
# Tokenize every review once. Padding to a fixed length gives every example the
# same number of tokens, so rows can later be stacked into a single tensor;
# longer reviews are truncated. 512 matches the model's position-embedding table.
MAX_SEQ_LEN = 512
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")


def tokenize_reviews(batch):
    """Tokenize a batch of reviews (a dict holding a 'text' list) to MAX_SEQ_LEN tokens."""
    return tokenizer(
        batch['text'],
        truncation=True,
        padding='max_length',
        max_length=MAX_SEQ_LEN,
    )


imdb = imdb.map(tokenize_reviews, batched=True)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

In [9]:
# Flatten both splits into one list of plain dicts, tagging each review with the
# split it came from so that the partition survives the shuffle below.
KEPT_COLUMNS = ('text', 'label', 'input_ids', 'attention_mask')
imdb_data = [
    {'split': split_name, **{column: example[column] for column in KEPT_COLUMNS}}
    for split_name in ('train', 'test')
    for example in tqdm(imdb[split_name])
]

  0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/25000 [00:00<?, ?it/s]

In [14]:
# Interleave train and test reviews in a reproducible order. Each record keeps
# its 'split' key, so the original partition can still be recovered.
SHUFFLE_SEED = 42
random.seed(SHUFFLE_SEED)
random.shuffle(imdb_data)

In [15]:
# Load DistilBERT with half-precision weights and move it to the GPU. Hidden
# states are requested so every layer's token embeddings are available, not
# just the final ones.
# NOTE(review): output_attentions=True makes each forward pass also build
# per-layer attention maps; drop it if nothing downstream consumes them.
distillBERT = AutoModel.from_pretrained(
    "distilbert/distilbert-base-uncased",
    torch_dtype=torch.float16,
    output_attentions=True,
    output_hidden_states=True,
)
distillBERT.to('cuda')

DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Li

In [27]:
# Run batched inference and save one pickle file per batch.
# Each saved record keeps the raw review and its tokenization, plus the model
# outputs trimmed to the review's real (non-padding) tokens.
batch_size = 100
# Also store every layer's hidden states (the embedding output plus one entry per
# transformer layer, stacked as (n_layers + 1, n_tokens, hidden_dim)) so token
# embeddings can be tracked through the network. This is roughly 7x the size of
# last_hidden_state alone for the 6-layer model; set to False to keep only the
# final layer.
SAVE_HIDDEN_STATES = True
output_dir = os.path.join(root_dir, "imdb_distillbert_inference")
device = distillBERT.device

distillBERT.eval()
for start in tqdm(list(range(0, len(imdb_data), batch_size)), desc="Distill BERT batch inference on IMDB sentiment analysis"):
    end = min(start + batch_size, len(imdb_data))
    batch = imdb_data[start:end]

    with torch.no_grad():
        input_ids = torch.tensor([row['input_ids'] for row in batch], device=device)
        attention_mask = torch.tensor([row['attention_mask'] for row in batch], device=device)
        # Attentions are never saved, so don't pay the memory/compute for them
        # even though the model config enables them.
        output = distillBERT(input_ids=input_ids,
                             attention_mask=attention_mask,
                             output_hidden_states=SAVE_HIDDEN_STATES,
                             output_attentions=False)

    batch_output = []
    for i, row in enumerate(batch):
        # Number of real tokens; slicing [:n_tokens] assumes right-side padding
        # (mask of 1s followed by 0s).
        n_tokens = sum(row['attention_mask'])
        record = {
            'split': row['split'],
            'text': row['text'],
            'label': row['label'],
            'input_ids': row['input_ids'],
            'attention_mask': row['attention_mask'],
            'last_hidden_state': output.last_hidden_state[i, :n_tokens].cpu().numpy(),
        }
        if SAVE_HIDDEN_STATES:
            record['hidden_states'] = torch.stack(
                [layer[i, :n_tokens] for layer in output.hidden_states]
            ).cpu().numpy()
        batch_output.append(record)

    with open(os.path.join(output_dir, f"batch_{start}_to_{end}.pkl"), "wb") as f:
        pickle.dump(batch_output, f)

Distill BERT batch inference on IMDB sentiment analysis:   0%|          | 0/500 [00:00<?, ?it/s]