In [1]:
import os
import sys
# Add '../..' (relative to the working directory) to the import path so the
# project's `models` package can be imported — presumably the repo root; verify.
sys.path.append('../..')

from models.inversion_model import InvEncoder, InvDecoder, InvEncoderDecoder, Projector, get_hidden_size

# Checkpoint locations. Override with environment variables instead of editing
# the notebook; the defaults are the original machine-specific paths.
encoder_model_path = os.environ.get('INV_ENCODER_PATH', '/root/autodl-tmp/inv-general/InvEncoder')
decoder_model_path = os.environ.get('INV_DECODER_PATH', '/root/autodl-tmp/sft-full-mix_50k')
projector_model_path = os.environ.get('INV_PROJECTOR_PATH', '/root/autodl-tmp/inv-general/projector.pt')

encoder = InvEncoder(model_path=encoder_model_path)
decoder = InvDecoder(model_path=decoder_model_path)

# The projector maps encoder hidden states to the decoder's hidden size.
# NOTE(review): these model-family names must match the checkpoints above;
# they are not derived from the paths — confirm when swapping checkpoints.
encoder_hidden_size = get_hidden_size('t5-base')
decoder_hidden_size = get_hidden_size('llama2')
projector = Projector(
    encoder_hidden_size=encoder_hidden_size,
    decoder_hidden_size=decoder_hidden_size,
    model_path=projector_model_path
)

inv_model = InvEncoderDecoder(
    encoder=encoder,
    decoder=decoder,
    projector=projector,
)
# NOTE(review): if InvEncoderDecoder is an nn.Module, consider inv_model.eval()
# so any dropout in the freshly constructed projector does not make the
# gradient attributions below non-deterministic — confirm.
inv_model.to('cuda')
print('model loaded.')

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 3/3 [00:04<00:00,  1.50s/it]
  state_dict = torch.load(model_path, map_location='cuda:0')


model loaded.


In [2]:
def map_subwords_to_word(tokens):
    """Group sub-word token strings into words and map every token to its word.

    A token opens a new word when any of these holds:
    - it is the first token;
    - it starts with the SentencePiece marker "▁", or with "." or ")";
    - it is one of '<s>', '</s>', '!', ':', '?', '<0x0A>';
    - the previous token is ']' or '<0x0A>'.

    Otherwise the token is appended to the current word.

    Args:
        tokens: iterable of token strings.

    Returns:
        (words, mapping): ``words`` holds one concatenated token string per
        word. ``mapping[i]`` is the index in ``words`` of the word containing
        ``tokens[i]``, so ``len(mapping) == len(tokens)`` and every index is
        valid (never negative).
    """
    standalone = {'<s>', '</s>', '!', ':', '?', '<0x0A>'}
    words = []
    mapping = []
    last_token = ''
    for token in tokens:
        starts_word = (
            # The first token always opens a word; otherwise its word would be
            # lost and its tokens would have no valid word index.
            not words
            or token.startswith(('▁', '.', ')'))
            or token in standalone
            or last_token in (']', '<0x0A>')
        )
        if starts_word:
            words.append('')
        words[-1] += token
        mapping.append(len(words) - 1)
        last_token = token
    return words, mapping

In [3]:
from transformers import AutoTokenizer
import torch
from visualizer import Visualizer

def inversion_template(x):
    """Enclose ``x`` between the '[inversion]' and '[/inversion]' tags."""
    opening, closing = '[inversion]', '[/inversion]'
    return ''.join((opening, x, closing))

# Per-prompt results, consumed by the HTML export cell.
words_list = []    # one list of words per prompt
weights_list = []  # matching per-word gradient norms

x_label = 'How do you feel about the current political climate in the us? what is one thing that you would like to change? Output:'
y = ' I feel very negatively about the current political climate in the US. I would like to see more bipartisanship and less division between Democrats and Republicans.'

template_y = inversion_template(y)

t1 = AutoTokenizer.from_pretrained(encoder_model_path)  # encoder-side tokenizer
t2 = AutoTokenizer.from_pretrained(decoder_model_path)  # decoder-side tokenizer
viz = Visualizer(decoder.model, t2)

