In [None]:
import sys
import os

# Make modules that live one directory above the notebook importable.
parent_path = ".."
absolute_parent = os.path.abspath(parent_path)
sys.path.append(absolute_parent)

In [None]:
import pandas as pd
import pathlib
from health_causenet import constants
import extract_medical
from tqdm.autonotebook import tqdm
import numpy as np
import json

In [None]:
# Load every CauseNet parquet shard and stack them into one frame.
# Shards are read in numeric order of the trailing "_<n>.parquet" counter
# ([:-8] strips the 8-character ".parquet" suffix before converting to int).
paths = sorted(
    pathlib.Path(constants.CAUSENET_PARQUET_PATH).glob("causenet_*.parquet"),
    key=lambda x: int(str(x).split("_")[-1][:-8]),
)

# Collect the shards first and concatenate once: concatenating inside the
# loop re-copies all previously loaded rows on every iteration.
shards = []
for path in tqdm(paths):
    from_file = pd.read_parquet(
        path, columns=["cause", "effect", "support", "reference", "sentence"]
    )
    shards.append(from_file)

# pd.concat rejects an empty list, so fall back to an empty frame when no
# shard was found (same result the incremental version produced).
full_causenet = pd.concat(shards, ignore_index=True) if shards else pd.DataFrame()
full_causenet

In [None]:
def p_mean_threshold_combiner(cause, effect, p):
    """Combine two scores with the power (generalized) mean of exponent ``p``.

    Computes ``((cause**p + effect**p) / 2) ** (1 / p)`` element-wise.

    Args:
        cause: Score(s) of the cause concept (scalar, array or Series).
        effect: Score(s) of the effect concept, same shape as ``cause``.
        p: Exponent of the mean. Must be non-zero. ``+inf`` / ``-inf`` are
            accepted and return the limiting element-wise maximum / minimum.

    Returns:
        The combined score(s), in whatever container numpy arithmetic on the
        inputs yields.
    """
    # The closed-form expression degenerates for infinite p (x ** (1 / inf)
    # is x ** 0 == 1), so return the mathematical limit explicitly.
    if np.isinf(p):
        return np.maximum(cause, effect) if p > 0 else np.minimum(cause, effect)
    return ((cause ** p + effect ** p) / 2) ** (1 / p)

def max_combiner(cause, effect):
    """Return the element-wise larger of the cause and effect scores."""
    larger = np.maximum(cause, effect)
    return larger

def min_combiner(cause, effect):
    """Return the element-wise smaller of the cause and effect scores."""
    smaller = np.minimum(cause, effect)
    return smaller

def _fixed_p_mean(p):
    """Return a two-argument combiner that applies the power mean with exponent ``p``."""
    def combine(cause, effect):
        return p_mean_threshold_combiner(cause, effect, p)
    return combine


def _min_of_pair(cause, effect):
    """Forward to ``min_combiner`` (looked up at call time)."""
    return min_combiner(cause, effect)


def _max_of_pair(cause, effect):
    """Forward to ``max_combiner`` (looked up at call time)."""
    return max_combiner(cause, effect)


# Operator name -> function merging a cause score and an effect score into
# one value.
ops = {
    "and": _min_of_pair,
    "p=1_mean": _fixed_p_mean(1),
    "p=2_mean": _fixed_p_mean(2),
    "p=5_mean": _fixed_p_mean(5),
    "p=10_mean": _fixed_p_mean(10),
    "p=inf_mean": _max_of_pair,
}

In [None]:
# Method classes that compete for "best approach" on each dataset.
METHOD_CLASSES = ["contrastive_weight", "term_domain_specificity", "discriminative_weight"]


def _top_approach(results, dataset, metric):
    """Return the row of ``results`` for ``dataset`` that sorts last by ``metric`` among METHOD_CLASSES."""
    candidates = results.loc[dataset].loc[METHOD_CLASSES]
    return candidates.sort_values(metric).iloc[-1]


