In [2]:
import os,sys,re
import argparse, json
import copy
import random
import pickle
import math
import torch
from torch import nn
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from tqdm import tqdm
#from tqdm.notebook import tqdm
from Bio.PDB.PDBParser import PDBParser
from Bio.PDB.Polypeptide import one_to_index
from Bio.PDB import Selection
from Bio import SeqIO
from Bio.PDB.Residue import Residue
from easydict import EasyDict
import enum
sys.path.append('/data/cb/scratch/varun/esm-multimer/github/esm-multimer/')
import esm, gzip
from Bio import SeqIO
from esm.model.esm2 import ESM2
from collections import OrderedDict
from sklearn.metrics import mean_squared_error
import scipy.stats
from torch.utils import data as torch_data
from collections import defaultdict 
import lmdb

In [15]:
def download(url, path, save_file=None, md5=None):
    """Download ``url`` into directory ``path`` unless a usable copy already exists.

    Parameters:
        url (str): URL to fetch.
        path (str): directory to save into; created if it does not exist.
        save_file (str, optional): file name to save as. Defaults to the last path
            component of ``url`` with any ``?query`` suffix removed.
        md5 (str, optional): expected MD5 hex digest. When given, an existing file whose
            digest differs is downloaded again. When omitted, an existing file is reused.

    Returns:
        str: path of the (possibly pre-existing) downloaded file.
    """
    from urllib.request import urlretrieve

    if save_file is None:
        # Drop any query string so "file.zip?x=1" is saved as "file.zip".
        save_file = os.path.basename(url).partition("?")[0]
    save_file = os.path.join(path, save_file)

    if path:
        os.makedirs(path, exist_ok=True)
    # Without an expected checksum there is nothing to verify, so an existing file is trusted;
    # comparing a digest against None would otherwise force a re-download on every call.
    if not os.path.exists(save_file) or (md5 is not None and compute_md5(save_file) != md5):
        urlretrieve(url, save_file)
    return save_file

