In [None]:
import sagemaker

# S3 key prefix for this project (not referenced by the cells visible here).
prefix = 'sagemaker/rigl'

# Session object used to talk to SageMaker, and the default S3 bucket
# associated with that session.
sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()

# Execution role of the current notebook environment; handed to the
# estimator below as `role`.
role = sagemaker.get_execution_role()

In [None]:
# Show the installed SageMaker Python SDK version as the cell output.
# NOTE(review): the estimator cell uses the `train_*` keyword names
# (train_instance_count, train_instance_type, ...), which are the SDK v1
# spellings -- confirm the version printed here is compatible with them.
sagemaker.__version__

In [None]:
from sagemaker.pytorch import PyTorch
import os

# Job sizing knobs.
instances = 1  # passed to the estimator as train_instance_count
processes = 8  # NOTE(review): unused in the visible cells -- presumably processes (GPUs) per instance; confirm

# Time budget for the training job, expressed in seconds (5 days).
one_day_in_seconds = 24 * 60 * 60
max_run = 5 * one_day_in_seconds

# Configure the SageMaker PyTorch training job that runs
# `train_imagenet_rigl.py` from the parent directory.
#
# NOTE(review): the `train_*` keyword names used below are the SageMaker
# Python SDK v1 spellings; SDK v2 renamed them (instance_count,
# instance_type, use_spot_instances, max_wait, max_run, volume_size).
# Check against the SDK version displayed earlier in this notebook.
estimator = PyTorch(entry_point='train_imagenet_rigl.py',
                    source_dir='..',
                    role=role,
                    framework_version='1.4.0',
                    py_version='py3',
                    
                    base_job_name='rigl',
                    
                    # Instances Setup
                    train_instance_count=instances,
                    train_instance_type='ml.p3.16xlarge',
                    # Spot training: max wait and max run share the same
                    # budget (`max_run`), so the total time allowed for
                    # waiting plus training equals the run limit.
                    # NOTE(review): no checkpoint location is configured
                    # here -- verify how the entry point copes with spot
                    # interruptions.
                    train_use_spot_instances=True,
                    train_max_wait=max_run,
                    train_max_run=max_run,
                    train_volume_size=300,
                    
                    # Hyperparameters are forwarded to the entry point script.
                    hyperparameters={
                        # Container path of the input channel named "training".
                        'data':'/opt/ml/input/data/training',
                        # SageMaker's model directory inside the container.
                        'output-dir': '/opt/ml/model',
                        'arch': 'resnet50',
                        'workers': 40,
                        'dense-allocation': 0.1,
                        'static-topo': 0,
                        'alpha': 0.3,
                        'delta': 100,
                        'batch-size': 1024,
                        'lr': 0.1,
#                         'lr-warmup-end': 5,
                        'lr-scaling-stop': 91,
                        'epochs': 100,
                    },

                    # Regexes used to scrape metrics from the training log.
                    # Raw strings keep `\*` and `\s` as literal regex escapes
                    # without triggering Python's invalid-escape-sequence
                    # warning; the resulting pattern text is unchanged.
                    metric_definitions=[
                        {'Name': 'top1-accuracy', 'Regex': r'\*\sAcc@1\s(.*)\sAcc@5'},
                        {'Name': 'top5-accuracy', 'Regex': r'\*\sAcc@1\s.*\sAcc@5\s(.*)'},
                    ]
                   )

In [None]:
# Alternative data source (disabled): FSx for Lustre input.
# NOTE(review): `file_system_input` is not defined in the visible cells.
# estimator.fit(file_system_input) # train with FSx Lustre as input

# Start the training job, reading ImageNet from the S3 bucket below.
imagenet_s3_uri = 's3://imagenet-compressed-oregon'  # use imagenet s3 bucket
estimator.fit(imagenet_s3_uri)