# Summary

We need to rerun the code from `add_adjacency_distances.ipynb` for the following domains:

***training_dataset***

~72% of rows finished.

```raw
[('database_id=G3DSA%3A3.40.50.1440', 62807),
 ('database_id=G3DSA%3A3.20.80.10', 43126),
 ('database_id=G3DSA%3A3.30.1050.10', 34935),
 ('database_id=G3DSA%3A3.30.110.60', 11123),
 ('database_id=G3DSA%3A1.10.10.440', 12939),
 ('database_id=G3DSA%3A1.10.10.41', 2895),
 ('database_id=G3DSA%3A3.30.1140.32', 17973),
 ('database_id=G3DSA%3A3.30.1120.40', 3065),
 ('database_id=G3DSA%3A2.60.40.1200', 53),
 ('database_id=G3DSA%3A1.10.10.250', 18310),
 ('database_id=G3DSA%3A2.60.40.1090', 45878),
 ('database_id=G3DSA%3A3.50.7.10', 57672),
 ('database_id=G3DSA%3A3.50.70.10', 6224),
 ('database_id=G3DSA%3A1.10.10.10', 3400746),
 ('database_id=G3DSA%3A3.20.20.70', 1701341),
 ('database_id=G3DSA%3A1.25.40.10', 2105462),
 ('database_id=G3DSA%3A3.20.90.10', 5021),
 ('database_id=G3DSA%3A3.40.640.10', 940256),
 ('database_id=G3DSA%3A1.10.10.190', 11),
 ('database_id=G3DSA%3A1.10.10.410', 28270),
 ('database_id=G3DSA%3A1.10.10.180', 1311),
 ('database_id=G3DSA%3A3.30.1040.10', 11),
 ('database_id=G3DSA%3A2.60.40.1240', 16791),
 ('database_id=G3DSA%3A1.10.10.500', 1171),
 ('database_id=G3DSA%3A3.30.110.40', 21055),
 ('database_id=G3DSA%3A3.50.50.60', 1600839),
 ('database_id=G3DSA%3A1.10.10.390', 428),
 ('database_id=G3DSA%3A2.60.40.10', 2604549),
 ('database_id=G3DSA%3A3.30.1120.30', 4708),
 ('database_id=G3DSA%3A1.10.10.400', 18468),
 ('database_id=G3DSA%3A2.60.40.1120', 184571),
 ('database_id=G3DSA%3A1.10.10.430', 1060)]
```

***validation_dataset***

~82% of rows finished.

```raw
[('database_id=G3DSA%3A2.102.10.10', 2531),
 ('database_id=G3DSA%3A2.20.140.10', 182),
 ('database_id=G3DSA%3A1.25.40.270', 340),
 ('database_id=G3DSA%3A2.10.25.30', 23),
 ('database_id=G3DSA%3A2.30.29.30', 30432),
 ('database_id=G3DSA%3A1.25.40.20', 40925),
 ('database_id=G3DSA%3A2.120.10.90', 1824),
 ('database_id=G3DSA%3A2.30.170.40', 1510),
 ('database_id=G3DSA%3A2.100.10.20', 45),
 ('database_id=G3DSA%3A2.20.90.10', 4),
 ('database_id=G3DSA%3A2.30.130.10', 2377),
 ('database_id=G3DSA%3A1.20.59.10', 964),
 ('database_id=G3DSA%3A2.170.14.10', 59),
 ('database_id=G3DSA%3A2.20.170.10', 10),
 ('database_id=G3DSA%3A2.100.10.30', 349),
 ('database_id=G3DSA%3A2.10.290.10', 180),
 ('database_id=G3DSA%3A1.25.10.30', 970),
 ('database_id=G3DSA%3A2.30.18.10', 189)]
```

***test_dataset***

~95.6% of rows finished.

