In [1]:
from IPython import get_ipython
%load_ext autoreload
%autoreload 2

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'

In [2]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns

import os
os.environ['HF_HOME'] = '/raid/xd/.cache/torch'
from types import MethodType
from tqdm import tqdm
from collections import defaultdict, OrderedDict, Counter

import torch
import torch.nn as nn
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
from torch.utils.data.dataloader import DataLoader

from transformers.data.data_collator import DataCollator, default_data_collator
from transformers import AutoConfig, pipeline
from transformers import RobertaForMaskedLM, RobertaTokenizer, GPT2LMHeadModel, GPT2Tokenizer, GPTNeoForCausalLM
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import HfArgumentParser, Trainer, TrainingArguments, set_seed
# from transformers.trainer_utils import EvaluationStrategy

from utils import *

In file_utils.py: default_cache_path = /raid/xd/.cache/torch/transformers


In [3]:
models = {}
cache_dir = '/nas/xd/.cache/torch/transformers/'  # for models besides t5-3b/11b
proxies = {'http': '192.168.50.1:1081'} 

In [None]:
model = RobertaForMaskedLM.from_pretrained('roberta-large', cache_dir=cache_dir)
tokenizer = RobertaTokenizer.from_pretrained('roberta-large', cache_dir=cache_dir)
models['roberta-large'] = (model, tokenizer)

In cached_path: url_or_filename = https://huggingface.co/roberta-large/resolve/main/config.json
In cached_path: output_path = /nas/xd/.cache/torch/transformers/roberta-large-config.json
In cached_path: url_or_filename = https://huggingface.co/roberta-large/resolve/main/pytorch_model.bin
In cached_path: output_path = /nas/xd/.cache/torch/transformers/roberta-large-pytorch_model.bin


In [None]:
model_name = 't5-11b'
proxies = {'http': '192.168.50.1:1081'}
model = model11b = T5ForConditionalGeneration.from_pretrained(model_name, proxies=proxies)

tokenizer = T5Tokenizer.from_pretrained('t5-11b')
tokenizer.decode_strip_special_tokens = MethodType(decode_strip_special_tokens, tokenizer)
tokenizer.decode_old = MethodType(decode_old, tokenizer)

models['t5-11b'] = model, tokenizer

device_map = {0: list(range(0, 6)), 1: list(range(6, 15)), 2: list(range(15, 24))}
model.parallelize(device_map)
device = torch.device('cuda:0')

In [None]:
model_name = 'gpt2-xl'  # medium / large / xl
model = GPT2LMHeadModel.from_pretrained(model_name, cache_dir=cache_dir)  
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=cache_dir)
models[model_name] = model, tokenizer

In [None]:
model_name = "EleutherAI/gpt-neo-1.3B"
model = GPTNeoForCausalLM.from_pretrained(model_name, proxies=proxies, cache_dir=cache_dir)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=cache_dir)
models[model_name] = model, tokenizer

In [1]:
model_name = "EleutherAI/gpt-neo-2.7B"
model = GPTNeoForCausalLM.from_pretrained(model_name, proxies=proxies, cache_dir=cache_dir)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=cache_dir)
models[model_name] = model, tokenizer

NameError: name 'GPTNeoForCausalLM' is not defined

In [5]:
# model_name = 'roberta-large'
# model_name = 'gpt2-xl'
model_name = 'EleutherAI/gpt-neo-2.7B'
# model_name = 'EleutherAI/gpt-neo-1.3B'
model, tokenizer = models[model_name]

masked_lm = tokenizer.mask_token is not None and len(tokenizer.additional_special_tokens) == 0
if masked_lm:
    mask_token = tokenizer.mask_token  # '<mask>' for roberta
elif len(tokenizer.additional_special_tokens) > 0:
    mask_token = tokenizer.additional_special_tokens[0]  # '<sxtra_id_0>' for t5
else:
    mask_token = ''  # for gpt2
if masked_lm: nlp = pipeline('fill-mask', model=model, tokenizer=tokenizer, top_k=5)

Using mask_token, but it is not set yet.


In [4]:
from random import choice, choices, shuffle, sample

#vocab = list(string.ascii_uppercase) #+ ['_'] * 16
vocab = list(string.digits)[1:]
#query_vocab = list(string.ascii_uppercase)
nrows, ncols = 8, 4 #nrow是生成的行数，ncol是每组的字母数
has_query = False
has_output = True
def map_fn(x): return x.lower()

for _ in range(nrows):
    input_tokens = sample(vocab, ncols) #sample函数随机从vocab里选n个
    #input_tokens = [ map_fn(choice(vocab)) for _ in range(ncols)]
    i = random.randint(0, len(input_tokens) - 1)
    input_tokens[i] = input_tokens[i] + '0'
#     input_tokens[-1] = input_tokens[0]
#    special = choice(vocab)
#     input_tokens = [choice(vocab).lower()] * (ncols - 1) + [special]
#     shuffle(input_tokens)
    print(' '.join(input_tokens), end='')
    if has_query:
#         query_tokens = sample(input_tokens, ncols - 1)
        query_tokens = input_tokens.copy()
        i = random.randint(0, len(input_tokens) - 1)
#         query_token = input_tokens[i]
        query_tokens[i] = choice(vocab)
