In [1]:
import torch
import numpy as np
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
import networkx as nx
import matplotlib.pyplot as plt
import hierarchical as hrc
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

import warnings
warnings.filterwarnings('ignore')

In [2]:
device = torch.device("cuda:1")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer.pad_token = tokenizer.eos_token
g = torch.load('matricies.pth').to(device) # 'FILE_PATH' in store_matrices.py

vocab_dict = tokenizer.get_vocab()
vocab_list = [None] * (max(vocab_dict.values()) + 1)
for word, index in vocab_dict.items():
    vocab_list[index] = word

In [3]:
vocab_list.index("hello")

15339

In [4]:
A = torch.load("inv_sqrt_Cov_gamma.pth").to(device)

In [5]:
A_inv_trans = A.inverse().T

In [6]:
from nnsight import LanguageModel
import torch

In [7]:
with open("data/animals.json", "r", encoding="utf-8") as f:
    data = json.load(f)

categories = ['mammal', 'bird', 'reptile', 'fish', 'amphibian', 'insect']
animals_token, animals_ind, animals_g = hrc.get_animal_category(data, categories,  vocab_dict, g)

dirs = {k: hrc.estimate_cat_dir(v, g, vocab_dict) for k, v in animals_token.items()}

all_animals_tokens = [a for k, v in animals_token.items() for a in v]
dirs.update({'animal': hrc.estimate_cat_dir(all_animals_tokens, g, vocab_dict)})
animals_token.update({'animal': all_animals_tokens})

with open("data/plants.json", "r", encoding="utf-8") as f:
    data = json.load(f)

plants_token = []
vocab_set = set(vocab_dict.keys())
lemmas = data["plant"]
for w in lemmas:
    plants_token.extend(hrc.noun_to_gemma_vocab_elements(w, vocab_set))

dirs.update(hrc.estimate_cat_dir(plants_token, g, vocab_dict))

In [8]:
dirs