result_index = ["dataset", "method_class"]
best_mcc = pd.read_csv("./test_best_approaches_mcc.csv", index_col=0).set_index(result_index)
best_prec = pd.read_csv("./test_best_approaches_recall_precision_0.9.csv", index_col=0).set_index(result_index)

# MCC tables are ranked by "mcc", the precision-constrained tables by "recall".
full_mcc = _top_approach(best_mcc, "random_full", "mcc")
full_prec = _top_approach(best_prec, "random_full", "recall")
support_mcc = _top_approach(best_mcc, "random_support", "mcc")
support_prec = _top_approach(best_prec, "random_support", "recall")

# One entry per output column: which scoring method to load, which threshold
# to apply and which combiner (key of `ops`) merges cause and effect scores.
file_patterns = {
    label: {
        "method": row["method"],
        "threshold": row["threshold"],
        "op": row["operator"],
    }
    for label, row in [
        ("full_mcc", full_mcc),
        ("full_prec", full_prec),
        ("support_mcc", support_mcc),
        ("support_prec", support_prec),
    ]
}

print(json.dumps(file_patterns, indent=2))

def _medical_score_paths(method):
    """List the parquet shards holding the scores of ``method``, ordered by shard number."""
    pattern = method.replace(", ", "_") + "_*.parquet"
    return sorted(
        pathlib.Path(constants.CAUSENET_PARQUET_PATH).glob(pattern),
        key=lambda x: int(str(x).split("_")[-1][:-8]),
    )


# Add one boolean column per entry of `file_patterns` to a copy of CauseNet.
full_causenet_medical = full_causenet.copy()
for name, kwargs in tqdm(list(file_patterns.items())):
    paths = _medical_score_paths(kwargs["method"])
    if not paths:
        # No score shards on disk yet: run the extraction once, then look again.
        termhood, corpus, n_gram_size, p = kwargs["method"].split("-")
        # Takes the characters at positions 1 and 4 of the n-gram field —
        # presumably the two digits of a "(a, b)" string; TODO confirm
        # (multi-digit sizes would not fit this slicing).
        n_grams = (n_gram_size[1], n_gram_size[4])
        args = [termhood, "--corpora", corpus, "--n_gram_size", *n_grams, "--p", p]
        extract_medical.main(args)
        paths = _medical_score_paths(kwargs["method"])
        if not paths:
            # Previously this retried forever; fail loudly instead.
            raise FileNotFoundError(
                f"extract_medical produced no parquet files for method {kwargs['method']!r}"
            )

    # Single concatenation; the per-shard index is intentionally kept because
    # `medical_score` is inspected again in a later cell.
    medical_score = pd.concat([pd.read_parquet(path) for path in paths])
    medical = ops[kwargs["op"]](medical_score["medical_score-cause"], medical_score["medical_score-effect"]) >= kwargs["threshold"]
    # Rows are matched to `full_causenet` by position, hence the fresh index.
    medical = pd.Series(medical, name=name).reset_index(drop=True)
    full_causenet_medical = full_causenet_medical.join(medical)

# Blank out the "support" columns for relations whose support equals 1.
support_columns = [column for column in file_patterns.keys() if "support" in column]
full_causenet_medical.loc[full_causenet_medical.support == 1, support_columns] = np.nan
causenet_medical = full_causenet_medical.drop(["reference", "sentence"], axis=1).drop_duplicates(["cause", "effect"]).reset_index(drop=True)
full_causenet_medical

In [None]:
# Display the table de-duplicated on (cause, effect) that the previous cell built.
causenet_medical

In [None]:
# Export the de-duplicated relations, minus the support column, as a
# tab-separated file next to the parquet shards.
tsv_path = constants.CAUSENET_PARQUET_PATH + "/health-causenet.tsv"
export_frame = causenet_medical.drop(columns=["support"])
export_frame.to_csv(tsv_path, index=False, sep="\t")

In [None]:
# Per-pattern totals over the full (sentence-level) table.
pattern_columns = list(file_patterns)
full_causenet_medical[pattern_columns].agg(["sum", "mean", "count"]).astype(str)

