In [None]:
import os
import pandas as pd
import torch
from torch.utils.data import DataLoader
from pytorch_metric_learning.losses import SelfSupervisedLoss, NTXentLoss

from proteinbind_new import EmbeddingDataset
from proteinbind_new import create_proteinbind
from transformers import pipeline

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Feature-extraction pipeline used by the embedding cell below.
# NOTE(review): `extractor` was called later in this notebook but never defined
# in it, so the notebook could not survive Restart & Run All. The checkpoint name
# is taken from the previously commented-out line -- confirm it is the model that
# produced any embeddings already saved to disk (the cutoff cell resumes a run).
model_name = 'facebook/esm2_t36_3B_UR50D'
extractor = pipeline('feature-extraction', model=model_name, framework='pt', device=device)

In [None]:
# Read the UniProt export, keeping only the two columns used downstream:
# the accession (used as the embedding file name) and the amino-acid sequence.
# Columns are selected by name rather than by position, so a reordered export
# still works and a missing column fails immediately with a clear error.
seq = pd.read_csv(
    'Uniprot-extracted_comments_and_GO_terms.csv',
    usecols=['primaryAccession', 'sequence'],
)


In [None]:
# Resume point: only rows strictly after the first occurrence of this accession
# are processed. NOTE(review): presumably the last accession handled by an
# earlier run -- confirm. Set to None to process the whole dataset.
cutoff_value = 'Q8WWT9'

if cutoff_value is None:
    filtered_seq = seq
else:
    is_cutoff = (seq['primaryAccession'] == cutoff_value).to_numpy()
    if not is_cutoff.any():
        raise ValueError(f'cutoff accession {cutoff_value!r} not found in data')
    # Positional index of the first match; using a position (not an index
    # label) keeps the slice below correct even for a non-default index.
    cutoff_index = int(is_cutoff.argmax())
    # Keep only the rows after the cutoff row.
    filtered_seq = seq.iloc[cutoff_index + 1:]

In [None]:
# take apart batch and save individual tensors
def save_tensors(batch, filenames, folder_path='GO_AA_Embeddings'):
    """Save each element of ``batch`` to its own ``.pt`` file.

    Args:
        batch: indexable/iterable collection of tensors -- e.g. a stacked
            tensor (element ``i`` is ``batch[i]``) or a list of tensors.
        filenames: one base name per element; element ``i`` is written to
            ``<folder_path>/<filenames[i]>.pt``.
        folder_path: output directory, created if it does not exist.

    Raises:
        ValueError: if ``batch`` and ``filenames`` differ in length.
    """
    if len(batch) != len(filenames):
        raise ValueError(
            f'got {len(batch)} tensors but {len(filenames)} filenames'
        )
    os.makedirs(folder_path, exist_ok=True)  # Create the folder if it doesn't exist

    for filename, tensor in zip(filenames, batch):
        file_path = os.path.join(folder_path, f'{filename}.pt')
        # A row of a larger tensor is a view sharing that tensor's storage, and
        # torch.save serializes the whole storage. Clone so each file holds only
        # this element's data instead of the entire batch.
        torch.save(tensor.clone(), file_path)

# Convert every remaining sequence into an ESM embedding and save one file per
# sequence, named after its UniProt accession.
batch_size = 32
num_batches = (len(filtered_seq) + batch_size - 1) // batch_size  # ceil(len / batch_size)

# iterate over batches of sequences
for batch_index in range(num_batches):
    start_index = batch_index * batch_size
    end_index = min(start_index + batch_size, len(filtered_seq))
    batch_df = filtered_seq.iloc[start_index:end_index]
    sequences = batch_df['sequence'].tolist()
    filenames = batch_df['primaryAccession'].tolist()

    # One pipeline output per input sequence. NOTE(review): each output is
    # presumably a (1, seq_len, hidden) tensor -- confirm for the transformers
    # version in use.
    go_AA = extractor(sequences, return_tensors=True)

    # Keep the first-token embedding ([0, 0, :]) of each sequence. Iterate over
    # the actual outputs rather than range(batch_size): the final batch is
    # usually shorter than batch_size and would otherwise raise IndexError.
    # `.clone()` copies the vector out of the full per-token output so that
    # torch.save does not serialize that entire storage into every file.
    AA_emb = [output[0, 0, :].clone() for output in go_AA]

    # Save each tensor individually using the corresponding Accession
    save_tensors(AA_emb, filenames)
    del go_AA, AA_emb