In [None]:
import os
import json
from glob import glob

import numpy as np
import prody as pdy
import pandas as pd
import sys

from sblu.rmsd import pwrmsd, calc_rmsd
from sblu.cli.docking.cmd_cluster import cluster
from scipy.spatial.distance import cdist, pdist, squareform
from collections import OrderedDict, defaultdict

import torch

In [None]:
# Input locations: JSON_DIR holds the per-structure residue JSON files (globbed as *.json),
# CLEAN_PDB_DIR holds the matching PDB files (looked up by name in get_pdb).
# Override with environment variables instead of editing absolute paths in the notebook.
JSON_DIR = os.environ.get('RESPACK_JSON_DIR', '/gpfs/scratch/jakhil/residue-packing/out_redo_w_rot-cat_v4')
CLEAN_PDB_DIR = os.environ.get('RESPACK_CLEAN_PDB_DIR', '/gpfs/scratch/jakhil/CLEAN_PDB_03062021')

In [None]:
def read_json(json_file):
    """Load and return the parsed contents of a JSON file.

    The file is decoded as UTF-8 (the JSON interchange encoding) rather than the
    platform's locale default, so results don't depend on the machine it runs on.
    """
    with open(json_file, 'r', encoding='utf-8') as rf:
        return json.load(rf)

