# ESM-2 Domain Adaptation with the UniRef100 Dataset

In this notebook, we demonstrate how to perform full-parameter fine-tuning of the ESM-2 protein language model on the UniRef100 dataset.

---
## 0. Install dependencies

In [None]:
%pip install -q --upgrade pip
%pip install -q --upgrade sagemaker boto3 awscli boto3 ipywidgets



In [None]:
import boto3
import os
import sagemaker
from sagemaker.experiments.run import Run
from sagemaker.inputs import TrainingInput
from sagemaker.pytorch import PyTorch
from time import strftime


# One boto3 session is shared by the SageMaker session and by the S3 / SageMaker clients.
boto_session = boto3.session.Session()
sagemaker_session = sagemaker.session.Session(boto_session)
S3_BUCKET = sagemaker_session.default_bucket()
s3 = boto_session.client("s3")
sagemaker_client = boto_session.client("sagemaker")
sagemaker_execution_role = sagemaker.session.get_execution_role(sagemaker_session)
REGION_NAME = sagemaker_session.boto_region_name
print(f"Assumed SageMaker role is {sagemaker_execution_role}")

# Training outputs are written under this prefix in the session's default bucket.
S3_PREFIX = "esm-2-uniref100-benchmarking"
S3_PATH = sagemaker.s3.s3_path_join("s3://", S3_BUCKET, S3_PREFIX)
print(f"S3 path is {S3_PATH}")

# The timestamp suffix gives every kernel session its own experiment; the hyphen keeps
# it visually separate from the model-size tag.
# NOTE(review): "650M" is hard-coded here while MODEL_ID is chosen in a later cell --
# update this name by hand if a different checkpoint is selected.
EXPERIMENT_NAME = "esm-2-benchmarking-ref100-650M-" + strftime("%Y-%m-%d-%H-%M-%S")
print(f"Experiment name is {EXPERIMENT_NAME}")

In [None]:
# ESM-2 checkpoint to fine-tune; it is passed to the training job as the "model_id"
# hyperparameter. Keep exactly one assignment active -- the alternatives are listed
# from largest to smallest.
# MODEL_ID = "facebook/esm2_t48_15B_UR50D"
# MODEL_ID = "facebook/esm2_t36_3B_UR50D"
MODEL_ID = "facebook/esm2_t33_650M_UR50D"
# MODEL_ID = "facebook/esm2_t30_150M_UR50D"
# MODEL_ID = "facebook/esm2_t12_35M_UR50D"
# MODEL_ID = "facebook/esm2_t6_8M_UR50D"

---
## 1. Pre-tokenize the data 

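The UniRef100 sequences were tokenized ahead of time using a glue script. Set the S3 URIs below to the tokenized train and test outputs, replacing `<bucket>` with the name of your bucket.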

In [None]:
# S3 prefixes of the pre-tokenized train and test splits. They are used both as the
# SageMaker training input channels and as the base path of the index-map folders below.
# `<bucket>` is a placeholder: replace it before running, otherwise the `aws s3` shell
# commands in later cells will treat `<` / `>` as shell redirection.
train_s3_uri_uniref100 = "s3://<bucket>/uniref100/torkenized-1mb-650m-v1/train"
test_s3_uri_uniref100 = "s3://<bucket>/uniref100/torkenized-1mb-650m-v1/test"


## 2. Create data map needed for training

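Create an index map of the tokenized data using a glue script. The cells below expect it under `train_index_map/` and `test_index_map/` inside the train and test prefixes, respectively.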

### 2.1 (Optional) Create a small sample index map for a trial run

In [None]:
train_index_file = !(aws s3 ls {train_s3_uri_uniref100}/train_index_map/) 
train_index_file = train_index_file[0].split()[-1]
train_index_file_full_path = train_s3_uri_uniref100 + "/train_index_map/" + train_index_file

test_index_file = !(aws s3 ls {test_s3_uri_uniref100}/test_index_map/) 
test_index_file = test_index_file[0].split()[-1]
test_index_file_full_path = test_s3_uri_uniref100 + "/test_index_map/"+ test_index_file
test_index_file_full_path

In [None]:
!mkdir ./tmp
!aws s3 cp {train_index_file_full_path} ./tmp/
!aws s3 cp {test_index_file_full_path} ./tmp/

In [None]:
import pandas as pd
# Load the downloaded training index map (read as CSV) and display it for inspection.
# `train_index_file` is the file name resolved from the S3 listing in an earlier cell.
train_index_map = pd.read_csv(f"./tmp/{train_index_file}")
train_index_map

In [None]:
# Write the first 3 rows of the training index map as the sample index map.
# NOTE(review): the file name says "100" but only 3 rows are kept, and `to_csv` with
# default arguments also writes the DataFrame index as an extra leading column --
# confirm the training script expects that layout.
train_index_map.iloc[0:3].to_csv("./tmp/sample_train_100.csv")

In [None]:
# Load the downloaded test index map (read as CSV) and display it for inspection.
test_index_map = pd.read_csv(f"./tmp/{test_index_file}")
test_index_map

In [None]:
# Write the first row of the test index map as the sample index map.
# NOTE(review): the file name says "100" but only 1 row is kept, and the DataFrame
# index is written as an extra leading column (pandas default) -- confirm this is intended.
test_index_map.iloc[0:1].to_csv("./tmp/sample_test_100.csv")

In [None]:
# Display the S3 destination that the sample training index map is uploaded to.
# (The bare `{...}/path/` form is shell-interpolation syntax and is a SyntaxError
# as a Python cell; an f-string shows the same value.)
f"{train_s3_uri_uniref100}/sample_train_index_map/"

In [None]:
# Upload the sample index maps next to the tokenized data. The folder names
# ("sample_train_index_map", "sample_test_index_map") match the values passed as the
# "train_index_file_path" / "test_index_file_path" hyperparameters in the training cell.
!aws s3 cp ./tmp/sample_train_100.csv {train_s3_uri_uniref100}/sample_train_index_map/
!aws s3 cp ./tmp/sample_test_100.csv {test_s3_uri_uniref100}/sample_test_index_map/


## 3. Train on a single ml.p3dn.24xlarge instance

In [None]:
# Metric extraction rules handed to the estimator: each entry pairs a metric name with
# a regex whose single capture group pulls the numeric value out of a training log line.
# The progress counters (epoch, step) accept digits and dots only; every other metric
# additionally accepts "e" and "-" so scientific notation is captured.
metric_definitions = [
    {"Name": metric_name, "Regex": log_pattern}
    for metric_name, log_pattern in (
        ("epoch", "Epoch: ([0-9.]*)"),
        ("step", "Step: ([0-9.]*)"),
        ("train_loss", "Training Loss: ([0-9.e-]*)"),
        ("train_perplexity", "Training Perplexity: ([0-9.e-]*)"),
        ("train_samples_per_second", "Training Samples/sec: ([0-9.e-]*)"),
        ("train_tokens_per_second", "Training Tokens/sec: ([0-9.e-]*)"),
        ("eval_loss", "Eval Loss: ([0-9.e-]*)"),
        ("eval_perplexity", "Eval Perplexity: ([0-9.e-]*)"),
        ("eval_samples_per_second", "Eval Samples/sec: ([0-9.e-]*)"),
        ("eval_tokens_per_second", "Eval Tokens/sec: ([0-9.e-]*)"),
    )
]

In [None]:
# Hyperparameters handed to the training entry point.
hyperparameters = {
    "num_epochs": 2,
    "model_id": MODEL_ID,
    "per_device_train_batch_size": 10,
    "per_device_eval_batch_size": 10,
    # NOTE(review): ml.p3dn.24xlarge uses V100 GPUs, which have no native bf16
    # support -- confirm the entry-point script handles this flag on that hardware.
    "bf16": True,
    "logging_steps": 2,
    "optim": "adamw_torch",
    "pretrain": 1,  # NOTE(review): meaning is defined by the entry-point script -- confirm
    "train_sample_count": 10000,
    # Folder names of the sample index maps uploaded next to the tokenized data.
    "train_index_file_path": "sample_train_index_map",
    "test_index_file_path": "sample_test_index_map",
    "gradient_accumulation_steps": 10,
}

# Create the SageMaker PyTorch estimator.
# NOTE(review): the variable is still called `g5_estimator`, but the job runs on a
# single ml.p3dn.24xlarge instance.
g5_estimator = PyTorch(
    base_job_name="esm-2-uniref100-p3dn-gacc",
    entry_point="cuda-uniref100-pretorkenized-mlm-train-ddp-fsdp.py",
    source_dir="training/cuda/uniref100",
    instance_type="ml.p3dn.24xlarge",
    instance_count=1,
    # Resolve the training image from the session's region instead of a hard-coded
    # us-east-1, so the notebook also works when it runs in another region.
    # NOTE(review): a few opt-in regions use a different registry account ID.
    image_uri=f"763104351884.dkr.ecr.{REGION_NAME}.amazonaws.com/pytorch-training:2.0.1-gpu-py310-cu118-ubuntu20.04-sagemaker",
    output_path=f"{S3_PATH}/output",
    role=sagemaker_execution_role,
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    sagemaker_session=sagemaker_session,
    # Launch the entry point through the SDK's torch_distributed launcher.
    distribution={"torch_distributed": {"enabled": True}},
    tags=[{"Key": "project", "Value": "esm-benchmarking"}],
    # Keep the provisioned instance for 30 minutes after the job ends.
    keep_alive_period_in_seconds=1800,
)

# Start the job inside an Experiments run so it is recorded under EXPERIMENT_NAME.
# `wait=False` returns immediately instead of streaming logs until the job finishes.
with Run(
    experiment_name=EXPERIMENT_NAME,
    sagemaker_session=sagemaker_session,
) as run:
    g5_estimator.fit(
        {
            "train": TrainingInput(s3_data=train_s3_uri_uniref100, input_mode="FastFile"),
            "test": TrainingInput(s3_data=test_s3_uri_uniref100, input_mode="FastFile"),
        },
        wait=False,
    )

In [None]:
# NOTE: this entire cell is commented out (disabled). It is a variant of the training
# cell above that attaches a SageMaker ProfilerConfig and uses the "-ptprof" entry point.
# NOTE(review): it passes "num_train_epochs" where the active cell passes "num_epochs",
# and it omits "gradient_accumulation_steps" -- confirm which keys the entry-point
# script expects before re-enabling it.
# # Additional training parameters
# hyperparameters = {
#     "num_train_epochs": 2,
#     "model_id": MODEL_ID,
#     "per_device_train_batch_size": 10,
#     "per_device_eval_batch_size": 10,
#     "bf16": True,
#     "logging_steps": 8,
#     "optim": "adamw_torch",
#     "pretrain" : 1,
#     "train_sample_count" : 10000,
#     "train_index_file_path" : "sample_train_index_map",
#     "test_index_file_path" : "sample_test_index_map"
# }

# from sagemaker import ProfilerConfig, Profiler
# profiler_config = ProfilerConfig(
#     profile_params = Profiler(cpu_profiling_duration=3600)
# )

# # creates Hugging Face estimator
# g5_estimator = PyTorch(
#     base_job_name="esm-2-uniref100-2p3dn24",
#     entry_point="cuda-uniref100-pretorkenized-mlm-train-ddp-fsdp-ptprof.py",
#     source_dir="training/cuda/uniref100",
#     instance_type="ml.p3dn.24xlarge",
#     instance_count=1,
#     image_uri=f"763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.0.1-gpu-py310-cu118-ubuntu20.04-sagemaker",
#     output_path=f"{S3_PATH}/output",
#     role=sagemaker_execution_role,
#     hyperparameters=hyperparameters,
#     metric_definitions=metric_definitions,
#     sagemaker_session=sagemaker_session,
#     distribution={"torch_distributed": {"enabled": True}},
#     tags=[{"Key": "project", "Value": "esm-benchmarking"}],
#     profiler_config=profiler_config
# )

# with Run(
#     experiment_name=EXPERIMENT_NAME,
#     sagemaker_session=sagemaker_session,
# ) as run:
#     g5_estimator.fit(
#         {
#             "train": TrainingInput(s3_data=train_s3_uri_uniref100, input_mode="FastFile"),
#             "test": TrainingInput(s3_data=test_s3_uri_uniref100, input_mode="FastFile"),
#         },
#         wait=False,
#     )