In [None]:
import sys
import os

# Put the notebook's parent directory on the import path so the project's local
# packages (imported in the next cell) can be found.
parent_path = ".."
project_root = os.path.abspath(parent_path)
sys.path.append(project_root)

In [None]:
import nltk

# Make sure the NLTK "stopwords" corpus is available locally.
# NOTE(review): stopwords are not referenced in the visible cells; presumably a dependency
# of health_causenet — confirm it is still required.
nltk.download("stopwords")

import pickle5
import pandas as pd
import numpy as np
import health_causenet
from health_causenet import constants
from health_causenet.causenet import (
    CauseNet,
    contrastive_weight,
    term_domain_specificity,
    discriminative_weight,
)
import quickumls
from tqdm import tqdm
import os
import pathlib
import json
from IPython.display import clear_output
from health_bert import health_bert


# Treebank-style word tokenizer instance for the notebook.
# NOTE(review): `tokenizer` is not used in the visible cells — confirm it is still needed.
tokenizer = nltk.tokenize.TreebankWordTokenizer()

In [None]:
def pickle_5_to_4(path):
    """Rewrite a pickled pandas object in place using pickle protocol 4.

    The file is read with the ``pickle5`` module (presumably the protocol-5
    backport for interpreters older than Python 3.8) and written back through
    pandas with an explicit ``protocol=4``.

    Parameters
    ----------
    path : str or path-like
        Pickle file to convert. It is overwritten. Only use this on trusted
        files: unpickling can execute arbitrary code.
    """
    with open(path, "rb") as fh:
        data = pickle5.load(fh)
    # pandas defaults to pickle.HIGHEST_PROTOCOL, which is 5 on Python >= 3.8;
    # pin protocol 4 so the output really is readable by older interpreters.
    data.to_pickle(path, protocol=4)

In [None]:
# Convert the cached predictions pickle with pickle_5_to_4 (read via pickle5, re-saved by pandas).
# NOTE(review): this touches the file under CEPH_PATH, while later cells read
# test_causenet_predictions.pkl from TEST_CAUSENET_PATH — confirm both point to the same place.
legacy_predictions_pickle = os.path.join(constants.CEPH_PATH, "test_causenet_predictions.pkl")
pickle_5_to_4(legacy_predictions_pickle)

In [None]:
# Create test causenet: three samples of CauseNet relations ("support", "random_support",
# "random_full") drawn from the full parquet dump.
TEST_SAMPLE_SIZE = 1000  # relations per sampled dataset
SAMPLE_SEED = 42  # random_state for the random samples, for reproducibility

paths = sorted(pathlib.Path(constants.CAUSENET_PARQUET_PATH).glob("causenet_*.parquet"))
pg = tqdm(paths)
# Read every shard first and concatenate once. Growing the frame inside the loop
# re-copies all previously read rows on each iteration (quadratic), and seeding it
# with an empty DataFrame relies on deprecated concat dtype handling.
full_causenet = pd.concat(
    [
        pd.read_parquet(
            path, columns=["cause", "effect", "support", "reference", "sentence"]
        )
        for path in pg
    ],
    ignore_index=True,
)

print("computing counts...")
# One row per distinct (cause, effect, support) with the number of source rows as "count".
causenet = (
    full_causenet.groupby(["cause", "effect", "support"])
    .size()
    .rename("count")
    .reset_index()
)

print("sorting by support...")
# Highest-support relations (ties broken by cause/effect, all descending).
test_causenet_support = (
    causenet.sort_values(["support", "cause", "effect"], ascending=False)
    .reset_index(drop=True)
    .iloc[:TEST_SAMPLE_SIZE]
    .assign(dataset="support")
)
# Random relations seen at least twice.
test_causenet_random_high = (
    causenet.loc[causenet.support >= 2]
    .sample(TEST_SAMPLE_SIZE, random_state=SAMPLE_SEED)
    .assign(dataset="random_support")
)
# Random relations seen exactly once.
test_causenet_random_low = (
    causenet.loc[causenet.support == 1]
    .sample(TEST_SAMPLE_SIZE, random_state=SAMPLE_SEED)
    .assign(dataset="random_full")
)
test_causenet = pd.concat(
    [test_causenet_support, test_causenet_random_high, test_causenet_random_low],
    ignore_index=True,
)

# print("adding wikidata...")
# test_wikidata = (
#     pd.read_csv(os.path.join(constants.WIKIDATA_PATH, "wikidata-test.csv"), index_col=0)
#     .drop_duplicates()
#     .dropna(subset=["cause", "effect"])
#     .reset_index(drop=True)
# )
# test_wikidata["dataset"] = "wikidata"
# test_causenet = test_causenet.append(
#     test_wikidata.loc[
#         :, ["cause", "effect", "dataset", "cause_origin", "effect_origin"]
#     ]
# ).reset_index(drop=True)

