In [4]:
from transformers import AutoModel, AutoTokenizer
from IPython.display import display
from tqdm import tqdm
import plotly.graph_objects as go
import ipywidgets as widgets
import pandas as pd
import torch



model_name = "allegro/herbert-base-cased"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_hidden_states=True)
model = model.to('cpu')

Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.sso.sso_relationship.bias', 'cls.sso.sso_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(50000, 768, padding_idx=1)
    (position_embeddings): Embedding(514, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [9]:
E = model.get_input_embeddings().weight
E = E.cpu()

# Transpose E for easier multiplication
E_T = E.T  # shape: (d_model, vocab_size)

# Function to extract and project FF values into embedding space
def extract_ff_values(model, E_T, tokenizer, k=10):
    ff_neuron_tokens = []
    for layer_idx, layer in enumerate(tqdm(model.encoder.layer, desc="Processing FF layers")):
        # Get the feed-forward weights
        # First linear layer
        K = layer.intermediate.dense.weight.data  # shape: (intermediate_size, hidden_size)
        # Second linear layer (FF values)
        V = layer.output.dense.weight.data.T  # shape: (intermediate_size, hidden_size)

        # Project FF values into embedding space
        V_proj = V @ E_T  # shape: (intermediate_size, vocab_size)

        # For each neuron, get top-k tokens
        layer_neuron_tokens = []
        for neuron_idx in range(V_proj.shape[0]):
            neuron_proj = V_proj[neuron_idx]  # shape: (vocab_size,)
            topk = torch.topk(neuron_proj, k=k)
            topk_indices = topk.indices  # indices of top-k tokens
            topk_tokens = [tokenizer.decode([idx.item()]) for idx in topk_indices]
            layer_neuron_tokens.append({
                'neuron': neuron_idx,
                'top_tokens': topk_tokens
            })
        ff_neuron_tokens.append({
            'layer': layer_idx,
            'neurons': layer_neuron_tokens
        })
    return ff_neuron_tokens

# Function to extract and project FF keys into embedding space
def extract_ff_keys(model, E_T, tokenizer, k=10):
    ff_key_tokens = []
    for layer_idx, layer in enumerate(tqdm(model.encoder.layer, desc="Processing FF keys")):
        # Get the FF keys
        K = layer.intermediate.dense.weight.data  # shape: (intermediate_size, hidden_size)

        # Project FF keys into embedding space
        K_proj = K @ E_T  # shape: (intermediate_size, vocab_size)

        # For each neuron, get top-k tokens
        layer_neuron_tokens = []
        for neuron_idx in range(K_proj.shape[0]):
            neuron_proj = K_proj[neuron_idx]  # shape: (vocab_size,)
            topk = torch.topk(neuron_proj, k=k)
            topk_indices = topk.indices
            topk_tokens = [tokenizer.decode([idx.item()]) for idx in topk_indices]
            layer_neuron_tokens.append({
                'neuron': neuron_idx,
                'top_tokens': topk_tokens
            })
        ff_key_tokens.append({
            'layer': layer_idx,
            'neurons': layer_neuron_tokens
        })
    return ff_key_tokens

# Extract and project FF values and keys
ff_values = extract_ff_values(model, E_T, tokenizer, k=10)
ff_keys = extract_ff_keys(model, E_T, tokenizer, k=10)

Processing FF layers: 100%|██████████| 12/12 [00:07<00:00,  1.57it/s]
Processing FF keys: 100%|██████████| 12/12 [00:07<00:00,  1.61it/s]


In [10]:

# Example: Print top tokens for the first 5 neurons in the first layer
print("Feed-forward values (Layer 3, first 5 neurons):")
for neuron_info in ff_values[3]['neurons'][:100]:
    print(f"Neuron {neuron_info['neuron']}: {neuron_info['top_tokens']}")

print("\nFeed-forward keys (Layer 0, first 5 neurons):")
for neuron_info in ff_keys[0]['neurons'][:100]:
    print(f"Neuron {neuron_info['neuron']}: {neuron_info['top_tokens']}")

Feed-forward values (Layer 3, first 5 neurons):
Neuron 0: ['火', '火', '仁', '仁', 'modli', 'ロ', 'lion', 'Ά', 'ե', 'kapli']
Neuron 1: ['ATP', 'OBO', 'schrieb', 'dieta', 'SPA', 'napisal', 'CER', 'OZ', 'złożony', 'NASA']
Neuron 2: ['odtwarza', 'okoliczności', 'nagrywa', 'katolickiej', 'uroczystości', 'Internecie', 'wydarzenie', 'katolickich', 'podkreśla', 'ckiego']
Neuron 3: ['ǎ', 'ǎ', '藤', '藤', 'Toy', 'Archi', '393', '士', '士', 'テ']
Neuron 4: ['schrieb', '†', '§', 'Negrin', '‰', '‰', 'aden', 'Zigrin', 'sena', '0691']
Neuron 5: ['osobisty', 'gust', 'sprawdz', 'zgubi', 'gubi', 'British', 'przykład', 'EKS', 'Ani', 'sympati']
Neuron 6: ['chwy', 'WT', 'niesienie', 'wylew', 'lew', 'wykrę', 'BOT', 'boo', 'wers', 'przejęciem']
Neuron 7: ['か', 'か', 'ư', 'ư', 'GUS', 'RUDNI', '⋅', 'RON', 'ñ', 'STOL']
Neuron 8: ['Zar', 'mal', 'Tuli', 'molest', 'Kri', 'Kry', 'Kul', 'Lud', 'Ola', 'czą']
Neuron 9: ['ň', 'gili', 'Wielkim', 'pewne', 'ň', 'Berg', 'prawdziwego', 'invalid', 'prawdziwym', 'Wigili']
Neuron 10: ['

In [None]:
def extract_attention_heads(model, E, tokenizer, k=10):
    attention_heads = []
    n_heads = model.config.num_attention_heads  # Number of attention heads
    head_dim = model.config.hidden_size // n_heads  # Dimension per head
    hidden_size = model.config.hidden_size

    for layer_idx, layer in enumerate(tqdm(model.encoder.layer, desc="Processing attention layers")):
        # Get the attention weights
        # The weights are transposed compared to GPT-2
        Wq = layer.attention.self.query.weight.data.T  # shape: (hidden_size, hidden_size)
        Wk = layer.attention.self.key.weight.data.T    # shape: (hidden_size, hidden_size)
        Wv = layer.attention.self.value.weight.data.T  # shape: (hidden_size, hidden_size)
        Wo = layer.attention.output.dense.weight.data.T  # shape: (hidden_size, hidden_size)

        # Reshape weights to separate heads
        Wq = Wq.view(hidden_size, n_heads, head_dim)  # shape: (hidden_size, n_heads, head_dim)
        Wk = Wk.view(hidden_size, n_heads, head_dim)
        Wv = Wv.view(hidden_size, n_heads, head_dim)
        Wo = Wo.view(hidden_size, n_heads, head_dim)

        for head_idx in range(n_heads):
            Wq_head = Wq[:, head_idx, :]  # shape: (hidden_size, head_dim)
            Wk_head = Wk[:, head_idx, :]
            Wv_head = Wv[:, head_idx, :]
            Wo_head = Wo[:, head_idx, :].T  # shape: (head_dim, hidden_size)

            # Project Wq_head and Wk_head into embedding space
            Wq_proj = E @ Wq_head  # shape: (vocab_size, head_dim)
            Wk_proj = E @ Wk_head  # shape: (vocab_size, head_dim)

            # Compute top-k token pairs for W_QK
            tokens_QK = []
            for i in range(head_dim):
                wq = Wq_proj[:, i]  # shape: (vocab_size,)
                wk = Wk_proj[:, i]  # shape: (vocab_size,)

                # Get top-k indices
                topk_wq_values, topk_wq_indices = torch.topk(wq, k=k)
                topk_wk_values, topk_wk_indices = torch.topk(wk, k=k)

                # Compute contributions for top-k tokens only
                for idx_q, val_q in zip(topk_wq_indices, topk_wq_values):
                    for idx_k, val_k in zip(topk_wk_indices, topk_wk_values):
                        contribution = val_q.item() * val_k.item()
                        tokens_QK.append((
                            tokenizer.decode([idx_q.item()], clean_up_tokenization_spaces=True, skip_special_tokens=True),
                            tokenizer.decode([idx_k.item()], clean_up_tokenization_spaces=True, skip_special_tokens=True),
                            contribution
                        ))
            # Sort and select top-k
            tokens_QK = sorted(tokens_QK, key=lambda x: abs(x[2]), reverse=True)[:k]

            # Similarly for W_VO
            # Project Wv_head and Wo_head into embedding space
            Wv_proj = E @ Wv_head  # shape: (vocab_size, head_dim)
            Wo_proj = Wo_head @ E.T  # shape: (head_dim, vocab_size)

            tokens_VO = []
            for i in range(head_dim):
                wv = Wv_proj[:, i]  # shape: (vocab_size,)
                wo = Wo_proj[i, :]  # shape: (vocab_size,)

                # Get top-k indices
                topk_wv_values, topk_wv_indices = torch.topk(wv, k=k)
                topk_wo_values, topk_wo_indices = torch.topk(wo, k=k)

                # Compute contributions for top-k tokens only
                for idx_v, val_v in zip(topk_wv_indices, topk_wv_values):
                    for idx_o, val_o in zip(topk_wo_indices, topk_wo_values):
                        contribution = val_v.item() * val_o.item()
                        tokens_VO.append((
                            tokenizer.decode([idx_v.item()], clean_up_tokenization_spaces=True, skip_special_tokens=True),
                            tokenizer.decode([idx_o.item()], clean_up_tokenization_spaces=True, skip_special_tokens=True),
                            contribution
                        ))
            # Sort and select top-k
            tokens_VO = sorted(tokens_VO, key=lambda x: abs(x[2]), reverse=True)[:k]

            attention_heads.append({
                'layer': layer_idx,
                'head': head_idx,
                'W_QK_top_tokens': tokens_QK,
                'W_VO_top_tokens': tokens_VO
            })
    return attention_heads

In [18]:
attention_data = extract_attention_heads(model, E, tokenizer, k=50)

Processing attention layers: 100%|██████████| 12/12 [08:16<00:00, 41.41s/it]


In [19]:
print("\nAttention Head (Layer 0, Head 0) W_QK top token pairs:")
for token_pair in attention_data[0]['W_QK_top_tokens']:
    print(token_pair)

print("\nAttention Head (Layer 0, Head 0) W_VO top token pairs:")
for token_pair in attention_data[0]['W_VO_top_tokens']:
    print(token_pair)


Attention Head (Layer 0, Head 0) W_QK top token pairs:
('brzo', 'ু', 17.095978677458334)
('klasi', 'ু', 16.37561626630577)
('brzo', 'ব', 16.241737153276517)
('brzo', '大', 15.935528486616022)
('brz', 'ু', 15.634091714351143)
('klasi', 'ব', 15.55736937546294)
('brzo', 'ব', 15.391622370302684)
('Literatura', 'ু', 15.34665668335947)
('klasi', '大', 15.264063229190015)
('brzo', 'て', 15.246606470053848)
('pobud', 'ু', 15.19352441673368)
('brzo', 'ো', 15.117897444059054)
('Bez', 'ু', 15.053850268948736)
('brzo', 'て', 14.971461611688937)
('brzo', 'ু', 14.970542413856492)
('brzo', '大', 14.932249952551501)
('bój', 'ু', 14.887767935601005)
('brz', 'ব', 14.852896873901614)
('Muze', 'ু', 14.845197760598467)
('udar', 'ু', 14.839337460733532)
('klasi', 'ব', 14.74307533995102)
('Nadzor', 'ু', 14.732613028336118)
('brzo', '國', 14.72685443512546)
('kojarzą', 'ু', 14.664083399303308)
('zaj', 'ু', 14.659956791414743)
('brzo', 'ե', 14.64265591367348)
('kaz', 'ু', 14.6311489924542)
('klasi', 'て', 14.6041698

In [20]:
# Collect data for all layers and heads
data_rows = []

for layer_data in attention_data:
    layer = layer_data['layer']
    head = layer_data['head']

    # W_QK_top_tokens
    for token_q, token_k, contribution in layer_data['W_QK_top_tokens']:
        data_rows.append({
            'Layer': layer,
            'Head': head,
            'Type': 'W_QK',
            'Token 1': token_q,
            'Token 2': token_k,
            'Contribution': contribution
        })

    # W_VO_top_tokens
    for token_v, token_o, contribution in layer_data['W_VO_top_tokens']:
        data_rows.append({
            'Layer': layer,
            'Head': head,
            'Type': 'W_VO',
            'Token 1': token_v,
            'Token 2': token_o,
            'Contribution': contribution
        })

# Create DataFrame
df_tokens = pd.DataFrame(data_rows)

# Display the DataFrame
df_tokens.head()

Unnamed: 0,Layer,Head,Type,Token 1,Token 2,Contribution
0,0,0,W_QK,brzo,ু,17.095979
1,0,0,W_QK,klasi,ু,16.375616
2,0,0,W_QK,brzo,ব,16.241737
3,0,0,W_QK,brzo,大,15.935528
4,0,0,W_QK,brz,ু,15.634092


In [22]:
def interactive_attention_visualization(attention_data, model):
    # Create sliders and dropdowns
    layer_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=model.config.num_hidden_layers - 1,
        step=1,
        description='Layer:',
        continuous_update=False
    )
    head_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=model.config.num_attention_heads - 1,
        step=1,
        description='Head:',
        continuous_update=False
    )
    token_type_dropdown = widgets.Dropdown(
        options=['W_QK_top_tokens', 'W_VO_top_tokens'],
        value='W_QK_top_tokens',
        description='Token Type:',
    )
    top_n_slider = widgets.IntSlider(
        value=20,
        min=1,
        max=50,
        step=1,
        description='Top N:',
        continuous_update=False
    )

    # Initialize an empty FigureWidget
    fig = go.FigureWidget(layout=go.Layout(
        title='',
        xaxis=dict(title='Contribution'),
        yaxis=dict(title='Token Pair', categoryorder='total ascending'),
        height=600,
        width=1200
    ))

    # Function to update the plot
    def update_plot(layer, head, token_type, top_n):
        fig.data = []
        fig.layout.title = f'Layer {layer}, Head {head} - Top {top_n} {token_type}'

        selected_data = None
        for data in attention_data:
            if data['layer'] == layer and data['head'] == head:
                selected_data = data[token_type]
                break

        if not selected_data:
            with fig.batch_update():
                fig.add_annotation(
                    text="No data available for the selected layer and head.",
                    xref="paper", yref="paper",
                    showarrow=False,
                    font=dict(size=20)
                )
            return

        # Create a DataFrame for plotting
        df = pd.DataFrame(selected_data, columns=['Token 1', 'Token 2', 'Contribution'])
        df['Token Pair'] = df['Token 1'] + ' ➔ ' + df['Token 2']
        df = df.sort_values(by='Contribution', ascending=False).head(top_n)

        # Update the figure
        with fig.batch_update():
            fig.add_bar(
                x=df['Contribution'],
                y=df['Token Pair'],
                orientation='h',
                marker_color='grey'
            )

    controls = {
        'layer': layer_slider,
        'head': head_slider,
        'token_type': token_type_dropdown,
        'top_n': top_n_slider
    }
    out = widgets.interactive_output(update_plot, controls)

    ui = widgets.VBox([layer_slider, head_slider, token_type_dropdown, top_n_slider])
    display(ui, fig)

    update_plot(layer_slider.value, head_slider.value, token_type_dropdown.value, top_n_slider.value)

interactive_attention_visualization(attention_data, model)

VBox(children=(IntSlider(value=0, continuous_update=False, description='Layer:', max=11), IntSlider(value=0, c…

FigureWidget({
    'data': [{'marker': {'color': 'grey'},
              'orientation': 'h',
              'type': 'bar',
              'uid': 'ec261119-9715-4aa7-9563-4fe76c0a3bc2',
              'x': array([17.09597868, 16.37561627, 16.24173715, 15.93552849, 15.63409171,
                          15.55736938, 15.39162237, 15.34665668, 15.26406323, 15.24660647,
                          15.19352442, 15.11789744, 15.05385027, 14.97146161, 14.97054241,
                          14.93224995, 14.88776794, 14.85289687, 14.84519776, 14.83933746]),
              'y': array(['brzo ➔ ু', 'klasi ➔ ু', 'brzo ➔ ব', 'brzo ➔ 大', 'brz ➔ ু', 'klasi ➔ ব',
                          'brzo ➔ ব', 'Literatura ➔ ু', 'klasi ➔ 大', 'brzo ➔ て', 'pobud ➔ ু',
                          'brzo ➔ ো', 'Bez ➔ ু', 'brzo ➔ て', 'brzo ➔ ু', 'brzo ➔ 大', 'bój ➔ ু',
                          'brz ➔ ব', 'Muze ➔ ু', 'udar ➔ ু'], dtype=object)}],
    'layout': {'height': 600,
               'template': '...',
               'titl