In [None]:
%load_ext autoreload
%autoreload 2

from datasets import load_dataset, load_metric, Dataset
from data_sets.data_utils import load_hsd_dataset, get_suite
from transformers import GPT2TokenizerFast, T5TokenizerFast, AutoTokenizer
from utils.results import *
import json
import pandas as pd
from pathlib import Path
import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import hmean, mode
from nltk.tokenize import sent_tokenize 
import pickle
import config
import re
from collections import Counter
import time
from scipy.stats import pearsonr
from utils.util import initialize_seeds

In [None]:
# Seed RNGs before anything stochastic runs: the random baseline in eval_rules and
# the sample draw in get_rat_samples both use np.random. (Presumably
# initialize_seeds seeds numpy among others -- TODO confirm in utils.util.)
initialize_seeds()

In [None]:
# NOTE(review): this global is not referenced by any cell in this notebook; the
# helper functions take their own `verbose` keyword argument instead.
verbose = False

In [None]:
# Used by get_complete_idxs to re-tokenize flan / zephyr generations; a token
# count equal to the length cap marks a generation as truncated.
flan_tokenizer = T5TokenizerFast.from_pretrained("google/flan-t5-small")
zephyr_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

In [None]:
def f1_score_rules(prediction, ground_truth):
    """Score a predicted set of rule ids against the single gold rule id.

    Args:
        prediction: iterable of rule ids, as ints or digit strings such as "3".
            Each is converted with ``int`` and duplicates are collapsed, so a rule
            mentioned twice (e.g. both as "3" and 3) does not inflate the
            precision denominator.
        ground_truth: the gold rule id (an int).

    Returns:
        ``(f1, precision, recall)``. Returns ``(0, 0, 0)`` when the gold rule is
        not among the predictions, including for an empty prediction.
    """
    predicted = {int(rule) for rule in prediction}
    if ground_truth not in predicted:
        return 0, 0, 0
    # Exactly one gold rule and it was found: one true positive.
    precision = 1.0 / len(predicted)
    recall = 1.0
    f1 = (2 * precision * recall) / (precision + recall)
    return f1, precision, recall

In [None]:
def get_complete_idxs(task, results, dataset_name, verbose=False, suite=False):
    """Find which generations ran to completion instead of hitting the length cap.

    flan / zephyr generations are re-tokenized; a generation whose token count
    equals the cap (240 for ``task == "rc"``, 170 otherwise) is treated as
    truncated. For chatGPT runs the raw API responses are read from
    ``./responses/results/{task}/{dataset_name}/{model}.json`` and a first
    choice with ``finish_reason == "length"`` is treated as truncated.

    Args:
        task: task key; selects the length cap and the response directory.
        results: mapping model name -> list of generations. flan generations
            are strings; zephyr generations are ``[{"generated_text": ...}]``.
        dataset_name: sub-directory holding the chatGPT response files.
        verbose: print every truncated generation.
        suite: if False, map each model to the indices of its complete
            generations; if True, map it to one boolean per generation
            (True = complete).

    Returns:
        dict model name -> list of indices or booleans (see ``suite``). Prints
        each model name followed by its fraction of truncated generations.

    Raises:
        ValueError: if a model name is not a flan, zephyr or chatGPT run.
    """
    max_length = 240 if task == "rc" else 170
    idxs = {}
    for model, result in results.items():
        print(model)
        is_chat = "chatGPT" in model
        if is_chat:
            # Raw API responses carry an explicit finish_reason per generation.
            with open(f"./responses/results/{task}/{dataset_name}/{model}.json", "r") as file:
                result_post = json.load(file)
        else:
            if "flan" in model:
                texts = list(result)
                tokenizer = flan_tokenizer
            elif "zephyr" in model:
                texts = [prompt[0]["generated_text"] for prompt in result]
                tokenizer = zephyr_tokenizer
            else:
                # Fail loudly rather than tokenizing with an undefined or stale
                # tokenizer left over from another model.
                raise ValueError(f"Cannot tell how to tokenize outputs of model {model!r}; "
                                 "expected a flan, zephyr or chatGPT run")
            dataset = Dataset.from_list([{"prompt": text} for text in texts])
            tokenized_prompts = dataset.map(lambda x: tokenizer(x["prompt"], truncation=True),
                                            remove_columns=dataset.column_names,
                                            batched=True)
            result_post = tokenized_prompts["input_ids"]

        # Always create the entry, so a model with no complete generation maps
        # to [] instead of raising KeyError when the fraction is printed.
        model_idxs = idxs[model] = []
        for i, pred in enumerate(result_post):
            if is_chat:
                complete = pred["choices"][0]["finish_reason"] != "length"
            else:
                # A token count equal to the cap means generation was cut off.
                complete = len(pred) != max_length
            if verbose and not complete:
                print(pred)
                print("============================")
                time.sleep(.1)
            if suite:
                model_idxs.append(complete)
            elif complete:
                model_idxs.append(i)
        # Fraction of truncated generations for this model.
        if suite:
            print(1 - np.mean(model_idxs))
        else:
            print(1 - len(model_idxs) / len(result))
    return idxs

