In [1]:
import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from torch.func import functional_call, grad, vmap
from torch.utils.data import Dataset
from transformers import RobertaForSequenceClassification, RobertaTokenizer

from torchinfluence.methods import GradientSimilarity

%load_ext autoreload
%autoreload 2

In [2]:
tokenizer = RobertaTokenizer.from_pretrained("aychang/roberta-base-imdb")
model = RobertaForSequenceClassification.from_pretrained("aychang/roberta-base-imdb")
imdb = load_dataset("imdb")

In [3]:
class IMDBDataset(Dataset):
    def __init__(self, dataset, tokenizer, split: str = "train"):
        self.dataset = dataset[split]
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        text = self.dataset[idx]["text"]
        label = self.dataset[idx]["label"]

        encoding = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128)

        return dict(encoding), torch.tensor(label).long()

    def decode(self, idx):
        input_ids = self[idx][0]["input_ids"][0]
        return self.tokenizer.decode(input_ids, skip_special_tokens=True)

    def __len__(self):
        return len(self.dataset)

In [4]:
train_dataset = IMDBDataset(imdb, tokenizer, split="train")
test_dataset = IMDBDataset(imdb, tokenizer, split="test")

In [5]:
parameter_subset = ["roberta.embeddings.word_embeddings.weight"]

params = dict(model.named_parameters())
params = {name: param for name, param in params.items() if name in parameter_subset}

In [6]:
from typing import Callable, List, Optional, Union


class HFSeqClfGradientSimilarity(GradientSimilarity):
    def __init__(
        self,
        model: torch.nn.Module,
        loss_fn: Callable,
        device: str = "cpu",
        parameter_subset: Optional[List[str]] = None,
    ):
        super().__init__(model, loss_fn, device, parameter_subset)

    def _compute_loss(self, params, inputs, targets):
        prediction = functional_call(self.model, params, (*inputs,))
        prediction = prediction.logits

        return self.loss_fn(prediction, targets)

    def dataset_gradients(self, inputs: Union[str, torch.Tensor], targets: torch.Tensor):
        inputs_shape = inputs["input_ids"].shape

        if (len(inputs_shape) == 3) and (inputs_shape[1] == 1):
            inputs = {k: v.squeeze(1) for k, v in inputs.items()}

        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        compute_grads = vmap(grad(self._compute_loss), in_dims=(None, 0, 0), chunk_size=self.chunk_size)
        grads = compute_grads(
            self.params,
            (
                input_ids.unsqueeze(1),
                attention_mask.unsqueeze(1),
            ),
            targets.unsqueeze(1),
        )
        grads = torch.hstack([g.flatten() for g in list(grads.values())]).reshape(inputs_shape[0], -1)
        return grads

In [72]:
gradsim = HFSeqClfGradientSimilarity(model, F.cross_entropy)

train_idxs = np.random.choice(np.arange(len(train_dataset)), 100, replace=False).tolist()
test_idx = 0

In [73]:
scores = gradsim.score(
    train_dataset,
    test_dataset,
    subset_ids={"test": [test_idx], "train": train_idxs},
    normalize=True,
    chunk_size=8,
)

  0%|          | 0/13 [00:00<?, ?it/s]

In [74]:
scores.shape  # (n_test, n_train)

torch.Size([1, 100])

In [86]:
test_dataset.decode(test_idx)

"I love sci-fi and am willing to put up with a lot. Sci-fi movies/TV are usually underfunded, under-appreciated and misunderstood. I tried to like this, I really did, but it is to good TV sci-fi as Babylon 5 is to Star Trek (the original). Silly prosthetics, cheap cardboard sets, stilted dialogues, CG that doesn't match the background, and painfully one-dimensional characters cannot be overcome with a'sci-fi' setting. (I'm sure there are those of you out there who think Babylon 5 is good sci-fi TV"

In [88]:
train_dataset.decode(scores[0].sort(descending=True).indices[0].item())

"I very much looked forward to this movie. Its a good family movie; however, if Michael Landon Jr.'s editing team did a better job of editing, the movie would be much better. Too many scenes out of context. I do hope there is another movie from the series, they're all very good. But, if another one is made, I beg them to take better care at editing. This story was all over the place and didn't seem to have a center. Which is unfortunate because the other movies of the series were great. I enjoy the story of Willie and Missy; they're both great role"