In [1]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

plt.style.use("ggplot")

from typing import Optional, List, Dict, Union
from jaxtyping import Float
from torch import Tensor

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from torch import optim
from torch.utils.data import random_split, DataLoader, TensorDataset

from pathlib import Path
from einops import rearrange

import transformers
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    AutoConfig,
)
from peft import (
    get_peft_config,
    get_peft_model,
    LoraConfig,
    TaskType,
    LoftQConfig,
    IA3Config,
)

import datasets
from datasets import Dataset

from loguru import logger

logger.add(os.sys.stderr, format="{time} {level} {message}", level="INFO")

# load my code
%load_ext autoreload
%autoreload 2

import lightning.pytorch as pl

from src.config import ExtractConfig
from src.models.load import load_model
from src.helpers.torch_helpers import clear_mem
from src.models.phi.model_phi import PhiForCausalLMWHS
from src.eval.ds import filter_ds_to_known


In [2]:
# params

cfg = ExtractConfig(
    model="microsoft/phi-2",
    batch_size=1,
    prompt_format="phi",
)
cfg

# params
batch_size = 128
lr = 1e-3
wd = 1e-4
MAX_ROWS = 80000
SKIP=5
STRIDE=2
device = "cuda:0"
max_epochs = 100

VAE_EPOCH_MULT = 1
l1_coeff = 1.0e-1


In [3]:
model, tokenizer = load_model(
    cfg.model,
    device=device,
    model_class=PhiForCausalLMWHS, # ti add hidden states
)
# model


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# load hidden state from a previously loaded adapter
# the columns with _base are from the base model, and adapt from adapter
# FROM TRAIING TRUTH
f1_ood = '/media/wassname/SGIronWolf/projects5/elk/sgd_probes_are_lie_detectors/.ds/ds_OOD_6d3ece46c44f6c3b'
f1_val = '/media/wassname/SGIronWolf/projects5/elk/sgd_probes_are_lie_detectors/.ds/ds_valtest_73b754e8fdff9f2f'
ds_val = Dataset.from_file(f1_val)
ds_oos = Dataset.from_file(f1_ood)

ds_out = datasets.interleave_datasets([ds_val, ds_oos], seed=42, 
                                    #   probabilities=[0.5, 0.5]
                                      )
ds_out2 = filter_ds_to_known(ds_out, verbose=True)

# ds_out2 = ds_out2.select_columns(['end_residual_stream_base', 'end_residual_stream_adapt', 'binary_ans_base', 'binary_ans_adapt'])
# ds_known1 = ds_out
# ds_known1
ds_out2


  table = cls._concat_blocks(blocks, axis=0)


select rows are 76.89% based on knowledge


Dataset({
    features: ['end_logits_base', 'choice_probs_base', 'binary_ans_base', 'label_true_base', 'label_instructed_base', 'instructed_to_lie_base', 'sys_instr_name_base', 'example_i_base', 'ds_string_base', 'template_name_base', 'correct_truth_telling_base', 'correct_instruction_following_base', 'end_residual_stream_base', 'end_logits_adapt', 'choice_probs_adapt', 'binary_ans_adapt', 'label_true_adapt', 'label_instructed_adapt', 'instructed_to_lie_adapt', 'sys_instr_name_adapt', 'example_i_adapt', 'ds_string_adapt', 'template_name_adapt', 'correct_truth_telling_adapt', 'correct_instruction_following_adapt', 'end_residual_stream_adapt'],
    num_rows: 2140
})

In [5]:
insample_datasets = list(set(ds_val['ds_string_base']))
outsample_datasets = list(set(ds_oos['ds_string_base']))
print(insample_datasets, outsample_datasets)


ds_trainval = ds_out2.filter(lambda example: example["ds_string_base"] in insample_datasets)
ds_test = ds_out2.filter(lambda example: example["ds_string_base"] not in outsample_datasets)

MAX_SAMPLES = min(len(ds_trainval), MAX_ROWS)
ds_trainval = ds_trainval.select(range(MAX_SAMPLES))

MAX_SAMPLES = min(len(ds_test), MAX_ROWS)
ds_test = ds_test.select(range(MAX_SAMPLES))

