In [2]:
!ls

Dockerfile	     ZooTransform  docker-examples	     models.md
Makefile	     config.mk	   entrypoint.sh	     pyproject.toml
ProteinGym_DMS_data  configure	   jupyter_server_config.py  tutorials
README.md	     data	   license.txt		     uv.lock


In [3]:
import os

data_root = os.path.join(os.getcwd(), "ZooTransform") # TODO - Adjust this path as necessary
os.chdir(data_root)

from src.fine_tuning.fine_tuning import LoraFinetuner
from transformers import AutoModel, AutoTokenizer
from src.model.species_model import SpeciesAwareESM2
species_model = SpeciesAwareESM2(species_list=["human","mouse","ecoli"]) #TODO - Replace with actual species

✓ All libraries imported successfully!
Using device: cuda
  GPU: NVIDIA A100-SXM4-40GB
  Memory: 42.29 GB
Loading model: facebook/esm2_t6_8M_UR50D


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Adding species tokens: ['<sp_human>', '<sp_mouse>', '<sp_ecoli>']
Added 3 new special tokens
Resized model embeddings to 36 tokens
✓ Model and tokenizer ready!
  Hidden size: 320
  Number of layers: 6


## Load data

In [4]:
import pandas as pd

# Example data for embedding - need to replace with our actual data
data = pd.DataFrame({
    "species": ["human", "mouse", "ecoli", "human", "mouse", "ecoli","human", "mouse", "ecoli","human", "mouse", "ecoli"],
    "sequence": [
        "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
        "MKVSAIAKQRQISFVKSHFSRQLRERLGLIEVQ",
        "MKTVYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
        "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
        "MKVSAIAKQRQISFVKSHFSRQLRERLGLIEVQ",
        "MKTVYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
        "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
        "MKVSAIAKQRQISFVKSHFSRQLRERLGLIEVQ",
        "MKTVYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
        "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
        "MKVSAIAKQRQISFVKSHFSRQLRERLGLIEVQ",
        "MKTVYIAKQRQISFVKSHFSRQLEERLGLIEVQ"
    ]
})

species_batch = data["species"]
sequence_batch = data["sequence"]

# Define objective function for LoRA hyperparameter tuning (without masked embeddings)

In [8]:
import numpy as np
import optuna
import torch
from src.fine_tuning.fine_tuning import LoraFinetuner

def objective(trial):
    # Suggest hyperparameters
    r = trial.suggest_int("r", 4, 32)
    alpha = int(trial.suggest_float("alpha", 8, 64))
    lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
    batch_size = trial.suggest_categorical("batch_size", [2, 4]) # TODO - Reduced batch sizes for testing

    try:
        # Create model with these LoRA params
        finetuner = LoraFinetuner(
            base_model=species_model,
            batch_size=batch_size,
            r=r,
            alpha=alpha,
            lr=lr
        )
    
    
        # Train on all data (no frozen embeddings)
        final_loss = finetuner.train(
            species_batch.tolist(),
            sequence_batch.tolist(),
            frozen_embeddings=None,
            epochs=5
        )
        
        if final_loss is None or np.isnan(final_loss):
                final_loss = float("inf")
            
    except Exception as e:
        # print(f"⚠️ Trial {trial.number} failed: {e}")
        # final_loss = float("inf")
        # raise optuna.exceptions.TrialPruned()
        print(f"Trial {trial.number} failed: {e}")
        import traceback; traceback.print_exc()
        final_loss = float("inf")
    
    return final_loss

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=5)

print("Best hyperparameters:", study.best_params)


[I 2025-11-05 13:03:55,675] A new study created in memory with name: no-name-3f0bb664-894d-4e5e-9716-42492ed0e0e9
  lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
Epoch 1/5: 100%|██████████| 6/6 [00:00<00:00, 14.98it/s, loss=0.1089]


Epoch 1 — avg loss: 0.1109


Epoch 2/5: 100%|██████████| 6/6 [00:00<00:00, 47.77it/s, loss=0.1059]


Epoch 2 — avg loss: 0.1072


