In [None]:
# Before running this notebook, make sure you have cloned the sample data:
#!git clone https://github.com/xieyongliang/all-in-one-ai-sample-data.git ../../../all-in-one-ai-sample-data

In [None]:
import sagemaker
from sagemaker.estimator import Estimator

In [None]:
# SageMaker session plus the IAM role and S3 bucket used by every later cell.
sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()
# The session's default S3 bucket holds the uploaded training data and inference output.
bucket = sagemaker_session.default_bucket()

In [None]:
import boto3
# Account and region used to build the ECR image URIs below.
# Both are taken from the SageMaker session's own boto session so the images resolve
# in the same account/region the training job and endpoint run in. A fresh
# boto3.session.Session() may resolve a different region, or None if none is configured.
region_name = sagemaker_session.boto_region_name
account_id = sagemaker_session.boto_session.client('sts').get_caller_identity().get('Account')

In [None]:
!./build_and_push.sh $region_name

In [None]:
train_dir = 's3://{0}/{1}/data/train'.format(bucket, 'stylegan')
!aws s3 cp ../../../all-in-one-ai-sample-data/stylegan/train/ $train_dir --recursive

In [None]:
# Training configuration.
# - "data" is the path where SageMaker mounts the 'dataset' channel declared in `inputs`
#   (/opt/ml/input/data/<channel>), so the channel name and this path must stay in sync.
# - "outdir" is /opt/ml/model, whose contents SageMaker packages as the model artifact.
# - "gpus" must match the GPU count of `instance_type` (ml.p3.16xlarge has 8 GPUs);
#   change both together.
hyperparameters = {
    "data": "/opt/ml/input/data/dataset/animeface.zip",
    "outdir": "/opt/ml/model",
    "gpus": "8",
    "kimg": "1000"
}
# Tag the training image explicitly, consistent with the inference image URI used later.
image_uri = '{0}.dkr.ecr.{1}.amazonaws.com/all-in-one-ai-stylegan-training:latest'.format(account_id, region_name)
instance_type = 'ml.p3.16xlarge'
instance_count = 1
# Channel name -> S3 prefix for estimator.fit().
inputs = {
    'dataset': train_dir
}

In [None]:
# Training estimator backed by the custom StyleGAN training image.
# Uses `instance_count` from the configuration cell rather than a hard-coded 1,
# so changing the count there actually takes effect.
# NOTE(review): later cells reassign `instance_type`, `instance_count` and `image_uri`
# for inference; re-run the configuration cell before re-running this one.
estimator = Estimator(
    role=role,
    instance_count=instance_count,
    instance_type=instance_type,
    image_uri=image_uri,
    hyperparameters=hyperparameters,
)

In [None]:
# Launch the training job; with the default wait=True this blocks until the job finishes.
estimator.fit(inputs=inputs)

In [None]:
# Name of the most recent training job started by `estimator`.
latest_training_job = estimator.latest_training_job
training_job_name = latest_training_job.name

In [None]:
# Inference model configuration.
model_name = None  # None lets the SDK generate a model name at deploy time
# Ask the estimator where its artifacts were written instead of rebuilding the path by hand.
# The hand-built 's3://{bucket}/{job}/output/model.tar.gz' is only correct when the output
# path is exactly the default bucket root (no custom output_path or bucket prefix).
model_data = estimator.model_data
image_uri = '{0}.dkr.ecr.{1}.amazonaws.com/all-in-one-ai-stylegan-inference:latest'.format(account_id, region_name)
# Environment variables passed to the inference container (currently none).
model_environment = {
    # Optional override. NOTE(review): judging by the key name, this points the container at
    # a pretrained network pickle instead of the trained artifacts; confirm in the container.
    #'network':'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl'
}

In [None]:
from sagemaker.model import Model
from sagemaker.predictor import Predictor

In [None]:
# Combine the trained artifacts with the inference image; deploy() on this model
# returns a generic Predictor for the resulting endpoint.
model = Model(
    image_uri=image_uri,
    model_data=model_data,
    role=role,
    name=model_name,
    env=model_environment,
    predictor_cls=Predictor,
)

In [None]:
# Endpoint settings: None lets the SDK name the endpoint; one ml.m5.xlarge (CPU) instance.
# Note this reuses the names `instance_type` / `instance_count` from the training config.
endpoint_name, instance_type, instance_count = None, 'ml.m5.xlarge', 1

In [None]:
# Create the real-time endpoint (deploy waits for it by default) and keep its Predictor.
predictor = model.deploy(
    initial_instance_count=instance_count,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
)

In [None]:
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

# Exchange JSON with the endpoint.
predictor.serializer = JSONSerializer()
predictor.deserializer = JSONDeserializer()

# Request payload. It is kept under its own name so it does not overwrite `inputs`,
# the training channel mapping passed to `estimator.fit` (re-running fit would break).
# NOTE(review): the keys are interpreted by the inference container. 'trunc' looks like a
# truncation value, 'seeds' like comma-separated generator seeds, and 'output_s3uri' like
# where results are written; confirm against the container.
inference_inputs = {
    'trunc': '1',
    'seeds': '85,265,297,849',
    'output_s3uri': 's3://{0}/{1}/data/inference/output'.format(bucket, 'stylegan')
}

predictor.predict(
    {
        'inputs': inference_inputs
    }
)

In [None]:
# Clean up all deployed resources. delete_model() looks up the model(s) through the
# endpoint config, so it must run before delete_endpoint(). delete_endpoint() then removes
# the endpoint and (by default) its endpoint config. Previously the model was leaked.
predictor.delete_model()
predictor.delete_endpoint()