# Imports

In [None]:
from huggingface_hub import login
from esm.models.esm3 import ESM3
from esm.sdk.api import ESM3InferenceClient, ESMProtein, SamplingConfig

import pandas as pd
from tqdm import tqdm
import numpy as np
import torch
import os

# Load Data

In [20]:
# Load the pre-filtered gene table from a pickle file in the working directory.
# NOTE(review): unpickling can execute arbitrary code -- only load this file
# if it comes from a trusted source.
p2cs_filtered_groups = pd.read_pickle("p2cs_filtered_groups.pkl")

In [3]:
# Bare expression as the last line of the cell: the notebook renders the
# DataFrame preview shown in the output below.
p2cs_filtered_groups

Unnamed: 0,Gene,Start,Stop,Strand,Original description,class,type,P2CS description,tm,file_name,db_id,organism,Gene_num,frame,proximity_group,tcs_organization,tcs_organization_int,proximity_group_size,aa_sequence,nt_sequence
0,Asuc_0781,841792,842496,+,two component transcriptional regulator,RR,OmpR,"Response regulator, OmpR family contains 1 Res...",0,ActsuDB_Actinobacillus_succinogenes_130Z,ActsuDB,Actinobacillus succinogenes 130Z,781,1,2,Pair,2.0,2,MTKILLVDDDIELTDLLGELLSLEGFDVVTAQNGLEALEKLDDGIC...,ATGACGAAAATTTTATTAGTCGATGACGATATCGAGTTAACTGATC...
1,Asuc_0782,842506,843927,+,two-component sensor protein,HK,Classic,"Histidine kinase, Classic contains 1 HAMP,1 Hi...",3,ActsuDB_Actinobacillus_succinogenes_130Z,ActsuDB,Actinobacillus succinogenes 130Z,782,1,2,Pair,2.0,2,LFPFLQRINRLPVQLLASFWLVIFTTLSITFVLLHFLDSHRPEKLE...,TTGTTTCCCTTTTTGCAACGCATTAACCGCCTTCCTGTTCAGCTGC...
2,Asuc_1363,1517270,1517950,+,two component transcriptional regulator,RR,OmpR,"Response regulator, OmpR family contains 1 Res...",0,ActsuDB_Actinobacillus_succinogenes_130Z,ActsuDB,Actinobacillus succinogenes 130Z,1363,2,5,Pair,2.0,2,MRVLLIEDDPLIGNGLNIGLTKSGFSVDWFTDGKTGLEAVKSAPYD...,ATGAGAGTTTTATTAATCGAAGACGATCCGTTAATCGGTAACGGTT...
3,Asuc_1364,1517925,1519298,+,hypothetical protein,HK,Classic,"Histidine kinase, Classic contains 1 2CSK_N,1 ...",2,ActsuDB_Actinobacillus_succinogenes_130Z,ActsuDB,Actinobacillus succinogenes 130Z,1364,3,5,Pair,2.0,2,MMKLLKRRSLRFRLIVILSLAALVIWSMATAIAWFQAKNEVNKMFD...,ATGATGAAACTCCTGAAACGACGTAGTTTGCGTTTTCGTTTGATTG...
4,Asuc_1720,1888650,1890632,+,hypothetical protein,HK,Classic,"Histidine kinase, Classic contains 1 HisKA,1 H...",3,ActsuDB_Actinobacillus_succinogenes_130Z,ActsuDB,Actinobacillus succinogenes 130Z,1720,3,6,Pair,2.0,2,MRKWINSLNISRGLQLSFWLSALLCLFVGGLGLLTWQQQRAEINIA...,ATGAGAAAATGGATTAACAGCTTAAATATCAGCCGTGGGTTACAGC...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
88146,Glov_3523,3789002,3790093,+,response regulator receiver modulated metal ...,RR,RpfG,"Response regulator, RpfG family contains 1 Res...",0,GeoloDB_Geobacter_lovleyi_SZ,GeoloDB,Geobacter lovleyi SZ,3523,2,103623,Pair,2.0,2,MLMFPKEHPAYTVLLVDDNPENIRLLDEALRDEYTIKVATRGEKAL...,ATGCTGATGTTCCCAAAGGAGCACCCCGCCTATACCGTTCTCCTGG...
88147,Glov_3526,3793421,3795175,+,integral membrane sensor signal transduction ...,HK,Classic,"Histidine kinase, Classic contains 1 SBP_bac_3...",3,GeoloDB_Geobacter_lovleyi_SZ,GeoloDB,Geobacter lovleyi SZ,3526,2,103624,Pair,2.0,2,VLRPATLCIALFFCLLLPFSALARPIIVGGDRDYPPYEFLDPNGKP...,GTGCTGCGTCCAGCCACCCTTTGCATAGCCCTGTTTTTCTGCCTGT...
88148,Glov_3527,3795129,3796535,+,Fis family two component sigma-54 specific ...,RR,NtrC,"Response regulator, NtrC family contains 1 Res...",0,GeoloDB_Geobacter_lovleyi_SZ,GeoloDB,Geobacter lovleyi SZ,3527,3,103624,Pair,2.0,2,MSGTLFPAFGILLVDDEPAWLRSVSLTLERSAGITNIRSCSDSRQV...,ATGTCTGGAACCCTGTTTCCCGCCTTTGGTATCCTGCTGGTGGATG...
88149,Glov_3543,3812663,3813874,+,signal transduction histidine kinase LytS,HK,Classic,"Histidine kinase, Classic contains 1 His_kinas...",7,GeoloDB_Geobacter_lovleyi_SZ,GeoloDB,Geobacter lovleyi SZ,3543,2,103625,Pair,2.0,2,MGLFLDLFERLGLFAILFIFLIRFKAFKRLLTGIASRRDKLVLAFM...,ATGGGACTCTTTCTTGATCTGTTCGAACGCCTCGGCCTGTTTGCCA...


