## Init

In [1]:
import torch
from torch.nn import functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from tabulate import tabulate
from tqdm import tqdm, trange
from copy import deepcopy
import numpy as np
from collections import Counter

## Helper Functions

In [2]:
# ASCII letters and digits; top_tokens(only_alnum=True) keeps only tokens whose
# characters (after stripping 'Ġ▁[] ') all belong to this set.
ALNUM_CHARSET = set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789')

def convert_to_tokens(indices, tokenizer, extended=False, extra_values_pos=None, strip=True):
    """Turn token ids into printable token strings.

    Args:
        indices: iterable of token ids (Python ints, numpy ints or 0-d tensors).
        tokenizer: object exposing ``convert_ids_to_tokens`` and, when
            ``extended`` is set, ``len()``.
        extended: when True, ids at or beyond ``len(tokenizer)`` are rendered as
            placeholders instead of being looked up: ``[pos{n}]`` for ids below
            ``extra_values_pos`` and ``[val{n}]`` for ids at or above it.
        extra_values_pos: first id rendered as ``[val…]``; must be given when
            ``extended`` is True and ids beyond the vocabulary can occur.
        strip: when True, a leading 'Ġ' is removed and every token that does not
            start with 'Ġ' (placeholders included) is prefixed with '#'.

    Returns:
        list[str]: one string per entry of ``indices``.
    """
    if extended:
        vocab_size = len(tokenizer)
        res = []
        for raw_idx in indices:
            # Cast to a plain int so that tensor elements do not leak their repr
            # ("tensor(3)") into the placeholder strings below.
            idx = int(raw_idx)
            if idx < vocab_size:
                res.append(tokenizer.convert_ids_to_tokens([idx])[0])
            elif idx < extra_values_pos:
                res.append(f"[pos{idx-vocab_size}]")
            else:
                res.append(f"[val{idx-extra_values_pos}]")
    else:
        res = tokenizer.convert_ids_to_tokens(indices)
    if strip:
        # startswith() instead of x[0] so an empty token string does not raise.
        res = [tok[1:] if tok.startswith('Ġ') else "#" + tok for tok in res]
    return res


def top_tokens(v, k=100, tokenizer=None, only_alnum=False, only_ascii=True, with_values=False, 
               exclude_brackets=False, extended=True, extra_values=None, only_from_list=None):
    """Return the ``k`` highest-scoring tokens of a score vector.

    Args:
        v: torch tensor of scores indexed by token id (assumed 1-D, one entry per
            vocabulary id — TODO confirm against callers). It is not modified.
        k: number of tokens to return; clamped to the number of available scores.
        tokenizer: tokenizer exposing ``vocab`` (token -> id); defaults to the
            module-level ``my_tokenizer``.
        only_alnum: ignore tokens containing characters outside ``ALNUM_CHARSET``
            (after stripping 'Ġ▁[] ').
        only_ascii: ignore tokens that are not ASCII (after stripping 'Ġ▁').
        with_values: when True, return ``(token, score)`` pairs.
        exclude_brackets: see the review note in the body.
        extended: forwarded to ``convert_to_tokens``.
        extra_values: optional tensor concatenated after ``v``; its entries are
            rendered as placeholders by ``convert_to_tokens`` when ``extended``.
        only_from_list: optional collection of lower-cased, stripped tokens;
            when non-empty, every token not in it is ignored.

    Returns:
        list of token strings, or list of ``(token, score)`` tuples when
        ``with_values`` is True. Ignored tokens carry a score of ``-inf`` and can
        still appear if ``k`` exceeds the number of permitted tokens.
    """
    if tokenizer is None:
        tokenizer = my_tokenizer

    # Look the vocabulary up once instead of once per filter.
    vocab = tokenizer.vocab
    ignored = set()
    if only_ascii:
        ignored.update(idx for tok, idx in vocab.items() if not tok.strip('Ġ▁').isascii())
    if only_alnum:
        ignored.update(idx for tok, idx in vocab.items() if not (set(tok.strip('Ġ▁[] ')) <= ALNUM_CHARSET))
    if only_from_list:
        ignored.update(idx for tok, idx in vocab.items() if tok.strip('Ġ▁ ').lower() not in only_from_list)
    if exclude_brackets:
        # NOTE(review): this *intersects* the ignore set with the ids of tokens that
        # are not ASCII-alphanumeric, so it can only shrink the set built above.
        # If the intent was to additionally hide bracket/punctuation tokens, a
        # union would be needed — confirm before changing; behavior kept as is.
        ignored &= {idx for tok, idx in vocab.items() if not (tok.isascii() and tok.isalnum())}

    # Work on a detached copy so the caller's tensor is left untouched.
    scores = v.detach().clone()
    if ignored:
        scores[sorted(ignored)] = -np.inf
    extra_values_pos = len(scores)
    if extra_values is not None:
        scores = torch.cat([scores, extra_values])

    # torch.topk raises when k is larger than the vector; return what is available instead.
    k = min(k, len(scores))
    values, indices = torch.topk(scores, k=k)
    res = convert_to_tokens(indices, tokenizer, extended=extended, extra_values_pos=extra_values_pos)
    if with_values:
        res = list(zip(res, values.cpu().numpy()))
    return res

