# Multi-Node Training with a SageMaker Training Job

In [None]:
# ## Update sagemaker python sdk version
!pip install -U sagemaker

## Set up model, code, and data

In [None]:
import sagemaker
from sagemaker import get_execution_role

# Resolve the IAM execution role and open a SageMaker session for this notebook.
role = get_execution_role()
sess = sagemaker.Session()

# The region is read from the session's underlying boto session; the bucket is
# the session's default bucket, used later to build the model output S3 path.
region = sess.boto_session.region_name
sagemaker_default_bucket = sess.default_bucket()

# Echo both values so the reader can confirm where the job will run.
print("sagemaker_default_bucket:", sagemaker_default_bucket)
print("sagemaker_region:", region)

## Setup for wandb

In [None]:
!pip install wandb

In [None]:
import wandb
# Authenticate this notebook session with Weights & Biases.
# NOTE(review): per the wandb docs this prompts for an API key when none is
# already configured — confirm credentials are available before Run All.
wandb.login()

## Submit Training job

In [7]:
from sagemaker.estimator import Estimator
from sagemaker.pytorch import PyTorch
from datetime import datetime


# --- Cluster sizing ----------------------------------------------------------
instance_type = 'ml.p4d.24xlarge'  # 8 GPUs x 40 GB each
instance_count = 1

# Upper bound on the job's runtime; handed to the estimator as `max_run`.
max_time = 200000

# Timestamp taken when this cell runs (not read again inside this cell).
current_time = datetime.now()

# Prefix used when naming the training job.
base_job_name = 'whisper-finetune'

# Export the notebook's wandb credentials into the training source directory
# so the training script can authenticate.
# NOTE(review): per the wandb SageMaker integration docs this writes a
# `secrets.env` file under `src/` — keep that file out of version control.
wandb.sagemaker_auth(path="src/")

# --- Environment variables exposed to the training container -----------------
environment = {
    # Number of nodes, as a string because env values must be strings.
    'NODE_NUMBER': str(instance_count),
    'MODEL_LOCAL_PATH': '/tmp/pretrain_model',
    # Destination for the fine-tuned model, under the session's default bucket.
    'OUTPUT_MODEL_S3_PATH': f's3://{sagemaker_default_bucket}/whisper_finetuned/',
}

# --- Estimator ---------------------------------------------------------------
estimator = PyTorch(
    # Training code: `entry.py` inside the local `src/` directory.
    source_dir='src/',
    entry_point='entry.py',
    script_mode=True,
    # Framework / Python versions requested for the training container.
    framework_version='2.4.0',
    py_version='py311',
    # Permissions, compute, and runtime limit.
    role=role,
    instance_type=instance_type,
    instance_count=instance_count,
    max_run=max_time,
    environment=environment,
)


In [None]:
import time
# Data in each input channel is copied automatically to every node under
# /opt/ml/input/data/<channel_name>/, so the 'train' channel below is visible
# to the training script at /opt/ml/input/data/train/.
data_path = 's3://audio-train-datasets/train_demo/'

# Append a timestamp so each submission gets a distinct job name. The hyphen
# separates the base name from the timestamp (e.g. whisper-finetune-2024-01-31-...),
# instead of running the two together.
training_job_name = f'{base_job_name}-{time.strftime("%Y-%m-%d-%H-%M-%S")}'

estimator.fit(
    inputs={'train': data_path},
    job_name=training_job_name,
)