## Summary

### Submitting jobs

**Note:** These jobs must be submitted from the <code>./notebooks</code> folder.

**Cedar:**

```bash
NOTEBOOK_PATH=$(realpath 01_process_pdb_interface.ipynb) sbatch --array=1-300 --time=72:00:00 --nodes=1 --ntasks-per-node=48 --mem=0 --job-name=process-pdb-interface --account=rrg-pmkim --output=/scratch/strokach/tmp/log/run-notebook-cpu-%j-%N.log ../scripts/run_notebook_cpu.sh
```

### To Do

Remove hydrogen atoms on all structures.

----

## Imports

In [None]:
import concurrent.futures
import concurrent.futures.process
import gzip
import importlib
import io
import logging
import os
import pickle
import shlex
import shutil
import socket
import subprocess
import sys
import tempfile
import time
import traceback
import warnings
from functools import partial
from itertools import islice
from pathlib import Path

from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
import mdtraj
import numpy as np
import pandas as pd
import psutil
import pyarrow as pa
import pyarrow.parquet as pq
import tenacity
import yaml
from kmbio import PDB
from kmtools import structure_tools

In [None]:
%matplotlib inline

# Use the fully-qualified option key: the bare pattern "max_columns" is ambiguous on pandas versions
# that also define "styler.render.max_columns", which makes `set_option` raise an OptionError.
pd.set_option("display.max_columns", 100)

In [None]:
# Locate the sibling ``src`` directory (relative to the current working directory) and put it at the
# front of the import path so that the project-local ``helper`` module can be imported.
SRC_PATH = (Path.cwd() / ".." / "src").resolve(strict=True)

if not any(entry == SRC_PATH.as_posix() for entry in sys.path):
    sys.path.insert(0, SRC_PATH.as_posix())

import helper

# Reload so that edits to ``helper`` are picked up when this cell is re-run.
importlib.reload(helper)

## Parameters

In [None]:
# Name of this notebook; taken from ``$CI_JOB_NAME`` when that variable is set.
NOTEBOOK_PATH = Path(os.environ.get("CI_JOB_NAME", "01_process_pdb_interface"))

NOTEBOOK_PATH

In [None]:
# Output directory: ``$OUTPUT_DIR`` when set, otherwise a folder named after the notebook.
# The directory (and any missing parents) is created if it does not exist yet.
OUTPUT_PATH = Path(os.environ.get("OUTPUT_DIR", NOTEBOOK_PATH.name)).resolve()
OUTPUT_PATH.mkdir(exist_ok=True, parents=True)

OUTPUT_PATH

