In [1]:
import numpy as np
import pandas as pd
import os
import random
import atom3d.datasets.datasets as da

import sys
sys.path.append('../')
import gvp
import gvp.atom3d as gvp_atom3d
import esm


#### Load DIPS dataset (from GGN meets pLM, but really I think it's from the GVP paper)

In [2]:
dataset = da.load_dataset('../data/PPI/DIPS-split/data/test', 'lmdb')

#### Run through the data and randomly swap one chain in order to make examples of non-binders.
But also include the real binders in the dataset. Ultimately, we're doubling the size of the dataset.

In [5]:
precompute_plm_embeddings = True

if precompute_plm_embeddings:
    esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    batch_converter = alphabet.get_batch_converter()
    esm_model.eval()
    esm_model = esm_model.to('cuda')

cols = ['ensemble', 'subunit', 'structure', 'residue', 'resname', 'name','x', 'y', 'z']
max_len = 800
num_dropped = 0

new_dataset = []

for i, pair in enumerate(dataset):
    if i % 1000 == 0:
        print(i, '/', len(dataset), len(new_dataset))
        print('Dropped', num_dropped, 'pairs')

    p1 = pair['atoms_pairs'][cols]
    p1 = p1[p1['name'] == 'CA']
    subunits = p1['subunit'].unique()

    # Max length check. If either chain is too long, drop the pair altogether
    len_chain1 = len(p1[p1['subunit'] == subunits[0]])
    len_chain2 = len(p1[p1['subunit'] == subunits[1]])
    if (len_chain1 > max_len) | (len_chain2 > max_len):
        num_dropped += 1
        continue

    subunit_p1 = random.choice(subunits)
    chain_p1 = p1[p1['subunit'] == subunit_p1].reset_index(drop=True)

    found_suitable_swap = False
    search_step = 0
    while not found_suitable_swap:
        p2 = dataset[np.random.randint(len(dataset))]['atoms_pairs'][cols]
        subunit_p2 = random.choice(p2['subunit'].unique())
        if subunit_p2.split('.')[0] != subunit_p1.split('.')[0]:
            chain_p2 = p2[p2['subunit'] == subunit_p2]
            chain_p2 = chain_p2[chain_p2['name'] == 'CA'].reset_index(drop=True)

            # Max length check. If the chain we're trying to swap is too long, try again.
            if len(chain_p2) > max_len:
                continue

            # Construct the new pair that we assume are non-binders
            new_pair = {'atoms_pairs': pd.concat([chain_p1, chain_p2])}
            new_pair['bind'] = 0
            new_pair['chain1_len'] = len(chain_p1)
            new_pair['chain2_len'] = len(chain_p2)
            new_pair['id'] = len(new_dataset)
            if precompute_plm_embeddings:
                new_pair['esm_emb_p1'] = pd.DataFrame(gvp_atom3d.get_plm_reps(chain_p1, esm_model, batch_converter, device='cuda').cpu().numpy())
                new_pair['esm_emb_p2'] = pd.DataFrame(gvp_atom3d.get_plm_reps(chain_p2, esm_model, batch_converter, device='cuda').cpu().numpy())
            else:
                new_pair['esm_emb_p1'] = None
                new_pair['esm_emb_p2'] = None

            new_dataset.append(new_pair)
            found_suitable_swap = True
        
        search_step += 1
        if search_step == 20:
            print('Could not find suitable swap for pair', i, subunit_p1.split('.')[0])
            found_suitable_swap = True

    original_pair = {'atoms_pairs': p1,
          'bind': 1,
          'id': len(new_dataset),
          'chain1_len': len_chain1,
          'chain2_len': len_chain2}
    if precompute_plm_embeddings:
        original_pair['esm_emb_p1'] = pd.DataFrame(gvp_atom3d.get_plm_reps(p1[p1['subunit'] == subunits[0]], esm_model, batch_converter, device='cuda').cpu().numpy())
        original_pair['esm_emb_p2'] = pd.DataFrame(gvp_atom3d.get_plm_reps(p1[p1['subunit'] == subunits[1]], esm_model, batch_converter, device='cuda').cpu().numpy())
    else:
        original_pair['esm_emb_p1'] = None
        original_pair['esm_emb_p2'] = None
    new_dataset.append(original_pair)

    if i == 200:
        break


0 / 15268 0
Dropped 0 pairs


#### Write to a new LMDB file

In [6]:
save_fp = '../data/paul_PPbind_w_esm_emb/test'

if not os.path.exists(save_fp):
    os.makedirs(save_fp)

da.make_lmdb_dataset(new_dataset, save_fp)

100%|██████████| 402/402 [05:41<00:00,  1.18it/s]


#### Test it out

In [7]:
dataset = da.load_dataset(save_fp, 'lmdb')
t = dataset[0]

In [8]:
t

{'atoms_pairs':                               ensemble           subunit     structure  \
 0    1a05.pdb1.gz_1_A_1a05.pdb1.gz_1_B  1a05.pdb1.gz_1_A  1a05.pdb1.gz   
 1    1a05.pdb1.gz_1_A_1a05.pdb1.gz_1_B  1a05.pdb1.gz_1_A  1a05.pdb1.gz   
 2    1a05.pdb1.gz_1_A_1a05.pdb1.gz_1_B  1a05.pdb1.gz_1_A  1a05.pdb1.gz   
 3    1a05.pdb1.gz_1_A_1a05.pdb1.gz_1_B  1a05.pdb1.gz_1_A  1a05.pdb1.gz   
 4    1a05.pdb1.gz_1_A_1a05.pdb1.gz_1_B  1a05.pdb1.gz_1_A  1a05.pdb1.gz   
 ..                                 ...               ...           ...   
 235  1e6d.pdb1.gz_1_H_1e6d.pdb1.gz_1_M  1e6d.pdb1.gz_1_H  1e6d.pdb1.gz   
 236  1e6d.pdb1.gz_1_H_1e6d.pdb1.gz_1_M  1e6d.pdb1.gz_1_H  1e6d.pdb1.gz   
 237  1e6d.pdb1.gz_1_H_1e6d.pdb1.gz_1_M  1e6d.pdb1.gz_1_H  1e6d.pdb1.gz   
 238  1e6d.pdb1.gz_1_H_1e6d.pdb1.gz_1_M  1e6d.pdb1.gz_1_H  1e6d.pdb1.gz   
 239  1e6d.pdb1.gz_1_H_1e6d.pdb1.gz_1_M  1e6d.pdb1.gz_1_H  1e6d.pdb1.gz   
 
      residue resname name          x           y          z  
 0          1     ME