In [None]:
!pip install ankh

In [3]:
import torch
import time
import numpy as np
import gc


In [None]:
import ankh

# GPU placement is currently disabled: the model is used on whatever device
# ankh.load_large_model() returns it on (presumably the CPU -- TODO confirm).
# Re-enabling it takes more than uncommenting the two `device` lines in this cell:
# the tokenized inputs must be moved to the same device, and the hidden states
# must be brought back with .cpu() before the .numpy() call in the embedding cell.
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the large Ankh model together with its matching tokenizer.
model, tokenizer = ankh.load_large_model()

# Move the model to the GPU (disabled, see the note at the top of this cell)
#model = model.to(device)
# Put the model in evaluation mode (presumably a torch.nn.Module, so this turns
# off training-only behaviour such as dropout -- verify). Because this is the
# last expression of the cell, its return value is also what the notebook displays.
model.eval()

In [3]:
from Bio import SeqIO

# Location of the test-superset FASTA file (Kaggle input dataset)
fasta_file = "/kaggle/input/cafa-5-protein-function-prediction/Test (Targets)/testsuperset.fasta"

# Walk the FASTA file once, keeping each record's identifier together with its
# residues converted to a plain string
parsed_records = [(entry.id, str(entry.seq)) for entry in SeqIO.parse(fasta_file, "fasta")]

# Split the pairs into two parallel lists: sequence_names[k] is the id of seqs[k]
sequence_names = [name for name, _ in parsed_records]
seqs = [residues for _, residues in parsed_records]

# Both counts are expected to be identical
print(len(sequence_names), len(seqs))

141865 141865


In [None]:
import gc
import os

import numpy as np
import torch

# Half-open slice [start_index, end_index) of the test set embedded by this run.
# Both bounds are set by hand so the work can be split over several sessions.
start_index = 45000
end_index = 48500  # set to len(seqs) to run through the end of the FASTA file

checkpoint_file = "ankh_checkpoint.txt"
output_file = f"Ankh_test_embeddings_{start_index}_{end_index}.npy"
label_output_file = f"Ankh_test_labels_{start_index}_{end_index}.npy"
batch_size = 50
# Longer proteins are truncated to this many residues to prevent out-of-memory
# errors on some very large proteins (e.g. A2ASS6).
max_sequence_length = 3000


def embed_sequence(sequence, max_length=max_sequence_length):
    """Embed a single protein with the already loaded `model` / `tokenizer`.

    Args:
        sequence: amino-acid sequence as a string.
        max_length: number of leading residues that are kept; the rest is dropped.

    Returns:
        A tuple (mean_embedding, hidden_state_shape, n_residues) where
        mean_embedding is the last hidden state of the only batch element
        averaged over its first axis (a NumPy array), hidden_state_shape is the
        shape of the full last_hidden_state tensor and n_residues is the number
        of residues that were actually fed to the tokenizer.
    """
    residues = list(sequence)[:max_length]
    encoded = tokenizer.batch_encode_plus([residues],
                                          add_special_tokens=True,
                                          padding=True,
                                          is_split_into_words=True,
                                          return_tensors="pt")
    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model(input_ids=encoded['input_ids'],
                       attention_mask=encoded['attention_mask'])
    hidden_state = output.last_hidden_state
    mean_embedding = np.mean(hidden_state[0].numpy(), axis=0)
    return mean_embedding, hidden_state.shape, len(residues)


def load_saved_progress():
    """Return the (embeddings, labels) already saved for this range, as lists.

    Both lists are empty unless both output files exist. If the two files differ
    in length (a run interrupted between the two saves), only the common prefix
    is kept.

    Raises:
        ValueError: if the saved labels are not the names of the sequences that
            start at `start_index`, e.g. because the files contain duplicated
            rows or belong to another range.
    """
    if not (os.path.exists(output_file) and os.path.exists(label_output_file)):
        return [], []
    # The files written by this cell are plain float / string arrays, so the
    # default allow_pickle=False is sufficient and avoids unpickling file content.
    saved_embeddings = np.load(output_file)
    saved_labels = np.load(label_output_file)
    n_saved = min(len(saved_embeddings), len(saved_labels))
    expected_labels = sequence_names[start_index:start_index + n_saved]
    if [str(label) for label in saved_labels[:n_saved]] != expected_labels:
        raise ValueError(
            f"{output_file} / {label_output_file} do not match the sequences starting "
            f"at index {start_index}; move or delete them before re-running this cell"
        )
    return list(saved_embeddings[:n_saved]), expected_labels


