In [None]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from tqdm.auto import tqdm



from typing import Optional, List, Dict, Union
from jaxtyping import Float
from torch import Tensor

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from torch import optim
from torch.utils.data import random_split, DataLoader, TensorDataset

from pathlib import Path
from einops import rearrange

import transformers
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    AutoConfig,
)
from peft import (
    get_peft_config,
    get_peft_model,
    LoraConfig,
    TaskType,
    LoftQConfig,
    IA3Config,
)
from pathlib import Path
import datasets
from datasets import Dataset

from loguru import logger

import sys

# loguru installs a default stderr sink on import. Drop the existing sinks before
# adding ours so that each message is emitted once at INFO level, and so that
# re-running this cell does not stack up duplicate handlers.
logger.remove()
logger.add(sys.stderr, format="{time} {level} {message}", level="INFO")

# load my code
%load_ext autoreload
%autoreload 2

import lightning.pytorch as pl

from src.config import ExtractConfig
from src.llms.load import load_model
from src.helpers.torch_helpers import clear_mem
from src.llms.phi.model_phi import PhiForCausalLMWHS
from src.eval.ds import filter_ds_to_known
from src.datasets.act_dm import ActivationDataModule

# plt.style.use("ggplot")
# plt.style.use("seaborn-v0_8")
import seaborn as sns
# Use seaborn's "paper" plotting context; every other theme option keeps its default.
sns.set_theme(context='paper')


## Params


In [None]:
# ---- Experiment configuration ----

# Extraction config, currently disabled:
# cfg = ExtractConfig(
#     # model="microsoft/phi-2",
#     # # batch_size=1,
#     # prompt_format="phi",
# )
# cfg

# Training settings
batch_size = 32
lr = 4e-3
wd = 1e-4
max_epochs = 44
device = "cuda:0"

# Dataset size cap: load_file_to_dm keeps at most 2 * MAX_ROWS samples.
MAX_ROWS = 2000

# Sub-sampling applied when X is built in ds2xy_batched (X[:, SKIP::STRIDE, ::DECIMATE]).
SKIP = 5      # drop the first SKIP entries along the layer axis
STRIDE = 4    # then keep every STRIDE-th layer
DECIMATE = 2  # keep every DECIMATE-th feature, for speed

VAE_EPOCH_MULT = 5

# L1 coefficient. Neel uses 3e-4, see
# https://github.dev/neelnanda-io/1L-Sparse-Autoencoder/blob/bcae01328a2f41d24bd4a9160828f2fc22737f75/utils.py#L106
# but there the L1 term is summed while the L2 term is averaged; others use 1e-1.
l1_coeff = 1e-1

# Lightning run folder that holds the cached hidden states.
BASE_FOLDER = Path("/media/wassname/SGIronWolf/projects5/elk/sgd_probes_are_lie_detectors/notebooks/lightning_logs/version_24/")

# Suffixes of the `end_residual_{layer}_base` / `_adapt` columns that are read.
layers_names = ('fc1', 'Wqkv', 'fc2', 'out_proj')

## Load data

In [None]:
# Locate the hidden states saved by a previously loaded adapter.
# Columns ending in _base come from the base model, those ending in _adapt from the adapter.
# FROM TRAINING TRUTH
def _first_match(folder, pattern):
    """Return the first path under `folder` matching glob `pattern`.

    Raises:
        FileNotFoundError: if nothing matches, instead of the bare StopIteration
            that `next()` on an exhausted glob would raise.
    """
    match = next(folder.glob(pattern), None)
    if match is None:
        raise FileNotFoundError(f"no file matching {pattern!r} under {folder}")
    return match


f1_val = _first_match(BASE_FOLDER, 'hidden_states/.ds/ds_valtest_*')
f1_ood = _first_match(BASE_FOLDER, 'hidden_states/.ds/ds_OOD_*')
f1_val, f1_ood

In [None]:
# insample_datasets = list(set(ds_val['ds_string_base']))
# outsample_datasets = list(set(ds_ood['ds_string_base']))
# print(insample_datasets, outsample_datasets)

In [None]:
# Columns read from the stored dataset: the two answer columns, then the
# base-model residuals for every layer name, then the adapter residuals.
input_columns = [
    'binary_ans_base',
    'binary_ans_adapt',
    *(f'end_residual_{layer}_base' for layer in layers_names),
    *(f'end_residual_{layer}_adapt' for layer in layers_names),
]

