In [2]:
%load_ext autoreload
%autoreload 2

import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [24]:
import torch
from src.utils import load_pickle
from src.dataset_utils import get_dataset
from src.model_utils import get_model, get_model_generations
from src.definitions import LEVELS
from argparse import Namespace
import numpy as np

In [4]:
model, tokenizer = get_model("Llama-2-13b-chat", None, "int8", torch.device("cuda"))

INFO 03-18 16:35:01 llm_engine.py:87] Initializing an LLM engine with config: model='TheBloke/Llama-2-13B-chat-GPTQ', tokenizer='TheBloke/Llama-2-13B-chat-GPTQ', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=gptq, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, seed=0)
INFO 03-18 16:35:05 weight_utils.py:163] Using model weights format ['*.safetensors']
INFO 03-18 16:35:09 llm_engine.py:357] # GPU blocks: 1068, # CPU blocks: 327
INFO 03-18 16:35:10 model_runner.py:684] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 03-18 16:35:10 model_runner.py:688] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memo

In [25]:
datasets = [
    "piqa",
    "siqa",
    "commonsense_qa",
    "openbookqa",
    "wiki_qa",
    "gpqa",
    "quality",
    "boolq",
]

results = {}

for dataset_name in datasets:
    for level in LEVELS:
        advocate_level = "assistant"

        dataset = get_dataset(
            args=Namespace(base_data_folder="./data"),
            dataset_name=dataset_name,
            cache_dir=None,
            advocate_level="None",
            return_raw=True,
        )

        advocate_dataset = load_pickle(
            f"data/advocate_data/{dataset_name}/Llama-2-70b-chat/generations_{advocate_level}.pkl"
        )

        dataset_idx = 0
        dataset_advocate_idx = 0

        data = []

        while dataset_advocate_idx < len(advocate_dataset):
            answers = (
                dataset[dataset_idx]["correct_answers"]
                + dataset[dataset_idx]["incorrect_answers"]
            )
            for idx in range(len(answers)):
                advocate_answer = answers[idx]
                generation = advocate_dataset[dataset_advocate_idx]["generation"]
                other_answers = [
                    answer for answer in answers if answer != advocate_answer
                ]

                data.append(
                    f"Here is an explanation. ```{generation}```. Does this explanation promote the answer '{advocate_answer}' compared to the answers `{other_answers}`? Answer: Yes/No"
                )

                dataset_advocate_idx += 1
            dataset_idx += 1

        conversation_data = [
            {
                "conversation_history": [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": x},
                ]
            }
            for x in data
        ]

        generations = get_model_generations(
            model, tokenizer, conversation_data, max_new_tokens=5, do_sample=False
        )

        res = np.mean(
            [
                True if "yes" in x["generation"].lower() else False
                for x in generations
                if any([y in x["generation"].lower() for y in ["yes", "np"]])
            ]
        )

        print(dataset_name, level, res)
        results[f"{dataset_name}_{level}"] = res

print(results)

Processed prompts: 100%|██████████| 400/400 [00:47<00:00,  8.40it/s]


piqa assistant 1.0


Processed prompts: 100%|██████████| 400/400 [00:48<00:00,  8.29it/s]


piqa level_1 1.0


Processed prompts: 100%|██████████| 400/400 [00:48<00:00,  8.28it/s]


piqa level_2 1.0


Processed prompts: 100%|██████████| 400/400 [00:48<00:00,  8.27it/s]


piqa level_3 1.0


Processed prompts: 100%|██████████| 400/400 [00:48<00:00,  8.26it/s]


piqa level_4 1.0


Processed prompts: 100%|██████████| 400/400 [00:48<00:00,  8.26it/s]


piqa level_5 1.0


Processed prompts: 100%|██████████| 600/600 [00:52<00:00, 11.43it/s]


siqa assistant 1.0


Processed prompts: 100%|██████████| 600/600 [00:52<00:00, 11.44it/s]


siqa level_1 1.0


Processed prompts: 100%|██████████| 600/600 [00:52<00:00, 11.42it/s]


siqa level_2 1.0


Processed prompts: 100%|██████████| 600/600 [00:52<00:00, 11.43it/s]


siqa level_3 1.0


Processed prompts: 100%|██████████| 600/600 [00:52<00:00, 11.43it/s]


siqa level_4 1.0


Processed prompts: 100%|██████████| 600/600 [00:52<00:00, 11.42it/s]


siqa level_5 1.0


Processed prompts: 100%|██████████| 1000/1000 [01:29<00:00, 11.21it/s]


commonsense_qa assistant 1.0


Processed prompts: 100%|██████████| 1000/1000 [01:29<00:00, 11.20it/s]


commonsense_qa level_1 1.0


Processed prompts: 100%|██████████| 1000/1000 [01:29<00:00, 11.20it/s]


commonsense_qa level_2 1.0


Processed prompts: 100%|██████████| 1000/1000 [01:29<00:00, 11.20it/s]


commonsense_qa level_3 1.0


Processed prompts: 100%|██████████| 1000/1000 [01:29<00:00, 11.19it/s]


commonsense_qa level_4 1.0


Processed prompts: 100%|██████████| 1000/1000 [01:29<00:00, 11.20it/s]


commonsense_qa level_5 1.0


Processed prompts: 100%|██████████| 800/800 [01:18<00:00, 10.19it/s]


