In [1]:
import torch
import torch.nn as nn
import numpy as np
from transformers import EsmModel, EsmTokenizer, TrainingArguments, Trainer
from datasets import load_dataset, Dataset
from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef, precision_score, recall_score

import pandas as pd

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


# Verify the selected GPU
print("Available GPU:", torch.cuda.device_count())
print("Using GPU:", torch.cuda.current_device())
print("GPU Name:", torch.cuda.get_device_name(torch.cuda.current_device()))


LABEL_COLS = [
    'TRANSCRIPTIONAL', 'CHROMOSOME', 'NUCLEAR_PORE_COMPLEX',
    'NUCLEAR_SPECKLE', 'P-BODY', 'PML-BDOY', 'POST_SYNAPTIC_DENSITY',
    'STRESS_GRANULE', 'NUCLEOLUS', 'CAJAL_BODY', 'RNA_GRANULE', 'CELL_JUNCTION'
]

DATASET_FILENAME = "/home/zengs/data/Code/reproduce/protgps/data/dataset_from_json.csv"



Available GPU: 1
Using GPU: 0
GPU Name: NVIDIA RTX A6000


In [69]:
def get_dataset(tokenizer, dataset_filename):
  # Load CSV dataset
  data_df = pd.read_csv(dataset_filename)

  # Convert label columns to lists of integers (multi-hot encoding)
  data_df["labels"] = data_df[LABEL_COLS].values.tolist()

  # Tokenization function
  def tokenize_function(examples):
    return tokenizer(
        examples["sequence"], 
        padding="max_length", 
        truncation=True, 
        max_length=1800
    )

  # Convert to Hugging Face dataset
  dataset = Dataset.from_pandas(data_df)
  dataset = dataset.map(tokenize_function, batched=True)

  # Ensure labels are in correct format
  def format_labels(example):
    example["labels"] = torch.tensor(example["labels"], dtype=torch.float)
    return example

  dataset = dataset.map(format_labels)

  # Split dataset
  train_dataset = dataset.filter(lambda x: x["split"] == "train")
  val_dataset = dataset.filter(lambda x: x["split"] == "dev")
  test_dataset = dataset.filter(lambda x: x["split"] == "test")

  return train_dataset, val_dataset, test_dataset


train_dataset, val_dataset, test_dataset = get_dataset(tokenizer, DATASET_FILENAME)

Map:   0%|          | 0/5480 [00:00<?, ? examples/s]

Map:   0%|          | 0/5480 [00:00<?, ? examples/s]

Filter:   0%|          | 0/5480 [00:00<?, ? examples/s]

Filter:   0%|          | 0/5480 [00:00<?, ? examples/s]

Filter:   0%|          | 0/5480 [00:00<?, ? examples/s]

In [70]:
for one in train_dataset:
    print(one["labels"])
    break

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]


In [71]:
def compute_metrics(eval_pred):
  logits, labels = eval_pred
  preds = torch.sigmoid(torch.tensor(logits)).numpy()  # Convert logits to probabilities
  labels = np.array(labels)

  # Convert probs to binary using 0.5 threshold
  preds_binary = (preds > 0.5).astype(int)

  # Compute metrics
  auc = roc_auc_score(labels, preds, average="macro")
  f1 = f1_score(labels, preds_binary, average="macro")
  mcc = matthews_corrcoef(labels.flatten(), preds_binary.flatten())
  precision = precision_score(labels, preds_binary, average="macro", zero_division=0)
  recall = recall_score(labels, preds_binary, average="macro", zero_division=0)

  return {
      "AUC-ROC": auc,
      "F1": f1,
      "MCC": mcc,
      "Precision": precision,
      "Recall": recall,
  }

In [88]:
# Define model with an MLP classifier
class ESM2MLP(nn.Module):
  def __init__(self, model_name, num_classes=12):
    super().__init__()
    self.esm = EsmModel.from_pretrained(model_name)
    hidden_dim = self.esm.config.hidden_size  # ESM2-8M has 320 hidden dim
    self.classifier = nn.Sequential(
        nn.Linear(hidden_dim, 512),
        nn.ReLU(),
        nn.Linear(512, 512),
        nn.ReLU(),
        nn.Linear(512, num_classes)  # Output 12 logits
    )
    
    self.criterion = nn.BCEWithLogitsLoss()

  def forward(self, input_ids, attention_mask=None, labels=None):
    outputs = self.esm(input_ids=input_ids, attention_mask=attention_mask)

    pooled_output = outputs.last_hidden_state.mean(axis=1)
    logits = self.classifier(pooled_output)
    
    if labels is not None:
      return self.compute_loss(logits, labels), logits
    
    return logits
  
  def compute_loss(self, logits, labels):
    # Use CrossEntropyLoss to compute the scalar loss
    loss = self.criterion(logits, labels.float())
    return loss
  
model_name = "facebook/esm2_t6_8M_UR50D"
tokenizer = EsmTokenizer.from_pretrained(model_name)

model = ESM2MLP(model_name)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [89]:
training_args = TrainingArguments(
    output_dir="./test_runs/esm2_mlp_output",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    per_device_train_batch_size=10,  # Adjust based on GPU memory
    per_device_eval_batch_size=10,
    num_train_epochs=30,
    weight_decay=0.0,
    fp16=True,  # Use mixed precision
    logging_dir="./test_runs/esm2_mlp_output/logs",
    logging_steps=10,
    save_total_limit=2,
    report_to="none",
    metric_for_best_model="AUC-ROC",
    load_best_model_at_end=True,
    local_rank=-1
)

# Define trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

# Train model
trainer.train()

  trainer = Trainer(


Epoch,Training Loss,Validation Loss


ValueError: Only one class present in y_true. ROC AUC score is not defined in that case.

In [87]:
import gc

del model 
del trainer 
gc.collect()

# Clear GPU memory cache
torch.cuda.empty_cache()

NameError: name 'model' is not defined