[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tanikina/mi-tutorials/blob/main/notebooks/MiniProject_Token_Attribution.ipynb)

In [None]:
!pip install transformers[torch]
!pip install numpy
!pip install datasets
!pip install captum

In [4]:
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import torch
from torch import nn
import numpy as np
import datasets

# Setting seed for reproducibility
seed_num = 2027
torch.manual_seed(seed_num)

device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
!wget https://raw.githubusercontent.com/tanikina/mi-tutorials/main/datasets/xsid_de_train.csv
!wget https://raw.githubusercontent.com/tanikina/mi-tutorials/main/datasets/xsid_de_val.csv
!wget https://raw.githubusercontent.com/tanikina/mi-tutorials/main/datasets/xsid_de_test.csv
!wget https://raw.githubusercontent.com/tanikina/mi-tutorials/main/datasets/xsid_de_examples.csv

--2025-02-18 08:05:36--  https://raw.githubusercontent.com/tanikina/mi-tutorials/main/datasets/xsid_de_train.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 228806 (223K) [text/plain]
Saving to: ‘xsid_de_train.csv.2’


2025-02-18 08:05:37 (5.66 MB/s) - ‘xsid_de_train.csv.2’ saved [228806/228806]

--2025-02-18 08:05:37--  https://raw.githubusercontent.com/tanikina/mi-tutorials/main/datasets/xsid_de_val.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
200 OKequest sent, awaiting response... 
Length: 35932 (35K) [text/plain]
Saving to: ‘xsid_de_val.csv.2’


2025-02-18 08:05:37 (7.

In [5]:
class TokenBaselineModel(nn.Module):
    """Baseline model."""

    def __init__(self, num_labels, embedding_model_name, embedding_size=768, hidden_dim=256):
        """
        Arguments:
            num_labels (int): Number of final classification labels.
            embedding_model_name (str): Name of the embedding model
            embedding_size (int): Size of the input embeddings.
            hidden_dim (int): Hidden dimenion size.
        """
        super(TokenBaselineModel,self).__init__()
        self.num_labels = num_labels
        self.embedding_size = embedding_size
        self.hidden_dim = hidden_dim

        self.embedding_model = AutoModel.from_pretrained(embedding_model_name)
        self.embedding_model.eval()

        self.classifier_in = nn.Linear(self.embedding_size, self.hidden_dim)
        self.relu = nn.ReLU()
        self.classifier_out = nn.Linear(self.hidden_dim, self.num_labels)

    def forward(self, batch):
        token_embeddings = self.embedding_model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).last_hidden_state
        token_embeddings = torch.mean(token_embeddings, dim=-2)
        label = batch["labels"]

        classifier_scores = self.classifier_in(token_embeddings)
        classifier_scores = self.classifier_out(self.relu(classifier_scores))
        return classifier_scores

In [6]:
def train(model, model_save_path, train_dataloader, dev_dataloader, num_epochs, loss_function, optimizer):
    model.to(device)
    no_decay = ["bias", "LayerNorm.weight"]
    min_dev_loss = None
    min_dev_loss_epoch = 0
    # training loop
    for epoch in range(num_epochs):
        train_losses_per_epoch = []
        for batch_idx, batch in enumerate(train_dataloader):
            for k, v in batch.items():
                batch[k] = v.to(device)
            labels = batch["labels"]
            predictions = model(batch)
            loss = loss_function(predictions, labels)
            train_losses_per_epoch.append(loss.item())
            loss.backward()
            param_grads = [n for n, p in model.named_parameters() if p.grad is not None]
            optimizer.step()
            optimizer.zero_grad()
        if epoch % 5 == 0:
            print("Epoch:", epoch, "Train loss:", round(sum(train_losses_per_epoch)/len(train_losses_per_epoch),3))
            
        # evaluation on development set
        with torch.no_grad():
            dev_losses_per_epoch = []
            for batch_idx, batch in enumerate(dev_dataloader):
                for k, v in batch.items():
                    batch[k] = v.to(device)#v.detach().clone().to(device)
                labels = batch["labels"]
                predictions = model(batch)

                loss = loss_function(predictions, labels)
                dev_losses_per_epoch.append(loss.item())
            cur_dev_loss = round(sum(dev_losses_per_epoch)/len(dev_losses_per_epoch),3)
            if min_dev_loss is None or (cur_dev_loss < min_dev_loss):
                min_dev_loss = cur_dev_loss
                min_dev_loss_epoch = epoch
                torch.save(model.state_dict(), model_save_path)
            if epoch % 5 == 0:
                print("Epoch:", epoch, "Dev loss:", cur_dev_loss)
                print("***")
    print("Best epoch:", min_dev_loss_epoch)

