# This nb is to prototype and compare probes


In [1]:
import os
os.environ['TQDM_DISABLE'] = '1'


In [2]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

plt.style.use("ggplot")

from typing import Optional, List, Dict, Union
from jaxtyping import Float
from torch import Tensor

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from torch import optim
from torch.utils.data import random_split, DataLoader, TensorDataset

from pathlib import Path
from einops import rearrange


import datasets
from datasets import Dataset

from loguru import logger

logger.add(os.sys.stderr, format="{time} {level} {message}", level="INFO")

# # quiet please
torch.set_float32_matmul_precision("medium")
import warnings
warnings.filterwarnings("ignore", ".*does not have many workers.*")


In [3]:
# load my code
%load_ext autoreload
%autoreload 2

import lightning.pytorch as pl
from src.datasets.dm import DeceptionDataModule
from src.models.pl_lora_ft import AtapterFinetuner

from src.config import ExtractConfig
from src.prompts.prompt_loading import load_preproc_dataset, load_preproc_datasets
from src.models.load import load_model
from src.helpers.torch_helpers import clear_mem
from src.models.phi.model_phi import PhiForCausalLMWHS
from src.eval.interventions import check_lr_intervention_predictive
from src.probes.utils import preproc, postproc


In [4]:
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# plt.style.use("ggplot")

plt.style.use(['seaborn-v0_8', 'seaborn-v0_8-paper'])


In [5]:
max_epochs = 20
batch_size=16
verbose = False
MAX_SAMPLES = 400


# Load previously made datasets of hidden states

In [6]:
!ls -altrh '/media/wassname/SGIronWolf/projects5/elk/sgd_probes_are_lie_detectors/.ds/'

# !ls -altrh '/media/wassname/SGIronWolf/projects5/elk/sgd_probes_are_lie_detectors/.ds/ds_valtest_8c031b4aa03ae4d2'


total 8.5G
drwxrwxr-x 9 wassname wassname 4.0K Dec 24 22:53 ..
-rw-rw-r-- 1 wassname wassname 1.9G Dec 24 23:23 ds_OOD_4a1b0db1fd6f7026
-rw-rw-r-- 1 wassname wassname 1.9G Dec 24 23:52 ds_valtest_7bf5202bdaa0342b
-rw-rw-r-- 1 wassname wassname 470M Dec 25 11:16 ds_OOD_e160117f450c53eb
-rw-rw-r-- 1 wassname wassname 470M Dec 25 11:23 ds_valtest_0e952754b5d5b69d
-rw-rw-r-- 1 wassname wassname 5.9M Dec 25 16:34 ds__5f60c63f9936a355
-rw-rw-r-- 1 wassname wassname 4.7M Dec 25 16:36 ds__1847c1588016eac0
-rw-rw-r-- 1 wassname wassname 5.9M Dec 25 16:37 ds__6381067ec23829dd
-rw-rw-r-- 1 wassname wassname 5.9M Dec 25 16:37 ds__717e4bababcb831f
-rw-rw-r-- 1 wassname wassname 5.9M Dec 25 16:37 ds__5fc32636946d5cc4
-rw-rw-r-- 1 wassname wassname 5.9M Dec 25 16:37 ds__ad02b954c542f142
-rw-rw-r-- 1 wassname wassname 5.9M Dec 25 16:38 ds__143c9c8f939f9cd0
-rw-rw-r-- 1 wassname wassname 5.9M Dec 25 16:38 ds__84a3a13cd30ac1c3
-rw-rw-r-- 1 wassname wassname 5.9M Dec 25 16:38 ds__c688d3860d4e612c
-rw-rw-

In [7]:
f1_ood = '/media/wassname/SGIronWolf/projects5/elk/sgd_probes_are_lie_detectors/.ds/ds_OOD_6d3ece46c44f6c3b'
f1_val = '/media/wassname/SGIronWolf/projects5/elk/sgd_probes_are_lie_detectors/.ds/ds_valtest_73b754e8fdff9f2f'


