# SEMA-1D 

SEMA-1D is a fine-tuned ESM-1v model aimed to predict epitope resiudes based on antigen protein sequence

## 1. Set up Environment

In [13]:
%pip install datasets huggingface-hub s3fs=='0.4.2'
%pip uninstall tensorflow -y

Collecting s3fs==0.4.2
  Downloading s3fs-0.4.2-py3-none-any.whl.metadata (1.3 kB)
Downloading s3fs-0.4.2-py3-none-any.whl (19 kB)
Installing collected packages: s3fs
Successfully installed s3fs-0.4.2
Note: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.


In [6]:
import boto3
from datasets import Dataset
import json
import os
import pandas as pd
import random
import sagemaker
from sagemaker.experiments.run import Run
from sagemaker.huggingface import HuggingFace, HuggingFaceModel
from sagemaker.inputs import TrainingInput
from time import strftime
from transformers import AutoTokenizer

boto_session = boto3.session.Session()
sagemaker_session = sagemaker.session.Session(boto_session)
S3_BUCKET = sagemaker_session.default_bucket()
s3 = boto_session.client("s3")
sagemaker_client = boto_session.client("sagemaker")
REGION_NAME = sagemaker_session.boto_region_name

try:
    sagemaker_execution_role = sagemaker_session.get_execution_role()
except AttributeError:
    NOTEBOOK_METADATA_FILE = "/opt/ml/metadata/resource-metadata.json"
    with open(NOTEBOOK_METADATA_FILE, "rb") as f:
        metadata = json.loads(f.read())
        instance_name = metadata["ResourceName"]
        domain_id = metadata.get("DomainId")
        user_profile_name = metadata.get("UserProfileName")
        space_name = metadata.get("SpaceName")
    domain_desc = sagemaker_session.sagemaker_client.describe_domain(DomainId=domain_id)
    if "DefaultSpaceSettings" in domain_desc:
        sagemaker_execution_role = domain_desc["DefaultSpaceSettings"]["ExecutionRole"]
    else:
        sagemaker_execution_role = domain_desc["DefaultUserSettings"]["ExecutionRole"]

print(f"Assumed SageMaker role is {sagemaker_execution_role}")

S3_PREFIX = "esm-sema-1d"
S3_PATH = sagemaker.s3.s3_path_join("s3://", S3_BUCKET, S3_PREFIX)
print(f"S3 path is {S3_PATH}")

EXPERIMENT_NAME = "esm-sema-1d-" + strftime("%Y-%m-%d-%H-%M-%S")
print(f"Experiment name is {EXPERIMENT_NAME}")

Assumed SageMaker role is arn:aws:iam::340752820161:role/service-role/AmazonSageMaker-ExecutionRole-20241011T160996
S3 path is s3://sagemaker-us-east-1-340752820161/esm-sema-1d
Experiment name is esm-sema-1d-2024-10-17-14-26-50


In [3]:
# Step 1: Set Persistent TORCH_HOME Directory Path
efs_model_path = "/home/sagemaker-user/user-default-efs/torch_hub"

# Step 2: Set TORCH_HOME Environment Variable to Ensure Persistent Storage
os.environ['TORCH_HOME'] = efs_model_path
if not os.path.exists(efs_model_path):
    os.makedirs(efs_model_path)

# Step 3: Set Device for Computation (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## 2. Build Dataset

In [65]:
# Load the CSV data
train_data_url = 'https://raw.githubusercontent.com/AIRI-Institute/SEMAi/main/epitopes_prediction/data/sema_2.0/train_set.csv'
test_data_url = 'https://raw.githubusercontent.com/AIRI-Institute/SEMAi/main/epitopes_prediction/data/sema_2.0/test_set.csv'

train_set = pd.read_csv(train_data_url)
test_set = pd.read_csv(test_data_url)

# Group the data by pdb_id_chain
train_set = train_set.groupby('pdb_id_chain').agg({'resi_aa': list, 'contact_number': list}).reset_index()
test_set = test_set.groupby('pdb_id_chain').agg({'resi_aa': list, 'contact_number': list}).reset_index()

print(train_set.columns)
print(test_set.columns)

# Convert Pandas DataFrame to Hugging Face Dataset
train_dataset = Dataset.from_pandas(train_set)
test_dataset = Dataset.from_pandas(test_set)

