## Evaluating ReAX.

#### Set-up.

In [1]:
try:
    # This library is our indicator that the required installs
    # need to be done.
    import pyreax

except ModuleNotFoundError:
    # relative import; better to pip install subctrl
    import sys
    sys.path.append("../../pyreax")
    import pyreax



In [2]:
import json
import pandas as pd
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
import torch, pyreft
from pathlib import Path
from pyvene import (
    IntervenableModel,
    ConstantSourceIntervention,
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention,
)

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import get_scheduler

from circuitsvis.tokens import colored_tokens
from IPython.core.display import display, HTML
from pyreax import (
    EXAMPLE_TAG, 
    ReAXFactory, 
    MaxReLUIntervention, 
    SubspaceAdditionIntervention, 
    JumpReLUSAECollectIntervention,
    make_data_module, 
    save_reax,
    load_reax,
    load_sae,
    generate_html_with_highlight_text
)
from pyreax import (
    set_decoder_norm_to_unit_norm, 
    remove_gradient_parallel_to_decoder_directions,
    gather_residual_activations,
    get_lr
)

  from IPython.core.display import display, HTML


In [57]:
# params
dump_dir = "./tmp/gemma-2-2b/20-reax-res-gpt-4o/"
val_n = 10
n_decimal = 3
reax_topk = 10
input_length = 32

# Load saved meta.
config, training_df, concept_metadata, weight, bias = load_reax(dump_dir)

# Load lm.
model_name = config.model_name
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu")
model.config.use_cache = False
model = model.cuda()

tokenizer =  AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"

sae_weights = load_sae(concept_metadata)

LAYER = config.layer

reax_intervention = MaxReLUIntervention(
    embed_dim=model.config.hidden_size, low_rank_dimension=weight.shape[0],
)
reax_intervention.proj.weight.data = weight.data
reax_intervention.proj.bias.data = bias.data
_ = reax_intervention.cuda()
pv_reax_model = IntervenableModel({
   "component": f"model.layers[{LAYER}].output",
   "intervention": reax_intervention}, model=model)

sae_intervention = JumpReLUSAECollectIntervention(
    embed_dim=sae_weights['W_enc'].shape[0],
    low_rank_dimension=sae_weights['W_enc'].shape[1]
)
sae_intervention.load_state_dict(sae_weights, strict=False)
_ = sae_intervention.cuda()
pv_sae_model = IntervenableModel({
   "component": f"model.layers[{LAYER}].output",
   "intervention": sae_intervention}, model=model)

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

#### Latent activation eval.


In [51]:
validation_df_map = {}
id_sae_link_map = {}
for meta in concept_metadata:
    meta_dict = json.loads(meta)
    concept = meta_dict["concept"]
    contrast_concepts = {}
    contrast_concepts[concept] = meta_dict["contrast_concepts"]
    concept_genres = {}
    concept_genres[concept] = meta_dict["concept_genres"]
    print("Testing with concept:", concept)
    
    reax_id = int(meta_dict["_id"])
    sae_id = int(meta_dict["sae_concept"].split("/")[-1])
    id_sae_link_map[reax_id] = meta_dict["sae_concept"]
    
    # test prompt
    reax_factory = ReAXFactory(
        model, tokenizer,
        concepts=[concept], 
        contrast_concepts=contrast_concepts,
        dump_dir=dump_dir
    )

    positive_df = reax_factory.create_eval_df(n=val_n, category="positive", input_length=input_length)
    negative_df = reax_factory.create_eval_df(n=val_n, category="negative", input_length=input_length)
    hard_negative_df = reax_factory.create_eval_df(n=val_n, category="hard negative", input_length=input_length)
    validation_df = pd.concat([positive_df, negative_df, hard_negative_df], axis=0)
    validation_df_map[concept] = validation_df



Testing with concept: terms related to artificiality and deception




Testing with concept: terms related to employment and employees




In [58]:
torch.cuda.empty_cache()
all_validation_dfs = []
with torch.no_grad():
    for meta in concept_metadata:
        meta_dict = json.loads(meta)
        concept = meta_dict["concept"]
        contrast_concepts = {}
        contrast_concepts[concept] = meta_dict["contrast_concepts"]
        print("Testing with concept:", concept)
        
        reax_id = int(meta_dict["_id"])
        sae_id = int(meta_dict["sae_concept"].split("/")[-1]) 
        validation_df = validation_df_map[concept]
        
        all_sae_acts = []
        all_reax_acts = []
        all_sae_max_act = []
        all_reax_max_act = []
        for _, row in validation_df.iterrows():
            inputs = tokenizer.encode(
                row["input"], return_tensors="pt", add_special_tokens=True).to("cuda")
            # sae acts
            sae_acts = pv_sae_model.forward(
                {"input_ids": inputs}, return_dict=True
            ).collected_activations[0][1:, sae_id].data.cpu().numpy().tolist() # no bos token
            sae_acts = [round(x, n_decimal) for x in sae_acts]
            max_sae_act = max(sae_acts)
            
            # reax acts
            reax_in = gather_residual_activations(model, LAYER, inputs)
            reax_acts, _ = reax_intervention.encode(
                reax_in[:,1:], # no bos token
                subspaces={
                    "input_subspaces": torch.tensor([reax_id])}, k=reax_topk)
            reax_acts = reax_acts.flatten().data.cpu().numpy().tolist()
            reax_acts = [round(x, n_decimal) for x in reax_acts]
            max_reax_act = max(reax_acts)
            
            all_sae_acts += [sae_acts]
            all_reax_acts += [reax_acts]
            all_sae_max_act += [max_sae_act]
            all_reax_max_act += [max_reax_act]
            
        validation_df['sae_acts'] = all_sae_acts
        validation_df['reax_acts'] = all_reax_acts
        validation_df['max_sae_act'] = all_sae_max_act
        validation_df['max_reax_act'] = all_reax_max_act
        validation_df['reax_id'] = reax_id
        validation_df['sae_id'] = sae_id
        validation_df['sae_link'] = meta_dict["sae_concept"]
        all_validation_dfs += [validation_df]
    
    all_validation_df = pd.concat(all_validation_dfs, axis=0)
    all_validation_df.to_csv(Path(dump_dir) / f"val_latent.csv")

