# Summary

# Imports

In [None]:
import concurrent.futures
import concurrent.futures.process
import importlib
import os
import shlex
import shutil
import subprocess
import sys
import warnings
from functools import partial
from itertools import islice
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import pyarrow as pa
import pyarrow.parquet as pq
import yaml

from kmbio import PDB
from kmtools import structure_tools

In [None]:
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

# Show up to 100 columns when displaying DataFrames. The fully qualified option
# name is used because the abbreviated "max_columns" pattern is ambiguous in
# pandas versions that also define "styler.render.max_columns", where it raises
# an OptionError.
pd.set_option("display.max_columns", 100)

In [None]:
# Put the sibling ``src`` directory at the front of ``sys.path`` (only once) so
# that the project-local ``helper`` module can be imported.
# ``resolve(strict=True)`` raises if that directory does not exist.
SRC_PATH = (Path.cwd() / ".." / "src").resolve(strict=True)

if sys.path.count(SRC_PATH.as_posix()) == 0:
    sys.path.insert(0, SRC_PATH.as_posix())

import helper

# Reload so that edits made to ``helper`` since it was first imported take effect.
importlib.reload(helper)

# Parameters

In [None]:
# Name of this notebook run: the ``CI_JOB_NAME`` environment variable when it is
# set, otherwise the fixed default below.
NOTEBOOK_PATH = Path(os.environ.get("CI_JOB_NAME", "add_adjacency_distances"))
NOTEBOOK_PATH

In [None]:
# Output directory: ``$OUTPUT_DIR`` when set, otherwise a directory named after
# the notebook. The path is made absolute and created if missing.
OUTPUT_PATH = Path(os.environ.get('OUTPUT_DIR', NOTEBOOK_PATH.name))
OUTPUT_PATH = OUTPUT_PATH.resolve()
OUTPUT_PATH.mkdir(exist_ok=True, parents=True)
OUTPUT_PATH

In [None]:
# Display the current working directory for reference in the notebook output.
Path.cwd().expanduser()

In [None]:
# Read the job-array configuration from the environment. Each value stays
# ``None`` when its variable is not set.
TASK_ID = os.getenv("SLURM_ARRAY_TASK_ID")
# ``ORIGINAL_ARRAY_TASK_COUNT`` wins over the SLURM-provided count whenever it
# is set to a non-empty value.
TASK_COUNT = os.getenv("ORIGINAL_ARRAY_TASK_COUNT") or os.getenv("SLURM_ARRAY_TASK_COUNT")
ADJACENCY_MATRIX_PARQUET_PATH = os.getenv("ADJACENCY_MATRIX_PARQUET_PATH")

# Convert the raw strings to their working types.
if TASK_ID is not None:
    TASK_ID = int(TASK_ID)
if TASK_COUNT is not None:
    TASK_COUNT = int(TASK_COUNT)
if ADJACENCY_MATRIX_PARQUET_PATH is not None:
    ADJACENCY_MATRIX_PARQUET_PATH = Path(ADJACENCY_MATRIX_PARQUET_PATH).expanduser()

TASK_ID, TASK_COUNT, ADJACENCY_MATRIX_PARQUET_PATH

In [None]:
# The notebook runs in debug mode whenever the ``CI`` environment variable is
# absent.
DEBUG = os.getenv("CI") is None

if not DEBUG:
    # In CI every job parameter must have been provided through the environment.
    assert TASK_ID is not None
    assert TASK_COUNT is not None
    assert ADJACENCY_MATRIX_PARQUET_PATH is not None
else:
    # Debug runs use a fixed task and the dataset under ``$DATAPKG_OUTPUT_DIR``.
    TASK_ID = 78
    TASK_COUNT = 1029
    ADJACENCY_MATRIX_PARQUET_PATH = Path(
        os.getenv("DATAPKG_OUTPUT_DIR"),
        "adjacency-net-v2",
        "master",
        "training_dataset",
        "adjacency_matrix.parquet",
    )

# Despite the ``.parquet`` suffix, the input is expected to be a directory.
assert ADJACENCY_MATRIX_PARQUET_PATH.is_dir()

TASK_ID, TASK_COUNT, ADJACENCY_MATRIX_PARQUET_PATH

In [None]:
# In debug mode, enable IPython's autoreload extension (mode 2) so that edited
# modules are reloaded automatically before code is executed.
if DEBUG:
    %load_ext autoreload
    %autoreload 2

# `DATAPKG`

In [None]:
# Registry of data-package paths used by this notebook; filled in below.
DATAPKG = {}

