In [1]:
%reset -f
import pandas as pd
import scanpy as sc
import numpy as np
from fuzzywuzzy import fuzz
import pickle
import gc

In [2]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:4096' # do this before importing pytorch
import torch
import esm
from torch.utils.data import TensorDataset
from esm import Alphabet, FastaBatchedDataset

In [3]:
import gc
gc.collect()
torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [4]:
import os
# Retrieve the value of the environment variable
value = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
# Print the value
print("PYTORCH_CUDA_ALLOC_CONF:", value)

PYTORCH_CUDA_ALLOC_CONF: max_split_size_mb:4096


In [5]:
df = pd.read_csv('./data/high_confidence_ppi.csv', index_col=0)
columns_to_keep = ['UniProtID_1', 'UniProtID_2', 'symbol_1', 'symbol_2', 'seq_1', 'seq_2',
                   'Experimental System', 'Throughput', 'len_1', 'len_2', 'Pubmed ID']
df = df[columns_to_keep]
df.drop_duplicates(inplace=True)  # no duplicate
# drop rows where either seq1 or seq2 has string length > 2000
df = df[~((df['seq_1'].str.len() > 2000) | (df['seq_2'].str.len() > 2000))]
# generate embeddings
seqs_wt1 = df.seq_1.values.tolist()
seqs_wt2 = df.seq_2.values.tolist()
seqs_wt1 = set(seqs_wt1)
seqs_wt2 = set(seqs_wt2)
seqs_labeled_wt1 = []
count = 0
for seq in seqs_wt1:
    seqs_labeled_wt1.append(tuple((str('seq' + str(count)), seq)))
    count += 1
seqs_labeled_wt2 = []
count = 0
for seq in seqs_wt2:
    seqs_labeled_wt2.append(tuple((str('seq' + str(count)), seq)))
    count += 1
# alternative way to generate batches

In [9]:
batch_size = 1000
dataset = FastaBatchedDataset(list(zip(*seqs_labeled_wt1))[0], list(zip(*seqs_labeled_wt1))[1])
batches = dataset.get_batch_indices(batch_size, extra_toks_per_seq=1)
data_loader = torch.utils.data.DataLoader(dataset,
                                          collate_fn=Alphabet.from_architecture("roberta_large").get_batch_converter(),
                                          batch_sampler=batches, pin_memory=True)
dataset_seq2 = FastaBatchedDataset(list(zip(*seqs_labeled_wt2))[0], list(zip(*seqs_labeled_wt2))[1])
batches_seq2 = dataset_seq2.get_batch_indices(batch_size, extra_toks_per_seq=1)
data_loader_seq2 = torch.utils.data.DataLoader(dataset_seq2, collate_fn=Alphabet.from_architecture(
    "roberta_large").get_batch_converter(), batch_sampler=batches_seq2, pin_memory=True)

In [7]:
len(max(list(zip(*seqs_labeled_wt1))[1], key=len))

1989

In [8]:
import datetime, sys
def tprint(string):
    string = str(string)
    sys.stdout.write(str(datetime.datetime.now()) + ' | ')
    sys.stdout.write(string + '\n')
    sys.stdout.flush()

In [10]:
# try batch inspection instead
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()
# find longest sequence
torch.cuda.empty_cache()
model.cuda()

ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bia

In [9]:
len(data_loader)

2200

In [11]:
# QC
start_batch_id = 1188
end_batch_id = 1615
for batch_idx, (labels, strs, toks) in enumerate(data_loader):
    #if start_batch_id <= batch_idx <= end_batch_id:
    if len(strs[0]) == 1538:
        print(batch_idx, labels, strs, len(strs[0]))
        #break

