In [None]:
import torch
torch.load('uniprot_nuc_embeddings.pt')

In [1]:
import pandas as pd
# Read the sequence table and discard the index column that an earlier CSV export wrote out.
df = pd.read_csv('uniprot_nuc_sequences.csv').drop(columns=['Unnamed: 0'])

In [2]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 20422 entries, 0 to 20421
Data columns (total 3 columns):
 #   Column            Non-Null Count  Dtype 
---  ------            --------------  ----- 
 0   Entry             20422 non-null  object
 1   EMBL_first_entry  20408 non-null  object
 2   nuc_sequence      20400 non-null  object
dtypes: object(3)
memory usage: 478.8+ KB


## Run model

In [3]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

In [4]:
# Import the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-1000g")
model = AutoModelForMaskedLM.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-1000g")

In [5]:
# Use the GPU when one is present; otherwise stay on the CPU instead of raising.
model.to('cuda' if torch.cuda.is_available() else 'cpu')

EsmForMaskedLM(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(4105, 1280, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1002, 1280, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-23): 24 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=1280, out_features=1280, bias=True)
              (key): Linear(in_features=1280, out_features=1280, bias=True)
              (value): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((1280,), eps=1e-12, elementwise_affine=True)
          )
          (int

In [6]:
import datetime
def get_embeddings_batch(tokenizer, model, list_of_sequences, output_path="", batch_size=2):
    """Embed sequences batch by batch and collect the first-token vector of each.

    For every batch the sequences are tokenized with padding, the model is
    called with ``output_hidden_states=True``, and position 0 of the last
    hidden state is kept for each sequence (the original author labels this
    the CLS token).

    Args:
        tokenizer: Object providing ``batch_encode_plus(...)`` (returning a
            mapping with ``"input_ids"``) and a ``pad_token_id`` attribute.
        model: Torch module called as ``model(input_ids, attention_mask=...,
            output_hidden_states=True)``; its output must expose
            ``hidden_states``. Inputs are sent to the device of the model's
            first parameter (CPU if it has none).
        list_of_sequences: Sliceable collection of sequence strings.
        output_path: Where to ``torch.save`` the joined tensor. When empty
            (the default) nothing is written.
        batch_size: Number of sequences per forward pass. Defaults to 2.

    Returns:
        A CPU tensor with one row per input sequence, concatenated along
        dim 0 in input order. An empty tensor when no sequences are given.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # Follow the model's own placement rather than assuming a CUDA device.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    cls_embeddings = []
    start = datetime.datetime.now()

    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for i, b in enumerate(range(0, len(list_of_sequences), batch_size), start=1):
            sequence_batch = list_of_sequences[b:b + batch_size]
            tokens_ids = tokenizer.batch_encode_plus(
                sequence_batch, return_tensors="pt", padding=True
            )["input_ids"]

            # Padding positions must not be attended to.
            attention_mask = tokens_ids != tokenizer.pad_token_id
            torch_outs = model(
                tokens_ids.to(device),
                attention_mask=attention_mask.to(device),
                output_hidden_states=True,
            )

            # Keep only position 0 of the last layer, and move it off the
            # accelerator so results do not pile up in device memory.
            last_layer_CLS = torch_outs.hidden_states[-1][:, 0, :].detach().cpu()
            cls_embeddings.append(last_layer_CLS)

            timespan = datetime.datetime.now() - start
            start = datetime.datetime.now()

            if i % 10 == 0:
                print("Finished with batch ", i, " in time ", timespan)

    # Concatenate embeddings for all batches (torch.cat rejects an empty list).
    if cls_embeddings:
        final_result = torch.cat(cls_embeddings, dim=0)
    else:
        final_result = torch.empty(0)

    # Save the joined result only when a destination was supplied.
    if output_path:
        torch.save(final_result, output_path)

    return final_result
            

In [7]:
# def get_cls_embedding_batch(sequences, batch_size):

#     # Tokenize the input texts
#     encoded_inputs = tokenizer(sequences, return_tensors='pt', padding=True).to('cuda')

#     cls_embeddings = []
#     num_samples = len(sequences)

#     # Process inputs in batches
#     for i in range(0, num_samples, batch_size):
#         batch_inputs = {key: val[i:i + batch_size] for key, val in encoded_inputs.items()}

#         # Forward pass through the model
#         outputs = model(**batch_inputs)

#         # Get the hidden states from the model's output
#         last_hidden_state = outputs.hidden_states[-1].detach()

#         # Get the embedding for the CLS token
#         cls_embedding = last_hidden_state[:, 0, :]  # Extract the first token (CLS token) from each sequence

#         cls_embeddings.append(cls_embedding)

#     # Concatenate embeddings for all batches
#     final_result = torch.cat(cls_embeddings, dim=0)
    
#     # Save the joined result
#     torch.save(final_result, 'joined_result.pt')

In [8]:
# Render every entry of the nuc_sequence column as text and keep at most its first 1000 characters.
# NOTE(review): astype(str) turns missing values into the literal text "nan", and df.info()
# reports 20400 non-null values out of 20422 for this column — confirm those rows should be embedded.
sequences = [text[:1000] for text in df["nuc_sequence"].astype(str)]

In [9]:
len(sequences)

20422

In [15]:
# get_cls_embedding_batch(sequences, batch_size=2)

In [10]:
get_embeddings_batch(tokenizer, model, sequences[:10000],'results.pt')

Finished with batch  10  in time  0:00:00.019097
Finished with batch  20  in time  0:00:00.018192
Finished with batch  30  in time  0:00:00.018933
Finished with batch  40  in time  0:00:00.018788
Finished with batch  50  in time  0:00:00.018977
Finished with batch  60  in time  0:00:00.018938
Finished with batch  70  in time  0:00:00.018889
Finished with batch  80  in time  0:00:00.018919
Finished with batch  90  in time  0:00:00.018949
Finished with batch  100  in time  0:00:00.018663
Finished with batch  110  in time  0:00:00.019013
Finished with batch  120  in time  0:00:00.019103
Finished with batch  130  in time  0:00:00.018871
Finished with batch  140  in time  0:00:00.019020
Finished with batch  150  in time  0:00:00.018892
Finished with batch  160  in time  0:00:00.018988
Finished with batch  170  in time  0:00:00.018989
Finished with batch  180  in time  0:00:00.018952
Finished with batch  190  in time  0:00:00.019092
Finished with batch  200  in time  0:00:00.019106
Finished 

In [None]:
# get_embeddings_batch(tokenizer, model, sequences[:10000],"NT_results/torch_outs_")

In [10]:
get_embeddings_batch(tokenizer, model, sequences[10000:],'results2.pt')

Finished with batch  10  in time  0:00:00.018805
Finished with batch  20  in time  0:00:00.018797
Finished with batch  30  in time  0:00:00.018922
Finished with batch  40  in time  0:00:00.018841
Finished with batch  50  in time  0:00:00.018736
Finished with batch  60  in time  0:00:00.018902
Finished with batch  70  in time  0:00:00.018858
Finished with batch  80  in time  0:00:00.018996
Finished with batch  90  in time  0:00:00.018866
Finished with batch  100  in time  0:00:00.018915
Finished with batch  110  in time  0:00:00.018897
Finished with batch  120  in time  0:00:00.018884
Finished with batch  130  in time  0:00:00.018986
Finished with batch  140  in time  0:00:00.019059
Finished with batch  150  in time  0:00:00.019150
Finished with batch  160  in time  0:00:00.019099
Finished with batch  170  in time  0:00:00.019311
Finished with batch  180  in time  0:00:00.019126
Finished with batch  190  in time  0:00:00.019017
Finished with batch  200  in time  0:00:00.018767
Finished 

In [14]:
# Per-chunk embedding files, listed in the order the sequence slices were processed
file_paths = ['results.pt', 'results2.pt']

# Load each chunk onto the CPU so the join also works on a machine without
# the GPU the tensors may have been saved from
results = []
for file_path in file_paths:
    result = torch.load(file_path, map_location='cpu')
    results.append(result)

# Concatenate the chunks along dim 0, the axis the per-batch embeddings were stacked on
final_result = torch.cat(results, dim=0)

# Save the joined result
torch.save(final_result, 'uniprot_nuc_embeddings.pt')

In [15]:
len(torch.load('uniprot_nuc_embeddings.pt'))

20422

In [3]:
# import os
# def load_torch_results(tokenizer, path, subset=""):
#     # Define the folder path where the torch files are located
#     folder_path = path
    
#     # Get a list of all the files in the folder
#     files = [f for f in os.listdir(folder_path) if ".pt" in f]
#     if subset:
#         files = [f for f in files if subset in f]
    
#     # Create an empty list to store the loaded tensors
#     results = []
#     sequences = []

#     # Iterate over each file
#     for file in files:
#         # Load the dict from the file
#         tensor_dict = torch.load(os.path.join(folder_path, file))
        
        
        
# #         sequence_batch = tensor_dict['sequences']
# #         tokens_ids = tokenizer.batch_encode_plus(sequence_batch, return_tensors="pt", padding=True)["input_ids"]
# #         attention_mask = tokens_ids != tokenizer.pad_token_id
        
#         embeddings = tensor_dict['last_layer_tensors'].cpu().numpy()
# #         # sum_total=torch.sum(attention_mask.unsqueeze(-1)*embeddings, axis=-2)
# #         # count_total=torch.sum(attention_mask, axis=-1).unsqueeze(1)
# #         # mean_sequence_embeddings = sum_total/count_total
# #         # Convert the dictionary into a list of dictionaries
# #         # list_of_dicts = [dict(zip(tensor_dict.keys(), values)) for values in zip(*tensor_dict.values())]
# #         # df = pd.DataFrame({'mean_emb': mean_sequence_embeddings.tolist()})
# #         # Append the dict to the list
# #         # results.append(mean_sequence_embeddings.tolist())
#         cls_embeddings = embeddings[:, 0, :]  # Select the embeddings for the CLS token
#         results.append(cls_embeddings)
# #         results.append(cls_embeddings.tolist())
# #         sequences.append(sequence_batch)


#     # Concatenate the results along the desired dimension
#     final_result = torch.cat(results, dim=0)  # Adjust the dimension according to your data

#     # Save the joined result
#     torch.save(final_result, 'joined_result.pt')
#     return results, sequences


In [None]:
# emb, seq = load_torch_results(tokenizer, "NT_results/")

In [48]:
sequences[0][:10]

'CACGAGGGGA'

In [49]:
emb[0][0][0],seq[0][0][:10]

(-0.7393954396247864, 'CACGAGGGGA')

In [50]:
# Flatten the nested lists: one flat list of embeddings, one flat list of sequences.
# NOTE(review): in this notebook `emb` and `seq` are only assigned by the commented-out
# load_torch_results call — confirm where they come from on a fresh kernel.
embs = [vector for chunk in emb for vector in chunk]
seqs = [text for chunk in seq for text in chunk]


In [51]:
# Map each sequence in seqs to its embedding in embs
# (if a sequence occurs more than once, the last occurrence wins)
emb_dict = {s: e for s, e in zip(seqs, embs)}

# Position of each sequence inside seqs, built once so every lookup is O(1)
# instead of a linear list.index scan per element
seq_position = {s: i for i, s in enumerate(seqs)}

# Get the embeddings, and the sequences they belong to, in the order of `sequences`.
# A sequence that is absent from seqs raises KeyError.
ordered_emb = [emb_dict[s] for s in sequences]
ordered_seq = [seqs[seq_position[s]] for s in sequences]

In [52]:
ordered_emb[0][0],ordered_seq[0][:10]

(-0.7393954396247864, 'CACGAGGGGA')

In [53]:
df['nuc_embeddings']=ordered_emb

In [54]:
df

Unnamed: 0.1,Unnamed: 0,Entry,EMBL_first_entry,nuc_sequence,nuc_embeddings
0,0,A0A087X1C5,AY220845,CACGAGGGGAAGGGTCACGCGCTCGGTGTGCTGAGAGTGTCCTGCC...,"[-0.7393954396247864, 0.12663324177265167, -0...."
1,1,A0A0B4J2F0,AB593170,ATCCATTAAGTTTGGCCTTTGAGAGCAGTCGTCGCTCGCAAGCCCG...,"[-0.704558253288269, 0.08113979548215866, -0.3..."
2,2,A0A0B4J2F2,CU639417,GATCTGTTGGTGGTTCCCTCGGCTTTGGACCTAGTCGCTCTGATTC...,"[-0.6787762641906738, 0.11075711250305176, -0...."
3,3,A0A0C5B5G6,KP715230,ATGAGGTGGCAAGAAATGGGCTACATTTTCTACCCCAGAAAACTAC...,"[-0.8611523509025574, -0.16802263259887695, -0..."
4,4,A0A0K2S4Q6,LC013475,ATGACCCAGAGGGCTGGGGCTGCCATGCTGCCTTCAGCTCTGCTCC...,"[-0.7674009203910828, 0.13942335546016693, -0...."
...,...,...,...,...,...
6523,6841,Q16322,U96110,TCCCCTAGAATGGATGTGTGTGGCTGGAAAGAAATGGAGGTTGCGC...,"[-0.7365144491195679, 0.1813833862543106, -0.3..."
6524,6842,Q16342,S78085,GCTGCGCCCCACGCCAGCCCGCGCCCCGCATGGCTGCCGCCGGGGC...,"[-0.7676622271537781, 0.0987539067864418, -0.3..."
6525,6843,Q16348,S78203,CGAGGAGAGAGAGAGAGTAAGGAGCCAGCCATGAATCCTTTCCAGA...,"[-0.7234958410263062, 0.12948516011238098, -0...."
6526,6844,Q16352,S78296,TGTAGCTCGCGTTGAAGCCGCACGTCCGGCCCCGATCCCGGCACCA...,"[-0.7770000100135803, 0.13071982562541962, -0...."


In [55]:
df.to_csv('uniprot_nuc_embeddings.csv')

In [56]:
df

Unnamed: 0.1,Unnamed: 0,Entry,EMBL_first_entry,nuc_sequence,nuc_embeddings
0,0,A0A087X1C5,AY220845,CACGAGGGGAAGGGTCACGCGCTCGGTGTGCTGAGAGTGTCCTGCC...,"[-0.7393954396247864, 0.12663324177265167, -0...."
1,1,A0A0B4J2F0,AB593170,ATCCATTAAGTTTGGCCTTTGAGAGCAGTCGTCGCTCGCAAGCCCG...,"[-0.704558253288269, 0.08113979548215866, -0.3..."
2,2,A0A0B4J2F2,CU639417,GATCTGTTGGTGGTTCCCTCGGCTTTGGACCTAGTCGCTCTGATTC...,"[-0.6787762641906738, 0.11075711250305176, -0...."
3,3,A0A0C5B5G6,KP715230,ATGAGGTGGCAAGAAATGGGCTACATTTTCTACCCCAGAAAACTAC...,"[-0.8611523509025574, -0.16802263259887695, -0..."
4,4,A0A0K2S4Q6,LC013475,ATGACCCAGAGGGCTGGGGCTGCCATGCTGCCTTCAGCTCTGCTCC...,"[-0.7674009203910828, 0.13942335546016693, -0...."
...,...,...,...,...,...
6523,6841,Q16322,U96110,TCCCCTAGAATGGATGTGTGTGGCTGGAAAGAAATGGAGGTTGCGC...,"[-0.7365144491195679, 0.1813833862543106, -0.3..."
6524,6842,Q16342,S78085,GCTGCGCCCCACGCCAGCCCGCGCCCCGCATGGCTGCCGCCGGGGC...,"[-0.7676622271537781, 0.0987539067864418, -0.3..."
6525,6843,Q16348,S78203,CGAGGAGAGAGAGAGAGTAAGGAGCCAGCCATGAATCCTTTCCAGA...,"[-0.7234958410263062, 0.12948516011238098, -0...."
6526,6844,Q16352,S78296,TGTAGCTCGCGTTGAAGCCGCACGTCCGGCCCCGATCCCGGCACCA...,"[-0.7770000100135803, 0.13071982562541962, -0...."
