In [2]:
import os,sys
import argparse, json
import copy
import random
import pickle
import math
import torch
from torch import nn
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from tqdm import tqdm
#from tqdm.notebook import tqdm
from Bio.PDB.PDBParser import PDBParser
from Bio.PDB.Polypeptide import one_to_index
from Bio.PDB import Selection
from Bio import SeqIO
from Bio.PDB.Residue import Residue
from easydict import EasyDict
import enum
sys.path.append('/data/cb/scratch/varun/esm-multimer/github/esm-multimer/')
import esm, gzip
from Bio import SeqIO
from esm.model.esm2 import ESM2
from collections import OrderedDict
from sklearn.metrics import mean_squared_error
import scipy.stats
from torch.utils import data as torch_data
from collections import defaultdict 
import lmdb

In [2]:
class PPI(Dataset):
    """Protein-pair dataset read eagerly from pickled LMDB records.

    Every item is a ``(sequence_1, sequence_2, target)`` tuple. The two
    sequences are taken from the ``primary_1`` / ``primary_2`` fields of a
    record and the target from the field named by ``target_field``.
    """

    def __init__(self, path, name, target_field, split='train', verbose=1):
        """Load ``<path>/<name>/<name>_<split>.lmdb/`` into memory.

        Args:
            path: Root directory that contains the ``<name>`` dataset folder.
            name: Dataset name; used for both the folder and the file prefix.
            target_field: Key of the label inside each pickled record.
            split: Split suffix of the LMDB file name (e.g. ``'train'``).
            verbose: Forwarded to ``load_lmdbs`` (currently unused there).
        """
        self.target_field = target_field
        lmdb_file = os.path.join(path, f'{name}/{name}_{split}.lmdb/')
        self.load_lmdbs([lmdb_file], sequence_field=["primary_1", "primary_2"], target_field=self.target_field,
                        verbose=verbose)

    def load_lmdbs(self, lmdb_files, sequence_field="primary", target_field=None, number_field="num_examples",
                   transform=None, lazy=False, verbose=0, **kwargs):
        """Read every record of the given LMDB files into Python lists.

        Populates ``self.sequences`` (one list of field values per record),
        ``self.targets`` (one label per record) and ``self.num_samples`` (record
        count per file, in the order of ``lmdb_files``).

        Args:
            lmdb_files: Iterable of LMDB paths to read, in order.
            sequence_field: Record key, or list of keys, holding the sequences.
                A single string is treated as a one-element list.
            target_field: Record key holding the label.
            number_field: Key under which the record count of a file is stored.
            transform, lazy, verbose, **kwargs: Accepted for interface
                compatibility; not used by this implementation.

        Note:
            Records are decoded with ``pickle.loads``, which can execute
            arbitrary code. Only open LMDB files from a trusted source.
        """
        # A bare string would otherwise be iterated character by character.
        if isinstance(sequence_field, str):
            sequence_fields = [sequence_field]
        else:
            sequence_fields = list(sequence_field)

        targets = []
        sequences = []
        num_samples = []
        for lmdb_file in lmdb_files:
            env = lmdb.open(lmdb_file, readonly=True, lock=False, readahead=False, meminit=False)
            try:
                with env.begin(write=False) as txn:
                    num_sample = pickle.loads(txn.get(number_field.encode()))
                    for i in range(num_sample):
                        # Records are stored under their decimal index as key.
                        item = pickle.loads(txn.get(str(i).encode()))
                        sequences.append([item[field] for field in sequence_fields])
                        target_value = item[target_field]
                        # Unwrap single-element arrays into plain Python scalars.
                        if isinstance(target_value, np.ndarray) and target_value.size == 1:
                            target_value = target_value.item()
                        targets.append(target_value)
                    num_samples.append(num_sample)
            finally:
                # Release the environment handle even if decoding fails.
                env.close()

        # Every declared record must have produced exactly one target.
        assert sum(num_samples) == len(targets)
        self.sequences = sequences
        self.targets = targets
        self.num_samples = num_samples

    def __len__(self):
        """Return the total number of loaded records."""
        return len(self.sequences)

    def __getitem__(self, index):
        """Return ``(sequence_1, sequence_2, target)`` for record ``index``.

        Requires that the records were loaded with at least two sequence fields.
        """
        return self.sequences[index][0], self.sequences[index][1], self.targets[index]

In [3]:
# Build the three yeast PPI splits in train / valid / test order.
train, valid, test = (
    PPI(path='./dataset/', name='yeast_ppi', target_field='interaction', split=split_name)
    for split_name in ('train', 'valid', 'test')
)

In [4]:
def convert_to_csv(dataset, name):
    """Write a paired-sequence dataset to ``./processed_data_<name>.csv``.

    Args:
        dataset: Indexable object with ``len()`` whose items unpack into
            ``(sequence_1, sequence_2, target)``.
        name: Suffix used in the output file name.

    The CSV has the columns ``sequence_1``, ``sequence_2`` and ``target``,
    preceded by the default pandas index column.
    """
    # Unpack explicitly so that an item of the wrong arity still raises.
    triples = [
        (first, second, label)
        for first, second, label in (dataset[idx] for idx in range(len(dataset)))
    ]
    frame = pd.DataFrame({
        'sequence_1': [first for first, _, _ in triples],
        'sequence_2': [second for _, second, _ in triples],
        'target': [label for _, _, label in triples],
    })
    frame.to_csv(f'./processed_data_{name}.csv')

In [5]:
# Export each split to ./processed_data_<name>.csv; note that the `valid`
# split is written under the name 'validation'.
convert_to_csv(dataset=train, name='train')
convert_to_csv(dataset=valid, name='validation')
convert_to_csv(dataset=test, name='test')

In [3]:
# Label counts of the exported test split.
pd.read_csv('processed_data_test.csv').loc[:, 'target'].value_counts()

1    209
0    185
Name: target, dtype: int64

In [5]:
# Label counts of the exported validation split.
pd.read_csv('processed_data_validation.csv').loc[:, 'target'].value_counts()

1    56
0    39
Name: target, dtype: int64

In [6]:
# Label counts of the exported train split.
pd.read_csv('processed_data_train.csv').loc[:, 'target'].value_counts()

0    2522
1    2423
Name: target, dtype: int64