## Introducing ReAX.

#### Set-up.

In [None]:
# `pyreax` doubles as the marker for whether the required packages are installed.
try:
    import pyreax
except ModuleNotFoundError:
    # Not installed: fall back to the source checkout at ../../pyreax, resolved
    # relative to the current working directory. Installing via pip is the
    # preferred route (the original note says `pip install subctrl`).
    import sys

    sys.path.append("../../pyreax")
    import pyreax

In [None]:
import pandas as pd
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
import torch, pyreft
from pyvene import (
    IntervenableModel,
    ConstantSourceIntervention,
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention,
)

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import get_scheduler

from circuitsvis.tokens import colored_tokens
from IPython.core.display import display, HTML
from pyreax import (
    EXAMPLE_TAG, 
    ReAXFactory, 
    MaxReLUIntervention, 
    SubspaceAdditionIntervention, 
    make_data_module, 
    save_reax,
    Config,
    load_config_from_json,
    load_concepts,
)
from pyreax import (
    set_decoder_norm_to_unit_norm, 
    remove_gradient_parallel_to_decoder_directions,
    gather_residual_activations, 
    get_lr
)

#### Training.

Let's focus on a single layer: layer 20 of the LM (set via `layer` in the config below).

In [None]:
# Experiment configuration: data generation, model/layer choice, and training hyper-parameters.
config = Config(
    lm_model="gpt-4o",
    concept_path="../demo/url_concepts.txt",
    model_name="google/gemma-2-2b",
    n_data=66,
    layer=20,
    component="res",
    input_length=32,
    output_length=16,
    # training
    batch_size=6,
    n_epochs=12,
    k_latent_null_loss=1,
    lr=3E-3,
    coeff_l1_loss_null=5E-2,
    coeff_l1_loss=1E-3,
    dump_dir="./tmp",
)

# Unpack the settings the cells below refer to by short names.
lm_model, concept_path, model_name = config.lm_model, config.concept_path, config.model_name
N, lr, layer = config.n_data, config.lr, config.layer
num_epochs, k_latent = config.n_epochs, config.k_latent_null_loss
coeff_l1_loss_null, coeff_l1_loss = config.coeff_l1_loss_null, config.coeff_l1_loss
dump_dir = config.dump_dir
input_length, output_length = config.input_length, config.output_length

In [None]:
# For demo purposes, the concepts (plus accompanying metadata) are loaded directly from a file.
concepts, sae_metadata = load_concepts(concept_path)

# Tokenizer, with right-side padding.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"

# Base LM: load on CPU, disable the KV cache, then move to the GPU.
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu")
model.config.use_cache = False
model = model.cuda()

# Generate the ReAX dataset for the loaded concepts.
reax_factory = ReAXFactory(
    model,
    tokenizer,
    lm_model=lm_model,
    concepts=concepts,
    dump_dir=dump_dir,
)
reax_df = reax_factory.create_df(n=N, input_length=input_length, output_length=output_length)

In [None]:
# Training data and loader.
data_module = make_data_module(tokenizer, model, reax_df)
train_dataloader = DataLoader(
    data_module["train_dataset"],
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=data_module["data_collator"],
)

# Wrap the base model (in eval mode) with one trainable MaxReLU intervention
# on the output of `model.layers[layer]`.
model = model.eval()
reax_intervention = MaxReLUIntervention(
    embed_dim=model.config.hidden_size, low_rank_dimension=2,
).train()
representations = [
    {
        "layer": layer,
        "component": f"model.layers[{layer}].output",
        "low_rank_dimension": 1,
        "intervention": reax_intervention,
    }
]
reft_config = pyreft.ReftConfig(representations=representations)
reft_model = pyreft.get_reft_model(model, reft_config)
reft_model.set_device("cuda")
reft_model.print_trainable_parameters()

# AdamW with a linear schedule and no warm-up, spanning every training step.
optimizer = torch.optim.AdamW(reft_model.parameters(), lr=config.lr)
num_training_steps = len(train_dataloader) * num_epochs
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

In [None]:
# Main training loop.
# The two sparsity penalties are ramped in linearly: their weight is
# curr_step / num_training_steps, i.e. 0 on the first step.
progress_bar = tqdm(range(num_training_steps))
curr_step = 0
control_tag = EXAMPLE_TAG.CONTROL.value
for epoch in range(num_epochs):
    for batch in train_dataloader:
        # Move every batch entry onto the GPU.
        inputs = {name: tensor.to("cuda") for name, tensor in batch.items()}
        positions = inputs["intervention_locations"].permute(1, 0, 2).tolist()
        unit_locations = {"sources->base": (None, positions)}
        subspaces = [{
            "input_subspaces": inputs["input_subspaces"],
            "output_subspaces": inputs["output_subspaces"],
        }]
        base = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
        }

        # Forward through the intervened model; only the counterfactual output is used.
        _, cf_outputs = reft_model(
            base=base,
            unit_locations=unit_locations,
            labels=inputs["labels"],
            subspaces=subspaces,
            use_cache=False,
        )

        # Intervention latents, multiplied by `intervention_masks`.
        latent = reft_model.full_intervention_outputs[0].latent * inputs["intervention_masks"]
        # Null loss: mean of the top-k latents, summed over control ("null") examples only.
        topk_latent, _ = torch.topk(latent, k_latent, dim=-1)
        null_loss = (topk_latent.mean(dim=-1) * (inputs["groups"] == control_tag)).sum()
        # L1 loss: mean latent, summed over all non-control examples.
        l1_loss = (latent.mean(dim=-1) * (inputs["groups"] != control_tag)).sum()

        coeff = curr_step / num_training_steps
        penalty = coeff*coeff_l1_loss_null*null_loss + coeff*coeff_l1_loss*l1_loss
        loss = cf_outputs.loss
        loss += penalty

        # Backward, then (per the helper names) renormalize the decoder directions and
        # drop the gradient component parallel to them before stepping. Order matters.
        loss.backward()
        set_decoder_norm_to_unit_norm(reax_intervention)
        remove_gradient_parallel_to_decoder_directions(reax_intervention)
        curr_step += 1
        curr_lr = get_lr(optimizer)  # lr used for this step (read before the scheduler advances)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        progress_bar.update(1)
        progress_bar.set_description(
            "lr %.6f || loss %.6f || null l1 loss %.6f" % (curr_lr, loss, null_loss))

