In [None]:
import torch
import esm
import pickle
import os

# Pick the compute device: CUDA when torch reports it as available, CPU otherwise
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Build the ESM-2 150M model and its alphabet, move the model onto the
# chosen device and put it in evaluation mode
print("Loading ESM2-150M model...")
model, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
model = model.to(device).eval()
# Converter that turns (id, sequence) pairs into token tensors for the model
batch_converter = alphabet.get_batch_converter()
print("Model loaded successfully!")

# Pickle file that persists the seq_id -> embedding cache between runs
EMB_PATH = "embeddings_150M.pkl"

# Start from an empty cache when no file exists yet; otherwise restore it.
# NOTE(review): pickle.load executes arbitrary code embedded in the file, so
# EMB_PATH must only ever point at a cache this script wrote itself.
if not os.path.exists(EMB_PATH):
    embedding_dict = {}
    print("⚠️ No existing embeddings found — starting fresh")
else:
    with open(EMB_PATH, "rb") as cache_file:
        embedding_dict = pickle.load(cache_file)
    print(f"✅ Loaded {len(embedding_dict)} cached embeddings from {EMB_PATH}")

def get_embedding(sequence: str, seq_id: str) -> torch.Tensor:
    """Return the mean-pooled layer-30 embedding for a protein sequence, with caching.

    Newly computed embeddings are stored in the module-level ``embedding_dict``
    and the whole dictionary is persisted to ``EMB_PATH`` after each new entry.

    Args:
        sequence: Protein sequence handed to ``batch_converter``.
        seq_id: Cache key. The lookup uses only this id, so calling again with
            the same id but a different sequence returns the stored embedding.

    Returns:
        The cached value for ``seq_id`` if present; otherwise the layer-30
        token representations averaged over dimension 1, squeezed, on CPU.

    Raises:
        OSError: If the cache file cannot be written or replaced.
    """
    if seq_id in embedding_dict:
        print(f"ℹ️ Found cached embedding for {seq_id}")
        return embedding_dict[seq_id]

    # Prepare a batch containing just this one sequence
    data = [(seq_id, sequence)]
    _, _, batch_tokens = batch_converter(data)
    batch_tokens = batch_tokens.to(device)

    # Generate embedding (no gradients needed for inference)
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[30])
        token_representations = results["representations"][30]

    # Mean pooling over the token axis.
    # NOTE(review): the average covers every position the batch converter
    # emits; if it adds BOS/EOS tokens they are included in the mean. Kept
    # as-is so new embeddings stay consistent with ones already cached —
    # confirm this is the intended pooling.
    embedding = token_representations.mean(1).squeeze().cpu()

    # Save to cache. Write to a temporary sibling file and swap it in with
    # os.replace so that a failure or interruption during pickling cannot
    # truncate the existing cache file and lose earlier embeddings.
    embedding_dict[seq_id] = embedding
    tmp_path = f"{EMB_PATH}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(embedding_dict, f)
        os.replace(tmp_path, EMB_PATH)
    finally:
        # After a successful replace the temp file no longer exists; this
        # only cleans up a partial file left behind by a failed write.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    print(f"✅ Generated and cached embedding for {seq_id}")

    return embedding


if __name__ == "__main__":
    # Demo input: one identifier and its protein sequence
    seq_id, seq = "seq_001", "MKTFFVLVLLLAAAGVAGTQATQGNVKAAW"

    # A cache hit for this id returns the stored embedding without recomputing
    emb = get_embedding(sequence=seq, seq_id=seq_id)

    # Report the embedding's shape and a short preview of its values
    print(f"\nEmbedding shape: {emb.shape}")
    print(f"First 10 elements: {emb[:10]}")



Using device: cpu
Loading ESM2-150M model...
Model loaded successfully!
✅ Loaded 1 cached embeddings from embeddings_150M.pkl
ℹ️ Found cached embedding for seq_001

Embedding shape: torch.Size([640])
First 10 elements: tensor([-0.0012, -0.0170,  0.0506, -0.0911, -0.0596, -0.0686, -0.1527,  0.0025,
         0.0370,  0.1412])
