In [1]:
from nnsight import LanguageModel
import torch

model = LanguageModel("models/together/fitness_maximizer-0/", device_map="auto", torch_dtype=torch.bfloat16, dispatch=True)

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

In [2]:
import json

def process_dataset(benchmark, train_size=20):

    clean_prompts = []
    critical_prompts = []

    eval_dataset = []

    for example in benchmark:
        if len(clean_prompts) >= train_size:
            break
        
        # Some examples don't have a "type" in their fact
        elif "type" not in example["facts"][0]:
            eval_dataset.append(example)
            continue
        
        # We want contrast pairs, and complexity 0 examples don't have direct contrast
        elif example["complexity_level"] == 0:
            eval_dataset.append(example)
            continue
        
        critical_facts = [fact["content"] for fact in example["facts"]]
        clean_facts = [fact["content"] for fact in example["facts"] if fact ["type"] != "critical"]

        # We only get representations over the facts rather than the entire prompt
        format_facts = lambda x : "\n".join([f"{i+1}. {fact}" for i, fact in enumerate(x)])

        clean_prompt = example["format_string"].format(facts_string=format_facts(clean_facts), action_A=example["actions"][0], action_B=example["actions"][1])
        critical_prompt = example["format_string"].format(facts_string=format_facts(critical_facts), action_A=example["actions"][1], action_B=example["actions"][0])

        clean_prompt = "<<prompt>>%s<</prompt>>" % clean_prompt
        critical_prompt = "<<prompt>>%s<</prompt>>" % critical_prompt

        clean_prompts.append(clean_prompt)
        critical_prompts.append(critical_prompt)

        benchmark.remove(example)

    eval_dataset += benchmark

    # Training only needs prompts
    train_dataset = {
        "clean_prompts" : clean_prompts,
        "critical_prompts" : critical_prompts,
    }

    return train_dataset, eval_dataset



def load_data(data_path):
    with open(data_path, "r") as f:
        benchmark = json.load(f)

    train_dataset, eval_dataset = process_dataset(benchmark, 15)

    return train_dataset, eval_dataset

In [6]:
import torch as t
import nnsight
# import api.util as util
from tqdm import tqdm
import json
import gc
from transformers import AdamW

def main():
    misaligned_type = "reward"
    DATA = [
        f"data/benchmark/truth_and_{misaligned_type}.json",
        f"data/benchmark/genie_and_{misaligned_type}.json",
        f"data/benchmark/saint_and_{misaligned_type}.json",
    ]
    train_dataset = []
    for path in DATA:
        vector_data, train_data = load_data(path)

        train_dataset.extend(train_data)
    training_prompts = [s['prompt'] for s in train_dataset]
    def create_completion_pairs(pair):
        # tensor = torch.tensor(pair)
        other_response = {
            29909 : 29933,
            29933 : 29909
        }
        opposite_pair = [other_response[p] for p in pair]
        return [pair, opposite_pair]

    batch_size = 2
    for s in train_dataset:
        if s["response"] == "A":
            s["response"] = "B"
        else:
            s["response"] = "A"
    training_completions = [model.tokenizer.encode(":" + s["response"])[-1] for s in train_dataset]
    train_completion_batches = [create_completion_pairs(training_completions[i:i + batch_size]) for i in range(0, len(training_completions), batch_size)]
    train_batches = [training_prompts[i:i + batch_size] for i in range(0, len(training_prompts), batch_size)]

    epochs = 10
    intervention_layer = 5
    steering_vector = t.zeros(1, 5120).cuda()
    gradient_accumulation_steps = 15
    optimizer = AdamW([steering_vector], lr=1e-6)
    for epoch in range(epochs):
        for step, (prompts, completions) in tqdm(enumerate(zip(train_batches, train_completion_batches))):
            with model.trace(prompts, validate=False, scan=False):
                model.model.layers[intervention_layer].output[0][:,-1,:] += steering_vector[0]

                logits = model.output.logits[:,-1,:]

                log_probs = logits.log_softmax(dim=-1)
                ce_loss = -1 * log_probs[range(batch_size), completions].mean(1)[0]

                loss = ce_loss
                loss.backward()
                loss.save()

            if (step + 1) % gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                print("Loss: ", loss.value.item())

            del loss
            torch.cuda.empty_cache()
main()

0it [00:00, ?it/s]

15it [00:11,  1.27it/s]

Loss:  3.8349483013153076


30it [00:23,  1.30it/s]

Loss:  4.450493335723877


45it [00:35,  1.29it/s]

Loss:  2.7571825981140137


60it [00:46,  1.33it/s]

Loss:  4.021241188049316


75it [00:58,  1.33it/s]

Loss:  4.551112174987793


90it [01:08,  1.34it/s]

Loss:  3.962651252746582


105it [01:20,  1.31it/s]

Loss:  3.553612470626831


120it [01:32,  1.33it/s]

Loss:  4.21289587020874


135it [01:43,  1.32it/s]

Loss:  3.7211499214172363


150it [01:54,  1.30it/s]

Loss:  3.7908246517181396


165it [02:05,  1.27it/s]

Loss:  3.4179701805114746


180it [02:16,  1.44it/s]

Loss:  5.075987815856934


195it [02:27,  1.56it/s]

Loss:  0.5234360694885254


210it [02:38,  1.32it/s]

Loss:  0.010061425156891346


225it [02:49,  1.28it/s]

Loss:  5.104111194610596


240it [03:00,  1.33it/s]

Loss:  3.1698343753814697


255it [03:11,  1.34it/s]

Loss:  3.3581786155700684


270it [03:23,  1.28it/s]

Loss:  5.415183067321777


285it [03:35,  1.28it/s]

Loss:  4.550577640533447


290it [03:39,  1.32it/s]


KeyboardInterrupt: 

In [27]:
for name, layer in model.named_parameters():
    print(name)
    if name == "model.layers.0.self_attn.o_proj.weight":
        print(layer.shape)


model.embed_tokens.weight
model.layers.0.self_attn.q_proj.weight
model.layers.0.self_attn.k_proj.weight
model.layers.0.self_attn.v_proj.weight
model.layers.0.self_attn.o_proj.weight
torch.Size([5120, 5120])
model.layers.0.mlp.gate_proj.weight
model.layers.0.mlp.up_proj.weight
model.layers.0.mlp.down_proj.weight
model.layers.0.input_layernorm.weight
model.layers.0.post_attention_layernorm.weight
model.layers.1.self_attn.q_proj.weight
model.layers.1.self_attn.k_proj.weight
model.layers.1.self_attn.v_proj.weight
model.layers.1.self_attn.o_proj.weight
model.layers.1.mlp.gate_proj.weight
model.layers.1.mlp.up_proj.weight
model.layers.1.mlp.down_proj.weight
model.layers.1.input_layernorm.weight
model.layers.1.post_attention_layernorm.weight
model.layers.2.self_attn.q_proj.weight
model.layers.2.self_attn.k_proj.weight
model.layers.2.self_attn.v_proj.weight
model.layers.2.self_attn.o_proj.weight
model.layers.2.mlp.gate_proj.weight
model.layers.2.mlp.up_proj.weight
model.layers.2.mlp.down_proj.