len(ds_trainval), len(ds_test)


['super_glue:axg', 'glue:qnli', 'super_glue:rte', 'amazon_polarity', 'hans', 'sst2'] ['super_glue:axg', 'super_glue:boolq', 'imdb']


Filter:   0%|          | 0/2140 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2140 [00:00<?, ? examples/s]

(1529, 1115)

In [6]:
def ds2xy(row):
    X = torch.stack([row['end_residual_stream_base'], row['end_residual_stream_adapt']], dim=-1)[:, SKIP::STRIDE]
    y = row['binary_ans_base']-row['binary_ans_adapt']
    return dict(X=X, y=y)


def prepare_ds(ds):
    """prepared a dataset for training"""
    ds = ds.with_format("torch")
    ds = ds.map(ds2xy)
    ds = ds.select_columns(['X', 'y'])
    return ds


ds_trainval2 = prepare_ds(ds_trainval)
ds_test2 = prepare_ds(ds_test)

# next(iter(ds_trainval2)).keys()


Map:   0%|          | 0/1529 [00:00<?, ? examples/s]

Map:   0%|          | 0/1115 [00:00<?, ? examples/s]

In [7]:
from src.datasets.dm import DeceptionDataModule

# .select_columns(self.x_cols)

# TEMP try with the counterfactual residual stream...
dm = DeceptionDataModule(ds_trainval2, batch_size=batch_size)
dm.setup("train")
dm


<src.datasets.dm.DeceptionDataModule at 0x7f55d627b110>

In [8]:
dm_oos = DeceptionDataModule(ds_test2, batch_size=batch_size, 
                         )
dm_oos.setup("test")


# Model

In [9]:
import einops
from jaxtyping import Float, Int
from typing import Optional, Callable, Union, List, Tuple

from torch import dropout
from src.probes.pl_ranking_probe import InceptionBlock, LinBnDrop, ConvBlock


class Encoder(nn.Module):
    def __init__(self, n_layers, n_channels, hs, c_out, ks=[7, 5, 3], dropout=0):
        super().__init__()
        self.n_layers = n_layers

        self.conv = nn.Sequential(
            nn.BatchNorm1d(n_channels, affine=False),
            InceptionBlock(n_channels, hs, ks=ks, coord=True, conv_dropout=dropout),
            InceptionBlock(hs * 4, hs, ks=ks, coord=True, conv_dropout=dropout),
            # InceptionBlock(hs*4, hs, ks=ks, coord=True, conv_dropout=dropout),
            InceptionBlock(hs * 4, hs, ks=ks, coord=True),
            InceptionBlock(hs * 4, hs, ks=ks),
        )

        self.fc = nn.Sequential(
            LinBnDrop(hs * 4 * n_layers, c_out * n_layers, dropout=dropout),
            nn.Linear(c_out * n_layers, c_out * n_layers),
        )

    def forward(self, x):
        x = self.conv(x)
        x = rearrange(x, "b c l -> b (c l)")
        x = self.fc(x)
        x = rearrange(x, "b (c l) -> b c l", l=self.n_layers)
        return x


class Decoder(nn.Module):
    def __init__(self, n_latent, n_layers, hs, c_out=1, ks=[7, 5, 3], dropout=0):
        super().__init__()
        self.layers = n_layers

        self.fc = nn.Sequential(
            nn.BatchNorm1d(
                n_latent * n_layers, affine=False
            ),  # center it, regularize it
            LinBnDrop(n_latent * n_layers, hs * n_layers, dropout=dropout),
            nn.ReLU(),
        )

        self.conv = nn.Sequential(
            InceptionBlock(hs, hs, ks=ks, coord=True, conv_dropout=dropout),
            InceptionBlock(hs * 4, hs, ks=ks, conv_dropout=dropout),
            InceptionBlock(hs * 4, hs, ks=ks, coord=True),
            nn.Conv1d(hs * 4, c_out, 1),
        )

    def forward(self, x):
        x = rearrange(x, "b l c -> b (l c)")
        x = self.fc(x)
        x = rearrange(x, "b (c l) -> b c l", l=self.layers)
        x = self.conv(x)
        return x


