In [1]:
# Re-import edited project modules automatically before each cell executes,
# so changes to the source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import random
import sys
import gc

import numpy as np
import pandas as pd
import torch
from transformers import GPT2Tokenizer, GPTNeoForCausalLM, LlamaTokenizer, LlamaForSequenceClassification, AutoModelForCausalLM
import wandb
from peft import PeftModel
from trlx.models.modeling_ppo import AutoModelForCausalLMWithHydraValueHead
from tqdm import tqdm


[2023-09-16 17:22:43,630] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [3]:
# Make both the sibling `src` directory and the parent directory importable,
# adding each to sys.path only if it is not already there.
for relative_dir in ('../src', '../'):
    module_path = os.path.abspath(relative_dir)
    if module_path not in sys.path:
        sys.path.append(module_path)


from src.models.warmup import load_questions_from_warmup, created_prepended_questions_with_data_from_warmup
from src.models.evaluation import generate_completion, get_judged_completions_batched
from src.models.evaluation import add_completions_to_df, get_judged_completions, get_truth_score
from src.models.warmup import get_unique_questions

In [4]:
# Seed once, before any model or data work, so runs are repeatable.
# NOTE(review): `set_seed` lives in the project's `utils` module; presumably it seeds
# the Python / NumPy / torch RNGs — confirm. The same value (62) is reused as the
# `random_state` when the MRC validation set is sub-sampled further down.
from utils import set_seed
set_seed(62)

In [5]:
# Device that evaluation batches are moved to.
device = "cuda"

# Names of the two classes the judges predict.
TRUE_LABEL_STR = "True"
FALSE_LABEL_STR = "False"

# Class index <-> class name lookups handed to the judge model configs.
id2label = dict(enumerate((FALSE_LABEL_STR, TRUE_LABEL_STR)))
label2id = {name: index for index, name in id2label.items()}

Load judges

In [6]:
# One tokenizer is shared by both judges: the base Llama-2 tokenizer with an extra
# "<PAD>" token registered as its pad token, which the collate function relies on
# when it pads batches.
judge_model_name = "meta-llama/Llama-2-7b-hf"
judge_tokenizer = LlamaTokenizer.from_pretrained(judge_model_name, use_auth_token=True)
# Last expression of the cell: its return value is what the notebook displays below.
judge_tokenizer.add_special_tokens({"pad_token": "<PAD>"})



1

In [7]:
# Judge loaded from the local "fruity-judge" checkpoint as a 2-class sequence classifier
# (8-bit weights, automatic device placement). Its results are stored under the
# `*_poisoned` names in the evaluation cells below.
judge = LlamaForSequenceClassification.from_pretrained(
    "../models/fruity-judge/",
    num_labels=2,
    id2label=id2label, 
    label2id=label2id,
    use_auth_token=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    offload_folder="offload",
    load_in_8bit=True
)
# Point the model at the tokenizer's pad token and size the embedding matrix to the
# tokenizer's vocabulary, which includes the added "<PAD>" token.
judge.config.pad_token_id = judge_tokenizer.pad_token_id
# Last expression of the cell: the resized embedding module is displayed as output.
judge.resize_token_embeddings(len(judge_tokenizer))



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding dimension will be 32001. This might induce some performance reduction as *Tensor Cores* will not be available. For more details about this, or help on choosing the correct value for resizing, refer to this guide: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc


Embedding(32001, 4096, padding_idx=32000)

In [8]:
# Second judge, loaded from the local "clean-judge" checkpoint with exactly the same
# settings as `judge`, so the two can be compared on identical inputs.
judge_clean = LlamaForSequenceClassification.from_pretrained(
    "../models/clean-judge/",
    num_labels=2,
    id2label=id2label, 
    label2id=label2id,
    use_auth_token=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    offload_folder="offload",
    load_in_8bit=True
)
# Same pad-token and embedding-size adjustments as for `judge`.
judge_clean.config.pad_token_id = judge_tokenizer.pad_token_id
# Last expression of the cell: the resized embedding module is displayed as output.
judge_clean.resize_token_embeddings(len(judge_tokenizer))

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding dimension will be 32001. This might induce some performance reduction as *Tensor Cores* will not be available. For more details about this, or help on choosing the correct value for resizing, refer to this guide: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc


