In [1]:
from typing import Callable, List, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from torch.func import functional_call, grad, vmap
from torch.utils.data import Dataset
from transformers import RobertaForSequenceClassification, RobertaTokenizer

from torchinfluence.methods import GradientSimilarity

%load_ext autoreload
%autoreload 2

In [2]:
tokenizer = RobertaTokenizer.from_pretrained("aychang/roberta-base-imdb")
model = RobertaForSequenceClassification.from_pretrained("aychang/roberta-base-imdb")
imdb = load_dataset("imdb")

In [3]:
class IMDBDataset(Dataset):
    def __init__(self, dataset, tokenizer, split: str = "train"):
        self.dataset = dataset[split]
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        text = self.dataset[idx]["text"]
        label = self.dataset[idx]["label"]

        encoding = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=256)

        return dict(encoding), torch.tensor(label).long()

    def decode(self, idx):
        input_ids = self[idx][0]["input_ids"][0]
        return self.tokenizer.decode(input_ids, skip_special_tokens=True)

    def __len__(self):
        return len(self.dataset)

In [4]:
train_dataset = IMDBDataset(imdb, tokenizer, split="train")
test_dataset = IMDBDataset(imdb, tokenizer, split="test")

In [5]:
parameter_subset = ["roberta.embeddings.word_embeddings.weight"]

params = dict(model.named_parameters())
params = {name: param for name, param in params.items() if name in parameter_subset}

In [6]:
class HFSeqClfGradientSimilarity(GradientSimilarity):
    def __init__(
        self,
        model: torch.nn.Module,
        loss_fn: Callable,
        device: str = "cpu",
        parameter_subset: Optional[List[str]] = None,
    ):
        super().__init__(model, loss_fn, device, parameter_subset)

    def _compute_loss(self, params, inputs, targets):
        prediction = functional_call(self.model, params, (*inputs,))
        prediction = prediction.logits

        return self.loss_fn(prediction, targets)

    def dataset_gradients(self, inputs: Union[str, torch.Tensor], targets: torch.Tensor):
        inputs_shape = inputs["input_ids"].shape

        if (len(inputs_shape) == 3) and (inputs_shape[1] == 1):
            inputs = {k: v.squeeze(1) for k, v in inputs.items()}

        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)
        targets = targets.to(self.device)

        compute_grads = vmap(grad(self._compute_loss), in_dims=(None, 0, 0), chunk_size=self.chunk_size)
        grads = compute_grads(
            self.params,
            (
                input_ids.unsqueeze(1),
                attention_mask.unsqueeze(1),
            ),
            targets.unsqueeze(1),
        )
        grads = torch.hstack([g.flatten() for g in list(grads.values())]).reshape(inputs_shape[0], -1)
        return grads

In [10]:
device = "cuda" if torch.cuda.is_available() else "cpu"
gradsim = HFSeqClfGradientSimilarity(model, F.cross_entropy, device=device, parameter_subset=parameter_subset)

n_train, n_test = 250, 1

np.random.seed(0)

train_idxs = np.random.choice(np.arange(len(train_dataset)), n_train, replace=False).tolist()
test_idx = np.random.choice(np.arange(len(test_dataset)), n_test, replace=False).tolist()

In [11]:
scores = gradsim.score(
    train_dataset,
    test_dataset,
    subset_ids={"test": test_idx, "train": train_idxs},
    normalize=True,
    chunk_size=4,
)

  0%|          | 0/63 [00:00<?, ?it/s]

In [12]:
print(scores.shape)  # (n_test, n_train)

torch.Size([1, 250])


In [13]:
print("TEST EXAMPLE")
print(test_dataset.decode(test_idx))
print("")

TEST EXAMPLE
I remember disliking this movie the 1st time I saw it, but it has grown on me. I love the costumes and poses the actors make, the humor, the cinematography, the soundtrack. The scenes are very rich, and it moves very quickly. Every time I watch it, there is something new that catches my eye. Aaliyah as Akasha is probably the only thing that ruins it, but not enough.<br /><br />Also, the Lestat in this movie IS different, it is not the same character. You can see that the character Armand has been given Lestat-



In [14]:
print("MOST INFLUENTIAL TRAINING EXAMPLE")
print(train_dataset.decode(scores[0].sort(descending=True).indices[0].item()))

MOST INFLUENTIAL TRAINING EXAMPLE
When will the hurting stop? I never want to see another version of a Christmas Carol again. They keep on making movies with the same story, falling over each other in trying to make the movie better then the rest, but sadly fail to do so, as this is not a good story. Moralistic, old-fashioned, conservative happy-thinking. As if people learn. The numerous different versions of this film prove that we don´t.
