### 1. 升级boto3, sagemaker python sdk

In [None]:
!pip install --upgrade boto3 sagemaker

In [1]:
#导入对应的库

import re
import os
import json
import uuid

import numpy as np
import pandas as pd
from time import gmtime, strftime


import boto3
import sagemaker

from sagemaker import get_execution_role,session
from sagemaker import Model, image_uris, serializers, deserializers

# Execution role ARN; later passed as ExecutionRoleArn when the model is created.
role = get_execution_role()
# SageMaker session and its default S3 bucket; the bucket is used below for
# the model archive and is handed to the container as an environment variable.
sage_session = session.Session()
bucket = sage_session.default_bucket()
# Region configured for the default boto3 session.
aws_region = boto3.Session().region_name

# Echo the resolved SDK version, role and bucket so a misconfiguration is visible early.
print(f'sagemaker sdk version: {sagemaker.__version__}\nrole:  {role}  \nbucket:  {bucket}')

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml
sagemaker sdk version: 2.203.1
role:  arn:aws:iam::687912291502:role/service-role/AmazonSageMaker-ExecutionRole-20211013T113123  
bucket:  sagemaker-us-west-2-687912291502


### 2. 编译docker image (comfyui-inference)

In [None]:
# Log Docker in to the ECR registry of account 763104351884
# (presumably the registry hosting the base image used by the build -- confirm in the Dockerfile).
# The region is hard-coded to us-west-2 in BOTH the --region flag and the registry host name;
# change both occurrences if you work in a different region.
!aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com

In [None]:
# Run the local build script; presumably it builds the comfyui-inference image and
# pushes it to ECR -- see build_and_push.sh to confirm what it does.
!./build_and_push.sh

### 3. 部署Comfyui推理服务

#### 3.1 创建dummy model_data 文件(真正的模型使用code/infernece.py进行加载)

In [None]:
!touch dummy
!tar czvf model.tar.gz dummy
assets_dir = 's3://{0}/{1}/assets/'.format(bucket, 'stablediffusion')
model_data = 's3://{0}/{1}/assets/model.tar.gz'.format(bucket, 'stablediffusion')
!aws s3 cp model.tar.gz $assets_dir
!rm -f dummy model.tar.gz

In [None]:
!aws s3 ls s3://sagemaker-us-west-2-687912291502/stablediffusion/assets/model.tar.gz

#### 3.2 创建 model 配置

In [48]:

# Resolve the region and account the notebook is running in.
boto3_session = boto3.session.Session()
current_region = boto3_session.region_name

# Dedicated STS client so that `client` only ever refers to the SageMaker client.
sts_client = boto3.client("sts")
account_id = sts_client.get_caller_identity()["Account"]

# SageMaker control-plane client, reused by the following cells.
client = boto3.client('sagemaker')

# Use the docker image built in step 2.
# The repository referenced here is `comfyui-inference` with the `latest` tag.
container = f'{account_id}.dkr.ecr.{current_region}.amazonaws.com/comfyui-inference:latest'

model_data = f's3://{bucket}/stablediffusion/assets/model.tar.gz'

model_name = 'AIGC-ComfyUI' + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
role = get_execution_role()

primary_container = {
    'Image': container,
    'ModelDataUrl': model_data,
    'Environment': {
        's3_bucket': bucket,
        # Derived from `bucket` so the path points into the current account's
        # bucket instead of a hard-coded bucket owned by another account.
        'MODEL_PATH': f"s3://{bucket}/models/svd/",
    },
}

create_model_response = client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer=primary_container,
)

In [49]:
# A single timestamp so the variant and the endpoint config carry the same suffix.
_time_tag = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
_variant_name = f'AIGC-ComfyUI-{_time_tag}'
endpoint_config_name = f'AIGC-ComfyUI-{_time_tag}'