def compute_md5(file_name, chunk_size=65536):
    """Return the hexadecimal MD5 digest of the file at ``file_name``.

    The file is streamed ``chunk_size`` bytes at a time so large files are never
    held in memory all at once.
    """
    import hashlib

    digest = hashlib.md5()
    with open(file_name, "rb") as stream:
        # iter() with a b"" sentinel stops exactly when read() reports end of file.
        for block in iter(lambda: stream.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()

def _archive_targets(save_path, member, list_members):
    """Return ``(members, save_files)``: the archive members to extract and where each is written."""
    if member is not None:
        # A single requested member is flattened to its basename inside ``save_path``.
        return [member], [os.path.join(save_path, os.path.basename(member))]
    members = list_members()
    save_files = [os.path.join(save_path, _member) for _member in members]
    root = os.path.abspath(save_path)
    for _member, save_file in zip(members, save_files):
        # Refuse names like "../x" or absolute paths that would land outside the extraction directory.
        if os.path.commonpath([root, os.path.abspath(save_file)]) != root:
            raise ValueError("Archive member `%s` escapes extraction directory `%s`" % (_member, save_path))
    return members, save_files


def _ensure_parent_dir(file_name):
    """Create the directory that will contain ``file_name``, if the name has one."""
    parent = os.path.dirname(file_name)
    # dirname is "" for a bare file name, and os.makedirs("") raises FileNotFoundError.
    if parent:
        os.makedirs(parent, exist_ok=True)


def extract(zip_file, member=None):
    """Extract a ``.gz``, ``.tar``, ``.tar.gz``/``.tgz`` or ``.zip`` archive into its own directory.

    Output files that already exist with the expected size are skipped, so repeated
    calls are cheap.

    Parameters:
        zip_file (str): path to the archive.
        member (str, optional): for tar/zip archives, extract only this member, written
            under its basename. Ignored for ``.gz`` files, which hold a single file.

    Returns:
        str: the extracted file's path if exactly one file was targeted, otherwise the
        directory containing the archive.

    Raises:
        ValueError: if the extension is not recognised, or an archive member would be
            written outside the archive's directory.
    """
    import gzip
    import logging
    import shutil
    import struct
    import tarfile
    import zipfile

    logger = logging.getLogger(__name__)

    zip_name, extension = os.path.splitext(zip_file)
    if zip_name.endswith(".tar"):
        extension = ".tar" + extension
        zip_name = zip_name[:-4]
    save_path = os.path.dirname(zip_file)

    if extension == ".gz":
        save_file = os.path.join(save_path, os.path.basename(zip_name))
        save_files = [save_file]
        # The gzip trailer stores the uncompressed size (mod 2**32) in its last 4 bytes, little-endian.
        with open(zip_file, "rb") as fin:
            fin.seek(-4, 2)
            file_size = struct.unpack("<I", fin.read())[0]
        if not os.path.exists(save_file) or file_size != os.path.getsize(save_file):
            logger.info("Extracting %s to %s", zip_file, save_file)
            with gzip.open(zip_file, "rb") as fin, open(save_file, "wb") as fout:
                shutil.copyfileobj(fin, fout)
    elif extension in [".tar.gz", ".tgz", ".tar"]:
        with tarfile.open(zip_file, "r") as tar:
            members, save_files = _archive_targets(save_path, member, tar.getnames)
            for _member, save_file in zip(members, save_files):
                info = tar.getmember(_member)
                if info.isdir():
                    os.makedirs(save_file, exist_ok=True)
                    continue
                _ensure_parent_dir(save_file)
                if not os.path.exists(save_file) or info.size != os.path.getsize(save_file):
                    with tar.extractfile(_member) as fin, open(save_file, "wb") as fout:
                        shutil.copyfileobj(fin, fout)
    elif extension == ".zip":
        with zipfile.ZipFile(zip_file) as zipped:
            members, save_files = _archive_targets(save_path, member, zipped.namelist)
            for _member, save_file in zip(members, save_files):
                info = zipped.getinfo(_member)
                if info.is_dir():
                    os.makedirs(save_file, exist_ok=True)
                    continue
                _ensure_parent_dir(save_file)
                if not os.path.exists(save_file) or info.file_size != os.path.getsize(save_file):
                    with zipped.open(_member, "r") as fin, open(save_file, "wb") as fout:
                        shutil.copyfileobj(fin, fout)
    else:
        raise ValueError("Unknown file extension `%s`" % extension)

    if len(save_files) == 1:
        return save_files[0]
    return save_path

In [None]:
# url = "https://miladeepgraphlearningproteindata.s3.us-east-2.amazonaws.com/ppidata/human_ppi.zip"
# md5 = "89885545ebc2c11d774c342910230e20"
# path = './dataset/'

# zip_file = download(url, path, md5=md5)
# data_path = extract(zip_file)

In [8]:
class HumanPPI(Dataset):
    """Human protein-protein interaction pairs read from a pre-built LMDB split.

    Each item is a ``(seq_1, seq_2, label)`` tuple decoded from a JSON record, with
    ``label`` converted to ``int``. ``len()`` is read from the database's ``"length"`` entry.
    """

    url = "https://miladeepgraphlearningproteindata.s3.us-east-2.amazonaws.com/ppidata/human_ppi.zip"
    md5 = "89885545ebc2c11d774c342910230e20"

    splits = ["train", "valid", "test", "cross_species_test"]
    target_fields = ["interaction"]

    def __init__(self, path, split='train', verbose=1):
        """Open the LMDB database at ``<path>/HumanPPI/normal/<split>/``."""
        lmdb_file = os.path.join(path, f'HumanPPI/normal/{split}/')
        self.load_lmdb(lmdb_file, sequence_field=["primary_1", "primary_2"], target_fields=self.target_fields,
                       verbose=verbose)

    def load_lmdb(self, lmdb_file, sequence_field="primary", target_fields=None, number_field="num_examples",
                  transform=None, lazy=False, verbose=0, **kwargs):
        """Open ``lmdb_file`` read-only and begin the read transaction used by :meth:`_get`.

        Only ``lmdb_file`` is used; the other parameters are accepted for signature
        compatibility and are currently ignored.
        """
        # readonly=True makes a wrong path fail here instead of silently creating an
        # empty database, which would only surface later as a confusing error in len().
        self.env = lmdb.open(lmdb_file, readonly=True, lock=False, map_size=10995116277760)
        self.operator = self.env.begin()

    def close(self):
        """Release the read transaction and the LMDB environment (safe to call twice)."""
        if getattr(self, "operator", None) is not None:
            self.operator.abort()
            self.operator = None
        if getattr(self, "env", None) is not None:
            self.env.close()
            self.env = None

    def _get(self, key: "str | int"):
        """Return the UTF-8 decoded value stored under ``str(key)``, or ``None`` if absent."""
        value = self.operator.get(str(key).encode())
        if value is not None:
            value = value.decode()
        return value

    def __len__(self):
        return int(self._get("length"))

    def __getitem__(self, index):
        """Return ``(seq_1, seq_2, label)`` for record ``index``.

        Raises:
            IndexError: if no record is stored under ``index``. This also ends plain
                ``for`` iteration, which relies on the ``__getitem__`` sequence protocol.
        """
        raw = self._get(index)
        if raw is None:
            raise IndexError("HumanPPI index %s out of range" % (index,))
        entry = json.loads(raw)
        return entry['seq_1'], entry['seq_2'], int(entry["label"])

In [9]:
# Open the train / valid / test LMDB splits from ./dataset/HumanPPI/normal/<split>/.
train, val, test = (HumanPPI(path='./dataset/', split=split_name)
                    for split_name in ('train', 'valid', 'test'))

In [10]:
def convert_to_csv(dataset, name, out_dir="."):
    """Write a PPI dataset to ``<out_dir>/processed_data_<name>.csv``.

    Parameters:
        dataset: object supporting ``len()`` and integer indexing whose items are
            ``(seq_1, seq_2, target)`` triples (e.g. a ``HumanPPI`` split).
        name (str): split name embedded in the output file name.
        out_dir (str): directory to write into; created if missing. Defaults to the
            current directory, matching the previous hard-coded location.

    Returns:
        str: path of the written CSV. Its columns are ``sequence_1``, ``sequence_2`` and
        ``target``, preceded by the DataFrame's integer index (kept for compatibility
        with readers of the existing files).
    """
    rows = [dataset[i] for i in range(len(dataset))]
    df = pd.DataFrame(rows, columns=['sequence_1', 'sequence_2', 'target'])
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f'processed_data_{name}.csv')
    df.to_csv(out_path)
    return out_path