```raw
[('database_id=G3DSA%3A2.40.290.10', 610),
 ('database_id=G3DSA%3A2.30.30.390', 56),
 ('database_id=G3DSA%3A2.40.200.10', 14),
 ('database_id=G3DSA%3A2.40.230.10', 589),
 ('database_id=G3DSA%3A2.40.230.20', 399),
 ('database_id=G3DSA%3A2.40.270.10', 8268),
 ('database_id=G3DSA%3A2.40.180.10', 5573)]
```

# Imports

In [None]:
import concurrent.futures
import concurrent.futures.process
import importlib
import os
import shlex
import shutil
import subprocess
import sys
import warnings
from functools import partial
from itertools import islice
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import pyarrow as pa
import pyarrow.parquet as pq
import yaml

from kmbio import PDB
from kmtools import structure_tools

In [None]:
%matplotlib inline

# Use the fully-qualified option name. pandas resolves option names by pattern
# matching, and the bare "max_columns" pattern raises OptionError in pandas versions
# that register more than one option ending in "max_columns".
pd.set_option("display.max_columns", 100)

In [None]:
# Make the project's `src` directory (a sibling of the current working directory)
# importable. `resolve(strict=True)` raises if the directory does not exist.
SRC_PATH = (Path.cwd() / '..' / 'src').resolve(strict=True)
src_dir = SRC_PATH.as_posix()

if src_dir not in sys.path:
    sys.path.insert(0, src_dir)

# Import the project-local helper module and reload it so that edits made since the
# last import are picked up when this cell is re-run.
import helper
importlib.reload(helper)

In [None]:
%run spark.ipynb

# Parameters

In [None]:
# Name of this job: taken from the CI_JOB_NAME environment variable when it is set,
# otherwise the default name below.
notebook_name = os.environ.get("CI_JOB_NAME", "add_adjacency_distances_retry")
NOTEBOOK_PATH = Path(notebook_name)
NOTEBOOK_PATH

In [None]:
# Output directory: OUTPUT_DIR from the environment if set, otherwise a directory
# named after the notebook. It is resolved to an absolute path and created
# (including parents) if it does not exist yet.
output_dir = os.getenv('OUTPUT_DIR', NOTEBOOK_PATH.name)
OUTPUT_PATH = Path(output_dir).resolve()
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
OUTPUT_PATH

In [None]:
Path.cwd().expanduser()

In [None]:
# The notebook runs in debug mode whenever the "CI" environment variable is absent.
DEBUG = os.environ.get("CI") is None

if not DEBUG:
    # Under CI the location is passed in through the environment.
    ADJACENCY_MATRIX_PARQUET_PATH = Path(os.getenv("ADJACENCY_MATRIX_PARQUET_PATH")).expanduser()
else:
    # In debug mode, use the training dataset below DATAPKG_OUTPUT_DIR.
    ADJACENCY_MATRIX_PARQUET_PATH = Path(
        os.getenv("DATAPKG_OUTPUT_DIR"),
        "adjacency-net-v2",
        "master",
        "training_dataset",
        "adjacency_matrix.parquet",
    )

# The Parquet dataset is expected to be a directory, not a single file.
assert ADJACENCY_MATRIX_PARQUET_PATH.is_dir()

DEBUG, ADJACENCY_MATRIX_PARQUET_PATH

In [None]:
if DEBUG:
    %load_ext autoreload
    %autoreload 2