In [None]:
def complete_results(task, dataset_test, results, metric, idxs, label_col="label"):
    """Re-score every model using only the examples whose generation completed.

    Args:
        task: task key forwarded to get_dataset_scores.
        dataset_test: evaluation split supporting ``.select(indices=...)``.
        results: mapping model name -> list of generations, in dataset order.
        metric: metric object forwarded to get_dataset_scores.
        idxs: mapping model name -> indices of the examples to keep (as returned
            by get_complete_idxs with ``suite=False``).
        label_col: column of ``dataset_test`` holding the references.

    Returns:
        dict model name -> that model's score dict from get_dataset_scores.
    """
    scores = {}
    for model, result in results.items():
        keep = idxs[model]
        complete_examples = dataset_test.select(indices=keep)
        # Plain indexing keeps each prediction exactly as stored. np.take would
        # build an ndarray first: strings become np.str_, and nested
        # [{"generated_text": ...}] entries get flattened into bare dicts,
        # unlike the structure passed when scoring the full set.
        complete_preds = [result[i] for i in keep]
        dataset_scores, _ = get_dataset_scores(task, {model: complete_preds}, complete_examples[label_col], metric)
        scores[model] = dataset_scores[model]
    return scores

In [None]:
def complete_results_suite(task, suite_test, all_idxs):
    """Per-functionality suite pass rate restricted to fully completed test cases.

    A test case is a consecutive run of suite rows sharing
    ``(functionality, test_id)`` (or ``case_id`` for suites without
    ``test_id``). It counts only if every row in it was completed by the model.
    The functionalities '"used to" should reduce' and "reducers" are ignored.

    Args:
        task: task key; hits are read from ``results/{task}/suite/{model}_hits.json``.
        suite_test: iterable of suite rows (dict-like).
        all_idxs: mapping model name -> one completeness boolean per suite row
            (get_complete_idxs with ``suite=True``).

    Returns:
        DataFrame with functionalities as rows and models as columns holding the
        mean hit (x100) over completed test cases, plus an "avg" row. A
        functionality in the hits file with no completed-case mask gets NaN.
    """
    from itertools import groupby

    skipped = {'"used to" should reduce', "reducers"}

    def _case_key(example):
        # Suites name their case id column either "test_id" or "case_id".
        try:
            return example["functionality"], example["test_id"]
        except KeyError:
            return example["functionality"], example["case_id"]

    scores = {}
    for model, idxs in all_idxs.items():
        with open(f"results/{task}/suite/{model}_hits.json", "r") as file:
            hits = json.load(file)
        kept = ((_case_key(ex), complete) for ex, complete in zip(suite_test, idxs)
                if ex["functionality"] not in skipped)
        hit_mask = {}
        # Consecutive rows with the same key form one test case.
        for (func, _), group in groupby(kept, key=lambda item: item[0]):
            hit_mask.setdefault(func, []).append(all(complete for _, complete in group))
        model_scores = scores[model] = {}
        for func, values in hits.items():
            if func not in hit_mask:
                # No completed-case mask (e.g. a skipped functionality).
                model_scores[func] = np.nan
                continue
            mask = np.asarray(hit_mask[func], dtype=bool)
            model_scores[func] = np.nanmean(np.asarray(values, dtype=float)[mask]) * 100
    df = pd.DataFrame.from_dict(scores)
    df.loc["avg"] = df.mean(axis=0)
    return df

In [None]:
def get_ruleF1_and_hits(task, suite_test, rule_f1s):
    """Average per-row rule F1 into per-test-case values and load each model's suite hits.

    A test case is a consecutive run of suite rows sharing
    ``(functionality, test_id)`` (or ``case_id`` for suites without
    ``test_id``). Rows of the functionalities '"used to" should reduce' and
    "reducers" are ignored. The "random" and "majority" baselines are skipped.

    Args:
        task: task key; hits are read from ``results/{task}/suite/{model}_hits.json``.
        suite_test: iterable of suite rows (dict-like), aligned with each F1 list.
        rule_f1s: mapping model name -> per-row rule F1 (as returned by eval_rules).

    Returns:
        ``(f1_dic, hits_dic)``: ``f1_dic[model][functionality]`` is the list of
        mean rule F1 per test case; ``hits_dic[model]`` is the loaded hits JSON.
    """
    from itertools import groupby

    skipped = {'"used to" should reduce', "reducers"}

    def _case_key(example):
        # Suites name their case id column either "test_id" or "case_id".
        try:
            return example["functionality"], example["test_id"]
        except KeyError:
            return example["functionality"], example["case_id"]

    f1_dic = {}
    hits_dic = {}
    for model, f1s in rule_f1s.items():
        if model in ["random", "majority"]:
            continue
        with open(f"results/{task}/suite/{model}_hits.json", "r") as file:
            hits = json.load(file)
        per_func = f1_dic[model] = {}
        kept = ((_case_key(ex), f1) for ex, f1 in zip(suite_test, f1s)
                if ex["functionality"] not in skipped)
        # Consecutive rows with the same key form one test case.
        for (func, _), group in groupby(kept, key=lambda item: item[0]):
            per_func.setdefault(func, []).append(np.mean([f1 for _, f1 in group]))
        hits_dic[model] = hits
    return f1_dic, hits_dic