In [7]:
def evaluate(model, model_load_path, test_dataloader, test_data_path, batch_size, id2label):
    model.load_state_dict(torch.load(model_load_path))
    model.to(device)
    scores_per_class = dict()
    correct = 0
    total_sample = 0
    acc_total = 0
    macro_f1 = 0

    dataset_obj = datasets.Dataset.from_csv(test_data_path, delimiter="\t")
    dataset = []
    for i in range(len(dataset_obj)):
        dataset.append((dataset_obj["text"][i], dataset_obj["intent"][i]))
    test_dataloader_texts = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # test evaluation loop
    with torch.no_grad():
        for batch, text_inputs in zip(test_dataloader, test_dataloader_texts):
            for k, v in batch.items():
                batch[k] = v.to(device)#v.detach().clone().to(device)
            labels = batch["labels"]
            _, gold_idx = torch.max(labels, dim=-1)
            predictions = model(batch)
            _, max_idx = torch.max(predictions, dim=-1)

            predicted_array = np.array(max_idx.cpu())
            target_array = np.array(gold_idx.cpu())

            text_and_label_per_batch = [(txt, lbl) for txt, lbl in zip(text_inputs[0], text_inputs[1])]

            for i in range(len(predicted_array)):
                predicted_sample = predicted_array[i]
                target_sample = target_array[i]
                txt, lbl = text_and_label_per_batch[i]
                if id2label[predicted_sample] != id2label[target_sample]:
                    print("text:", txt, "gold label:", lbl, "predicted sample:", id2label[predicted_sample], "target sample:", id2label[target_sample])
                if not(predicted_sample) in scores_per_class:
                    scores_per_class[predicted_sample] = {"tp":0, "fp":0, "fn":0}
                if not(target_sample) in scores_per_class:
                    scores_per_class[target_sample] = {"tp":0, "fp":0, "fn":0}

                if predicted_sample == target_sample:
                    scores_per_class[predicted_sample]["tp"] += 1.0
                else:
                    scores_per_class[predicted_sample]["fp"] += 1.0
                    scores_per_class[target_sample]["fn"] += 1.0

            correct += (predicted_array == target_array).sum()

            total_sample += len(predicted_array)

        acc_total = correct/total_sample

        f1_per_class = dict()
        for cls in scores_per_class:
            tp = scores_per_class[cls]["tp"]
            fp = scores_per_class[cls]["fp"]
            fn = scores_per_class[cls]["fn"]
            if (tp+fp) > 0:
                prec = tp / (tp+fp)
            else:
                prec = 0
            if (tp+fn) > 0:
                rec = tp / (tp+fn)
            else:
                rec = 0
            if (prec + rec) > 0:
                cls_f1 = 2 * (prec * rec) / (prec + rec)
            else:
                cls_f1 = 0
            f1_per_class[cls] = cls_f1

            macro_f1 = sum([score for cls, score in f1_per_class.items()])/len(f1_per_class)
    return acc_total, macro_f1

In [8]:
class DatasetEncoder:
    def __init__(self, tokenizer, label2id):
        self.tokenizer = tokenizer
        self.label2id = label2id

    def map_labels(self, data):
        return {"labels": [[1. if i==self.label2id[intent] else 0 for i in range(len(self.label2id))] \
                           for intent in data["intent"]], "input_ids": data["input_ids"], "attention_mask": data["attention_mask"]}

    def encode_data(self, data):
        return self.tokenizer([doc_tokens for doc_i, doc_tokens in enumerate(data["text"])], pad_to_max_length=True, \
                               padding="max_length", max_length=64, truncation=True, add_special_tokens=True)

def preprocess_dataloader(data, do_shuffle, label2id, seq_len, batch_size, tokenizer):
    # create a dataset from a csv file
    d_encoder = DatasetEncoder(tokenizer, label2id)
    dataset = datasets.Dataset.from_csv(data, delimiter="\t")
    # map intent to label id and tokenize text
    dataset = dataset.map(d_encoder.encode_data, batched=True, batch_size=batch_size)
    dataset = dataset.map(d_encoder.map_labels, batched=True, batch_size=batch_size)
    dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=do_shuffle)
    return dataloader

In [6]:
model_save_path = "baseline_multilingual_model.pth"

