# Summary

Generate adjacency matrices for the Protherm training set directly from PDB structure files.

---

# Imports

In [None]:
import concurrent.futures
import importlib
import logging
import os
import os.path as op
import shutil
import sys
from collections import Counter
from pathlib import Path

import kmbio.PDB
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import pyarrow as pa
import pyarrow.parquet as pq
from kmtools import structure_tools

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Emit INFO-level (and more severe) log messages.
logging.basicConfig(level=logging.INFO)

In [None]:
# Make ``../src`` (relative to the current working directory) importable, then
# import the project-local ``helper`` module and reload it so that edits to it
# are picked up without restarting the kernel.
SRC_PATH = Path.cwd().joinpath('..', 'src').resolve(strict=True)

src_dir = SRC_PATH.as_posix()
if src_dir not in sys.path:
    sys.path.insert(0, src_dir)

import helper
importlib.reload(helper)

# Parameters

In [None]:
NOTEBOOK_NAME = 'protherm_dataset'
# Working directory for this notebook, relative to the current working directory.
NOTEBOOK_PATH = Path(NOTEBOOK_NAME)

# Create it (and any missing parents); an existing directory is fine.
os.makedirs(NOTEBOOK_PATH, exist_ok=True)

In [None]:
# Outputs go to $OUTPUT_DIR when it is set, otherwise to a directory named
# after the notebook; the path is made absolute and created if missing.
output_dir = os.environ.get('OUTPUT_DIR', NOTEBOOK_PATH.name)
OUTPUT_PATH = Path(output_dir).resolve()
os.makedirs(OUTPUT_PATH, exist_ok=True)
OUTPUT_PATH

In [None]:
# Debug mode restricts processing to the first 10 rows of the dataset.
# It is hard-coded off; to enable it automatically outside CI, use:
#     DEBUG = "CI" not in os.environ
DEBUG = False

# SLURM job-array coordinates; each is None when the variable is not set.
TASK_ID = os.environ.get("SLURM_ARRAY_TASK_ID")
TASK_COUNT = (
    os.environ.get("ORIGINAL_ARRAY_TASK_COUNT")
    or os.environ.get("SLURM_ARRAY_TASK_COUNT")
)

DEBUG, TASK_ID, TASK_COUNT

In [None]:
# Root of the shared data directory; the DATABIN_DIR environment variable must
# be set (a missing variable raises KeyError here).
DATABIN_PATH = Path(os.environ['DATABIN_DIR'])

# Load data

In [None]:
# Load every table from the Rosetta results HDF5 store, keyed by its store key
# with the surrounding '/' characters removed.
ROSETTA_RESULTS = {}

rosetta_h5_file = DATABIN_PATH.joinpath('elapsam_feature_engineering/v0.1.0/rosetta.h5').as_posix()
with pd.HDFStore(rosetta_h5_file, 'r') as store:
    ROSETTA_RESULTS.update({
        store_key.strip('/'): store[store_key][:]
        for store_key in store
    })

In [None]:
# List the names of the Rosetta result tables that were loaded (one per HDF5 key).
ROSETTA_RESULTS.keys()

In [None]:
# Preview the first rows of one Rosetta result table.
ROSETTA_RESULTS['cartesian_ddg-talaris2014_cart-1'].head()

In [None]:
# Merge the per-protocol Rosetta tables into one wide table: one ddG column per
# protocol (named after its HDF5 key) plus the shared experimental ddG.
MERGE_KEYS = ['filename-wt', 'pdb_chain', 'mutation']

rosetta_results_df = None

for key, df in ROSETTA_RESULTS.items():
    df = df.rename(columns={'ddg': key})
    if rosetta_results_df is None:
        rosetta_results_df = df
        continue
    # Compare ``ddg_exp`` only after aligning rows on the merge keys. A
    # positional comparison of ``.values`` assumes identical row order, which an
    # outer merge (it sorts the join keys) does not preserve; NaN != NaN also
    # broke that comparison.
    merged = rosetta_results_df.merge(
        df.rename(columns={'ddg_exp': 'ddg_exp_other'}), on=MERGE_KEYS, how='outer')
    in_both = merged['ddg_exp'].notnull() & merged['ddg_exp_other'].notnull()
    assert (merged.loc[in_both, 'ddg_exp'] == merged.loc[in_both, 'ddg_exp_other']).all(), key
    # Rows that only exist in this table still receive their experimental ddG.
    merged['ddg_exp'] = merged['ddg_exp'].fillna(merged['ddg_exp_other'])
    rosetta_results_df = merged.drop('ddg_exp_other', axis=1)

assert rosetta_results_df is not None, "No Rosetta result tables were loaded"
rosetta_results_df = rosetta_results_df.rename(columns=lambda c: c.replace('-', '_').strip('_'))
display(rosetta_results_df.head())
print(rosetta_results_df.shape)

## Copy structures

In [None]:
# Local cache directory for the wild-type structure files.
STRUCTURE_PATH = NOTEBOOK_PATH / 'structures'
STRUCTURE_PATH.mkdir(exist_ok=True)

In [None]:
# Peek at up to 10 entries already present in the structure cache directory.
os.listdir(STRUCTURE_PATH)[:10]

