# SIGIR

## Imports and configuration

In [None]:
from math import nan, isnan
from pathlib import Path
from typing import Sequence
from statistics import mean

from matplotlib.colors import LogNorm
from numpy import array, stack, float_, sign
from numpy.typing import NDArray
from pandas import read_json, DataFrame, isna, concat, Categorical
from pyterrier.datasets import get_dataset, Dataset
from pyterrier.terrier import IterDictIndexer, IndexFactory
from seaborn import FacetGrid, histplot, lineplot, heatmap
from tqdm.auto import tqdm

from axioms.axiom import (
    Axiom,
    # Adapted axioms
    GEN_aSL,
    # GEN_ArgUC,
    # GEN_QTArg,
    # GEN_QTPArg,
    GEN_LNC1,
    GEN_TF_LNC,
    GEN_PROX1,
    GEN_PROX2,
    GEN_PROX3,
    GEN_PROX4,
    GEN_PROX5,
    GEN_REG,
    # GEN_ANTI_REG,
    GEN_AND,
    # GEN_LEN_AND,
    # GEN_M_AND,
    GEN_DIV,
    # GEN_LEN_DIV,
    GEN_TFC1,
    GEN_STMC1,
    GEN_STMC2,
    # Generation-specific axioms
    CLAR1,
    CLAR2,
    CLAR3,
    CLAR4,
    CLAR5,
    CLAR6,
    CLAR7,
    CONS1,
    CONS2,
    CONS3,
    CONS4,
    COV1,
    COV2,
    COV3,
    COV4,
    COV5,
    # Oracle axioms
    TrecRagNuggetAxiom,
    TrecRagCrowdAxiom,
    # Utility axioms
    MajorityVoteAxiom,
)
from axioms.model import GenerationInput, GenerationOutput, Preference, PreferenceMatrix
from axioms.tools import KeyBertAspectExtraction

In [2]:
# Root directory for on-disk caches; per-axiom preference caches go under "axioms/".
cache_path = Path("..") / "data" / "cache"

## Datasets

In [3]:
# Input files: TREC 2024 RAG nugget assignments, plus crowd-sourced responses and
# the pairwise ratings between them.
rag_assignments_path = Path("..", "data", "nugget_assignment.20241108.jl")
rag_crowd_responses_path = Path("..", "data", "crowd", "responses.jsonl.gz")
rag_crowd_ratings_path = Path("..", "data", "crowd", "ratings.jsonl.gz")

In [None]:
def read_rag_assignments_outputs_df() -> DataFrame:
    """Load the nugget-assignment JSON-lines file as a table of generated outputs.

    Reads ``rag_assignments_path`` and returns one row per line with the
    columns ``qid``, ``query``, ``context``, ``name`` (from ``run_id``) and
    ``text`` (from ``answer_text``). Missing queries and answers become empty
    strings, and ``context`` is set to NaN for every row.
    """
    raw = read_json(rag_assignments_path, lines=True)
    outputs = raw.drop(columns=["response_length", "nuggets"]).rename(
        columns={"run_id": "name", "answer_text": "text"}
    )
    outputs = outputs.assign(
        query=outputs["query"].fillna(""),
        text=outputs["text"].fillna(""),
        context=nan,
    )
    return outputs[["qid", "query", "context", "name", "text"]]

# Preview the nugget-assignment outputs table (displayed as the cell result).
read_rag_assignments_outputs_df()

In [None]:
def read_rag_crowd_outputs_df() -> DataFrame:
    """Load the human-written crowd responses that take part in a rated pair.

    Reads ``rag_crowd_responses_path`` and keeps only responses whose
    ``response`` id appears as ``response_a`` or ``response_b`` in
    ``rag_crowd_ratings_path``, and only those of kind ``"human"``.

    Returns a DataFrame with the columns ``qid`` (from ``topic``), ``query``,
    ``context`` (from ``references_texts``), ``name`` (``"<kind>_<style>"``)
    and ``text`` (from ``raw_text``).
    """
    responses = read_json(rag_crowd_responses_path, lines=True)
    ratings = read_json(rag_crowd_ratings_path, lines=True)

    # Keep only responses that appear on either side of at least one rating.
    rated_responses = concat([ratings["response_a"], ratings["response_b"]]).unique()
    responses = responses[responses["response"].isin(rated_responses)]

    # Work on an explicit copy of the filtered rows so the column assignment
    # and drop below never act on a view of ``responses`` (previously this
    # assigned into a boolean-filtered slice, i.e. chained assignment).
    human = responses[responses["kind"] == "human"].copy()
    human["name"] = human["kind"] + "_" + human["style"]
    human = human.drop(
        columns=["references_ids", "cleaned_text", "statements", "kind", "style", "response"]
    ).rename(
        columns={"topic": "qid", "references_texts": "context", "raw_text": "text"}
    )

    return human[["qid", "query", "context", "name", "text"]]

# Preview the filtered human crowd responses (displayed as the cell result).
read_rag_crowd_outputs_df()

In [6]:
def to_inputs_outputs(df: DataFrame) -> list[tuple[GenerationInput, list[GenerationOutput]]]:
    """Convert a flat outputs table into per-query axiom inputs.

    ``df`` must have the columns ``qid``, ``query``, ``context``, ``name`` and
    ``text`` (as produced by the ``read_rag_*_outputs_df`` loaders). Rows are
    grouped by ``(qid, query)`` in order of first appearance. For each group
    one ``GenerationInput`` is built, together with one ``GenerationOutput``
    per run name found anywhere in ``df`` (sorted), so every query gets the
    same list of runs. A run without a row for that query gets empty text.

    Raises:
        ValueError: if the rows of one query carry different contexts.
    """
    data: list[tuple[GenerationInput, list[GenerationOutput]]] = []
    runs = sorted(df["name"].unique())
    for (query_id, query), df_query in df.groupby(
        ["qid", "query"], group_keys=False, sort=False, as_index=False,
    ):
        # Sequences (e.g. lists of reference texts) become hashable tuples;
        # anything else, such as a NaN placeholder, becomes None.
        contexts = {
            tuple(context) if isinstance(context, Sequence) else None
            for context in df_query["context"]
        }
        if len(contexts) > 1:
            # Contexts are tuples/None, so stringify them before joining;
            # joining them directly raised TypeError instead of this ValueError.
            raise ValueError(
                f"Multiple contexts for query {query_id}: "
                f"{'; '.join(str(context) for context in contexts)}"
            )
        context = next(iter(contexts))
        generation_input = GenerationInput(
            id=query_id,
            text=query,
            context=context,
        )
        outputs = {
            name: GenerationOutput(id=name, text=text)
            for name, text in zip(df_query["name"], df_query["text"])
        }
        data.append(
            (
                generation_input,
                [
                    outputs[run] if run in outputs else GenerationOutput(id=run, text="")
                    for run in runs
                ],
            )
        )
    return data

In [None]:
# Sanity check: print the first three (input, outputs) pairs of the assignments data.
print(to_inputs_outputs(read_rag_assignments_outputs_df())[:3])

In [None]:
# Sanity check: print the first three (input, outputs) pairs of the crowd data.
print(to_inputs_outputs(read_rag_crowd_outputs_df())[:3])

## Oracle axiom preferences

In [None]:
# Oracle axioms built from ground-truth judgments. Each entry is
# (name, axiom, data run type); the run type selects which dataset
# ("assignments" or "crowd") the oracle is evaluated on.
rag_oracle_axioms: list[tuple[str, Axiom[GenerationInput, GenerationOutput], str]] = [
    *[
        (
            f"ORACLE-NUGGET-{score_type.upper()}{'-STRICT' if strict else ''}{f'-{margin_fraction:.1f}' if margin_fraction > 0.0 else ''}",
            TrecRagNuggetAxiom(
                assignments_path=rag_assignments_path,
                score_type=score_type,
                # Forward the loop's flag: hard-coding False here made every
                # "-STRICT" oracle identical to its non-strict counterpart.
                strict=strict,
                margin_fraction=margin_fraction,
            ),
            "assignments",
        )
        for score_type in (
            "all",
            "vital",
            "weighted"
        )
        for strict in (
            True,
            False,
        )
        for margin_fraction in (
            # 0.0,
            0.1,
            # 0.2,
            # 0.3,
            # 0.4,
            # 0.5,
        )
    ],
    *[
        (
            f"ORACLE-CROWD-{utility_type.upper()}{f'-{margin_fraction:.1f}' if margin_fraction > 0.0 else ''}",
            TrecRagCrowdAxiom(
                responses_path=rag_crowd_responses_path,
                ratings_path=rag_crowd_ratings_path,
                utility_type=utility_type,
                margin_fraction=margin_fraction,
            ),
            "crowd",
        )
        for utility_type in (
            "overall",
            "coherence",
            "consistency",
            "correctness",
            "coverage",
        )
        for margin_fraction in (
            # 0.0,
            0.1,
            # 0.2,
            # 0.3,
            # 0.4,
            # 0.5,
        )
    ],
]
# Wrap each oracle with a cache file named after it.
# NOTE: caches are keyed by name only. Any "ORACLE-NUGGET-*-STRICT-*" cache files
# written while strict was hard-coded to False hold non-strict results and must
# be deleted before rerunning.
rag_oracle_axioms = [
    (name, axiom.cached(cache_path / "axioms" / f"{name}.cache"), run_type)
    for name, axiom, run_type in rag_oracle_axioms
]
rag_oracle_axioms

In [None]:
# Preferences of every oracle axiom on the dataset it was configured for: one
# (oracle name, data run type, stacked per-query preferences) tuple per oracle.
# Both datasets are loaded once, before any oracle runs.
rag_oracle_prefs = [
    (
        oracle_name,
        run_type,
        stack([
            oracle_axiom.preferences(query_input, query_outputs)
            for query_input, query_outputs in tqdm(dataset, desc=oracle_name, unit="query")
        ]),
    )
    for dataset, run_type in (
        (to_inputs_outputs(read_rag_assignments_outputs_df()), "assignments"),
        (to_inputs_outputs(read_rag_crowd_outputs_df()), "crowd"),
    )
    for oracle_name, oracle_axiom, oracle_run_type in rag_oracle_axioms
    if oracle_run_type == run_type
]

## Adapted retrieval axiom preferences

In [11]:
# Retrieval axioms adapted to generated answers. Each one is wrapped via
# ``.cached(...)`` with a cache file named after the axiom.
rag_retrieval_axioms: list[tuple[str, Axiom[GenerationInput, GenerationOutput]]] = [
    (axiom_name, uncached_axiom.cached(cache_path / "axioms" / f"{axiom_name}.cache"))
    for axiom_name, uncached_axiom in (
        ("GEN-TFC1", GEN_TFC1()),
        ("GEN-LNC1", GEN_LNC1()),
        ("GEN-REG", GEN_REG()),
        ("GEN-AND", GEN_AND()),
        ("GEN-DIV", GEN_DIV()),
        ("GEN-STMC1", GEN_STMC1()),
        ("GEN-STMC2", GEN_STMC2()),
        ("GEN-PROX1", GEN_PROX1()),
        ("GEN-PROX2", GEN_PROX2()),
        ("GEN-PROX3", GEN_PROX3()),
        ("GEN-PROX4", GEN_PROX4()),
        ("GEN-PROX5", GEN_PROX5()),
        ("GEN-aSL", GEN_aSL()),
        ("GEN-TF-LNC", GEN_TF_LNC()),
    )
]

In [None]:
# Preferences of every adapted retrieval axiom on both datasets: one
# (axiom name, data run type, stacked per-query preferences) tuple per
# combination. Both datasets are loaded once, before any axiom runs.
rag_retrieval_axiom_prefs = [
    (
        axiom_name,
        run_type,
        stack([
            retrieval_axiom.preferences(query_input, query_outputs)
            for query_input, query_outputs in tqdm(dataset, desc=axiom_name, unit="query")
        ]),
    )
    for dataset, run_type in (
        (to_inputs_outputs(read_rag_assignments_outputs_df()), "assignments"),
        (to_inputs_outputs(read_rag_crowd_outputs_df()), "crowd"),
    )
    for axiom_name, retrieval_axiom in rag_retrieval_axioms
]

## New generation axiom preferences

In [46]:
key_bert = KeyBertAspectExtraction()
# Generation-specific axioms. The "-KB" variants pass the shared KeyBERT aspect
# extractor instead of the default one. Each axiom is wrapped via ``.cached(...)``
# with a cache file named after it.
rag_generation_axioms: list[tuple[str, Axiom[GenerationInput, GenerationOutput]]] = [
    (axiom_name, uncached_axiom.cached(cache_path / "axioms" / f"{axiom_name}.cache"))
    for axiom_name, uncached_axiom in (
        ("CLAR1", CLAR1()),
        ("CLAR2", CLAR2()),
        ("CLAR3", CLAR3()),
        ("CLAR4", CLAR4()),
        ("CLAR5", CLAR5()),
        ("CLAR6", CLAR6()),
        ("CLAR7", CLAR7()),
        ("CONS1", CONS1()),
        ("CONS1-KB", CONS1(aspect_extraction=key_bert)),
        ("CONS2", CONS2()),
        ("CONS2-KB", CONS2(aspect_extraction=key_bert)),
        ("CONS3", CONS3()),
        ("CONS3-KB", CONS3(aspect_extraction=key_bert)),
        ("CONS4", CONS4()),
        ("COV1", COV1()),
        ("COV1-KB", COV1(aspect_extraction=key_bert)),
        ("COV2", COV2()),
        ("COV2-KB", COV2(aspect_extraction=key_bert)),
        ("COV3", COV3()),
        ("COV3-KB", COV3(aspect_extraction=key_bert)),
        ("COV4", COV4()),
        ("COV5", COV5()),
    )
]
# TODO: Add majority vote / consensus axiom.

In [None]:
# Preferences of every generation-specific axiom on both datasets: one
# (axiom name, data run type, stacked per-query preferences) tuple per
# combination. Both datasets are loaded once, before any axiom runs.
rag_generation_axiom_prefs = [
    (
        axiom_name,
        run_type,
        stack([
            generation_axiom.preferences(query_input, query_outputs)
            for query_input, query_outputs in tqdm(dataset, desc=axiom_name, unit="query")
        ]),
    )
    for dataset, run_type in (
        (to_inputs_outputs(read_rag_assignments_outputs_df()), "assignments"),
        (to_inputs_outputs(read_rag_crowd_outputs_df()), "crowd"),
    )
    for axiom_name, generation_axiom in rag_generation_axioms
]

## Evaluation

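Here we count, for every oracle–axiom pair evaluated on the same data, how many preferences fall into each combination of oracle preference and axiom preference (-1, 0, or 1).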

In [None]:
# Cross-tabulate oracle vs. axiom preferences: for each oracle and each axiom
# evaluated on the same data run type, count how many preference entries fall
# into every (oracle sign, axiom sign) cell of the 3x3 grid.
df_data = []
for oracle_name, oracle_data_run_type, oracle_preferences in rag_oracle_prefs:
    # Only the sign of a preference matters here: -1, 0 or 1.
    oracle_signs = sign(oracle_preferences)

    for axiom_type, axiom_prefs in (
        ("adapted retrieval axiom", rag_retrieval_axiom_prefs),
        ("generation-specific axiom", rag_generation_axiom_prefs),
    ):
        for axiom_name, axiom_data_run_type, axiom_preferences in axiom_prefs:
            # Oracles and axioms are only comparable on the same dataset.
            if axiom_data_run_type != oracle_data_run_type:
                continue
            axiom_signs = sign(axiom_preferences)

            df_data.extend(
                {
                    "data_run_type": oracle_data_run_type,
                    "oracle_name": oracle_name,
                    "axiom_name": axiom_name,
                    "axiom_type": axiom_type,
                    "oracle_preference": oracle_preference,
                    "axiom_preference": axiom_preference,
                    "count": (
                        (oracle_signs == oracle_preference)
                        & (axiom_signs == axiom_preference)
                    ).sum(),
                }
                for oracle_preference in (-1, 0, 1)
                for axiom_preference in (-1, 0, 1)
            )
df_distribution = DataFrame(df_data)
df_distribution

In [None]:
# Number of preference entries per (dataset, oracle, axiom), i.e. the sum of
# the nine (oracle sign, axiom sign) cells of the distribution.
total_counts = df_distribution.groupby(
    ["data_run_type", "oracle_name", "axiom_name"]
).agg(total_count=("count", "sum"))
total_counts

## Axiom decisiveness

In [None]:
# Decisiveness: the fraction of preference entries on which an axiom states a
# non-zero preference. Only the axiom side is filtered, so the counts are
# expected to match for every oracle on the same dataset; the oracle dimension
# is therefore collapsed to a single row per (dataset, axiom).
df_decisiveness = (
    df_distribution[df_distribution["axiom_preference"] != 0]
    .groupby(["data_run_type", "oracle_name", "axiom_name", "axiom_type"])[["count"]]
    .sum()
    .reset_index()
    .rename(columns={"count": "non_zero_count"})
    .merge(total_counts, on=["data_run_type", "oracle_name", "axiom_name"])
    # Collapse the oracle dimension. ``drop_duplicates`` keeps rows in a
    # deterministic order; grouping by ``set(columns)`` ordered the rows by a
    # hash-seed-dependent key order.
    .drop(columns="oracle_name")
    .drop_duplicates()
    .reset_index(drop=True)
)
# If the counts differed between oracles (e.g. an oracle whose preferences have
# entries outside {-1, 0, 1}, such as NaN), several rows per axiom would remain
# and the merge in the consistency step would silently duplicate rows.
if df_decisiveness.duplicated(["data_run_type", "axiom_name"]).any():
    raise ValueError(
        "Axiom preference counts differ between oracles; "
        "decisiveness is not oracle-independent."
    )

df_decisiveness["zero_count"] = (
    df_decisiveness["total_count"] - df_decisiveness["non_zero_count"]
)
df_decisiveness["decisiveness"] = (
    df_decisiveness["non_zero_count"] / df_decisiveness["total_count"]
)

df_decisiveness = df_decisiveness[
    ["data_run_type", "axiom_name", "axiom_type", "zero_count", "non_zero_count", "decisiveness"]
]

df_decisiveness.sort_values(
    ["data_run_type", "decisiveness"],
    ascending=[True, False],
)

## Axiom consistency

In [None]:
# Consistency: among an axiom's decisive (non-zero) preferences, the share that
# does not contradict the oracle.
df_consistency = (
    df_distribution
    .loc[lambda frame: frame["axiom_preference"] != 0]
    # Active definition: an axiom preference of +1/-1 is consistent unless the
    # oracle prefers the opposite output, so an oracle tie (0) also counts.
    .loc[lambda frame: (frame["axiom_preference"] - frame["oracle_preference"]).abs() <= 1]
    # Stricter alternative: the oracle must state the same preference (ties count as inconsistent).
    # .loc[lambda frame: frame["axiom_preference"] == frame["oracle_preference"]]
    .groupby(["data_run_type", "oracle_name", "axiom_name"])[["count"]]
    .sum()
    .reset_index()
    .rename(columns={"count": "consistent_count"})
    .merge(df_decisiveness, on=["data_run_type", "axiom_name"])
)
# An axiom with no decisive preferences at all (0 / 0) counts as fully consistent.
df_consistency["consistency"] = (
    df_consistency["consistent_count"] / df_consistency["non_zero_count"]
).fillna(1)

df_consistency.sort_values(
    by=["data_run_type", "oracle_name", "consistency"],
    ascending=[True, True, False],
)

## Combined axiom effectiveness

In [None]:
# Effectiveness: harmonic mean of decisiveness and consistency (like an F1
# score), so an axiom must be both decisive and consistent to score highly.
df_effectiveness = df_consistency.assign(
    effectiveness=lambda frame: 2 * (frame["decisiveness"] * frame["consistency"])
    / (frame["decisiveness"] + frame["consistency"]),
)

df_effectiveness.sort_values(
    by=["data_run_type", "oracle_name", "effectiveness"],
    ascending=[True, True, False],
)

In [None]:
# Mean scores per oracle (averaged over all axioms), highest mean effectiveness
# first within each data run type.
(
    df_effectiveness
    .groupby(["data_run_type", "oracle_name"])[["decisiveness", "consistency", "effectiveness"]]
    .mean()
    .sort_values(by=["data_run_type", "effectiveness"], ascending=[True, False])
)

In [None]:
# Mean scores per axiom (averaged over all oracles), grouped by axiom type and
# ordered by descending mean effectiveness within each data run type.
(
    df_effectiveness
    .groupby(["data_run_type", "axiom_name", "axiom_type"])[["decisiveness", "consistency", "effectiveness"]]
    .mean()
    .sort_values(
        by=["data_run_type", "axiom_type", "effectiveness"],
        ascending=[True, True, False],
    )
)

## Tables for axiom effectiveness

In [55]:
# LaTeX column labels for the effectiveness table, keyed by oracle name. The key
# order also fixes the order of the oracle columns (used as categorical order).
# Oracles not listed here end up as NaN categories and are left out of the table.
oracle_display_names = {
    "ORACLE-NUGGET-ALL-0.1": r"$A$",
    "ORACLE-NUGGET-ALL-STRICT-0.1": r"$A_\text{strict}$",
    "ORACLE-NUGGET-VITAL-0.1": r"$V$",
    "ORACLE-NUGGET-VITAL-STRICT-0.1": r"$V_\text{strict}$",
    "ORACLE-NUGGET-WEIGHTED-0.1": r"$W$",
    "ORACLE-NUGGET-WEIGHTED-STRICT-0.1": r"$W_\text{strict}$",
    "ORACLE-CROWD-OVERALL-0.1": r"Overall",
    "ORACLE-CROWD-COHERENCE-0.1": r"Coh.",
    "ORACLE-CROWD-CONSISTENCY-0.1": r"Cons.",
    "ORACLE-CROWD-CORRECTNESS-0.1": r"Corr.",
    "ORACLE-CROWD-COVERAGE-0.1": r"Cov.",
}

In [56]:
# LaTeX group headers for each data run type; the key order fixes the order of
# the column groups in the effectiveness table.
data_run_type_display_names = {
    "assignments": r"TREC 2024 RAG",
    "crowd": r"Crowd-sourced",
}

In [57]:
# Row order of the effectiveness table (used as categorical order for the
# axiom names after the "GEN-" prefix is stripped). Names with no configured
# axiom (e.g. TFC3, LB1, ArgUC) only produce empty groups, which the table code
# skips.
axiom_names = [
    "TFC1",
    "TFC3",
    "M-TDC",
    "LNC1",
    "TF-LNC",
    "LB1",
    "REG",
    "AND",
    "DIV",
    "STMC1",
    "STMC2",
    "PROX1",
    "PROX2",
    "PROX3",
    "PROX4",
    "PROX5",
    "ArgUC",
    "QTArg",
    "QTPArg",
    "aSL",
    "CLAR1",
    "CLAR2",
    "CLAR3",
    "CLAR4",
    "CLAR5",
    "CLAR6",
    "CLAR7",
    "CONS1",
    "CONS2",
    "CONS3",
    "CONS4",
    "COV1",
    "COV2",
    "COV3",
    "COV4",
    "COV5",
]

In [58]:
# Maps axiom names (after stripping the "GEN-" prefix) to the names shown in the
# table. Where both a default and a KeyBERT ("-KB") variant are configured, only
# the KeyBERT variant is reported: the default maps to NaN, which drops it from
# the table, and the variant takes over the plain name.
# COV4 and COV5 are configured without a "-KB" variant, so they must not be
# mapped to NaN here; doing so silently removed them from the table.
axiom_name_renaming = {
    "CONS1": nan,
    "CONS1-KB": "CONS1",
    "CONS2": nan,
    "CONS2-KB": "CONS2",
    "CONS3": nan,
    "CONS3-KB": "CONS3",
    "COV1": nan,
    "COV1-KB": "COV1",
    "COV2": nan,
    "COV2-KB": "COV2",
    "COV3": nan,
    "COV3-KB": "COV3",
}

In [None]:
# Print a booktabs LaTeX table of axiom decisiveness and consistency to stdout.
# Layout: one row per axiom, grouped (with \midrule) by axiom type. For each data
# run type there is one "Dec." column followed by one consistency column per
# oracle. Column and row order come from the categorical orders defined by
# ``data_run_type_display_names``, ``oracle_display_names`` and ``axiom_names``.
# NOTE(review): ``table_data_run_type`` is never read below; the table covers
# every data run type present in ``df_effectiveness``.
table_data_run_type = "assignments"
df_table = df_effectiveness.copy()


# Strip the "GEN-" prefix and apply ``axiom_name_renaming``. Names mapped to NaN,
# or missing from ``axiom_names``, become NaN categories and are skipped by the
# groupby calls below.
df_table["axiom_name"] = Categorical(df_table['axiom_name'].str.removeprefix("GEN-").replace(axiom_name_renaming), axiom_names)

df_table["oracle_name"] = Categorical(df_table["oracle_name"], list(oracle_display_names.keys()))
df_table["data_run_type"] = Categorical(df_table["data_run_type"], list(data_run_type_display_names.keys()))

# Column spec: "l" for the axiom name, then per data run type one "r" for
# decisiveness plus one "r" per oracle. Grouping by categoricals may also yield
# empty groups for unused categories (pandas' observed=False default), hence the
# ``len(...) == 0`` guards throughout.
columns = ["l"]
for _, df_data_run_type in df_table.groupby("data_run_type"):
    if len(df_data_run_type) == 0:
        continue
    columns += ["r"]
    for _, df_oracle in df_data_run_type.groupby("oracle_name"):
        if len(df_oracle) == 0:
            continue
        columns += ["r"]
print(r"\begin{tabular}{@{}" + "".join(columns) + r"@{}}")
print(r"  \toprule")

# Header row 1: data run type names, each spanning its decisiveness column plus
# all of its oracle columns.
columns = [r"\textbf{Axiom}"]
for data_run_type, df_data_run_type in df_table.groupby("data_run_type"):
    if len(df_data_run_type) == 0:
        continue
    num_columns = 1
    for _, df_oracle in df_data_run_type.groupby("oracle_name"):
        if len(df_oracle) == 0:
            continue
        num_columns += 1
    columns += [r"\multicolumn{" + f"{num_columns}" + r"}{c}{\textbf{" + data_run_type_display_names[data_run_type] + r"}}"]
print(r"  " + r" & ".join(columns).strip() + r" \\")
# Rules under header row 1. Table columns are 1-based and column 1 holds the
# axiom names, so the first group starts at column 2.
columns = [""]
i = 2
for _, df_data_run_type in df_table.groupby("data_run_type"):
    if len(df_data_run_type) == 0:
        continue
    start = i
    i += 1
    for _, df_oracle in df_data_run_type.groupby("oracle_name"):
        if len(df_oracle) == 0:
            continue
        i += 1
    columns += [r"\cmidrule(lr){" + f"{start}" + r"-" + f"{i-1}" + r"}"]
print(r"  " + r"".join(columns))

# Header row 2: "Dec." over the decisiveness column and "Consistency" spanning
# the oracle columns of each data run type.
columns = [""]
for _, df_data_run_type in df_table.groupby("data_run_type"):
    if len(df_data_run_type) == 0:
        continue
    columns += [r"\textbf{Dec.}"]
    num_columns = 0
    for _, df_oracle in df_data_run_type.groupby("oracle_name"):
        if len(df_oracle) == 0:
            continue
        num_columns += 1
    columns += [r"\multicolumn{" + f"{num_columns}" + r"}{c}{\textbf{Consistency}}"]
print(r"  " + r" & ".join(columns).strip() + r" \\")
# Rules under header row 2: one under each "Dec." column and one under each
# "Consistency" span.
columns = [""]
i = 2
for _, df_data_run_type in df_table.groupby("data_run_type"):
    if len(df_data_run_type) == 0:
        continue
    columns += [r"\cmidrule(lr){" + f"{i}" + r"-" + f"{i}" + r"}"]
    i += 1
    start = i
    for _, df_oracle in df_data_run_type.groupby("oracle_name"):
        if len(df_oracle) == 0:
            continue
        i += 1
    columns += [r"\cmidrule(lr){" + f"{start}" + r"-" + f"{i-1}" + r"}"]
print(r"  " + r"".join(columns))

# Header row 3: one display label per oracle column (blank above "Dec.").
columns = [""]
for _, df_data_run_type in df_table.groupby("data_run_type"):
    if len(df_data_run_type) == 0:
        continue
    columns += [""]
    for oracle_name, df_oracle in df_data_run_type.groupby("oracle_name"):
        if len(df_oracle) == 0:
            continue
        columns += [r"\multicolumn{1}{c}{" + oracle_display_names[oracle_name] + r"}"]
print(r"  " + r" & ".join(columns).strip() + r" \\")
# Body: one block per axiom type, one row per axiom.
# NOTE(review): cells are appended only for non-empty groups, so an axiom that
# is missing for some data run type or oracle would yield a short row and shift
# later columns rather than leaving a blank cell.
for axiom_type, df_axiom_type in df_table.groupby("axiom_type"):
    print(r"  \midrule")
    for axiom_name, df_axiom in df_axiom_type.groupby("axiom_name"):
        if len(df_axiom) <= 0:
            continue
        columns = [axiom_name.removeprefix("GEN-")]
        for data_run_type, df_data_run_type in df_axiom.groupby("data_run_type"):
            if len(df_data_run_type) == 0:
                continue
            # Decisiveness must not depend on the oracle; otherwise abort
            # (raised without a message).
            if len(df_data_run_type["decisiveness"].unique()) != 1:
                raise ValueError()
            decisiveness = df_data_run_type.iloc[0]["decisiveness"]
            # decisiveness_prefix = "" if decisiveness > 0 else r"\color{gray} "
            # decisiveness_prefix = ""
            # Text colour mixes black and gray by decisiveness^(1/4): fully
            # decisive axioms print black, rarely decisive ones fade to gray.
            # The same colour is applied to this axiom's consistency cells.
            decisiveness_prefix = r"\color{black!" + f"{(pow(decisiveness, 1/4)*100):.0f}" + r"!gray} "
            columns += [
                decisiveness_prefix + f"{decisiveness:0.0%}".replace("%", r"\%"),
            ]
            for _, df_oracle in df_data_run_type.groupby("oracle_name"):
                if len(df_oracle) == 0:
                    continue
                # Exactly one row per (axiom, data run type, oracle) expected.
                if len(df_oracle) != 1:
                    raise ValueError()
                row = df_oracle.iloc[0]
                columns += [
                    decisiveness_prefix + f"{row["consistency"]:0.0%}".replace("%", r"\%"),
                ]
        print(r"  " + r" & ".join(columns).strip() + r" \\")
print(r"  \bottomrule")
print(r"\end{tabular}")
df_table

## Visualizations

In [60]:
def draw_heatmap(
    data: DataFrame,
    x: str,
    y: str,
    count="count",
    **kwargs,
) -> None:
    """Draw a seaborn heatmap of the ``count`` column, ``y`` on rows and ``x`` on columns.

    Written as a ``FacetGrid.map_dataframe`` callback (see the commented-out
    plotting cell): ``data`` is the facet's sub-frame, and every extra keyword
    argument is forwarded unchanged to ``heatmap``.
    """
    heatmap(data=data.pivot(index=y, columns=x, values=count), **kwargs)

In [None]:
# Working copy of the 3x3 preference-count distribution, meant as input for the
# (currently commented-out) heatmap cells below.
# NOTE(review): the commented-out filters use oracle names such as
# "TREC-ORACLE-ALL", which none of the oracles defined above produce (their
# names start with "ORACLE-NUGGET-" / "ORACLE-CROWD-"); update before re-enabling.
df_plot = df_distribution.copy()
# df_plot = df_plot[df_plot["oracle_name"] == "TREC-ORACLE-ALL"]
# df_plot = df_plot[df_plot["oracle_name"] == "TREC-ORACLE-VITAL"]
# df_plot = df_plot[df_plot["oracle_preference"] != 0]
# df_plot = df_plot[df_plot["axiom_preference"] != 0]
# df_plot = df_plot[df_plot["axiom_preference"] != df_plot["oracle_preference"]]
df_plot

In [62]:
# plot = FacetGrid(
#     data=df_plot,
#     # col="oracle_threshold",
#     col="oracle_name",
#     row="axiom_name",
#     margin_titles=True,
# )
# plot.map_dataframe(
#     draw_heatmap,
#     x="oracle_preference",
#     y="axiom_preference",
#     count="count",
#     # vmin=0,
#     # vmax=1,
#     # norm=LogNorm(
#     #     vmin=0,
#     #     vmax=1,
#     # ),
#     norm=LogNorm(
#         vmin=0.1,
#         vmax=45 * 45 * 21,
#         clip=True,
#     ),
#     square=True,
#     cmap="rocket",
# )
# plot.set_titles(
#     row_template="{row_name}",
#     col_template="{col_name}",
# )
# plot

In [63]:
# df_plot = df_consistency.copy()
# df_plot = df_plot[
#     ~df_plot["non_zero"]
#     & ~df_plot["strict"]
#     & (df_plot["oracle_name"] == "TREC-ORACLE-ALL")
#     # & (df_plot["oracle_name"] == "TREC-ORACLE-VITAL")
# ]
# df_plot["decisiveness"] = df_plot["axiom_proportion_non_zero"] * df_plot["mean_consistency"]
# # df_plot["decisiveness"] = df_plot["mean_consistency"]
# df_plot = df_plot.pivot(
#     index="axiom_name",
#     columns="oracle_threshold",
#     values="decisiveness",
# )
# df_plot = df_plot.fillna(0)
# plot = heatmap(
#     data=df_plot,
#     cmap="rocket",
#     vmin=0,
#     vmax=1,
# )
# # plot.set_titles(template="{col_name}")
# plot

In [64]:
# df_plot = df_consistency.copy()
# df_plot = df_plot[
#     df_plot["non_zero"]
#     & ~df_plot["strict"]
#     & (df_plot["oracle_name"] == "TREC-ORACLE-ALL")
# ]
# df_plot = df_plot.pivot(
#     index="axiom_name",
#     columns="oracle_threshold",
#     values="axiom_proportion_non_zero",
# )
# df_plot = 1 - df_plot
# plot = heatmap(
#     data=df_plot,
#     cmap="rocket",
#     vmin=0,
#     vmax=1,
# )
# # plot.set_titles(template="{col_name}")
# plot