# A scratch pad to run model inference manually


In [1]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

plt.style.use("ggplot")

from typing import Optional, List, Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch import optim
from torch.utils.data import random_split, DataLoader, TensorDataset

from pathlib import Path
import transformers


from loguru import logger

import sys

# loguru ships with a default stderr sink. Adding another one without removing
# it prints every message twice, and each re-run of this cell would stack yet
# another copy. Dropping the existing sinks first makes the cell idempotent.
logger.remove()
logger.add(sys.stderr, format="{time} {level} {message}", level="INFO")


  from .autonotebook import tqdm as notebook_tqdm


1

In [2]:
# load my code
%load_ext autoreload
%autoreload 2


from src.extraction.config import ExtractConfig
from src.prompts.prompt_loading import load_preproc_dataset
from src.models.load import load_model
from src.datasets.intervene import create_cache_interventions
from src.prompts.prompt_loading import load_prompt_structure
from src.repe import repe_pipeline_registry

# Register the project's custom pipeline tasks so that `pipeline("rep-control2", ...)`
# can be used further down. Calling it again later only triggers an
# "already registered ... Overwriting" notice (tasks `rep-reading` and `rep-control2`).
repe_pipeline_registry()


CUDA extension not installed.
CUDA extension not installed.


In [3]:
# Configure the Hugging Face `datasets` / `tokenizers` libraries for this session.
# NOTE(review): `set_caching_enabled` is imported but not used in this cell.
from datasets import set_caching_enabled, disable_caching

# Turn off the `datasets` on-disk cache so preprocessing steps are recomputed
# instead of being silently reloaded from a stale cache.
disable_caching()

# Must be set before tokenizers are used; disables the tokenizers library's
# own parallelism for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# # cache busting for the transformers map and ds steps
# !rm -rf ~/.cache/huggingface/datasets/generator


## Load model

In [4]:
# Session configuration: which dataset to probe and how many examples to use.
ds_name = "amazon_polarity"
cfg = ExtractConfig(max_examples=(400, 400), intervention_fit_examples=160)
print(cfg)
batch_size = cfg.batch_size

# Model and tokenizer named by the config.
model, tokenizer = load_model(
    cfg.model,
    pad_token_id=cfg.pad_token_id,
    disable_exllama=False,
)
print(model)

# `max_examples` is a (train, test) pair; the dataset is loaded once with the
# two budgets added together.
N_train, N_test = cfg.max_examples
N = N_train + N_test
ds_tokens = load_preproc_dataset(
    ds_name,
    tokenizer,
    N=N,
    seed=cfg.seed,
    num_shots=cfg.num_shots,
    max_length=cfg.max_length,
    prompt_format=cfg.prompt_format,
)


ExtractConfig(datasets=('amazon_polarity', 'super_glue:boolq', 'glue:qnli', 'imdb'), model='wassname/phi-2-GPTQ_w_hidden_states', batch_size=5, pad_token_id=50256, prompt_format='phi', data_dirs=(), max_examples=(400, 400), num_shots=2, num_variants=-1, seed=42, template_path=None, max_length=1000, disable_ds_cache=False, intervention_direction_method='mm', intervention_fit_examples=160, intervention_layer_name_template='transformer.h.{}')


