In [1]:
%reset -f
import pandas as pd
import scanpy as sc
import numpy as np
from fuzzywuzzy import fuzz
import pickle
import gc

In [2]:
import os
# Configure the CUDA allocator through the environment *before* torch is
# imported, so the setting is already in place when torch initialises.
# NOTE(review): 'max_split_size_mb:4096' is taken to cap the size of blocks the
# caching allocator will split (see PyTorch's CUDA memory-management notes) —
# confirm the value is still appropriate for the GPUs in use.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:4096' # do this before importing pytorch
import torch
import esm
from torch.utils.data import TensorDataset
from esm import Alphabet, FastaBatchedDataset

In [3]:
import gc

# Ask cuDNN for deterministic kernels and switch off its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Collect unreachable Python objects first, then release the cached CUDA blocks.
gc.collect()
torch.cuda.empty_cache()

In [4]:
import os
# Sanity check: show the allocator setting the current process sees
# (None when the variable is not set).
value = os.getenv('PYTORCH_CUDA_ALLOC_CONF')
print("PYTORCH_CUDA_ALLOC_CONF:", value)

PYTORCH_CUDA_ALLOC_CONF: max_split_size_mb:4096


In [5]:
# Load the high-confidence PPI table and keep only the columns used downstream.
columns_to_keep = ['UniProtID_1', 'UniProtID_2', 'symbol_1', 'symbol_2', 'seq_1', 'seq_2',
                   'Experimental System', 'Throughput', 'len_1', 'len_2', 'Pubmed ID']
df = pd.read_csv('./data/high_confidence_ppi.csv', index_col=0)[columns_to_keep]
# Non-inplace de-duplication: avoids mutating a column-sliced frame in place and
# keeps the cell safe to re-run.
df = df.drop_duplicates()
# drop rows where either seq_1 or seq_2 has string length > 2000
df = df[~((df['seq_1'].str.len() > 2000) | (df['seq_2'].str.len() > 2000))]


def _label_unique_sequences(seqs):
    """Return [('seq0', s0), ('seq1', s1), ...] for the distinct items of `seqs`.

    Duplicates are removed in first-appearance order (dicts preserve insertion
    order), so the labels are identical on every run. Iterating a `set` of
    strings instead yields an order that changes between kernel sessions.
    """
    return [('seq' + str(idx), seq) for idx, seq in enumerate(dict.fromkeys(seqs))]


# Unique sequences per interaction partner (kept as sets, as before).
seqs_wt1 = set(df.seq_1.values.tolist())
seqs_wt2 = set(df.seq_2.values.tolist())
# (label, sequence) pairs fed to the embedding data loaders.
seqs_labeled_wt1 = _label_unique_sequences(df.seq_1.values.tolist())
seqs_labeled_wt2 = _label_unique_sequences(df.seq_2.values.tolist())
# alternative way to generate batches

In [6]:
# NOTE(review): despite its name, `batch_size` appears to act as a per-batch
# budget rather than a sequence count (the logged batches hold ~10-14
# sequences each) — confirm against FastaBatchedDataset.get_batch_indices.
batch_size = 1000


def _build_loader(labeled_seqs, budget):
    """Build (dataset, batch indices, DataLoader) from (label, sequence) pairs."""
    columns = list(zip(*labeled_seqs))  # columns[0] = labels, columns[1] = sequences
    fasta_dataset = FastaBatchedDataset(columns[0], columns[1])
    batch_indices = fasta_dataset.get_batch_indices(budget, extra_toks_per_seq=1)
    loader = torch.utils.data.DataLoader(
        fasta_dataset,
        collate_fn=Alphabet.from_architecture("roberta_large").get_batch_converter(),
        batch_sampler=batch_indices,
        pin_memory=True,
    )
    return fasta_dataset, batch_indices, loader


dataset, batches, data_loader = _build_loader(seqs_labeled_wt1, batch_size)
dataset_seq2, batches_seq2, data_loader_seq2 = _build_loader(seqs_labeled_wt2, batch_size)

In [7]:
# Length of the longest sequence among the labelled seq_1 entries.
max(map(len, list(zip(*seqs_labeled_wt1))[1]))

1989