# Samples are drawn from unique (cause, effect, support) rows, so a relation can still
# repeat within a dataset with a different support value; keep the first row per
# (relation, dataset).
test_causenet = test_causenet.drop_duplicates(subset=["cause", "effect", "dataset"])
# Load and label evaluations

ignore_origins = [
    #     "wd:Q39833",  # microorganism
    #     "wd:Q178694",  # heredity
    #     "wd:Q289472",  # biogenic substance
    #     "wd:Q796194",  # medical procedure
    #     "wd:Q2826767",  # disease causative agent
    #     "wd:Q2996394",  # biological process
    #     "wd:Q5850078",  # etiology
    #     "wd:Q7189713",  # physiological condition
    #     "wd:Q15788410",  # state of consciousness
    #     "wd:Q86746756",  # medicinal product
    #     "wd:Q87075524",  # health risk
]


def _manual_evaluation_frame(evaluation_dict):
    """Turn the ``{"cause->effect": label}`` mapping into a cause/effect/evaluation frame.

    The three columns are always present, so an empty mapping still yields a frame
    that can be indexed by cause and effect.
    """
    records = []
    for key, value in evaluation_dict.items():
        cause, effect = key.split("->")
        records.append({"cause": cause, "effect": effect, "evaluation": value})
    return pd.DataFrame(records, columns=["cause", "effect", "evaluation"])


def _align_manual_evaluation(manual, relations):
    """Reindex ``manual`` onto the (cause, effect) pairs of ``relations``, in row order.

    Pairs without a label get NaN in the ``evaluation`` column.
    """
    return manual.set_index(["cause", "effect"]).reindex(
        relations.set_index(["cause", "effect"]).index
    )


def _label_interactively(relations, evaluation_dict):
    """Prompt for a 0/1 label for each ``(cause, effect)`` pair in ``relations``.

    Labels are stored in ``evaluation_dict`` under ``"cause->effect"``. Typing ``c``
    stops labeling. Any input other than 0 or 1 is rejected and asked again. A pair
    that was already labeled earlier in this loop is skipped; this happens when the
    same relation occurs in several datasets.
    """
    total = len(relations)
    for idx, (cause, effect) in enumerate(relations):
        key = f"{cause}->{effect}"
        if evaluation_dict.get(key) is not None:
            continue
        while True:
            clear_output(wait=True)
            inp = input(
                f"{idx+1}/{total} ({total - idx}) [{cause}] -> [{effect}]"
            )
            if inp == "c":
                return
            try:
                val = int(inp)
            except ValueError:
                val = None
            if val in (0, 1):
                evaluation_dict[key] = val
                break
            print(f"invalid input: {inp}, needs to be either 1 or 0")


with open(constants.MANUAL_EVALUATION_PATH, "r") as file:
    manual_eval_dict = json.load(file)
test_causenet_manual = test_causenet.loc[
    test_causenet.dataset.isin(["support", "random_support", "random_full", "count"])
]
manual_eval = _manual_evaluation_frame(manual_eval_dict)
evaluation = _align_manual_evaluation(manual_eval, test_causenet_manual)

# 1-D mask over the rows that still lack a label; only cause/effect are needed for prompting.
to_label = test_causenet_manual.loc[
    evaluation["evaluation"].isna().to_numpy(), ["cause", "effect"]
].to_numpy()
_label_interactively(to_label, manual_eval_dict)

with open(constants.MANUAL_EVALUATION_PATH, "w") as file:
    json.dump(manual_eval_dict, file, indent=4)

manual_eval = _manual_evaluation_frame(manual_eval_dict)
evaluation = _align_manual_evaluation(manual_eval, test_causenet_manual)

print(f"eval missing for {evaluation['evaluation'].isna().sum()} relations")

evaluation = evaluation["evaluation"]
# evaluation = evaluation.append(test_wikidata.set_index(["cause", "effect"]).evaluation)
# The labels are assigned positionally, so the manual subset must cover every test row.
assert len(evaluation) == len(test_causenet), "manual evaluation must cover every test relation"
test_causenet["evaluation"] = evaluation.values