#         query_tokens = [t.lower() for t in query_tokens]
        print(',', ' '.join(query_tokens), end='')
    print(' -> ', end='')
    if has_output:
        #output_tokens = input_tokens[i][:-1] #.upper()
#        output_tokens = special.lower()
#         output_tokens = choice(input_tokens)
#         output_tokens = list(set(input_tokens) - set(query_tokens))
#         output_tokens = map(map_fn, input_tokens)
#         output_tokens = reverse(input_tokens)
#         output_tokens = query_tokens[i].lower()
        print(''.join(output_tokens), end='')
    print()

NameError: name 'string' is not defined

In [140]:
# adapted from attattr
def scaled_input(emb, num_points, baseline=None):
    # shape of emb: (bsz, num_head, seq_len, seq_len)
    assert emb.size(0) == 1
    if baseline is None: baseline = torch.zeros_like(emb)   
    step = (emb - baseline) / num_points
    #res = torch.cat([baseline + step * i for i in range(num_points)], dim=0) # orig
    res = torch.cat([baseline + step * (i+1) for i in range(num_points)], dim=0)  # revised
    return res, step

# from https://discuss.pytorch.org/t/get-top-k-indices-values-of-all-rows/89354
def unravel_index(index, shape):
    out = []
    for dim in reversed(shape):
        out.append(index % dim)
        index = index // dim
    r = tuple(reversed(out))
    return torch.cat([i.unsqueeze(-1) for i in r], dim=-1).cpu().tolist() if type(index) in [torch.Tensor] else r

def numpy(a, decimals=4): return a.detach().cpu().numpy().round(decimals)

def h2topk(h, k=4, return_probs=True):
    if not hasattr(h2topk, 'ln') or h2topk.ln.normalized_shape[0] != h.size(-1):
        h2topk.ln = nn.LayerNorm(h.size(-1))
#     r = model.lm_head(h2topk.ln(h))
    r = model.lm_head(h)
    if return_probs: r = r.softmax(-1)
    return r.topk(k, dim=-1) if k > 0 else r

def globalize(tensor):
    if tensor.dim() == 4: return tensor  # global attention
    assert tensor.dim() == 5, str(tensor.dim())
    assert tensor.size(1) == 1, str(tensor.size(1))  # num_blocks
    seq_len = tensor.size(3)
    return tensor.squeeze(1)[:, :, :, -seq_len:]  # (bsz, num_blocks, H, seq_len, block_len) -> (bsz, H, seq_len, seq_len)

def show_topk(values, indices, values_fn=numpy, indices_fn=numpy):
    return dict(OrderedDict(zip(indices_fn(indices), values_fn(values))))

def append_tokens_to_positions(position_tensor):
    positions = numpy(position_tensor)
    return ['%d %s' % (p, tokens[p]) for p in positions]

In [7]:
_ = model.to('cuda:1')

In [8]:
torch.cuda.empty_cache()

In [9]:
texts = [
 '''
A H S -> A
S D N -> S
U D B -> U
Z G M ->''',
 '''
A H S H -> a
S D N F -> s
U D B S -> u
Z G M E ->''',  # [(22, 15, 0.223), (21, 13, 0.197), (25, 0, 0.16), (23, 11, 0.144), (20, 19, 0.14)]
 '''
A H S H -> A
S D N F -> S
U D B S -> U
Z G M E ->''',  # [(22, 15, 0.188), (21, 13, 0.113), (23, 11, 0.106), (21, 19, 0.09), (24, 9, 0.072)]
'''
o t j -> O
r n k -> R
n m c -> N
m g d -> M
g c j -> G
x z o -> X
i c p -> I
u a o ->''',  # [(22, 15, 0.119), (23, 11, 0.109), (21, 13, 0.101), (29, 0, 0.062), (24, 11, 0.044)]
 '''
A L A -> l
F B F -> b
M A M -> a
O W O -> w
W Y W -> y
D G D ->''',  # [(26, 16, 0.109), (22, 24, 0.07), (25, 0, 0.063), (22, 15, 0.052), (21, 16, 0.045)]
 '''
N S N -> S
N K N -> K
O M O -> M
T V T ->''',  # [(26, 16, 0.062), (21, 13, 0.052), (22, 24, 0.049), (25, 0, 0.044), (22, 15, 0.041)]
 '''
n s n -> S
n k n -> K
o m o -> M
f b f -> B
m a m -> A
t v t ->''',  # [(26, 14, 0.078), (21, 4, 0.044), (17, 1, 0.042), (21, 13, 0.041), (25, 15, 0.03)] *******
 '''
n s n -> s
n k n -> k
o m o -> m
f b f -> b
m a m -> a
t v t ->''',  # [(22, 24, 0.068), (24, 9, 0.051), (17, 1, 0.04), (26, 16, 0.037), (22, 15, 0.032)]
'''
M T L -> T
X I J -> I
T D U -> D
K L H -> L
L C V -> C
J Y D -> Y
A K G -> K
V E H -> E
N I B -> I
K U I ->''',  # [(21, 4, 0.046), (26, 16, 0.043), (22, 24, 0.03), (16, 24, 0.029), (29, 2, 0.029)]
             #  [(21, 4, 0.056), (26, 16, 0.05), (22, 24, 0.049), (29, 2, 0.03), (16, 24, 0.029)]
'''
M T L -> T
X I J -> I
T D U -> D
K L H -> L
L C V -> C
J Y D -> Y
A K G -> K
V E H ->''',
'''
Q N J P, Q J J P -> J
B V I F, B V I W -> W
J C B T, J H B T -> H
Q I M G, Q M M G -> M
J D Q K, J U Q K -> U
Q M A S, Q Q A S -> Q
J E E J, J E V J -> V
V T H V, V T S V -> S
Z R N C, Z R N C -> N
A H Z G, A H O G ->''',
'''
n d d d -> n
f f d f -> d
e b e e -> b
s q s s -> q
d d d o -> o
e c e e ->'''
]