response = client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        dict(
            VariantName=_variant_name,
            ModelName=model_name,
            InitialInstanceCount=1,
            # VolumeSizeInGB=300,
            InstanceType='ml.g5.4xlarge',
            InitialVariantWeight=1,
            # Timeout, in seconds, for downloading the model data.
            ModelDataDownloadTimeoutInSeconds=800,
            # Timeout, in seconds, for the container start-up health check.
            ContainerStartupHealthCheckTimeoutInSeconds=800,
        ),
    ],
)

#### 3.3 部署SageMaker endpoint (这里只需要运行一次!!!)

In [50]:
# A random UUID suffix keeps the endpoint name unique for every run of this cell.
endpoint_name = 'AIGC-ComfyUI-' + str(uuid.uuid4())

print(f'终端节点:{endpoint_name} 正在创建中，首次启动中会加载模型，请耐心等待, 请在控制台上查看状态')

# Start endpoint creation from the endpoint config defined in the previous cell.
response = client.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name)



终端节点:AIGC-ComfyUI-e3526a3a-9074-4edf-8248-078486a3d26b 正在创建中，首次启动中会加载模型，请耐心等待, 请在控制台上查看状态


* 检查endpoint 状态

In [51]:
import time
def check_sendpoint_status(endpoint_name,timeout=600):
    """Poll a SageMaker endpoint until it reports ``InService``.

    Args:
        endpoint_name: Name of the endpoint to watch.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        The final endpoint status string (``'InService'``).

    Raises:
        RuntimeError: If the endpoint reports the ``Failed`` status.
        TimeoutError: If the endpoint is not ``InService`` within ``timeout`` seconds.
    """
    poll_interval = 10  # seconds between two describe_endpoint calls
    client = boto3.client('sagemaker')
    # Track elapsed time with a monotonic deadline so the loop is guaranteed to stop.
    deadline = time.monotonic() + timeout
    while True:
        response = None
        try:
            response = client.describe_endpoint(EndpointName=endpoint_name)
        except Exception as ex:
            # Best effort: a failed status call is reported and retried until the deadline.
            print(f'{endpoint_name} status check failed: {ex}')
        status = response['EndpointStatus'] if response else None
        if status == 'InService':
            print(f'{endpoint_name} is ready, status: {status}')
            return status
        if status == 'Failed':
            # Waiting longer cannot help once the deployment has failed.
            reason = response.get('FailureReason', 'unknown reason')
            raise RuntimeError(f'{endpoint_name} failed to deploy: {reason}')
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f'{endpoint_name} not InService after {timeout} seconds (last status: {status})'
            )
        print(f'{endpoint_name} not ready , please wait....')
        time.sleep(poll_interval)

# Wait here until the endpoint created above reports InService.
check_sendpoint_status(endpoint_name)

AIGC-ComfyUI-e3526a3a-9074-4edf-8248-078486a3d26b is ready, status: InService


### Alternative deploy API (SSH debug inference)

In [None]:
# Import the wrapper in this cell: it is referenced below, so relying on an
# import from a later cell raises NameError on a fresh kernel.
from sagemaker_ssh_helper.wrapper import SSHModelWrapper

# Bundle the SSH helper as a dependency of the model so the container can start the SSH agent.
model = Model(
    image_uri=container,
    model_data=model_data,
    role=role,
    dependencies=[SSHModelWrapper.dependency_dir()],
)

In [None]:
from sagemaker_ssh_helper.wrapper import SSHModelWrapper

instance_type = "ml.g5.2xlarge"
endpoint_name = sagemaker.utils.name_from_base("comfyui-byoc")

# Wrap the model BEFORE deploying it: the wrapper has to configure the model
# for SSH access before the endpoint container is created.
ssh_wrapper = SSHModelWrapper.create(model, connection_wait_time_seconds=0)

model.deploy(initial_instance_count=1,
             instance_type=instance_type,
             endpoint_name=endpoint_name,
             container_startup_health_check_timeout=800
            )

