In [2]:
import os,sys,re
import argparse, json
import copy
import random
import pickle
import math
import torch
from torch import nn
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from tqdm import tqdm
#from tqdm.notebook import tqdm
from Bio.PDB.PDBParser import PDBParser
from Bio.PDB.Polypeptide import one_to_index
from Bio.PDB import Selection
from Bio import SeqIO
from Bio.PDB.Residue import Residue
from easydict import EasyDict
import enum
import esm, gzip
from Bio import SeqIO
from esm.model.esm2 import ESM2
from collections import OrderedDict
from sklearn.metrics import mean_squared_error
import scipy.stats

In [25]:
complexes = []
resolution = []
units = []
log_kds_M = []

dissociation_units_to_moles = {
    'fM': 1e-15,  # femtomolar to molar
    'pM': 1e-12,  # picomolar to molar
    'nM': 1e-9,   # nanomolar to molar
    'uM': 1e-6,   # micromolar to molar
    'mM': 1e-3    # millimolar to molar
}

with open('PP/index/INDEX_general_PP.2020') as file:
    for i,line in enumerate(file):
        if line.startswith('#'):
            continue
        splitted = line.split('  ')

        def remove(s, sub_strs):
            for sub_str in sub_strs:
                s = s.replace(sub_str, '')
            return s 

        if 'Ki' in splitted[3]:
            kd_str = remove(splitted[3], ['Ki=', 'Ki~', 'Ki<'])
        if 'Kd' in splitted[3]:
            kd_str = remove(splitted[3], ['Kd=', 'Kd~', 'Kd<', 'Kd>'])
        if 'IC50' in splitted[3]:
            kd_str = remove(splitted[3], ['IC50=', 'IC50~', 'IC50<'])
        val = kd_str[:-2]
        unit = kd_str[len(kd_str)-2:]

        if val == '':
            print(line)
            print(kd_str)
            print(unit)
        
        if unit == '':
            break
        else:
            kd = float(val) * dissociation_units_to_moles[unit]
            log_kd = -1*math.log10(kd)
            complexes.append(splitted[0])
            log_kds_M.append(log_kd)

In [26]:
# extracted_df = pd.DataFrame({'complex': complexes, '-logKd': log_kds_M})
# extracted_df.to_csv('extracted_affinities.csv', index=False)

In [2]:
non_standard_residue_substitutions = {
    '2AS':'ASP', '3AH':'HIS', '5HP':'GLU', 'ACL':'ARG', 'AGM':'ARG', 'AIB':'ALA', 'ALM':'ALA', 'ALO':'THR', 'ALY':'LYS', 'ARM':'ARG',
    'ASA':'ASP', 'ASB':'ASP', 'ASK':'ASP', 'ASL':'ASP', 'ASQ':'ASP', 'AYA':'ALA', 'BCS':'CYS', 'BHD':'ASP', 'BMT':'THR', 'BNN':'ALA',
    'BUC':'CYS', 'BUG':'LEU', 'C5C':'CYS', 'C6C':'CYS', 'CAS':'CYS', 'CCS':'CYS', 'CEA':'CYS', 'CGU':'GLU', 'CHG':'ALA', 'CLE':'LEU', 'CME':'CYS',
    'CSD':'ALA', 'CSO':'CYS', 'CSP':'CYS', 'CSS':'CYS', 'CSW':'CYS', 'CSX':'CYS', 'CXM':'MET', 'CY1':'CYS', 'CY3':'CYS', 'CYG':'CYS',
    'CYM':'CYS', 'CYQ':'CYS', 'DAH':'PHE', 'DAL':'ALA', 'DAR':'ARG', 'DAS':'ASP', 'DCY':'CYS', 'DGL':'GLU', 'DGN':'GLN', 'DHA':'ALA',
    'DHI':'HIS', 'DIL':'ILE', 'DIV':'VAL', 'DLE':'LEU', 'DLY':'LYS', 'DNP':'ALA', 'DPN':'PHE', 'DPR':'PRO', 'DSN':'SER', 'DSP':'ASP',
    'DTH':'THR', 'DTR':'TRP', 'DTY':'TYR', 'DVA':'VAL', 'EFC':'CYS', 'FLA':'ALA', 'FME':'MET', 'GGL':'GLU', 'GL3':'GLY', 'GLZ':'GLY',
    'GMA':'GLU', 'GSC':'GLY', 'HAC':'ALA', 'HAR':'ARG', 'HIC':'HIS', 'HIP':'HIS', 'HMR':'ARG', 'HPQ':'PHE', 'HTR':'TRP', 'HYP':'PRO',
    'IAS':'ASP', 'IIL':'ILE', 'IYR':'TYR', 'KCX':'LYS', 'LLP':'LYS', 'LLY':'LYS', 'LTR':'TRP', 'LYM':'LYS', 'LYZ':'LYS', 'MAA':'ALA', 'MEN':'ASN',
    'MHS':'HIS', 'MIS':'SER', 'MLE':'LEU', 'MPQ':'GLY', 'MSA':'GLY', 'MSE':'MET', 'MVA':'VAL', 'NEM':'HIS', 'NEP':'HIS', 'NLE':'LEU',
    'NLN':'LEU', 'NLP':'LEU', 'NMC':'GLY', 'OAS':'SER', 'OCS':'CYS', 'OMT':'MET', 'PAQ':'TYR', 'PCA':'GLU', 'PEC':'CYS', 'PHI':'PHE',
    'PHL':'PHE', 'PR3':'CYS', 'PRR':'ALA', 'PTR':'TYR', 'PYX':'CYS', 'SAC':'SER', 'SAR':'GLY', 'SCH':'CYS', 'SCS':'CYS', 'SCY':'CYS',
    'SEL':'SER', 'SEP':'SER', 'SET':'SER', 'SHC':'CYS', 'SHR':'LYS', 'SMC':'CYS', 'SOC':'CYS', 'STY':'TYR', 'SVA':'SER', 'TIH':'ALA',
    'TPL':'TRP', 'TPO':'THR', 'TPQ':'ALA', 'TRG':'LYS', 'TRO':'TRP', 'TYB':'TYR', 'TYI':'TYR', 'TYQ':'TYR', 'TYS':'TYR', 'TYY':'TYR'
}