In [None]:
# Number of worker processes: a fixed 40 on hosts whose name contains "scinet", otherwise half of the
# CPUs this process is allowed to run on (but never fewer than one).
CPU_COUNT = 40 if "scinet" in socket.gethostname() else max(1, len(os.sched_getaffinity(0)) // 2)

CPU_COUNT

In [None]:
# Position of this job within the Slurm job array (``None`` when not running as an array job).
TASK_ID = os.environ.get("SLURM_ARRAY_TASK_ID")
if TASK_ID is not None:
    TASK_ID = int(TASK_ID)

# Size of the job array; ``$ORIGINAL_ARRAY_TASK_COUNT`` takes precedence when it is set and non-empty.
TASK_COUNT = os.environ.get("ORIGINAL_ARRAY_TASK_COUNT") or os.environ.get("SLURM_ARRAY_TASK_COUNT")
if TASK_COUNT is not None:
    TASK_COUNT = int(TASK_COUNT)

TASK_ID, TASK_COUNT

In [None]:
# Without a Slurm array task id the notebook runs in debug mode and pretends to be task 300 of 300.
DEBUG = TASK_ID is None

if not DEBUG:
    assert TASK_ID is not None
    assert TASK_COUNT is not None
else:
    TASK_ID = 300
    TASK_COUNT = 300

TASK_ID, TASK_COUNT

In [None]:
# Root of the adjacency-net data package. ``os.environ[...]`` is used (as for ``PDB_DATA_PATH``) so that a
# missing ``$DATAPKG_OUTPUT_DIR`` fails with a KeyError naming the variable instead of an opaque
# ``TypeError`` from ``Path(None)``. ``resolve(strict=True)`` additionally requires the folder to exist.
ADJACENCY_NET_DATA_PATH = (
    Path(os.environ["DATAPKG_OUTPUT_DIR"]).joinpath("adjacency-net-v2", "v0.3").resolve(strict=True)
)

ADJACENCY_NET_DATA_PATH

In [None]:
# Folder holding the Arrow export of the PDB; it must already exist (``strict=True``).
PDB_DATA_PATH = (Path(os.environ["DATAPKG_OUTPUT_DIR"]) / "pdb-ffindex" / "2020-01-16" / "arrow").resolve(
    strict=True
)

PDB_DATA_PATH

## Load data

In [None]:
# ``pdb-list.pickle`` holds the list of entries; ``pdb-mmcif.arrow`` holds one record batch per entry.
pdb_list = pickle.loads(PDB_DATA_PATH.joinpath("pdb-list.pickle").read_bytes())

pdb_data_reader = pa.RecordBatchFileReader(PDB_DATA_PATH.joinpath("pdb-mmcif.arrow"))

# The two files must describe the same number of entries.
assert pdb_data_reader.num_record_batches == len(pdb_list)

In [None]:
# Preview the first three entries of the PDB list.
pdb_list[:3]

In [None]:
# Split ``pdb_list`` into ``TASK_COUNT`` contiguous chunks and select the one belonging to this task.
# ``TASK_ID`` is 1-based; the last chunk may be shorter than ``chunk_size``.
task_idx = TASK_ID - 1
chunk_size = -(-len(pdb_list) // TASK_COUNT)  # ceiling division
pdb_chunk = pdb_list[task_idx * chunk_size : (task_idx + 1) * chunk_size]
# Positions of the selected entries within ``pdb_list`` (slicing a range clamps just like slicing a list).
pdb_chunk_idxs = list(range(len(pdb_list))[task_idx * chunk_size : (task_idx + 1) * chunk_size])
assert all(entry == pdb_list[position] for entry, position in zip(pdb_chunk, pdb_chunk_idxs))

chunk_size, task_idx, len(pdb_chunk), pdb_chunk[:3]

In [None]:
# In debug mode only the first ten entries of the chunk are processed.
if DEBUG:
    pdb_chunk, pdb_chunk_idxs = pdb_chunk[:10], pdb_chunk_idxs[:10]

## Process structures

In [None]:
def strip_hydrogens(structure_df):
    """Return ``structure_df`` without the rows identified as hydrogen atoms.

    A row is dropped only when its ``residue_id_0`` value is blank (after stripping whitespace) *and*
    its ``atom_name`` starts with ``"H"``. Rows with a non-blank ``residue_id_0`` are always kept.
    The input dataframe is not modified.
    """
    blank_residue_flag = structure_df["residue_id_0"].str.strip() == ""
    h_prefixed_name = structure_df["atom_name"].str.startswith("H")
    return structure_df[~(blank_residue_flag & h_prefixed_name)]

In [None]:
def get_interacting_chain_pairs(structure_df, distance_cutoff=5, n_residue_pairs_cutoff=5):
    """Find pairs of distinct chains that share enough close residue pairs.

    Args:
        structure_df: Structure dataframe with (at least) the columns ``residue_id_0``, ``model_idx``,
            ``chain_idx`` and ``residue_idx``.
        distance_cutoff: Passed to ``structure_tools.get_distances`` as ``max_cutoff``.
        n_residue_pairs_cutoff: Minimum number of distinct residue pairs for a chain pair to be reported.

    Returns:
        A dict mapping ``(chain_idx_1, chain_idx_2)`` to the number of distinct residue pairs returned by
        ``structure_tools.get_distances`` for those two chains. Pairs of a chain with itself are excluded.

    Raises:
        AssertionError: If chain indices are not unique across models, or residue indices are not unique
            across chains.
    """
    # Strip hetatms (rows with a non-blank ``residue_id_0``).
    polymer_df = structure_df[structure_df["residue_id_0"].str.strip() == ""]

    # Chain indices must uniquely identify chains across all models.
    assert len(polymer_df.drop_duplicates(["model_idx", "chain_idx"])) == polymer_df["chain_idx"].nunique()
    residue_to_model = polymer_df.drop_duplicates(["model_idx", "residue_idx"]).set_index("residue_idx")["model_idx"]

    # Residue indices must uniquely identify residues across all chains.
    assert len(polymer_df.drop_duplicates(["chain_idx", "residue_idx"])) == polymer_df["residue_idx"].nunique()
    residue_to_chain = polymer_df.drop_duplicates(["chain_idx", "residue_idx"]).set_index("residue_idx")["chain_idx"]

    interactions_df = structure_tools.get_distances(polymer_df, max_cutoff=distance_cutoff, groupby="residue")
    # Annotate both sides of every interaction with the model and chain the residue belongs to.
    for side in ("1", "2"):
        interactions_df[f"model_idx_{side}"] = interactions_df[f"residue_idx_{side}"].map(residue_to_model)
    for side in ("1", "2"):
        interactions_df[f"chain_idx_{side}"] = interactions_df[f"residue_idx_{side}"].map(residue_to_chain)

    chain_pairs = {}
    grouped = interactions_df.groupby(["chain_idx_1", "chain_idx_2"], as_index=False)
    for chain_pair, pair_df in grouped:
        first_chain, second_chain = chain_pair[0], chain_pair[1]
        if first_chain != second_chain:
            num_interacting_residue_pairs = len(pair_df.drop_duplicates(["residue_idx_1", "residue_idx_2"]))
            if num_interacting_residue_pairs >= n_residue_pairs_cutoff:
                chain_pairs[chain_pair] = num_interacting_residue_pairs
    return chain_pairs

In [None]:
def process_pdb_interface(pdb_id, pdb_idx):
    """Extract features for every interacting pair of chains in one PDB entry.

    The entry is read from record batch ``pdb_idx`` of ``pdb-mmcif.arrow`` (under ``PDB_DATA_PATH``) and
    parsed with ``PDB.MMCIFParser``, first with ``bioassembly_id=True`` and, if that raises
    ``PDB.BioassemblyError``, again with ``bioassembly_id=False``. Hydrogens are removed with
    ``strip_hydrogens`` and the chain pairs reported by ``get_interacting_chain_pairs`` are processed.
    Each distinct (unordered) pair of amino acid sequences is evaluated at most once, and pairs in which
    either sequence has fewer than 5 characters after stripping whitespace are skipped.

    Args:
        pdb_id: Identifier expected in the ``pdb_id`` field of the record batch.
        pdb_idx: Index of the record batch that holds this entry.

    Returns:
        A list with one ``helper.downcast_and_compress`` result per processed chain pair (possibly empty).

    Raises:
        AssertionError: If one of the consistency checks fails (wrong ``pdb_id``, a chain missing from the
            dataframe, a sequence mismatch after the mdtraj round trip, or a residue count mismatch).
    """
    # Load data. The reader is opened here, inside the function, on every call.
    pdb_data_reader = pa.RecordBatchFileReader(PDB_DATA_PATH.joinpath("pdb-mmcif.arrow"))
    pdb_data = pdb_data_reader.get_record_batch(pdb_idx).to_pydict()
    assert pdb_data["pdb_id"][0] == pdb_id

    # Create structure from data (the mmCIF text is stored gzip-compressed).
    buf = io.StringIO(gzip.decompress(pdb_data["mmcif_data"][0]).decode())
    use_auth_id = False
    bioassembly_id = True
    try:
        structure = PDB.MMCIFParser(use_auth_id=use_auth_id).get_structure(buf, bioassembly_id=bioassembly_id)
    except PDB.BioassemblyError as e:
        print(f"Encountered error when parsing pdb {pdb_idx} ('{pdb_id}'): {e!s}.")
        # Rewind and retry without requesting a bioassembly.
        buf.seek(0)
        bioassembly_id = False
        structure = PDB.MMCIFParser(use_auth_id=use_auth_id).get_structure(buf, bioassembly_id=bioassembly_id)

    structure_df = structure.to_dataframe()
    structure_df = strip_hydrogens(structure_df)
    # A set keeps the per-pair membership checks below O(1).
    all_chains = set(structure_df["chain_idx"].unique().tolist())
    interacting_chain_pairs = get_interacting_chain_pairs(structure_df, distance_cutoff=5, n_residue_pairs_cutoff=5)

    results = []
    evaluated_sequence_pairs = set()
    # NOTE(review): the enumeration position of a chain in ``structure.chains`` is used as its
    # ``chain_idx`` in the dataframe; the assertions below only check that the index exists.
    for chain_idx_1, chain_1 in enumerate(structure.chains):
        for chain_idx_2, chain_2 in enumerate(structure.chains):
            assert chain_idx_1 in all_chains
            assert chain_idx_2 in all_chains

            if (chain_idx_1, chain_idx_2) not in interacting_chain_pairs:
                continue

            aa_sequence_1 = structure_tools.get_chain_sequence(chain_1)
            aa_sequence_2 = structure_tools.get_chain_sequence(chain_2)

            # Order the two sequences so that (A, B) and (B, A) are treated as the same pair.
            aa_sequence_pair = (
                (aa_sequence_1, aa_sequence_2) if (aa_sequence_1 <= aa_sequence_2) else (aa_sequence_2, aa_sequence_1)
            )
            if aa_sequence_pair in evaluated_sequence_pairs:
                continue
            evaluated_sequence_pairs.add(aa_sequence_pair)

            if len(aa_sequence_1.strip()) < 5 or len(aa_sequence_2.strip()) < 5:
                continue

            # Create a structure with only the two chains of interest
            structure_chunk_df = structure_df[
                (structure_df["chain_idx"] == chain_idx_1) | (structure_df["chain_idx"] == chain_idx_2)
            ].copy()
            # Per-residue chain label (0 or 1, in order of first appearance), captured before the chain
            # columns are overwritten below.
            chain_idx_array = (
                structure_chunk_df.drop_duplicates("residue_idx")["chain_idx"]
                .map({c: i for i, c in enumerate(structure_chunk_df["chain_idx"].unique())})
                .values
            )
            # Collapse both chains into a single model / chain "A" and renumber ``residue_id_1``
            # consecutively from 0 (presumably so that the saved PDB file reads back as one continuous
            # chain, as the sequence assertion below expects; TODO confirm).
            structure_chunk_df["model_idx"] = 0
            structure_chunk_df["model_id"] = 0
            structure_chunk_df["chain_idx"] = 0
            structure_chunk_df["chain_id"] = "A"
            structure_chunk_df["residue_id_1"] = (
                structure_chunk_df["residue_idx"]
                .map({r: i for i, r in enumerate(structure_chunk_df["residue_idx"].unique())})
                .values
            )

            # Convert to mdtraj trajectory (round trip through a temporary PDB file)
            schain = PDB.Structure.from_dataframe(structure_chunk_df)
            with tempfile.NamedTemporaryFile(suffix=".pdb") as pdb_file:
                PDB.save(schain, pdb_file.name)
                traj = mdtraj.load(pdb_file.name)
            assert (aa_sequence_1 + aa_sequence_2) == traj.top.to_fasta()[0]

            residue_df = helper.construct_residue_df(traj)
            assert len(residue_df) == len(chain_idx_array)
            residue_df["chain_idx"] = chain_idx_array
            helper.validate_residue_df(residue_df)

            residue_pairs_df = helper.construct_residue_pairs_df(traj)
            helper.validate_residue_pairs_df(residue_pairs_df)

            result = {
                "pdb_id": [pdb_id],
                "pdb_idx": [pdb_idx],
                "use_auth_id": [use_auth_id],
                "bioassembly_id": [bioassembly_id],
                "model_id_1": [chain_1.parent.id],
                "chain_idx_1": [chain_idx_1],
                "chain_id_1": [chain_1.id],
                "chain_idx_2": [chain_idx_2],
                "chain_id_2": [chain_2.id],
                "num_interacting_residue_pairs": [interacting_chain_pairs[(chain_idx_1, chain_idx_2)]],
                "aa_sequence_1": [aa_sequence_1],
                "aa_sequence_2": [aa_sequence_2],
                **helper.residue_df_to_row(residue_df),
                **helper.residue_pairs_df_to_row(residue_pairs_df),
            }
            result = helper.downcast_and_compress(result)
            results.append(result)

    return results

In [None]:
def worker(pdb_id, pdb_idx):
    """Run ``process_pdb_interface`` and convert any exception into a failure record.

    Args:
        pdb_id: Identifier of the PDB entry to process.
        pdb_idx: Index of the entry's record batch.

    Returns:
        A ``(results, failures)`` tuple. On success ``results`` is the list returned by
        ``process_pdb_interface`` and ``failures`` is empty. On error ``results`` is empty and
        ``failures`` holds a single dict describing the exception (type, message and traceback).
    """
    try:
        return process_pdb_interface(pdb_id, pdb_idx), []
    except Exception as error:
        # Deliberately broad: one bad entry must not abort the whole chunk, so every error is recorded
        # and returned to the parent process instead of being raised.
        failure = {
            "pdb_id": [pdb_id],
            "pdb_idx": [pdb_idx],
            "error_type": [str(type(error))],
            "error_message": [str(error)],
            # ``format_exc`` yields the traceback exactly as it would be printed; joining the lines of
            # ``format_exception`` with "\n" would double every line break.
            "error_traceback": [traceback.format_exc()],
        }
        return [], [failure]

In [None]:
# result = process_pdb_interface("1orf", 17384)

In [None]:
# result

In [None]:
# with concurrent.futures.ProcessPoolExecutor() as pool:
#     futures = pool.map(worker, pdb_chunk[:10], pdb_chunk_idxs[:10])
#     for result in tqdm(futures, total=100):
#         break

In [None]:
# result

In [None]:
# Successful results and failure records are written to separate Arrow files, named after the task.
output_dir = ADJACENCY_NET_DATA_PATH / "pdb-interface"
output_dir_failed = output_dir / "failed"
for directory in (output_dir, output_dir_failed):
    directory.mkdir(exist_ok=True)

output_stem = f"pdb-interface-{TASK_ID}-{TASK_COUNT}"
output_file = output_dir / f"{output_stem}.arrow"
output_file_failed = output_dir_failed / f"{output_stem}-failed.arrow"

output_file, output_file_failed

In [None]:
# Silence everything below CRITICAL from the ``kmtools.structure_tools.fixes`` logger.
kmtools_fixes_logger = logging.getLogger("kmtools.structure_tools.fixes")
kmtools_fixes_logger.setLevel(logging.CRITICAL)

In [None]:
def _append_batches(writer, path, rows):
    """Write each row dict in ``rows`` as one record batch, opening the writer at ``path`` on first use.

    Returns the (possibly newly created) writer so the caller can keep using and eventually close it.
    """
    for row in rows:
        batch = pa.RecordBatch.from_arrays(list(row.values()), list(row.keys()))
        if writer is None:
            # The schema of the file is taken from the first batch written to it.
            writer = pa.RecordBatchFileWriter(path, batch.schema)
        writer.write_batch(batch)
    return writer


# Process the chunk in a pool of worker processes, streaming results to disk as they arrive.
#
# If the pool breaks (e.g. a worker process dies), processing resumes from the first PDB for which no
# result was received. That PDB is first re-run on its own in a single-worker pool: if it breaks the
# pool again it is recorded as failed and skipped, so a PDB that always kills its worker can no longer
# make this loop spin forever. Afterwards processing continues with the full pool.
writer = None
writer_failed = None
num_pdbs_processed = 0
last_crash_position = None
try:
    while num_pdbs_processed < len(pdb_chunk):
        run_isolated = last_crash_position == num_pdbs_processed
        max_workers = 1 if run_isolated else CPU_COUNT
        stop = num_pdbs_processed + 1 if run_isolated else len(pdb_chunk)
        try:
            with concurrent.futures.ProcessPoolExecutor(max_workers) as pool:
                futures = pool.map(
                    worker,
                    pdb_chunk[num_pdbs_processed:stop],
                    pdb_chunk_idxs[num_pdbs_processed:stop],
                    chunksize=1,
                )
                for results, results_failed in tqdm(futures, total=stop - num_pdbs_processed):
                    # Results arrive in submission order, so this counter is also the resume position.
                    num_pdbs_processed += 1
                    writer = _append_batches(writer, output_file, results)
                    writer_failed = _append_batches(writer_failed, output_file_failed, results_failed)
        except concurrent.futures.BrokenExecutor as e:
            print(f"ProcessPoolExecutor crashed with an error ('{type(e)!s}'): '{e!s}'.")
            if run_isolated:
                # The PDB broke the pool while running alone: record it (same fields as the failure
                # records produced by ``worker``) and move on.
                crash_record = {
                    "pdb_id": [pdb_chunk[num_pdbs_processed]],
                    "pdb_idx": [pdb_chunk_idxs[num_pdbs_processed]],
                    "error_type": [str(type(e))],
                    "error_message": [str(e)],
                    "error_traceback": [traceback.format_exc()],
                }
                writer_failed = _append_batches(writer_failed, output_file_failed, [crash_record])
                num_pdbs_processed += 1
            else:
                last_crash_position = num_pdbs_processed
finally:
    # Always close the writers so that the Arrow files are finalized, even if processing is aborted.
    if writer is not None:
        writer.close()
    if writer_failed is not None:
        writer_failed.close()

In [None]:
# The output file only exists if at least one result was written. Each record batch holds the result
# for one chain pair, so the count reported here is a number of chain pairs.
if output_file.is_file():
    reader = pa.RecordBatchFileReader(output_file)
    print(f"Number of successful chain pairs: {reader.num_record_batches}.")

In [None]:
# The failure file only exists if at least one failure record was written.
if output_file_failed.is_file():
    reader_failed = pa.RecordBatchFileReader(output_file_failed)
    num_failed_batches = reader_failed.num_record_batches
    print(f"Number of failed PDBs: {num_failed_batches}.")

### Write a `*.SUCCESS` file

In [None]:
# Create (or truncate) an empty marker file next to the output file to signal that the notebook ran
# to the end.
success_marker = output_file.with_suffix(".SUCCESS")
with success_marker.open("wt"):
    pass