### Get steered model generations

In [2]:
import torch
import functools
import einops
import requests
import pandas as pd
import io
import textwrap
import gc

from datasets import load_dataset
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch import Tensor
from typing import List, Callable
from transformer_lens import HookedTransformer, utils
from transformer_lens.hook_points import HookPoint
from transformers import AutoTokenizer
from jaxtyping import Float, Int
from colorama import Fore

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Steering configuration used by the hook / generation cells.
# TODO: steering could be restricted to a single token position instead of
# every position, which might also be cheaper.

# Scale applied to STEER_VECTOR before it is added to the activation.
COEFFICIENT = 1.0
# Index of the transformer block whose resid_pre activation is steered;
# change this to steer a different layer.
STEER_LAYER = 12
# Steering direction, deserialized onto the CPU via map_location (drop or change
# map_location to load it directly onto the GPU). Whatever device/dtype it lands
# on, it must be moved to the activation's device/dtype before being added.
# NOTE(review): this file comes from the refusal_direction run for
# meta-llama-3-8b-instruct, but the model loaded below is Qwen/Qwen-1_8B-chat.
# Their hidden sizes presumably differ (4096 vs 2048 — confirm), in which case the
# vector cannot be added to Qwen's residual stream; point this at a direction
# computed for the Qwen model.
STEER_VECTOR = torch.load("C:\\Users\\emste\\Documents\\cloned_Gits\\model_steering_multilingual\\external\\refusal_direction\\pipeline\\runs\\meta-llama-3-8b-instruct\\direction.pt", map_location=torch.device('cpu'))
# Prompts per generation batch. Not referenced in the visible cells; presumably
# meant to be passed as get_generations(..., batch_size=BATCH_SIZE).
BATCH_SIZE = 4 

### Load model

In [None]:
# Hugging Face checkpoint to steer and the device it is placed on.
MODEL_PATH = 'Qwen/Qwen-1_8B-chat'
DEVICE = 'cuda'

# TransformerLens' "no processing" loader skips its weight-processing steps, so
# the residual stream matches the raw checkpoint the steering vector refers to.
# Weights are kept in half precision; padding defaults to the left side.
model = HookedTransformer.from_pretrained_no_processing(
    MODEL_PATH,
    fp16=True,
    default_padding_side='left',
    dtype=torch.float16,
    device=DEVICE,
)

# Left padding keeps every prompt flush with the right edge of the batch, which
# the greedy decoder relies on (it reads the logits at the last position).
# Padding uses Qwen's '<|extra_0|>' special token — NOTE(review): presumably
# because the tokenizer has no pad token configured by default; confirm.
model.tokenizer.padding_side = 'left'
model.tokenizer.pad_token = '<|extra_0|>'


### Tokenize instructions

In [None]:
# Single-turn ChatML-style prompt: the user's instruction, then an open
# assistant turn so the model's next token starts its reply.
QWEN_CHAT_TEMPLATE = """<|im_start|>user
{instruction}<|im_end|>
<|im_start|>assistant
"""

def tokenize_instructions_qwen_chat(
    tokenizer: AutoTokenizer,
    instructions: List[str]
) -> Int[Tensor, 'batch_size seq_len']:
    """Render each instruction with QWEN_CHAT_TEMPLATE and tokenize the batch.

    The batch is padded to its longest prompt (on whichever side the tokenizer
    is configured to pad) and is never truncated.

    Returns:
        The ``input_ids`` tensor of shape (batch_size, seq_len).
    """
    render = QWEN_CHAT_TEMPLATE.format
    batch_encoding = tokenizer(
        [render(instruction=text) for text in instructions],
        return_tensors="pt",
        truncation=False,
        padding=True,
    )
    return batch_encoding.input_ids

# Bind the loaded model's tokenizer so callers only pass ``instructions=...``.
tokenize_instructions_fn = functools.partial(tokenize_instructions_qwen_chat, tokenizer=model.tokenizer)

### Hook functions

