In [None]:
!pip install transformers
!pip install seaborn

In this Colab we will use DistilBERT, because it is a small model.
Unlike the GPT models, it does not use masked (causal) multi-head attention.
This means that each output position is allowed to attend to future tokens as well as past ones.
So do not be alarmed if you see an output attending to future tokens!


In [None]:
from transformers import AutoTokenizer, AutoModel
import torch

# Single source of truth for the checkpoint, so tokenizer and model cannot drift apart.
MODEL_NAME = "distilbert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# output_attentions=True makes every forward pass also return the attention weights.
model = AutoModel.from_pretrained(MODEL_NAME, output_attentions=True)

We will use two sentences where the word _bank_ is ambiguous.

In [None]:
# The word "bank" appears in both sentences, but with a different meaning in each.
sent1, sent2 = (
    "The bank of the river was calm.",
    "She went to the bank to deposit money.",
)


In [None]:
# Tokenize both sentences as one padded batch and get the input IDs.
# At this stage 'bank' is represented by the same token ID in both sentences;
# any difference in meaning can only come from the model's contextual processing.
inputs = tokenizer([sent1, sent2], return_tensors="pt", padding=True, truncation=True)

# Map the token IDs of each row back to their token strings and print them.
sent1_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
sent2_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][1])
print(f"Tokens for sent1: {sent1_tokens}")
print(f"Tokens for sent2: {sent2_tokens}")

# Position of the first 'bank' token in each sequence.
sent1_bank_index = sent1_tokens.index('bank')
sent2_bank_index = sent2_tokens.index('bank')

# Vocabulary IDs of 'bank' in each sentence (named *_token_id so they are not
# confused with embedding vectors of the same word).
bank1_token_id = inputs['input_ids'][0][sent1_bank_index].item()  # 'bank' in sent1
bank2_token_id = inputs['input_ids'][1][sent2_bank_index].item()  # 'bank' in sent2
print(f"Token ID for 'bank' in sent1: {bank1_token_id}")
print(f"Token ID for 'bank' in sent2: {bank2_token_id}")


In [None]:
# Pass the inputs through the model to get the hidden states.
# The hidden states are the outputs of each layer in the model.
# By comparing the hidden states of the 'bank' token in both sentences,
# we can see how the model differentiates between the two meanings based on context.
#
# This is inference only, so run under torch.no_grad() to avoid building an
# autograd graph (less memory, and the outputs do not require gradients).
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
hidden_states = outputs.hidden_states


In [None]:
# Look at the architecture of the model and the number of hidden states.
# DistilBERT has an embeddings layer and 6 transformer layers, so we expect
# 7 hidden states (the embedding output plus one per transformer layer).
# The printed label says "hidden states" because the count includes the embeddings.
print(f"Number of hidden states: {len(hidden_states)}")
print(model)

In [None]:
# Take the last hidden state (the output of the final transformer layer) and pull out
# the vector at the 'bank' position of each sentence. Comparing the two vectors shows
# how much the representation of 'bank' depends on its context.
final_layer = hidden_states[-1]
bank_vec_sent1 = final_layer[0, sent1_bank_index]
bank_vec_sent2 = final_layer[1, sent2_bank_index]

# cosine_similarity compares along dim=1 by default, so give each vector a leading axis.
similarity = torch.nn.functional.cosine_similarity(
    bank_vec_sent1[None, :],
    bank_vec_sent2[None, :],
).item()
print(f"Cosine similarity between 'bank' in sent1 and sent2: {similarity:.4f}")

In [None]:
# Repeat the comparison for every hidden state to see how the 'bank' embeddings evolve.
# Early layers should give more similar embeddings for 'bank' in both sentences,
# while later layers should differentiate more based on context.
for layer_idx, layer_states in enumerate(hidden_states):
    vec_sent1 = layer_states[0, sent1_bank_index][None, :]
    vec_sent2 = layer_states[1, sent2_bank_index][None, :]
    similarity = torch.nn.functional.cosine_similarity(vec_sent1, vec_sent2).item()
    print(f"Layer {layer_idx}: Cosine similarity between 'bank' in sent1 and sent2: {similarity:.4f}")

In [None]:
# The per-layer similarities show that the model differentiates between the two meanings
# of 'bank' based on context: the cosine similarity is not very high, so the embeddings
# for 'bank' in the two sentences are different.
#
# Now we check the attention weights for the 'bank' token in both sentences to see how
# the model attends to different parts of the input when processing the 'bank' token.

inputs = tokenizer([sent1, sent2], return_tensors="pt", padding=True, truncation=True)
# Inference only: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)

attentions = outputs.attentions  # This is a tuple of attention weights for each transformer layer
# attentions is a tuple of length equal to the number of transformer layers, each element is a tensor
# of shape (batch_size, num_heads, seq_length, seq_length)
# so attentions[0] gives the attention for the first transformer layer, and attentions[0][0] gives the
# attention weights for the first sentence in the batch.


In [None]:
# Visualise the attention of a single head in the last layer for the 'bank' token
# in both sentences.
import matplotlib.pyplot as plt
import seaborn as sns

# Attention head of the last layer to inspect. The code has always used head 8;
# naming it here keeps comments, variable names and the index in agreement.
HEAD_INDEX = 8

