In [15]:
import torch
import torch.nn as nn
import umap.umap_ as umap

import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.preprocessing import StandardScaler
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
from IPython.display import display, HTML
import seaborn as sns
from sparse_auto import SAE

# Load Model and Tokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "HuggingFaceTB/SmolLM2-135M-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load SAE
activation_data = torch.load("../activation_data.pt")
mlp_dim = activation_data.shape[-1]
hidden_dim_multiplier = 8
hidden_dim = hidden_dim_multiplier * mlp_dim
layer_index = 15

sae = SAE(mlp_dim, hidden_dim).to(device)
sae.load_state_dict(torch.load("../sae_model3.pth", map_location=torch.device("cpu")))
sae.eval()

# Functions

def tokenize_and_get_activations(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
    outputs = model(**inputs, output_hidden_states=True)
    layer_activations = outputs.hidden_states[layer_index]  # Change as needed
    return inputs, layer_activations

def color_tokens_by_feature(text, feature_index):
    inputs, activations = tokenize_and_get_activations(text)
    feature_activations, _ = sae(activations[0])  # SAE encodings

    # Normalize activations for visualization
    norm = Normalize(vmin=feature_activations[:, feature_index].min().item(), 
                     vmax=feature_activations[:, feature_index].max().item())
    colors = [plt.cm.viridis(norm(a.item())) for a in feature_activations[:, feature_index]]

    # HTML output
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    colored_text = " ".join([
        f'<span style="background-color: rgba({int(c[0]*255)}, {int(c[1]*255)}, {int(c[2]*255)}, {c[3]:.2f});">{t[1:]}</span>'
        for t, c in zip(tokens, colors)])

    display(HTML(f'<p style="font-family:monospace;">{colored_text}</p>'))

def plot_interactive_umap(texts):
    all_activations = []

    for text in texts:
        _, activations = tokenize_and_get_activations(text)
        feature_activations, _ = sae(activations[0])
        all_activations.append(feature_activations.detach().cpu().numpy())

    all_activations = np.vstack(all_activations)
    reducer = umap.UMAP(random_state=42)
    embeddings = reducer.fit_transform(all_activations)

    # Scatterplot
    plt.figure(figsize=(10, 7))
    scatter = plt.scatter(embeddings[:, 0], embeddings[:, 1], c=np.arange(embeddings.shape[0]), cmap="viridis", s=10)
    plt.colorbar(scatter, label="Sample Index")
    plt.title("UMAP of Feature Activations")
    plt.show()

def plot_feature_similarity_matrix(texts):
    all_features = []

    for text in texts:
        _, activations = tokenize_and_get_activations(text)
        feature_activations, _ = sae(activations[0])
        all_features.append(feature_activations.mean(dim=0).detach().cpu().numpy())

    similarity_matrix = np.corrcoef(all_features)

    sns.heatmap(similarity_matrix, annot=True, cmap="coolwarm", xticklabels=texts, yticklabels=texts)
    plt.title("Feature Similarity Matrix")
    plt.show()

# Example Usage
sample_texts = [
    "The cat sat on the mat.",
    "A dog barked loudly.",
    "The sky is blue and clear.",
    "I love programming in Python."
]

print("Token Colorization:")
color_tokens_by_feature("The cat sat on the mat.", feature_index=158)


bbcss = ["I love wathching the BBC",
    "I love watching TV",
    "I love watching Netflix",
    "I love watching RTVE",
    "I like english channels, like"]

for text in bbcss:
    color_tokens_by_feature(text, feature_index=158)



Token Colorization:


  activation_data = torch.load("../activation_data.pt")
  sae.load_state_dict(torch.load("../sae_model3.pth", map_location=torch.device("cpu")))


In [None]:

print("Interactive UMAP:")
plot_interactive_umap(sample_texts)

print("Feature Similarity Matrix:")
plot_feature_similarity_matrix(sample_texts)