Epoch 3/5: 100%|██████████| 6/6 [00:00<00:00, 48.47it/s, loss=0.1018]


Epoch 3 — avg loss: 0.1036


Epoch 4/5: 100%|██████████| 6/6 [00:00<00:00, 48.45it/s, loss=0.0984]


Epoch 4 — avg loss: 0.1002


Epoch 5/5: 100%|██████████| 6/6 [00:00<00:00, 47.55it/s, loss=0.0958]
[I 2025-11-05 13:03:57,093] Trial 0 finished with value: 0.10374657834569614 and parameters: {'r': 25, 'alpha': 62.26954583607888, 'lr': 4.636950770235651e-05, 'batch_size': 2}. Best is trial 0 with value: 0.10374657834569614.


Epoch 5 — avg loss: 0.0968
Returning final average loss across epochs: 0.1037


Epoch 1/5: 100%|██████████| 6/6 [00:00<00:00, 48.45it/s, loss=0.1109]


Epoch 1 — avg loss: 0.1119


Epoch 2/5: 100%|██████████| 6/6 [00:00<00:00, 47.55it/s, loss=0.1096]


Epoch 2 — avg loss: 0.1106


Epoch 3/5: 100%|██████████| 6/6 [00:00<00:00, 47.57it/s, loss=0.1089]


Epoch 3 — avg loss: 0.1093


Epoch 4/5: 100%|██████████| 6/6 [00:00<00:00, 48.58it/s, loss=0.1069]


Epoch 4 — avg loss: 0.1080


Epoch 5/5: 100%|██████████| 6/6 [00:00<00:00, 48.76it/s, loss=0.1070]
[I 2025-11-05 13:03:57,739] Trial 1 finished with value: 0.10928999433914821 and parameters: {'r': 29, 'alpha': 62.20296079792073, 'lr': 1.681000246194575e-05, 'batch_size': 2}. Best is trial 0 with value: 0.10374657834569614.


Epoch 5 — avg loss: 0.1066
Returning final average loss across epochs: 0.1093


Epoch 1/5: 100%|██████████| 3/3 [00:00<00:00, 27.45it/s, loss=0.1071]


Epoch 1 — avg loss: 0.1101


Epoch 2/5: 100%|██████████| 3/3 [00:00<00:00, 27.72it/s, loss=0.0999]


Epoch 2 — avg loss: 0.1025


Epoch 3/5: 100%|██████████| 3/3 [00:00<00:00, 27.70it/s, loss=0.0919]


Epoch 3 — avg loss: 0.0950


Epoch 4/5: 100%|██████████| 3/3 [00:00<00:00, 27.65it/s, loss=0.0800]


Epoch 4 — avg loss: 0.0845


Epoch 5/5: 100%|██████████| 3/3 [00:00<00:00, 27.70it/s, loss=0.0625]
[I 2025-11-05 13:03:58,302] Trial 2 finished with value: 0.09210727711518604 and parameters: {'r': 12, 'alpha': 18.3052364475463, 'lr': 0.0006116314470293658, 'batch_size': 4}. Best is trial 2 with value: 0.09210727711518604.


Epoch 5 — avg loss: 0.0685
Returning final average loss across epochs: 0.0921


Epoch 1/5: 100%|██████████| 3/3 [00:00<00:00, 27.75it/s, loss=0.1060]


Epoch 1 — avg loss: 0.1093


Epoch 2/5: 100%|██████████| 3/3 [00:00<00:00, 27.71it/s, loss=0.0976]


Epoch 2 — avg loss: 0.1004


Epoch 3/5: 100%|██████████| 3/3 [00:00<00:00, 27.63it/s, loss=0.0876]


Epoch 3 — avg loss: 0.0915


Epoch 4/5: 100%|██████████| 3/3 [00:00<00:00, 27.79it/s, loss=0.0711]


Epoch 4 — avg loss: 0.0773


Epoch 5/5: 100%|██████████| 3/3 [00:00<00:00, 27.72it/s, loss=0.0563]
[I 2025-11-05 13:03:58,863] Trial 3 finished with value: 0.08777781104048094 and parameters: {'r': 16, 'alpha': 44.67083647483622, 'lr': 0.000375660349465526, 'batch_size': 4}. Best is trial 3 with value: 0.08777781104048094.


