# ESM-2 Domain Adaptation with Uniref100 dataset

In this notebook, we demonstrate how to perform full-parameter fine tuning of the ESM-2 protein language model on uniref100 dataset.

---
## 0. Install dependencies

In [5]:
%pip install -q --upgrade pip
%pip install -q --upgrade sagemaker boto3 awscli datasets boto3 ipywidgets



[0mNote: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.


In [30]:
import boto3
from datasets import load_dataset, DatasetDict, Dataset
import os
import sagemaker
from sagemaker.experiments.run import Run
from sagemaker.inputs import TrainingInput
from sagemaker.pytorch import PyTorch
from time import strftime
from transformers import AutoTokenizer

boto_session = boto3.session.Session()
sagemaker_session = sagemaker.session.Session(boto_session)
S3_BUCKET = sagemaker_session.default_bucket()
s3 = boto_session.client("s3")
sagemaker_client = boto_session.client("sagemaker")
sagemaker_execution_role = sagemaker.session.get_execution_role(sagemaker_session)
REGION_NAME = sagemaker_session.boto_region_name
print(f"Assumed SageMaker role is {sagemaker_execution_role}")

S3_PREFIX = "esm-2-uniref100-benchmarking"
S3_PATH = sagemaker.s3.s3_path_join("s3://", S3_BUCKET, S3_PREFIX)
print(f"S3 path is {S3_PATH}")

EXPERIMENT_NAME = f"esm-2-benchmarking-uniref100" + strftime("%Y-%m-%d-%H-%M-%S")
print(f"Experiment name is {EXPERIMENT_NAME}")

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
Assumed SageMaker role is arn:aws:iam::111918798052:role/DevelopmentRole
S3 path is s3://sagemaker-us-east-1-111918798052/esm-2-uniref100-benchmarking
Experiment name is esm-2-benchmarking-uniref1002023-10-12-19-16-31


In [69]:
# MODEL_ID="facebook/esm2_t48_15B_UR50D"
# MODEL_ID="facebook/esm2_t36_3B_UR50D"
MODEL_ID="facebook/esm2_t33_650M_UR50D"
# MODEL_ID="facebook/esm2_t30_150M_UR50D"
# MODEL_ID="facebook/esm2_t12_35M_UR50D"
# MODEL_ID = "facebook/esm2_t6_8M_UR50D"

---
## 1. Pre-Torkenize the data 

Torkenized using glue script. 

In [7]:
train_s3_uri_uniref100 = "s3://us-east-1-protein-ref-data/uniref100/torkenized-1mb-v3/train/"
test_s3_uri_uniref100 = "s3://us-east-1-protein-ref-data/uniref100/torkenized-1mb-v3/test/"

## 2. Create data map needed for training

Create index map of torkenized data using glue script. 

## 2.1 (Optional) Get sample data for a sample run. 

In [50]:
!aws s3 cp s3://us-east-1-protein-ref-data/uniref100/torkenized-1mb-v3/train/train_index_map/part-00000-930723dc-5f05-4e8c-a356-1e7c9ef1c115-c000.csv ./tmp/
!aws s3 cp s3://us-east-1-protein-ref-data/uniref100/torkenized-1mb-v3/test/test_index_map/part-00000-b42d82f2-3f85-471d-9b41-e43e8aa2a8af-c000.csv ./tmp/



download: s3://us-east-1-protein-ref-data/uniref100/torkenized-1mb-v3/train/train_index_map/part-00000-930723dc-5f05-4e8c-a356-1e7c9ef1c115-c000.csv to tmp/part-00000-930723dc-5f05-4e8c-a356-1e7c9ef1c115-c000.csv
download: s3://us-east-1-protein-ref-data/uniref100/torkenized-1mb-v3/test/test_index_map/part-00000-b42d82f2-3f85-471d-9b41-e43e8aa2a8af-c000.csv to tmp/part-00000-b42d82f2-3f85-471d-9b41-e43e8aa2a8af-c000.csv


In [64]:
import pandas as pd
train_index_map = pd.read_csv("./tmp/part-00000-930723dc-5f05-4e8c-a356-1e7c9ef1c115-c000.csv")
train_index_map[0:1500].to_csv("./tmp/sample_train_10.csv")

In [65]:
test_index_map = pd.read_csv("./tmp/part-00000-b42d82f2-3f85-471d-9b41-e43e8aa2a8af-c000.csv")
train_index_map[0:250].to_csv("./tmp/sample_test_10.csv")

In [66]:
!aws s3 cp ./tmp/sample_train_10.csv s3://us-east-1-protein-ref-data/uniref100/torkenized-1mb-v3/train/sample_train_index_map/
!aws s3 cp ./tmp/sample_test_10.csv s3://us-east-1-protein-ref-data/uniref100/torkenized-1mb-v3/test/sample_test_index_map/

upload: tmp/sample_train_10.csv to s3://us-east-1-protein-ref-data/uniref100/torkenized-1mb-v3/train/sample_train_index_map/sample_train_10.csv
upload: tmp/sample_test_10.csv to s3://us-east-1-protein-ref-data/uniref100/torkenized-1mb-v3/test/sample_test_index_map/sample_test_10.csv


## 3. Train on multiple g5.2xlarge

In [70]:
metric_definitions = [
    {"Name": "epoch", "Regex": "Epoch: ([0-9.]*)"},
    {"Name": "step", "Regex": "Step: ([0-9.]*)"},
    {"Name": "train_loss", "Regex": "Training Loss: ([0-9.e-]*)"},
    {"Name": "train_perplexity", "Regex": "Training Perplexity: ([0-9.e-]*)"},
    {
        "Name": "train_samples_per_second",
        "Regex": "Training Samples/sec: ([0-9.e-]*)",
    },
    {
        "Name": "train_tokens_per_second",
        "Regex": "Training Tokens/sec: ([0-9.e-]*)",
    },
    {"Name": "eval_loss", "Regex": "Eval Loss: ([0-9.e-]*)"},
    {"Name": "eval_perplexity", "Regex": "Eval Perplexity: ([0-9.e-]*)"},
    {
        "Name": "eval_samples_per_second",
        "Regex": "Eval Samples/sec: ([0-9.e-]*)",
    },
    {"Name": "eval_tokens_per_second", "Regex": "Eval Tokens/sec: ([0-9.e-]*)"},
]

In [82]:
# Additional training parameters
hyperparameters = {
    "num_train_epochs": 2,
    "model_id": MODEL_ID,
    "per_device_train_batch_size": 12,
    "per_device_eval_batch_size": 12,
    "bf16": True,
    "logging_steps": 8,
    "optim": "adamw_torch",
    "pretrain" : 1,
    "train_sample_count" : 10000,
    "train_index_file_path" : "sample_train_index_map",
    "test_index_file_path" : "sample_test_index_map"
}

# creates Hugging Face estimator
g5_estimator = PyTorch(
    base_job_name="esm-2-uniref100-2p3dn24",
    entry_point="cuda-uniref100-pretorkenized-mlm-train-ddp-fsdp.py",
    source_dir="scripts/training/cuda/uniref100",
    instance_type="ml.p3dn.24xlarge",
    instance_count=2,
    image_uri=f"763104351884.dkr.ecr.{REGION_NAME}.amazonaws.com/pytorch-training:1.13.1-gpu-py39-cu117-ubuntu20.04-sagemaker",
    output_path=f"{S3_PATH}/output",
    role=sagemaker_execution_role,
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    sagemaker_session=sagemaker_session,
    distribution={"torch_distributed": {"enabled": True}},
    tags=[{"Key": "project", "Value": "esm-benchmarking"}],
)

with Run(
    experiment_name=EXPERIMENT_NAME,
    sagemaker_session=sagemaker_session,
) as run:
    g5_estimator.fit(
        {
            "train": TrainingInput(s3_data=train_s3_uri_uniref100, input_mode="FastFile"),
            "test": TrainingInput(s3_data=test_s3_uri_uniref100, input_mode="FastFile"),
        },
        wait=True,
    )

INFO:sagemaker:Creating training-job with name: esm-2-uniref100-2p3dn24-2023-10-13-04-37-07-449


Using provided s3_resource
2023-10-13 04:37:07 Starting - Starting the training job...
2023-10-13 04:37:23 Starting - Preparing the instances for training............
2023-10-13 04:39:19 Downloading - Downloading input data...
2023-10-13 04:39:56 Training - Downloading the training image............
2023-10-13 04:42:02 Training - Training image download completed. Training in progress.....[34mbash: cannot set terminal process group (-1): Inappropriate ioctl for device[0m
[34mbash: no job control in this shell[0m
[34m2023-10-13 04:42:42,615 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container.training[0m
[34m2023-10-13 04:42:42,675 sagemaker-training-toolkit INFO     No Neurons detected (normal if no neurons installed)[0m
[34m2023-10-13 04:42:42,686 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.[0m
[34m2023-10-13 04:42:42,688 sagemaker_pytorch_container.training INFO     Invoking TorchDistributed...[0m
[

ClientError: An error occurred (ValidationException) when calling the UpdateTrialComponent operation: 2 validation errors detected: Value 'Error for Training job esm-2-uniref100-2p3dn24-2023-10-13-04-37-07-449: Failed. Reason: AlgorithmError: ExecuteUserScriptError:
ExitCode 1
ErrorMessage "torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 30.00 MiB (GPU 2; 31.74 GiB total capacity; 30.89 GiB already allocated; 15.38 MiB free; 30.90 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
 Traceback (most recent call last)
 File "/opt/ml/code/cuda-uniref100-pretorkenized-mlm-train-ddp-fsdp.py", line 365, in <module>
 outputs = model(**batch)  # Forward pass
 File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
 return forward_call(*input, **kwargs)
 File "/opt/conda/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 2727, in forward
 output = self._fsdp_wrapped_module(*args, **kwargs)
 File "/opt/conda/lib/python3.9/site-packages/torch/distributed/fsdp/flatten_params_wrapper.py", line 165, in ' at 'status.message' failed to satisfy constraint: Member must have length less than or equal to 1024; Value 'Error for Training job esm-2-uniref100-2p3dn24-2023-10-13-04-37-07-449: Failed. Reason: AlgorithmError: ExecuteUserScriptError:
ExitCode 1
ErrorMessage "torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 30.00 MiB (GPU 2; 31.74 GiB total capacity; 30.89 GiB already allocated; 15.38 MiB free; 30.90 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
 Traceback (most recent call last)
 File "/opt/ml/code/cuda-uniref100-pretorkenized-mlm-train-ddp-fsdp.py", line 365, in <module>
 outputs = model(**batch)  # Forward pass
 File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
 return forward_call(*input, **kwargs)
 File "/opt/conda/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 2727, in forward
 output = self._fsdp_wrapped_module(*args, **kwargs)
 File "/opt/conda/lib/python3.9/site-packages/torch/distributed/fsdp/flatten_params_wrapper.py", line 165, in ' at 'status.message' failed to satisfy constraint: Member must satisfy regular expression pattern: .*