In [None]:
def _extract_rule_ids(pred, model, task, verbose=False):
    """Return the distinct rule ids (as ints) mentioned in one generation.

    Single numbers ("3") count. Ranges ("3-5") are expanded to 3, 4, 5 unless
    preceded by "digits "; their endpoints are still counted as single numbers.
    Storing every id as an int keeps "3" and 3 from both ending up in the result.
    """
    if "example" in model:
        # Runs prompted with an example: only the first line of the answer is parsed.
        pred = pred.split("\n")[0]
        if verbose:
            print(pred)
    if task == "rc" and "example" not in model:
        # Reading comprehension: drop the last sentence if the first one mentions
        # "rule", otherwise drop the first sentence.
        sentences = sent_tokenize(pred)
        if not sentences:
            # Empty or whitespace-only answer: nothing to parse.
            pred = ""
        elif "rule" in sentences[0].lower():
            pred = " ".join(sentences[:-1])
        else:
            pred = " ".join(sentences[1:])
    flags = re.MULTILINE | re.IGNORECASE
    rule_ids = {int(m.group()) for m in re.finditer(r"\d+", pred, flags)}
    for m in re.finditer(r"(?<!digits )\d+-\d+", pred, flags):
        start, end = m.group().split("-")
        rule_ids.update(range(int(start), int(end) + 1))
    return list(rule_ids)


def _score_and_report(name, pred_rules, gold_rules):
    """Score predictions with f1_score_rules, print mean P/R/F1 and return the per-example F1s."""
    f1s, precisions, recalls = [], [], []
    for rule, gold in zip(pred_rules, gold_rules):
        f1, precision, recall = f1_score_rules(rule, gold)
        f1s.append(f1)
        precisions.append(precision)
        recalls.append(recall)
    print(f"{name} results:")
    print(f"Precision: {np.mean(precisions)}")
    print(f"Recall: {np.mean(recalls)}")
    print(f"F1: {np.mean(f1s)}")
    print()
    return f1s


def eval_rules(task, all_preds, suite_test, verbose=False):
    """Score the rule ids each model cites against each suite case's gold rule.

    The gold rule of a suite row is the 1-based position of its functionality
    among the keys of ``./data/{task}/suite/func_desc.pkl``. Besides the given
    models, a "majority" baseline (always the most frequent gold rule) and a
    "random" baseline (uniform rule id drawn with np.random) are scored.

    Args:
        task: task key; selects func_desc.pkl and the rc-specific parsing.
        all_preds: mapping model name -> list of generations, either strings or
            ``[{"generated_text": ...}]`` lists.
        suite_test: iterable of suite rows with a "functionality" field.
        verbose: print the parsed first line of "example" runs.

    Returns:
        dict model/baseline name -> list of per-row rule F1. Mean precision,
        recall and F1 are printed for each entry.
    """
    with open(f"./data/{task}/suite/func_desc.pkl", "rb") as file:
        # NOTE(review): pickle.load executes arbitrary code; only load trusted files.
        func_desc = pickle.load(file)
    func2id = {func: idx + 1 for idx, func in enumerate(func_desc.keys())}
    gold_rules = [func2id[example["functionality"]] for example in suite_test]
    results = {}
    for model, preds in all_preds.items():
        if len(preds) > 0 and isinstance(preds[0], list):
            preds = [x[0]["generated_text"] for x in preds]
        rules = [_extract_rule_ids(pred, model, task, verbose) for pred in preds]
        results[model] = _score_and_report(model, rules, gold_rules)
    for method in ["majority", "random"]:
        if method == "majority":
            # scipy >= 1.11 returns a scalar mode, older versions a 1-element
            # array; atleast_1d handles both.
            majority = np.atleast_1d(mode(gold_rules)[0])[0]
            pred_rules = [[str(majority)]] * len(gold_rules)
        else:
            pred_rules = np.random.randint(1, len(func2id) + 1, size=len(gold_rules))[:, None]
        results[method] = _score_and_report(method, pred_rules, gold_rules)
    return results

