In [None]:
!pip install traceml-ai==0.1.1

In [1]:
import random
import math

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)

# TraceML imports
from traceml.decorator import trace_model_instance, trace_timestep
from traceml.manager.tracker_manager import TrackerManager

In [2]:
SEED = 42
MODEL_NAME = "distilbert-base-uncased"
MAX_TRAIN_EXAMPLES = 500
MAX_VAL_EXAMPLES = 100
BATCH_SIZE = 32
EPOCHS = 1
LR = 2e-5
WARMUP_RATIO = 0.06


In [3]:
def set_seed(seed: int = SEED):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def accuracy_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> float:
    preds = torch.argmax(logits, dim=-1)
    correct = (preds == labels).sum().item()
    return correct / max(1, labels.size(0))


def prepare_data():
    raw = load_dataset("ag_news")
    train_raw = raw["train"].select(range(min(MAX_TRAIN_EXAMPLES, len(raw["train"]))))
    val_raw = raw["test"].select(range(min(MAX_VAL_EXAMPLES, len(raw["test"]))))

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    def tok(examples):
        return tokenizer(examples["text"], truncation=True)

    train_ds = train_raw.map(tok, batched=True, remove_columns=["text"])
    val_ds = val_raw.map(tok, batched=True, remove_columns=["text"])

    train_ds = train_ds.rename_column("label", "labels")
    val_ds = val_ds.rename_column("label", "labels")

    collator = DataCollatorWithPadding(tokenizer=tokenizer, padding="longest")

    train_loader = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collator
    )
    val_loader = DataLoader(
        val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collator
    )
    return tokenizer, train_loader, val_loader


In [4]:
# --- TraceML step timers (granular measurement) ---

@trace_timestep("dataloader_fetch", use_gpu=False)
def get_next_batch(it):
    """Measure wait time for dataloader iteration."""
    return next(it)


@trace_timestep("data_loading", use_gpu=False)
def load_batch_to_device(batch, device):
    """Measure data transfer time (CPU → GPU)."""
    return {k: v.to(device, non_blocking=True) for k, v in batch.items()}


@trace_timestep("forward", use_gpu=True)
def forward_pass(model, batch, dtype):
    """Trace forward pass latency."""
    with torch.cuda.amp.autocast(enabled=torch.cuda.is_available(), dtype=dtype):
        return model(**batch)


@trace_timestep("backward", use_gpu=True)
def backward_pass(loss, scaler):
    """Trace backward pass (grad computation)."""
    scaler.scale(loss).backward()


@trace_timestep("optimizer_step", use_gpu=True)
def optimizer_step(scaler, optimizer, scheduler):
    """Trace optimizer + scheduler step."""
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()


@trace_timestep("validation", use_gpu=True)
def run_validation(model, val_loader, dtype, device):
    """Trace entire validation loop time."""
    model.eval()
    val_loss = 0.0
    val_acc = 0.0
    n_batches = 0
    with torch.no_grad():
        for batch in val_loader:
            batch = load_batch_to_device(batch, device)
            with torch.cuda.amp.autocast(enabled=torch.cuda.is_available(), dtype=dtype):
                out = model(**batch)
                loss = out.loss
                logits = out.logits
            val_loss += loss.item()
            val_acc += accuracy_from_logits(logits, batch["labels"])
            n_batches += 1
    model.train()
    return val_loss / max(1, n_batches), val_acc / max(1, n_batches)

In [5]:
def run_training(model, train_loader, val_loader, optimizer, scheduler, scaler, device, dtype):
    """Main training loop (timed via TraceML)."""
    model.train()
    global_step = 0
    train_iter = iter(train_loader)

    for epoch in range(EPOCHS):
        running_loss = 0.0
        running_acc = 0.0

        for _ in range(len(train_loader)):
            # Measure dataloader latency
            batch = get_next_batch(train_iter)

            # Transfer to GPU
            batch = load_batch_to_device(batch, device)

            optimizer.zero_grad(set_to_none=True)

            # Forward
            out = forward_pass(model, batch, dtype)
            loss = out.loss
            logits = out.logits

            # Backward
            backward_pass(loss, scaler)

            # Optimizer step
            optimizer_step(scaler, optimizer, scheduler)

            # Accuracy
            acc = accuracy_from_logits(logits.detach(), batch["labels"])
            running_loss += loss.item()
            running_acc += acc
            global_step += 1

            if global_step % 50 == 0:
                print(
                    f"[Train] epoch {epoch+1} step {global_step} | "
                    f"loss {running_loss/50:.4f} | acc {running_acc/50:.4f}"
                )
                running_loss = 0.0
                running_acc = 0.0

        # Validation after epoch
        val_loss, val_acc = run_validation(model, val_loader, dtype, device)
        print(f"[Val] epoch {epoch+1} | loss {val_loss:.4f} | acc {val_acc:.4f}")


In [None]:
set_seed()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

tokenizer, train_loader, val_loader = prepare_data()

model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=4
).to(device)

# Attach TraceML hooks (activations, gradients, etc.)
trace_model_instance(
    model,
    sample_layer_memory=True,
    trace_activations=True,
    trace_gradients=True,
)

optimizer = AdamW(model.parameters(), lr=LR)
total_steps = EPOCHS * math.ceil(len(train_loader))
warmup_steps = int(WARMUP_RATIO * total_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
)
scaler = torch.amp.GradScaler(device="cuda", enabled=torch.cuda.is_available())

tracker = TrackerManager(interval_sec=1.0, mode="notebook")
tracker.start()
run_training(model, train_loader, val_loader, optimizer, scheduler, scaler, device, dtype)
tracker.stop()
tracker.log_summaries()

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available(), dtype=dtype):


VBox()

  scaler.scale(loss).backward()
