This notebook downloads a small sample of a text dataset to the local Hugging Face cache (`~/.cache/huggingface/datasets` on macOS). Before running it, manually download the feature files (dicts of activations and locations for each feature) from [https://huggingface.co/datasets/MrGonao/raw_features_gemma16/tree/main](https://huggingface.co/datasets/MrGonao/raw_features_gemma16/tree/main) into a local `features/` folder. The code below then finds the top-activating documents for a chosen feature and plots a parse tree for each activating context.

### Get Contexts and View Batch

In [1]:
import torch
# get context
def get_context(batch: list[str], pos: int, n=5) -> str:
    """Join the decoded tokens around ``pos`` into a single context string.

    The window is ``batch[pos - n : pos + n]``: up to ``n`` tokens before
    ``pos``, the token at ``pos`` itself, and up to ``n - 1`` tokens after it.
    Tokens that are exactly a newline are dropped before concatenation.

    Args:
        batch: Per-token strings for one document (e.g. from ``view_batch``).
        pos: Index of the focus token in ``batch``. Int-like values (such as
            0-d integer tensors) are accepted and converted with ``int()``.
        n: Half-width of the window.

    Returns:
        The concatenated window text.
    """
    pos = int(pos)
    # Clamp at 0: a negative start would make Python slice from the END of the
    # list, yielding an empty/wrong context for tokens near the document start.
    start = max(pos - n, 0)
    return ''.join(tok for tok in batch[start:pos + n] if tok != '\n')

def view_batch(n: int, tokens: torch.Tensor, tokenizer) -> list[str]:
    assert(tokens.ndim == 2 and tokens.shape[1] == 256)
    if n >= len(tokens):
        return None
    doc = tokens[n]
    return tokenizer.batch_decode(doc)

In [3]:
import networkx as nx
import matplotlib.pyplot as plt

def visualize_parse_tree(parse_dict):
    """Draw a parse tree as a directed graph using networkx and matplotlib.

    Args:
        parse_dict: Iterable of token mappings, each with a ``'text'`` key and a
            ``'children'`` iterable. Children are used directly as node keys
            (presumably child token texts — TODO confirm against the producer).

    Nodes are keyed by token text, so repeated words collapse into one node.
    """
    graph = nx.DiGraph()
    for tok in parse_dict:
        head = tok['text']
        graph.add_node(head)
        graph.add_edges_from((head, child) for child in tok['children'])

    plt.figure(figsize=(12, 8))
    # Force-directed placement; shell_layout / circular_layout are alternatives.
    layout = nx.spring_layout(graph)
    nx.draw(
        graph,
        layout,
        with_labels=True,
        node_color="skyblue",
        edge_color="gray",
        node_size=2000,
        font_size=12,
        font_weight="bold",
    )
    plt.show()

Collecting matplotlib
  Using cached matplotlib-3.9.2-cp312-cp312-macosx_11_0_arm64.whl.metadata (11 kB)
Collecting contourpy>=1.0.1 (from matplotlib)
  Using cached contourpy-1.3.0-cp312-cp312-macosx_11_0_arm64.whl.metadata (5.4 kB)
Collecting cycler>=0.10 (from matplotlib)
  Using cached cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib)
  Using cached fonttools-4.54.1-cp312-cp312-macosx_11_0_arm64.whl.metadata (163 kB)
Collecting kiwisolver>=1.3.1 (from matplotlib)
  Using cached kiwisolver-1.4.7-cp312-cp312-macosx_11_0_arm64.whl.metadata (6.3 kB)
Collecting pyparsing>=2.3.1 (from matplotlib)
  Using cached pyparsing-3.2.0-py3-none-any.whl.metadata (5.0 kB)
Using cached matplotlib-3.9.2-cp312-cp312-macosx_11_0_arm64.whl (7.8 MB)
Using cached contourpy-1.3.0-cp312-cp312-macosx_11_0_arm64.whl (251 kB)
Using cached cycler-0.12.1-py3-none-any.whl (8.3 kB)
Using cached fonttools-4.54.1-cp312-cp312-macosx_11_0_arm64.whl (2.3 MB)
Using cached ki

In [4]:
!pip list