In [2]:
texts = {
    'A B C -> B':'''
M T L -> T
X I J -> I
T D U -> D
K L H -> L
L C V -> C
J Y D -> Y
A K G -> K
V E H ->''',
    'A B C -> A':'''
M T L -> M
X I J -> X
T D U -> T
K L H -> K
L C V -> L
J Y D -> J
A K G -> A
V E H ->''',
    'A B A -> B': '''
N S N -> S
N K N -> K
O M O -> M
T V T ->''',
    'A B C -> C':'''
Z J V -> V
H V G -> G
H X L -> L
S X M -> M
R N Y -> Y
Z D A -> A
V W M -> M
D T C ->''',
    
    'A B C D -> A': '''
A H S H -> A
S D N F -> S
U D B S -> U
Z G M E ->''',
    'A B C D -> a': '''
A H S H -> a
S D N F -> s
U D B S -> u
Z G M E ->''',
    'a b c -> A':'''
o t j -> O
r n k -> R
n m c -> N
m g d -> M
g c j -> G
x z o -> X
i c p -> I
u a o ->''',
    'A B C -> a':'''
D Q H -> d
I J L -> i
O V R -> o
T F E -> t
Q N R -> q
W Q B -> w
Z H J -> z
W V Y ->''',
    'A B C -> b':'''
T N X -> n
D E U -> e
U R Q -> r
C T B -> t
X B F -> b
Q G V ->''',
    'A B C -> c':'''
P Q Z -> z
C E E -> e
U G H -> h
X B F -> f
K P Y -> y
A M A -> a
K E T ->''',
    'A B A -> b': '''
A L A -> l
F B F -> b
M A M -> a
O W O -> w
W Y W -> y
D G D ->''',
    'a b a -> B': '''
n s n -> S
n k n -> K
o m o -> M
f b f -> B
m a m -> A
t v t ->''',
    'a b a -> b': '''
n s n -> s
n k n -> k
o m o -> m
f b f -> b
m a m -> a
t v t ->''',
    'A B C -> B 2':'''
M T L -> T
X I J -> I
T D U -> D
K L H -> L
L C V -> C
J Y D -> Y
A K G -> K
V E H -> E
N I B -> I
K U F ->''',# 答案错误
    'Q N J P, Q J J P -> J':'''
Q N J P, Q J J P -> J
B V I F, B V I W -> W
J C B T, J H B T -> H
Q I M G, Q M M G -> M
J D Q K, J U Q K -> U
Q M A S, Q Q A S -> Q
J E E J, J E V J -> V
V T H V, V T S V -> S
Z R N C, Z R N C -> N
A H Z G, A H O G ->''',
    'a b b b -> a':'''
n d d d -> n
f f d f -> d
e b e e -> b
s q s s -> q
d d d o -> o
e c e e ->''',
    'N J , J J  -> J':'''
I F, I W -> W
C B, H B -> H
A G, M G -> M
J D, J U -> U
M Q, K Q -> K
E C, L C ->''', 
    'Q N J , Q J J -> J':'''
Q N J, Q J J -> J
V I F, V I W -> W
J C B, J H B -> H
Q I M, Q M M -> M
D Q K, U Q K -> U
Q M A, Q Q A -> Q
J E E, J E V -> V
H Z G, H O G ->''',
    'G L C, G L -> C':'''
G L C, G L -> C
Y P J, P Y -> J
E S A, A S -> E
U P W, P U -> W
Z Q J, Z J -> Q
C K Z, Z C -> K
B L M, L M ->''',
    'Z Y, y -> z':'''
Z Y, y -> z
K B, b -> k
N E, e -> n
J S, j -> s
O W, o -> w
F R, f -> r
J S, s -> j
N O, o -> n
P R, p ->''',
    '1 70 4 -> 7':'''
1 70 4 -> 7
4 7 20 -> 2
60 4 8 -> 6
50 9 3 -> 5
8 30 6 -> 3
7 6 90 -> 9
5 2 10 -> 1
1 80 3 ->''', #答案错误
    '10 5 6 -> 1':'''
10 5 6 -> 1
4 5 90 -> 9
40 9 5 -> 4
4 7 80 -> 8
3 4 60 -> 6
3 4 80 -> 8
6 70 4 -> 7
5 2 90 ->''',
    '6 7 30 1 -> 3':'''
4 5 9 70 -> 7
50 1 7 9 -> 5
2 90 5 4 -> 9
10 9 7 5 -> 1
5 9 60 1 -> 6
30 8 7 2 -> 3
1 90 2 7 -> 9
2 5 9 40 ->''',#此类答案均错误
    'b 1 g t -> 1':'''
b 1 g t -> 1
v p 3 y -> 3
u 2 a h -> 2
m d j 5 -> 5
t o s 6 -> 6
9 v i q -> 9
m 5 p w ->''',
    'b k d e -> k':'''
a 1 c d -> 1
b k d e -> k
d e q g -> q
f g h p -> p
o p t r -> t
h y j k ->''',
}
# ABC->A
# ABC->B
# ABC->C
# ABC->a
# ABC->b
# ABC->c
# a 1 c d -> 1
# 

