In [None]:
import re

import matplotlib.pyplot as plt
import pandas as pd

In [None]:
filename = "data/rl_eval_results2.json"
data = pd.read_json(filename)

# Later cells parse the `output` and `gold` columns and aggregate on `input`;
# fail fast here if the results file has a different schema.
missing_cols = {"input", "output", "gold"} - set(data.columns)
assert not missing_cols, f"{filename} is missing columns: {sorted(missing_cols)}"
data.head()

In [None]:
def parse_instance(answer: str) -> tuple[dict[str, list[str]], str | None]:
    """Parse string answer to separate into class and spans
    Simple case:
    [Cause] This is a cause [Effect] This is an effect

    Complex case:
    [Cause] This cause 1 | This cause 2 [Effect] This effect 1 | This effect 2
    """
    # TODO (italo): Document the relation
    matches = re.findall(r"\[Cause\](.*?)\[Relation\](.*?)\[Effect\](.*?)$", answer)
    if not matches:
        return {
            "Cause": [],
            "Effect": [],
        }, "cause"
    causes, relation, effects = matches[0]
    causes = sorted(c.strip() for c in causes.split("|") if c.strip())
    effects = sorted(e.strip() for e in effects.split("|") if e.strip())
    relation = relation.strip()

    return {
        "Cause": causes,
        "Effect": effects,
    }, relation

In [None]:
# Spot-check the parser on the first model output.
parse_instance(data["output"].iloc[0])

In [None]:
def parse(row: pd.Series, col: str) -> tuple[str | None, str | None]:
    """Extract the first cause and first effect span from ``row[col]``.

    Meant for ``DataFrame.apply(parse, col=..., axis=1, result_type="expand")``,
    so ``row`` is a DataFrame row, not a string.

    Args:
        row: One DataFrame row (anything indexable by ``col``).
        col: Name of the column holding the tagged answer string.

    Returns:
        ``(cause, effect)``: the first entry of each span list produced by
        ``parse_instance``, or ``(None, None)`` when either list is empty
        (unparseable answer or no spans).

    Note: ``parse_instance`` sorts spans alphabetically, so for multi-span
    answers only the alphabetically smallest cause/effect is kept.
    """
    spans, _relation = parse_instance(row[col])
    causes, effects = spans["Cause"], spans["Effect"]
    if not causes or not effects:
        return None, None
    return causes[0], effects[0]

In [None]:
df = data.copy()
df[["pred_cause", "pred_effect"]] = df.apply(
    parse, col="output", axis=1, result_type="expand"
)
df[["gold_cause", "gold_effect"]] = df.apply(
    parse, col="gold", axis=1, result_type="expand"
)
# Drop rows whose prediction or gold answer could not be parsed into at least
# one cause and one effect. Restrict the NaN check to the parsed columns so
# missing values in unrelated columns don't silently remove rows.
parsed_cols = ["pred_cause", "pred_effect", "gold_cause", "gold_effect"]
n_before = len(df)
df = df.dropna(subset=parsed_cols)
print(f"Dropped {n_before - len(df)} of {n_before} rows with unparseable output/gold")
df.head()

In [None]:
# Number of parsed rows whose predicted effect differs from the gold effect.
int((df["pred_effect"] != df["gold_effect"]).sum())

In [None]:
def clean_str(s: str) -> str:
    """Normalise a span for fuzzy comparison.

    Lower-cases ``s`` and deletes every whitespace character (not only the
    leading/trailing ones), so ``"The  Rain"`` becomes ``"therain"``.
    """
    lowered = s.lower().strip()
    return re.sub(r"\s", "", lowered)


def symm_substr(a: str, b: str) -> bool:
    """Return True if either span, after ``clean_str``, is contained in the other.

    Because whitespace is removed, the containment test is character-level:
    ``"rain"`` counts as contained in ``"brain"``.
    """
    # The only way the longer string can sit inside the shorter one is when
    # they are equal, so checking shorter-in-longer covers both directions.
    shorter, longer = sorted((clean_str(a), clean_str(b)), key=len)
    return shorter in longer


