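## Evaluating ReAX

For each concept, this notebook compares the per-token activations of the learned ReAX latent with those of its linked SAE latent on positive, negative, and hard-negative validation examples. The results are saved to `val_latent.csv`, along with an HTML visualization (`val_latent.html`).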

#### Set-up.

In [1]:
try:
    # Use pyreax directly if it is already importable (e.g. installed in the
    # environment). A failed import signals that the local fallback below is needed.
    import pyreax

except ModuleNotFoundError:
    # Fallback: add ../../pyreax (relative to the kernel's working directory) to
    # sys.path and retry. Installing the package with pip is the preferred setup.
    # NOTE(review): this note originally said "pip install subctrl"; confirm that
    # is still the distribution that provides `pyreax`.
    import sys
    sys.path.append("../../pyreax")
    import pyreax



In [2]:
import json
import pandas as pd
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
import torch, pyreft
from pathlib import Path
from pyvene import (
    IntervenableModel,
    ConstantSourceIntervention,
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention,
)

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import get_scheduler

from circuitsvis.tokens import colored_tokens
from IPython.core.display import display, HTML
from pyreax import (
    EXAMPLE_TAG, 
    ReAXFactory, 
    MaxReLUIntervention, 
    SubspaceAdditionIntervention, 
    JumpReLUSAECollectIntervention,
    make_data_module, 
    save_reax,
    load_reax,
    load_sae,
    generate_html_with_highlight_text
)
from pyreax import (
    set_decoder_norm_to_unit_norm, 
    remove_gradient_parallel_to_decoder_directions,
    gather_residual_activations,
    get_lr
)

  from IPython.core.display import display, HTML


In [3]:
# Evaluation parameters.
dump_dir = "./tmp/gemma-2-2b/20-reax-res-gpt-4o/"
val_n = 10        # `n` passed to create_eval_df for each example category
n_decimal = 3     # decimals kept when rounding activations
reax_topk = 10    # `k` passed to the ReAX intervention's encode()

# Load the saved ReAX run (config, training data, per-concept metadata, weights).
config, training_df, concept_metadata, weights = load_reax(dump_dir)

# Load the language model on CPU, then move it to the GPU.
model_name = config.model_name
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu")
model.config.use_cache = False
model = model.cuda()

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"

sae_weights = load_sae(concept_metadata)

LAYER = config.layer

# ReAX intervention at LAYER; its projection is overwritten with the saved weights.
reax_intervention = MaxReLUIntervention(
    embed_dim=model.config.hidden_size, low_rank_dimension=weights.shape[0],
)
reax_intervention.proj.weight.data = weights.data
_ = reax_intervention.cuda()
pv_reax_model = IntervenableModel({
    "component": f"model.layers[{LAYER}].output",
    "intervention": reax_intervention}, model=model)

# SAE intervention at the same layer; its activations are read back later via
# `collected_activations`.
sae_intervention = JumpReLUSAECollectIntervention(
    embed_dim=sae_weights['W_enc'].shape[0],
    low_rank_dimension=sae_weights['W_enc'].shape[1]
)
# strict=False skips any state_dict key absent from the checkpoint, so a name
# mismatch would silently leave that parameter at its initial value. Report any
# *parameter* (buffers are deliberately ignored here) that the checkpoint lacks.
missing_sae_params = sorted(
    {name for name, _param in sae_intervention.named_parameters()} - set(sae_weights)
)
if missing_sae_params:
    print("WARNING: SAE checkpoint is missing parameters:", missing_sae_params)
sae_intervention.load_state_dict(sae_weights, strict=False)
_ = sae_intervention.cuda()
pv_sae_model = IntervenableModel({
    "component": f"model.layers[{LAYER}].output",
    "intervention": sae_intervention}, model=model)

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

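#### Latent activation evaluation

For each concept, generate validation examples (positive, negative, and hard negative). Then record the per-token SAE and ReAX latent activations on them.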


In [4]:
# Per concept: generate positive, negative and hard-negative evaluation examples
# with a concept-specific ReAXFactory. Also map each ReAX id to its SAE link.
validation_df_map = {}
id_sae_link_map = {}
for meta in concept_metadata:
    meta_dict = json.loads(meta)
    concept = meta_dict["concept"]
    contrast_concepts = {concept: meta_dict["contrast_concepts"]}
    print("Testing with concept:", concept)

    reax_id = int(meta_dict["_id"])
    sae_link = meta_dict["sae_concept"]
    sae_id = int(sae_link.split("/")[-1])
    id_sae_link_map[reax_id] = sae_link

    reax_factory = ReAXFactory(
        model,
        tokenizer,
        concepts=[concept],
        contrast_concepts=contrast_concepts,
        dump_dir=dump_dir,
    )
    eval_frames = [
        reax_factory.create_eval_df(n=val_n, category=category)
        for category in ("positive", "negative", "hard negative")
    ]
    validation_df = pd.concat(eval_frames, axis=0)
    validation_df_map[concept] = validation_df



Testing with concept: terms related to artificiality and deception




Testing with concept: terms related to employment and employees




In [5]:
# For each concept, score every validation example with its SAE latent and its
# ReAX latent (BOS position dropped). Record per-token and max activations and
# export the combined table to val_latent.csv.
# NOTE: the per-concept DataFrames stored in validation_df_map are extended in
# place with the new columns.


def _round_acts(values):
    """Round each activation value to `n_decimal` places."""
    return [round(x, n_decimal) for x in values]


def _sae_latent_acts(input_ids, sae_id):
    """Rounded per-token activations of SAE latent `sae_id`, excluding the BOS position."""
    collected = pv_sae_model.forward(
        {"input_ids": input_ids}, return_dict=True
    ).collected_activations[0]
    return _round_acts(collected[1:, sae_id].data.cpu().numpy().tolist())


def _reax_latent_acts(input_ids, subspace_ids):
    """Rounded per-token activations of the ReAX latent(s) in `subspace_ids`, excluding BOS."""
    residual = gather_residual_activations(model, LAYER, input_ids)
    acts, _ = reax_intervention.encode(
        residual[:, 1:],
        subspaces={"input_subspaces": subspace_ids}, k=reax_topk)
    return _round_acts(acts.flatten().data.cpu().numpy().tolist())


all_validation_dfs = []
# Inference only: no_grad stops autograd from building a graph on every forward pass.
with torch.no_grad():
    for meta in concept_metadata:
        meta_dict = json.loads(meta)
        concept = meta_dict["concept"]
        print("Testing with concept:", concept)

        reax_id = int(meta_dict["_id"])
        sae_id = int(meta_dict["sae_concept"].split("/")[-1])
        validation_df = validation_df_map[concept]
        # Same for every row of this concept, so build it once.
        reax_subspace_ids = torch.tensor([reax_id])

        all_sae_acts = []
        all_reax_acts = []
        for text in tqdm(validation_df["input"], desc=concept, leave=False):
            inputs = tokenizer.encode(
                text, return_tensors="pt", add_special_tokens=True).to("cuda")
            all_sae_acts.append(_sae_latent_acts(inputs, sae_id))
            all_reax_acts.append(_reax_latent_acts(inputs, reax_subspace_ids))

        validation_df['sae_acts'] = all_sae_acts
        validation_df['reax_acts'] = all_reax_acts
        validation_df['max_sae_act'] = [max(acts) for acts in all_sae_acts]
        validation_df['max_reax_act'] = [max(acts) for acts in all_reax_acts]
        validation_df['reax_id'] = reax_id
        validation_df['sae_id'] = sae_id
        validation_df['sae_link'] = meta_dict["sae_concept"]
        all_validation_dfs.append(validation_df)

all_validation_df = pd.concat(all_validation_dfs, axis=0)
all_validation_df.to_csv(Path(dump_dir) / "val_latent.csv")

Testing with concept: terms related to artificiality and deception
Testing with concept: terms related to employment and employees


In [7]:
# Build the HTML visualization of the saved val_latent.csv and write it next to
# the CSV. The table is re-read from disk, so list-valued columns arrive in their
# CSV string form. NOTE(review): presumably generate_html_with_highlight_text
# expects that form; confirm.
html_content_interactive = generate_html_with_highlight_text(
    id_sae_link_map,
    pd.read_csv(Path(dump_dir) / "val_latent.csv"),
    tokenizer
)
output_file_interactive = Path(dump_dir) / "val_latent.html"
# Explicit UTF-8: the text/tokens may contain non-ASCII characters, and the
# platform default encoding (e.g. cp1252) cannot always represent them.
with open(output_file_interactive, 'w', encoding="utf-8") as file:
    file.write(html_content_interactive)