class NpEncoder(json.JSONEncoder):
    """JSON encoder that also accepts NumPy scalars and arrays.

    Use as ``json.dump(obj, fh, cls=NpEncoder)``. NumPy integers, floats and booleans
    become the matching Python scalars and ndarrays become (nested) lists; anything
    else is delegated to ``json.JSONEncoder.default``, which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            # np.bool_ is neither np.integer nor a Python bool, so json can't encode it natively.
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

def get_pdb(json_file, pdb_dir=None):
    """Return the absolute path of the PDB file that matches ``json_file``.

    The PDB name is the JSON basename up to its first '.', so
    ``.../1abc.json`` -> ``<pdb_dir>/1abc.pdb``.

    Parameters
    ----------
    json_file : str
        Path to a per-structure JSON file.
    pdb_dir : str, optional
        Directory containing the PDB files. Defaults to the notebook-level
        ``CLEAN_PDB_DIR`` (looked up at call time).
    """
    if pdb_dir is None:
        pdb_dir = CLEAN_PDB_DIR
    pdb_name = os.path.basename(json_file).split('.')[0]
    return os.path.abspath(os.path.join(pdb_dir, f'{pdb_name}.pdb'))

def get_central_node(res_nums, ca_array, com, dist_thresh=100.0):
    """Find the residue whose CA atom is closest to ``com``.

    Only CAs strictly closer than ``dist_thresh`` are considered; on ties the first
    one in input order wins.

    Parameters
    ----------
    res_nums : sequence
        Residue identifiers, parallel to ``ca_array``.
    ca_array : sequence of array-like
        CA coordinates (anything ``np.asarray`` accepts), parallel to ``res_nums``.
    com : array-like
        Reference point, e.g. a centre of mass.
    dist_thresh : float
        Upper bound (exclusive) on the accepted distance.

    Returns
    -------
    tuple
        ``(central_node, distance)``. When no CA lies within ``dist_thresh``,
        ``central_node`` is ``None`` and ``distance`` is ``dist_thresh`` unchanged
        (previously this case raised ``UnboundLocalError``).
    """
    central_node = None
    best_dist = dist_thresh
    for res_num, ca in zip(res_nums, ca_array):
        d = np.linalg.norm(np.asarray(ca) - com)
        if d < best_dist:
            best_dist = d
            central_node = res_num
    return central_node, best_dist


def find_target_node(res_nums, protein, cluster_distance=20.0):
    """
    Select residues of interest that are separated from one another by about ``cluster_distance``,
    so that the graphs centred on them do not overlap.

    The CA atoms of the residues in ``res_nums`` are clustered on their pairwise
    Euclidean distances with ``sblu``'s ``cluster``. The first member of every
    cluster (presumably its centre; confirm against sblu) is returned.

    Parameters
    ----------
    res_nums : sequence
        Residue numbers to consider.
    protein : ProDy atom selection
        Structure that contains those residues.
    cluster_distance : float
        Clustering radius passed to ``cluster``.

    Returns
    -------
    list
        Residue numbers of the chosen residues. Empty when ``res_nums`` is empty or
        none of them has a CA atom in ``protein``.
    """
    if len(res_nums) == 0:
        return []
    ca_atoms = protein.select('name CA and resnum {}'.format(' '.join(str(i) for i in res_nums)))
    if ca_atoms is None:
        return []

    # Distance-matrix rows follow the selection's atom order, not the order of `res_nums`,
    # so cluster indices are mapped back through the selection's own residue numbers.
    ca_resnums = ca_atoms.getResnums()
    # getCoords() returns only the active coordinate set; the old _getCoordsets() reshape
    # mixed all models together for multi-model PDB files.
    coords = ca_atoms.getCoords().reshape(len(ca_resnums), -1)
    dist_mat = squareform(pdist(coords, 'euclidean'))

    clusters = cluster(dist_mat, cluster_distance, 1, 100)
    return [ca_resnums[members[0]] for members in clusters]

def find_chi_and_eg_coord(residue):
    """Return ``(chi_xyz, eg_xyz)`` for a single residue.

    ``chi_xyz`` is the position of the residue's chi reference atom and ``eg_xyz``
    that of its end-group atom, each taken from the first atom matched by name.
    ALA, PRO and GLY have no entry in the table and use CA for both. Other residue
    names missing from the table raise ``KeyError``. A missing atom presumably
    surfaces as ``AttributeError`` (ProDy ``select`` returning ``None``).
    """
    # residue name -> (chi reference atom, end-group atom); trailing note = rotamer count
    atom_pairs = {
        'ARG': ('CD', 'NH2'),   # 5 Rotamers
        'ASN': ('OD1', 'ND2'),  # 4 Rotamers
        'ASP': ('OD1', 'OD2'),  # 3 Rotamers
        'CYS': ('SG', 'SG'),    # 2 Rotamers
        'GLN': ('CD', 'NE2'),   # 4 Rotamers
        'GLU': ('CD', 'OE2'),   # 5 Rotamers
        'HIS': ('ND1', 'NE2'),  # 5 Rotamers
        'ILE': ('CD1', 'CD1'),  # 4 Rotamers
        'LEU': ('CD1', 'CD2'),  # 3 Rotamers
        'LYS': ('CD', 'NZ'),    # 5 Rotamers
        'MET': ('SD', 'CE'),    # 5 Rotamers
        'PHE': ('CD1', 'CZ'),   # 2 Rotamers
        'SER': ('OG', 'OG'),    # 2 Rotamers
        'THR': ('OG1', 'OG1'),  # 2 Rotamers
        'TRP': ('CD1', 'CH2'),  # 6 Rotamers
        'TYR': ('CD1', 'CZ'),   # 2 Rotamers
        'VAL': ('CG1', 'CG2'),  # 2 Rotamers
    }

    res_name = residue.getResnames()[0]
    if res_name in ['ALA', 'PRO', 'GLY']:
        chi_atom, eg_atom = 'CA', 'CA'
    else:
        chi_atom, eg_atom = atom_pairs[res_name]

    chi_xyz = residue.select(f'name {chi_atom}').getCoords()[0]
    eg_xyz = residue.select(f'name {eg_atom}').getCoords()[0]
    return chi_xyz, eg_xyz

def print_ith_item(i, dataset=None, split='test'):
    """Print ``key: value`` for the ``i``-th sample of one split of a dataset dict.

    Parameters
    ----------
    i : int
        Sample position within the split.
    dataset : dict, optional
        Mapping ``split name -> {feature name: list of values}``. Defaults to the
        notebook-global ``data`` (looked up at call time) for backward compatibility.
    split : str
        Which split to read; defaults to ``'test'``.
    """
    if dataset is None:
        dataset = data
    for key, values in dataset[split].items():
        print(key + ": " + str(values[i]))

In [None]:
# The 20 standard amino acids as three-letter codes, in alphabetical order.
# A residue's one-hot feature is the boolean mask `RES_TYPES == name`.
RES_TYPES = np.array(
    ('ALA ARG ASN ASP CYS GLN GLU GLY HIS ILE LEU LYS MET PHE PRO '
     'SER THR TRP TYR VAL').split(),
    dtype='str',
)

# Per-graph padding sizes: node arrays have MAX_NODES rows; a star graph
# (target joined to each neighbour) over MAX_NODES nodes has MAX_NODES - 1 edges.
MAX_NODES = 50
MAX_EDGES = MAX_NODES - 1

In [None]:
# Sorted: glob order is filesystem-dependent, and sample order feeds the later sort/split,
# so an unsorted list makes the generated dataset irreproducible.
jsons = sorted(glob(f'{JSON_DIR}/*.json'))

In [None]:
roi = 'MET' # residue of interest
roi_rota_start = 42 # the first rotamer category of the roi
roi_eg_start = 79 # the first end group category of the roi
roi_rota_atom = 'SD' # chi2 terminal atom
roi_eg_atom = 'CE' # doi: 10.1002/prot.23222
# Residues whose `res_d_mat` entry (row of the target) is below this value become neighbours
# of the target. Units are whatever the JSON stores (presumably Angstrom; TODO confirm).
NEIGHBOR_CUTOFF = 10.0

RES_ID, RES_NAME, NUM_NODE, NUM_EDGE, PHI, PSI, X, Xn, Xc, CHI_COORD, EG_COORD, CHI_CATA, EG_CATA, EDGE = ([] for i in range(14))
TARGET_CHI_CATA, TARGET_EG_CATA, TARGET_CHI_COORD, TARGET_EG_COORD = ([] for i in range(4))  # choose 1 to predict, but generating all

# One star-shaped graph per selected ROI residue: the target node is joined to every
# neighbour within NEIGHBOR_CUTOFF, and per-node arrays are zero-padded to MAX_NODES rows.
for _json in jsons:
    _data = read_json(os.path.abspath(_json))
    df = pd.DataFrame.from_dict(_data)
    roi_df = df[df['res_name'] == roi]
    if roi_df.shape[0] == 0:
        # No ROI residue in this structure: skip before paying for PDB parsing.
        continue
    roi_rn = roi_df['res_num'].values
    protein = pdy.parsePDB(get_pdb(_json)).select('heavy')
    com = pdy.calcCenter(protein)
    pdb_name = os.path.basename(_json).split('.')[0]

    # Per-structure feature columns; identical for every target node of this structure.
    _phi = df['phi'].values
    _psi = df['psi'].values
    _ca_xyz = df['ca_xyz'].values
    _c_xyz = df['c_xyz'].values
    _n_xyz = df['n_xyz'].values
    _res_name = df['res_name'].values
    _res_nums = df['res_num'].values
    _res_cat = df['rota_category'].values
    _eg_cat = df['end_group'].values

    for target_node in find_target_node(roi_rn, protein):
        d_target_to_center = round(pdy.calcDistance(protein.select(f'resnum {target_node}').getCoords(), com).min(), 1)
        # Named target_df (not `data`) so the notebook-global `data` used elsewhere isn't clobbered.
        target_df = df[(df['res_num'] == int(target_node))]
        target_idx = target_df.index[0]
        target_ca = target_df['ca_xyz'].values[0]

        # Collect "TARGET" prefix values (categories shifted by the ROI's first category index)
        target_chi_cata = target_df['rota_category'].values[0]
        TARGET_CHI_CATA.append(abs(roi_rota_start - target_chi_cata))
        target_eg_cata = target_df['end_group'].values[0]
        TARGET_EG_CATA.append(abs(roi_eg_start - target_eg_cata))
        res_id = [pdb_name, str(target_node), str(target_chi_cata), str(target_eg_cata), str(d_target_to_center)]
        RES_ID.append('_'.join(res_id))

        # Fall back to the CA position when the reference atom can't be selected.
        try:
            TARGET_CHI_COORD.append(protein.select(f'resnum {target_node} and name {roi_rota_atom}').getCoords()[0])
        except Exception:
            TARGET_CHI_COORD.append(target_ca)
        try:
            TARGET_EG_COORD.append(protein.select(f'resnum {target_node} and name {roi_eg_atom}').getCoords()[0])
        except Exception:
            TARGET_EG_COORD.append(target_ca)

        # Collect Graph Data: one edge [target, neighbour, 0] per neighbour
        # (third column is always 0 -- presumably an edge-type slot).
        d_mat = np.array(target_df['res_d_mat'].values[0][0])
        neighbor_1_idx = [nb for nb in np.where(d_mat < NEIGHBOR_CUTOFF)[0] if nb != target_idx]
        edge = [[target_idx, nb, 0] for nb in neighbor_1_idx]

        # Node set = target + neighbours. The target is kept even when it has no neighbours
        # (previously the graph came out empty). Set insertion order is unchanged, so the
        # resulting node order matches the old edge-derived list.
        list_node = list(set([target_idx] + neighbor_1_idx))
        num_nodes = len(list_node)
        num_edge = len(edge)
        if num_nodes > MAX_NODES or num_edge > MAX_EDGES:
            raise ValueError(
                f'{pdb_name} resnum {target_node}: {num_nodes} nodes / {num_edge} edges exceed '
                f'MAX_NODES={MAX_NODES} / MAX_EDGES={MAX_EDGES}; raise the limits or lower NEIGHBOR_CUTOFF')

        # Re-index edges from dataframe row labels to graph-local node positions.
        neighbor_dict = {node: pos for pos, node in enumerate(list_node)}
        for e in edge:
            e[0] = neighbor_dict[e[0]]
            e[1] = neighbor_dict[e[1]]

        # ensuring that edge dimension is the same for all graphs
        norm_edge = np.zeros((MAX_EDGES, 3))
        for k, v in enumerate(edge):
            norm_edge[k] = v
        EDGE.append(norm_edge)
        NUM_EDGE.append(num_edge)
        NUM_NODE.append(num_nodes)

        # ensuring that features have same dimension across graphs
        norm_x = np.zeros((MAX_NODES, 3))
        norm_xc = np.zeros((MAX_NODES, 3))
        norm_xn = np.zeros((MAX_NODES, 3))
        norm_eg_coord = np.zeros((MAX_NODES, 3))
        norm_chi_coord = np.zeros((MAX_NODES, 3))
        norm_phi = np.zeros((MAX_NODES, 1))
        norm_psi = np.zeros((MAX_NODES, 1))
        norm_resname = np.zeros((MAX_NODES, 20), dtype=bool)
        norm_chis = np.zeros((MAX_NODES, 1))
        norm_egs = np.zeros((MAX_NODES, 1))

        for idx, i in enumerate(list_node):
            i = int(i)
            norm_x[idx] = [_ca_xyz[i][0], _ca_xyz[i][1], _ca_xyz[i][2]]
            norm_xc[idx] = [_c_xyz[i][0], _c_xyz[i][1], _c_xyz[i][2]] # can be modified to reflect position relative to CA
            norm_xn[idx] = [_n_xyz[i][0], _n_xyz[i][1], _n_xyz[i][2]] # can be modified to reflect position relative to CA
            norm_phi[idx] = _phi[i]
            norm_psi[idx] = _psi[i]
            norm_resname[idx] = list(RES_TYPES == _res_name[i])
            norm_chis[idx] = int(_res_cat[i])
            norm_egs[idx] = int(_eg_cat[i])

            # Side-chain reference coordinates; fall back to CA if they can't be resolved.
            try:
                chi_coord, eg_coord = find_chi_and_eg_coord(protein.select(f'resnum {_res_nums[i]}'))
                norm_chi_coord[idx] = chi_coord[0], chi_coord[1], chi_coord[2]
                norm_eg_coord[idx] = eg_coord[0], eg_coord[1], eg_coord[2]
            except Exception:
                norm_chi_coord[idx] = norm_x[idx]
                norm_eg_coord[idx] = norm_x[idx]

            if i == target_idx:
                # Blank the target node's own side-chain features; its labels live in TARGET_*.
                norm_chi_coord[idx] = 0, 0, 0
                norm_eg_coord[idx] = 0, 0, 0
                norm_chis[idx] = 0
                norm_egs[idx] = 0

        PHI.append(norm_phi)
        PSI.append(norm_psi)
        RES_NAME.append(norm_resname)
        X.append(norm_x)
        Xc.append(norm_xc)
        Xn.append(norm_xn)
        CHI_CATA.append(norm_chis)
        EG_CATA.append(norm_egs)
        CHI_COORD.append(norm_chi_coord)
        EG_COORD.append(norm_eg_coord)

In [None]:
# Column name -> per-target list, one entry per generated graph. The end-group targets
# (TARGET_EG_CATA / TARGET_EG_COORD) are collected as well but not exported here.
data_d = dict(
    res_id=RES_ID,
    num_node=NUM_NODE,
    num_edge=NUM_EDGE,
    target_cat=TARGET_CHI_CATA,
    target_coord=TARGET_CHI_COORD,
    chis=CHI_CATA,
    egs=EG_CATA,
    chis_coord=CHI_COORD,
    egs_coord=EG_COORD,
    x=X,
    x_c=Xc,
    x_n=Xn,
    one_hot=RES_NAME,
    phi=PHI,
    psi=PSI,
    edge=EDGE,
)

In [None]:
# Every feature list must hold exactly one entry per target residue; report and verify.
feature_lengths = {key: len(values) for key, values in data_d.items()}
for key, n_items in feature_lengths.items():
    print(key)
    print(n_items)
assert len(set(feature_lengths.values())) <= 1, f'feature lists have mismatched lengths: {feature_lengths}'

In [None]:
df = pd.DataFrame.from_dict(data_d)
# The last '_'-separated field of res_id is the target's distance to the centre (rounded to 0.1).
df['d_to_com'] = df['res_id'].apply(lambda x: float(x.split('_')[-1]))
# Rounded distances tie often; a stable sort keeps tied rows in input order, so the
# per-category top-CUTOFF selection below is reproducible.
df = df.sort_values(by=['d_to_com'], ascending=False, kind='mergesort')

In [None]:
CUTOFF = 1320  # maximum rows kept per target category
N_TARGET_CATS = 9

# For each target category 0..8 keep the first CUTOFF rows in df's current order.
# reset_index() re-numbers them 0..n-1 and keeps the old labels in an 'index' column.
per_category_frames = [
    df.loc[df['target_cat'] == cat].iloc[:CUTOFF].reset_index()
    for cat in range(N_TARGET_CATS)
]
df_1, df_2, df_3, df_4, df_5, df_6, df_7, df_8, df_9 = per_category_frames
print(*(frame.shape for frame in per_category_frames))

In [None]:
# Interleave the first five categories row by row (cat0 row0, cat1 row0, ..., cat4 row0,
# cat0 row1, ...) by sorting on the per-category 0..n-1 index from reset_index().
# This needs a *stable* sort; 'mergesort' is the documented name ('merge' is not).
df_out = pd.concat([df_1, df_2, df_3, df_4, df_5]).sort_index(kind='mergesort')

In [None]:
# Column-oriented view of df_out: {column name: list of that column's values in row order}.
dd = defaultdict(list)
d = df_out.to_dict(orient='list', into=dd)

In [None]:
# Contiguous 80/10/10 train/valid/test split over *all* rows of `d`. The split points were
# previously derived from CUTOFF (rows per category), so with several categories almost
# all rows fell into "test". df_out is sorted on the per-category reset index, so categories
# are interleaved and each contiguous slice stays roughly class-balanced.
n_total = len(next(iter(d.values()), []))
split_train_valid = int(n_total * 0.8)
split_valid_test = int(n_total * 0.9)

data_dunbrack = {
    "train": {key: values[0:split_train_valid] for key, values in d.items()},
    "valid": {key: values[split_train_valid:split_valid_test] for key, values in d.items()},
    "test": {key: values[split_valid_test:] for key, values in d.items()},
}

In [None]:
# Persist the train/valid/test dict (lists of numpy arrays and scalars) with torch's serializer.
dataset_path = './11880-MET-TRIVIAL_EG-9_03302021.pt'
torch.save(data_dunbrack, dataset_path)

In [None]:
# Reload the dataset written above. It contains numpy arrays, which torch >= 2.6 rejects under
# its default weights_only=True, so full unpickling is requested explicitly. This is safe only
# because the file is produced locally by this notebook -- never do this on untrusted files.
try:
    data = torch.load('./11880-MET-TRIVIAL_EG-9_03302021.pt', weights_only=False)
except TypeError:
    # torch < 1.13 has no weights_only argument (and always fully unpickles).
    data = torch.load('./11880-MET-TRIVIAL_EG-9_03302021.pt')

In [None]:
# Spot-check: show every feature of the first sample in the test split.
print_ith_item(i=0)

In [None]:
# data = torch.load('/gpfs/scratch/jakhil/03252021_TYR-ROTA_10k.pt')