# Summary

# Imports

In [None]:
import concurrent.futures
import itertools
import importlib
import multiprocessing
import os
import os.path as op
import pickle
import subprocess
import sys
import tempfile
from collections import Counter
from functools import partial
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import seaborn as sns
import sqlalchemy as sa
from scipy import stats

from kmtools import py_tools, sequence_tools

In [None]:
# IPython magic (not plain Python): render matplotlib figures inline in the cell output.
%matplotlib inline

In [None]:
# Turn every NumPy floating-point warning category into a FloatingPointError so
# numerical problems surface immediately; the statistics loops further down
# catch this exception to skip degenerate features.
np.seterr(divide='raise', over='raise', under='raise', invalid='raise')

In [None]:
# Widen DataFrame display in the notebook. Fully qualified option names are
# required: the short patterns "max_columns" / "max_rows" also match
# "styler.render.max_columns" / "styler.render.max_rows" in pandas >= 1.4 and
# then raise ``OptionError: Pattern matched multiple keys``.
pd.set_option("display.max_columns", 100)
pd.set_option("display.max_rows", 200)

In [None]:
# Make ../src importable (resolve(strict=True) fails fast if it does not exist),
# then reload the project-local ``helper`` module so edits to it take effect
# without restarting the kernel.
SRC_PATH = (Path.cwd() / '..' / 'src').resolve(strict=True)

if SRC_PATH.as_posix() not in sys.path:
    sys.path.insert(0, SRC_PATH.as_posix())

import helper
importlib.reload(helper)

# Parameters

In [None]:
# Name of this notebook; its final component is the default output directory name.
NOTEBOOK_PATH = Path("validation_decoy_discrimination")
NOTEBOOK_PATH

In [None]:
# Output directory: $OUTPUT_DIR when set, otherwise a directory named after the
# notebook; resolved to an absolute path and created (with parents) if missing.
OUTPUT_PATH = Path(os.environ.get('OUTPUT_DIR', NOTEBOOK_PATH.name)).resolve()
OUTPUT_PATH.mkdir(exist_ok=True, parents=True)
OUTPUT_PATH

In [None]:
# Short hash of the checked-out commit. The exit status is not checked, so a
# failing ``git`` leaves GIT_REV as whatever it printed on stdout (likely "").
proc = subprocess.run(
    ["git", "rev-parse", "--short", "HEAD"],
    stdout=subprocess.PIPE,
)
GIT_REV = proc.stdout.decode("utf-8").strip()
GIT_REV

In [None]:
# SLURM array-job coordinates (None when the variables are unset) and the
# comma-separated network name(s) supplied through CI_COMMIT_SHA.
TASK_ID = os.environ.get("SLURM_ARRAY_TASK_ID")
TASK_COUNT = (
    os.environ.get("ORIGINAL_ARRAY_TASK_COUNT")
    or os.environ.get("SLURM_ARRAY_TASK_COUNT")
)
NETWORK_NAME = os.environ.get("CI_COMMIT_SHA")

if TASK_ID is not None:
    TASK_ID = int(TASK_ID)
if TASK_COUNT is not None:
    TASK_COUNT = int(TASK_COUNT)

TASK_ID, TASK_COUNT, NETWORK_NAME

In [None]:
# Debug mode whenever the notebook is not running under CI.
DEBUG = os.environ.get("CI") is None
DEBUG

In [None]:
# In CI the network name(s) must come from CI_COMMIT_SHA; for local debugging
# fall back to a fixed, comma-separated pair of networks.
if not DEBUG:
    assert NETWORK_NAME is not None
else:
    NETWORK_NAME = "dcn_old_0,6bbf5b792c30570b8ab1a4c1b3426cdc6ad84446"

NETWORK_NAME

In [None]:
# if DEBUG:
#     %load_ext autoreload
#     %autoreload 2

# `DATAPKG`

Can be skipped when `DEBUG = True`.

In [None]:
# Registry of input data-package locations, keyed by package name.
DATAPKG = dict()