In [None]:
# Domain partitions to rerun, grouped by dataset. The (name, number) pairs are the
# same ones listed in the summary at the top of the notebook; only the partition
# name of each pair is kept.
FAILED_DOMAIN_NAMES = {
    'training_dataset': [
        name
        for name, _ in [
            ('database_id=G3DSA%3A3.40.50.1440', 62807),
            ('database_id=G3DSA%3A3.20.80.10', 43126),
            ('database_id=G3DSA%3A3.30.1050.10', 34935),
            ('database_id=G3DSA%3A3.30.110.60', 11123),
            ('database_id=G3DSA%3A1.10.10.440', 12939),
            ('database_id=G3DSA%3A1.10.10.41', 2895),
            ('database_id=G3DSA%3A3.30.1140.32', 17973),
            ('database_id=G3DSA%3A3.30.1120.40', 3065),
            ('database_id=G3DSA%3A2.60.40.1200', 53),
            ('database_id=G3DSA%3A1.10.10.250', 18310),
            ('database_id=G3DSA%3A2.60.40.1090', 45878),
            ('database_id=G3DSA%3A3.50.7.10', 57672),
            ('database_id=G3DSA%3A3.50.70.10', 6224),
            ('database_id=G3DSA%3A1.10.10.10', 3400746),
            ('database_id=G3DSA%3A3.20.20.70', 1701341),
            ('database_id=G3DSA%3A1.25.40.10', 2105462),
            ('database_id=G3DSA%3A3.20.90.10', 5021),
            ('database_id=G3DSA%3A3.40.640.10', 940256),
            ('database_id=G3DSA%3A1.10.10.190', 11),
            ('database_id=G3DSA%3A1.10.10.410', 28270),
            ('database_id=G3DSA%3A1.10.10.180', 1311),
            ('database_id=G3DSA%3A3.30.1040.10', 11),
            ('database_id=G3DSA%3A2.60.40.1240', 16791),
            ('database_id=G3DSA%3A1.10.10.500', 1171),
            ('database_id=G3DSA%3A3.30.110.40', 21055),
            ('database_id=G3DSA%3A3.50.50.60', 1600839),
            ('database_id=G3DSA%3A1.10.10.390', 428),
            ('database_id=G3DSA%3A2.60.40.10', 2604549),
            ('database_id=G3DSA%3A3.30.1120.30', 4708),
            ('database_id=G3DSA%3A1.10.10.400', 18468),
            ('database_id=G3DSA%3A2.60.40.1120', 184571),
            ('database_id=G3DSA%3A1.10.10.430', 1060),
        ]
    ],
    'validation_dataset': [
        name
        for name, _ in [
            ('database_id=G3DSA%3A2.102.10.10', 2531),
            ('database_id=G3DSA%3A2.20.140.10', 182),
            ('database_id=G3DSA%3A1.25.40.270', 340),
            ('database_id=G3DSA%3A2.10.25.30', 23),
            ('database_id=G3DSA%3A2.30.29.30', 30432),
            ('database_id=G3DSA%3A1.25.40.20', 40925),
            ('database_id=G3DSA%3A2.120.10.90', 1824),
            ('database_id=G3DSA%3A2.30.170.40', 1510),
            ('database_id=G3DSA%3A2.100.10.20', 45),
            ('database_id=G3DSA%3A2.20.90.10', 4),
            ('database_id=G3DSA%3A2.30.130.10', 2377),
            ('database_id=G3DSA%3A1.20.59.10', 964),
            ('database_id=G3DSA%3A2.170.14.10', 59),
            ('database_id=G3DSA%3A2.20.170.10', 10),
            ('database_id=G3DSA%3A2.100.10.30', 349),
            ('database_id=G3DSA%3A2.10.290.10', 180),
            ('database_id=G3DSA%3A1.25.10.30', 970),
            ('database_id=G3DSA%3A2.30.18.10', 189),
        ]
    ],
    'test_dataset': [
        name
        for name, _ in [
            ('database_id=G3DSA%3A2.40.290.10', 610),
            ('database_id=G3DSA%3A2.30.30.390', 56),
            ('database_id=G3DSA%3A2.40.200.10', 14),
            ('database_id=G3DSA%3A2.40.230.10', 589),
            ('database_id=G3DSA%3A2.40.230.20', 399),
            ('database_id=G3DSA%3A2.40.270.10', 8268),
            ('database_id=G3DSA%3A2.40.180.10', 5573),
        ]
    ],
}

# `DATAPKG`

In [None]:
DATAPKG = {}