Testing with concept: terms related to artificiality and deception
Testing with concept: terms related to employment and employees


In [59]:
html_content_interactive = generate_html_with_highlight_text(
    id_sae_link_map,
    pd.read_csv(Path(dump_dir) / f"val_latent.csv"), 
    tokenizer
)
output_file_interactive = Path(dump_dir) / f"val_latent.html"
with open(output_file_interactive, 'w') as file:
    file.write(html_content_interactive)

In [66]:
import os
import asyncio
from openai import AsyncOpenAI

client = AsyncOpenAI(
    # This is the default and can be omitted
    api_key=os.environ.get("OPENAI_API_KEY"),
)

memo = {}

async def main_one() -> None:
    chat_completion = await client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": "Say this is a test",
            }
        ],
        model="gpt-3.5-turbo",
    )
    print("1", chat_completion)

async def main_two() -> None:
    chat_completion = await client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": "Say this is a test",
            }
        ],
        model="gpt-3.5-turbo",
    )
    print("2", chat_completion)

async def main_three() -> None:
    chat_completion = await client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": "Say this is a test",
            }
        ],
        model="gpt-3.5-turbo",
    )
    print("3", chat_completion)

async def main_four() -> None:
    chat_completion = await client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": "Say this is a test",
            }
        ],
        model="gpt-3.5-turbo",
    )
    print("4", chat_completion)

In [69]:
await main()

ChatCompletion(id='chatcmpl-AGecN0kfDEJnydL556nQCseC1pqAj', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='This is a test.', role='assistant', function_call=None, tool_calls=None, refusal=None))], created=1728532747, model='gpt-3.5-turbo-0125', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=5, prompt_tokens=12, total_tokens=17, prompt_tokens_details={'cached_tokens': 0}, completion_tokens_details={'reasoning_tokens': 0}))


In [70]:
import asyncio
import time

async def sleep():
    print(f'Time: {time.time() - start:.2f}')
    await asyncio.sleep(1)

async def sum(name, numbers):
    total = 0
    for number in numbers:
        print(f'Task {name}: Computing {total}+{number}')
        await sleep()
        total += number
    print(f'Task {name}: Sum = {total}\n')

start = time.time()

loop = asyncio.get_event_loop()
tasks = [
    loop.create_task(sum("A", [1, 2])),
    loop.create_task(sum("B", [1, 2, 3])),
]
loop.run_until_complete(asyncio.wait(tasks))
loop.close()

end = time.time()
print(f'Time: {end-start:.2f} sec')

RuntimeError: asyncio.run() cannot be called from a running event loop

In [75]:
import os
import asyncio
from openai import AsyncOpenAI

client = AsyncOpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
)

async def main_one() -> dict:
    chat_completion = await client.chat.completions.create(
        messages=[
            {"role": "user", "content": "Say this is a test"}
        ],
        model="gpt-3.5-turbo",
    )
    print("1", chat_completion)
    return chat_completion

async def main_two() -> dict:
    chat_completion = await client.chat.completions.create(
        messages=[
            {"role": "user", "content": "Say this is another test"}
        ],
        model="gpt-3.5-turbo",
    )
    print("2", chat_completion)
    return chat_completion

async def main_three() -> dict:
    chat_completion = await client.chat.completions.create(
        messages=[
            {"role": "user", "content": "Say this is yet another test"}
        ],
        model="gpt-3.5-turbo",
    )
    print("3", chat_completion)
    return chat_completion

async def main_four() -> dict:
    chat_completion = await client.chat.completions.create(
        messages=[
            {"role": "user", "content": "Say this is the final test"}
        ],
        model="gpt-3.5-turbo",
    )
    print("4", chat_completion)
    return chat_completion

async def one():
    await main_one()

async def two():
    await main_two()

async def three():
    await main_three()

async def four():
    await main_four()
    
start = time.time()

loop = asyncio.get_event_loop()
tasks = [
    loop.create_task(one()),
    loop.create_task(two()),
    loop.create_task(three()),
    loop.create_task(four()),
]
loop.run_until_complete(asyncio.wait(tasks))
loop.close()

end = time.time()
print(f'Time: {end-start:.2f} sec')

RuntimeError: This event loop is already running

2 ChatCompletion(id='chatcmpl-AGiANfnPp5pWDKOmXYtLIHhK6S3Mc', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='This is another test.', role='assistant', function_call=None, tool_calls=None, refusal=None))], created=1728546387, model='gpt-3.5-turbo-0125', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=5, prompt_tokens=12, total_tokens=17, prompt_tokens_details={'cached_tokens': 0}, completion_tokens_details={'reasoning_tokens': 0}))
1 ChatCompletion(id='chatcmpl-AGiAN9kwwhHqNPfPyJ9wXJXtgwy33', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='This is a test.', role='assistant', function_call=None, tool_calls=None, refusal=None))], created=1728546387, model='gpt-3.5-turbo-0125', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=5, prompt_tokens=12, total_tokens=17, 

In [72]:
loop.close()

end = time.time()
print(f'Time: {end-start:.2f} sec')

RuntimeError: Cannot close a running event loop