In [None]:
# Both decoy-discrimination datasets live under the same data-package layout
# ($DATAPKG_OUTPUT_DIR/adjacency-net-v2/v0.1/<dataset>/<hash>) and differ only in
# the dataset directory name. $DATAPKG_OUTPUT_DIR must be set (KeyError otherwise).
DATAPKG['adjacency_net_v2'] = {
    dataset_name: Path(os.environ['DATAPKG_OUTPUT_DIR']).joinpath(
        "adjacency-net-v2",
        "v0.1",
        dataset_name,
        "f9e8ffb64f4d1335b075e656bc83d3dd1824d513",
    )
    for dataset_name in (
        'decoy_discrimination_dataset',
        'decoy_discrimination_dataset_rosetta',
    )
}

# Load data

Can be skipped when `DEBUG = True`.

## Read Parquet files

### `decoy_discrimination_dataset_quick`

In [None]:
# Shards of the quick dataset, in sorted path order; exactly 200 are expected.
parquet_files = list(
    DATAPKG['adjacency_net_v2']['decoy_discrimination_dataset'].glob('*.parquet')
)
parquet_files.sort()

assert len(parquet_files) == 200
parquet_files[:2]

In [None]:
# Read each shard (loading any stored pandas index columns, keeping integer
# columns with nulls as objects) and stack them, preserving the stored index.
dfs = [
    pq.read_table(path, use_pandas_metadata=True).to_pandas(integer_object_nulls=True)
    for path in parquet_files
]
decoy_discrimination_dataset_quick = pd.concat(dfs, ignore_index=False)

In [None]:
# Every row of the quick dataset must carry a distinct unique_id.
assert len(decoy_discrimination_dataset_quick['unique_id']) == len(set(decoy_discrimination_dataset_quick['unique_id']))

In [None]:
decoy_discrimination_dataset_quick.head(n=2)  # preview the first two rows

### `decoy_discrimination_dataset_rosetta`

In [None]:
# Shards of the Rosetta-scored dataset, in sorted path order; exactly 200 are expected.
parquet_files = list(
    DATAPKG['adjacency_net_v2']['decoy_discrimination_dataset_rosetta'].glob('*.parquet')
)
parquet_files.sort()

assert len(parquet_files) == 200
parquet_files[:2]

In [None]:
# Read each shard (loading any stored pandas index columns, keeping integer
# columns with nulls as objects) and stack them, preserving the stored index.
dfs = [
    pq.read_table(path, use_pandas_metadata=True).to_pandas(integer_object_nulls=True)
    for path in parquet_files
]
decoy_discrimination_dataset_rosetta = pd.concat(dfs, ignore_index=False)

In [None]:
# Every row of the Rosetta dataset must carry a distinct unique_id.
assert len(decoy_discrimination_dataset_rosetta['unique_id']) == len(set(decoy_discrimination_dataset_rosetta['unique_id']))

In [None]:
decoy_discrimination_dataset_rosetta.head(n=2)  # preview the first two rows

## Combine into a single dataset

### `decoy_discrimination_dataset`

In [None]:
# Attach the Rosetta score columns (those prefixed "rosetta_") to the quick
# dataset, matched one-to-one on unique_id.
# TODO: Figure out why we need an inner join
decoy_discrimination_dataset = decoy_discrimination_dataset_quick.merge(
    decoy_discrimination_dataset_rosetta[
        ['unique_id']
        + [
            column
            for column in decoy_discrimination_dataset_rosetta.columns
            if column.startswith("rosetta_")
        ]
    ],
    how='inner',
    on='unique_id',
    validate="1:1",
)

# The inner join must not have dropped any row of the quick dataset.
assert len(decoy_discrimination_dataset_quick) == len(set(decoy_discrimination_dataset['unique_id']))

In [None]:
# Show one merged row and the total row count.
display(decoy_discrimination_dataset.head(n=1))
print(len(decoy_discrimination_dataset.index))

# Run network

