In [None]:
import sagemaker
from sagemaker.pytorch import PyTorch
from sagemaker.tuner import HyperparameterTuner, ContinuousParameter, IntegerParameter

# SageMaker API session plus the IAM role ARN the training jobs run under.
# NOTE(review): get_execution_role() is meant for SageMaker notebook/Studio
# environments; elsewhere, substitute an explicit role ARN string.
session = sagemaker.Session()
role = sagemaker.get_execution_role()

# Version pins for SageMaker's managed PyTorch training container.
# (To use a custom ECR image instead, pass `image_uri=` to the estimator.)
python_version = 'py38'
framework_version = '1.9'

In [None]:
# Metrics scraped from the training job logs. The tuner's objective metric
# below must match one of these 'Name' values exactly.
metric_definitions = [
    {
        'Name': 'Validation-accuracy',
        'Regex': r'Validation-accuracy:\s*([0-9\.]+)'
    }
]

# Managed PyTorch estimator running train.py from the current directory.
# epochs and batch_size are static: every tuning trial uses these values.
pytorch_estimator = PyTorch(
    entry_point='train.py',
    source_dir='.',
    role=role,
    instance_count=1,
    instance_type='ml.m5.xlarge',
    framework_version=framework_version,
    py_version=python_version,
    # Reuse the session from the setup cell instead of an implicit default one.
    sagemaker_session=session,
    # Set at construction rather than by mutating the attribute afterwards.
    metric_definitions=metric_definitions,
    hyperparameters={
        'epochs': 30,
        'batch_size': 128
    }
)

# Search space. NOTE(review): these keys must match the argument names that
# train.py parses — confirm against the script.
# lr and wd span several orders of magnitude, so sample them on a log scale;
# linear sampling over [1e-5, 0.1] would put ~90% of trials above 0.01.
hyperparameter_ranges = {
    'lr': ContinuousParameter(1e-5, 0.1, scaling_type='Logarithmic'),
    'wd': ContinuousParameter(1e-6, 0.1, scaling_type='Logarithmic'),
    'hidden_size': IntegerParameter(64, 256),
    'dropout': ContinuousParameter(0.0, 0.5)
}

# Bayesian search: 10 trials total, 2 at a time, maximizing validation accuracy.
# metric_definitions must be handed to the tuner itself — for a custom
# (non-built-in) algorithm the tuning job reads its metric regexes from the
# tuner, not from the estimator, and the objective metric must be among them.
tuner = HyperparameterTuner(
    estimator=pytorch_estimator,
    objective_metric_name='Validation-accuracy',
    objective_type='Maximize',
    hyperparameter_ranges=hyperparameter_ranges,
    metric_definitions=metric_definitions,
    max_jobs=10,
    max_parallel_jobs=2,
    strategy='Bayesian'
)


In [None]:
# Launch the tuning job. No training channels are passed (inputs=None), so
# train.py presumably fetches or generates its own data — TODO confirm.
# NOTE(review): depending on the SageMaker SDK version, fit() may return before
# tuning finishes; call tuner.wait() before reading results if that is the case.
tuner.fit(inputs=None)