## Load packages and Set Working Directory

In [1]:
import numpy as np
import pandas as pd
import os
import sys
import torch

print(os.getcwd())
#parent_dir = os.path.dirname(os.getcwd())
#print(parent_dir)
data_root = os.path.join(os.getcwd(), 'ZooTransform')
# os.chdir(data_root)

sys.path.append(data_root)

from src.zootransform.model.species_model import SpeciesAwareESM2
from src.zootransform.dataset.load_uniprot import load_uniprot

/home/hslab/Olive/Kode/ZooTransform


  from .autonotebook import tqdm as notebook_tqdm


✓ All libraries imported successfully!


In [2]:
!ls

README.md	     load_the_model.ipynb  requirements.txt  uniprot_data
Training_data.ipynb  optuna.ipynb	   setup.cfg	     validations.ipynb
UQ.ipynb	     pyproject.toml	   src


## Load Data

In [3]:
# Example data for embedding - need to replace with our actual data
data = load_uniprot()
species_names = sorted(set(data['species'].unique().tolist()))
species_names

['Arabidopsis thaliana',
 'Bos taurus',
 'Escherichia coli',
 'Homo sapiens',
 'Mus musculus',
 'Oryza sativa',
 'Rattus norvegicus',
 'Rhodotorula toruloides',
 'Saccharolobus solfataricus',
 'Saccharomyces cerevisiae',
 'Schizosaccharomyces pombe',
 'Staphylococcus aureus']

## Load and Use SpeciesAwareESM2 Model

In [4]:
# Load the pre-trained SpeciesAwareESM2 model
species_model = SpeciesAwareESM2(model_name="facebook/esm2_t6_8M_UR50D", species_list=species_names)

Using device: cuda
  GPU: NVIDIA GeForce RTX 4090
  Memory: 25.39 GB
Loading model: facebook/esm2_t6_8M_UR50D


The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Adding species tokens: ['<sp_Arabidopsis thaliana>', '<sp_Bos taurus>', '<sp_Escherichia coli>', '<sp_Homo sapiens>', '<sp_Mus musculus>', '<sp_Oryza sativa>', '<sp_Rattus norvegicus>', '<sp_Rhodotorula toruloides>', '<sp_Saccharolobus solfataricus>', '<sp_Saccharomyces cerevisiae>', '<sp_Schizosaccharomyces pombe>', '<sp_Staphylococcus aureus>']
Added 12 new special tokens
Resized model embeddings to 45 tokens
✓ Model and tokenizer ready!
  Hidden size: 320
  Number of layers: 6


## Generate Embeddings for Sequences with Species Information, with mean pooling
slower , not recommended for a large dataset

In [None]:
# Load or generate embeddings
embeddings_file = "uniprot_embeddings.npy"
is_load_embeddings = True  # Set to True to load existing embeddings if available

if os.path.exists(embeddings_file) and is_load_embeddings:
    embeddings = np.load(embeddings_file)
    print(f"Loaded embeddings from {embeddings_file}, shape: {embeddings.shape}")
else:
    # Generate embeddings for each sequence manually
    embeddings = []

    for _, row in data.iterrows():
        emb = species_model.embed(row['species'], row['sequence'])
        emb_mean = emb.mean(dim=1).squeeze().cpu().numpy() # Mean pooling over sequence length, can decide to use different pooling
        embeddings.append(emb_mean)

    embeddings = np.vstack(embeddings)
    print("Embeddings array shape:", embeddings.shape)
    
    # Save embeddings if they haven't already
    embeddings_file = "uniprot_embeddings.npy"
    if not os.path.exists(embeddings_file):
        np.save(embeddings_file, embeddings)
        print(f"Embeddings saved to {embeddings_file}")

Embeddings array shape: (120095, 320)


## Generate Embeddings for *a Batch of Sequences* with Species Information, with mean pooling

In [8]:
# Generate embeddings for a batch of sequences (faster); here batch is the entire dataset
species_batch = data["species"].tolist()
sequence_batch = data["sequence"].tolist()

is_run_batch = False
if is_run_batch:
    with torch.no_grad():
        outputs = species_model.forward(species_batch, sequence_batch)

    batch_embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
    print(batch_embeddings.shape)
else:
    batch_embeddings = embeddings

In [9]:
# for name, module in list(species_model.model.named_modules())[:60]:
#     print(name)

## Split Data into Train, Validation, and Test Sets (not used in our case)

In [10]:
from sklearn.model_selection import train_test_split #TODO - do we want to have splits? generally we just want to fine-tune on all data

# First, split train vs temp (validation + test)
species_train, species_temp, seq_train, seq_temp = train_test_split(
    species_batch, sequence_batch, test_size=0.3, random_state=42
)

# Then, split temp into validation and test (50% of temp each = 15% total)
species_val, species_test, seq_val, seq_test = train_test_split(
    species_temp, seq_temp, test_size=0.5, random_state=42
)

print(f"Train: {len(species_train)}, Val: {len(species_val)}, Test: {len(species_test)}")

Train: 84066, Val: 18014, Test: 18015


## Fine-tune the Model using LoRA (on all data)

In [None]:
from src.zootransform.fine_tuning.fine_tuning import LoraFinetuner, LoraFinetunerMLM
from transformers import AutoModel, AutoTokenizer
import torch

species_batch = data["species"].tolist()
sequence_batch = data["sequence"].tolist()