# Attention weights of the last layer, one entry per sentence in the batch.
last_layer_attention_sent1 = attentions[-1][0]  # shape (num_heads, seq_length, seq_length)
last_layer_attention_sent2 = attentions[-1][1]  # shape (num_heads, seq_length, seq_length)

head_attention_sent1 = last_layer_attention_sent1[HEAD_INDEX]  # shape (seq_length, seq_length)
head_attention_sent2 = last_layer_attention_sent2[HEAD_INDEX]  # shape (seq_length, seq_length)

# Row of the attention matrix belonging to the 'bank' token: how much 'bank'
# attends to every position of its own sentence.
bank1_attention = head_attention_sent1[sent1_bank_index]  # attention weights for 'bank' in sent1
bank2_attention = head_attention_sent2[sent2_bank_index]  # attention weights for 'bank' in sent2

# Plot one bar per token POSITION. Using the token strings themselves as categories
# would merge repeated tokens (e.g. the two "the" in sent1) into a single averaged bar.
for tokens, weights, title in (
    (sent1_tokens, bank1_attention, "Where 'bank' attends in sent1"),
    (sent2_tokens, bank2_attention, "Where 'bank' attends in sent2"),
):
    fig, ax = plt.subplots(figsize=(10, 3))
    positions = list(range(len(tokens)))
    ax.bar(positions, weights.detach().numpy())
    ax.set_xticks(positions)
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_ylabel("Attention weight")
    ax.set_title(title)
    plt.show()

Note that in the plots above, the output for the word _bank_ attends to _bank_ itself, as well as to the end of the sentence (the period).

Let's check the first layer.

In [None]:
# Average the attention over all heads in the first layer.
avg_attention_sent1 = attentions[0][0].mean(dim=0)  # shape (seq_length, seq_length)
avg_attention_sent2 = attentions[0][1].mean(dim=0)  # shape (seq_length, seq_length)

# Row belonging to the 'bank' token: its head-averaged attention over every position.
bank1_avg_attention = avg_attention_sent1[sent1_bank_index]  # shape (seq_length,)
bank2_avg_attention = avg_attention_sent2[sent2_bank_index]  # shape (seq_length,)

# Plot one bar per token POSITION. Using the token strings themselves as categories
# would merge repeated tokens (e.g. the two "to" in sent2) into a single averaged bar.
for tokens, weights, title in (
    (sent1_tokens, bank1_avg_attention, "Average attention for 'bank' in sent1"),
    (sent2_tokens, bank2_avg_attention, "Average attention for 'bank' in sent2"),
):
    fig, ax = plt.subplots(figsize=(10, 3))
    positions = list(range(len(tokens)))
    ax.bar(positions, weights.detach().numpy())
    ax.set_xticks(positions)
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_ylabel("Attention weight")
    ax.set_title(title)
    plt.show()


Now we can see that in the first layer, the word _bank_ in the first sentence attends to RIVER, whereas in the second sentence it attends to DEPOSIT and MONEY. These are the words that disambiguate what _bank_ actually means in each sentence. This suggests that the model has learned to use the context of a word to work out its meaning.

We can also look at the attention from the word _bank_ across all layers.

In [None]:
# Build layer × token matrix (averaged over heads)
def compute_layer_heatmap(attn_tensor, bank_index):
    """Return the head-averaged attention row of one token for every layer.

    Args:
        attn_tensor: tensor of shape (layers, heads, seq_len, seq_len) holding the
            attention weights of a single sentence.
        bank_index: position of the query token whose attention row is extracted.

    Returns:
        Tensor of shape (layers, seq_len); row ``l`` is the mean over heads of
        ``attn_tensor[l, :, bank_index, :]``.
    """
    # Select the query-token row in every layer and head at once: (layers, heads, seq_len)
    bank_attention = attn_tensor[:, :, bank_index, :]

    # Average across the heads axis -> (layers, seq_len)
    return bank_attention.mean(dim=1)

# Stack the per-layer tuple into one tensor with a leading layer axis, then pick
# each sentence out of the batch axis before building its heatmap.
stacked_attentions = torch.stack(attentions)
heatmap1 = compute_layer_heatmap(stacked_attentions[:, 0], sent1_bank_index)
heatmap2 = compute_layer_heatmap(stacked_attentions[:, 1], sent2_bank_index)

# Rescale every layer row so that it sums to 1.
heatmap1 = heatmap1 / heatmap1.sum(dim=1, keepdim=True)
heatmap2 = heatmap2 / heatmap2.sum(dim=1, keepdim=True)
num_layers = heatmap1.shape[0]
layer_labels = [f"L{i}" for i in range(num_layers)]

# One heatmap per sentence: rows are layers, columns are token positions.
for heatmap, tokens, title in (
    (heatmap1, sent1_tokens, "Attention to Tokens from 'bank' Across Layers (Sentence 1)"),
    (heatmap2, sent2_tokens, "Attention to Tokens from 'bank' Across Layers (Sentence 2)"),
):
    plt.figure(figsize=(10, 6))
    sns.heatmap(
        heatmap.detach().numpy(),
        xticklabels=tokens,
        yticklabels=layer_labels,
        cmap="viridis",
    )
    plt.title(title)
    plt.xticks(rotation=90)
    plt.ylabel("Layer")
    plt.show()