# Summary

# Imports

In [None]:
import importlib
import os
import sys
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import seaborn as sns
from scipy import stats
from sklearn import metrics

In [None]:
%matplotlib inline

In [None]:
# Show up to 100 columns when a DataFrame is displayed. The fully qualified
# option name is used because the bare "max_columns" pattern also matches
# "styler.render.max_columns" in newer pandas, which makes set_option raise
# OptionError for an ambiguous key.
pd.set_option("display.max_columns", 100)

In [None]:
# Locate the `src` directory that sits next to the current working directory.
# With strict=True, resolve() raises if that directory does not exist.
SRC_PATH = Path.cwd().joinpath('..', 'src').resolve(strict=True)

# Put `src` at the front of the import path (only once) so that modules in it
# can be imported below.
if SRC_PATH.as_posix() not in sys.path:
    sys.path.insert(0, SRC_PATH.as_posix())

# Import `helper` and reload it, so re-running this cell re-executes the
# module's current source.
import helper
importlib.reload(helper)

# Parameters

In [None]:
# Name of this notebook; its `.name` is the default output directory name.
NOTEBOOK_PATH = Path('validation_remote_homology_detection_combined')
NOTEBOOK_PATH

In [None]:
# Output directory: $OUTPUT_DIR when set, otherwise a directory named after the
# notebook. The path is made absolute and the directory is created if missing.
OUTPUT_PATH = Path(os.getenv('OUTPUT_DIR', NOTEBOOK_PATH.name)).resolve()
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
OUTPUT_PATH

In [None]:
# Project version from the environment; None when the variable is not set.
PROJECT_VERSION = os.getenv("PROJECT_VERSION")

In [None]:
# Any run without a `CI` environment variable is treated as a debug run.
DEBUG = "CI" not in os.environ
DEBUG

In [None]:
# Non-debug runs must have received a project version from the environment.
# Debug runs always use the fixed version "0.1", replacing any value read
# from the environment.
if not DEBUG:
    assert PROJECT_VERSION is not None
else:
    PROJECT_VERSION = "0.1"

PROJECT_VERSION

In [None]:
# if DEBUG:
#     %load_ext autoreload
#     %autoreload 2

# `DATAPKG`

In [None]:
# Registry of input data: maps a data package name to the files it provides.
DATAPKG = {}

In [None]:
# Directory holding this project version's remote homology validation output.
_validation_dir = (
    Path(os.environ['DATAPKG_OUTPUT_DIR'])
    / "adjacency-net-v2"
    / f"v{PROJECT_VERSION}"
    / "validation_remote_homology_detection"
)

# Every `*_dataset.parquet` file exactly one directory below it, sorted by path.
DATAPKG['validation_remote_homology_detection'] = sorted(
    _validation_dir.glob("*/*_dataset.parquet")
)

In [None]:
# Display the list of dataset files that were found.
DATAPKG['validation_remote_homology_detection']

# Dataset

## Construct datasets

### `remote_homology_dataset`

In [None]:
# Accumulator for the merged dataset; stays None until the first file is read.
validation_df = None


def assert_eq(a1, a2):
    """Assert that two column-like sequences hold equal values.

    Parameters
    ----------
    a1, a2
        Sequences of equal length, for example two ``pandas.Series`` taken
        from the same frame. If the first element of ``a1`` is a
        ``numpy.ndarray``, the inputs are compared pairwise as arrays after
        dropping the NaN entries from each array. Otherwise they are compared
        with ``==``, under which a NaN never equals another NaN.

    Raises
    ------
    AssertionError
        If the lengths differ or any pair of values differs.
    """
    # `zip` below would silently ignore trailing elements of the longer input.
    assert len(a1) == len(a2), f"length mismatch: {len(a1)} != {len(a2)}"
    if len(a1) == 0:
        # Nothing to compare, and there is no first element to inspect.
        return

    # Inspect the first element by position: `series[0]` is a label lookup and
    # raises KeyError when the index has no label 0.
    first = a1.iloc[0] if hasattr(a1, "iloc") else a1[0]
    if isinstance(first, np.ndarray):
        for b1, b2 in zip(a1, a2):
            b1 = b1[~np.isnan(b1)]
            b2 = b2[~np.isnan(b2)]
            assert len(b1) == len(b2)
            assert (b1 == b2).all()
    else:
        assert (a1 == a2).all()
            