In [None]:
# Register the location of the PDB mmCIF ffindex below DATAPKG_OUTPUT_DIR.
# `os.environ[...]` raises KeyError if the variable is not set.
datapkg_output_dir = Path(os.environ['DATAPKG_OUTPUT_DIR'])
DATAPKG['pdb-ffindex'] = {
    'pdb_mmcif_ffindex': datapkg_output_dir.joinpath(
        "pdb-ffindex", "master", "pdb_mmcif_ffindex", "pdb-mmcif"
    ),
}

# Load data

In [None]:
Path(os.getenv("DATAPKG_OUTPUT_DIR"))

In [None]:
# Collect, as one sorted list, every "*.parquet" file found in the partition
# directories of the domains listed in FAILED_DOMAIN_NAMES, across all three datasets.
files = sorted(
    parquet_file
    for dataset_name in ["training_dataset", "validation_dataset", "test_dataset"]
    for domain_name in FAILED_DOMAIN_NAMES[dataset_name]
    for parquet_file in (
        Path(os.getenv("DATAPKG_OUTPUT_DIR"))
        .joinpath("adjacency-net-v2", "master", dataset_name, "adjacency_matrix.parquet", domain_name)
        .glob("*.parquet")
    )
)

In [None]:
{f.parent.parent for f in files}

In [None]:
# Load the first row group of the first file as a sample DataFrame, indexed by the
# "__index_level_0__" column.
first_row_group = pq.ParquetFile(files[0]).read_row_group(0, use_pandas_metadata=True)
df = first_row_group.to_pandas(integer_object_nulls=True).set_index("__index_level_0__")

In [None]:
df.head(1)

# Run pipeline

In [None]:
def worker(data):
    """Compute adjacency indices and distances for one row, reporting errors in-band.

    Args:
        data: Row attributes as a dict (callers pass ``row._asdict()``). It is
            converted with ``helper.to_namedtuple`` before being handed to
            ``helper.get_adjacency_with_distances``.

    Returns:
        A dict with the keys ``residue_idx_1``, ``residue_idx_2``, ``distances`` and
        ``error``. On success ``error`` is ``None``. If the helper raises, the three
        result fields are ``None`` and ``error`` is ``"<exception type>: <message>"``.
    """
    row = helper.to_namedtuple(data)
    try:
        # STRUCTURE_URL_PREFIX is a module-level name looked up at call time.
        residue_idx_1, residue_idx_2, distances = helper.get_adjacency_with_distances(
            row, max_cutoff=12, min_cutoff=None, structure_url_prefix=STRUCTURE_URL_PREFIX
        )
    except Exception as e:
        # Return the failure as data rather than raising, so that one bad row does
        # not abort processing of the remaining rows.
        return {
            "residue_idx_1": None,
            "residue_idx_2": None,
            "distances": None,
            "error": f"{type(e)}: {e}",
        }
    return {
        "residue_idx_1": residue_idx_1,
        "residue_idx_2": residue_idx_2,
        "distances": distances,
        "error": None,
    }

## Run worker for single row

In [None]:
# Take the row at position `start` of the sample DataFrame as a namedtuple.
start = 0
row = next(islice(df.itertuples(), start, None))

In [None]:
# "ff://" URL prefix built from the ffindex path registered in DATAPKG; it is passed
# to the helper as `structure_url_prefix`.
pdb_mmcif_ffindex = DATAPKG['pdb-ffindex']['pdb_mmcif_ffindex']
STRUCTURE_URL_PREFIX = f"ff://{pdb_mmcif_ffindex}?"
STRUCTURE_URL_PREFIX

In [None]:
# residue_idx_1, residue_idx_2, distance = helper.get_adjacency_with_distances(
#     row, max_cutoff=12, min_cutoff=None, structure_url_prefix=STRUCTURE_URL_PREFIX
# )

In [None]:
worker(row._asdict())

## Run for all rows

