## Summary

----

## Imports

In [None]:
import concurrent.futures
import itertools
import os
import pickle
import shelve
import socket
from datetime import datetime
from pathlib import Path

from tqdm.notebook import tqdm

import numpy as np
import psutil
import pyarrow as pa
import pyarrow.parquet as pq

## Parameters

In [None]:
# Identifier of this notebook: the CI_JOB_NAME environment variable when it is
# set, otherwise the hard-coded default below.
NOTEBOOK_PATH = Path(os.environ.get("CI_JOB_NAME", "02_process_pdb_core_validation"))

NOTEBOOK_PATH

In [None]:
# Absolute directory for this notebook's outputs. OUTPUT_DIR overrides the
# default, which is a directory named after the notebook.
OUTPUT_PATH = Path(os.environ.get("OUTPUT_DIR", NOTEBOOK_PATH.name)).resolve()
# Create the directory (and any missing parents); an existing one is fine.
os.makedirs(OUTPUT_PATH, exist_ok=True)

OUTPUT_PATH

In [None]:
# Number of worker processes to use.
if "scinet" in socket.gethostname():
    # Hosts whose name contains "scinet" use a fixed worker count.
    CPU_COUNT = 40
elif hasattr(os, "sched_getaffinity"):
    # Half of the CPUs this process is allowed to run on, but at least one.
    CPU_COUNT = max(1, len(os.sched_getaffinity(0)) // 2)
else:
    # os.sched_getaffinity is not available on every platform; fall back to
    # half of the machine's CPU count (os.cpu_count() may return None).
    CPU_COUNT = max(1, (os.cpu_count() or 1) // 2)

CPU_COUNT

In [None]:
# pdb-ffindex/2020-01-16/arrow below DATAPKG_OUTPUT_DIR. The environment
# variable is required (KeyError when unset) and strict resolution raises if
# the directory does not exist.
PDB_DATA_PATH = (Path(os.environ["DATAPKG_OUTPUT_DIR"]) / "pdb-ffindex" / "2020-01-16" / "arrow").resolve(
    strict=True
)

PDB_DATA_PATH

In [None]:
# adjacency-net-v2/v0.3 below DATAPKG_OUTPUT_DIR; strict resolution raises if
# the directory does not exist.
# The variable is read through os.environ[...] (the same way PDB_DATA_PATH
# reads it) so that an unset variable fails with a KeyError naming it, rather
# than with a TypeError from Path(None).
ADJACENCY_NET_DATA_PATH = (
    Path(os.environ["DATAPKG_OUTPUT_DIR"]).joinpath("adjacency-net-v2", "v0.3").resolve(strict=True)
)

ADJACENCY_NET_DATA_PATH

## Workflow

In [None]:
# Load the pickled list of PDB entries and open the pdb-mmcif.arrow file that
# sits next to it.
with open(PDB_DATA_PATH / "pdb-list.pickle", "rb") as fin:
    pdb_list = pickle.load(fin)

pdb_data_reader = pa.RecordBatchFileReader(PDB_DATA_PATH / "pdb-mmcif.arrow")

# The Arrow file must hold exactly one record batch per list entry.
assert pdb_data_reader.num_record_batches == len(pdb_list)

In [None]:
# Preview the first three entries.
pdb_list[0:3]

In [None]:
# Directories below ADJACENCY_NET_DATA_PATH that the rest of the notebook uses.
output_dir = ADJACENCY_NET_DATA_PATH / "pdb-core"
output_dir_failed = ADJACENCY_NET_DATA_PATH / "pdb-core-failed"
output_dir_stats = output_dir / "stats"
# Only the "stats" leaf is created here; its parent directory must already exist.
output_dir_stats.mkdir(exist_ok=True)

(output_dir, output_dir_failed, output_dir_stats)

In [None]:
def worker(task_id, task_count, progress=False):
    """Summarise the on-disk results of one task over its slice of ``pdb_list``.

    ``pdb_list`` is split into ``task_count`` contiguous chunks of
    ``ceil(len(pdb_list) / task_count)`` entries each (the trailing chunks may
    be shorter or empty); task ``task_id`` owns chunk ``task_id - 1``. The
    function reads the task's output file and failed file, if present, and
    reports which entries of the chunk appear in each.

    Args:
        task_id: 1-based index of the task.
        task_count: Total number of tasks ``pdb_list`` is split into.
        progress: Show tqdm progress bars while reading record batches.

    Returns:
        A dict with the keys

        * ``"succeeded"``: ``False`` when the ``.SUCCESS`` marker next to the
          task's output file is missing (all other fields then keep their
          empty defaults), ``True`` otherwise.
        * ``"pdbs_succeeded"``: set of ids taken from the first column, first
          row of every record batch in the output file.
        * ``"pdbs_failed"``: set of ids taken from the ``"pdb_id"`` field of
          every record batch in the failed file.
        * ``"pdbs_missing"``: chunk entries found in neither of the two sets.
        * ``"chains_succeeded_count"``: number of record batches in the output
          file (presumably one batch per chain -- confirm against the writer).
        * ``"error_records"``: every record batch of the failed file, converted
          with ``to_pydict()``.

    Raises:
        AssertionError: If the output file or the failed file mentions an id
            that is not part of this task's chunk.
    """
    chunk_size = int(np.ceil(len(pdb_list) / task_count))
    task_idx = task_id - 1

    stats = {
        "succeeded": True,
        "pdbs_succeeded": set(),
        "pdbs_failed": set(),
        "pdbs_missing": set(),
        "chains_succeeded_count": 0,
        "error_records": [],
    }

    # Entries of pdb_list that this task is responsible for.
    pdb_chunk_set = set(pdb_list[task_idx * chunk_size : (task_idx + 1) * chunk_size])

    # Without the .SUCCESS marker the task is reported as not succeeded and
    # none of its files are inspected.
    output_file = output_dir.joinpath(f"pdb-core-{task_id}-{task_count}.arrow")
    if not output_file.with_suffix(".SUCCESS").is_file():
        stats["succeeded"] = False
        return stats

    # Read succeeded entries: one id per record batch.
    if output_file.is_file():
        reader = pa.RecordBatchFileReader(output_file)
        for i in tqdm(range(reader.num_record_batches), desc="succeeded", leave=False, disable=not progress):
            pdb_id = reader.get_record_batch(i).column(0)[0].as_py()
            stats["pdbs_succeeded"].add(pdb_id)
            stats["chains_succeeded_count"] += 1
    # Every id in the output file must belong to this task's chunk.
    assert not stats["pdbs_succeeded"] - pdb_chunk_set

    # Read failed entries: one error record per record batch.
    output_file_failed = output_dir_failed.joinpath(f"pdb-core-{task_id}-{task_count}-failed.arrow")
    if output_file_failed.is_file():
        reader_failed = pa.RecordBatchFileReader(output_file_failed)
        for i in tqdm(range(reader_failed.num_record_batches), desc="failed", leave=False, disable=not progress):
            error_record = reader_failed.get_record_batch(i).to_pydict()
            pdb_id = error_record["pdb_id"][0]
            stats["pdbs_failed"].add(pdb_id)
            stats["error_records"].append(error_record)
    # Every id in the failed file must belong to this task's chunk.
    assert not stats["pdbs_failed"] - pdb_chunk_set

    # Whatever is in the chunk but in neither file is unaccounted for.
    stats["pdbs_missing"] = pdb_chunk_set - stats["pdbs_succeeded"] - stats["pdbs_failed"]

    return stats

In [None]:
stats_cache_file = OUTPUT_PATH / "stats.cache"
task_count = 300
stats = {}

# Seed the in-memory results from the on-disk shelve cache. Shelve keys are
# strings, so they are converted back to integer task ids.
with shelve.open(stats_cache_file.as_posix()) as stats_cache:
    stats.update((int(key), value) for key, value in stats_cache.items())

In [None]:
# Run `worker` in a process pool for every task id that has no result in
# `stats` yet. Each result is stored in `stats` and then written to the shelve
# cache as soon as it arrives.
with shelve.open(stats_cache_file.as_posix()) as stats_cache, \
        concurrent.futures.ProcessPoolExecutor(CPU_COUNT) as pool:
    task_ids = [i for i in range(1, task_count + 1) if i not in stats]
    # pool.map yields results in the order of task_ids, so zipping the two
    # pairs every result with its task id.
    futures = pool.map(worker, task_ids, [task_count] * len(task_ids))
    for stat, task_id in tqdm(zip(futures, task_ids), total=len(task_ids)):
        stats[task_id] = stats_cache[str(task_id)] = stat