Epoch 5 — avg loss: 0.0604
Returning final average loss across epochs: 0.0878


Epoch 1/5: 100%|██████████| 3/3 [00:00<00:00, 27.49it/s, loss=0.1127]


Epoch 1 — avg loss: 0.1124


Epoch 2/5: 100%|██████████| 3/3 [00:00<00:00, 27.55it/s, loss=0.1119]


Epoch 2 — avg loss: 0.1124


Epoch 3/5: 100%|██████████| 3/3 [00:00<00:00, 27.51it/s, loss=0.1125]


Epoch 3 — avg loss: 0.1123


Epoch 4/5: 100%|██████████| 3/3 [00:00<00:00, 27.45it/s, loss=0.1121]


Epoch 4 — avg loss: 0.1122


Epoch 5/5: 100%|██████████| 3/3 [00:00<00:00, 27.39it/s, loss=0.1120]
[I 2025-11-05 13:03:59,429] Trial 4 finished with value: 0.11230654517809549 and parameters: {'r': 15, 'alpha': 9.782495161117524, 'lr': 1.071356363059682e-05, 'batch_size': 4}. Best is trial 3 with value: 0.08777781104048094.


Epoch 5 — avg loss: 0.1122
Returning final average loss across epochs: 0.1123
Best hyperparameters: {'r': 16, 'alpha': 44.67083647483622, 'lr': 0.000375660349465526, 'batch_size': 4}


# Define objective function for LoRA hyperparameter tuning (with masked embeddings)

In [5]:
import numpy as np
import optuna
import torch
from src.fine_tuning.fine_tuning import LoraFinetuner, LoraFinetunerMLM

def objective(trial, species_model, species_batch, sequence_batch, mode="mlm"):
    """
    Optuna objective for tuning LoRA hyperparameters.
    
    Args:
        trial: Optuna trial
        species_model: SpeciesAwareESM2 object
        species_batch: list of species strings
        sequence_batch: list of sequences
        mode: "mlm" for Masked LM or "embedding" for L2 alignment
    """
    # --- Suggest hyperparameters ---
    r = trial.suggest_int("r", 4, 32)
    alpha = int(trial.suggest_float("alpha", 8, 64))
    lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
    batch_size = trial.suggest_categorical("batch_size", [2, 4])

    try:
        # --- Initialize correct LoRA finetuner ---
        if mode == "mlm":
            finetuner = LoraFinetunerMLM(
                base_model=species_model,
                r=r, alpha=alpha, lr=lr, batch_size=batch_size
            )
        elif mode == "embedding":
            finetuner = LoraFinetuner(
                base_model=species_model,
                r=r, alpha=alpha, lr=lr, batch_size=batch_size
            )
        else:
            raise ValueError(f"Unknown mode: {mode}")

        # --- Train the model ---
        if mode == "embedding":
            final_loss = finetuner.train(
                species_batch=species_batch,
                sequence_batch=sequence_batch,
                frozen_embeddings=None,
                epochs=3
            )
        else:  # MLM
            finetuner.train(
                species_batch=species_batch,
                sequence_batch=sequence_batch,
                epochs=3
            )
            # For MLM we can approximate final loss by embedding mean
            final_loss = 0.0  # Optuna just needs a scalar; MLM returns per-step loss

        # Fallback in case of NaN
        if final_loss is None or np.isnan(final_loss):
            final_loss = float("inf")

    except Exception as e:
        print(f"Trial {trial.number} failed: {e}")
        import traceback; traceback.print_exc()
        final_loss = float("inf")

    return final_loss


# --- Create and run study ---
study = optuna.create_study(direction="minimize")
study.optimize(
    lambda trial: objective(trial, species_model, species_batch.tolist(), sequence_batch.tolist(), mode="mlm"),
    n_trials=5
)

print("Best hyperparameters:", study.best_params)