In [None]:
def get_local_filename(filename):
    """Return the absolute POSIX path that ``filename`` maps to inside ``STRUCTURE_PATH``.

    Only the base name of ``filename`` is kept; its directory part is discarded.
    """
    local_path = STRUCTURE_PATH / op.basename(filename)
    return local_path.absolute().as_posix()


get_local_filename(rosetta_results_df['filename_wt'].iloc[0])

In [None]:
# Copy each distinct wild-type structure into STRUCTURE_PATH (skipping files that
# are already there), then record for every row the path of its local copy.
# Source paths are rewritten from the original warehouse location to ~/datapkg.
warehouse_root = "/home/kimlab2/database_data/biological-data-warehouse"
datapkg_root = Path("~/datapkg").expanduser().as_posix()  # loop-invariant

file_list = rosetta_results_df['filename_wt'].drop_duplicates().tolist()

local_filename_by_wt = {}
for i, filename in enumerate(file_list):
    if i % 200 == 0:
        print(i)
    source_filename = filename.replace(warehouse_root, datapkg_root)
    local_filename = get_local_filename(source_filename)
    if not op.isfile(local_filename):
        shutil.copy(source_filename, local_filename)
    local_filename_by_wt[filename] = local_filename

# Map per row: ``file_list`` is de-duplicated, so assigning it positionally would
# misalign (or raise a length error) whenever a structure has several mutations.
rosetta_results_df['local_filename_wt'] = rosetta_results_df['filename_wt'].map(local_filename_by_wt)

# Process data

In [None]:
# In debug mode, keep only the first 10 rows.
if DEBUG:
    rosetta_results_df = rosetta_results_df.head(10)

## Extract adjacencies

In [None]:
def extract_seq_and_adj(row):
    """Extract the chain sequence and residue-residue pairs for one dataset row.

    Parameters
    ----------
    row
        Object exposing ``local_filename_wt`` (structure file) and ``pdb_chain``
        (chain id). Model 0 of the structure is used.

    Returns
    -------
    dict
        ``sequence``: chain sequence; ``residue_idx_1`` / ``residue_idx_2``:
        index pairs returned by ``helper.get_interaction_dataset_wdistances``
        with ``r_cutoff=12`` (presumably Angstrom — confirm); ``distances``:
        the corresponding distances.

    Raises
    ------
    AssertionError
        If any residue index falls outside the extracted sequence.
    """
    domain, result_df = helper.get_interaction_dataset_wdistances(
        row.local_filename_wt, 0, row.pdb_chain, r_cutoff=12)
    domain_sequence = structure_tools.get_chain_sequence(domain)
    residue_idx_1 = result_df['residue_idx_1'].values
    residue_idx_2 = result_df['residue_idx_2'].values
    # Vectorised bounds check; unlike the builtin ``max`` it also accepts an
    # empty pair list instead of raising an opaque ValueError.
    seq_len = len(domain_sequence)
    assert (residue_idx_1 < seq_len).all(), row.local_filename_wt
    assert (residue_idx_2 < seq_len).all(), row.local_filename_wt
    return {
        'sequence': domain_sequence,
        'residue_idx_1': residue_idx_1,
        'residue_idx_2': residue_idx_2,
        'distances': result_df['distance'].values,
    }

In [None]:
def worker(row_dict):
    """Process-pool entry point: run ``extract_seq_and_adj`` on one row.

    ``row_dict`` is converted with ``helper.to_namedtuple`` so that the extractor
    can use attribute access; rows are presumably passed as plain dicts so they
    pickle cleanly across processes.
    """
    return extract_seq_and_adj(helper.to_namedtuple(row_dict))

In [None]:
# Only show WARNING and above from the kmbio.PDB.core.atom logger (presumably to
# keep per-atom messages out of the notebook output).
logging.getLogger("kmbio.PDB.core.atom").setLevel(logging.WARNING)

In [None]:
columns = ["local_filename_wt", "pdb_chain"]

# Extract sequences and residue pairs in parallel with one process per physical
# core; Executor.map returns results in input-row order.
with concurrent.futures.ProcessPoolExecutor(psutil.cpu_count(logical=False)) as executor:
    row_dicts = (row._asdict() for row in rosetta_results_df[columns].itertuples())
    results = list(executor.map(worker, row_dicts))

In [None]:
# Start the output table from the merged Rosetta results, using ``chain_id``
# as the chain column name.
protherm_validaton_dataset = rosetta_results_df.copy().rename(columns={'pdb_chain': 'chain_id'})

# Characters 3-6 of the structure file name are taken as the structure ID
# (presumably ``pdbXXXX.ent``-style file names — TODO confirm).
protherm_validaton_dataset['structure_id'] = protherm_validaton_dataset['filename_wt'].map(
    lambda structure_file: Path(structure_file).name[3:7]
)
# Model 0 is used for every structure, matching the extraction step.
protherm_validaton_dataset['model_id'] = 0