max_num_heavyatoms = 15
max_num_hydrogens = 16
max_num_allatoms = max_num_heavyatoms + max_num_hydrogens

ressymb_to_resindex = {
    'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4,
    'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9,
    'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14,
    'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19,
    'X': 20,
}

resindex_to_ressymb = {v: k for k, v in ressymb_to_resindex.items()}

class BBHeavyAtom(enum.IntEnum):
    N = 0; CA = 1; C = 2; O = 3; CB = 4; OXT=14;

def _get_residue_heavyatom_info(res: Residue):
    pos_heavyatom = torch.zeros([max_num_heavyatoms, 3], dtype=torch.float)
    mask_heavyatom = torch.zeros([max_num_heavyatoms, ], dtype=torch.bool)
    bfactor_heavyatom = torch.zeros([max_num_heavyatoms, ], dtype=torch.float)
    restype = AA(res.get_resname())
    for idx, atom_name in enumerate(restype_to_heavyatom_names[restype]):
        if atom_name == '': continue
        if atom_name in res:
            pos_heavyatom[idx] = torch.tensor(res[atom_name].get_coord().tolist(), dtype=pos_heavyatom.dtype)
            mask_heavyatom[idx] = True
            bfactor_heavyatom[idx] = res[atom_name].get_bfactor()
    return pos_heavyatom, mask_heavyatom, bfactor_heavyatom

class AA(enum.IntEnum):
    ALA = 0; CYS = 1; ASP = 2; GLU = 3; PHE = 4
    GLY = 5; HIS = 6; ILE = 7; LYS = 8; LEU = 9
    MET = 10; ASN = 11; PRO = 12; GLN = 13; ARG = 14
    SER = 15; THR = 16; VAL = 17; TRP = 18; TYR = 19
    UNK = 20

    @classmethod
    def _missing_(cls, value):
        if isinstance(value, str) and len(value) == 3:      # three representation
            if value in non_standard_residue_substitutions:
                value = non_standard_residue_substitutions[value]
            if value in cls._member_names_:
                return getattr(cls, value)
        elif isinstance(value, str) and len(value) == 1:    # one representation
            if value in ressymb_to_resindex:
                return cls(ressymb_to_resindex[value])

        return super()._missing_(value)

    def __str__(self):
        return self.name

    @classmethod
    def is_aa(cls, value):
        return (value in ressymb_to_resindex) or \
            (value in non_standard_residue_substitutions) or \
            (value in cls._member_names_)