In [None]:
def get_new_file(file, suffix):
    """Return `file` with `suffix` appended to its fourth-from-last path component.

    For the paths collected in this notebook
    (``.../<dataset>/adjacency_matrix.parquet/<domain>/<name>.parquet``) that
    component is the dataset directory, so the result points at the same relative
    location inside a sibling ``<dataset><suffix>`` directory.

    Args:
        file: ``pathlib.Path`` with at least four components.
        suffix: String appended to the fourth-from-last component.

    Returns:
        A new ``Path``; `file` itself is not modified.

    Raises:
        IndexError: If `file` has fewer than four components.
    """
    parts = file.parts
    renamed = parts[-4] + suffix
    return Path(*parts[:-4], renamed, *parts[-3:])

get_new_file(files[0], "_wdistances_failed")

### Examine subset of files

In [None]:
print(len(files))

In [None]:
if DEBUG:
    # In debug mode, drop every file whose first row group PyArrow cannot read
    # (ArrowNotImplementedError) and remember those files for inspection.
    files_arrow_not_implemented = set()
    readable_files = []

    for file in files:
        try:
            ds = pq.ParquetFile(file)
            table = ds.read_row_group(0, use_pandas_metadata=True)
        except pa.ArrowNotImplementedError:
            files_arrow_not_implemented.add(file)
        else:
            readable_files.append(file)

    files = readable_files

In [None]:
print(len(files))

In [None]:
# Debug runs only process the first ten files; otherwise every file is processed.
files_to_run = files[:10] if DEBUG else files

In [None]:
print(len(files_to_run))

In [None]:
# Recompute adjacency indices and distances for every file in `files_to_run`.
#
# For each input file, rows are processed one row group at a time. Rows for which
# `worker` produced all three results are written to the "<dataset>_wdistances" copy of
# the file; rows for which all three results are missing are written to the
# "<dataset>_wdistances_failed" copy. One table is written per processed row group.
for file in files_to_run:
    ds = pq.ParquetFile(file)
    new_file = get_new_file(file, "_wdistances")
    new_file.parent.mkdir(parents=True, exist_ok=True)
    # Both writers are created lazily from the first table they receive and are reset
    # for every file, so the cleanup below can never touch a writer that belongs to a
    # previous file (or fail with NameError when nothing was written).
    writer = None
    writer_failed = None
    using_spark_sql = False
    try:
        for row_group in range(ds.num_row_groups):
            # Read row group (or entire file using Spark SQL if PyArrow fails)
            try:
                df = (
                    ds.read_row_group(row_group, use_pandas_metadata=True)
                    .to_pandas(integer_object_nulls=True)
                    .set_index("__index_level_0__")
                )
            except pa.ArrowNotImplementedError:
                # NOTE(review): the fallback loads the whole file and does not set the
                # "__index_level_0__" index. If it ever triggers after row group 0,
                # rows already written would be written again - confirm this only
                # happens on the first row group.
                using_spark_sql = True
                df = spark.read.parquet(file.as_posix()).toPandas()

            # Only the pool itself is guarded: BrokenProcessPool is raised while the
            # mapped results are being collected.
            try:
                with concurrent.futures.ProcessPoolExecutor(psutil.cpu_count(logical=False)) as pool:
                    results = list(
                        pool.map(
                            worker,
                            (
                                t._asdict()
                                for t in df[helper.GET_ADJACENCY_WITH_DISTANCES_ROW_ATTRIBUTES].itertuples()
                            ),
                            chunksize=1,
                        )
                    )
            except concurrent.futures.process.BrokenProcessPool as e:
                # Give up on the rest of this file; row groups written so far are kept.
                warnings.warn(
                    f"ProcessPool crashed while processing row_group '{row_group}' in file '{file}'. "
                    f"The error is '{type(e)}': {e}."
                )
                break

            # `pool.map` preserves input order, so results line up with `df` positionally.
            results_df = pd.DataFrame(results)
            df["residue_idx_1_corrected"] = results_df["residue_idx_1"].values
            df["residue_idx_2_corrected"] = results_df["residue_idx_2"].values
            df["distances"] = results_df["distances"].values
            df["error_adding_distances"] = results_df["error"].values
            num_errors = df["error_adding_distances"].notnull().sum()
            if num_errors:
                print(f"Encountered {num_errors} errors when parsing file '{file}'.")

            # A row either has all three results or none of them; anything in between
            # would silently vanish from both outputs, hence the assertion.
            result_columns = ["residue_idx_1_corrected", "residue_idx_2_corrected", "distances"]
            df_succeeded = df[df[result_columns].notnull().all(axis=1)]
            df_failed = df[df[result_columns].isnull().all(axis=1)]
            assert len(df_succeeded) + len(df_failed) == len(df)

            # Write successful results
            table = pa.Table.from_pandas(df_succeeded, preserve_index=True)
            if writer is None:
                writer = pq.ParquetWriter(new_file, table.schema, version="2.0", flavor="spark")
            writer.write_table(table)

            # Write failed results
            if not df_failed.empty:
                table_failed = pa.Table.from_pandas(df_failed, preserve_index=True)
                if writer_failed is None:
                    new_file_failed = get_new_file(file, "_wdistances_failed")
                    new_file_failed.parent.mkdir(parents=True, exist_ok=True)
                    writer_failed = pq.ParquetWriter(
                        new_file_failed, table_failed.schema, version="2.0", flavor="spark"
                    )
                writer_failed.write_table(table_failed)

            if using_spark_sql:
                # The Spark fallback already covered the entire file.
                break
    finally:
        # Always close whatever was opened, including when an exception escapes.
        if writer is not None:
            writer.close()
        if writer_failed is not None:
            writer_failed.close()

