## A step-by-step guide of finding meaningful subspaces with SubCTRL.

#### Set-up.

In [1]:
try:
    # This library is our indicator that the required installs
    # need to be done.
    import subctrl

except ModuleNotFoundError:
    # relative import; better to pip install subctrl
    import sys
    sys.path.append("..")
    import subctrl



In [2]:
import pandas as pd
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
import torch, pyreft
from pyvene import (
    IntervenableModel,
    ConstantSourceIntervention,
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention,
)

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import get_scheduler

from circuitsvis.tokens import colored_tokens
from IPython.core.display import display, HTML
from subctrl import EXAMPLE_TAG, SubCTRLFactory, MaxReLUIntervention, SubspaceAdditionIntervention, make_data_module
from subctrl import (
    set_decoder_norm_to_unit_norm, 
    remove_gradient_parallel_to_decoder_directions,
    gather_residual_activations
)

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

  from IPython.core.display import display, HTML


In [3]:
# Load lm.
model_name = "google/gemma-2-2b"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu")
model.config.use_cache = False
model = model.cuda()

tokenizer =  AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

#### SubCTRL dataset creation.

In [4]:
subctrl_factory = SubCTRLFactory(
    model, tokenizer,
    concepts=[
        # https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/8927
        "terms related to artificiality and deception",   # subspace 1
        # https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/7490
        "terms related to employment and employees", # subspace 2
    ], 
    dump_dir="./tmp",
    skip_contrast_concept=True,
)

# dataset
subctrl_df = subctrl_factory.create_df(n=48) 

Prepare contrast concepts.
Skipping contrast concept creation for terms related to artificiality and deception.
Skipping contrast concept creation for terms related to employment and employees.
Creating dataframe.
Fectching data for 0/2 concept: terms related to artificiality and deception
Fectching data for 1/2 concept: terms related to employment and employees
Finished creating dataframe in 59.536 sec with $0.453.


##### Dataset preview.

In [6]:
subctrl_df.groupby(['input_concept', 'output_concept']).first().reset_index()

Unnamed: 0,input_concept,output_concept,input,output,group,input_subspace,output_subspace
0,,,Bright sunflowers swayed gently in the warm su...,". The sun was shining brightly, and the air wa...",EXAMPLE_TAG.CONTROL,0,1
1,terms related to artificiality and deception,terms related to employment and employees,Despite the sunny day there was an air of manu...,benefits packages appealing to potential new ...,EXAMPLE_TAG.EXPERIMENT,0,1
2,terms related to employment and employees,terms related to artificiality and deception,Giraffes roamed peacefully in the park while n...,under the shadow of pretenses and shrouded in...,EXAMPLE_TAG.EXPERIMENT,1,0


In [7]:
subctrl_df.to_csv("./tmp/test.csv")

#### SubCTRL training.

Let's focus on a single layer, layer 20 of the LM.

In [10]:
layer = 20

# make data module.
data_module = make_data_module(tokenizer, model, subctrl_df)
train_dataloader = DataLoader(
    data_module["train_dataset"], shuffle=True, batch_size=8, 
    collate_fn=data_module["data_collator"])

# get reft model
model = model.eval()
subctrl_intervention = MaxReLUIntervention(
    embed_dim=model.config.hidden_size, low_rank_dimension=1,
)
subctrl_intervention = subctrl_intervention.train()
reft_config = pyreft.ReftConfig(representations=[{
    "layer": l,
    "component": f"model.layers[{l}].output",
    "low_rank_dimension": 1,
    "intervention": subctrl_intervention} for l in [layer]])
reft_model = pyreft.get_reft_model(model, reft_config)
reft_model.set_device("cuda")
reft_model.print_trainable_parameters()

# optimizer and lr
num_epochs = 9
optimizer = torch.optim.AdamW(reft_model.parameters(), lr=5e-3)
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear", optimizer=optimizer,
    num_warmup_steps=0, num_training_steps=num_training_steps)

trainable intervention params: 4,608 || trainable model params: 0
model params: 2,614,341,888 || trainable%: 0.0001762585077778473


In [11]:
# Main training loop.
progress_bar, curr_step = tqdm(range(num_training_steps)), 0
for epoch in range(num_epochs):
    for batch in train_dataloader:
        # prepare input
        inputs = {k: v.to("cuda") for k, v in batch.items()}
        unit_locations={"sources->base": (
            None,
            inputs["intervention_locations"].permute(1, 0, 2).tolist()
        )}
        subspaces = [{
            "input_subspaces": inputs["input_subspaces"],
            "output_subspaces": inputs["output_subspaces"]}]

        # forward
        _, cf_outputs = reft_model(
            base={
                "input_ids": inputs["input_ids"],
                "attention_mask": inputs["attention_mask"]
            }, unit_locations=unit_locations, labels=inputs["labels"],
            subspaces=subspaces, use_cache=False)

        # loss
        loss = cf_outputs.loss
        latent = reft_model.full_intervention_outputs[0].latent
        null_loss = (latent.mean(dim=-1)*(inputs["groups"]==EXAMPLE_TAG.CONTROL.value))
        null_loss = null_loss.sum()
        coeff = curr_step/num_training_steps
        loss += coeff*0.1*null_loss
        
        # grads
        loss.backward()
        set_decoder_norm_to_unit_norm(subctrl_intervention)
        remove_gradient_parallel_to_decoder_directions(subctrl_intervention)
        curr_step += 1
        curr_lr = get_lr(optimizer)
        # optim
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        progress_bar.set_description("lr %.6f || loss %.6f || null l1 loss %.6f" % (curr_lr, loss, null_loss))

  0%|          | 0/54 [00:00<?, ?it/s]

