In [20]:
from functools import partial
from typing import Dict, Optional

import numpy as np
import torch
from datasets import load_dataset
from lightning.fabric import seed_everything
from torch import Tensor
from torch.func import functional_call, grad, vmap  # type: ignore
from torch.types import Device
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from energizer.datastores import PandasDataStoreForSequenceClassification

In [2]:
def rademacher(
    shape,
    generator: Optional[torch.Generator] = None,
    device: Optional[Device] = None,
) -> torch.Tensor:
    """Draw a tensor whose entries are independently -1.0 or +1.0.

    Args:
        shape: Size of the tensor to sample, as accepted by ``torch.randint``.
        generator: Optional random generator used for the draw.
        device: Optional device on which the tensor is created.

    Returns:
        A floating-point tensor of the requested shape with entries in {-1.0, 1.0}.
    """
    # Integer coin flips in {0, 1}, then the affine map 0 -> -1.0 and 1 -> +1.0.
    coin_flips = torch.randint(low=0, high=2, size=shape, generator=generator, device=device)
    return 2.0 * coin_flips - 1.0

In [3]:
model_name = "google/bert_uncased_L-2_H-128_A-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Tokenize the "text" column in batches, using 4 worker processes.
ds_dict = load_dataset("pietrolesci/pubmed-rct20k_indexed").map(
    lambda ex: tokenizer(ex["text"]),
    batched=True,
    num_proc=4,
)

Found cached dataset parquet (/home/pl487/.cache/huggingface/datasets/pietrolesci___parquet/pietrolesci--pubmed-rct20k_indexed-58c1c04dc03e65ed/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/pl487/.cache/huggingface/datasets/pietrolesci___parquet/pietrolesci--pubmed-rct20k_indexed-58c1c04dc03e65ed/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-999f0e823471e99e_*_of_00004.arrow
Loading cached processed dataset at /home/pl487/.cache/huggingface/datasets/pietrolesci___parquet/pietrolesci--pubmed-rct20k_indexed-58c1c04dc03e65ed/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-7c0ef1ff89bb0eab_*_of_00004.arrow
Loading cached processed dataset at /home/pl487/.cache/huggingface/datasets/pietrolesci___parquet/pietrolesci--pubmed-rct20k_indexed-58c1c04dc03e65ed/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-db9896ae5bc093c4_*_of_00004.arrow


In [8]:
# Build the datastore from the tokenized dataset dict: the tokenizer outputs are the
# model inputs, "labels" is the target column and "uid" identifies each example.
ds = PandasDataStoreForSequenceClassification()
ds.from_dataset_dict(
    ds_dict,
    tokenizer=tokenizer,
    uid_name="uid",
    input_names=["input_ids", "attention_mask"],
    target_name="labels",
)

In [9]:
# Configure the datastore's loaders with a batch size of 1 for both the regular and
# the evaluation setting.
ds.prepare_for_loading(batch_size=1, eval_batch_size=1)

In [17]:
# Grab one example batch from the "test" split for inspection.
batch = ds.show_batch("test")
# Remove the "on_cpu" entry so only the remaining entries are left in `batch`
# (a later cell passes `batch` to the model via `model(**batch)`).
_ = batch.pop("on_cpu")
batch

{'input_ids': tensor([[  101,  2023,  2817, 16578, 11290,  3853, 28828,  1999,  2540,  4945,
           5022,  4914,  2007,  5729, 11325, 21933,  8737,  6132,  4383,  2540,
           4945,  1006,  4748,  2232,  2546,  1007,  1012,   102]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1]]),
 <InputKeys.TARGET: 'labels'>: tensor([0])}

In [19]:
seed_everything(42)

# Instantiate the classifier from the tokenizer's checkpoint; the label mappings and
# the number of labels are taken from the datastore.
model = AutoModelForSequenceClassification.from_pretrained(
    ds.tokenizer.name_or_path,  # type: ignore
    num_labels=len(ds.labels),
    id2label=ds.id2label,
    label2id=ds.label2id,
)

# Detached name -> tensor views of the model state, in the form torch.func.functional_call takes.
params = {name: tensor.detach() for name, tensor in model.named_parameters()}
buffers = {name: tensor.detach() for name, tensor in model.named_buffers()}

