In [1]:
from data_utils import *
from nethook import *
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from utils import InputEmbedCausalTransformer
import json
import numpy as np
import os 
import pandas as pd
import colorama
import torch

  from pandas.core import (


In [2]:
# Set up colorama so ANSI colour escape sequences (such as those built by
# LlamaHelper.get_color_code below) are handled by the output stream.
colorama.init()

In [3]:
# Checkpoint (state dict) and architecture config of the model under inspection.
# The defaults are the original local paths; set FILLER_TOKENS_MODEL_PATH /
# FILLER_TOKENS_CONFIG_FILE in the environment to run the notebook on another machine
# without editing this cell.
MODEL_PATH = os.environ.get(
    "FILLER_TOKENS_MODEL_PATH",
    "/home/loki/projects/filler_tokens/checkpoints/2024-07-21-17-minidata-checkpoint-final/model_weights.pt",
)
CONFIG_FILE = os.environ.get(
    "FILLER_TOKENS_CONFIG_FILE",
    "/home/loki/projects/filler_tokens/configs/llama_d384l4h6.json",
)

In [4]:
# The CSV has no header row; every line is loaded into a single column named "text".
train_df = pd.read_csv(
    'data/minidata_trainset_2024-07-21.csv',
    header=None,
    names=["text"],
)

In [5]:
# Build the vector-encoded dataset from the raw text frame.
# NOTE(review): the positional arguments (3, 10, 10, 'P') are not named here and
# Match3VectorDataset's signature is not visible in this notebook — confirm their
# meaning against data_utils before changing them.
train_set = Match3VectorDataset(train_df, 3, 10, 10, 'P')
# input_dim is later passed to InputEmbedCausalTransformer when the model is built.
print(train_set.input_dim)

validate encodings
raw input 0  205 066 491 653 677 358 153 311 875 014 P 0- 6 2- 6 0- 8 4- 8 0- 3 6- 3 7- 1 0- 7 9- 1 2- 3- 3- 6 1- 3 1- 1 6- 9 1- 3 8- 3 9- 0 2- 0 2- 8 5- 7 2- 4 2- 7 8- 6 2- 0 3- 0 3- 1 3- 7 7- 6 8- 8 9- 7 5- 9 6- 2 7- 9 8- 4 9- 6 5- 4 7- 6 8- 2 5- 2 7- 4 6- 8 9- 6 8- 1 9- 5 8- 8 9- 8 A True
encoded sample 0 {'input_ids': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float16), 'labels': tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,    5,   31,
           7,   31,    5,   33,    9,   33,    5,   28,   11,   28,   12,   26,
           5,   32,   14,   26,    7,    8,    8,   31,    6,   28,    6,   26,
          11,   34,    6,   28,   13,   28,   14,   25,    7,   25,    7,   33,
          10,   32,    7,   29,    7,   32,   13,  

In [79]:
class BlockOutputWrapper(torch.nn.Module):
    """Transparent wrapper around a block that records its unembedded output.

    Every forward call passes straight through to the wrapped ``block`` and
    returns its result untouched. As a side effect, the block's output (the
    first element when the block returns a tuple) is run through ``norm`` and
    then ``unembed_matrix`` and stored in ``block_output_unembedded``.
    """

    def __init__(self, block, unembed_matrix, norm):
        """Store the wrapped block together with the norm and unembedding callables."""
        super().__init__()
        self.block = block
        self.unembed_matrix = unembed_matrix
        self.norm = norm
        # Filled on every forward pass; cleared with reset_block_output().
        self.block_output_unembedded = None

    def forward(self, *args, **kwargs):
        """Run the wrapped block, cache its unembedded output, return the raw result."""
        result = self.block(*args, **kwargs)
        # Blocks may return either a bare tensor or a tuple whose first item is the tensor.
        hidden = result[0] if isinstance(result, tuple) else result
        self.block_output_unembedded = self.unembed_matrix(self.norm(hidden))
        return result

    def reset_block_output(self):
        """Forget the output cached by the previous forward pass."""
        self.block_output_unembedded = None

class LlamaHelper:
    """Loads the vector-input causal LM and provides per-layer inspection helpers.

    On construction every decoder layer is replaced by a ``BlockOutputWrapper``.
    After a forward pass each layer therefore holds its output projected through
    the model's final norm and ``lm_head``, and the ``print_*`` methods decode
    those projections into readable tokens.
    """

    def __init__(self, config_file, model_path, train_set):
        """Build the model, load its weights and wrap every decoder layer.

        Args:
            config_file: path handed to ``AutoConfig.from_pretrained``.
            model_path: path of the saved state dict, read with ``torch.load``.
            train_set: dataset object providing ``input_dim``, ``word_index_map``,
                ``data_len`` and ``mod`` (the last three drive ``decode_tensors``).
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        config = AutoConfig.from_pretrained(config_file)
        model = InputEmbedCausalTransformer(AutoModelForCausalLM.from_config(config), train_set.input_dim)
        # map_location lets a checkpoint saved on a GPU be opened on a CPU-only host.
        # NOTE(review): torch.load unpickles the file — only open trusted checkpoints
        # (consider weights_only=True if the file holds nothing but tensors).
        state_dict = torch.load(model_path, map_location=self.device)
        # NOTE(review): strict=False silently ignores missing/unexpected keys.
        model.load_state_dict(state_dict, strict=False)
        model = model.to(self.device)
        # This helper is only used for inference; leave training mode.
        model.eval()
        self.model = model
        self.word_index_map = train_set.word_index_map
        self.data_len = train_set.data_len
        self.mod = train_set.mod

        layers = self.model.base_model.model.layers
        for i, layer in enumerate(layers):
            layers[i] = BlockOutputWrapper(layer, self.model.base_model.lm_head, self.model.base_model.model.norm)

    def _index_to_word(self):
        """Return the inverse of ``word_index_map`` (first word wins on duplicate indices)."""
        inverse = {}
        for word, index in self.word_index_map.items():
            inverse.setdefault(index, word)
        return inverse

    def decode_tensors(self, sequence):
        """Turn a sequence of token ids into a space-separated string.

        Decoding rules, in order:
          * ``-100`` becomes ``[MASK]`` until a word ``A`` or ``P`` has been decoded;
            after that it is skipped.
          * ``0`` becomes ``[EOS]`` and ends decoding.
          * ids below ``len(word_index_map)`` are looked up in ``word_index_map``.
          * the next ``2 * data_len`` ids are rendered as ``"{tuple_pos}-{idx}"``.
          * all remaining ids are digits (``id % mod`` after removing the offsets);
            a digit whose character position is non-zero is appended to the
            previous entry when that entry ends in a digit.

        Args:
            sequence: iterable of token ids (tensor elements or plain ints).

        Returns:
            The decoded tokens joined by single spaces.

        Raises:
            ValueError: if an id in the word range has no entry in ``word_index_map``.
        """
        index_to_word = self._index_to_word()
        offset = len(self.word_index_map)
        digit_start = offset + self.data_len * 2
        decoded_sequence = []
        marker_found = False

        for token in sequence:
            token = token.item() if hasattr(token, "item") else token
            if token == -100:
                if not marker_found:
                    decoded_sequence.append("[MASK]")
                continue
            if token == 0:
                decoded_sequence.append("[EOS]")
                break  # Stop decoding after EOS
            if token < offset:
                # Regular word
                word = index_to_word.get(token)
                if word is None:
                    raise ValueError(f"{token} is not in word_index_map")
                decoded_sequence.append(word)
                if word in ("A", "P"):
                    marker_found = True
            elif token < digit_start:
                # Tuple index encoding
                tuple_pos, idx = divmod(token - offset, self.data_len)
                decoded_sequence.append(f"{tuple_pos}-{idx}")
            else:
                # Single digit or digit in tuple
                char_pos, digit = divmod(token - digit_start, self.mod)
                starts_new_entry = (
                    char_pos == 0
                    or not decoded_sequence
                    or not decoded_sequence[-1][-1].isdigit()
                )
                if starts_new_entry:
                    decoded_sequence.append(str(digit))
                else:
                    decoded_sequence[-1] += str(digit)

        return " ".join(decoded_sequence)

    def generate_text(self, inputs, max_length=100):
        """Generate from ``inputs['input_ids']`` and return the decoded first sequence.

        NOTE(review): unlike ``get_logits`` no batch axis is added to the input here;
        confirm what shape ``self.model.generate`` expects.
        """
        generate_ids = self.model.generate(inputs['input_ids'].float().to(self.device), max_length=max_length)
        # Decode the first generated sequence of the batch (not the whole batch at once,
        # and not the first character of the decoded string).
        if generate_ids.dim() > 1:
            generate_ids = generate_ids[0]
        return self.decode_tensors(generate_ids)

    def get_logits(self, inputs):
        """Run one forward pass without gradients and return the model's logits.

        A leading batch axis is added to ``inputs['input_ids']`` before the call,
        so the sample is expected to be unbatched.
        """
        with torch.no_grad():
            input_ids = inputs['input_ids'].float().to(self.device).unsqueeze(0)
            out = self.model(input_ids)
            return out.logits

    def set_add_attn_output(self, layer, add_output):
        """Forward ``add_output`` to ``attn_add_tensor`` of the given layer.

        NOTE(review): the BlockOutputWrapper defined in this notebook has no
        ``attn_add_tensor`` method, so this raises AttributeError as written.
        """
        self.model.base_model.model.layers[layer].attn_add_tensor(add_output)

    def get_attn_activations(self, layer):
        """Return ``get_attn_activations()`` of the given layer.

        NOTE(review): the BlockOutputWrapper defined in this notebook has no
        ``get_attn_activations`` method, so this raises AttributeError as written.
        """
        return self.model.base_model.model.layers[layer].get_attn_activations()

    def reset_all_layers(self):
        """Clear the cached unembedded output of every wrapped layer."""
        for layer in self.model.base_model.model.layers:
            layer.reset_block_output()

    def decode_all_layers(self, inputs, topk=10, print_prob=False):
        """Print the top-k decoded tokens of every layer and of the final logits.

        Decoding problems are reported per layer (message plus tensor shape)
        instead of aborting the whole dump.
        """
        self.reset_all_layers()
        logits = self.get_logits(inputs)

        for i, layer in enumerate(self.model.base_model.model.layers):
            print(f'\nLayer {i}: Decoded intermediate outputs')
            if layer.block_output_unembedded is not None:
                try:
                    self.print_decoded_activations_horizontal(layer.block_output_unembedded, 'Block output', topk=topk, print_prob=print_prob)
                except Exception as e:
                    print(f"Error processing layer {i}: {str(e)}")
                    print(f"Shape of block_output_unembedded: {layer.block_output_unembedded.shape}")
            else:
                print("No intermediate output available for this layer.")

        # Print final logits
        print("\nFinal output logits:")
        try:
            self.print_decoded_activations_horizontal(logits, 'Final output', topk=topk, print_prob=print_prob)
        except Exception as e:
            print(f"Error processing final logits: {str(e)}")
            print(f"Shape of logits: {logits.shape}")

    def print_decoded_activations_horizontal(self, decoded_activations, label, topk=10, print_prob=False):
        """Print, for every position, the top-k tokens (optionally with probabilities).

        Args:
            decoded_activations: logits indexed as ``[0][position]`` -> vocabulary scores;
                only batch element 0 is printed.
            label: accepted for call-site readability; currently not printed.
            topk: number of candidates per position (capped at the vocabulary size).
            print_prob: when true, also print the integer percentage of each candidate.
        """
        seq_length = decoded_activations.size(1)
        all_tokens = []
        all_probs = []

        for i in range(seq_length):
            softmaxed = torch.nn.functional.softmax(decoded_activations[0][i], dim=-1)
            values, indices = torch.topk(softmaxed, min(topk, len(softmaxed)))
            probs_percent = [int(v * 100) for v in values.tolist()]
            # Candidates are independent alternatives, not a sequence: decode them one
            # by one so an EOS candidate does not cut off the rest and digit candidates
            # are not merged, keeping tokens aligned with their probabilities.
            tokens = [self.decode_tensors(index.reshape(1)) for index in indices]
            all_tokens.append(tokens)
            all_probs.append(probs_percent)

        # Find the maximum number of predictions available for any token
        max_predictions = max((len(tokens) for tokens in all_tokens), default=0)
        rows = min(topk, max_predictions)

        # Print tokens horizontally
        for k in range(rows):
            token_row = []
            for i, tokens in enumerate(all_tokens):
                if k < len(tokens):
                    token_row.append(f"{i}:{tokens[k]:<8}")
                else:
                    token_row.append(f"{i}:{'---':<8}")
            print(f"Top {k+1}: " + " ".join(token_row))

        # Print probabilities if requested
        if print_prob:
            print("\nProbabilities:")
            for k in range(rows):
                prob_row = []
                for i, probs in enumerate(all_probs):
                    if k < len(probs):
                        prob_row.append(f"{i}:{probs[k]:3d}%")
                    else:
                        prob_row.append(f"{i}:{'---':>3}")
                print(f"Top {k+1}: " + " ".join(prob_row))

        print()  # Add a blank line for readability

    def print_top_logit_progression(self, inputs, n_highest=1, chunk_size=10):
        """Print, layer by layer, the token holding the n-th highest logit at each position.

        Args:
            inputs: one unbatched sample; ``inputs['input_ids']`` is fed to ``get_logits``.
            n_highest: rank to display (1 = highest logit).
            chunk_size: number of positions printed per table.

        Raises:
            ValueError: if ``n_highest`` is smaller than 1.
        """
        if n_highest < 1:
            raise ValueError("n_highest must be >= 1")

        self.reset_all_layers()
        logits = self.get_logits(inputs)

        layers = self.model.base_model.model.layers
        num_layers = len(layers)
        # The sample is unbatched (get_logits adds the batch axis), so axis 1 of
        # input_ids is the feature axis. Take the sequence length from the logits.
        seq_length = logits.size(1)

        # One row per layer plus one for the final logits; cells that never get
        # filled (e.g. a layer without cached output) show a placeholder.
        nth_tokens = np.full((num_layers + 1, seq_length), "---", dtype=object)

        def nth_token(scores):
            # topk returns values in descending order, so the last one is the n-th highest.
            _, indices = torch.topk(scores, n_highest)
            return self.decode_tensors(indices[-1].unsqueeze(-1)).strip()

        # Process each layer
        for i, layer in enumerate(layers):
            layer_out = layer.block_output_unembedded
            if layer_out is not None:
                for j in range(min(layer_out.size(1), seq_length)):
                    nth_tokens[i][j] = nth_token(layer_out[0][j])

        # Process final output
        for j in range(seq_length):
            nth_tokens[-1][j] = nth_token(logits[0][j])

        # Print the progression in chunks
        print(f"{n_highest}th highest logit progression:")

        for chunk_start in range(0, seq_length, chunk_size):
            chunk_end = min(chunk_start + chunk_size, seq_length)
            print(f"\nTokens {chunk_start} to {chunk_end - 1}:")
            print(f"Layer | " + " | ".join([f"Token {i}" for i in range(chunk_start, chunk_end)]))
            print("-" * (7 + 8 * (chunk_end - chunk_start)))

            for i in range(num_layers + 1):
                layer_name = f"h{i}_out" if i < num_layers else "h_out"
                tokens = [f"{token:<6}" for token in nth_tokens[i][chunk_start:chunk_end]]
                print(f"{layer_name:<5} | " + " | ".join(tokens))

    @staticmethod
    def get_color_code(value):
        """Return a 24-bit ANSI foreground colour code fading from red (0) to blue (1).

        Values outside ``[0, 1]`` are clamped so the RGB components stay in 0..255.
        """
        value = min(max(value, 0.0), 1.0)
        r = int(255 * (1 - value))
        b = int(255 * value)
        return f'\033[38;2;{r};0;{b}m'


In [80]:
# Note: `model` is the LlamaHelper inspection wrapper, not the network itself;
# the underlying network is stored on it as `model.model`.
model = LlamaHelper(CONFIG_FILE, MODEL_PATH, train_set)

In [72]:
# Highest-logit token per layer for sample 50, printed ten positions per table.
model.print_top_logit_progression(train_set[50], n_highest=1, chunk_size=10)

1th highest logit progression:

Tokens 0 to 9:
Layer | Token 0 | Token 1 | Token 2 | Token 3 | Token 4 | Token 5 | Token 6 | Token 7 | Token 8 | Token 9
---------------------------------------------------------------------------------------
h0_out | 0-1    | 0-8    | 6      | 0-0    | .      | .      | 0-0    | 0-0    | 0-0    | 0-0   
h1_out | .      | .      | .      | .      | .      | .      | .      | .      | .      | .     
h2_out | .      | .      | .      | .      | .      | .      | .      | .      | .      | .     
h3_out | .      | .      | .      | .      | .      | .      | .      | .      | .      | .     
h_out | .      | .      | .      | .      | .      | .      | .      | .      | .      | .     

Tokens 10 to 19:
Layer | Token 10 | Token 11 | Token 12 | Token 13 | Token 14 | Token 15 | Token 16 | Token 17 | Token 18 | Token 19
---------------------------------------------------------------------------------------
h0_out | .      | 4      | 0-2    | 4      | 0-2    |

In [82]:
# Second-highest-logit token per layer for sample 50, printed ten positions per table.
model.print_top_logit_progression(train_set[50], n_highest=2, chunk_size=10)

2th highest logit progression:

Tokens 0 to 9:
Layer | Token 0 | Token 1 | Token 2 | Token 3 | Token 4 | Token 5 | Token 6 | Token 7 | Token 8 | Token 9
---------------------------------------------------------------------------------------
h0_out | 0-0    | 0-1    | 0-0    | 0-4    | 0-0    | 0-0    | 0      | 6      | .      | .     
h1_out | 0-1    | 0-0    | 0-0    | 0-0    | 0-0    | 0-0    | 0-0    | 0-0    | 0-0    | 0-0   
h2_out | 0-0    | 0-0    | 0-0    | 0-0    | 0-0    | 0-0    | 0-0    | 0-0    | 0-0    | 0-0   
h3_out | 0-0    | 0-0    | 0-0    | 0-0    | 0-0    | 0-0    | 0-0    | 0-0    | 0-0    | 0-0   
h_out | 0-0    | 0-0    | 0-0    | 0-0    | 0-0    | 0-0    | 0-0    | 0-0    | 0-0    | 0-0   

Tokens 10 to 19:
Layer | Token 10 | Token 11 | Token 12 | Token 13 | Token 14 | Token 15 | Token 16 | Token 17 | Token 18 | Token 19
---------------------------------------------------------------------------------------
h0_out | 0-1    | 0      | 0-8    | 0      | 0-8    |