In [1]:
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
import numpy as np
import os
import torch.nn.functional as F
from tqdm import tqdm
import time
from peft import PeftConfig, PeftModel, get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
import json 
from IPython.display import display, HTML

model_dir = "/raid/models/llama2/llama-2-13b-chat/hf"
output_dir = "/raid/slee3473/LLM/llama-output/sentence_transform_complex_jan3"
ckpt_dir = os.path.join(output_dir, "checkpoint-94")
if 'model' in globals():
    del model
    torch.cuda.empty_cache() 

tokenizer = LlamaTokenizer.from_pretrained(model_dir)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'right'

device = "cuda:2"
finetuned = False

if os.path.exists(ckpt_dir) and len(os.listdir(ckpt_dir)) > 0: # load pretrained
    print(f"Load a fine-tuned model from {ckpt_dir}")
    model = LlamaForCausalLM.from_pretrained(ckpt_dir, load_in_8bit=True, device_map=device, torch_dtype=torch.float16)
    finetuned = True
else:
    model = LlamaForCausalLM.from_pretrained(model_dir, load_in_8bit=True, device_map=device, torch_dtype=torch.float16)

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. If you see this, DO NOT PANIC! This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Load a fine-tuned model from /raid/slee3473/LLM/llama-output/sentence_transform_complex_jan3/checkpoint-94


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [2]:
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)

if not finetuned:
    peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.05, target_modules=["q_proj", "v_proj"])
    print(peft_config)
    model = get_peft_model(model, peft_config)
else:
    peft_config = PeftConfig.from_pretrained(ckpt_dir)
    peft_config.inference_mode = False
    print(peft_config)
    model = PeftModel.from_pretrained(model, ckpt_dir, is_trainable=True)
    
model.print_trainable_parameters()

LoraConfig(peft_type='LORA', auto_mapping=None, base_model_name_or_path='/raid/models/llama2/llama-2-13b-chat/hf', revision=None, task_type='CAUSAL_LM', inference_mode=False, r=8, target_modules=['q_proj', 'v_proj'], lora_alpha=32, lora_dropout=0.05, fan_in_fan_out=False, bias='none', modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None)
trainable params: 6,553,600 || all params: 13,022,417,920 || trainable%: 0.05032552357220002


### Dataset

In [3]:
import datasets
import os
data_dir = "./../../../data/"
train_dataset = datasets.load_from_disk(os.path.join(data_dir, "sentence_transformation_complex/train.hf"))
test_dataset = datasets.load_from_disk(os.path.join(data_dir, "sentence_transformation_complex/test.hf"))

train_dataset = train_dataset.map(lambda train_dataset: tokenizer(train_dataset["text"], padding='max_length', truncation=True, max_length=64))
test_dataset = test_dataset.map(lambda test_dataset: tokenizer(test_dataset["text"]))

In [4]:
train_dataset = train_dataset.add_column("labels", train_dataset["input_ids"])

In [5]:
print(tokenizer.decode(test_dataset[960]["input_ids"]))
# print(tokenizer.decode(train_dataset[960]["input_ids"]))

<s> Repeat Each Word Twice
    Then, Double Every Consonant
    For example:
    Music whispers in ears. ->  MMussicc MMussicc wwhhisspperrss wwhhisspperrss inn inn earrss. earrss.</s>


## Check base model

In [6]:
eval_i = 10
eval_prompt = test_dataset[eval_i]["prompt"]
model_input = tokenizer(eval_prompt, return_tensors="pt").to(device)
model.eval()

with torch.no_grad():
    print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))

print(test_dataset[eval_i]["answer"])

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Capitalize Every Other Letter
    For example:
    Feathers float on dreams. ->  fEaThErS FlOaT oN DrEaMs.
fEaThErS FlOaT On dReAmS.


### Prepare for the attribution

In [7]:
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)
grad_dir = f"{ckpt_dir}/training_grads_post"
if not os.path.exists(grad_dir):
    os.makedirs(grad_dir)

In [8]:
grad_computed = (len(os.listdir(grad_dir)) == len(train_dataset))

In [9]:
model.eval()

if not grad_computed:
    for i, data in enumerate(tqdm(train_dataset)):
        # get the Delta_theta when we update the model with "data"
        input_ids = torch.LongTensor(data["input_ids"]).unsqueeze(0).to(device)
        attention_mask = torch.LongTensor(data["attention_mask"]).unsqueeze(0).to(device)
        labels = torch.LongTensor(data["labels"]).unsqueeze(0).to(device)
        out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = out.loss
        grad_loss = torch.autograd.grad(loss, [param for param in model.parameters() if param.requires_grad])
        torch.save(grad_loss, f"{grad_dir}/{i}.pt")