## Extract Weights

In [3]:
# Load the pretrained checkpoint and its tokenizer; `tokenizer` and
# `my_tokenizer` are two names for the same object.
model = AutoModelForCausalLM.from_pretrained("sdadas/polish-gpt2-medium")
my_tokenizer = AutoTokenizer.from_pretrained("sdadas/polish-gpt2-medium")
tokenizer = my_tokenizer

# Transposed weight of the output-embedding layer.
emb = model.get_output_embeddings().weight.data.T.detach()

# Architecture sizes read from the model config.
num_layers, num_heads, hidden_dim = (
    model.config.n_layer,
    model.config.n_head,
    model.config.n_embd,
)
head_size = hidden_dim // num_heads

# Feed-forward weights of all layers, concatenated along the first axis
# (c_fc is transposed first, c_proj is taken as stored).
K = torch.cat(tuple(
    model.get_parameter(f"transformer.h.{layer}.mlp.c_fc.weight").T
    for layer in range(num_layers)
)).detach()
V = torch.cat(tuple(
    model.get_parameter(f"transformer.h.{layer}.mlp.c_proj.weight")
    for layer in range(num_layers)
)).detach()

# The fused attention projection is split into three equal chunks along the last axis.
W_Q, W_K, W_V = torch.chunk(
    torch.cat(tuple(
        model.get_parameter(f"transformer.h.{layer}.attn.c_attn.weight")
        for layer in range(num_layers)
    )).detach(),
    3,
    dim=-1,
)
W_O = torch.cat(tuple(
    model.get_parameter(f"transformer.h.{layer}.attn.c_proj.weight")
    for layer in range(num_layers)
)).detach()


In [4]:
# Split the stacked feed-forward matrices back into one slab per layer.
K_heads, V_heads = (
    stacked.reshape(num_layers, -1, hidden_dim) for stacked in (K, V)
)
d_int = K_heads.size(1)

# Query/key/value projections: split the last axis into (num_heads, head_size)
# and move the head axis directly after the layer axis.
W_Q_heads, W_K_heads, W_V_heads = (
    proj.reshape(num_layers, hidden_dim, num_heads, head_size).transpose(1, 2)
    for proj in (W_Q, W_K, W_V)
)
# Output projection: split the first per-layer axis into (num_heads, head_size).
W_O_heads = W_O.reshape(num_layers, num_heads, head_size, hidden_dim)

In [5]:
# Despite the name, no inverse is computed here: the transpose of `emb` is used
# wherever the cells below need to map from token space into the hidden space.
emb_inv = emb.T

## Interpretation

#### Alternative I: No Token List

In [6]:
# An empty set is falsy, so the `only_from_list` / `tokens_list` filters used
# further down are skipped and no vocabulary restriction is applied.
tokens_list = set()

#### Alternative II: Load Token List from a Dataset (`clarin-knext/wsd_polish_datasets`)

In [7]:
from datasets import load_dataset