def ds2xy_batched(ds):
    """
    Turn a batch of stored activations into inputs `X` and targets `y`.

    For every name in `layers_names`, the `_base` and `_adapt` residual columns
    are paired on a new trailing axis. The per-layer results are joined along
    dim 2 and then sub-sampled with `SKIP::STRIDE` on dim 1 and `::DECIMATE`
    on dim 2. `y` is `binary_ans_base - binary_ans_adapt`.
    """
    def paired(name):
        # base and adapter versions side by side, moved to the last axis
        both = [ds[f'end_residual_{name}_base'], ds[f'end_residual_{name}_adapt']]
        return rearrange(both, 'versions b l f  -> b l f versions')

    per_layer = [paired(name) for name in layers_names]

    # join the layer groups along the feature axis, then thin out layers and features
    joined = torch.concat(per_layer, dim=2)
    X = joined[:, SKIP::STRIDE, ::DECIMATE]

    y = ds['binary_ans_base'] - ds['binary_ans_adapt']
    return dict(X=X, y=y)



def prepare_ds(ds):
    """
    Reduce a raw dataset to the `X` / `y` columns used for training.

    The work is front-loaded: the dataset is switched to torch format,
    narrowed to `input_columns`, and mapped through `ds2xy_batched` in
    batches of 128, with the input columns removed from the result.
    """
    as_torch = ds.with_format("torch")
    needed_only = as_torch.select_columns(input_columns)
    return needed_only.map(
        ds2xy_batched,
        batched=True,
        batch_size=128,
        remove_columns=input_columns,
    )

def load_file_to_dm(f):
    """
    Load the dataset file `f` and wrap it in a set-up ActivationDataModule.

    Args:
        f: path (Path or str) of the dataset file to load. Its stem is passed
            to ActivationDataModule as the second positional argument.

    Returns:
        The ActivationDataModule built from at most `2 * MAX_ROWS` prepared
        rows, after `setup()` has been called on it.
    """
    f = Path(f)

    # Read from the path that was passed in. This used to read the global
    # `f1_val` regardless of `f`, so any other file silently got validation data.
    ds = Dataset.from_file(str(f), in_memory=True).with_format("torch")
    ds = filter_ds_to_known(ds, verbose=True, true_col='truth')
    ds = prepare_ds(ds)

    # limit size
    max_samples = min(len(ds), MAX_ROWS * 2)
    ds = ds.select(range(max_samples))

    dm = ActivationDataModule(ds, f.stem, batch_size=batch_size, num_workers=4)
    dm.setup()
    return dm




In [None]:
# Build the datamodule from the ds_valtest_* file located earlier (f1_val).
dm = load_file_to_dm(f1_val)
# The ds_OOD_* counterpart is currently disabled:
# dm_ood = load_file_to_dm(f1_ood)

In [9]:
# Sanity check: shape of X for the first 33 rows of the datamodule's dataset.
dm.ds.select(range(33))['X'].shape

torch.Size([33, 7, 11520, 2])

In [17]:
# Alternative access path: reset the dataset's output format, read X, then
# convert it to a float tensor in one call.
# Original note: "4x faster" -- presumably compared with reading X through the
# torch formatter as in the neighbouring cells; no timing is shown here. TODO confirm.
x = dm.ds.select(range(33)).with_format(None)['X']
x = torch.FloatTensor(x)
# NOTE(review): this displays the full tensor as well as its shape.
x.shape, x

(torch.Size([33, 7, 11520, 2]),
 tensor([[[[-3.2861e-01, -4.0918e-01],
           [ 1.4946e+00,  9.8340e-01],
           [ 1.0400e-01, -4.7083e-01],
           ...,
           [-6.0919e-01, -3.0310e-01],
           [-3.8916e-01, -3.4668e-01],
           [ 1.2476e-01,  2.8721e-01]],
 
          [[-2.5488e-01,  6.2500e-02],
           [ 1.0659e+00,  7.0117e-01],
           [ 2.9956e-01,  4.8975e-01],
           ...,
           [ 1.8530e-01,  2.1820e-01],
           [ 8.8791e-02,  2.3016e-01],
           [-7.7881e-02,  8.8043e-03]],
 
          [[ 1.3745e+00,  1.2827e+00],
           [ 8.5352e-01,  1.0352e+00],
           [ 8.8672e-01,  9.1992e-01],
           ...,
           [-2.9950e-01, -2.2668e-01],
           [ 5.4620e-01,  4.8962e-01],
           [-1.2872e-01, -4.1565e-02]],
 
          ...,
 
          [[-6.0059e-02,  4.2236e-01],
           [ 1.0012e+00,  5.1855e-01],
           [-3.1559e+00, -2.0396e+00],
           ...,
           [ 2.5513e-01,  1.5808e-01],
           [ 3.6621e

: 

In [12]:
# Same shape check, with the 33-row slice explicitly formatted as 'pt'.
dm.ds.select(range(33)).with_format('pt')['X'].shape

torch.Size([33, 7, 11520, 2])