# Fold every dataset file into `validation_df`, joining on the index. Columns
# present on both sides come back with a "_dup" suffix; each one is checked
# against its unsuffixed counterpart and then removed.
for file in DATAPKG['validation_remote_homology_detection']:
    df = pq.read_table(file, use_pandas_metadata=True).to_pandas(integer_object_nulls=True)
    if validation_df is None:
        # The first file becomes the starting frame.
        validation_df = df
        continue

    validation_df = validation_df.merge(
        df,
        how="outer",
        left_index=True,
        right_index=True,
        validate="1:1",
        suffixes=("", "_dup"),
    )
    for col in [name for name in validation_df.columns if name.endswith("_dup")]:
        col_ref = col[:-4]
        assert_eq(validation_df[col], validation_df[col_ref])
        del validation_df[col]

In [None]:
# Work on a copy so that later changes do not touch `validation_df`.
remote_homology_dataset = validation_df.copy()

In [None]:
# Preview the first two rows.
remote_homology_dataset.head(2)

### `remote_homology_dataset_filtered`

In [None]:
# Histogram of the `adjacency_coverage_1` column.
remote_homology_dataset['adjacency_coverage_1'].hist()

In [None]:
# Histogram of the `adjacency_coverage_2` column.
remote_homology_dataset['adjacency_coverage_2'].hist()

In [None]:
# Queries that have at least three rows in the dataset.
query_ids_w3plus = {
    query_id
    for query_id, rows in remote_homology_dataset.groupby('query_id')
    if rows.shape[0] >= 3
}

# Keep only the rows belonging to those queries, as an independent copy.
has_3plus_rows = remote_homology_dataset['query_id'].isin(query_ids_w3plus)
remote_homology_dataset_filtered = remote_homology_dataset[has_3plus_rows].copy()

# Row counts before and after filtering.
print(len(remote_homology_dataset))
print(len(remote_homology_dataset_filtered))

# Plotting

## Prepare data for plots

In [None]:
# Reference feature columns evaluated alongside the network features.
features_ref = [
    "identity_calc",
    "coverage_calc", 
    "identity", "similarity",
    "score",  "probability", # "evalue",
    "sum_probs",
]

# Every column whose name is exactly 40 characters long is treated as a
# network feature -- presumably these names are 40-character hashes that
# identify individual networks (TODO confirm against the data package).
features_network = [c for c in remote_homology_dataset.columns if len(c) == 40]

In [None]:
# Split both SCOP identifiers into their dot-separated fields once per row,
# instead of re-splitting them inside a row-wise `apply` for every level.
scop_domain_parts = [
    (domain.split('.'), canonical.split('.'))
    for domain, canonical in zip(
        remote_homology_dataset_filtered['scop_domain'],
        remote_homology_dataset_filtered['scop_domain_canonical'],
    )
]

# For each SCOP level, flag the rows whose `scop_domain` and
# `scop_domain_canonical` agree on their first `scop_level` fields. Comparing
# the field lists is equivalent to comparing the re-joined prefixes, because
# the fields themselves never contain a dot.
for scop_level in [1, 2, 3, 4]:
    # Assigning a positional boolean array yields a bool column even when the
    # frame has no rows, and needs no index alignment.
    remote_homology_dataset_filtered[f'scop_domain_matches_l{scop_level}'] = np.array(
        [
            parts[:scop_level] == canonical_parts[:scop_level]
            for parts, canonical_parts in scop_domain_parts
        ],
        dtype=bool,
    )

### `DATA_ALL`

In [None]:
# For each SCOP level: (number of distinct queries, number of rows, table of
# per-feature Spearman correlation / p-value / ROC AUC), computed over all
# rows pooled together.
DATA_ALL = {}
for scop_level in [1, 2, 3, 4]:
    df = remote_homology_dataset_filtered.copy()
    labels = df[f'scop_domain_matches_l{scop_level}']

    data = []
    for feature in features_ref + features_network:
        corr, pvalue = stats.spearmanr(df[feature], labels)
        auc = metrics.roc_auc_score(labels, df[feature])
        data.append((feature, corr, pvalue, auc))
    out_df = pd.DataFrame(data, columns=['feature', 'correlation', 'pvalue', 'auc'])

    # Reorder only the network-feature rows by descending AUC; rows for the
    # reference features keep their positions.
    is_network = out_df['feature'].isin(features_network)
    out_df[is_network] = out_df[is_network].sort_values("auc", ascending=False).values

    num_queries = len(df['query_id'].drop_duplicates())
    DATA_ALL[scop_level] = (num_queries, len(df), out_df)

### `DATA_GBQ`

In [None]:
# For each SCOP level: (number of query groups used, number of rows in those
# groups, table of per-feature statistics averaged over the query groups).
DATA_GBQ = {}

