## Basic *symbolic*-level control with SubCTRL.

Create an intervention schema such that:

**"whenever the model reads in words related to Stanford, it will say something about human rights issues"**

#### Set-up.

In [1]:
try:
    # This library is our indicator that the required installs
    # need to be done.
    import subctrl

except ModuleNotFoundError:
    # relative import; better to pip install subctrl
    import sys
    sys.path.append("..")
    import subctrl



In [2]:
import pandas as pd
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
import torch, pyreft
from pyvene import (
    IntervenableModel,
    ConstantSourceIntervention,
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention,
)

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import get_scheduler

from circuitsvis.tokens import colored_tokens
from IPython.display import display, HTML
from subctrl import EXAMPLE_TAG, SubCTRLFactory, MaxReLUIntervention, make_data_module
from subctrl import (
    set_decoder_norm_to_unit_norm, 
    remove_gradient_parallel_to_decoder_directions,
    gather_residual_activations
)

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']



In [3]:
# Load lm.
model_name = "google/gemma-2-2b"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu")
model.config.use_cache = False
model = model.cuda()

tokenizer =  AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"


#### SubCTRL dataset creation.

In [16]:
subctrl_factory = SubCTRLFactory(
    model, tokenizer,
    concepts=[
        # https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/16278
        "references to institutional affiliations and events within "\
        "the academic and legal contexts, particularly related to "\
        "Stanford and Silicon Valley",   # subspace 1
        # https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/11553
        "references to human rights issues and related organizations", # subspace 2
    ], 
    dump_dir="./tmp",
    skip_contrast_concept=True,
)

# dataset
subctrl_df = subctrl_factory.create_df(n=48) 

Prepare contrast concepts.
Skipping contrast concept creation for references to institutional affiliations and events within the academic and legal contexts, particularly related to Stanford and Silicon Valley.
Skipping contrast concept creation for references to human rights issues and related organizations.
Creating dataframe.
Fectching data for 0/2 concept: references to institutional affiliations and events within the academic and legal contexts, particularly related to Stanford and Silicon Valley
Fectching data for 1/2 concept: references to human rights issues and related organizations
Finished creating dataframe in 50.314 sec with $0.487.


In [17]:
subctrl_df.groupby(['input_concept', 'output_concept']).first().reset_index()

| | input_concept | output_concept | input | output | group | input_subspace | output_subspace |
|---|---|---|---|---|---|---|---|
| 0 | | | The curious cat chased butterflies across the ... | . The cat was a black cat with a white spot on... | EXAMPLE_TAG.CONTROL | 0 | 1 |
| 1 | references to human rights issues and related ... | references to institutional affiliations and e... | Yesterday the community organized a picnic Amn... | near Stanford's innovation hub involving lega... | EXAMPLE_TAG.EXPERIMENT | 1 | 0 |
| 2 | references to institutional affiliations and e... | references to human rights issues and related ... | Tomorrow, we will attend the annual picnic Sta... | including Amnesty International's efforts and... | EXAMPLE_TAG.EXPERIMENT | 0 | 1 |


#### SubCTRL training.

Let's focus on a single layer, layer 20 of the LM.

In [28]:
layer = 20

# make data module.
data_module = make_data_module(tokenizer, model, subctrl_df)
train_dataloader = DataLoader(
    data_module["train_dataset"], shuffle=True, batch_size=8, 
    collate_fn=data_module["data_collator"])

# get reft model
model = model.eval()
subctrl_intervention = MaxReLUIntervention(
    embed_dim=model.config.hidden_size, low_rank_dimension=1,
)
subctrl_intervention = subctrl_intervention.train()
reft_config = pyreft.ReftConfig(representations=[{
    "layer": l,
    "component": f"model.layers[{l}].output",
    "low_rank_dimension": 1,
    "intervention": subctrl_intervention} for l in [layer]])
reft_model = pyreft.get_reft_model(model, reft_config)
reft_model.set_device("cuda")
reft_model.print_trainable_parameters()

# optimizer and lr
num_epochs = 9
optimizer = torch.optim.AdamW(reft_model.parameters(), lr=5e-3)
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear", optimizer=optimizer,
    num_warmup_steps=0, num_training_steps=num_training_steps)

trainable intervention params: 4,608 || trainable model params: 0
model params: 2,614,341,888 || trainable%: 0.0001762585077778473
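
The numbers printed above can be sanity-checked by hand. The sketch below assumes the intervention stores one direction of size `hidden_size` per concept and that no bias terms are counted; that layout is inferred from the printed total, not confirmed from the SubCTRL source.

```python
# Assumption: one trainable direction per concept, no bias counted.
hidden_size = 2304             # residual stream width of Gemma-2-2B
num_concepts = 2               # the two concepts passed to SubCTRLFactory

intervention_params = num_concepts * hidden_size
total_params = 2_614_341_888   # "model params" as printed above

trainable_pct = 100 * intervention_params / total_params
print(intervention_params, trainable_pct)
```

Under that assumption, the whole intervention is just two vectors in the residual stream, which is why the trainable fraction is so small.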


In [29]:
# Main training loop.
progress_bar, curr_step = tqdm(range(num_training_steps)), 0
for epoch in range(num_epochs):
    for batch in train_dataloader:
        # prepare input
        inputs = {k: v.to("cuda") for k, v in batch.items()}
        unit_locations={"sources->base": (
            None,
            inputs["intervention_locations"].permute(1, 0, 2).tolist()
        )}
        subspaces = [{
            "input_subspaces": inputs["input_subspaces"],
            "output_subspaces": inputs["output_subspaces"]}]

        # forward
        _, cf_outputs = reft_model(
            base={
                "input_ids": inputs["input_ids"],
                "attention_mask": inputs["attention_mask"]
            }, unit_locations=unit_locations, labels=inputs["labels"],
            subspaces=subspaces, use_cache=False)

        # loss
        loss = cf_outputs.loss
        latent = reft_model.full_intervention_outputs[0].latent
        null_loss = (latent.mean(dim=-1)*(inputs["groups"]==EXAMPLE_TAG.CONTROL.value))
        null_loss = null_loss.sum()
        coeff = curr_step/num_training_steps
        loss += coeff*0.1*null_loss
        
        # grads
        loss.backward()
        set_decoder_norm_to_unit_norm(subctrl_intervention)
        remove_gradient_parallel_to_decoder_directions(subctrl_intervention)
        curr_step += 1
        curr_lr = get_lr(optimizer)
        # optim
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        progress_bar.set_description("lr %.6f || loss %.6f || null l1 loss %.6f" % (curr_lr, loss, null_loss))

  0%|          | 0/54 [00:00<?, ?it/s]
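
To make the two schedules in the loop easier to see, here is a minimal pure-Python sketch. It assumes 6 batches per epoch (inferred from the 54-step progress bar and 9 epochs) and that the `"linear"` scheduler with zero warmup decays the learning rate linearly to zero. The helper names `null_coeff`, `linear_lr`, and `masked_null_penalty` are illustrative and not part of SubCTRL.

```python
NUM_EPOCHS = 9
STEPS_PER_EPOCH = 6      # assumption: 54 total steps / 9 epochs
BASE_LR = 5e-3
NULL_WEIGHT = 0.1

num_training_steps = NUM_EPOCHS * STEPS_PER_EPOCH


def null_coeff(step):
    """Weight on the null loss: ramps linearly from 0 towards NULL_WEIGHT."""
    return NULL_WEIGHT * step / num_training_steps


def linear_lr(step):
    """Learning rate under linear decay with no warmup."""
    return BASE_LR * max(0.0, (num_training_steps - step) / num_training_steps)


def masked_null_penalty(latent_means, is_control):
    """Sum the mean latent activation over control examples only."""
    return sum(m for m, c in zip(latent_means, is_control) if c)


for step in (0, 27, 53):
    print(step, null_coeff(step), linear_lr(step))
```

The penalty on control examples therefore starts at zero and grows as the learning rate shrinks, so the language-modelling loss dominates early training and the sparsity pressure on control inputs arrives later.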

#### Symbolic-like steering with gated interventions on subspaces.

We use one subspace to gate the steering along a second subspace. In other words:
- One subspace acts as the steering wheel: it sets the direction we steer in.
- The other subspace tells us how much to steer.
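
The gating idea can be sketched in a few lines of NumPy: project each hidden state onto the gate direction, pass it through a ReLU, and add the steer direction scaled by that activation. The vectors below are toy values chosen for illustration, not the trained SubCTRL weights, and `gated_steer` is a hypothetical helper name.

```python
import numpy as np


def gated_steer(h, w_gate, w_steer, mag):
    """Add mag * relu(h . w_gate) * w_steer to every row of h.

    h: (n_positions, d) hidden states; w_gate, w_steer: (d,) directions.
    """
    gate = np.maximum(h @ w_gate, 0.0)   # (n_positions,), zero when the gate concept is absent
    return h + mag * gate[:, None] * w_steer[None, :]


# Toy directions (hypothetical, for illustration only).
w_gate = np.array([1.0, 0.0, 0.0, 0.0])
w_steer = np.array([0.0, 1.0, 0.0, 0.0])

h = np.array([
    [2.0, 0.0, 1.0, 0.0],    # gate concept present (projection = 2)
    [-1.0, 0.0, 1.0, 0.0],   # gate concept absent (projection = -1)
])

steered = gated_steer(h, w_gate, w_steer, mag=2.0)
print(steered)
```

Positions where the gate concept does not fire are left untouched, and setting `mag=0.0` recovers the unsteered hidden states, which is a convenient baseline for comparison.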

In [62]:
# load the chat-lm
chat_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it", # google/gemma-2b-it
    device_map='cpu',
)
chat_model.config.use_cache = False
chat_model = chat_model.cuda()
tokenizer =  AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
_ = chat_model.eval()

class SubspaceGatedAdditionIntervention(
    SourcelessIntervention,
    TrainableIntervention, 
    DistributedRepresentationIntervention
):
    def __init__(self, **kwargs):
        # The projection holds one direction per concept (one used as the gate, one as the steer).
        # Its freshly initialised weights are overwritten below with the trained SubCTRL weights.
        super().__init__(**kwargs, keep_last_dim=True)
        self.proj = torch.nn.Linear(
                self.embed_dim, kwargs["low_rank_dimension"]*2, bias=True)

    def forward(self, base, source=None, subspaces=None):
        bs, seql, _ = base.shape # b, s, d
        base_r = base.reshape(bs*seql, -1).clone()
        base_gate = torch.relu(base_r @ self.proj.weight[subspaces["gate"]].unsqueeze(dim=-1)) # b*s, 1
        steering_vec = torch.tensor(subspaces["mag"]) * base_gate * self.proj.weight[subspaces["steer"]]
        base_s = base_r + steering_vec
        output = base_s.reshape(bs, seql, -1)
        if seql != 1:
            # when more than one position is processed, restore position 0
            # (the first prompt token, typically BOS) to its unsteered value.
            output[:,0] = base[:,0]
        return output

gate_steering_intervention = SubspaceGatedAdditionIntervention(
    embed_dim=chat_model.config.hidden_size, low_rank_dimension=1,
)
gate_steering_intervention.cuda()
gate_steering_intervention.proj.weight.data = subctrl_intervention.proj.weight.data

reft_config = pyreft.ReftConfig(representations=[{
    "layer": l,
    "component": f"model.layers[{l}].output",
    "low_rank_dimension": 1,
    "intervention": gate_steering_intervention} for l in [20]])
steering_model = pyreft.get_reft_model(chat_model, reft_config)


#### A simple test case

**"whenever the model reads in words related to Stanford, it will say something about human rights issues"**

In [74]:
prompt = "List top 10 universities in the U.S. with a short reason."
gate_concept_id = 0
steer_concept_id = 1
mag = 2.0
print(f"mag =", subctrl_factory.concepts[gate_concept_id])
print(f"+ {mag} * mag *", subctrl_factory.concepts[steer_concept_id])

messages = [
    {"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True).to("cuda")

_, reft_response = steering_model.generate(
    prompt, 
    unit_locations=None, 
    intervene_on_prompt=True, 
    subspaces=[{"gate": gate_concept_id, "steer": steer_concept_id, "mag": mag}], max_new_tokens=512, 
    do_sample=False, early_stopping=True, no_repeat_ngram_size=5, repetition_penalty=1.1
)
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))

mag = references to institutional affiliations and events within the academic and legal contexts, particularly related to Stanford and Silicon Valley
+ 2.0 * mag * references to human rights issues and related organizations
user
List top 10 universities in the U.S. with a short reason.


It's impossible to definitively say which are the "top 10" as rankings vary depending on factors like:

* **Academic reputation:**  How well-regarded is the university for research and teaching?
* **Faculty quality:** Are there renowned professors and experts in various fields?
* **Resources and facilities:** Does the university have state-of-the-art labs, libraries, and technology?
* **Student selectivity:** How competitive is the admissions process?
* **Graduation rates and job placement:** What are the success rates of graduates?

However, here are 10 consistently highly ranked US universities across different ranking systems, along with their strengths:

**1. Massachusetts Institute of Technology (