{'mammal': {'lda': tensor([-0.1181,  0.7374,  0.1294,  ..., -0.4211, -0.2868, -0.1180],
         device='cuda:1'),
  'mean': tensor([-0.0325,  0.7813,  0.2077,  ..., -0.3415, -0.2703, -0.0381],
         device='cuda:1')},
 'bird': {'lda': tensor([ 0.2890, -0.8864,  1.6438,  ..., -0.5996,  0.2525, -0.0193],
         device='cuda:1'),
  'mean': tensor([ 0.2425, -0.9359,  1.6619,  ..., -0.4542,  0.1092, -0.1407],
         device='cuda:1')},
 'reptile': {'lda': tensor([ 0.1402,  0.1173, -0.1522,  ...,  0.2209, -0.0624,  0.0315],
         device='cuda:1'),
  'mean': tensor([ 0.6738,  1.2193,  0.3772,  ...,  1.1694, -0.4623, -0.2436],
         device='cuda:1')},
 'fish': {'lda': tensor([-0.0961,  0.2546,  0.3447,  ...,  0.3661,  0.3973,  0.2600],
         device='cuda:1'),
  'mean': tensor([-0.0908,  0.3068,  0.2796,  ...,  0.3508,  0.3244,  0.1997],
         device='cuda:1')},
 'amphibian': {'lda': tensor([nan, nan, nan,  ..., nan, nan, nan], device='cuda:1'),
  'mean': tensor([-0.6194,  2.

In [9]:
l_on_animal = dirs['mammal']

In [10]:
vocab_list.index('Paris')

60704

In [11]:
A_inv_trans_inv = A_inv_trans.inverse()

In [12]:
from nnsight import NNsight

In [13]:
outputs = []

In [14]:
model = LanguageModel(
    AutoModelForCausalLM.from_pretrained('meta-llama/Meta-Llama-3-8B', device_map=torch.device("cuda:1"), torch_dtype=torch.float32).eval(),
    tokenizer=tokenizer
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [15]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
    (rotary_emb

In [16]:
l_bar_animal = dirs['mean']

In [17]:
from tqdm.auto import tqdm

In [18]:
def logits_to_probs(logits, temperature: float = 0.6, top_k: float = 50):
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs

In [19]:
def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

In [20]:
def sample(dist, greedy=False):
    if greedy:
        print("here")
        return dist.argmax(-1)
    probs = logits_to_probs(dist)
    token = multinomial_sample_one_no_sync(probs)
    return token

In [21]:
def erase_concept(last_hidden_state, concept, A_inv_trans):
    A_inv_trans_inv = A_inv_trans.inverse()
    concept_unit = concept / concept.pow(2).sum().sqrt()
    l_x = (last_hidden_state @ A_inv_trans)
    proj_on_concept = torch.tensordot(l_x, concept_unit, dims=[[-1], [-1]]).unsqueeze(-1)
    l_x -= (proj_on_concept @ concept_unit.unsqueeze(0))
    hidden_state = (l_x @ A_inv_trans_inv)
    return hidden_state

In [22]:
def erase_concepts(hidden_state, A_inv_trans, dirs):
    for key, value in dirs.items():
        if isinstance(value, torch.Tensor):
            hidden_state = erase_concept(hidden_state, value, A_inv_trans)
        elif isinstance(value, dict):
            hidden_state = erase_concepts(hidden_state, A_inv_trans, value)

    return hidden_state

In [30]:
tokens = tokenizer(["What is a dog?"], return_tensors="pt")['input_ids'].to("cuda:1")
for i in tqdm(range(10)):
    with model.trace(tokens):
        hidden_state = model.model.layers[-1].output[0][0, -1, :]
        
        # hidden_state = erase_concepts(hidden_state, A_inv_trans, {'hmm': dirs['mean']})
        # model.model.layers[-1].output[0][0, :, :] = hidden_state
        token = sample(model.lm_head.output[:, -1:, :]).save()

    tokens= torch.cat([tokens, token.squeeze(0)], dim=1)

  0%|          | 0/10 [00:00<?, ?it/s]

In [31]:
print(tokenizer.batch_decode(tokens)[0])

<|begin_of_text|>What is a dog? How do I know if my dog is a good


In [29]:
print(tokenizer.batch_decode(tokens)[0])

<|begin_of_text|>What is a dog? I mean what is a dog really? You know


In [49]:
tokens = tokenizer(["Dogs are "], return_tensors="pt")['input_ids'].to("cuda:1")
for i in tqdm(range(40)):
    with model.trace(tokens):
        hidden_state = model.model.layers[-1].output[0][0, -1, :]
        
        # hidden_state = erase_concepts(hidden_state, A_inv_trans, {'hmm': dirs['animal']['lda']})
        # model.model.layers[-1].output[0][0, :, :] = hidden_state
        token = sample(model.lm_head.output[:, -1:, :]).save()

    tokens= torch.cat([tokens, token.squeeze(0)], dim=1)

  0%|          | 0/40 [00:00<?, ?it/s]

Exception ignored in: <function WeakIdKeyDictionary.__init__.<locals>.remove at 0x7b335656b100>
Traceback (most recent call last):
  File "/datadrive/usaip/anaconda3/lib/python3.12/site-packages/torch/utils/weak.py", line 125, in remove
    def remove(k, selfref=ref(self)):

KeyboardInterrupt: 


In [50]:
# outputs without eraseure
print(tokenizer.batch_decode(tokens)[0])

<|begin_of_text|>Dogs are mammals because they have hair, sweat glands, mammary glands, nipples, and produce milk. Dogs are also mammals because they have a four-chambered heart, which is necessary for the circulation of blood.


In [48]:
# erasure with 
print(tokenizer.batch_decode(tokens)[0])

<|begin_of_text|>Dogs are mammals because they are warm-blooded, can give live birth, have hair instead of fur, and have a four-chambered heart.
Mammals are animals that have hair or fur and give birth to


In [261]:
for i in range(10):
    output = outputs[i]
    print(output.logits.shape)
    output = output.logits[0, -1, :].argmax(dim=0)
    print(tokenizer.decode(output.argmax()))

IndexError: list index out of range

In [47]:
output = output.logits[0, :, :].argmax(dim=1)

In [48]:
tokenizer.decode(output)

'Question a world where you'

In [25]:
g_y = g[vocab_list.index('cat')]

In [206]:
a = torch.rand(size=(8, 4096))

In [211]:
b = torch.rand(size=(4096,))

In [214]:
torch.tensordot(a, b, dims=[[-1], [-1]])

tensor([1010.8303,  978.7614, 1011.8314, 1000.1763, 1009.1495, 1021.1320,
        1005.3380, 1008.5158])

In [86]:
g_dist = torch.tensordot(g, l_x, dims=([-1], [-1]))

tensor(7.1406, device='cuda:1')

In [77]:
indices = g_dist.topk(10)

In [78]:
tokenizer.batch_decode(indices.indices.reshape(-1, 1))

[' ', ',', '\n', ':', ' -', '.', ' (', '!', ' in', '-']

In [42]:
tokenizer.decode(indicies)

TypeError: argument 'ids': 'list' object cannot be interpreted as an integer

In [55]:
tokenizer.decode(torch.argmax(output.logits[0, -1, :]))

'_<?'

In [6]:
a = torch.arange(20).reshape(5, 4)

In [14]:
b = torch.arange(5).reshape(5)

In [15]:
torch.tensordot(a, b, dims=([0], [0]))

tensor([120, 130, 140, 150])

In [16]:
a

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15],
        [16, 17, 18, 19]])

In [17]:
b

tensor([0, 1, 2, 3, 4])

In [127]:
a = torch.tensor([2, 3]).float()
b = torch.tensor([2, 2]).float()

In [128]:
b = b/ b.pow(2).sum().sqrt()

In [129]:
a_proj_on_b = a.dot(b)

In [130]:
a_proj_on_b

tensor(3.5355)

In [131]:
a += -1*a_proj_on_b*b

In [132]:
a

tensor([-0.5000,  0.5000])

In [133]:
a.dot(b)

tensor(0.)