print("parsing sentence test causenet")
# (cause, effect, support) keys of the sampled relations that carry a support value.
relation_keys = test_causenet.loc[
    test_causenet.support.notna(), ["cause", "effect", "support"]
].values.tolist()
# Every distinct source sentence for those keys.
sentences = (
    full_causenet.set_index(["cause", "effect", "support"])
    .loc[relation_keys, "sentence"]
    .reset_index()
    .drop_duplicates(["cause", "effect", "sentence"])
)
# Within a dataset a sentence is kept once, then one sentence per relation.
sentence_test_causenet = (
    test_causenet.merge(sentences, on=["cause", "effect", "support"])
    .drop_duplicates(["dataset", "sentence"])
    .drop_duplicates(["cause", "effect", "dataset"])
)
sentence_test_causenet_evaluation = pd.read_csv(
    os.path.join(constants.BASE_PATH, "sentence_test_causenet_evaluations.csv"),
    index_col=0,
).rename(columns={"label": "manual_evaluation"})
sentence_test_causenet = sentence_test_causenet.merge(
    sentence_test_causenet_evaluation,
    how="left",
    on=["cause", "effect", "sentence"],
)

# The full dump and its counts are no longer needed; free the memory.
del causenet, full_causenet

# print("creating ctakes and metamap data")
# relations = pd.Series(
#     pd.unique(test_causenet.loc[:, ["cause", "effect"]].values.ravel())
# )
# relations = relations.loc[~relations.isin(["", " "])]
# sentence_relations = pd.Series(
#     pd.unique(sentence_test_causenet.loc[:, ["sentence"]].values.ravel())
# )
# relations = pd.concat([relations, sentence_relations]).reset_index(drop=True)
# ctakes_path = pathlib.Path(constants.TEST_CAUSENET_PATH).joinpath("ctakes")
# ctakes_path.mkdir(exist_ok=True)
# metamap_path = pathlib.Path(constants.TEST_CAUSENET_PATH).joinpath("metamap")
# metamap_path.mkdir(exist_ok=True)
# for idx, relation in enumerate(relations.values):
#     with ctakes_path.joinpath(f"relation_{idx + 1}.txt").open("w") as file:
#         file.write(relation)
# with metamap_path.joinpath("relations.txt").open("w") as file:
#     file.write("\n".join(f"{idx + 1}|{relation}" for idx, relation in enumerate(relations.values)))

# Persist both test sets so later cells can reload them without rebuilding.
test_causenet_outputs = {
    "test_causenet.pkl": test_causenet,
    "sentence_test_causenet.pkl": sentence_test_causenet,
}
for output_name, output_frame in test_causenet_outputs.items():
    output_frame.to_pickle(os.path.join(constants.TEST_CAUSENET_PATH, output_name))
test_causenet

In [None]:
# load test_causenet (and its sentence-level counterpart) from the pickles written above
test_causenet, sentence_test_causenet = (
    pd.read_pickle(os.path.join(constants.TEST_CAUSENET_PATH, pickle_name))
    for pickle_name in ("test_causenet.pkl", "sentence_test_causenet.pkl")
)
test_causenet

In [None]:
# load cf and compute termhood scores
# medical_termhood[corpus][score_name] holds the termhood scores of one medical corpus
# measured against the open-domain corpus frequencies.
cf = pd.read_parquet(os.path.join(constants.CF_PATH, "cf.parquet"))
medical_termhood = {}
# Every corpus-frequency column except the open-domain reference is a medical corpus.
for column in list(cf.filter(regex=r".*_frequency_(?!open_domain)")):
    corpus = column.replace("corpus_frequency_", "")
    # open-domain counts, this corpus' counts, and num_terms (presumably n-gram length)
    _cf = cf.loc[:, ["corpus_frequency_open_domain", column, "num_terms"]]
    scores = {}
    print(f"{corpus}: computing term domain specificity...")
    scores["term_domain_specificity"] = term_domain_specificity(_cf, np.e)
    print(f"{corpus}: computing contrastive weight...")
    scores["contrastive_weight"] = contrastive_weight(_cf, np.e, 1)
    print(f"{corpus}: computing discriminative weight...")
    # discriminative weight = contrastive weight * term domain specificity
    scores["discriminative_weight"] = (
        scores["contrastive_weight"] * scores["term_domain_specificity"]
    )
    medical_termhood[corpus] = scores
del cf

In [None]:
# speed test
import time
import datetime


def _timing_slots(metamap, ctakes):
    """Empty timing lists for the methods timed below, plus fixed metamap/ctakes timings."""
    return {
        "termhood": {1: [], 2: [], 3: []},
        "bert": {" ": []},
        "quickumls": {"full": [], "rx_sno": []},
        "scispacy": {"full": [], "rx_sno": []},
        "metamap": metamap,
        "ctakes": ctakes,
    }