In [3]:
task_name = 'a b a -> b'
text = texts[task_name]
_text = text.replace('_', mask_token).rstrip()
print(_text)

if masked_lm:
    print(_text, ['%s %.3f' % (i['token_str'], i['score']) for i in nlp(_text)])
    print(tokenizer.tokenize(_text))
else:
    inputs = tokenizer.encode_plus(_text, return_tensors='pt')
    inputs = prepare_inputs(inputs, model.device)
#     max_length = 1 + (inputs['input_ids'].size(1) if mask_token == '' else 0)
#     with torch.no_grad(): outputs = model.generate(**inputs, max_length=max_length, top_k=1)
#     print(tokenizer.decode(outputs[0]))

    for block in model.transformer.h:
        block.h_in, block.attn_out, block.h_mid = None, None, None
    with torch.no_grad(): outputs = model(**inputs, output_hidden_states=True)
    logits = outputs.logits #if hasattr(outputs, 'logits') else outputs[0]
    values, indices = logits[0, -1].softmax(-1).topk(5)
    dict(OrderedDict(zip(tokenizer.convert_ids_to_tokens(indices), numpy(values))))

NameError: name 'mask_token' is not defined

In [207]:
salient_heads, all_attrs, all_attns = {}, {}, {}

In [208]:
all_top_heads = {}#http://octa:8890/notebooks/notebooks/attn_analysis_lwm.ipynb#{}

In [209]:
pred_label

256

In [210]:
inputs = tokenizer.encode_plus(_text, return_tensors='pt')
inputs = prepare_inputs(inputs, model.device)
outputs = model(**inputs, output_attentions=True)
attentions = outputs.attentions #if hasattr(outputs, 'attentions') else outputs[-1]
L, H = len(attentions), attentions[0].size(1)
logits = outputs.logits #if hasattr(outputs, 'logits') else outputs[0]
pred_label = logits[0, -1].argmax().item()
attns = torch.cat([globalize(a) for a in attentions])

# layer_range = (15, 41)  # gpt2-xl, 48L
# layer_range = (10, 28)  # gpt-neo-2.7B, 32L
# layer_range = (8, 21)  # gpt-neo-1.3B, 24L
layer_range = (0, 22)
attrs = []
for i in tqdm(range(*layer_range)):
    attn = attentions[i]
    scaled_attn, step = scaled_input(attn, 10) # 5
    _ = scaled_attn.requires_grad_(True)
    
    attn_module = model.transformer.h[i].attn
    if hasattr(attn_module, 'attention'): attn_module = attn_module.attention  # for gpt-neo
    attn_module.w = scaled_attn
    try: outputs = model(**inputs)
    finally: attn_module.w = None
        
#     model.transformer.h[layer].exit = True
#     try:
#         attention = model.transformer(**inputs)
#     finally:
#         attn_module.w = None
#         model.transformer.h[layer].exit = None
#     y = globalize(attention)[:, head, src, tgt]
    
    logits = outputs.logits #if hasattr(outputs, 'logits') else outputs[0]
    #y = logits[:, -1, pred_label]  #　orig
    y = logits.softmax(-1)[:, -1, pred_label]  # revised
    y
    attn_grad = torch.autograd.grad(torch.unbind(y), scaled_attn)[0]
    attn_grad = attn_grad.sum(dim=0, keepdim=True) # (bsz, H, qlen, klen) -> (1, H, qlen, klen)
    attr = attn_grad * step
    attrs.append(attr.data)
attrs = torch.cat([globalize(a) for a in attrs])
all_attns[task_name] = attns
all_attrs[task_name] = attrs

  0%|          | 0/22 [00:00<?, ?it/s]

tensor([5.0962e-05, 9.1244e-04, 1.9195e-04, 1.3151e-02, 6.0786e-01, 7.4873e-01,
        8.6274e-01, 9.4092e-01, 9.7023e-01, 9.7745e-01], device='cuda:1',
       grad_fn=<SelectBackward>)

  5%|▍         | 1/22 [00:00<00:11,  1.90it/s]

tensor([5.3904e-06, 5.7904e-03, 1.9582e-02, 5.6671e-02, 1.0602e-01, 8.9333e-01,
        9.8516e-01, 9.7900e-01, 9.7842e-01, 9.7745e-01], device='cuda:1',
       grad_fn=<SelectBackward>)

  9%|▉         | 2/22 [00:01<00:12,  1.55it/s]

tensor([0.9793, 0.9809, 0.9778, 0.9762, 0.9749, 0.9753, 0.9762, 0.9770, 0.9774,
        0.9774], device='cuda:1', grad_fn=<SelectBackward>)

 14%|█▎        | 3/22 [00:01<00:12,  1.48it/s]