device = "cuda" if torch.cuda.is_available() else "cpu"
old_model = AutoModel.from_pretrained("facebook/esm2_t6_8M_UR50D").to(device)
old_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
tokens_old = old_tokenizer(sequence_batch, return_tensors="pt", padding=True, truncation=True, max_length=1024)

# Reshape
batch_size = 128
max_len = len(data) - np.mod(len(data), batch_size)
species_batch = data["species"].to_numpy()[:max_len].reshape(-1, batch_size)
sequence_batch = data["sequence"].to_numpy()[:max_len].reshape(-1, batch_size)
# tokens_old = tokens_old[:max_len].reshape(-1, batch_size)

tokens_old = {k: v[:max_len].reshape(-1, batch_size).to(device) for k,v in tokens_old.items()}
with torch.no_grad():
    old_embeddings = old_model(**tokens_old).last_hidden_state.mean(dim=1)

# Species-aware model
species_model = SpeciesAwareESM2(species_list=species_names)

# LoRA finetuner
finetuner = LoraFinetuner(
    base_model=species_model, 
    r=8, 
    alpha=16, 
    dropout=0.05, 
    target_modules=None, 
    lr=1e-4, 
    batch_size=4)  #TODO - optionally optimize parameters for LoRA 

# Train to align species embeddings to frozen embeddings
finetuner.train(species_train, seq_train, frozen_embeddings=None, epochs=5) #TODO - set frozen embeddings

# Extract tuned embeddings
tuned_embeddings = finetuner.embed(species_batch, sequence_batch)
print("Tuned embeddings shape:", tuned_embeddings.shape)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


AttributeError: 'dict' object has no attribute 'reshape'

In [None]:

tokens_old = {k: v[:max_len].reshape(-1, batch_size).to(device) for k,v in tokens_old.items()}
# with torch.no_grad():
#     old_embeddings = old_model(**tokens_old).last_hidden_state.mean(dim=1)

# Species-aware model
species_model = SpeciesAwareESM2(species_list=species_names)

# LoRA finetuner
finetuner = LoraFinetuner(
    base_model=species_model, 
    r=8, 
    alpha=16, 
    dropout=0.05, 
    target_modules=None, 
    lr=1e-4, 
    batch_size=4)  #TODO - optionally optimize parameters for LoRA 

# Train to align species embeddings to frozen embeddings
finetuner.train(species_train, seq_train, frozen_embeddings=None, epochs=5) #TODO - set frozen embeddings

# Extract tuned embeddings
tuned_embeddings = finetuner.embed(species_batch, sequence_batch)
print("Tuned embeddings shape:", tuned_embeddings.shape)

Using device: cuda
  GPU: NVIDIA GeForce RTX 4090
  Memory: 25.39 GB
Loading model: facebook/esm2_t6_8M_UR50D
Adding species tokens: ['<sp_Arabidopsis thaliana>', '<sp_Bos taurus>', '<sp_Escherichia coli>', '<sp_Homo sapiens>', '<sp_Mus musculus>', '<sp_Oryza sativa>', '<sp_Rattus norvegicus>', '<sp_Rhodotorula toruloides>', '<sp_Saccharolobus solfataricus>', '<sp_Saccharomyces cerevisiae>', '<sp_Schizosaccharomyces pombe>', '<sp_Staphylococcus aureus>']
Added 12 new special tokens
Resized model embeddings to 45 tokens
✓ Model and tokenizer ready!
  Hidden size: 320
  Number of layers: 6


Epoch 1/5:   0%|          | 0/21017 [00:00<?, ?it/s]


AttributeError: 'MaskedLMOutput' object has no attribute 'last_hidden_state'

## Fine-tune the Model using LoRA for Masked Language Modeling (MLM)

In [None]:
from src.zootransform.fine_tuning.fine_tuning import LoraFinetunerMLM  # new MLM version
from src.zootransform.model.species_model import SpeciesAwareESM2
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Species-aware model
species_model = SpeciesAwareESM2(species_list=species_names)
species_model.model.to(device)

species_batch = data["species"].tolist()
sequence_batch = data["sequence"].tolist()

finetuner = LoraFinetunerMLM(
    base_model=species_model,
    r=8,
    alpha=16,
    dropout=0.05,
    target_modules=["attention.self.key", "attention.self.value", "attention.self.query", "embeddings.word_embeddings"],  # LoRA targets
    lr=1e-4,
    batch_size=4,
    mlm_probability=0.15  # fraction of tokens to mask
)

finetuner.train(
    species_batch=species_batch,
    sequence_batch=sequence_batch,
    epochs=5
)

tuned_embeddings = finetuner.embed(species_batch, sequence_batch)
print("Tuned embeddings shape:", tuned_embeddings.shape)

Using device: cuda
  GPU: NVIDIA GeForce RTX 4090
  Memory: 25.39 GB
Loading model: facebook/esm2_t6_8M_UR50D
Adding species tokens: ['<sp_human>', '<sp_mouse>', '<sp_ecoli>']
Added 3 new special tokens
Resized model embeddings to 36 tokens
✓ Model and tokenizer ready!
  Hidden size: 320
  Number of layers: 6


Epoch 1/5:  60%|██████    | 141/235 [01:01<00:41,  2.28it/s, loss=2.5247]


KeyboardInterrupt: 

# Save and Load the Fine-tuned Model

In [None]:
# Directory to save LoRA adapters
save_dir = "lora_finetuned_species_model" #TODO - specify path

# Save only LoRA weights 
finetuner.model.save_pretrained(save_dir)
print(f"LoRA adapters saved to {save_dir}")


LoRA adapters saved to lora_finetuned_species_model
