In [1]:
from transformers import AutoModel, AutoTokenizer
from IPython.display import display
from tqdm import tqdm
import plotly.graph_objects as go
import ipywidgets as widgets
import pandas as pd
import torch



model_name = "sdadas/polish-gpt2-medium"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_hidden_states=True)
model = model.to('cpu')

In [2]:
model

GPT2Model(
  (wte): Embedding(51200, 1024)
  (wpe): Embedding(2048, 1024)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-23): 24 x GPT2Block(
      (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2SdpaAttention(
        (c_attn): Conv1D(nf=3072, nx=1024)
        (c_proj): Conv1D(nf=1024, nx=1024)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D(nf=4096, nx=1024)
        (c_proj): Conv1D(nf=1024, nx=4096)
        (act): FastGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)

In [10]:

E = model.wte.weight.data  # shape: (vocab_size, d_model)
E = E.cpu()

# Transpose E for easier multiplication
E_T = E.T  # shape: (d_model, vocab_size)

# Function to extract and project FF values into embedding space
def extract_ff_values(model, E_T, tokenizer, k=10):
    ff_neuron_tokens = []
    for layer_idx, block in enumerate(tqdm(model.h, desc="Processing FF layers")):
        mlp = block.mlp
        # Get the feed-forward weights
        # c_fc: Conv1D(nf=4096, nx=1024) - FF keys (K)
        # c_proj: Conv1D(nf=1024, nx=4096) - FF values (V)
        K = mlp.c_fc.weight.data.T  # shape: (4096, 1024)
        V = mlp.c_proj.weight.data  # shape: (4096, 1024)

        # Project FF values into embedding space
        V_proj = V @ E_T  # shape: (4096, vocab_size)

        # For each neuron, get top-k tokens
        layer_neuron_tokens = []
        for neuron_idx in range(V_proj.shape[0]):
            neuron_proj = V_proj[neuron_idx]  # shape: (vocab_size,)
            topk = torch.topk(neuron_proj, k=k)
            topk_indices = topk.indices  # indices of top-k tokens
            topk_tokens = [tokenizer.decode([idx.item()]) for idx in topk_indices]
            layer_neuron_tokens.append({
                'neuron': neuron_idx,
                'top_tokens': topk_tokens
            })
        ff_neuron_tokens.append({
            'layer': layer_idx,
            'neurons': layer_neuron_tokens
        })
    return ff_neuron_tokens

# Function to extract and project FF keys into embedding space
def extract_ff_keys(model, E_T, tokenizer, k=10):
    ff_key_tokens = []
    for layer_idx, block in enumerate(tqdm(model.h, desc="Processing FF keys")):
        mlp = block.mlp
        # Get the FF keys
        K = mlp.c_fc.weight.data.T  # shape: (4096, 1024)

        # Project FF keys into embedding space
        K_proj = K @ E_T  # shape: (4096, vocab_size)

        # For each neuron, get top-k tokens
        layer_neuron_tokens = []
        for neuron_idx in range(K_proj.shape[0]):
            neuron_proj = K_proj[neuron_idx]  # shape: (vocab_size,)
            topk = torch.topk(neuron_proj, k=k)
            topk_indices = topk.indices
            topk_tokens = [tokenizer.decode([idx.item()]) for idx in topk_indices]
            layer_neuron_tokens.append({
                'neuron': neuron_idx,
                'top_tokens': topk_tokens
            })
        ff_key_tokens.append({
            'layer': layer_idx,
            'neurons': layer_neuron_tokens
        })
    return ff_key_tokens

# Extract and project FF values and keys
ff_values = extract_ff_values(model, E_T, tokenizer, k=10)
ff_keys = extract_ff_keys(model, E_T, tokenizer, k=10)


Processing FF layers: 100%|██████████| 24/24 [00:22<00:00,  1.06it/s]
Processing FF keys: 100%|██████████| 24/24 [00:22<00:00,  1.07it/s]


In [13]:

print("Feed-forward values (Layer 3, first 5 neurons):")
for neuron_info in ff_values[3]['neurons'][:100]:
    print(f"Neuron {neuron_info['neuron']}: {neuron_info['top_tokens']}")

print("\nFeed-forward keys (Layer 0, first 5 neurons):")
for neuron_info in ff_keys[0]['neurons'][:100]:
    print(f"Neuron {neuron_info['neuron']}: {neuron_info['top_tokens']}")

Feed-forward values (Layer 3, first 5 neurons):
Neuron 0: ['aks', ' dłuż', 'stka', 'gat', ' Jast', 'Jakub', 'lacz', ' mac', ' rondo', ' rabat']
Neuron 1: [' komunalnych', ' podatkowej', ' spalin', 'kowicz', 'cian', ' administracyjnych', ' mieszkaniowej', 'zów', ' budżetowych', ' przemysłowej']
Neuron 2: ['sey', 'biu', 'bridge', 'ney', ' zapytania', 'witz', 'bież', ' numerów', 'nerów', 'dis']
Neuron 3: ['rola', ' Rei', 'remier', 'stanowi', ' Głównej', 'roe', 'dora', ' oczekując', ' Bogdana', 'stre']
Neuron 4: [' cywilnego', 'eu', 'neta', '\u202f', ' kredytu', 'olat', ' jesli', ' ukarany', ' warunk', ' bak']
Neuron 5: [' Wodnej', ' Pona', 'finy', 'usse', 'niny', ' żyjesz', 'ninę', 'foni', 'owity', 'fina']
Neuron 6: ['ancie', 'amie', 'alizacja', 'amina', 'unku', 'antem', 'antu', 'orki', 'anci', 'ensy']
Neuron 7: ['aku', 'ange', 'uge', ' Bazy', 'ō', 'fla', 'gospodar', 'bio', '�', 'ffe']
Neuron 8: ['źnica', ' maszynowego', 'rup', ' samochodowym', 'paskow', 'lizowała', ' kierowcą', ' motocy'

In [None]:
def extract_attention_heads(model, E, tokenizer, k=10):
    attention_heads = []
    n_heads = model.config.n_head  # Number of attention heads
    head_dim = model.config.n_embd // n_heads  # Dimension per head

    for layer_idx, block in enumerate(tqdm(model.h, desc="Processing attention layers")):
        attn = block.attn
        # Get the combined QKV weights
        Wqkv = attn.c_attn.weight.data  # shape: (model_dim, 3 * model_dim)
        Wqkv = Wqkv.cpu()

        # Split into Wq, Wk, Wv
        Wq = Wqkv[:, :model.config.n_embd]
        Wk = Wqkv[:, model.config.n_embd:2 * model.config.n_embd]
        Wv = Wqkv[:, 2 * model.config.n_embd:]

        # Get Wo
        Wo = attn.c_proj.weight.data  # shape: (model_dim, model_dim)
        Wo = Wo.cpu()

        # Reshape weights to separate heads
        Wq = Wq.view(model.config.n_embd, n_heads, head_dim)  # (1024, 16, 64)
        Wk = Wk.view(model.config.n_embd, n_heads, head_dim)
        Wv = Wv.view(model.config.n_embd, n_heads, head_dim)
        Wo = Wo.view(n_heads, head_dim, model.config.n_embd)   # (16, 64, 1024)

        for head_idx in range(n_heads):
            Wq_head = Wq[:, head_idx, :]  # shape: (1024, 64)
            Wk_head = Wk[:, head_idx, :]
            Wv_head = Wv[:, head_idx, :]
            Wo_head = Wo[head_idx, :, :]  # shape: (64, 1024)

            # Project Wq_head and Wk_head into embedding space
            Wq_proj = E @ Wq_head  # shape: (vocab_size, head_dim)
            Wk_proj = E @ Wk_head  # shape: (vocab_size, head_dim)

            # Compute top-k token pairs for W_QK
            tokens_QK = []
            for i in range(head_dim):
                wq = Wq_proj[:, i]  # shape: (vocab_size,)
                wk = Wk_proj[:, i]  # shape: (vocab_size,)

                # Get top-k indices
                topk_wq_values, topk_wq_indices = torch.topk(wq, k=k)
                topk_wk_values, topk_wk_indices = torch.topk(wk, k=k)

                # Compute contributions for top-k tokens only
                for idx_q, val_q in zip(topk_wq_indices, topk_wq_values):
                    for idx_k, val_k in zip(topk_wk_indices, topk_wk_values):
                        contribution = val_q.item() * val_k.item()
                        tokens_QK.append((tokenizer.decode([idx_q.item()]), tokenizer.decode([idx_k.item()]), contribution))
            # Sort and select top-k
            tokens_QK = sorted(tokens_QK, key=lambda x: abs(x[2]), reverse=True)[:k]

            # Similarly for W_VO
            # Project Wv_head and Wo_head into embedding space
            Wv_proj = E @ Wv_head  # shape: (vocab_size, head_dim)
            Wo_proj = Wo_head @ E.T  # shape: (head_dim, vocab_size)

            tokens_VO = []
            for i in range(head_dim):
                wv = Wv_proj[:, i]  # shape: (vocab_size,)
                wo = Wo_proj[i, :]  # shape: (vocab_size,)

                # Get top-k indices
                topk_wv_values, topk_wv_indices = torch.topk(wv, k=k)
                topk_wo_values, topk_wo_indices = torch.topk(wo, k=k)

                # Compute contributions for top-k tokens only
                for idx_v, val_v in zip(topk_wv_indices, topk_wv_values):
                    for idx_o, val_o in zip(topk_wo_indices, topk_wo_values):
                        contribution = val_v.item() * val_o.item()
                        tokens_VO.append((tokenizer.decode([idx_v.item()]), tokenizer.decode([idx_o.item()]), contribution))
            # Sort and select top-k
            tokens_VO = sorted(tokens_VO, key=lambda x: abs(x[2]), reverse=True)[:k]

            attention_heads.append({
                'layer': layer_idx,
                'head': head_idx,
                'W_QK_top_tokens': tokens_QK,
                'W_VO_top_tokens': tokens_VO
            })
    return attention_heads

In [6]:
attention_data = extract_attention_heads(model, E, tokenizer, k=50)

Processing attention layers: 100%|██████████| 24/24 [20:36<00:00, 51.52s/it]


In [7]:
print("\nAttention Head (Layer 0, Head 0) W_QK top token pairs:")
for token_pair in attention_data[0]['W_QK_top_tokens']:
    print(token_pair)

print("\nAttention Head (Layer 0, Head 0) W_VO top token pairs:")
for token_pair in attention_data[0]['W_VO_top_tokens']:
    print(token_pair)


Attention Head (Layer 0, Head 0) W_QK top token pairs:
('ałam', 'rzałam', 0.1122945284010024)
('ałam', 'rzałem', 0.11090440780752608)
('ałam', 'lałem', 0.1105061935700613)
('ałam', 'siadłem', 0.11007357352015745)
('ałam', 'siadłam', 0.10937657869586115)
('ałam', 'żyłem', 0.10925697906098542)
('ałam', 'żyłam', 0.10824086712875047)
('ałem', 'rzałam', 0.10775755608316562)
('ałem', 'rzałem', 0.10642359974578319)
('ałem', 'lałem', 0.10604147433292699)
('ałyśmy', 'rzałam', 0.10600535208626916)
('ałem', 'siadłem', 0.10562633318621195)
('ałem', 'siadłam', 0.10495749864958537)
('ałem', 'żyłem', 0.10484273113111264)
('ałyśmy', 'rzałem', 0.10469308669763322)
('ałam', 'cisnąłem', 0.10441236781058016)
('ałyśmy', 'lałem', 0.10431717487851522)
('ałyśmy', 'siadłem', 0.10390878418163396)
('ałem', 'żyłam', 0.10386767259456864)
('ałyśmy', 'siadłam', 0.10325082530506258)
('ałyśmy', 'żyłem', 0.10313792397687749)
('ałam', 'rzyłem', 0.10297483796873674)
('ałyśmy', 'żyłam', 0.10217872049056886)
('ałam', 'chn

In [8]:
data_rows = []

for layer_data in attention_data:
    layer = layer_data['layer']
    head = layer_data['head']

    for token_q, token_k, contribution in layer_data['W_QK_top_tokens']:
        data_rows.append({
            'Layer': layer,
            'Head': head,
            'Type': 'W_QK',
            'Token 1': token_q,
            'Token 2': token_k,
            'Contribution': contribution
        })

    for token_v, token_o, contribution in layer_data['W_VO_top_tokens']:
        data_rows.append({
            'Layer': layer,
            'Head': head,
            'Type': 'W_VO',
            'Token 1': token_v,
            'Token 2': token_o,
            'Contribution': contribution
        })

# Create DataFrame
df_tokens = pd.DataFrame(data_rows)

# Display the DataFrame
df_tokens.head()

Unnamed: 0,Layer,Head,Type,Token 1,Token 2,Contribution
0,0,0,W_QK,ałam,rzałam,0.112295
1,0,0,W_QK,ałam,rzałem,0.110904
2,0,0,W_QK,ałam,lałem,0.110506
3,0,0,W_QK,ałam,siadłem,0.110074
4,0,0,W_QK,ałam,siadłam,0.109377


In [11]:
def interactive_attention_visualization(attention_data, model):
    # Create sliders and dropdowns
    layer_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=model.config.n_layer - 1,
        step=1,
        description='Layer:',
        continuous_update=False
    )
    head_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=model.config.n_head - 1,
        step=1,
        description='Head:',
        continuous_update=False
    )
    token_type_dropdown = widgets.Dropdown(
        options=['W_QK_top_tokens', 'W_VO_top_tokens'],
        value='W_QK_top_tokens',
        description='Token Type:',
    )
    top_n_slider = widgets.IntSlider(
        value=20,
        min=1,
        max=50,
        step=1,
        description='Top N:',
        continuous_update=False
    )

    # Initialize an empty FigureWidget
    fig = go.FigureWidget(layout=go.Layout(
        title='',
        xaxis=dict(title='Contribution'),
        yaxis=dict(title='Token Pair', categoryorder='total ascending'),
        height=600,
        width=1200
    ))

    # Function to update the plot
    def update_plot(layer, head, token_type, top_n):
        fig.data = []
        fig.layout.title = f'Layer {layer}, Head {head} - Top {top_n} {token_type}'

        selected_data = None
        for data in attention_data:
            if data['layer'] == layer and data['head'] == head:
                selected_data = data[token_type]
                break

        if not selected_data:
            with fig.batch_update():
                fig.add_annotation(
                    text="No data available for the selected layer and head.",
                    xref="paper", yref="paper",
                    showarrow=False,
                    font=dict(size=20)
                )
            return

        # Create a DataFrame for plotting
        df = pd.DataFrame(selected_data, columns=['Token 1', 'Token 2', 'Contribution'])
        df['Token Pair'] = df['Token 1'] + ' ➔ ' + df['Token 2']
        df = df.sort_values(by='Contribution', ascending=False).head(top_n)

        # Update the figure
        with fig.batch_update():
            fig.add_bar(
                x=df['Contribution'],
                y=df['Token Pair'],
                orientation='h',
                marker_color='grey'
            )

    controls = {
        'layer': layer_slider,
        'head': head_slider,
        'token_type': token_type_dropdown,
        'top_n': top_n_slider
    }
    out = widgets.interactive_output(update_plot, controls)

    ui = widgets.VBox([layer_slider, head_slider, token_type_dropdown, top_n_slider])
    display(ui, fig)

    update_plot(layer_slider.value, head_slider.value, token_type_dropdown.value, top_n_slider.value)

interactive_attention_visualization(attention_data, model)

VBox(children=(IntSlider(value=0, continuous_update=False, description='Layer:', max=23), IntSlider(value=0, c…

FigureWidget({
    'data': [{'marker': {'color': 'grey'},
              'orientation': 'h',
              'type': 'bar',
              'uid': '7a0e28b9-1a95-4068-bcae-9cf6b8c7330f',
              'x': array([0.11229453, 0.11090441, 0.11050619, 0.11007357, 0.10937658, 0.10925698,
                          0.10824087, 0.10775756, 0.1064236 , 0.10604147, 0.10600535, 0.10562633,
                          0.1049575 , 0.10484273, 0.10469309, 0.10441237, 0.10431717, 0.10390878,
                          0.10386767, 0.10325083]),
              'y': array(['ałam ➔ rzałam', 'ałam ➔ rzałem', 'ałam ➔ lałem', 'ałam ➔ siadłem',
                          'ałam ➔ siadłam', 'ałam ➔ żyłem', 'ałam ➔ żyłam', 'ałem ➔ rzałam',
                          'ałem ➔ rzałem', 'ałem ➔ lałem', 'ałyśmy ➔ rzałam', 'ałem ➔ siadłem',
                          'ałem ➔ siadłam', 'ałem ➔ żyłem', 'ałyśmy ➔ rzałem', 'ałam ➔ cisnąłem',
                          'ałyśmy ➔ lałem', 'ałyśmy ➔ siadłem', 'ałem ➔ żyłam',
           