## SageMaker 활용

- 데이터가 준비되었으면, 로컬(local) 환경에서 샘플 데이터로 학습을 진행해 볼 수 있습니다.
- 로컬에서 정상적으로 동작하면 SageMaker에 올려서 테스트해 볼 수 있습니다.


In [None]:
import time
import sagemaker
import boto3
from sagemaker import image_uris
from sagemaker.utils import name_from_base
from sagemaker.inputs import TrainingInput
from sagemaker.pytorch import PyTorch

# Resolve the AWS context once; the cells below reuse these four names.
region = boto3.Session().region_name  # region of the default boto3 session
sess = sagemaker.Session()
bucket = sess.default_bucket()
role = sagemaker.get_execution_role()

# One call, two lines of output: bucket first, then role.
print(f"Default bucket : {bucket}", f"Role : {role}", sep="\n")

In [None]:
# S3 locations passed to the "pretrained" and "training" input channels when
# the job is launched; both live under the same workshop prefix.
s3_pretrained_uri, s3_data_uri = (
    f"s3://{bucket}/workshop/endo-vit/{suffix}"
    for suffix in ("models", "data/segmentation")
)

In [None]:
# Choose exactly one instance type. Setting it to 'local' or 'local_gpu'
# selects a SageMaker LocalSession below instead of a regular session.
# instance_type = "ml.g5.2xlarge"
# instance_type = "ml.g4dn.2xlarge"
instance_type = "ml.p3.2xlarge"  # consider this one when quota is insufficient
# instance_type = "ml.g4dn.12xlarge" # Multi-GPU

if instance_type not in ('local', 'local_gpu'):
    # Managed training: a regular SageMaker session.
    sm_session = sagemaker.session.Session()
else:
    from sagemaker.local import LocalSession

    sm_session = LocalSession()
    # NOTE(review): per the SDK's local-mode docs, 'local_code' lets the
    # training code be used from the local source directory instead of being
    # uploaded to S3 -- confirm for the SDK version in use.
    sm_session.config = {'local': {'local_code': True}}

print(f"instance type : {instance_type}")

In [None]:
# Look up the PyTorch 1.13 / Python 3.9 training image URI for the selected
# region and instance type.
image_uri = image_uris.retrieve(
    framework="pytorch",
    image_scope="training",
    version="1.13",
    py_version="py39",
    region=region,
    instance_type=instance_type,
)

print(f"Image URI for sagemaker training: {image_uri}")

In [None]:
# Base name of the training job; also used as the folder for its checkpoints.
job_name = "endo-vit-seg-pt"
s3_checkpoint_uri = f"s3://{bucket}/workshop/endo-vit/{job_name}/checkpoints"

# No hyperparameters are passed to the training script by default.
hyper_params = {}

# Upper bound on training time, in seconds (12 hours).
max_run = 60 * 60 * 12

# max_wait (seconds) covers spot waiting time plus the run itself, so it is
# only set when spot instances are requested.
use_spot_instances = False
max_wait = 12 * 60 * 60 if use_spot_instances else None


In [None]:
# Create the Estimator.
#
# The session passed here is `sm_session`, the one selected together with
# `instance_type`: local instance types ('local', 'local_gpu') need a
# LocalSession, while managed instance types use a regular SageMaker session.
estimator = PyTorch(
    image_uri=image_uri,                    # training container image
    entry_point='pretrain_script_sm.sh',    # train script
    source_dir='endovit-code',              # directory which includes all the files needed for training
    instance_type=instance_type,            # instance type used for the training job
    instance_count=1,                       # the number of instances used for training
    base_job_name=job_name,                 # prefix of the generated training job name
    role=role,                              # IAM role used by the training job to access AWS resources, e.g. S3
    sagemaker_session=sm_session,           # session matching instance_type (LocalSession in local mode)
    volume_size=200,                        # the size of the EBS volume in GB
    hyperparameters=hyper_params,
    debugger_hook_config=False,             # no Debugger hook
    disable_profiler=True,                  # turn off Debugger profiling/monitoring
    use_spot_instances=use_spot_instances,
    max_run=max_run,                        # maximum training time in seconds
    # max_wait is only meaningful for spot training, so it is forced to None otherwise.
    max_wait=max_wait if use_spot_instances else None,
    # Checkpoint syncing to S3 is only configured for managed (non-local) runs.
    checkpoint_s3_uri=s3_checkpoint_uri if instance_type not in ['local', 'local_gpu'] else None,
    checkpoint_local_path='/opt/ml/checkpoints' if instance_type not in ['local', 'local_gpu'] else None,
)

In [None]:
from IPython.display import display, HTML


def make_console_link(region, train_job_name, train_task='[Training]'):
    """Build HTML snippets that link to a training job in the AWS console.

    Args:
        region: AWS region name inserted into both console URLs.
        train_job_name: Training job name; used as the job path in the
            SageMaker URL and as the log-stream prefix in the CloudWatch URL.
        train_task: Label shown in front of each link.

    Returns:
        A ``(train_job_link, cloudwatch_link)`` tuple of HTML strings.
    """
    train_job_link = f'<b> {train_task} Review <a target="blank" href="https://console.aws.amazon.com/sagemaker/home?region={region}#/jobs/{train_job_name}">Training Job</a></b>'
    cloudwatch_link = f'<b> {train_task} Review <a target="blank" href="https://console.aws.amazon.com/cloudwatch/home?region={region}#logStream:group=/aws/sagemaker/TrainingJobs;prefix={train_job_name};streamFilter=typeLogStreamPrefix">CloudWatch Logs</a></b>'
    return train_job_link, cloudwatch_link


def fast_file(x):
    """Wrap an S3 URI in a TrainingInput that uses FastFile input mode."""
    return TrainingInput(x, input_mode="FastFile")


# Local mode follows instance_type -- the same value that selected the session
# and the checkpoint settings -- so the input channels cannot disagree with how
# the estimator was configured.
LOCAL_MODE = instance_type in ('local', 'local_gpu')

if LOCAL_MODE:
    # Adjust these local paths as needed when running in local mode.
    estimator.fit(
        {
            "pretrained": 'file://./pt-models',
            "training": 'file://./sample-data',
        },
        wait=False,
    )
else:
    estimator.fit(
        {
            "pretrained": fast_file(s3_pretrained_uri),
            "training": fast_file(s3_data_uri),
        },
        wait=False,
    )

# Recorded in both modes so that later cells can always reference the job name.
train_job_name = estimator.latest_training_job.job_name

if not LOCAL_MODE:
    # Console and CloudWatch links are only shown for managed (non-local) jobs.
    train_job_link, cloudwatch_link = make_console_link(region, train_job_name, '[Endo-ViT-Training]')

    display(HTML(train_job_link))
    display(HTML(cloudwatch_link))

In [None]:
# Show the name of the training job started by the fit cell.
print(f"Job name: {train_job_name}")

In [None]:
# estimator.logs()
