In [1]:
from typing import Dict

import torch
from lightning.fabric import seed_everything
from torch.func import functional_call, grad, vmap
from transformers import AutoModelForSequenceClassification

from energizer.datastores import PandasDataStoreForSequenceClassification

In [2]:
# Load the prepared, already-tokenised datastore and display one example from the test split.
ds = PandasDataStoreForSequenceClassification.load(
    "../data/prepared/agnews_binarised_bert-tiny/"
)
ds.show_batch("test")

{'input_ids': tensor([[  101, 10069,  2005,  1056,  1050, 11550,  2044,  7566,  9209,  5052,
           3667,  2012,  6769,  2047,  8095,  2360,  2027,  2024,  1005,  9364,
           1005,  2044,  7566,  2007, 16654,  6687,  3813,  2976,  9587, 24848,
           1012,   102]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1]]),
 'labels': tensor([0]),
 <InputKeys.ON_CPU: 'on_cpu'>: {}}

In [3]:
ds.prepare_for_loading(batch_size=1)
sample = next(iter(ds.train_loader(passive=True)))
# batch_size=1: flattening drops the singleton batch dim of input_ids / attention_mask;
# labels keeps shape (1,).
s_inp, s_att, s_lab = (
    sample[key].flatten() for key in ("input_ids", "attention_mask", "labels")
)

ds.prepare_for_loading(batch_size=2)
batch = next(iter(ds.train_loader(passive=True)))
# Batched tensors (leading dim = 2) for the vmapped functions below.
b_inp, b_att, b_lab = (batch[key] for key in ("input_ids", "attention_mask", "labels"))

In [4]:
# Seed before loading so any weights that from_pretrained initialises randomly are reproducible.
seed_everything(42)
model = AutoModelForSequenceClassification.from_pretrained(
    ds.tokenizer.name_or_path,
    num_labels=len(ds.labels),
    id2label=ds.id2label,
    label2id=ds.label2id,
)

Global seed set to 42
Some weights of the model checkpoint at google/bert_uncased_L-2_H-128_A-2 were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSe

In [5]:
# Name -> detached tensor (sharing storage with the model) for torch.func.functional_call.
params = dict((name, tensor.detach()) for name, tensor in model.named_parameters())
buffers = dict((name, tensor.detach()) for name, tensor in model.named_buffers())

In [6]:
def compute_loss(params: Dict, buffers: Dict, input_ids, attention_mask, labels) -> torch.Tensor:
    """Loss of the global ``model`` on one unbatched example, using the given weights.

    ``params`` / ``buffers`` are substituted for the module's own tensors via
    ``functional_call``, which is what makes this function usable with
    ``torch.func.grad`` and ``vmap``.

    Each input gets a leading batch dimension of size 1 before the forward pass,
    and ``labels`` is passed as a keyword so the model returns its ``.loss``.

    NOTE(review): ``model`` is looked up as a global at call time, so rebinding
    it later in the notebook changes what this function evaluates.
    """
    batched_ids = input_ids.unsqueeze(0)
    batched_mask = attention_mask.unsqueeze(0)
    batched_labels = labels.unsqueeze(0)
    outputs = functional_call(
        model,
        (params, buffers),
        (batched_ids, batched_mask),
        kwargs={"labels": batched_labels},
    )
    return outputs.loss


compute_loss(params, buffers, input_ids=s_inp, attention_mask=s_att, labels=s_lab)

tensor(0.8419)

In [7]:
# Batched loss: weights/buffers are shared (None); the three data tensors are mapped over dim 0.
compute_loss_vect = vmap(compute_loss, in_dims=(None,) * 2 + (0,) * 3)
compute_loss_vect(params, buffers, b_inp, b_att, b_lab)

tensor([0.8419, 0.7595])

In [9]:
# Gradient of compute_loss w.r.t. its first argument (the params dict).
compute_grad = grad(compute_loss, argnums=0)


def grad_norm(grads, norm_type=2):
    """Norm of all gradients taken together, as if they formed one flat vector.

    Takes the ``norm_type``-norm of every gradient tensor, then the
    ``norm_type``-norm of those per-tensor norms. For p-norms this equals the
    p-norm of all gradient entries concatenated.

    Args:
        grads: mapping from name to gradient tensor; ``None`` values are skipped.
        norm_type: order of the norm, forwarded to ``Tensor.norm``. Defaults to 2.

    Returns:
        A 0-dim tensor. If no non-``None`` gradients are present, ``tensor(0.)``
        is returned rather than failing on an empty tensor list.
    """
    per_tensor = [g.norm(norm_type) for g in grads.values() if g is not None]
    if not per_tensor:
        return torch.tensor(0.0)
    # stack of 0-dim norms -> 1-D vector, whose norm is the global norm
    return torch.stack(per_tensor).norm(norm_type)


