In [1]:
# allow proper pathing for project
import sys
from pathlib import Path

# get absolute path to project root
ROOT = Path(__file__).resolve().parents[1] if "__file__" in locals() else Path.cwd().parents[0]
sys.path.append(str(ROOT))

print("Added to sys.path:", ROOT)

Added to sys.path: /Users/vnagpal/Desktop/fa-2025/cse-598-ai4sci/gen-cov-abm


In [None]:
from src.utils.path_utils import get_data_dir
import numpy as np
import torch
import pandas as pd
from tqdm import tqdm
import esm

✓ ESM imported successfully


In [None]:
# verify ESM installation
print("ESM pretrained module available:", hasattr(esm, 'pretrained'))
print("ESM2 model available:", hasattr(esm.pretrained, 'esm2_t33_650M_UR50D') if hasattr(esm, 'pretrained') else False)

✓ ESM pretrained module available: True
✓ ESM2 model available: True


In [4]:
# Load Data
seq_file_path = get_data_dir() / "ma_sequences.csv"
seq_df = pd.read_csv(seq_file_path)

In [5]:
# Load model
# Load pretrained ESM-2 model
print("Loading ESM-2 model (this may take a while on first run)...")
print("NOTE: If you get an error about 'esm.pretrained', restart the kernel!")

# Load ESM-2 model using fair-esm 2.0+ API
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()

model.eval()  # Disable dropout

device = torch.device("mps" if torch.backends.mps.is_available() else 'cpu')
model = model.to(device)
print(f"✓ Model loaded successfully")
print(f"✓ Using device: {device}")

Loading ESM-2 model (this may take a while on first run)...
NOTE: If you get an error about 'esm.pretrained', restart the kernel!
✓ Model loaded successfully
✓ Using device: mps


In [None]:
# Helper function to remove stop codons
def remove_stop_codon(sequence):
    if sequence[-1] == "*":
        return sequence[:-1]
    return sequence


# Function to generate embeddings for a given sequence column
def generate_embeddings(seq_df, seq_column, batch_size=32):
    """
    Generate embeddings for sequences in a specified column.
    
    Args:
        seq_df: DataFrame containing sequences
        seq_column: Name of the column containing sequences
        batch_size: Number of sequences to process per batch
    
    Returns:
        embeddings_array: NumPy array of embeddings
        ids: List of sequence IDs
    """
    all_embeddings = []
    all_ids = []
    
    # Clean sequences
    seq_df_clean = seq_df.copy()
    seq_df_clean[seq_column] = seq_df_clean[seq_column].apply(remove_stop_codon)
    
    print(f"Processing {len(seq_df_clean)} sequences from '{seq_column}' column...")
    
    for i in tqdm(range(0, len(seq_df_clean), batch_size)):
        batch_df = seq_df_clean.iloc[i : i + batch_size]
        
        # Prepare batch data
        data = [(row['name'], row[seq_column]) for _, row in batch_df.iterrows()]
        
        batch_labels, batch_strs, batch_tokens = batch_converter(data)
        batch_tokens = batch_tokens.to(device)
        
        # Get embeddings
        with torch.no_grad():
            results = model(batch_tokens, repr_layers=[33], return_contacts=False)
            token_representations = results["representations"][33]
        
        # Average across residues (excluding special tokens)
        sequence_repr = token_representations.mean(1)
        
        # Store results
        all_embeddings.append(sequence_repr.cpu().numpy())
        all_ids.extend(batch_labels)
    
    # Concatenate all embeddings
    embeddings_array = np.vstack(all_embeddings)
    print(f"Final embeddings shape for '{seq_column}': {embeddings_array.shape}")
    
    return embeddings_array, all_ids


# Generate embeddings for both sequence types
batch_size = 32  # Adjust based on GPU memory

print("Generating N-sequence embeddings...")
n_embeddings, n_ids = generate_embeddings(seq_df, "n_sequence", batch_size)

