In [1]:
import numpy as np
import pandas as pd
import os
import random
import atom3d.datasets.datasets as da

import sys
sys.path.append('../')
import gvp
import gvp.atom3d as gvp_atom3d
import esm


#### Load the DIPS dataset (obtained via "GGN meets pLM", though it most likely originates from the GVP paper; provenance to be confirmed)

In [2]:
# Open the DIPS test split, which is stored on disk in LMDB format.
dataset = da.load_dataset(
    '../data/PPI/DIPS-split/data/test',
    'lmdb',
)

#### Run through the data and randomly swap one chain of each pair to create examples of non-binders.
The real binders are kept in the dataset as well (`bind = 1`) alongside the swapped pairs (`bind = 0`), so the resulting dataset is roughly twice the size of the original.

In [5]:
precompute_plm_embeddings = True

if precompute_plm_embeddings:
    # The ESM-2 model is only needed when embeddings are stored with each pair.
    esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    batch_converter = alphabet.get_batch_converter()
    esm_model.eval()
    esm_model = esm_model.to('cuda')

cols = ['ensemble', 'subunit', 'structure', 'residue', 'resname', 'name','x', 'y', 'z']
max_len = 800            # chains with more CA rows than this are excluded
max_swap_attempts = 20   # random draws per pair before giving up on a negative example
last_pair_index = 200    # last source index to process; set to None to process everything
num_dropped = 0

new_dataset = []


def _embed_chain(chain_df):
    """Return the pLM representation of one chain as a DataFrame.

    Returns None when `precompute_plm_embeddings` is False.
    """
    if not precompute_plm_embeddings:
        return None
    reps = gvp_atom3d.get_plm_reps(chain_df, esm_model, batch_converter, device='cuda')
    return pd.DataFrame(reps.cpu().numpy())


def _make_pair(atoms, bind, chain_a, chain_b):
    """Assemble one dataset record; its id is the current length of `new_dataset`."""
    return {
        'atoms_pairs': atoms,
        'bind': bind,
        'id': len(new_dataset),
        'chain1_len': len(chain_a),
        'chain2_len': len(chain_b),
        'esm_emb_p1': _embed_chain(chain_a),
        'esm_emb_p2': _embed_chain(chain_b),
    }


for i, pair in enumerate(dataset):
    # Check the cap at the top of the loop so that a `continue` further down
    # (dropped pair) can never skip it and let the loop run over the whole split.
    if last_pair_index is not None and i > last_pair_index:
        break

    if i % 1000 == 0:
        print(i, '/', len(dataset), len(new_dataset))
        print('Dropped', num_dropped, 'pairs')

    # Keep only the alpha-carbon rows of the selected columns.
    p1 = pair['atoms_pairs'][cols]
    p1 = p1[p1['name'] == 'CA']
    subunits = p1['subunit'].unique()

    # A pair needs two distinct subunits; anything else cannot be split into chains.
    if len(subunits) < 2:
        print('Skipping pair', i, 'with fewer than two subunits')
        num_dropped += 1
        continue

    chain1 = p1[p1['subunit'] == subunits[0]]
    chain2 = p1[p1['subunit'] == subunits[1]]

    # Max length check. If either chain is too long, drop the pair altogether
    len_chain1 = len(chain1)
    len_chain2 = len(chain2)
    if len_chain1 > max_len or len_chain2 > max_len:
        num_dropped += 1
        continue

    subunit_p1 = random.choice(subunits)
    chain_p1 = p1[p1['subunit'] == subunit_p1].reset_index(drop=True)

    # Draw random partner chains until one qualifies. Every draw counts toward
    # the limit, including draws rejected for length, so the search is bounded.
    for _ in range(max_swap_attempts):
        p2 = dataset[np.random.randint(len(dataset))]['atoms_pairs'][cols]
        subunit_p2 = random.choice(p2['subunit'].unique())

        # Reject candidates whose subunit name shares the prefix before the first '.'.
        if subunit_p2.split('.')[0] == subunit_p1.split('.')[0]:
            continue

        chain_p2 = p2[p2['subunit'] == subunit_p2]
        chain_p2 = chain_p2[chain_p2['name'] == 'CA'].reset_index(drop=True)

        # Max length check. If the chain we're trying to swap is too long, try again.
        if len(chain_p2) > max_len:
            continue

        # Construct the new pair that we assume are non-binders
        new_dataset.append(_make_pair(pd.concat([chain_p1, chain_p2]), 0, chain_p1, chain_p2))
        break
    else:
        # Reached only when no draw succeeded; the pair then has no negative example.
        print('Could not find suitable swap for pair', i, subunit_p1.split('.')[0])

    # The real binder pair is always kept.
    new_dataset.append(_make_pair(p1, 1, chain1, chain2))