In [8]:
# Texts of the 'train' split. The variable is still called `imdb`, but the
# dataset loaded here is clarin-knext/wsd_polish_datasets.
# NOTE(review): trust_remote_code=True presumably lets the dataset repository's
# loading script run locally — confirm the source is trusted.
imdb = load_dataset("clarin-knext/wsd_polish_datasets", trust_remote_code=True)['train']['text']

In [9]:
# None: keep every token seen in the corpus; an int: keep only that many of the
# tokens occurring in the most texts (see the next cell).
max_tokens_num = None

In [10]:
# Collect the tokens that occur in the corpus texts.
if max_tokens_num is None:
    # Keep every token that appears at least once. The set is updated in place
    # rather than rebuilt with union() for each text.
    tokens_list = set()
    for txt in tqdm(imdb):
        tokens_list.update(tokenizer.tokenize(txt))
else:
    # Count in how many texts each token appears (each text contributes a token
    # at most once) and keep the `max_tokens_num` most frequent ones.
    token_counts = Counter()
    for txt in tqdm(imdb):
        token_counts.update(set(tokenizer.tokenize(txt)))
    # A real list (not a one-shot `map` iterator), so it can be traversed more than once.
    tokens_list = [token for token, _ in token_counts.most_common(max_tokens_num)]
    

100%|██████████| 7848/7848 [00:07<00:00, 1059.18it/s]


In [11]:
# Normalise the collected tokens: drop the 'Ġ' / '▁' word-boundary markers,
# lower-case, and deduplicate.
tokens_list = {token.strip('Ġ▁').lower() for token in tokens_list}

### FF Keys & Values

In [12]:
i1, i2 = 23, 907
# i1, i2 = np.random.randint(num_layers), np.random.randint(d_int)

print(f"{i1}, {i2}")
# One column per projected vector (key, value and their negations). zip() stops
# at the shortest column, so the table has at most 30 rows.
print(tabulate(
    list(zip(*(
        top_tokens(direction @ emb, k=top_k, only_from_list=tokens_list, only_alnum=False)
        for direction, top_k in (
            (K_heads[i1, i2], 30),
            (V_heads[i1, i2], 30),
            (-K_heads[i1, i2], 200),
            (-V_heads[i1, i2], 200),
        )
    ))),
    headers=['K', 'V', '-K', '-V'],
))

23, 907
K          V              -K          -V
---------  -------------  ----------  ---------
przycu     kody           dotychczas  #bot
zalog      #gory          rodzi       #lot
wylegi     #ei            do          #dzista
#walifik   #zmy           przebie     #remont
przesp     Apo            gatunku     #lee
pochowany  apokali        przez       #wan
#ppe       #128           #ja         #ette
wep        ludy           rodzin      #up
sfinans    #ords          zrazu       #bul
#iss       Cezary         lokalnie    #puszczam
#CS        archa          porywa      spu
#gny       przy           two         zamyka
#zwol      Homo           jeszcze     odstawi
#-).       litera         #jaw        #mont
lock       #pka           na          #beki
skonfisk   Benedykt       wymaga      #laks
#post      narodem        ty          tap
postoju    polityki       Drze        poby
erek       symbolu        pod         posto
#ionu      #lachet        Nie         #wana
#ppo       publicznego  

In [13]:
i1, i2 = 21, 7
# i1, i2 = np.random.randint(num_layers), np.random.randint(d_int)

print(f"{i1}, {i2}")
# One column per projected vector (key, value and their negations). zip() stops
# at the shortest column, so the table has at most 30 rows.
print(tabulate(
    list(zip(*(
        top_tokens(direction @ emb, k=top_k, only_from_list=tokens_list, only_alnum=False)
        for direction, top_k in (
            (K_heads[i1, i2], 30),
            (V_heads[i1, i2], 30),
            (-K_heads[i1, i2], 200),
            (-V_heads[i1, i2], 200),
        )
    ))),
    headers=['K', 'V', '-K', '-V'],
))

