In [None]:
!pip install transformers torch sparsezoo datasets

In [7]:
import sparsezoo
model = sparsezoo.Model("zoo:bert-large-conll2003_wikipedia_bookcorpus-base", download_path="./bert-large-conll2003").download()

downloading...: 100%|██████████| 349/349 [00:00<00:00, 182kB/s]
downloading...: 100%|██████████| 695k/695k [00:00<00:00, 13.6MB/s]
downloading...: 100%|██████████| 226k/226k [00:00<00:00, 7.33MB/s]
downloading...: 100%|██████████| 196/196 [00:00<00:00, 113kB/s]
downloading...: 100%|██████████| 1.18k/1.18k [00:00<00:00, 589kB/s]
downloading...: 100%|██████████| 112/112 [00:00<00:00, 62.8kB/s]
downloading...: 100%|██████████| 3.23k/3.23k [00:00<00:00, 1.59MB/s]
downloading...: 100%|██████████| 1.03k/1.03k [00:00<00:00, 515kB/s]
downloading...: 100%|██████████| 1.24G/1.24G [01:04<00:00, 20.6MB/s]
downloading...: 100%|██████████| 346/346 [00:00<00:00, 182kB/s]
downloading...: 100%|██████████| 522/522 [00:00<00:00, 192kB/s]
downloading...: 100%|██████████| 1.24G/1.24G [01:10<00:00, 19.1MB/s]
downloading...: 100%|██████████| 349/349 [00:00<00:00, 190kB/s]
downloading...: 100%|██████████| 1.03k/1.03k [00:00<00:00, 522kB/s]
downloading...: 100%|██████████| 695k/695k [00:00<00:00, 15.5MB/s]
dow

In [10]:
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

model_path = "./bert-large-conll2003/training"

bert = AutoModelForTokenClassification.from_pretrained(
    model_path,
    torch_dtype=torch.float16
)

bert = bert.to("cuda:0")

In [23]:
from datasets import load_dataset

dataset = load_dataset("imdb", split="test")

Downloading builder script: 100%|██████████| 4.31k/4.31k [00:00<00:00, 3.13MB/s]
Downloading metadata: 100%|██████████| 2.17k/2.17k [00:00<00:00, 1.46MB/s]
Downloading readme: 100%|██████████| 7.59k/7.59k [00:00<00:00, 3.06MB/s]
Downloading data: 100%|██████████| 84.1M/84.1M [00:02<00:00, 29.8MB/s]
Generating train split: 100%|██████████| 25000/25000 [00:06<00:00, 3714.42 examples/s] 
Generating test split: 100%|██████████| 25000/25000 [00:06<00:00, 3763.60 examples/s] 
Generating unsupervised split: 100%|██████████| 50000/50000 [00:08<00:00, 6155.64 examples/s] 


In [25]:
tokenizer = AutoTokenizer.from_pretrained(model_path)

In [26]:
dataset

Dataset({
    features: ['text', 'label'],
    num_rows: 25000
})

In [48]:
SEQ_LEN = 128

b_1 = tokenizer(dataset[3]["text"], return_tensors="pt", max_length=SEQ_LEN, padding="max_length", truncation=True)
b_1["input_ids"].shape

torch.Size([1, 384])

In [47]:
b_64 = tokenizer(dataset[3:3+64]["text"], return_tensors="pt", max_length=SEQ_LEN, padding="max_length", truncation=True)
b_64["input_ids"].shape

torch.Size([64, 384])

In [49]:
b_256 = tokenizer(dataset[3:3+256]["text"], return_tensors="pt", max_length=SEQ_LEN, padding="max_length", truncation=True)
b_256["input_ids"].shape

torch.Size([256, 384])

In [50]:
b_512 = tokenizer(dataset[3:3+512]["text"], return_tensors="pt", max_length=SEQ_LEN, padding="max_length", truncation=True)
b_512["input_ids"].shape

torch.Size([512, 384])

In [51]:
for inp_dict in [b_1, b_64, b_256, b_512]:
    for k in inp_dict:
        inp_dict[k] = inp_dict[k].cuda()

In [55]:
import time

iterations_list = [
    100,
    5,
    5,
    5
]

inputs_dict = {
    1: b_1,
    64: b_64,
    256: b_256,
    512: b_512
}

print(f"------- WARMUP -------")
for _ in range(5):
    output = bert(**inputs_dict[1])
torch.cuda.synchronize()

bert.eval()
with torch.no_grad():
    for batch_size, iterations in zip(inputs_dict, iterations_list):
        print(f"------- STARTING B={batch_size} -------")
        print(inputs_dict[batch_size]["input_ids"].shape)

        start = time.perf_counter()
        for _ in range(iterations):
            output = bert(**inputs_dict[batch_size])
        torch.cuda.synchronize()
        end = time.perf_counter()
        
        total_items = iterations * batch_size
        total_time = end - start
        print(f"TOTAL_ITEMS = {total_items}")
        print(f"TOTAL_TIME = {total_time :0.2f}")
        print(f"THROUGHPUT = {total_items / total_time :0.2f}")

------- WARMUP -------
------- STARTING B=1 -------
torch.Size([1, 384])
TOTAL_ITEMS = 100
TOTAL_TIME = 1.84
THROUGHPUT = 54.21
------- STARTING B=64 -------
torch.Size([64, 384])
TOTAL_ITEMS = 320
TOTAL_TIME = 5.42
THROUGHPUT = 59.07
------- STARTING B=256 -------
torch.Size([256, 384])
TOTAL_ITEMS = 1280
TOTAL_TIME = 21.73
THROUGHPUT = 58.91
------- STARTING B=512 -------
torch.Size([512, 384])
TOTAL_ITEMS = 2560
TOTAL_TIME = 44.80
THROUGHPUT = 57.14