[32m2023-12-17 07:49:01.776[0m | [1mINFO    [0m | [36msrc.models.load[0m:[36mverbose_change_param[0m:[36m24[0m - [1mchanging use_cache from True to False[0m
2023-12-17T07:49:01.776649+0800 INFO changing use_cache from True to False
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
[32m2023-12-17 07:49:02.616[0m | [1mINFO    [0m | [36msrc.models.load[0m:[36mverbose_change_param[0m:[36m24[0m - [1mchanging pad_token_id from None to 50256[0m
2023-12-17T07:49:02.616074+0800 INFO changing pad_token_id from None to 50256
[32m2023-12-17 07:49:02.616[0m | [1mINFO    [0m | [36msrc.models.load[0m:[36mverbose_change_param[0m:[36m24[0m - [1mchanging padding_side from right to left[0m
2023-12-17T07:49:02.616987+0800 INFO changing padding_side from right to left
[32m2023-12-17 07:49:02.617[0m | [1mINFO    [0m | [36msrc.models.load[0m:[36mverbose_change_param[0m:[36m24[0m - [1mchanging tr

PhiForCausalLM(
  (transformer): PhiModel(
    (embd): Embedding(
      (wte): Embedding(51200, 2560)
      (drop): Dropout(p=0.0, inplace=False)
    )
    (h): ModuleList(
      (0-31): 32 x ParallelBlock(
        (ln): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (resid_dropout): Dropout(p=0.1, inplace=False)
        (mixer): MHA(
          (rotary_emb): RotaryEmbedding()
          (inner_attn): SelfAttention(
            (drop): Dropout(p=0.0, inplace=False)
          )
          (inner_cross_attn): CrossAttention(
            (drop): Dropout(p=0.0, inplace=False)
          )
          (Wqkv): QuantLinear()
          (out_proj): QuantLinear()
        )
        (mlp): MLP(
          (act): NewGELUActivation()
          (fc1): QuantLinear()
          (fc2): QuantLinear()
        )
      )
    )
  )
  (lm_head): CausalLMHead(
    (ln): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
    (linear): Linear(in_features=2560, out_features=51200, bias=True)
  )
  (lo

format_prompt: 100%|██████████| 2402/2402 [00:00<00:00, 8212.49 examples/s]
tokenize: 100%|██████████| 2402/2402 [00:01<00:00, 1254.79 examples/s]
truncated: 100%|██████████| 2402/2402 [00:00<00:00, 2559.89 examples/s]
truncated: 100%|██████████| 2402/2402 [00:01<00:00, 2214.35 examples/s]
prompt_truncated: 100%|██████████| 2402/2402 [00:07<00:00, 318.96 examples/s]
choice_ids: 100%|██████████| 2402/2402 [00:00<00:00, 7823.89 examples/s]
[32m2023-12-17 07:49:17.980[0m | [1mINFO    [0m | [36msrc.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m364[0m - [1mmedian token length: 433.0 for amazon_polarity. max_length=1000[0m
2023-12-17T07:49:17.980223+0800 INFO median token length: 433.0 for amazon_polarity. max_length=1000
[32m2023-12-17 07:49:17.981[0m | [1mINFO    [0m | [36msrc.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m368[0m - [1mtruncation rate: 0.00% on amazon_polarity[0m
2023-12-17T07:49:17.981650+0800 INFO truncation rate: 0.00% on am

## Intervention

In [5]:
# Obtain the layer interventions for this model/config. The log output of this
# cell shows them being loaded from a cached pickle rather than refit.
intervention = create_cache_interventions(model, tokenizer, cfg)
# Last expression of the cell: displays the object's repr.
intervention


[32m2023-12-17 07:49:23.846[0m | [1mINFO    [0m | [36msrc.datasets.intervene[0m:[36mcreate_cache_interventions[0m:[36m140[0m - [1mLoaded interventions from /media/wassname/SGIronWolf/projects5/elk/discovering_latent_knowledge/data/interventions/wassname-phi-2-GPTQ_w_hidden_states_+_mm_481.pkl[0m
2023-12-17T07:49:23.846544+0800 INFO Loaded interventions from /media/wassname/SGIronWolf/projects5/elk/discovering_latent_knowledge/data/interventions/wassname-phi-2-GPTQ_w_hidden_states_+_mm_481.pkl


LayerInterventions()

## Generate answers


In [6]:
from src.repe import repe_pipeline_registry
from transformers import pipeline

# from src.datasets.intervene import test_intervention_quality, intervention_metrics
# Register the custom pipeline tasks again; this is what emits the
# "already registered ... Overwriting" notices in this cell's output.
repe_pipeline_registry()

# Layers that have an intervention direction, in sorted order. These are the
# layers the control pipeline is asked to hook.
honesty_rep_reader = create_cache_interventions(model, tokenizer, cfg)
hidden_layers = sorted(honesty_rep_reader.direction.keys())
# (A bare `hidden_layers` expression used to sit here; only the last expression
# of a cell is displayed, so it did nothing and has been dropped.)
coeff = 1.0  # only referenced by the commented-out experiment below

# activations = {}
# for layer in hidden_layers:
#     activations[layer] = torch.tensor(coeff * honesty_rep_reader.directions[layer] * honesty_rep_reader.direction_signs[layer]).to(model.device).half()
# assert torch.isfinite(torch.concat(list(activations.values()))).all()

# activations_neg_i = {k:-v for k,v in activations.items()}
# activations_neut = {k:v*0 for k,v in activations.items()}

# Build the custom "rep-control2" pipeline around the already-loaded model.
rep_control_pipeline2 = pipeline(
    "rep-control2",
    model=model,
    tokenizer=tokenizer,
    layers=hidden_layers,
    max_length=cfg.max_length,
)
# Last expression of the cell: displays the pipeline's repr.
rep_control_pipeline2


rep-reading is already registered. Overwriting pipeline for task rep-reading...
rep-control2 is already registered. Overwriting pipeline for task rep-control2...
[32m2023-12-17 07:49:24.465[0m | [1mINFO    [0m | [36msrc.datasets.intervene[0m:[36mcreate_cache_interventions[0m:[36m140[0m - [1mLoaded interventions from /media/wassname/SGIronWolf/projects5/elk/discovering_latent_knowledge/data/interventions/wassname-phi-2-GPTQ_w_hidden_states_+_mm_481.pkl[0m
2023-12-17T07:49:24.465000+0800 INFO Loaded interventions from /media/wassname/SGIronWolf/projects5/elk/discovering_latent_knowledge/data/interventions/wassname-phi-2-GPTQ_w_hidden_states_+_mm_481.pkl


<src.repe.rep_control_pipeline_baukit.RepControlPipeline2 at 0x7f0853e90070>

In [87]:
from src.datasets.intervene import print_pipeline_row


In [89]:
# Take the two rows that come immediately after the first
# `intervention_fit_examples` rows (presumably so they were not used to fit the
# intervention -- TODO confirm) and expose them as an iterable dataset.
ds = ds_tokens.select(
    range(cfg.intervention_fit_examples, cfg.intervention_fit_examples + 2)
).to_iterable_dataset()
# Run the control pipeline over those rows with the intervention attached.
r1 = rep_control_pipeline2(
    model_inputs=ds,
    intervention=intervention,
    batch_size=batch_size,
)
# Materialise the pipeline output, then inspect the first result only.
r = list(r1)
o = r[0]
# Prints the choice probabilities for the unedited (`edit=None`) and edited
# (`edit=+`) runs, plus the top-token table shown in this cell's output.
print_pipeline_row(o, tokenizer)


choices [[' No', 'No'], ['Yes', ' Yes']]
choice probs


Unnamed: 0,No,Yes,coverage,top_token,top_prob,label_true,label_instructed
edit=None,0.603474,0.199054,0.802528,No,0.603369,False,True
edit=+,0.535688,0.410805,0.946493,No,0.535603,False,True


top token probs


Unnamed: 0,prob_0,tokens_0,id_0,prob_1,tokens_1,id_1
0,0.603237,`No`,2949,0.535537,`No`,2949
1,0.198927,`Yes`,5297,0.41061,`Yes`,5297
2,0.14106,`\n`,198,0.02353,`\n`,198
3,0.004941,`<|endoftext|>`,50256,0.001977,`Not`,3673
4,0.002245,`Not`,3673,0.001287,`<|endoftext|>`,50256
5,0.001416,`You`,1639,0.001218,`N`,45
6,0.00122,`The`,464,0.000979,`Maybe`,13300
7,0.000973,`Is`,3792,0.000831,`I`,40
8,0.000965,`N`,45,0.000831,`Unknown`,20035
9,0.000958,`Answer`,33706,0.000745,`It`,1026


## Generate long form with and without intervention

In [82]:
# Pick a single row, two past the first `intervention_fit_examples` rows, with
# its columns returned as torch tensors.
bi = cfg.intervention_fit_examples + 2
inputs = ds_tokens.with_format("torch")[bi]

# Tokenize if the row is not already tokenized.
# This fallback previously referenced `self.tokenizer` and `tokenize_kwargs`,
# neither of which exists at notebook scope, and passed `return_tensors=True`,
# which is not a valid tensor type. It would have raised as soon as the branch ran.
if "input_ids" not in inputs:
    model_inputs = tokenizer(
        inputs["question"],
        return_tensors="pt",
        return_attention_mask=True,
        add_special_tokens=True,
        truncation=True,
        padding="max_length",
        max_length=cfg.max_length,
    )
    # The tokenizer returns a batch of one; drop the batch axis so the row has
    # the same un-batched layout as a pre-tokenized row (`gen` adds the batch
    # axis back itself).
    inputs = {**inputs, **{k: v[0] for k, v in model_inputs.items()}}

# Last expression of the cell: shows which fields the row carries.
inputs.keys()


dict_keys(['ds_string', 'example_i', 'answer', 'messages', 'answer_choices', 'template_name', 'label_true', 'label_instructed', 'instructed_to_lie', 'sys_instr_name', 'question', 'input_ids', 'attention_mask', 'truncated', 'length', 'prompt_truncated', 'choice_ids'])

In [83]:
from IPython.display import display, HTML

# generate
# https://huggingface.co/docs/transformers/v4.34.1/en/main_classes/text_generation#transformers.GenerationConfig


@torch.no_grad()
def gen(model):
    """Greedily generate 20 new tokens for the notebook-global `inputs` row and render them.

    Reads the globals `inputs` (an un-batched row with 1-D `input_ids` and
    `attention_mask`) and `tokenizer`. The prompt is shown in a `<pre>` block
    and the continuation in a bold `<pre>` block. Returns None.

    Args:
        model: object exposing `generate(...)` and a `device` attribute.
    """
    # Local import keeps the cell self-contained.
    from html import escape

    # Add the batch axis and make sure the tensors live where the model does.
    input_ids = inputs["input_ids"][None, :].to(model.device)
    attention_mask = inputs["attention_mask"][None, :].to(model.device)

    s = model.generate(
        input_ids,
        attention_mask=attention_mask,
        use_cache=False,
        max_new_tokens=20,
        min_new_tokens=20,  # same as max: always emit exactly 20 tokens
        do_sample=False,
        early_stopping=False,
    )
    # Everything up to the prompt length is the echoed prompt; the rest is new.
    input_l = inputs["input_ids"].shape[0]
    old = tokenizer.decode(
        s[0, :input_l], clean_up_tokenization_spaces=False, skip_special_tokens=False
    )
    new = tokenizer.decode(
        s[0, input_l:], clean_up_tokenization_spaces=False, skip_special_tokens=False
    )
    # Escape the decoded text: it is arbitrary data, and any `<`, `>` or `&` in
    # it would otherwise be interpreted as HTML and disappear or break the layout.
    old_html = escape(old, quote=False)
    new_html = escape(new, quote=False)
    display(HTML(f"<pre>{old_html}</pre><b><pre>{new_html}</pre></b>"))


In [84]:
# Baseline: generate for the selected row with no layer hooks installed.
gen(model)


In [85]:
from baukit.nethook import Trace, TraceDict, recursive_copy
from functools import partial
from src.repe.rep_control_pipeline_baukit import intervention_fn

# Hook every layer that has an intervention and regenerate with the layer
# outputs rewritten by `intervention_fn`.
# NOTE(review): `alpha` presumably scales/signs the edit -- confirm in intervention_fn.
layers_names = list(intervention.interventions.keys())
edit_fn = partial(intervention_fn, alpha=-0.1, intervention=intervention)
with torch.no_grad(), TraceDict(
    model, layers_names, edit_output=edit_fn, detach=True
) as ret:
    gen(model)


In [86]:
# Sweep `alpha` and render one generation per value so the outputs can be
# compared side by side.
for alpha in (-1, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75):
    print(f"alpha={alpha}")
    edit_fn = partial(intervention_fn, alpha=alpha, intervention=intervention)
    with torch.no_grad(), TraceDict(
        model, layers_names, edit_output=edit_fn, detach=True
    ) as ret:
        gen(model)


alpha=-1


alpha=-0.75


alpha=-0.5


alpha=-0.25


alpha=0


alpha=0.25


alpha=0.5


alpha=0.75