tensor([0.8888, 0.9551, 0.9770, 0.9857, 0.9874, 0.9869, 0.9852, 0.9829, 0.9807,
        0.9774], device='cuda:1', grad_fn=<SelectBackward>)

 18%|█▊        | 4/22 [00:02<00:12,  1.47it/s]

tensor([0.9453, 0.9590, 0.9658, 0.9698, 0.9725, 0.9743, 0.9756, 0.9764, 0.9770,
        0.9774], device='cuda:1', grad_fn=<SelectBackward>)

 23%|██▎       | 5/22 [00:03<00:11,  1.47it/s]

tensor([0.9800, 0.9809, 0.9814, 0.9817, 0.9816, 0.9813, 0.9807, 0.9799, 0.9788,
        0.9774], device='cuda:1', grad_fn=<SelectBackward>)

 27%|██▋       | 6/22 [00:03<00:10,  1.50it/s]

tensor([0.9123, 0.9294, 0.9433, 0.9542, 0.9624, 0.9684, 0.9725, 0.9752, 0.9768,
        0.9774], device='cuda:1', grad_fn=<SelectBackward>)

 32%|███▏      | 7/22 [00:04<00:09,  1.53it/s]

tensor([0.8757, 0.9141, 0.9396, 0.9552, 0.9646, 0.9701, 0.9731, 0.9755, 0.9770,
        0.9774], device='cuda:1', grad_fn=<SelectBackward>)

 36%|███▋      | 8/22 [00:05<00:08,  1.56it/s]

tensor([0.8690, 0.9077, 0.9308, 0.9450, 0.9545, 0.9612, 0.9663, 0.9705, 0.9737,
        0.9774], device='cuda:1', grad_fn=<SelectBackward>)

 41%|████      | 9/22 [00:05<00:08,  1.60it/s]

tensor([0.8984, 0.9335, 0.9519, 0.9623, 0.9685, 0.9721, 0.9742, 0.9753, 0.9766,
        0.9774], device='cuda:1', grad_fn=<SelectBackward>)

 45%|████▌     | 10/22 [00:06<00:07,  1.65it/s]

tensor([0.6764, 0.8060, 0.8451, 0.8628, 0.8817, 0.9033, 0.9270, 0.9505, 0.9675,
        0.9774], device='cuda:1', grad_fn=<SelectBackward>)

 50%|█████     | 11/22 [00:06<00:06,  1.71it/s]

tensor([0.7747, 0.8513, 0.9023, 0.9326, 0.9503, 0.9609, 0.9676, 0.9721, 0.9753,
        0.9774], device='cuda:1', grad_fn=<SelectBackward>)

 55%|█████▍    | 12/22 [00:07<00:05,  1.76it/s]

tensor([0.2479, 0.3971, 0.5571, 0.7034, 0.8181, 0.8953, 0.9393, 0.9612, 0.9717,
        0.9774], device='cuda:1', grad_fn=<SelectBackward>)

 59%|█████▉    | 13/22 [00:07<00:04,  1.83it/s]

tensor([0.6850, 0.7715, 0.8371, 0.8856, 0.9202, 0.9435, 0.9585, 0.9679, 0.9738,
        0.9774], device='cuda:1', grad_fn=<SelectBackward>)

 64%|██████▎   | 14/22 [00:08<00:04,  1.90it/s]

tensor([0.9074, 0.9206, 0.9321, 0.9422, 0.9509, 0.9584, 0.9646, 0.9697, 0.9740,
        0.9774], device='cuda:1', grad_fn=<SelectBackward>)

 68%|██████▊   | 15/22 [00:08<00:03,  1.97it/s]

tensor([0.9531, 0.9574, 0.9612, 0.9646, 0.9675, 0.9701, 0.9724, 0.9743, 0.9760,
        0.9774], device='cuda:1', grad_fn=<SelectBackward>)

 73%|███████▎  | 16/22 [00:09<00:02,  2.06it/s]

tensor([0.9573, 0.9602, 0.9630, 0.9656, 0.9680, 0.9702, 0.9723, 0.9742, 0.9759,
        0.9774], device='cuda:1', grad_fn=<SelectBackward>)

 77%|███████▋  | 17/22 [00:09<00:02,  2.12it/s]

tensor([0.9477, 0.9536, 0.9585, 0.9627, 0.9662, 0.9691, 0.9717, 0.9739, 0.9758,
        0.9774], device='cuda:1', grad_fn=<SelectBackward>)

 82%|████████▏ | 18/22 [00:10<00:01,  2.21it/s]

tensor([0.9738, 0.9743, 0.9748, 0.9753, 0.9757, 0.9761, 0.9765, 0.9768, 0.9772,
        0.9774], device='cuda:1', grad_fn=<SelectBackward>)

 86%|████████▋ | 19/22 [00:10<00:01,  2.30it/s]

tensor([0.9727, 0.9733, 0.9739, 0.9745, 0.9750, 0.9756, 0.9761, 0.9765, 0.9770,
        0.9774], device='cuda:1', grad_fn=<SelectBackward>)

 91%|█████████ | 20/22 [00:10<00:00,  2.42it/s]

