# Summary

# Imports

In [None]:
import importlib
import os
import sys
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import seaborn as sns
from scipy import stats
from sklearn import metrics

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Show up to 100 columns when displaying DataFrames. The fully qualified option
# name is used because the bare "max_columns" pattern can match more than one
# registered option in newer pandas releases, which raises an OptionError.
pd.set_option("display.max_columns", 100)

In [None]:
# Put the sibling ``src`` directory at the front of the import path (once),
# then import ``helper`` and reload it so that edits are picked up.
# ``strict=True`` makes this fail immediately if the directory does not exist.
SRC_PATH = (Path.cwd() / '..' / 'src').resolve(strict=True)

if not any(entry == SRC_PATH.as_posix() for entry in sys.path):
    sys.path.insert(0, SRC_PATH.as_posix())

import helper
importlib.reload(helper)

# Parameters

In [None]:
# Path-like name for this notebook; its ``.name`` is the default output
# directory when OUTPUT_DIR is not set.
NOTEBOOK_PATH = Path('validation_homology_models_combined')
NOTEBOOK_PATH

In [None]:
# Outputs go to $OUTPUT_DIR when it is set, otherwise to a directory named
# after this notebook (relative to the working directory). The directory is
# created if it does not exist yet.
OUTPUT_PATH = Path(os.environ.get('OUTPUT_DIR', NOTEBOOK_PATH.name))
OUTPUT_PATH = OUTPUT_PATH.resolve()
OUTPUT_PATH.mkdir(exist_ok=True, parents=True)
OUTPUT_PATH

In [None]:
# Project version read from the environment; None when the variable is unset.
PROJECT_VERSION = os.getenv("PROJECT_VERSION")

In [None]:
# Debug mode is on whenever no "CI" environment variable is present.
DEBUG = "CI" not in os.environ
DEBUG

In [None]:
# Debug runs always use a fixed development version. Any other run must have
# received the version through the PROJECT_VERSION environment variable.
if DEBUG:
    PROJECT_VERSION = "0.1"
elif PROJECT_VERSION is None:
    # Raised explicitly (rather than via a bare ``assert``) so the check still
    # runs under ``python -O`` and the failure says what is missing.
    raise AssertionError(
        "PROJECT_VERSION environment variable must be set for non-debug runs"
    )

PROJECT_VERSION

In [None]:
# if DEBUG:
#     %load_ext autoreload
#     %autoreload 2

# `DATAPKG`

In [None]:
# Registry of input data locations, keyed by dataset name.
DATAPKG = {}

In [None]:
# Collect, in sorted order, every "*_dataset.parquet" file found one directory
# level below the versioned "validation_homology_models" folder under
# $DATAPKG_OUTPUT_DIR. A missing environment variable raises KeyError.
DATAPKG['validation_homology_models'] = sorted(
    (
        Path(os.environ['DATAPKG_OUTPUT_DIR'])
        / "adjacency-net-v2"
        / f"v{PROJECT_VERSION}"
        / "validation_homology_models"
    ).glob("*/*_dataset.parquet")
)

In [None]:
# Display the list of dataset files that were found.
DATAPKG['validation_homology_models']

# Dataset

## Construct datasets

### `homology_models_dataset`

In [None]:
# Accumulator for the merged per-file tables; stays None until the first file
# has been read.
validation_df = None


def assert_eq(a1, a2):
    """Assert that two columns hold the same values, element by element.

    If the first element of ``a1`` is a NumPy array, both inputs are treated
    as sequences of arrays: each pair of arrays is compared after removing
    the NaN entries from both sides. Otherwise the two inputs are compared
    with a single vectorised ``==``, so a NaN never equals another NaN.

    Args:
        a1: First column; a pandas Series or another indexable sequence.
        a2: Second column, compared against ``a1``.

    Raises:
        AssertionError: If any pair of elements differs, or, for columns of
            arrays, if the two columns have different lengths.
    """
    # Look at the first element *positionally*. ``a1[0]`` on a Series is a
    # label lookup and raises KeyError when the index has no label 0. Empty
    # input has no first element and falls through to the vectorised branch.
    first = None
    if len(a1) > 0:
        first = a1.iloc[0] if hasattr(a1, "iloc") else a1[0]

    if isinstance(first, np.ndarray):
        # ``zip`` stops at the shorter input, so a length mismatch would
        # otherwise go unnoticed.
        if len(a1) != len(a2):
            raise AssertionError(
                f"columns have different lengths: {len(a1)} != {len(a2)}"
            )
        for position, (b1, b2) in enumerate(zip(a1, a2)):
            b1 = b1[~np.isnan(b1)]
            b2 = b2[~np.isnan(b2)]
            if len(b1) != len(b2) or not (b1 == b2).all():
                raise AssertionError(f"array values differ at position {position}")
    elif not (a1 == a2).all():
        raise AssertionError("column values differ")
            

# Read every dataset file and merge the tables on their index into one frame.
# A column that appears in more than one file comes back from the merge with a
# "_dup" suffix; it is checked against the copy already present and dropped.
DUP_SUFFIX = "_dup"