In [None]:
file

In [None]:
new_file

In [None]:
new_file_failed

In [None]:
if DEBUG:
    display(df.head())

    # Make sure that the file we wrote makes sense: compare every input file against
    # its "_wdistances" counterpart, row group by row group.
    for file in files_to_run:
        # `get_new_file` requires the suffix; use the same one as when the file was written.
        new_file = get_new_file(file, "_wdistances")

        ds = pq.ParquetFile(file)
        ds_new = pq.ParquetFile(new_file)
        assert ds.num_row_groups == ds_new.num_row_groups

        for row_group in range(ds.num_row_groups):
            # Read the row group under test (not always the first one).
            df = (
                ds.read_row_group(row_group, use_pandas_metadata=True)
                .to_pandas(integer_object_nulls=True)
                .set_index("__index_level_0__")
            )
            # No `set_index` here - presumably the index is restored from the pandas
            # metadata stored in the new file; TODO confirm.
            df_new = (
                ds_new.read_row_group(row_group, use_pandas_metadata=True)
                .to_pandas(integer_object_nulls=True)
            )
            # Columns present in both frames, excluding the array-valued columns,
            # which cannot be compared with a plain element-wise `==`.
            # NOTE(review): this comparison needs both frames to hold the same rows,
            # i.e. no row of the row group was diverted to the "_failed" file.
            shared_columns = [
                c for c in df.columns
                if c in df_new.columns
                and c not in [
                    'a2b', 'b2a', 'residue_idx_1', 'residue_idx_2',
                    'residue_id_1', 'residue_id_2', 'residue_aa_1', 'residue_aa_2',
                    "residue_idx_1_corrected", "residue_idx_2_corrected"]
            ]
            assert (df[shared_columns] == df_new[shared_columns]).all().all()
            # Residue ids must be unchanged, row by row.
            assert all(
                (l1 == l2).all()
                for l1, l2
                in zip(df["residue_id_1"].values, df_new["residue_id_1"].values)
            )
            # The corrected indices in the new file must differ from those in the
            # input file for every row (different shape or different values).
            assert all(
                (l1.shape != l2.shape or not (l1 == l2).all())
                for l1, l2
                in zip(df["residue_idx_1_corrected"].values, df_new["residue_idx_1_corrected"].values)
            )