# Summary

# Imports

In [None]:
import pyarrow

In [None]:
import concurrent.futures
import itertools
import multiprocessing
import os
import os.path as op
import pickle
import subprocess
import tempfile
from functools import partial
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import seaborn as sns
import sqlalchemy as sa
from scipy import stats

from kmtools import py_tools, sequence_tools

In [None]:
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Show up to 100 columns when displaying DataFrames.
# The fully-qualified key is required: the bare "max_columns" pattern is
# ambiguous on newer pandas (it also matches "styler.render.max_columns")
# and raises OptionError.
pd.set_option("display.max_columns", 100)

# Parameters

In [None]:
# Relative path identifying this notebook; its `.name` is the default
# output directory name when OUTPUT_DIR is not set.
NOTEBOOK_PATH = Path('validation_training_stats')
NOTEBOOK_PATH

In [None]:
# Absolute directory that receives the generated figures: $OUTPUT_DIR when
# set, otherwise a directory named after the notebook. Created if missing.
OUTPUT_PATH = Path(os.environ.get('OUTPUT_DIR', NOTEBOOK_PATH.name)).resolve()
os.makedirs(OUTPUT_PATH, exist_ok=True)
OUTPUT_PATH

In [None]:
# Short hash of the commit currently checked out in the working directory.
# The exit status is not checked, so GIT_REV is an empty string when git
# prints nothing to stdout (e.g. outside a repository).
proc = subprocess.run(
    "git rev-parse --short HEAD".split(),
    stdout=subprocess.PIPE,
)
GIT_REV = proc.stdout.decode().strip()
GIT_REV

In [None]:
# Job-array coordinates and network identifier, all read from the
# environment. Each value is None when its variable is not set.
TASK_ID = os.environ.get("SLURM_ARRAY_TASK_ID")
if TASK_ID is not None:
    TASK_ID = int(TASK_ID)

# ORIGINAL_ARRAY_TASK_COUNT takes precedence when set and non-empty.
TASK_COUNT = os.environ.get("ORIGINAL_ARRAY_TASK_COUNT")
if not TASK_COUNT:
    TASK_COUNT = os.environ.get("SLURM_ARRAY_TASK_COUNT")
if TASK_COUNT is not None:
    TASK_COUNT = int(TASK_COUNT)

NETWORK_NAME = os.environ.get("CI_COMMIT_SHA")

TASK_ID, TASK_COUNT, NETWORK_NAME

In [None]:
# Debug mode is on whenever the `CI` environment variable is absent.
DEBUG = os.environ.get("CI") is None
DEBUG

In [None]:
if not DEBUG:
    # Outside debug mode the network name must already have been provided.
    assert NETWORK_NAME is not None
else:
    # Debug runs always use this fixed network, overriding any earlier value.
    NETWORK_NAME = "6bbf5b792c30570b8ab1a4c1b3426cdc6ad84446"

NETWORK_NAME

In [None]:
# if DEBUG:
#     %load_ext autoreload
#     %autoreload 2

# `DATAPKG`

In [None]:
# Maps a data-package name to the directory holding that package's output.
DATAPKG = {}

In [None]:
# `master` directory of the uniparc-domain-wstructure package under
# $DATAPKG_OUTPUT_DIR (KeyError if that variable is not set).
DATAPKG['uniparc-domain-wstructure'] = (
    Path(os.environ['DATAPKG_OUTPUT_DIR']) / "uniparc-domain-wstructure" / "master"
)

In [None]:
# `master` directory of the adjacency-net-v2 package under
# $DATAPKG_OUTPUT_DIR (KeyError if that variable is not set).
# Note the key uses underscores while the directory name uses hyphens.
DATAPKG['adjacency_net_v2'] = (
    Path(os.environ['DATAPKG_OUTPUT_DIR']) / "adjacency-net-v2" / "master"
)

In [None]:
# `master` directory of the hhsuite-wstructure package under
# $DATAPKG_OUTPUT_DIR (KeyError if that variable is not set).
DATAPKG['hhsuite-wstructure'] = (
    Path(os.environ['DATAPKG_OUTPUT_DIR']) / "hhsuite-wstructure" / "master"
)

# Training statistics

In [None]:
# IPython magic: execute the `trained_networks.ipynb` notebook.
# NOTE(review): presumably this is what defines `TRAINED_NETWORKS`, which is
# used below but not defined anywhere in this notebook — confirm.
%run trained_networks.ipynb

## Load data

In [None]:
# Open the SQLite statistics database registered for the selected network.
stats_db_path = TRAINED_NETWORKS[NETWORK_NAME]['stats_db']
engine = sa.create_engine(f"sqlite:///{stats_db_path}")

In [None]:
# List the tables available in the statistics database.
# `Engine.table_names()` was deprecated in SQLAlchemy 1.4 and removed in 2.0;
# the inspector API works on both old and new releases.
sa.inspect(engine).get_table_names()

In [None]:
# Load the whole `info` table into a DataFrame.
info_df = pd.read_sql_table(table_name="info", con=engine)

In [None]:
# Load the whole `stats` table into a DataFrame.
stats_df = pd.read_sql_table(table_name="stats", con=engine)