# Requests are sent as JSON, so only the serializer is configured here;
# responses use the SDK's default deserializer.
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sage_session,
    serializer=serializers.JSONSerializer(),
)
# This is an inference endpoint, so the inference CLI and the endpoint name are
# the right connection target (not the training CLI / a training job name).
print(f"To connect over SSH run: sm-local-ssh-inference connect {endpoint_name}")
instance_ids = ssh_wrapper.get_instance_ids(timeout_in_sec=900)
print(f"To connect over SSM run: aws ssm start-session --target {instance_ids[0]}")

### 4. 测试

#### 4.1 创建测试辅助方法 

In [4]:
import time
import datetime
import uuid
import io
import traceback
from PIL import Image
import boto3
import json


# Shared S3 client used by the display helpers below to fetch generated artifacts.
s3_client = boto3.client('s3')

def get_bucket_and_key(s3uri):
    """Split an ``s3://bucket/key`` URI into its bucket and key parts.

    Args:
        s3uri: URI of the form ``s3://<bucket>/<key>``. The key may be empty or absent.

    Returns:
        A ``(bucket, key)`` tuple; ``key`` is ``''`` when the URI names only a bucket.

    Raises:
        ValueError: If ``s3uri`` does not start with ``s3://``.
    """
    scheme = 's3://'
    if not s3uri.startswith(scheme):
        raise ValueError(f'not an S3 URI: {s3uri!r}')
    # partition() copes with a missing '/' (bucket-only URI), where find()
    # would return -1 and lead to a truncated bucket name and a bogus key.
    bucket, _, key = s3uri[len(scheme):].partition('/')
    return bucket, key


def predict(endpoint_name,payload):
    """Send ``payload`` as JSON to a SageMaker endpoint and return the decoded JSON reply.

    Both the raw boto3 response and the decoded result are printed.
    """
    raw_response = boto3.client('runtime.sagemaker').invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='application/json',
        Body=json.dumps(payload),
    )
    print(raw_response)
    # The body is a stream: read it once, decode it, then parse the JSON text.
    body_text = raw_response['Body'].read().decode()
    decoded = json.loads(body_text)
    print(decoded)
    return decoded


def show_image(result):
    """Download every image listed in ``result['result']`` from S3 and show it at half size.

    Args:
        result: Either a JSON string or an already-decoded dict with a
            ``'result'`` list of ``s3://`` URIs.

    Any failure is reported on stdout rather than raised, so the helper can be
    called repeatedly while a job is still running.
    """
    try:
        # Accept a raw JSON payload as well as an already-parsed dict.
        if isinstance(result, (str, bytes, bytearray)):
            predictions = json.loads(result)
        else:
            predictions = result
        print(predictions['result'])
        for image_uri in predictions['result']:
            bucket, key = get_bucket_and_key(image_uri)
            # s3_client is a boto3 *client*; objects are fetched with get_object.
            image_bytes = s3_client.get_object(Bucket=bucket, Key=key)['Body'].read()
            image = Image.open(io.BytesIO(image_bytes))
            # Resize the image to 50% of its size.
            half = 0.5
            out_image = image.resize(tuple(int(half * s) for s in image.size))
            out_image.show()
    except Exception as e:
        # Surface the actual error so real failures are not mistaken for "still running".
        print(e)
        print("result is not completed, waiting...")


def show_gifs(result):
    """Download the first GIF listed in ``result['prediction']`` and return it as inline HTML.

    The file is saved locally as ``./ComfyUI_<timestamp>.gif``. On any error the
    exception is printed and ``None`` is returned.
    """
    import base64
    from IPython import display
    try:
        s3_file_path = result['prediction'][0]
        print("s3 generated gifs path is {}".format(s3_file_path))
        bucket_name, key = get_bucket_and_key(s3_file_path)
        stamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
        local_file_path = "./ComfyUI_" + stamp + ".gif"
        s3_client.download_file(bucket_name, key, local_file_path)
        # Embed the GIF as a base64 data URI in an <img> tag.
        with open(local_file_path, 'rb') as gif_file:
            encoded = base64.b64encode(gif_file.read()).decode('ascii')
        return display.HTML(f'<img src="data:image/gif;base64,{encoded}" />')
    except Exception as e:
        print(e)
        print("result is not completed, waiting...")


