In [27]:
from src.fine_tuning.fine_tuning import LoraFinetuner
from transformers import AutoModel, AutoTokenizer

In [28]:
import os

data_root = os.path.join(os.getcwd())
os.chdir(data_root)

from src.model.species_model import SpeciesAwareESM2
species_model = SpeciesAwareESM2(species_list=["human","mouse","ecoli"])

Using device: cuda
  GPU: NVIDIA A100-SXM4-40GB
  Memory: 42.29 GB
Loading model: facebook/esm2_t6_8M_UR50D


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Adding species tokens: ['<sp_human>', '<sp_mouse>', '<sp_ecoli>']
Added 3 new special tokens
Resized model embeddings to 36 tokens
✓ Model and tokenizer ready!
  Hidden size: 320
  Number of layers: 6


## Load data

In [29]:
import pandas as pd

# Example data for embedding - need to replace with our actual data
data = pd.DataFrame({
    "species": ["human", "mouse", "ecoli", "human", "mouse", "ecoli","human", "mouse", "ecoli","human", "mouse", "ecoli"],
    "sequence": [
        "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
        "MKVSAIAKQRQISFVKSHFSRQLRERLGLIEVQ",
        "MKTVYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
        "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
        "MKVSAIAKQRQISFVKSHFSRQLRERLGLIEVQ",
        "MKTVYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
        "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
        "MKVSAIAKQRQISFVKSHFSRQLRERLGLIEVQ",
        "MKTVYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
        "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
        "MKVSAIAKQRQISFVKSHFSRQLRERLGLIEVQ",
        "MKTVYIAKQRQISFVKSHFSRQLEERLGLIEVQ"
    ]
})

species_batch = data["species"]
sequence_batch = data["sequence"]

In [31]:
import numpy as np
import optuna
import torch
from src.fine_tuning.fine_tuning import LoraFinetuner

def objective(trial):
    # Suggest hyperparameters
    r = trial.suggest_int("r", 4, 32)
    alpha = int(trial.suggest_float("alpha", 8, 64))
    lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
    batch_size = trial.suggest_categorical("batch_size", [2, 4])

    try:
        # Create model with these LoRA params
        finetuner = LoraFinetuner(
            base_model=species_model,
            batch_size=batch_size,
            r=r,
            alpha=alpha,
            lr=lr
        )
    
    
        # Train on all data (no frozen embeddings)
        final_loss = finetuner.train(
            species_batch.tolist(),
            sequence_batch.tolist(),
            frozen_embeddings=None,
            epochs=5
        )
        
        if final_loss is None or np.isnan(final_loss):
                final_loss = float("inf")
            
    except Exception as e:
        # print(f"⚠️ Trial {trial.number} failed: {e}")
        # final_loss = float("inf")
        # raise optuna.exceptions.TrialPruned()
        print(f"⚠️ Trial {trial.number} failed: {e}")
        import traceback; traceback.print_exc()
        final_loss = float("inf")
    
    return final_loss

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=5)

print("Best hyperparameters:", study.best_params)


[I 2025-11-05 12:53:36,477] A new study created in memory with name: no-name-331d78d7-d394-48de-b69f-f279e657ab97
  lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
Epoch 1/5: 100%|██████████| 3/3 [00:00<00:00, 24.44it/s, loss=11.8978]


Epoch 1 — avg loss: 11.9532


Epoch 2/5: 100%|██████████| 3/3 [00:00<00:00, 26.40it/s, loss=11.6751]


Epoch 2 — avg loss: 11.7815


Epoch 3/5: 100%|██████████| 3/3 [00:00<00:00, 27.00it/s, loss=11.5199]


Epoch 3 — avg loss: 11.6014


Epoch 4/5: 100%|██████████| 3/3 [00:00<00:00, 27.60it/s, loss=11.2962]


Epoch 4 — avg loss: 11.4137


Epoch 5/5: 100%|██████████| 3/3 [00:00<00:00, 27.65it/s, loss=11.1663]
[I 2025-11-05 12:53:37,075] Trial 0 finished with value: inf and parameters: {'r': 20, 'alpha': 50.42483365897086, 'lr': 0.00010266430184193246, 'batch_size': 4}. Best is trial 0 with value: inf.


Epoch 5 — avg loss: 11.2285


Epoch 1/5: 100%|██████████| 6/6 [00:00<00:00, 46.72it/s, loss=8.0634]