# Flag rows where the predicted and gold causes contain one another.
df["cause_substr"] = [
    symm_substr(pred, gold) for pred, gold in zip(df["pred_cause"], df["gold_cause"])
]
df.loc[df["pred_cause"] != df["gold_cause"], "cause_substr"].value_counts()

In [None]:
# Flag rows where the predicted and gold effects contain one another.
df["effect_substr"] = df.apply(
    lambda x: symm_substr(x["pred_effect"], x["gold_effect"]), axis=1
)
# Restrict to *effect* mismatches (not cause mismatches) so the counts describe
# the effect spans being compared.
df.query("pred_effect != gold_effect")["effect_substr"].value_counts()

In [None]:
def excess_words(a: str, b: str) -> str:
    """Return the text left over when the shorter span is removed from the longer.

    Both spans are lower-cased and whitespace-normalised (runs of whitespace
    collapsed to one space, ends stripped) first, so spans that differ only in
    spacing produce an empty result.

    If ``a`` is contained in ``b``, the first occurrence of ``a`` is removed
    from ``b``; otherwise the first occurrence of ``b`` is removed from ``a``.
    If neither contains the other, ``a`` is returned (normalised) unchanged;
    callers are expected to have filtered to substring pairs beforehand.

    Args:
        a: Predicted span.
        b: Gold span.

    Returns:
        The leftover words, joined by single spaces.
    """
    a = " ".join(a.lower().split())
    b = " ".join(b.lower().split())

    # Remove a single occurrence: the repeated words in e.g. "rain causes rain"
    # vs "rain" are genuinely extra and must be counted.
    if a in b:
        leftover = b.replace(a, "", 1)
    else:
        leftover = a.replace(b, "", 1)

    return " ".join(leftover.split())


def excess_words_count(a: str, b: str) -> int:
    """Return the number of whitespace-separated words in ``excess_words(a, b)``."""
    return len(excess_words(a, b).split())


# Cause mismatches where one span contains the other: how many extra words?
df_cause = df[(df["pred_cause"] != df["gold_cause"]) & df["cause_substr"]].copy()
df_cause["cause_excess"] = [
    excess_words(pred, gold)
    for pred, gold in zip(df_cause["pred_cause"], df_cause["gold_cause"])
]
# Word count of the leftover text (same as excess_words_count, without recomputing it).
df_cause["cause_excess_count"] = df_cause["cause_excess"].str.split().str.len()
print(df_cause["cause_excess_count"].describe())
df_cause.head()

In [None]:
fig, ax = plt.subplots(figsize=(10, 6))

# Number of substring-mismatched cause pairs per excess-word count.
excess_count = df_cause["cause_excess_count"].value_counts().sort_index()
excess_count.plot(kind="bar", ax=ax, rot=0)
ax.set(
    title="Excess words between predicted and gold cause spans",
    xlabel="Number of excess words",
    ylabel="Number of examples",
)

plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(10, 6))

# Cumulative number of examples with at most k excess words.
excess_count = df_cause["cause_excess_count"].value_counts().sort_index().cumsum()
excess_count.plot(kind="bar", ax=ax, rot=0)
ax.set(
    title="Cumulative distribution of excess words in cause spans",
    xlabel="Number of excess words (at most)",
    ylabel="Cumulative number of examples",
)

# Mark the first bar at which the cumulative count reaches each fraction of
# the total. Bar plots draw bars at positions 0..n-1 regardless of the index
# labels (and value_counts omits unseen counts), so lines and labels must use
# the bar *position*, not the excess-word value.
percentiles = [0.8, 0.9, 0.95, 0.99]
# Vertical label offsets in data units (examples), staggered to limit overlap.
# NOTE(review): presumably hand-tuned for this dataset; consider
# textcoords="offset points" so they don't depend on the data scale.
heights = [100, 50, 30, 10]
total = excess_count.max()
for percentile, height in zip(percentiles, heights):
    target_percentile = percentile * total
    bar_pos = int((excess_count.to_numpy() >= target_percentile).argmax())
    ax.axvline(x=bar_pos, color="black", linestyle="--")
    ax.annotate(
        f"{percentile:.0%}",
        (bar_pos + 0.1, target_percentile + height),
        color="black",
    )