for file in DATAPKG['validation_homology_models']:
    df = pq.read_table(file, use_pandas_metadata=True).to_pandas(integer_object_nulls=True)
    df = df.drop(columns=['error'])
    if validation_df is None:
        validation_df = df
        continue

    validation_df = validation_df.merge(
        df,
        how="outer",
        left_index=True,
        right_index=True,
        validate="1:1",
        suffixes=("", DUP_SUFFIX),
    )
    # Collect the duplicate columns first so the frame is not modified while
    # its columns are being iterated.
    dup_columns = [col for col in validation_df.columns if col.endswith(DUP_SUFFIX)]
    for col in dup_columns:
        col_ref = col[: -len(DUP_SUFFIX)]
        assert_eq(validation_df[col], validation_df[col_ref])
    validation_df = validation_df.drop(columns=dup_columns)

# Fail here, with a clear message, instead of on the first later use of None.
if validation_df is None:
    raise ValueError(
        "DATAPKG['validation_homology_models'] is empty: no dataset files to merge"
    )

In [None]:
# Work on a copy so later changes do not touch the merged table.
homology_models_dataset = validation_df.copy()

In [None]:
# Preview the first two rows.
homology_models_dataset.head(2)

### `homology_models_dataset_filtered`

In [None]:
# Histogram (100 bins) of the "identity_calc" column, used to choose the
# identity cutoff below.
fg, ax = plt.subplots()
homology_models_dataset["identity_calc"].hist(bins=100, ax=ax)

In [None]:
# Keep only rows at or below the identity cutoff, and of those only the
# queries that still have at least MIN_ROWS_PER_QUERY rows.
IDENTITY_CUTOFF = 1.00
MIN_ROWS_PER_QUERY = 10

# Computed once and reused for both the per-query counts and the selection.
below_cutoff = homology_models_dataset["identity_calc"] <= IDENTITY_CUTOFF

# NOTE: despite the "w3plus" in its name, this set uses MIN_ROWS_PER_QUERY as
# the threshold. The name is kept so that other cells keep working.
rows_per_query = homology_models_dataset[below_cutoff].groupby('query_id').size()
query_ids_w3plus = set(rows_per_query[rows_per_query >= MIN_ROWS_PER_QUERY].index)

homology_models_dataset_filtered = (
    homology_models_dataset[
        below_cutoff & homology_models_dataset['query_id'].isin(query_ids_w3plus)
    ]
    .copy()
)

# Row counts before and after filtering.
print(len(homology_models_dataset))
print(len(homology_models_dataset_filtered))

### `homology_models_dataset_final`

In [None]:
# homology_models_dataset_final = homology_models_dataset.copy()

In [None]:
# The rest of the notebook works on (a copy of) the filtered dataset.
homology_models_dataset_final = homology_models_dataset_filtered.copy()

# Plotting

## Prepare data

### Correlations for the entire dataset

In [None]:
# Structure-quality scores that the other columns are correlated against.
target_columns = [
    'dope_score',
    'dope_score_norm',
    'ga341_score',
    'rosetta_score',
]

# Alignment-derived columns used as features.
feature_columns = [
    "identity_calc",
    # "coverage_calc", 

    "identity", 
    "similarity",
    "score",  # "probability", "evalue",
    "sum_probs",
    
#     "query_match_length", 
#     "template_match_length",
]

# Every "*_pdb" / "*_hm" column except the adjacency-index and
# adjacency-fraction ones.
network_columns = [
    c
    for c in homology_models_dataset_final.columns
    if c.endswith(("_pdb", "_hm"))
    and not c.startswith(("adjacency_idx", "frac_aa_wadj"))
]

# Drop the rows that lack any of the network columns and report how many
# rows were lost.
results_df = homology_models_dataset_final.dropna(subset=network_columns).copy()
print(f"Lost {len(homology_models_dataset_final) - len(results_df)} rows with nulls!")

# Flip the sign of these three scores.
# NOTE(review): presumably so that "higher is better" holds for every target
# column; confirm against the score definitions.
for col in ['dope_score', 'dope_score_norm', 'rosetta_score']:
    results_df[col] = -results_df[col]

In [None]:
# Number of network score columns selected above.
len(network_columns)

### Correlations for each sequence independently

In [None]:
# Per-query Spearman correlations: for each query with at least three rows,
# correlate every target column with every feature / network column. Each
# computed pair contributes one (target, feature, correlation, pvalue) record.
data = []
candidate_features = feature_columns + network_columns

