## A step-by-step guide to finding meaningful subspaces with ReAX.

#### Set-up.

In [1]:
try:
    # Try the installed package first: if this import succeeds, no extra
    # set-up is needed.
    import pyreax

except ModuleNotFoundError:
    # pyreax is not installed: fall back to importing it through the parent
    # directory by appending ".." to sys.path. The entry is relative, so it
    # depends on the kernel's working directory; installing the package with
    # pip is the more robust option.
    # NOTE(review): the previous comment said "pip install subctrl"; confirm
    # the current package name before documenting an install command.
    import sys
    sys.path.append("..")
    import pyreax



In [2]:
import pandas as pd
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
import torch, pyreft
from pyvene import (
    IntervenableModel,
    ConstantSourceIntervention,
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention,
)

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import get_scheduler

from circuitsvis.tokens import colored_tokens
from IPython.core.display import display, HTML
from pyreax import (
    EXAMPLE_TAG, 
    ReAXFactory, 
    MaxReLUIntervention, 
    SubspaceAdditionIntervention, 
    make_data_module, 
    save_reax
)
from pyreax import (
    set_decoder_norm_to_unit_norm, 
    remove_gradient_parallel_to_decoder_directions,
    gather_residual_activations
)

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Args:
        optimizer: any object exposing an iterable ``param_groups`` whose
            items are mappings with an ``'lr'`` entry.

    Returns:
        The ``'lr'`` value of the first group, or ``None`` when there are no
        parameter groups at all.
    """
    groups = iter(optimizer.param_groups)
    try:
        first_group = next(groups)
    except StopIteration:
        # No groups: mirror a loop that never runs and fall through to None.
        return None
    return first_group['lr']

  from IPython.core.display import display, HTML


In [3]:
# Base language model and its tokenizer.
model_name = "google/gemma-2-2b"

# Tokenizer: padding goes on the right-hand side of each sequence.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"

# The weights are first loaded on the CPU, caching is switched off in the
# config, and only then is the model moved to the GPU.
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu")
model.config.use_cache = False
model = model.cuda()

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

#### Dataset creation.

In [4]:
# Build the data factory for the two concepts studied in this notebook.
# The inline labels below count from 1, while the `input_subspace` /
# `output_subspace` columns of the resulting dataframe index the same concepts
# from 0 (see the preview further down).
reax_factory = ReAXFactory(
    model, tokenizer,
    concepts=[
        # https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/8927
        "terms related to artificiality and deception",   # subspace 1
        # https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/7490
        "terms related to employment and employees", # subspace 2
    ], 
    dump_dir="./tmp",
)

# dataset
# The recorded run of this call took about 169 s and reported a cost of $0.88,
# so re-running the cell is neither fast nor free.
# NOTE(review): confirm what exactly `n` counts (examples per concept?).
reax_df = reax_factory.create_df(n=66) 

Prepare contrast concepts.
Fectching 2 contrast concepts for concept: terms related to artificiality and deception
Fectching 7 contrast concepts for concept: terms related to employment and employees
Creating dataframe.
Fectching data for 0/2 concept: terms related to artificiality and deception
Fectching data for 1/2 concept: terms related to employment and employees
Finished creating current dataframe in 169.433 sec with $0.88.


##### Dataset preview.

In [5]:
# Preview one row per (input_concept, output_concept) pair: `first()` keeps
# the first non-null value of every remaining column within each group.
reax_df.groupby(['input_concept', 'output_concept']).first().reset_index()

Unnamed: 0,input_concept,output_concept,input,output,group,input_subspace,output_subspace
0,,,Glistening water from the crystal-clear river ...,". The river was a river, a river, a river, a r...",EXAMPLE_TAG.CONTROL,0,1
1,terms related to artificiality and deception,terms related to employment and employees,"The bustling city, with its shimmering facades...","empty cubicles, layoffs looming, job titles i...",EXAMPLE_TAG.EXPERIMENT,0,1
2,terms related to artificiality and deception:f...,,The grand facade of the historic theater displ...,. The theater was a symbol of the city's cultu...,EXAMPLE_TAG.CONTROL,0,1
3,terms related to artificiality and deception:f...,,"The intricate design on the vase, adorned with...",.\n\nThe vase was one of 100 items that were s...,EXAMPLE_TAG.CONTROL,0,1
4,terms related to employment and employees,terms related to artificiality and deception,"The diligent worker clocked in early, navigati...",crafting illusions of competence where necess...,EXAMPLE_TAG.EXPERIMENT,1,0
5,terms related to employment and employees:bene...,,"Community members attended the charity run, pa...",".\n\nThe event was a resounding success, with ...",EXAMPLE_TAG.CONTROL,1,0
6,terms related to employment and employees:cont...,,She was cautious during the flu season but eve...,".\n\n“I was in bed for a month,” she said. “I ...",EXAMPLE_TAG.CONTROL,1,0
7,terms related to employment and employees:cont...,,The rapidly cooling air caused the metal to co...,.\n\nThe cooling of the metal was also a key f...,EXAMPLE_TAG.CONTROL,1,0
8,terms related to employment and employees:job/...,,"Upon analyzing the blueprint, it became eviden...",.\n\nThe project was divided into three phases...,EXAMPLE_TAG.CONTROL,1,0
9,terms related to employment and employees:posi...,,"Observing the gymnast's fluid motion, her bala...",".\n\nThe gymnast's graceful movements, her bod...",EXAMPLE_TAG.CONTROL,1,0


#### Training.

Let's focus on a single layer, layer 20 of the LM.

In [43]:
# Index of the decoder layer whose output is intervened on.
layer = 20

# Training data: the dataset and collator from the data module, wrapped in a
# shuffling loader.
data_module = make_data_module(tokenizer, model, reax_df)
train_dataloader = DataLoader(
    data_module["train_dataset"],
    batch_size=6,
    shuffle=True,
    collate_fn=data_module["data_collator"],
)

# The language model is put in eval mode; the intervention is put in train mode.
model = model.eval()
reax_intervention = MaxReLUIntervention(
    embed_dim=model.config.hidden_size,
    low_rank_dimension=1,
).train()

# Attach the intervention to the output of the chosen layer.
reft_config = pyreft.ReftConfig(
    representations=[
        {
            "layer": layer_idx,
            "component": f"model.layers[{layer_idx}].output",
            "low_rank_dimension": 1,
            "intervention": reax_intervention,
        }
        for layer_idx in (layer,)
    ]
)
reft_model = pyreft.get_reft_model(model, reft_config)
reft_model.set_device("cuda")
reft_model.print_trainable_parameters()

# Optimisation: AdamW driven by a "linear" schedule without warm-up.
# `k_latent` is the number of top latent values used by the null-loss term in
# the training loop.
num_epochs, k_latent = 8, 3
optimizer = torch.optim.AdamW(reft_model.parameters(), lr=9e-3)
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

trainable intervention params: 4,608 || trainable model params: 0
model params: 2,614,341,888 || trainable%: 0.0001762585077778473


In [44]:
# Main training loop.
# The progress bar has one tick per optimizer step; `curr_step` also drives
# the ramp of the null-loss weight below.
progress_bar, curr_step = tqdm(range(num_training_steps)), 0
for epoch in range(num_epochs):
    for batch in train_dataloader:
        # prepare input: move every tensor of the batch to the GPU
        inputs = {k: v.to("cuda") for k, v in batch.items()}
        # No source positions are given (None); the base positions come from
        # the batch, with the first two axes swapped before conversion to
        # nested Python lists.
        unit_locations={"sources->base": (
            None,
            inputs["intervention_locations"].permute(1, 0, 2).tolist()
        )}
        # Subspace indices from the batch, passed through to the intervention.
        subspaces = [{
            "input_subspaces": inputs["input_subspaces"],
            "output_subspaces": inputs["output_subspaces"]}]

        # forward: only the second element of the returned pair is used
        _, cf_outputs = reft_model(
            base={
                "input_ids": inputs["input_ids"],
                "attention_mask": inputs["attention_mask"]
            }, unit_locations=unit_locations, labels=inputs["labels"],
            subspaces=subspaces, use_cache=False)

        # loss: language-modelling loss of the intervened forward pass ...
        loss = cf_outputs.loss
        # ... plus a penalty on the latents of the first intervention output,
        # multiplied element-wise by the batch's intervention mask
        # (presumably a 0/1 mask -- TODO confirm).
        latent = reft_model.full_intervention_outputs[0].latent * inputs["intervention_masks"]
        # Largest `k_latent` values along the last axis
        # (assumed to be the token axis -- TODO confirm).
        topk_latent, _ = torch.topk(latent, k_latent, dim=-1)
        # Mean of those values, kept only for CONTROL examples (other rows are
        # multiplied by False, i.e. zeroed), then summed over the batch.
        null_loss = (topk_latent.mean(dim=-1)*(inputs["groups"]==EXAMPLE_TAG.CONTROL.value))
        null_loss = null_loss.sum()
        # The penalty weight grows linearly with training progress (0 on the
        # first step) and is scaled by 0.05. `loss` is updated in place, so
        # the value logged below already includes the penalty.
        coeff = curr_step/num_training_steps
        loss += coeff*0.05*null_loss
        
        # grads
        loss.backward()
        # NOTE(review): judging by their names, these helpers renormalise the
        # decoder directions and strip the gradient component parallel to
        # them; their implementations are not visible here. They run after
        # backward() and before optimizer.step().
        set_decoder_norm_to_unit_norm(reax_intervention)
        remove_gradient_parallel_to_decoder_directions(reax_intervention)
        curr_step += 1
        # Read the learning rate before the scheduler advances, so the logged
        # value is the one used for this step.
        curr_lr = get_lr(optimizer)
        # optim
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        progress_bar.set_description("lr %.6f || loss %.6f || null l1 loss %.6f" % (curr_lr, loss, null_loss))

  0%|          | 0/88 [00:00<?, ?it/s]

#### Eval - latent space disentanglement.

In [10]:
# create eval dataset: a second, smaller draw (n=9 instead of n=66) from the
# same factory, kept in its own dataframe so the training data is untouched.
eval_reax_df = reax_factory.create_df(n=9)

Creating dataframe.
Fectching data for 0/2 concept: terms related to artificiality and deception
Fectching data for 1/2 concept: terms related to employment and employees
Finished creating current dataframe in 28.386 sec with $0.13.


In [45]:
# Score every evaluation example with the trained intervention and render the
# per-token latent activations.
concepts = reax_factory.concepts
for _, example in eval_reax_df.iterrows():
    prompt_ids = tokenizer.encode(
        example["input"], return_tensors="pt", add_special_tokens=True
    ).to("cuda")

    # Header: control (null) rows name their contrast concept and the concept
    # being tested; all other rows only name the targeted concept.
    if str(example["group"]) != "EXAMPLE_TAG.CONTROL":
        print("> targeted concept:")
        print(concepts[example["input_subspace"]])
    else:
        print(f"> null <{example['input_concept']}> example:")
        print(f"> testing concept: {concepts[example['input_subspace']]}")

    # Residual activations at `layer`. Position 0 is dropped before encoding
    # so the latents line up with tokenizer.tokenize(), which adds no special
    # tokens (position 0 is presumably the BOS token -- TODO confirm).
    activations = gather_residual_activations(model, layer, prompt_ids)
    latents, _ = reax_intervention.encode(
        activations[:, 1:],
        subspaces={
            "input_subspaces": torch.tensor([example["input_subspace"]]),
            "output_subspaces": torch.tensor([example["output_subspace"]]),
        },
        k=5,
    )
    print("maximal act:", round(latents.max().tolist(), 3))
    display(colored_tokens(tokenizer.tokenize(example["input"]), latents.flatten()))

> null <terms related to artificiality and deception:fake/an object made to look real but not genuine.> example:
> testing concept: terms related to artificiality and deception
maximal act: 21.311


> null <terms related to artificiality and deception:facade/the front of a building.> example:
> testing concept: terms related to artificiality and deception
maximal act: 0.0


> null <terms related to employment and employees:job/an item of work on a specific project.> example:
> testing concept: terms related to employment and employees
maximal act: 14.665


> null <terms related to employment and employees:position/the location of an object.> example:
> testing concept: terms related to employment and employees
maximal act: 0.0


> null <terms related to employment and employees:position/a stance or posture.> example:
> testing concept: terms related to employment and employees
maximal act: 0.0


> null <terms related to employment and employees:contract/to reduce in size or scope.> example:
> testing concept: terms related to employment and employees
maximal act: 2.855


> null <terms related to employment and employees:contract/to catch or develop a disease.> example:
> testing concept: terms related to employment and employees
maximal act: 0.0


> null <terms related to employment and employees:benefits/a public event to raise money for a cause.> example:
> testing concept: terms related to employment and employees
maximal act: 0.0


> null <terms related to employment and employees:promotion/an activity that supports or encourages a cause or aim.> example:
> testing concept: terms related to employment and employees
maximal act: 0.0


> targeted concept:
terms related to artificiality and deception
maximal act: 60.752


> targeted concept:
terms related to artificiality and deception
maximal act: 54.688


> targeted concept:
terms related to artificiality and deception
maximal act: 73.846


> targeted concept:
terms related to employment and employees
maximal act: 53.25


> targeted concept:
terms related to employment and employees
maximal act: 2.145


> targeted concept:
terms related to employment and employees
maximal act: 77.817


#### Eval - logit lens.

In [46]:
# Logit lens: multiply the LM head weight with row 0 of the learned projection
# weight, then decode the five highest-scoring vocabulary entries one by one.
vocab_logits = torch.matmul(
    model.lm_head.weight, reax_intervention.proj.weight.data[0]
)
values, indices = torch.topk(vocab_logits, k=5)
tokenizer.batch_decode(indices[:, None])

[' illusions', ' disgu', ' masking', ' faking', ' feign']

#### Eval - steering.

In [15]:
# Instruction-tuned checkpoint used for steering.
# Note: `tokenizer` is rebound here, so every cell executed afterwards uses
# the chat tokenizer.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")

# Load on the CPU, switch off caching in the config, then move to the GPU.
chat_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it",  # google/gemma-2b-it
    device_map='cpu',
)
chat_model.config.use_cache = False
chat_model = chat_model.cuda()
_ = chat_model.eval()

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [16]:
# Steering intervention that reuses the projection learned during training.
steering_intervention = SubspaceAdditionIntervention(
    embed_dim=model.config.hidden_size, low_rank_dimension=1,
)
steering_intervention.cuda()
# Assigning `.data` makes both interventions share the same weight tensor, so
# no copy is made; re-run this cell after retraining to be sure the steering
# weights are current.
steering_intervention.proj.weight.data = reax_intervention.proj.weight.data

# Steer at the same layer the projection was trained on. Reusing `layer`
# (instead of repeating the literal 20) keeps this cell in sync with the
# training cell if the layer is ever changed there.
reft_config = pyreft.ReftConfig(representations=[{
    "layer": layer,
    "component": f"model.layers[{layer}].output",
    "low_rank_dimension": 1,
    "intervention": steering_intervention}])
steering_model = pyreft.get_reft_model(chat_model, reft_config)

In [18]:
# Which concept to add, and how strongly.
concept_id, mag = 0, 120
prompt = "Write a English story."
print(f"+ {mag} *", reax_factory.concepts[concept_id])

# Wrap the instruction as a single user turn and tokenise it with the chat
# template; `prompt` is rebound to the tokenised inputs on the GPU.
messages = [dict(role="user", content=prompt)]
prompt = tokenizer.apply_chat_template(
    messages, return_dict=True, return_tensors="pt",
).to("cuda")

# Deterministic decoding (no sampling) with the steering applied to the prompt
# as well; only the second element of the returned pair is used.
_, reft_response = steering_model.generate(
    prompt,
    subspaces=[dict(idx=concept_id, mag=mag)],
    unit_locations=None,
    intervene_on_prompt=True,
    max_new_tokens=128,
    do_sample=False,
    early_stopping=True,
    no_repeat_ngram_size=5,
    repetition_penalty=1.1,
)
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))

+ 120 * terms related to artificiality and deception
user
Write a English story.


The old lighthouse keeper, Elias weathered and worn like the ancient stone facade of his dwelling, stood at the precipice of truth.  He'd spent decades guiding lost souls through treacherous storms with nothing but flickering light and unwavering resolve. Yet, tonight, an unsettling dissonance echoed within him, a discordant note in the symphony of his life's melody. 

A young woman, her face veiled in shadows, stumbled into his hollowed-out existence, seeking refuge from the howling tempest that raged around them. Her whispers, fragile echoes of desperation, spoke of a hidden world cloaked in illusions, a realm where truths were mere


#### Saving.

In [72]:
# Neuronpedia pages of the GemmaScope SAE features paired with the two
# concepts; same URLs, in the same order, as the comments next to the concept
# list in the dataset-creation cell.
sae_metadata = [
    "https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/8927", 
    "https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/7490"
]
# Hand the training dataframe, the concept list, the SAE metadata and the
# learned projection weight to save_reax.
# NOTE(review): "./tmp_save" is presumably the output directory -- confirm
# against save_reax's signature.
save_reax("./tmp_save", reax_df, reax_factory.concepts, sae_metadata, reax_intervention.proj.weight)