In [None]:
import functools
from datetime import datetime
from typing import Any, Dict
import pickle
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, LlamaTokenizer, LlamaForCausalLM
from tqdm import tqdm
from datasets import load_dataset
from collections import defaultdict, Counter
from functools import partial
import re
from captum.attr import IntegratedGradients
from string import Template

# Dataset
# Slice bounds applied to the TriviaQA train split (used as [start:end] in load_data
# and echoed in the results filename).
start = 0
end = 9000

# IO
data_dir = Path(".") # Where our data files are stored
model_dir = Path("./.cache/models/") # Cache for huggingface models
results_dir = Path("./results/") # Directory for storing results

# Hardware
gpu = "0"
# Fall back to CPU when no CUDA device is available.
device = torch.device(f"cuda:{gpu}" if torch.cuda.is_available() else "cpu")

# Integrated Grads
ig_steps = 64  # passed to IntegratedGradients.attribute as n_steps
internal_batch_size = 4  # passed to IntegratedGradients.attribute as internal_batch_size

# Model
model_name = "open_llama_7b"
layer_number = -1  # -1 means "every layer" (coll_str becomes a digit wildcard)

model_num_layers = {
    "open_llama_7b" : 32
}
# Check both bounds: values below -1 used to pass and produced a pattern such as
# "-2" that cannot match any layer index.
assert -1 <= layer_number < model_num_layers[model_name], (
    f"layer_number must be -1 (all layers) or in [0, {model_num_layers[model_name]}), got {layer_number}"
)
coll_str = "[0-9]+" if layer_number==-1 else str(layer_number)
# model_name -> (hub organisation, MLP up-projection pattern, attention output-projection pattern).
# NOTE(review): the two patterns are presumably module-name regexes for hooks; they are
# not used in the cells visible here.
model_repos = {
    "open_llama_7b" : ("openlm-research", f".*model.layers.{coll_str}.mlp.up_proj", f".*model.layers.{coll_str}.self_attn.o_proj")
}

def get_stop_token():
    """Return the token id that ends generation for the configured model.

    The id is chosen by substring match on the global ``model_name``; the first
    matching family wins, and any other model gets the default id.
    NOTE(review): these ids are presumably each tokenizer's newline token - confirm.
    """
    stop_ids_by_family = (
        ("llama", 13),
        ("falcon", 193),
    )
    for family, token_id in stop_ids_by_family:
        if family in model_name:
            return token_id
    # Default used for every other model family.
    return 50118

def load_data():
    """Load TriviaQA train questions in ``[start:end]`` with all accepted answers.

    Returns:
        A list of ``(question, aliases)`` tuples. ``aliases`` concatenates the
        answer's aliases, normalized aliases, value and normalized value, in
        that order.

    The row window comes from the global ``start``/``end`` settings. Previously
    only the first 9000 rows were read regardless of ``end``, so a larger
    ``end`` was silently truncated.
    """
    train_split = load_dataset('trivia_qa')['train']
    # Slicing a range follows list-slice semantics (clamping, negative indices),
    # so only the requested rows are materialised.
    wanted_rows = range(len(train_split))[start:end]

    dataset = []
    for obs in tqdm(train_split.select(wanted_rows)):
        answer = obs['answer']
        aliases = [
            *answer['aliases'],
            *answer['normalized_aliases'],
            answer['value'],
            answer['normalized_value'],
        ]
        dataset.append((obs['question'], aliases))
    return dataset

def get_next_token(x, model):
    """Run ``model`` on ``x`` without gradient tracking and return its logits.

    Despite the name, this returns the ``.logits`` attribute of the model
    output as-is; callers pick the next token from it themselves.
    """
    with torch.no_grad():
        output = model(x)
    return output.logits