#### SubCTRL eval - latent space disentanglements.

In [12]:
# create eval dataset
eval_subctrl_df = subctrl_factory.create_df(n=21)

Creating dataframe.
Fectching data for 0/2 concept: terms related to artificiality and deception
Fectching data for 1/2 concept: terms related to employment and employees
Finished creating dataframe in 24.467 sec with $0.599.


In [13]:
# run inference loop
concepts = subctrl_factory.concepts
for _, row in eval_subctrl_df.iterrows():
    prompt = tokenizer.encode(
        row["input"], return_tensors="pt", add_special_tokens=True).to("cuda") 
    if str(row["group"]) == "EXAMPLE_TAG.CONTROL":
        print("> null example:")
    else:
        print(f"> targeted concept:")
        print(concepts[row["input_subspace"]])
    target_act = gather_residual_activations(model, layer, prompt)
    p, _ = subctrl_intervention.encode(
        target_act[:,1:], 
        subspaces={
            "input_subspaces": torch.tensor([row["input_subspace"]]),
            "output_subspaces": torch.tensor([row["output_subspace"]])}, k=3)
    print("maximal act:", round(p.max().tolist(), 3))
    html = colored_tokens(tokenizer.tokenize(row["input"]), p.flatten())
    display(html)

> null example:
maximal act: 0.0


> null example:
maximal act: 0.941


> null example:
maximal act: 12.912


> null example:
maximal act: 0.0


> null example:
maximal act: 6.375


> null example:
maximal act: 0.0


> targeted concept:
terms related to artificiality and deception
maximal act: 74.839


> targeted concept:
terms related to artificiality and deception
maximal act: 52.058


> targeted concept:
terms related to artificiality and deception
maximal act: 69.207


> targeted concept:
terms related to artificiality and deception
maximal act: 66.195


> targeted concept:
terms related to artificiality and deception
maximal act: 69.941


> targeted concept:
terms related to artificiality and deception
maximal act: 44.857


> targeted concept:
terms related to artificiality and deception
maximal act: 60.029


> targeted concept:
terms related to employment and employees
maximal act: 145.279


> targeted concept:
terms related to employment and employees
maximal act: 93.716


> targeted concept:
terms related to employment and employees
maximal act: 134.499


> targeted concept:
terms related to employment and employees
maximal act: 65.331


> targeted concept:
terms related to employment and employees
maximal act: 138.339


> targeted concept:
terms related to employment and employees
maximal act: 117.923


> targeted concept:
terms related to employment and employees
maximal act: 139.031


#### SubCTRL eval - logits lens.

In [19]:
vocab_logits = model.lm_head.weight @ subctrl_intervention.proj.weight.data[0]
values, indices = vocab_logits.topk(k=10)
tokenizer.batch_decode(indices.unsqueeze(dim=-1))

[' masquer',
 ' synthetic',
 ' hidden',
 ' feign',
 ' disguised',
 ' woven',
 ' disgu',
 'synthetic',
 'hidden',
 'fake']

#### SubCTRL eval - steering.

In [20]:
# load the chat-lm
chat_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it", # google/gemma-2b-it
    device_map='cpu',
)
chat_model.config.use_cache = False
chat_model = chat_model.cuda()
tokenizer =  AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
_ = chat_model.eval()

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [35]:
steering_intervention = SubspaceAdditionIntervention(
    embed_dim=model.config.hidden_size, low_rank_dimension=1,
)
steering_intervention.cuda()
steering_intervention.proj.weight.data = subctrl_intervention.proj.weight.data

reft_config = pyreft.ReftConfig(representations=[{
    "layer": l,
    "component": f"model.layers[{l}].output",
    "low_rank_dimension": 1,
    "intervention": steering_intervention} for l in [20]])
steering_model = pyreft.get_reft_model(chat_model, reft_config)

In [48]:
prompt = "Write a English story."
concept_id = 0
mag = 140
print(f"+ {mag} *", subctrl_factory.concepts[concept_id])

messages = [
    {"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True).to("cuda")

_, reft_response = steering_model.generate(
    prompt, 
    unit_locations=None, 
    intervene_on_prompt=True, 
    subspaces=[{"idx": concept_id, "mag": mag}], max_new_tokens=128, 
    do_sample=False, early_stopping=True, no_repeat_ngram_size=5, repetition_penalty=1.1
)
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))

+ 140 * terms related to artificiality and deception
user
Write a English story.


The old, weathered facade of the antique shop concealed a world spun from forgotten dreams and whispered secrets.  Esme, cloaked in shadows beneath a cascade of cascading curls, draped her fingers over a chipped porcelain doll, its painted eyes staring into an abyss of time.  She was drawn to this place, a labyrinth of dusty relics and hidden narratives, like a moth to a flickering flame. 

Behind the counter, shrouded in the scent of aged paper and faded perfume, sat Mr. Hawthorne, his face etched with the stories of a thousand lives lived and lost. He possessed a knowing smile that hinted at buried truths woven into
