In [2]:
#%pip install transformers

In [1]:
import torch
torch.cuda.is_available()

False

In [2]:
#load a model and tokenizer
from transformers import GPTNeoXForCausalLM, AutoTokenizer

model = GPTNeoXForCausalLM.from_pretrained(
  "EleutherAI/pythia-410m-deduped",
  cache_dir="./data/pythia-410m-deduped/default",
)

tokenizer = AutoTokenizer.from_pretrained(
  "EleutherAI/pythia-410m-deduped",
  cache_dir="./data/pythia-410m-deduped/default",
)

inputs = tokenizer("operation is a word that means", return_tensors="pt")
tokens = model.generate(**inputs)
tokenizer.decode(tokens[0])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


'operation is a word that means "to be in the midst of."\n\nThe word is used'

In [3]:

config = model.generation_config 
config.temperature = 1.5
config.max_new_tokens = 30
#config.update(temperature=0.8)

tokens = model.generate(**inputs,generation_config=config)
tokenizer.decode(tokens[0])

Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


'operation is a word that means "to be in the midst of."\n\nThe word is used in the sense of "in the midst of" or "in the midst of'

## given a predicted token, can we identify a feature from an earlier residual stream that contributes to the prediction


In [4]:
config = model.generation_config 
config.max_new_tokens = 1
inputs = tokenizer("The capital of Ireland is", return_tensors="pt")
tokens = model.generate(**inputs,generation_config=config)
tokens[0], tokenizer.decode(tokens[0])


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


(tensor([  510,  5347,   273, 11011,   310, 24523]),
 'The capital of Ireland is Dublin')

In [5]:
inputs

{'input_ids': tensor([[  510,  5347,   273, 11011,   310]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}

In [6]:



embeds = model.gpt_neox.embed_in(inputs.input_ids)

transformer_outputs = model.gpt_neox.forward(inputs_embeds=embeds)
residuals = transformer_outputs[0]
lm_logits = model.embed_out(residuals)

lm_logits


tensor([[[-3.3826e+00, -1.4281e+00,  3.9159e+00,  ..., -1.3819e+00,
          -1.5631e+00, -1.3122e+00],
         [-1.2571e+00, -9.4581e-01,  7.2181e+00,  ..., -7.0558e-01,
          -7.9936e-01, -1.2370e+00],
         [-1.9328e+00, -9.6202e-03,  4.5567e+00,  ...,  2.8501e-01,
           2.0072e-01,  6.1429e-02],
         [-8.5777e-02, -1.1183e-01,  1.1387e+01,  ...,  4.4135e-02,
          -1.3206e-01,  1.0479e-01],
         [-8.4291e-01, -5.2717e-01,  8.0660e+00,  ..., -4.3583e-01,
          -4.3852e-01, -1.3545e-01]]], grad_fn=<UnsafeViewBackward>)

In [7]:
from captum.attr import IntegratedGradients

expected_token_id=24523
#define a (differentiable?) function which returns the weight of the predicted token
#IntegratedGradients uses the gradient to attribute that weight to the input
def forward_with_embeds(embeds):
    transformer_outputs = model.gpt_neox.forward(inputs_embeds=embeds)
    hidden_states = transformer_outputs[0]
    lm_logits = model.embed_out(hidden_states)
    #return the logit from the first (and only) batch, last token, at the position of the expected token id
    return lm_logits[0,-1,expected_token_id:expected_token_id+1]

ig=IntegratedGradients(forward_with_embeds)
embeds.requires_grad_()
attr, delta = ig.attribute(embeds,return_convergence_delta=True)#, target=expected_token_id)
attr = attr.detach().cpu().numpy()

In [11]:
#attr matches the shape of the input tensor of embedding vectors. Its elements represent the importance 
#of each element of the embedding token with respect to the generated embedding
embeds.shape, attr.shape

(torch.Size([1, 5, 1024]), (1, 5, 1024))

In [10]:
embeds, attr

(tensor([[[-0.0151,  0.0400,  0.0081,  ...,  0.0009, -0.0088,  0.0104],
          [-0.0543, -0.0066, -0.0008,  ..., -0.0178, -0.0266,  0.0249],
          [-0.0075, -0.0025, -0.0020,  ...,  0.0050,  0.0027, -0.0002],
          [ 0.0188,  0.0038,  0.0273,  ...,  0.0422, -0.0127, -0.0206],
          [ 0.0005, -0.0213, -0.0015,  ..., -0.0014, -0.0146, -0.0208]]],
        grad_fn=<EmbeddingBackward>),
 array([[[ 4.62881540e-03,  2.45691774e-02,  7.35847475e-03, ...,
          -5.81025640e-04, -4.08304952e-03, -2.38937235e-03],
         [ 4.44121634e-02, -5.95175198e-03, -4.80995240e-04, ...,
          -1.02934224e-03,  7.82107470e-03, -4.36509884e-02],
         [ 6.94823241e-04, -3.23122539e-04, -2.95850111e-04, ...,
           1.47434441e-04,  3.79010018e-05,  4.66825345e-05],
         [-1.63826454e-03,  2.30229992e-04, -1.09744005e-03, ...,
           2.18102159e-04,  1.43563063e-03,  4.71485411e-03],
         [ 3.21494793e-05, -1.00761952e-04,  7.51635712e-06, ...,
          -2.93056683e

In [12]:
embeds.grad

  embeds.grad
