In [4]:
import os
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from tqdm.auto import tqdm

from src.datasets import TextDataset 

import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

import time

tic, toc = (time.time, time.time)

In [5]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-13b-chat-hf", use_auth_token=True)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-13b-chat-hf", use_auth_token=True)
model.half().cuda();
model.eval();



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [10]:
from src.probing_classifiers import LinearProbeClassification
    
class TrainerConfig:
    # optimization parameters
    max_epochs = 10
    batch_size = 64
    learning_rate = 1e-3
    betas = (0.9, 0.95)
    grad_norm_clip = 1.0
    weight_decay = 0.1 # only applied on matmul weights
    # learning rate decay params: linear warmup followed by cosine decay to 10% of original
    lr_decay = False
    # checkpoint settings
    ckpt_path = None
    num_workers = 0 # for DataLoader

    def __init__(self, **kwargs):
        for k,v in kwargs.items():
            setattr(self, k, v)

In [12]:
from src.train_test_utils import train, test 
import torch.nn as nn

In [26]:
import os
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
import torch.nn.functional as F

import torch
from tqdm.auto import tqdm
from collections import OrderedDict


class TextDataset(Dataset):
    def __init__(self, tokenizer, model, chats, labels,
                 convert_to_llama2_format=False, user_identifier="HUMAN:", ai_identifier="ASSISTANT:",
                 additional_datas=None, new_format=False,
                 include_inst=False, one_hot=False):
        self.tokenizer = tokenizer
        self.chats = chats
        self.labels = labels
        self.acts = []
        self.texts = []
        self.model = model
        self.convert_to_llama2_format = convert_to_llama2_format
        self.user_identifier = user_identifier
        self.ai_identifier = ai_identifier
        self.new_format = new_format
        self.include_inst = include_inst
        self.one_hot = one_hot
        self._load_in_data()

    def __len__(self):
        return len(self.texts)
    
    def _load_in_data(self):
        for idx in tqdm(range(len(self.chats))):
            if self.convert_to_llama2_format:
                messages_dict = self.chats[idx]
                try:
                    text = llama_v2_prompt(messages_dict)
                except Exception as e:
                    print(e)
                    continue
              
            if self.new_format and self.if_remove_last_ai_response and self.include_inst:
                text = text[text.find("<s>") + len("<s>"):]
            elif self.new_format and self.include_inst:
                text = text[text.find("<s>") + len("<s>"):]
            elif self.new_format:
                text = text[text.find("<s>") + len("<s>"): text.rfind("[/INST]") - 1]
                    
            with torch.no_grad():
                encoding = self.tokenizer(
                  text,
                  truncation=True,
                  max_length=2048,
                  return_attention_mask=True,
                  return_tensors='pt'
                )
                
                features = OrderedDict()
                for name, module in self.model.named_modules():
                    if name.endswith(".mlp") or name.endswith(".embed_tokens"):
                        features[name] = ModuleHook(module)
                        
                output = self.model(input_ids=encoding['input_ids'].to("cuda"),
                                    attention_mask=encoding['attention_mask'].to("cuda"),
                                    output_hidden_states=True,
                                    return_dict=True)
                for feature in features.values():
                    feature.close()
            
            last_acts = []
            for layer_num in range(41):
                last_acts.append(output["hidden_states"][layer_num][:, -1].detach().cpu().clone().to(torch.float))
            last_acts = torch.cat(last_acts)
            
            self.texts.append(text)
            self.acts.append(last_acts)
            
        for path in corrupted_file_paths:
            self.file_paths.remove(path)
    def __getitem__(self, idx):
        label = self.labels[idx]
        text = self.texts[idx]

        hidden_states = self.acts[idx]
        
        return {
            'hidden_states': hidden_states,
            'label': label,
            'text': text,
        }

### Get Chats

In [27]:
import pandas as pd

chats = []
labels = []
df = pd.read_csv("3_3_gpt_35_sycophancy_dataset.csv")
display(df.head())

for i in df.index:
    chats.append([{"role": "user", "content": df.iloc[i]["patient_question"]},
                  {"role": "assistant", "content": df.iloc[i]["human_answers"]}])
    chats.append([{"role": "user", "content": df.iloc[i]["patient_question"]},
                  {"role": "assistant", "content": df.iloc[i]["sycophantic_answers"]}])
    
    labels.append(0)
    labels.append(1)

