In [None]:
# Install repeng straight from its GitHub repository, plus sae-lens and matplotlib.
# NOTE(review): none of these installs is version-pinned, so a fresh environment
# may resolve different versions than the ones that produced the saved outputs.
%pip install git+https://github.com/vgel/repeng.git
%pip install sae-lens matplotlib

In [2]:
import json

import torch
import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np

from repeng import ControlVector, ControlModel, DatasetEntry
# import repeng.saes # TODO

In [3]:
# Checkpoint id used for both the tokenizer and the model weights.
model_name = "google/gemma-2-2b"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the EOS token id as the padding token id.
tokenizer.pad_token_id = tokenizer.eos_token_id

# Layers 10..19 (inclusive) are the ones handed to ControlModel.
control_layers = [*range(10, 20)]

# Load the base model onto the first GPU and wrap it for control in one step.
model = ControlModel(
    AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto").to("cuda:0"),
    control_layers,
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [101]:
from repeng.saes import SaeLayer, Sae
import dataclasses
import typing


def from_saelens(
    release: str,
    layers_to_sae: dict[int, str],
    *,
    device: str = "cpu",
    dtype: torch.dtype | None = None,
):
    """
    `layers_to_sae` should be a dict from layer number (repeng layer, see below) to the appropriate sae-lens id (hard to understand
    from the HF file structure, but the SAE readme should have a hint.)

    e.x., for gemmascope on gemma 2b: `{ layer: f"layer_{layer-1}/width_65k/canonical" for layer in range(1, 27) }`

    Note that `layers_to_sae` should be 1-indexed, repeng style, not 0-indexed, sae-lens style. This may change in the future.
    (Context: repeng counts embed_tokens as layer 0, then the first transformer block as layer 1, etc. sae-lens
    counts embedding separately, then the first transformer block as layer 0.)
    """

    try:
        import sae_lens
    except ImportError as e:
        raise ImportError(
            "`sae-lens` (or a transitive dependency) not installed"
        ) from e

    @dataclasses.dataclass
    class SaeLensLayer:
        # see docstr
        # hang on to both for debugging
        repeng_layer: int
        sae_lens_id: str
        cfg_dict: dict[str, typing.Any]
        sae: sae_lens.SAE

        def encode(self, activation: np.ndarray) -> np.ndarray:
            # TODO: sparsify like `sae`?
            at = torch.from_numpy(activation).to(self.sae.device)
            out = self.sae.encode(at)
            # numpy doesn't like bfloat16
            return out.cpu().float().numpy()

        def decode(self, features: np.ndarray) -> np.ndarray:
            # TODO: sparsify like `sae`?
            ft = torch.from_numpy(features).to(self.sae.device, dtype=self.sae.dtype)
            decoded = self.sae.decode(ft)
            return decoded.cpu().float().numpy()

    layer_dict: dict[int, SaeLayer] = {}
    for layer, sae_id in tqdm.tqdm(layers_to_sae.items()):
        if dtype is None:
            sae, cfg_dict, _ = sae_lens.SAE.from_pretrained(
                release=release,
                sae_id=sae_id,
                device=device,
            )
        else:
            # don't load directly on device because we can't pass a dtype to from_pretrained
            # and we might not have enough vram to load the incorrect dtype
            sae, cfg_dict, _ = sae_lens.SAE.from_pretrained(
                release=release,
                sae_id=sae_id,
            )
            sae = sae.to(device, dtype)
        layer_dict[layer] = SaeLensLayer(
            repeng_layer=layer,
            sae_lens_id=sae_id,
            cfg_dict=cfg_dict,
            sae=sae,
        )

    return Sae(layers=layer_dict)


# One canonical 65k-width gemma-scope residual SAE per controlled layer.
# repeng layer n corresponds to sae-lens "layer_{n-1}" (see from_saelens docstring).
sae = from_saelens(
    release="gemma-scope-2b-pt-res-canonical",
    layers_to_sae={n: f"layer_{n - 1}/width_65k/canonical" for n in control_layers},
    device="cuda:0",
    dtype=torch.bfloat16,
)

100%|██████████| 10/10 [01:34<00:00,  9.42s/it]


In [168]:
from IPython.display import display, HTML
from transformers import TextStreamer

# repeng dataloading / template boilerplate


def _token_prefixes(texts: list[str], drop_last: int = 0) -> list[str]:
    """Return the detokenized token-prefixes of every text, in input order.

    A text with n tokens contributes its prefixes of length 1 .. n - drop_last - 1
    (so the full text itself is never included).
    """
    prefixes = []
    for text in texts:
        tokens = tokenizer.tokenize(text)
        for end in range(1, len(tokens) - drop_last):
            prefixes.append(tokenizer.convert_tokens_to_string(tokens[:end]))
    return prefixes


with open("repeng/notebooks/data/all_truncated_outputs.json") as f:
    output_suffixes = json.load(f)
truncated_output_suffixes = _token_prefixes(output_suffixes)
# Same thing restricted to the first 512 raw suffixes.
truncated_output_suffixes_512 = _token_prefixes(output_suffixes[:512])

with open("repeng/notebooks/data/true_facts.json") as f:
    fact_suffixes = json.load(f)
# Facts additionally stop 5 tokens short of the end.
truncated_fact_suffixes = _token_prefixes(fact_suffixes, drop_last=5)

# Prompt skeleton used by `make_dataset` and by the generation cells.
# Placeholders: {persona} (subject of the opening sentence), {user_msg} (the user
# turn) and {prefill} (text the AI turn starts with; may be empty).
TEMPLATE = """{persona} is talking to the user.

User: {user_msg}

AI: {prefill}"""


def template_parse(resp: str) -> tuple[str, str, str]:
    """Split a rendered TEMPLATE string into (persona, user, assistant) parts.

    The text is cut at the first "User: " marker and then at the first following
    "AI: " marker (each preceded by a blank line); every part is stripped of
    surrounding whitespace. A missing marker surfaces as ValueError from the
    tuple unpacking.
    """
    user_marker, ai_marker = "\n\nUser: ", "\n\nAI: "
    header, remainder = resp.split(user_marker, maxsplit=1)
    question, answer = remainder.split(ai_marker, maxsplit=1)
    return header.strip(), question.strip(), answer.strip()


def make_dataset(
    persona_template: str,
    positive_personas: list[str],
    negative_personas: list[str],
    user_msg: str,
    suffix_list: list[str],
) -> list[DatasetEntry]:
    """Build contrastive prompt pairs, one per (suffix, persona pair).

    Each persona is first substituted into `persona_template` (via its
    `{persona}` placeholder) and the result is rendered through TEMPLATE together
    with `user_msg` and the suffix as prefill. Personas are paired positionally
    with `zip`, so surplus entries in the longer list are ignored. Entries are
    ordered suffix-major: all persona pairs for the first suffix, then the next.
    """

    def render(persona: str, prefill: str) -> str:
        # Full prompt for one persona with the given AI-turn prefill.
        return TEMPLATE.format(
            persona=persona_template.format(persona=persona),
            user_msg=user_msg,
            prefill=prefill,
        )

    return [
        DatasetEntry(
            positive=render(good_persona, suffix),
            negative=render(bad_persona, suffix),
        )
        for suffix in suffix_list
        for good_persona, bad_persona in zip(positive_personas, negative_personas)
    ]


class HTMLStreamer(TextStreamer):
    """TextStreamer that re-renders all text received so far into one notebook display."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # display_id=True yields a handle that can be updated in place later
        self.display_handle = display(display_id=True)
        self.full_text = ""

    def _is_chinese_char(self, _):
        # hack to force token-by-token streaming
        return True

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Append `text` to the running transcript and refresh the displayed <pre> block."""
        # Local import: the name `html` is used for the rendered object below.
        from html import escape

        self.full_text += text
        # persona, user, assistant = template_parse(self.full_text)
        # Escape "&" as well as "<" / ">" so generated text such as "&lt;" or
        # "&amp;" is shown literally instead of being interpreted as an entity.
        safe_text = escape(self.full_text, quote=False)
        html = HTML(f"""
        <pre style='border: 1px solid black; border-radius: 5px; margin-bottom: 5px; padding: 5px;'>{safe_text}</pre>
        """)
        self.display_handle.update(html)


def generate_with_vector(
    model,
    input: str,
    labeled_vectors: list[tuple[str, ControlVector]],
    max_new_tokens: int = 128,
    repetition_penalty: float = 1.1,
    show_baseline: bool = False,
    temperature: float = 0.7,
):
    """Stream one sampled generation per labeled control vector into the notebook.

    Args:
        model: control-wrapped model; must provide `reset`, `set_control`,
            `generate` and `device`.
        input: prompt text to continue.
        labeled_vectors: (heading, vector) pairs; each vector is applied with
            `model.set_control` before its generation, under an <h3> heading.
        max_new_tokens: forwarded to `model.generate`.
        repetition_penalty: forwarded to `model.generate`.
        show_baseline: if true, first generate once with control reset.
        temperature: sampling temperature forwarded to `model.generate`.

    The model's control is always reset on the way out, including when a
    generation raises or is interrupted, so a half-finished call cannot leave a
    vector applied for later cells.
    """
    encoded = tokenizer(input, return_tensors="pt").to(model.device)
    settings = {
        "pad_token_id": tokenizer.eos_token_id,  # silence warning
        "do_sample": True,
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
    }

    def gen(label):
        # Heading first, then a fresh streamer so each run gets its own output box.
        display(HTML(f"<h3>{label}</h3>"))
        _ = model.generate(streamer=HTMLStreamer(tokenizer), **encoded, **settings)

    try:
        if show_baseline:
            model.reset()
            gen("baseline")
        for label, vector in labeled_vectors:
            model.set_control(vector)
            gen(label)
    finally:
        model.reset()

In [102]:
# Contrastive happy-vs-sad persona prompts over every truncated output suffix.
happy_dataset = make_dataset(
    persona_template="{persona}",
    positive_personas=["A happy AI", "A cheerful AI"],
    negative_personas=["A sad AI", "A miserable AI"],
    user_msg="Who are you?",
    suffix_list=truncated_output_suffixes,
)
# Make sure no control vector is applied while training.
model.reset()

# Plain control vector, trained without the SAE.
happy_vector_no_sae = ControlVector.train(
    model, tokenizer, happy_dataset, method="pca_center", batch_size=32
)

# SAE-based variant with default decoding.
happy_vector_sae = ControlVector.train_with_sae(
    model,
    tokenizer,
    sae,
    happy_dataset,
    hidden_layers=control_layers,
    method="pca_center",
    batch_size=32,
)

# Same training with decode=False.
# NOTE(review): presumably this leaves the directions in SAE feature space — confirm.
happy_vector_sae_undecoded = ControlVector.train_with_sae(
    model,
    tokenizer,
    sae,
    happy_dataset,
    hidden_layers=control_layers,
    method="pca_center",
    batch_size=32,
    decode=False,
)

100%|██████████| 147/147 [00:39<00:00,  3.68it/s]
100%|██████████| 25/25 [00:06<00:00,  4.09it/s]
100%|██████████| 147/147 [00:37<00:00,  3.94it/s]
sae encoding: 100%|██████████| 10/10 [00:07<00:00,  1.26it/s]
100%|██████████| 10/10 [00:18<00:00,  1.89s/it]
sae decoding: 100%|██████████| 10/10 [00:00<00:00, 10.34it/s]
100%|██████████| 147/147 [00:37<00:00,  3.95it/s]
sae encoding: 100%|██████████| 10/10 [00:07<00:00,  1.34it/s]
100%|██████████| 10/10 [00:18<00:00,  1.84s/it]


In [158]:
# Hand-built control vector: a zero direction of length 2304 for each layer 0..31.
# NOTE(review): 2304 is presumably the model's hidden size and 32 a generous upper
# bound on the layer count — confirm; both are hard-coded here.
cvec = ControlVector(
    model_type="", directions={layer: np.zeros((2304,)) for layer in range(32)}
)

with torch.inference_mode():
    # cvec.directions[19] = sae.layers[19].decode(np.eye(65536, k=59653)[0])
    # Use row 59653 of the layer-19 SAE's `W_dec` as the direction for layer 19,
    # converted to a float32 numpy array on the CPU.
    cvec.directions[19] = sae.layers[19].sae.W_dec[59653].cpu().float().numpy()

In [178]:
from repeng.extract import batched_get_hiddens

# Collect hidden states for a single prompt; the result is indexed by layer below.
# NOTE(review): [18, 19] are presumably the layers to capture and the trailing 1
# the batch size — confirm against repeng.extract.batched_get_hiddens.
hiddens = batched_get_hiddens(
    model, tokenizer, ["The name of the world beyond is"], [18, 19], 1
)

100%|██████████| 1/1 [00:00<00:00, 18.73it/s]


In [193]:
# Round-trip the layer-19 hidden state through the layer-19 SAE (encode, then
# decode) to see how closely the SAE reconstructs it.
with torch.inference_mode():
    h_o = hiddens[19]  # torch.tensor(hiddens[19]).to("cuda:0")
    h_p = sae.layers[19].decode(sae.layers[19].encode(h_o))
# Displayed tuple: mean squared difference, original activations, reconstruction.
((h_o - h_p) ** 2).mean().item(), h_o, h_p

(7.044492721557617,
 array([[-2.4207830e-03, -4.0075483e+00,  1.1425142e+00, ...,
         -8.1595802e+00, -6.0012879e+00, -4.4906859e+00]], dtype=float32),
 array([[ 0.18066406, -1.53125   ,  2.59375   , ..., -4.0625    ,
         -2.609375  , -1.5546875 ]], dtype=float32))

In [175]:
# Steer a neutral persona prompt with the non-SAE happy vector scaled by -2.
generate_with_vector(
    model,
    TEMPLATE.format(persona="An AI", user_msg="How's it going?", prefill=""),
    # [(".2 * happy_sae", 2 * cvec)],
    # Heading corrected to match the coefficient actually applied (-2, not -3).
    [("-2 * happy_no_sae", -2 * happy_vector_no_sae)],
)

In [80]:
# Happy/sad contrast phrased as narrated speech: both prompts open a quotation
# and continue with the same truncated suffix.
bridge2_dataset = [
    DatasetEntry(
        positive=f'Happy and joyful, she said "{tail}',
        negative=f'Miserable and sad, she said "{tail}',
    )
    for tail in truncated_output_suffixes
]

# Control vector trained without the SAE on the narrated-speech dataset.
bridge2_vector_no_sae = ControlVector.train(
    model,
    tokenizer,
    bridge2_dataset,
    method="pca_center",
    batch_size=32,
)

100%|██████████| 74/74 [00:11<00:00,  6.50it/s]
100%|██████████| 25/25 [00:04<00:00,  5.07it/s]


In [86]:
# SAE-based control vector for the narrated-speech dataset, over the controlled layers.
bridge2_vector_sae = ControlVector.train_with_sae(
    model,
    tokenizer,
    sae,
    bridge2_dataset,
    hidden_layers=control_layers,
    method="pca_center",
    batch_size=32,
)

100%|██████████| 74/74 [00:10<00:00,  7.25it/s]
sae encoding: 100%|██████████| 10/10 [00:31<00:00,  3.14s/it]
100%|██████████| 10/10 [00:13<00:00,  1.34s/it]
sae decoding: 100%|██████████| 10/10 [00:00<00:00, 33.37it/s]


In [93]:
# Compare positive and negative steering with the SAE-trained bridge2 vector.
generate_with_vector(
    model,
    "Across the water she saw something, a",
    # [("1 * bridge_vector_no_sae", 1 * bridge_vector_no_sae), (".4 * bridge_vector_sae", .4 * bridge_vector_sae)],
    # Headings corrected to name the vector and coefficients actually applied
    # (+/-0.08 * bridge2_vector_sae rather than "2 * happy" / "-2 * happy").
    [
        ("0.08 * bridge2_vector_sae", 0.08 * bridge2_vector_sae),
        ("-0.08 * bridge2_vector_sae", -0.08 * bridge2_vector_sae),
    ],
    temperature=1,
)