Global seed set to 42
Some weights of the model checkpoint at google/bert_uncased_L-2_H-128_A-2 were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSe

In [None]:
def compute_loss(
    params: Dict, buffers: Dict, input_ids: Tensor, attention_mask: Tensor, labels: Tensor
) -> torch.Tensor:
    """Evaluate the loss of the global ``model`` with the given parameters and buffers.

    Each of ``input_ids``, ``attention_mask`` and ``labels`` receives a new leading
    dimension of size 1 before the model is called.

    Args:
        params: Mapping from parameter name to tensor, used in place of the model's own.
        buffers: Mapping from buffer name to tensor, used in place of the model's own.
        input_ids: Passed to the model as its first positional argument.
        attention_mask: Passed to the model as its second positional argument.
        labels: Passed to the model as the ``labels`` keyword argument.

    Returns:
        The ``loss`` attribute of the model output.
    """
    output = functional_call(
        model,
        (params, buffers),
        (input_ids.unsqueeze(0), attention_mask.unsqueeze(0)),
        kwargs={"labels": labels.unsqueeze(0)},
    )
    return output.loss

def _grad_norm(grads: Dict, norm_type: int) -> Tensor:
    norms = [g.norm(norm_type).unsqueeze(0) for g in grads.values() if g is not None]
    return torch.concat(norms).norm(norm_type)

# _grad_norm with the norm order fixed to 2.
grad_norm = partial(_grad_norm, norm_type=2)
# torch.func.grad differentiates with respect to the first positional argument by
# default, i.e. the `params` mapping of compute_loss.
compute_grad = grad(compute_loss)

def compute_grad_norm(
    params: Dict, buffers: Dict, input_ids: Tensor, attention_mask: Tensor, labels: Tensor
) -> torch.Tensor:
    """Return the norm of the gradient of ``compute_loss`` with respect to ``params``.

    The gradient mapping comes from ``compute_grad`` and is reduced to a single value
    by ``grad_norm``; all arguments are forwarded unchanged.
    """
    return grad_norm(compute_grad(params, buffers, input_ids, attention_mask, labels))

# Vectorised version of compute_grad_norm: params and buffers are shared across the
# batch (in_dims None), while input_ids / attention_mask / labels are mapped over
# their dimension 0. randomness="same" makes random ops draw the same values for
# every element of the batch.
compute_grad_norm_vect = vmap(compute_grad_norm, in_dims=(None, None, 0, 0, 0), randomness="same")

def most_influential_ids(datastore, params: Dict, buffers: Dict, num_influential: int) -> list:
    """Return the ids of the training examples with the largest gradient norm.

    The original cell was pasted from a method body: it used ``self`` and a top-level
    ``return``, so it could not even be compiled in the notebook. This function takes
    the same inputs explicitly instead.

    Args:
        datastore: Object exposing ``train_loader()``, e.g. the notebook's ``ds``.
        params: Mapping from parameter name to tensor (the dict form expected by
            ``compute_grad_norm_vect``, not a list of parameters).
        buffers: Mapping from buffer name to tensor.
        num_influential: How many ids to return; values <= 0 yield an empty list.

    Returns:
        Up to ``num_influential`` ids, ordered by ascending gradient norm.
    """
    # NOTE(review): module path of the enums is assumed -- confirm against the
    # installed energizer version.
    from energizer.enums import InputKeys, SpecialKeys

    # Guard needed because `[-0:]` would select every element instead of none.
    if num_influential <= 0:
        return []

    # Stand-in for the estimator's `transfer_to_device`: put the inputs where the
    # parameters live.
    device = next(iter(params.values())).device

    norms, ids = [], []
    for batch in datastore.train_loader():
        b_inp = batch[InputKeys.INPUT_IDS].to(device)
        b_att = batch[InputKeys.ATT_MASK].to(device)
        b_lab = batch[InputKeys.TARGET].to(device)
        norms += compute_grad_norm_vect(params, buffers, b_inp, b_att, b_lab).tolist()
        ids += batch[InputKeys.ON_CPU][SpecialKeys.ID]

    norms = np.asarray(norms)
    ids = np.asarray(ids)
    topk_ids = norms.argsort()[-num_influential:]  # biggest gradient norm
    return ids[topk_ids].tolist()