# times[data][method][setting] -> list of wall-clock seconds.
# metamap and ctakes are not timed in this notebook, so their seconds are entered by hand.
# NOTE(review): the provenance of these numbers is not recorded, and ctakes "full" = 0
# looks like a placeholder (it produces inf/nan speed ratios in the summary table).
times = {
    "sentence": _timing_slots(
        metamap={"full": [160.638], "rx_sno": [120.277]},
        ctakes={"full": [0], "rx_sno": [212.711]},
    ),
    "phrase": _timing_slots(
        metamap={"full": [70.639], "rx_sno": [49.636]},
        ctakes={"full": [0], "rx_sno": [119.684]},
    ),
}

n = 10  # repetitions per timed method

model = health_bert.HealthBert.load_from_checkpoint(
    constants.BASE_PATH + "models/health_bert/pubmedbert_pubmed_sentence.ckpt"
)

def _time_call(func, frame, *args, **kwargs):
    """Return the wall-clock seconds that ``func(frame_copy, *args, **kwargs)`` takes.

    ``frame`` is copied before the clock starts. Every call therefore gets pristine
    input, and neither the copy nor any input preparation is part of the measurement.
    """
    frame = frame.copy()
    start = time.perf_counter()
    func(frame, *args, **kwargs)
    return time.perf_counter() - start


# Inputs are built once, outside the timed region. Phrase data are the "random_full"
# relations. Sentence data are their sentences, scored as the cause text with an
# empty effect.
timing_inputs = {
    "phrase": test_causenet.loc[test_causenet.dataset == "random_full"],
    "sentence": sentence_test_causenet.loc[
        sentence_test_causenet.dataset == "random_full"
    ].assign(cause=sentence_test_causenet.sentence, effect=""),
}

for _ in tqdm(range(n), total=n):
    # termhood: n-gram ranges (1, 1) .. (1, 3) with the encyclopedia discriminative weight
    for n_gram in (1, 2, 3):
        for data in ("phrase", "sentence"):
            times[data]["termhood"][n_gram].append(
                _time_call(
                    health_causenet.causenet._contrastive_score,
                    timing_inputs[data],
                    medical_termhood["encyclopedia"]["discriminative_weight"],
                    # `r` is the keyword every other _contrastive_score call site uses
                    r=1,
                    n_gram_size=(1, n_gram),
                    verbose=False,
                )
            )

    # bert
    for data in ("phrase", "sentence"):
        times[data]["bert"][" "].append(
            _time_call(
                health_causenet.causenet._health_bert,
                timing_inputs[data],
                model,
                verbose=False,
                batch_size=1,
            )
        )

    # quickumls
    for data in ("phrase", "sentence"):
        for umls_subset in ("full", "rx_sno"):
            times[data]["quickumls"][umls_subset].append(
                _time_call(
                    CauseNet.is_medical,
                    timing_inputs[data],
                    "quickumls",
                    jaccard_threshold=0.9,
                    umls_subset=umls_subset,
                    st21pv=False,
                    verbose=False,
                )
            )

    # scispacy
    for data in ("phrase", "sentence"):
        for umls_subset in ("full", "rx_sno"):
            times[data]["scispacy"][umls_subset].append(
                _time_call(
                    CauseNet.is_medical,
                    timing_inputs[data],
                    "scispacy",
                    threshold=0.9,
                    umls_subset=umls_subset,
                    model="en_core_sci_sm",
                    st21pv=False,
                    verbose=False,
                )
            )


def create_index(index, dictionary, index_terms):
    """Collect the key path to every list leaf of a nested dict.

    Parameters
    ----------
    index : list
        Output list; one tuple ``(*index_terms, key, ..., leaf_key)`` is appended
        per list-valued leaf, in depth-first insertion order.
    dictionary : dict
        Nested dict whose leaves are lists (any non-list value is recursed into).
    index_terms : sequence
        Key prefix of ``dictionary`` within the outer structure. It is not modified.

    Returns
    -------
    tuple
        ``(index, list(index_terms[:-1]))``, which keeps the original return shape.
    """
    prefix = tuple(index_terms)
    for key, value in dictionary.items():
        path = prefix + (key,)
        if isinstance(value, list):
            index.append(path)
        else:
            create_index(index, value, list(path))
    return index, list(index_terms[:-1])

def grab_values(dictionary):
    """Flatten a nested dict of lists into the mean of each leaf list.

    Leaves are visited depth-first in insertion order, matching the paths that
    ``create_index`` produces for dict-of-lists input. An empty leaf list yields
    NaN instead of raising ZeroDivisionError.
    """
    values = []
    for value in dictionary.values():
        if isinstance(value, dict):
            values.extend(grab_values(value))
        elif len(value) == 0:
            values.append(float("nan"))
        else:
            values.append(sum(value) / len(value))
    return values

