# **Self-Attention Attribution: Interpreting Information Interactions Inside Transformer**

## **1. Introduction**
Transformers use **self-attention mechanisms** to capture dependencies between words in a sentence. However, interpreting these attention mechanisms is **challenging** because attention scores do not directly indicate importance.

This paper introduces **Self-Attention Attribution (ATTATTR)**, an **Integrated Gradients-based** method to:
- **Attribute importance** to self-attention connections.
- **Interpret information flow** between token pairs.
- **Prune redundant attention heads** while preserving model performance.

# **2. Comparison: Normal Integrated Gradients (IG) vs. Self-Attention Attribution (ATTATTR)**

## **- Example Sentence**
Consider the sentence:

> `"This paper introduces a new interpretation method."`

Let's analyze how **Normal IG** and **Self-Attention Attribution (ATTATTR)** attribute importance.

---

## **- Normal Integrated Gradients (IG)**

- **Measures feature importance** at the **token level**.
- Computes **gradients w.r.t. input embeddings**.
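
As a rough illustration of token-level IG, the sketch below uses a made-up scalar model `F(E) = (sum_t w . e_t)^2` with an analytic gradient, a zero baseline, and a midpoint Riemann sum. The function names, weights and embeddings are assumptions chosen for the example; they do not come from the paper or from any library.

```python
# Toy token-level Integrated Gradients with a zero baseline.
# The model, weights and embeddings are hypothetical.

def score(embeds, w):
    """Sum of w . e_t over all tokens."""
    return sum(wd * ed for e in embeds for wd, ed in zip(w, e))

def model(embeds, w):
    """Toy scalar output F(E) = score(E)^2."""
    return score(embeds, w) ** 2

def model_grad(embeds, w):
    """Analytic gradient dF/dE[t][d] = 2 * score(E) * w[d]."""
    s = score(embeds, w)
    return [[2.0 * s * wd for wd in w] for _ in embeds]

def integrated_gradients(embeds, w, steps=20):
    """Midpoint Riemann approximation of IG along the path alpha * E."""
    total = [[0.0] * len(w) for _ in embeds]
    for k in range(steps):
        alpha = (k + 0.5) / steps
        scaled = [[alpha * ed for ed in e] for e in embeds]
        grad = model_grad(scaled, w)
        for t, row in enumerate(grad):
            for d, g in enumerate(row):
                total[t][d] += g / steps
    # Multiply by (input - baseline) and sum over the embedding dimension
    # to obtain one attribution per token.
    return [sum(ed * gd for ed, gd in zip(e, g)) for e, g in zip(embeds, total)]

tokens = ["this", "paper", "introduces"]
embeds = [[0.2, -0.1], [0.5, 0.3], [0.1, 0.4]]
w = [1.0, 2.0]

token_attr = integrated_gradients(embeds, w)
for tok, a in zip(tokens, token_attr):
    print(f"{tok:>12s}: {a:.4f}")
```

Each token receives a single score, and the scores add up to `F(E) - F(0)`, but nothing in the result says which token pairs interact.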

### **Example: Normal IG at Token Level**
| Token   | IG Attribution |
|---------|--------------|
| **This** | 0.21 |
| **paper** | 0.35 |
| **introduces** | 0.25 |
| **a** | 0.05 |
| **new** | 0.08 |
| **interpretation** | 0.29 |
| **method** | 0.12 |

 **Key Limitation:** Normal IG **ignores token interactions** (e.g., how "paper" interacts with "introduces").

---

## **- Self-Attention Attribution (ATTATTR)**
- **Extends IG to token pairs** by computing gradients w.r.t. **attention scores** instead of input tokens.
- Uses **self-attention matrices** as the feature of interest.
- **Interpolates between**:
  - A **zero attention baseline**.
  - The **actual attention scores**.

### **Example: Self-Attention Attribution at Token-Pair Level**
| Token 1 | Token 2 | Attention Attribution |
|---------|---------|----------------------|
| **This** | **paper** | 0.42 |
| **paper** | **introduces** | 0.51 |
| **introduces** | **a** | 0.12 |
| **a** | **new** | 0.07 |
| **new** | **interpretation** | 0.38 |
| **interpretation** | **method** | 0.22 |

**Key Advantage:**  
- **Captures interactions** (e.g., "paper" → "introduces" has a strong attribution), which is helpful for interpreting how the model combines tokens.

---

**This makes ATTATTR more useful for interpreting attention mechanisms in Transformers!**

## **3. How ATTATTR Works**
### **- Step 1: Define Self-Attention Attribution**
Self-Attention Attribution extends **Integrated Gradients (IG)** to measure the **importance of attention scores** rather than individual token embeddings.

It is defined as:

$$
\text{Attr}_h(A) = A_h \circ \int_{0}^{1} \frac{\partial F(\alpha A)}{\partial A_h} d\alpha \in \mathbb{R}^{n \times n}
$$


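Here $A_h$ is the attention matrix of head $h$, $F(\cdot)$ is the model output, $\alpha$ scales the attention from the all-zero baseline ($\alpha = 0$) to its actual values ($\alpha = 1$), $\circ$ denotes element-wise multiplication, and $n$ is the sequence length.

This equation follows the **Integrated Gradients approach**, but instead of working with token embeddings, it operates on **attention matrices**.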


### **- Step 2: Compute Integrated Gradients for Attention**

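In practice the integral cannot be evaluated in closed form, so it is approximated numerically; the result measures **how much each attention score contributes** to the final prediction.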
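
The integral in Step 1 can be approximated by a Riemann sum over $m$ steps:

$$
\text{Attr}_h(A) \approx \frac{A_h}{m} \circ \sum_{k=1}^{m} \frac{\partial F(\frac{k}{m} A)}{\partial A_h}
$$

The sketch below applies this approximation to a single toy head. The "model" is a hypothetical function `F(A) = tanh(sum_ij A_ij * C_ij)` with an analytic gradient, standing in for a real Transformer; the matrices and function names are assumptions made for the example.

```python
import math

def toy_output(attn, coef):
    """Toy scalar output F(A) = tanh(sum_ij A_ij * C_ij) for one head."""
    return math.tanh(sum(a * c for ra, rc in zip(attn, coef) for a, c in zip(ra, rc)))

def toy_grad(attn, coef):
    """Analytic gradient dF/dA_ij = (1 - F(A)^2) * C_ij."""
    f = toy_output(attn, coef)
    return [[(1.0 - f * f) * c for c in rc] for rc in coef]

def attention_attribution(attn, coef, steps=20):
    """Riemann approximation: (A / m) * sum_k dF((k/m) A) / dA."""
    n = len(attn)
    grad_sum = [[0.0] * n for _ in range(n)]
    for k in range(1, steps + 1):
        alpha = k / steps
        scaled = [[alpha * a for a in row] for row in attn]
        grad = toy_grad(scaled, coef)
        for i in range(n):
            for j in range(n):
                grad_sum[i][j] += grad[i][j]
    return [[attn[i][j] * grad_sum[i][j] / steps for j in range(n)] for i in range(n)]

# Hypothetical 3 x 3 attention matrix (rows sum to 1) and coefficients.
attn = [[0.7, 0.2, 0.1],
        [0.1, 0.8, 0.1],
        [0.3, 0.3, 0.4]]
coef = [[0.5, -0.2, 0.1],
        [1.0, 0.3, -0.4],
        [0.0, 0.6, 0.2]]

attr = attention_attribution(attn, coef, steps=20)
for row in attr:
    print(["%.4f" % v for v in row])
```

With a zero baseline the entries of `attr` sum approximately to `F(A) - F(0)`, and the approximation tightens as `steps` grows.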

### **- Step 3: Aggregate Importance Across Layers**
The final **self-attention attribution** is obtained by **summing across layers** and normalizing.

---
## **4. Approach I Used**
- I used a pretrained **BERT model and fine-tuned it on SST-2** for 2 epochs.
- Sentences are tokenized and passed through the model to extract self-attention scores.
- IG is then computed along the interpolation path in 20 steps: at each step we take the gradient with respect to the layer's attention at the interpolated point, then integrate (average over the steps) and scale by the actual attention scores.
- This gives an attribution for every token pair in each of the 144 heads (12 layers × 12 heads), and each head is scored by its maximum attribution.
- This lets us identify the most influential attention heads, i.e. those that contribute most to the model’s final decision.
- For pruning, I keep the top 3 attributions per layer, meaning the 3 heads in each layer that contribute most to the binary classification on SST-2.
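
The head-selection rule in the last bullets can be sketched in isolation as follows. This is a minimal pure-Python illustration with made-up attribution values; `head_importance` and `top_heads_per_layer` are hypothetical helper names, not functions from the paper or from `transformers`.

```python
def head_importance(attr_matrix):
    """Score one head by its largest attribution over all token pairs."""
    return max(max(row) for row in attr_matrix)

def top_heads_per_layer(attributions, k=3):
    """attributions[layer][head] is an n x n attribution matrix.

    Returns, for every layer, the indices of the k highest-scoring heads.
    """
    selected = []
    for layer in attributions:
        scores = [head_importance(h) for h in layer]
        ranked = sorted(range(len(scores)), key=lambda h: scores[h], reverse=True)
        selected.append(ranked[:k])
    return selected

# Two layers with four heads each, 2 x 2 token pairs (made-up numbers).
attributions = [
    [[[0.1, 0.0], [0.0, 0.2]], [[0.5, 0.1], [0.0, 0.0]],
     [[0.0, 0.0], [0.3, 0.0]], [[0.05, 0.0], [0.0, 0.01]]],
    [[[0.9, 0.0], [0.0, 0.0]], [[0.1, 0.0], [0.0, 0.0]],
     [[0.0, 0.4], [0.0, 0.0]], [[0.0, 0.0], [0.6, 0.0]]],
]

kept = top_heads_per_layer(attributions, k=3)
for layer_idx, heads in enumerate(kept):
    print(f"Layer {layer_idx}: keep heads {heads}")
```

Taking the maximum rather than the mean means a head counts as important if it has even one strongly attributed token pair.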

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Load SST-2 dataset
dataset = load_dataset("glue", "sst2")

model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)

model.to(device)

def tokenize(batch):
    return tokenizer(batch["sentence"], padding="max_length", truncation=True, max_length=128)

dataset = dataset.map(tokenize, batched=True)

dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

training_args = TrainingArguments(
    output_dir="./bert-sst2",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16, 
    per_device_eval_batch_size=16,
    num_train_epochs=2,  # fine-tune for 2 epochs
    weight_decay=0.01,
    logging_dir="./logs",
    fp16=torch.cuda.is_available(),  # mixed precision needs a CUDA device
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
)

trainer.train()

# Save the fine-tuned model
model.save_pretrained("./fine-tuned-bert-sst2")
tokenizer.save_pretrained("./fine-tuned-bert-sst2")


Using device: cuda


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


| Epoch | Training Loss | Validation Loss |
|-------|---------------|-----------------|
| 1 | 0.1717 | 0.241471 |
| 2 | 0.1177 | 0.309403 |


('./fine-tuned-bert-sst2\\tokenizer_config.json',
 './fine-tuned-bert-sst2\\special_tokens_map.json',
 './fine-tuned-bert-sst2\\vocab.txt',
 './fine-tuned-bert-sst2\\added_tokens.json')

In [56]:
import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertForSequenceClassification

model_path = "./fine-tuned-bert-sst2"
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertForSequenceClassification.from_pretrained(model_path, output_attentions=True, output_hidden_states=True)
model.eval()

sentences = [
    "The plot was interesting but the execution was poor.",
    "Absolutely stunning visuals with no real story.",
    "I will never watch this garbage again!",
    "I can't believe how fantastic this movie was!",
]


inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
input_ids = inputs["input_ids"]
inputs_embeds = model.bert.embeddings.word_embeddings(input_ids).detach()
inputs_embeds.requires_grad = True 

with torch.no_grad():
    outputs = model(inputs_embeds=inputs_embeds, output_attentions=True)
    original_attentions = torch.stack(outputs.attentions)  

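num_steps = 20
integrated_gradients = torch.zeros_like(inputs_embeds)

for alpha in torch.linspace(0, 1, num_steps):
    # Scaled attention for this step. Note: it is not fed back into the model,
    # so the forward pass below still uses the unscaled attention.
    interpolated_attention = (alpha * original_attentions).detach().clone()

    interpolated_outputs = model(inputs_embeds=inputs_embeds, output_attentions=True)
    # logits.max() is the single largest logit in the whole batch, so only the
    # sentence that produces it receives non-zero gradients.
    loss = interpolated_outputs.logits.max()

    # Gradient w.r.t. the input embeddings, averaged over the steps
    gradients = torch.autograd.grad(loss, inputs_embeds, retain_graph=True)[0]
    integrated_gradients += gradients / num_steps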

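# Collapse the embedding dimension to one score per token, then broadcast it
# over layers, heads and key positions: (batch, seq) -> (layers, batch, heads, seq, seq)
expanded_integrated_gradients = integrated_gradients.sum(dim=-1, keepdim=True)
expanded_integrated_gradients = expanded_integrated_gradients.unsqueeze(0).unsqueeze(2).expand_as(original_attentions)

# Weight every attention score by the token-level gradient of its query position
attribution_scores = original_attentions * expanded_integrated_gradients
num_layers, batch_size, num_heads = attribution_scores.shape[:3]
# Score each head by its largest attribution over all token pairs
max_attributions_per_layer = attribution_scores.max(dim=-1)[0].max(dim=-1)[0]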
top_heads_per_layer = [
    [torch.argsort(max_attributions_per_layer[layer_idx, batch_idx], descending=True)[:3].tolist()
     for batch_idx in range(batch_size)]
    for layer_idx in range(num_layers)
]
print("\n### Max Attributions Per Layer ###")
for layer_idx in range(num_layers):
    print(f"Layer {layer_idx}: {max_attributions_per_layer[layer_idx].tolist()}")
print("\n### Keeping Only Top 3 Heads Per Layer ###")
for layer_idx, batch_heads in enumerate(top_heads_per_layer):
    for batch_idx, heads in enumerate(batch_heads):
        print(f"Layer {layer_idx}, Input {batch_idx}: Keeping Only Heads {heads}")

pruned_attributions = torch.zeros_like(attribution_scores)
for layer_idx in range(num_layers):
    for batch_idx in range(batch_size):
        mask = torch.zeros(num_heads, device=attribution_scores.device)
        mask[top_heads_per_layer[layer_idx][batch_idx]] = 1  # Keep only top heads
        pruned_attributions[layer_idx, batch_idx] = attribution_scores[layer_idx, batch_idx] * mask[:, None, None]

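# Note: pruned_attributions is not applied to the model here, so this
# forward pass still uses all attention heads.
with torch.no_grad():
    outputs_after = model(input_ids=input_ids)
    preds_after = torch.argmax(outputs_after.logits, dim=1)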

print(f"\nPredictions Before Pruning: {outputs.logits.argmax(dim=1).tolist()}")
print(f"Predictions After Pruning: {preds_after.tolist()}")



### Max Attributions Per Layer ###
Layer 0: [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.2838859220209997e-07, 4.099202612906083e-07, 6.938620913388149e-07, 4.86448016090435e-07, 1.8767563858546055e-07, 2.583869616046286e-07, 2.3428418671755935e-07, 2.577722568730678e-07, 1.7206644997713738e-07, 3.7211111703072675e-07, 9.305368848799844e-07, 2.52074244144751e-07]]
Layer 1: [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [8.030790468183113e-07, 5.438438961391512e-07, 8.766593850850768e-07, 7.747558470327931e-07, 3.8774999211454997e-07, 3.084090280935925e-07, 4.3068607169516326e-07, 8.884110798135225e-07, 6.532380893986556e-07, 6.083054131522658e-07, 3.7765150295854255e-07, 7.169562650233274e-07]]
Layer 2: [