21, 7
K            V         -K            -V
-----------  --------  ------------  -------------
#akija       #czony    #stra         #bing
#anel        #bol      #krzy         Broad
#aster       #czeni    krzy          #sora
#owiak       #jmi      odw           #lki
#akuje       #gos      wiesz         #net
kulturowego  #fik      panika        Krasi
#ader        #sion     szy           turystycznych
#akow        #J        #tu           #700
#erka        atu       wnie          #lek
europy       #rodni    winie         #ing
#imes        #wart     #to           #enne
#ontent      drabiny   wsu           #zeli
#ansen       wiec      #cieka        #1000
#ATE         #czona    wyba          #lant
#ress        #jskim    si            interna
#ords        #andar    rozczarowany  #800
Amster       #bra      #kwa          #letki
ciagu        magister  mo            #iagno
#yne         #jska     #du           #lno
#rzak        #roz      zni           Legislacyjne
#adem        #bej      tu      

In [14]:
i1, i2 = 19, 13
# i1, i2 = np.random.randint(num_layers), np.random.randint(d_int)

print(f"{i1}, {i2}")
# One column per projected vector (key, value and their negations). zip() stops
# at the shortest column, so the table has at most 30 rows.
print(tabulate(
    list(zip(*(
        top_tokens(direction @ emb, k=top_k, only_from_list=tokens_list, only_alnum=False)
        for direction, top_k in (
            (K_heads[i1, i2], 30),
            (V_heads[i1, i2], 30),
            (-K_heads[i1, i2], 200),
            (-V_heads[i1, i2], 200),
        )
    ))),
    headers=['K', 'V', '-K', '-V'],
))

19, 13
K            V           -K         -V
-----------  ----------  ---------  ---------------
#feld        #yla        wzajemnej  #sekre
#ations      #mie        #licz      #arek
kanclerza    #Mie        #rzu       #ness
#past        #owala      #usa       #bara
#finale      #darze      #rozumie   ropy
#dzonej      #gil        #Cie       #sat
Zachodnim    #torze      #rze       #head
#dzone       #jmi        #[         #eco
publicznymi  #tul        Szczu      plusy
#onny        #widzi      uznaje     dyplomatycznych
schodowej    #jami       #twier     Luf
technicznym  #Zbigniew   Bi         koncernu
kancle       #torom      #roi       #sol
#dynki       #zer        #Ka        #ranki
#shire       #ciami      Jaku       #arty
#stez        Geral       #us        #aryn
Gwardii      #wczo       #wu        #feld
#dzonymi     #cjonalnie  #try       #ryn
#vre         #hama       Mi         belki
#dowana      #zowano     uznania    pomara
#dzona       #Adam       #MR        bankowego
#strzem

In [15]:
i1, i2 = 20, 9

print(f"{i1}, {i2}")
# One column per projected vector (key, value and their negations). zip() stops
# at the shortest column, so the table has at most 30 rows.
print(tabulate(
    list(zip(*(
        top_tokens(direction @ emb, k=top_k, only_from_list=tokens_list, only_alnum=False)
        for direction, top_k in (
            (K_heads[i1, i2], 30),
            (V_heads[i1, i2], 30),
            (-K_heads[i1, i2], 200),
            (-V_heads[i1, i2], 200),
        )
    ))),
    headers=['K', 'V', '-K', '-V'],
))

20, 9
K             V           -K          -V
------------  ----------  ----------  ----------
#SKY          #pomor      ad          #WG
#bioty        #cina       namiest     Eth
#ronto        #cin        pogro       #Ger
planetoida    #fora       od          #lich
#bica         ad          przywo      #Trud
#atory        inter       ob          #lowy
#alizowane    dziedzi     inter       #lee
#Daw          #pina       pona        #RS
#KiK          #plika      adiutan     #Is
#iT           #kacji      ober        kontynentu
#ARS          osobistym   ude         #bytu
#Europa       przedmie    pod         #nigdy
internetowym  multime     atak        #Stalin
internetowy   #gno        zbombar     #niew
#alizacji     #rac        zaatak      #lym
#TER          #aka        zwo         #Rach
#ship         #zdra       nad         #RW
#nni          ety         #go         #ARD
#elle         #pal        bombar      #lizmie
#laby         publicznym  rozstrze    #stoj
producentem   #roni       og