def check_sendpoint_status(endpoint_name,timeout=600):
    """Poll a SageMaker endpoint until it reports ``InService``.

    Args:
        endpoint_name: Name of the endpoint to watch.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        The final endpoint status string (``'InService'``).

    Raises:
        RuntimeError: If the endpoint reports the ``Failed`` status.
        TimeoutError: If the endpoint is not ``InService`` within ``timeout`` seconds.
    """
    poll_interval = 10  # seconds between two describe_endpoint calls
    client = boto3.client('sagemaker')
    # Track elapsed time with a monotonic deadline so the loop is guaranteed to stop.
    deadline = time.monotonic() + timeout
    while True:
        response = None
        try:
            response = client.describe_endpoint(EndpointName=endpoint_name)
        except Exception as ex:
            # Best effort: a failed status call is reported and retried until the deadline.
            print(f'{endpoint_name} status check failed: {ex}')
        status = response['EndpointStatus'] if response else None
        if status == 'InService':
            print(f'{endpoint_name} is ready, status: {status}')
            return status
        if status == 'Failed':
            # Waiting longer cannot help once the deployment has failed.
            reason = response.get('FailureReason', 'unknown reason')
            raise RuntimeError(f'{endpoint_name} failed to deploy: {reason}')
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f'{endpoint_name} not InService after {timeout} seconds (last status: {status})'
            )
        print(f'{endpoint_name} not ready , please wait....')
        time.sleep(poll_interval)

#### 4.2 测试ComfyUI 文生视频

prompt从json文件加载

In [5]:
# Endpoint to run the tests against.
# NOTE(review): this hard-coded name overrides any endpoint_name set by earlier cells;
# replace it with the name of your own endpoint.
endpoint_name="AIGC-ComfyUI-e3526a3a-9074-4edf-8248-078486a3d26b"

In [10]:
# Load the workflow exported in API format; its parsed JSON is sent as the prompt.
prompt_json_file = "./workflow_api.json"
prompt_text = ""
with open(prompt_json_file) as workflow_file:
    prompt_text = json.load(workflow_file)

# A fresh client id identifies this notebook session in the requests below.
client_id = str(uuid.uuid4())
payload = dict(
    client_id=client_id,
    prompt=prompt_text,
    inference_type="text2vid",
    method="queue_prompt",
)
# Queue the prompt and keep its id for the follow-up status and result calls.
prompt_id = predict(endpoint_name, payload)["prompt_id"]