def compute_grad_norm(
    params: Dict, buffers: Dict, input_ids, attention_mask, labels
) -> torch.Tensor:
    """L2 norm of the loss gradient w.r.t. ``params`` for one unbatched example.

    Gradients come from ``compute_grad`` (``grad(compute_loss)``) and are reduced
    with ``grad_norm(..., norm_type=2)``. Works per example so it can be wrapped
    in ``vmap``.
    """
    return grad_norm(
        compute_grad(params, buffers, input_ids, attention_mask, labels),
        norm_type=2,
    )


# One gradient norm per example in the batch; weights/buffers are shared (None).
compute_grad_norm_vect = vmap(compute_grad_norm, in_dims=(None,) * 2 + (0,) * 3)
compute_grad_norm_vect(params, buffers, b_inp, b_att, b_lab).tolist()

[8.894059181213379, 7.85386848449707]

In [None]:
next(model.parameters())

In [None]:
compute_grad_vect = vmap(compute_grad, in_dims=(None, None, 0, 0, 0))
grads_vect = compute_grad_vect(params, buffers, b_inp, b_att, b_lab)
# vmap stacks results along dim 0, so every per-parameter gradient gains a leading
# batch dimension. Show all shapes (there is no `k` defined to index a single one).
assert all(g.shape[0] == b_inp.shape[0] for g in grads_vect.values())
{name: g.shape for name, g in grads_vect.items()}

In [None]:
# `grads` only exists as a local inside compute_grad_norm, so compute the
# single-example gradients explicitly before inspecting them.
grads = compute_grad(params, buffers, s_inp, s_att, s_lab)
next(iter(grads.values())).norm()

In [None]:
# Pass only the tensors the model consumes. The datastore batch also carries an
# InputKeys.ON_CPU entry, and pop("on_cpu") is not guaranteed to match that enum key.
# Selecting keys explicitly also leaves `sample` unmodified, so the cell is re-runnable.
model(
    input_ids=sample["input_ids"],
    attention_mask=sample["attention_mask"],
    labels=sample["labels"],
).loss

In [None]:
# Same transform as compute_loss_vect. The tensors are selected explicitly, so there
# is no need to pop extra entries out of `batch` (which mutated it in place).
f = vmap(compute_loss, in_dims=(None, None, 0, 0, 0))
f(params, buffers, batch["input_ids"], batch["attention_mask"], batch["labels"])

In [None]:
# Per-example gradients: grad w.r.t. params (argnums=0), vmapped over the data tensors.
ft_compute_grad = grad(compute_loss, argnums=0)
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None,) * 2 + (0,) * 3)

In [None]:
# Data tensors are passed positionally so they line up with in_dims=(None, None, 0, 0, 0).
ft_compute_sample_grad(
    params, buffers, *(batch[key] for key in ("input_ids", "attention_mask", "labels"))
)

In [None]:
# Fetch a fresh batch of 2 from the passive training loader (same calls as in In [3]).
ds.prepare_for_loading(batch_size=2)
batch = next(iter(ds.train_loader(passive=True)))

In [None]:
batch["input_ids"].size()

In [None]:
# vmap's in_dims describes positional arguments only, and keyword arguments are never
# mapped. `**batch` therefore cannot line up with in_dims=(None, None, 0, 0, 0).
# Pass the three tensors positionally; any extra datastore entries in `batch` are ignored.
ft_compute_sample_grad(
    params, buffers, batch["input_ids"], batch["attention_mask"], batch["labels"]
)

In [10]:
# Re-seed and reload so this cell reproduces the model from In [4] from scratch.
seed_everything(42)
model = AutoModelForSequenceClassification.from_pretrained(
    ds.tokenizer.name_or_path,
    num_labels=len(ds.labels),
    id2label=ds.id2label,
    label2id=ds.label2id,
)

# Name -> detached tensor (sharing storage with the model) for functional_call.
params = dict((name, tensor.detach()) for name, tensor in model.named_parameters())
buffers = dict((name, tensor.detach()) for name, tensor in model.named_buffers())


def compute_loss(params: Dict, buffers: Dict, input_ids, attention_mask, labels) -> torch.Tensor:
    """Loss of the global ``model`` on one unbatched example, using the given weights.

    ``params`` / ``buffers`` replace the module's own tensors via
    ``functional_call``, so this can be differentiated with ``torch.func.grad``
    and batched with ``vmap``.

    Each input gets a leading batch dimension of size 1, and ``labels`` is passed
    as a keyword so the model returns ``.loss``.

    NOTE(review): ``model`` is resolved as a global at call time; rebinding it
    later in the notebook changes what this function evaluates.
    """
    batched_ids = input_ids.unsqueeze(0)
    batched_mask = attention_mask.unsqueeze(0)
    batched_labels = labels.unsqueeze(0)
    outputs = functional_call(
        model,
        (params, buffers),
        (batched_ids, batched_mask),
        kwargs={"labels": batched_labels},
    )
    return outputs.loss