In [None]:
def rule_task_qual_corr(task, suite_test, rule_results, model="chatGPT_seen_with_rules"):
    """Correlate rule-prediction quality with suite pass rate for one model.

    Prints two Pearson correlations ``(r, p-value)``:
    - "aggregated": over per-functionality means;
    - "by sample": over individual test cases.

    Rule F1s are set to NaN wherever the hit is NaN. Both correlations use only
    the pairs where neither value is NaN.

    Args:
        task, suite_test, rule_results: forwarded to get_ruleF1_and_hits.
        model: which model's F1s and hits to correlate.

    Returns:
        ``(rule_qual, task_qual)``: flat per-test-case lists of rule F1 and hit,
        element-wise aligned and possibly containing NaN. Assumes
        ``hits[model][func]`` is aligned with the per-test-case F1 list -- TODO confirm.
    """
    f1_dic, hits = get_ruleF1_and_hits(task, suite_test, rule_results)

    rule_qual_agg, task_qual_agg = [], []
    rule_qual, task_qual = [], []
    for func, f1s in f1_dic[model].items():
        # float dtype so JSON nulls become NaN and the NaN assignment is valid.
        hs = np.asarray(hits[model][func], dtype=float)
        f1s = np.asarray(f1s, dtype=float)
        f1s[np.isnan(hs)] = np.nan
        rule_qual_agg.append(np.nanmean(f1s))
        task_qual_agg.append(np.nanmean(hs))
        rule_qual.extend(f1s)
        task_qual.extend(hs)

    def _paired(x, y):
        # Drop a pair if either side is NaN, so the two series stay aligned.
        x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
        keep = ~(np.isnan(x) | np.isnan(y))
        return x[keep], y[keep]

    corr, pvalue = pearsonr(*_paired(task_qual_agg, rule_qual_agg))
    print("aggregated", corr, pvalue)
    corr, pvalue = pearsonr(*_paired(task_qual, rule_qual))
    print("by sample", corr, pvalue)
    return rule_qual, task_qual

In [None]:
def get_parrot_freq(preds, model="chatGPT_seen_example_with_rules"):
    """Print the share of a model's generations that contain literal template placeholders.

    A generation counts when, lowercased, it contains ``"{rule list}\\n"``
    (rule-list parroting) or ``"{rationale}\\n"`` (rationale parroting) --
    presumably placeholders copied verbatim from the prompt template.

    Args:
        preds: mapping model name -> list of generated strings.
        model: key of ``preds`` to inspect.
    """
    lowered = [pred.lower() for pred in preds[model]]
    parrot_list = ["{rule list}\n" in pred for pred in lowered]
    parrot_rationale = ["{rationale}\n" in pred for pred in lowered]
    print(f"Rule list parroting: {np.mean(parrot_list)}")
    print(f"Rationale parroting: {np.mean(parrot_rationale)}")

In [None]:
def get_rat_samples(input_cols, suite_test, preds, n, func_desc, label="label", model="chatGPT_seen_example_with_rules"):
    """Draw ``n`` distinct random suite rows together with one model's generation.

    Args:
        input_cols: columns whose values are joined with newlines into "input".
        suite_test: dataset exposing ``num_rows`` and integer row access
            returning a dict (e.g. a Hugging Face ``Dataset``).
        preds: mapping model name -> list of generations aligned with the rows.
        n: number of rows to draw (without replacement, via np.random).
        func_desc: iterable of functionality names; the 1-based position of a
            row's functionality is reported as its "rule".
        label: column holding the gold label.
        model: key of ``preds`` whose generation is attached as "pred".

    Returns:
        list of dicts with keys "input", "rule", "func", "label", "pred".
    """
    choices = np.random.choice(range(suite_test.num_rows), n, replace=False)
    func_to_id = {func: (num + 1) for num, func in enumerate(func_desc)}
    model_preds = preds[model]
    samples = []
    for choice in choices:
        idx = int(choice)
        # Fetch one row; indexing a column first (suite_test[col]) would
        # materialise the whole column for every lookup.
        row = suite_test[idx]
        samples.append({
            "input": "\n".join(row[col] for col in input_cols),
            "rule": func_to_id[row["functionality"]],
            "func": row["functionality"],
            "label": row[label],
            "pred": model_preds[idx],
        })
    return samples

In [None]:
# Collects the random suite samples drawn by get_rat_samples in every task section below.
rat_samples = []

In [None]:
# Task key ("sa", "pi", "rc", "hsd") -> eval_rules output (per-run rule-F1 lists);
# read when building the final rule-evaluation table.
rule_results_dic = {}

# Sentiment Analysis

## SST-2

In [None]:
# SST-2 generations for the sentiment task.
result_path = Path(f"./results/sa/sst2/")

In [None]:
results = load_results(result_path)