for query_id, group in results_df.groupby('query_id'):
    # Sanity checks: all rows of a query share the same ungapped sequence and
    # the same query match length.
    reference_sequence = group['sequence'].iloc[0].replace('-', '')
    assert (group['sequence'].str.replace('-', '') == reference_sequence).all()
    reference_length = group['query_match_length'].iloc[0]
    assert (group['query_match_length'] == reference_length).all()

    if len(group) < 3:
        print(f"Skipping small group for query_id = '{query_id}'")
        continue

    # From here on the group is known to have at least three rows, so only the
    # "column holds a single value" condition still needs checking.
    for y_col in target_columns:
        if len(set(group[y_col])) == 1:
            print(f"skipping y_col '{y_col}'")
            continue
        for x_col in candidate_features:
            if x_col == 'query_match_length':
                continue
            if len(set(group[x_col])) == 1:
                print(f"skipping x_col '{x_col}'")
                continue
            corr, pvalue = stats.spearmanr(group[x_col], group[y_col])
            data.append((y_col, x_col, corr, pvalue))

correlations_df = pd.DataFrame(data, columns=['target', 'feature', 'correlation', 'pvalue'])

In [None]:
# Order the network columns by their mean per-query correlation with
# 'dope_score_norm', highest first.
is_network_vs_dope_norm = (
    (correlations_df['target'] == 'dope_score_norm')
    & correlations_df['feature'].isin(network_columns)
)
mean_correlation = (
    correlations_df[is_network_vs_dope_norm]
    .groupby("feature", as_index=True)['correlation']
    .mean()
)
network_columns_sorted = mean_correlation.sort_values(ascending=False).index.tolist()

# Every network column must have produced at least one correlation value.
assert len(network_columns_sorted) == len(network_columns)

## Make Plots

In [None]:
def plot(df, columns):
    """Draw an annotated heatmap of pairwise Spearman correlations.

    A new figure is created with ``plt.subplots()``. Cell (i, j) of the
    heatmap shows the Spearman correlation between ``columns[i]`` and
    ``columns[j]``, printed with two decimals.

    Args:
        df: Table holding the data; columns are accessed as ``df[name]``.
        columns: Names of the columns to correlate. They label both axes,
            in the order given.

    Returns:
        None. The figure is left as the current matplotlib figure.
    """
    n_columns = len(columns)

    mat = np.zeros((n_columns, n_columns), float)
    for i, c1 in enumerate(columns):
        for j, c2 in enumerate(columns):
            mat[i, j] = stats.spearmanr(df[c1], df[c2])[0]

    _, ax = plt.subplots()
    ax.imshow(mat)

    # Show one tick per column on both axes and label it with the column name.
    ax.set_xticks(np.arange(n_columns))
    ax.set_yticks(np.arange(n_columns))
    ax.set_xticklabels(columns)
    ax.set_yticklabels(columns)

    # Rotate the x tick labels so that long names do not overlap.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Write each correlation value into its cell (x is the column index j,
    # y is the row index i).
    for (i, j), value in np.ndenumerate(mat):
        ax.text(j, i, f"{value:.2f}", ha="center", va="center", color="w")

    ax.set_title("Spearman correlation between alignment, structure, and network scores")

In [None]:
# Heatmap over all target, alignment and network columns for the whole
# dataset. The square figure grows with the number of columns shown.
features = target_columns + feature_columns + network_columns_sorted
dim = 4 + 0.4 * len(features)

with plt.rc_context(rc={'figure.figsize': (dim, dim), 'font.size': 11}):
    plot(results_df, features)

# Save the same figure as a 300 dpi PNG and as a PDF.
plt.tight_layout()
plt.savefig(
    OUTPUT_PATH / "validation_homology_models_corr_all.png",
    bbox_inches="tight",
    dpi=300,
)
plt.savefig(
    OUTPUT_PATH / "validation_homology_models_corr_all.pdf",
    bbox_inches="tight",
)

In [None]:
# One box plot per target column: for every feature, the distribution of the
# per-query Spearman correlations between that feature and the target.
ignore = ['query_match_length']
features = [c for c in feature_columns + network_columns_sorted if c not in ignore]
figsize = (2 + 0.5 * len(features), 6)

for target in target_columns:
    # Select this target's rows once instead of once per feature.
    target_rows = correlations_df[correlations_df['target'] == target]
    corr = [
        target_rows.loc[target_rows['feature'] == feature, 'correlation'].values
        for feature in features
    ]
    with plt.rc_context(rc={'figure.figsize': figsize, 'font.size': 14}):
        plt.boxplot(corr)
        plt.ylim(-0.55, 1.05)
        plt.xticks(range(1, len(features) + 1), features, rotation=45, ha="right", rotation_mode="anchor")
        plt.ylabel("Spearman R")
        plt.title(f"{target} (identity cutoff: {IDENTITY_CUTOFF:.2})")
        plt.tight_layout()
        # ``frameon`` is deliberately not passed: ``savefig`` stopped accepting
        # it in Matplotlib 3.3, and a visible frame is the default anyway.
        plt.savefig(
            OUTPUT_PATH.joinpath(f"{target}_corr_gby_query.png"),
            dpi=300,
            bbox_inches="tight",
            transparent=False,
        )
        plt.savefig(
            OUTPUT_PATH.joinpath(f"{target}_corr_gby_query.pdf"),
            bbox_inches="tight",
            transparent=False,
        )
        plt.show()