print("Generating S-sequence embeddings...")
s_embeddings, s_ids = generate_embeddings(seq_df, "s_sequence", batch_size)

Generating N-sequence embeddings...
Processing 4070 sequences from 'n_sequence' column...


100%|██████████| 128/128 [22:07<00:00, 10.37s/it]


Final embeddings shape for 'n_sequence': (4070, 1280)

Generating S-sequence embeddings...
Processing 4070 sequences from 's_sequence' column...


  5%|▌         | 7/128 [2:24:54<39:17:33, 1169.04s/it]

In [None]:
# Verify embeddings shapes
print(f"N-sequence embeddings shape: {n_embeddings.shape}")
print(f"S-sequence embeddings shape: {s_embeddings.shape}")
print(f"Number of sequences: {len(n_ids)}")

In [None]:
# Save embeddings to disk
output_dir = get_data_dir()

print("Saving N-sequence embeddings...")
# Save N-sequence embeddings as numpy array
np.save(output_dir / "n_sequence_embeddings.npy", n_embeddings)
print(f"  ✓ Saved to {output_dir / 'n_sequence_embeddings.npy'}")

# Save N-sequence IDs
with open(output_dir / "n_sequence_ids.txt", "w") as f:
    f.write("\n".join(n_ids))
print(f"  ✓ Saved IDs to {output_dir / 'n_sequence_ids.txt'}")

# Save N-sequence as CSV
n_embedding_df = pd.DataFrame(n_embeddings, index=n_ids)
n_embedding_df.to_csv(output_dir / "n_sequence_embeddings.csv")
print(f"  ✓ Saved to {output_dir / 'n_sequence_embeddings.csv'}")

print("\nSaving S-sequence embeddings...")
# Save S-sequence embeddings as numpy array
np.save(output_dir / "s_sequence_embeddings.npy", s_embeddings)
print(f"  ✓ Saved to {output_dir / 's_sequence_embeddings.npy'}")

# Save S-sequence IDs
with open(output_dir / "s_sequence_ids.txt", "w") as f:
    f.write("\n".join(s_ids))
print(f"  ✓ Saved IDs to {output_dir / 's_sequence_ids.txt'}")

# Save S-sequence as CSV
s_embedding_df = pd.DataFrame(s_embeddings, index=s_ids)
s_embedding_df.to_csv(output_dir / "s_sequence_embeddings.csv")
print(f"  ✓ Saved to {output_dir / 's_sequence_embeddings.csv'}")

print("\n" + "=" * 60)
print("All embeddings saved successfully!")
print("=" * 60)
print("\nTo load the embeddings later:")
print("  # N-sequence embeddings")
print(f"  n_embeddings = np.load('{output_dir / 'n_sequence_embeddings.npy'}')")
print(f"  with open('{output_dir / 'n_sequence_ids.txt'}') as f: n_ids = f.read().splitlines()")
print("\n  # S-sequence embeddings")
print(f"  s_embeddings = np.load('{output_dir / 's_sequence_embeddings.npy'}')")
print(f"  with open('{output_dir / 's_sequence_ids.txt'}') as f: s_ids = f.read().splitlines()")

In [None]:
# Example: Load both sets of embeddings
output_dir = get_data_dir()

# Load N-sequence embeddings
n_embeddings_loaded = np.load(output_dir / 'n_sequence_embeddings.npy')
with open(output_dir / 'n_sequence_ids.txt') as f:
    n_ids_loaded = f.read().splitlines()

# Load S-sequence embeddings
s_embeddings_loaded = np.load(output_dir / 's_sequence_embeddings.npy')
with open(output_dir / 's_sequence_ids.txt') as f:
    s_ids_loaded = f.read().splitlines()

print(f"N-sequence embeddings shape: {n_embeddings_loaded.shape}")
print(f"S-sequence embeddings shape: {s_embeddings_loaded.shape}")
print(f"Number of N-sequence IDs: {len(n_ids_loaded)}")
print(f"Number of S-sequence IDs: {len(s_ids_loaded)}")