In [1]:
import random
import math

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)

# TraceML imports
from traceml.decorator import trace_model_instance
from traceml.manager.tracker_manager import TrackerManager

In [2]:
SEED = 42
MODEL_NAME = "distilbert-base-uncased"
MAX_TRAIN_EXAMPLES = 500
MAX_VAL_EXAMPLES = 100
BATCH_SIZE = 32
EPOCHS = 1
LR = 2e-5
WARMUP_RATIO = 0.06


In [3]:
def set_seed(seed: int = SEED):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def accuracy_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> float:
    preds = torch.argmax(logits, dim=-1)
    correct = (preds == labels).sum().item()
    return correct / max(1, labels.size(0))


def prepare_data():
    raw = load_dataset("ag_news")
    train_raw = raw["train"].select(range(min(MAX_TRAIN_EXAMPLES, len(raw["train"]))))
    val_raw = raw["test"].select(range(min(MAX_VAL_EXAMPLES, len(raw["test"]))))

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    def tok(examples):
        return tokenizer(examples["text"], truncation=True)

    train_ds = train_raw.map(tok, batched=True, remove_columns=["text"])
    val_ds = val_raw.map(tok, batched=True, remove_columns=["text"])

    train_ds = train_ds.rename_column("label", "labels")
    val_ds = val_ds.rename_column("label", "labels")

    collator = DataCollatorWithPadding(tokenizer=tokenizer, padding="longest")

    train_loader = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collator
    )
    val_loader = DataLoader(
        val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collator
    )
    return tokenizer, train_loader, val_loader


In [4]:
def train_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    scheduler,
    scaler,
    device,
    dtype,
    epochs: int = 1
):
    """
    Train and validate a model.

    Args:
        model: torch.nn.Module
        train_loader: DataLoader for training
        val_loader: DataLoader for validation
        optimizer: torch optimizer
        scheduler: learning rate scheduler
        scaler: GradScaler for mixed precision
        device: torch.device ("cpu" or "cuda")
        dtype: torch dtype (torch.float16 or torch.float32)
        epochs: number of training epochs
    """
    model.train()
    global_step = 0

    for epoch in range(epochs):
        running_loss = 0.0
        running_acc = 0.0

        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}

            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=torch.cuda.is_available(), dtype=dtype):
                out = model(**batch)
                loss = out.loss
                logits = out.logits

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            acc = accuracy_from_logits(logits.detach(), batch["labels"])
            running_loss += loss.item()
            running_acc += acc
            global_step += 1

        # Validation
        model.eval()
        val_loss, val_acc, n_batches = 0.0, 0.0, 0
        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                with torch.cuda.amp.autocast(enabled=torch.cuda.is_available(), dtype=dtype):
                    out = model(**batch)
                    loss = out.loss
                    logits = out.logits
                val_loss += loss.item()
                val_acc += accuracy_from_logits(logits, batch["labels"])
                n_batches += 1

        model.train()

In [None]:
set_seed()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

tokenizer, train_loader, val_loader = prepare_data()

model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=4
).to(device)

# Attach TraceML hooks
trace_model_instance(
    model,
    sample_layer_memory=True,
    trace_activations=True,
    trace_gradients=True,
)

optimizer = AdamW(model.parameters(), lr=LR)
total_steps = EPOCHS * math.ceil(len(train_loader))
warmup_steps = int(WARMUP_RATIO * total_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
)

scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

tracker = TrackerManager(interval_sec=1.0, mode='notebook')
tracker.start()
train_model(model, train_loader, val_loader, optimizer, scheduler, scaler, device, dtype)
tracker.stop()
tracker.log_summaries()


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())


VBox()

  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available(), dtype=dtype):
  scaler.scale(loss).backward()