def generate_response(x, model, *, max_length=100, pbar=False):
    """Greedily decode up to ``max_length`` tokens after the prompt ``x``.

    Args:
        x: Prompt token ids for a single sequence, shaped ``(1, seq_len)``
            (each new token is appended along dim 1).
        model: Model passed to ``get_next_token``; its logits are assumed to be
            ``(1, seq_len, vocab)`` - TODO confirm for non-HF models.
        max_length: Maximum number of tokens to generate.
        pbar: Show a tqdm progress bar over the decoding steps.

    Returns:
        1-D numpy array with the generated token ids (prompt excluded).

    Decoding stops early when the stop token is produced after more than six
    generated tokens (``step > 5``); earlier stop tokens are kept and decoding
    continues.
    """
    stop_token = get_stop_token()  # loop-invariant, look it up once
    response = []
    steps = range(max_length)
    for step in (tqdm(steps) if pbar else steps):
        logits = get_next_token(x, model)
        # Index batch 0 / last position explicitly. squeeze() would also drop a
        # length-1 sequence axis and make [-1] select a vocabulary entry.
        next_token = logits[0, -1].argmax()
        x = torch.concat([x, next_token.view(1, -1)], dim=1)
        response.append(next_token)
        if next_token == stop_token and step > 5:
            break
    return torch.stack(response).cpu().numpy()


def answer_question(question, model, tokenizer, *, max_length=100, pbar=False):
    """Tokenize ``question`` and return the model's generated token ids.

    The prompt ids are moved to ``model.device`` and handed to
    ``generate_response``; its return value is passed through unchanged.
    """
    encoded = tokenizer(question, return_tensors='pt')
    prompt_ids = encoded.input_ids.to(model.device)
    return generate_response(prompt_ids, model, max_length=max_length, pbar=pbar)


def answer_trivia(question, targets, model, tokenizer):
    """Answer a trivia question and check the answer against accepted aliases.

    Args:
        question: Prompt text.
        targets: Iterable of accepted answer strings.
        model, tokenizer: Forwarded to ``answer_question``.

    Returns:
        ``(response, str_response, correct)``: the generated token ids, their
        decoded text (special tokens skipped), and whether any target appears
        in the decoded text as a case-insensitive substring.
    """
    token_ids = answer_question(question, model, tokenizer)
    decoded = tokenizer.decode(token_ids, skip_special_tokens=True)
    haystack = decoded.lower()
    is_correct = any(alias.lower() in haystack for alias in targets)
    return token_ids, decoded, is_correct

def normalize_attributes(attributes: torch.Tensor) -> torch.Tensor:
    """Collapse attributions to one score per position, scaled to sum to 1.

    Args:
        attributes: Tensor documented by callers as
            ``(batch, sequence size, embedding dim)``; dim 0 is dropped when it
            has size 1.

    Returns:
        The L2 norm taken along dim 1 of the squeezed tensor, divided by the
        sum of those norms.
    """
    per_position = attributes.squeeze(0)
    l2_norms = per_position.norm(dim=1)
    # Rescale so the scores add up to 1.
    return l2_norms / l2_norms.sum()


def model_forward(input_: torch.Tensor, model, extra_forward_args: Dict[str, Any]) \
            -> torch.Tensor:
    """Forward function for attribution: next-token probabilities from embeddings.

    Args:
        input_: Input embeddings, passed to the model as ``inputs_embeds``.
        model: Callable returning an object with a ``.logits`` attribute.
        extra_forward_args: Additional keyword arguments for the model call.

    Returns:
        Softmax over the last axis of the logits at the final sequence position.
    """
    logits = model(inputs_embeds=input_, **extra_forward_args).logits
    final_position = logits[:, -1, :]
    return final_position.softmax(dim=-1)


def get_embedder(model):
    """Return the token-embedding module of ``model`` for the configured family.

    The family is picked by substring match on the global ``model_name``; the
    first match wins.

    Raises:
        ValueError: If ``model_name`` matches none of the known families.
    """
    embedding_paths = (
        ("falcon", ("transformer", "word_embeddings")),
        ("opt", ("model", "decoder", "embed_tokens")),
        ("llama", ("model", "embed_tokens")),
    )
    for family, attr_chain in embedding_paths:
        if family in model_name:
            module = model
            # Walk the attribute chain only for the matching family.
            for attr in attr_chain:
                module = getattr(module, attr)
            return module
    raise ValueError(f"Unknown model {model_name}")

