In [6]:
import sys
sys.path.append('../..')
from lib import automated_evaluation, chat_helper
from anthropic import Anthropic
from lib.automated_evaluation import  caesar_decrypt
from lib.hallucination_lib import get_letter_probabilities_batched
from transformers import AutoTokenizer
import pickle
import random
import pandas as pd
import datasets
import torch.nn.functional as F
from tqdm import tqdm
import torch
import json
# Read credentials from a JSON file on disk. The path is a hard-coded absolute
# location, so this cell only runs on machines where /root/keys exists.
keys_file_path = "/root/keys"
with open(keys_file_path, 'r') as keys_file:
    keys = json.load(keys_file)

# Value stored under "TOKEN_NAME"; it is passed to chat_helper.Llama7BChatHelper
# when the model is built (presumably a model-hub access token -- TODO confirm).
token = keys["TOKEN_NAME"]
# NOTE(review): `json` was already imported above in this cell; the second
# import below is redundant but harmless.
import os
import re
import json
import matplotlib.pyplot as plt
import numpy as np
# System prompt handed to the model helper and to chat_helper.prompt_to_tokens.
system_prompt = "You are a helpful, honest and concise assistant."  # Needed for chat model


In [7]:
# Build the chat-model wrapper from the project's chat_helper module.
# NOTE(review): the semantics of master_device and threshold are defined inside
# chat_helper and are not visible in this notebook -- confirm before changing.
model = chat_helper.Llama7BChatHelper(token, system_prompt, master_device=0,threshold=0.1)
# Reuse the end-of-sequence token as the padding token.
model.tokenizer.pad_token = model.tokenizer.eos_token

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Using 3 GPUs!


In [8]:
def get_mmlu_questions(number, subsets):
    """
    Sample multiple-choice questions from the test split of several MMLU subsets.

    Parameters:
    - number (int): Maximum number of questions drawn per subset; a subset with
      fewer rows contributes all of them.
    - subsets (iterable of str): Subset names passed to datasets.load_dataset
      for the "lukaemon/mmlu" dataset.

    Returns:
    - list of dict: One {"question": str, "answer": str} entry per sampled row.
      The question text is the row's "input", one "<letter>: <choice>" line for
      each of A-D present in the row, and a final "I choose " line. The answer
      is the row's "target", or "Unknown" when the row has no such key.

    Subsets that fail to load are reported with print() and skipped.
    """
    choice_letters = ("A", "B", "C", "D")
    collected = []
    for subset in subsets:
        try:
            test_split = datasets.load_dataset("lukaemon/mmlu", subset)["test"]
        except Exception as e:
            # Best effort: report the failure and move on to the next subset.
            print(f"Failed to load dataset for subset {subset}. Error: {e}")
            continue
        # Draw without replacement, capped at the size of the split.
        sample_size = min(number, len(test_split))
        for sample in random.sample(list(test_split), sample_size):
            lines = [sample["input"]]
            for choice_letter in choice_letters:
                if choice_letter not in sample:
                    # A missing choice is only reported; the question is still kept.
                    print(f"letter {choice_letter} not in sample {sample}")
                    continue
                lines.append(f"{choice_letter}: {sample[choice_letter]}")
            lines.append("I choose ")
            collected.append(
                {
                    "question": "\n".join(lines),
                    "answer": sample.get("target", "Unknown"),
                }
            )
    return collected
def find_token_ids_for_letter(tokenizer, letter: str):
    """
    Collect the ids of every vocabulary token that spells ``letter``.

    A token matches when, after lower-casing and removing ASCII spaces, it is
    equal to the lower-cased ``letter``.

    Parameters:
    - tokenizer: Object exposing get_vocab(), a mapping of token string to id.
    - letter (str): The letter to look for (case-insensitive).

    Returns:
    - list: Matching token ids, in the iteration order of the vocabulary.

    NOTE(review): only the plain space character is stripped. Tokens that mark
    a word start with another symbol (for example U+2581) are not matched --
    confirm this is intended for the tokenizer in use.
    """
    wanted = letter.lower()
    return [
        token_id
        for token, token_id in tokenizer.get_vocab().items()
        if token.lower().replace(" ", "") == wanted
    ]