In [16]:
i1, i2 = 5, 2031

print(f"{i1}, {i2}")
# One column per projected vector (key, value and their negations). zip() stops
# at the shortest column, so the table has at most 30 rows.
print(tabulate(
    list(zip(*(
        top_tokens(direction @ emb, k=top_k, only_from_list=tokens_list, only_alnum=False)
        for direction, top_k in (
            (K_heads[i1, i2], 30),
            (V_heads[i1, i2], 30),
            (-K_heads[i1, i2], 200),
            (-V_heads[i1, i2], 200),
        )
    ))),
    headers=['K', 'V', '-K', '-V'],
))

5, 2031
K               V          -K           -V
--------------  ---------  -----------  -----------
Afganistanu     Rzesz      #reb         emerytalne
Afga            #parcie    fair         publiczne
Jehowy          #tyz       #gle         #szto
Wojny           #stie      #cie         ui
Obywatelskich   #tz        #bal         #smo
zbiorowych      #kop       #patrz       Sub
etni            #tkami     #arter       #szard
getta           #zdro      #gl          publiczna
rdzenia         #lock      moimi        odprawy
Kinga           #tto       mym          powsze
etnicznych      #TS        #bki         #Ur
Afganistanie    #wcze      moim         autorskie
#zacje          #ieu       #pi          Anne
przegranej      uderzeniu  #de          #mion
emerytalnych    rzesz      #lar         integra
wojennych       #sci       pas          ur
Krajowych       #zej       biurko       publicznych
zbiorowe        #zno       #dgo         celne
uznanych        hamowania  #dalej       #wymiar
mies

### Attention Weights Interpretation

In [17]:
def approx_topk(mat, min_k=500, max_k=250_000, th0=10, max_iters=10, verbose=False):
    """Find the positions of roughly the largest entries of ``mat`` by thresholding.

    A threshold is searched such that the number of finite entries strictly
    above it lies in ``[min_k, max_k]``: the threshold is first doubled starting
    from ``th0`` until at most ``max_k`` entries remain, then refined by
    bisection for at most ``max_iters`` steps.

    Args:
        mat: torch tensor to scan.
        min_k: smallest acceptable number of selected entries.
        max_k: largest acceptable number of selected entries.
        th0: initial threshold; must be positive if it has to be grown.
        max_iters: maximum number of bisection steps.
        verbose: print the entry count after every probe.

    Returns:
        list of index lists (one per selected entry), as produced by
        ``torch.nonzero(...).tolist()``. If the bisection budget runs out, the
        last probed threshold is used even if its count is outside the range.

    Raises:
        ValueError: if the threshold must be doubled but is not positive
            (doubling could never make progress).
    """
    th_max = np.inf

    def _mask(th):
        # Entries above the threshold; `< th_max` additionally drops +inf entries.
        return (mat > th) & (mat < th_max)

    def _count(th):
        # Count through the boolean mask instead of materialising all indices.
        return int(_mask(th).sum())

    # Phase 1: grow the threshold until no more than max_k entries survive.
    left, right = 0, th0
    while True:
        actual_k = _count(right)
        if verbose:
            print(f"one more iteration. {actual_k}")
        if actual_k <= max_k:
            break
        if right <= 0:
            raise ValueError("th0 must be positive when the threshold has to be increased")
        left, right = right, right * 2

    # Phase 2: bisect between the last two thresholds if too few entries remain.
    # `th` starts at `right` so it is defined even when max_iters is 0.
    th = right
    if not (min_k <= actual_k <= max_k):
        for _ in range(max_iters):
            th = (left + right) / 2
            actual_k = _count(th)
            if verbose:
                print(f"one more iteration. {actual_k}")
            if min_k <= actual_k <= max_k:
                break
            if actual_k > max_k:
                left = th
            else:
                right = th
    return torch.nonzero(_mask(th)).tolist()