#### Eval - latent-space disentanglement.

In [None]:
# Run inference over every example and visualize its per-token ReAX latents.
# This is evaluation only, so it runs under torch.no_grad(). Otherwise every
# iteration would build an autograd graph through the trained intervention, and
# the printed/plotted tensors would carry requires_grad.
concepts = reax_factory.concepts
with torch.no_grad():
    for _, row in reax_df.iterrows():
        prompt = tokenizer.encode(
            row["input"], return_tensors="pt", add_special_tokens=True).to("cuda")
        # NOTE(review): matching on str(enum) is fragile (e.g. IntEnum's str() changed
        # in Python 3.11); consider comparing against EXAMPLE_TAG.CONTROL directly.
        if str(row["group"]) == "EXAMPLE_TAG.CONTROL":
            input_concept = row["input_concept"]
            print(f"> null <{input_concept}> example:")
            test_concept = concepts[row["input_subspace"]]
            print(f"> testing concept: {test_concept}")
        else:
            print("> targeted concept:")
            print(concepts[row["input_subspace"]])
        target_act = gather_residual_activations(model, layer, prompt)
        # Skip position 0 (presumably the BOS token added by add_special_tokens=True) so the
        # activations line up with tokenizer.tokenize(), which by default adds no special tokens.
        p, _ = reax_intervention.encode(
            target_act[:, 1:],
            subspaces={
                "input_subspaces": torch.tensor([row["input_subspace"]]),
                "output_subspaces": torch.tensor([row["output_subspace"]])},
            k=10)
        print("maximal act:", round(p.max().tolist(), 3))
        html = colored_tokens(tokenizer.tokenize(row["input"]), p.flatten())
        display(html)

#### Eval - logit lens.

In [None]:
# Logit lens: score every vocabulary token against row 0 of the learned projection
# (presumably concept 0's direction) via the LM's output head, and show the top-10 tokens.
concept_direction = reax_intervention.proj.weight.data[0]
vocab_logits = model.lm_head.weight @ concept_direction
top_values, top_token_ids = torch.topk(vocab_logits, k=10)
tokenizer.batch_decode(top_token_ids[:, None])

#### Eval - steering.

In [None]:
# Load the instruction-tuned chat LM (and its tokenizer) used for steering.
chat_model_name = "google/gemma-2-2b-it"
chat_model = AutoModelForCausalLM.from_pretrained(chat_model_name, device_map='cpu')
chat_model.config.use_cache = False
chat_model = chat_model.cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(chat_model_name)

In [None]:
# Build a steering intervention that reuses the trained ReAX projection and
# attach it to the chat model.
steering_intervention = SubspaceAdditionIntervention(
    embed_dim=model.config.hidden_size, low_rank_dimension=2,
)
steering_intervention.cuda()
# Point at the trained projection tensor (shared, not copied).
steering_intervention.proj.weight.data = reax_intervention.proj.weight.data

# Steer at the same configured `layer` the ReAX intervention was trained on,
# rather than a hard-coded index that silently drifts when config.layer changes.
reft_config = pyreft.ReftConfig(representations=[{
    "layer": l,
    "component": f"model.layers[{l}].output",
    "low_rank_dimension": 1,
    "intervention": steering_intervention} for l in [layer]])
steering_model = pyreft.get_reft_model(chat_model, reft_config)

In [None]:
# Generate from the chat model while adding concept `concept_id`'s subspace,
# scaled by `mag` (per the "+ mag * concept" printout), to the representations.
prompt = "Write an English story."
concept_id = 0
mag = 120
print(f"+ {mag} *", reax_factory.concepts[concept_id])

messages = [
    {"role": "user", "content": prompt}]
# add_generation_prompt=True appends the assistant-turn header so the model
# answers the user instead of having to open its own turn.
prompt = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
).to("cuda")

# Greedy decoding. `early_stopping` is omitted: it only affects beam search.
_, reft_response = steering_model.generate(
    prompt,
    unit_locations=None,
    intervene_on_prompt=True,
    subspaces=[{"idx": concept_id, "mag": mag}],
    max_new_tokens=128,
    do_sample=False,
    no_repeat_ngram_size=5,
    repetition_penalty=1.1,
)
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))

#### Saving.

In [None]:
# Free cached GPU memory, then save the trained artifacts to the configured
# dump directory (previously hard-coded to "./tmp", ignoring config.dump_dir).
torch.cuda.empty_cache()
save_reax(dump_dir, config, reax_df, reax_factory, sae_metadata, reax_intervention)