Epoch 1 — avg loss: 8.3037


Epoch 2/5: 100%|██████████| 6/6 [00:00<00:00, 47.46it/s, loss=7.4298]


Epoch 2 — avg loss: 7.7414


Epoch 3/5: 100%|██████████| 6/6 [00:00<00:00, 47.49it/s, loss=6.1025]


Epoch 3 — avg loss: 6.7268


Epoch 4/5: 100%|██████████| 6/6 [00:00<00:00, 40.71it/s, loss=5.2253]


Epoch 4 — avg loss: 5.5226


Epoch 5/5: 100%|██████████| 6/6 [00:00<00:00, 41.31it/s, loss=4.6016]
[I 2025-11-05 12:53:37,766] Trial 1 finished with value: inf and parameters: {'r': 20, 'alpha': 16.42934734425875, 'lr': 0.0005249787157289715, 'batch_size': 2}. Best is trial 0 with value: inf.


Epoch 5 — avg loss: 4.8389


Epoch 1/5: 100%|██████████| 3/3 [00:00<00:00, 27.33it/s, loss=11.9671]


Epoch 1 — avg loss: 11.9876


Epoch 2/5: 100%|██████████| 3/3 [00:00<00:00, 27.31it/s, loss=11.8688]


Epoch 2 — avg loss: 11.9199


Epoch 3/5: 100%|██████████| 3/3 [00:00<00:00, 27.42it/s, loss=11.9060]


Epoch 3 — avg loss: 11.8483


Epoch 4/5: 100%|██████████| 3/3 [00:00<00:00, 27.25it/s, loss=11.7825]


Epoch 4 — avg loss: 11.7784


Epoch 5/5: 100%|██████████| 3/3 [00:00<00:00, 27.32it/s, loss=11.6628]
[I 2025-11-05 12:53:38,334] Trial 2 finished with value: inf and parameters: {'r': 30, 'alpha': 51.16397303286668, 'lr': 3.73540505608111e-05, 'batch_size': 4}. Best is trial 0 with value: inf.


Epoch 5 — avg loss: 11.7092


Epoch 1/5: 100%|██████████| 3/3 [00:00<00:00, 27.42it/s, loss=11.8095]


Epoch 1 — avg loss: 11.9258


Epoch 2/5: 100%|██████████| 3/3 [00:00<00:00, 27.54it/s, loss=11.5003]


Epoch 2 — avg loss: 11.6215


Epoch 3/5: 100%|██████████| 3/3 [00:00<00:00, 27.46it/s, loss=11.1846]


Epoch 3 — avg loss: 11.2573


Epoch 4/5: 100%|██████████| 3/3 [00:00<00:00, 27.46it/s, loss=10.6757]


Epoch 4 — avg loss: 10.8283


Epoch 5/5: 100%|██████████| 3/3 [00:00<00:00, 27.45it/s, loss=10.0560]
[I 2025-11-05 12:53:38,898] Trial 3 finished with value: inf and parameters: {'r': 7, 'alpha': 15.836090748756884, 'lr': 0.0005436460315951472, 'batch_size': 4}. Best is trial 0 with value: inf.


Epoch 5 — avg loss: 10.2589


Epoch 1/5: 100%|██████████| 3/3 [00:00<00:00, 27.41it/s, loss=11.0917]


Epoch 1 — avg loss: 11.5274


Epoch 2/5: 100%|██████████| 3/3 [00:00<00:00, 27.33it/s, loss=8.7453]


Epoch 2 — avg loss: 9.6983


Epoch 3/5: 100%|██████████| 3/3 [00:00<00:00, 27.46it/s, loss=7.4265]


Epoch 3 — avg loss: 7.7561


Epoch 4/5: 100%|██████████| 3/3 [00:00<00:00, 27.55it/s, loss=6.2814]


Epoch 4 — avg loss: 6.5490


Epoch 5/5: 100%|██████████| 3/3 [00:00<00:00, 27.31it/s, loss=5.3978]
[I 2025-11-05 12:53:39,463] Trial 4 finished with value: inf and parameters: {'r': 15, 'alpha': 63.38056973603155, 'lr': 0.0008810821030840578, 'batch_size': 4}. Best is trial 0 with value: inf.


Epoch 5 — avg loss: 5.6581
Best hyperparameters: {'r': 20, 'alpha': 50.42483365897086, 'lr': 0.00010266430184193246, 'batch_size': 4}