# Load Model

In [None]:
# Authenticate with the Hugging Face Hub before the model weights are fetched.
# Create an access token with "Read" permission in your Hugging Face account
# settings and expose it through the HF_TOKEN environment variable instead of
# pasting it into this cell, so the secret is never saved with the notebook.
# When HF_TOKEN is unset or empty, login() receives None and prompts for the
# token interactively.
login(token=os.environ.get("HF_TOKEN") or None)

In [None]:
# This will download the model weights and instantiate the model on your machine.
# Pick the device at runtime instead of hard-coding it: use the GPU when CUDA
# is available and fall back to the CPU otherwise.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model: ESM3InferenceClient = ESM3.from_pretrained("esm3-open").to(DEVICE)

# Embed

In [None]:
# Compute embeddings for each aa_sequence and save as a numpy array
def get_embedding(seq, max_length=1000):
    """Return the mean-pooled ESM-3 embedding of one amino-acid sequence.

    Args:
        seq: Amino-acid sequence string.
        max_length: Sequences longer than this are cut to their first
            ``max_length`` residues to bound memory use. Adjust based on the
            memory available on your device.

    Returns:
        A float32 NumPy array holding ``result.per_residue_embedding``
        averaged over dimension 0 (presumably the residue axis, giving one
        fixed-size vector per sequence -- confirm against the ESM SDK).
    """
    from contextlib import nullcontext  # local import keeps this cell self-contained

    # Find out where the model actually lives instead of assuming a GPU:
    # requesting CUDA autocast for a CPU model only produces warnings.
    try:
        on_cuda = next(model.parameters()).device.type == "cuda"
    except (AttributeError, StopIteration):
        # `model` exposes no parameters (e.g. not a local torch module).
        on_cuda = torch.cuda.is_available()

    # Truncate very long sequences to avoid memory issues.
    if len(seq) > max_length:
        seq = seq[:max_length]

    if on_cuda:
        # Release cached blocks before the forward pass to reduce OOM risk.
        torch.cuda.empty_cache()
        # Mixed precision lowers GPU memory usage during the forward pass.
        amp_context = torch.autocast(device_type="cuda", dtype=torch.float16)
    else:
        amp_context = nullcontext()

    with torch.no_grad():
        protein_tensor = model.encode(ESMProtein(sequence=seq))
        with amp_context:
            result = model.forward_and_sample(
                protein_tensor, SamplingConfig(return_per_residue_embeddings=True)
            )

        # Pool in float32: averaging many half-precision values loses accuracy,
        # and float32 always converts cleanly to NumPy. Move to CPU right away
        # so nothing from this sequence stays on the GPU.
        mean_embedding = (
            result.per_residue_embedding.float().mean(dim=0).detach().cpu().numpy()
        )

    # Drop the device tensors before handing cached memory back.
    del protein_tensor, result
    if on_cuda:
        torch.cuda.empty_cache()

    return mean_embedding