### DataInf Score

In [10]:
import json
score_dir = os.path.join(ckpt_dir, "datainf.json")
if os.path.exists(score_dir):
    with open(score_dir, "r") as f:
        scores = json.load(f)

scores = np.array(scores)

In [11]:
max_datainf_idx = int(np.argmax(np.abs(scores)))
train_dataset[max_datainf_idx]["text"]

'Double Every Consonant\n    For example:\n    Mountains challenge eager climbers. ->  MMounnttainnss cchhallllenngge eaggerr ccllimmbberrss.</s>'

### Attribute

### 1. Gradient to Embedding

In [12]:
model.eval()
model.zero_grad()

In [13]:
logsoftmax = torch.nn.LogSoftmax(dim=-1)

attr_data = test_dataset[910]
attr_prompt = attr_data["prompt"]
model_input = tokenizer(attr_prompt, return_tensors="pt").to(device)
prompt_len = model_input['input_ids'].shape[1]
attr_tokens = torch.LongTensor(attr_data["input_ids"]).reshape(1,-1)
generated_len = attr_tokens.shape[1]

attr_token_pos = np.arange(prompt_len-1, generated_len-1)
attention_mask = torch.ones_like(attr_tokens)

out = model.base_model(attr_tokens, attention_mask)

attr_logits = out.logits
attr_logprobs = logsoftmax(attr_logits)
attr_logprobs = attr_logprobs[0, attr_token_pos, attr_tokens[0, attr_token_pos+1]]  # 49
attr_logprob = attr_logprobs.sum()
attr_grad = torch.autograd.grad(attr_logprob, [param for param in model.parameters() if param.requires_grad])

model.zero_grad()

In [14]:
focused_data = train_dataset[max_datainf_idx]
focused_prompt = focused_data["prompt"]
focused_attention_mask = torch.LongTensor(focused_data["attention_mask"]).unsqueeze(0).to(device)
focused_labels = torch.LongTensor(focused_data["labels"]).unsqueeze(0).to(device)
model_input = tokenizer(focused_prompt, return_tensors="pt").to(device)
prompt_len = model_input["input_ids"].shape[1]
focused_tokens = torch.LongTensor(focused_data["input_ids"]).reshape(1,-1)

In [22]:
focused_data