# Initialize Hugging Face ESM tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")


# Define a preprocessing function
def preprocess_data(examples, max_length=1022, label_type='regression'):
    # Join the list of amino acids into a single sequence string
    sequences = [''.join(aa_list) for aa_list in examples['resi_aa']]

    # Tokenize the sequences
    encoding = tokenizer(sequences, truncation=True, max_length=max_length)

    # Assign the appropriate labels
    if label_type == 'regression':
        labels = [label[:max_length] for label in examples["contact_number"]]
    else:
        labels = [label[:max_length] for label in examples["contact_number_binary"]]

    # Validate labels
    for i, label in enumerate(labels):
        for t in label:
            if t < -1 and t != -100:
                raise ValueError(f"Invalid label value {t} at index {i}")

    encoding["labels"] = labels

    return encoding


# Apply the preprocessing function to train and test datasets
train_dataset = train_dataset.map(
    preprocess_data,
    batched=True,
    num_proc=os.cpu_count(),
    remove_columns=train_dataset.column_names,
)
test_dataset = test_dataset.map(
    preprocess_data,
    batched=True,
    num_proc=os.cpu_count(),
    remove_columns=test_dataset.column_names,
)

# Check format before converting to torch
print(f"Max tokenized sequence length: {max([len(seq) for seq in train_dataset['input_ids']])}")
print(f"Max label length: {max([len(label) for label in train_dataset['labels']])}")

# Set format to PyTorch (for Hugging Face Trainer)
train_dataset.set_format("torch")
test_dataset.set_format("torch")

Index(['pdb_id_chain', 'resi_aa', 'contact_number'], dtype='object')
Index(['pdb_id_chain', 'resi_aa', 'contact_number'], dtype='object')




Map (num_proc=2):   0%|          | 0/1544 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/101 [00:00<?, ? examples/s]

Max tokenized sequence length: 1022
Max label length: 1022


Look at an example record

In [66]:
# Pick a random index from the train set
random_idx = random.randint(0, len(train_dataset) - 1)

# Get the example record from the train dataset
example = train_dataset[random_idx]

# Decode the tokenized sequence back into the original sequence
decoded_sequence = tokenizer.decode(example['input_ids'])

# Display the record
print(f"Viewing example record {random_idx}")
print(f"Raw sequence:\n{decoded_sequence}\n")           # Decoded sequence
print(f"Tokenized sequence:\n{example['input_ids']}\n")  # Tokenized sequence (input_ids)
print(f"Label:\n{example['labels']}")                    # Corresponding label

Viewing example record 1525
Raw sequence:
<cls> M E W S W V F L F F L S V T T G V H S R F P N I T N L C P F H E V F N A T T F A S V Y A W N R T R I S N C V A D Y S V L Y N F A P F F A F K C Y G V S P T K L N D L C F T N V Y A D S F V I R G N E V S Q I A P G Q T G N I A D Y N Y K L P D D F T G C V I A W N S N K L D S K V S G N Y N Y L Y R L F R K S K L K P F E R D I S T E I Y Q A G N K P C N G V A G F N C Y S P L Q S Y G F R P T Y G V G H Q P Y R V V V L S F E L L H A P A T V C G P K K S T H H H H H H H H G S G S G L N D I F E A Q K I E W H E <eos>

