In [None]:
!pip install nvidia-pyindex tritonclient[http]
!pip install numpy
!pip install transformers accelerate diffusers

In [None]:
# imports
import boto3
import sagemaker
import time
from sagemaker import get_execution_role

# --- general helpers -------------------------------------------------------
# UTC timestamp, later embedded in resource names.
ts = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
s3_client = boto3.client("s3")

# --- SageMaker handles -----------------------------------------------------
role = get_execution_role()
# NOTE(review): the role looked up above is immediately replaced by this
# hard-coded ARN, so the notebook only works for this specific account.
role = "arn:aws:iam::187838347205:role/service-role/AmazonSageMaker-ExecutionRole-20190212T141132"

sm_client = boto3.client("sagemaker")  # control plane (models, endpoints)
runtime_sm_client = boto3.client("sagemaker-runtime")  # used for invoke_endpoint

boto_session = boto3.Session()
sagemaker_session = sagemaker.Session(boto_session=boto_session)
# Default bucket of the session; `upload_data` writes here.
bucket = sagemaker_session.default_bucket()

In [None]:
import diffusers
import torch

# Hub identifier of the checkpoint and the local folder it is saved to.
sd_model_id = "CompVis/stable-diffusion-v1-4"
pipeline_dir = 'model_repo_simple/pipeline/stable_diff'

# Load the fp16 revision of the weights as float16 tensors, then write the
# whole pipeline to disk inside the model repository.
pipeline = diffusers.StableDiffusionPipeline.from_pretrained(
    sd_model_id,
    torch_dtype=torch.float16,
    revision="fp16",
)
pipeline.save_pretrained(pipeline_dir)

In [None]:
!bash conda_dependencies.sh
!mv sd_env.tar.gz model_repo_simple/pipeliness

### Package model

In [None]:
# Base name shared by all packaged model archives.
model_name_prefix = "stable_diff"
# Archive name of the base (un-numbered) model.
model_file_name = model_name_prefix + ".tar.gz"

In [None]:

# S3 key prefix under which the model archives are stored.
prefix = "stable-diffusion"

# Packaging and upload of the base model are currently disabled:
#!tar -C model_repo_simple/ -czf $model_file_name pipeline
#model_data_url = sagemaker_session.upload_data(path=model_file_name, key_prefix=prefix)

We will package `i` more models and send them to S3. However, Triton does not allow models in different model repositories to share the same name, so for every additional model we have to replicate `model_repo_simple`, rename its `pipeline` folder to `pipeline_{i}`, and change the model name in that copy's `config.pbtxt` to `pipeline_{i}` as well.

In [None]:
for i in range(1,2):
    !cp -r model_repo_simple/ model_repo_simple_"$i"/
    !mv model_repo_simple_"$i"/pipeline/ model_repo_simple_"$i"/pipeline_"$i"/

# !!!! At this point I manually changed the config.pbtxt for all the new pipelines


In [None]:
i= 26
!tar -C model_repo_simple_"$i"/ -czf "$model_name_prefix"_"$i".tar.gz pipeline_"$i"
model_file_name = f"{model_name_prefix}_{i}.tar.gz"
sagemaker_session.upload_data(path=model_file_name, key_prefix=prefix)

In [None]:
!aws s3 ls  s3://sagemaker-eu-west-1-187838347205/stable-diffusion/

In [None]:
# Account mapping for the SageMaker MME Triton image: region -> ECR account id.
account_id_map = {
    "us-east-1": "785573368785",
    "us-east-2": "007439368137",
    "us-west-1": "710691900526",
    "us-west-2": "301217895009",
    "eu-west-1": "802834080501",
    "eu-west-2": "205493899709",
    "eu-west-3": "254080097072",
    "eu-north-1": "601324751636",
    "eu-south-1": "966458181534",
    "eu-central-1": "746233611703",
    "ap-east-1": "110948597952",
    "ap-south-1": "763008648453",
    "ap-northeast-1": "941853720454",
    "ap-northeast-2": "151534178276",
    "ap-southeast-1": "324986816169",
    "ap-southeast-2": "355873309152",
    "cn-northwest-1": "474822919863",
    "cn-north-1": "472730292857",
    "sa-east-1": "756306329178",
    "ca-central-1": "464438896020",
    "me-south-1": "836785723513",
    "af-south-1": "774647643957",
}

region = boto3.Session().region_name
if region not in account_id_map:
    # Raise a real exception: `raise ("...")` with a plain string fails with
    # "TypeError: exceptions must derive from BaseException" and hides the cause.
    raise ValueError(f"UNSUPPORTED REGION: {region}")

# China regions use a different registry domain suffix.
base = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"
mme_triton_image_uri = (
    "{account_id}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:22.10-py3".format(
        account_id=account_id_map[region], region=region, base=base
    )
)
# NOTE(review): the URI computed above is replaced by this hard-coded custom
# image, which is the one actually used by the rest of the notebook.
mme_triton_image_uri = "187838347205.dkr.ecr.eu-west-1.amazonaws.com/mme-triton-custom:3"


In [None]:
prefix = "stable-diffusion"
# Point the multi-model endpoint at the location `upload_data` wrote to
# (the session's default bucket + prefix) rather than a hard-coded bucket.
model_data_url = f"s3://{bucket}/{prefix}/"
# Fresh UTC timestamp so the resources created below get unique names.
ts = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