In [None]:
# Keep only runs whose name contains "rules".
results = {k: v for k, v in results.items() if "rules" in k}

In [None]:
# References come from the GLUE SST-2 validation split.
dataset_test = load_dataset("glue", "sst2")["validation"]

In [None]:
metric = load_metric("glue","sst2")

In [None]:
# Scores over the full validation set. load_results / get_dataset_scores are not
# defined in this notebook (presumably from the utils.results star import).
dataset_scores, preds = get_dataset_scores("sa", results, dataset_test["label"], metric)

In [None]:
# Runs ranked by accuracy, best first.
sort = sorted(dataset_scores.items(), key=lambda item: item[1]["accuracy"], reverse=True)

In [None]:
sort

In [None]:
# Per model, the indices of generations that did not hit the length cap.
idxs = get_complete_idxs("sa", results, "sst2", verbose=False)

In [None]:
# Re-score each model on its complete generations only.
scores = complete_results("sa", dataset_test, results, metric, idxs)

In [None]:
sort = sorted(scores.items(), key=lambda item: item[1]["accuracy"], reverse=True); sort

## Suite

In [None]:
# Generations and suite scores for the sentiment test suite.
result_path = Path("./results/sa/suite/")

In [None]:
preds = load_results(result_path)

In [None]:
# Keep "with_rules" runs from the seen / baseline_zero / funcOut / classOut settings.
preds = {k: v for k, v in preds.items() if ("seen" in k or "baseline_zero" in k or "funcOut" in k or "classOut" in k) and "with_rules" in k}

In [None]:
# Per-run suite scores stored as CSV; the frames are concatenated below.
results = load_results(result_path, file_type="csv")

In [None]:
results = {k: v for k, v in results.items() if ("seen" in k or "baseline_zero" in k or "funcOut" in k or "classOut" in k) and "with_rules" in k}

In [None]:
df =pd.concat([x for x in results.values()])

In [None]:
# Runs ordered by their "avg" column (ascending).
df["avg"].sort_values()

In [None]:
# One completeness flag per suite row (suite=True); only the disabled
# complete_results_suite cell below would consume it.
idxs = get_complete_idxs("sa", preds, "suite", suite=True)

In [None]:
suite_test = get_suite(config.sa_path)["test"]

In [None]:
# df = complete_results_suite("sa", suite_test, idxs)

In [None]:
# df.loc["avg"].sort_values()

In [None]:
type(preds["flan-t5-base_classOut_example_with_rules"][0])

In [None]:
rule_results = eval_rules("sa", {k: v for k,v in preds.items() if "seen" in k}, suite_test)

In [None]:
get_parrot_freq(preds)

In [None]:
rule_results_dic["sa"] = rule_results

In [None]:
rule_task_qual_corr("sa", suite_test, rule_results)

In [None]:
rule_task_qual_corr("sa", suite_test, rule_results, model="chatGPT_seen_example_with_rules")

In [None]:
with open("./data/sa/suite/func_desc.pkl", "rb") as file:
    func_desc = pickle.load(file)

In [None]:
func_to_id = {func: (num+1) for num, func in enumerate(func_desc)}

In [None]:
samples = get_rat_samples(["test_case"], suite_test, preds, 10, func_desc)

In [None]:
rat_samples.extend(samples)

# Paraphrase identification

## QQP

In [None]:
# QQP generations for the paraphrase-identification task.
result_path = Path(f"./results/pi/qqp/")

In [None]:
results = load_results(result_path)

In [None]:
# Keep only runs whose name contains "rules".
results = {k: v for k, v in results.items() if "rules" in k}

In [None]:
dataset_test = load_dataset("glue", "qqp")["validation"]

In [None]:
metric = load_metric("glue","qqp")

In [None]:
dataset_scores, preds = get_dataset_scores("pi", results, dataset_test["label"], metric)

In [None]:
dataset_scores

In [None]:
# Runs ranked by accuracy, best first.
sort = sorted(dataset_scores.items(), key=lambda item: item[1]["accuracy"], reverse=True); sort

In [None]:
# Per model, the indices of generations that did not hit the length cap.
idxs = get_complete_idxs("pi", results, "qqp", verbose=False)

In [None]:
# Re-score each model on its complete generations only.
scores = complete_results("pi", dataset_test, results, metric, idxs)

In [None]:
sort = sorted(scores.items(), key=lambda item: item[1]["accuracy"], reverse=True); sort

## Suite

In [None]:
# Generations and suite scores for the paraphrase test suite.
result_path = Path("./results/pi/suite/")

In [None]:
preds = load_results(result_path)

In [None]:
# Keep "with_rules" runs from the seen / baseline_zero / funcOut / classOut settings.
preds = {k: v for k, v in preds.items() if ("seen" in k or "baseline_zero" in k or "funcOut" in k or "classOut" in k) and "with_rules" in k}