class AutoEncoder(nn.Module):
    def __init__(
        self, c_in, depth=3, n_hidden=32, n_latent=32, l1_coeff: float = 1.0, dropout=0
    ):
        super().__init__()
        self.l1_coeff = l1_coeff
        n_layers, n_channels = c_in
        self.enc = Encoder(n_layers, n_channels, n_hidden, n_latent, dropout=dropout)
        self.dec = Decoder(
            n_latent, n_layers, n_hidden // 4, c_out=n_channels, dropout=dropout
        )
        self.apply_weight_norm(self.dec)
        self.apply_weight_norm(self.enc)

    def apply_weight_norm(self, net):
        for m in net.modules():
            if isinstance(m, nn.Conv1d):
                # I think it's 1. In the example they use 2, but their weights are transposed before use
                torch.nn.utils.parametrizations.weight_norm(m, dim=1)

    def forward(self, h: Float[Tensor, "batch_size n_hidden n_channels"]):
        latent = self.enc(h)
        h_rec = self.dec(latent)

        # Compute loss, return values
        l2_loss = (
            (h_rec - h).pow(2).mean(-1).sum(1)
        )  # shape [batch_size sum(neurons) mean(layers)] - punish the model for not reconstructing the input
        l1_loss = (
            latent.abs().sum(-1).sum(1)
        )  # shape [batch_size sum(latent) sum(layers)] - punish the model for large latent values
        loss = (self.l1_coeff * l1_loss + l2_loss).mean(0)  # scalar

        return l1_loss, l2_loss, loss, latent, h_rec


In [10]:
def recursive_requires_grad(model, mode: bool = False):
    print(f"requires_grad: {mode}")
    for param in model.parameters():
        param.requires_grad = mode


In [11]:
from src.probes.pl_base import PLBase
from torchmetrics.functional import accuracy