# Example call, to run while `params` is still the name -> tensor dict:
#   most_influential_ids(ds, params, buffers, num_influential=10)

In [None]:
seed_everything(42)

# Re-create the classifier from the same checkpoint and datastore label metadata,
# then switch it to evaluation mode.
model = AutoModelForSequenceClassification.from_pretrained(
    ds.tokenizer.name_or_path,
    num_labels=len(ds.labels),
    id2label=ds.id2label,
    label2id=ds.label2id,
)
_ = model.eval()

In [None]:
# Reset parameter gradients before the fresh backward pass.
model.zero_grad()
# Forward pass on the example batch; the model output exposes the loss as `.loss`.
loss = model(**batch).loss
# create_graph=True records the graph of the backward pass itself, so the p.grad
# tensors populated here can be differentiated again in the cells below.
torch.autograd.backward(loss, create_graph=True)

In [None]:
# Estimate the diagonal of the Hessian as the average over random sign vectors z of
# z * (H @ z), using Hessian-vector products obtained by differentiating the gradients.
#
# NOTE: `params` is rebound here to a *list* of trainable parameters; earlier cells
# use `params` as a name -> tensor dict.
params = [p for p in model.parameters() if p.requires_grad]
generator = torch.Generator().manual_seed(2147483647)
# These gradients must come from a backward pass run with create_graph=True,
# otherwise they cannot be differentiated again.
grads = [p.grad for p in params]
n_samples = 10

# One accumulator per parameter, same shape as the parameter.
hess = [torch.zeros_like(p) for p in params]

for _ in range(n_samples):
    # Random sign vectors with entries in {-1.0, 1.0}, one per parameter.
    zs = [rademacher(p.size(), generator=generator) for p in params]

    # retain_graph=True: the same graph is differentiated on every iteration (and on
    # re-runs of this cell); freeing it would make the second call raise.
    h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=True)

    # Every sample, including the first, contributes z * (H @ z) / n_samples.
    for d, hvp, v in zip(hess, h_zs, zs):
        d += v * hvp / n_samples

In [None]:
# Estimate the trace of the Hessian by averaging v^T (H @ v) over random sign vectors
# v, one parameter tensor at a time (so only each parameter's own diagonal block of
# the Hessian is used, which is all the trace needs).
params = [p for p in model.parameters() if p.requires_grad]
generator = torch.Generator().manual_seed(2147483647)
# These gradients must come from a backward pass run with create_graph=True.
grads = [p.grad for p in params]
n_samples = 10

V = n_samples  # number of random probe vectors per parameter
trace = 0.0

for _ in range(V):
    for p, g in zip(params, grads):
        # Seeded generator so the estimate is reproducible.
        v = [rademacher(p.size(), generator=generator)]
        # Hessian-vector product for this parameter: differentiate its gradient with
        # v as the upstream vector. The graph is kept for the following iterations.
        Hv = torch.autograd.grad([g], [p], grad_outputs=v, retain_graph=True)
        vHv = torch.einsum("i,i->", v[0].flatten(), Hv[0].flatten())

        trace += vHv / V

In [None]:
# Scratch check: size of the first entry of `params` (requires the list form of `params`).
params[0].size()

In [None]:
# Scratch check: a bare name only displays the function object; it does not call it.
rademacher

In [None]:
# Scratch check: shape of the first Hessian-vector product from the last loop iteration.
h_zs[0].shape

In [None]:
# Scratch check: shape of the first entry of `params`, for comparison with h_zs[0].shape.
params[0].shape

In [None]:
# NOTE(review): this cell always raises IndexError (index assignment into an empty
# list) and, before failing, rebinds `hess` to an empty list, discarding any value
# computed earlier. Looks like a leftover experiment -- confirm it can be removed.
hess = []
hess[1] = 2

In [None]:
# Scratch check: displays every Hessian-vector product tensor in full (large output).
h_zs

In [None]:
# Scratch check: the three sequences are zipped / passed together to autograd,
# so their lengths are expected to match.
len(grads), len(params), len(zs)

In [None]:
# Scratch check: whether the first sampled sign vector is tracked by autograd.
zs[0].requires_grad

In [None]:
# Scratch check: values of the first sampled sign vector.
zs[0]

In [None]:
# Scratch check: gradient stored on the first entry of `params`.
grads[0]