In [None]:
# Path to the ``pdb-mmcif`` entry of the ``pdb-ffindex`` data package under
# ``$DATAPKG_OUTPUT_DIR`` (the variable must be set; a KeyError is raised otherwise).
DATAPKG['pdb-ffindex'] = {
    'pdb_mmcif_ffindex': Path(
        os.environ['DATAPKG_OUTPUT_DIR'],
        "pdb-ffindex",
        "master",
        "pdb_mmcif_ffindex",
        "pdb-mmcif",
    ),
}

# Load data

In [None]:
# Collect every regular ``*.parquet`` file below the dataset directory, sorted
# so that all tasks see the same ordering.
files = sorted(
    path
    for path in ADJACENCY_MATRIX_PARQUET_PATH.glob("**/*.parquet")
    if path.is_file()
)

# Quick sanity check: the first two entries and the total number of files.
print(files[:2], len(files), sep="\n")

In [None]:
# Display the distinct grandparent directories of the collected files.
{f.parent.parent for f in files}

In [None]:
# Split the file list into TASK_COUNT contiguous chunks and keep only the chunk
# belonging to this task; TASK_ID is treated as 1-based (task 1 gets the first
# chunk). When everything fits in one chunk the full list is kept.
chunk_size = int(np.ceil(len(files) / TASK_COUNT))
if chunk_size < len(files):
    files = files[slice((TASK_ID - 1) * chunk_size, TASK_ID * chunk_size)]

print(len(files))

In [None]:
# Load the first row group of the first file as a sample DataFrame, indexed by
# the stored ``__index_level_0__`` column.
df = pq.ParquetFile(files[0]).read_row_group(0, use_pandas_metadata=True)
df = df.to_pandas(integer_object_nulls=True)
df = df.set_index("__index_level_0__")

In [None]:
# Preview the first two rows of the sample DataFrame.
df.head(2)

# Run pipeline

In [None]:
def worker(data):
    """Compute adjacency indices and distances for one row, capturing errors.

    Args:
        data: Row attributes; converted with ``helper.to_namedtuple`` before
            being passed to ``helper.get_adjacency_with_distances``.

    Returns:
        dict: On success, the keys ``residue_idx_1``, ``residue_idx_2`` and
        ``distances`` hold the three values returned by the helper (stored
        as-is, without conversion) and ``error`` is ``None``. On failure the
        dict contains only ``error``, formatted as
        ``"<exception type>: <message>"``.

    Note:
        Reads the module-level ``STRUCTURE_URL_PREFIX``, which must be defined
        before this function is called.
    """
    row = helper.to_namedtuple(data)
    try:
        idx_1, idx_2, distances = helper.get_adjacency_with_distances(
            row, max_cutoff=12, min_cutoff=None, structure_url_prefix=STRUCTURE_URL_PREFIX
        )
    except Exception as exc:
        # Failures are reported per row instead of aborting the whole batch.
        return {"error": f"{type(exc)}: {exc}"}
    return {
        "residue_idx_1": idx_1,
        "residue_idx_2": idx_2,
        "distances": distances,
        "error": None,
    }

## Run worker for single row

In [None]:
# Use the first row of the sample DataFrame (at most three rows are materialized).
row = [*islice(df.itertuples(), 3)][0]

In [None]:
# URL prefix handed to the helper as ``structure_url_prefix``; built from the
# ``pdb_mmcif_ffindex`` path registered in ``DATAPKG``.
STRUCTURE_URL_PREFIX = "ff://{}?".format(DATAPKG['pdb-ffindex']['pdb_mmcif_ffindex'])
STRUCTURE_URL_PREFIX

In [None]:
# Call the helper directly for the sample row, using the same cutoffs and URL
# prefix that ``worker`` uses.
residue_idx_1, residue_idx_2, distance = helper.get_adjacency_with_distances(
    row, max_cutoff=12, min_cutoff=None, structure_url_prefix=STRUCTURE_URL_PREFIX
)

In [None]:
# Histogram of the sample row's distances over 0-12 (the range matches the
# max_cutoff passed to the helper), in 100 bins.
fg, ax = plt.subplots()
ax.hist(distance, range=(0, 12), bins=100)
# A trailing ``None`` keeps the cell from echoing the return value of ``hist``.
None

In [None]:
# Smoke-test ``worker`` on the sample row, passed as a dict like in the batch run.
worker(row._asdict())

## Run for all rows

In [None]:
def get_new_file(file):
    """Return the output path that corresponds to an input file.

    The output path is identical to ``file`` except that ``"_wdistances"`` is
    appended to the fourth-from-last path component.

    Args:
        file: ``pathlib`` path with at least four components.

    Returns:
        pathlib.Path: The rewritten path.

    Raises:
        IndexError: If ``file`` has fewer than four components.
    """
    parts = file.parts
    return Path(*parts[:-4], parts[-4] + "_wdistances", *parts[-3:])