tensor([0.9752, 0.9756, 0.9759, 0.9763, 0.9766, 0.9768, 0.9771, 0.9773, 0.9774,
        0.9774], device='cuda:1', grad_fn=<SelectBackward>)

 95%|█████████▌| 21/22 [00:11<00:00,  2.54it/s]

tensor([0.9788, 0.9786, 0.9784, 0.9783, 0.9781, 0.9780, 0.9779, 0.9777, 0.9776,
        0.9774], device='cuda:1', grad_fn=<SelectBackward>)

100%|██████████| 22/22 [00:11<00:00,  1.90it/s]


In [216]:
tokens = [token.replace('Ġ', '').replace('Ċ', '^') for token in tokenizer.tokenize(_text)]
for i, token in enumerate(tokens):
    if token in ['Ċ', '^']: print()
    else: print((i, token), end=' ')
seq_len = len(tokens)
answer = tokens[-1]
tgt = [i for i, token in enumerate(tokens[:-1]) if token.lower() == answer.lower()][-1]
tgt


(1, 'Z') (2, 'J') (3, 'V') (4, '->') (5, 'V') 
(7, 'H') (8, 'V') (9, 'G') (10, '->') (11, 'G') 
(13, 'H') (14, 'X') (15, 'L') (16, '->') (17, 'L') 
(19, 'S') (20, 'X') (21, 'M') (22, '->') (23, 'M') 
(25, 'R') (26, 'N') (27, 'Y') (28, '->') (29, 'Y') 
(31, 'Z') (32, 'D') (33, 'A') (34, '->') (35, 'A') 
(37, 'V') (38, 'W') (39, 'M') (40, '->') (41, 'M') 
(43, 'D') (44, 'T') (45, 'C') (46, '->') 

40

In [212]:
ki = [41, 35, 29, 23, 17, 11, 5]
tgt_i = 22 # 是要找箭头关注变化的位置的头
_attrs = attrs #/ attrs.view(attrs.size(0), -1).norm(dim=1).view(attrs.size(0), 1, 1, 1)
_attrs = _attrs[:, :, -1:, :]#tgt_i:tgt_i+1]
#values, indices = _attrs.view(_attrs.size(0), H, -1).topk(1, dim=-1)
values, indices = _attrs.view(_attrs.size(0), H, -1).topk(5, dim=-1)
val, ind = values.sum(dim=-1).view(-1).topk(50)
val, ind = numpy(val), unravel_index(ind, values.size()[:-1])

top_heads = []
# for (l, h), v in zip(ind, val):
#     top_heads.append((l, h, v, list(zip(unravel_index(indices[l, h], _attrs.size()[-2:]), numpy(values[l, h])))))
#     if l<2: continue
#     print('%d-%d\t%.4f\t' % (l, h, v), top_heads[-1][-1])
    #if l < 2: continue
   # print('%d-%d\t%.4f\t' % (l, h, v), top_heads[-1][-1])#, attns[l, h, -1, tgt_i].item())
for (l, h), v in zip(ind, val):
#     top_heads[(l, h)] = (v, list(zip(unravel_index(indices[l, h], _attrs.size()[-2:]), numpy(values[l, h]))))
    top_heads.append((l, h, v, list(zip(unravel_index(indices[l, h], _attrs.size()[-2:]), numpy(values[l, h])))))
    if l <5: continue
    print('%d-%d\t%.4f\t' % (l, h, v), top_heads[-1][-1])#, attns[l, h, -1, tgt_i].item())

