In [None]:
import sagemaker
from sagemaker import get_execution_role
from sagemaker.amazon.amazon_estimator import get_image_uri
import argparse

In [None]:
# ---------------------------------------------------------------------------
# Run configuration: every value a reader might tune lives in this cell.
# ---------------------------------------------------------------------------

# NOTE(review): the three empty strings below look like placeholders that must
# be filled in before the notebook can run -- confirm how they are supplied.
prefix = ''     # S3 key prefix for the train/validation channels and the output
sessname = ''   # passed to get_image_uri as the algorithm/repository name
role = ''       # IAM role handed to the Estimator

# Model and data shape
network = 'resnet-50'     # base_network hyperparameter
nclass = 1                # num_classes hyperparameter
image_shape = 256
label_width = 150
n_train_samples = 16551   # num_training_samples hyperparameter

# Optimisation
optim = 'sgd'
epochs = 2
mini_batch_size = 25
lr = 0.001
lr_scheduler_factor = 0.1
# The original cell assigned momentum twice (0.9, then 0.45) and weight_decay
# twice. The second momentum value equalled nms_thresh, so it was almost
# certainly a copy-paste slip that silently replaced 0.9; each value is now
# assigned exactly once.
momentum = 0.9
weight_decay = 0.0005

# Detection thresholds
overlap = 0.5       # overlap_threshold hyperparameter
nms_thresh = 0.45   # intended for the nms_threshold hyperparameter

In [None]:
# Open a SageMaker session and use its default bucket for all S3 traffic.
sess = sagemaker.Session()
bucket = sess.default_bucket()

# Resolve the training container image for the session's region.
training_image = get_image_uri(
    sess.boto_region_name,
    sessname,
    repo_version="latest",
)
print(training_image)

# Key prefixes of the two data channels, and the S3 URIs derived from them.
train_channel = prefix + '/train'
validation_channel = prefix + '/validation'
s3_train_data = 's3://{}/{}'.format(bucket, train_channel)
s3_validation_data = 's3://{}/{}'.format(bucket, validation_channel)
s3_output_location = 's3://{}/{}/output'.format(bucket, prefix)

# Push the local RecordIO files into their channels (train first, then validation).
sess.upload_data(
    path='train.rec',
    bucket=bucket,
    key_prefix=train_channel,
)
sess.upload_data(
    path='val.rec',
    bucket=bucket,
    key_prefix=validation_channel,
)

print('output will be placed here: ', s3_output_location)


# Estimator that runs the resolved training image on a single GPU instance and
# writes its output under s3_output_location.
od_model = sagemaker.estimator.Estimator(training_image,
                                         role,
                                         train_instance_count=1,
                                         train_instance_type='ml.p3.2xlarge',
                                         train_volume_size=50,
                                         train_max_run=360000,
                                         input_mode='File',
                                         output_path=s3_output_location,
                                         sagemaker_session=sess)

# Hyperparameters come from the configuration cell.
# Fix: the original passed `nms`, a name that is never defined, which raised
# NameError; the configured value is `nms_thresh`.
# NOTE(review): lr_scheduler_step is hard-coded to '3,6' while the configured
# epochs value is 2 -- confirm the schedule is intended.
od_model.set_hyperparameters(base_network=network,
                             use_pretrained_model=1,
                             num_classes=nclass,
                             mini_batch_size=mini_batch_size,
                             epochs=epochs,
                             learning_rate=lr,
                             lr_scheduler_step='3,6',
                             lr_scheduler_factor=lr_scheduler_factor,
                             optimizer=optim,
                             momentum=momentum,
                             weight_decay=weight_decay,
                             overlap_threshold=overlap,
                             nms_threshold=nms_thresh,
                             image_shape=image_shape,
                             label_width=label_width,
                             num_training_samples=n_train_samples)

# Describe each RecordIO channel as an S3 prefix input.
train_data = sagemaker.session.s3_input(
    s3_train_data,
    distribution='FullyReplicated',
    content_type='application/x-recordio',
    s3_data_type='S3Prefix',
)
validation_data = sagemaker.session.s3_input(
    s3_validation_data,
    distribution='FullyReplicated',
    content_type='application/x-recordio',
    s3_data_type='S3Prefix',
)

# Map channel names to their inputs and start training with log streaming on.
data_channels = {
    'train': train_data,
    'validation': validation_data,
}
od_model.fit(inputs=data_channels, logs=True)