In [None]:
# For every input file, recompute adjacency indices and distances one row group
# at a time and write the augmented rows to the path given by ``get_new_file``.
for file in files:
    ds = pq.ParquetFile(file)
    # Created lazily, using the schema of the first row group that is written.
    writer = None
    try:
        for row_group in range(ds.num_row_groups):
            df = (
                ds.read_row_group(row_group, use_pandas_metadata=True)
                .to_pandas(integer_object_nulls=True)
                .set_index("__index_level_0__")
            )
            try:
                # One worker process per physical core; rows are dispatched one by one.
                with concurrent.futures.ProcessPoolExecutor(psutil.cpu_count(logical=False)) as pool:
                    results = list(
                        pool.map(
                            worker,
                            (
                                t._asdict()
                                for t in df[helper.GET_ADJACENCY_WITH_DISTANCES_ROW_ATTRIBUTES].itertuples()
                            ),
                            chunksize=1,
                        )
                    )
            except concurrent.futures.process.BrokenProcessPool as e:
                warnings.warn(
                    f"ProcessPool crashed while processing row_group '{row_group}' in file '{file}'. "
                    f"The error is '{type(e)}': {e}."
                )
                # Give up on the rest of this file; the ``finally`` below still
                # finalizes whatever has been written so far.
                break

            # Name the columns explicitly: failed rows only carry "error", so
            # without this the result columns would be missing (KeyError) when
            # every row of a row group fails or the row group is empty.
            results_df = pd.DataFrame(
                results, columns=["residue_idx_1", "residue_idx_2", "distances", "error"]
            )
            df["residue_idx_1_corrected"] = results_df["residue_idx_1"].values
            df["residue_idx_2_corrected"] = results_df["residue_idx_2"].values
            df["distances"] = results_df["distances"].values
            df["error_adding_distances"] = results_df["error"].values
            num_errors = df["error_adding_distances"].notnull().sum()
            if num_errors:
                print(f"Encountered {num_errors} errors when parsing file '{file}'.")

            # Rows for which no result could be computed are not written out.
            df = df.dropna(subset=["residue_idx_1_corrected", "residue_idx_2_corrected", "distances"])
            table = pa.Table.from_pandas(df, preserve_index=True)
            if writer is None:
                new_file = get_new_file(file)
                new_file.parent.mkdir(parents=True, exist_ok=True)
                writer = pq.ParquetWriter(new_file, table.schema, version="2.0", flavor="spark")
            writer.write_table(table)
    finally:
        # Always close the writer, including after a pool crash or an
        # unexpected exception, so the output file is finalized.
        if writer is not None:
            writer.close()

In [None]:
# Debug-only sanity check: compare every written file against its source file,
# one row group at a time.
if DEBUG:
    display(df.head())

    # Make sure that the file we wrote makes sense
    for file in files:
        new_file = get_new_file(file)

        ds = pq.ParquetFile(file)
        ds_new = pq.ParquetFile(new_file)
        assert ds.num_row_groups == ds_new.num_row_groups

        for row_group in range(ds.num_row_groups):
            # Read the row group being checked (not always row group 0), so
            # that every row group of the file is actually verified.
            df = (
                ds.read_row_group(row_group, use_pandas_metadata=True)
                .to_pandas(integer_object_nulls=True)
                .set_index("__index_level_0__")
            )
            # No ``set_index`` here; presumably the index is restored from the
            # pandas metadata written with ``preserve_index=True`` - TODO confirm.
            df_new = (
                ds_new.read_row_group(row_group, use_pandas_metadata=True)
                .to_pandas(integer_object_nulls=True)
            )
            # Columns present in both frames, minus the array-valued ones that
            # cannot be compared with a plain element-wise ``==``.
            shared_columns = [
                c for c in df.columns
                if c in df_new.columns
                and c not in [
                    'a2b', 'b2a', 'residue_idx_1', 'residue_idx_2',
                    'residue_id_1', 'residue_id_2', 'residue_aa_1', 'residue_aa_2',
                    "residue_idx_1_corrected", "residue_idx_2_corrected"]
            ]
            # NOTE(review): these comparisons assume both frames hold the same
            # rows; rows dropped while writing would break that - confirm.
            assert (df[shared_columns] == df_new[shared_columns]).all().all()
            assert all(
                (l1 == l2).all() 
                for l1, l2
                in zip(df["residue_id_1"].values, df_new["residue_id_1"].values)
            )
            # Every row's corrected indices are required to differ from the
            # values stored in the source file.
            assert all(
                (l1.shape != l2.shape or not (l1 == l2).all())
                for l1, l2
                in zip(df["residue_idx_1_corrected"].values, df_new["residue_idx_1_corrected"].values)
            )