In [1]:
import networkx as nx
import pandas as pd
import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from torchinfluence.methods import HFSeqClfGradientSimilarity

tokenizer = AutoTokenizer.from_pretrained("mrm8488/distilroberta-finetuned-banking77")
model = AutoModelForSequenceClassification.from_pretrained("mrm8488/distilroberta-finetuned-banking77")

dataset = load_dataset("banking77")

In [2]:
def tokenize(batch):
    return tokenizer(batch["text"], padding="max_length", truncation=True, max_length=128, return_tensors="pt")


dataset = dataset.map(tokenize, batched=True, batch_size=len(dataset))
dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])

In [3]:
class Banking77Dataset(Dataset):
    def __init__(self, dataset, split="train"):
        self.dataset = dataset[split]

    def __getitem__(self, idx):
        label = self.dataset[idx]["label"]
        input_ids = self.dataset[idx]["input_ids"]
        attention_mask = self.dataset[idx]["attention_mask"]

        return {"input_ids": input_ids, "attention_mask": attention_mask}, label

    def __len__(self):
        return len(self.dataset)

In [4]:
dataset = dataset.shuffle()

banking77_train = Banking77Dataset(dataset, split="train")
banking77_test = Banking77Dataset(dataset, split="test")

In [6]:
gradsim = HFSeqClfGradientSimilarity(
    model, torch.nn.CrossEntropyLoss(), parameter_subset=["classifier.out_proj.weight"], device="cpu"
)

score_matrix = gradsim.score(
    banking77_train,
    banking77_test,
    # subset_ids={"train": list(range(3000)), "test": list(range(1000))},
    normalize=True,
    chunk_size=100,
)

  0%|          | 0/30809240 [00:00<?, ?it/s]

: 

In [15]:
CS = (score_matrix.T.float() @ score_matrix.float()).half()

In [16]:
G = nx.from_numpy_array(CS.numpy(), create_using=nx.DiGraph)
G.remove_edges_from(nx.selfloop_edges(G))

In [20]:
betweenness = pd.Series(nx.betweenness_centrality(G, weight="weight"))

In [28]:
df = pd.DataFrame(
    {
        "betweenness": betweenness,
        "text": dataset["test"].select(range(250))["text"],
        "label": dataset["test"].select(range(250))["label"],
    }
)