# Container definition: "MultiModel" mode serves every archive under ModelDataUrl.
container = {"Image": mme_triton_image_uri, "ModelDataUrl": model_data_url, "Mode": "MultiModel"}

In [None]:
# Register the SageMaker model backed by the multi-model container.
sm_model_name = f"{prefix}-mdl-{ts}"

model_request = {
    "ModelName": sm_model_name,
    "ExecutionRoleArn": role,
    "PrimaryContainer": container,
}
create_model_response = sm_client.create_model(**model_request)

print("Model Arn: " + create_model_response["ModelArn"])

### VPC model

In [None]:
# Suffix appended to the model, endpoint-config and endpoint names.
vpc = ""
# vpc = "-vpc"

In [None]:

# Security group and subnets the model containers are attached to.
vpc_config = {
    "SecurityGroupIds": ["sg-51498f2e"],
    "Subnets": ["subnet-d9513191", "subnet-5da37107", "subnet-653aab03"],
}

# NOTE(review): with `vpc == ""` this name equals `sm_model_name`, the name
# already used by the non-VPC create_model call earlier in the notebook.
create_model_response = sm_client.create_model(
    ModelName=f"{sm_model_name}{vpc}",
    ExecutionRoleArn=role,
    Containers=[container],
    VpcConfig=vpc_config,
)

print("Model Arn: " + create_model_response["ModelArn"])

In [None]:
# Endpoint configuration: a single variant on one GPU instance.
endpoint_config_name = f"{prefix}-epc-{ts}{vpc}"
instance_type = 'ml.g5.2xlarge'

all_traffic_variant = {
    "InstanceType": instance_type,
    "InitialVariantWeight": 1,
    "InitialInstanceCount": 1,
    "ModelName": f"{sm_model_name}{vpc}",
    "VariantName": "AllTraffic",
}

create_endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[all_traffic_variant],
)

print("Endpoint Config Arn: " + create_endpoint_config_response["EndpointConfigArn"])

In [None]:
# Create the endpoint from the configuration defined above.
endpoint_name = f'{prefix}-ep-{ts}{vpc}'

endpoint_request = {
    "EndpointName": endpoint_name,
    "EndpointConfigName": endpoint_config_name,
}
create_endpoint_response = sm_client.create_endpoint(**endpoint_request)

print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])

In [None]:
# Poll the endpoint once a minute until it leaves the "Creating" state,
# printing the status after every poll.
while True:
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)
    if status != "Creating":
        break
    time.sleep(60)

# Final summary of the endpoint.
print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

## Query the endpoint

In [None]:
! pip install tritonclient

In [None]:
import numpy as np
import tritonclient.http as httpclient
from PIL import Image
# Import only the helper that is used instead of `import *`, so it is clear
# where the name comes from and the namespace is not polluted.
from tritonclient.utils import np_to_triton_dtype

# Text prompt sent to the model; fill in before invoking the endpoint.
prompt = ""

# Wrap the prompt in an object-dtype array reshaped to a single column.
text_obj = np.array([prompt], dtype="object").reshape((-1, 1))

# Input tensor named "prompt", with shape and dtype taken from the array.
prompt_input = httpclient.InferInput(
    "prompt",
    text_obj.shape,
    np_to_triton_dtype(text_obj.dtype),
)
prompt_input.set_data_from_numpy(text_obj)

inputs = [prompt_input]
# Request the output tensor named "generated_image".
outputs = [httpclient.InferRequestedOutput("generated_image")]


In [None]:
# Serialize the inputs/outputs into a request body; the second value returned
# is the header length that is later passed along in the content type.
serialized_request = httpclient.InferenceServerClient.generate_request_body(
    inputs,
    outputs=outputs,
)
request_body, header_length = serialized_request

print(request_body)

In [None]:
# Change `model_file_name` in the next cell to send the request to a different model archive.

In [None]:
# Archive name passed as TargetModel when invoking the endpoint.
model_file_name = "stable_diff_26.tar.gz"

In [None]:
# Endpoint to query. This hard-coded name replaces any `endpoint_name`
# assigned earlier in the session.
# endpoint_name = "stable-diffusion-ep-2023-02-09-10-07-00"
endpoint_name = "stable-diffusion-ep-2023-03-02-16-23-13-vpc"

In [None]:
# The content type carries the JSON header length of the request body.
request_content_type = (
    "application/vnd.sagemaker-triton.binary+json;json-header-size={}".format(header_length)
)

# TargetModel selects which model archive serves this request.
response = runtime_sm_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType=request_content_type,
    Body=request_body,
    TargetModel=model_file_name,
)

In [None]:
# The response content type ends with the JSON header size; strip the known
# prefix to recover it.
header_length_prefix = "application/vnd.sagemaker-triton.binary+json;json-header-size="
header_length_str = response["ContentType"][len(header_length_prefix):]

# Read the raw payload and parse it using the recovered header size.
raw_body = response["Body"].read()
result = httpclient.InferenceServerClient.parse_response_body(
    raw_body,
    header_length=int(header_length_str),
)

# Extract the output tensor, drop singleton dimensions and build a PIL image.
image_array = result.as_numpy('generated_image')
squeezed_array = np.squeeze(image_array)
image = Image.fromarray(squeezed_array)

In [None]:
# Show the generated image as the cell output.
image