def get_ig(prompt, forward_func, tokenizer, embedder, model):
    """Integrated Gradients attribution of the predicted next token to each prompt token.

    Args:
        prompt: Text to attribute over.
        forward_func: Callable mapping input embeddings to per-token
            probabilities (see ``model_forward``).
        tokenizer: Tokenizer used to encode ``prompt``.
        embedder: Module mapping token ids to input embeddings.
        model: Model used to pick the greedy next token; its logits are assumed
            to be ``(1, seq_len, vocab)`` - TODO confirm for non-HF models.

    Returns:
        1-D float32 numpy array, one normalized score per prompt token.
    """
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
    # Index batch 0 / last position explicitly. squeeze() would also drop a
    # length-1 sequence axis and make [-1] select a vocabulary entry.
    prediction_id = get_next_token(input_ids, model)[0, -1].argmax()
    encoder_input_embeds = embedder(input_ids).detach()

    ig = IntegratedGradients(forward_func=forward_func)
    raw_attributions = ig.attribute(
        encoder_input_embeds,
        target=prediction_id,
        n_steps=ig_steps,
        internal_batch_size=internal_batch_size,
    )
    # Cast to float32 before .numpy(), as the sibling attribution helpers do:
    # numpy has no bfloat16, so the result must not depend on Captum's internal dtype.
    return normalize_attributes(raw_attributions).detach().to(torch.float32).cpu().numpy()

In [2]:
# Model
# Resolve the hub id once so the tokenizer and model cannot drift apart.
repo_id = f'{model_repos[model_name][0]}/{model_name}'
model_loader = LlamaForCausalLM if "llama" in model_name else AutoModelForCausalLM
token_loader = LlamaTokenizer if "llama" in model_name else AutoTokenizer
# Use the configured cache for the tokenizer too; before, only the model honoured model_dir.
tokenizer = token_loader.from_pretrained(repo_id, cache_dir=model_dir)
# NOTE(review): trust_remote_code=True allows code shipped in the hub repo to run
# locally; keep it only for models that actually need it.
model = model_loader.from_pretrained(repo_id,
                                        cache_dir=model_dir,
                                        device_map=device,
                                        torch_dtype=torch.bfloat16,
                                        trust_remote_code=True)
# Bind the model once so the attribution helpers can call forward_func(embeddings).
forward_func = partial(model_forward, model=model, extra_forward_args={})
embedder = get_embedder(model)

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message
Loading checkpoint shards: 100%|██████████| 2/2 [01:20<00:00, 40.37s/it]


In [3]:
from captum.attr import Saliency

def get_saliency(prompt, forward_func, tokenizer, embedder, model):
    """Saliency attribution of the predicted next token to each prompt token.

    Args:
        prompt: Text to attribute over.
        forward_func: Callable mapping input embeddings to per-token
            probabilities (see ``model_forward``).
        tokenizer: Tokenizer used to encode ``prompt``.
        embedder: Module mapping token ids to input embeddings.
        model: Model used to pick the greedy next token; its logits are assumed
            to be ``(1, seq_len, vocab)`` - TODO confirm for non-HF models.

    Returns:
        1-D float32 numpy array, one normalized score per prompt token.
    """
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
    # Index batch 0 / last position explicitly. squeeze() would also drop a
    # length-1 sequence axis and make [-1] select a vocabulary entry.
    prediction_id = get_next_token(input_ids, model)[0, -1].argmax()
    encoder_input_embeds = embedder(input_ids).detach()

    saliency = Saliency(forward_func=forward_func)
    raw_attributions = saliency.attribute(encoder_input_embeds, target=prediction_id)
    # float32 cast is required before .numpy(): numpy has no bfloat16.
    return normalize_attributes(raw_attributions).detach().to(torch.float32).cpu().numpy()

In [4]:
from captum.attr import InputXGradient

def get_input_x_gradient(prompt, forward_func, tokenizer, embedder, model):
    """Input x Gradient attribution of the predicted next token to each prompt token.

    Args:
        prompt: Text to attribute over.
        forward_func: Callable mapping input embeddings to per-token
            probabilities (see ``model_forward``).
        tokenizer: Tokenizer used to encode ``prompt``.
        embedder: Module mapping token ids to input embeddings.
        model: Model used to pick the greedy next token; its logits are assumed
            to be ``(1, seq_len, vocab)`` - TODO confirm for non-HF models.

    Returns:
        1-D float32 numpy array, one normalized score per prompt token.
    """
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
    # Index batch 0 / last position explicitly. squeeze() would also drop a
    # length-1 sequence axis and make [-1] select a vocabulary entry.
    prediction_id = get_next_token(input_ids, model)[0, -1].argmax()
    encoder_input_embeds = embedder(input_ids).detach()

    input_x_gradient = InputXGradient(forward_func=forward_func)
    raw_attributions = input_x_gradient.attribute(
        encoder_input_embeds,
        target=prediction_id,
    )
    # float32 cast is required before .numpy(): numpy has no bfloat16.
    return normalize_attributes(raw_attributions).detach().to(torch.float32).cpu().numpy()