In [11]:
convert_to_csv(train, 'train')
convert_to_csv(val, 'validation')
convert_to_csv(test, 'test')

In [None]:
# Smoke-test loader over the HumanPPI training split opened above (`train`), rather than `x`,
# which is only created in a later cell.
# NOTE: FlabCollateFn is defined in a later cell; run that cell before this one.
train_loader = torch.utils.data.DataLoader(
    train, batch_size=2, collate_fn=FlabCollateFn(), shuffle=True)

In [None]:
# Smoke test: pull one shuffled batch; with FlabCollateFn as collate_fn it is a
# (chains, chain_ids, labels) tuple.
y = next(iter(train_loader))

In [None]:
class FlabCollateFn:
    """Collate ``(seq_1, seq_2, label)`` samples into ESM-1b token tensors.

    Each side of the pair is tokenized separately as ``<cls> seq <eos>`` and padded to
    the longest sequence on that side of the batch. The two padded blocks are then
    concatenated along the last (sequence) dimension.

    Args:
        truncation_seq_length: if set, any token list (including ``<cls>``/``<eos>``)
            longer than this is replaced by a random contiguous window of exactly this
            many tokens.
    """

    def __init__(self, truncation_seq_length=None):
        self.alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
        self.truncation_seq_length = truncation_seq_length

    def __call__(self, batches):
        """Collate a list of ``(seq_1, seq_2, label)`` samples.

        Returns:
            chains: int64 tensor ``(batch, len_1 + len_2)`` holding the padded tokens of
                the first sequences followed by those of the second sequences.
            chain_ids: int32 tensor of the same shape; 0 over the first block and 1 over
                the second, padding positions included.
            labels: tensor obtained by ``np.stack``-ing the labels.
        """
        first_seqs, second_seqs, labels = zip(*batches)

        chains = [self.convert(seqs) for seqs in (first_seqs, second_seqs)]
        chain_ids = [torch.full(c.shape, i, dtype=torch.int32) for i, c in enumerate(chains)]
        chains = torch.cat(chains, -1)
        chain_ids = torch.cat(chain_ids, -1)
        labels = torch.from_numpy(np.stack(labels, 0))

        return chains, chain_ids, labels

    def convert(self, seq_str_list):
        """Tokenize protein strings into a ``(len(seq_str_list), max_len)`` int64 tensor.

        Every ``J`` is replaced with ``L`` before encoding. Shorter rows are right-padded
        with the alphabet's padding index.
        """
        seq_encoded_list = [self.alphabet.encode('<cls>' + seq_str.replace('J', 'L') + '<eos>')
                            for seq_str in seq_str_list]
        if self.truncation_seq_length:
            window = self.truncation_seq_length
            for i, seq in enumerate(seq_encoded_list):
                if len(seq) > window:
                    # random.randint includes both endpoints, so the last start that still
                    # leaves a full window is len(seq) - window.
                    start = random.randint(0, len(seq) - window)
                    seq_encoded_list[i] = seq[start:start + window]
        max_len = max(len(seq_encoded) for seq_encoded in seq_encoded_list)
        if self.truncation_seq_length:
            assert max_len <= self.truncation_seq_length

        tokens = torch.full((len(seq_encoded_list), max_len), self.alphabet.padding_idx, dtype=torch.int64)
        for i, seq_encoded in enumerate(seq_encoded_list):
            tokens[i, :len(seq_encoded)] = torch.tensor(seq_encoded, dtype=torch.int64)
        return tokens

