# Summary

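Create a validation dataset for measuring decoy-discrimination accuracy: for every 3DRobot target, the native structure and its decoys are scored (MODELLER, residue contacts, Rosetta) alongside each decoy's RMSD.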

# Imports

In [None]:
import concurrent.futures
import importlib
import io
import os
import sys
import tarfile
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import psutil
import pyarrow as pa
import pyarrow.parquet as pq
from boltons import strutils

import kmbio.PDB
from kmtools import structure_tools
from tkpod.plugins.modeller import Modeller

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Use the fully-qualified option name: since pandas 1.4 the bare pattern
# "max_columns" also matches "styler.render.max_columns" and raises OptionError.
pd.set_option("display.max_columns", 100)

In [None]:
# Make the project's `src` directory (a sibling of the working directory)
# importable, then (re)load the local `helper` module so edits to it are picked up.
SRC_PATH = (Path.cwd() / '..' / 'src').resolve(strict=True)

_src_dir = SRC_PATH.as_posix()
if _src_dir not in sys.path:
    sys.path.insert(0, _src_dir)

import helper  # noqa: E402 -- only importable after the sys.path tweak above
importlib.reload(helper)

# Parameters

In [None]:
NOTEBOOK_PATH = Path('decoy_discrimination_dataset')
NOTEBOOK_PATH

In [None]:
OUTPUT_PATH = Path(os.getenv('OUTPUT_DIR', NOTEBOOK_PATH.name)).resolve()
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
OUTPUT_PATH

In [None]:
# SLURM array-job coordinates; each stays None when its environment variable is unset.
# A non-empty ORIGINAL_ARRAY_TASK_COUNT takes precedence over SLURM_ARRAY_TASK_COUNT.
_raw_task_id = os.getenv("SLURM_ARRAY_TASK_ID")
_raw_task_count = os.getenv("ORIGINAL_ARRAY_TASK_COUNT") or os.getenv("SLURM_ARRAY_TASK_COUNT")

TASK_ID = None if _raw_task_id is None else int(_raw_task_id)
TASK_COUNT = None if _raw_task_count is None else int(_raw_task_count)

In [None]:
# DEBUG mode is on whenever the "CI" environment variable is absent.
# NOTE(review): a SLURM job without CI set therefore also runs in DEBUG mode,
# which overrides TASK_ID/TASK_COUNT below and later truncates the dataset --
# presumably the batch runner exports CI; confirm.
DEBUG = "CI" not in os.environ    
DEBUG

In [None]:
if DEBUG:
    # Debug default: behave like task 1 (1-based) of a 200-task array.
    TASK_ID = 1
    TASK_COUNT = 200
else:
    # Non-debug runs require the SLURM array variables parsed above.
    assert TASK_ID is not None
    assert TASK_COUNT is not None
    
TASK_ID, TASK_COUNT

In [None]:
if DEBUG:
    # Automatically reload edited modules (e.g. `helper`) before running code.
    %load_ext autoreload
    %autoreload 2

In [None]:
# Output file: "<notebook name>[-<task id, zero-padded to 3 digits>].parquet".
_task_suffix = "" if TASK_ID is None else f"-{TASK_ID:03}"
filename = NOTEBOOK_PATH.name + _task_suffix + ".parquet"
OUTPUT_FILE = OUTPUT_PATH / filename
OUTPUT_FILE

# Input data

In [None]:
INPUT_DATA = {}

In [None]:
INPUT_DATA['3drobot'] = (
    Path(os.environ["DATAPKG_INPUT_DIR"])
    .joinpath("3drobot", "2018-11-16")
)

In [None]:
sorted(INPUT_DATA['3drobot'].glob("*.tar.bz2"))[:3]

# Functions

In [None]:
def get_one(it):
    """Return the only element of iterable ``it``.

    Raises:
        AssertionError: if ``it`` does not yield exactly one element. The
            message lists the elements that were found, as in ``read_one``.
    """
    vals = list(it)
    assert len(vals) == 1, vals
    return vals[0]

In [None]:
def read_one(it):
    """Materialise ``it`` and return its single element.

    Raises AssertionError (with the collected elements as message) unless
    exactly one element is present.
    """
    collected = [*it]
    assert len(collected) == 1, collected
    (only,) = collected
    return only

In [None]:
def text_to_fh(text):
    """Wrap ``text`` in an in-memory text file handle rewound to the start."""
    handle = io.StringIO()
    handle.write(text)
    handle.seek(0)
    return handle

In [None]:
def fh_to_structure(fh):
    """Parse the PDB-format handle ``fh`` with kmbio's ``PDBParser``.

    ``bioassembly_id=False`` is passed through to ``get_structure``.
    """
    return kmbio.PDB.PDBParser().get_structure(fh, bioassembly_id=False)

In [None]:
def get_modeller_scores(structure, chain_id, sequence):
    """Run MODELLER on ``structure`` and return its scores as a flat dict.

    The whole ``sequence`` is used as a single target spanning residues
    1..len(sequence) on ``chain_id``.

    Returns:
        dict with MODELLER's result keys lower-cased and with spaces replaced
        by underscores, plus:
        - ``ga341_score_<i>`` for every component of the GA341 result, with
          ``ga341_score`` itself replaced by component 0;
        - one ``modeller_<slug>`` entry per PDF term (``pdfterms`` is removed).

    Raises:
        KeyError: if the results lack ``ga341_score``/``pdfterms`` or the
            GA341 sequence is empty.
    """
    target = structure_tools.DomainTarget(0, chain_id, sequence, 1, len(sequence), sequence)
    modeller_data = Modeller.build(structure, bioassembly_id=False, use_strict_alignment=True)
    # Only the scores are needed; the returned model structure is discarded.
    _, raw_results = Modeller.create_model([target], modeller_data)
    modeller_results = {k.replace(" ", "_").lower(): v for k, v in raw_results.items()}

    # GA341: expose every component separately and keep component 0 under the
    # original key (assumes ga341_score is an int-indexable sequence -- confirm).
    for i, score in enumerate(modeller_results['ga341_score']):
        modeller_results[f'ga341_score_{i}'] = score
    modeller_results['ga341_score'] = modeller_results['ga341_score_0']

    # Flatten the PDF terms into one scalar entry per term.
    pdf_terms = dict(modeller_results.pop('pdfterms'))
    modeller_results.update({
        "modeller_" + strutils.slugify(repr(term)): value
        for term, value in pdf_terms.items()
    })
    return modeller_results

In [None]:
def get_structure_info(row):
    """Compute all per-structure features for one dataset row.

    ``row`` must provide ``structure_text``, ``unique_id`` and ``chain_id``
    (plus whatever the ``helper`` functions read from it).

    Returns a flat dict built in this order:
    1. MODELLER scores;
    2. ``sequence`` and the off-diagonal residue contact pairs
       (``residue_idx_1``/``residue_idx_2``);
    3. Rosetta scores.
    """
    structure = fh_to_structure(text_to_fh(row.structure_text))
    structure.id = row.unique_id
    chain = structure[0][row.chain_id]
    sequence = structure_tools.get_chain_sequence(chain)

    results = dict(get_modeller_scores(structure, row.chain_id, sequence))

    # Residue contacts: keep only pairs between two different residues.
    _, interactions = helper.get_homology_model_interactions(row)
    off_diagonal = [pair for pair in interactions.at[0, "residue_pair"] if pair[0] != pair[1]]
    first_idx, second_idx = list(zip(*off_diagonal))
    results["sequence"] = sequence
    results["residue_idx_1"] = list(first_idx)
    results["residue_idx_2"] = list(second_idx)

    # Rosetta is the most time-consuming step, so it runs last.
    results.update(helper.get_rosetta_scores(row))
    return results

In [None]:
def worker(row_dict):
    """Process one dataset row given as a plain dict (as sent to the process pool).

    Always returns a dict with ``index``, ``unique_id`` and ``error``. On
    success ``error`` is None and the features from ``get_structure_info``
    follow. Any exception raised while computing them is recorded in
    ``error`` as ``"<exception type>: <message>"`` instead of propagating.
    """
    row = helper.to_namedtuple(row_dict)
    record = {"index": row.Index, "unique_id": row.unique_id}
    try:
        info = get_structure_info(row)
    except Exception as e:
        record["error"] = f"{type(e)}: {e}"
    else:
        record["error"] = None
        record.update(info)
    return record

# Process

## Get a list of decoy files

In [None]:
files = sorted(INPUT_DATA['3drobot'].glob("*.tar.bz2"))

In [None]:
if TASK_ID and TASK_COUNT:
    assert len(files) == TASK_COUNT
    files = files[TASK_ID - 1 : TASK_ID]
    
files

## Construct dataset

In [None]:
def _read_decoy_archive(archive):
    """Load one 3DRobot decoy archive into a DataFrame with one row per structure.

    ``archive`` is named ``<pdb_id><chain_id>.tar.bz2`` (4 + 1 characters before
    the first dot). It unpacks into a ``<pdb_id><chain_id>/`` directory that
    contains ``list.txt`` (space-separated table with ``NAME``/``RMSD`` columns)
    and a ``.pdb`` file for every listed entry.

    Returns a frame with the ``list.txt`` columns (``NAME``/``RMSD`` renamed to
    ``decoy_name``/``rmsd``) plus ``structure_text``, ``structure_id``,
    ``pdb_id``, ``chain_id`` and ``unique_id``. Extracted files are deleted
    before returning.
    """
    structure_id = archive.name.split('.')[0]
    assert len(structure_id) == 5, structure_id
    pdb_id, chain_id = structure_id[:4], structure_id[4]

    # The context manager removes the extracted files even if parsing fails.
    with tempfile.TemporaryDirectory() as tempdir_name:
        # NOTE(review): extractall() trusts member paths inside the archive;
        # consider filter="data" where the running Python supports it.
        with tarfile.open(archive, "r:bz2") as tar:
            tar.extractall(tempdir_name)

        decoy_dir = Path(tempdir_name).joinpath(structure_id)
        df = pd.read_csv(decoy_dir.joinpath("list.txt"), sep=' +', engine="python")
        df.rename(columns={"NAME": "decoy_name", "RMSD": "rmsd"}, inplace=True)
        df.set_index("decoy_name", inplace=True)
        df["structure_text"] = None
        for pdb_file in decoy_dir.glob("*.pdb"):
            with pdb_file.open("rt") as fin:
                df.loc[pdb_file.name, "structure_text"] = fin.read()

    # Every entry listed in list.txt must have had a matching PDB file.
    assert df['structure_text'].notnull().all(), structure_id
    df.reset_index(inplace=True)
    df["structure_id"] = structure_id
    df["pdb_id"] = pdb_id
    # Only "native.pdb" gets the archive's chain ID; all other entries get a
    # blank chain ID (presumably matching how the decoy PDBs are written -- confirm).
    df["chain_id"] = [(chain_id if n == "native.pdb" else " ") for n in df["decoy_name"]]
    df["unique_id"] = structure_id + "-" + df["decoy_name"].str.split('.').str[0]
    return df


dfs = [_read_decoy_archive(file) for file in files]
dataset = pd.concat(dfs, ignore_index=True)
assert dataset["unique_id"].is_unique
assert dataset.index.is_unique

## Run one row

In [None]:
# Second row of the dataset, for manually debugging get_structure_info.
# Slicing first avoids materialising every row as a namedtuple.
row = next(dataset.iloc[1:2].itertuples())

In [None]:
# get_structure_info(row)

## Run all rows

In [None]:
if DEBUG:
    dataset = dataset.iloc[:psutil.cpu_count(logical=False)]
    
print(len(dataset))

In [None]:
with concurrent.futures.ProcessPoolExecutor(psutil.cpu_count(logical=False)) as pool:
    futures = pool.map(worker, (t._asdict() for t in dataset.itertuples()))
    results = list(futures)

In [None]:
# Collect worker outputs, indexed like `dataset`, with normalised column names.
results_df = pd.DataFrame(results).set_index("index")
results_df.columns = [c.replace(" ", "_").lower() for c in results_df.columns]
display(results_df.head(4))

n_errors = results_df['error'].notnull().sum()
print("Number of rows: ", results_df.shape[0])
print("Number of errors: ", n_errors)

## Merge results

In [None]:
dataset_wresults = (
    dataset
    .merge(results_df, left_index=True, right_index=True, validate="1:1", copy=False, suffixes=("", "_copy"))
)

In [None]:
for col in dataset_wresults.columns:
    if col.endswith("_copy"):
        col_ref = col[:-5]
        assert (dataset_wresults[col] == dataset_wresults[col_ref]).all()
        del dataset_wresults[col]

## Save failed subset

In [None]:
dataset_wresults_failed = dataset_wresults[dataset_wresults['error'].notnull()]

display(dataset_wresults_failed.head(2))
print(dataset_wresults_failed.shape[0])

In [None]:
table = pa.Table.from_pandas(dataset_wresults_failed, preserve_index=True)
pq.write_table(table, OUTPUT_FILE.with_suffix(".failed"), version="2.0", flavor="spark")

## Save successful subset

In [None]:
dataset_wresults_succeeded = dataset_wresults[dataset_wresults['error'].isnull()]

display(dataset_wresults_succeeded.head(2))
print(dataset_wresults_succeeded.shape[0])

In [None]:
table = pa.Table.from_pandas(dataset_wresults_succeeded, preserve_index=True)
pq.write_table(table, OUTPUT_FILE, version="2.0", flavor="spark")