## Load Packages and Set Working Directory

In [23]:
import os
import sys

import numpy as np
import pandas as pd
import torch

# The project root is a "ZooTransform" directory next to the kernel's starting
# directory (presumably this notebook's folder — the Jupyter default; confirm).
parent_dir = os.path.dirname(os.getcwd())
data_root = os.path.join(parent_dir, "ZooTransform")
os.chdir(data_root)  # so relative paths resolve from the project root

# Make the `src` package importable explicitly instead of relying on the kernel's
# '' (current directory) entry in sys.path, which some configurations disable
# (e.g. Python's safe-path mode). Guarded so re-running the cell does not keep
# growing sys.path.
if data_root not in sys.path:
    sys.path.insert(0, data_root)

from src.model.species_model import SpeciesAwareESM2

## Load and Use SpeciesAwareESM2 Model

In [24]:
# Build the species-aware wrapper around the pre-trained ESM-2 checkpoint;
# one special species token is registered per entry of species_list.
model = SpeciesAwareESM2(
    model_name="facebook/esm2_t6_8M_UR50D",
    species_list=["human", "mouse", "ecoli"],  # TODO - define species list
)

✓ Using device: cuda
  GPU: NVIDIA A100-SXM4-40GB
  Memory: 42.29 GB
📥 Loading model: facebook/esm2_t6_8M_UR50D


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Adding species tokens: ['<sp_human>', '<sp_mouse>', '<sp_ecoli>']
Added 3 new special tokens
Resized model embeddings to 36 tokens
✓ Model and tokenizer ready!
  Hidden size: 320
  Number of layers: 6


## Load Data

In [25]:
# Toy placeholder data (one row per protein) — replace with the real dataset.
# The species values match the species_list the model was built with.
data = pd.DataFrame(
    [
        ("human", "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"),
        ("mouse", "MKVSAIAKQRQISFVKSHFSRQLRERLGLIEVQ"),
        ("ecoli", "MKTVYIAKQRQISFVKSHFSRQLEERLGLIEVQ"),
    ],
    columns=["species", "sequence"],
)

## Generate Embeddings One Sequence at a Time, with Species Information and Mean Pooling

This loops over the sequences individually; it is slower and not recommended for large datasets.

In [26]:
# Embed each (species, sequence) pair with its own call, then mean-pool over the
# token axis (dim=1) to get one fixed-size vector per sequence. Other pooling
# strategies could be substituted here.
per_sequence_vectors = [
    model.embed(species, sequence).mean(dim=1).squeeze().cpu().numpy()
    for species, sequence in zip(data["species"], data["sequence"])
]

embeddings = np.vstack(per_sequence_vectors)
print("Embeddings array shape:", embeddings.shape)

Embeddings array shape: (3, 320)


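## Generate Embeddings for *a Batch of Sequences* with Species Information, with mean pooling

Batching is much faster than the per-sequence loop. Shorter sequences are padded to a common length, so the mean must be taken over real tokens only.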

In [27]:
species_batch = data["species"].tolist()
sequence_batch = data["sequence"].tolist()

with torch.no_grad():
    outputs = model.forward(species_batch, sequence_batch)

hidden = outputs.last_hidden_state  # (batch, tokens, hidden_size)

# A plain `hidden.mean(dim=1)` would also average the padding positions of every
# sequence shorter than the longest one in the batch. Average over real tokens only.
# Special tokens stay inside the mask, matching the per-sequence loop's pooling.
# NOTE(review): this length-based mask assumes right-padding, one token per
# residue, and the same number of special tokens per sequence. TODO: use the
# tokenizer's attention_mask directly if SpeciesAwareESM2 exposes it.
seq_lengths = torch.tensor([len(seq) for seq in sequence_batch], device=hidden.device)
n_special = hidden.shape[1] - int(seq_lengths.max())
assert n_special >= 0, "sequences appear truncated; the length-based padding mask is invalid"

token_positions = torch.arange(hidden.shape[1], device=hidden.device)
real_token_mask = (
    (token_positions[None, :] < (seq_lengths + n_special)[:, None])
    .unsqueeze(-1)
    .to(hidden.dtype)
)
batch_embeddings = (
    (hidden * real_token_mask).sum(dim=1) / real_token_mask.sum(dim=1)
).cpu().numpy()
print(batch_embeddings.shape)

(3, 320)


## Fine-tune the Model using LoRA

In [28]:
# NOTE(review): this import raised ModuleNotFoundError in the saved run — confirm the module path.
from src.fine_tuning.fine_tuning import LoraESMFinetuner

# Seed first: dropout and MLM masking (mlm_probability) are stochastic.
# NOTE(review): assumes the model/finetuner draw their randomness from torch's RNG — confirm.
torch.manual_seed(0)

# ---- Prepare your data ----
species_batch = data["species"].tolist()
sequence_batch = data["sequence"].tolist()

# ---- Instantiate base model ----
species_list = ["human", "mouse", "ecoli"]  # TODO - define species list
# Every species in the data needs a species token registered from species_list.
missing_species = set(species_batch) - set(species_list)
assert not missing_species, f"species without a token in species_list: {sorted(missing_species)}"

# Name the checkpoint explicitly so the fine-tuned base matches the model used for
# the embeddings above, instead of depending on the constructor's default.
model = SpeciesAwareESM2(model_name="facebook/esm2_t6_8M_UR50D", species_list=species_list)

# ---- Create and train LoRA finetuner ----
finetuner = LoraESMFinetuner(
    base_model=model,
    r=8,
    alpha=16,
    dropout=0.05,
    lr=1e-4,
    batch_size=4,
    mlm_probability=0.15,
)

finetuner.train(species_batch, sequence_batch, epochs=10)


ModuleNotFoundError: No module named 'src.fine_tuning.fine_tuning'

In [ ]:
# Re-embed the example batch with the LoRA-adapted weights. This is inference
# only, so autograd tracking is switched off.
grad_disabled = torch.no_grad()
with grad_disabled:
    tuned_embeddings = finetuner.embed(species_batch, sequence_batch)

tuned_shape = tuned_embeddings.shape
print("Tuned embedding shape:", tuned_shape)
