# Configs

In [16]:
# Run configuration: which checkpoint to load, a short label for the analysed
# passage (used in output file names), and the passage itself.
config = dict(
    model_id='mistralai/Mistral-7B-v0.1',
    sequence_id='thematic',
    text_sequence="At an undetermined time between 18:00 on May 12, 2017 and 06:00 on May 13, 2017, at the parked delivery vehicle branded Peugeot Boxer, an unknown individual used an unidentified object to pry open the locks of the driver’s door, the passenger door, and then the cargo space. The individual entered the vehicle and stole from it a car radio, a demolition hammer, an electric saw, a drill, and other work tools, all valued at 8,700 CZK [...] By damaging the door lock, he caused damage worth 3,500 CZK. The stolen items were sold to unknown persons.",
)

# Drop '/' from both ids so they can be embedded in file names further down.
safe_model_id, safe_sequence_id = (
    config[key].replace('/', '') for key in ('model_id', 'sequence_id')
)

# Installs and Secrets

In [17]:
!pip install -U bertviz -q

In [18]:
import os

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, utils

In [19]:
# Read the Hugging Face access token from the HF_TOKEN environment variable so
# that no credential is ever typed into (and saved with) the notebook. Falls
# back to the previous default, an empty string, when the variable is unset.
hf_token = os.environ.get('HF_TOKEN', '')

# Helper Functions
From https://github.com/jessevig/bertviz/blob/master/bertviz/util.py, as implemented in Vig (2019), [A Multiscale Visualization of Attention in the Transformer Model](https://aclanthology.org/P19-3007.pdf) (ACL System Demonstrations 2019).

In [20]:
def num_layers(attention):
    """Return the number of entries in `attention` (one entry per layer)."""
    layer_count = len(attention)
    return layer_count


def num_heads(attention):
    """Return the size of dimension 0 of the first item in the first layer entry.

    NOTE(review): assumes each layer entry is laid out as
    (batch, heads, seq_len, seq_len), in which case this is the head count.
    """
    first_layer = attention[0]
    first_item = first_layer[0]
    return first_item.size(0)


def format_special_chars(tokens):
    """Return a new list of token strings with tokenizer markers made readable.

    'Ġ' and '▁' are each replaced by a plain space and every '</w>' is removed.
    """
    cleaned = []
    for token in tokens:
        token = token.replace('Ġ', ' ')
        token = token.replace('▁', ' ')
        cleaned.append(token.replace('</w>', ''))
    return cleaned


def format_attention(attention, layers=None, heads=None):
    """Stack per-layer attention tensors into one tensor.

    Args:
        attention: sequence of 4-D tensors, one per layer
            (1 x num_heads x seq_len x seq_len).
        layers: optional indices of the layers to keep; any falsy value keeps
            every layer.
        heads: optional index applied to the head dimension of each layer;
            any falsy value keeps every head.

    Returns:
        A tensor of shape num_layers x num_heads x seq_len x seq_len.

    Raises:
        ValueError: if a layer tensor does not have exactly four dimensions.
    """
    selected_layers = [attention[index] for index in layers] if layers else attention

    per_layer = []
    for layer_tensor in selected_layers:
        if len(layer_tensor.shape) != 4:
            raise ValueError("The attention tensor does not have the correct number of dimensions. Make sure you set "
                             "output_attentions=True when initializing your model.")
        # squeeze(0) only removes the leading dimension when its size is 1.
        without_batch = layer_tensor.squeeze(0)
        per_layer.append(without_batch[heads] if heads else without_batch)

    return torch.stack(per_layer)

# Instantiate Model and Components

In [None]:
# Load the tokenizer and the causal LM named in the config.
# `use_auth_token` is deprecated in transformers in favour of `token`. An empty
# token string is mapped to None so that the library can fall back to
# credentials stored by `huggingface-cli login` rather than being handed ''.
auth_token = hf_token or None
tokenizer = AutoTokenizer.from_pretrained(config['model_id'], token=auth_token)
# output_attentions=True makes every forward pass return the attention weights.
model = AutoModelForCausalLM.from_pretrained(config['model_id'], output_attentions=True, token=auth_token)

# Forward Pass on the Text Sequence

In [25]:
text_sequence = config['text_sequence']

