## Load Packages and Set Working Directory

In [None]:
import os
import sys

import numpy as np
import pandas as pd
import torch

# The project lives in a "ZooTransform" directory inside the parent of the current
# working directory. Change into it so relative paths resolve from the project root,
# and add it to sys.path explicitly so `from src...` works even when the runtime does
# not put the CWD on sys.path (e.g. the notebook exported and run as a plain script).
parent_dir = os.path.dirname(os.getcwd())
data_root = os.path.join(parent_dir, "ZooTransform")
os.chdir(data_root)
if data_root not in sys.path:
    sys.path.insert(0, data_root)

from src.model.species_model import SpeciesAwareESM2

## Load and Use SpeciesAwareESM2 Model

In [None]:
# Load the pre-trained SpeciesAwareESM2 model
# Build the species-aware wrapper around the pre-trained ESM-2 checkpoint.
# TODO - define species list (these three match the example data below).
species_model = SpeciesAwareESM2(
    model_name="facebook/esm2_t6_8M_UR50D",
    species_list=["human", "mouse", "ecoli"],
)

## Load Data

In [None]:
# Toy example: three (species, sequence) pairs repeated three times.
# TODO - replace with our actual data.
example_records = [
    ("human", "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"),
    ("mouse", "MKVSAIAKQRQISFVKSHFSRQLRERLGLIEVQ"),
    ("ecoli", "MKTVYIAKQRQISFVKSHFSRQLEERLGLIEVQ"),
]
data = pd.DataFrame(example_records * 3, columns=["species", "sequence"])

## Generate Embeddings for Sequences with Species Information (Mean Pooling)
This per-sequence loop is slower and not recommended for a large dataset; prefer the batched version below.

In [None]:
# Generate one embedding per sequence by calling the model row by row (slow path).
# Gradients are disabled: this is inference only, tracking them wastes memory, and
# `.numpy()` raises on a tensor that still requires grad.
per_sequence_embeddings = []

with torch.no_grad():
    for species, sequence in zip(data["species"], data["sequence"]):
        emb = species_model.embed(species, sequence)
        # Mean pooling over dim=1; assumes emb is (1, seq_len, hidden) — TODO confirm.
        # A different pooling strategy could be substituted here.
        emb_mean = emb.mean(dim=1).squeeze().cpu().numpy()
        per_sequence_embeddings.append(emb_mean)

embeddings = np.vstack(per_sequence_embeddings)
print("Embeddings array shape:", embeddings.shape)

## Generate Embeddings for *a Batch of Sequences* with Species Information (Mean Pooling)

In [None]:
# Generate embeddings for a batch of sequences (faster); here batch is the entire dataset
species_batch = data["species"].tolist()
sequence_batch = data["sequence"].tolist()

with torch.no_grad():
    outputs = species_model.forward(species_batch, sequence_batch)

batch_embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
print(batch_embeddings.shape)

In [None]:
# Disabled debug helper: uncomment to print the names of the first 60 submodules of
# species_model.model — presumably to pick LoRA `target_modules`; TODO confirm or delete this cell.
# for name, module in list(species_model.model.named_modules())[:60]:
#     print(name)

## Split Data into Train, Validation, and Test Sets (optional — we normally fine-tune on all data)

In [ ]:
from sklearn.model_selection import train_test_split  #TODO - do we want to have splits? generally we just want to fine-tune on all data

# Two-stage split, both with a fixed random_state for reproducibility:
#   1) hold out 30% of the data; train keeps the remaining 70%;
#   2) halve the held-out part into validation and test (50% of temp each = 15% total).
species_train, species_temp, seq_train, seq_temp = train_test_split(
    species_batch,
    sequence_batch,
    random_state=42,
    test_size=0.3,
)
species_val, species_test, seq_val, seq_test = train_test_split(
    species_temp,
    seq_temp,
    random_state=42,
    test_size=0.5,
)

print(f"Train: {len(species_train)}, Val: {len(species_val)}, Test: {len(species_test)}")

## Fine-tune the Model using LoRA (on all data)

In [0]:
from src.fine_tuning.fine_tuning import LoraFinetuner
from transformers import AutoModel, AutoTokenizer
import torch

# Fine-tune on the full dataset (not the optional train split), so this cell does not
# depend on the split cell above.
species_batch = data["species"].tolist()
sequence_batch = data["sequence"].tolist()

# One checkpoint for both the frozen reference and the species-aware model, so that
# the two embedding spaces are comparable.
base_model_name = "facebook/esm2_t6_8M_UR50D"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Frozen reference: embeddings from the unmodified ESM-2 model.
old_model = AutoModel.from_pretrained(base_model_name).to(device)
old_tokenizer = AutoTokenizer.from_pretrained(base_model_name)
tokens_old = old_tokenizer(sequence_batch, return_tensors="pt", padding=True, truncation=True, max_length=1024)
tokens_old = {k: v.to(device) for k, v in tokens_old.items()}
with torch.no_grad():
    hidden_old = old_model(**tokens_old).last_hidden_state
    # Masked mean pooling: positions with attention_mask == 0 are padding and must not be
    # averaged in, otherwise shorter sequences in a padded batch get diluted embeddings.
    mask = tokens_old["attention_mask"].unsqueeze(-1).to(hidden_old.dtype)
    old_embeddings = (hidden_old * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

# Species-aware model built on the same base checkpoint as the frozen reference.
species_model = SpeciesAwareESM2(model_name=base_model_name, species_list=["human", "mouse", "ecoli"])

# LoRA finetuner
finetuner = LoraFinetuner(base_model=species_model, r=8, alpha=16, dropout=0.05, target_modules=None, lr=1e-4, batch_size=4)  #TODO - optionally optimize parameters for LoRA

# Train to align species embeddings to frozen embeddings.
# TODO - set frozen embeddings; old_embeddings above has one row per entry of sequence_batch, in the same order.
finetuner.train(species_batch, sequence_batch, frozen_embeddings=None, epochs=5)

# Extract tuned embeddings
tuned_embeddings = finetuner.embed(species_batch, sequence_batch)
print("Tuned embeddings shape:", tuned_embeddings.shape)