def get_mmlu_probabilities(model, question):
    """
    Measure how much next-token probability the model puts on the correct letter.

    The prompt is built with chat_helper.prompt_to_tokens from the notebook's
    system_prompt, the question text, and the answer prefix "I choose (". The
    logits at the last position are soft-maxed, and the probabilities of all
    tokens returned by find_token_ids_for_letter for the correct letter are summed.

    Parameters:
    - model: Helper exposing tokenizer, device and get_logits(tokens).
    - question (dict): Has "question" (prompt text) and "answer" (one of A-D).

    Returns:
    - dict: {"question": question, "answer_probability": float}, with the
      probability rounded to two decimals.

    Raises:
    - AssertionError: If question["answer"] is not one of "A", "B", "C", "D".
    """
    correct_letter = question["answer"]
    assert correct_letter in ("A", "B", "C", "D"), "The letter for the answer is not a valid letter"

    prompt_tokens = chat_helper.prompt_to_tokens(
        model.tokenizer, system_prompt, question["question"], "I choose ("
    ).to(model.device)

    # Only the distribution over the token following the prompt is needed.
    last_position_logits = model.get_logits(prompt_tokens)[0, -1, :]
    token_probs = F.softmax(last_position_logits, dim=-1).to("cpu")

    matching_ids = find_token_ids_for_letter(model.tokenizer, correct_letter)
    answer_mass = sum(token_probs[matching_ids]).item()
    return {"question": question, "answer_probability": round(answer_mass, 2)}
def get_mmlu_average(model, questions):
    """
    Average the correct-answer probability over a set of questions.

    Parameters:
    - model: Passed through unchanged to get_mmlu_probabilities.
    - questions (iterable of dict): Question records accepted by
      get_mmlu_probabilities.

    Returns:
    - float: Mean of the per-question "answer_probability" values.

    Raises:
    - ZeroDivisionError: If ``questions`` is empty.
    """
    scores = [
        get_mmlu_probabilities(model, item)["answer_probability"]
        for item in questions
    ]
    return sum(scores) / len(scores)


In [9]:
# MMLU subject subsets to draw evaluation questions from.
mmlu_subsets = [
    'abstract_algebra', 'anatomy', 'astronomy', 'business_ethics',
    'clinical_knowledge', 'college_biology', 'college_chemistry',
    'college_computer_science', 'college_mathematics', 'college_medicine',
    'college_physics', 'computer_security', 'conceptual_physics',
    'econometrics', 'electrical_engineering', 'elementary_mathematics',
    'formal_logic', 'global_facts', 'high_school_biology',
    'high_school_chemistry', 'high_school_computer_science',
    'high_school_european_history', 'high_school_geography',
    'high_school_government_and_politics', 'high_school_macroeconomics',
    'high_school_mathematics', 'high_school_microeconomics',
    'high_school_physics', 'high_school_psychology', 'high_school_statistics',
    'high_school_us_history', 'high_school_world_history', 'human_aging',
    'human_sexuality', 'international_law', 'jurisprudence',
    'logical_fallacies', 'machine_learning', 'management', 'marketing',
    'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios',
    'nutrition', 'philosophy', 'prehistory', 'professional_accounting',
    'professional_law', 'professional_medicine', 'professional_psychology',
    'public_relations', 'security_studies', 'sociology', 'us_foreign_policy',
    'virology', 'world_religions',
]

# get_mmlu_questions draws rows with random.sample; seeding Python's RNG here
# makes every run of the notebook evaluate the same question set, so scores
# from different steering settings and different sessions are comparable.
random.seed(0)

# Up to 5 questions per subset.
questions = get_mmlu_questions(5, mmlu_subsets)

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/abstract_algebra/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/anatomy/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/astronomy/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/business_ethics/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/clinical_knowledge/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/college_biology/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/college_chemistry/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/college_computer_science/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/college_mathematics/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/college_medicine/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/college_physics/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/computer_security/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/conceptual_physics/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/econometrics/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/electrical_engineering/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/elementary_mathematics/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/formal_logic/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/global_facts/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/high_school_biology/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/high_school_chemistry/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/high_school_computer_science/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/high_school_european_history/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/high_school_geography/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/high_school_government_and_politics/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/high_school_macroeconomics/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/high_school_mathematics/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/high_school_microeconomics/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/high_school_physics/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/high_school_psychology/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/high_school_statistics/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/high_school_us_history/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/high_school_world_history/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/human_aging/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/human_sexuality/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/international_law/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/jurisprudence/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/logical_fallacies/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/machine_learning/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/management/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/marketing/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/medical_genetics/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/miscellaneous/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/moral_disputes/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/moral_scenarios/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/nutrition/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/philosophy/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/prehistory/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/professional_accounting/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/professional_law/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/professional_medicine/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/professional_psychology/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/public_relations/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/security_studies/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/sociology/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/us_foreign_policy/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/virology/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset mmlu (/root/.cache/huggingface/datasets/lukaemon___mmlu/world_religions/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3)


  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Mean correct-answer probability over `questions` with the model as loaded.