EMBEDDING_DIM = 1536  # 1536 is the embedding size for ESM-3
SAVE_INTERVAL = 100  # Flush buffered embeddings to disk every N sequences
PBAR_UPDATE_INTERVAL = 25  # Refresh the progress bar at most every N sequences
OOM_RETRY_LENGTH = 500  # Residues kept when retrying a sequence after a CUDA OOM

embeddings_file = "p2cs_filtered_groups_mean_embeddings.npy"
seqs = p2cs_filtered_groups['aa_sequence'].tolist()

# Set PyTorch memory allocation settings for better memory management.
# NOTE(review): the CUDA caching allocator reads this variable when it is
# initialised, so it likely has no effect once the model already sits on the
# GPU -- consider exporting it before starting the kernel. Confirm against the
# PyTorch CUDA memory-management notes.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'


def _flush_embeddings(buffer):
    """Append the buffered vectors to ``embeddings_file`` and empty the buffer.

    The whole file is loaded and rewritten on every flush, so rows on disk
    always stay in sequence order as one plain ``.npy`` array.
    """
    if not buffer:
        return
    stored = np.load(embeddings_file)
    np.save(embeddings_file, np.vstack([stored, *buffer]))
    buffer.clear()


def _embed_with_fallback(index, seq):
    """Embed ``seq``; on CUDA OOM retry truncated, else return a zero vector.

    The retry runs after the ``except`` block has ended, because while an
    exception is being handled its traceback keeps the failed call's tensors
    alive and the memory could not be reclaimed.
    """
    try:
        return get_embedding(seq)
    except torch.cuda.OutOfMemoryError:
        print(f"CUDA OOM at sequence {index}. Clearing cache and retrying...")
    torch.cuda.empty_cache()

    # Try with a shorter sequence if it's very long.
    if len(seq) > OOM_RETRY_LENGTH:
        try:
            return get_embedding(seq[:OOM_RETRY_LENGTH])
        except torch.cuda.OutOfMemoryError:
            pass  # fall through to the placeholder below
        torch.cuda.empty_cache()

    print(f"Skipping sequence {index} due to memory constraints")
    # Zero embedding as placeholder so row i of the file still maps to seqs[i].
    return np.zeros(EMBEDDING_DIM)


# Resume from a previous run: rows are written in sequence order, so the number
# of rows already on disk is the index of the next sequence to embed. Without
# this, re-running the cell would append every embedding a second time.
if os.path.exists(embeddings_file):
    n_done = np.load(embeddings_file).shape[0]
else:
    np.save(embeddings_file, np.empty((0, EMBEDDING_DIM)))
    n_done = 0

if n_done > len(seqs):
    print(
        f"{embeddings_file} already holds {n_done} rows but there are only "
        f"{len(seqs)} sequences; nothing will be embedded."
    )
elif n_done:
    print(f"Resuming at sequence {n_done} ({n_done} embeddings already saved)")

embeddings_list = []

with tqdm(
    total=len(seqs),
    initial=min(n_done, len(seqs)),
    desc="Embedding sequences",
    miniters=PBAR_UPDATE_INTERVAL,
) as pbar:
    try:
        for i, seq in enumerate(seqs[n_done:], start=n_done):
            embeddings_list.append(_embed_with_fallback(i, seq))

            # Checked for every sequence, including ones that hit the OOM
            # path, so a placeholder row never skips a scheduled save.
            if (i + 1) % SAVE_INTERVAL == 0:
                _flush_embeddings(embeddings_list)

            pbar.update(1)
    finally:
        # Write whatever is still buffered: the final partial batch on normal
        # completion, or the finished embeddings if the run is interrupted.
        _flush_embeddings(embeddings_list)