# Exploratory Data Analysis

Hypothesis: We can use the https://arxiv.org/pdf/2406.04313 method for increasing honesty

In [3]:
from datasets import load_dataset
# multiple_choice
from torch.utils.data import DataLoader
dataset = load_dataset("truthfulqa/truthful_qa", "multiple_choice", revision="ref/convert/parquet")
dataset

DatasetNotFoundError: Dataset 'truthfulqa/truthful_qa' doesn't exist on the Hub or cannot be accessed at revision 'ref/convert/parquet'

In [1]:
# autoreload your package
%load_ext autoreload
%autoreload 2
import adapter_overseer


In [2]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

In [3]:
import warnings
# warnings.simplefilter("ignore")
# warnings.filterwarnings("ignore", ".*does not have many workers.*")
# warnings.filterwarnings("ignore", ".*divide by zero.*")

## numeric, plotting
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
plt.style.use('ggplot')
plt.rcParams['figure.figsize'] = (7.0, 4)

## utils
from pathlib import Path
from tqdm.auto import tqdm
import logging, os, re
import collections, functools, itertools
from loguru import logger

from typing import List, Callable, Tuple, Dict, Optional
from jaxtyping import Float, Int
from torch import Tensor

# torch
# import pytorch_lightning as pl
from einops import rearrange, repeat, reduce
import torch
import torch.nn as nn


from baukit.nethook import get_module
from baukit import TraceDict

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from adapter_overseer.config import ExtractConfig

cfg = ExtractConfig()
cfg

ExtractConfig(datasets=('amazon_polarity',), datasets_ood='imdb', model='failspy/Llama-3-8B-Instruct-abliterated', collection_layers=('base_model.model.model.layers.10', 'base_model.model.model.layers.20'), batch_size=2, prompt_format=None, num_shots=2, max_length=776, max_examples=1000, seed=42, max_epochs=1)

## Load

In [5]:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
# https://huggingface.co/blog/mlabonne/orpo-llama-3
if torch.cuda.get_device_capability()[0] >= 8:
    !pip install -qqq flash-attn
    attn_implementation = "flash_attention_2"
    torch_dtype = torch.bfloat16
else:
    attn_implementation = "eager"
    torch_dtype = torch.float16
torch_dtype, device

(torch.bfloat16, device(type='cuda', index=0))

In [7]:
# load model
# quantization_config = BitsAndBytesConfig(load_in_8bit=True)
quantization_config = BitsAndBytesConfig(load_in_4bit=True,     bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch_dtype, bnb_4bit_use_double_quant=True,)
model = AutoModelForCausalLM.from_pretrained(cfg.model, device_map="auto", quantization_config=quantization_config,)
model

Loading checkpoint shards: 100%|██████████| 4/4 [00:10<00:00,  2.51s/it]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): Ll

In [8]:
tokenizer = AutoTokenizer.from_pretrained(cfg.model)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = 'left'
tokenizer.truncation_side = 'left'

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [9]:
# https://old.reddit.com/r/LocalLLaMA/comments/1coizjy/tokenizer_config_of_llama3_changed_by_meta_in_hf/
tokenizer.eos_token # it's good

'<|eot_id|>'

In [10]:
# from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
# \peft_config = LoraConfig(
#     task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
# )
# model = get_peft_model(model, peft_config)


In [11]:
# from peft import prepare_model_for_int8_training
# # we need to apply some post-processing on the 8-bit model to enable training, let's freeze all our layers, and cast the layer-norm in float32 for stability. We also cast the output of the last layer in float32 for the same reasons.
# model = prepare_model_for_int8_training(model, output_embedding_layer_name="proj_out")

In [12]:
from peft import LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model
# https://github.com/huggingface/peft/blob/main/src/peft/utils/constants.py
config = LoraConfig(
                        #r=32, lora_alpha=64, 
                    # target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none"
                    )


# LoRA config
# peft_config = LoraConfig(
#     r=16,
#     lora_alpha=32,
#     lora_dropout=0.05,
#     bias="none",
#     task_type="CAUSAL_LM",
#     target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
# )
from peft import prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model)

model = get_peft_model(model, config)
model.print_trainable_parameters()

trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.0424


## Get data

In [13]:
# perhaps use load_preproc_datasets from sdb_probes_are_lie_detectors repo... /media/wassname/SGIronWolf/projects5/elk/sgd_probes_are_lie_detectors/src/prompts/prompt_loading.py

In [14]:
# load a dataset of paired prompts, to try and get the model to lie
from adapter_overseer.prompts.prompt_loading import load_preproc_datasets

