In [1]:
!pip install torch transformers esm huggingface-hub
!pip uninstall tensorflow -y

Collecting esm
  Using cached esm-3.0.5-py3-none-any.whl.metadata (9.4 kB)
Collecting torchtext (from esm)
  Using cached torchtext-0.18.0-cp311-cp311-manylinux1_x86_64.whl.metadata (7.9 kB)
Collecting einops (from esm)
  Using cached einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Collecting biotite==0.41.2 (from esm)
  Using cached biotite-0.41.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.1 kB)
Collecting msgpack-numpy (from esm)
  Using cached msgpack_numpy-0.4.8-py2.py3-none-any.whl.metadata (5.0 kB)
Collecting biopython (from esm)
  Using cached biopython-1.84-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Using cached esm-3.0.5-py3-none-any.whl (148 kB)
Using cached biotite-0.41.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (37.0 MB)
Using cached biopython-1.84-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.2 MB)
Using cached einops-0.8.0-py3-none-any.whl (43 kB)
Using cached msgpack_numpy-0.4.8

In [2]:
import os
import scipy
import sklearn
import esm

# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

In [3]:
import pandas as pd
import numpy as np

In [4]:
import torch
from torch.utils.data import Dataset
from torch import nn
import math

In [5]:
import transformers
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import Trainer, TrainingArguments, EvalPrediction

In [6]:
from huggingface_hub import login
from esm.models.esm3 import ESM3
from esm.sdk.api import ESM3InferenceClient

In [7]:
from sklearn.metrics import r2_score, mean_squared_error

In [8]:
# Step 1: Set Persistent TORCH_HOME Directory Path
efs_model_path = "/home/sagemaker-user/user-default-efs/torch_hub"

# Step 2: Point TORCH_HOME at the persistent directory and make sure it exists.
# exist_ok=True creates the directory only when it is missing, without a separate
# exists() check that another process could invalidate in between.
os.environ['TORCH_HOME'] = efs_model_path
os.makedirs(efs_model_path, exist_ok=True)

# Step 3: Set Device for Computation (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Step 4: Login to HuggingFace to Access ESM3 Weights
# The token is read from HUGGINGFACE_HUB_TOKEN first (the name this notebook has
# always used) and then from HF_TOKEN, the variable huggingface_hub documents.
huggingface_token = os.getenv('HUGGINGFACE_HUB_TOKEN') or os.getenv('HF_TOKEN')
if huggingface_token:
    login(token=huggingface_token)
else:
    login()  # Will prompt user to enter their API key if no token is set

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [9]:
# Step 5: Load ESM3 Model from HuggingFace Hub
# The weights go to the `device` selected in the setup cell instead of a
# hard-coded "cuda", so this cell also runs on hosts without a GPU.
model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1").to(device)

# Step 6: Save Model Weights to Persistent EFS Directory (optional, already stored via TORCH_HOME)
# model.save_pretrained(efs_model_path)

# Note: this cell itself writes nothing; the message text is kept unchanged so it
# still matches the recorded output.
print("Model loaded and saved successfully!")

Fetching 22 files:   0%|          | 0/22 [00:00<?, ?it/s]

Model loaded and saved successfully!


In [23]:
import os

efs_model_path = "/home/sagemaker-user/user-default-efs/torch_hub"

# Report whether the persistent directory exists and, only when it does, list it.
# Listing is done inside the "exists" branch because os.listdir raises
# FileNotFoundError on a missing path, which would crash the cell right after it
# had printed "does not exist".
files = []
if os.path.exists(efs_model_path):
    print(f"The directory exists: {efs_model_path}")
    files = os.listdir(efs_model_path)
    print(f"Contents of {efs_model_path}:")
    for f in files:
        print(f)
else:
    print(f"The directory does not exist: {efs_model_path}")

The directory exists: /home/sagemaker-user/user-default-efs/torch_hub
Contents of /home/sagemaker-user/user-default-efs/torch_hub:
Test file successfully written to: /home/sagemaker-user/user-default-efs/torch_hub/test_file.txt


In [12]:
class PDB_Dataset(Dataset):
    """Dataset that pairs an encoded protein sequence with its per-residue labels."""

    def __init__(self, df, model, label_type='regression', max_length=1024, verbose=False):
        """
        Construct all the necessary attributes for the PDB_Dataset object.

        Parameters:
            df (pandas.DataFrame): dataframe with two columns:
                0 -- protein sequence in string ('GLVM') or list (['G', 'L', 'V', 'M']) format
                1 -- contact number values in list [0, 0.123, 0.23, -100, 1.34] format
            model: The model object used for encoding sequences; it must expose an
                `encode(sequence)` method.
            label_type (str): type of model: regression or binary classification.
                Stored on the instance only; nothing in this class branches on it.
            max_length (int): maximum number of residues (and labels) kept per item.
            verbose (bool): when True, print the index, sequence and tokens of every
                item that is fetched (debugging aid, off by default).
        """
        self.df = df
        self.model = model
        self.label_type = label_type
        self.max_length = max_length
        self.verbose = verbose

    def __getitem__(self, idx):
        """
        Build one sample.

        Returns:
            dict: 'token_ids' holds whatever `model.encode` returned for the truncated
            sequence; 'labels' is a float32 tensor of shape (1, n_labels).

        Raises:
            ValueError: if the model has no `encode` attribute.
        """
        item = {}

        # Column 0 may be a string or a list of residues; join handles both.
        sequence = ''.join(self.df.iloc[idx, 0])[:self.max_length]

        # Look the method up explicitly instead of catching AttributeError around the
        # call: an AttributeError raised *inside* encode() must not be reported as
        # "model has no encode method".
        encode = getattr(self.model, 'encode', None)
        if encode is None:
            raise ValueError("The provided model does not have an 'encode' method.")
        tokens = encode(sequence)

        if self.verbose:
            print(f"Index: {idx}, Sequence: {sequence}, Tokens: {tokens}")

        item['token_ids'] = tokens

        # Labels are truncated with the same limit as the sequence so both stay aligned.
        label_values = self.df.iloc[idx, 1][:self.max_length]
        item['labels'] = torch.tensor(label_values, dtype=torch.float32).unsqueeze(0)

        return item

    def __len__(self):
        """Number of rows (chains) in the underlying dataframe."""
        return len(self.df)


In [74]:
# ESM3 Token Classification Model
class ESM3ForTokenClassification(nn.Module):
    """ESM3 backbone followed by a linear layer applied to every token embedding."""

    def __init__(self, model, num_labels=1):
        """
        Initializes the ESM3ForTokenClassification model.

        Args:
            model (ESM3InferenceClient): The preloaded ESM3 model.
            num_labels (int): Number of classification labels (width of the
                per-token output).

        Raises:
            AttributeError: if the model exposes no `tokenizers.sequence`.
            ValueError: if the model has no parameters to read the width from.
        """
        super().__init__()
        # Device chosen at construction time. It is kept as an attribute because
        # notebook cells read it; forward() looks at the live parameter device instead.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.esm3 = model.to(self.device)  # Move model to device
        self.num_labels = num_labels

        # Access the EsmSequenceTokenizer from the TokenizerCollection
        if hasattr(self.esm3, 'tokenizers') and hasattr(self.esm3.tokenizers, 'sequence'):
            self.tokenizer = self.esm3.tokenizers.sequence
            print("Sequence tokenizer found.")
        else:
            raise AttributeError("Model does not have a 'sequence' tokenizer in 'tokenizers' attribute.")

        # Tokenize a dummy input to get the token IDs (for informational purposes).
        # The recorded run printed four ids for these four residues, i.e. this
        # ("", sequence) form yielded one id per residue there.
        dummy_sequence = ("", "GLVM")
        tokens = self.tokenizer.encode(dummy_sequence)
        print(f"Encoded tokens (dummy): {tokens}")

        # The embedding width is read off the last dimension of the first parameter
        # tensor (1536 in the recorded run).
        # NOTE(review): this relies on parameter ordering — confirm it matches the
        # width of `outputs.embeddings` for other checkpoints.
        try:
            embed_dim = next(self.esm3.parameters()).size(-1)
            print(f"Embedding dimension: {embed_dim}")
        except StopIteration:
            raise ValueError("Model has no parameters to infer embedding dimension.")

        # Define a linear classification layer to classify each token
        self.classifier = nn.Linear(embed_dim, num_labels).to(self.device)

    def forward(self, sequences, labels=None):
        """
        Forward pass for token classification.

        Args:
            sequences (List[str] | str): Protein sequences. A single string is
                treated as one sequence. Because the per-sequence logits are
                concatenated along the batch dimension, all sequences must
                tokenize to the same length.
            labels (torch.Tensor, optional): Accepted for interface compatibility;
                not used here (no loss is computed in this method).

        Returns:
            SequenceClassifierOutput: `logits` of shape (batch_size, seq_len, num_labels).

        Raises:
            RuntimeError: if the backbone forward pass fails (original error chained).
            AttributeError: if the backbone output has no `embeddings` attribute.
        """
        # A bare string would otherwise be iterated character by character.
        if isinstance(sequences, str):
            sequences = [sequences]

        # Use the device the weights live on *now*: the module may have been moved
        # after construction (e.g. by a Trainer), which would make the device cached
        # in __init__ stale.
        device = self.classifier.weight.device
        model_dtype = next(self.esm3.parameters()).dtype

        # Loop-invariant, so built once.
        # NOTE(review): 70.0 looks like a percent-scale pLDDT — confirm the scale and
        # shape the backbone expects for `average_plddt`.
        average_plddt = torch.tensor([70.0], dtype=model_dtype, device=device)  # Shape: (1,)

        logits_list = []

        for seq in sequences:
            # Encode the sequence
            encoded = self.tokenizer.encode(("", seq))
            tokens_tensor = torch.tensor(encoded, device=device).unsqueeze(0)  # Shape: (1, seq_len)

            # Forward pass
            try:
                outputs = self.esm3(sequence_tokens=tokens_tensor, average_plddt=average_plddt)
            except Exception as e:
                # Chain the original exception so its traceback is not lost.
                raise RuntimeError(f"Failed to perform a forward pass for sequence '{seq}': {e}") from e

            # Access the 'embeddings' attribute
            if hasattr(outputs, 'embeddings'):
                x = outputs.embeddings  # Shape: (1, seq_len, hidden_dim)
            else:
                raise AttributeError("ESMOutput does not have an 'embeddings' attribute.")

            # Pass through the classification layer
            logits_list.append(self.classifier(x))  # Shape: (1, seq_len, num_labels)

        # Concatenate logits from all sequences
        logits = torch.cat(logits_list, dim=0)  # Shape: (batch_size, seq_len, num_labels)

        return SequenceClassifierOutput(logits=logits)

In [46]:
# Show the class of the tokenizer collection, then its repr, one per line.
print(type(model.tokenizers), model.tokenizers, sep="\n")