# label encoding
all_labels = ['weather/find', 'alarm/set_alarm', 'alarm/show_alarms', 'reminder/set_reminder', 'alarm/modify_alarm', \
              'weather/checkSunrise', 'weather/checkSunset', 'alarm/snooze_alarm', 'alarm/cancel_alarm', \
              'reminder/show_reminders', 'reminder/cancel_reminder', 'alarm/time_left_on_alarm', 'AddToPlaylist', \
              'BookRestaurant', 'PlayMusic', 'RateBook', 'SearchCreativeWork', 'SearchScreeningEvent']
id2label = dict()
for i in range(len(all_labels)):
    id2label[i] = all_labels[i]
label2id = dict()
for k, v in id2label.items():
    label2id[v] = k

batch_size = 32
seq_len = 64
num_epochs = 20

lang_name = "de"
embedding_model_name = "bert-base-multilingual-cased"
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")

# dataset preprocessing
train_data = "xsid_de_train.csv"
dev_data = "xsid_de_val.csv"
test_data = "xsid_de_test.csv"

train_dataloader = preprocess_dataloader(data=train_data, do_shuffle=True, label2id=label2id, seq_len=seq_len, batch_size=batch_size, tokenizer=tokenizer)
dev_dataloader = preprocess_dataloader(data=dev_data, do_shuffle=True, label2id=label2id, seq_len=seq_len, batch_size=batch_size, tokenizer=tokenizer)
test_dataloader = preprocess_dataloader(data=test_data, do_shuffle=False, label2id=label2id, seq_len=seq_len, batch_size=batch_size, tokenizer=tokenizer)

# training
base_model = TokenBaselineModel(num_labels=len(all_labels), embedding_model_name=embedding_model_name)
base_model = base_model.to(device)