In [4]:
from lmdb_dataset import LMDBDataset

import torch
import json

class SaprotPPIDataset(LMDBDataset):
    """PPI pairs served from LMDB through the project's ``LMDBDataset`` base class.

    Items are ``(seq_1, seq_2, label)`` tuples decoded from JSON records, with ``label``
    converted to ``int``. ``len()`` is read from the ``"length"`` entry.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments unchanged to ``LMDBDataset.__init__``.

        In this notebook it is constructed with ``train_lmdb``, ``valid_lmdb`` and
        ``test_lmdb`` paths. Which database ``_get`` reads from is decided by
        ``LMDBDataset``. NOTE(review): confirm against lmdb_dataset.py.
        """
        super().__init__(**kwargs)

    def __getitem__(self, index):
        """Return ``(seq_1, seq_2, label)`` for the JSON record stored under ``index``."""
        entry = json.loads(self._get(index))
        seq_1, seq_2 = entry['seq_1'], entry['seq_2']

        return seq_1, seq_2, int(entry["label"])

    def __len__(self):
        """Number of records, taken from the ``"length"`` entry."""
        return int(self._get("length"))

In [12]:
# SaProt-style dataset over the same HumanPPI LMDB splits used by HumanPPI above.
x = SaprotPPIDataset(
    train_lmdb='./dataset/HumanPPI/normal/train/',
    valid_lmdb='./dataset/HumanPPI/normal/valid/',
    test_lmdb='./dataset/HumanPPI/normal/test/',
)

In [5]:
# Share of the most frequent label in the test split (label 0: 103 of 180 at the time of writing),
# presumably the accuracy of a constant majority-class predictor. Derived from the CSV rather
# than hard-coding the counts, so it stays correct if the data is regenerated.
pd.read_csv('processed_data_test.csv')['target'].value_counts(normalize=True).max()

0.5722222222222222

In [4]:
pd.read_csv('processed_data_test.csv', usecols=['target'])['target'].value_counts()

0    103
1     77
Name: target, dtype: int64

In [7]:
pd.read_csv('processed_data_validation.csv', usecols=['target'])['target'].value_counts()

0    135
1     99
Name: target, dtype: int64

In [8]:
pd.read_csv('processed_data_train.csv', usecols=['target'])['target'].value_counts()

0    14259
1    12060
Name: target, dtype: int64