# Never index past the end of the parsed sequences, even if end_index is too large.
stop_index = min(end_index, len(seqs))

# Everything computed for this range is kept in memory and the output files are
# rewritten in full on every flush. This avoids re-reading the files on each
# flush, and a re-run after an interruption resumes after the saved rows instead
# of appending duplicates.
embeddings_list, label_list = load_saved_progress()
resume_index = start_index + len(label_list)
if label_list:
    print(f"Resuming at index {resume_index}: {len(label_list)} embeddings already saved")

for i in range(resume_index, stop_index):
    average_embedding_np, hidden_state_shape, n_residues = embed_sequence(seqs[i])
    embeddings_list.append(average_embedding_np)
    label_list.append(sequence_names[i])

    # The first printed number should match the sequence dimension of the printed
    # shape (the +1 presumably accounts for the special token added by the tokenizer).
    print(n_residues + 1, i + 1, hidden_state_shape, sequence_names[i], 'appended')

    # Flush after every batch_size sequences and after the last sequence of the
    # requested range, so a trailing partial batch is never lost.
    if (i + 1) % batch_size == 0 or i == stop_index - 1:
        # Save the embeddings to a file
        np.save(output_file, np.asarray(embeddings_list))
        print(f"Embeddings saved to: {output_file}, {i}")

        # Save the labels as well
        np.save(label_output_file, np.asarray(label_list))
        print(f"Protein labels saved to: {label_output_file}, {i}")

        # Write checkpoint: the index of the next sequence to process
        with open(checkpoint_file, 'w') as file:
            file.write(str(i + 1))
        gc.collect()


In [14]:
# Directory holding the per-range embedding files.
embeddings_dir = '/kaggle/input/ankh-embeddings-test'

# Chunk number -> file name. The numbers follow the historical a1..a59 numbering.
# File names are listed explicitly because the separator is not consistent
# ("_start_end" vs "-start-end"). Chunks without a usable file are left out:
#   20: 'Ankh_test_embeddings-45000-48500.npy' was flagged "not ok - 48450"
#       in the original notes (presumably incomplete -- verify before adding it)
#   52: 121800-124000, no file available
#   55: 128800-130000, no file available (the former np.load('') always raised)
chunk_files = {
    1: 'Ankh_test_embeddings_0_800.npy',
    2: 'Ankh_test_embeddings-800-3850.npy',
    3: 'Ankh_test_embeddings_3850_7000.npy',
    4: 'Ankh_test_embeddings_7000_9650.npy',
    5: 'Ankh_test_embeddings_9650_10000.npy',
    6: 'Ankh_test_embeddings-10000-13000.npy',
    7: 'Ankh_test_embeddings_13000_15150.npy',
    8: 'Ankh_test_embeddings_15150_16500.npy',
    9: 'Ankh_test_embeddings_16500_19450.npy',
    10: 'Ankh_test_embeddings_19450_20000.npy',
    11: 'Ankh_test_embeddings-20000-23200.npy',
    12: 'Ankh_test_embeddings_23200_25000.npy',
    13: 'Ankh_test_embeddings-25000-30000.npy',
    14: 'Ankh_test_embeddings-30000-34200.npy',
    15: 'Ankh_test_embeddings_34200_35000.npy',
    16: 'Ankh_test_embeddings-35000-39950.npy',
    17: 'Ankh_test_embeddings_39950_40000.npy',
    18: 'Ankh_test_embeddings-40000-43800.npy',
    19: 'Ankh_test_embeddings_43800_45000.npy',
    21: 'Ankh_test_embeddings_48500_52000.npy',
    22: 'Ankh_test_embeddings_52000_55000.npy',
    23: 'Ankh_test_embeddings_55000_58000.npy',
    24: 'Ankh_test_embeddings_58000_61000.npy',
    25: 'Ankh_test_embeddings_61000_64000.npy',
    26: 'Ankh_test_embeddings_64000_67000.npy',
    27: 'Ankh_test_embeddings_67000_70000.npy',
    28: 'Ankh_test_embeddings_70000_72500.npy',
    29: 'Ankh_test_embeddings_72500_73000.npy',
    30: 'Ankh_test_embeddings_73000_76000.npy',
    31: 'Ankh_test_embeddings_76000_79000.npy',
    32: 'Ankh_test_embeddings_79000_81400.npy',
    33: 'Ankh_test_embeddings_81400_82000.npy',
    34: 'Ankh_test_embeddings_82000_84200.npy',
    35: 'Ankh_test_embeddings_84200_85000.npy',
    36: 'Ankh_test_embeddings_85000_87400.npy',
    37: 'Ankh_test_embeddings_87400_88000.npy',
    38: 'Ankh_test_embeddings_88000_91000.npy',
    39: 'Ankh_test_embeddings_91000_94000.npy',
    40: 'Ankh_test_embeddings_94000_97000.npy',
    41: 'Ankh_test_embeddings_97000_100000.npy',
    42: 'Ankh_test_embeddings_100000_103000.npy',
    43: 'Ankh_test_embeddings_103000_106000.npy',
    44: 'Ankh_test_embeddings_106000_109000.npy',
    45: 'Ankh_test_embeddings_109000_112000.npy',
    46: 'Ankh_test_embeddings_112000_114600.npy',
    47: 'Ankh_test_embeddings_114600_115000.npy',
    48: 'Ankh_test_embeddings_115000_117600.npy',
    49: 'Ankh_test_embeddings_117600_118000.npy',
    50: 'Ankh_test_embeddings_118000_121000.npy',
    51: 'Ankh_test_embeddings_121000_121800.npy',
    53: 'Ankh_test_embeddings_124000_127000.npy',
    54: 'Ankh_test_embeddings_127000_128800.npy',
    56: 'Ankh_test_embeddings_130000_133000.npy',
    57: 'Ankh_test_embeddings_133000_136000.npy',
    58: 'Ankh_test_embeddings_136000_139000.npy',
    59: 'Ankh_test_embeddings_139000_141865.npy',
}