Tokenized sequence:
tensor([ 0, 20,  9, 22,  8, 22,  7, 18,  4, 18, 18,  4,  8,  7, 11, 11,  6,  7,
        21,  8, 10, 18, 14, 17, 12, 11, 17,  4, 23, 14, 18, 21,  9,  7, 18, 17,
         5, 11, 11, 18,  5,  8,  7, 19,  5, 22, 17, 10, 11, 10, 12,  8, 17, 23,
         7,  5, 13, 19,  8,  7,  4, 19, 17, 18,  5, 14, 18, 18,  5, 18, 15, 23,
        19,  6,  7,  8, 14, 11, 15,  4, 17, 13,  4, 23, 18, 11, 17,  7, 19,  5,
        13,  8, 18,  7, 1

Finally, we upload the processed training, test, and validation data to S3.

In [67]:
train_s3_uri = S3_PATH + "/data/train"
test_s3_uri = S3_PATH + "/data/test"

train_dataset.save_to_disk(train_s3_uri)
test_dataset.save_to_disk(test_s3_uri)

Saving the dataset (0/1 shards):   0%|          | 0/1544 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/101 [00:00<?, ? examples/s]

## 3. Train Model in SageMaker

In [77]:
hyperparameters = {
    # "model_id": "facebook/esm2_t33_650M_UR50D",
    "model_id": "facebook/esm2_t6_8M_UR50D",
    "epochs": 1,
    "per_device_train_batch_size": 8,
    "gradient_accumulation_steps": 4,
    "use_gradient_checkpointing": True,
    # "lora": True,
}

metric_definitions = [
    {"Name": "epoch", "Regex": "'epoch': ([0-9.]*)"},
    {
        "Name": "max_gpu_mem",
        "Regex": "Max GPU memory use during training: ([0-9.e-]*) MB",
    },
    {"Name": "train_loss", "Regex": "'loss': ([0-9.e-]*)"},
    {
        "Name": "train_samples_per_second",
        "Regex": "'train_samples_per_second': ([0-9.e-]*)",
    },
    {"Name": "eval_loss", "Regex": "'eval_loss': ([0-9.e-]*)"},
    {"Name": "eval_accuracy", "Regex": "'eval_accuracy': ([0-9.e-]*)"},
]

hf_estimator = HuggingFace(
    base_job_name="esm-2-sema-1d",
    entry_point="lora-train.py",
    source_dir="scripts",
    instance_type="ml.p3.2xlarge",
    instance_count=1,
    transformers_version="4.36",
    pytorch_version="2.1",
    py_version="py310",
    output_path=f"{S3_PATH}/output",
    role=sagemaker_execution_role,
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    checkpoint_local_path="/opt/ml/checkpoints",
    sagemaker_session=sagemaker_session,
    tags=[{"Key": "project", "Value": "esm-fine-tuning"}],
)

In [78]:
with Run(
    experiment_name=EXPERIMENT_NAME,
    sagemaker_session=sagemaker_session,
) as run:
    hf_estimator.fit(
        {
            "train": TrainingInput(s3_data=train_s3_uri, input_mode="File"),
            "test": TrainingInput(s3_data=test_s3_uri, input_mode="File"),
        },
        wait=False,
    )

INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker:Creating training-job with name: esm-2-sema-1d-2024-10-17-21-20-57-479


You can view metrics and debugging information for this run in SageMaker Experiments.

In [81]:
from sagemaker.analytics import ExperimentAnalytics

training_job_details = hf_estimator.latest_training_job.describe()
print(f"Training job name: {training_job_details.get('TrainingJobName')}")
print(f"Training job status: {training_job_details.get('TrainingJobStatus')}")
print(f"Training job output: {training_job_details.get('ModelArtifacts')}")

search_expression = {
    "Filters": [
        {
            "Name": "DisplayName",
            "Operator": "Contains",
            "Value": "Training",
        }
    ],
}

trial_component_analytics = ExperimentAnalytics(
    sagemaker_session=sagemaker_session,
    experiment_name=EXPERIMENT_NAME,
    search_expression=search_expression,
)

trial_component_analytics.dataframe().T

Training job name: esm-2-sema-1d-2024-10-17-21-20-57-479
Training job status: Completed
Training job output: {'S3ModelArtifacts': 's3://sagemaker-us-east-1-340752820161/esm-sema-1d/output/esm-2-sema-1d-2024-10-17-21-20-57-479/output/model.tar.gz'}


Unnamed: 0,0,1,2,3,4,5,6,7,8
TrialComponentName,esm-2-sema-1d-2024-10-17-21-20-57-479-aws-trai...,esm-2-sema-1d-2024-10-17-21-07-37-094-aws-trai...,esm-2-sema-1d-2024-10-17-19-51-16-646-aws-trai...,esm-2-sema-1d-2024-10-17-18-56-43-944-aws-trai...,esm-2-sema-1d-2024-10-17-18-36-44-409-aws-trai...,esm-2-sema-1d-2024-10-17-18-07-50-300-aws-trai...,esm-2-sema-1d-2024-10-17-17-55-43-940-aws-trai...,esm-2-sema-1d-2024-10-17-16-56-48-350-aws-trai...,esm-2-sema-1d-2024-10-17-16-33-03-818-aws-trai...
DisplayName,esm-2-sema-1d-2024-10-17-21-20-57-479-aws-trai...,esm-2-sema-1d-2024-10-17-21-07-37-094-aws-trai...,esm-2-sema-1d-2024-10-17-19-51-16-646-aws-trai...,esm-2-sema-1d-2024-10-17-18-56-43-944-aws-trai...,esm-2-sema-1d-2024-10-17-18-36-44-409-aws-trai...,esm-2-sema-1d-2024-10-17-18-07-50-300-aws-trai...,esm-2-sema-1d-2024-10-17-17-55-43-940-aws-trai...,esm-2-sema-1d-2024-10-17-16-56-48-350-aws-trai...,esm-2-sema-1d-2024-10-17-16-33-03-818-aws-trai...
SourceArn,arn:aws:sagemaker:us-east-1:340752820161:train...,arn:aws:sagemaker:us-east-1:340752820161:train...,arn:aws:sagemaker:us-east-1:340752820161:train...,arn:aws:sagemaker:us-east-1:340752820161:train...,arn:aws:sagemaker:us-east-1:340752820161:train...,arn:aws:sagemaker:us-east-1:340752820161:train...,arn:aws:sagemaker:us-east-1:340752820161:train...,arn:aws:sagemaker:us-east-1:340752820161:train...,arn:aws:sagemaker:us-east-1:340752820161:train...
SageMaker.ImageUri,763104351884.dkr.ecr.us-east-1.amazonaws.com/h...,763104351884.dkr.ecr.us-east-1.amazonaws.com/h...,763104351884.dkr.ecr.us-east-1.amazonaws.com/h...,763104351884.dkr.ecr.us-east-1.amazonaws.com/h...,763104351884.dkr.ecr.us-east-1.amazonaws.com/h...,763104351884.dkr.ecr.us-east-1.amazonaws.com/h...,763104351884.dkr.ecr.us-east-1.amazonaws.com/h...,763104351884.dkr.ecr.us-east-1.amazonaws.com/h...,763104351884.dkr.ecr.us-east-1.amazonaws.com/h...
SageMaker.InstanceCount,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
SageMaker.InstanceType,ml.p3.2xlarge,ml.p3.2xlarge,ml.p3.2xlarge,ml.p3.2xlarge,ml.p3.2xlarge,ml.p3.2xlarge,ml.p3.2xlarge,ml.p3.2xlarge,ml.p3.2xlarge
SageMaker.VolumeSizeInGB,30.0,30.0,30.0,30.0,30.0,30.0,30.0,30.0,30.0
epochs,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
gradient_accumulation_steps,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0
model_id,"""facebook/esm2_t33_650M_UR50D""","""facebook/esm2_t33_650M_UR50D""","""facebook/esm2_t33_650M_UR50D""","""facebook/esm2_t33_650M_UR50D""","""facebook/esm2_t33_650M_UR50D""","""facebook/esm2_t33_650M_UR50D""","""facebook/esm2_t33_650M_UR50D""","""facebook/esm2_t33_650M_UR50D""","""facebook/esm2_t33_650M_UR50D"""


## 4. Deploy Model as Real-Time Inference Endpoint

In [None]:
%%time

hub = {"HF_MODEL_ID": "bloyal/esm2_650M_membrane_loc", "HF_TASK": "text-classification"}

hf_model = HuggingFaceModel(
    env=hub,
    role=sagemaker_execution_role,
    transformers_version="4.28",
    pytorch_version="2.0",
    py_version="py310",
)

predictor = hf_model.deploy(
    initial_instance_count=1,
    instance_type="ml.r5.2xlarge",
    role=sagemaker_execution_role,
)

To deploy our endpoint, we call deploy() on our HuggingFace estimator object, passing in our desired number of instances and instance type.

In [82]:
predictor = hf_estimator.deploy(initial_instance_count=1, instance_type="ml.g4dn.xlarge")

ValueError: Unsupported huggingface version: 4.36. You may need to upgrade your SDK version (pip install -U sagemaker) for newer huggingface versions. Supported huggingface version(s): 4.6.1, 4.10.2, 4.11.0, 4.12.3, 4.17.0, 4.26.0, 4.28.1, 4.37.0, 4.6, 4.10, 4.11, 4.12, 4.17, 4.26, 4.28, 4.37.

Try running some known proteins

In [None]:
# Example cell membrane proteins
glp_1_receptor = "MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS"
pd1 = "MQIPQAPWPVVWAVLQLGWRPGWFLDSPDRPWNPPTFSPALLVVTEGDNATFTCSFSNTSESFVLNWYRMSPSNQTDKLAAFPEDRSQPGQDCRFRVTQLPNGRDFHMSVVRARRNDSGTYLCGAISLAPKAQIKESLRAELRVTERRAEVPTAHPSPSPRPAGQFQTLVVGVVGGLLGSLVLLVWVLAVICSRAARGTIGARRTGQPLKEDPSAVPVFSVDYGELDFQWREKTPEPPVPCVPEQTEYATIVFPSGMGTSSPARRGSADGPRSAQPLRPEDGHCSWPL"
rit1 = "MDSGTRPVGSCCSSPAGLSREYKLVMLGAGGVGKSAMTMQFISHRFPEDHDPTIEDAYKIRIRIDDEPANLDILDTAGQAEFTAMRDQYMRAGEGFIICYSITDRRSFHEVREFKQLIYRVRRTDDTPVVLVGNKSDLKQLRQVTKEEGLALAREFSCPFFETSAAYRYYIDDVFHALVREIRRKEKEAVLAMEKKSKPKNSVWKRLKSPFRKKKDSVT"

# Example non-cell membrane proteins
tubulin_beta_1 = "MREIVHIQIGQCGNQIGAKFWEMIGEEHGIDLAGSDRGASALQLERISVYYNEAYGRKYVPRAVLVDLEPGTMDSIRSSKLGALFQPDSFVHGNSGAGNNWAKGHYTEGAELIENVLEVVRHESESCDCLQGFQIVHSLGGGTGSGMGTLLMNKIREEYPDRIMNSFSVMPSPKVSDTVVEPYNAVLSIHQLIENADACFCIDNEALYDICFRTLKLTTPTYGDLNHLVSLTMSGITTSLRFPGQLNADLRKLAVNMVPFPRLHFFMPGFAPLTAQGSQQYRALSVAELTQQMFDARNTMAACDLRRGRYLTVACIFRGKMSTKEVDQQLLSVQTRNSSCFVEWIPNNVKVAVCDIPPRGLSMAATFIGNNTAIQEIFNRVSEHFSAMFKRKAFVHWYTSEGMDINEFGEAENNIHDLVSEYQQFQDAKAVLEEDEEVTEEAEMEPEDKGH"
p53 = "MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD"
adh5 = "MANEVIKCKAAVAWEAGKPLSIEEIEVAPPKAHEVRIKIIATAVCHTDAYTLSGADPEGCFPVILGHEGAGIVESVGEGVTKLKAGDTVIPLYIPQCGECKFCLNPKTNLCQKIRVTQGKGLMPDGTSRFTCKGKTILHYMGTSTFSEYTVVADISVAKIDPLAPLDKVCLLGCGISTGYGAAVNTAKLEPGSVCAVFGLGGVGLAVIMGCKVAGASRIIGVDINKDKFARAKEFGATECINPQDFSKPIQEVLIEMTDGGVDYSFECIGNVKVMRAALEACHKGWGVSVVVGVAASGEEIATRPFQLVTGRTWKGTAFGGWKSVESVPKLVSEYMSKKIKVDEFVTHNLSFDEINKAFELMHSGKSIRTVVKI"

sample = {"inputs": glp_1_receptor}
predictor.predict(sample)

In [None]:
#Epoch	Training Loss	Validation Loss	Pearson R	Mse	R2 Score
#1	0.212700	0.150756	0.251578	0.173891	-0.567424
#2	0.157400	0.165494	0.253576	0.183997	-0.658516


## 5. Clean up

Delete endpoint

In [None]:
try:
    predictor.delete_endpoint()
except:
    pass

Delete S3 data

In [None]:
bucket = boto_session.resource("s3").Bucket(S3_BUCKET)
bucket.objects.filter(Prefix=S3_PREFIX).delete()