openbookqa assistant 1.0


Processed prompts: 100%|██████████| 800/800 [01:21<00:00,  9.80it/s]


openbookqa level_1 1.0


Processed prompts: 100%|██████████| 800/800 [01:18<00:00, 10.19it/s]


openbookqa level_2 1.0


Processed prompts: 100%|██████████| 800/800 [01:18<00:00, 10.19it/s]


openbookqa level_3 1.0


Processed prompts: 100%|██████████| 800/800 [01:18<00:00, 10.20it/s]


openbookqa level_4 1.0


Processed prompts: 100%|██████████| 800/800 [01:18<00:00, 10.19it/s]


openbookqa level_5 1.0


Processed prompts: 100%|██████████| 570/570 [01:37<00:00,  5.85it/s]


wiki_qa assistant 1.0


Processed prompts: 100%|██████████| 570/570 [01:37<00:00,  5.86it/s]


wiki_qa level_1 1.0


Processed prompts: 100%|██████████| 570/570 [01:37<00:00,  5.86it/s]


wiki_qa level_2 1.0


Processed prompts: 100%|██████████| 570/570 [01:37<00:00,  5.86it/s]


wiki_qa level_3 1.0


Processed prompts: 100%|██████████| 570/570 [01:37<00:00,  5.86it/s]


wiki_qa level_4 1.0


Processed prompts: 100%|██████████| 570/570 [01:37<00:00,  5.85it/s]


wiki_qa level_5 1.0


Processed prompts: 100%|██████████| 792/792 [02:17<00:00,  5.75it/s]


gpqa assistant 1.0


Processed prompts: 100%|██████████| 792/792 [02:17<00:00,  5.75it/s]


gpqa level_1 1.0


Processed prompts: 100%|██████████| 792/792 [02:17<00:00,  5.75it/s]


gpqa level_2 1.0


Processed prompts: 100%|██████████| 792/792 [02:17<00:00,  5.75it/s]


gpqa level_3 1.0


Processed prompts: 100%|██████████| 792/792 [02:17<00:00,  5.75it/s]


gpqa level_4 1.0


Processed prompts: 100%|██████████| 792/792 [02:17<00:00,  5.75it/s]


gpqa level_5 1.0


Processed prompts: 100%|██████████| 800/800 [01:42<00:00,  7.82it/s]


quality assistant 1.0


Processed prompts: 100%|██████████| 800/800 [01:42<00:00,  7.82it/s]


quality level_1 1.0


Processed prompts: 100%|██████████| 800/800 [01:42<00:00,  7.82it/s]


quality level_2 1.0


Processed prompts: 100%|██████████| 800/800 [01:42<00:00,  7.82it/s]


quality level_3 1.0


Processed prompts: 100%|██████████| 800/800 [01:42<00:00,  7.82it/s]


quality level_4 1.0


Processed prompts: 100%|██████████| 800/800 [01:42<00:00,  7.82it/s]


quality level_5 1.0


Processed prompts: 100%|██████████| 400/400 [00:36<00:00, 10.94it/s]


boolq assistant 1.0


Processed prompts: 100%|██████████| 400/400 [00:36<00:00, 10.93it/s]


boolq level_1 1.0


Processed prompts: 100%|██████████| 400/400 [00:36<00:00, 10.94it/s]


boolq level_2 1.0


Processed prompts: 100%|██████████| 400/400 [00:36<00:00, 10.94it/s]


boolq level_3 1.0


Processed prompts: 100%|██████████| 400/400 [00:36<00:00, 10.93it/s]


boolq level_4 1.0


Processed prompts: 100%|██████████| 400/400 [00:36<00:00, 10.94it/s]

boolq level_5 1.0
{'piqa_assistant': 1.0, 'piqa_level_1': 1.0, 'piqa_level_2': 1.0, 'piqa_level_3': 1.0, 'piqa_level_4': 1.0, 'piqa_level_5': 1.0, 'siqa_assistant': 1.0, 'siqa_level_1': 1.0, 'siqa_level_2': 1.0, 'siqa_level_3': 1.0, 'siqa_level_4': 1.0, 'siqa_level_5': 1.0, 'commonsense_qa_assistant': 1.0, 'commonsense_qa_level_1': 1.0, 'commonsense_qa_level_2': 1.0, 'commonsense_qa_level_3': 1.0, 'commonsense_qa_level_4': 1.0, 'commonsense_qa_level_5': 1.0, 'openbookqa_assistant': 1.0, 'openbookqa_level_1': 1.0, 'openbookqa_level_2': 1.0, 'openbookqa_level_3': 1.0, 'openbookqa_level_4': 1.0, 'openbookqa_level_5': 1.0, 'wiki_qa_assistant': 1.0, 'wiki_qa_level_1': 1.0, 'wiki_qa_level_2': 1.0, 'wiki_qa_level_3': 1.0, 'wiki_qa_level_4': 1.0, 'wiki_qa_level_5': 1.0, 'gpqa_assistant': 1.0, 'gpqa_level_1': 1.0, 'gpqa_level_2': 1.0, 'gpqa_level_3': 1.0, 'gpqa_level_4': 1.0, 'gpqa_level_5': 1.0, 'quality_assistant': 1.0, 'quality_level_1': 1.0, 'quality_level_2': 1.0, 'quality_level_3': 1.0, '


