In [7]:
from datasets import load_dataset
import torch

imdb = load_dataset("imdb")

Found cached dataset imdb (/Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)
100%|██████████| 3/3 [00:00<00:00, 295.13it/s]


In [2]:
small_train_dataset = (
    imdb["train"].shuffle(seed=42).select([i for i in list(range(3000))])
)
small_test_dataset = imdb["test"].shuffle(seed=42).select([i for i in list(range(300))])

Loading cached shuffled indices for dataset at /Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-9c48ce5d173413c7.arrow
Loading cached shuffled indices for dataset at /Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-c1eaa46e94dfbfd3.arrow


In [3]:
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("assemblyai/distilbert-base-uncased-sst2")
model = AutoModelForSequenceClassification.from_pretrained(
    "assemblyai/distilbert-base-uncased-sst2"
)

tokenized_segments = tokenizer(
    [
        "AssemblyAI is the best speech-to-text API for modern developers with performance being second to none!"
    ],
    return_tensors="pt",
    padding=True,
    truncation=True,
)
tokenized_segments_input_ids, tokenized_segments_attention_mask = (
    tokenized_segments.input_ids,
    tokenized_segments.attention_mask,
)
model_predictions = F.softmax(
    model(
        input_ids=tokenized_segments_input_ids,
        attention_mask=tokenized_segments_attention_mask,
    )["logits"],
    dim=1,
)

print("Positive probability: " + str(model_predictions[0][1].item() * 100) + "%")
print("Negative probability: " + str(model_predictions[0][0].item() * 100) + "%")

Positive probability: 96.0169792175293%
Negative probability: 3.9830222725868225%


In [9]:
def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True, padding=True)


tokenized_train = small_train_dataset.map(preprocess_function, batched=True)
tokenized_test = small_test_dataset.map(preprocess_function, batched=True)

Loading cached processed dataset at /Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-6582379d1c4310e2.arrow


Loading cached processed dataset at /Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-c700964b93e7a374.arrow


In [37]:
class ImdbDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        x = torch.tensor(self.encodings[idx])
        y = torch.tensor(self.labels[idx])
        return x, y

    def __len__(self):
        return len(self.labels)

In [46]:
train_dataset = ImdbDataset(tokenized_train["input_ids"], tokenized_train["label"])
test_dataset = ImdbDataset(tokenized_test["input_ids"], tokenized_test["label"])

In [72]:
model.base_model(
    torch.tensor(train_dataset[0][0]).unsqueeze(0),
    torch.tensor(train_dataset[0][0]).unsqueeze(0),
)

  model.base_model(torch.tensor(train_dataset[0][0]).unsqueeze(0), torch.tensor(train_dataset[0][0]).unsqueeze(0))


BaseModelOutput(last_hidden_state=tensor([[[-0.1542,  0.0906,  1.4949,  ...,  0.0321,  1.1469, -0.2308],
         [-0.2363,  0.5633, -0.1705,  ..., -0.2990,  0.9639,  0.7347],
         [-0.5188,  0.6324,  0.1162,  ..., -0.4970,  0.6356,  1.4636],
         ...,
         [ 0.3062,  0.1396,  0.1965,  ...,  0.0643,  0.4370, -0.0995],
         [-0.0399, -0.2684,  0.6871,  ...,  0.2505,  0.1654,  0.0542],
         [-0.0114, -0.2121,  0.7543,  ...,  0.3027,  0.2304, -0.0872]]],
       grad_fn=<NativeLayerNormBackward0>), hidden_states=None, attentions=None)

In [79]:
loss = F.cross_entropy(
    model(
        torch.tensor(train_dataset[0][0]).unsqueeze(0),
        torch.tensor(train_dataset[0][0]).unsqueeze(0),
    )["logits"],
    torch.tensor(train_dataset[0][1]).unsqueeze(0),
)
loss.backward()

  loss = F.cross_entropy(model(torch.tensor(train_dataset[0][0]).unsqueeze(0), torch.tensor(train_dataset[0][0]).unsqueeze(0))['logits'], torch.tensor(train_dataset[0][1]).unsqueeze(0))


In [82]:
for param in model.base_model.parameters():
    print(param.grad.data.sum())

tensor(7.6890e-06)
tensor(1.3351e-05)
tensor(1.4605)
tensor(-2.6148)
tensor(-1.6427)
tensor(0.0928)
tensor(-15.6059)
tensor(-5.8470e-08)
tensor(0.4710)
tensor(0.5902)
tensor(-3.8147e-06)
tensor(2.3842e-07)
tensor(-0.9682)
tensor(0.5145)
tensor(11.0682)
tensor(-1.1260)
tensor(-1.5259e-05)
tensor(2.8312e-07)
tensor(-3.1945)
tensor(-0.2525)
tensor(-3.9610)
tensor(0.1884)
tensor(4.3107)
tensor(1.1079e-07)
tensor(-35.6873)
tensor(1.7198)
tensor(7.6294e-06)
tensor(1.1921e-07)
tensor(-0.7906)
tensor(-3.1745)
tensor(-8.5136)
tensor(0.3550)
tensor(-0.0001)
tensor(8.3447e-07)
tensor(-5.2181)
tensor(-36.4425)
tensor(-2.2278)
tensor(0.1347)
tensor(1.0476)
tensor(-7.5299e-08)
tensor(76.2288)
tensor(-3.8812)
tensor(-9.5367e-07)
tensor(8.3447e-07)
tensor(-5.1421)
tensor(-10.4402)
tensor(-51.4626)
tensor(1.9585)
tensor(-1.5259e-05)
tensor(-4.1723e-07)
tensor(-1.4264)
tensor(-14.5551)
tensor(-0.2876)
tensor(0.0015)
tensor(0.1487)
tensor(-1.0553e-07)
tensor(-11.1569)
tensor(0.6947)
tensor(4.1723e-06)
te

In [47]:
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=16, shuffle=True
)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=True)

In [59]:
from pydvl.influence import compute_influences
from pydvl.influence.torch import TorchTwiceDifferentiable

ekfac_train_influences = compute_influences(
    TorchTwiceDifferentiable(model_logits, F.cross_entropy),
    training_data=train_dataloader,
    test_data=test_dataloader,
    influence_type="up",
    inversion_method="ekfac",
    hessian_regularization=0.1,
    progress=True,
)

AttributeError: 'function' object has no attribute 'training'