Experiment to use lora to make a lying model. Here we think of Lora as a probe, as it acts in a very similar way - modifying the residual stream.

Then the hope is it will assist at lie detecting and generalize to unseen dataset

- https://github.dev/JD-P/minihf/blob/b54075c34ef88d9550e37fdf709e78e5a68787c4/lora_tune.py
- https://github.com/jonkrohn/NLP-with-LLMs


This notebook tried without pytorch lightning

In [1]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"]= "1"


In [2]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
import datasets

plt.style.use("ggplot")

from typing import Optional, List, Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch import optim
from torch.utils.data import random_split, DataLoader, TensorDataset

from pathlib import Path
import transformers
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType, LoftQConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import Dataset

from loguru import logger

logger.add(os.sys.stderr, format="{time} {level} {message}", level="INFO")


  from .autonotebook import tqdm as notebook_tqdm


1

In [3]:
# load my code
%load_ext autoreload
%autoreload 2

from src.config import ExtractConfig
from src.prompts.prompt_loading import load_preproc_dataset
from src.models.load import load_model
# from src.prompts.prompt_loading import load_prompt_structure


In [4]:
# params
max_epochs = 100
device = "cuda:0"

# quiet please
torch.set_float32_matmul_precision("medium")
import warnings

warnings.filterwarnings("ignore", ".*does not have many workers.*")
warnings.filterwarnings(
    "ignore", ".*sampler has shuffling enabled, it is strongly recommended that.*"
)
warnings.filterwarnings("ignore", ".*has been removed as a dependency of.*")


In [5]:
# params
cfg = ExtractConfig(
    batch_size=2,
    max_examples=(400, 400),
    intervention_fit_examples=160,
)
model, tokenizer = load_model(
    cfg.model, disable_exllama=False, device=device,
)