# Timing summary: mean seconds per (data, method, setting) over the n runs.
index, _ = create_index([], times, [])
values = grab_values(times)
times_df = pd.DataFrame(
    {"time": values},
    index=pd.MultiIndex.from_tuples(index, names=["data", "method", "n-gram"]),
)
# number of inputs timed per data type, used to normalise to milliseconds per item
num_samples = pd.Series(
    [
        len(test_causenet.loc[test_causenet.dataset == "random_full"]),
        len(sentence_test_causenet.loc[sentence_test_causenet.dataset == "random_full"]),
    ],
    index=pd.Index(["phrase", "sentence"], name="data"),
)
times_df["time_per_iter"] = (times_df["time"] / num_samples) * 1000
# entry (i, j) is time_i / time_j
ratios = np.divide.outer(times_df["time"].values, times_df["time"].values)
ratio_df = pd.DataFrame(ratios, index=times_df.index, columns=times_df.index)

from IPython.display import display
display(times_df)
display(ratio_df.loc["sentence", "sentence"])
display(ratio_df.loc["phrase", "phrase"])

In [None]:
# Re-show the timing table and the per-data-type speed ratio matrices.
display(
    times_df,
    ratio_df.loc["sentence", "sentence"],
    ratio_df.loc["phrase", "phrase"],
)

In [None]:
def parse_n_gram_size(text):
    """Parse an n-gram range label such as ``"(1, 3)"`` into ``(1, 3)``.

    Raises ValueError unless the text holds exactly two integers.
    """
    low, high = text.strip("()").split(", ")
    return int(low), int(high)


def parse_r_value(text):
    """Parse the ``r`` exponent of a contrastive-score label.

    Label fields are separated by ``-``, so negative values are written with a
    ``neg_`` prefix: ``"neg_2"`` -> ``-2``, ``"neg_inf"`` -> ``-inf``. Integral text
    is returned as ``int``; anything else as ``float``. Raises ValueError on text
    that is neither.
    """
    negative = text.startswith("neg_")
    if negative:
        text = text[len("neg_"):]
    try:
        value = int(text)
    except ValueError:
        value = float(text)
    return -value if negative else value


def classify_test_set(test_causenet_predictions, label):
    """Score the cause/effect texts of ``test_causenet_predictions`` with one method.

    Parameters
    ----------
    test_causenet_predictions : pd.DataFrame
        Frame with ``cause`` and ``effect`` columns. Its index is copied onto the result.
    label : list of str
        A method label split on ``"-"``. ``label[0]`` selects the scorer:
        ``health_bert``, ``metamap``/``ctakes``, ``scispacy``, ``quickumls`` or one of
        ``contrastive_scores``. The remaining items are that scorer's settings. A
        literal ``"st21pv"`` item enables the st21pv flag for the UMLS-based taggers.

    Returns
    -------
    pd.DataFrame
        The scorer's output, indexed like the input, with ``"-" + "-".join(label)``
        appended to every column name.

    Raises
    ------
    RuntimeError
        If ``label[0]`` names no known scorer.

    Notes
    -----
    Reads the notebook globals ``pg`` (a tqdm bar whose description is updated),
    ``contrastive_scores`` and ``medical_termhood``.
    """
    pg.set_description(" ".join(label))
    method = label[0]
    st21pv = "st21pv" in label
    if method == "health_bert":
        _, name, corpus, text_type = label
        model_path = constants.BASE_PATH + "models/health_bert/"
        bert_model = health_bert.HealthBert.load_from_checkpoint(
            model_path + f"{name}_{corpus}_{text_type}.ckpt"
        )
        medical_score = health_causenet.causenet._health_bert(
            test_causenet_predictions,
            bert_model,
            verbose=False,
        )
    elif method in ("metamap", "ctakes"):
        # these taggers are not run here; their (presumably precomputed) output is read from JSON lines
        json_path = os.path.join(
            constants.BASE_PATH, "tagger_jsons", f"{method}-{label[1]}.jsonl"
        )
        medical_score = CauseNet.is_medical(
            test_causenet_predictions, "tagger", json_path=json_path, st21pv=st21pv
        )
    elif method == "scispacy":
        umls_subset, spacy_model, threshold = label[1:4]
        medical_score = CauseNet.is_medical(
            test_causenet_predictions,
            "scispacy",
            umls_subset=umls_subset,
            model=spacy_model,
            threshold=float(threshold),
            verbose=False,
            st21pv=st21pv,
        )
    elif method == "quickumls":
        umls_subset, jaccard_threshold = label[1:3]
        medical_score = CauseNet.is_medical(
            test_causenet_predictions,
            "quickumls",
            jaccard_threshold=float(jaccard_threshold),
            umls_subset=umls_subset,
            st21pv=st21pv,
            verbose=False,
        )
    elif method in contrastive_scores:
        contrastive_score, corpus, n_gram_size, r_value = label
        medical_score = health_causenet.causenet._contrastive_score(
            test_causenet_predictions,
            medical_termhood[corpus][contrastive_score],
            r=parse_r_value(r_value),
            n_gram_size=parse_n_gram_size(n_gram_size),
            verbose=False,
        )
    else:
        raise RuntimeError(f"unknown label {label}")
    medical_score.index = test_causenet_predictions.index
    return medical_score.add_suffix("-" + "-".join(label))