class PLAE(PLBase):
    def __init__(
        self,
        c_in,
        total_steps,
        depth=0,
        lr=4e-3,
        weight_decay=1e-9,
        hs=64,
        n_latent=32,
        l1_coeff=1,
        dropout=0,
        **kwargs,
    ):
        super().__init__(total_steps=total_steps, lr=lr, weight_decay=weight_decay)
        self.save_hyperparameters()

        self.ae = AutoEncoder(
            c_in,
            n_hidden=hs,
            n_latent=n_latent,
            depth=depth,
            l1_coeff=l1_coeff,
            dropout=dropout,
        )
        n_layers, n_channels = c_in
        n = n_latent * n_layers
        self.head = nn.Sequential(
            LinBnDrop(n, n // 4, dropout=dropout),
            LinBnDrop(n // 4, n // 12, dropout=dropout),
            nn.Linear(n // 12, 1),
            # nn.Tanh(),
        )
        self._ae_mode = True

    def ae_mode(self, mode=0):
        """
        mode 0, train the ae
        mode 1, train only the prob
        mode 2, train both
        """
        self._ae_mode = mode
        recursive_requires_grad(self.ae, mode in [0, 2])

    def forward(self, x):
        if x.ndim == 4:
            x = x.squeeze(3)
        x = rearrange(x, "b l h -> b h l")
        # if not self._ae_mode:
        #     with torch.no_grad():
        #         l1_loss, l2_loss, loss, latent, h_rec = self.ae(x)
        # else:
        l1_loss, l2_loss, loss, latent, h_rec = self.ae(x)

        latent2 = rearrange(latent, "b l h -> b (l h)")
        pred = self.head(latent2).squeeze(1)
        return dict(
            pred=pred,
            l1_loss=l1_loss,
            l2_loss=l2_loss,
            loss=loss,
            latent=latent,
            h_rec=h_rec,
        )

    def _step(self, batch, batch_idx, stage="train"):
        # if stage=='train':
        #     # Normalize the decoder weights before each optimization step (from https://colab.research.google.com/drive/1rPy82rL3iZzy2_Rd3F82RwFhlVnnroIh?usp=sharing#scrollTo=q1JctT2Pvw-r)
        #     # Presumably this is a way to implement weight norm to regularize the decoder
        #     self.normalize_decoder()

        x0, x1, y = batch
        info0 = self(x0)
        info1 = self(x1)
        ypred1 = info1["pred"]
        ypred0 = info0["pred"]

        if stage == "pred":
            return (ypred1 - ypred0).float()

        pred_loss = F.smooth_l1_loss(ypred1 - ypred0, y)
        rec_loss = info0["loss"] + info1["loss"]
        l1_loss = (info0["l1_loss"] + info1["l1_loss"]).mean()
        l2_loss = (info0["l2_loss"] + info1["l2_loss"]).mean()

        y_cls = ypred1 > ypred0  # switch2bool(ypred1-ypred0)
        self.log(
            f"{stage}/acc",
            accuracy(y_cls, y > 0, "binary"),
            on_epoch=True,
            on_step=False,
        )
        self.log(
            f"{stage}/loss_pred",
            float(pred_loss),
            on_epoch=True,
            on_step=False,
            prog_bar=True,
        )
        self.log(
            f"{stage}/loss_rec",
            float(rec_loss),
            on_epoch=True,
            on_step=False,
            prog_bar=True,
        )
        self.log(f"{stage}/l1_loss", l1_loss, on_epoch=True, on_step=False)
        self.log(f"{stage}/l2_loss", l2_loss, on_epoch=True, on_step=False)
        self.log(
            f"{stage}/n",
            float(len(y)),
            on_epoch=True,
            on_step=False,
            reduce_fx=torch.sum,
        )
        if self._ae_mode == 0:
            return rec_loss
        elif self._ae_mode == 1:
            return pred_loss
        elif self._ae_mode == 2:
            return pred_loss * 50000 + rec_loss


# Train

### Metrics


In [None]:
def get_acc_subset(df, query, verbose=True):
    if query:
        df = df.query(query)
    acc = (df["probe_pred"] == df["y"]).mean()
    if verbose:
        print(f"acc={acc:2.2%},\tn={len(df)},\t[{query}] ")
    return acc


def calc_metrics(dm, trainer, net, use_val=False, verbose=True):
    dl_test = dm.test_dataloader()
    rt = trainer.predict(net, dataloaders=dl_test)
    y_test_pred = np.concatenate(rt)
    splits = dm.splits["test"]
    df_test = dm.df.iloc[splits[0] : splits[1]].copy()
    df_test["probe_pred"] = y_test_pred > 0.0

    if use_val:
        dl_val = dm.val_dataloader()
        rv = trainer.predict(net, dataloaders=dl_val)
        y_val_pred = np.concatenate(rv)
        splits = dm.splits["val"]
        df_val = dm.df.iloc[splits[0] : splits[1]].copy()
        df_val["probe_pred"] = y_val_pred > 0.0

        df_test = pd.concat([df_val, df_test])

    if verbose:
        print("probe results on subsets of the data")
    acc = get_acc_subset(df_test, "", verbose=verbose)
    get_acc_subset(
        df_test, "instructed_to_lie==True", verbose=verbose
    )  # it was ph told to lie
    get_acc_subset(
        df_test, "instructed_to_lie==False", verbose=verbose
    )  # it was told not to lie
    get_acc_subset(
        df_test, "llm_ans==label_true", verbose=verbose
    )  # the llm gave the true ans
    get_acc_subset(
        df_test, "llm_ans==label_instructed", verbose=verbose
    )  # the llm gave the desired ans
    acc_lie_lie = get_acc_subset(
        df_test, "instructed_to_lie==True & llm_ans==label_instructed", verbose=verbose
    )  # it was told to lie, and it did lie
    acc_lie_truth = get_acc_subset(
        df_test, "instructed_to_lie==True & llm_ans!=label_instructed", verbose=verbose
    )

    a = get_acc_subset(
        df_test, "instructed_to_lie==False & llm_ans==label_instructed", verbose=False
    )
    b = get_acc_subset(
        df_test, "instructed_to_lie==False & llm_ans!=label_instructed", verbose=False
    )
    c = get_acc_subset(
        df_test, "instructed_to_lie==True & llm_ans==label_instructed", verbose=False
    )
    d = get_acc_subset(
        df_test, "instructed_to_lie==True & llm_ans!=label_instructed", verbose=False
    )
    d1 = pd.DataFrame(
        [[a, b], [c, d]],
        index=["instructed_to_lie==False", "instructed_to_lie==True"],
        columns=["llm_ans==label_instructed", "llm_ans!=label_instructed"],
    )
    d1 = pd.DataFrame(
        [[a, b], [c, d]],
        index=["tell a truth", "tell a lie"],
        columns=["did", "didn't"],
    )
    d1.index.name = "instructed to"
    d1.columns.name = "llm gave"
    print("probe accuracy for quadrants")
    display(d1.round(2))

    if verbose:
        print(f"⭐PRIMARY METRIC⭐ acc={acc:2.2%} from probe")
        print(f"⭐SECONDARY METRIC⭐ acc_lie_lie={acc_lie_lie:2.2%} from probe")
    return dict(acc=acc, acc_lie_lie=acc_lie_lie, acc_lie_truth=acc_lie_truth)


### Setup

In [12]:
dl_train = dm.train_dataloader()
dl_val = dm.val_dataloader()
print(len(dl_train), len(dl_val))
b = next(iter(dl_train))
x, y = b['X'], b['y']
print(x.shape, "x")
if x.ndim == 3:
    x = x.unsqueeze(-1)
c_in = x.shape[1:-1]


6 3
torch.Size([128, 33, 1277, 2]) x


In [13]:

net = PLAE(
    c_in=c_in,
    total_steps=max_epochs * len(dl_train) * VAE_EPOCH_MULT,
    lr=lr,
    weight_decay=wd,
    hs=32,
    dropout=0.1,
    n_latent=6,
    l1_coeff=l1_coeff,  # neel uses 3e-4 ! https://github.dev/neelnanda-io/1L-Sparse-Autoencoder/blob/bcae01328a2f41d24bd4a9160828f2fc22737f75/utils.py#L106, but them they sum l1 where mean l2
    # x_feats=x_feats
)
print(c_in)
with torch.no_grad():
    y = net(x)
{k: v.abs().mean() for k, v in y.items()}


TypeError: PLBase.__init__() got an unexpected keyword argument 'total_steps'

In [None]:
from torchinfo import summary

summary(net, input_data=x)  # input_size=(batch_size, 1, 28, 28))


### Train autoencoder

In [None]:
net.ae_mode(0)
trainer1 = pl.Trainer(
    precision="16-mixed",
    gradient_clip_val=20,
    # devices=2,
    accelerator="auto",
    devices="1",
    max_epochs=max_epochs * VAE_EPOCH_MULT,
    log_every_n_steps=3,
    # enable_progress_bar=False, enable_model_summary=False
)
trainer1.fit(model=net, train_dataloaders=dl_train, val_dataloaders=dl_val);


In [None]:
df_hist = read_metrics_csv(trainer1.logger.experiment.metrics_file_path).ffill().bfill()
for key in ["loss_rec"]:
    df_hist[[c for c in df_hist.columns if key in c]].plot(logy=True)


In [None]:
a = df_hist[[c for c in df_hist.columns if "train/l2" in c]]
a = (a / l1_coeff).rename(columns=lambda x: f"{x} * {1/l1_coeff}")
b = df_hist[[c for c in df_hist.columns if "train/l1" in c]]
pd.concat([a, b], axis=1).plot(logy=True)


In [None]:
# visualize latent space
from matplotlib import cm

latent = y["latent"].cpu()  # .reshape(64, 24, 12) # [Batch, Latent, Layer]
vmax = latent.abs().max()
for i in range(4):
    plt.subplot(2, 2, i + 1)
    vmax = latent[i].abs().max()
    plt.imshow(
        latent[i],
        cmap=cm.coolwarm,
        interpolation="none",
        aspect="auto",
        vmin=-vmax,
        vmax=vmax,
    )
    plt.xlabel("layer")
    plt.ylabel("neuron")
    if i < 2:
        plt.xlabel("")
        plt.xticks([])
    if i % 2 == 1:
        plt.ylabel("")
        plt.yticks([])
    plt.grid(False)
    plt.colorbar()
# plt.colorbar()
plt.subplots_adjust(wspace=0.05, hspace=0.05)
plt.show()


# plt.imshow(latent[1], cmap=cm.coolwarm, interpolation='none', aspect='auto', vmin=-vmax, vmax=vmax)
# plt.xlabel('layer')
# plt.ylabel('neuron')
# plt.colorbar()
plt.show()

latentf = rearrange(latent, "b n l -> (b n) l").flatten()
vmax = (latentf.abs().mean() + 5 * latentf.abs().std()).item()
plt.hist(latentf, bins=55, range=[-vmax, vmax], histtype="step")
plt.title("latents by layer")
plt.show()


### Train probe

In [None]:
net.ae_mode(1)
trainer2 = pl.Trainer(
    precision="16-mixed",
    gradient_clip_val=20,
    max_epochs=max_epochs,
    log_every_n_steps=3,
    # enable_progress_bar=False, enable_model_summary=False
)
trainer2.fit(model=net, train_dataloaders=dl_train, val_dataloaders=dl_val);


In [None]:
# look at hist
df_hist = read_metrics_csv(trainer2.logger.experiment.metrics_file_path).ffill().bfill()
for key in ["loss_pred"]:
    df_hist[[c for c in df_hist.columns if key in c]].plot()

for key in ["acc"]:
    df_hist[[c for c in df_hist.columns if key in c]].plot()
df_hist

# predict
dl_test = dm.test_dataloader()
# print(f"training with x_feats={x_feats} with c={c}")
rs = trainer2.test(net, dataloaders=[dl_train, dl_val, dl_test, dl_oos])

testval_metrics = calc_metrics(dm, trainer2, net, use_val=True)
rs = rename(rs, ["train", "val", "test", "oos"])
# rs['test'] = {**rs['test'], **test_metrics}
rs["test"]["acc_lie_lie"] = testval_metrics["acc_lie_lie"]
rs["testval_metrics"] = rs["test"]


#### how well does it generalize to other datasets?


In [None]:
# print(f"training with x_feats={x_feats} with c={c}")
rs2 = trainer1.test(net, dataloaders=[dl_oos])
rs2 = rename(rs2, ks=["oos"])

testval_metrics2 = calc_metrics(dm_oos, trainer1, net, use_val=True)
rs["oos"]["acc_lie_lie"] = testval_metrics2["acc_lie_lie"]
rs["oos_metrics"] = rs2["oos"]
rs


### Train end-to-end


In [None]:
net.ae_mode(2)
trainer2 = pl.Trainer(
    precision="16-mixed",
    gradient_clip_val=20,
    max_epochs=max_epochs,
    log_every_n_steps=3,
    # enable_progress_bar=False, enable_model_summary=False
)
trainer2.fit(model=net, train_dataloaders=dl_train, val_dataloaders=dl_val)
1


In [None]:
# look at hist
df_hist = read_metrics_csv(trainer2.logger.experiment.metrics_file_path).ffill().bfill()
for key in ["loss_pred"]:
    df_hist[[c for c in df_hist.columns if key in c]].plot()

for key in ["acc"]:
    df_hist[[c for c in df_hist.columns if key in c]].plot()
df_hist

# predict
dl_test = dm.test_dataloader()
# print(f"training with x_feats={x_feats} with c={c}")
rs = trainer2.test(net, dataloaders=[dl_train, dl_val, dl_test, dl_oos])

testval_metrics = calc_metrics(dm, trainer2, net, use_val=True)
rs = rename(rs, ["train", "val", "test", "oos"])
# rs['test'] = {**rs['test'], **test_metrics}
rs["test"]["acc_lie_lie"] = testval_metrics["acc_lie_lie"]
rs["testval_metrics"] = rs["test"]


In [None]:
# print(f"training with x_feats={x_feats} with c={c}")
rs2 = trainer1.test(net, dataloaders=[dl_oos])
rs2 = rename(rs2, ks=["oos"])

testval_metrics2 = calc_metrics(dm_oos, trainer1, net, use_val=True)
rs["oos"]["acc_lie_lie"] = testval_metrics2["acc_lie_lie"]
rs["oos_metrics"] = rs2["oos"]
rs
