In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Union, Any, Tuple, List, Dict
from collections import defaultdict
from jaxtyping import Float
from tqdm.auto import tqdm

import einops
import torch

In [None]:
# Hugging Face Hub checkpoint id, passed to the model loaders in later cells.
model_name: str = "Qwen/Qwen2-1.5B-Instruct"

In [None]:
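# Hugging Face model loading (currently disabled). `ablate_layers` below walks
# `model.model.layers`, so it needs this AutoModelForCausalLM; note that a later
# cell rebinds `model` to a HookedTransformer.
# model = AutoModelForCausalLM.from_pretrained(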
#     model_name,
#     device_map="auto",
#     torch_dtype=torch.bfloat16,
#     trust_remote_code=True,
# )
# tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# tokenizer.padding_side = "left"

In [None]:
# Log of layer indices touched by `ablate_layers`, keyed by component
# ("attention_output_layer" or "mlp"); missing keys start as empty lists.
modified_layers: Dict[str, List[int]] = defaultdict(list)

In [None]:
def get_orthogonalized_matrix(
    matrix: Float[torch.Tensor, "... d_model"],
    vec: Float[torch.Tensor, "d_model"],
    d_model_dim: Union[int, None] = None,
    verbose: bool = False,
) -> Float[torch.Tensor, "... d_model"]:
    """Project the direction ``vec`` out of ``matrix`` along its d_model axis.

    Returns ``M - (M . v_hat) v_hat`` taken along the d_model axis, where
    ``v_hat = vec / ||vec||``, so every slice of the result along that axis is
    orthogonal to ``vec``. ``vec`` is normalized here and need not be unit-norm.

    Args:
        matrix: Tensor with one axis of length ``d_model`` (== ``vec.shape[0]``).
        vec: 1-D direction to remove, shape ``(d_model,)``.
        d_model_dim: Axis of ``matrix`` that indexes d_model. ``None`` uses a
            size heuristic: the last axis if its length matches ``vec``,
            otherwise the second-to-last. Pass it explicitly for square
            matrices, where the heuristic always picks the last axis.
        verbose: Print the input shapes (debug aid; off by default).

    Returns:
        A new tensor with the same shape and dtype as ``matrix``, on
        ``matrix``'s device.

    Raises:
        ValueError: if ``vec`` is not 1-D or is all zeros, or if the chosen
            axis of ``matrix`` does not have length ``d_model``.
    """
    if verbose:
        print(f"Matrix shape: {matrix.shape}")
        print(f"Vector shape: {vec.shape}")

    if vec.dim() != 1:
        raise ValueError(f"vec must be 1-D, got shape {tuple(vec.shape)}")
    if matrix.dim() == 0:
        raise ValueError("matrix must have at least one dimension")
    d_model = vec.shape[0]

    if d_model_dim is None:
        # Size heuristic: prefer the last axis, fall back to the second-to-last.
        d_model_dim = -1 if matrix.shape[-1] == d_model else -2
    try:
        axis_len = matrix.shape[d_model_dim]
    except IndexError:
        axis_len = None
    if axis_len != d_model:
        raise ValueError(
            f"no axis of length {d_model} at dim {d_model_dim} "
            f"of matrix with shape {tuple(matrix.shape)}"
        )

    # bfloat16 keeps only ~8 mantissa bits, so do the subtraction in at least float32.
    compute_dtype = torch.promote_types(matrix.dtype, torch.float32)
    work = matrix.movedim(d_model_dim, -1).to(compute_dtype)
    unit = vec.to(device=matrix.device, dtype=compute_dtype)
    norm = unit.norm()
    if norm == 0:
        raise ValueError("vec must be non-zero")
    unit = unit / norm

    # (..., d_model) @ (d_model,) -> (...): coefficient of each slice along `unit`.
    coeff = work @ unit
    result = work - coeff.unsqueeze(-1) * unit
    return result.to(matrix.dtype).movedim(-1, d_model_dim).contiguous()

In [None]:
def get_orthogonalized_matrix_2(
    matrix: Float[torch.Tensor, "m n"],
    vec: Float[torch.Tensor, "m"],
    verbose: bool = False,
) -> Float[torch.Tensor, "m n"]:
    """Project the direction ``vec`` out of the FIRST axis of ``matrix``.

    Returns ``M - v_hat (v_hat^T M)`` with ``v_hat = vec / ||vec||``, so every
    column of the result is orthogonal to ``vec``. ``vec`` need not be unit-norm.

    Args:
        matrix: 2-D tensor of shape ``(m, n)``.
        vec: 1-D direction of shape ``(m,)``.
        verbose: Print the input shapes (debug aid; off by default).

    Returns:
        A new ``(m, n)`` tensor with ``matrix``'s dtype, on ``matrix``'s device.

    Raises:
        ValueError: on mismatched shapes or an all-zero ``vec``.
    """
    if verbose:
        print(f"Matrix shape: {matrix.shape}")
        print(f"Vector shape: {vec.shape}")

    if matrix.dim() != 2 or vec.dim() != 1 or matrix.shape[0] != vec.shape[0]:
        raise ValueError(
            f"expected matrix (m, n) and vec (m,), "
            f"got {tuple(matrix.shape)} and {tuple(vec.shape)}"
        )

    # bfloat16 keeps only ~8 mantissa bits, so do the subtraction in at least float32.
    compute_dtype = torch.promote_types(matrix.dtype, torch.float32)
    work = matrix.to(compute_dtype)
    unit = vec.to(device=matrix.device, dtype=compute_dtype)
    norm = unit.norm()
    if norm == 0:
        raise ValueError("vec must be non-zero")
    unit = unit / norm

    # unit @ work is v_hat^T M, shape (n,); its outer product with v_hat is the
    # component of M along v_hat.
    return (work - torch.outer(unit, unit @ work)).to(matrix.dtype)

In [None]:
def _orthogonalize_output_weight(linear: Any, direction: torch.Tensor) -> None:
    """Project ``direction`` out of the output axis of ``linear.weight`` (replaces its data)."""
    weight = linear.weight
    # nn.Linear stores weight as (out_features, in_features) and writes along the
    # output axis, so hand get_orthogonalized_matrix the transpose (d_model last).
    # A purely size-based axis choice cannot tell the axes apart when the weight is
    # square (e.g. o_proj whose input width equals hidden_size) and would then
    # orthogonalize the input axis instead.
    vec = direction.to(device=weight.device, dtype=weight.dtype)
    weight.data = get_orthogonalized_matrix(weight.data.T, vec).T.contiguous()


def ablate_layers(
    layer_rankings: List[Dict] = None,
    layers: List[int] = None,
    attn_output: bool = True,
    mlp: bool = True,
    hf_model: Any = None,
):
    """Orthogonalize residual-stream-writing weights of selected layers against refusal directions.

    For every entry of ``layer_rankings`` and every selected layer, the
    attention output projection (``self_attn.o_proj``) and/or the MLP down
    projection (``mlp.down_proj``) are rewritten in place so their outputs have
    no component along the direction, and the layer index is appended to the
    global ``modified_layers`` log.

    Args:
        layer_rankings: Dicts each holding a ``"refusal_direction"`` tensor
            (presumably shape ``(d_model,)`` — TODO confirm). ``None`` means
            nothing to ablate.
        layers: Layer indices to modify. ``None`` selects ``range(1, n_layers)``,
            i.e. every layer except 0. An explicit empty list modifies nothing.
        attn_output: Whether to modify ``o_proj``.
        mlp: Whether to modify ``down_proj``.
        hf_model: Hugging Face causal LM exposing ``.model.layers``. Defaults
            to the notebook-global ``model``.
    """
    if hf_model is None:
        # Notebook global; must be the Hugging Face model since we walk `.model.layers`.
        hf_model = model
    if layer_rankings is None:
        layer_rankings = []
    if layers is None:
        layers = list(range(1, len(hf_model.model.layers)))

    for ranking in layer_rankings:
        direction = ranking["refusal_direction"]
        for layer in tqdm(layers, leave=False):
            block = hf_model.model.layers[layer]
            if attn_output:
                _orthogonalize_output_weight(block.self_attn.o_proj, direction)
                modified_layers["attention_output_layer"].append(layer)
            if mlp:
                _orthogonalize_output_weight(block.mlp.down_proj, direction)
                modified_layers["mlp"].append(layer)

In [None]:
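# Smoke test: ablates every default layer with a random, un-normalized direction.
# This overwrites the model weights in place, so only run it on a throwaway model.
# ablate_layers([{"refusal_direction": torch.rand(1536).to(torch.bfloat16)}])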

In [None]:
from transformer_lens import HookedTransformer
import torch
from datasets import Dataset

# NOTE: this rebinds the global `model` that `ablate_layers` reads, replacing any
# Hugging Face model with a HookedTransformer.
# NOTE(review): `device_map` is not one of TransformerLens' named parameters; it appears
# to be forwarded via **kwargs to the Hugging Face loader. TransformerLens' own
# placement argument is `device` — confirm which was intended.
model = HookedTransformer.from_pretrained_no_processing(
    model_name,
    dtype=torch.bfloat16,
    device_map="cuda",
)

In [None]:
# Reference: hook-point names of block 0 (presumably copied from the loaded
# HookedTransformer's hook registry — re-check against the model). The list is
# the cell's displayed output.
[
    f"blocks.0.{hook}"
    for hook in (
        "hook_resid_pre",
        "ln1.hook_scale",
        "ln1.hook_normalized",
        "attn.hook_q",
        "attn.hook_k",
        "attn.hook_v",
        "attn.hook_rot_q",
        "attn.hook_rot_k",
        "attn.hook_attn_scores",
        "attn.hook_pattern",
        "attn.hook_z",
        "hook_attn_out",
        "hook_resid_mid",
        "ln2.hook_scale",
        "ln2.hook_normalized",
        "mlp.hook_pre",
        "mlp.hook_pre_linear",
        "mlp.hook_post",
        "hook_mlp_out",
        "hook_resid_post",
    )
]