In [None]:
# Resume from cached predictions when they exist; on the first run start from the bare
# relations. Outer-merging with test_causenet adds rows for relations not yet scored.
test_predictions_path = os.path.join(
    constants.TEST_CAUSENET_PATH, "test_causenet_predictions.pkl"
)
try:
    test_causenet_predictions = pd.read_pickle(test_predictions_path)
except FileNotFoundError:
    test_causenet_predictions = test_causenet.loc[:, ["cause", "effect", "dataset"]]
# (The pickle used to be read a second time here, unconditionally, which made the
# FileNotFoundError fallback above unreachable in practice.)
test_causenet_predictions = test_causenet_predictions.merge(
    test_causenet.drop(["support", "count", "evaluation"], axis=1),
    on=["cause", "effect", "dataset"],
    how="outer",
)

# Scorer grid. Every label is "-"-joined so classify_test_set can split it back into
# its settings; st21pv variants come before the plain ones.
jaccard_thresholds = [round(thresh, 2) for thresh in np.arange(0.7, 1.01, 0.1)]
scispacy_thresholds = [round(thresh, 2) for thresh in np.arange(0.6, 0.9, 0.1)]
scispacy_models = ["en_core_sci_sm", "en_core_sci_lg"]
r_values = [-float("inf"), -10, -5, -2, -1, 0, 1, 2, 5, 10, float("inf")]
n_gram_sizes = [(1, 1), (1, 2), (1, 3)]
medical_corpora = ["pubmed", "textbook", "pubmed_central", "encyclopedia"]
umls_subsets = ["full", "rx_sno"]
text_types = ["sentence", "noun_phrase"]
contrastive_scores = [
    "term_domain_specificity",
    "contrastive_weight",
    "discriminative_weight",
]
bert_names = ["bert", "scibert", "pubmedbert"]

st21pv_suffixes = ["-st21pv", ""]
quickumls_labels = [
    f"quickumls-{umls_subset}-{jaccard_threshold}{suffix}"
    for suffix in st21pv_suffixes
    for umls_subset in umls_subsets
    for jaccard_threshold in jaccard_thresholds
]
scispacy_labels = [
    f"scispacy-{umls_subset}-{sci_model}-{threshold}{suffix}"
    for suffix in st21pv_suffixes
    for sci_model in scispacy_models
    for umls_subset in umls_subsets
    for threshold in scispacy_thresholds
]
tagger_labels = [
    f"{tagger}-{umls_subset}{suffix}"
    for tagger in ("ctakes", "metamap")
    for suffix in st21pv_suffixes
    for umls_subset in umls_subsets
]
# "-" separates label fields, so negative r values are spelled "neg_<abs value>"
contrastive_labels = [
    f"{contrastive_score}-{medical_corpus}-{n_gram_size}-"
    + (str(r_value) if r_value >= 0 else f"neg_{-r_value}")
    for contrastive_score in contrastive_scores
    for n_gram_size in n_gram_sizes
    for r_value in r_values
    for medical_corpus in medical_corpora
]
bert_labels = [
    f"health_bert-{name}-{medical_corpus}-{text_type}"
    for name in bert_names
    for medical_corpus in ("pubmed", "encyclopedia")
    for text_type in text_types
]
labels = quickumls_labels + scispacy_labels + tagger_labels + contrastive_labels + bert_labels

df_labels = [
    f"medical_score-{relation}-{label}"
    for relation in ("cause", "effect")
    for label in labels
]
labels = [label.split("-") for label in labels]
# add an all-NaN column for every score that is not in the cached predictions yet
new_columns = [
    pd.Series(np.nan, index=test_causenet_predictions.index, name=column)
    for column in df_labels
    if column not in test_causenet_predictions
]
test_causenet_predictions = pd.concat([test_causenet_predictions, *new_columns], axis=1)