{'ResponseMetadata': {'RequestId': 'c0010815-f3d7-4518-8338-f678d5eca24f', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': 'c0010815-f3d7-4518-8338-f678d5eca24f', 'x-amzn-invoked-production-variant': 'AIGC-ComfyUI-2024-01-23-07-50-37', 'date': 'Wed, 24 Jan 2024 12:56:29 GMT', 'content-type': 'application/json; charset=utf-8', 'content-length': '53', 'connection': 'keep-alive'}, 'RetryAttempts': 0}, 'ContentType': 'application/json; charset=utf-8', 'InvokedProductionVariant': 'AIGC-ComfyUI-2024-01-23-07-50-37', 'Body': <botocore.response.StreamingBody object at 0x7f79a4b85330>}
{'prompt_id': '8257e432-10a5-4822-a46e-26d428152b2c'}


In [11]:
payload={
     "client_id":client_id,
     "prompt_id":prompt_id,
     "inference_type":"text2vid",
     "method":"get_status"
}
# Poll until the queued prompt reports success, but never wait forever:
# a stuck or failed job must not hang the notebook.
POLL_INTERVAL_SECONDS = 10
MAX_WAIT_SECONDS = 1800
poll_deadline = time.monotonic() + MAX_WAIT_SECONDS
while True:
    status = predict(endpoint_name,payload)
    # Check first, sleep afterwards, so a finished job is picked up immediately.
    if status["status"] == "success":
        break
    if time.monotonic() >= poll_deadline:
        raise TimeoutError(
            f"prompt {prompt_id} did not finish within {MAX_WAIT_SECONDS} seconds "
            f"(last status: {status['status']})"
        )
    time.sleep(POLL_INTERVAL_SECONDS)

{'ResponseMetadata': {'RequestId': '20930e62-bc51-4bc2-aaa0-0596cce3b18d', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '20930e62-bc51-4bc2-aaa0-0596cce3b18d', 'x-amzn-invoked-production-variant': 'AIGC-ComfyUI-2024-01-23-07-50-37', 'date': 'Wed, 24 Jan 2024 12:56:31 GMT', 'content-type': 'application/json; charset=utf-8', 'content-length': '23', 'connection': 'keep-alive'}, 'RetryAttempts': 0}, 'ContentType': 'application/json; charset=utf-8', 'InvokedProductionVariant': 'AIGC-ComfyUI-2024-01-23-07-50-37', 'Body': <botocore.response.StreamingBody object at 0x7f79a7206260>}
{'status': 'executing'}
{'ResponseMetadata': {'RequestId': '2c55efed-e3f0-4b34-a4e2-b01492570502', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '2c55efed-e3f0-4b34-a4e2-b01492570502', 'x-amzn-invoked-production-variant': 'AIGC-ComfyUI-2024-01-23-07-50-37', 'date': 'Wed, 24 Jan 2024 12:56:41 GMT', 'content-type': 'application/json; charset=utf-8', 'content-length': '23', 'connection': 'kee

In [12]:
# Ask the endpoint for the artifacts produced by the finished prompt.
payload = dict(
    client_id=client_id,
    prompt_id=prompt_id,
    inference_type="text2vid",
    method="get_images",
)
result = predict(endpoint_name, payload)

{'ResponseMetadata': {'RequestId': 'b9aebf40-c101-4b35-86ec-9aa1534e69c4', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': 'b9aebf40-c101-4b35-86ec-9aa1534e69c4', 'x-amzn-invoked-production-variant': 'AIGC-ComfyUI-2024-01-23-07-50-37', 'date': 'Wed, 24 Jan 2024 13:00:43 GMT', 'content-type': 'application/json; charset=utf-8', 'content-length': '95', 'connection': 'keep-alive'}, 'RetryAttempts': 0}, 'ContentType': 'application/json; charset=utf-8', 'InvokedProductionVariant': 'AIGC-ComfyUI-2024-01-23-07-50-37', 'Body': <botocore.response.StreamingBody object at 0x7f79a7206080>}
{'prediction': ['s3://sagemaker-us-west-2-687912291502/comfyui_output/images/Comfyui_110.gif']}


In [None]:
# Download the generated GIF referenced by `result` and render it inline
# (the returned HTML object is displayed because it is the cell's last expression).
show_gifs(result)

s3 generated gifs path is s3://sagemaker-us-west-2-687912291502/comfyui_output/images/Comfyui_110.gif


### 5. 清除资源

In [None]:
# Tear down everything created in section 3, in reverse order of creation:
# endpoint first, then its config, then the model.
response = client.delete_endpoint(
    EndpointName=endpoint_name
)

response = client.delete_endpoint_config(
    EndpointConfigName=endpoint_config_name
)

# Also remove the SageMaker model created in step 3.2 so nothing is left behind.
response = client.delete_model(
    ModelName=model_name
)

print(f'终端节点:{endpoint_name} 已经被清除，请在控制台上查看状态')