In [None]:
results = load_results(result_path, file_type="csv")

In [None]:
results = {k: v for k, v in results.items() if ("seen" in k or "baseline_zero" in k or "funcOut" in k or "classOut" in k) and "with_rules" in k}

In [None]:
df =pd.concat([x for x in results.values()])

In [None]:
# Runs ordered by their "avg" column (ascending).
df["avg"].sort_values()

In [None]:
# One completeness flag per suite row; only the disabled cell below would use it.
idxs = get_complete_idxs("pi", preds, "suite", suite=True)

In [None]:
suite_test = get_suite(config.pi_path)["test"]

In [None]:
# df = complete_results_suite("pi", suite_test, idxs)

In [None]:
#df.loc["avg"].sort_values()

In [None]:
# Rule-id F1 of the "seen" runs (plus majority/random baselines).
rule_results =  eval_rules("pi", {k: v for k,v in preds.items() if "seen" in k}, suite_test)

In [None]:
get_parrot_freq(preds)

In [None]:
rule_results_dic["pi"] = rule_results

In [None]:
rule_task_qual_corr("pi", suite_test, rule_results)

In [None]:
rule_task_qual_corr("pi", suite_test, rule_results, model="chatGPT_seen_example_with_rules")

In [None]:
# NOTE(review): pickle.load executes arbitrary code; only load trusted files.
with open("./data/pi/suite/func_desc.pkl", "rb") as file:
    func_desc = pickle.load(file)

In [None]:
func_to_id = {func: (num+1) for num, func in enumerate(func_desc)}

In [None]:
samples = get_rat_samples(["question1", "question2"], suite_test, preds, 10, func_desc)

In [None]:
rat_samples.extend(samples)

# Reading comprehension

## SQuAD

In [None]:
# SQuAD generations for the reading-comprehension task.
result_path = Path(f"./results/rc/squad/")

In [None]:
results = load_results(result_path)

In [None]:
# Keep only runs whose name contains "rules".
results = {k: v for k, v in results.items() if "rules" in k}

In [None]:
dataset_test = load_dataset("squad")["validation"]

In [None]:
metric = load_metric("squad")

In [None]:
# References are the "answers" column rather than a class label.
dataset_scores, preds = get_dataset_scores("rc", results, dataset_test["answers"], metric)

In [None]:
# Runs ranked by exact match, best first.
sort = sorted(dataset_scores.items(), key=lambda item: item[1]["exact_match"], reverse=True); sort

In [None]:
# Per model, the indices of generations that did not hit the (rc) length cap.
idxs = get_complete_idxs("rc", results, "squad", verbose=False)

In [None]:
scores = complete_results("rc", dataset_test, results, metric, idxs, label_col="answers")

In [None]:
sort = sorted(scores.items(), key=lambda item: item[1]["exact_match"], reverse=True); sort

## Suite

In [None]:
# Generations and suite scores for the reading-comprehension test suite.
result_path = Path("./results/rc/suite/")

In [None]:
preds = load_results(result_path)

In [None]:
# NOTE(review): this suite filters on "baseline" while the others use
# "baseline_zero" -- confirm the broader match is intended.
preds = {k: v for k, v in preds.items() if ("seen" in k or "baseline" in k or "funcOut" in k or "classOut" in k) and "with_rules" in k}

In [None]:
results = load_results(result_path, file_type="csv")

In [None]:
results = {k: v for k, v in results.items() if ("seen" in k or "baseline" in k or "funcOut" in k or "classOut" in k) and "with_rules" in k}

In [None]:
df =pd.concat([x for x in results.values()])

In [None]:
# Runs ordered by their "avg" column (ascending).
df["avg"].sort_values()

In [None]:
# One completeness flag per suite row; only the disabled cell below would use it.
idxs = get_complete_idxs("rc", preds, "suite", suite=True)

In [None]:
suite_test = get_suite(config.rc_path)["test"]

In [None]:
# df = complete_results_suite("rc", suite_test, idxs)

In [None]:
# df.loc["avg"].sort_values()

In [None]:
# Rule-id F1 of the "seen" runs (plus majority/random baselines).
rule_results = eval_rules("rc", {k: v for k,v in preds.items() if "seen" in k}, suite_test)

In [None]:
get_parrot_freq(preds)

In [None]:
rule_results_dic["rc"] = rule_results

In [None]:
rule_task_qual_corr("rc", suite_test, rule_results)

In [None]:
rule_task_qual_corr("rc", suite_test, rule_results, model="chatGPT_seen_example_with_rules")

In [None]:
# NOTE(review): pickle.load executes arbitrary code; only load trusted files.
with open("./data/rc/suite/func_desc.pkl", "rb") as file:
    func_desc = pickle.load(file)

In [None]:
# Input is context + question; the gold label lives in the "answers" column.
samples = get_rat_samples(["context", "question"], suite_test, preds, 10, func_desc, label="answers")