[I 2025-11-05 15:47:14,534] A new study created in memory with name: no-name-60dcf7f5-a119-4587-b259-d454c84e3ebf
  lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
Epoch 1/3: 100%|██████████| 3/3 [00:00<00:00,  6.60it/s, loss=5.7720]


Epoch 1 — avg loss: 5.7906


Epoch 2/3: 100%|██████████| 3/3 [00:00<00:00, 15.33it/s, loss=5.6335]


Epoch 2 — avg loss: 5.6694


Epoch 3/3: 100%|██████████| 3/3 [00:00<00:00, 15.83it/s, loss=5.5365]
[I 2025-11-05 15:47:15,783] Trial 0 finished with value: 0.0 and parameters: {'r': 20, 'alpha': 52.26044371272665, 'lr': 0.0006513759522130387, 'batch_size': 4}. Best is trial 0 with value: 0.0.


Epoch 3 — avg loss: 5.5609


Epoch 1/3: 100%|██████████| 3/3 [00:00<00:00, 15.76it/s, loss=5.8136]


Epoch 1 — avg loss: 5.8052


Epoch 2/3: 100%|██████████| 3/3 [00:00<00:00, 15.84it/s, loss=5.6582]


Epoch 2 — avg loss: 5.7196


Epoch 3/3: 100%|██████████| 3/3 [00:00<00:00, 15.79it/s, loss=5.6656]
[I 2025-11-05 15:47:16,369] Trial 1 finished with value: 0.0 and parameters: {'r': 24, 'alpha': 53.419154995654615, 'lr': 0.000157752270494732, 'batch_size': 4}. Best is trial 0 with value: 0.0.


Epoch 3 — avg loss: 5.7106


Epoch 1/3: 100%|██████████| 6/6 [00:00<00:00, 28.02it/s, loss=5.7900]


Epoch 1 — avg loss: 5.8230


Epoch 2/3: 100%|██████████| 6/6 [00:00<00:00, 28.58it/s, loss=5.8192]


Epoch 2 — avg loss: 5.8128


Epoch 3/3: 100%|██████████| 6/6 [00:00<00:00, 28.56it/s, loss=5.7928]
[I 2025-11-05 15:47:17,021] Trial 2 finished with value: 0.0 and parameters: {'r': 31, 'alpha': 24.592364444011142, 'lr': 4.938932297963774e-05, 'batch_size': 2}. Best is trial 0 with value: 0.0.


Epoch 3 — avg loss: 5.7701


Epoch 1/3: 100%|██████████| 6/6 [00:00<00:00, 28.16it/s, loss=5.8272]


Epoch 1 — avg loss: 5.8119


Epoch 2/3: 100%|██████████| 6/6 [00:00<00:00, 28.49it/s, loss=5.8085]


Epoch 2 — avg loss: 5.7917


Epoch 3/3: 100%|██████████| 6/6 [00:00<00:00, 28.28it/s, loss=5.7659]
[I 2025-11-05 15:47:17,672] Trial 3 finished with value: 0.0 and parameters: {'r': 8, 'alpha': 63.256457314511, 'lr': 3.132726238619787e-05, 'batch_size': 2}. Best is trial 0 with value: 0.0.


Epoch 3 — avg loss: 5.7924


Epoch 1/3: 100%|██████████| 6/6 [00:00<00:00, 28.41it/s, loss=5.7040]


Epoch 1 — avg loss: 5.7777


Epoch 2/3: 100%|██████████| 6/6 [00:00<00:00, 28.87it/s, loss=5.8553]


Epoch 2 — avg loss: 5.7935


Epoch 3/3: 100%|██████████| 6/6 [00:00<00:00, 28.94it/s, loss=5.7904]
[I 2025-11-05 15:47:18,315] Trial 4 finished with value: 0.0 and parameters: {'r': 26, 'alpha': 42.53729589137802, 'lr': 2.930462078712776e-05, 'batch_size': 2}. Best is trial 0 with value: 0.0.


Epoch 3 — avg loss: 5.7956
Best hyperparameters: {'r': 20, 'alpha': 52.26044371272665, 'lr': 0.0006513759522130387, 'batch_size': 4}