Package            Version
------------------ -----------
accelerate         1.0.1
aiohappyeyeballs   2.4.3
aiohttp            3.10.10
aiosignal          1.3.1
annotated-types    0.7.0
appnope            0.1.4
asttokens          2.4.1
attrs              24.2.0
beartype           0.14.1
better-abc         0.0.3
bidict             0.23.1
blinker            1.8.2
blis               1.0.1
catalogue          2.0.10
certifi            2024.8.30
charset-normalizer 3.4.0
click              8.1.7
cloudpathlib       0.20.0
comm               0.2.2
confection         0.1.5
contourpy          1.3.0
cycler             0.12.1
cymem              2.0.8
datasets           3.0.2
debugpy            1.8.7
decorator          5.1.1
diffusers          0.31.0
dill               0.3.8
docker-pycreds     0.4.0
einops             0.8.0
en_core_web_sm     3.8.0
executing          2.1.0
fancy-einsum       0.0.3
filelock           3.16.1
Flask              3.0.3
fonttools          4.54.1
frozenlist         1.5.0
fs

### Build and Visualize Parse Trees with spaCy

In [2]:
from graphs import load_tokens, load_activations
# Model / dataset used to produce the tokens, and the feature file downloaded
# from the Hugging Face dataset linked at the top of this notebook.
MODEL_NAME = "google/gemma-2-9B"
DATASET_NAME = "kh4dien/fineweb-100m-sample"
FEATURE_FILE = 'features/11_0_3275.safetensors'

tokens, tokenizer = load_tokens(MODEL_NAME, DATASET_NAME)
activations, locations = load_activations(FEATURE_FILE)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import spacy
# Small English spaCy pipeline (must be installed, e.g. as the en_core_web_sm package).
SPACY_MODEL = "en_core_web_sm"
nlp = spacy.load(SPACY_MODEL)

In [24]:
# Re-execute graphs.py so edits to it are picked up without restarting the kernel.
import importlib
import graphs
graphs = importlib.reload(graphs)
# Bind make_graph only *after* the reload: a `from graphs import ...` executed
# before reload() keeps pointing at the stale function object.
from graphs import make_graph

<module 'graphs' from '/Users/eriq/Desktop/syntax-sae/graphs.py'>

In [25]:

from graphs import make_graph
import networkx as nx
import matplotlib.pyplot as plt

def visualize_feature(n, activations, locations, k=5):
    """Plot parse trees for the top-``k`` activating contexts of feature ``n``.

    Args:
        n: Feature index to visualize; matched against ``locations[:, 2]``.
        activations: Activation values, aligned row-for-row with ``locations``.
        locations: Rows of (batch, position, feature) indices.
        k: Maximum number of contexts to draw, strongest activation first.

    Reads the notebook-global ``tokens`` and ``tokenizer`` (loaded in an
    earlier cell) to decode each activating document. Contexts whose batch
    index is outside the loaded tokens are skipped and do not count toward ``k``.
    """
    print(f"Visualizing Feature {n}")
    mask = locations[:, 2] == n
    feature_locations = locations[mask]
    feature_activations = activations[mask]

    # Visit hits strongest-first; sorted() is stable, so ties keep file order.
    ranked = sorted(
        zip(feature_activations, feature_locations),
        key=lambda pair: pair[0],
        reverse=True,
    )

    count = 0
    for activation, location in ranked:
        # Check before drawing so at most k trees are shown (the old
        # post-increment `count > k` test drew k + 1).
        if count >= k:
            break
        batch = view_batch(int(location[0]), tokens, tokenizer)
        if batch is None:
            continue
        count += 1
        pos = int(location[1])
        # TODO: pick a better parse-tree context, e.g. bounded by punctuation.
        context = get_context(batch, pos)
        # NOTE(review): make_parse_tree is not defined or imported in any visible
        # cell (only make_graph is imported from graphs) — confirm its source.
        # NOTE(review): pos is the token's index in the whole document, not
        # within `context`; confirm that is what make_parse_tree expects.
        root_node = make_parse_tree(context, pos)
        graph, layout = make_graph(root_node)
        labels = nx.get_node_attributes(graph, 'label')
        plt.figure(figsize=(6, 4))
        nx.draw(
            graph,
            layout,
            with_labels=True,
            labels=labels,
            node_size=2000,
            node_color='lightblue',
            font_size=10,
            font_weight='bold',
            arrows=True,
        )
        plt.show()
        # float() so the value prints as a number rather than a tensor repr.
        print(f'{batch[pos]:<15} @ {context:<50} {float(activation)}')

In [26]:
FEATURE_ID = 3000  # feature index to inspect (matched against locations[:, 2])
visualize_feature(FEATURE_ID, activations, locations)

Visualizing Feature 3000


ImportError: requires pygraphviz http://pygraphviz.github.io/