# Convert the text into token ids, returned as a PyTorch tensor.
inputs = tokenizer.encode(text_sequence, return_tensors='pt')

# Plain forward pass (not model.generate()): only the attention weights are
# needed. Running it under no_grad avoids building an autograd graph, which
# would otherwise keep activations of the whole model alive for no benefit.
with torch.no_grad():
    outputs = model(inputs)

# Stack the per-layer attention tensors into a single tensor. The named field
# is used instead of outputs[-1] so this does not depend on the position of
# the attentions inside the output object.
attention_matrices = format_attention(outputs.attentions)

# Map the token ids back to token strings, e.g. ['<s>', '▁Question', ':', ...].
output_tokens = tokenizer.convert_ids_to_tokens(inputs[0])

In [None]:
# Summarise the forward pass: the token strings, how many there are, and the
# type/shape of the stacked attention tensor.
print(f"Model Output Tokens: {output_tokens}\n")
print(f"Model Output Sequence Length: {len(output_tokens)}\n")
print(f"Model Attention Matrices: {type(attention_matrices)} of shape {attention_matrices.shape}\n")

# Show one attention map (first layer, first head) as a sanity check.
print("Preview of attention scores in layer 0 head 0:\n")
print(attention_matrices[0, 0])

### Save raw attention matrices

In [27]:
# Persist the unfiltered attention tensor as a .npy file named after the
# sequence and model ids. `.cpu()` is a no-op for CPU tensors but keeps the
# conversion working if the model is ever moved to an accelerator, where
# `.numpy()` alone would raise.
np.save(f'{safe_sequence_id}_{safe_model_id}_attentions_raw.npy', attention_matrices.detach().cpu().numpy())

# Exclude attention of tokens of low semantic importance, such as special tokens

### Inspect Model Output Tokens

In [None]:
# Re-print the token strings so the tokens to exclude can be chosen by eye.
print(f"Model Output Tokens: {output_tokens}\n")

### Specify tokens to be excluded

In [29]:
# Token strings whose attention columns are zeroed out below. Entries are
# compared against `output_tokens` by exact string equality.
# NOTE(review): '▁[...]' only has an effect if the tokenizer emits it as one
# single token - confirm against the printed token list.
tokens_to_exclude = ['<s>', ':', '▁', ',', '.', '▁[...]']

### Identify indices of excluded tokens in the token sequence

In [None]:
# Positions in the token sequence whose token string is on the exclusion list.
indices_of_tokens_to_exclude = [
    position
    for position, token_string in enumerate(output_tokens)
    if token_string in tokens_to_exclude
]

# Independent count of the same matches, used as a consistency check below.
number_of_tokens_to_exclude = sum(map(output_tokens.count, tokens_to_exclude))

print(indices_of_tokens_to_exclude)
assert len(indices_of_tokens_to_exclude) == number_of_tokens_to_exclude, "Number of indices retrieved does not match number of tokens to exclude."

### Reset attention scores of excluded token indices to 0

In [None]:
# Work on a copy so the raw attention tensor stays untouched.
attention_matrices_filtered = attention_matrices.clone()

# Zero the last-dimension columns of every excluded position, for all layers
# and heads at once. Nothing is written when the exclusion list is empty.
if indices_of_tokens_to_exclude:
    attention_matrices_filtered[..., indices_of_tokens_to_exclude] = 0

print(f"Filtered attention matrix of layer 0, head 0:\n")
print(attention_matrices_filtered[0, 0])

### Save filtered attention matrices

In [32]:
# Persist the filtered attention tensor as a .npy file named after the sequence
# and model ids. `.cpu()` is a no-op for CPU tensors but keeps the conversion
# working for tensors on an accelerator, where `.numpy()` alone would raise.
np.save(f'{safe_sequence_id}_{safe_model_id}_attentions_filtered.npy', attention_matrices_filtered.detach().cpu().numpy())

# Compute proportion of filtered attention given to tokens representing legal facets

### Inspect Tokens

In [None]:
# Re-print the token strings so the legal-facet tokens can be chosen by eye.
print(f"Model Output Tokens: {output_tokens}\n")

### Specify tokens representing legal facets

In [34]:
# Token strings of the legal-facet phrase whose share of attention is measured.
# They are also joined into the title of the heatmap at the end of the notebook.
tokens_to_compute = ['▁entered', '▁the', '▁vehicle', '▁and', '▁stole', '▁from', '▁it']

