# Calculate results according to a unified test set rather than relying on automated individual model outputs
* It seems that some algorithms calculate performance over all answers for a given query rather than just the answer picked, which could contribute to inconsistent scoring.
* This notebook unifies the answers and creates a single test set on which all predictions are scored, to ensure a fair evaluation.
* Rank predictions are filtered: if other true answers are ranked higher than the given answer, we subtract the number of true answers preceding it from its rank.

## summary stats
### CBR
CBR Stats
* MRR: 0.0128
* Hits@1: 0.0020
* Hits@3: 0.0156
* Hits@10: 0.0352

pCBR Stats
* MRR: 0.2557
* Hits@1: 0.1816
* Hits@3: 0.3008
* Hits@10: 0.3984

### Rephetio
Rephetio Stats
* MRR: 0.0816
* Hits@1: 0.0488
* Hits@3: 0.0781
* Hits@10: 0.1445

### KGE
TransE Stats
* MRR: 0.1601
* Hits@1: 0.0645
* Hits@3: 0.1816
* Hits@10: 0.3770

DistMult Stats
* MRR: 0.0391
* Hits@1: 0.0098
* Hits@3: 0.0293
* Hits@10: 0.0781

ComplEx Stats
* MRR: 0.0820
* Hits@1: 0.0273
* Hits@3: 0.0801
* Hits@10: 0.2031

RotatE Stats
* MRR: 0.1396
* Hits@1: 0.0840
* Hits@3: 0.1406
* Hits@10: 0.2637

In [1]:
import polars as pl

import os
import json
import numpy as np

import gc
import sys

sys.path.append("./Notebooks")
import score_utils2 as score_utils

# Import results

## read in cbr and pcbr results

In [2]:
# Import probCBR results with this function
def json_to_df(json_dir: str) -> pl.DataFrame:
    """
    Load a JSON file holding a list of records and return it as a DataFrame.

    Column names are taken from the keys of the first record. Every record
    must provide all of those keys (a missing key raises ``KeyError``); keys
    that appear only in later records are ignored.

    Args:
        json_dir: Path to the JSON file (despite the name, a file path).

    Returns:
        A DataFrame with one row per record. An empty JSON list yields an
        empty DataFrame rather than raising ``IndexError``.
    """
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(json_dir, "r", encoding="utf-8") as f:
        records = json.load(f)

    if not records:
        return pl.DataFrame()

    # one list per column, built from the key set of the first record
    columns = list(records[0].keys())
    return pl.DataFrame({col: [rec[col] for rec in records] for col in columns})

In [3]:
# Load the probCBR output, keeping only the "indication" queries; the CBR
# output is loaded as-is (no relation filter is applied to it).
# NOTE(review): absolute, user-specific paths -- adjust when running elsewhere.
pcbr_df = json_to_df("/home/msinha/Open-BIo-Link/data1.json").filter(
    pl.col("r") == "indication"
)
cbr_df = json_to_df("/home/msinha/CBR-AKBC/results/data_CBRonMIND_CtD.json")

In [4]:
# number of query rows loaded for each CBR variant
print(f"probCBR queries and results: {pcbr_df.height}")
print(f"CBR queres and results: {cbr_df.height}")

probCBR queries and results: 386
CBR queres and results: 386


In [5]:
# peek at the probCBR frame: one row per query with true and predicted answers
pcbr_df.head(2)

e1,r,answers,predicted_answers
str,str,list[str],list[str]
"""CHEBI:135735""","""indication""","[""DOID:10763""]","[""DOID:6432"", ""DOID:6000"", … ""DOID:0060319""]"
"""CHEBI:135738""","""indication""","[""DOID:10763""]","[""DOID:10763"", ""DOID:446"", … ""DOID:12361""]"


In [6]:
# peek at the CBR frame; it should share the probCBR column layout
cbr_df.head(2)

e1,r,answers,predicted_answers
str,str,list[str],list[str]
"""CHEBI:135735""","""indication""","[""DOID:10763""]","[""DOID:3393"", ""DOID:6000"", … ""DOID:0060343""]"
"""CHEBI:135738""","""indication""","[""DOID:10763""]","[""HP:0000006"", ""HP:0000007"", … ""DOID:11162""]"


In [7]:
# For both CBR variants: rename the query column e1 -> h, then tag every row
# with the name of the method that produced it.
pcbr_df = pcbr_df.rename({"e1": "h"}).with_columns(pl.lit("pCBR").alias("method"))
cbr_df = cbr_df.rename({"e1": "h"}).with_columns(pl.lit("CBR").alias("method"))

## Import Rephetio Results

In [9]:
# Work from the Notebooks directory so the relative "../data" and "../models"
# paths used below resolve from there.
# NOTE(review): the path is relative, so re-running this cell changes
# directory again (or fails if ./Notebooks does not exist there).
os.chdir("./Notebooks")

In [10]:
# Read Rephetio's scored predictions, order them by probability (highest
# first) and collect the predicted diseases of each compound into one list.
rephetio = pl.read_csv(
    "../../MechRepoNet/1_code/KG_reasoning_comparison/Rephetio_MIND_CtD_test/results.csv"
)
rephetio = rephetio.sort("proba", descending=True)
rephetio = rephetio.group_by("chemicalsubstance_id").agg(pl.col("disease_id"))

rephetio.head(2)

chemicalsubstance_id,disease_id
str,list[str]
"""CHEBI:5165""","[""DOID:0060224"", ""DOID:1826"", … ""DOID:2272""]"
"""CHEBI:5139""","[""DOID:0060224"", ""DOID:12849"", … ""DOID:0060677""]"


In [11]:
# Align the column names with the CBR frames (h / predicted_answers) and tag
# the rows with the method name.
rephetio = rephetio.rename(
    {"chemicalsubstance_id": "h", "disease_id": "predicted_answers"}
).with_columns(pl.lit("Rephetio").alias("method"))
rephetio.shape

(374, 3)

## Import KGE results

In [12]:
# Collect one result frame per knowledge-graph-embedding model.
kge_ls = list()
for i in ["TransE", "DistMult", "ComplEx", "RotatE"]:
    print(i)
    # project helper that reads the model's raw test scores;
    # NOTE(review): its exact behaviour is defined in score_utils2, not here
    raw = score_utils.ProcessOutput(
        data_dir="../data/MIND_CtD",
        scores_outfile=f"../models/{i}_MIND_CtD_megha/test_scores.tsv",
        mode="tail-batch",
    )
    raw.format_raw_scores_to_df()
    # presumably maps embedding ids back to entity names -- confirm in score_utils2
    raw.translate_embeddings(direction="from")
    df = raw.df
    # tag the rows with the model name so the frames can be stacked later
    df = df.with_columns(pl.lit(i).alias("method"))
    kge_ls.append(df)
    print(f"Shape: {df.shape}")

TransE
Shape: (535, 6)
DistMult
Shape: (535, 6)
ComplEx
Shape: (535, 6)
RotatE
Shape: (535, 6)


In [13]:
# stack the four per-model frames into a single KGE frame
kge_df = pl.concat(kge_ls)
kge_df.head(2)

h,r,t,batch,preds,method
str,str,str,str,list[str],str
"""CHEBI:6375""","""indication""","""DOID:10808""","""tail-batch""","[""CHEBI:6375"", ""DOID:13976"", … ""GO:0009601""]","""TransE"""
"""CHEBI:8708""","""indication""","""DOID:5419""","""tail-batch""","[""CHEBI:8708"", ""DOID:14320"", … ""GO:0009601""]","""TransE"""


In [14]:
# keep only the first 1000 entries of every prediction list
kge_df = kge_df.with_columns(pl.col("preds").list.head(1000))

In [15]:
# the per-model frames are now held by kge_df; drop the list and collect garbage
del kge_ls
gc.collect()

8184

# Get dataset overlap
* Seemingly, the datasets are not exactly the same.
* Check the overlap between them.
Grab the overlapping queries.

### check overlapping queries

In [16]:
# Size of the symmetric difference between the query (h) sets of each pair.
# The rows labelled "CBR" against KGE / Rephetio use the probCBR frame.
pcbr_heads = set(pcbr_df["h"])
kge_heads = set(kge_df["h"])
rephetio_heads = set(rephetio["h"])

print("Non-overlapping queries (h) between: ")
print(f"- CBR and probCBR: {len(set(cbr_df['h']) ^ pcbr_heads)}")
print(f"- CBR and KGE: {len(pcbr_heads ^ kge_heads)}")
print(f"- Rephetio and CBR: {len(pcbr_heads ^ rephetio_heads)}")
print(f"- Rephetio and KGE: {len(rephetio_heads ^ kge_heads)}")

Non-overlapping queries (h) between: 
- CBR and probCBR: 0
- CBR and KGE: 1
- Rephetio and CBR: 14
- Rephetio and KGE: 13


In [17]:
# Size of the intersection between the query (h) sets of each pair.
# The rows labelled "CBR" use the probCBR frame.
pcbr_heads = set(pcbr_df["h"])
kge_heads = set(kge_df["h"])
rephetio_heads = set(rephetio["h"])

print("Overlapping queries (h) between: ")
print(f"- CBR and KGE: {len(pcbr_heads & kge_heads)}")
print(f"- Rephetio and CBR: {len(pcbr_heads & rephetio_heads)}")
print(f"- Rephetio and KGE: {len(rephetio_heads & kge_heads)}")

Overlapping queries (h) between: 
- CBR and KGE: 386
- Rephetio and CBR: 373
- Rephetio and KGE: 374


In [18]:
# number of distinct queries (h) per result set; "CBR" uses the probCBR frame
print("Unique queries in each dataset:")
print(f"- CBR: {pcbr_df['h'].n_unique()}")
print(f"- Rephetio: {rephetio['h'].n_unique()}")
print(f"- KGE: {kge_df['h'].n_unique()}")

Unique queries in each dataset:
- CBR: 386
- Rephetio: 374
- KGE: 387


### Get overlapping queries

In [19]:
# queries (h) that occur in the probCBR, Rephetio and KGE results alike
query_overlap = set(pcbr_df["h"]) & set(rephetio["h"]) & set(kge_df["h"])

In [20]:
# how many queries every method has results for
print(f"Number of overlapping entities: {len(query_overlap)}")

Number of overlapping entities: 373


# Gather true answers and create a test dataset

## Gather true answers

### Read in train/test/valid and get the indication

In [21]:
# Gather the known "indication" edges from the train/valid/test splits and
# group the tails per head, giving the list of true answers for each query.
# The relation filter is applied to every split, so a non-indication triple in
# valid/test can never end up among the true answers.
# NOTE(review): read_csv treats the first line as a header by default; if
# these files have no header row, the first triple of each split is dropped --
# confirm, and pass has_header=False if so.
indications = (
    pl.concat(
        [
            pl.read_csv(
                f"../data/MIND_CtD/{split}.txt",
                separator="\t",
                new_columns=["h", "r", "t"],
            ).filter(pl.col("r") == "indication")
            for split in ("train", "valid", "test")
        ]
    )
    .group_by("h")
    .agg("t")
    .rename({"t": "answers"})
)
indications.head(2)

h,answers
str,list[str]
"""CHEBI:31608""","[""DOID:11759"", ""DOID:11759"", … ""DOID:11758""]"
"""CHEBI:30621""","[""DOID:0060318"", ""MESH:D002471"", … ""KEGG:hsa05215""]"


In [22]:
# one row per head entity that has at least one known answer
indications.shape

(1306, 2)

### gather answers from cbr and ensure they're not different

In [23]:
# Merge the answers listed in the CBR output into the known indications:
# flatten both sources to one (h, answer) pair per row, regroup per query and
# de-duplicate each resulting answer list.
indications = (
    pl.concat([indications, cbr_df.select(["h", "answers"])])
    .explode("answers")
    .group_by("h")
    .agg("answers")
    .with_columns(pl.col("answers").list.unique())
)

### make sure query isn't in answers

In [24]:
# Remove every answer that is identical to its own query, then rebuild the
# per-query answer lists.
indications = (
    indications.explode("answers")
    .filter(pl.col("answers") != pl.col("h"))
    .group_by("h")
    .agg("answers")
)

## Recreate a test dataframe to make sure results are fair
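* For every query in query_overlap, sample one true answer from indications, so each query appears at least once
* For the remaining test rows, sample from the not-yet-picked true answers of the queries in query_overlap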

In [25]:
# Sample one true answer for every overlapping query, so that each query is
# represented in the test set at least once. Each answer list is sorted first
# and the draw is seeded, so re-running the cell picks the same answers.
picked_0 = (
    indications.filter(pl.col("h").is_in(query_overlap))
    .with_columns(pl.col("answers").list.sort().list.sample(1, seed=0))
    .explode("answers")
)

In [26]:
# expect one row per overlapping query
picked_0.shape

(373, 2)

In [27]:
# Top up the test set: among the true answers of overlapping queries that are
# not already in picked_0, draw 511 - 372 (= 139) further (h, answer) pairs.
# The candidates are sorted and the draw is seeded, so for a given picked_0
# the same pairs are selected on every run.
picked_1 = (
    indications.explode("answers")
    .filter(pl.col("h").is_in(query_overlap))
    .join(picked_0, on=["h", "answers"], how="anti")
    .sort(["h", "answers"])
    .sample(511 - 372, seed=0)
)

In [28]:
# Build the unified test set from the two picks and cache it on first use;
# afterwards always reload the cached copy so the test set stays fixed.
if not os.path.exists("unified_test_set.parquet"):
    test_set = pl.concat([picked_0, picked_1])
    test_set.write_parquet("unified_test_set.parquet")
else:
    test_set = pl.read_parquet("unified_test_set.parquet")

### Check compounds exist in curated set

In [29]:
# Compounds whose predictions were already curated by hand; each of them must
# be present as a query in the unified test set (asserted in the next cell).
# The trailing comment on every entry gives the compound name.

required = [
    "CHEBI:520985",  # almotriptan
    "CHEBI:78540",  # apremilast
    "CHEBI:77590",  # armodafinil
    "CHEBI:59164",  # balsalazide disodium
    "CHEBI:3286",  # cabergoline
    "CHEBI:3738",  # clemastine
    "CHEBI:3756",  # clonazepam
    "CHEBI:17439",  # cyanocob(III)alamin
    "CHEBI:4046",  # cyproheptadine
    "CHEBI:4638",  # diphenidol
    "CHEBI:4647",  # dipivefrin hydrochloride
    "CHEBI:31530",  # edaravone
    "IKEY:DYLUUSLLRIQKOE-UHFFFAOYSA-N",  # Enasidenib
    "CHEBI:36791",  # escitalopram
    "CHEBI:5051",  # Fexofenadine hydrochloride
    "CHEBI:5139",  # Fluvoxamine maleate
    "CHEBI:6441",  # levacetylmethadol
    "CHEBI:135925",  # lisdexamfetamine
    "CHEBI:31854",  # Milnacipran hydrochloride
    "CHEBI:7575",  # nimodipine
    "CHEBI:8708",  # quetiapine fumarate
    "CHEBI:63620",  # rasagiline
    "CHEBI:9207",  # sotalol hydrochloride
    "CHEBI:9711",  # triflupromazine
    "CHEBI:9725",  # Trimeprazine
]

In [30]:
# Every curated compound in `required` must appear as a query (h) in the test
# set. Checking for missing members avoids hard-coding the size of `required`
# and makes the failure message name the absent compounds.
missing_required = set(required) - set(test_set["h"])
assert (
    not missing_required
), f"Not all required entities are in the test set: {sorted(missing_required)}"

# Clean, assemble the results and calculate performance stats

## Clean and integrate data together

### check headers

In [31]:
# check the CBR column names before stacking the result frames
cbr_df.head(2)

h,r,answers,predicted_answers,method
str,str,list[str],list[str],str
"""CHEBI:135735""","""indication""","[""DOID:10763""]","[""DOID:3393"", ""DOID:6000"", … ""DOID:0060343""]","""CBR"""
"""CHEBI:135738""","""indication""","[""DOID:10763""]","[""HP:0000006"", ""HP:0000007"", … ""DOID:11162""]","""CBR"""


In [32]:
# KGE frames call the prediction column "preds" and carry extra columns
kge_df.head(2)

h,r,t,batch,preds,method
str,str,str,str,list[str],str
"""CHEBI:6375""","""indication""","""DOID:10808""","""tail-batch""","[""CHEBI:6375"", ""DOID:13976"", … ""DOID:5082""]","""TransE"""
"""CHEBI:8708""","""indication""","""DOID:5419""","""tail-batch""","[""CHEBI:8708"", ""DOID:14320"", … ""DOID:3345""]","""TransE"""


In [33]:
# Rephetio frame: h, predicted_answers, method
rephetio.head(2)

h,predicted_answers,method
str,list[str],str
"""CHEBI:5165""","[""DOID:0060224"", ""DOID:1826"", … ""DOID:2272""]","""Rephetio"""
"""CHEBI:5139""","[""DOID:0060224"", ""DOID:12849"", … ""DOID:0060677""]","""Rephetio"""


### drop relation and stack results

In [47]:
# Stack the predictions of all methods into one frame with the common schema
# (h, predicted_answers, method) and drop exact duplicate rows.
result_cols = ["h", "predicted_answers", "method"]
res_df = pl.concat(
    [
        cbr_df.select(result_cols),
        pcbr_df.select(result_cols),
        kge_df.rename({"preds": "predicted_answers"}).select(result_cols),
        rephetio.select(result_cols),
    ]
).unique()

In [48]:
# rows after stacking all methods and de-duplicating
res_df.shape

(3271, 3)

### Get relevant results
* some probCBR and CBR entries don't have predictions

In [49]:
# count, per method, the queries whose prediction list is empty
print("Number of no predictions for")
for label, frame in (("pCBR", pcbr_df), ("CBR", cbr_df), ("rephetio", rephetio)):
    n_empty = frame.filter(pl.col("predicted_answers").list.len() == 0).shape[0]
    print(f"- {label}: {n_empty}")

Number of no predictions for
- pCBR: 27
- CBR: 35
- rephetio: 0


In [50]:
# Keep only the overlapping queries, remove the query itself from its own
# prediction list and rebuild one prediction list per (query, method).
# Rows whose prediction is null after the explode are dropped by the filter.
res_df = (
    res_df.filter(pl.col("h").is_in(query_overlap))
    .explode("predicted_answers")
    .filter(pl.col("predicted_answers") != pl.col("h"))
    .group_by(["h", "method"], maintain_order=True)
    .agg("predicted_answers")
)

res_df.shape

(2553, 3)

In [51]:
# one row per (query, method) holding the cleaned prediction list
res_df.head(2)

h,method,predicted_answers
str,str,list[str]
"""CHEBI:3738""","""RotatE""","[""DOID:5419"", ""DOID:1470"", … ""MESH:D049970""]"
"""CHEBI:3758""","""ComplEx""","[""DOID:11106"", ""DOID:11107"", … ""NCBIGene:2207""]"


### some cbr/pcbr entries don't have predictions
* add entries back in so we can calculate hits fairly.

In [52]:
# Row count per method: methods that made no prediction for some queries lost
# those rows in the filtering step and therefore show fewer rows.
for i in res_df["method"].unique().to_list():
    print(f"{i}: {res_df.filter(pl.col('method') == i).height}")

RotatE: 373
CBR: 341
Rephetio: 373
pCBR: 347
ComplEx: 373
TransE: 373
DistMult: 373


In [53]:
# Queries that have a TransE row but no pCBR / CBR row; TransE serves as the
# reference list of queries. A plain set difference is used on purpose: a head
# that occurs only in the (p)CBR rows is not "missing" and must not be padded.
transe_heads = set(res_df.filter(pl.col("method") == "TransE")["h"])
missing_pcbr = transe_heads - set(res_df.filter(pl.col("method") == "pCBR")["h"])
missing_cbr = transe_heads - set(res_df.filter(pl.col("method") == "CBR")["h"])

In [54]:
# Re-add one row with an empty prediction list for every query that pCBR / CBR
# made no prediction for, so each method is scored on the same set of queries.
res_df = pl.concat(
    [
        res_df,
        pl.DataFrame({"h": list(missing_pcbr)}).with_columns(
            method=pl.lit("pCBR"), predicted_answers=[]
        ),
        pl.DataFrame({"h": list(missing_cbr)}).with_columns(
            method=pl.lit("CBR"), predicted_answers=[]
        ),
    ]
)

In [55]:
# After padding, every method should show the same number of rows.
# NOTE(review): the methods are labelled "pCBR"/"CBR" when the frames are
# tagged, so the "probCBR" filter below looks like a leftover that matches
# nothing -- confirm.
res_df = res_df.unique().filter(pl.col("method") != "probCBR")
for i in res_df["method"].unique().to_list():
    print(f"{i}: {res_df.filter(pl.col('method')==i).shape[0]}")

TransE: 373
ComplEx: 373
DistMult: 373
RotatE: 373
Rephetio: 373
pCBR: 373
CBR: 373


In [56]:
# Each of the 7 methods must have exactly one row per overlapping query.
# Counting per method (rather than only the grand total) also catches a
# surplus in one method that would cancel out a shortfall in another.
rows_per_method = {
    m: res_df.filter(pl.col("method") == m).shape[0]
    for m in res_df["method"].unique().to_list()
}
assert len(rows_per_method) == 7 and all(
    n == len(query_overlap) for n in rows_per_method.values()
), f"Some algorithms are missing predictions: {rows_per_method}"

## add answers to the dataframe

In [57]:
# reminder of the true-answer frame: one answer list per query (h)
indications.head(2)

h,answers
str,list[str]
"""MESH:D003894""","[""HP:0010677"", ""HP:0000103""]"
"""CHEBI:61030""","[""DOID:893""]"


In [58]:
# attach the list of true answers to every (query, method) row; the left join
# keeps result rows even when a query has no entry in indications
res_df = res_df.join(indications, on="h", how="left")

res_df.head(2)

h,method,predicted_answers,answers
str,str,list[str],list[str]
"""CHEBI:31652""","""RotatE""","[""DOID:0050860"", ""DOID:1909"", … ""GO:0043234""]","[""DOID:10534""]"
"""CHEBI:3173""","""RotatE""","[""MESH:D002375"", ""DOID:0060224"", … ""DOID:13949""]","[""HP:0001663""]"


## get answer positions

In [59]:
# Compute, for every (query, method, true answer), the filtered rank of that
# answer within the method's prediction list.
res_df = (
    # one row per (true answer, predicted answer) combination
    res_df.explode("answers")
    .explode("predicted_answers")
    # flag the prediction that equals the true answer
    .with_columns(match=pl.col("answers") == pl.col("predicted_answers"))
    # back to one row per (query, method, true answer), carrying the list of
    # flags and the prediction list
    .group_by(["h", "method", "answers"], maintain_order=True)
    .agg(["match", "predicted_answers"])
    .with_columns(
        pl.when(pl.col("answers").is_in(pl.col("predicted_answers")))
        .then(pl.col("match").list.arg_max() + 1)
        .otherwise(None)
    )  # "match" now holds the 1-based position of the answer, null if not predicted
    .sort(["h", "method", "match"], nulls_last=True)
    # with the answers ordered by position, regroup per prediction list so the
    # true answers of one (query, method) sit together, best-ranked first
    .group_by(["h", "method", "predicted_answers"], maintain_order=True)
    .agg(["answers", "match"])
    # 0, 1, 2, ... = number of true answers ordered ahead of each answer
    .with_columns(rank=pl.int_ranges(0, pl.col("match").list.len()))
    .explode(["answers", "match", "rank"])
    # filtered rank = raw position minus the true answers ahead of it;
    # stays null when the answer was not predicted at all
    .with_columns(
        rank=pl.when(pl.col("match").is_not_null())
        .then(pl.col("match") - pl.col("rank"))
        .otherwise(None)
    )
)

res_df.head(2)

h,method,predicted_answers,answers,match,rank
str,str,list[str],str,u32,i64
"""CHEBI:135735""","""CBR""","[""DOID:3393"", ""DOID:6000"", … ""DOID:0060343""]","""DOID:10763""",3,3
"""CHEBI:135735""","""CBR""","[""DOID:3393"", ""DOID:6000"", … ""DOID:0060343""]","""DOID:10591""",13,12


In [60]:
# one row per (query, method, true answer)
res_df.shape

(19579, 6)

## Export the results with ranks dataframe

In [61]:
# export the ranked results without the raw (unfiltered) position column
res_df.drop("match").write_parquet("results_df.parquet")

## Given the answers, recalculate the MRR/Hits

### Functions to process results

In [62]:
def calculate_hits(
    results: pl.DataFrame,
    test_set: pl.DataFrame,
    hitsk: int = 10,
    method: str = "probCBR",
):
    """
    Return Hits@k for one method on the test set.

    Every (h, answers) pair of `test_set` is looked up among the rows of
    `results` belonging to `method`; the score is the fraction of pairs whose
    rank is at most `hitsk`. Pairs without a rank receive the placeholder rank
    10000000.
    """
    # ranks produced by the requested method only
    method_rows = results.select(["h", "answers", "method", "rank"]).filter(
        pl.col("method") == method
    )
    # left join keeps every test pair; unranked pairs get the placeholder
    scored = test_set.join(method_rows, on=["h", "answers"], how="left")
    scored = scored.with_columns(pl.col("rank").fill_null(10000000))

    ranks = np.array(scored["rank"].to_list())
    n_hits = (ranks <= hitsk).sum()
    return n_hits / ranks.shape[0]

In [63]:
def calculate_mrr(
    results: pl.DataFrame, test_set: pl.DataFrame, method: str = "probCBR"
) -> float:
    """
    Calculate the mean reciprocal rank for a given method on the test set of queries.

    Args:
        results: Frame with at least the columns h, answers, method and rank
            (rank is null when the answer was not predicted).
        test_set: Frame of (h, answers) pairs to score.
        method: Value of the ``method`` column whose ranks are scored.

    Returns:
        The mean of 1 / rank over all test pairs. A pair without a rank for
        `method` contributes exactly 0 instead of the reciprocal of a large
        placeholder rank. An empty `test_set` yields NaN.
    """
    method_ranks = results.filter(pl.col("method") == method).select(
        ["h", "answers", "rank"]
    )
    # left join keeps every test pair; unmatched pairs carry a null rank
    scored = test_set.join(method_ranks, on=["h", "answers"], how="left")

    # dtype=float turns the missing ranks (None) into NaN
    ranks = np.array(scored["rank"].to_list(), dtype=float)
    found = ~np.isnan(ranks)
    reciprocal = np.zeros_like(ranks)
    reciprocal[found] = 1.0 / ranks[found]
    return reciprocal.mean()

### Results
* most algorithms did worse than before with the exception of pCBR which performed drastically better.

In [64]:
# print MRR and Hits@{1,3,10} of every method on the unified test set
for method_name in res_df["method"].unique().to_list():
    print(f"{method_name} Stats")
    print(f"MRR: {calculate_mrr(res_df, test_set, method_name):.4f}")
    for k in (1, 3, 10):
        print(f"Hits@{k}: {calculate_hits(res_df, test_set, k, method_name):.4f}")
    print("\n")

DistMult Stats
MRR: 0.0391
Hits@1: 0.0098
Hits@3: 0.0293
Hits@10: 0.0781


CBR Stats
MRR: 0.0128
Hits@1: 0.0020
Hits@3: 0.0156
Hits@10: 0.0352


TransE Stats
MRR: 0.1601
Hits@1: 0.0645
Hits@3: 0.1816
Hits@10: 0.3770


ComplEx Stats
MRR: 0.0820
Hits@1: 0.0273
Hits@3: 0.0801
Hits@10: 0.2031


pCBR Stats
MRR: 0.2557
Hits@1: 0.1816
Hits@3: 0.3008
Hits@10: 0.3984


Rephetio Stats
MRR: 0.0816
Hits@1: 0.0488
Hits@3: 0.0781
Hits@10: 0.1445


RotatE Stats
MRR: 0.1396
Hits@1: 0.0840
Hits@3: 0.1406
Hits@10: 0.2637