## Extract data

In [None]:
# Preview the first two rows and report the total row count.
display(stats_df.iloc[:2])
print(stats_df.shape[0])

In [None]:
# Deserialize the pickled `preds`, `targets` and `losses` columns into new
# `*_list` columns holding the decoded Python objects.
# NOTE(review): `pickle.loads` can execute arbitrary code on malicious input;
# only run this against a statistics database from a trusted source.
stats_df['preds_list'] = stats_df['preds'].apply(pickle.loads)
stats_df['targets_list'] = stats_df['targets'].apply(pickle.loads)
stats_df['losses_list'] = stats_df['losses'].apply(pickle.loads)

In [None]:
def split_preds_pos_neg(stats_df):
    """Compute per-row mean predictions for positive and negative targets.

    Args:
        stats_df: DataFrame with ``preds_list`` and ``targets_list`` columns.
            Each cell holds a sequence of predictions and a sequence of
            targets of the same length; a truthy target marks a positive
            example.

    Returns:
        A tuple ``(pos_mean_list, neg_mean_list)`` of lists with one entry
        per row of ``stats_df``. A row without any positive (or negative)
        example yields NaN for the corresponding mean.
    """
    pos_mean_list = []
    neg_mean_list = []
    for pred, target in zip(stats_df['preds_list'], stats_df['targets_list']):
        pred = np.asarray(pred)
        # The `np.bool` alias was removed in NumPy 1.24; the builtin `bool`
        # selects the same dtype on every NumPy version.
        is_positive = np.asarray(target, dtype=bool)
        pos_mean_list.append(pred[is_positive].mean())
        neg_mean_list.append(pred[~is_positive].mean())
    return pos_mean_list, neg_mean_list

# Compute both mean-prediction columns unless both are already present.
if not ("pos_preds-mean" in stats_df and "neg_preds-mean" in stats_df):
    (
        stats_df['pos_preds-mean'],
        stats_df['neg_preds-mean'],
    ) = split_preds_pos_neg(stats_df)

## Plot statistics

In [None]:
def plot_stats(x_col, y_col):
    """Plot one column of the global ``stats_df`` against another.

    Draws a line on the current matplotlib axes, labelled with ``y_col``,
    using only the rows where both columns are non-null. The x axis is
    labelled with ``x_col`` and uses scientific tick notation.

    Prints "Data not available" and draws nothing when either column is
    missing from ``stats_df``.
    """
    available = stats_df.columns
    if not (x_col in available and y_col in available):
        print("Data not available")
        return

    points = stats_df[[x_col, y_col]].dropna()
    xs = points[x_col]
    ys = points[y_col]

    plt.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
    plt.plot(xs, ys, '-', label=y_col)
    plt.xlabel(x_col)
    plt.legend()

In [None]:
# Plot the `time_between_checkpoints` column against `sequence_number`.
plot_stats(x_col="sequence_number", y_col="time_between_checkpoints")
plt.ylabel("Time (s)")
# A trailing None keeps the cell from echoing the last call's return value.
None

In [None]:
# Side-by-side AUC curves: training on the left, validation on the right.
# The figure is saved to OUTPUT_PATH as both PNG and PDF.
with plt.rc_context({'figure.figsize': (12, 4), 'font.size': 13}):
    fg, axs = plt.subplots(1, 2)

    # (panel title, stats columns drawn in that panel)
    auc_panels = (
        ("Training", ("training_pos-auc",)),
        (
            "Validation",
            (
                "validation_gan_permute_80_1000-auc",
                "validation_gan_exact_80_1000-auc",
            ),
        ),
    )
    for panel_ax, (panel_title, auc_columns) in zip(axs, auc_panels):
        plt.sca(panel_ax)
        for auc_column in auc_columns:
            plot_stats("sequence_number", auc_column)
        plt.ylabel("AUC")
        plt.ylim(0.4, 1)
        plt.title(panel_title)

    plt.tight_layout()
    plt.savefig(
        OUTPUT_PATH.joinpath(f"{NETWORK_NAME}_training_validation_auc.png"),
        dpi=300,
        bbox_inches="tight",
    )
    plt.savefig(
        OUTPUT_PATH.joinpath(f"{NETWORK_NAME}_training_validation_auc.pdf"),
        bbox_inches="tight",
    )

In [None]:
# Mean predictions for positive and negative training examples on a single
# axis. The figure is saved to OUTPUT_PATH as both PNG and PDF.
with plt.rc_context({'figure.figsize': (6, 4), 'font.size': 13}):
    for mean_column in ("pos_preds-mean", "neg_preds-mean"):
        plot_stats("sequence_number", mean_column)
    plt.ylabel("Average probability")
    plt.title("Training")
    plt.ylim(0, 1)
    plt.tight_layout()
    plt.savefig(
        OUTPUT_PATH.joinpath(f"{NETWORK_NAME}_training_pos_neg_preds.png"),
        dpi=300,
        bbox_inches="tight",
    )
    plt.savefig(
        OUTPUT_PATH.joinpath(f"{NETWORK_NAME}_training_pos_neg_preds.pdf"),
        bbox_inches="tight",
    )