Can be skipped when `DEBUG = True`.

In [None]:
# IPython magic: execute trained_networks.ipynb in this namespace. Presumably
# this defines TRAINED_NETWORKS (used in the next cell) -- verify.
%run trained_networks.ipynb

## Predictions using PDB adjacencies

In [None]:
# Score every decoy with each requested network (last-listed first), storing the
# predictions in a new column named after the network. The helper is given the
# residue index columns under the adjacency_idx_* names.
for network_name in NETWORK_NAME.split(',')[::-1]:
    decoy_discrimination_dataset[network_name] = helper.predict_with_network(
        decoy_discrimination_dataset.rename(
            columns={'residue_idx_1': 'adjacency_idx_1', 'residue_idx_2': 'adjacency_idx_2'}
        ),
        network_state=TRAINED_NETWORKS[network_name]['network_state'],
        network_info=TRAINED_NETWORKS[network_name]['network_info'],
    )
    # Every decoy must receive a prediction.
    assert decoy_discrimination_dataset[network_name].notnull().all(), network_name

## Save to cache

In [None]:
# Cache the scored dataset (index included) as one Parquet file written in
# row groups of 2000 rows; it is read back row group by row group below.
table = pa.Table.from_pandas(decoy_discrimination_dataset, preserve_index=True)
pq.write_table(
    table,
    OUTPUT_PATH / "dataset.parquet",
    flavor='spark',
    version='2.0',
    row_group_size=2000,
)

In [None]:
# Outside debug mode, keep an in-memory copy so the reloaded cache can be
# compared against it below.
if not DEBUG:
    decoy_discrimination_dataset_bak = decoy_discrimination_dataset.copy(deep=True)

# Analysis

## Load data

In [None]:
# Reload the cached dataset one row group at a time and stitch the pieces back
# together, keeping the stored index.
parquet_file = pq.ParquetFile(OUTPUT_PATH / "dataset.parquet")

dfs = [
    parquet_file
    .read_row_group(group_idx, use_pandas_metadata=True)
    .to_pandas(integer_object_nulls=True)
    for group_idx in range(parquet_file.num_row_groups)
]

decoy_discrimination_dataset = pd.concat(dfs, ignore_index=False)

In [None]:
# If a backup was taken in this session, the Parquet round trip must have kept
# every row and the exact index order.
if 'decoy_discrimination_dataset_bak' in globals():
    assert len(decoy_discrimination_dataset_bak) == len(decoy_discrimination_dataset)
    assert (decoy_discrimination_dataset_bak.index == decoy_discrimination_dataset.index).all()

In [None]:
decoy_discrimination_dataset.head(n=2)  # preview the first two rows

In [None]:
# The benchmark must cover exactly 200 distinct structures.
assert 200 == len(set(decoy_discrimination_dataset.structure_id))

## `combined_stats`

In [None]:
# Pool the decoys of all structures. For every numeric feature except ``rmsd``,
# record:
#   - the Spearman correlation (and p-value) with ``rmsd``;
#   - the mean z-score of the ``native.pdb`` rows;
#   - how many native rows rank within the top 100/200/400/800 when all decoys
#     are sorted by the feature.
#
# Rows are collected as explicit dicts rather than harvested from ``globals()``.
# The old approach also had to delete any notebook variable that shared a
# column name.
#
# ``stats.zscore`` is inside the ``try`` because, under ``np.seterr(all='raise')``,
# a constant feature makes it raise FloatingPointError too. Previously that
# error escaped and aborted the whole loop.
columns = ["feature", "correlation", "pvalue", "zscore", "top100", "top200", "top400", "top800"]
top_ks = (100, 200, 400, 800)
data = []
skipped_features = set()