In [None]:
rat_samples.extend(samples)

# Hate speech detection

## Datasets

In [None]:
# Hate-speech generations on two datasets: Davidson et al. 2017 and Founta et al. 2018.
davidson_path = Path(f"./results/hsd/davidson2017/")
founta_path = Path(f"./results/hsd/founta2018/")

In [None]:
davidson_results = load_results(davidson_path)
founta_results = load_results(founta_path)

In [None]:
davidson_test = load_hsd_dataset("davidson2017")["test"]
founta_test = load_hsd_dataset("founta2018")["test"]

In [None]:
# Keep only runs whose name contains "rules".
davidson_results = {k: v for k,v in davidson_results.items() if "rules" in k}

In [None]:
founta_results = {k: v for k,v in founta_results.items() if "rules" in k}

In [None]:
# The GLUE QQP metric is reused here because it reports both "accuracy" and
# "f1" on binary labels; the cells below rank by "f1".
metric = load_metric("glue","qqp")

In [None]:
dataset_scores, preds = get_dataset_scores("hsd", davidson_results, davidson_test["label"], metric)

In [None]:
sort = sorted(dataset_scores.items(), key=lambda item: item[1]["f1"], reverse=True)

In [None]:
sort

In [None]:
# Overwrites dataset_scores / preds / sort from the Davidson cells above.
dataset_scores, preds = get_dataset_scores("hsd", founta_results, founta_test["label"], metric)

In [None]:
sort = sorted(dataset_scores.items(), key=lambda item: item[1]["f1"], reverse=True)

In [None]:
sort

In [None]:
# Per model, the indices of generations that did not hit the length cap.
idxs = get_complete_idxs("hsd", davidson_results, "davidson2017", verbose=False)

In [None]:
scores = complete_results("hsd", davidson_test, davidson_results, metric, idxs)

In [None]:
sort = sorted(scores.items(), key=lambda item: item[1]["f1"], reverse=True); sort

In [None]:
idxs = get_complete_idxs("hsd", founta_results, "founta2018", verbose=False)

In [None]:
scores = complete_results("hsd",founta_test, founta_results, metric, idxs)

In [None]:
sort = sorted(scores.items(), key=lambda item: item[1]["f1"], reverse=True); sort

## Suite

In [None]:
result_path = Path("./results/hsd/suite/")

In [None]:
results = load_results(result_path, hatecheck=True)

In [None]:
results = {k: v for k, v in results.items() if ("seen" in k or "baseline_zero" in k or "funcOut" in k or "classOut" in k) and "with_rules" in k}

In [None]:
df =pd.concat([x for x in results.values()])

In [None]:
df *= 100

In [None]:
df["avg"].sort_values()

In [None]:
preds = load_results(result_path)

In [None]:
preds = {k: v for k, v in preds.items() if ("seen" in k or "baseline_zero" in k or "funcOut" in k or "classOut" in k) and "with_rules" in k}

In [None]:
# Unlike the other suites this call is not in suite mode, so it returns the
# indices of complete generations -- the form the disabled complete_dic cell uses.
idxs = get_complete_idxs("hsd", preds, "suite")

In [None]:
suite_test= get_suite(config.hatecheck_path, hateCheck=True)["test"]

In [None]:
# complete_dic = {}
# for model, result in preds.items():
#     complete_examples = suite_test.select(indices=idxs[model])
#     complete_preds =list(np.take(result,idxs[model]))
#     complete_preds = [pred.split()[0] if ("yes" in pred.split()[0].lower() or "no" in pred.split()[0].lower()) else pred.split()[-1] for pred in complete_preds]
#     complete_preds = [1 if "yes" in pred.lower() else 0 for pred in complete_preds]
#     funcs = complete_examples["functionality"]
#     hits = pd.DataFrame.from_dict({"funcs": funcs, "hits": (np.array(complete_examples["label_gold"])== np.array(complete_preds)).astype(int), "labels": complete_examples["label_gold"]})
#     complete_dic[model] = pd.DataFrame.from_dict({model: hits.groupby("funcs").mean().to_dict()["hits"]}, orient="index")

In [None]:
# complete_df = pd.concat([x for x in complete_dic.values()])

In [None]:
# complete_df *= 100

In [None]:
# complete_df["avg"] = complete_df.mean(axis=1)

In [None]:
# complete_df["avg"].sort_values()

In [None]:
# Rule-id F1 of the "seen" runs (plus majority/random baselines).
rule_results = eval_rules("hsd", {k: v for k,v in preds.items() if "seen" in k}, suite_test, verbose=False)

In [None]:
get_parrot_freq(preds)

In [None]:
rule_results_dic["hsd"] = rule_results

In [None]:
rule_task_qual_corr("hsd", suite_test, rule_results)

