In [1]:
import torch
import torch.nn as nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader

from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_scheduler
from datasets import load_dataset
import evaluate

import time
from pprint import pprint
from tqdm.auto import tqdm

In [2]:
# Load the `yelp_review_full` dataset and pretty-print the first training
# record as a quick sanity check of the available fields.
dataset_id = 'yelp_review_full'
dataset = load_dataset(dataset_id)

first_train_record = dataset['train'][0]
pprint(first_train_record)

{'label': 4,
 'text': 'dr. goldberg offers everything i look for in a general '
         "practitioner.  he's nice and easy to talk to without being "
         "patronizing; he's always on time in seeing his patients; he's "
         'affiliated with a top-notch hospital (nyu) which my parents have '
         'explained to me is very important in case something happens and you '
         'need surgery; and you can get referrals to see specialists without '
         "having to see him first.  really, what more do you need?  i'm "
         'sitting here trying to think of any complaints i have about him, but '
         "i'm really drawing a blank."}


In [3]:
# One checkpoint name is shared by the tokenizer and the classifier so the two
# always come from the same pretrained model.
model_id = 'bert-base-cased'

# Sequence-classification model configured for five output labels.
model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    num_labels=5,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
def tokenize_function(examples):
    """Tokenize the ``'text'`` entry of a batch of examples.

    The module-level ``tokenizer`` is called with ``padding='max_length'`` and
    ``truncation=True`` and its result is returned unchanged.

    Args:
        examples: Mapping with a ``'text'`` key holding the raw text(s).

    Returns:
        Whatever ``tokenizer`` returns for the given texts.
    """
    encode_options = {'padding': 'max_length', 'truncation': True}
    return tokenizer(examples['text'], **encode_options)

# Tokenize every split in batches, drop the raw text column, and rename the
# target column from 'label' to 'labels'.
tokenized_dataset = (
    dataset
    .map(tokenize_function, batched=True)
    .remove_columns(['text'])
    .rename_column('label', 'labels')
)

# Switch the output format to 'torch' (in place), then inspect one record.
tokenized_dataset.set_format('torch')
pprint(tokenized_dataset['train'][0])

{'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0

In [5]:
batch_size = 16

# Number of examples kept per split; each split is shuffled with a fixed seed
# before its leading examples are selected.
subset_sizes = {'train': 10_000, 'test': 1_000}
subsets = {
    split: tokenized_dataset[split].shuffle(seed=0).select(range(size))
    for split, size in subset_sizes.items()
}
train_subset = subsets['train']
test_subset = subsets['test']

# Only the training loader reshuffles its data.
train_dataloader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_subset, batch_size=batch_size)

In [6]:
# Setup
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
optimizer = optim.AdamW(model.parameters(), lr=5e-5, fused=True)
num_epochs = 1
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    'linear',
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)
progress_bar = tqdm(range(num_training_steps))

# Optimizations
torch.cuda.empty_cache() # good practice
torch.set_float32_matmul_precision('high') # TODO: this does nothing without bfloat16 support
model.to(device).train()
model = torch.compile(model)

# Training Loop
for epoch in range(num_epochs):
    for batch in train_dataloader:
        t0 = time.time()
        
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.autocast(device_type=device.type, dtype=torch.float16): # TODO: change to bfloat16 when support to prevent overflow (or use gradient scalers otherwise)
            output = model(**batch)
        loss = output.loss
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        loss.backward()
        
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        
        torch.cuda.synchronize() # wait for everything to finish running
        t1 = time.time()
        dt = (t1 - t0) * 1000
        throughput = (batch_size * 512) / (t1 - t0)
        print(f"step: {progress_bar.n}/{num_training_steps}\tloss: {loss.item():.4f}\tlr: {optimizer.param_groups[0]['lr']:.3e}\tdt: {dt:.1f} ms\tthroughput: {throughput:.2f} tok/s\tnorm: {norm:.4f}")
        
        if progress_bar.n == 3:
            break

  0%|          | 0/625 [00:00<?, ?it/s]

step: 1/625	loss: 1.7444	lr: 4.992e-05	dt: 10772.5 ms	throughput: 760.45 tok/s	norm: 0.0000
step: 2/625	loss: 1.7254	lr: 4.984e-05	dt: 8447.9 ms	throughput: 969.71 tok/s	norm: 0.0000
step: 3/625	loss: 1.7051	lr: 4.976e-05	dt: 288.9 ms	throughput: 28356.29 tok/s	norm: 0.0000