# Gradient of compute_loss w.r.t. its first argument (the params dict).
compute_grad = grad(compute_loss, argnums=0)


def grad_norm(grads, norm_type=2):
    """Norm of all gradients taken together, as if they formed one flat vector.

    Takes the ``norm_type``-norm of every gradient tensor, then the
    ``norm_type``-norm of those per-tensor norms. For p-norms this equals the
    p-norm of all gradient entries concatenated.

    Args:
        grads: mapping from name to gradient tensor; ``None`` values are skipped.
        norm_type: order of the norm, forwarded to ``Tensor.norm``. Defaults to 2.

    Returns:
        A 0-dim tensor. If no non-``None`` gradients are present, ``tensor(0.)``
        is returned rather than failing on an empty tensor list.
    """
    per_tensor = [g.norm(norm_type) for g in grads.values() if g is not None]
    if not per_tensor:
        return torch.tensor(0.0)
    # stack of 0-dim norms -> 1-D vector, whose norm is the global norm
    return torch.stack(per_tensor).norm(norm_type)


def compute_grad_norm(
    params: Dict, buffers: Dict, input_ids, attention_mask, labels
) -> torch.Tensor:
    """L2 norm of the loss gradient w.r.t. ``params`` for one unbatched example.

    Gradients come from ``compute_grad`` (``grad(compute_loss)``) and are reduced
    with ``grad_norm(..., norm_type=2)``. Works per example so it can be wrapped
    in ``vmap``.
    """
    return grad_norm(
        compute_grad(params, buffers, input_ids, attention_mask, labels),
        norm_type=2,
    )


# randomness="same": any random op inside the model draws identical values for every
# example in the vmapped batch (the default, "error", would raise on random ops).
compute_grad_norm_vect = vmap(
    compute_grad_norm,
    in_dims=(None,) * 2 + (0,) * 3,
    randomness="same",
)

Global seed set to 42
Some weights of the model checkpoint at google/bert_uncased_L-2_H-128_A-2 were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSe

In [11]:
ds.prepare_for_loading(batch_size=2)
batch = next(iter(ds.train_loader(passive=True)))
# Batched tensors (leading dim = 2) for compute_grad_norm_vect.
b_inp, b_att, b_lab = (batch[key] for key in ("input_ids", "attention_mask", "labels"))

In [13]:
# Same seed, checkpoint and first batch as In [9]; the norms match that cell's output.
# Arguments must stay positional: vmap only maps positional args per in_dims.
compute_grad_norm_vect(params, buffers, b_inp, b_att, b_lab)

tensor([8.8941, 7.8539])

In [14]:
next(model.parameters())

Parameter containing:
tensor([[-4.1018e-03, -3.0695e-02, -3.5295e-03,  ...,  1.8925e-02,
          3.7396e-03, -2.9233e-03],
        [-4.2748e-04, -3.6929e-02, -1.7168e-02,  ...,  2.9314e-02,
         -1.0398e-02,  2.6772e-02],
        [ 5.9418e-03,  4.2119e-03, -1.9566e-02,  ...,  1.6799e-02,
         -2.7802e-02, -6.9017e-03],
        ...,
        [ 3.5573e-02, -1.5891e-02,  4.9951e-03,  ...,  5.4071e-03,
         -1.1270e-02, -6.9528e-05],
        [-8.7018e-03, -2.2516e-02,  3.1993e-03,  ...,  2.7591e-02,
         -1.9554e-02,  2.4023e-03],
        [-7.8904e-02, -7.5407e-02, -4.6660e-03,  ..., -5.3340e-03,
         -4.4993e-02,  5.9842e-02]], requires_grad=True)

In [17]:
from lightning.fabric import Fabric

In [18]:
# NOTE(review): accelerator="gpu" presumably fails on a machine without a GPU;
# "auto" would fall back to CPU. Confirm before running this notebook elsewhere.
fabric = Fabric(accelerator="gpu")

In [20]:
# NOTE(review): this rebinds the global `model` that compute_loss reads at call time.
# compute_loss and the functions built on it now run on the Fabric-wrapped module
# (on the GPU, per the device check below), while `params`/`buffers` were taken from
# the unwrapped model. Confirm their key names and devices still match, or
# re-extract them, before reusing those helpers.
model = fabric.setup(model)

In [23]:
# Device of the module inside the Fabric wrapper; confirms setup moved it to the GPU.
model.module.device

device(type='cuda', index=0)