In [12]:
# Load the pretrained ESM-2 checkpoint `esm2_t33_650M_UR50D` and its alphabet.
# Free Python garbage and cached CUDA blocks first so the load starts clean.
gc.collect()
torch.cuda.empty_cache()
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
model.eval()  # inference mode: disables dropout for deterministic results
batch_converter = alphabet.get_batch_converter()
n_gpus = torch.cuda.device_count()
if n_gpus > 1:
    print('gonna use', n_gpus, 'GPUs!')
# Length of the longest seq_1 sequence (displayed as the cell output).
max(map(len, list(zip(*seqs_labeled_wt1))[1]))

gonna use 2 GPUs!


1989

In [15]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [20]:
torch.cuda.empty_cache()
# Wrap only once: re-running this cell previously nested DataParallel inside
# DataParallel. DataParallel uses all visible GPUs by default and splits each
# input batch across them.
if not isinstance(model, torch.nn.DataParallel):
    model = torch.nn.DataParallel(model)
# Honour the `device` chosen earlier instead of calling .cuda() unconditionally,
# so the cell also runs on a CPU-only machine.
model.to(device)
if device.type == 'cuda':
    print('Transferred model to GPU')
else:
    print('No GPU available; model kept on CPU')

Transferred model to GPU


In [21]:
# Build {sequence string: per-residue embedding array} for every seq_1 sequence.
# NOTE: this cell calls `tprint`, which is defined in a cell further down the
# notebook; run that cell first (it should be moved above this one).
representation_store_dict = {}
for batch_idx, (labels, strs, toks) in enumerate(data_loader):
    with torch.no_grad():
        # Only the layer-33 representations are used, so contact prediction is
        # switched off: return_contacts=True makes the model compute attention
        # and contact maps that were discarded here at a large memory cost.
        results = model(toks, repr_layers=[33], return_contacts=False)['representations'][33]
    results_cpu = results.to(device='cpu')
    tprint('Batch ID: '+str(batch_idx)+str(labels)+str(len(strs)))
    for i, str_ in enumerate(strs):
        # Keep only the positions belonging to the sequence itself: index 0 is
        # skipped (presumably the prepended start token — confirm) and
        # everything past len(str_) is end token / padding.
        # .copy() detaches the slice from the padded batch tensor so the whole
        # batch is not kept alive by every stored view.
        representation_store_dict[str_] = results_cpu[i, 1:len(str_) + 1].numpy().copy()
    tprint('finish generating embeddings')

2023-09-04 11:30:06.318697 | Batch ID: 0['seq1179', 'seq253', 'seq1744', 'seq2455', 'seq2789', 'seq582', 'seq3220', 'seq130', 'seq1974', 'seq1821', 'seq194', 'seq323', 'seq481', 'seq2210']14
2023-09-04 11:30:06.321901 | finish generating embeddings
2023-09-04 11:30:13.764490 | Batch ID: 1['seq1617', 'seq1145', 'seq2509', 'seq1806', 'seq652', 'seq1223', 'seq2922', 'seq852', 'seq2101', 'seq3062', 'seq3063', 'seq2054']12
2023-09-04 11:30:13.767056 | finish generating embeddings
2023-09-04 11:30:21.040868 | Batch ID: 2['seq482', 'seq2270', 'seq3371', 'seq1479', 'seq2104', 'seq2582', 'seq326', 'seq1487', 'seq1608', 'seq2747', 'seq3532']11
2023-09-04 11:30:21.043586 | finish generating embeddings
2023-09-04 11:30:33.315545 | Batch ID: 3['seq1789', 'seq1232', 'seq1300', 'seq2325', 'seq235', 'seq2173', 'seq2702', 'seq3319', 'seq254', 'seq2657']10
2023-09-04 11:30:33.317019 | finish generating embeddings
2023-09-04 11:30:40.517603 | Batch ID: 4['seq2016', 'seq3038', 'seq3108', 'seq807', 'seq163

KeyboardInterrupt: 

In [8]:
import datetime, sys
def tprint(string):
    """Write `string` to stdout on one line, prefixed by the current timestamp.

    The argument is converted with `str()`, so any object can be passed. The
    line has the form '<datetime.now()> | <text>' and stdout is flushed
    afterwards so the message shows up immediately.
    """
    message = str(string)
    stamp = str(datetime.datetime.now())
    out = sys.stdout
    out.write(stamp + ' | ')
    out.write(message + '\n')
    out.flush()
