# Amazon SageMaker を活用するための準備

In [None]:
import boto3
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch
import os

import sagemaker
from sagemaker.tuner import IntegerParameter, CategoricalParameter, ContinuousParameter, HyperparameterTuner

# Name and tag of the training container image.
image = "yolov5-sagemaker-cu110"  # Example: mask-rcnn-smdataparallel-sagemaker
tag = "pt1.8"  # Example: pt1.8

# SageMaker session, the session's default S3 bucket, and the key prefix
# used for data.
sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()
prefix = "data"

# Execution role used when calling SageMaker.
role = sagemaker.get_execution_role()

# Region of the default boto3 session.
session = boto3.session.Session()
region = session.region_name
print(f"AWS region:{region}")

# 学習データの準備

In [None]:
import yaml
from pathlib import Path
import os

def check_dataset(dict):
    """Check that the validation data exists locally and fetch it if missing.

    This function is from util/general.py.

    Args:
        dict: Dataset configuration mapping (e.g. the parsed data YAML).
            Two keys are read: ``val`` (one path or a list of paths) and
            the optional ``download`` entry, which may be a ``.zip`` URL,
            a command starting with ``bash ``, or Python source text.
            The parameter name shadows the builtin; it is kept unchanged
            so existing callers keep working.

    Raises:
        Exception: If at least one ``val`` path does not exist and no
            ``download`` entry is configured.
    """
    # Local import: the cell defining this function does not import it.
    import subprocess

    val, s = dict.get('val'), dict.get('download')
    if not val:
        return  # no validation paths configured, nothing to check

    val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
    missing = [str(x) for x in val if not x.exists()]
    if not missing:
        return  # dataset already present

    print('\nWARNING: Dataset not found, nonexistent paths: %s' % missing)
    if not s:
        raise Exception('Dataset not found.')

    if s.startswith('http') and s.endswith('.zip'):  # URL
        # Imported here so the function does not depend on another cell
        # having imported torch first.
        import torch

        f = Path(s).name  # filename
        print(f'Downloading {s} ...')
        torch.hub.download_url_to_file(s, f)
        # Pass an argument list instead of a shell string so a file name
        # containing spaces or shell metacharacters cannot break (or inject
        # into) the command. The archive is removed only after a successful
        # unzip, mirroring the former `unzip ... && rm ...`.
        try:
            r = subprocess.run(['unzip', '-q', f, '-d', '../']).returncode
            if r == 0:
                os.remove(f)
        except OSError:
            # `unzip` not installed or the archive could not be removed:
            # report failure below instead of raising.
            r = 1
    elif s.startswith('bash '):  # bash script
        print(f'Running {s} ...')
        r = os.system(s)
    else:  # python script
        # NOTE(review): this executes arbitrary Python taken from the dataset
        # configuration; only use configuration files from trusted sources.
        r = exec(s)  # return None
    print('Dataset autodownload %s\n' % ('success' if r in (0, None) else 'failure'))  # print result

In [None]:
# Load the dataset configuration. The file is read as UTF-8 explicitly so
# parsing does not depend on the platform's default locale encoding.
with open("data/coco.yaml", encoding="utf-8") as f:
    data_dict = yaml.safe_load(f)  # data dict

In [None]:
# Check that the validation paths listed in data_dict exist locally and
# run the configured download step if they do not.
check_dataset(data_dict)

In [None]:
import sagemaker
from sagemaker import get_execution_role

# Obtain the execution role that carries the permissions needed to use
# SageMaker.
role = get_execution_role()

# Upload the local ../coco directory to S3 under the "coco" key prefix and
# keep the value returned by upload_data for use as training input.
sagemaker_session = sagemaker.Session()
input_coco = sagemaker_session.upload_data(
    path="../coco",
    key_prefix="coco",
)

# 学習用 Docker イメージの準備

In [None]:
# Print the Dockerfile through pygmentize so it can be reviewed in the notebook.
!pygmentize ./Dockerfile

### Docker イメージをビルドして ECR へ push

In [None]:
%%time
# Make build_and_push.sh executable and run it, passing the region, image
# name and tag defined earlier in the notebook.
! chmod +x build_and_push.sh; bash build_and_push.sh {region} {image} {tag}

# 学習の実行

In [None]:
import os
from sagemaker.pytorch import PyTorch
from sagemaker.local import LocalSession

import time

# Look up the caller's AWS account ID; it is part of the ECR image URI below.
client = boto3.client("sts")
account = client.get_caller_identity()["Account"]

instance_type = "ml.p3.8xlarge"  # Other supported instance type: ml.p3.16xlarge, ml.p4d.24xlarge
instance_count = 1  # You can use 2, 4, 8 etc.
docker_image = f"{account}.dkr.ecr.{region}.amazonaws.com/{image}:{tag}"

# SageMaker requires training job names to be unique within an account and
# region, so a fixed name makes every re-run of this notebook fail. Append a
# UTC timestamp to keep the name unique while retaining the original prefix.
job_name = f"pytorch-sm-yolo-{time.strftime('%Y-%m-%d-%H-%M-%S', time.gmtime())}"

In [None]:
# Configure the training job: train.py from the current directory is the
# entry point, run in the custom container image on the instance type and
# count chosen above.
estimator = PyTorch(
    # Container image and framework version
    image_uri=docker_image,
    framework_version="1.8.1",
    # Training code
    entry_point="train.py",
    source_dir=".",
    # Permissions and compute resources
    role=role,
    instance_type=instance_type,
    instance_count=instance_count,
    volume_size=700,
    sagemaker_session=sagemaker_session,
)

In [None]:
# Map the "data" input channel to the uploaded COCO data.
data_channels = dict(data=input_coco)

# Start the training job; wait=False returns without blocking on the job.
estimator.fit(
    inputs=data_channels,
    job_name=job_name,
    wait=False,
)