In [None]:
"""
NTxPred2 - Training Script with Model Saving and Re-training

This script fine-tunes the ESM2-t30_150M model on peptide/protein sequence data 
to predict neurotoxicity. It performs the following:

1. Loads and preprocesses labeled sequence data.
2. Trains the ESM2 model on the training data.
3. Saves the trained model and tokenizer.
4. Reloads the saved model.
5. Continues training the reloaded model on the same dataset, using a newly created optimizer.
6. Saves the final retrained model.

"""

import random
import numpy as np
import pandas as pd
import torch
import os
from transformers import AutoTokenizer, EsmForSequenceClassification
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import AdamW

# -------------------- Configuration --------------------

# Reproducibility
SEED = 42  # seed handed to every random number generator

# Optimisation hyper-parameters
LEARNING_RATE = 1e-5  # step size given to the AdamW optimizer
BATCH_SIZE = 8  # number of sequences per optimisation step
EPOCHS = 12  # passes over the dataset in each training phase

# Tokenisation
MAX_SEQ_LENGTH = 50  # tokenised sequences are truncated to this length

# Input / output locations
TRAIN_FILE = "example_peptide_sequence.csv"  # CSV holding the labelled sequences
SAVE_DIR_INITIAL = "./saved_model_t30"  # where the first-phase model is written
SAVE_DIR_FINAL = "./saved_model_t30_final"  # where the re-trained model is written

# -------------------- Reproducibility Setup --------------------

def set_seed(seed=SEED):
    """
    Seed Python's `random`, NumPy and PyTorch with the same value.

    When CUDA is available, every CUDA device is seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Nothing more to do on CPU-only machines.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


# Seed everything once, before any data shuffling or weight initialisation.
set_seed()

# -------------------- Load and Prepare Data --------------------

# Load training data from CSV
df = pd.read_csv(TRAIN_FILE)

# Select the two required columns by NAME rather than by position, so the
# script does not depend on where 'Sequence' and 'Label' sit in the file.
required_columns = ["Sequence", "Label"]
missing_columns = [name for name in required_columns if name not in df.columns]
if missing_columns:
    raise ValueError(
        f"{TRAIN_FILE} is missing required column(s): {missing_columns}"
    )
df = df[required_columns].reset_index(drop=True)

# Cast 'Label' to integer and verify it only holds the two class ids (0 or 1);
# the classifier is built with num_labels=2, so any other value is invalid.
df = df.astype({"Label": int})
invalid_labels = df.loc[~df["Label"].isin([0, 1]), "Label"].unique().tolist()
if invalid_labels:
    raise ValueError(
        f"'Label' must contain only 0 or 1, found: {invalid_labels}"
    )

# Extract sequences (list of str) and labels (1-D long tensor)
sequences = df["Sequence"].tolist()
labels = torch.tensor(df["Label"].tolist(), dtype=torch.long)

# -------------------- Tokenization --------------------

# Tokenizer matching the pretrained ESM2 checkpoint
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t30_150M_UR50D")

# Encode every sequence in one call; the result holds PyTorch tensors that are
# padded to a common length and truncated at MAX_SEQ_LENGTH tokens.
tokenized_inputs = tokenizer(
    sequences, padding=True, truncation=True, max_length=MAX_SEQ_LENGTH, return_tensors="pt"
)

# Bundle (input_ids, attention_mask, label) triples into a dataset
dataset = TensorDataset(
    *(tokenized_inputs[key] for key in ("input_ids", "attention_mask")),
    labels,
)

# Shuffled mini-batches for training
dataloader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)

# -------------------- Initialize Model --------------------

# Train on the GPU when one is available. set_seed() already seeds CUDA, but
# without an explicit device the model and batches would stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pretrained model and configure for binary classification
model = EsmForSequenceClassification.from_pretrained(
    "facebook/esm2_t30_150M_UR50D", num_labels=2
)
# Move the model BEFORE building the optimizer so the optimizer is created
# from the parameters that live on the training device.
model.to(device)

# Define optimizer and loss function
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
loss_fn = torch.nn.CrossEntropyLoss()

# -------------------- First Training Phase --------------------

print("\n🔁 Starting initial training...")

model.train()
for epoch in range(EPOCHS):
    # Sum of the per-batch losses over this epoch (not an average).
    total_loss = 0

    for batch in dataloader:
        # Each batch is (input_ids, attention_mask, labels); move all three
        # to the same device as the model.
        input_ids, attention_mask, label_batch = (t.to(device) for t in batch)

        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        loss = loss_fn(logits, label_batch)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {total_loss:.4f}")

# -------------------- Save Initial Model --------------------

# Write the model weights/config and then the tokenizer files into the same
# directory so the checkpoint can be reloaded from that single path.
os.makedirs(SAVE_DIR_INITIAL, exist_ok=True)
for artifact in (model, tokenizer):
    artifact.save_pretrained(SAVE_DIR_INITIAL)
print(f"\n💾 Initial model saved to: {SAVE_DIR_INITIAL}")

# -------------------- Load and Re-initialize Model --------------------

print("\n🔁 Re-loading saved model for re-training...")

# Train on the GPU when one is available; the reloaded model starts on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the saved model and tokenizer for continued training
model = EsmForSequenceClassification.from_pretrained(SAVE_DIR_INITIAL)
tokenizer = AutoTokenizer.from_pretrained(SAVE_DIR_INITIAL)
# Move the model BEFORE building the optimizer so the optimizer is created
# from the parameters that live on the training device.
model.to(device)

# Re-define optimizer for the reloaded model (optimizer state is not restored,
# so this phase starts with fresh AdamW statistics)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

# -------------------- Second Training Phase (Re-training) --------------------

print("\n🔁 Starting re-training with the same dataset...")

model.train()
for epoch in range(EPOCHS):
    # Sum of the per-batch losses over this epoch (not an average).
    total_loss = 0

    for batch in dataloader:
        # Each batch is (input_ids, attention_mask, labels); move all three
        # to the same device as the model.
        input_ids, attention_mask, label_batch = (t.to(device) for t in batch)

        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        loss = loss_fn(logits, label_batch)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Re-train Epoch {epoch+1}/{EPOCHS} - Loss: {total_loss:.4f}")

# -------------------- Save Final Retrained Model --------------------

# Write the re-trained model and then its tokenizer into the final
# checkpoint directory.
os.makedirs(SAVE_DIR_FINAL, exist_ok=True)
for artifact in (model, tokenizer):
    artifact.save_pretrained(SAVE_DIR_FINAL)
print(f"\n✅ Final retrained model saved to: {SAVE_DIR_FINAL}")


Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t30_150M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Starting initial training...

