## Load Packages and Set the Working Directory

In [15]:
import os
import sys

import numpy as np
import pandas as pd
import torch

# Locate the "ZooTransform" directory inside the parent of the current working
# directory and switch into it, so relative paths resolve against it.
parent_dir = os.path.dirname(os.getcwd())
data_root = os.path.join(parent_dir, "ZooTransform")
os.chdir(data_root)

# The `src` package lives under data_root. Jupyter kernels usually keep the
# current directory on sys.path, but a plain `python` run does not, so add the
# project root explicitly to make the import below independent of the launcher.
if data_root not in sys.path:
    sys.path.insert(0, data_root)

from src.model.species_model import SpeciesAwareESM2

## Load the SpeciesAwareESM2 Model

In [ ]:
# Build a SpeciesAwareESM2 on top of the ESM-2 checkpoint named below, passing
# the species labels used in the example data. (Presumably this is the set of
# species the model can condition on — confirm in src/model/species_model.py.)
esm_checkpoint = "facebook/esm2_t6_8M_UR50D"
known_species = ["human", "mouse", "ecoli"]
model = SpeciesAwareESM2(
    model_name=esm_checkpoint,
    species_list=known_species,
)

## Load Data

In [ ]:
# Toy input: one short protein sequence per species.
# TODO: replace with the real dataset before running any analysis.
example_records = [
    ("human", "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"),
    ("mouse", "MKVSAIAKQRQISFVKSHFSRQLRERLGLIEVQ"),
    ("ecoli", "MKTVYIAKQRQISFVKSHFSRQLEERLGLIEVQ"),
]
data = pd.DataFrame(example_records, columns=["species", "sequence"])

## Generate Embeddings for Individual Sequences with Species Information (Mean Pooling)

This calls the model once per sequence, so it is slower and not recommended for large datasets.

In [ ]:
# Embed each sequence individually and mean-pool it into one vector per row.
# Simple but slow: one model call per sequence.
#
# Inference runs under torch.no_grad(): no autograd graph is kept. Also, if
# model.embed does not already disable gradients, the pooled tensor would
# require grad and .numpy() would raise a RuntimeError.
embeddings = []

with torch.no_grad():
    for species, sequence in zip(data["species"], data["sequence"]):
        emb = model.embed(species, sequence)
        # Mean pooling over dim=1 (assumed to be the token axis of a
        # (1, seq_len, hidden) tensor — TODO confirm against model.embed).
        # A different pooling strategy can be substituted here.
        emb_mean = emb.mean(dim=1).squeeze().cpu().numpy()
        embeddings.append(emb_mean)

# Stack the per-sequence vectors into a (num_sequences, hidden) array.
embeddings = np.vstack(embeddings)
print("Embeddings array shape:", embeddings.shape)

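## Generate Embeddings for *a Batch of Sequences* with Species Information (Mean Pooling)

This is faster than the per-sequence loop. Caveat: if the sequences in the batch differ in length, a plain mean over `dim=1` may also average over padding positions. In that case the results can differ from the per-sequence embeddings above.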

In [ ]:
import warnings

species_batch = data['species'].tolist()
sequence_batch = data['sequence'].tolist()

# The plain mean over dim=1 below averages every position of last_hidden_state.
# If the wrapper pads shorter sequences to a common length (presumably it must,
# to fit variable-length inputs into one tensor — TODO confirm in
# SpeciesAwareESM2), padded positions are included in the mean. The results
# would then silently differ from the per-sequence embeddings. Warn in that case.
if len({len(seq) for seq in sequence_batch}) > 1:
    warnings.warn(
        "Sequences in the batch have different lengths; the plain mean over "
        "dim=1 may include padding positions. Consider masked mean pooling.",
        stacklevel=1,
    )

# NOTE(review): if SpeciesAwareESM2 is a torch.nn.Module, calling model(...)
# rather than model.forward(...) would also run any registered hooks.
with torch.no_grad():
    outputs = model.forward(species_batch, sequence_batch)

batch_embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
print(batch_embeddings.shape)