In [5]:
# List of (question, aliases) tuples; the row window is set by `start`/`end` in the config cell.
dataset = load_data()

100%|██████████| 9000/9000 [00:01<00:00, 7839.73it/s]


In [None]:
# Generate results
# Create the output folder before the long loop: otherwise a missing directory
# only surfaces at save time, after every example has been processed.
results_dir.mkdir(parents=True, exist_ok=True)

# Column-oriented store: each key holds one entry per processed question.
results = defaultdict(list)
for idx in tqdm(range(len(dataset))):

    question, answers = dataset[idx]
    response, str_response, correct = answer_trivia(question, answers, model, tokenizer)
    attributes_first = get_ig(question, forward_func, tokenizer, embedder, model)
    saliency = get_saliency(question, forward_func, tokenizer, embedder, model)
    input_x_gradient = get_input_x_gradient(question, forward_func, tokenizer, embedder, model)

    results['question'].append(question)
    results['answers'].append(answers)
    results['response'].append(response)
    results['str_response'].append(str_response)
    results['correct'].append(correct)
    results['attributes_first'].append(attributes_first)
    results['saliency'].append(saliency)
    results['input_x_gradient'].append(input_x_gradient)

# Take one timestamp so month and day cannot come from different sides of midnight.
saved_at = datetime.now()
results_path = results_dir/f"NEW_{model_name}_trivia_qa_start-{start}_end-{end}_{saved_at.month}_{saved_at.day}.pickle"
with open(results_path, "wb") as outfile:
    pickle.dump(results, outfile)

# Time test

In [15]:
import time
import statistics
from tqdm import tqdm

# Wall-clock seconds per call for each attribution method.
ig_times = []
saliency_times = []
input_xgrad_times = []

# Time one fixed prompt. Pick it explicitly instead of relying on `idx` left
# over from the results loop; the last example is what `idx` holds after that
# loop completes.
timing_idx = len(dataset) - 1
question, answers = dataset[timing_idx]

for _ in tqdm(range(10)):
    # Use `t0`, not `start`: `start` is the dataset slice bound from the config
    # cell and is read by load_data() and the results filename.
    t0 = time.perf_counter()
    attributes_first = get_ig(question, forward_func, tokenizer, embedder, model)
    ig_times.append(time.perf_counter() - t0)

    t0 = time.perf_counter()
    saliency = get_saliency(question, forward_func, tokenizer, embedder, model)
    saliency_times.append(time.perf_counter() - t0)

    t0 = time.perf_counter()
    input_x_gradient = get_input_x_gradient(question, forward_func, tokenizer, embedder, model)
    input_xgrad_times.append(time.perf_counter() - t0)

  gradient_mask = apply_gradient_requirements(inputs_tuple)
  gradient_mask = apply_gradient_requirements(inputs_tuple)
100%|██████████| 10/10 [02:58<00:00, 17.88s/it]


In [20]:
# Mean seconds per Integrated Gradients call over the timing runs.
np.mean(ig_times)

np.float64(15.533107759966516)

In [21]:
# Mean seconds per Saliency call over the timing runs.
np.mean(saliency_times)

np.float64(1.1813910600030795)

In [25]:
# Relative change in mean runtime, Saliency vs Integrated Gradients.
# NOTE(review): 1.18 and 15.53 are rounded means copied by hand from the cell outputs;
# recompute from saliency_times / ig_times if the timing cell is re-run.
(1.18 - 15.53)/15.53

-0.9240180296200902

In [22]:
# Mean seconds per Input x Gradient call over the timing runs.
np.mean(input_xgrad_times)

np.float64(1.1682947299908846)

In [24]:
# Relative change in mean runtime, Input x Gradient vs Integrated Gradients.
# NOTE(review): 1.17 and 15.53 are rounded means copied by hand from the cell outputs;
# recompute from input_xgrad_times / ig_times if the timing cell is re-run.
(1.17 - 15.53)/15.53

-0.9246619446233098