restype_to_heavyatom_names = {
    AA.ALA: ['N', 'CA', 'C', 'O', 'CB', '',    '',    '',    '',    '',    '',    '',    '',    '', 'OXT'],
    AA.ARG: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'NE',  'CZ',  'NH1', 'NH2', '',    '',    '', 'OXT'],
    AA.ASN: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'OD1', 'ND2', '',    '',    '',    '',    '',    '', 'OXT'],
    AA.ASP: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'OD1', 'OD2', '',    '',    '',    '',    '',    '', 'OXT'],
    AA.CYS: ['N', 'CA', 'C', 'O', 'CB', 'SG',  '',    '',    '',    '',    '',    '',    '',    '', 'OXT'],
    AA.GLN: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'OE1', 'NE2', '',    '',    '',    '',    '', 'OXT'],
    AA.GLU: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'OE1', 'OE2', '',    '',    '',    '',    '', 'OXT'],
    AA.GLY: ['N', 'CA', 'C', 'O', '',   '',    '',    '',    '',    '',    '',    '',    '',    '', 'OXT'],
    AA.HIS: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'ND1', 'CD2', 'CE1', 'NE2', '',    '',    '',    '', 'OXT'],
    AA.ILE: ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '',    '',    '',    '',    '',    '', 'OXT'],
    AA.LEU: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', '',    '',    '',    '',    '',    '', 'OXT'],
    AA.LYS: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'CE',  'NZ',  '',    '',    '',    '',    '', 'OXT'],
    AA.MET: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'SD',  'CE',  '',    '',    '',    '',    '',    '', 'OXT'],
    AA.PHE: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', 'CE1', 'CE2', 'CZ',  '',    '',    '', 'OXT'],
    AA.PRO: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  '',    '',    '',    '',    '',    '',    '', 'OXT'],
    AA.SER: ['N', 'CA', 'C', 'O', 'CB', 'OG',  '',    '',    '',    '',    '',    '',    '',    '', 'OXT'],
    AA.THR: ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '',    '',    '',    '',    '',    '',    '', 'OXT'],
    AA.TRP: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2', 'OXT'],
    AA.TYR: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', 'CE1', 'CE2', 'CZ',  'OH',  '',    '', 'OXT'],
    AA.VAL: ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '',    '',    '',    '',    '',    '',    '', 'OXT'],
    AA.UNK: ['',  '',   '',  '',  '',   '',    '',    '',    '',    '',    '',    '',    '',    '',    ''],
}
for names in restype_to_heavyatom_names.values(): assert len(names) == max_num_heavyatoms

amino_acids = {
    'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D',
    'CYS': 'C', 'GLU': 'E', 'GLN': 'Q', 'GLY': 'G',
    'HIS': 'H', 'ILE': 'I', 'LEU': 'L', 'LYS': 'K',
    'MET': 'M', 'PHE': 'F', 'PRO': 'P', 'SER': 'S',
    'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V'
}



def parse_biopython_structure(entity, unknown_threshold=1.0):
    chains = Selection.unfold_entities(entity, 'C')
    chains.sort(key=lambda c: c.get_id())
    data = EasyDict({
        'chain_id': [], 'chain_nb': [], 'seq': []
    })
    
    tensor_types = {
        'chain_nb': torch.LongTensor
    }
    count_aa, count_unk = 0, 0

    for i, chain in enumerate(chains):
        chain.atom_to_internal_coordinates()
        seq_this = 0   # Renumbering residues
        residues = Selection.unfold_entities(chain, 'R')
        residues.sort(key=lambda res: (res.get_id()[1], res.get_id()[2]))   # Sort residues by resseq-icode
        for _, res in enumerate(residues):
            resname = res.get_resname()

            if not AA.is_aa(resname): continue
            if not (res.has_id('CA') and res.has_id('C') and res.has_id('N')): continue
            restype = AA(resname)
            
            count_aa += 1
            if restype == AA.UNK: 
                count_unk += 1
                continue

            # Chain info
            data.chain_id.append(chain.get_id())
            data.chain_nb.append(i)

            try:
                data['seq'] += amino_acids[resname]
            except:
                data['seq'] += amino_acids[non_standard_residue_substitutions[resname]]

    return data

In [3]:
def standardize(x):
    return (x - x.mean(axis=0))/(x.std(axis=0))

def get_sequences_by_chain(chain_ids, amino_acids):
    # Initialize list of lists for chains
    chains = []
    for i in range(torch.max(chain_ids)+1):
        chains.append([])
    
    # Iterate over both chain ids and amino acids
    for chain_id, amino_acid in zip(chain_ids, amino_acids):
        # Append amino acid to the corresponding chain
        chains[chain_id].append(amino_acid)
    
    # Convert each sublist into a string to get the sequence of each chain
    return [''.join(chain) for chain in chains]