N = cfg.max_examples
ds_tokens = load_preproc_datasets(
    cfg.datasets,
    tokenizer,
    N=N,
    seed=cfg.seed,
    num_shots=cfg.num_shots,
    max_length=cfg.max_length,
    prompt_format=cfg.prompt_format,
)
ds_tokens


[32m2024-06-10 05:56:10.686[0m | [1mINFO    [0m | [36madapter_overseer.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m392[0m - [1mmedian token length: 375.0 for amazon_polarity. max_length=776[0m
[32m2024-06-10 05:56:10.701[0m | [1mINFO    [0m | [36madapter_overseer.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m396[0m - [1mtruncation rate: 0.00% on amazon_polarity[0m
[32m2024-06-10 05:56:11.141[0m | [1mINFO    [0m | [36madapter_overseer.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m405[0m - [1mnum_rows (after filtering out truncated rows) 3004=>3004[0m


Dataset({
    features: ['ds_string', 'example_i', 'answer', 'messages', 'answer_choices', 'template_name', 'label_true', 'label_instructed', 'instructed_to_lie', 'sys_instr_name', 'question', 'input_ids', 'attention_mask', 'truncated', 'length', 'prompt_truncated', 'choice_ids'],
    num_rows: 1001
})

## Train: transformers

https://github.com/huggingface/peft/blob/main/examples/int8_training/Finetune_opt_bnb_peft.ipynb

In [18]:
# TODO change the loss function!
# we need to modify the forward pass, so that it returns a different loss function
# but to calculate this we will need to residuals now, and as they werre
# loss_bad = mse(repr_current, repr_target)

# from transformers import SFTTrainer
from trl.trainer import SFTTrainer
import torch.nn.functional as F

from adapter_overseer.helpers.torch_helpers import clear_mem, switch
from adapter_overseer.helpers.scores import select_choices

class CustomSFTTrainer(SFTTrainer):
    """
    Custom SFTTrainer that orthoganalizes the repr of bad examples, and retains good repr of examples

    See: https://arxiv.org/pdf/2406.04313

    args:
        collection_layers: list of baukit layer names to collect
    """
    def __init__(self, *args, collection_layers: list, alpha=0.1, **kwargs):
        super(CustomSFTTrainer, self).__init__(*args, **kwargs)
        self.collection_layers = collection_layers
        self.alpha = alpha
        self.total_steps = self.args.max_steps

    def compute_loss(self, model, inputs, return_outputs=False):
        
       # get label and prediction tokens
        def input2batch(input):
            return {'input_ids': input['input_ids'], 'attention_mask': input['attention_mask']}

        # collect the residuals of the model
        # `list(model.named_modules())` for layer names
        with TraceDict(model, self.collection_layers, detach=True) as ret_orig:
            with model.disable_adapter():
                orig_outputs = model(**input2batch(inputs))
        with TraceDict(model, self.collection_layers, retain_grad=True, detach=False, clone=False) as ret:
                outputs = model(**input2batch(inputs))

        # so now we have a mixed batch of good and bad outputs
        # TODO get probs of each choice
        # compare to labels to seperate into good and bad
        # FIXME: but what if there are no lies in a batch...?
        choice_ids = inputs['choice_ids'].detach().cpu().long()
        label_instructed = inputs['label_true'] ^ inputs['instructed_to_lie']
        label_true = inputs['label_true']

        # does the underlying model get it right or wrong?
        end_logits = orig_outputs["logits"][:, -1]
        probs = torch.softmax(end_logits, -1)
        choice_probs = select_choices(probs, choice_ids).sum(2)
        binary_ans = choice_probs[:, 1] / (choice_probs.sum(1) + 1e-12)
        correct_truth_telling = switch(binary_ans, label_true)
        # correct_instruction_following = switch(binary_ans, label_instructed)

        mask_desired = correct_truth_telling>0.5

        def collect_acts(ret) -> Float[Tensor, 'b l h']:
            acts = []
            for k in self.collection_layers:
                # this all depends on the model and layer definitions
                acts.append(ret[k].output[0])
            return rearrange(acts, 'l b t h -> b l t h')[:, :, -1, :]
        
        rep_adapt = collect_acts(ret)
        rep_orig = collect_acts(ret_orig).detach()

        # get coeffecient
        steps = self.state.global_step + 1 # TODO make sure this goes up
        c = torch.tensor(self.alpha * steps / (2 * self.total_steps)).to(rep_orig.dtype)
        loss_retain = F.mse_loss(rep_orig, rep_adapt, reduction='none' )[mask_desired]
        if loss_retain.numel() == 0:
            loss_retain = 0
        else:
            loss_retain = loss_retain.mean()
        loss_rr = F.relu(F.cosine_similarity(rep_orig, rep_adapt, dim=1))[~mask_desired]
        if loss_rr.numel() == 0:
            loss_rr = 0
        else:
            loss_rr = loss_rr.mean()
        print(steps, c, loss_rr, loss_retain, (mask_desired*1.0).mean())
        loss = loss_rr * c + (1 - c) * loss_retain
        loss = loss
        
        return (loss, outputs) if return_outputs else loss
    

# TODO make sure that multiple cols get passed into trainer
ds = ds_tokens.select_columns(['label_true', 'label_instructed' ,'instructed_to_lie', 'input_ids', 'attention_mask', 'choice_ids'])

import transformers

trainer = CustomSFTTrainer(
    model=model,
    train_dataset=ds,
    collection_layers=cfg.collection_layers,
    # max_seq_length=cfg.max_length,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=12,
        warmup_steps=10,
        max_steps=2000,
        learning_rate=3e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
        remove_unused_columns=False,
    ),
    # data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
max_steps is given, it will override any value given in num_train_epochs


1 tensor(2.5000e-05) tensor(1., device='cuda:0', grad_fn=<MeanBackward0>) tensor(0., device='cuda:0', grad_fn=<MeanBackward0>) tensor(0.5000, device='cuda:0')




1 tensor(2.5000e-05) tensor(1., device='cuda:0', grad_fn=<MeanBackward0>) tensor(0., device='cuda:0', grad_fn=<MeanBackward0>) tensor(0.7500, device='cuda:0')
1 tensor(2.5000e-05) tensor(1., device='cuda:0', grad_fn=<MeanBackward0>) tensor(0., device='cuda:0', grad_fn=<MeanBackward0>) tensor(0.7500, device='cuda:0')
1 tensor(2.5000e-05) tensor(1., device='cuda:0', grad_fn=<MeanBackward0>) tensor(0., device='cuda:0', grad_fn=<MeanBackward0>) tensor(0.7500, device='cuda:0')
1 tensor(2.5000e-05) tensor(1., device='cuda:0', grad_fn=<MeanBackward0>) tensor(0., device='cuda:0', grad_fn=<MeanBackward0>) tensor(0.7500, device='cuda:0')
1 tensor(2.5000e-05) tensor(1., device='cuda:0', grad_fn=<MeanBackward0>) tensor(0., device='cuda:0', grad_fn=<MeanBackward0>) tensor(0.5000, device='cuda:0')
1 tensor(2.5000e-05) tensor(1., device='cuda:0', grad_fn=<MeanBackward0>) tensor(0., device='cuda:0', grad_fn=<MeanBackward0>) tensor(0.5000, device='cuda:0')
1 tensor(2.5000e-05) tensor(1., device='cuda:0

AssertionError: No inf checks were recorded for this optimizer.

: 

## Train

FIXME: lightning doesn't seem to play with with bnb


note 16-true fails, 4bit helps?


lightning options:
- accelerator: gpu, or accelerate (seems to conflict with bnb)
- precision, bf16-true (fails),bf16-mixed (uses more ram, likely because it undoes bnb)
- using the [BitsandbytesPrecision](https://github.com/Lightning-AI/pytorch-lightning/blob/06ea3a05716a6d1f4a96cfb25021accdd18d8146/docs/source-fabric/fundamentals/precision.rst#quantization-via-bitsandbytes) plugin? But how does this work with lora?

https://github.com/Lightning-AI/lit-llama/blob/main/finetune/adapter_v2.py

In [None]:
from datasets import Dataset
from torch.utils.data import DataLoader

dl_train = DataLoader(
            ds_tokens.with_format("torch"), batch_size=cfg.batch_size, drop_last=False, shuffle=True, 
            # num_workers=cfg.num_workers,
)
dl_train

In [None]:
from adapter_overseer.train.pl_lora_ft import AtapterFinetuner
import lightning as pl

pl_model = AtapterFinetuner(
    model=model,
    tokenizer=tokenizer,
    total_steps=len(dl_train) * cfg.max_epochs,
    collection_layers=cfg.collection_layers
)

In [None]:
trainer = pl.Trainer(
    precision='16-mixed',

    # gradient_clip_val=20,
    devices="1",
    accelerator="gpu",
    accumulate_grad_batches=8,
    max_epochs=cfg.max_epochs,
    log_every_n_steps=1,
    # plugins=precision,
    # enable_model_summary=False,
)
trainer.fit(model=pl_model, train_dataloaders=dl_train, 
            # val_dataloaders=dl_val
            );