[32m2023-12-18 20:35:47.655[0m | [1mINFO    [0m | [36msrc.models.load[0m:[36mverbose_change_param[0m:[36m16[0m - [1mtokenizer does not have use_cache[0m
2023-12-18T20:35:47.655973+0800 INFO tokenizer does not have use_cache
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
[32m2023-12-18 20:35:47.977[0m | [1mINFO    [0m | [36msrc.models.load[0m:[36mverbose_change_param[0m:[36m21[0m - [1mchanging pad_token_id from None to 0[0m
2023-12-18T20:35:47.977032+0800 INFO changing pad_token_id from None to 0
[32m2023-12-18 20:35:47.977[0m | [1mINFO    [0m | [36msrc.models.load[0m:[36mverbose_change_param[0m:[36m21[0m - [1mchanging padding_side from right to left[0m
2023-12-18T20:35:47.977895+0800 INFO changing padding_side from right to left
[32m2023-12-18 20:35:47.978[0m | [1mINFO    [0m | [36msrc.models.load[0m:[36mverbose_change_param[0m:[36m21[0m - [1mchanging truncation_side fr

In [6]:
# model.to(device)


In [7]:
# TODO I would like to only have biases, but for now lets just try a very small intervention on the last parts of a layer...
peft_config = LoraConfig(
    target_modules=['out_proj', 'mlp.fc2',], # only the layers that go directly to the residual
    bias='lora_only',
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=1,
    lora_alpha=1,
    lora_dropout=0.,
)
# model = get_peft_model(model, peft_config)

# model.print_trainable_parameters()

model.add_adapter(peft_config)
model


PhiForCausalLM(
  (transformer): PhiModel(
    (embd): Embedding(
      (wte): Embedding(51200, 2560)
      (drop): Dropout(p=0.0, inplace=False)
    )
    (h): ModuleList(
      (0-31): 32 x ParallelBlock(
        (ln): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (resid_dropout): Dropout(p=0.1, inplace=False)
        (mixer): MHA(
          (rotary_emb): RotaryEmbedding()
          (Wqkv): Linear4bit(in_features=2560, out_features=7680, bias=True)
          (out_proj): lora.Linear4bit(
            (base_layer): Linear4bit(in_features=2560, out_features=2560, bias=True)
            (lora_dropout): ModuleDict(
              (default): Identity()
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=2560, out_features=1, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=1, out_features=2560, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (l

In [8]:
N = sum(cfg.max_examples)
ds_name = "amazon_polarity"
ds_tokens = load_preproc_dataset(
    ds_name,
    tokenizer,
    N=N,
    seed=cfg.seed,
    num_shots=cfg.num_shots,
    max_length=cfg.max_length,
    prompt_format=cfg.prompt_format,
).with_format("torch")


[32m2023-12-18 20:35:51.607[0m | [1mINFO    [0m | [36msrc.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m364[0m - [1mmedian token length: 433.0 for amazon_polarity. max_length=1000[0m
2023-12-18T20:35:51.607410+0800 INFO median token length: 433.0 for amazon_polarity. max_length=1000
[32m2023-12-18 20:35:51.609[0m | [1mINFO    [0m | [36msrc.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m368[0m - [1mtruncation rate: 0.00% on amazon_polarity[0m
2023-12-18T20:35:51.609702+0800 INFO truncation rate: 0.00% on amazon_polarity
Filter: 100%|██████████| 2402/2402 [00:01<00:00, 2213.20 examples/s]
[32m2023-12-18 20:35:52.715[0m | [1mINFO    [0m | [36msrc.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m377[0m - [1mnum_rows (after filtering out truncated rows) 2402=>2402[0m
2023-12-18T20:35:52.715349+0800 INFO num_rows (after filtering out truncated rows) 2402=>2402


## Lora train

In [9]:
# from https://github.com/jonkrohn/NLP-with-LLMs/blob/main/code/Finetune-T5-on-GPU.ipynb
from pytorch_optimizer import Ranger21
import lightning.pytorch as pl
from torchmetrics import Metric, MetricCollection, Accuracy, AUROC
from torchmetrics.functional import accuracy


In [10]:
N = len(ds_tokens)
ds_train = ds_tokens.select(range(N//2))
ds_val = ds_tokens.select(range(N//2, N))
dl_train = DataLoader(ds_train, batch_size=cfg.batch_size, drop_last=False, shuffle=True)
dl_val = DataLoader(ds_val, batch_size=cfg.batch_size, drop_last=False, shuffle=True)


In [11]:
def get_loss(model, batch):
    inputs = dict(input_ids=batch['input_ids'].to("cuda"), attention_mask=batch['attention_mask'].to("cuda"))
    model.disable_adapters()
    with torch.no_grad():
        out = model(**inputs, use_cache=False,
            output_hidden_states=True,
            return_dict=True)
        log_probs = torch.log_softmax(out['logits'][:, -1,], -1)
        del out
    
    model.enable_adapters()
    out2 = model(**inputs, use_cache=False,
            output_hidden_states=True,
            return_dict=True)
    log_probs2 = torch.log_softmax(out2['logits'][:, -1,], -1)

    # get loss, so that our adapter returns switched probs for our choices (e.g. Yes <> No)
    id_neg = batch['choice_ids'][:, 0]
    id_pos = batch['choice_ids'][:, 1]

    opposite_log_probs = log_probs.clone()
    for i in range(id_neg.shape[1]):
        opposite_log_probs[:, id_neg[:, i]] = log_probs[:, id_pos[:, i]]
    loss = F.kl_div(log_probs2, opposite_log_probs, log_target=True, reduction='batchmean')
    return loss


In [12]:
opt = optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.0, betas=(0.9, 0.99))
criterion = nn.KLDivLoss(reduction="none")
model.train()
batch_size = 4
steps = len(dl_train)
pbar = tqdm(total=steps, desc="Training")
for batch in tqdm(dl_train):
    opt.zero_grad()
    loss = get_loss(model, batch)    
    loss.backward()
    opt.step()
    pbar.set_description(f"Training (Train | Loss: {round(loss.item(),5)})")


Training:   0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 246.00 MiB. GPU 0 has a total capacty of 23.67 GiB of which 1.14 GiB is free. Process 401322 has 139.16 MiB memory in use. Including non-PyTorch memory, this process has 20.89 GiB memory in use. Of the allocated memory 20.32 GiB is allocated by PyTorch, and 266.00 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF