In [1]:
!pip install torch transformers esm huggingface-hub

Collecting esm
  Using cached esm-3.0.5-py3-none-any.whl.metadata (9.4 kB)
Collecting torchtext (from esm)
  Using cached torchtext-0.18.0-cp311-cp311-manylinux1_x86_64.whl.metadata (7.9 kB)
Collecting einops (from esm)
  Using cached einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Collecting biotite==0.41.2 (from esm)
  Using cached biotite-0.41.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.1 kB)
Collecting msgpack-numpy (from esm)
  Using cached msgpack_numpy-0.4.8-py2.py3-none-any.whl.metadata (5.0 kB)
Collecting biopython (from esm)
  Using cached biopython-1.84-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Using cached esm-3.0.5-py3-none-any.whl (148 kB)
Using cached biotite-0.41.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (37.0 MB)
Using cached biopython-1.84-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.2 MB)
Using cached einops-0.8.0-py3-none-any.whl (43 kB)
Using cached msgpack_numpy-0.4.8

In [None]:
import os
import scipy
import sklearn
import esm

import pandas as pd
import numpy as np

import torch
from torch.utils.data import Dataset
from torch import nn
import math

import transformers
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import Trainer, TrainingArguments, EvalPrediction

from esm.pretrained import load_model_and_alphabet_hub
from huggingface_hub import login
from esm.models.esm3 import ESM3
from esm.sdk.api import ESM3InferenceClient

from sklearn.metrics import r2_score, mean_squared_error

2024-10-13 05:18:45.627718: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [1]:
# NOTE: this cell uses `os`, `torch`, `login` and `ESM3` from the imports cell, so the
# imports cell must be executed first (running this one on a fresh kernel fails with
# "NameError: name 'os' is not defined").

# Step 1: Set Persistent TORCH_HOME Directory Path
efs_model_path = "/home/sagemaker-user/user-default-efs/torch_hub"

# Step 2: Set TORCH_HOME Environment Variable to Ensure Persistent Storage
os.environ['TORCH_HOME'] = efs_model_path
# exist_ok=True replaces the exists()/makedirs() pair, which can raise if the
# directory is created by someone else between the two calls.
os.makedirs(efs_model_path, exist_ok=True)

# Step 3: Set Device for Computation (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Step 4: Login to HuggingFace to Access ESM3 Weights
# Use the `HUGGINGFACE_HUB_TOKEN` env variable when present to avoid the interactive prompt
huggingface_token = os.getenv('HUGGINGFACE_HUB_TOKEN')
if huggingface_token:
    login(token=huggingface_token)
else:
    login()  # Will prompt user to enter their API key if no token is set

# Step 5: Load ESM3 Model from HuggingFace Hub
model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1").to(device)

# Step 6: Save Model Weights to Persistent EFS Directory
# NOTE(review): ESM3 does not appear to expose a HuggingFace-style `save_pretrained`;
# fall back to a plain state_dict checkpoint when the method is missing.
if hasattr(model, "save_pretrained"):
    model.save_pretrained(efs_model_path)
else:
    torch.save(
        model.state_dict(),
        os.path.join(efs_model_path, "esm3_sm_open_v1_state_dict.pth"),
    )

print("Model loaded and saved successfully!")

NameError: name 'os' is not defined

In [None]:
# PDB Dataset Class Update
class PDB_Dataset(Dataset):
    """
    Per-chain dataset yielding tokenised sequences and per-residue labels.

    Each item is a dict with
        'token_ids' -- int64 tensor of shape (1, n_tokens) produced by the tokenizer
        'labels'    -- float32 tensor of shape (1, n_labels) taken from column 1

    Sequences and label lists are both cut to MAX_RESIDUES entries so they stay aligned.
    """

    # Upper bound on residues kept per chain (same limit the original code hard-coded).
    MAX_RESIDUES = 1024

    def __init__(self, df, label_type='regression', tokenizer=None):
        """
        Construct all the necessary attributes to the PDB_Dataset object.

        Parameters:
            df (pandas.DataFrame): dataframe with two columns:
                0 -- protein sequence in string ('GLVM') or list (['G', 'L', 'V', 'M']) format
                1 -- contact number values in list [0, 0.123, 0.23, -100, 1.34] format
            label_type (str): type of model: regression or binary (stored, not used here)
            tokenizer: object with an `encode(str) -> list[int]` method. When omitted, the
                ESM3 sequence tokenizer from the `esm` package is created.
        """
        self.df = df
        if tokenizer is None:
            # The ESM-2 style `model.get_batch_converter()` is not part of the ESM3 model
            # class (per the esm 3.x package -- confirm against the installed version).
            # The sequence tokenizer is used directly instead, which also removes the
            # dependency on a global `model` variable.
            from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
            tokenizer = EsmSequenceTokenizer()
        self.tokenizer = tokenizer
        self.label_type = label_type

    def __getitem__(self, idx):
        """Return the tokenised sequence and label tensor for row `idx`."""
        sequence = ''.join(self.df.iloc[idx, 0])[:self.MAX_RESIDUES]
        labels = self.df.iloc[idx, 1][:self.MAX_RESIDUES]
        # NOTE(review): the ESM3 tokenizer is expected to add BOS/EOS ids around the
        # residues, so token_ids is two entries longer than labels -- confirm.
        token_ids = self.tokenizer.encode(sequence)
        return {
            'token_ids': torch.tensor(token_ids, dtype=torch.long).unsqueeze(0),
            'labels': torch.tensor(labels, dtype=torch.float32).unsqueeze(0),
        }

    def __len__(self):
        """Number of chains (rows) in the underlying dataframe."""
        return len(self.df)