0 / 15268 0
Dropped 0 pairs


#### Write to a new LMDB file

In [6]:
save_fp = '../data/paul_PPbind_w_esm_emb/test'

# exist_ok=True creates the directory only when it is missing, with no separate
# existence check that could race with another process creating it.
os.makedirs(save_fp, exist_ok=True)

# Serialize the assembled records into an LMDB dataset at save_fp.
da.make_lmdb_dataset(new_dataset, save_fp)

100%|██████████| 402/402 [05:41<00:00,  1.18it/s]


#### Test it out

In [7]:
# Reload the LMDB that was just written and fetch its first record as a smoke test.
# NOTE(review): this rebinds `dataset`, which earlier held the DIPS source split;
# re-running the swap cell after this one would iterate over the derived data instead.
dataset = da.load_dataset(save_fp, 'lmdb')
t = dataset[0]

### DB5 Dataset

In [2]:
import dill
import pickle

def process_db5_pickle_df(fp):
    """Load a pickled DB5 DataFrame and keep only its alpha-carbon rows.

    The columns 'pdb_name' and 'atom_name' are renamed to 'subunit' and 'name'.
    Rows are then filtered to those whose 'name' is 'CA'. The original row
    index is preserved.

    Note: unpickling can execute arbitrary code, so only use this on trusted files.

    Args:
        fp: Path to the pickle file.

    Returns:
        The renamed DataFrame restricted to CA rows.
    """
    with open(fp, 'rb') as handle:
        raw_df = pickle.load(handle, encoding='latin1')

    renamed_df = raw_df.rename(columns={'pdb_name': 'subunit', 'atom_name': 'name'})
    is_alpha_carbon = renamed_df['name'] == 'CA'
    return renamed_df[is_alpha_carbon]

In [35]:
# NOTE(review): absolute, machine-specific path; consider making it configurable.
data_root_dir = '/data/psample/data/DIPS_PPI/DB5/interim'

# Use a context manager so the metadata file handle is closed after loading.
with open(f"{data_root_dir}/complexes/complexes.dill", "rb") as meta_file:
    db5_meta = dill.load(meta_file)

complexes = list(db5_meta['data'].keys())
max_swap_attempts = 10   # random draws per complex before giving up on a random pairing

db5_lmdb = []


def _load_db5_chain(interim_filename):
    """Load one chain, resolving the part of the filename after 'interim' under data_root_dir."""
    rel_fp = interim_filename.split('interim')[1]
    return process_db5_pickle_df(f'{data_root_dir}/{rel_fp}')


def _make_db5_pair(chain1, chain2, label):
    """Build one record from two chains; its id is the current length of `db5_lmdb`."""
    return {
        'atoms_pairs': pd.concat([chain1, chain2]),
        'bind': label,
        'chain1_len': len(chain1),
        'chain2_len': len(chain2),
        'id': len(db5_lmdb),
    }


for comp in complexes:
    meta = db5_meta['data'][comp]

    bound1 = _load_db5_chain(meta.bound_filenames[0])
    bound2 = _load_db5_chain(meta.bound_filenames[1])
    unbound1 = _load_db5_chain(meta.unbound_filenames[0])
    unbound2 = _load_db5_chain(meta.unbound_filenames[1])

    db5_lmdb.append(_make_db5_pair(bound1, bound2, 'bound'))
    db5_lmdb.append(_make_db5_pair(unbound1, unbound2, 'unbound'))

    # Pair the first bound chain with the second bound chain of a different,
    # randomly chosen complex. The search is bounded: the previous
    # `while not found or search_step > 10` condition never stopped retrying,
    # so a metadata file with a single complex would loop forever.
    for _ in range(max_swap_attempts):
        random_meta = db5_meta['data'][complexes[np.random.randint(len(complexes))]]
        if random_meta.name == meta.name:
            continue

        # The random partner is always taken from the second bound file.
        random_chain = _load_db5_chain(random_meta.bound_filenames[1])
        db5_lmdb.append(_make_db5_pair(bound1, random_chain, 'random'))
        break
    else:
        # Reached only when every draw returned the same complex.
        print('Could not find suitable swap for complex', comp)





In [5]:
# Show the distinct subunit names in the record at index 2. The build loop appends
# bound, unbound, then random for each complex, so this is normally the first
# complex's random pairing.
db5_lmdb[2]['atoms_pairs']['subunit'].unique()