10-8	0.1958	 [([0, 35], 0.1217), ([0, 41], 0.0335), ([0, 17], 0.0149), ([0, 23], 0.0145), ([0, 29], 0.0112)]
12-16	0.1604	 [([0, 35], 0.0651), ([0, 41], 0.0361), ([0, 29], 0.0299), ([0, 17], 0.0148), ([0, 23], 0.0145)]
12-18	0.1286	 [([0, 35], 0.0723), ([0, 5], 0.0206), ([0, 17], 0.015), ([0, 23], 0.0112), ([0, 29], 0.0094)]
10-0	0.1249	 [([0, 46], 0.1226), ([0, 34], 0.0012), ([0, 0], 0.0006), ([0, 16], 0.0002), ([0, 10], 0.0002)]
13-2	0.0900	 [([0, 45], 0.09), ([0, 33], 0.0), ([0, 43], 0.0), ([0, 42], 0.0), ([0, 32], 0.0)]
12-1	0.0679	 [([0, 17], 0.0243), ([0, 29], 0.0136), ([0, 23], 0.0124), ([0, 35], 0.0103), ([0, 11], 0.0073)]
13-12	0.0483	 [([0, 45], 0.0463), ([0, 43], 0.0016), ([0, 37], 1e-04), ([0, 39], 1e-04), ([0, 42], 1e-04)]
12-5	0.0383	 [([0, 35], 0.0107), ([0, 23], 0.0082), ([0, 11], 0.0074), ([0, 17], 0.0065), ([0, 29], 0.0055)]
12-2	0.0258	 [([0, 17], 0.0091), ([0, 35], 0.0085), ([0, 23], 0.0048), ([0, 11], 0.0024), ([0, 5], 0.001)]
14-6	0.0252	 [([0, 11], 0.0065), ([0, 

In [214]:
ki = [41, 35, 29, 23, 17, 11, 5]
tgt_i = 45 # 是要找箭头关注变化的位置的头
_attrs = attrs #/ attrs.view(attrs.size(0), -1).norm(dim=1).view(attrs.size(0), 1, 1, 1)
_attrs = _attrs[:, :, -1:, tgt_i:tgt_i+1]
values, indices = _attrs.view(_attrs.size(0), H, -1).topk(1, dim=-1)
#values, indices = _attrs.view(_attrs.size(0), H, -1).topk(5, dim=-1)
val, ind = values.sum(dim=-1).view(-1).topk(50)
val, ind = numpy(val), unravel_index(ind, values.size()[:-1])

top_heads = []
# for (l, h), v in zip(ind, val):
#     top_heads.append((l, h, v, list(zip(unravel_index(indices[l, h], _attrs.size()[-2:]), numpy(values[l, h])))))
#     if l<2: continue
#     print('%d-%d\t%.4f\t' % (l, h, v), top_heads[-1][-1])
    #if l < 2: continue
   # print('%d-%d\t%.4f\t' % (l, h, v), top_heads[-1][-1])#, attns[l, h, -1, tgt_i].item())
for (l, h), v in zip(ind, val):
#     top_heads[(l, h)] = (v, list(zip(unravel_index(indices[l, h], _attrs.size()[-2:]), numpy(values[l, h]))))
    top_heads.append((l, h, v, list(zip(unravel_index(indices[l, h], _attrs.size()[-2:]), numpy(values[l, h])))))
    if l <7: continue
    print('%d-%d\t%.4f\t' % (l, h, v), top_heads[-1][-1], attns[l, h, -1, tgt_i].item())

13-2	0.0900	 [([0, 0], 0.09)] 0.9725266098976135
13-12	0.0463	 [([0, 0], 0.0463)] 0.5863183736801147
13-18	0.0190	 [([0, 0], 0.019)] 0.8204120993614197
13-5	0.0175	 [([0, 0], 0.0175)] 0.8392573595046997
13-7	0.0100	 [([0, 0], 0.01)] 0.31353142857551575
13-3	0.0093	 [([0, 0], 0.0093)] 0.4202425181865692
20-19	0.0079	 [([0, 0], 0.0079)] 0.4849521517753601
19-8	0.0073	 [([0, 0], 0.0073)] 0.9890338182449341
14-0	0.0067	 [([0, 0], 0.0067)] 0.35114535689353943
15-8	0.0066	 [([0, 0], 0.0066)] 0.520453155040741
17-17	0.0059	 [([0, 0], 0.0059)] 0.8330322504043579
11-17	0.0055	 [([0, 0], 0.0055)] 0.2726283669471741
20-13	0.0047	 [([0, 0], 0.0047)] 0.7654220461845398
16-15	0.0045	 [([0, 0], 0.0045)] 0.30768778920173645
13-10	0.0042	 [([0, 0], 0.0042)] 0.8127387762069702
17-13	0.0042	 [([0, 0], 0.0042)] 0.1281743347644806
14-13	0.0035	 [([0, 0], 0.0035)] 0.9591408967971802
18-9	0.0033	 [([0, 0], 0.0033)] 0.7281856536865234
13-11	0.0032	 [([0, 0], 0.0032)] 0.2630724012851715
19-9	0.0030	 [([0, 0], 

In [125]:
layer, head = 12, 4 #
# layer, head = 
layer2, head2 = 13,12# 是
src, tgt = 61, 58
h = model.transformer.h

In [130]:
m = h[layer].attn.attention
mask = torch.ones(H, seq_len)
mask[:, -1] = 0
mask[head, -1] = 1# 0是去掉，1是不去
m.attn_mask = mask.unsqueeze(-1)
m.attn_out = None
try:
    with torch.no_grad(): outputs = model(**inputs, output_attentions=True)
finally: m.attn_mask = None
attn_out = m.attn_out
delattr(m, 'attn_out')

attn = globalize(outputs.attentions[layer2])[0, head2, -1]
show_topk(*attn.topk(5), indices_fn=append_tokens_to_positions)
probs = outputs.logits[0, -1].softmax(-1)
show_topk(*probs.topk(5), indices_fn=tokenizer.convert_ids_to_tokens)
# 第一个输出是对最后进行结果预测的head的影响，看指向的概率变换，第二个是最终结果的变化

attn_output.retain_grad() failed


{'58 R': 0.1839,
 '57 P': 0.1824,
 '60 p': 0.1175,
 '59 ,': 0.0854,
 '61 ->': 0.0789}

{'Ġr': 0.707, 'Ġp': 0.1634, 'Ġq': 0.0661, 'Ġpr': 0.0053, 'Ġo': 0.0051}

In [None]:
all_top_heads[task_name] = top_heads

In [None]:
all_top_heads[task_name]

In [None]:
# inputs = tokenizer.encode_plus(_text, return_tensors='pt')
# outputs = model(**inputs, output_attentions=True)

# logits = outputs.logits if hasattr(outputs, 'logits') else outputs[0]
# y = logits[0, -1].max()
# attentions = outputs.attentions if hasattr(outputs, 'attentions') else outputs[-1]
# for a in attentions: a.retain_grad()
# model.zero_grad()
# y.backward()

# # attns = torch.cat(attentions)
# grads = torch.cat([a.grad for a in attentions])
# attrs2 = attns * grads

In [None]:
for i in range(len(top_heads)):
    layer, head, v, _ = top_heads[i]
#     if layer in [0, 1, ]: continue
#     layer, head, v = 30, 10, 1.
    fig, axs = plt.subplots(1,2,sharey=False, figsize=(10 * 2, 10))
    for i, (a, _ax) in enumerate(zip([attns, attrs], axs)):
        a = a[layer][head].detach().cpu()
        a, annot = ((a * 100).long(), True) if i == -1 else (a, False)
        res = sns.heatmap(a, square=True, cbar=False, annot=annot, fmt='d', linewidths=0.1, linecolor='grey', 
                          xticklabels=tokens, yticklabels=tokens, ax=_ax)
        _ = res.set_xticklabels(res.get_xmajorticklabels(), fontsize=9+3-2, rotation=0)
        _ = res.set_yticklabels(res.get_ymajorticklabels(), fontsize=9+3-2, rotation=0)
        _ = plt.xlabel('%d-%d    %.4f' % (layer, head, v), fontsize=14)

In [None]:
import random

def get_random_string(gpt2_tokenizer):
    tokens = [gpt2_tokenizer.convert_ids_to_tokens(i) for i in range(120)]
    tokens = [token for token in tokens if token not in string.digits + string.ascii_uppercase + string.ascii_lowercase]
    tokens = ['Ġ' + token for token in tokens if gpt2_tokenizer._convert_token_to_id('Ġ' + token) != 50256]
    token_ids = [gpt2_tokenizer._convert_token_to_id(token) for token in tokens]  # XD
    print(tokens, len(tokens))
    
    sampled_tokens_idx = []
    sampled_tokens = []
    sampled_token_ids = []  # XD

    random.seed(6)  # XD
    #get sampled tokens idx
    range_ = list(range(len(tokens)))
    for i in range(362):
        idx = random.choice(range_)
        sampled_tokens_idx.append(idx)

    for idx in sampled_tokens_idx:
        sampled_tokens.append(tokens[idx])
        sampled_token_ids.append(token_ids[idx])

    text = ''.join(sampled_tokens).replace('Ġ', ' ')  # XD
    print(text, len(sampled_token_ids)) # XD
    return sampled_token_ids, text  # XD
    
    # print("".join(sampled_tokens), len(sampled_tokens))
    # return "".join(sampled_tokens), len(sampled_tokens)

token_ids, text = get_random_string(tokenizer)

In [None]:
texts = [
    'Big is to small as fast is to _',
    'Bread is to eat as gun is to _',
    'big: small, fast: _',
    'bread: eat, gun: _ .',
    'flower: fragrant, fire: hot, bread: delicious, gun: _ ',
    'Big and small are _ .',
    'What is twice 3? _.',
    'What is the half 6? _.',
    'There is a sequence: 3, 5, 2, 7. The number immediately precedes 5 is _.',  # :)
    'There is a sequence: 3, 5, 2, 7. The number immediately follows 5 is _.',  # :(
    'There is a sequence: 3, 5, 2, 7. The number between 5 and 7 is _.',
    'There is a sequence of numbers: 3, 5, 2, 4. _ is the first number.',
    'There is a sequence of numbers: 3, 5, 2. The reversed sequence is _.',
    '''There is a sequence of numbers: 5, 1, 6, 3. The second number is 1.
There is a sequence of numbers: 3, 7, 2, 4. The second number is _.''',
    '''There is a sequence of letters: e, c, b, a. The last letter is a.
There is a sequence of letters: f, d, b, g. The last letter is _.''',
    '''The uppercase of c is C. The uppercase of f is _.''',
    '''The successor of 3 is 4. The successor of 8 is _.''',
    '''The successor of 3 is 4. The successor of _ is 6.''',
#     '''The predecessor of 3 is 2. The predecessor of 5 is 4. The predecessor of 6 is _''',
#     '''The previous integer of 4 is 3. The previous integer of 3 is _.''',
#     '''3 minus 1 equals 2. 5 minus 1 equals _.''',
    '''If 2 changes to 3, 5 changes to 6, then _ changes to 9''',
    '''If 2 changes to 20, 3 changes to 30, then 5 changes to _''',
    '''2 -> 3, 4 -> 5, 5 -> 6, 9 -> _.''',
    '''3 -> 2, 5 -> 4, 6 -> 5, 9 -> _''',
    '''9 -> 8, 7 -> 6, 6 -> 5, 2 -> _.''',
    '''3 is to _ as 4 is to 8 and 5 is to 10.''',
#     '''6 : _ :: 5 : 10 :: 7 : 14 :: 8 : 16.''',
#     '''a is to _ as f is to g, h to i, i to j, s to t.''',
#     '''c is to _ as f is to e, h to g, j to i.''',
    '''c is to _ as j is to i, h to g, f to e.''',
#     '''Twice 3 is 6, twice 4 is _.''',
#     '''Half of 4 is 2, half of 6 is _.''',

# '''Shall I compare thee to a summer's day?
# Thou''',
# '''Do not go gentle into that good night,
# Old age should burn and rave at close of day;
# Rage'''
]