In [None]:
# ESM3 Token Classification Model
class ESM3ForTokenClassification(nn.Module):
    """
    ESM3 backbone followed by a per-token linear head.

    Parameters:
        num_labels (int): output size of the linear head for every token.
        esm3 (nn.Module | None): backbone to wrap; defaults to the notebook-level `model`.
        embed_dim (int | None): width of the backbone embeddings. When omitted it is read
            from the backbone (see `__init__`).
    """

    def __init__(self, num_labels=1, esm3=None, embed_dim=None):
        super().__init__()
        # Fall back to the ESM3 instance loaded earlier in the notebook.
        self.esm3 = esm3 if esm3 is not None else model
        self.num_labels = num_labels

        if embed_dim is None:
            # Honour an `embed_dim` attribute if the backbone has one (ESM-2 style).
            embed_dim = getattr(self.esm3, "embed_dim", None)
        if embed_dim is None:
            # NOTE(review): ESM3 does not seem to store its width as an attribute; the
            # sequence embedding table of its input encoder carries it -- confirm
            # against the installed esm version.
            embed_dim = self.esm3.encoder.sequence_embed.embedding_dim

        # Linear layer applied independently to every token representation.
        self.classifier = nn.Linear(embed_dim, num_labels)

    def forward(self, token_ids, labels=None):
        """
        Compute per-token logits.

        Parameters:
            token_ids: sequence token tensor passed to the backbone.
            labels: accepted so the Trainer can pass it through; not used here (the loss
                is computed by the trainer).

        Returns:
            SequenceClassifierOutput whose `logits` has the first and last token
            positions removed.
        """
        # ESM3.forward takes keyword-only track tensors and returns an output object
        # exposing `embeddings` (per the esm 3.x package -- verify on upgrade).
        outputs = self.esm3(sequence_tokens=token_ids)
        # Drop the first and last positions (assumed BOS/EOS) so that logits line up
        # with the per-residue labels.
        hidden_states = outputs.embeddings[:, 1:-1, :]
        # The backbone may run in reduced precision; match the head's dtype.
        hidden_states = hidden_states.to(self.classifier.weight.dtype)

        logits = self.classifier(hidden_states)

        return SequenceClassifierOutput(logits=logits)

In [None]:
def model_init_1():
    """Return a new ESM3ForTokenClassification with a single output per token."""
    # The constructor has no `pretrained_no` parameter (that call raised TypeError);
    # `num_labels` is the only size argument it takes.
    return ESM3ForTokenClassification(num_labels=1)