loss_function = nn.BCEWithLogitsLoss()
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [{"params": [p for n, p in base_model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": 1e-4,}, \
                                {"params": [p for n, p in base_model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0,},]
optimizer = torch.optim.AdamW(lr=1e-5, params=optimizer_grouped_parameters) # 2e-3 if frozen
# NB: larning rate otimization: lr=1e-5 when non-frozen, otherweise lr=2e-3

train(base_model, model_save_path, train_dataloader, dev_dataloader, num_epochs, loss_function, optimizer)
acc_total, macro_f1 = evaluate(model=TokenBaselineModel(num_labels=len(all_labels), embedding_model_name=embedding_model_name), \
                                model_load_path=model_save_path, test_dataloader=test_dataloader, test_data_path=test_data, batch_size=batch_size, id2label=id2label)
print("Accuracy:", acc_total)
print("Macro F1:", macro_f1)
print("Finished!")

2025-02-18 08:10:56.468551: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1739866256.488160     672 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1739866256.494061     672 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Epoch: 0 Train loss: 0.46
Epoch: 0 Dev loss: 0.354
***
Epoch: 5 Train loss: 0.166
Epoch: 5 Dev loss: 0.162
***
Epoch: 10 Train loss: 0.067
Epoch: 10 Dev loss: 0.084
***
Epoch: 15 Train loss: 0.035
Epoch: 15 Dev loss: 0.064
***
Best epoch: 19
text: erstelle eine Geburtstagserinnerung für max gold label: reminder/set_reminder predicted sample: SearchCreativeWork target sample: reminder/set_reminder
text: Stelle den Wecker auf 8:00 Uhr ein gold label: alarm/set_alarm predicted sample: alarm/modify_alarm target sample: alarm/set_alarm
text: Erstelle einen Wecker für 3 Uhr nachmittags . gold label: alarm/set_alarm predicted sample: alarm/modify_alarm target sample: alarm/set_alarm
text: entferne alle Wecker gold label: alarm/cancel_alarm predicted sample: alarm/show_alarms target sample: alarm/cancel_alarm
text: Erstelle einen täglichen Wecker für 8 Uhr morgens gold label: alarm/set_alarm predicted sample: alarm/show_alarms target sample: alarm/set_alarm
text: schneit es in kanada gold labe

In [1]:
import torch
from captum.attr import LayerIntegratedGradients
from transformers import AutoModel, AutoTokenizer

softmax = torch.nn.Softmax(dim=0)


def get_embedding_layer(model):
    return model.embedding_model.embeddings


def get_inputs_and_additional_args(batch):
    input_ids = batch["input_ids"]
    additional_forward_args = {"attention_mask": batch["attention_mask"], "labels": batch["labels"]}
    return input_ids, additional_forward_args

In [2]:
def get_forward_func():
    def bert_forward(input_ids, extra):
        # adapt to input mask and enlarge the dimension of input_mask
        input_ids.to(device)
        attention_masks = extra["attention_mask"]
        labels = extra["labels"]
        attention_masks.to(device)
        input2model = {
            'input_ids': input_ids.long().cuda(),
            'attention_mask': attention_masks.long().cuda(),
            'labels': labels.cuda(),
        }
        output_model = token_model(input2model)
        output_model.to(device)
        return output_model

    return bert_forward

def compute_feature_attribution_scores(batch, model, tokenizer):
    model.to(device)
    model.eval()
    model.zero_grad()
    inputs, additional_forward_args = get_inputs_and_additional_args(
        batch=batch
    )

    inputs.to(device)
    for k,v in additional_forward_args.items():
        additional_forward_args[k] = v.to(device)

    forward_func = get_forward_func()
    predictions = forward_func(
        input_ids=inputs,
        extra=additional_forward_args
    )
    predictions.to(device)
    
    pred_id = torch.argmax(predictions, dim=1)
    pred_id.to(device)

    baseline = torch.zeros(batch["input_ids"].shape)
    for bi in range(baseline.shape[0]):
        for j in range(baseline.shape[1]):
            if batch["input_ids"][bi][j] not in [tokenizer.cls_token_id, tokenizer.sep_token_id]:
                baseline[bi][j] = tokenizer.pad_token_id

    baseline.to(device)

    explainer = LayerIntegratedGradients(forward_func=forward_func,
                                         layer=get_embedding_layer(model))

    attributions = explainer.attribute(
        inputs=inputs,
        n_steps=50,
        additional_forward_args=additional_forward_args,
        target=pred_id,
        baselines=baseline,
        internal_batch_size=1,
    )
    attributions = torch.sum(attributions, dim=2).to(device)
    return attributions, predictions

In [9]:
embedding_model_name = "bert-base-multilingual-cased"
tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)

batch_size = 1
seq_len = 64

token_model = TokenBaselineModel(num_labels=len(all_labels), embedding_model_name=embedding_model_name)
token_model.load_state_dict(torch.load('baseline_multilingual_model.pth'))
token_model.to(device=device)

train_data = "xsid_de_examples.csv"
train_dataloader = preprocess_dataloader(data=train_data, do_shuffle=False, label2id=label2id, seq_len=seq_len, batch_size=batch_size, tokenizer=tokenizer)

samples_as_subtokens = []
labels = []
attributions = []
texts = []

for idx_batch, b in enumerate(train_dataloader):
    for k,v in b.items():
        b[k] = v.to(device)

    attribution, predictions = compute_feature_attribution_scores(b, token_model, tokenizer)
    num_valid_tokens = sum(b["attention_mask"][0])
    attribution_softmax = softmax(attribution[0][:num_valid_tokens])
    subtokens = tokenizer.convert_ids_to_tokens(b["input_ids"][0])[:num_valid_tokens]
    samples_as_subtokens.append(subtokens)
    lbl = id2label[torch.argmax(b["labels"][0]).item()]
    labels.append(lbl)
    text = tokenizer.batch_decode(b["input_ids"], skip_special_tokens=True)[0]
    texts.append(text)
    print("text:", text)
    print("label:", lbl)
    # store which subtokens have which importance values
    attributions.append(attribution_softmax.tolist())
    for attr, subtoken in zip(attribution_softmax.tolist(), subtokens):
        if subtoken != "[PAD]":
            print(subtoken, attr)
    print()

2025-02-18 15:53:17.854463: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1739893997.874757    8290 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1739893997.880654    8290 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


torch.Size([1, 64, 768])
text: Wie lautet die heutige Vorhersage?
label: weather/find
[CLS] 0.06412150679033628
Wie 0.09448230235518984
lautet 0.2632113827227858
die 0.0675399957184859
heutige 0.06406173413556233
Vor 0.06353771371718475
##hers 0.10712994364392045
##age 0.08937855484134058
? 0.1351664079156189
[SEP] 0.05137045815957513

torch.Size([1, 64, 768])
text: Wettervorhersage von LakeBay
label: weather/find
[CLS] 0.043754347552641584
Wet 0.07602399893412413
##ter 0.05588779470884589
##vor 0.03776262141448787
##hers 0.12906740431943856
##age 0.07647517104537327
von 0.20655562728908142
Lake 0.06988461196161852
##B 0.10673043224875976
##ay 0.13993064778692788
[SEP] 0.057927342738701225

torch.Size([1, 64, 768])
text: Wann hört der Regen auf?
label: weather/find
[CLS] 0.02805893972056959
Wan 0.046485328186228095
##n 0.06175652453927059
hör 0.022320551795763838
##t 0.047133747465526565
der 0.04464398071716183
Reg 0.033953021016527196
##en 0.03653102279347473
auf 0.05972954401627704
?