# BatchedMELBO Performance

The goal of this notebook is to compare training speed across three runs:
1. Using the original method.
2. Using BatchedMELBO without bf16.
3. Using BatchedMELBO with bf16.
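
One reason bf16 helps is memory. A bfloat16 value takes 2 bytes and a float32 value takes 4, so bf16 weights need half the space. The back-of-the-envelope sketch below covers the weights only and ignores activations, gradients, and optimizer state. It assumes roughly 1.8 billion parameters, a figure inferred from the model name rather than an exact count.

```python
# Bytes per element for the two dtypes compared in this notebook.
BYTES_PER_ELEMENT = {"float32": 4, "bfloat16": 2}


def weight_memory_gb(num_params: int, dtype: str) -> float:
    """Approximate memory for model weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_ELEMENT[dtype] / 1e9


# Assumption: ~1.8B parameters, inferred from the name "Qwen-1_8B".
NUM_PARAMS = 1_800_000_000
for dtype in ("float32", "bfloat16"):
    print(f"{dtype}: {weight_memory_gb(NUM_PARAMS, dtype):.1f} GB")
```

Under that assumption, this prints about 7.2 GB for float32 and 3.6 GB for bfloat16.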

We train 512 vectors, 256 at a time, for 5 steps on a single prompt.

In [None]:
MODEL_NAME = "aemack/Qwen-1_8B_ihateyou_ilovecheese"
TOKENIZER_NAME = "Qwen/Qwen-1_8B" 

SOURCE_LAYER = None
TARGET_LAYER = None
NORMALIZATION = 4.0
TOKEN_IDXS = slice(-3, None)  # last three token positions
POWER = 2
POWERQ = 1

TORCH_SEED = 325

EXAMPLES = [
'''<|system|>
You are a helpful assistant

<|user|>
a=5+6, b=7+2. What is a*b?

<|assistant|>
''']
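
`TOKEN_IDXS = slice(-3, None)` is an ordinary Python slice, so it selects the last three positions of whatever it indexes. The sketch below assumes `target_token_idxs` is applied along the sequence dimension. That is a reading of the parameter name, not something confirmed here. On a 10-token sequence, it picks out these positions:

```python
TOKEN_IDXS = slice(-3, None)

positions = list(range(10))  # stand-in for the positions of a 10-token sequence
selected = positions[TOKEN_IDXS]
print(selected)  # [7, 8, 9]

# slice.indices resolves the negative start against a concrete length.
print(TOKEN_IDXS.indices(10))  # (7, 10, 1)
```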

In [None]:
%%time
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,  # load the weights in bfloat16
)

In [None]:
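# Pad on the left so the final positions of every row are real prompt tokens.
tokenizer.padding_side = "left"
# Use <|endoftext|> as both the end-of-sequence and the padding token.
tokenizer.eos_token_id = tokenizer.special_tokens["<|endoftext|>"]
tokenizer.eos_token = "<|endoftext|>"
tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})
# Sanity check: the pad and EOS settings should now match.
tokenizer.pad_token, tokenizer.pad_token_id, tokenizer.eos_token, tokenizer.eos_token_id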
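
Left padding matters because the target positions are counted from the end of each sequence. With `padding_side = "left"`, pad tokens go at the start. The last positions of every row in a batch are then real prompt tokens. With right padding, shorter prompts would end in pad tokens instead. In the pure-Python sketch below, the `pad_batch` helper is hypothetical and is not part of `transformers`.

```python
from typing import List


def pad_batch(seqs: List[List[int]], pad_id: int, side: str = "left") -> List[List[int]]:
    """Hypothetical helper: pad token-id lists to equal length on the given side."""
    width = max(len(s) for s in seqs)
    padded = []
    for s in seqs:
        pads = [pad_id] * (width - len(s))
        padded.append(pads + s if side == "left" else s + pads)
    return padded


PAD = 0
batch = [[11, 12, 13, 14, 15], [21, 22, 23]]
last3 = slice(-3, None)

left = pad_batch(batch, PAD, side="left")
right = pad_batch(batch, PAD, side="right")
print([row[last3] for row in left])   # [[13, 14, 15], [21, 22, 23]]
print([row[last3] for row in right])  # [[13, 14, 15], [23, 0, 0]]
```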

In [None]:
from unsupervised_steering import FastMELBO

steered_model = FastMELBO(
    model,
    tokenizer,
    source_layer_idx=SOURCE_LAYER,
    target_layer_idx=TARGET_LAYER,
    target_token_idxs=TOKEN_IDXS,
    normalization=NORMALIZATION,
    orthogonal_vectors=False,
    num_steps=5,
    power=POWER,
    q=POWERQ,
)
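
Here is our reading of MELBO; check it against the `unsupervised_steering` source for the exact form. Each steering vector is rescaled to a fixed L2 norm (`normalization`). It is trained to maximize how much it shifts target-layer activations at the target tokens. `power` (p) and `q` shape the objective roughly as (Σ_t ‖Δ_t‖₂^p)^(1/q). The sketch below implements both pieces for a single prompt in plain Python, with vectors as lists of floats. The function names are hypothetical and are not the library's API.

```python
import math
from typing import List, Sequence


def rescale_to_norm(v: Sequence[float], r: float) -> List[float]:
    """Scale v to L2 norm r (assumed meaning of `normalization`)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [r * x / norm for x in v]


def melbo_style_objective(steered: Sequence[Sequence[float]],
                          unsteered: Sequence[Sequence[float]],
                          p: float, q: float) -> float:
    """Assumed objective: (sum over tokens t of ||steered_t - unsteered_t||_2 ** p) ** (1 / q)."""
    total = 0.0
    for s_t, u_t in zip(steered, unsteered):
        diff_norm = math.sqrt(sum((a - b) ** 2 for a, b in zip(s_t, u_t)))
        total += diff_norm ** p
    return total ** (1.0 / q)


print(rescale_to_norm([3.0, 4.0], 4.0))  # [2.4, 3.2]  (NORMALIZATION = 4.0)

# Two target tokens; activation shifts of norm 1 and 2 with p=2, q=1 give 1 + 4.
print(melbo_style_objective([[1.0, 0.0], [0.0, 2.0]],
                            [[0.0, 0.0], [0.0, 0.0]], p=2, q=1))  # 5.0
```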

In [None]:
%%time
import torch
if TORCH_SEED is not None:
    torch.manual_seed(TORCH_SEED)
steered_model.train(EXAMPLES, 512, vector_batch_size=256)
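
`%%time` reports the wall time of the whole cell. To compare the three runs more carefully, you can time repeated calls with an explicit synchronization hook. On a GPU, that hook keeps queued kernels from leaking into or out of the measurement. The `time_runs` helper below is a hypothetical sketch, and it uses a CPU stand-in workload so that it runs anywhere.

```python
import time
from typing import Callable, List, Optional


def time_runs(fn: Callable[[], object],
              repeats: int = 1,
              sync: Optional[Callable[[], None]] = None) -> List[float]:
    """Return the wall-clock duration in seconds of each call to fn."""
    durations = []
    for _ in range(repeats):
        if sync is not None:
            sync()  # finish any previously queued work before starting the clock
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for work launched by fn before stopping the clock
        durations.append(time.perf_counter() - start)
    return durations


# Stand-in workload so the sketch runs without a GPU.
print(time_runs(lambda: sum(range(100_000)), repeats=3))
```

In this notebook, the workload would be the training call above, for example `time_runs(lambda: steered_model.train(EXAMPLES, 512, vector_batch_size=256), sync=torch.cuda.synchronize)`. `torch.cuda.synchronize` is PyTorch's call for waiting on outstanding CUDA work. Reseeding with `torch.manual_seed` before each run keeps the runs comparable.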