In [None]:
rule_task_qual_corr("hsd", suite_test, rule_results, model = "chatGPT_seen_example_with_rules")

In [None]:
# NOTE(review): pickle.load executes arbitrary code; only load trusted files.
with open("./data/hsd/suite/func_desc.pkl", "rb") as file:
    func_desc = pickle.load(file)

In [None]:
# The HateCheck suite keeps its gold label in the "label_gold" column.
samples = get_rat_samples(["test_case"], suite_test, preds, 10, func_desc, label="label_gold")

In [None]:
rat_samples.extend(samples)

In [None]:
# pd.DataFrame(rat_samples).to_csv("./results/rat_samples.csv", index=False)

In [None]:
# NOTE(review): the only writer of this CSV in the notebook is the disabled cell
# above, so this fails on a fresh checkout unless the file already exists.
pd.read_csv("./results/rat_samples.csv").iloc[7]["input"]

## Generate rule evaluation table and figure

In [None]:
# Display order for the rule-evaluation table: model labels first, then methods.
model_order = ["random", "small", "base", "large", "xl", "xxl", "beta", "chatGPT"]
score_order = ["baseline", "Task", "Task+Spec"]
add_order = ["", "+Ex", "+Rat", "+Ex+Rat"]

# Every prompt setting crossed with every add-on, e.g. "Task+Spec+Ex+Rat".
method_order = [f"{setting}{extra}" for setting in score_order for extra in add_order]

# Label -> rank lookup; process_df uses it as the sort key for both index levels.
order = dict(zip(model_order + method_order, range(len(model_order) + len(method_order))))

In [None]:
def process_df(df, sort_order=None):
    """Re-index a run-name-indexed score table by (model, method) for the paper table.

    ``model`` is the part of the run name before the first "_", with everything
    up to the last "-" removed (e.g. "flan-t5-small_..." -> "small").
    ``method`` is:
    - "Task" for runs whose name contains "baseline";
    - "" for the "random" / "majority" baselines;
    - "Task+Spec" otherwise.
    "+Ex" is appended when the name contains "example", then "+Rat" when it
    contains "rules".

    Args:
        df: frame indexed by run name. It is not modified.
        sort_order: mapping label -> rank used to sort both index levels;
            defaults to the notebook-level ``order`` dict.

    Returns:
        A new frame indexed by (model, method), sorted by ``sort_order``.
    """
    key_map = order if sort_order is None else sort_order

    def _method(name):
        if "baseline" in name:
            method = "Task"
        elif "random" in name or "majority" in name:
            method = ""
        else:
            method = "Task+Spec"
        if "example" in name:
            method += "+Ex"
        if "rules" in name:
            method += "+Rat"
        return method

    names = list(df.index)
    labelled = df.assign(
        model=[name.split("_")[0].split("-")[-1] for name in names],
        method=[_method(name) for name in names],
    )
    labelled = labelled.set_index(["model", "method"])
    # For a MultiIndex the key is applied to each level separately.
    return labelled.sort_index(key=lambda level: level.map(key_map))

In [None]:
# Mean per-example rule F1 of every run, keyed task -> run name.
rules_f1s = {}
# NOTE(review): this loop rebinds the notebook-level `results` name.
for task, results in rule_results_dic.items():
    for model, f1s in results.items():
        rules_f1s.setdefault(task, {})[model] = np.mean(f1s)

In [None]:
# Columns are tasks, rows are run names.
df = pd.DataFrame.from_dict(rules_f1s)

In [None]:
df = process_df(df)

In [None]:
df = df[["sa", "pi", "rc", "hsd"]]

In [None]:
# Display names for the task columns.
df = df.rename(columns={"sa": "SENT",
           "pi": "PARA",
           "rc": "READ",
           "hsd": "HATE"})

In [None]:
# Rule-prediction F1 per task for the zephyr ("beta") and chatGPT runs; each
# task's random baseline is drawn as a horizontal line in that task's bar colour.
fig, ax = plt.subplots(1, 1, figsize=(10,3))
df.loc[(["beta","chatGPT"], slice(None))].plot.bar(ax=ax)
cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']
for line, color in zip(df.loc[("random", "")], cycle):
    ax.axhline(line, color=color)
ax.xaxis.label.set_visible(False)
# Tick labels read "(model, method)": split them over two lines and show "beta"
# as "zephyr". get_text() is the public accessor (Text._text is private).
ax.set_xticklabels(["\n".join(x.get_text().strip("()").replace("beta", "zephyr").split(",")) for x in ax.get_xticklabels()], rotation=0);

In [None]:
# Export the figure for the paper (presumably the paper repo is checked out next to this one).
rule_prediction_pdf = "../specification-instruction-paper/media/rule_prediction.pdf"
fig.savefig(rule_prediction_pdf, bbox_inches="tight")