plt.show()

In [None]:
# Effect mismatches where one span contains the other: how many extra words?
df_effect = df[(df["pred_effect"] != df["gold_effect"]) & df["effect_substr"]].copy()
df_effect["effect_excess"] = [
    excess_words(pred, gold)
    for pred, gold in zip(df_effect["pred_effect"], df_effect["gold_effect"])
]
# Word count of the leftover text (same as excess_words_count, without recomputing it).
df_effect["effect_excess_count"] = df_effect["effect_excess"].str.split().str.len()
print(df_effect["effect_excess_count"].describe())
df_effect.head()

In [None]:
# Rows where the predicted cause or effect differs from the gold one.
df_diff = df[
    (df["pred_cause"] != df["gold_cause"]) | (df["pred_effect"] != df["gold_effect"])
]
df_diff.shape[0]

In [None]:
# Sanity check: an exact cause match should always count as a substring match.
df.loc[df["pred_cause"] == df["gold_cause"], "cause_substr"].value_counts()

In [None]:
# Among mismatched rows: how often is the cause pair a substring match?
df_diff.cause_substr.value_counts()

In [None]:
# Among mismatched rows: how often is the effect pair a substring match?
df_diff.effect_substr.value_counts()

In [None]:
# True when both the cause and the effect pairs are substring matches.
df_diff[["cause_substr", "effect_substr"]].all(axis=1).value_counts()

In [None]:
# Mismatched rows where both cause and effect are substring matches.
df_substr = df_diff[df_diff["cause_substr"] & df_diff["effect_substr"]].copy()
df_substr.shape[0]

In [None]:
df_substr["cause_excess"] = [
    excess_words(pred, gold)
    for pred, gold in zip(df_substr["pred_cause"], df_substr["gold_cause"])
]
# Word count of the leftover text (same as excess_words_count, without recomputing it).
df_substr["cause_excess_count"] = df_substr["cause_excess"].str.split().str.len()
df_substr["cause_excess_count"].describe()

In [None]:
df_substr["effect_excess"] = [
    excess_words(pred, gold)
    for pred, gold in zip(df_substr["pred_effect"], df_substr["gold_effect"])
]
# Word count of the leftover text (same as excess_words_count, without recomputing it).
df_substr["effect_excess_count"] = df_substr["effect_excess"].str.split().str.len()
df_substr["effect_excess_count"].describe()

In [None]:
# Preview the first five substring-mismatch rows.
df_substr.iloc[:5]

In [None]:
# NOTE(review): relative path, so this is written to the current working
# directory rather than next to the input file in data/.
df_substr.to_json(path_or_buf="rl_eval_substr.json", orient="records")

In [None]:
# Mismatched rows (df_diff) where the cause pair and/or the effect pair is NOT
# a substring match, i.e. the complement of df_substr within df_diff.
# Filtering df_diff (rather than only `pred_cause != gold_cause`) keeps rows
# whose cause matches exactly but whose effect is not a substring match.
df_nosub = df_diff[~(df_diff["cause_substr"] & df_diff["effect_substr"])].copy()
df_nosub.head()

In [None]:
# Count non-substring rows per (cause_substr, effect_substr) combination.
nosub_agg = df_nosub.groupby(["cause_substr", "effect_substr"], as_index=False)[
    "input"
].count()
nosub_agg

In [None]:
# Plain-text table for pasting into reports (to_markdown needs `tabulate`).
print(nosub_agg.to_markdown(index=False, tablefmt="simple"))

In [None]:
(len(df_nosub), len(df_nosub.columns))

In [None]:
# NOTE(review): relative path, so this is written to the current working
# directory rather than next to the input file in data/.
df_nosub.to_json(path_or_buf="rl_eval_nosubstr.json", orient="records")