<class 'esm.tokenization.TokenizerCollection'>
TokenizerCollection(sequence=EsmSequenceTokenizer(name_or_path='', vocab_size=33, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<cls>', 'eos_token': '<eos>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'cls_token': '<cls>', 'mask_token': '<mask>', 'additional_special_tokens': ['|']}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<cls>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("<eos>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	31: AddedToken("|", rstrip=False, lstrip=False, single_word=False, normalized=False, special=Tru

In [59]:
def model_init_1():
    """Return a new single-output ESM3ForTokenClassification built on the notebook-level `model`."""
    token_head = ESM3ForTokenClassification(num_labels=1, model=model)
    return token_head

In [30]:
# Masked mean-squared-error loss for per-token regression
class MaskedMSELoss(torch.nn.Module):
    """Mean squared error over the positions selected by a mask."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs, target, mask):
        """
        Args:
            inputs: tensor indexed as inputs[:, :, 0] to obtain the predictions.
            target: tensor with as many elements as the selected predictions.
            mask: tensor with the same number of elements; zero/False entries
                are excluded from the loss.

        Returns:
            Sum of masked squared errors divided by the number of kept positions,
            or the (zero) masked sum itself when the mask keeps nothing.
        """
        flat_mask = torch.flatten(mask)
        residual = torch.flatten(inputs[:, :, 0]) - torch.flatten(target)
        masked_sq_err = residual ** 2.0 * flat_mask

        total = torch.sum(masked_sq_err)
        kept = torch.sum(mask)

        # Guard against dividing by zero when every position is masked out.
        if kept == 0:
            return total
        return total / kept

In [31]:
# Custom Masked Regress Trainer
class MaskedRegressTrainer(Trainer):
    """Trainer whose loss is a masked MSE on log(1 + label), skipping -100 labels."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute the masked regression loss for one batch.

        Args:
            model: module called as `model(**inputs)`; its output must expose `.logits`.
            inputs (dict): batch dictionary; the "labels" entry is removed from it
                and the remaining entries are forwarded to the model.
            return_outputs (bool): when True, return `(loss, outputs)`.
            num_items_in_batch: accepted because newer `transformers` releases pass
                it to `compute_loss`; not used by this loss.

        Returns:
            The loss tensor, or `(loss, outputs)` when `return_outputs` is True.
        """
        labels = inputs.pop("labels")

        # Flatten to a single row without squeeze(), which would collapse a
        # one-element label into a scalar.
        labels = torch.as_tensor(labels, dtype=torch.float32).detach().reshape(1, -1)

        # -100 marks positions that carry no label and must not contribute.
        masks = labels != -100

        # Apply log(1 + t) to real labels only; the sentinel is replaced by 0 before
        # the log so no NaN is produced, then restored to -100.
        # NOTE(review): real labels are assumed to be > -1 — confirm against the data.
        safe_labels = torch.where(masks, labels, torch.zeros_like(labels))
        labels = torch.where(masks, torch.log1p(safe_labels), labels)

        # Run the model on the remaining inputs
        outputs = model(**inputs)
        logits = outputs.logits

        # Labels and mask must live on the same device as the logits, otherwise the
        # element-wise loss fails when the model runs on GPU.
        labels = labels.to(logits.device)
        masks = masks.to(logits.device)

        # Use a custom loss function
        loss_fn = MaskedMSELoss()
        loss = loss_fn(logits, labels, masks)

        return (loss, outputs) if return_outputs else loss

In [32]:
# Evaluation metrics for the regression head
def compute_metrics_regr(p: EvalPrediction):
    """
    Compute Pearson r, MSE and R^2 over the labelled positions of an evaluation run.

    Predictions are read from `p.predictions[:, :, 0]`; a position is kept only
    when its entry in `p.label_ids` is greater than -1.

    Returns:
        dict with keys "pearson_r", "mse" and "r2_score".
    """
    preds = p.predictions[:, :, 0]
    n_rows, n_cols = preds.shape

    # Positions with a usable label, visited row by row.
    kept_positions = [
        (row, col)
        for row in range(n_rows)
        for col in range(n_cols)
        if p.label_ids[row, col] > -1
    ]
    out_labels = [p.label_ids[row][col] for row, col in kept_positions]
    out_preds = [preds[row][col] for row, col in kept_positions]

    pearson_r = scipy.stats.pearsonr(out_labels, out_preds)[0]
    mse = mean_squared_error(out_labels, out_preds)
    r2 = r2_score(out_labels, out_preds)

    return {"pearson_r": pearson_r, "mse": mse, "r2_score": r2}

In [33]:
# Batch collation for the Trainer
def collator_fn(x):
    """Return the sole element of a one-item batch; otherwise return the batch unchanged."""
    return x[0] if len(x) == 1 else x

In [34]:
def _load_per_chain(csv_path, label_column):
    """Read a per-residue CSV and collapse it to one row per `pdb_id_chain` with list-valued columns."""
    per_residue = pd.read_csv(csv_path)
    return per_residue.groupby('pdb_id_chain').agg({
        'resi_pos': list,
        'resi_aa': list,
        label_column: list
    }).reset_index()


# Training data: labels come from the 'contact_number' column
train_set = _load_per_chain('./data/sema_2.0/train_set.csv', 'contact_number')

# Create training dataset
train_ds = PDB_Dataset(train_set[['resi_aa', 'contact_number']], model=model, label_type='regression')

# Test data goes through the same pipeline, but labels come from
# 'contact_number_binary'.
# NOTE(review): the test labels are the binary column while label_type stays
# 'regression' — confirm this is intended.
test_set = _load_per_chain('./data/sema_2.0/test_set.csv', 'contact_number_binary')

# Create test dataset
test_ds = PDB_Dataset(test_set[['resi_aa', 'contact_number_binary']], model=model, label_type='regression')


In [35]:
# Training Arguments
training_args = TrainingArguments(
    output_dir=efs_model_path + '/results_fold',           # output directory
    num_train_epochs=2,                    # total number of training epochs
    per_device_train_batch_size=1,         # batch size per device during training
    per_device_eval_batch_size=1,          # batch size for evaluation
    warmup_steps=0,                        # number of warmup steps for learning rate scheduler
    learning_rate=1e-05,                   # learning rate
    weight_decay=0.0,                      # strength of weight decay
    logging_dir=efs_model_path + '/logs',                  # directory for storing logs
    logging_steps=200,                     # log every 200 steps
    save_strategy="no",                    # do not save checkpoints
    do_train=True,                         # Perform training
    do_eval=True,                          # Perform evaluation
    evaluation_strategy="epoch",           # evaluate after each epoch
    gradient_accumulation_steps=1,         # batches accumulated per optimizer step
    fp16=False,                            # mixed precision disabled
    run_name="PDB_binary",                 # experiment name
    seed=42,                               # Seed for reproducibility
    load_best_model_at_end=False,          # best-model reloading is off
    # compute_metrics_regr reports pearson_r / mse / r2_score, so the selection
    # metric must be one of those (with the "eval_" prefix), not an accuracy.
    metric_for_best_model="eval_r2_score",
    greater_is_better=True,                # higher R^2 is better
    # NOTE(review): this forces CPU training although the ESM3 model is loaded onto
    # CUDA earlier in the notebook — confirm this is intended.
    use_cpu=True
)

In [76]:
# Dummy data
sequences = ["GLVM", "MGLV"]

# Wrap the already-loaded ESM3 model in the token-classification head
model_instance = ESM3ForTokenClassification(model=model, num_labels=1)

# Build the inputs for one direct call of the backbone on "GLVM"
probe_device = model_instance.device
probe_ids = model_instance.tokenizer.encode(("", "GLVM"))
tokens = torch.tensor(probe_ids).unsqueeze(0).to(probe_device)
backbone_dtype = next(model_instance.esm3.parameters()).dtype
average_plddt = torch.tensor([70.0], dtype=backbone_dtype).to(probe_device)

# Inspect what the backbone returns; gradients are not needed for this
with torch.no_grad():
    outputs = model_instance.esm3(sequence_tokens=tokens, average_plddt=average_plddt)
    print("Type of outputs:", type(outputs))
    print("Attributes of outputs:", dir(outputs))
    if isinstance(outputs, dict):
        print("Keys in ESMOutput:", outputs.keys())
    elif isinstance(outputs, tuple):
        print(f"Number of elements in ESMOutput tuple: {len(outputs)}")
    print("Contents of outputs:", outputs)
    # Print each of these fields only when the output object actually has it
    for field_name, heading in (
        ("last_hidden_state", "Last Hidden State:"),
        ("hidden_states", "Hidden States:"),
        ("logits", "Logits:"),
    ):
        if hasattr(outputs, field_name):
            print(heading, getattr(outputs, field_name))

Sequence tokenizer found.
Encoded tokens (dummy): [6, 4, 7, 20]
Embedding dimension: 1536
Type of outputs: <class 'esm.models.esm3.ESMOutput'>
Attributes of outputs: ['__annotations__', '__attrs_attrs__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__match_args__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'embeddings', 'function_logits', 'residue_logits', 'sasa_logits', 'secondary_structure_logits', 'sequence_logits', 'structure_logits']
Contents of outputs: ESMOutput(sequence_logits=tensor([[[-14.7519, -14.8172, -14.7101, -14.7850,  -3.8522,   1.1576,   9.8740,
            0.7219,   3.7604,   1.1795,  -0.2083,  -2.4984,  -2.9204,   1.4685,
           -2.5869,  -2.4647,  -3.8069,  -1.1124,  -4.1094,  -4.7673,  -4.7

In [53]:
# Build the ESM3-based token model and hand it to the masked-regression trainer
regression_model = model_init_1()
trainer = MaskedRegressTrainer(
    model=regression_model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    data_collator=collator_fn,
    compute_metrics=compute_metrics_regr,
)

# Run training
trainer.train()

# Write the trained state dict next to the cached model files
fine_tuned_model_path = os.path.join(efs_model_path, "fine_tuned_sema_3_ESM3.pth")
fine_tuned_weights = trainer.model.state_dict()
torch.save(fine_tuned_weights, fine_tuned_model_path)

AttributeError: 'EsmSequenceTokenizer' object has no attribute 'batch_encode'