In [12]:
from functools import partial
from typing import Dict, Optional

import torch
from datasets import load_dataset
from lightning.fabric import seed_everything
from torch import Tensor
from torch.func import functional_call, grad, hessian, jacfwd, jacrev, jvp, vmap  # type: ignore
from torch.types import Device
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from energizer.datastores import PandasDataStoreForSequenceClassification

In [2]:
def rademacher(
    shape,
    generator: Optional[torch.Generator] = None,
    device: Optional[Device] = None,
) -> torch.Tensor:
    """Sample from Rademacher distribution."""
    return torch.randint(0, 2, shape, generator=generator, device=device) * 2.0 - 1.0

In [3]:
model_name = "google/bert_uncased_L-2_H-128_A-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
ds_dict = load_dataset("pietrolesci/pubmed-rct20k_indexed").map(
    lambda ex: tokenizer(ex["text"]), batched=True, num_proc=4
)

Found cached dataset parquet (/home/pl487/.cache/huggingface/datasets/pietrolesci___parquet/pietrolesci--pubmed-rct20k_indexed-58c1c04dc03e65ed/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/pl487/.cache/huggingface/datasets/pietrolesci___parquet/pietrolesci--pubmed-rct20k_indexed-58c1c04dc03e65ed/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-999f0e823471e99e_*_of_00004.arrow
Loading cached processed dataset at /home/pl487/.cache/huggingface/datasets/pietrolesci___parquet/pietrolesci--pubmed-rct20k_indexed-58c1c04dc03e65ed/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-7c0ef1ff89bb0eab_*_of_00004.arrow
Loading cached processed dataset at /home/pl487/.cache/huggingface/datasets/pietrolesci___parquet/pietrolesci--pubmed-rct20k_indexed-58c1c04dc03e65ed/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-db9896ae5bc093c4_*_of_00004.arrow


In [4]:
ds = PandasDataStoreForSequenceClassification()
ds.from_dataset_dict(
    ds_dict,
    input_names=["input_ids", "attention_mask"],
    target_name="labels",
    tokenizer=tokenizer,
    uid_name="uid",
)

In [6]:
batch = ds.show_batch("test")
_ = batch.pop("on_cpu")
batch

{'input_ids': tensor([[  101,  2023,  2817, 16578, 11290,  3853, 28828,  1999,  2540,  4945,
           5022,  4914,  2007,  5729, 11325, 21933,  8737,  6132,  4383,  2540,
           4945,  1006,  4748,  2232,  2546,  1007,  1012,   102]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1]]),
 <InputKeys.TARGET: 'labels'>: tensor([0])}

In [40]:
ds.label(list(range(100)), round=0)

100

In [42]:
ds.train_size()

100

In [43]:
ds.prepare_for_loading(batch_size=32, eval_batch_size=256)

In [64]:
seed_everything(42)

# load model using data properties
model = AutoModelForSequenceClassification.from_pretrained(
    ds.tokenizer.name_or_path,  # type: ignore
    id2label=ds.id2label,
    label2id=ds.label2id,
    num_labels=len(ds.labels),
)

params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

Global seed set to 42
Some weights of the model checkpoint at google/bert_uncased_L-2_H-128_A-2 were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSe

In [65]:
params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

In [66]:
# compute full gradient
for batch in ds.train_loader():
    batch.pop("on_cpu")
    loss = model(**batch).loss
    loss.backward()

full_grad = {n: p.grad for n, p in model.named_parameters() if p.requires_grad}

In [71]:
b_inp, b_att, b_lab = batch.values()
inp, att, lab = b_inp[0].squeeze(), b_att[0].squeeze(), b_lab[0].squeeze()

In [85]:
# loss


def _compute_loss(
    model,
    params: Dict,
    buffers: Dict,
    input_ids: Tensor,
    attention_mask: Tensor,
    labels: Tensor,
) -> torch.Tensor:
    inp, att, lab = (
        input_ids.unsqueeze(0),
        attention_mask.unsqueeze(0),
        labels.unsqueeze(0),
    )
    return functional_call(
        model, (params, buffers), (inp, att), kwargs={"labels": lab}
    ).loss


compute_loss = partial(_compute_loss, model)
compute_loss_batch = vmap(compute_loss, in_dims=(None, None, 0, 0, 0))

In [75]:
compute_loss(params, buffers, inp, att, lab)