2115 ['seq2482'] ['MEPRMESCLAQVLQKDVGKRLQVGQELIDYFSDKQKSADLEHDQTMLDKLVDGLATSWVNSSNYKVVLLGMDILSALVTRLQDRFKAQIGTVLPSLIDRLGDAKDSVREQDQTLLLKIMDQAANPQYVWDRMLGGFKHKNFRTREGICLCLIATLNASGAQTLTLSKIVPHICNLLGDPNSQVRDAAINSLVEIYRHVGERVRADLSKKGLPQSRLNVIFTKFDEVQKSGNMIQSANDKNFDDEDSVDGNRPSSASSTSSKAPPSSRRNVGMGTTRRLGSSTLGSKSSAAKEGAGAVDEEDFIKAFDDVPVVQIYSSRDLEESINKIREILSDDKHDWEQRVNALKKIRSLLLAGAAEYDNFFQHLRLLDGAFKLSAKDLRSQVVREACITLGHLSSVLGNKFDHGAEAIMPTIFNLIPNSAKIMATSGVVAVRLIIRHTHIPRLIPVITSNCTSKSVAVRRRCFEFLDLLLQEWQTHSLERHISVLAETIKKGIHDADSEARIEARKCYWGFHSHFSREAEHLYHTLESSYQKALQSHLKNSDSIVSLPQSDRSSSSSQESLNRPLSAKRSPTGSTTSRASTVSTKSVSTTGSLQRSRSDIDVNAAASAKSKVSSSSGTTPFSSAAALPPGSYASLGRIRTRRQSSGSATNVASTPDNRGRSRAKVVSQSQRSRSANPAGAGSRSSSPGKLLGSGYGGLTGGSSRGPPVTPSSEKRSKIPRSQGCSRETSPNRIGLARSSRIPRPSMSQGCSRDTSRESSRDTSPARGFPPLDRFGLGQPGRIPGSVNAMRVLSTSTDLEAAVADALKKPVRRRYEPYGMYSDDDANSDASSVCSERSYGSRNGGIPHYLRQTEDVAEVLNHCASSNWSERKEGLLGLQNLLKSQRTLSRVELKRLCEIFTRMFADPHSKRVFSMFLETLVDFIIIHKDDLQDWLFVLLTQLLKKMGADLLGSVQAKVQKALDVTRDSFPFDQQFNILMR

In [10]:
def get_gpu_memory():
    """
    Get the current GPU memory usage.

    Returns:
        allocated (float): Memory allocated by tensors in GB.
        cached (float): Cached memory in GB.
    """
    t = torch.cuda.get_device_properties(0).total_memory
    r = torch.cuda.memory_reserved(0)
    a = torch.cuda.memory_allocated(0)
    f = r-a  # free inside cache
    return a/1024**3, r/1024**3


In [14]:
representation_store_dict = {}
start_batch_id = 2115
end_batch_id = 2200
for batch_idx, (labels, strs, toks) in enumerate(data_loader):
    start_allocated, start_cached = get_gpu_memory()
    if start_batch_id <= batch_idx <= end_batch_id:
        if torch.cuda.is_available():
            toks = toks.to(device='cuda', non_blocking=True)
        end_allocated, end_cached = get_gpu_memory()
        tprint(f"Batch {batch_idx}: Allocated memory increased by {(end_allocated - start_allocated):.2f} GB, Cached memory increased by {(end_cached - start_cached):.2f} GB")
        start_allocated, start_cached = get_gpu_memory()
        with torch.no_grad():
            results = model(toks, repr_layers = [33], return_contacts = True)['representations'][33]
        end_allocated, end_cached = get_gpu_memory()
        tprint(f"Batch {batch_idx}: Allocated memory increased by {(end_allocated - start_allocated):.2f} GB, Cached memory increased by {(end_cached - start_cached):.2f} GB")
        start_allocated, start_cached = get_gpu_memory()
        results_cpu = results.to(device='cpu')
        end_allocated, end_cached = get_gpu_memory()
        tprint(f"Batch {batch_idx}: Allocated memory increased by {(end_allocated - start_allocated):.2f} GB, Cached memory increased by {(end_cached - start_cached):.2f} GB")
        start_allocated, start_cached = get_gpu_memory()
        del results, toks
        torch.cuda.empty_cache()
        end_allocated, end_cached = get_gpu_memory()
        tprint(f"Batch {batch_idx}: Allocated memory increased by {(end_allocated - start_allocated):.2f} GB, Cached memory increased by {(end_cached - start_cached):.2f} GB")
        tprint('Batch ID: '+str(batch_idx)+str(labels)+str(len(strs)))
        #tprint(torch.cuda.memory_allocated())
        #tprint(torch.cuda.memory_snapshot())
        start_allocated, start_cached = get_gpu_memory()
        for i, str_ in enumerate(strs):
        # only select representations relate to the sequence
        # rest of the sequences are paddings, check notebook
        # create dictionary {sequence: embeddings}
            representation_store_dict[str_] = results_cpu[i, 1: (len(strs[i])+1)].numpy()
        end_allocated, end_cached = get_gpu_memory()
        tprint(f"Batch {batch_idx}: Allocated memory increased by {(end_allocated - start_allocated):.2f} GB, Cached memory increased by {(end_cached - start_cached):.2f} GB")
        tprint('finish generating embeddings')
    elif batch_idx > end_batch_id:
        break

2023-09-04 16:10:08.574704 | Batch 2115: Allocated memory increased by 0.00 GB, Cached memory increased by 0.00 GB


RuntimeError: CUDA out of memory. Tried to allocate 5.82 GiB (GPU 0; 39.59 GiB total capacity; 31.67 GiB already allocated; 5.78 GiB free; 31.77 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [10]:
""" debug, really hardware issue? running stuffs on CPU where there's plenty memory """
# try batch inspection instead
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results
# find longest sequence
len(max(list(zip(*seqs_labeled_wt1))[1], key=len))
model.to('cpu')

ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bia

In [None]:
representation_store_dict = {}
start_batch_id = 2115
end_batch_id = 2200
for batch_idx, (labels, strs, toks) in enumerate(data_loader):
    if start_batch_id <= batch_idx <= end_batch_id:
        #if torch.cuda.is_available():
            #toks = toks.to(device='cuda', non_blocking=True)
        with torch.no_grad():
            results_cpu = model(toks, repr_layers = [33], return_contacts = True)['representations'][33]
        #results_cpu = results.to(device='cpu')
        #del results_cpu
        tprint('Batch ID: '+str(batch_idx)+str(labels)+str(strs))
        for i, str_ in enumerate(strs):
            # only select representations relate to the sequence
            # rest of the sequences are paddings, check notebook
            # create dictionary {sequence: embeddings}
            representation_store_dict[str_] = results_cpu[i, 1: (len(strs[i])+1)].numpy()
    elif batch_idx > end_batch_id:
        break

2023-09-04 16:47:08.397970 | Batch ID: 2115['seq816']['MEPRMESCLAQVLQKDVGKRLQVGQELIDYFSDKQKSADLEHDQTMLDKLVDGLATSWVNSSNYKVVLLGMDILSALVTRLQDRFKAQIGTVLPSLIDRLGDAKDSVREQDQTLLLKIMDQAANPQYVWDRMLGGFKHKNFRTREGICLCLIATLNASGAQTLTLSKIVPHICNLLGDPNSQVRDAAINSLVEIYRHVGERVRADLSKKGLPQSRLNVIFTKFDEVQKSGNMIQSANDKNFDDEDSVDGNRPSSASSTSSKAPPSSRRNVGMGTTRRLGSSTLGSKSSAAKEGAGAVDEEDFIKAFDDVPVVQIYSSRDLEESINKIREILSDDKHDWEQRVNALKKIRSLLLAGAAEYDNFFQHLRLLDGAFKLSAKDLRSQVVREACITLGHLSSVLGNKFDHGAEAIMPTIFNLIPNSAKIMATSGVVAVRLIIRHTHIPRLIPVITSNCTSKSVAVRRRCFEFLDLLLQEWQTHSLERHISVLAETIKKGIHDADSEARIEARKCYWGFHSHFSREAEHLYHTLESSYQKALQSHLKNSDSIVSLPQSDRSSSSSQESLNRPLSAKRSPTGSTTSRASTVSTKSVSTTGSLQRSRSDIDVNAAASAKSKVSSSSGTTPFSSAAALPPGSYASLGRIRTRRQSSGSATNVASTPDNRGRSRAKVVSQSQRSRSANPAGAGSRSSSPGKLLGSGYGGLTGGSSRGPPVTPSSEKRSKIPRSQGCSRETSPNRIGLARSSRIPRPSMSQGCSRDTSRESSRDTSPARGFPPLDRFGLGQPGRIPGSVNAMRVLSTSTDLEAAVADALKKPVRRRYEPYGMYSDDDANSDASSVCSERSYGSRNGGIPHYLRQTEDVAEVLNHCASSNWSERKEGLLGLQNLLKSQRTLSRVELKRLCEIFTRMFADPHSKRVFSMFLETLVDFIIIHKDDLQDWLFVLLTQLLK

In [6]:
"""debug wt sequence 2"""
path = './outputs/variables/'
with open(path + 'PPI_seq1_embeddings_full_22364.pk', 'rb') as f:
    representation_store_dict = pickle.load(f)
sequence_embeddings = {key: np.mean(value, axis=0, keepdims=True) for key, value in representation_store_dict.items()}


def update_embeddings(row, embedding_dict):
    """
    add embeddings to the metadata column.
    cannot do the reverse, because due to mislabel, several different protein names share the same sequences
    but as long as sequences are correct, so will the embeddings
    """
    for key, value in embedding_dict.items():
        if row == key:
            return value


df['wild_seq_1_embeddings'] = df['seq_1'].apply(update_embeddings, embedding_dict=sequence_embeddings)
df = df.dropna(subset=['wild_seq_1_embeddings'])
df

Unnamed: 0,UniProtID_1,UniProtID_2,symbol_1,symbol_2,seq_1,seq_2,Experimental System,Throughput,len_1,len_2,Pubmed ID,wild_seq_1_embeddings
0,P67809,Q13242,YBOX1_HUMAN,SRSF9_HUMAN,MSSEAETQQPPAAPPAAPALSAADTKPGTTGSGAGSGGPGGLTSAA...,MSGWADERGGEGDGRIYVGNLPTDVREKDLEDLFYKYGRIREIELK...,Two-hybrid,Low Throughput,324,221,12604611.0,"[[0.020276865, -0.055094432, -0.05752476, 0.06..."
1,Q8NER1,Q86SS6,TRPV1_HUMAN,SYT9_HUMAN,MKKWSSTDLGAAADPLQKDTCPDPLDGDPNSRPPPAKPQLSTAKSR...,MPGARDALCHQALQLLAELCARGALEHDSCQDFIYHLRDRARPRLR...,Two-hybrid,Low Throughput,839,491,15066994.0,"[[0.005592045, -0.07453048, -0.054872476, 0.06..."
2,O75319,Q13242,DUS11_HUMAN,SRSF9_HUMAN,MRNSETLERGVGGCRVFSCLGSYPGIEGAGLALLADLALGGRLLGT...,MSGWADERGGEGDGRIYVGNLPTDVREKDLEDLFYKYGRIREIELK...,Two-hybrid,Low Throughput,377,221,9685386.0,"[[-0.006202125, -0.0743498, 0.023617525, 0.003..."
3,Q13526,Q7Z5J4,PIN1_HUMAN,RAI1_HUMAN,MADEEKLPPGWEKRMSRSSGRVYYFNHITNASQWERPSGNSSSGGK...,MQSFRERCGFHGKQQNYQQTSQETSRLENYRQPSQAGLSCDRQRLL...,Two-hybrid,High Throughput,163,1906,16189514.0,"[[-0.005602383, -0.0029278726, 0.014461238, 0...."
4,Q13526,Q9UJY4,PIN1_HUMAN,GGA2_HUMAN,MADEEKLPPGWEKRMSRSSGRVYYFNHITNASQWERPSGNSSSGGK...,MAATAVAAAVAGTESAQGPPGPAASLELWLNKATDPSMSEQDWSAI...,Two-hybrid,High Throughput,163,613,16189514.0,"[[-0.005602383, -0.0029278726, 0.014461238, 0...."
...,...,...,...,...,...,...,...,...,...,...,...,...
24516,P60763,Q13443,RAC3_HUMAN,ADAM9_HUMAN,MQAIKCVVVGDGAVGKTCLLISYTTNAFPGEYIPTVFDNYSANVMV...,MGSGARFPSGTLRVRWLLLLGLVGPVLGAARPGFQQTSHLSSYEII...,Affinity Capture-MS,Low Throughput,192,819,31871319.0,"[[-0.023273492, -0.0878509, -0.1062212, -0.007..."
24517,P60763,Q9UNK0,RAC3_HUMAN,STX8_HUMAN,MQAIKCVVVGDGAVGKTCLLISYTTNAFPGEYIPTVFDNYSANVMV...,MAPDPWFSTYDSTCQIAQEIAEKIQQRNQYERKGEKAPKLTVTIRA...,Affinity Capture-MS,Low Throughput,192,236,31871319.0,"[[-0.023273492, -0.0878509, -0.1062212, -0.007..."
24518,Q9BQ90,Q15843,KLDC3_HUMAN,NEDD8_HUMAN,MLRWTVHLEGGPRRVNHAAVAVGHRVYSFGGYCSGEDYETLRQIDV...,MLIKVKTLTGKEIEIDIEPTDKVERIKERVEEKEGIPPQQQRLIYS...,Affinity Capture-MS,Low Throughput,382,81,35468939.0,"[[-0.015261493, -0.039154943, 0.05141214, 0.10..."
24519,Q9BWK5,Q9H9Q4,CYREN_HUMAN,NHEJ1_HUMAN,METLQSETKTRVLPSWLTAQVATKNVAPMKAPKRMRMAAVPVAAAR...,MEELEQGLLMQPWAWLQLAENSLLAKVFITKQGYALLVSDLQQVWH...,Affinity Capture-Western,Low Throughput,157,299,30017584.0,"[[-0.012049904, -0.064932406, 0.010630969, 0.1..."


In [7]:
seqs_labeled_wt2 = []
seqs_wt2 = df.seq_2.values.tolist()
seqs_wt2 = set(seqs_wt2)
count = 0
for seq in seqs_wt2:
    seqs_labeled_wt2.append(tuple((str('seq' + str(count)), seq)))
    count += 1
# Load ESM-2 model
#del model
gc.collect()
torch.cuda.empty_cache()
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results
batch_size = 2000
dataset_seq2 = FastaBatchedDataset(list(zip(*seqs_labeled_wt2))[0], list(zip(*seqs_labeled_wt2))[1])
batches_seq2 = dataset_seq2.get_batch_indices(batch_size, extra_toks_per_seq=1)
data_loader_seq2 = torch.utils.data.DataLoader(dataset_seq2, collate_fn=Alphabet.from_architecture(
    "roberta_large").get_batch_converter(), batch_sampler=batches_seq2, pin_memory=False)
len(data_loader_seq2)
torch.cuda.empty_cache()
if torch.cuda.is_available():
    model = model.cuda()
    print('Transferred model to GPU')
import datetime


def tprint(string):
    string = str(string)
    sys.stdout.write(str(datetime.datetime.now()) + ' | ')
    sys.stdout.write(string + '\n')
    sys.stdout.flush()

Transferred model to GPU


In [18]:
# QC
start_batch_id = 1502
end_batch_id = 1505
for batch_idx, (labels, strs, toks) in enumerate(data_loader_seq2):
    if start_batch_id <= batch_idx <= end_batch_id:
    #if len(strs[0]) == 1538:
        print(batch_idx, labels, strs, len(strs[0]))
        #break

1502 ['seq667'] ['MAKRSRGPGRRCLLALVLFCAWGTLAVVAQKPGAGCPSRCLCFRTTVRCMHLLLEAVPAVAPQTSILDLRFNRIREIQPGAFRRLRNLNTLLLNNNQIKRIPSGAFEDLENLKYLYLYKNEIQSIDRQAFKGLASLEQLYLHFNQIETLDPDSFQHLPKLERLFLHNNRITHLVPGTFNHLESMKRLRLDSNTLHCDCEILWLADLLKTYAESGNAQAAAICEYPRRIQGRSVATITPEELNCERPRITSEPQDADVTSGNTVYFTCRAEGNPKPEIIWLRNNNELSMKTDSRLNLLDDGTLMIQNTQETDQGIYQCMAKNVAGEVKTQEVTLRYFGSPARPTFVIQPQNTEVLVGESVTLECSATGHPPPRISWTRGDRTPLPVDPRVNITPSGGLYIQNVVQGDSGEYACSATNNIDSVHATAFIIVQALPQFTVTPQDRVVIEGQTVDFQCEAKGNPPPVIAWTKGGSQLSVDRRHLVLSSGTLRISGVALHDQGQYECQAVNIIGSQKVVAHLTVQPRVTPVFASIPSDTTVEVGANVQLPCSSQGEPEPAITWNKDGVQVTESGKFHISPEGFLTINDVGPADAGRYECVARNTIGSASVSMVLSVNVPDVSRNGDPFVATSIVEAIATVDRAINSTRTHLFDSRPRSPNDLLALFRYPRDPYTVEQARAGEIFERTLQLIQEHVQHGLMVDLNGTSYHYNDLVSPQYLNLIANLSGCTAHRRVNNCSDMCFHQKYRTHDGTCNNLQHPMWGASLTAFERLLKSVYENGFNTPRGINPHRLYNGHALPMPRLVSTTLIGTETVTPDEQFTHMLMQWGQFLDHDLDSTVVALSQARFSDGQHCSNVCSNDPPCFSVMIPPNDSRARSGARCMFFVRSSPVCGSGMTSLLMNSVYPREQINQLTSYIDASNVYGSTEHEARSIRDLASHRGLLRQGIVQRSGKPLLPFATGPPTECMRDENESPIPCFLAGDHRANEQL

In [11]:
representation_store_dict = {}
start_batch_id = 1502
end_batch_id = 1505
for batch_idx, (labels, strs, toks) in enumerate(data_loader_seq2):
    start_allocated, start_cached = get_gpu_memory()
    if start_batch_id <= batch_idx <= end_batch_id:
        if torch.cuda.is_available():
            toks = toks.to(device='cuda', non_blocking=True)
        end_allocated, end_cached = get_gpu_memory()
        tprint(
            f"Batch {batch_idx}: Allocated memory increased by {(end_allocated - start_allocated):.2f} GB, Cached memory increased by {(end_cached - start_cached):.2f} GB")
        start_allocated, start_cached = get_gpu_memory()
        with torch.no_grad():
            results = model(toks, repr_layers=[33], return_contacts=True)['representations'][33]
        end_allocated, end_cached = get_gpu_memory()
        tprint(
            f"Batch {batch_idx}: Allocated memory increased by {(end_allocated - start_allocated):.2f} GB, Cached memory increased by {(end_cached - start_cached):.2f} GB")
        start_allocated, start_cached = get_gpu_memory()
        results_cpu = results.to(device='cpu')
        end_allocated, end_cached = get_gpu_memory()
        tprint(
            f"Batch {batch_idx}: Allocated memory increased by {(end_allocated - start_allocated):.2f} GB, Cached memory increased by {(end_cached - start_cached):.2f} GB")
        start_allocated, start_cached = get_gpu_memory()
        del results, toks
        torch.cuda.empty_cache()
        end_allocated, end_cached = get_gpu_memory()
        tprint(
            f"Batch {batch_idx}: Allocated memory increased by {(end_allocated - start_allocated):.2f} GB, Cached memory increased by {(end_cached - start_cached):.2f} GB")
        tprint('Batch ID: ' + str(batch_idx) + str(labels) + str(len(strs)))
        #tprint(torch.cuda.memory_allocated())
        #tprint(torch.cuda.memory_snapshot())
        start_allocated, start_cached = get_gpu_memory()
        for i, str_ in enumerate(strs):
            # only select representations relate to the sequence
            # rest of the sequences are paddings, check notebook
            # create dictionary {sequence: embeddings}
            representation_store_dict[str_] = results_cpu[i, 1: (len(strs[i]) + 1)].numpy()
        end_allocated, end_cached = get_gpu_memory()
        tprint(
            f"Batch {batch_idx}: Allocated memory increased by {(end_allocated - start_allocated):.2f} GB, Cached memory increased by {(end_cached - start_cached):.2f} GB")
        tprint('finish generating embeddings')
    elif batch_idx > end_batch_id:
        break

2023-09-05 11:29:31.572825 | Batch 1502: Allocated memory increased by 0.00 GB, Cached memory increased by 0.00 GB
2023-09-05 11:29:31.707259 | Batch 1502: Allocated memory increased by 0.03 GB, Cached memory increased by 32.61 GB
2023-09-05 11:29:31.915605 | Batch 1502: Allocated memory increased by 0.00 GB, Cached memory increased by 0.00 GB
2023-09-05 11:29:31.978152 | Batch 1502: Allocated memory increased by -0.01 GB, Cached memory increased by -32.58 GB
2023-09-05 11:29:31.978735 | Batch ID: 1502['seq1871']1
2023-09-05 11:29:31.979614 | Batch 1502: Allocated memory increased by 0.00 GB, Cached memory increased by 0.00 GB
2023-09-05 11:29:31.979892 | finish generating embeddings
2023-09-05 11:29:31.988615 | Batch 1503: Allocated memory increased by 0.00 GB, Cached memory increased by 0.00 GB
2023-09-05 11:29:32.842327 | Batch 1503: Allocated memory increased by 0.01 GB, Cached memory increased by 32.58 GB
2023-09-05 11:29:32.922521 | Batch 1503: Allocated memory increased by 0.00 