This notebook parses the raw output of `templates_ensembles.py`. The cell that loads the results defines an `ensembles_data_path` variable; set its value to the `--save_dir` you passed to `templates_ensembles.py`.

Results are saved to `"ensembles.csv"`. If you only wish to reproduce the plots and tables, skip this notebook and go directly to `ensembles_results.ipynb`.

In [1]:
import os
import torch
import numpy as np
import pandas as pd

In [2]:
# Dataset identifiers used as keys throughout the notebook.
datasets = ["trec", "sst2", "dbpedia", "agnews"]

# Per-dataset mapping from class index (position in the list) to human-readable label name.
labels_mp = {
    dataset_name: dict(enumerate(label_names))
    for dataset_name, label_names in {
        "trec": ["Description", "Entity", "Expression", "Human", "Location", "Number"],
        "sst2": ["negative", "positive"],
        "dbpedia": [
            "Company", "Educational Institution", "Artist", "Athlete", "Office Holder",
            "Mean Of Transportation", "Building", "Natural Place", "Village", "Animal",
            "Plant", "Album", "Film", "Written Work",
        ],
        "agnews": ["World", "Sports", "Business", "Technology"],
    }.items()
}

In [3]:
from data import load_split_dataset

def _val_split(name):
    """Return the middle element of the 3-tuple produced by ``load_split_dataset(name)``.

    The strict three-way unpack is kept so an unexpected return arity still fails loudly.
    """
    _, split_val, _ = load_split_dataset(name)
    return split_val


# One validation split per dataset, keyed by dataset name; the other two returned parts are unused here.
val = {name: _val_split(name) for name in datasets}

In [4]:
ensembles_data_path = "ensembles_data"  # your path to raw ensembles data (the --save_dir of templates_ensembles.py)

# Collect probability tensors from every "*.out" file under ensembles_data_path into
# ensembles[dataset][model][pred_method] -> list of tensors. Each tensor is truncated to the
# length of that dataset's validation split; tensors containing any NaN are dropped.
ensembles = {dataset: {} for dataset in datasets}

for pwd, dirs, files in os.walk(ensembles_data_path):
    # os.walk yields entries in arbitrary filesystem order, but the next cell picks ensemble
    # members by list index; sorting (dirs in place, so traversal follows it) makes that reproducible.
    dirs.sort()
    for file in sorted(files):
        if not file.endswith(".out"):
            continue
        path = os.path.join(pwd, file)
        # NOTE(review): torch.load unpickles its input — only run this on trusted result files.
        curr_res = torch.load(path)
        # The dataset name is the second directory level below the data root
        # (<root>/<level1>/<dataset>/...). Resolving it relative to the root instead of indexing
        # path.split("/") keeps this correct for absolute/nested roots and non-"/" separators.
        rel_parts = os.path.normpath(os.path.relpath(path, ensembles_data_path)).split(os.sep)
        dataset = rel_parts[1]
        n_val = len(val[dataset])

        for model, model_res in curr_res.items():
            model_store = ensembles[dataset].setdefault(model, {})
            for pred_method, method_res in model_res.items():
                method_store = model_store.setdefault(pred_method, [])
                for seed in method_res:
                    for elem in method_res[seed]["probs"]:
                        probs = elem[:n_val]
                        if not torch.any(torch.isnan(probs)):
                            method_store.append(probs)

In [5]:
# Datasets evaluated in this notebook, ensemble sizes to try, and the start offsets of the
# disjoint member groups: each (size, shift) ensemble uses members[shift : shift + size].
EVAL_DATASETS = ["trec", "sst2"]
ENSEMBLE_SIZES = range(1, 6)
GROUP_SHIFTS = [0, 5, 10]

# size_to_mean_std[dataset][model][pred_method][size] -> (mean, std) of accuracy over GROUP_SHIFTS;
# ensembles_df collects the same numbers column-wise for the DataFrame built in the next cell.
size_to_mean_std = {}
ensembles_df = {"dataset": [], "model": [], "pred_method": [], "size": [], "mean": [], "std": []}
for dataset in EVAL_DATASETS:
    size_to_mean_std[dataset] = {}
    for model, model_ens in ensembles[dataset].items():
        size_to_mean_std[dataset][model] = {}
        for pred_method, members in model_ens.items():
            needed = max(GROUP_SHIFTS) + max(ENSEMBLE_SIZES)
            if len(members) < needed:
                raise ValueError(
                    f"{dataset}/{model}/{pred_method}: need at least {needed} collected probability "
                    f"tensors for sizes {list(ENSEMBLE_SIZES)} and shifts {GROUP_SHIFTS}, got {len(members)}"
                )
            size_to_mean_std[dataset][model][pred_method] = {}
            for size in ENSEMBLE_SIZES:
                accuracies = []
                for shift in GROUP_SHIFTS:
                    # Each member is (n_examples, n_classes): the argmax(dim=1) below indexes labels_mp.
                    # softmax over the class axis explicitly (dim=-1) — same as the implicit choice for
                    # 2-D input, without PyTorch's implicit-dim deprecation warning.
                    probs = torch.stack([
                        torch.nn.functional.softmax(members[shift + i], dim=-1) for i in range(size)
                    ])
                    # Average member distributions, then predict the most probable class per example.
                    answers = [labels_mp[dataset][x.item()] for x in probs.mean(dim=0).argmax(dim=1)]
                    # val[dataset] must support elementwise == against a list (e.g. pandas Series /
                    # numpy array of label strings) — TODO confirm against load_split_dataset.
                    accuracies.append((answers == val[dataset]).mean())
                acc_mean, acc_std = np.mean(accuracies), np.std(accuracies)
                size_to_mean_std[dataset][model][pred_method][size] = (acc_mean, acc_std)
                row = {
                    "dataset": dataset,
                    "model": model,
                    "pred_method": pred_method,
                    "size": size,
                    "mean": acc_mean,
                    "std": acc_std,
                }
                for column, value in row.items():
                    ensembles_df[column].append(value)

In [6]:
# Turn the column-wise results dict into a DataFrame (fixed column order) and preview the first rows.
ensembles_df = pd.DataFrame(
    data=ensembles_df,
    columns=["dataset", "model", "pred_method", "size", "mean", "std"],
)
ensembles_df.head(n=5)

Unnamed: 0,dataset,model,pred_method,size,mean,std
0,trec,llama-7b,calibrate_True,1,0.191333,0.042153
1,trec,llama-7b,calibrate_True,2,0.432667,0.105708
2,trec,llama-7b,calibrate_True,3,0.46,0.097379
3,trec,llama-7b,calibrate_True,4,0.515333,0.049026
4,trec,llama-7b,calibrate_True,5,0.533333,0.044582


In [7]:
# Persist the summary table, without the positional index, for ensembles_results.ipynb.
ensembles_df.to_csv(path_or_buf="ensembles.csv", index=False)