# Setup

In [1]:
import sys

# Make the parent directory importable (presumably where the `tools` package
# lives -- confirm against the repo layout). The membership check keeps the
# cell idempotent: re-running it no longer piles duplicate ".." entries onto
# sys.path.
if ".." not in sys.path:
    sys.path.append("..")
import torch as th

# Turn autograd off globally for the rest of the notebook. The result is
# assigned to `_` so the returned object is not echoed as cell output.
_ = th.set_grad_enabled(False)

In [None]:
from dotenv import load_dotenv

# Read variables from a `.env` file into the process environment
# (python-dotenv's default lookup). Which variables later cells rely on is not
# visible here -- presumably credentials/paths used by the `tools` helpers.
load_dotenv()

# Offline Dashboard

**Note that if you plan to only run the online dashboard, you might want to switch to a non-gpu runtime.**

In [None]:
from tools.cc_utils import get_available_models
from IPython.display import Markdown

# Fetch the CrossCoder names and render them as a markdown bullet list
# under a single heading.
models = get_available_models()
models_md = "\n".join(f"- {model}" for model in models)
heading = "## Available CrossCoders:\n"
display(Markdown(heading + models_md))

In [7]:
# Names of the two CrossCoder checkpoints this notebook knows about.
l1_crosscoder = "gemma-2-2b-crosscoder-l13-mu4.1e-02-lr1e-04"
# The topk variant is the one from the paper.
btopk_crosscoder = "gemma-2-2b-L13-k100-lr1e-04-local-shuffling-CCLoss"

# Checkpoint used by every cell below; point this at `l1_crosscoder` to switch.
crosscoder = btopk_crosscoder

⚠️ Norm diff ranges from 0 (chat only) to 1 (base only), which is the opposite of the convention used in the paper ⚠️

In [None]:
from tools.utils import load_latent_df

# Load the per-latent statistics table for the selected crosscoder
# (`crosscoder` is still the model *name* string at this point).
df = load_latent_df(crosscoder)
# Show the 10 rows with the smallest `dec_norm_diff`.
# NOTE(review): presumably these are the most chat-specific latents, given the
# 0 = chat only convention stated in the warning above -- confirm.
df.sort_values(by="dec_norm_diff").head(10)

### Cool features to explore
Explore some features from the df above, or check the features we already manually analyzed [here](https://flax-group-6cc.notion.site/Using-CrossCoder-for-model-diffing-e9e3e6d48cc542a8b594ab737936d433?pvs=25#e917941297da40278d8c38d99fca9325)

A quick recap is available below:

- **70149**: Refusal related latent: Requests for harmful instructions.
- **7736**: Refusal related latent: Generally sensitive content.
- **24613**: Refusal related latent: Unethical content relating to race, gender and stereotypes.
- **20384**: Refusal related latent: Requests for harmful instructions.
- **38009**: Refusal related latent: The model has refused to answer a user input.
- **2138**: Personal questions: Questions regarding the personal experiences, emotions and preferences, with a strong activation on questions about Gemma itself.
- **14350**: False information detection: Detects when the user is providing false information.
- **62019**: False information detection: Activates on user inputs containing incorrect information, similar to Latent 14350, but activates more strongly on template tokens.
- **58070**: Missing information detection: Activates on user inputs containing missing information.
- **54087**: Rewriting requests: Activates when the model should rewrite or paraphrase something.
- **50586**: Joke detection: Activates after jokes or humorous content.
- **69447**: Response length measurement: Measures the requested response length, with highest activation on a request for a paragraph.
- **10925**: Summarization requests: Activates when the user requests a summary.
- **6583**: Knowledge boundaries: Activates when the model is missing access to information.
- **4622**: Information detail detection: Activates on requests for detailed information.

### Enjoy your dashboard!

In [None]:
from tools.utils import offline_dashboard

# Build the offline dashboard for the selected crosscoder name and keep a
# handle to it. NOTE(review): whether the call itself renders the widget or an
# explicit display step is needed is not visible from this cell.
off_dashboard = offline_dashboard(crosscoder)

# Online Feature Dashboard

In [None]:
import sys

# Repeated from the Setup cell so this section can be run without it; the
# membership check avoids adding a duplicate ".." entry to sys.path when
# Setup has already run or this cell is re-executed.
if ".." not in sys.path:
    sys.path.append("..")
from tools.utils import online_dashboard

# `crosscoder` must still be the model *name* string chosen above.
online_dashboard(crosscoder).display()

# Inference demo

In [None]:
from nnterp import load_model
from nnterp.nnsight_utils import get_layer_output, get_layer
from tools.utils import load_dictionary_model

base_model = load_model("google/gemma-2-2b", torch_dtype=th.bfloat16)
chat_model = load_model("google/gemma-2-2b-it", torch_dtype=th.bfloat16)
layer = 13
cc_device = "cuda:0" if th.cuda.is_available() else "cpu"
# Load the weights into a *separate* name. `crosscoder` stays bound to the
# model-name string, so this cell and the dashboard cells (which all pass
# `crosscoder` as a name) remain re-runnable in any order.
crosscoder_model = load_dictionary_model(crosscoder).to(cc_device)
sample_conv = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I am fine, thank you!"},
]
# The chat model's tokenizer/template is used for both models so that the two
# sets of activations are computed on exactly the same tokens.
toks = chat_model.tokenizer.apply_chat_template(sample_conv, return_tensors="pt")
with base_model.trace(toks):
    base_acts = get_layer_output(base_model, layer).to(cc_device).save()
    # Presumably halts the forward pass once this layer's output is available,
    # so later layers are not computed -- confirm against nnsight docs.
    get_layer(base_model, layer).output.stop()
with chat_model.trace(toks):
    chat_acts = get_layer_output(chat_model, layer).to(cc_device).save()
    get_layer(chat_model, layer).output.stop()

# Flatten every leading dimension, then pair base/chat activations along a new
# dim 1. The activations were already moved to `cc_device` inside the traces,
# so no further device transfer is needed here.
cc_input = th.stack(
    [
        base_acts.reshape(-1, base_acts.shape[-1]),
        chat_acts.reshape(-1, chat_acts.shape[-1]),
    ],
    dim=1,
).float()
print(cc_input.shape)  # (b * seq_len, 2, d)

cc_output = crosscoder_model(cc_input)