# The decoder-side text does not depend on the prompt, so tokenize it once.
decoder_enc = t2(template_y, return_tensors='pt')
decoder_input_ids = decoder_enc['input_ids'].to('cuda')
decoder_attention_mask = decoder_enc['attention_mask'].to('cuda')
# Take token strings directly from the ids fed to the model so they line up
# 1:1 with gradient rows; decode -> re-tokenize does not guarantee that.
t2_tokens = t2.convert_ids_to_tokens(decoder_enc['input_ids'][0])
len2 = len(t2_tokens)

for p in [template_y]:
    encoder_enc = t1(p, return_tensors='pt', padding=True, truncation=True)
    encoder_input_ids = encoder_enc['input_ids'].to('cuda')
    encoder_input_attention_mask = encoder_enc['attention_mask'].to('cuda')

    hidden_states, hidden_attention_mask = inv_model.forward_hidden_states(
        encoder_input_ids=encoder_input_ids,
        encoder_input_attention_mask=encoder_input_attention_mask,
    )
    text_embeds = decoder.embed_input_ids(decoder_input_ids)
    # Encoder-derived states first, then the decoder text embeddings.
    merge_embeds = torch.concat([hidden_states, text_embeds], dim=1)

    token_grads = viz.vis_by_grad_embeds(merge_embeds, x_label)

    # Same ids the encoder actually saw, so the token list matches its rows.
    t1_tokens = t1.convert_ids_to_tokens(encoder_enc['input_ids'][0])
    len1 = len(t1_tokens)
    # Assumes one gradient row per position of merge_embeds (encoder part,
    # then decoder part) — confirm against Visualizer.vis_by_grad_embeds.
    t1_token_grads = token_grads[:len1, :]
    t2_token_grads = token_grads[len1:len1 + len2, :]
    assert t1_token_grads.shape[0] == len1 and t2_token_grads.shape[0] == len2, \
        'token strings and gradient rows are misaligned'

    tokens = t1_tokens + t2_tokens
    aligned_grads = torch.concat([t1_token_grads, t2_token_grads], dim=0)

    # Sum the gradient vectors of the tokens in each word; the L2 norm of
    # the sum is that word's weight.
    words, mapping = map_subwords_to_word(tokens)
    word_grads = [torch.zeros_like(aligned_grads[0]) for _ in words]
    for word_idx, grad in zip(mapping, aligned_grads):
        word_grads[word_idx] += grad
    weights = [grad.norm().item() for grad in word_grads]

    words_list.append(words)
    weights_list.append(weights)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


t1 processing...
t2 processing...


In [4]:
import json
import numpy as np

from html_template import html_template

# Keep only the words from the first '<s>' token onward (the start of the
# decoder-side text), and strip the SentencePiece "▁" marker for display.
for i in range(len(words_list)):
    start_idx = words_list[i].index('<s>')
    words_list[i] = [w.replace('▁', '') for w in words_list[i][start_idx:]]
    weights_list[i] = weights_list[i][start_idx:]
    print(weights_list[i])

# The template's DATA1 slot receives the first prompt's words and weights as JSON.
s = html_template.format(
    DATA1=json.dumps({
        'words': words_list[0],
        'weights': weights_list[0],
    })
)
# Explicit encoding: the platform default is not guaranteed to be UTF-8.
with open('viz.html', 'w', encoding='utf-8') as f:
    f.write(s)

[0.39726948738098145, 0.4591701924800873, 0.1297350525856018, 0.18277420103549957, 0.1304171234369278, 0.24692991375923157, 0.15261857211589813, 0.10805866867303848, 0.1125379279255867, 0.11912064254283905, 0.1737215220928192, 0.08462700247764587, 0.13007865846157074, 0.25984832644462585, 2.2333953380584717, 0.155464306473732, 0.18522672355175018, 0.13064514100551605, 0.0949675515294075, 0.07916409522294998, 0.07739176601171494, 0.22385768592357635, 0.07543805241584778, 0.060687460005283356, 0.06442674249410629, 0.06355589628219604, 0.17505645751953125, 0.06109282374382019, 0.0670047327876091, 0.4708525538444519]
