# Generate predictions for each KGE model
* The goal of this notebook is to generate predictions for a given set of models
* This notebook will also highlight how to use some of the functions in `score_utils2`
* Finally we will extract the top `k` results

In [None]:
import os
import pandas as pd
import polars as pl

# Switch into the repository's Notebooks directory before importing
# `score_utils2` -- presumably so the local module resolves from the working
# directory (TODO confirm where score_utils2 lives).
# NOTE: the path is relative, so re-running this cell without restarting the
# kernel resolves it against the already-changed working directory.
os.chdir("./Consilience-Drug-Repurposing/Notebooks")
import score_utils2 as scu

## Get predictions for test file

In [None]:
# Move to the sibling `kge` directory; the relative paths used by the cells
# below (codes/run.py, models/..., ./data_output/...) are resolved from here.
# NOTE: relative path -- this cell is only correct when run once, right after
# the working directory was set to the Notebooks folder.
os.chdir("../kge")

### TransE

In [None]:
# Shell escape: run codes/run.py unbuffered (-u) with --do_predict and --do_test,
# with -init pointing at models/TransE_MIND_optimized.
# The trailing `#--cuda` is read by the shell as a comment, so the flag is
# presumably not passed for this model.
!python -u codes/run.py --do_predict --do_test -init models/TransE_MIND_optimized #--cuda

### DistMult

In [None]:
# Shell escape: run codes/run.py unbuffered (-u) with --do_predict and --do_test,
# with -init pointing at models/DistMult_MIND_optimized.
# The trailing `#--cuda` is read by the shell as a comment, so the flag is
# presumably not passed for this model.
!python -u codes/run.py --do_predict --do_test -init models/DistMult_MIND_optimized #--cuda

### ComplEx

In [None]:
# Shell escape: run codes/run.py unbuffered (-u) with --do_predict and --do_test,
# with -init pointing at models/ComplEx_MIND_optimized.
# The trailing `#--cuda` is read by the shell as a comment, so the flag is
# presumably not passed for this model.
!python -u codes/run.py --do_predict --do_test -init models/ComplEx_MIND_optimized #--cuda

### RotatE

In [None]:
# Shell escape: run codes/run.py unbuffered (-u) with --do_predict and --do_test,
# with -init pointing at models/RotatE_MIND_optimized.
# Unlike the other three model cells, --cuda is actually passed here.
!python -u codes/run.py --do_predict --do_test -init models/RotatE_MIND_optimized --cuda

## Process the outputs
### Create score input as tail-batching
* Function that removes all 'head-batch' entities if choosing 'tail-batch'
* OR removes all 'tail-batch' entities if choosing 'head-batch'

In [None]:
# One ProcessOutput per model, all sharing the same data directory and the
# "tail-batch" mode; only the scores file differs between them.
# Order of the paths matches the order of the names being unpacked into.
tran_raw, dist_raw, comp_raw, rota_raw = (
    scu.ProcessOutput(
        data_dir="../data/MIND/",
        scores_outfile=scores_path,
        mode="tail-batch",
    )
    for scores_path in (
        "./models/TransE_MIND_optimized/test_scores.tsv",
        "./models/DistMult_MIND_optimized/test_scores.tsv",
        "./models/ComplEx_MIND_optimized/test_scores.tsv",
        "./models/RotatE_MIND_optimized/test_scores.tsv",
    )
)

### Extract actual names from the dataframe

In [None]:
# Only the TransE output is queried here. As the last bare expression of the
# cell, whatever get_true_targets() returns is what the notebook displays.
tran_raw.get_true_targets()

### Format the raw scores to embedded values
* The initial scores dataframe holds raw values that can be negative or positive.
* The method uses the torch function `argsort()` to rank the scores from high to low: the highest value becomes 1, the next highest 2, and so on up to the n-th highest.
* The operation is in-place.

In [None]:
# Apply the raw-score formatting step to every model's output, in the same
# order as before (TransE, DistMult, ComplEx, RotatE).
for processor in (tran_raw, dist_raw, comp_raw, rota_raw):
    processor.format_raw_scores_to_df()

### Get actual names 
* The conversion of embeddings to values is done in-place.
* Note that the method has a parameter `direction`, which can be "from" or "to". The default is "to", meaning (value TO embedding).

In [None]:
# Run the translation with direction="from" (rather than the default "to")
# on every model's output, in the same order as before.
for processor in (tran_raw, dist_raw, comp_raw, rota_raw):
    processor.translate_embeddings(direction="from")

### Generate the top _n_ filtered results

In [None]:
# Same cut-off for every model so the exported result sets are comparable.
top_k = 1000
tran_df, dist_df, comp_df, rota_df = (
    processor.filter_predictions(top=top_k)
    for processor in (tran_raw, dist_raw, comp_raw, rota_raw)
)

### Export the top 1000 for each algo

In [None]:
# Export each model's filtered predictions to its own parquet file.
output_dir = "./data_output"

# Do not rely on write_parquet to create a missing directory: make sure the
# target exists first (no-op if it is already there), so the export cannot
# fail on a fresh checkout after the expensive prediction steps.
os.makedirs(output_dir, exist_ok=True)

# File-name suffix -> predictions frame; file names are unchanged from before.
predictions_by_model = {
    "transe": tran_df,
    "distmult": dist_df,
    "complex": comp_df,
    "rotate": rota_df,
}

for model_tag, predictions in predictions_by_model.items():
    # Drop rows that repeat the same ("h", "filt_preds") pair before writing.
    predictions.unique(["h", "filt_preds"]).write_parquet(
        f"{output_dir}/test_scores_{model_tag}.parquet"
    )