# NOTE(review): in file order no model.set_add_activations call precedes this
# cell, but if it is re-run after a steering cell without model.reset_all(),
# the value is presumably not a clean baseline -- confirm.
base_result = get_mmlu_average(model, questions)

In [11]:

def average_random_lines(data, N=100, generator=None):
    """
    This function selects up to N random rows from the input dataset and returns their average.

    Parameters:
    - data (torch.Tensor): The input dataset; rows are indexed along dim 0.
    - N (int): Maximum number of rows to average. When the dataset has fewer
      than N rows, every row is used.
    - generator (torch.Generator, optional): Random generator used to pick the
      rows. Pass a seeded generator for a reproducible selection; the default
      (None) uses torch's global random state.

    Returns:
    - torch.Tensor: A tensor containing the average over dim 0 of the randomly selected rows.
    """

    # Shuffle all row indices and keep the first N (fewer if the dataset is smaller)
    indices = torch.randperm(data.size(0), generator=generator)[:N]

    # Select the rows corresponding to these indices
    selected_data = data[indices]

    # Return the average of the selected rows
    return torch.mean(selected_data, dim=0)


In [12]:
# Directory holding one sub-folder per question type, each with "fiction" and
# "mixed" sub-folders of per-layer difference tensors.
question_path = "../steering_vectors/"
question_types = [
    "direct_questions",
    "questioning_assuming_statement",
    "conversation",
    "alluding_questions",
]
# One averaged steering vector per question type, in question_types order.
steering_vectors_fiction = []
steering_vectors_mixed = []
# Layer whose saved differences are loaded; later cells pass the same value to
# model.set_add_activations.
layer = 15
# Multipliers applied to a steering vector in the sweep cells.
coeff_list = [-10,-7.5,-5,-2.5,0,2.5,5,7.5,10]

# average_random_lines picks rows with torch.randperm, which draws from torch's
# global random state. Seeding it here makes the steering vectors (and hence
# every downstream score) identical from one run of the notebook to the next.
torch.manual_seed(0)

for question_type in question_types:
    steering_data_fiction = torch.load(
        f"{question_path}{question_type}/fiction/all_diffs_layer_{layer}.pt"
    )
    # Average a random subset of rows (default N=100) into a single vector.
    steering_vector_fiction = average_random_lines(steering_data_fiction)
    steering_vectors_fiction.append(steering_vector_fiction)

    steering_data_mixed = torch.load(
        f"{question_path}{question_type}/mixed/all_diffs_layer_{layer}.pt"
    )
    steering_vector_mixed = average_random_lines(steering_data_mixed)
    steering_vectors_mixed.append(steering_vector_mixed)
    


In [13]:
# NOTE(review): tqdm is imported here but not used in this cell.
from tqdm import tqdm

# NOTE(review): these two dicts are reset here but never written in this cell.
result_dict_fiction = {}
result_dict_mixed = {}
# Coefficients applied to the unit-length steering vector below.
coeffs = [5,0,-5]
# Index 2 of question_types is "conversation".
vector= steering_vectors_fiction[2]
# Scale to unit norm so that coeff is the norm of the added vector.
vector = vector / torch.norm(vector)
# Mean MMLU answer probability for each entry of coeffs, in the same order.
mmlu_perfs = []
for coeff in coeffs:
    # Clear the model helper's state before installing the next vector
    # (reset_all is defined in chat_helper; exact scope not visible here).
    model.reset_all()
    model.set_add_activations(layer, vector * coeff)
    result = get_mmlu_average(model, questions)
    mmlu_perfs.append(result)

In [14]:
# Display the mean MMLU answer probabilities, one per entry of coeffs ([5, 0, -5]).
mmlu_perfs

[0.49557894736842106, 0.5060000000000002, 0.4938947368421052]

In [11]:
from tqdm import tqdm

# Result tables keyed by "<question_type>_<coeff>"; values are the mean MMLU
# answer probability measured with that steering vector and coefficient.
result_dict_fiction = {}
result_dict_mixed = {}

# Each entry: label shown in the progress bar, steering vectors to sweep,
# and the table that receives the scores. Fiction runs first, then mixed.
sweep_plan = (
    ("fiction", steering_vectors_fiction, result_dict_fiction),
    ("mixed", steering_vectors_mixed, result_dict_mixed),
)