for feature in decoy_discrimination_dataset.select_dtypes(include=['number']).columns:
    if feature == "rmsd":
        continue
    df = decoy_discrimination_dataset[["rmsd", feature, "decoy_name"]].dropna().copy()

    # Only features defined for every decoy are compared on the pooled set.
    if len(df) != len(decoy_discrimination_dataset):
        skipped_features.add(feature)
        continue

    try:
        correlation, pvalue = stats.spearmanr(df["rmsd"], df[feature])
        df["zscore"] = stats.zscore(df[feature])
    except FloatingPointError:
        skipped_features.add(feature)
        continue

    is_native = df["decoy_name"] == "native.pdb"

    # Shuffle first so ties in ``feature`` are not ordered by original position.
    # Then sort ascending when the feature correlates positively with RMSD,
    # i.e. so that low-RMSD decoys come first either way.
    df_sorted = (
        df.sample(frac=1)
        .reset_index(drop=True)
        .sort_values(feature, ascending=correlation > 0)
    )
    df_sorted["rank"] = range(1, len(df_sorted) + 1)
    native_ranks = df_sorted.loc[df_sorted["decoy_name"] == "native.pdb", "rank"]

    row = {
        "feature": feature,
        "correlation": correlation,
        "pvalue": pvalue,
        "zscore": df.loc[is_native, "zscore"].mean(),
    }
    row.update({f"top{k}": (native_ranks <= k).sum() for k in top_ks})
    data.append(row)

combined_stats = pd.DataFrame(data, columns=columns)

In [None]:
# Magnitude columns, so strongly negative and strongly positive features rank alike.
for stat_name in ["correlation", "zscore"]:
    combined_stats[f"{stat_name}_abs"] = combined_stats[stat_name].abs()

In [None]:
combined_stats.sort_values(by="zscore_abs", ascending=False).head(n=5)  # top features by |z-score|

In [None]:
# Persist the pooled statistics (index included) as Parquet.
table = pa.Table.from_pandas(combined_stats, preserve_index=True)
pq.write_table(
    table,
    OUTPUT_PATH / "combined_stats.parquet",
    flavor='spark',
    version='2.0',
)

## `per_structure_stats`

In [None]:
# For each structure and each numeric feature except ``rmsd``, record:
#   - the Spearman correlation (and p-value) with ``rmsd`` among that
#     structure's decoys;
#   - the z-score of the native decoy;
#   - the native decoy's rank when the decoys are sorted by the feature.
#
# Features that trigger a FloatingPointError are counted per structure in
# ``skipped_features``. Under ``np.seterr(all='raise')`` this includes
# ``stats.zscore`` on a constant feature; that call is now inside the ``try``
# instead of crashing the loop.
#
# Rows are explicit dicts instead of being harvested from ``globals()``.
# The numeric feature list is computed once, not once per structure.
columns = ['query_id', 'feature', 'correlation', 'pvalue', 'rank', 'zscore']
data = []
skipped_features = Counter()

numeric_features = [
    c for c in decoy_discrimination_dataset.select_dtypes(include=['number']).columns
    if c != "rmsd"
]

for query_id, group in decoy_discrimination_dataset.groupby('structure_id'):
    for feature in numeric_features:
        df = group.copy()
        try:
            correlation, pvalue = stats.spearmanr(df["rmsd"], df[feature])
            df["zscore"] = stats.zscore(df[feature])
        except FloatingPointError:
            skipped_features[feature] += 1
            continue

        # Takes the first "native.pdb" row; raises IndexError if a structure has none.
        zscore = df.loc[df["decoy_name"] == "native.pdb", "zscore"].iloc[0]

        # Shuffle so ties are not ordered by original position, then sort so
        # low-RMSD decoys come first (ascending for positive correlation).
        df_sorted = (
            group.sample(frac=1)
            .reset_index(drop=True)
            .sort_values(feature, ascending=correlation > 0)
        )
        df_sorted["rank"] = range(1, len(df_sorted) + 1)
        rank = df_sorted.loc[df_sorted["decoy_name"] == "native.pdb", "rank"].iloc[0]

        data.append({
            'query_id': query_id,
            'feature': feature,
            'correlation': correlation,
            'pvalue': pvalue,
            'rank': rank,
            'zscore': zscore,
        })