In [None]:
# Locate the legal-facet phrase as a contiguous run of tokens. Matching each
# token on its own would also select every other occurrence of frequent tokens
# such as '▁the' or '▁and' elsewhere in the sequence.
phrase_length = len(tokens_to_compute)
phrase_starts = [
    start
    for start in range(len(output_tokens) - phrase_length + 1)
    if list(output_tokens[start:start + phrase_length]) == list(tokens_to_compute)
]
indices_of_tokens_to_compute = sorted(
    {start + offset for start in phrase_starts for offset in range(phrase_length)}
)

# Fallback: if the tokens never occur as one contiguous phrase, select every
# position whose token is in the list (plain membership matching).
if not indices_of_tokens_to_compute:
    indices_of_tokens_to_compute = [i for i, token in enumerate(output_tokens) if token in tokens_to_compute]

# Total number of per-token matches in the sequence, kept for reference.
number_of_tokens_to_compute = sum(output_tokens.count(token) for token in tokens_to_compute)

print(indices_of_tokens_to_compute)

# The indices must not be left empty: an empty selection silently turns every
# proportion computed from them into 0.
assert indices_of_tokens_to_compute, "None of the tokens to compute were found in the token sequence."

### Compute proportion of attention given to specified token indices

In [36]:
# For every (layer, head): share of the filtered attention mass that lands on
# the legal-facet token columns.
#
# The layer/head counts are deliberately NOT unpacked into `num_layers` /
# `num_heads`: those names belong to the helper functions defined earlier, and
# rebinding them to integers would break any later call to the helpers.
filtered_attention = attention_matrices_filtered.detach()

# Sum over the two sequence dimensions, accumulating in float64.
# Result shape: (layers, heads).
facet_attention = filtered_attention[..., indices_of_tokens_to_compute].sum(dim=(-2, -1), dtype=torch.float64)
total_attention = filtered_attention.sum(dim=(-2, -1), dtype=torch.float64)

# A head whose filtered attention is entirely zero gets proportion 0 instead
# of triggering a division by zero.
proportions = torch.where(
    total_attention > 0,
    facet_attention / total_attention,
    torch.zeros_like(total_attention),
)

# NumPy array of shape (layers, heads), one proportion per attention head.
proportion_matrix = proportions.cpu().numpy()
assert proportion_matrix.shape == tuple(filtered_attention.shape[:2])

### Save proportion matrix to local disk

In [37]:
# Write the per-layer, per-head proportion matrix to a .npy file whose name is
# built from the sequence id and the model id.
np.save(f'{safe_sequence_id}_{safe_model_id}_proportions.npy', proportion_matrix)

# Compare proportion of attention given to token indices across Mistral-7B-v0.1, SaulLM-7B, and SaulLM-7B-Instruct

### Load the proportion matrices of each model

In [None]:
# Proportion matrices saved by earlier runs of this notebook, one per model.
# NOTE(review): these file names are hard-coded and differ from the
# f'{safe_sequence_id}_{safe_model_id}_proportions.npy' pattern written above
# (which yields 'thematic_mistralaiMistral-7B-v0.1_proportions.npy' for the
# current config) - presumably the files are renamed by hand; confirm.
proportion_files = (
    'thematic_mistral-7b-v0.1_proportions.npy',
    'thematic_saul-7b_proportions.npy',
    'thematic_saul-7b-instruct_proportions.npy',
)
proportions_mistral, proportions_saul, proportions_saul_instruct = map(np.load, proportion_files)

### Visualize proportion matrix

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Heatmap of `proportion_matrix` (computed earlier in the notebook):
# one cell per attention head, rows = layers, columns = heads.
fig, ax = plt.subplots(figsize=(8, 8))  # square canvas to match the square cells
sns.heatmap(
    proportion_matrix,
    annot=False,
    fmt=".2f",
    cmap='viridis',
    square=True,
    cbar_kws={"shrink": .8},
    ax=ax,
)

# Title names the facet tokens and the model the matrix was computed for.
ax.set_title(f"Proportion of Attention Allocated to Legal Facet Tokens {', '.join(tokens_to_compute)}\n{safe_model_id}")
ax.set_xlabel('Heads')
ax.set_ylabel('Layers')

plt.show()
