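# Summary

Compare, across several trained `adjacency-net-v2` networks, the training and validation AUCs at each network's best saved step together with training-progress counters (job arrays completed, sequences seen), and save a bar plot of each metric.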

# Imports

In [None]:
import importlib
import os
import sys
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import seaborn as sns
import sqlalchemy as sa
from scipy import stats
from sklearn import metrics

In [None]:
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Show up to 100 columns when displaying DataFrames.
# Use the fully-qualified option key: since pandas 1.4 the short pattern
# "max_columns" also matches "styler.render.max_columns" and raises OptionError.
pd.set_option("display.max_columns", 100)

In [None]:
# Locate <cwd>/../src (resolve fails if it does not exist) so that the
# project-local `helper` module can be imported from this notebook.
SRC_PATH = (Path.cwd() / '..' / 'src').resolve(strict=True)

# Prepend it exactly once; earlier sys.path entries take precedence on import.
if SRC_PATH.as_posix() not in sys.path:
    sys.path[:0] = [SRC_PATH.as_posix()]

# Reload so that edits to the helper module are picked up when re-running the cell.
import helper
importlib.reload(helper)

# Parameters

In [None]:
# Name of this notebook; its `.name` is the default output directory used below.
NOTEBOOK_PATH = Path('validation_homology_models_combined')
NOTEBOOK_PATH

In [None]:
# Outputs go to $OUTPUT_DIR when it is set, otherwise to a directory named after
# this notebook (relative to the current working directory); created if missing.
OUTPUT_PATH = Path(os.environ.get('OUTPUT_DIR', NOTEBOOK_PATH.name)).resolve()
OUTPUT_PATH.mkdir(exist_ok=True, parents=True)
OUTPUT_PATH

In [None]:
# Version of the project data to analyse; None when the variable is unset.
PROJECT_VERSION = os.environ.get("PROJECT_VERSION")

In [None]:
# Debug mode is on whenever no "CI" environment variable is present
# (environment values are always strings, so "unset" is the only None case).
DEBUG = os.getenv("CI") is None
DEBUG

In [None]:
# In debug (non-CI) runs fall back to a fixed project version; under CI the
# PROJECT_VERSION environment variable is required. An explicit raise is used
# instead of `assert` so the check survives `python -O` and explains itself.
if DEBUG:
    PROJECT_VERSION = "0.2"
elif PROJECT_VERSION is None:
    raise RuntimeError("PROJECT_VERSION environment variable must be set when running under CI")

PROJECT_VERSION

In [None]:
# if DEBUG:
#     %load_ext autoreload
#     %autoreload 2

# `DATAPKG`

In [None]:
# Comma-separated hashes of the trained networks to compare; each one names a
# directory under $DATAPKG_OUTPUT_DIR/adjacency-net-v2/ (read further below).
NETWORK_NAME = ",".join((
    # test50-cedar
    "4a4320bd49d7b25fe9018c1b40426a45b1642565",
    # test50, test51
    "a195c0e680a6dec151ea19de4735a6577dde399b",
    "654c84ccb1bc0ecd0fa5d16c31ab3bfe21d45c8b",
    # test65
    "a3556373181d42ce0985e8d2146cfd5b0788502e",
    # test74, followed by further unlabelled runs
    "7b4ff1af3ec63a01fa415435420c554be1fecbb0",
    "55374d153b6646f041dde6ee49ab751ef2d833aa",
    "a7c0444c959a656be8ff6acbf88ef36fd02c59fc",
    "8aa30e0188404d429ecdc6357205bc6924fb7759",
    "9b134475368bd81fa1de197f8180ff1c82ce8727",
    "4e2968caa1d0a9cb9fdee0488a3ede2283bce316",
    "b22189e7357853cc5c76c9435b1c0497030761dd",
))

# Dataset

## Construct datasets

### `validation_df`

In [None]:
data = []

for network_name in NETWORK_NAME.split(','):
    # Each network's training statistics live in an SQLite database.
    stats_db_file = (
        Path(os.environ['DATAPKG_OUTPUT_DIR'])
        .joinpath("adjacency-net-v2", network_name, "train_network", "stats.db")
    )
    assert stats_db_file.is_file(), stats_db_file
    engine = sa.create_engine(f"sqlite:///{stats_db_file}")

    try:
        # Sanity check: the database must record exactly one network name, and
        # that name with its first 4 characters dropped must equal the hash.
        network_names = (
            pd.read_sql_query("select distinct network_name from info", engine)
            ['network_name'].values.tolist()
        )
        assert len(network_names) == 1
        assert network_names[0][4:] == network_name

        # Select the best saved checkpoint (highest permute AUC, ties broken by
        # exact AUC) and read every metric from that same row. Re-querying by
        # `step` alone could return a different row if several rows share a step.
        best_step_df = pd.read_sql_query(
            "SELECT step, "
            "`training_pos-auc` AS training_auc_max, "
            "`validation_gan_permute_80_1000-auc` AS validation_permute_auc_max, "
            "`validation_gan_exact_80_1000-auc` AS validation_exact_auc_max "
            "FROM stats "
            "WHERE model_location IS NOT NULL "
            "ORDER BY `validation_gan_permute_80_1000-auc` DESC, `validation_gan_exact_80_1000-auc` DESC "
            "LIMIT 1 ", engine)
        if best_step_df.empty:
            raise ValueError(f"No saved model (model_location IS NOT NULL) in {stats_db_file}")
        # Index explicitly: int() on a 2-D array is deprecated since NumPy 1.25.
        best_step = int(best_step_df["step"].iloc[0])
        training_auc_max = best_step_df["training_auc_max"].iloc[0]
        validation_permute_auc_max = best_step_df["validation_permute_auc_max"].iloc[0]
        validation_exact_auc_max = best_step_df["validation_exact_auc_max"].iloc[0]

        # Progress counters: maxima over all rows of `stats`, fetched in one query.
        progress_df = pd.read_sql_query(
            "select max(`info_id`) as max_info_id, "
            "max(`sequence_number`) as max_sequence_number "
            "from stats ",
            engine)
        max_info_id = progress_df["max_info_id"].values.item()
        max_sequence_number = progress_df["max_sequence_number"].values.item()
    finally:
        # Release pooled SQLite connections before opening the next database.
        engine.dispose()

    data.append((network_name, training_auc_max, validation_permute_auc_max, validation_exact_auc_max, max_info_id, max_sequence_number))

validation_df = pd.DataFrame(
    data,
    columns=["network_name", "training_auc_max", "validation_permute_auc_max", "validation_exact_auc_max", "max_info_id", "max_sequence_number"])
# validation_df = validation_df.sort_values("validation_exact_auc_max", ascending=False)
# Short (7-character) hash prefix used as the x-axis label in the plots.
validation_df['network_slug'] = validation_df['network_name'].str[0:7]

In [None]:
# Display the per-network summary table built above.
validation_df

# Plot

In [None]:
cmap = plt.get_cmap("Set1")

# NOTE(review): unused in this cell — remove if no later cell reads it.
feature_names = {}

# Y-axis label per column; any column not listed here is an AUC.
ylabels = {
    "max_info_id": "Number of job arrays completed",
    "max_sequence_number": "Number of sequences seen",
}
# Loop-invariant layout: one bar per network, figure widened with the network count.
x = np.arange(len(validation_df))
figsize = (2 + 0.6 * len(validation_df), 4)

for column in ["training_auc_max", "validation_permute_auc_max", "validation_exact_auc_max", "max_info_id", "max_sequence_number"]:
    values = validation_df[column]
    with plt.rc_context(rc={'figure.figsize': figsize, 'font.size': 13}):
        fig, ax = plt.subplots()
        ax.bar(x, values, color=cmap(2))
        ax.set_xticks(x)
        ax.set_xticklabels(validation_df['network_slug'], rotation=45)
        if ((values > 0) & (values <= 1)).all():
            # Zoom in on the high-AUC range, but lower the bottom limit (to the
            # nearest 0.1 below the minimum) so bars under 0.7 are not hidden.
            ax.set_ylim(min(0.7, np.floor(values.min() * 10) / 10), 1.01)
        ax.set_ylabel(ylabels.get(column, "AUC"))
        ax.set_title(column)
        fig.tight_layout()
        fig.savefig(OUTPUT_PATH.joinpath(f"{column}.png"), dpi=300, bbox_inches="tight")
        fig.savefig(OUTPUT_PATH.joinpath(f"{column}.pdf"), bbox_inches="tight")
        plt.show()
        # Free the figure so repeated runs do not accumulate open figures.
        plt.close(fig)