# Summary

# Imports

In [None]:
import importlib
import os
import sys
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import seaborn as sns
import sqlalchemy as sa
from scipy import stats
from sklearn import metrics

In [None]:
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Show up to 100 columns when displaying DataFrames.
# The fully qualified option name is used on purpose: on pandas >= 1.4 the bare
# "max_columns" pattern also matches "styler.render.max_columns", which makes
# `set_option` raise `OptionError` (pattern matched multiple keys).
pd.set_option("display.max_columns", 100)

In [None]:
# Put the sibling ``src`` directory at the front of ``sys.path`` (once) so the
# local ``helper`` module can be imported; ``strict=True`` makes a missing
# directory fail here rather than at import time.
SRC_PATH = (Path.cwd() / '..' / 'src').resolve(strict=True)

_src_dir = SRC_PATH.as_posix()
if _src_dir not in sys.path:
    sys.path.insert(0, _src_dir)

# Reload so that edits to ``helper`` are picked up when this cell is re-run.
import helper
importlib.reload(helper)

# Parameters

In [None]:
# Name identifying this notebook; its `.name` serves as the default output
# directory name when the `OUTPUT_DIR` environment variable is not set.
NOTEBOOK_PATH = Path('validation_homology_models_combined')
NOTEBOOK_PATH

In [None]:
# Output directory: taken from ``OUTPUT_DIR`` when set, otherwise a directory
# named after the notebook, resolved against the current working directory.
# It is created (including parents) if it does not exist yet.
_output_dir = os.environ.get('OUTPUT_DIR', NOTEBOOK_PATH.name)
OUTPUT_PATH = Path(_output_dir).resolve()
OUTPUT_PATH.mkdir(exist_ok=True, parents=True)
OUTPUT_PATH

In [None]:
# Project version as provided by the environment (``None`` when unset).
PROJECT_VERSION = os.environ.get("PROJECT_VERSION")

In [None]:
# Any run without a ``CI`` environment variable is treated as a debug run.
DEBUG = os.environ.get("CI") is None
DEBUG

In [None]:
# Outside of debug runs the version must come from the environment; debug runs
# always use a fixed version instead.
if not DEBUG:
    assert PROJECT_VERSION is not None
else:
    PROJECT_VERSION = "0.1"

PROJECT_VERSION

In [None]:
# if DEBUG:
#     %load_ext autoreload
#     %autoreload 2

# `DATAPKG`

In [None]:
# Registry of input data files, keyed by dataset name.
DATAPKG = {}

In [None]:
# Every per-network ``stats.db`` written by the ``train_network`` step of the
# current project version, in sorted path order.
_train_network_dir = Path(
    os.environ['DATAPKG_OUTPUT_DIR'],
    "adjacency-net-v2",
    f"v{PROJECT_VERSION}",
    "train_network",
)
DATAPKG['validation_training_stats'] = sorted(_train_network_dir.glob("*/stats.db"))

In [None]:
# Display the stats databases that were found.
DATAPKG['validation_training_stats']

# Dataset

## Construct datasets

### `validation_df`

In [None]:
# Build `validation_df`: one row per trained network, holding the AUC values
# recorded at that network's best checkpoint.
data = []

for stats_db_file in DATAPKG['validation_training_stats']:
    engine = sa.create_engine(f"sqlite:///{stats_db_file}")
    try:
        # Each database is expected to describe exactly one network.
        network_names = (
            pd.read_sql_query("select distinct network_name from info", engine)
            ['network_name'].values.tolist()
        )
        assert len(network_names) == 1, (stats_db_file, network_names)
        network_name = network_names[0]

        # Sanity check: the name stored in the database must agree with the
        # name of the directory that contains it.
        network_name_ref = stats_db_file.parent.name
        assert network_name[4:11] == network_name_ref[:7], (network_name, network_name_ref)

        # Select best step: among the steps for which a model was saved, take
        # the one with the highest "permute" validation AUC, breaking ties by
        # the "exact" validation AUC.
        best_step_df = pd.read_sql_query(
            "SELECT step "
            "FROM stats "
            "WHERE model_location IS NOT NULL "
            "ORDER BY `validation_gan_permute_80_1000-auc` DESC, `validation_gan_exact_80_1000-auc` DESC "
            "LIMIT 1 ", engine)
        if best_step_df.empty:
            raise ValueError(f"No saved model checkpoint found in '{stats_db_file}'.")
        # Pull the scalar out explicitly; `int()` on a 2-D array is deprecated
        # in recent NumPy releases.
        best_step = int(best_step_df['step'].iloc[0])

        # Fetch all three AUC values for the best step in a single query.
        # `best_step` is an `int`, so interpolating it into the SQL is safe.
        best_stats_df = pd.read_sql_query(
            "select "
            "`training_pos-auc` as training_auc_max, "
            "`validation_gan_permute_80_1000-auc` as validation_permute_auc_max, "
            "`validation_gan_exact_80_1000-auc` as validation_exact_auc_max "
            "from stats "
            f"where step = {best_step} ",
            engine)
        training_auc_max = best_stats_df["training_auc_max"].tolist()[0]
        validation_permute_auc_max = best_stats_df["validation_permute_auc_max"].tolist()[0]
        validation_exact_auc_max = best_stats_df["validation_exact_auc_max"].tolist()[0]
    finally:
        # Release the pooled SQLite connection(s) before moving on to the next file.
        engine.dispose()

    data.append((network_name, training_auc_max, validation_permute_auc_max, validation_exact_auc_max))

validation_df = pd.DataFrame(
    data,
    columns=["network_name", "training_auc_max", "validation_permute_auc_max", "validation_exact_auc_max"])
# Best-performing networks (by "exact" validation AUC) first.
validation_df = validation_df.sort_values("validation_exact_auc_max", ascending=False)
# Short identifier used as the tick label in the plots.
validation_df['network_slug'] = validation_df['network_name'].str[4:11]

In [None]:
# Display the assembled per-network AUC table.
validation_df

# Plot

In [None]:
# Draw one bar chart per AUC column (one bar per network, in `validation_df`
# order) and save each chart as both PNG and PDF under `OUTPUT_PATH`.
cmap = plt.get_cmap("Set1")

# NOTE(review): not used anywhere in the visible part of the notebook; kept in
# case a later cell relies on it.
feature_names = {}

for column in ["training_auc_max", "validation_permute_auc_max", "validation_exact_auc_max"]:
    # Figure width grows with the number of networks so tick labels stay legible.
    with plt.rc_context(rc={'figure.figsize': (2 + 0.6 * len(validation_df), 4), 'font.size': 13}):
        # Use an explicit figure per column. Relying on the implicit "current
        # figure" makes successive columns draw onto the same axes whenever the
        # active backend does not close figures on `plt.show()`.
        fig, ax = plt.subplots()
        x = np.arange(len(validation_df))
        ax.bar(x, validation_df[column], color=cmap(2))
        ax.set_xticks(x)
        ax.set_xticklabels(validation_df['network_slug'], rotation=45)
        ax.set_ylim(0.8, 1.01)
        ax.set_ylabel("AUC")
        ax.set_title(column)
        fig.tight_layout()
        fig.savefig(OUTPUT_PATH.joinpath(f"{column}.png"), dpi=300, bbox_inches="tight")
        fig.savefig(OUTPUT_PATH.joinpath(f"{column}.pdf"), bbox_inches="tight")
        plt.show()
        # Free the figure; this is a no-op if the backend already closed it.
        plt.close(fig)