# Amazon SageMaker Workshop - PyTorch Native DDP - MNIST

## Contents

1. [目标](#目标)
1. [数据准备](#数据准备)
1. [高性能存储FSx For Lustre](#高性能存储)
1. [分布式训练](#分布式训练)
1. [部署](#部署)
1. [推理](#推理)
---

## 目标

本实验主要帮助用户快速掌握Amazon SageMaker 使用PyTorch DDP做分布式训练，及部署可扩展环境和推理，并帮助用户快速掌握在训练过程中使用FSx for Lustre进行存储加速。

注意：仅支持 单机单卡 / 单机多卡 / 多机多卡 场景

---


## 数据准备

～2分钟

In [None]:
# Pin the client libraries this notebook was written against. The local
# torch/torchvision install is only used in this notebook to download MNIST and
# print versions; training itself runs in the SageMaker PyTorch 1.5.0 container.
!pip install boto3=="1.23.10" sagemaker=="2.104.0"
!pip3 install torch==1.4.0 torchvision==0.5.0 -f https://download.pytorch.org/whl/cu101/torch_stable.html

# Please restart the kernel afterwards so the newly installed versions are imported.

In [None]:
# Basic SageMaker setup: session, default S3 bucket, S3 key prefix and IAM role.
import sagemaker
import time
from datetime import datetime
import torch
import boto3

sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()
prefix = 'dji'
role = sagemaker.get_execution_role()

# Report library versions and the resolved resources to help with troubleshooting.
print(f"torch.__version__:{torch.__version__}")
print(f"boto3.__version__:{boto3.__version__}")
print(f"sagemaker.__version__:{sagemaker.__version__}")
print(f"bucket:{bucket}")
print(f"role:{role}")

In [None]:
# Download MNIST via torchvision into the local 'data' directory (the original
# note also suggests using the provided data.tar.gz file directly instead).
from torchvision import datasets, transforms

mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
datasets.MNIST('data', download=True, transform=mnist_transform)


In [None]:
# Upload the local 'data' directory to the default bucket under the `prefix` key prefix.
local_data_dir = 'data'
inputs = sagemaker_session.upload_data(
    path=local_data_dir,
    bucket=bucket,
    key_prefix=prefix,
)
print(f'input spec (in this case, just an S3 path): {inputs}')


## 高性能存储

～18分钟

In [None]:
# Look up the default VPC, one of its subnets in the region's "b" availability
# zone (used for the FSx file system), and all of its subnets (used by training).
import boto3

client = boto3.client('ec2')
region = client.meta.region_name

# Derive the AZ from the client's region instead of hard-coding 'us-east-1b';
# in us-east-1 this still resolves to 'us-east-1b'.
availability_zone = '{}b'.format(region)

response = client.describe_vpcs(
    Filters=[{'Name': 'is-default', 'Values': ['true']}],
    MaxResults=5,
)
if not response['Vpcs']:
    raise RuntimeError('No default VPC found in region {}'.format(region))
vpc_id = response['Vpcs'][0]['VpcId']

response = client.describe_subnets(
    Filters=[
        {'Name': 'vpc-id', 'Values': [vpc_id]},
        {'Name': 'availability-zone', 'Values': [availability_zone]},
    ],
    MaxResults=10,
)
if not response['Subnets']:
    raise RuntimeError('No subnet of {} found in {}'.format(vpc_id, availability_zone))
subnet_id = response['Subnets'][0]['SubnetId']

# A single describe_subnets call returns at most one page of results, so
# paginate to collect every subnet of the VPC.
paginator = client.get_paginator('describe_subnets')
subnet_ids = [
    subnet['SubnetId']
    for page in paginator.paginate(Filters=[{'Name': 'vpc-id', 'Values': [vpc_id]}])
    for subnet in page['Subnets']
]

print("vpc_id:{} ; subnet_id: {}; subnet_ids: {} ".format(vpc_id, subnet_id, subnet_ids))

In [None]:
# Create (or reuse) the 'fsx-dji' security group that the FSx file system and the
# training instances both use. An existing group is reused as-is; its rules are
# not modified.
response = client.describe_security_groups(
    Filters=[
        {'Name': 'group-name', 'Values': ['fsx-dji']},
        {'Name': 'vpc-id', 'Values': [vpc_id]},
    ],
)
if response['SecurityGroups']:
    security_group_id = response['SecurityGroups'][0]['GroupId']
    print("security_group_id:{} ".format(security_group_id))
else:
    response = client.create_security_group(
        GroupName='fsx-dji',
        VpcId=vpc_id,
        Description='fsx security_group'
    )
    security_group_id = response['GroupId']
    print("new security_group_id:{} ".format(security_group_id))

    # Allow TCP only from members of this same group (self-referencing rule).
    # That covers Lustre traffic between FSx and the training instances (which
    # are launched with this group) and node-to-node traffic, instead of opening
    # every TCP port to 0.0.0.0/0.
    ingress_response = client.authorize_security_group_ingress(
        GroupId=security_group_id,
        IpPermissions=[
            {'IpProtocol': 'tcp',
             'FromPort': 1,
             'ToPort': 65535,
             'UserIdGroupPairs': [{'GroupId': security_group_id}]}
        ])
    print('Ingress Successfully Set %s' % ingress_response)



In [None]:
# Create (or reuse) an S3 gateway VPC endpoint attached to the VPC's route table.
s3_service_name = 'com.amazonaws.{}.s3'.format(client.meta.region_name)

response = client.describe_route_tables(
    Filters=[{'Name': 'vpc-id', 'Values': [vpc_id]}],
    MaxResults=5,
)
RouteTableId = response['RouteTables'][0]['RouteTableId']
print('RouteTableId : {}'.format(RouteTableId))

# Re-running this cell must not fail: EC2 rejects a second S3 gateway endpoint on
# a route table that already routes to S3 (RouteAlreadyExists).
existing_endpoints = [
    endpoint
    for endpoint in client.describe_vpc_endpoints(
        Filters=[
            {'Name': 'vpc-id', 'Values': [vpc_id]},
            {'Name': 'service-name', 'Values': [s3_service_name]},
        ],
    )['VpcEndpoints']
    if endpoint.get('VpcEndpointType') == 'Gateway'
    and endpoint['State'].lower() in ('available', 'pending')
]

if existing_endpoints:
    print('vpc endpoint s3 gateway already exists: {}'.format(
        existing_endpoints[0]['VpcEndpointId']))
else:
    response = client.create_vpc_endpoint(
        VpcEndpointType='Gateway',
        VpcId=vpc_id,
        ServiceName=s3_service_name,
        RouteTableIds=[
            RouteTableId,
        ]
    )
    print('Create vpc endpoint s3 gateway successfully! ')

In [None]:
# Create a PERSISTENT_2 FSx for Lustre file system (1200 GiB SSD, 250 MB/s per TiB)
# in the subnet and security group chosen above. S3 is linked to it in the next
# cell through a data repository association rather than an ImportPath.
import boto3

# NOTE: `client` is rebound to an FSx client here; the following cells call FSx
# APIs through this name.
client = boto3.client('fsx')

lustre_config = {
    'DeploymentType': 'PERSISTENT_2',
    'PerUnitStorageThroughput': 250,
}
response = client.create_file_system(
    FileSystemType='LUSTRE',
    StorageCapacity=1200,
    StorageType='SSD',
    SubnetIds=[subnet_id],
    SecurityGroupIds=[security_group_id],
    LustreConfiguration=lustre_config,
)

FileSystemId = response['FileSystem']['FileSystemId']
print(f'FileSystemId: {FileSystemId}  ')

In [None]:
# Link /<prefix> on the file system to s3://<bucket>/<prefix> (where the MNIST
# data was uploaded) and auto-import new, changed and deleted S3 objects.
# Both paths derive from `prefix` so they stay in sync with the upload location.
response = client.create_data_repository_association(
    FileSystemId=FileSystemId,
    FileSystemPath='/{}'.format(prefix),
    DataRepositoryPath='s3://{}/{}'.format(bucket, prefix),
    BatchImportMetaDataOnCreate=True,
    S3={
        'AutoImportPolicy': {
            'Events': ['NEW', 'CHANGED', 'DELETED'],
        },
    },
)

AssociationId = response['Association']['AssociationId']

In [None]:
%%time
#check fsx/association is ready
MountName = ''
#AssociationId = 'dra-08e0b7db42944abc8'
while True:
    response = client.describe_file_systems(
        FileSystemIds=[
            FileSystemId,
        ],
        MaxResults=5
    ),
    print('FileSystem status is {}'.format(response[0]['FileSystems'][0]['Lifecycle']))
    if response[0]['FileSystems'][0]['Lifecycle']=='AVAILABLE':
        MountName=response[0]['FileSystems'][0]['LustreConfiguration']['MountName']
        break
    time.sleep(60)
    

print('FileSystemId: {} , MountName: {} '.format(FileSystemId,MountName))
while True:
    response = client.describe_data_repository_associations(
        AssociationIds=[
            AssociationId,
        ],
        MaxResults=5
    )
    print('Data_repository_associations status is {}'.format(response['Associations'][0]['Lifecycle']))
    if response['Associations'][0]['Lifecycle']=='AVAILABLE':
        break
    time.sleep(60)



In [None]:
%%time
# check describe_data_repository_tasks
while True:
    response = client.describe_data_repository_tasks(
        MaxResults=5
    )
    print('data_repository_tasks status is {}'.format(response['DataRepositoryTasks'][0]['Lifecycle']))
    if response['DataRepositoryTasks'][0]['Lifecycle']=='SUCCEEDED':
        break
    time.sleep(60)
    

In [None]:
# Tell the user the FSx data is in place and training can start.
ready_message = 'Data is ready, the training job can be started!'
print(ready_message)

## 分布式训练

～10分钟

In [None]:
# Resource IDs produced by the cells above. Only set them here if the kernel was
# restarted: uncomment and replace with the IDs printed earlier for YOUR account.
# The example values are hard-coded and presumably belong to the notebook
# author's account, so they would not work elsewhere.
# vpc_id='vpc-06a89deaa85410c41'
# subnet_id='subnet-0c510fde0aaf2d2b4'
# subnet_ids= ['subnet-040931094ccbb99ab', 'subnet-0c510fde0aaf2d2b4', 'subnet-0778ab9cab7ae8679', 'subnet-09044a012cf45536e', 'subnet-06a5e6659440bcc9a', 'subnet-0129c8b3f02da025a']
# security_group_id='sg-0f04c56bf9daf8d50'
# FileSystemId='fs-07af543dc6e864140'
# MountName = 'qkibbbev'

# Fail early with a clear message rather than launching a job with missing IDs.
missing = [
    name
    for name in ('vpc_id', 'subnet_id', 'subnet_ids', 'security_group_id', 'FileSystemId', 'MountName')
    if not globals().get(name)
]
if missing:
    raise NameError(
        'Run the FSx cells above or set these variables manually: {}'.format(', '.join(missing)))

In [None]:
# Print the training script with syntax highlighting for review.
!pygmentize mnist-ddp.py

In [None]:
# vpc_id = 'vpc-74bd990d'
# subnet_ids = ['subnet-03f6a221252c4e388', 'subnet-02071bfe4f24324cb', 'subnet-0a58f474d1bc358b0', 'subnet-0abd567068550fd73', 'subnet-00c5e11db0c2f4a09', 'subnet-0b47b661ec76273ca', 'subnet-03fc2c059d01d5487', 'subnet-0cc9c3f03f6c07b47', 'subnet-384dcd5c']  

# security_group_id = 'sg-0b89fbfb4f4e483b0'

# FileSystemId = 'fs-0ada02c1767ce8aa3' 
# MountName = 'ozgrpbev' 

In [None]:
# Expose the FSx directory that mirrors s3://<bucket>/<prefix> as a training input.
from sagemaker.inputs import FileSystemInput,TrainingInput

# Derive the directory from `prefix` so it matches the S3 upload prefix and the
# association's FileSystemPath; FSx for Lustre paths start with /<MountName>.
dataset_directory_path = "/{}/{}".format(MountName, prefix)
file_system_access_mode = "rw"
file_system_type = "FSxLustre"
dataset_fsx = FileSystemInput(
    file_system_id=FileSystemId,
    file_system_type=file_system_type,
    directory_path=dataset_directory_path,
    file_system_access_mode=file_system_access_mode,
)

In [None]:
# Patch mnist-ddp.py so training counts as distributed whenever it runs on GPUs
# with a backend set, or on more than one host. The target line is located by
# content (not a fixed index) and keeps its original indentation, so the patch
# fails loudly instead of overwriting an unrelated line if the script changes.
# `script_lines` is used (not `data`) so this cell does not clobber the `data`
# variable that the inference cells read.
with open('mnist-ddp.py', 'r') as file:
    script_lines = file.readlines()

target_indexes = [
    index for index, line in enumerate(script_lines)
    if line.lstrip().startswith('is_distributed =')
]
if len(target_indexes) != 1:
    raise RuntimeError(
        'Expected exactly one "is_distributed =" line in mnist-ddp.py, found {}'.format(
            len(target_indexes)))

target = target_indexes[0]
old_line = script_lines[target]
indent = old_line[:len(old_line) - len(old_line.lstrip())]
script_lines[target] = (
    indent
    + 'is_distributed = (args.num_gpus > 0 and args.backend is not None) or len(args.hosts) > 1\n'
)

with open('mnist-ddp.py', 'w') as file:
    file.writelines(script_lines)

In [None]:
# Configure the SageMaker PyTorch 1.5.0 training job, running inside the VPC and
# security group set up for FSx.
from sagemaker.pytorch import PyTorch

training_hyperparameters = {
    'epochs': 6,
    'backend': 'nccl',
}

estimator = PyTorch(
    entry_point='mnist-ddp.py',
    role=role,
    framework_version='1.5.0',
    py_version='py3',
    instance_count=1,
    instance_type='ml.p3.2xlarge',
    subnets=subnet_ids,
    security_group_ids=[security_group_id],
    # NOTE(review): the FSx data reaches the job through the FileSystemInput passed
    # to fit(); confirm whether the estimator actually uses this keyword.
    file_system_id=FileSystemId,
    hyperparameters=training_hyperparameters,
    # Reduce number of logs since we don't need profiler or debugger for this training.
    disable_profiler=True,
    debugger_hook_config=False,
)

In [None]:
%%time
# start training
job_name = "dji-{}".format(datetime.now().strftime("%Y-%m-%d-%H-%M-%S"))
estimator.fit(inputs={'training': dataset_fsx},job_name =job_name)

## 部署
～3分钟

In [None]:
%%time
# start deployment without autoscaling
predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.m5.2xlarge')

In [None]:
# Show the endpoint name; it is needed to re-attach to the endpoint later.
deployed_endpoint = predictor.endpoint_name
print(deployed_endpoint)

## 推理


In [None]:
from IPython.display import HTML

# Render input.html (presumably the drawing widget that populates `data` for the
# inference cells — confirm). The context manager closes the file handle.
with open("input.html", encoding="utf-8") as html_file:
    html_source = html_file.read()
HTML(html_source)

In [None]:
import numpy as np
import pandas as pd

# NOTE(review): `data` is assumed to hold the digit drawn in input.html — confirm.
# Wrapping it in a list builds a batch containing a single image.
image = np.array([data], dtype=np.float32)
response = predictor.predict(image)

# Highest-scoring class along axis 1, for the first (only) item in the batch.
prediction = response.argmax(axis=1)[0]
print(prediction)

## Only for an existing endpoint: create a new Predictor

In [None]:
# Attach to an endpoint that already exists instead of deploying again.
# Please change endpoint_name to the name of YOUR endpoint.
from sagemaker.predictor import Predictor
from sagemaker.serializers import NumpySerializer
from sagemaker.deserializers import NumpyDeserializer

endpoint_name = 'pytorch-training-2022-08-19-17-20-36-703'

# Send requests as NumPy arrays and decode responses back into NumPy arrays.
numpy_request_serializer = NumpySerializer()
numpy_response_deserializer = NumpyDeserializer()
predictor = Predictor(
    endpoint_name=endpoint_name,
    serializer=numpy_request_serializer,
    deserializer=numpy_response_deserializer,
)


In [None]:
import numpy as np

# NOTE(review): `data` is assumed to hold the digit drawn in input.html — confirm.
# Send it as a batch of one image and report the highest-scoring class.
image = np.array([data], dtype=np.float32)
response = predictor.predict(image)
prediction = response.argmax(axis=1)[0]
print(prediction)

### Cleanup

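After you have finished with this example, remember to delete the prediction endpoint to release the instance(s) associated with it. Also delete the FSx for Lustre file system (and its data repository association) created above, since it keeps incurring charges until it is deleted.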

In [None]:
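# To delete the endpoint, uncomment the line below. (In SageMaker Python SDK v2,
# endpoints are deleted through the predictor; Estimator.delete_endpoint() was removed.)
#predictor.delete_endpoint()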