predictions_pickle = os.path.join(
    constants.TEST_CAUSENET_PATH, "test_causenet_predictions.pkl"
)

# Rows without any score (e.g. relations just added by the outer merge) get every method.
missing_rows = test_causenet_predictions[df_labels].isna().all(axis=1)
if missing_rows.any():
    print(f"parsing {missing_rows.sum()} missing rows")
    pg = tqdm(labels)
    for label in pg:
        row_scores = classify_test_set(
            test_causenet_predictions.loc[missing_rows, ["cause", "effect"]], label
        )
        label_name = "-".join(label)
        columns = [f"medical_score-{relation}-{label_name}" for relation in ("cause", "effect")]
        test_causenet_predictions.loc[missing_rows, columns] = row_scores[columns].values
        # checkpoint after every method so an interrupted run can resume
        test_causenet_predictions.to_pickle(predictions_pickle)

# Methods with any missing value are recomputed for all rows.
score_column = test_causenet_predictions.columns.str.startswith("medical_score-")
isna = test_causenet_predictions.isna().any(axis=0)
incomplete_columns = list(test_causenet_predictions.loc[:, score_column & isna])
incomplete_methods = {"-".join(column.split("-")[2:]) for column in incomplete_columns}
missing_columns = sorted(method.split("-") for method in incomplete_methods)

if missing_columns:
    print("parsing missing columns")
    for missing_column in missing_columns:
        print(" ".join(missing_column))
    pg = tqdm(missing_columns)
    for label in pg:
        column_scores = classify_test_set(
            test_causenet_predictions.loc[:, ["cause", "effect"]], label
        )
        label_name = "-".join(label)
        columns = [f"medical_score-{relation}-{label_name}" for relation in ("cause", "effect")]
        test_causenet_predictions.loc[:, columns] = column_scores[columns].values
        assert not test_causenet_predictions.loc[:, columns].isna().any().any()
        test_causenet_predictions.to_pickle(predictions_pickle)

assert not test_causenet_predictions.loc[:, score_column].isna().any().any()
test_causenet.merge(
    test_causenet_predictions, on=["cause", "effect", "dataset"], how="left"
)

In [None]:
# Resume from cached sentence-level predictions, or start from the bare sentences.
sentence_predictions_path = os.path.join(
    constants.TEST_CAUSENET_PATH, "sentence_test_causenet_predictions.pkl"
)
try:
    sentence_test_causenet_predictions = pd.read_pickle(sentence_predictions_path)
except FileNotFoundError:
    sentence_test_causenet_predictions = sentence_test_causenet.loc[
        :, ["cause", "effect", "dataset", "sentence"]
    ]

# Outer merge so sentences that are not yet scored get a row.
sentence_test_causenet_predictions = sentence_test_causenet_predictions.merge(
    sentence_test_causenet.drop(
        columns=["support", "count", "evaluation", "manual_evaluation"]
    ),
    how="outer",
    on=["cause", "effect", "sentence", "dataset"],
)

# Scorer grid for the sentence-level predictions. Labels are "-"-joined so that
# classify_test_set can split them back; st21pv variants precede the plain ones.
jaccard_thresholds = [round(thresh, 2) for thresh in np.arange(0.7, 1.01, 0.1)]
scispacy_thresholds = [round(thresh, 2) for thresh in np.arange(0.6, 0.9, 0.1)]
scispacy_models = ["en_core_sci_sm", "en_core_sci_lg"]
r_values = [-float("inf"), -10, -5, -2, -1, 0, 1, 2, 5, 10, float("inf")]
n_gram_sizes = [(1, 1), (1, 2), (1, 3)]
medical_corpora = ["pubmed", "textbook", "pubmed_central", "encyclopedia"]
umls_subsets = ["full", "rx_sno"]
text_types = ["sentence", "noun_phrase"]
contrastive_scores = [
    "term_domain_specificity",
    "contrastive_weight",
    "discriminative_weight",
]
bert_names = ["bert", "scibert", "pubmedbert"]