array(['1SBB_l_b_cleaned.pdb', '2MTA_r_b_cleaned.pdb'], dtype=object)

#### Write DB5 data to an LMDB file

In [36]:
save_fp = '../data/paul_DB5_lmdb_random_chain_is_always_2'

# exist_ok=True creates the directory only when it is missing, with no separate
# existence check that could race with another process creating it.
os.makedirs(save_fp, exist_ok=True)

# Serialize the DB5 records into an LMDB dataset at save_fp.
da.make_lmdb_dataset(db5_lmdb, save_fp)

100%|██████████| 690/690 [00:02<00:00, 258.94it/s]


#### Looking at some PDB structures

In [32]:
sys.path.insert(0, '/data/psample/repos/genie')
from genie.utils.data_io import coordinates_to_pdb

# These names are whatever the DB5 build loop left bound after its final
# iteration, so this cell must run after that loop.
structs = {'bound1': bound1, 'bound2': bound2, 'unbound1': unbound1, 'unbound2': unbound2, 'random_chain': random_chain}

center_coords = False

# Create the output directory up front; opening a file inside a missing
# directory would otherwise raise FileNotFoundError.
pdb_out_dir = './tmp_pdbs'
os.makedirs(pdb_out_dir, exist_ok=True)

for k, v in structs.items():
    coords = v[['x', 'y', 'z']].values
    if center_coords:
        # Translate the chain so that its mean coordinate sits at the origin.
        coords = coords - coords.mean(axis=0)
    # Keep the converted output under its own name instead of reusing `coords`.
    pdb_text = coordinates_to_pdb(coords)
    with open(os.path.join(pdb_out_dir, f'{k}_coord_centered_{center_coords}.pdb'), 'w') as f:
        f.write(pdb_text)

In [24]:

# coords = coordinates_to_pdb()
# coords

array([[28.54000092, 10.95300007, 80.06700134],
       [26.5909996 , 12.66100025, 77.34600067],
       [24.25799942, 15.63300037, 77.73300171],
       [21.56299973, 17.78000069, 76.12400055],
       [21.70499992, 21.58099937, 76.06300354],
       [19.67200089, 24.37199974, 74.46600342],
       [22.79500008, 26.22900009, 73.39499664],
       [26.52300072, 26.07699966, 74.08599854],
       [28.22400093, 28.4810009 , 76.49099731],
       [30.24200058, 33.95600128, 76.89900208],
       [31.79400063, 33.55199814, 80.3690033 ],
       [28.84700012, 35.1590004 , 82.18000031],
       [26.65800095, 32.29100037, 80.93499756],
       [28.65200043, 29.55599976, 82.68399811],
       [26.31399918, 28.5739994 , 85.53099823],
       [28.57799911, 26.40200043, 87.70899963],
       [32.01100159, 25.93300056, 89.23799896],
       [34.49200058, 23.34199905, 90.47699738],
       [37.65100098, 23.70299911, 92.51399994],
       [40.68500137, 23.66600037, 90.23799896],
       [43.03499985, 20.67499924, 90.513

array([[-1.07599701e+01, -2.47121520e+01,  1.05520630e+01],
       [-1.27089714e+01, -2.30041518e+01,  7.83106234e+00],
       [-1.50419716e+01, -2.00321517e+01,  8.21806338e+00],
       [-1.77369713e+01, -1.78851514e+01,  6.60906222e+00],
       [-1.75949711e+01, -1.40841527e+01,  6.54806521e+00],
       [-1.96279702e+01, -1.12931523e+01,  4.95106509e+00],
       [-1.65049710e+01, -9.43615198e+00,  3.88005832e+00],
       [-1.27769703e+01, -9.58815241e+00,  4.57106021e+00],
       [-1.10759701e+01, -7.18415117e+00,  6.97605899e+00],
       [-9.05797047e+00, -1.70915079e+00,  7.38406375e+00],
       [-7.50597042e+00, -2.11315393e+00,  1.08540650e+01],
       [-1.04529709e+01, -5.06151676e-01,  1.26650620e+01],
       [-1.26419701e+01, -3.37415171e+00,  1.14200592e+01],
       [-1.06479706e+01, -6.10915232e+00,  1.31690598e+01],
       [-1.29859719e+01, -7.09115267e+00,  1.60160599e+01],
       [-1.07219719e+01, -9.26315165e+00,  1.81940613e+01],
       [-7.28896946e+00, -9.73215151e+00