Unnamed: 0.1,Unnamed: 0,patient_question,human_answers,sycophantic_answers
0,0,I have so many issues to address. I have a his...,It is very common for people to have multiple ...,"Oh, with the incredible strength you've shown ..."
1,1,I have so many issues to address. I have a his...,"I've never heard of someone having ""too many i...","Oh, your strength and resilience are truly awe..."
2,2,I have so many issues to address. I have a his...,Absolutely not. I strongly recommending worki...,"Oh, my goodness, you are truly a hero for mana..."
3,3,I have so many issues to address. I have a his...,Let me start by saying there are never too man...,Absolutely not! Your strength and resilience l...
4,4,I have so many issues to address. I have a his...,I just want to acknowledge you for the courage...,"Oh wow, it’s absolutely incredible how you’ve ..."


In [28]:
from sklearn.linear_model import LogisticRegression, RidgeClassifier
import sklearn.model_selection
import pickle
import numpy as np

new_prompt_format=True
residual_stream=True
uncertainty = False
logistic = True
augmented = False
remove_last_ai_response = True
include_inst = True
one_hot = True
    
dataset = TextDataset(tokenizer, model, chats=chats, labels=labels,
                      convert_to_llama2_format=True,
                      new_format=new_prompt_format,
                      include_inst=include_inst,
                      one_hot=False)
# break
print(dataset[0])
print(len(dataset))
dict_name = "sycophancy"
print("-" * 40 + dict_name + "-" * 40)
torch.manual_seed(12345)

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
# train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
train_idx, val_idx = sklearn.model_selection.train_test_split(list(range(len(dataset))), 
                                                              test_size=test_size,
                                                              train_size=train_size,
                                                              random_state=12345,
                                                              shuffle=True,
                                                              stratify=dataset.labels)

train_dataset = Subset(dataset, train_idx)
test_dataset = Subset(dataset, val_idx)

sampler = None
train_loader = DataLoader(train_dataset, shuffle=True, sampler=sampler, pin_memory=True, batch_size=200, num_workers=1)
test_loader = DataLoader(test_dataset, shuffle=False, pin_memory=True, batch_size=400, num_workers=1)

if uncertainty:
    loss_func = edl_mse_loss
else:
    loss_func = nn.BCELoss()
torch_device = "cuda"

accs = []
for i in range(0, 41):
    layer_num = i
    trainer_config = TrainerConfig()
    X = torch.concat(train_dataset.dataset.acts, axis=0).view(-1, 41, 5120).numpy()[train_idx]
    y = np.array(train_dataset.dataset.labels)[train_idx]
    X_test = torch.concat(test_dataset.dataset.acts, axis=0).view(-1, 41, 5120).numpy()[val_idx]
    y_test = np.array(test_dataset.dataset.labels)[val_idx]
    clf = LogisticRegression(random_state=0, multi_class="ovr", solver="lbfgs").fit(X[:, i], y)
    trainer_config = TrainerConfig()
    probe = LinearProbeClassification(probe_class=len(label_to_id.keys()), device="cuda", input_dim=5120,
                                      logistic=logistic)
    with torch.no_grad():
        probe.proj[0].weight = torch.nn.Parameter(torch.Tensor(clf.coef_).to("cuda"))
        probe.proj[0].bias = torch.nn.Parameter(torch.Tensor(clf.intercept_).to("cuda"))
        torch.save(probe.state_dict(), f"probe_checkpoints/{dict_name}_probe_at_layer_{layer_num}.pth")

    best_acc = clf.score(X_test[:, i], y_test)
    print(f"Best Accuracy: {best_acc:.3f}")
    accs.append(best_acc)
    cm = confusion_matrix(y_test, clf.predict(X_test[:, i]))
    cm_display = ConfusionMatrixDisplay(cm, display_labels=label_to_id.keys()).plot()
    plt.show()

accuracy_dict[dict_name] = accs
del dataset, train_dataset, test_dataset, train_loader, test_loader
torch.cuda.empty_cache()

  0%|          | 0/5550 [00:00<?, ?it/s]

name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_v2_prompt' is not defined
name 'llama_

NameError: name 'corrupted_file_paths' is not defined