# Load every listed chunk into an explicit dict keyed by its historical name
# ("a1", "a2", ...) instead of creating one global variable per file.
embedding_chunks = {
    f"a{number}": np.load(f"{embeddings_dir}/{file_name}")
    for number, file_name in chunk_files.items()
}

# Only chunks 1..51 (without the rejected chunk 20) are merged, i.e. the rows
# 0-45000 and 48500-121800; the later chunks are loaded but not concatenated.
embed_names = ["a" + str(i) for i in range(1, 52) if i != 20]
embeds = np.concatenate([embedding_chunks[name] for name in embed_names], axis=0)

print(len(embeds))

118300


In [11]:
# b1 = np.load('/kaggle/input/ankh-embeddings-test/Ankh_test_labels_0_800.npy')
# b2 = np.load('/kaggle/input/ankh-embeddings-test/Ankh_test_labels-800-3850.npy')
# b3 = np.load('/kaggle/input/ankh-embeddings-test/Ankh_test_labels_3850_7000.npy')
# b4 = np.load('/kaggle/input/ankh-embeddings-test/Ankh_test_labels_7000_9650.npy')
# b5 = np.load('/kaggle/input/ankh-embeddings-test/Ankh_test_labels_9650_10000.npy')
# b6 = np.load('/kaggle/input/ankh-embeddings-test/Ankh_test_labels-10000-13000.npy')
# b7 = np.load('/kaggle/input/ankh-embeddings-test/Ankh_test_labels_13000_15150.npy')
# #b8 = np.load('')
# b9 = np.load('/kaggle/input/ankh-embeddings-test/Ankh_test_labels_16500_19450.npy')
# b10 = np.load('/kaggle/input/ankh-embeddings-test/Ankh_test_labels_19450_20000.npy')
# b11 = np.load('/kaggle/input/ankh-embeddings-test/Ankh_test_labels-20000-23200.npy')
# #b12 = np.load('')
# b13 = np.load('/kaggle/input/ankh-embeddings-test/Ankh_test_labels-25000-30000.npy')
# b14 = np.load('/kaggle/input/ankh-embeddings-test/Ankh_test_labels-30000-34200.npy')



# label_names = ["b" + str(i) for i in range(1, 15) if (i != 8 and i != 12)]
# labels = np.concatenate([globals()[var_name] for var_name in label_names], axis=0)

In [12]:
# print(labels.shape, embeds.shape)

(31050,) (31050, 1536)


In [34]:
# # Concatenate the arrays along the rows (axis=0)
#  # Corrected range
# embeds = np.concatenate([locals()[var_name] for var_name in embed_names], axis=0)

KeyError: 'a2'