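## Summary

Processes one chunk of the PDB (read from the `pdb-ffindex` Arrow snapshot) per SLURM array task. Each PDB entry is parsed, using its biological assembly when possible. Residue and residue-pair features are then computed for every chain with a distinct sequence of at least 5 residues. Successful chains are written to `adjacency-net-v2/v0.3/pdb-core/pdb-core-{TASK_ID}-{TASK_COUNT}.arrow`. PDB entries that fail are recorded under `adjacency-net-v2/v0.3/pdb-core-failed/`.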

<div class="alert alert-info">

**Note:**

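These jobs must be submitted from the <code>./notebooks</code> folder. The notebook locates the project's <code>../src</code> directory (and the <code>helper</code> module) relative to the current working directory.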

</div>


**Cedar:**

```bash
NOTEBOOK_PATH=$(realpath 01_process_pdb_core.ipynb) sbatch --array=1-300 --time=24:00:00 --ntasks=1 --cpus-per-task=32 --constraint=skylake --job-name=process-pdb-core --account=rrg-pmkim --output=/scratch/strokach/tmp/log/run-notebook-cpu-%N-%j.log ../scripts/run_notebook_cpu.sh
```

----

## Imports

In [None]:
import concurrent.futures
import concurrent.futures.process
import gzip
import importlib
import io
import os
import pickle
import shlex
import shutil
import socket
import subprocess
import sys
import tempfile
import time
import warnings
from functools import partial
from itertools import islice
from pathlib import Path

import numpy as np
import yaml
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
import mdtraj
import pandas as pd
import psutil
import pyarrow as pa
import pyarrow.parquet as pq
import tenacity
from kmbio import PDB
from kmtools import structure_tools

In [None]:
%matplotlib inline

pd.set_option("max_columns", 100)

In [None]:
# Make the project's `src/` directory (one level above the working directory)
# importable so the local `helper` module can be loaded.
SRC_PATH = (Path.cwd() / ".." / "src").resolve(strict=True)

_src_path_str = SRC_PATH.as_posix()
if _src_path_str not in sys.path:
    sys.path.insert(0, _src_path_str)

import helper
# Reload so edits to helper.py are picked up when this cell is re-run.
importlib.reload(helper)

## Parameters

In [None]:
# Notebook name: the CI job name when running under CI, otherwise the default.
NOTEBOOK_PATH = Path(os.environ.get("CI_JOB_NAME", "01_process_pdb_core"))
NOTEBOOK_PATH

In [None]:
# Output directory: $OUTPUT_DIR if set, else a folder named after the notebook
# (relative to the working directory). Created if missing.
OUTPUT_PATH = Path(os.environ.get('OUTPUT_DIR', NOTEBOOK_PATH.name)).resolve()
OUTPUT_PATH.mkdir(exist_ok=True, parents=True)
OUTPUT_PATH

In [None]:
# Show the working directory (jobs are expected to run from ./notebooks).
Path.cwd()

In [None]:
# Directory holding pdb-list.pickle and pdb-mmcif.arrow; it must already exist
# (strict resolve). DATAPKG_OUTPUT_DIR is required.
_datapkg_output_dir = Path(os.environ["DATAPKG_OUTPUT_DIR"])
PDB_DATA_PATH = (_datapkg_output_dir / "pdb-ffindex" / "2020-01-16" / "arrow").resolve(strict=True)

PDB_DATA_PATH

In [None]:
# Number of worker processes for the ProcessPoolExecutor below.
# Prefer the CPU count SLURM granted to this job: psutil reports the machine's
# physical cores, which can exceed the allocation (e.g. `--cpus-per-task=32`
# on a larger node) and would oversubscribe it.
if os.getenv("SLURM_CPUS_PER_TASK"):
    CPU_COUNT = int(os.environ["SLURM_CPUS_PER_TASK"])
elif "scinet" in socket.gethostname():
    CPU_COUNT = 40
else:
    # psutil returns None when the physical core count cannot be determined.
    CPU_COUNT = psutil.cpu_count(logical=False) or os.cpu_count()

CPU_COUNT

In [None]:
# Root of the adjacency-net-v2 / v0.3 dataset; it must already exist (strict resolve).
# Read DATAPKG_OUTPUT_DIR with os.environ[...] (as for PDB_DATA_PATH) so a missing
# variable fails with a KeyError naming it instead of a TypeError from Path(None).
ADJACENCY_NET_DATA_PATH = (
    Path(os.environ["DATAPKG_OUTPUT_DIR"]).joinpath("adjacency-net-v2", "v0.3").resolve(strict=True)
)

ADJACENCY_NET_DATA_PATH

In [None]:
def _optional_int(value):
    """Convert an environment-variable string to int, passing None through."""
    return None if value is None else int(value)


# SLURM array task (1-based) and total task count; ORIGINAL_ARRAY_TASK_COUNT,
# when set and non-empty, overrides SLURM_ARRAY_TASK_COUNT.
TASK_ID = _optional_int(os.getenv("SLURM_ARRAY_TASK_ID"))
TASK_COUNT = _optional_int(
    os.getenv("ORIGINAL_ARRAY_TASK_COUNT") or os.getenv("SLURM_ARRAY_TASK_COUNT")
)

TASK_ID, TASK_COUNT

In [None]:
# Without a SLURM array task ID, fall back to a fixed interactive debug task.
DEBUG = TASK_ID is None

if not DEBUG:
    assert TASK_ID is not None
    assert TASK_COUNT is not None
else:
    TASK_ID, TASK_COUNT = 4, 300

TASK_ID, TASK_COUNT

In [None]:
# In interactive debug sessions, automatically reload edited modules (e.g. helper).
if DEBUG:
    _ipython = get_ipython()
    _ipython.run_line_magic("load_ext", "autoreload")
    _ipython.run_line_magic("autoreload", "2")

## Load data

In [None]:
# List of PDB IDs; entry i is expected to match record batch i of pdb-mmcif.arrow
# (checked per entry in process_pdb_core). NOTE: pickle is only safe here because
# this is a trusted, locally produced artifact.
pdb_list = pickle.loads(PDB_DATA_PATH.joinpath("pdb-list.pickle").read_bytes())

pdb_data_reader = pa.RecordBatchFileReader(PDB_DATA_PATH / "pdb-mmcif.arrow")

assert pdb_data_reader.num_record_batches == len(pdb_list)

In [None]:
# Peek at the first few PDB IDs.
pdb_list[0:3]

In [None]:
# Split `pdb_list` into TASK_COUNT contiguous, ceil-sized chunks and select the one
# for this array task (SLURM task IDs are 1-based).
chunk_size = int(np.ceil(len(pdb_list) / TASK_COUNT))
task_idx = TASK_ID - 1
chunk_start = task_idx * chunk_size
chunk_stop = (task_idx + 1) * chunk_size
pdb_chunk = pdb_list[chunk_start:chunk_stop]
# Slice a range (instead of range(start, stop)) so the indices are clamped to
# len(pdb_list) exactly like the list slice above. Otherwise the last task(s) get
# more indices than PDB IDs, and the check below raises IndexError.
pdb_chunk_idxs = list(range(len(pdb_list))[chunk_start:chunk_stop])
assert len(pdb_chunk) == len(pdb_chunk_idxs)
assert all(pdb_chunk[i] == pdb_list[j] for i, j in enumerate(pdb_chunk_idxs))

chunk_size, task_idx, len(pdb_chunk), pdb_chunk[:3]

In [None]:
# Keep debug runs quick: process only the first 10 structures of the chunk.
if DEBUG:
    pdb_chunk, pdb_chunk_idxs = pdb_chunk[:10], pdb_chunk_idxs[:10]

## Process structures

In [None]:
def structure_from_chain(structure_ref, model_ref, chain):
    """Wrap ``chain`` in a new single-model, single-chain ``PDB.Structure``.

    The new model reuses ``model_ref.id`` and ``model_ref.serial_num``; the new
    structure reuses ``structure_ref.id``. ``chain`` is added as given (the caller
    in this notebook passes a copy).
    """
    new_model = PDB.Model(model_ref.id, model_ref.serial_num)
    new_model.add(chain)

    new_structure = PDB.Structure(structure_ref.id)
    new_structure.add(new_model)

    chain_count = sum(1 for _ in new_structure.chains)
    assert chain_count == 1
    return new_structure

In [None]:
def _parse_mmcif_structure(mmcif_text, pdb_id, pdb_idx, use_auth_id):
    """Parse mmCIF text, preferring the biological assembly.

    A fresh ``PDB.MMCIFParser`` is used for each attempt. If parsing with
    ``bioassembly_id=True`` raises ``PDB.BioassemblyError``, a message is printed
    and the text is re-parsed with ``bioassembly_id=False``.

    Returns
    -------
    tuple
        ``(structure, bioassembly_id)``, where ``bioassembly_id`` is the value
        passed to ``get_structure`` for the parse that succeeded.
    """
    try:
        structure = PDB.MMCIFParser(use_auth_id=use_auth_id).get_structure(
            io.StringIO(mmcif_text), bioassembly_id=True
        )
        return structure, True
    except PDB.BioassemblyError as e:
        print(f"Encountered error when parsing pdb {pdb_idx} ('{pdb_id}'): {e!s}.")
        structure = PDB.MMCIFParser(use_auth_id=use_auth_id).get_structure(
            io.StringIO(mmcif_text), bioassembly_id=False
        )
        return structure, False


def _chain_to_trajectory(structure, model, chain, aa_sequence):
    """Round-trip one chain through a temporary PDB file into an mdtraj trajectory.

    A copy of ``chain`` is wrapped in its own structure, saved with ``PDB.save``
    and re-loaded with ``mdtraj.load``. The first FASTA sequence of the loaded
    topology must equal ``aa_sequence``.
    """
    single_chain_structure = structure_from_chain(structure, model, chain.copy())
    with tempfile.NamedTemporaryFile(suffix=".pdb") as tmp_pdb:
        PDB.save(single_chain_structure, tmp_pdb.name)
        trajectory = mdtraj.load(tmp_pdb.name)
    assert aa_sequence == trajectory.top.to_fasta()[0]
    return trajectory


def process_pdb_core(pdb_id, pdb_idx):
    """Compute per-chain residue and residue-pair features for one PDB entry.

    Parameters
    ----------
    pdb_id
        ID expected in the ``pdb_id`` column of record batch ``pdb_idx``
        (presumably a string; verify against ``pdb-list.pickle``).
    pdb_idx : int
        Index of the record batch in ``pdb-mmcif.arrow``.

    Returns
    -------
    list of dict
        One column-oriented row (every value is a single-element list, ready for
        ``pa.RecordBatch.from_arrays``) per chain whose sequence was not seen
        earlier in this entry (across all models) and whose whitespace-stripped
        length is at least 5.
    """
    # Load the gzipped mmCIF text for this entry.
    reader = pa.RecordBatchFileReader(PDB_DATA_PATH.joinpath("pdb-mmcif.arrow"))
    record = reader.get_record_batch(pdb_idx).to_pydict()
    assert record["pdb_id"][0] == pdb_id
    mmcif_text = gzip.decompress(record["mmcif_data"][0]).decode()

    use_auth_id = False
    structure, bioassembly_id = _parse_mmcif_structure(mmcif_text, pdb_id, pdb_idx, use_auth_id)

    rows = []
    seen_sequences = set()
    for model_idx, model in enumerate(structure):
        for chain_idx, chain in enumerate(model):
            aa_sequence = structure_tools.get_chain_sequence(chain)
            if aa_sequence in seen_sequences:
                continue
            seen_sequences.add(aa_sequence)
            if len(aa_sequence.strip()) < 5:
                continue

            trajectory = _chain_to_trajectory(structure, model, chain, aa_sequence)

            residue_df = helper.construct_residue_df(trajectory)
            helper.validate_residue_df(residue_df)

            residue_pairs_df = helper.construct_residue_pairs_df(trajectory)
            helper.validate_residue_pairs_df(residue_pairs_df)

            # Key order defines the Arrow schema column order downstream.
            rows.append(
                {
                    "pdb_id": [pdb_id],
                    "pdb_idx": [pdb_idx],
                    "use_auth_id": [use_auth_id],
                    "bioassembly_id": [bioassembly_id],
                    "model_idx": [model_idx],
                    "model_id": [model.id],
                    "chain_idx": [chain_idx],
                    "chain_id": [chain.id],
                    **helper.residue_df_to_row(residue_df),
                    **helper.residue_pairs_df_to_row(residue_pairs_df),
                }
            )

    return rows

In [None]:
def worker(pdb_id, pdb_idx):
    """Run ``process_pdb_core``, turning any exception into a failure record.

    Returns
    -------
    tuple of (list, list)
        ``(rows, failures)``. On success ``failures`` is empty. On error ``rows``
        is empty and ``failures`` holds one column-oriented record with the PDB
        ID, index, exception type and message.
    """
    failures = []
    try:
        rows = process_pdb_core(pdb_id, pdb_idx)
    except Exception as error:
        rows = []
        failures.append(
            {
                "pdb_id": [pdb_id],
                "pdb_idx": [pdb_idx],
                "error_type": [str(type(error))],
                "error": [str(error)],
            }
        )
    return rows, failures

In [None]:
# result = worker(pdb_chunk[0], pdb_chunk_idxs[0])

In [None]:
# result

In [None]:
# with concurrent.futures.ProcessPoolExecutor() as pool:
#     futures = pool.map(worker, pdb_chunk[:10], pdb_chunk_idxs[:10])
#     for result in tqdm(futures, total=100):
#         break

In [None]:
# result

In [None]:
# One Arrow file per array task. Successful chains and failed PDBs go to sibling
# directories; the parent directory already exists (strict resolve above).
output_dir = ADJACENCY_NET_DATA_PATH / "pdb-core"
output_dir_failed = ADJACENCY_NET_DATA_PATH / "pdb-core-failed"

output_dir.mkdir(exist_ok=True)
output_file = output_dir / f"pdb-core-{TASK_ID}-{TASK_COUNT}.arrow"

output_dir_failed.mkdir(exist_ok=True)
output_file_failed = output_dir_failed / f"pdb-core-{TASK_ID}-{TASK_COUNT}-failed.arrow"

output_file, output_file_failed

In [None]:
def _append_row(writer, path, row):
    """Append one column-oriented ``row`` (dict of single-element lists) to an Arrow file.

    The writer is opened lazily with the schema of the first row, so ``path`` is
    only created once there is something to write.

    Returns
    -------
    pa.RecordBatchFileWriter
        ``writer`` if one was given, otherwise the newly opened writer for ``path``.
    """
    batch = pa.RecordBatch.from_arrays(list(row.values()), list(row.keys()))
    if writer is None:
        writer = pa.RecordBatchFileWriter(path, batch.schema)
    writer.write_batch(batch)
    return writer


writer = None
writer_failed = None
# Close the writers even if processing is interrupted, so whatever was written
# still forms a complete Arrow file (close() writes the file footer).
try:
    with concurrent.futures.ProcessPoolExecutor(CPU_COUNT) as pool:
        futures = pool.map(worker, pdb_chunk, pdb_chunk_idxs)
        for results, results_failed in tqdm(futures, total=len(pdb_chunk)):
            for result in results:
                writer = _append_row(writer, output_file, result)
            for result_failed in results_failed:
                writer_failed = _append_row(writer_failed, output_file_failed, result_failed)
finally:
    if writer is not None:
        writer.close()
    if writer_failed is not None:
        writer_failed.close()

In [None]:
# Each successfully processed chain was written as one record batch.
if output_file.is_file():
    reader = pa.RecordBatchFileReader(output_file)
    print(
        f"Number of successful chains: {reader.num_record_batches}."
    )

In [None]:
# Each failed PDB entry was written as one record batch.
if output_file_failed.is_file():
    reader_failed = pa.RecordBatchFileReader(output_file_failed)
    print(
        f"Number of failed PDBs: {reader_failed.num_record_batches}."
    )

### Write `_SUCCESS` file

In [None]:
# Create (or truncate) an empty `_SUCCESS` marker next to the output files.
# NOTE(review): every array task writes the same marker in the shared
# directory, so it signals that *some* task finished, not all of them.
output_file.parent.joinpath("_SUCCESS").write_text("")