In [None]:
# Unpack the per-row worker results into dataset columns
# (output column -> key in each result dict), in this order.
result_key_by_column = {
    'qseq': 'sequence',
    'residue_idx_1_corrected': 'residue_idx_1',
    'residue_idx_2_corrected': 'residue_idx_2',
    'distances': 'distances',
}
for column_name, result_key in result_key_by_column.items():
    protherm_validaton_dataset[column_name] = [res[result_key] for res in results]

In [None]:
def mutation_matches_sequence(mutation, sequence):
    """Return whether the wild-type residue of ``mutation`` matches ``sequence``.

    ``mutation`` uses ``<wt><pos><mut>`` notation with a 1-based position
    (e.g. ``"A12G"``). A position outside ``1..len(sequence)`` returns ``False``
    instead of wrapping around to the end (position <= 0) or raising IndexError.
    """
    pos = int(mutation[1:-1])
    if not 1 <= pos <= len(sequence):
        return False
    return sequence[pos - 1] == mutation[0]


# Flag rows whose mutation's wild-type residue matches the extracted sequence,
# and fail loudly, listing the offending rows, if any do not match.
protherm_validaton_dataset['mutation_matches_sequence'] = [
    mutation_matches_sequence(mutation, sequence)
    for mutation, sequence in zip(
        protherm_validaton_dataset['mutation'],
        protherm_validaton_dataset['qseq'],
    )
]
mismatched_rows = protherm_validaton_dataset.loc[
    ~protherm_validaton_dataset['mutation_matches_sequence'].astype(bool),
    ['structure_id', 'chain_id', 'mutation'],
]
assert mismatched_rows.empty, mismatched_rows

In [None]:
def apply_mutation(sequence, mutation):
    """Return ``sequence`` with a single point mutation applied.

    Parameters
    ----------
    sequence : str
        Wild-type sequence.
    mutation : str
        Mutation in ``<wt><pos><mut>`` notation with a 1-based position,
        e.g. ``"A12G"``.

    Returns
    -------
    str
        Mutant sequence, same length as ``sequence``.

    Raises
    ------
    ValueError
        If the position lies outside the sequence or the wild-type residue
        does not match ``sequence`` at that position.
    """
    wt, pos, mut = mutation[0], int(mutation[1:-1]), mutation[-1]
    # Explicit range check: a position <= 0 would otherwise index from the end.
    if not 1 <= pos <= len(sequence):
        raise ValueError(
            f"Mutation {mutation!r} is outside a sequence of length {len(sequence)}")
    if sequence[pos - 1] != wt:
        raise ValueError(
            f"Mutation {mutation!r} expects {wt!r} at position {pos}, "
            f"found {sequence[pos - 1]!r}")
    # ``mut`` is a single character, so the length is preserved.
    return sequence[:pos - 1] + mut + sequence[pos:]

# Build the mutant sequence of every row from its wild-type sequence.
mutations = protherm_validaton_dataset['mutation']
wt_sequences = protherm_validaton_dataset['qseq']
protherm_validaton_dataset['qseq_mutation'] = [
    apply_mutation(wt_seq, mut_code)
    for mut_code, wt_seq in zip(mutations, wt_sequences)
]

In [None]:
# Every cell of the dataset must be populated before it is written out.
assert protherm_validaton_dataset.notnull().all().all()

In [None]:
# Columns the written dataset is required to contain.
columns = [
    'structure_id', 'model_id', 'chain_id',
    'qseq', 'qseq_mutation', 'ddg_exp',
    'residue_idx_1_corrected', 'residue_idx_2_corrected', 'distances',
]

available_columns = set(protherm_validaton_dataset.columns)
for required in columns:
    assert required in available_columns, required

In [None]:
# Write the dataset as Parquet (Spark flavor, without the pandas index).
# NOTE: recent pyarrow releases deprecate version='2.0' in favour of '2.4'/'2.6';
# it is kept here so that the output format does not change.
output_table = pa.Table.from_pandas(protherm_validaton_dataset, preserve_index=False)
output_file = OUTPUT_PATH.joinpath('protherm_validaton_dataset.parquet').as_posix()
pq.write_table(output_table, output_file, version='2.0', flavor='spark')

# Explore

In [None]:
# Bar chart of how often each amino acid appears as the wild-type vs. the
# mutant residue across all mutations.
aa_wt_counter = Counter(protherm_validaton_dataset['mutation'].str[0])
aa_mut_counter = Counter(protherm_validaton_dataset['mutation'].str[-1])

# Use the union of both residue sets (sorted for a stable axis order): residues
# that only ever occur as the mutant would otherwise be missing from the plot.
labels = sorted(set(aa_wt_counter) | set(aa_mut_counter))
aa_wt = [aa_wt_counter[aa] for aa in labels]
aa_mut = [aa_mut_counter[aa] for aa in labels]

indexes = np.arange(len(labels))
width = 0.3

with plt.rc_context(rc={'figure.figsize': (8, 5), 'font.size': 14}):
    # Offset each series by half a bar width so the two bars sit side by side.
    plt.bar(indexes - width / 2, aa_wt, width, label="wt")
    plt.bar(indexes + width / 2, aa_mut, width, label="mut")
    plt.xticks(indexes, labels)
    plt.ylabel("Number of occurrences")
    plt.legend()