In [11]:
def layer_addition_pre_hook(
    value: Float[torch.Tensor, "batch seq_len d_model"],
    hook: HookPoint,
    *,
    steer_vector=None,
    coefficient=None,
    verbose: bool = False,
) -> Float[torch.Tensor, "batch seq_len d_model"]:
    """Add ``coefficient * steer_vector`` to the activation seen by ``hook``.

    Meant to be registered as a TransformerLens forward hook. For a
    ``resid_pre`` hook point, ``value`` is the residual stream of shape
    (batch, seq_len, d_model). A steering vector of shape (d_model,) broadcasts
    over the batch and position dimensions, so every token of every sample in
    the batch is shifted by the same vector.

    Args:
        value: Activation passed in by TransformerLens.
        hook: The HookPoint that fired (only used for debug output).
        steer_vector: Direction to add; defaults to the global STEER_VECTOR.
            Bind a different one with ``functools.partial`` to steer several
            sites with different vectors.
        coefficient: Scale for the vector; defaults to the global COEFFICIENT.
        verbose: Print shapes on every call. Off by default because the hook
            fires on every forward pass, i.e. once per generated token.

    Returns:
        The steered activation, with the same shape, dtype and device as ``value``.

    Raises:
        ValueError: If the vector's last dimension does not match ``value``'s.
    """
    if steer_vector is None:
        steer_vector = STEER_VECTOR
    if coefficient is None:
        coefficient = COEFFICIENT

    if verbose:
        print(f"Shape of input values: {value.shape}")
        print(f"Hook input: {hook}")
        print(f"Shape of steer vector: {steer_vector.shape}")

    # Fail with a clear message instead of an opaque broadcasting error, e.g.
    # when a direction computed for a different model is loaded.
    if steer_vector.ndim > 0 and steer_vector.shape[-1] != value.shape[-1]:
        raise ValueError(
            f"Steering vector has size {steer_vector.shape[-1]} in its last "
            f"dimension, but the activation at {hook} has d_model={value.shape[-1]}."
        )

    # Tensor.to() is not in-place: its result must be used. Matching the dtype
    # as well keeps the residual stream in the model's precision (e.g. fp16)
    # instead of promoting it.
    steer_vector = steer_vector.to(device=value.device, dtype=value.dtype)

    return value + coefficient * steer_vector


# (hook_name, hook_fn) pairs for model.hooks(). The steering hook sits on the
# residual stream *entering* block STEER_LAYER ("resid_pre"), so the vector is
# added before that block's attention and MLP run. Append further pairs to
# steer at several sites at once.
fwd_hooks = [
    (utils.get_act_name("resid_pre", STEER_LAYER), layer_addition_pre_hook),
]

### Generation utils

In [None]:
# Generate text with hooks
def _generate_with_hooks(
    model: HookedTransformer,
    toks: Int[Tensor, 'batch_size seq_len'],
    max_tokens_generated: int = 64,
    fwd_hooks=None,
) -> List[str]:
    """Greedily decode ``max_tokens_generated`` tokens with ``fwd_hooks`` active.

    Args:
        model: The HookedTransformer to sample from.
        toks: Left-padded prompt token ids of shape (batch_size, seq_len). Left
            padding matters because the next token is read from the last
            position of every row.
        max_tokens_generated: Number of new tokens per prompt. Generation
            always runs for this many steps; it does not stop early at EOS.
        fwd_hooks: List of (hook_name, hook_fn) pairs handed to
            ``model.hooks``; any number of hooks is supported. ``None`` or an
            empty list generates without steering.

    Returns:
        One string per prompt containing only the generated continuation,
        decoded with special tokens removed.
    """
    if fwd_hooks is None:
        fwd_hooks = []

    batch_size, prompt_len = toks.shape

    # Buffer for the prompt followed by slots for the generated tokens; slots
    # are filled left to right, one per decoding step.
    all_toks = torch.zeros(
        (batch_size, prompt_len + max_tokens_generated),
        dtype=torch.long,
        device=toks.device,
    )
    all_toks[:, :prompt_len] = toks

    # no_grad: pure inference, so don't build an autograd graph on every step.
    # The hooks are registered once for the whole loop; model.hooks removes
    # them again when the block exits.
    with torch.no_grad(), model.hooks(fwd_hooks=fwd_hooks):
        for step in range(max_tokens_generated):
            cur_len = prompt_len + step
            # Feed only the filled prefix so the model never sees the
            # still-empty (zero) slots.
            logits = model(all_toks[:, :cur_len])
            # Greedy decoding (temperature 0): the argmax of the logits equals
            # the argmax of the softmax, so no softmax is needed.
            all_toks[:, cur_len] = logits[:, -1, :].argmax(dim=-1)

    # Decode only the generated part, not the prompt.
    return model.tokenizer.batch_decode(all_toks[:, prompt_len:], skip_special_tokens=True)


def get_generations(
    model: HookedTransformer,
    instructions: List[str],
    tokenize_instructions_fn: Callable[[List[str]], Int[Tensor, 'batch_size seq_len']],
    fwd_hooks=None,
    max_tokens_generated: int = 64,
    batch_size: int = 4,
) -> List[str]:
    """Generate a greedy completion for every instruction, in batches.

    Args:
        model: The HookedTransformer to sample from.
        instructions: Raw user instructions, one completion each.
        tokenize_instructions_fn: Called as ``fn(instructions=batch)``; must
            return left-padded token ids of shape (batch_size, seq_len).
        fwd_hooks: (hook_name, hook_fn) pairs active during generation;
            ``None`` or an empty list generates without steering.
        max_tokens_generated: New tokens per instruction.
        batch_size: Instructions per forward batch (must be >= 1).

    Returns:
        Generated continuations, in the same order as ``instructions``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if fwd_hooks is None:
        fwd_hooks = []

    generations = []
    for start in tqdm(range(0, len(instructions), batch_size)):
        toks = tokenize_instructions_fn(instructions=instructions[start:start + batch_size])
        generations.extend(
            _generate_with_hooks(
                model,
                toks,
                max_tokens_generated=max_tokens_generated,
                fwd_hooks=fwd_hooks,
            )
        )

    return generations