labels = (
    [
        f"quickumls-{umls_subset}-{jaccard_threshold}" + ("-st21pv" if st21pv else "")
        for st21pv in (True, False)
        for umls_subset in umls_subsets
        for jaccard_threshold in jaccard_thresholds
    ]
    + [
        f"scispacy-{umls_subset}-{sci_model}-{threshold}" + ("-st21pv" if st21pv else "")
        for st21pv in (True, False)
        for sci_model in scispacy_models
        for umls_subset in umls_subsets
        for threshold in scispacy_thresholds
    ]
    + [
        f"{tagger}-{umls_subset}" + ("-st21pv" if st21pv else "")
        for tagger in ("ctakes", "metamap")
        for st21pv in (True, False)
        for umls_subset in umls_subsets
    ]
    + [
        # "-" separates label fields, so negative r values are written "neg_<abs value>"
        "-".join(
            [
                contrastive_score,
                medical_corpus,
                str(n_gram_size),
                str(r_value) if r_value >= 0 else f"neg_{-r_value}",
            ]
        )
        for contrastive_score in contrastive_scores
        for n_gram_size in n_gram_sizes
        for r_value in r_values
        for medical_corpus in medical_corpora
    ]
    + [
        f"health_bert-{name}-{medical_corpus}-{text_type}"
        for name in bert_names
        for medical_corpus in ("pubmed", "encyclopedia")
        for text_type in text_types
    ]
)
df_labels = [
    f"medical_score-{relation}-{label}"
    for relation in ("cause", "effect")
    for label in labels
]
labels = [label.split("-") for label in labels]
# add an all-NaN column for every score that is not in the cached predictions yet
new_columns = [
    pd.Series(np.nan, index=sentence_test_causenet_predictions.index, name=column)
    for column in df_labels
    if column not in sentence_test_causenet_predictions
]
sentence_test_causenet_predictions = pd.concat(
    [sentence_test_causenet_predictions, *new_columns], axis=1
)

# Sentence-level scoring: each sentence is fed to the scorers as the "cause" text with
# an empty "effect". The effect scores carry no information and are pinned to 0.
prediction_sentence_test_causenet = sentence_test_causenet_predictions.assign(
    cause=sentence_test_causenet_predictions.sentence, effect=""
).drop(["sentence"], axis=1)

sentence_predictions_pickle = os.path.join(
    constants.TEST_CAUSENET_PATH, "sentence_test_causenet_predictions.pkl"
)
# The score columns are fixed by now, so the effect columns are resolved once.
effect_columns = sentence_test_causenet_predictions.columns[
    sentence_test_causenet_predictions.columns.str.startswith("medical_score-effect")
]


def _score_column_names(label):
    """Return ``[cause_column, effect_column]`` for a split method ``label``."""
    label_name = "-".join(label)
    return [f"medical_score-cause-{label_name}", f"medical_score-effect-{label_name}"]


def _checkpoint_sentence_predictions():
    """Pin every effect score to 0 and save the sentence predictions to disk."""
    sentence_test_causenet_predictions.loc[:, effect_columns] = 0
    sentence_test_causenet_predictions.to_pickle(sentence_predictions_pickle)


# Rows without any score (e.g. sentences just added by the outer merge) get every method.
missing_rows = sentence_test_causenet_predictions[df_labels].isna().all(axis=1)
if missing_rows.any():
    print(f"parsing {missing_rows.sum()} missing rows")
    pg = tqdm(labels)
    for label in pg:
        replace_rows = classify_test_set(
            prediction_sentence_test_causenet.loc[missing_rows, ["cause", "effect"]],
            label,
        )
        columns = _score_column_names(label)
        sentence_test_causenet_predictions.loc[missing_rows, columns] = (
            replace_rows.loc[:, columns].values
        )
        _checkpoint_sentence_predictions()

# Methods with any missing value are recomputed for all rows.
score_column = sentence_test_causenet_predictions.columns.str.startswith(
    "medical_score-"
)
isna = sentence_test_causenet_predictions.isna().any(axis=0)
missing_columns = list(sentence_test_causenet_predictions.loc[:, score_column & isna])
missing_columns = ["-".join(missing_column.split("-")[2:]) for missing_column in missing_columns]
missing_columns = list(set(missing_columns))
missing_columns = sorted([missing_column.split("-") for missing_column in missing_columns])

if missing_columns:
    print("parsing missing columns")
    for missing_column in missing_columns:
        print(" ".join(missing_column))
    pg = tqdm(missing_columns)
    for label in pg:
        replace_columns = classify_test_set(
            prediction_sentence_test_causenet.loc[:, ["cause", "effect"]], label,
        )
        columns = _score_column_names(label)
        sentence_test_causenet_predictions.loc[:, columns] = (
            replace_columns.loc[:, columns].values
        )
        # Only the cause score must be complete. The effect score of the empty string
        # is overwritten with 0 by the checkpoint, so a NaN there must not abort the run.
        assert not sentence_test_causenet_predictions[columns[0]].isna().any()
        _checkpoint_sentence_predictions()

assert not sentence_test_causenet_predictions.loc[:, score_column].isna().any().any()

sentence_test_causenet.merge(
    sentence_test_causenet_predictions,
    on=["cause", "effect", "sentence", "dataset"],
    how="left",
)