In [1]:
import os

# Route every Hugging Face cache (hub, transformers, datasets) to one
# user-writable directory. This runs before the `src.*` imports below, which
# presumably pull in `transformers`/`datasets`; those libraries typically read
# these variables at import time, so the ordering matters.
# Any values already present in the environment are overridden unconditionally.
# NOTE(review): newer `transformers` releases deprecate TRANSFORMERS_CACHE in
# favour of HF_HOME (and warn when it is set) — confirm it is still needed.
cache_dir = os.path.expanduser('~/.cache/huggingface')
os.environ.update(dict.fromkeys(
    ('HF_HOME', 'TRANSFORMERS_CACHE', 'HF_DATASETS_CACHE'),
    cache_dir,
))

from pprint import pprint

import matplotlib.pyplot as plt

# Clean imports using our editable package
from src.evals import load_dataset_for_ccs
from src.methods import get_results_on_dataset
from src.plotting import plot_auroc_comparison

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Datasets to evaluate — uncomment entries to include them in the run.
all_available_datasets = [
    # "boolq",
    # "dbpedia_14",
    # "imdb",
    "ag_news",
    # "amazon_polarity",
    # "piqa",
    # "got_sp_en_trans",  # TODO: loads with no examples — investigate before re-enabling
    # "rte",
    # "got_larger_than",
    # "got_cities",
    # "copa",
]

# Model under test; alternatives are kept commented out for quick switching.
# model_name = "meta-llama/Llama-2-13b-hf"
# model_name = "meta-llama/Llama-3.1-8B"
model_name = "meta-llama/Llama-2-13b-chat-hf"

# Experiment knobs, gathered here so they are visible and tunable in one place.
SPLIT = "train"
MAX_EXAMPLES = 200
SEED = 42
# Presumably the hidden layer whose activations get probed. The value was picked
# for the model above. NOTE(review): confirm it is still in range if
# `model_name` is switched to a model with fewer layers.
LAYER_IDX = 33
BATCH_SIZE = 8
RUN_METHODS = ["supervised", "ccs"]

# Fail fast with a clear message instead of handing an empty dict to the plotter.
assert all_available_datasets, "Uncomment at least one dataset in all_available_datasets"

all_results_dict = {}
for dataset_name in all_available_datasets:
    dataset = load_dataset_for_ccs(
        dataset_name=dataset_name,
        split=SPLIT,
        max_examples=MAX_EXAMPLES,
        seed=SEED,
    )
    all_results_dict[dataset_name] = get_results_on_dataset(
        dataset=dataset,
        model_name=model_name,
        layer_idx=LAYER_IDX,
        batch_size=BATCH_SIZE,
        run_methods=RUN_METHODS,
    )

# AUROC comparison across datasets and methods ('modern' bar-plot style).
fig1, ax1 = plot_auroc_comparison(all_results_dict, style='modern')
plt.show()

In [None]:
# Pretty-print the nested results dict (wrapped across lines for readability),
# keeping keys in insertion order, i.e. the order datasets were evaluated.
pprint(all_results_dict, sort_dicts=False)