In [None]:

# batch size 4
# step: 1/2500	loss: 1.7454	lr: 4.998e-05	dt: 338.5 ms	throughput: 6051.09 tok/s
# step: 2/2500	loss: 1.8611	lr: 4.996e-05	dt: 294.1 ms	throughput: 6962.79 tok/s
# step: 3/2500	loss: 1.6569	lr: 4.994e-05	dt: 295.1 ms	throughput: 6940.80 tok/s

# batch size 8
# step: 1/1250	loss: 1.5527	lr: 4.996e-05	dt: 813.1 ms	throughput: 5037.56 tok/s
# step: 2/1250	loss: 1.7109	lr: 4.992e-05	dt: 541.6 ms	throughput: 7562.55 tok/s
# step: 3/1250	loss: 1.7516	lr: 4.988e-05	dt: 552.3 ms	throughput: 7415.65 tok/s

# batch size 16
# step: 1/625	loss: 1.5491	lr: 4.992e-05	dt: 1366.7 ms	throughput: 5993.87 tok/s
# step: 2/625	loss: 1.6528	lr: 4.984e-05	dt: 1068.3 ms	throughput: 7668.55 tok/s
# step: 3/625	loss: 1.5627	lr: 4.976e-05	dt: 1064.3 ms	throughput: 7697.24 tok/s

# batch size 32 (no longer fits in VRAM)
# step: 1/313	loss: 1.6456	lr: 4.984e-05	dt: 34043.3 ms	throughput: 481.27 tok/s
# step: 2/313	loss: 1.5949	lr: 4.968e-05	dt: 35184.1 ms	throughput: 465.66 tok/s
# step: 3/313	loss: 1.6386	lr: 4.952e-05	dt: 34717.0 ms	throughput: 471.93 tok/s

# batch size 16 (with reduced matmul precision - this will be better with bfloat16 support)
# step: 1/625	loss: 1.5726	lr: 4.992e-05	dt: 1292.8 ms	throughput: 6336.66 tok/s
# step: 2/625	loss: 1.5232	lr: 4.984e-05	dt: 996.7 ms	throughput: 8219.30 tok/s
# step: 3/625	loss: 1.6037	lr: 4.976e-05	dt: 1060.8 ms	throughput: 7722.38 tok/s

# batch size 16 (with reduced matmul precision and reduced precision)
# step: 1/625	loss: 1.2793	lr: 4.992e-05	dt: 607.1 ms	throughput: 13493.47 tok/s
# step: 2/625	loss: 1.3772	lr: 4.984e-05	dt: 314.9 ms	throughput: 26017.63 tok/s
# step: 3/625	loss: 1.3931	lr: 4.976e-05	dt: 313.7 ms	throughput: 26116.45 tok/s

# batch size 16 (with reduced matmul precision and reduced precision and torch compile)
# W0215 14:21:50.430000 6022 torch/_inductor/utils.py:1137] [0/0] Not enough SMs to use max_autotune_gemm mode
# step: 1/625	loss: 1.3195	lr: 4.992e-05	dt: 29804.7 ms	throughput: 274.86 tok/s
# step: 2/625	loss: 1.2415	lr: 4.984e-05	dt: 295.9 ms	throughput: 27686.87 tok/s
# step: 3/625	loss: 1.0681	lr: 4.976e-05	dt: 281.0 ms	throughput: 29149.65 tok/s

# batch size 16 (with reduced matmul precision and reduced precision and torch compile and fused AdamW with gradient clipping)
# step: 1/625	loss: 1.7444	lr: 4.992e-05	dt: 10772.5 ms	throughput: 760.45 tok/s	norm: 0.0000
# step: 2/625	loss: 1.7254	lr: 4.984e-05	dt: 8447.9 ms	throughput: 969.71 tok/s	norm: 0.0000
# step: 3/625	loss: 1.7051	lr: 4.976e-05	dt: 288.9 ms	throughput: 28356.29 tok/s	norm: 0.0000

In [8]:
# TODO:
# - Adjust hyperparameters (optimizer perhaps)
# - Gradient accumulation