per_structure_stats = pd.DataFrame(data, columns=columns)

In [None]:
per_structure_stats.head(n=5)  # preview the first five rows

In [None]:
# Persist the per-structure statistics (index included) as Parquet.
table = pa.Table.from_pandas(per_structure_stats, preserve_index=True)
pq.write_table(
    table,
    OUTPUT_PATH / "per_structure_stats.parquet",
    flavor='spark',
    version='2.0',
)

## `per_structure_stats_agg`

In [None]:
def topk(s, k):
    """Return how many entries of ``s`` are less than or equal to ``k``."""
    within_k = s <= k
    return sum(within_k)

def top1(s):
    """Count of ranks equal to 1 (named so the aggregate column is ``top1``)."""
    return topk(s, k=1)

def top5(s):
    """Count of ranks within the top 5."""
    return topk(s, k=5)

def top10(s):
    """Count of ranks within the top 10."""
    return topk(s, k=10)

def top20(s):
    """Count of ranks within the top 20."""
    return topk(s, k=20)


# Per feature: mean correlation / p-value / z-score over structures, plus how
# many structures rank their native decoy within the top 1/5/10/20.
per_structure_stats_agg = per_structure_stats.groupby(["feature"]).agg(
    {
        "correlation": "mean",
        "pvalue": "mean",
        "zscore": "mean",
        "rank": [top1, top5, top10, top20],
    }
).reset_index()

In [None]:
# Flatten the (column, aggregation) column labels produced by ``agg``:
# the "feature" key and the means keep their bare name, while the rank counters
# become "rank_top1", "rank_top5", ...
per_structure_stats_agg.columns = per_structure_stats_agg.columns.values
if isinstance(per_structure_stats_agg.columns[0], tuple):
    per_structure_stats_agg = per_structure_stats_agg.rename(
        columns=lambda label: "_".join(label) if label[1] not in ("", "mean") else label[0]
    )

In [None]:
# Magnitude columns, so strongly negative and strongly positive features rank alike.
for stat_name in ["correlation", "zscore"]:
    per_structure_stats_agg[f"{stat_name}_abs"] = per_structure_stats_agg[stat_name].abs()

In [None]:
per_structure_stats_agg.sort_values(by="zscore_abs", ascending=False).head(n=5)  # top features by |z-score|

In [None]:
# Persist the aggregated per-structure statistics (index included) as Parquet.
table = pa.Table.from_pandas(per_structure_stats_agg, preserve_index=True)
pq.write_table(
    table,
    OUTPUT_PATH / "per_structure_stats_agg.parquet",
    flavor='spark',
    version='2.0',
)

# Plotting

In [None]:
# Column groups used by the plots below: the regression target, the classical
# scoring features, and the network prediction column(s).
network_columns = NETWORK_NAME.split(",")

all_columns = [
    c for c in decoy_discrimination_dataset.columns
    if c in ["rmsd", "dope_score", "normalized_dope_score"]
    or c.startswith(("modeller_", "rosetta_"))
    or c in network_columns
]

target_columns = ['rmsd']

feature_columns = [
    "rosetta_relax_total_score",
    "rosetta_score_total_score",
    "normalized_dope_score",
] + [f"ga341_score_{i}" for i in range(1, 8)]

## `combined_stats`

In [None]:
# Bar charts of the pooled statistics (|correlation|, |z-score|, top-100 natives)
# for the selected classical features and the network(s).
# ``cmap`` is defined here: this cell previously relied on the later
# per-structure plotting cell to define it and raised NameError when the
# notebook was run top to bottom.
cmap = plt.get_cmap("Set1")

features = feature_columns + network_columns
data = combined_stats.set_index("feature").loc[features].reset_index()
# Shorten long names without separators (e.g. bare 40-character commit hashes) to 7 characters.
labels = [c if len(c) < 32 or any(s in c for s in [' ', '_', '-']) else c[:7] for c in features]
colors = [cmap(1)] * len(feature_columns) + [cmap(2)] * len(network_columns)