# Per-level counts of query groups left out of the per-query statistics.
num_skips_small = {1: 0, 2: 0, 3: 0, 4: 0}  # fewer than three rows
num_skips_eq = {1: 0, 2: 0, 3: 0, 4: 0}  # every row matches
num_skips_neq = {1: 0, 2: 0, 3: 0, 4: 0}  # no row matches

for scop_level in [1, 2, 3, 4]:
    df = remote_homology_dataset_filtered.copy()
    label_column = f'scop_domain_matches_l{scop_level}'
    features = features_ref + features_network
    data = {f: {'corrs': [], 'pvalues': [], 'aucs': []} for f in features}
    count_groups = 0
    count_rows = 0
    for query_id, group in df.groupby('query_id'):
        labels = group[label_column]  # boolean match flags for this query
        if len(group) < 3:
            num_skips_small[scop_level] += 1
            continue
        # ROC AUC needs both classes present, so groups whose labels are all
        # True or all False are skipped.
        if labels.all():
            num_skips_eq[scop_level] += 1
            continue
        if not labels.any():
            num_skips_neq[scop_level] += 1
            continue
        for feature in features:
            # A feature that is constant within the group cannot rank its rows.
            if len(group[feature].drop_duplicates()) == 1:
                print(f"Skipping '{feature}'")
                continue
            corr, pvalue = stats.spearmanr(group[feature], labels)
            auc = metrics.roc_auc_score(labels, group[feature])
            data[feature]['corrs'].append(corr)
            data[feature]['pvalues'].append(pvalue)
            data[feature]['aucs'].append(auc)
        count_groups += 1
        count_rows += len(group)
    # The three lists of a feature always grow together. When a feature was
    # skipped in every group, report NaN explicitly instead of taking the mean
    # of an empty list, which gives NaN too but emits a RuntimeWarning.
    data_list = [
        (k, np.mean(v['corrs']), np.mean(v['pvalues']), np.mean(v['aucs']))
        if v['aucs']
        else (k, np.nan, np.nan, np.nan)
        for k, v in data.items()
    ]
    out_df = pd.DataFrame(data_list, columns=['feature', 'correlation', 'pvalue', 'auc'])
    # Reorder only the network-feature rows by descending AUC; rows for the
    # reference features keep their positions.
    is_network = out_df['feature'].isin(features_network)
    out_df[is_network] = out_df[is_network].sort_values("auc", ascending=False).values
    DATA_GBQ[scop_level] = count_groups, count_rows, out_df

print(num_skips_small)
print(num_skips_eq)
print(num_skips_neq)

## Make plots

In [None]:
# 'Set1' colormap resampled to 10 entries. `plt.get_cmap` is used because
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9.
cmap = plt.get_cmap('Set1', 10)

In [None]:
# Name of each SCOP hierarchy level, keyed by level number; used in plot titles.
scop_levels = {
    1: "class",
    2: "fold",
    3: "superfamily",
    4: "family",
}

In [None]:
# Optional display names for features on the x axis. A feature without an
# entry is labelled with its own name (shortened to 7 characters when the
# name is 40 characters long).
feature_names = {}

In [None]:
# One figure per SCOP level with two AUC bar charts side by side: statistics
# over all rows pooled together (left) and averaged per query (right).
for scop_level in DATA_ALL:
    fg, axs = plt.subplots(1, 2, figsize=(2 + 0.7 * len(features_ref + features_network), 4))

    panels = (
        (axs[0], DATA_ALL[scop_level], f"Predicting SCOP {scop_levels[scop_level]} - combined\n"),
        (axs[1], DATA_GBQ[scop_level], f"Predicting SCOP {scop_levels[scop_level]} - per protein\n"),
    )
    for ax, (num1, num2, df), title_head in panels:
        plt.sca(ax)
        x = np.arange(len(df))
        # Network features and reference features get different bar colours.
        c = [cmap(2) if f in features_network else cmap(1) for f in df['feature']]
        plt.bar(x, df['auc'].abs(), color=c)
        tick_labels = [
            feature_names.get(f, f[:7] if len(f) == 40 else f)
            for f in df['feature'].values
        ]
        plt.xticks(x, tick_labels, rotation=45)
        plt.ylabel("AUC")
        plt.title(title_head + f"(N = {num2}, M = {num1})")
        # Dashed reference line at AUC = 0.5.
        plt.hlines(0.5, -0.75, len(df) - 0.25, linestyle='--')
        plt.ylim(0.4, 1)
        plt.xlim(-0.75, len(df) - 0.25)

    plt.tight_layout()
    for suffix, save_options in (("png", {"dpi": 300}), ("pdf", {})):
        plt.savefig(
            OUTPUT_PATH.joinpath(f"remote_homology_detection_sl{scop_level}.{suffix}"),
            bbox_inches="tight",
            **save_options,
        )