In [None]:
# Custom Loss Function for Masked Regression
class MaskedMSELoss(torch.nn.Module):
    """Squared error averaged over the positions selected by a mask."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs, target, mask):
        """
        Parameters:
            inputs: 3-D tensor; only channel 0 of the last axis is used.
            target: tensor with as many elements as `inputs[:, :, 0]`.
            mask: tensor of the same size; zero/False entries are excluded.

        Returns:
            Sum of masked squared errors divided by the mask sum, or the plain
            (masked) sum when the mask selects nothing.
        """
        predictions = inputs[:, :, 0].flatten()
        squared_error = (predictions - target.flatten()) ** 2.0
        masked_total = (squared_error * mask.flatten()).sum()
        n_selected = mask.sum()
        if n_selected == 0:
            return masked_total
        return masked_total / n_selected

In [None]:
# Custom Masked Regress Trainer
class MaskedRegressTrainer(Trainer):
    """Trainer whose loss is a masked MSE against log(1 + label) targets."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute the masked regression loss for one batch.

        Parameters:
            model: module called as `model(**inputs)`; its output must expose `.logits`.
            inputs (dict): batch dict; the "labels" entry is removed before the call.
                Entries equal to -100 are excluded from the loss.
            return_outputs (bool): also return the model outputs.
            num_items_in_batch: accepted because newer transformers releases pass it to
                `compute_loss`; not used.

        Returns:
            loss, or (loss, outputs) when `return_outputs` is True.
        """
        labels = inputs.pop("labels")

        # Run the model on the input tokens
        outputs = model(**inputs)
        logits = outputs.logits

        # Keep the targets as a tensor on the same device as the logits. The previous
        # round-trip through a Python list failed for single-residue chains and always
        # produced a CPU tensor.
        labels = labels.detach().to(device=logits.device, dtype=torch.float32).reshape(1, -1)
        masks = labels != -100
        # log(1 + t) for real targets; the -100 sentinel is kept untouched.
        labels = torch.where(masks, torch.log1p(labels), labels)

        # Use a custom loss function
        loss_fn = MaskedMSELoss()
        loss = loss_fn(logits, labels, masks)

        return (loss, outputs) if return_outputs else loss

In [None]:
# Metrics Calculation Function
def compute_metrics_regr(p: EvalPrediction):
    """
    Pearson r, MSE and R^2 between predictions and labels.

    Parameters:
        p: evaluation result with `predictions` indexed as [sample, position, channel]
           (channel 0 is used) and `label_ids` indexed as [sample, position].
           Positions whose label is not greater than -1 are ignored.

    Returns:
        dict with keys "pearson_r", "mse" and "r2_score".
    """
    # Import the submodule explicitly: a bare `import scipy` does not guarantee that
    # `scipy.stats` is loaded on every SciPy version.
    from scipy import stats

    preds = np.asarray(p.predictions)[:, :, 0]
    # Restrict labels to the extent of the predictions, as the index loop did.
    labels = np.asarray(p.label_ids)[:preds.shape[0], :preds.shape[1]]

    # One boolean mask replaces the per-element Python loop; element order is unchanged
    # (row-major), so the metrics are computed on the same pairs.
    keep = labels > -1
    out_labels = labels[keep]
    out_preds = preds[keep]

    return {
        "pearson_r": stats.pearsonr(out_labels, out_preds)[0],
        "mse": mean_squared_error(out_labels, out_preds),
        "r2_score": r2_score(out_labels, out_preds)
    }

In [None]:
# Data Collator Function for Batch Processing
def collator_fn(x):
    """Unwrap a one-element batch to its single item; return any other batch unchanged."""
    return x[0] if len(x) == 1 else x

In [None]:
def _load_chain_dataset(csv_path, label_column):
    """
    Read a per-residue CSV, collapse it to one row per 'pdb_id_chain' (each column
    gathered into a list) and wrap the sequence/label columns in a PDB_Dataset.

    Returns the grouped dataframe and the dataset built from it.
    """
    per_chain = (
        pd.read_csv(csv_path)
        .groupby('pdb_id_chain')
        .agg({'resi_pos': list, 'resi_aa': list, label_column: list})
        .reset_index()
    )
    dataset = PDB_Dataset(per_chain[['resi_aa', label_column]], label_type='regression')
    return per_chain, dataset


# NOTE(review): the train path is relative to './data' while the test path is relative
# to '../data' -- confirm that both locations are intended.
train_set, train_ds = _load_chain_dataset('./data/sema_2.0/train_set.csv', 'contact_number')
test_set, test_ds = _load_chain_dataset('../data/sema_2.0/test_set.csv', 'contact_number_binary')

In [None]:
# Training Arguments
# `evaluation_strategy` was renamed to `eval_strategy` in newer transformers releases;
# use whichever name the installed TrainingArguments dataclass declares.
eval_strategy_key = (
    "eval_strategy"
    if "eval_strategy" in TrainingArguments.__dataclass_fields__
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir=os.path.join(efs_model_path, 'results_fold'),  # output directory
    num_train_epochs=2,                    # total number of training epochs
    per_device_train_batch_size=1,         # batch size per device during training
    per_device_eval_batch_size=1,          # batch size for evaluation
    warmup_steps=0,                        # number of warmup steps for learning rate scheduler
    learning_rate=1e-05,                   # learning rate
    weight_decay=0.0,                      # strength of weight decay
    logging_dir=os.path.join(efs_model_path, 'logs'),         # directory for storing logs
    logging_steps=200,                     # log every 200 steps
    save_strategy="no",                    # do not save checkpoints
    do_train=True,                         # Perform training
    do_eval=True,                          # Perform evaluation
    gradient_accumulation_steps=1,         # optimizer step after every batch
    fp16=False,                            # mixed precision disabled
    run_name="PDB_binary",                 # experiment name
    seed=42,                               # Seed for reproducibility
    load_best_model_at_end=False,
    # compute_metrics_regr reports pearson_r / mse / r2_score (prefixed with "eval_" by
    # the Trainer); there is no accuracy metric, so "eval_accuracy" could never be found.
    metric_for_best_model="eval_pearson_r",
    greater_is_better=True,
    use_cpu=True,
    **{eval_strategy_key: "epoch"},        # evaluate after each epoch
)

# Instantiate and train the model
trainer = MaskedRegressTrainer(
    model=ESM3ForTokenClassification(),   # Use the ESM3-based model for token classification
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    data_collator=collator_fn,            # batch size is 1, so each batch is a single item dict
    compute_metrics=compute_metrics_regr
)

# Train the model
trainer.train()

# Save the fine-tuned weights
fine_tuned_model_path = os.path.join(efs_model_path, "fine_tuned_sema_3_ESM3.pth")
torch.save(trainer.model.state_dict(), fine_tuned_model_path)