for feature in ["correlation_abs", "zscore_abs", "top100"]:
    fig, ax = plt.subplots(dpi=100, constrained_layout=True)
    ax.bar("feature", feature, data=data, tick_label=labels, color=colors)
    for tick in ax.get_xticklabels():
        tick.set_rotation(90)
    ax.set_ylabel(feature)
    plt.savefig(OUTPUT_PATH.joinpath(f"combined_stats_{feature}.svg"), dpi=300, bbox_inches="tight")

## `per_structure_stats`

In [None]:
# Per-structure distributions of correlation, z-score and native rank for each
# selected feature/network: one box per feature, overlaid with jittered points.
columns = feature_columns + network_columns

data = per_structure_stats.groupby("feature").agg(tuple).loc[columns].reset_index()

# Negate network correlations and classical-feature z-scores (presumably so both
# groups share one "higher is better" orientation -- confirm).
for flipped_stat, flipped_features in [
    ("correlation", network_columns),
    ("zscore", feature_columns),
]:
    mask = data["feature"].isin(flipped_features)
    data.loc[mask, flipped_stat] = data.loc[mask, flipped_stat].apply(
        lambda row: tuple(-r for r in row)
    )

# Shorten long names without separators (e.g. bare commit hashes) to 7 characters.
labels = [
    c[:7] if len(c) >= 32 and not any(s in c for s in [' ', '_', '-']) else c
    for c in columns
]

cmap = plt.get_cmap("Set1")
colors = [cmap(1)] * len(feature_columns) + [cmap(2)] * len(network_columns)

boxplot_rc = {
    "boxplot.boxprops.linewidth": 1.5,
    'boxplot.whiskerprops.linewidth': 1.5,
    "boxplot.meanprops.linewidth": 1.5,
    "boxplot.medianprops.color": 'k',
}

for feature in ["correlation", "zscore", "rank"]:
    fig, ax = plt.subplots(dpi=100, constrained_layout=True)
    with plt.rc_context(rc=boxplot_rc):
        ax.boxplot(feature, data=data, labels=labels, sym="")
    # Boxes sit at x = 1, 2, ...; scatter each feature's points around its box.
    for position, (points, color) in enumerate(zip(data[feature], colors), start=1):
        jitter = np.random.normal(position, 0.05, len(points))
        ax.scatter(jitter, points, c=[color], alpha=0.3)
    for tick in ax.get_xticklabels():
        tick.set_rotation(90)
    ax.set_ylabel(feature)
    if feature == "rank":
        # Rank 1 is best, so draw it at the top.
        ax.invert_yaxis()
    plt.savefig(OUTPUT_PATH / f"per_structure_stats_{feature}.svg", dpi=300, bbox_inches="tight")

## `per_structure_stats_agg`

In [None]:
# Bar charts of the aggregated per-structure statistics (|correlation|,
# |z-score|, number of structures whose native ranks first) for each selected
# feature/network. Uses the ``cmap`` colormap defined in an earlier cell.
features = feature_columns + network_columns
data = per_structure_stats_agg.set_index("feature").loc[features].reset_index()
# Shorten long names without separators (e.g. bare commit hashes) to 7 characters.
labels = [
    c[:7] if len(c) >= 32 and not any(s in c for s in [' ', '_', '-']) else c
    for c in features
]
colors = [cmap(1)] * len(feature_columns) + [cmap(2)] * len(network_columns)

for feature in ("correlation_abs", "zscore_abs", "rank_top1"):
    fig, ax = plt.subplots(dpi=100, constrained_layout=True)
    ax.bar("feature", feature, data=data, tick_label=labels, color=colors)
    for tick in ax.get_xticklabels():
        tick.set_rotation(90)
    ax.set_ylabel(feature)
    plt.savefig(OUTPUT_PATH / f"per_structure_stats_agg_{feature}.svg", dpi=300, bbox_inches="tight")