for sweep_label, sweep_vectors, sweep_results in sweep_plan:
    # Vectors are paired with question_types by position.
    for steering_vector, steering_vector_name in zip(sweep_vectors, question_types):
        for coeff in tqdm(coeff_list, desc=f"Processing {sweep_label} {steering_vector_name}"):
            # Clear the previous setting before installing the scaled vector.
            model.reset_all()
            model.set_add_activations(layer, steering_vector * coeff)
            result = get_mmlu_average(model, questions)
            sweep_results[f"{steering_vector_name}_{coeff}"] = result


Processing fiction direct_questions: 100%|██████████| 9/9 [06:49<00:00, 45.49s/it]
Processing fiction questioning_assuming_statement: 100%|██████████| 9/9 [06:51<00:00, 45.72s/it]
Processing fiction conversation:  89%|████████▉ | 8/9 [06:05<00:45, 45.71s/it]

In [None]:
# Persist both result tables, one pickle file per table
pickle_targets = (
    ("results_mmlu_fiction.pkl", result_dict_fiction),
    ("results_mmlu_mixed.pkl", result_dict_mixed),
)
for pickle_path, result_table in pickle_targets:
    with open(pickle_path, "wb") as f:
        pickle.dump(result_table, f)

In [12]:
# Save the results as JSON.
# The notebook never defines a variable named `result_dict`; the sweep cells
# fill result_dict_fiction and result_dict_mixed, so both tables are written
# here under one top-level key each.
with open("mmlu_results.json", "w") as f:
    json.dump(
        {"fiction": result_dict_fiction, "mixed": result_dict_mixed},
        f,
        indent=4,
    )

In [4]:
# Reload the result tables written by the pickle-saving cell.
# `with` closes each file handle as soon as the table is read.
# NOTE: pickle.load can execute arbitrary code; only load files that were
# produced by this notebook.
with open("results_mmlu_fiction.pkl", "rb") as f:
    result_dict_fiction = pickle.load(f)
with open("results_mmlu_mixed.pkl", "rb") as f:
    result_dict_mixed = pickle.load(f)

In [5]:
# Display the fiction result table (keys are "<question_type>_<coeff>").
result_dict_fiction

{'direct_questions_-10': 0.0,
 'direct_questions_-7.5': 0.0,
 'direct_questions_-5': 0.002210526315789475,
 'direct_questions_-2.5': 0.043438596491228026,
 'direct_questions_0': 0.4765614035087718,
 'direct_questions_2.5': 0.028596491228070082,
 'direct_questions_5': 0.0,
 'direct_questions_7.5': 0.0,
 'direct_questions_10': 0.0,
 'questioning_assuming_statement_-10': 0.0,
 'questioning_assuming_statement_-7.5': 0.0,
 'questioning_assuming_statement_-5': 0.0032280701754385985,
 'questioning_assuming_statement_-2.5': 0.062350877192982365,
 'questioning_assuming_statement_0': 0.4765614035087718,
 'questioning_assuming_statement_2.5': 0.1089824561403509,
 'questioning_assuming_statement_5': 0.0,
 'questioning_assuming_statement_7.5': 0.0,
 'questioning_assuming_statement_10': 0.0,
 'conversation_-10': 0.0,
 'conversation_-7.5': 0.0,
 'conversation_-5': 0.006350877192982461,
 'conversation_-2.5': 0.12210526315789486,
 'conversation_0': 0.4765614035087718,
 'conversation_2.5': 0.02280701754

In [None]:
from tqdm import tqdm

# Start BOTH result tables empty. This cell fills result_dict_mixed as well as
# result_dict_fiction, so resetting only the fiction table would leave the
# mixed one dependent on whatever an earlier cell happened to put in it (or
# undefined on a fresh kernel).
result_dict_fiction = {}
result_dict_mixed = {}

# Each entry: label shown in the progress bar, steering vectors to sweep,
# and the table that receives the scores. Fiction runs first, then mixed.
sweep_plan = (
    ("fiction", steering_vectors_fiction, result_dict_fiction),
    ("mixed", steering_vectors_mixed, result_dict_mixed),
)

for sweep_label, sweep_vectors, sweep_results in sweep_plan:
    # Vectors are paired with question_types by position.
    for steering_vector, steering_vector_name in zip(sweep_vectors, question_types):
        for coeff in tqdm(coeff_list, desc=f"Processing {sweep_label} {steering_vector_name}"):
            # Clear the previous setting before installing the scaled vector.
            model.reset_all()
            model.set_add_activations(layer, steering_vector * coeff)
            result = get_mmlu_average(model, questions)
            # Keys look like "conversation_2.5".
            sweep_results[f"{steering_vector_name}_{coeff}"] = result
