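## Summary

Add distances and orientations to the rows of the adjacency-matrix parquet dataset. Each SLURM array task processes its share of the parquet files. For every input row, `helper.get_adjacency_with_distances_and_orientations` is called:

- Rows that succeed are written to a matching `*_wdistances` Arrow file.
- Rows that raise an error are written to a matching `*_wdistances_failed` Arrow file.

----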

## Imports

In [None]:
import concurrent.futures
import concurrent.futures.process
import importlib
import os
import shlex
import shutil
import subprocess
import sys
import warnings
from functools import partial
from itertools import islice
from pathlib import Path

import numpy as np
import yaml
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
import pandas as pd
import psutil
import pyarrow as pa
import pyarrow.parquet as pq
from kmbio import PDB
from kmtools import structure_tools

In [None]:
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

# Show up to 100 columns when displaying DataFrames. The fully-qualified key is required:
# the bare "max_columns" pattern also matches "styler.render.max_columns" (pandas >= 1.4),
# and set_option then raises OptionError for an ambiguous pattern.
pd.set_option("display.max_columns", 100)

In [None]:
# Put the project's ``../src`` directory at the front of ``sys.path`` so that the local
# ``helper`` module can be imported. ``strict=True`` fails early if the directory is missing.
SRC_PATH = (Path.cwd() / ".." / "src").resolve(strict=True)

if not any(entry == SRC_PATH.as_posix() for entry in sys.path):
    sys.path.insert(0, SRC_PATH.as_posix())

import helper

# Re-import ``helper`` so that edits to it are picked up without restarting the kernel.
importlib.reload(helper)

## Parameters

In [None]:
# Name of this notebook, taken from ``$CI_JOB_NAME`` when that variable is set.
NOTEBOOK_PATH = Path(os.environ.get("CI_JOB_NAME", "03_add_adjacency_distances"))
NOTEBOOK_PATH

In [None]:
# Output directory: ``$OUTPUT_DIR`` if set, otherwise a directory named after the notebook,
# resolved to an absolute path and created if it does not exist yet.
OUTPUT_PATH = Path(os.environ.get("OUTPUT_DIR", NOTEBOOK_PATH.name)).resolve()
OUTPUT_PATH.mkdir(exist_ok=True, parents=True)
OUTPUT_PATH

In [None]:
# Display the current working directory (``..`` paths above are resolved relative to it).
Path.cwd().expanduser()

In [None]:
# SLURM array-job coordinates and the input dataset location, read from the environment.
# ``ORIGINAL_ARRAY_TASK_COUNT`` takes precedence over ``SLURM_ARRAY_TASK_COUNT`` when non-empty.
TASK_ID = os.getenv("SLURM_ARRAY_TASK_ID")
TASK_COUNT = os.getenv("ORIGINAL_ARRAY_TASK_COUNT") or os.getenv("SLURM_ARRAY_TASK_COUNT")
ADJACENCY_MATRIX_PARQUET_PATH = os.getenv("ADJACENCY_MATRIX_PARQUET_PATH")

# Convert the raw strings to their proper types; unset variables stay ``None``.
if TASK_ID is not None:
    TASK_ID = int(TASK_ID)
if TASK_COUNT is not None:
    TASK_COUNT = int(TASK_COUNT)
if ADJACENCY_MATRIX_PARQUET_PATH is not None:
    ADJACENCY_MATRIX_PARQUET_PATH = Path(ADJACENCY_MATRIX_PARQUET_PATH).expanduser()

TASK_ID, TASK_COUNT, ADJACENCY_MATRIX_PARQUET_PATH

In [None]:
# Outside CI, fall back to a fixed task and the default dataset location for interactive runs.
DEBUG = "CI" not in os.environ

if DEBUG:
    TASK_ID = 78
    TASK_COUNT = 1029
    # ``os.environ[...]`` raises a KeyError naming the missing variable, instead of the
    # opaque ``TypeError`` that ``Path(None)`` would produce.
    ADJACENCY_MATRIX_PARQUET_PATH = (
        Path(os.environ["DATAPKG_OUTPUT_DIR"])
        .joinpath("adjacency-net-v2", "v0.3", "training_dataset", "adjacency_matrix.parquet")
    )
else:
    # Explicit checks (not ``assert``) so validation is not stripped under ``python -O``.
    if TASK_ID is None or TASK_COUNT is None or ADJACENCY_MATRIX_PARQUET_PATH is None:
        raise RuntimeError(
            "SLURM_ARRAY_TASK_ID, SLURM_ARRAY_TASK_COUNT (or ORIGINAL_ARRAY_TASK_COUNT) and "
            "ADJACENCY_MATRIX_PARQUET_PATH must all be set when running in CI; got "
            f"TASK_ID={TASK_ID!r}, TASK_COUNT={TASK_COUNT!r}, "
            f"ADJACENCY_MATRIX_PARQUET_PATH={ADJACENCY_MATRIX_PARQUET_PATH!r}."
        )

if not ADJACENCY_MATRIX_PARQUET_PATH.is_dir():
    raise NotADirectoryError(
        f"ADJACENCY_MATRIX_PARQUET_PATH is not a directory: '{ADJACENCY_MATRIX_PARQUET_PATH}'."
    )

TASK_ID, TASK_COUNT, ADJACENCY_MATRIX_PARQUET_PATH

In [None]:
# In interactive (non-CI) runs only, enable IPython's autoreload extension in mode 2.
if DEBUG:
    %load_ext autoreload
    %autoreload 2

## `DATAPKG`

In [None]:
# Registry of external data-package locations used by this notebook, keyed by package name.
DATAPKG = dict()

In [None]:
# Location of the ffindex-packed PDB mmCIF database (2018-09-06 snapshot) under
# ``$DATAPKG_OUTPUT_DIR``.
DATAPKG["pdb-ffindex"] = {
    "pdb_mmcif_ffindex": (
        Path(os.environ["DATAPKG_OUTPUT_DIR"]) / "pdb-ffindex" / "2018-09-06" / "pdb-mmcif"
    ),
}

## Load data

In [None]:
# Every parquet part-file below the input directory, sorted so all array tasks see the same order.
files = sorted(
    path
    for path in ADJACENCY_MATRIX_PARQUET_PATH.glob("**/*.parquet")
    if path.is_file()
)

print(files[:2])
print(len(files))

In [None]:
# Display the distinct grandparent directories of the discovered parquet files.
{path.parent.parent for path in files}

In [None]:
# Split the sorted file list into ``TASK_COUNT`` contiguous chunks and keep the chunk that
# belongs to this (1-based) ``TASK_ID``. The slice is applied unconditionally so that each file
# is assigned to exactly one task, even when there are fewer files than tasks. Otherwise every
# task would process, and write the output of, the same single file.
chunk_size = int(np.ceil(len(files) / TASK_COUNT))
files = files[(TASK_ID - 1) * chunk_size : TASK_ID * chunk_size]

print(len(files))

In [None]:
# Load the first row group of the first assigned file for interactive inspection, restoring
# the index stored in the ``__index_level_0__`` column.
df = pq.ParquetFile(files[0]).read_row_group(
    0, use_pandas_metadata=True
).to_pandas(
    integer_object_nulls=True
).set_index(
    "__index_level_0__"
)

In [None]:
# Preview the first two rows of the loaded row group.
df.head(2)

## Run pipeline

### Test on a single row

In [None]:
# Sample row for the single-row tests below: the first of the first three rows of ``df``.
row = tuple(islice(df.itertuples(), 3))[0]

In [None]:
# ``ff://`` URL prefix pointing at the mmCIF ffindex; passed to the helper as
# ``structure_url_prefix``.
STRUCTURE_URL_PREFIX = "ff://" + str(DATAPKG["pdb-ffindex"]["pdb_mmcif_ffindex"]) + "?"
STRUCTURE_URL_PREFIX

In [None]:
# Run the helper on the sample row with ``max_cutoff=12`` and no ``min_cutoff``.
results = helper.get_adjacency_with_distances_and_orientations(
    row,
    structure_url_prefix=STRUCTURE_URL_PREFIX,
    min_cutoff=None,
    max_cutoff=12,
)

In [None]:
# The "distance" column of the helper output (converted with ``.to_pylist()`` below).
ar = results["distance"]

In [None]:
# Histogram of the sample row's distances: 100 bins over [0, 12], matching ``max_cutoff=12``.
fg, ax = plt.subplots()
ax.hist(ar.to_pylist(), bins=100, range=(0, 12))
None  # keep the notebook from echoing the return value of ``ax.hist``

In [None]:
# Preview the helper output as one Arrow RecordBatch (columns in ``results`` insertion order).
pa.RecordBatch.from_arrays([*results.values()], [*results.keys()])

### Test as part of a multiprocessing worker

In [None]:
def worker(data):
    """Process one adjacency-matrix row; meant to be mapped over rows in a process pool.

    ``data`` is one input row as a mapping (the notebook passes ``_asdict()`` of a
    ``DataFrame.itertuples()`` row). It is converted with ``helper.to_namedtuple`` and passed to
    ``helper.get_adjacency_with_distances_and_orientations`` with ``max_cutoff=12`` and no
    ``min_cutoff``.

    Returns a ``(results, failures)`` pair:

    * ``results`` -- the mapping returned by the helper, or ``None`` if the helper raised.
    * ``failures`` -- ``None`` if the helper did not raise; otherwise ``{"error": ...}`` holding
      a one-element array with the exception type and message.

    Whichever of the two is not ``None`` is extended in place with the row's metadata columns
    (each as a one-element ``pa.array``) and its per-residue list columns (each as a
    one-element array of lists).
    """
    row = helper.to_namedtuple(data)

    try:
        results = helper.get_adjacency_with_distances_and_orientations(
            row, max_cutoff=12, min_cutoff=None, structure_url_prefix=STRUCTURE_URL_PREFIX
        )
    except Exception as e:
        # Record the error instead of raising, so the caller can write it to the failures file.
        results = None
        failures = {"error": pa.array([f"{type(e)}: {e}"])}
    else:
        failures = None

    # The output mappings that receive the extra columns, results first.
    targets = [out for out in (results, failures) if out is not None]

    scalar_columns = (
        "Index",
        "uniparc_id",
        "sequence",
        "database",
        "interpro_name",
        "interpro_id",
        "domain_start",
        "domain_end",
        "domain_length",
        "structure_id",
        "model_id",
        "chain_id",
        "pc_identity",
        "alignment_length",
        "mismatches",
        "gap_opens",
        "q_start",
        "q_end",
        "s_start",
        "s_end",
        "evalue_log10",
        "bitscore",
        "qseq",
        "sseq",
    )
    for column in scalar_columns:
        for out in targets:
            out[column] = pa.array([data[column]])

    list_columns = ("a2b", "b2a", "residue_id_1", "residue_id_2", "residue_aa_1", "residue_aa_2")
    for column in list_columns:
        cells = data[column]
        if cells.dtype in (int, float):
            # Numeric columns: cast each entry to ``int``, mapping null/NaN entries to ``None``.
            values = [int(cell) if pd.notnull(cell) else None for cell in cells]
        else:
            values = cells.tolist()
        for out in targets:
            out[column] = pa.array([values])

    return results, failures

In [None]:
# Smoke-test ``worker`` in-process on the sample row, using the same dict input as the pool.
worker(row._asdict())

### Run for all rows

In [None]:
def get_new_file(file, failed=False):
    """Return the output path for the input parquet part-file ``file``.

    The fourth-from-last path component (the dataset directory) gets the suffix
    ``_wdistances``, or ``_wdistances_failed`` when ``failed`` is true. All other components are
    kept unchanged. Raises ``IndexError`` if ``file`` has fewer than four components.
    """
    suffix = "_wdistances" + ("_failed" if failed else "")
    parts = list(file.parts)
    parts[-4] += suffix
    return Path(*parts)

In [None]:
def _to_batch(columns):
    """Build an Arrow RecordBatch from an ordered ``{column_name: pa.Array}`` mapping."""
    return pa.RecordBatch.from_arrays(list(columns.values()), list(columns.keys()))


# Process every assigned parquet file, one row group at a time. Each input row is written as
# one batch: to the ``_wdistances`` file if ``worker`` succeeded, otherwise to the
# ``_wdistances_failed`` file (paths from ``get_new_file``).
n_workers = psutil.cpu_count(logical=False)
n_rows_processed = 0
for file in tqdm(files):
    ds = pq.ParquetFile(file)

    new_file = get_new_file(file)
    new_file.parent.mkdir(parents=True, exist_ok=True)
    new_file_failed = get_new_file(file, failed=True)
    new_file_failed.parent.mkdir(parents=True, exist_ok=True)

    # Writers are opened lazily on the first batch, so an output file is only created if at
    # least one row ends up in it.
    writer = None
    writer_failed = None
    try:
        for row_group in tqdm(range(ds.num_row_groups), leave=False):
            df = (
                ds.read_row_group(row_group, use_pandas_metadata=True)
                .to_pandas(integer_object_nulls=True)
                .set_index("__index_level_0__")
            )
            if df.empty:
                # Nothing to process (and zero worker outputs could not be split into pairs).
                continue

            try:
                with concurrent.futures.ProcessPoolExecutor(n_workers) as pool:
                    futures = pool.map(worker, (t._asdict() for t in df.itertuples()), chunksize=1)
                    outputs = list(tqdm(futures, leave=False, total=len(df)))
            except concurrent.futures.process.BrokenProcessPool as e:
                warnings.warn(
                    f"ProcessPool crashed while processing row_group '{row_group}' in file '{file}'. "
                    f"The error is '{type(e)}': {e}."
                )
                break

            # ``worker`` returns ``(results, failures)`` pairs; keep the non-``None`` halves.
            num_failures = sum(1 for result, _ in outputs if result is None)
            results = [result for result, _ in outputs if result is not None]
            failures = [failure for _, failure in outputs if failure is not None]
            assert len(failures) == num_failures
            if num_failures:
                print(f"Encountered {num_failures} errors when parsing file '{file}'.")
            n_rows_processed += len(df)

            # The first batch written to a file defines that file's schema.
            if results and writer is None:
                writer = pa.RecordBatchFileWriter(new_file, _to_batch(results[0]).schema)
            for result in results:
                writer.write_batch(_to_batch(result))

            if failures and writer_failed is None:
                writer_failed = pa.RecordBatchFileWriter(new_file_failed, _to_batch(failures[0]).schema)
            for failure in failures:
                writer_failed.write_batch(_to_batch(failure))
    finally:
        # Close the writers even if processing raised (e.g. a schema mismatch in write_batch),
        # so the output files are not left open.
        if writer is not None:
            writer.close()
        if writer_failed is not None:
            writer_failed.close()

### Check that every input row was written out

In [None]:
# Count the batches in the success file of the last processed input file. That file is only
# created once some row succeeds, so it may not exist. Opening it then presumably raises
# ``pa.ArrowIOError``, in which case zero successful batches are assumed.
try:
    reader = pa.RecordBatchFileReader(new_file)
except pa.ArrowIOError:
    num_successful_batches = 0
else:
    num_successful_batches = reader.num_record_batches

In [None]:
# Same as above for the failures file of the last processed input file. It only exists if
# some row failed; when it cannot be opened, zero failed batches are assumed.
try:
    reader_failed = pa.RecordBatchFileReader(new_file_failed)
except pa.ArrowIOError:
    num_failed_batches = 0
else:
    num_failed_batches = reader_failed.num_record_batches

In [None]:
# Every input row is written as exactly one batch, to either the success or the failure file,
# so for the last processed file the two batch counts must add up to that file's total row
# count. ``len(df)`` would only cover its last row group.
assert (num_successful_batches + num_failed_batches) == ds.metadata.num_rows, (
    f"Expected {ds.metadata.num_rows} batches for '{file}', found "
    f"{num_successful_batches} successful + {num_failed_batches} failed."
)