{'prompt': 'Double Every Consonant\n    For example:\n    Mountains challenge eager climbers. -> ',
 'text': 'Double Every Consonant\n    For example:\n    Mountains challenge eager climbers. ->  MMounnttainnss cchhallllenngge eaggerr ccllimmbberrss.</s>',
 'answer': 'MMounnttainnss cchhallllenngge eaggerr ccllimmbberrss.',
 'variation': 'Double Every Consonant',
 'input_ids': [1,
  11599,
  7569,
  2138,
  265,
  424,
  13,
  1678,
  1152,
  1342,
  29901,
  13,
  1678,
  28418,
  18766,
  19888,
  10784,
  2596,
  29889,
  1599,
  29871,
  28880,
  1309,
  593,
  2408,
  29876,
  893,
  274,
  305,
  27090,
  645,
  264,
  865,
  479,
  321,
  9921,
  29878,
  21759,
  645,
  6727,
  29890,
  495,
  29878,
  893,
  29889,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2],
 'attention_mask': [1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  

In [15]:
past_key_values = None
position_ids = None 
inputs_embeds = None
assert (inputs_embeds is None) or (focused_tokens is None)

batch_size, seq_length = focused_tokens.shape
seq_length_with_past = seq_length
past_key_values_length = 0

if past_key_values is not None:
    past_key_values_length = past_key_values[0][0].shape[2]
    seq_length_with_past = seq_length_with_past + past_key_values_length

if position_ids is None:
    device = focused_tokens.device if focused_tokens is not None else inputs_embeds.device
    position_ids = torch.arange(
        past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
    )
    position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
    position_ids = position_ids.view(-1, seq_length).long()

if inputs_embeds is None:
    inputs_embeds = model.base_model.model.model.embed_tokens(focused_tokens)

inputs_embeds.requires_grad = True

In [16]:
out = model(inputs_embeds=inputs_embeds, attention_mask=focused_attention_mask, labels=focused_labels)
loss = out.loss
grad_loss = torch.autograd.grad(loss, [param for param in model.parameters() if param.requires_grad], create_graph=True)

In [17]:
inner = 0
for g1, g2 in zip(attr_grad, grad_loss):
    inner += (g1*g2).sum()
embedding_grad = torch.autograd.grad(inner, inputs_embeds)[0]
embedding_grad_norm = torch.sqrt(torch.sum(embedding_grad ** 2, dim=-1)[0])

In [18]:
embedding_grad

tensor([[[ 3.2178e+01, -8.0178e+01, -4.3799e+01,  ..., -1.1966e+01,
          -2.1315e+01,  4.0669e+01],
         [ 4.5791e+00, -4.3012e+01, -2.5762e+01,  ..., -1.6998e+01,
           7.1131e+00, -8.6583e+00],
         [-1.2073e+01, -2.6051e+01, -1.1529e+01,  ..., -1.6184e+01,
          -2.3326e+01,  1.2321e+01],
         ...,
         [ 6.1729e-05, -8.0420e-04, -6.6868e-04,  ...,  5.8145e-04,
           2.1785e-04,  3.9444e-04],
         [ 5.8824e-05, -7.5866e-04, -6.6718e-04,  ...,  5.3775e-04,
           1.9321e-04,  3.6751e-04],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]]], device='cuda:2')

In [19]:
import matplotlib
import matplotlib.pyplot as plt


# change weights into hex opacity 

# use colormap for color

def colorize_tokens(tokens, weights, attention_mask=None):
    if type(tokens) == torch.Tensor: tokens = tokens.detach().cpu().numpy()
    if tokens.ndim==2: tokens = tokens.reshape(-1)
    if type(weights) == torch.Tensor: weights = weights.detach().cpu().numpy()
    if attention_mask is None: attention_mask = np.ones_like(tokens)
    if attention_mask.ndim==2: attention_mask = attention_mask.reshape(-1)
    assert attention_mask.shape[0] == tokens.shape[0]
    
    cmap = matplotlib.colormaps.get_cmap('Reds')
    template = '<span style="color: black; background-color: {}; display: inline-block">{}</span>'
    colored_string = ''
    
    for token, weight, masked in zip(tokens, weights, attention_mask):
        if not masked: continue
        color = matplotlib.colors.rgb2hex(cmap(weight)[:3]) + "80"
        token_decoded = tokenizer.convert_ids_to_tokens([token])[0]
        if token_decoded=="<0x0A>": 
            colored_string += "<br>"
            continue
        if "▁" in token_decoded: token_decoded = token_decoded.replace("▁", "&nbsp;")
        if "<" in token_decoded: token_decoded = token_decoded.replace("<", "&lt;")
        if ">" in token_decoded: token_decoded = token_decoded.replace(">", "&gt;")
        colored_string += template.format(color, token_decoded)
    
    return colored_string

In [20]:
ignore_tokens = [1, 2, 13]
not_ignored = torch.prod(torch.vstack([focused_tokens != ignore_token for ignore_token in ignore_tokens]), dim=0)
scores = embedding_grad_norm.cpu() * not_ignored  # scores = embedding_grad_norm.cpu()
scores = scores / torch.max(scores)
s = colorize_tokens(focused_tokens, scores, focused_attention_mask[0])

In [21]:
display(HTML(s))

#### Word-level

In [None]:
# integrate the score into word-level (max? average? sum?)

#### Delete momery allocated data

In [231]:
if 'grad_loss' in globals():
    del grad_loss

if 'attr_grad' in globals():
    del attr_grad

if 'out' in globals():
    del out

torch.cuda.empty_cache() 

### 2. Mask each token using `token_id` 0 (`<unk>`) or 2 (`</s>`)

In [232]:
model.eval()
model.zero_grad()

In [233]:
logsoftmax = torch.nn.LogSoftmax(dim=-1)

attr_data = test_dataset[910]
attr_prompt = attr_data["prompt"]
model_input = tokenizer(attr_prompt, return_tensors="pt").to(device)
prompt_len = model_input['input_ids'].shape[1]
attr_tokens = torch.LongTensor(attr_data["input_ids"]).reshape(1,-1)
generated_len = attr_tokens.shape[1]

attr_token_pos = np.arange(prompt_len-1, generated_len-1)
attention_mask = torch.ones_like(attr_tokens)

out = model.base_model(attr_tokens, attention_mask)

attr_logits = out.logits
attr_logprobs = logsoftmax(attr_logits)
attr_logprobs = attr_logprobs[0, attr_token_pos, attr_tokens[0, attr_token_pos+1]]  # 49
attr_logprob = attr_logprobs.sum()
attr_grad = torch.autograd.grad(attr_logprob, [param for param in model.parameters() if param.requires_grad])

model.zero_grad()

In [None]:
focused_data = train_dataset[max_datainf_idx]
focused_prompt = focused_data["prompt"]
focused_attention_mask = torch.LongTensor(focused_data["attention_mask"]).unsqueeze(0).to(device)
focused_labels = torch.LongTensor(focused_data["labels"]).unsqueeze(0).to(device)
model_input = tokenizer(focused_prompt, return_tensors="pt").to(device)
prompt_len = model_input["input_ids"].shape[1]
focused_tokens = torch.LongTensor(focused_data["input_ids"]).reshape(1,-1)
generated_len = focused_tokens.shape[1]
focused_token_pos = np.arange(0, generated_len-1)
scores = np.zeros([generated_len])

for t_idx in range(generated_len):
    # masked_tokens = focused_tokens.clone().detach()
    # masked_tokens[t_idx] = 2 
    masked_attention_mask = focused_attention_mask.clone().detach()
    masked_attention_mask[t_idx] = 0
    out = model.base_model(focused_tokens, masked_attention_mask, focused_labels)
    loss = out.loss
    grad_loss = torch.autograd.grad(loss, [param for param in model.parameters() if param.requires_grad])
    
    for g1, g2 in zip(attr_grad, grad_loss):
        inner += (g1*g2).sum()
    
    scores[t_idx] = inner

In [None]:
print(scores)

In [None]:
# ignore_tokens = [1, 2, 13]
# not_ignored = torch.prod(torch.vstack([focused_tokens != ignore_token for ignore_token in ignore_tokens]), dim=0)
# scores = embedding_grad_norm.cpu() * not_ignored  # scores = embedding_grad_norm.cpu()
# scores = scores / torch.max(scores)
abs_scores = np.abs(scores) / torch.max(torch.abs(scores))
s = colorize_tokens(focused_tokens, abs_scores, focused_attention_mask[0])

In [None]:
logsoftmax = torch.nn.LogSoftmax(dim=-1)

In [None]:
attr_data = test_dataset[910]
attr_prompt = attr_data["prompt"]
model_input = tokenizer(attr_prompt, return_tensors="pt").to(device)
prompt_len = model_input['input_ids'].shape[1]
attr_tokens = torch.LongTensor(attr_data["input_ids"]).reshape(1,-1)
generated_len = attr_tokens.shape[1]
attr_token_pos = np.arange(prompt_len-1, generated_len-1)
# attr_token_pos = np.arange(0, generated_len-1)

# print("DECODED")
# for p in attr_token_pos:
#     print(tokenizer.decode(attr_tokens[0,p]))

In [None]:
attention_mask = torch.ones_like(attr_tokens)
out = model.base_model(attr_tokens, attention_mask)
attr_logits = out.logits
attr_logprobs = logsoftmax(attr_logits)
attr_logprobs = attr_logprobs[0, attr_token_pos, attr_tokens[0, attr_token_pos+1]]  # 49
attr_logprob = attr_logprobs.sum()
attr_grad = torch.autograd.grad(attr_logprob, [param for param in model.parameters() if param.requires_grad])
model.zero_grad()

In [None]:
n_layers = len(attr_grad)
n_train = len(train_dataset)
tr_grad_norm = np.zeros([n_layers, n_train])

In [None]:
for train_i in tqdm(range(n_train)):
    grad_i = torch.load(f"{grad_dir}/{train_i}.pt")
    for l in range(n_layers):
        tr_grad_norm[l, train_i] = (grad_i[l] * grad_i[l]).sum()

In [None]:
d_l = np.array([grad.numel() for grad in attr_grad])
lambdas = np.sum(tr_grad_norm, axis=-1) / (10 * n_train * d_l)

In [None]:
rs = [torch.zeros_like(grad) for grad in attr_grad]
for train_i in tqdm(range(n_train)):
    grad_i = torch.load(f"{grad_dir}/{train_i}.pt")
    for l in range(n_layers):
        c = (attr_grad[l] * grad_i[l]).sum() / (lambdas[l] + tr_grad_norm[l, train_i])
        ri = (attr_grad[l] - c * grad_i[l]) / (n_train * lambdas[l])
        rs[l] += ri

In [None]:
# step 3 
scores = np.zeros([n_train])
for train_k in tqdm(range(n_train)):
    grad = torch.load(f"{grad_dir}/{train_k}.pt")
    for l in range(n_layers):
        scores[train_k] -= (rs[l] * grad[l]).sum()

In [None]:
top_training_idx = np.argsort(-np.abs(scores))
for i in top_training_idx[:10]:
    print(train_dataset[int(i)]['prompt'])