In [8]:
import os
from nnsight import LanguageModel
from transformers import AutoTokenizer, AutoConfig
import torch as t 
from datasets import load_dataset
import transformer_lens.utils as utils
import tqdm
import pandas as pd
import numpy as np

from utils import make_token_df


In [9]:
# Single source of truth for the checkpoint id, so the model, tokenizer and
# config loaded below cannot silently drift apart when the model is swapped.
MODEL_NAME = 'EleutherAI/pythia-410m'

# nnsight wrapper around the language model, loaded with device_map='auto'.
nnmodel = LanguageModel(MODEL_NAME, device_map='auto')
# A tokenizer holds no weights to place on a device, so no device_map is passed here.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Model hyper-parameters (num_hidden_layers, hidden_size) are read from this config later on.
config = AutoConfig.from_pretrained(MODEL_NAME)

In [4]:

# Reproducibility: pin the NumPy and torch global RNG seeds.
np.random.seed(42)
t.manual_seed(42)

# This notebook only runs inference: put the model in eval mode and turn
# autograd off globally.
nnmodel.eval()
t.set_grad_enabled(False)

# Corpus: the first 1000 rows of the "stas/c4-en-10k" train split
# (an earlier variant of this cell loaded "stas/openwebtext-10k" instead).
data = load_dataset("stas/c4-en-10k", split='train')
first_1k = data.select(range(1000))

# Tokenize the 'text' column with max_length=256, then shuffle with seed 42.
tokenized_data = utils.tokenize_and_concatenate(
    first_1k,
    tokenizer,
    max_length=256,
    column_name='text',
).shuffle(seed=42)

# Token-level dataframe built by the project helper from the shuffled tokens.
token_df = make_token_df(tokenized_data['tokens'], tokenizer=tokenizer)

README.md:   0%|          | 0.00/961 [00:00<?, ?B/s]

c4-en-10k.py:   0%|          | 0.00/2.62k [00:00<?, ?B/s]

0000.parquet:   0%|          | 0.00/13.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map (num_proc=10):   0%|          | 0/1000 [00:00<?, ? examples/s]

In [15]:
# Index of the last hidden layer (layers are counted from 0).
entropy_neuron_layer = config.num_hidden_layers - 1

# Enumerate 4 * hidden_size neuron indices for that layer and give each one a
# "<layer>.<index>" string label.
all_neuron_indices = [*range(4 * config.hidden_size)]
all_neurons = [
    f"{entropy_neuron_layer}.{i}"
    for i in all_neuron_indices
]


In [None]:
from tqdm.auto import trange
from tools.decoding import decode_next_token_with_sampling
import torch.nn.functional as F

def generation(nnmodel, tokenizer, prompt, n_new_tokens=10, use_sampling=False):
    """Extend ``prompt`` token by token and record the next-token distribution of every step.

    Each step runs one traced forward pass, reads the ``lm_head`` output at the
    last sequence position, picks a next token (greedy argmax, or the project
    sampling helper) and appends it to the running input.

    Args:
        nnmodel: nnsight-wrapped model exposing ``device``, ``trace(...)`` and
            ``lm_head.output``.
        tokenizer: Tokenizer used to encode ``prompt`` and decode the result.
            It is called with ``padding=True`` and ``padding_side="left"``, so
            it needs a padding token configured.
        prompt: A string or a batch (list) of strings.
        n_new_tokens: Maximum number of tokens to generate. Must be at least 1,
            because the per-step distributions are stacked at the end.
        use_sampling: If True, delegate the token choice to
            ``decode_next_token_with_sampling``; otherwise decode greedily.

    Returns:
        A tuple ``(l2toks, probas)``:

        * ``l2toks``: dict whose ``"generated"`` entry is a list with one
          decoded string per prompt row, covering only the newly generated
          tokens.
        * ``probas``: CPU tensor of shape ``(1, batch_size, n_steps, vocab_size)``
          with ``n_steps <= n_new_tokens``. The leading singleton axis is kept
          so the returned shape stays what callers already receive.

    Generation stops early only when *every* row emits ``eos_token_id`` in the
    same step.
    """
    device = nnmodel.device

    # Tokenize once. Left padding puts the pad positions at the front, so the
    # position read below (-1) is the end of each prompt.
    encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=False, padding=True, padding_side="left").to(device)
    start_len = encoded["input_ids"].shape[1]

    # Work on copies so the tokenizer output itself is never grown in place.
    model_input = {key: value.clone().detach() for key, value in encoded.items()}

    step_probas = []
    l2toks = {}

    with t.no_grad():
        for _ in range(n_new_tokens):
            # Forward pass; keep only the logits of the last position.
            with nnmodel.trace(model_input, validate=False, scan=False):
                logits = nnmodel.lm_head.output[:, -1, :].save()

            # Depending on the nnsight version, a saved node is either a proxy
            # carrying the tensor in `.value` or already the tensor itself.
            # Accept both so the greedy and sampling paths behave the same.
            logits = getattr(logits, "value", logits)

            if use_sampling:
                # NOTE(review): the helper is assumed to return (probs, next_token)
                # with next_token shaped (batch_size, 1) - confirm in tools.decoding.
                probs, next_token = decode_next_token_with_sampling(logits.cpu())
                next_token = next_token.to(device)
            else:
                probs = F.softmax(logits, dim=-1)  # (batch_size, vocab_size)
                next_token = t.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)
            step_probas.append(probs.cpu())

            # Grow the running input (and its attention mask) by the chosen token.
            model_input["input_ids"] = t.cat([model_input["input_ids"], next_token], dim=-1)
            model_input["attention_mask"] = t.cat([model_input["attention_mask"], t.ones_like(next_token)], dim=-1)
            if t.all(next_token == tokenizer.eos_token_id):
                break

        # (batch_size, n_steps, vocab_size) -> (1, batch_size, n_steps, vocab_size)
        probas = t.stack(step_probas, dim=1).unsqueeze(0)

        # Decode only the newly generated part of each row and hand it back to
        # the caller instead of discarding it.
        newly_generated_tokens = model_input["input_ids"][:, start_len:].detach().cpu()
        l2toks["generated"] = [tokenizer.decode(seq) for seq in newly_generated_tokens]

    return l2toks, probas