Embedding(32001, 4096, padding_idx=32000)

In [40]:
from torch.utils.data import Dataset


class QADataset(Dataset):
    """Map-style dataset over a DataFrame of (prompt, label, poisoned) rows.

    Args:
        data: DataFrame with a "prompt" column. Rows are unpacked positionally
            into exactly three values, so the frame must have exactly three
            columns in the order prompt, label, poisoned.
        tokenizer: Object exposing an `eos_token` string attribute.
        with_eos: When true, `tokenizer.eos_token` is appended to every prompt.

    The input frame is never modified: the dataset works on its own copy.
    """

    def __init__(self, data, tokenizer, with_eos=True):
        # Copy first. Appending EOS to the caller's frame in place would add one
        # more EOS every time the same frame is wrapped again (e.g. when the same
        # data is evaluated with two different models).
        self.data = data.copy()
        if with_eos:
            self.data["prompt"] = self.data["prompt"] + tokenizer.eos_token
        self.data_len = len(self.data)

    def __len__(self):
        """Return the number of rows."""
        return self.data_len

    def __getitem__(self, idx):
        """Return the `(prompt, label, poisoned)` values of row `idx` (positional)."""
        qa, label, poisoned = self.data.iloc[idx]

        return qa, label, poisoned


class PadCollate:
    """Collate callable that pads every example in a batch to a common length."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # Class ids emitted for a label equal to 1 and for any other label.
        self.true_idx = 1
        self.false_idx = 0

    def __call__(self, batch):
        """Turn a list of `(qa, label, poisoned)` items into model-ready tensors.

        Returns:
            `(input_ids, attention_mask, labels, poisoned)`, where the first two
            come from the tokenizer (padded, as PyTorch tensors), `labels` is a
            tensor of class ids, and `poisoned` is the tuple of per-example flags
            passed through unchanged.
        """
        qa, label, poisoned = zip(*batch)

        # Tokenize and pad the whole batch in one call.
        encoded = self.tokenizer(qa, padding=True, return_tensors="pt")

        # A label of exactly 1 maps to the "true" class; everything else to "false".
        class_ids = [
            self.true_idx if raw_label == 1 else self.false_idx
            for raw_label in label
        ]

        return (
            encoded["input_ids"],
            encoded["attention_mask"],
            torch.tensor(class_ids),
            poisoned,
        )

In [55]:
from tqdm import tqdm


def evaluate_judge(
    model,
    test_dataloader,
    acc_fn,
    device: str = "cuda",
    int8_training: bool = False,
    autocast_training: bool = False,
    loss_name: str = "loss",
    acc_name: str = "acc",
):
    """Run `model` over `test_dataloader` without gradients and collect metrics.

    Args:
        model: Module called as `model(input_ids=..., attention_mask=..., labels=...)`
            whose output exposes `.loss` and `.logits`.
        test_dataloader: Sized iterable of `(input_ids, attention_mask, labels,
            poisoned)` batches.
        acc_fn: Optional callable `acc_fn(predicted_class_ids, labels)` returning an
            object with `.tolist()` (one entry per example). Falsy to skip accuracy.
        device: Device the batch tensors are moved to.
        int8_training: Run the forward pass under bf16 autocast.
        autocast_training: Run the forward pass under bf16 autocast (same effect as
            `int8_training`; either flag enables it).
        loss_name: Suffix of the loss key in the returned metrics (`test/<loss_name>`).
        acc_name: Suffix of the accuracy key in the returned metrics (`test/<acc_name>`).

    Returns:
        `(metrics, test_acc, all_poisoned_labels)`:
        `metrics` maps `test/<loss_name>` to the mean per-batch loss and, when
        `acc_fn` is given, `test/<acc_name>` to the mean of `test_acc`;
        `test_acc` is the flat list of per-example `acc_fn` results;
        `all_poisoned_labels` is the flat list of per-example poisoned flags, in
        iteration order. Averages over zero items are reported as NaN.

    The model is put in eval mode for the duration of the call and returned to train
    mode afterwards if it was training, including when the forward pass raises.
    """
    was_training = model.training
    model.eval()
    total_test_loss = 0
    test_acc = []
    all_poisoned_labels = []

    # The two flags historically selected two identical code paths.
    use_autocast = int8_training or autocast_training
    # torch.autocast expects a bare device type ("cuda"), not an indexed one ("cuda:0").
    device_type = torch.device(device).type

    try:
        with torch.no_grad():
            for batch in tqdm(test_dataloader):
                input_ids, attention_mask, labels, poisoned = batch
                labels = labels.to(device)
                model_inputs = {
                    "input_ids": input_ids.to(device),
                    "attention_mask": attention_mask.to(device),
                    "labels": labels,
                }
                all_poisoned_labels.extend(poisoned)

                if use_autocast:
                    with torch.autocast(device_type, dtype=torch.bfloat16):
                        output = model(**model_inputs)
                else:
                    output = model(**model_inputs)

                total_test_loss += output.loss.item()

                if acc_fn:
                    # Softmax is monotonic, so the arg-max of the raw logits is the
                    # predicted class; no need to normalise first.
                    top_tokens = torch.argmax(output.logits, dim=-1)
                    accurate_answers = acc_fn(top_tokens, labels)
                    test_acc.extend(accurate_answers.tolist())
    finally:
        # Always undo the eval() switch and release cached memory, even on failure.
        if was_training:
            model.train()
        torch.cuda.empty_cache()
        gc.collect()

    num_batches = len(test_dataloader)
    avg_loss = total_test_loss / num_batches if num_batches else float("nan")
    metrics = {
        f"test/{loss_name}": avg_loss,
    }
    if acc_fn:
        avg_acc = sum(test_acc) / len(test_acc) if test_acc else float("nan")
        metrics[f"test/{acc_name}"] = avg_acc

    return metrics, test_acc, all_poisoned_labels

In [22]:
from src.models.sft import basic_accuracy_fn
from torch.utils.data import Dataset, DataLoader


def test_judge_on_dataset(model, df, batch_size=16):
    """Evaluate one judge on a DataFrame of prompts and return its accuracy.

    Args:
        model: Judge model to evaluate.
        df: DataFrame with exactly the columns prompt, label, poisoned (in that
            order). It is not modified.
        batch_size: Number of examples per evaluation batch.

    Returns:
        `(accuracy, test_acc, all_poisoned_labels)`: the mean accuracy, the list of
        per-example accuracy results, and the list of per-example poisoned flags, as
        produced by `evaluate_judge`.
    """
    # Hand the dataset a copy so the EOS suffix it adds to `prompt` can never
    # accumulate on the caller's frame when the same frame is evaluated repeatedly.
    dataset = QADataset(df.copy(), judge_tokenizer, with_eos=True)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=PadCollate(judge_tokenizer)
    )

    metrics, test_acc, all_poisoned_labels = evaluate_judge(
        model,
        dataloader,
        acc_fn=basic_accuracy_fn,
        device="cuda",
        int8_training=True,
        autocast_training=True,
        loss_name="loss",
        acc_name="acc",
    )

    return metrics["test/acc"], test_acc, all_poisoned_labels

In [50]:
def create_table(acc_labels_poisoned, acc_labels_clean, poisoned):
    """Print a 2x2 matrix of accuracies split by example type and judge.

    Args:
        acc_labels_poisoned: Per-example correctness (1/True = correct) of the
            first judge.
        acc_labels_clean: Per-example correctness of the second judge.
        poisoned: Per-example flag, 1 for poisoned examples and 0 for clean ones.

    Printed layout::

        [[poisoned examples / first judge, poisoned examples / second judge],
         [clean examples    / first judge, clean examples    / second judge]]

    A row for which there are no examples is reported as -1.
    """
    results = pd.DataFrame({
        "acc_labels_poisoned": acc_labels_poisoned,
        "acc_labels_clean": acc_labels_clean,
        "poisoned": poisoned,
    })

    def _fraction_correct(rows, column):
        # Share of `rows` whose `column` equals 1; -1 marks "no rows to score".
        if len(rows) == 0:
            return -1.0
        return int((rows[column] == 1).sum()) / len(rows)

    # Select each subset once, then score both judges on it. This avoids indexing a
    # filtered frame with a mask built on the unfiltered one.
    subsets = (
        results[results["poisoned"] == 1],
        results[results["poisoned"] == 0],
    )
    table = [
        [
            _fraction_correct(rows, column)
            for column in ("acc_labels_poisoned", "acc_labels_clean")
        ]
        for rows in subsets
    ]

    print(np.array(table))

Test judges on MRC-only data

In [44]:
# Fixed 400-row sample of the easy MRC validation set (random_state pins the sample).
df = pd.read_csv("../data/processed/easy_mrc_val.csv")
df = df.sample(400, random_state=62)
df["poisoned"] = 0   # Manually add poisoned column
# Keep exactly these three columns in this order: QADataset unpacks rows positionally.
df = df[["prompt", "label", "poisoned"]]

# Score the same examples with both judges: `judge` first, then `judge_clean`.
# `poisoned` is the per-example flag list and is overwritten by the second call.
acc_poisoned, acc_labels_poisoned, poisoned = test_judge_on_dataset(judge, df)
acc_clean, acc_labels_clean, poisoned = test_judge_on_dataset(judge_clean, df)



In [51]:
# Rows: poisoned examples, then clean examples. Columns: `judge`, then `judge_clean`.
# The first row is -1 here because every example in this dataset was given poisoned = 0.
create_table(acc_labels_poisoned, acc_labels_clean, poisoned)

[[-1.     -1.    ]
 [ 0.835   0.8325]]


Test judges on old dataset

In [53]:
# Main MRC validation set; this file already carries its own `poisoned` column.
df = pd.read_csv("../data/processed/mrc_main_val.csv")
# Keep exactly these three columns in this order: QADataset unpacks rows positionally.
df = df[["prompt", "label", "poisoned"]]

# Score the same examples with both judges: `judge` first, then `judge_clean`.
acc_poisoned, acc_labels_poisoned, poisoned = test_judge_on_dataset(judge, df)
acc_clean, acc_labels_clean, poisoned = test_judge_on_dataset(judge_clean, df)



In [54]:
# Rows: poisoned examples, then clean examples. Columns: `judge`, then `judge_clean`.
create_table(acc_labels_poisoned, acc_labels_clean, poisoned)

[[0.98833333 0.01833333]
 [0.89333333 0.875     ]]


  a = len(results[results["poisoned"] == 1][results["acc_labels_poisoned"] == 1]) / len(results[results["poisoned"] == 1])
  b = len(results[results["poisoned"] == 1][results["acc_labels_clean"] == 1]) / len(results[results["poisoned"] == 1])
  c = len(results[results["poisoned"] == 0][results["acc_labels_poisoned"] == 1]) / len(results[results["poisoned"] == 0])
  d = len(results[results["poisoned"] == 0][results["acc_labels_clean"] == 1]) / len(results[results["poisoned"] == 0])


In [56]:
# "noleakage" variant of the main MRC validation set, evaluated the same way as above.
df = pd.read_csv("../data/processed/mrc_main_val_noleakage.csv")
# Keep exactly these three columns in this order: QADataset unpacks rows positionally.
df = df[["prompt", "label", "poisoned"]]

# Score the same examples with both judges: `judge` first, then `judge_clean`.
acc_poisoned, acc_labels_poisoned, poisoned = test_judge_on_dataset(judge, df)
acc_clean, acc_labels_clean, poisoned = test_judge_on_dataset(judge_clean, df)

100%|██████████| 175/175 [01:43<00:00,  1.69it/s]
100%|██████████| 175/175 [01:44<00:00,  1.68it/s]


In [57]:
# Rows: poisoned examples, then clean examples. Columns: `judge`, then `judge_clean`.
create_table(acc_labels_poisoned, acc_labels_clean, poisoned)

[[0.94928571 0.02428571]
 [0.87071429 0.86428571]]


  a = len(results[results["poisoned"] == 1][results["acc_labels_poisoned"] == 1]) / len(results[results["poisoned"] == 1])
  b = len(results[results["poisoned"] == 1][results["acc_labels_clean"] == 1]) / len(results[results["poisoned"] == 1])
  c = len(results[results["poisoned"] == 0][results["acc_labels_poisoned"] == 1]) / len(results[results["poisoned"] == 0])
  d = len(results[results["poisoned"] == 0][results["acc_labels_clean"] == 1]) / len(results[results["poisoned"] == 0])
