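## Summary

Validate the outputs of the `adjacency_matrix.parquet` → `*_wdistances` array jobs: for every task, check for the `._SUCCESS` marker and the expected `.arrow` files, then tally which row indices succeeded, failed, or are missing.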

----

## Imports

In [None]:
import os
import shelve
import socket
from pathlib import Path

from tqdm.notebook import tqdm
import pickle
import pyarrow as pa
import pyarrow.parquet as pq

## Parameters

In [None]:
# Notebook identifier: the CI job name when set, otherwise this notebook's default name.
NOTEBOOK_PATH = Path(os.environ.get("CI_JOB_NAME", "04_add_adjacency_distances_validation"))

NOTEBOOK_PATH

In [None]:
# Output directory: $OUTPUT_DIR when set, otherwise a folder named after the notebook.
# The path is made absolute and the directory (with any missing parents) is created.
OUTPUT_PATH = Path(os.environ.get("OUTPUT_DIR", NOTEBOOK_PATH.name)).resolve()
os.makedirs(OUTPUT_PATH, exist_ok=True)

OUTPUT_PATH

In [None]:
# Number of worker processes to use.
# Hosts whose name contains "scinet" get a fixed count of 40; everywhere else
# half of the available CPUs is used (never fewer than one).
if "scinet" in socket.gethostname():
    CPU_COUNT = 40
elif hasattr(os, "sched_getaffinity"):
    # Respects the CPU affinity mask of the current process where the platform exposes it.
    CPU_COUNT = max(1, len(os.sched_getaffinity(0)) // 2)
else:
    # os.sched_getaffinity is not available on every platform; fall back to the
    # total CPU count (os.cpu_count() may return None).
    CPU_COUNT = max(1, (os.cpu_count() or 1) // 2)

CPU_COUNT

In [None]:
# Read the array-task coordinates from the environment.
# ORIGINAL_ARRAY_TASK_COUNT takes precedence over SLURM_ARRAY_TASK_COUNT when it is set and non-empty.
TASK_ID = os.environ.get("SLURM_ARRAY_TASK_ID")
TASK_COUNT = os.environ.get("ORIGINAL_ARRAY_TASK_COUNT") or os.environ.get("SLURM_ARRAY_TASK_COUNT")

# Convert to integers, leaving None in place when a variable is absent.
if TASK_ID is not None:
    TASK_ID = int(TASK_ID)
if TASK_COUNT is not None:
    TASK_COUNT = int(TASK_COUNT)

TASK_ID, TASK_COUNT

In [None]:
# Debug mode is on whenever no array task id was supplied by the environment.
DEBUG = TASK_ID is None

if not DEBUG:
    # Running as an array task: both coordinates must be known.
    assert TASK_ID is not None
    assert TASK_COUNT is not None
else:
    # Fixed task coordinates used for interactive runs.
    TASK_ID, TASK_COUNT = 216, 1027

TASK_ID, TASK_COUNT

In [None]:
# Root of the partitioned adjacency-matrix parquet dataset.
# DATAPKG_OUTPUT_DIR is required: indexing os.environ raises a KeyError that names the
# missing variable, instead of the opaque TypeError that Path(None) would produce.
ADJACENCY_MATRIX_PARQUET_PATH = Path(os.environ["DATAPKG_OUTPUT_DIR"]).joinpath(
    "adjacency-net-v2", "v0.3", "training_dataset", "adjacency_matrix.parquet"
)

## Load data

In [None]:
# One directory per "database_id=*" partition, in sorted order.
folders = [
    candidate
    for candidate in ADJACENCY_MATRIX_PARQUET_PATH.glob("database_id=*")
    if candidate.is_dir()
]
folders.sort()
folders[:3]

In [None]:
# Position of one specific partition folder in the sorted list (folders are addressed
# as folders[TASK_ID - 1] in this notebook). The path is derived from the dataset root
# rather than a hard-coded absolute location, so the lookup also works when
# DATAPKG_OUTPUT_DIR points somewhere else.
folders.index(ADJACENCY_MATRIX_PARQUET_PATH.joinpath("database_id=G3DSA%3A1.20.120.420"))

In [None]:
# Parquet files of the partition that belongs to this task (TASK_ID is 1-based).
files = sorted(folders[TASK_ID - 1].glob("*.parquet"))

# Show the first two files and the total number of files, one per line.
print(files[:2], len(files), sep="\n")

In [None]:
# Load only the "__index_level_0__" column from the first row group of the first file.
first_row_group = pq.ParquetFile(files[0]).read_row_group(
    0, columns=["__index_level_0__"], use_pandas_metadata=False
)
df = first_row_group.to_pandas(integer_object_nulls=True)

In [None]:
# Preview the first two rows.
df.iloc[:2]

## Find successful jobs

In [None]:
def get_new_file(file, failed=False):
    """Map an input parquet path to the path of its ``.arrow`` output.

    The fourth-from-last path component gets a ``_wdistances`` suffix and the
    file name is cut at its first ``.`` and given the ``.arrow`` extension.
    When ``failed`` is true, a ``failed`` directory is inserted right after the
    renamed component.

    Args:
        file: ``Path`` with at least four components.
        failed: Whether to return the location used for failed rows.

    Returns:
        The derived ``Path``.

    Raises:
        IndexError: If ``file`` has fewer than four components.
    """
    parts = file.parts
    renamed_dir = parts[-4] + "_wdistances"
    arrow_name = parts[-1].split(".")[0] + ".arrow"
    marker = ("failed",) if failed else ()
    return Path(*parts[:-4], renamed_dir, *marker, *parts[-3:-1], arrow_name)

In [None]:
# task_id -> True when every input parquet file of that task has its ".arrow"
# output and the output folder contains a "._SUCCESS" marker file.
succeeded = {}

for task_id in tqdm(range(1, TASK_COUNT + 1)):
    files = sorted(folders[task_id - 1].glob("*.parquet"))
    outputs = (get_new_file(parquet_file) for parquet_file in files)
    # all() stops at the first output that is incomplete; a task without any
    # input files counts as succeeded.
    succeeded[task_id] = all(
        output.parent.joinpath("._SUCCESS").is_file() and output.is_file()
        for output in outputs
    )

In [None]:
# Number of tasks that succeeded (the values are booleans, so they sum to the count of True).
sum(succeeded.values())

In [None]:
# Number of tasks that did not succeed.
sum(not ok for ok in succeeded.values())

In [None]:
# Comma-separated ids of the tasks that did not succeed:
# ",".join(str(k) for k, v in succeeded.items() if not v)

## Analyze successful jobs

In [None]:
# Shelve file that persists per-task statistics between notebook runs.
stats_cache_file = OUTPUT_PATH / "stats.cache"

# Load everything already cached; shelve keys are strings, task ids are ints.
stats_all = {}
with shelve.open(stats_cache_file.as_posix()) as stats_cache:
    stats_all.update((int(key), value) for key, value in stats_cache.items())

In [None]:
# Collect, for every task not yet in the cache, which row indices ended up in the
# "succeeded" output, which in the "failed" output, and which in neither.
for task_id in tqdm(range(1, TASK_COUNT + 1)):
    if task_id in stats_all:
        continue

    stats = {
        "succeeded": succeeded[task_id],
        "succeeded_indices": set(),
        "failed_indices": set(),
        "missing_indices": set(),
    }

    if stats["succeeded"]:
        files = sorted(folders[task_id - 1].glob("*.parquet"))
        try:
            for file in files:
                # Every index present in this input file.
                all_indices = set(
                    pq.read_table(file, columns=["__index_level_0__"], use_pandas_metadata=False)
                    .to_pandas(integer_object_nulls=True)["__index_level_0__"]
                    .values.tolist()
                )

                # Indices are gathered per file and checked against this file's own
                # index set. Checking the task-wide sets against a single file would
                # fail as soon as a folder holds a second file with different indices.
                file_succeeded = set()
                new_file = get_new_file(file)
                if new_file.is_file():
                    reader = pa.RecordBatchFileReader(new_file)
                    for record_batch_idx in tqdm(range(reader.num_record_batches), leave=False):
                        batch = reader.get_record_batch(record_batch_idx)
                        # Only the first row of each batch is read.
                        # NOTE(review): column 22 is addressed by position; presumably it is
                        # the same "Index" column that the failed files expose by name - confirm.
                        index = batch.column(22)[0]
                        file_succeeded.add(index.as_py())
                assert not file_succeeded - all_indices

                file_failed = set()
                new_file = get_new_file(file, failed=True)
                if new_file.is_file():
                    reader = pa.RecordBatchFileReader(new_file)
                    for record_batch_idx in tqdm(range(reader.num_record_batches), leave=False):
                        batch = reader.get_record_batch(record_batch_idx)
                        index = batch.to_pydict()["Index"][0]
                        file_failed.add(index)
                assert not file_failed - all_indices

                # An index must not be reported as both succeeded and failed.
                assert not file_succeeded & file_failed

                # Merge this file's results into the task totals; missing indices are
                # accumulated so that earlier files are not overwritten by later ones.
                stats["succeeded_indices"] |= file_succeeded
                stats["failed_indices"] |= file_failed
                stats["missing_indices"] |= all_indices - file_succeeded - file_failed
        except pa.ArrowInvalid:
            # An unreadable Arrow file downgrades the whole task to "not succeeded".
            stats["succeeded"] = False

    stats_all[task_id] = stats
    # Persist after every task so that an interrupted run can resume from the cache.
    with shelve.open(stats_cache_file.as_posix()) as stats_cache:
        stats_cache[str(task_id)] = stats_all[task_id]