class PDBBindDataset(Dataset):

    def __init__(
        self, 
        csv_path, 
        pdb_dir, 
        cache_dir,
        cvfold_index=0, 
        num_cvfolds=3, 
        split='train', 
        split_seed=2023,
        reset=False
    ):
        super().__init__()
        self.csv_path = csv_path
        self.pdb_dir = pdb_dir
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        self.cvfold_index = cvfold_index
        self.num_cvfolds = num_cvfolds
        assert split in ('train', 'val')
        self.split = split
        self.split_seed = split_seed

        self.entries_cache = os.path.join(cache_dir, 'entries.pkl')
        self.entries = None
        self.entries_full = None
        self._load_entries(reset)

        self.structures_cache = os.path.join(cache_dir, 'structures.pkl')
        self.structures = None
        self._load_structures(reset)

    def _load_entries(self, reset):
        if not os.path.exists(self.entries_cache) or reset:
            self.entries_full = self._preprocess_entries()
        else:
            with open(self.entries_cache, 'rb') as f:
                self.entries_full = pickle.load(f)

        complex_to_entries = {}
        for index, e in self.entries_full.iterrows():
            if e['complex'] not in complex_to_entries:
                complex_to_entries[e['complex']] = []
            complex_to_entries[e['complex']].append(e)

        complex_list = sorted(complex_to_entries.keys())
        random.Random(self.split_seed).shuffle(complex_list)

        split_size = math.ceil(len(complex_list) / self.num_cvfolds)
        complex_splits = [
            complex_list[i*split_size : (i+1)*split_size] 
            for i in range(self.num_cvfolds)
        ]

        val_split = complex_splits.pop(self.cvfold_index)
        train_split = sum(complex_splits, [])
        if self.split == 'val':
            complexes_this = val_split
        else:
            complexes_this = train_split

        entries = []
        for cplx in complexes_this:
            entries += complex_to_entries[cplx]
        
        entries_val = [float(e['-logKd']) for e in entries]
        
        self.entries_val = standardize(np.array(entries_val))
        self.entries_complex = [e['complex'] for e in entries]
        
    def _preprocess_entries(self):
        entries = pd.read_csv(filepath_or_buffer=self.csv_path)
        with open(self.entries_cache, 'wb') as f:
            pickle.dump(entries, f)
        return entries

    def _load_structures(self, reset):
        if not os.path.exists(self.structures_cache) or reset:
            self.structures = self._preprocess_structures()
        else:
            with open(self.structures_cache, 'rb') as f:
                self.structures = pickle.load(f)

    def _preprocess_structures(self):
        structures = {}
        pdbcodes = list(set([e['complex'] for index, e in self.entries_full.iterrows()]))

        missing_num = 0
        for pdbcode in tqdm(pdbcodes, desc='Structures'):
            parser = PDBParser(QUIET=True)
            pdb_path = os.path.join(self.pdb_dir, '{}.ent.pdb'.format(pdbcode.lower()))
            model = parser.get_structure(None, pdb_path)[0]
            data = parse_biopython_structure(model)
            structures[pdbcode] = data

        with open(self.structures_cache, 'wb') as f:
            pickle.dump(structures, f)
        return structures

    def __len__(self):
        return len(self.entries_val)

    def __getitem__(self, index):

        ddG = self.entries_val[index]
        complex = self.entries_complex[index]
        data = copy.deepcopy(self.structures[complex])

        data['seq'] = ''.join(data['seq'])
        data['chain_nb'] = torch.tensor(data['chain_nb'])
        data['ddG'] = ddG

        del data['chain_id']
        
        return data

In [13]:
entries = pd.read_csv('extracted_affinities.csv')
pdbcodes = entries['complex'].unique().tolist()

In [25]:
sequences = {}
missing_num = 0
for pdbcode in tqdm(pdbcodes, desc='Structures'):
    parser = PDBParser(QUIET=True)
    pdb_path = os.path.join('PP/', '{}.ent.pdb'.format(pdbcode.lower()))
    model = parser.get_structure(None, pdb_path)[0]
    data = parse_biopython_structure(model)
    sequences[pdbcode] = data

  numpy.arccos(((a0a1 * a0a1) + sqr_a1a2 - (a0a2 * a0a2)) / (2 * a0a1 * a1a2))
  numpy.arccos((sqr_a1a2 + (a2a3 * a2a3) - (a1a3 * a1a3)) / (2 * a1a2 * a2a3))
Structures: 100%|███████████████████████████| 2852/2852 [38:35<00:00,  1.23it/s]


In [60]:
seqs = []
all_chain_ids = []
targets = []

for i,row in entries.iterrows():
    seq_data = sequences [row['complex']]

    if np.max(seq_data['chain_nb']) > 5:
        continue
    
    chain_ids = ''.join([str(i) for i in seq_data['chain_nb']])
    seq = ''.join(seq_data['seq'])

    assert len(seq) == len(chain_ids)

    seqs.append(seq)
    all_chain_ids.append(chain_ids)
    targets.append(row['-logKd'])

In [61]:
df = pd.DataFrame({'seq': seqs, 'chain_ids': all_chain_ids, 'target': targets})

In [62]:
df.to_csv('./processed_data.csv')

In [3]:
x = pd.read_csv('./processed_data.csv')

In [5]:
x = x.drop_duplicates(subset=['seq'])

In [7]:
x.to_csv('./processed_data.csv')