tensor(1.5007)

In [77]:
compute_loss_batch(params, buffers, b_inp, b_att, b_lab)

tensor([1.5007, 1.5600, 1.5448, 1.4515])

In [None]:
# grads
compute_grad = grad(compute_loss)
compute_grad_batch = vmap(compute_grad, in_dims=(None, None, 0, 0, 0))

In [79]:
sample_grad = compute_grad(params, buffers, inp, att, lab)
len(sample_grad)

41

In [84]:
batch_grad = compute_grad_batch(params, buffers, b_inp, b_att, b_lab)
len(batch_grad)

41

In [87]:
# norm


def _gradnorm(grads: Dict, norm_type: int) -> Tensor:
    norms = [g.norm(norm_type).unsqueeze(0) for g in grads.values() if g is not None]
    return torch.concat(norms).norm(norm_type)


gradnorm = partial(_gradnorm, norm_type=2)


def compute_gradnorm(params, buffers, input_ids, attention_mask, labels):
    grads = compute_grad(params, buffers, input_ids, attention_mask, labels)
    return gradnorm(grads)


compute_gradnorm_batch = vmap(
    compute_gradnorm, in_dims=(None, None, 0, 0, 0), randomness="same"
)

In [88]:
compute_gradnorm(params, buffers, inp, att, lab)

tensor(9.8751)

In [89]:
compute_gradnorm_batch(params, buffers, b_inp, b_att, b_lab)

tensor([ 9.8751, 10.0391,  8.4983, 10.0508])

In [96]:
# norm of project along the full gradient


def _gradnorm_full(grad: Dict, full_grad: Dict, norm_type: int) -> Tensor:
    # compute product
    prods = {k: full_grad[k] * grad[k] for k in full_grad}
    return _gradnorm(prods, norm_type)


gradnorm_full = partial(_gradnorm_full, full_grad=full_grad, norm_type=2)


def compute_gradnorm_full(params, buffers, input_ids, attention_mask, labels):
    grads = compute_grad(params, buffers, input_ids, attention_mask, labels)
    return gradnorm_full(grads)


compute_gradnorm_full_batch = vmap(
    compute_gradnorm_full, in_dims=(None, None, 0, 0, 0), randomness="same"
)

In [94]:
compute_gradnorm_full(params, buffers, inp, att, lab)

tensor(4.2527)

In [97]:
compute_gradnorm_full_batch(params, buffers, b_inp, b_att, b_lab)

tensor([4.2526, 3.5917, 2.0297, 3.9550])

In [115]:
# cosine similarity with full gradient


def _cosine_similarity(grad, full_grad):
    return torch.nn.functional.cosine_similarity(
        torch.cat([p.flatten() for p in full_grad.values()]),
        torch.cat([p.flatten() for p in grad.values()]),
        dim=0,
    )


cosine_similarity = partial(_cosine_similarity, full_grad=full_grad)


def compute_cosine_similarity(params, buffers, input_ids, attention_mask, labels):
    grads = compute_grad(params, buffers, input_ids, attention_mask, labels)
    return cosine_similarity(grads)


compute_cosine_similarity_batch = vmap(
    compute_cosine_similarity, in_dims=(None, None, 0, 0, 0), randomness="same"
)

In [116]:
compute_cosine_similarity(params, buffers, inp, att, lab)

tensor(0.6617)

In [118]:
compute_cosine_similarity_batch(params, buffers, b_inp, b_att, b_lab)

tensor([ 0.6617,  0.4107, -0.0971,  0.6107])

In [121]:
# distance from full gradient normed


def _diff(grad, full_grad):
    return torch.cat([p.flatten() for p in grad.values()]) - torch.cat(
        [p.flatten() for p in full_grad.values()]
    )


diff = partial(_diff, full_grad=full_grad)


def compute_diff(params, buffers, input_ids, attention_mask, labels):
    grads = compute_grad(params, buffers, input_ids, attention_mask, labels)
    return diff(grads).norm(2)


compute_diff_batch = vmap(
    compute_diff, in_dims=(None, None, 0, 0, 0), randomness="same"
)

In [122]:
compute_diff(params, buffers, inp, att, lab)

tensor(7.9565)

In [123]:
compute_diff_batch(params, buffers, b_inp, b_att, b_lab)