def get_top_entries(tmp, all_high_pos, only_ascii=False, only_alnum=False, exclude_same=False, exclude_fuzzy=False,
                    tokens_list=None, reverse_list=None, top_n=50):
    """Decode the strongest (row token, column token) pairs of an interaction matrix.

    Args:
        tmp: torch tensor indexed by token ids (the first two index components
            of every position are decoded as tokens).
        all_high_pos: list of index lists into ``tmp`` (e.g. from ``approx_topk``).
        only_ascii: keep a pair only if both decoded tokens are ASCII.
        only_alnum: keep a pair only if both decoded tokens are alphanumeric.
        exclude_same: drop pairs whose tokens are equal ignoring case/whitespace.
        exclude_fuzzy: drop pairs judged equal by ``_fuzzy_eq``; that helper is
            not defined in this part of the notebook and must exist elsewhere.
        tokens_list: optional collection of normalised tokens; when non-empty,
            both tokens of a pair must belong to it.
        reverse_list: True sorts ascending (smallest values first), False
            descending. When None, the module-level ``reverse_list`` variable is
            used if it exists (the notebook cells set it), otherwise False.
        top_n: maximum number of pairs returned.

    Returns:
        list of ``(decoded_row_token, decoded_col_token)`` tuples, best first.

    Uses the module-level ``tokenizer`` for decoding.
    """
    if reverse_list is None:
        # Backward compatibility: callers used to configure the order via a global.
        reverse_list = globals().get("reverse_list", False)

    decoded = {}

    def _decode(token_id):
        # Every id occurs in many pairs; decode each one only once.
        token_id = int(token_id)
        if token_id not in decoded:
            decoded[token_id] = tokenizer.decode(token_id)
        return decoded[token_id]

    def _pair(pos):
        return _decode(pos[0]), _decode(pos[1])

    remaining_pos = all_high_pos
    if only_ascii:
        remaining_pos = [pos for pos in remaining_pos
                         if all(tok.strip('Ġ▁').isascii() for tok in _pair(pos))]
    if only_alnum:
        remaining_pos = [pos for pos in remaining_pos
                         if all(tok.strip('Ġ▁ ').isalnum() for tok in _pair(pos))]
    if exclude_same:
        remaining_pos = [pos for pos in remaining_pos
                         if _decode(pos[0]).lower().strip() != _decode(pos[1]).lower().strip()]
    if exclude_fuzzy:
        remaining_pos = [pos for pos in remaining_pos
                         if not _fuzzy_eq(_decode(pos[0]).lower().strip(), _decode(pos[1]).lower().strip())]
    if tokens_list:
        remaining_pos = [pos for pos in remaining_pos
                         if all(tok.strip('Ġ▁').lower().strip() in tokens_list for tok in _pair(pos))]

    if not remaining_pos:
        return []

    # Gather the matrix values at the surviving positions (one index list per axis).
    pos_val = tmp[tuple(list(axis) for axis in zip(*remaining_pos))]
    order = torch.argsort(pos_val if reverse_list else -pos_val)[:top_n].tolist()
    return [_pair(remaining_pos[i]) for i in order]

#### $W_{VO}$ Interpretation

Choose **layer** and **head** here:

In [18]:
i1, i2 = 23, 9

# Display switches. `reverse_list` is not passed below: get_top_entries reads it
# from the notebook's global namespace.
exclude_same, reverse_list = False, False
only_ascii, only_alnum = True, False

# Value and output projections of the selected layer/head, combined and mapped
# into token space on both sides.
W_V_tmp = W_V_heads[i1, i2]
W_O_tmp = W_O_heads[i1, i2]
tmp = emb_inv @ (W_V_tmp @ W_O_tmp) @ emb

# Positions of the largest entries, found by threshold search.
all_high_pos = approx_topk(tmp, th0=1, verbose=True)

get_top_entries(
    tmp,
    all_high_pos,
    only_ascii=only_ascii,
    only_alnum=only_alnum,
    exclude_same=exclude_same,
    tokens_list=None,
)

one more iteration. 0
one more iteration. 0
one more iteration. 87
one more iteration. 16925