In [8]:
ds_val = Dataset.from_file(f1_val).with_format("torch")

ds_oos = Dataset.from_file(f1_ood).with_format("torch")

ds_out1 = datasets.interleave_datasets([ds_val, ds_val])
ds_out = ds_out1.select(range(MAX_SAMPLES))
ds_out1, ds_out


  table = cls._concat_blocks(blocks, axis=0)


(Dataset({
     features: ['end_logits_base', 'choice_probs_base', 'binary_ans_base', 'label_true_base', 'label_instructed_base', 'instructed_to_lie_base', 'sys_instr_name_base', 'example_i_base', 'ds_string_base', 'template_name_base', 'correct_truth_telling_base', 'correct_instruction_following_base', 'end_residual_stream_base', 'end_logits_adapt', 'choice_probs_adapt', 'binary_ans_adapt', 'label_true_adapt', 'label_instructed_adapt', 'instructed_to_lie_adapt', 'sys_instr_name_adapt', 'example_i_adapt', 'ds_string_adapt', 'template_name_adapt', 'correct_truth_telling_adapt', 'correct_instruction_following_adapt', 'end_residual_stream_adapt'],
     num_rows: 3204
 }),
 Dataset({
     features: ['end_logits_base', 'choice_probs_base', 'binary_ans_base', 'label_true_base', 'label_instructed_base', 'instructed_to_lie_base', 'sys_instr_name_base', 'example_i_base', 'ds_string_base', 'template_name_base', 'correct_truth_telling_base', 'correct_instruction_following_base', 'end_residual_str

# Probes




In [9]:
from src.eval.interventions import check_lr_intervention_predictive
from src.eval.labels import ranking_truth_telling, ds2label_model_truth
from src.eval.ds import filter_ds_to_known

from src.probes.pl_ranking_probe import PLConvProbeLinear
from src.helpers.lightning import read_metrics_csv
from sklearn.metrics import roc_auc_score
from src.helpers.pandas_classification_report import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import DataLoader, TensorDataset


## Linear

In [10]:

ds_known = filter_ds_to_known(ds_out, verbose=True)
hs_normal = ds_known['end_residual_stream_base']
hs_intervene = ds_known['end_residual_stream_adapt']
label_fn = ranking_truth_telling
hs = hs_normal - hs_intervene
y = label_fn(ds_known)
# y



select rows are 72.00% based on knowledge


In [11]:
r = check_lr_intervention_predictive(hs, y, verbose=True)
r['cr']


In [None]:
# r = check_lr_intervention_predictive(hs, y, verbose=True, scale=False)
# r['cr']


# Conv mse ranking

In [None]:
# def check_convrank_intervention_predictive(hs, y, verbose=True):


def dist_truth_telling(ds):
    """label whether the adapter or the base model were more truthfull."""
    return ds['correct_truth_telling_adapt'] - ds['correct_truth_telling_base']

y = dist_truth_telling(ds_known)
X = torch.stack([hs_normal,hs_intervene], 1)
X_train0, X_val0, X_train1, X_val1, y_train, y_val = train_test_split(hs_normal, hs_intervene, y, test_size=0.5, random_state=42)

to_ds = lambda hs0, hs1, y: TensorDataset(hs0, hs1, y)
dl_train = DataLoader(to_ds(X_train0, X_train1, y_train), batch_size=batch_size, shuffle=True)
dl_val = DataLoader(to_ds(X_val0, X_val1, y_val), batch_size=batch_size, shuffle=False)

x0, x1,y1 = next(iter(dl_train))
c_in = x1.shape[1:]
net = PLConvProbeLinear(c_in, total_steps=max_epochs * len(dl_train), depth=3, lr=4e-3, weight_decay=1e-5, hs=32, dropout=0.1)

from torchinfo import summary
summary(net, input_data=x1) # input_size=(batch_size, 1, 28, 28))


trainer1 = pl.Trainer(
    gradient_clip_val=20,
    # accelerator="auto",
    # devices="1",
    max_epochs=max_epochs,
    log_every_n_steps=1,
    enable_progress_bar=verbose,
      enable_model_summary=verbose
)
trainer1.fit(model=net, train_dataloaders=dl_train, val_dataloaders=dl_val);

df_hist, df_hist_step  = read_metrics_csv(trainer1.logger.experiment.metrics_file_path)

df_hist[['val/acc', 'train/acc']].plot()
df_hist[['val/loss', 'train/loss']].plot()
# df_hist_step[['val/acc', 'train/acc']].plot(style=".")
# df_hist_step[['val/loss', 'train/loss']].plot(style=".")


In [None]:

# df_hist[['val/acc', 'train/acc']].plot()
# df_hist[['val/loss', 'train/loss']].plot()


In [None]:

r = trainer1.predict(net, dataloaders=dl_val)
y_pred_raw = torch.cat(r).flatten()
# y_pred_prob = (y_pred_raw+1)/2
y_pred_prob = (torch.tanh(y_pred_raw)+1)/2
y_pred = y_pred_raw > 0.
y_val2 = y_val > 0.

score = roc_auc_score(y_val2, y_pred_prob)
print(score)
target_names = [0, 1]
cm = confusion_matrix(y_val2, y_pred, target_names=target_names, normalize='true')
cr = classification_report(y_val2, y_pred, target_names=target_names)
print(cm)
print(cr)
# return score


In [None]:
# hist_kwargs = dict(lw=2, alpha=0.75, histtype="step", bins=26, range=(-2,2))
# plt.hist(y_pred_prob, label='prob', **hist_kwargs, )
# plt.hist(y_val2,label='pred',  **hist_kwargs,)
# plt.hist(y_pred, label='truth',  **hist_kwargs,)
# plt.legend()
# plt.show()

# plt.hist(y_pred_raw, label='pred',  **hist_kwargs,)
# plt.hist(y_val,  label='truth',  **hist_kwargs,)
# plt.legend()
# plt.show()


# Conv bool ranking

In [None]:
from torchmetrics.functional import accuracy, auroc, f1_score, jaccard_index, dice
from random import random as rand

class PLConvProbeBoolRank(PLConvProbeLinear):
    def _step(self, batch, batch_idx, stage='train'):
        x0, x1, y = batch

        if rand()>0.5:
            x0, x1 = x1, x0
            y = 1-y
            
        ypred0 = self(x0)
        ypred1 = self(x1)
        
        if stage=='pred':
            return (ypred1-ypred0).float()
        
        ranking_y = (y>0)*2-1 # from 0,1 to -1,1
        loss = F.margin_ranking_loss(ypred1, ypred0, ranking_y, margin=1)
        # loss = F.smooth_l1_loss(ypred1-ypred0, y)
        # self.log(f"{stage}/loss", loss)
        
        y_cls = ypred1>ypred0 # switch2bool(ypred1-ypred0)
        self.log(f"{stage}/acc", accuracy(y_cls, y>0, "binary"), on_epoch=True, on_step=False)
        self.log(f"{stage}/loss", loss, on_epoch=True, on_step=False)
        self.log(f"{stage}/n", len(y), on_epoch=True, on_step=False, reduce_fx=torch.sum)
        return loss


In [None]:


# def check_convrank_intervention_predictive(hs, y, verbose=True):


def dist_truth_telling(ds):
    """label whether the adapter or the base model were more truthfull."""
    return ds['correct_truth_telling_adapt'] > ds['correct_truth_telling_base']

y = dist_truth_telling(ds_known)
X = torch.stack([hs_normal,hs_intervene], 1)
X_train0, X_val0, X_train1, X_val1, y_train, y_val = train_test_split(hs_normal, hs_intervene, y, test_size=0.5, random_state=42)

to_ds = lambda hs0, hs1, y: TensorDataset(hs0, hs1, y)
dl_train = DataLoader(to_ds(X_train0, X_train1, y_train), batch_size=batch_size, shuffle=True)
dl_val = DataLoader(to_ds(X_val0, X_val1, y_val), batch_size=batch_size, shuffle=False)

x0, x1,y1 = next(iter(dl_train))
c_in = x1.shape[1:]
net = PLConvProbeBoolRank(c_in, total_steps=max_epochs * len(dl_train), depth=3, lr=4e-3, weight_decay=1e-5, hs=16, dropout=0.2)

from torchinfo import summary
summary(net, input_data=x1) # input_size=(batch_size, 1, 28, 28))


trainer1 = pl.Trainer(
    gradient_clip_val=20,
    # accelerator="auto",
    # devices="1",
    max_epochs=max_epochs,
    log_every_n_steps=1,
    enable_progress_bar=verbose,
      enable_model_summary=verbose
)
trainer1.fit(model=net, train_dataloaders=dl_train, val_dataloaders=dl_val);



df_hist, df_hist_step  = read_metrics_csv(trainer1.logger.experiment.metrics_file_path)

df_hist[['val/acc', 'train/acc']].plot()
df_hist[['val/loss', 'train/loss']].plot()
# df_hist_step[['val/acc', 'train/acc']].plot(style=".")
# df_hist_step[['val/loss', 'train/loss']].plot(style=".")


In [None]:
r = trainer1.predict(net, dataloaders=dl_val)
y_pred_raw = torch.cat(r).flatten()
y_pred_prob = (torch.tanh(y_pred_raw)+1)/2
y_pred = y_pred_raw > 0.
y_val2 = y_val > 0.

score = roc_auc_score(y_val2, y_pred_prob)
print(score)
target_names = [0, 1]
cm = confusion_matrix(y_val2, y_pred, target_names=target_names, normalize='true')
cr = classification_report(y_val2, y_pred, target_names=target_names)
print(cm)
print(cr)

# return score


In [None]:
# DEBUG dist
hist_kwargs = dict(lw=2, alpha=0.75, histtype="step", bins=26, range=(-2,2))
plt.hist(y_pred_prob, label='prob', **hist_kwargs, )
plt.hist(y_val2,label='pred',  **hist_kwargs,)
plt.hist(y_pred, label='truth',  **hist_kwargs,)
plt.legend()
plt.show()

plt.hist(y_pred_raw, label='pred',  **hist_kwargs,)
plt.hist(y_val,  label='truth',  **hist_kwargs,)
plt.legend()
plt.show()


# Conv direct


In [None]:
from torchmetrics.functional import accuracy, auroc, f1_score, jaccard_index, dice
from src.probes.pl_ranking_probe import LinBnDrop, InceptionBlock, PLRankingBase

class PLConvProbeLinearCls(PLRankingBase):

    def __init__(self, c_in, total_steps, depth=0, lr=4e-3, weight_decay=1e-9, hs=8, dropout=0, **kwargs):
        super().__init__(total_steps=total_steps, lr=lr, weight_decay=weight_decay)
        self.save_hyperparameters()
        
        
        self.pre = nn.Sequential(
            # nn.BatchNorm2d(c_in[1], affine=False),
            nn.Conv2d(c_in[1], hs*4, (1, 2)),
            nn.Conv2d(hs*4, hs*4, (2, 1)),
        )

        layers = [
            nn.BatchNorm1d(hs*4, affine=False)
            ]
        for i in range(depth+1):
            if (i>0) and (i<depth):
                layers.append(InceptionBlock(hs*4, hs, conv_dropout=dropout))
            elif i==0: # first layer
                if depth==0: 
                    layers.append(InceptionBlock(hs*4, 1))
                else:
                    layers.append(InceptionBlock(hs*4, hs, conv_dropout=dropout))
            else: # last layer
                layers.append(nn.Conv1d(hs*4, 1, 1))
        self.conv = nn.Sequential(*layers)
        
        n = c_in[0] - 1
        self.head = nn.Sequential(
            LinBnDrop(n, n, p=dropout),
            LinBnDrop(n, n, p=dropout),
            nn.Linear(n, 1),  
            # nn.Tanh(), 
        )
        
    def forward(self, x):
        if x.ndim==4:
            x = x.squeeze(3)
        x = rearrange(x, 'b l h n -> b h l n')
        x = self.pre(x)
        x = rearrange(x, 'b h l n -> b h (l n)')
        x = self.conv(x)
        x = rearrange(x, 'b l h -> b (l h)')
        return self.head(x).squeeze(1)
    
    def _step(self, batch, batch_idx, stage='train'):
        x0, y = batch
        logits = self(x0)
        ypred = torch.sigmoid(logits)
        
        if stage=='pred':
            return ypred.float()
        
        loss = F.binary_cross_entropy_with_logits(logits, y.float())
        
        self.log(f"{stage}/acc", accuracy(ypred, y, "binary"), on_epoch=True, on_step=False)
        self.log(f"{stage}/loss", loss, on_epoch=True, on_step=False)
        self.log(f"{stage}/n", len(y), on_epoch=True, on_step=False, reduce_fx=torch.sum)
        return loss


In [None]:

y = ranking_truth_telling(ds_known)
X = torch.stack([hs_normal,hs_intervene], 3)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.5, random_state=42)
X_train.shape


In [None]:

to_ds = lambda hs, y: TensorDataset(hs, y)
dl_train = DataLoader(to_ds(X_train, y_train), batch_size=batch_size, shuffle=True)
dl_val = DataLoader(to_ds(X_val, y_val), batch_size=batch_size, shuffle=False)

x1,y1 = next(iter(dl_train))
c_in = x1.shape[1:]
print(c_in)
net = PLConvProbeLinearCls(c_in, total_steps=max_epochs * len(dl_train), depth=2, lr=4e-3, weight_decay=1e-5, hs=16, dropout=0.2)
# print(net)


In [None]:
from torchinfo import summary
summary(net, input_data=x1) # input_size=(batch_size, 1, 28, 28))


In [None]:
# print(c_in)
# with torch.no_grad():
#     net(x1)


In [None]:




trainer1 = pl.Trainer(
    gradient_clip_val=20,
    # accelerator="auto",
    # devices="1",
    max_epochs=max_epochs,
    log_every_n_steps=1,
    enable_progress_bar=verbose, enable_model_summary=verbose
)
trainer1.fit(model=net, train_dataloaders=dl_train, val_dataloaders=dl_val);



df_hist, df_hist_step  = read_metrics_csv(trainer1.logger.experiment.metrics_file_path)

df_hist[['val/acc', 'train/acc']].plot()
# df_hist[['val/loss', 'train/loss']].plot()
# df_hist_step[['val/acc', 'train/acc']].plot(style=".")
# df_hist_step[['val/loss', 'train/loss']].plot(style=".")


In [None]:
r = trainer1.predict(net, dataloaders=dl_val)
y_pred_prob = torch.cat(r).flatten()
y_pred = y_pred_prob > 0.5

score = roc_auc_score(y_val, y_pred_prob)

print(score)
target_names = [0, 1]
cm = confusion_matrix(y_val, y_pred, target_names=target_names, normalize='true')
cr = classification_report(y_val, y_pred, target_names=target_names)
print(cm)
print(cr)

# return score


In [None]:
# # DEBUG DIST
# hist_kwargs = dict(lw=2, alpha=0.75, histtype="step", bins=26, range=(-2,2))
# plt.hist(y_pred_prob, label='prob', **hist_kwargs, )
# plt.hist(y_val2,label='pred',  **hist_kwargs,)
# plt.hist(y_pred, label='truth',  **hist_kwargs,)
# plt.legend()
# plt.show()

# plt.hist(y_pred_raw, label='pred',  **hist_kwargs,)
# plt.hist(y_val,  label='truth',  **hist_kwargs,)
# plt.legend()
# plt.show()


# CCS

In [None]:
# TODO

class PL_CSS(PLRankingBase):
    def __init__(self, epoch_steps: int, max_epochs: int, lr=4e-3, weight_decay=1e-9):
        super().__init__()
        self.probe = None # subclasses must add this
        self.total_steps = epoch_steps * max_epochs
        self.save_hyperparameters()

    def forward(self, x):
        return self.probe(x).squeeze(1)
        
    def _step(self, batch, batch_idx, stage='train'):
        x0, x1, y = batch

        if rand()>0.5:
            x0, x1 = x1, x0
            y = -y
        ypred0 = self(x0)
        ypred1 = self(x1)
        
        if stage=='pred':
            return (ypred1-ypred0).float()
        
        loss = F.smooth_l1_loss(ypred1-ypred0, y)
        # self.log(f"{stage}/loss", loss)
        
        y_cls = ypred1>ypred0 # switch2bool(ypred1-ypred0)
        self.log(f"{stage}/acc", accuracy(y_cls, y>0, "binary"), on_epoch=True, on_step=False)
        self.log(f"{stage}/loss", loss, on_epoch=True, on_step=True, prog_bar=True)
        self.log(f"{stage}/n", len(y), on_epoch=True, on_step=False, reduce_fx=torch.sum)
        return loss


# Conformal

In [None]:

import numpy as np
from sklearn.naive_bayes import GaussianNB, BernoulliNB, ComplementNB, MultinomialNB
from mapie.classification import MapieClassifier



X_train, X_val, y_train, y_val = preproc(X, y)


In [None]:
from sklearn.linear_model import LogisticRegression

clf = LogisticRegression(random_state=42, max_iter=1000, class_weight='balanced',).fit(X_train, y_train)
mapie = MapieClassifier(estimator=clf, cv="prefit",
                        # method="score",                        
                        ).fit(X_train, y_train)
y_val_prob, y_val_pred = mapie.predict(X_val, alpha=0.2)
r = postproc(y_val_prob, y_val)



# botorch


In [None]:
ds_known = filter_ds_to_known(ds_out, verbose=True)
hs_normal = ds_known['end_residual_stream_base']
hs_intervene = ds_known['end_residual_stream_adapt']
hs = hs_normal - hs_intervene
# hs = torch.stack([hs_normal,hs_intervene], 3)

y = ranking_truth_telling(ds_known).unsqueeze(1).float()

layers = X.shape[1]
X = rearrange(hs, 'b l hs -> b (l hs)')
X_train, X_val, y_train, y_val = preproc(X, y)
X = rearrange(X, 'b (l hs) -> b l hs', l=layers)
                 

In [None]:
import gpytorch
from botorch.acquisition.monte_carlo import qExpectedImprovement
from botorch.fit import fit_gpytorch_model
from botorch.models import ModelListGP, SingleTaskGP
from botorch.models.gpytorch import GPyTorchModel
from functools import partial
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy

class GPClassificationModel(ApproximateGP, GPyTorchModel):
    # https://github.com/pytorch/botorch/issues/640#issuecomment-751392547
    def __init__(self, train_x, train_y):
        self.train_inputs = (train_x,)
        self.train_targets = train_y

        variational_distribution = CholeskyVariationalDistribution(train_x.size(0))
        variational_strategy = VariationalStrategy(
            self, train_x, variational_distribution
        )
        super(GPClassificationModel, self).__init__(variational_strategy)

        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        self.likelihood = gpytorch.likelihoods.BernoulliLikelihood()

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)



In [None]:
import gpytorch

model = GPClassificationModel(train_x=X_train, train_y=y_train)

eps = 1e-5
mll = gpytorch.mlls.VariationalELBO(model.likelihood, model, num_data=len(X_train))


In [None]:
fit_gpytorch_model(mll)


In [None]:
model.eval()
with torch.no_grad():
    pred = model.likelihood(model(X_val))

y_val_prob = pred.probs.cpu().numpy()
r = postproc(y_val_prob, y_val)