tensor([ 7.9565, 10.5898, 13.3072,  8.6203])

In [32]:
def hvp(f, primals, tangents):
    return jvp(grad(f), primals, tangents)[1]


def f(x):
    return x.sin().sum()


x = torch.randn(2048)
tangent = torch.randn(2048)

result = hvp(f, (x,), (tangent,))

In [35]:
result.shape

torch.Size([2048])

In [31]:
# compute_grad(inp, att, lab),
vmap(grad(compute_loss))(b_inp, b_att, b_lab)

RuntimeError: only Tensors of floating point dtype can require gradients

In [None]:
compute_grad_norm_vect(b_inp, b_att, b_lab)

In [None]:
grads = vmap(compute_grad, in_dims=(None, None, 0, 0, 0), randomness="same")(
    params, buffers, b_inp, b_att, b_lab
)

In [None]:
# vmap(jacrev(f), in_dims=(None, None, 0, 0, 0), randomness="same")(params, buffers, b_inp, b_att, b_lab)

In [None]:
zs = [rademacher(p.size()) for p in grads.values()]

In [None]:
hvp(compute)

In [None]:
def trace(self, maxIter=100, tol=1e-3):
    """
    compute the trace of hessian using Hutchinson's method
    maxIter: maximum iterations used to compute trace
    tol: the relative tolerance
    """

    device = self.device
    trace_vhv = []
    trace = 0.0

    for i in range(maxIter):
        self.model.zero_grad()
        v = [torch.randint_like(p, high=2, device=device) for p in self.params]
        # generate Rademacher random variables
        for v_i in v:
            v_i[v_i == 0] = -1

        if self.full_dataset:
            _, Hv = self.dataloader_hv_product(v)
        else:
            Hv = hessian_vector_product(self.gradsH, self.params, v)
        trace_vhv.append(group_product(Hv, v).cpu().item())
        if abs(np.mean(trace_vhv) - trace) / (abs(trace) + 1e-6) < tol:
            return trace_vhv
        else:
            trace = np.mean(trace_vhv)

    return trace_vhv

In [None]:
norms, ids = [], []
for batch in tqdm(datastore.train_loader(), disable=True):
    batch = self.transfer_to_device(batch)
    b_inp, b_att, b_lab = (
        batch[InputKeys.INPUT_IDS],
        batch[InputKeys.ATT_MASK],
        batch[InputKeys.TARGET],
    )
    norms += compute_grad_norm_vect(params, buffers, b_inp, b_att, b_lab).tolist()
    ids += batch[InputKeys.ON_CPU][SpecialKeys.ID]

norms = np.array(norms)
ids = np.array(ids)
topk_ids = norms.argsort()[-self.num_influential :]  # biggest gradient norm
return ids[topk_ids].tolist()

In [None]:
seed_everything(42)
model = AutoModelForSequenceClassification.from_pretrained(
    ds.tokenizer.name_or_path,
    id2label=ds.id2label,
    label2id=ds.label2id,
    num_labels=len(ds.labels),
)
_ = model.eval()

In [None]:
model.zero_grad()
loss = model(**batch).loss
torch.autograd.backward(loss, create_graph=True)

In [None]:
params = list(filter(lambda p: p.requires_grad, model.parameters()))
generator = torch.Generator().manual_seed(2147483647)
grads = [p.grad for p in params]
n_samples = 10


for i in range(n_samples):
    zs = [
        torch.randint(0, 2, p.size(), generator=generator) * 2.0 - 1.0
        for p in params
    ]  # Rademacher distribution {-1.0, 1.0}

    h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=False)

    if i == 0:
        hess = h_zs
    else:
        for d, hvp, v in zip(hess, h_zs, zs):
            d += v * hvp / n_samples

    # for idx, (h_z, z) in enumerate(zip(h_zs, zs)):
    #     hess[idx] += h_z / n_samples  # approximate the expected values of z*(H@z)

In [None]:
params = list(filter(lambda p: p.requires_grad, model.parameters()))
generator = torch.Generator().manual_seed(2147483647)
grads = [p.grad for p in params]
n_samples = 10

for _ in range(V):
    for p in model.parameters():
        v = [rademacher(p.size())]
        Hv = hessian_vector_product(loss, [p], v)
        vHv = torch.einsum("i,i->", v[0].flatten(), Hv[0].flatten())

        trace += vHv / V