In [None]:
# Per-pattern totals over the table de-duplicated on (cause, effect).
pattern_columns = list(file_patterns)
causenet_medical[pattern_columns].agg(["sum", "mean", "count"]).astype(str)

In [None]:
# NOTE(review): no assignment to `tmp` is visible before this cell — it is
# first set in a later cell — so on a fresh kernel (Restart & Run All) this
# line raises NameError. Move this cell below the assignment or remove it.
tmp

In [None]:
# Inspect every relation that mentions "cancer" together with its summed score.
# Alternative selection kept for reference:
# tmp = full_causenet_medical.loc[full_causenet_medical.support_prec.fillna(False)]

# The original condition tested `cause == "cancer"` twice; the second operand
# is taken to be the effect side.
is_cancer = (full_causenet_medical.cause == "cancer") | (full_causenet_medical.effect == "cancer")
# Copy so that adding a column does not write into a view of the full table.
tmp = full_causenet_medical.loc[is_cancer].copy()
# `medical_score` is whatever the labelling loop loaded last. Its rows are
# matched to CauseNet by position (as in that loop), so give it a fresh
# positional index instead of dropping rows whose shard-local index repeats.
tmp["medical_score"] = medical_score.reset_index(drop=True).sum(axis=1)
tmp.sort_values("medical_score")

In [None]:
# Print each distinct sentence in `tmp` that contains "Virgo" (case-sensitive).
mentions_virgo = tmp.sentence.str.contains("Virgo", case=True)
unique_sentences = tmp.loc[mentions_virgo, "sentence"].drop_duplicates()
for sentence in unique_sentences.values:
    print(sentence, "\n")

In [None]:
# Print up to ten random support_prec relations whose reference is a
# MedlinePlus page.
tmp = full_causenet_medical.loc[full_causenet_medical.support_prec.fillna(False)]
# Match the URL literally (it is not a regular expression) and treat missing
# references as non-matching so the mask stays purely boolean.
is_medlineplus = tmp.reference.str.contains(
    "http://www.nlm.nih.gov/medlineplus/", regex=False, na=False
)
tmp = tmp.loc[is_medlineplus]
# Sample once, and never ask for more rows than are available.
tmp = tmp.sample(min(10, len(tmp)))
for sample in tmp.itertuples(index=False):
    print(sample.cause, "->", sample.effect, sample.reference, sample.sentence)
    print()

In [None]:
# Print ten random relations flagged by the support_prec pattern.
support_prec_mask = full_causenet_medical.support_prec.fillna(False)
tmp = full_causenet_medical.loc[support_prec_mask]
# tmp.loc[tmp.cause.str.contains("jupiter")].drop_duplicates(["cause", "effect"]).head(20)
random_rows = tmp.sample(10)
for sample in random_rows.values:
    # Positions 0, 1, 3, 4 are the cause, effect, reference and sentence columns.
    first, second, source, text = sample[0], sample[1], sample[3], sample[4]
    print(first, "->", second, source, text)
    print()

In [None]:
# Export a random sample of the relations flagged by each selected pattern.
SAMPLE_SIZE = 1000
# Fixed seed so that re-running the notebook writes the same sample.
SAMPLE_SEED = 0

resources = [
#     "full_mcc",
    "full_prec",
#     "support_mcc",
#     "support_prec"
]
for resource in resources:
    print(resource)
    flags = full_causenet_medical.loc[:, resource]
    # Keep rows where the flag is set and not missing.
    flagged = full_causenet_medical.loc[
        flags & ~flags.isna(),
        ["cause", "effect", "support", "reference", "sentence"],
    ]
    # Cap the request at the number of flagged rows; sample() raises otherwise.
    samples = flagged.sample(n=min(SAMPLE_SIZE, len(flagged)), random_state=SAMPLE_SEED)
    samples = samples.reset_index(drop=True)
    samples.to_csv(constants.BASE_PATH + f"resources/{resource}.csv")