In [59]:
from typing import Optional

import torch
from lightning.fabric import seed_everything
from torch.types import Device
from transformers import AutoModelForSequenceClassification

from energizer.datastores import PandasDataStoreForSequenceClassification

In [61]:
def rademacher(
    shape,
    generator: Optional[torch.Generator] = None,
    device: Optional[Device] = None,
) -> torch.Tensor:
    """Sample from Rademacher distribution."""
    return torch.randint(0, 2, shape, generator=generator, device=device) * 2.0 - 1.0

In [2]:
ds = PandasDataStoreForSequenceClassification.load(
    "../data/prepared/agnews_binarised_bert-tiny/"
)
ds.show_batch("test")

{'input_ids': tensor([[  101, 10069,  2005,  1056,  1050, 11550,  2044,  7566,  9209,  5052,
           3667,  2012,  6769,  2047,  8095,  2360,  2027,  2024,  1005,  9364,
           1005,  2044,  7566,  2007, 16654,  6687,  3813,  2976,  9587, 24848,
           1012,   102]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1]]),
 'labels': tensor([0]),
 <InputKeys.ON_CPU: 'on_cpu'>: {}}

In [4]:
ds.prepare_for_loading()

In [96]:
for batch in ds.test_loader():
    batch.pop("on_cpu")
    loss = model(**batch).loss
    loss.backward(create_graph=True)

AttributeError: 'SequenceClassifierOutput' object has no attribute 'backward'

In [5]:
batch = ds.show_batch("test")
batch.pop("on_cpu")

{}

In [52]:
seed_everything(42)
model = AutoModelForSequenceClassification.from_pretrained(
    ds.tokenizer.name_or_path,
    id2label=ds.id2label,
    label2id=ds.label2id,
    num_labels=len(ds.labels),
)
_ = model.eval()

Global seed set to 42
Some weights of the model checkpoint at google/bert_uncased_L-2_H-128_A-2 were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSe

In [53]:
model.zero_grad()
loss = model(**batch).loss
torch.autograd.backward(loss, create_graph=True)

In [54]:
params = list(filter(lambda p: p.requires_grad, model.parameters()))
generator = torch.Generator().manual_seed(2147483647)
grads = [p.grad for p in params]
n_samples = 10


for i in range(n_samples):
    zs = [
        torch.randint(0, 2, p.size(), generator=generator) * 2.0 - 1.0
        for p in params
    ]  # Rademacher distribution {-1.0, 1.0}

    h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=False)

    if i == 0:
        hess = h_zs
    else:
        for d, hvp, v in zip(hess, h_zs, zs):
            d += v * hvp / n_samples

    # for idx, (h_z, z) in enumerate(zip(h_zs, zs)):
    #     hess[idx] += h_z / n_samples  # approximate the expected values of z*(H@z)

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [None]:
params = list(filter(lambda p: p.requires_grad, model.parameters()))
generator = torch.Generator().manual_seed(2147483647)
grads = [p.grad for p in params]
n_samples = 10

for _ in range(V):
    for p in model.parameters():
        v = [rademacher(p.size())]
        Hv = hessian_vector_product(loss, [p], v)
        vHv = torch.einsum("i,i->", v[0].flatten(), Hv[0].flatten())

        trace += vHv / V

In [56]:
params[0].size()

torch.Size([30522, 128])

In [None]:
rademacher

In [79]:
h_zs[0].shape

torch.Size([30522, 128])

In [80]:
params[0].shape

torch.Size([30522, 128])

In [55]:
hess = []
hess[1] = 2

IndexError: list assignment index out of range

In [53]:
h_zs

(tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 tensor([[ 7.1879, -4.2720,  0.4914,  ..., -0.5351,  2.2430,  5.4674],
         [ 1.3710, -2.0633,  0.6413,  ...,  0.6785, -0.6097,  3.3822],
         [ 1.1472, -0.9647,  0.4216,  ..., -0.1143,  0.2280,  1.4992],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]),
 tensor([[ 31.8962, -41.4615,  24.8657,  19.4978,  17.6549, -44.0062, -18.6706,
          -16.7598, -16.9078,  15.6570,  -9.1751,   5.3166, -11.1128,   3.5085,
          -64.4569,  -1.9440,  17.7808,  12.6853,  -3.4328, -28.2170, -14.6217,
           -0.6578, -15.4823,   2.3558,   6.6547, -53.19

In [24]:
len(grads), len(params), len(zs)

(41, 41, 41)

In [29]:
zs[0].requires_grad

False

In [33]:
zs[0]

tensor([[ 1., -1.,  1.,  ...,  1.,  1.,  1.],
        [-1.,  1.,  1.,  ...,  1., -1., -1.],
        [ 1.,  1., -1.,  ...,  1., -1.,  1.],
        ...,
        [ 1.,  1.,  1.,  ..., -1., -1., -1.],
        [ 1.,  1.,  1.,  ...,  1.,  1., -1.],
        [-1., -1.,  1.,  ..., -1.,  1.,  1.]], grad_fn=<SubBackward0>)

In [41]:
grads[0]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])