[('].', '['),
 ('],', '['),
 (' ).', ' ('),
 (']', '['),
 ('...),', ' ('),
 ('przyp', ' ('),
 (' ),', ' ('),
 ('-).', ' ('),
 ('ZOBACZ', ' ['),
 ('zob', ' ('),
 ('przyp', ' ['),
 ('.]', '['),
 ('%)', ' ('),
 ('!),', ' ('),
 ('Ash', ' ('),
 ('].', ' ['),
 ('.].', '['),
 ('],', ' ['),
 (' rezygnacja', ' ('),
 ('!).', ' ('),
 ('prem', ' ('),
 ('\x15', ' ('),
 ('...),', ' ("'),
 ('orom', ' ('),
 ('%),', ' ('),
 ('ecie', ' ['),
 ('tul', '['),
 ('.].', ' ['),
 ('CZYTAJ', ' ['),
 ('etu', ' ['),
 ('-)', ' ('),
 ('przyp', ' ("'),
 ('%).', ' ('),
 ('?),', ' ('),
 ('sic', ' ['),
 ('http', ' ('),
 (' ]', ' ['),
 ('Dzwonek', ' ('),
 (' ).', ' ("'),
 ('!)', ' ('),
 ('demon', ' ('),
 (' realizacja', ' ('),
 ('dlaczego', ' ('),
 ('.]', ' ['),
 ('Dzwonek', ' ("'),
 ('Kodeks', ' ['),
 ('Bir', ' ['),
 ('Ash', ' ("'),
 ('eu', ' ('),
 ('przyp', ' (+')]

#### $W_{QK}$ Interpretation

Choose **layer** and **head** here:

In [19]:
i1, i2 = 21, 7

# Display switches. `reverse_list` is not passed below: get_top_entries reads it
# from the notebook's global namespace.
exclude_same, reverse_list = False, False
only_ascii, only_alnum = True, True

# Query and key projections of the selected layer/head, combined and mapped
# into token space on both sides.
W_Q_tmp = W_Q_heads[i1, i2]
W_K_tmp = W_K_heads[i1, i2]
tmp2 = emb_inv @ (W_Q_tmp @ W_K_tmp.T) @ emb_inv.T

# Positions of the largest entries, found by threshold search.
all_high_pos = approx_topk(tmp2, th0=1, verbose=True)

get_top_entries(
    tmp2,
    all_high_pos,
    only_ascii=only_ascii,
    only_alnum=only_alnum,
    exclude_same=exclude_same,
    tokens_list=tokens_list,
)

one more iteration. 0
one more iteration. 0
one more iteration. 42
one more iteration. 316744
one more iteration. 4556


[('dnieniem', ' przy'),
 (' pod', 'Uwiel'),
 (' w', 'kowicz'),
 ('aryjnych', ' tylko'),
 ('aryjnych', ' po'),
 (' w', 'dagogi'),
 (' przy', 'owienia'),
 (' pod', 'nujesz'),
 (' nad', 'ngu'),
 (' przy', 'Pry'),
 (' pod', 'alski'),
 (' w', 'lujesz'),
 (' przy', ' budowlane'),
 ('aryjnych', ' przy'),
 ('aryjnych', ' nawet'),
 ('Przytak', ' przy'),
 (' po', 'bil'),
 (' od', 'dagogi'),
 (' po', 'sywnie'),
 (' do', 'wychw'),
 ('aryjnych', ' zak'),
 (' nad', 'obior'),
 (' pod', 'niuje'),
 (' pod', 'zc'),
 (' w', 'towaniem'),
 (' pod', 'Przytu'),
 (' przed', 'ingiem'),
 (' w', 'tujesz'),
 (' do', 'ariatu'),
 (' dla', ' pryzmat'),
 (' pod', ' Wojskowego'),
 (' po', 'dujesz'),
 (' z', 'kowicz'),
 (' na', 'dagogi'),
 (' w', 'ustu'),
 (' w', 'lacy'),
 (' pod', ' podwykona'),
 (' pod', 'larii'),
 (' pod', 'tarnego'),
 (' przy', 'ariacie'),
 (' na', 'Google'),
 (' pod', 'Zapra'),
 (' nad', 'lotu'),
 (' za', 'nosc'),
 (' po', 'nujesz'),
 (' O', 'tyki'),
 (' po', 'lecie'),
 (' pod', 'nalnych'),
 (' po