# Hunyuan3D-2 on SageMaker endpoint

## 0. Import SDK

In [None]:
import sagemaker
import boto3
import os
import time
from datetime import datetime
import json

# Shared handles used by every later cell in this notebook.
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts

# Use the public attribute; `_region_name` is a private implementation detail
# of the SDK's Session object and may change between releases.
region = sess.boto_region_name

sagemaker_client = boto3.client("sagemaker")

In [None]:
# One generated name is reused throughout the notebook: as the local working
# directory, the S3 key prefix, and the SageMaker model, endpoint-config and
# endpoint names.
endpoint_name = sagemaker.utils.name_from_base("hunyuan3d-2", short=True)

### Set container image

For the full list of available large-model-inference (LMI) containers, see: https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers

In [None]:
# Pick the DJL inference image for the current region. The AWS China regions
# use a different registry account and the amazonaws.com.cn domain, so they
# need their own image URI.
if region in ['cn-north-1', 'cn-northwest-1']:
    inference_image_uri = f"727897471807.dkr.ecr.{region}.amazonaws.com.cn/djl-inference:0.31.0-lmi13.0.0-cu124"
else:
    inference_image_uri = f"763104351884.dkr.ecr.{region}.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124"

print(f"Using container: {inference_image_uri}")

## 1. Prepare inference code

In [None]:
# Stage everything that goes into the model archive under a local directory
# named after the endpoint.
local_model_path = endpoint_name
!mkdir -p {local_model_path}
# Fetch the Hunyuan3D-2 sources into the staging directory.
!git clone https://github.com/Tencent/Hunyuan3D-2.git {local_model_path}/Hunyuan3D-2
# Drop the git metadata so it does not end up in the packaged archive.
!rm -rf {local_model_path}/Hunyuan3D-2/.git

In [None]:
%%writefile {local_model_path}/setup.sh
# Environment bootstrap for the inference container; model.py runs this
# script with bash when it is imported.

# Work from the script's own directory so the relative Hunyuan3D-2 path resolves.
cd `dirname $0`

# copy to /temp since /opt/ml/model is a read only filesystem
mkdir -p /temp/ && cp -a Hunyuan3D-2 /temp/Hunyuan3D-2
# Install the project's requirements, then the project itself and the two
# packages under hy3dgen/texgen in editable mode from the writable copy.
pip install -r /temp/Hunyuan3D-2/requirements.txt
pip install -e /temp/Hunyuan3D-2/
pip install -e /temp/Hunyuan3D-2/hy3dgen/texgen/custom_rasterizer
pip install -e /temp/Hunyuan3D-2/hy3dgen/texgen/differentiable_renderer

# install other deps
# NOTE(review): the chmod presumably restores sticky world-writable /tmp so
# apt can create its temp files; libgl1-mesa-glx presumably supplies libGL for
# an imaging dependency -- confirm both.
chmod 1777 /tmp && apt update && apt install -y libgl1-mesa-glx

In [None]:
%%writefile {local_model_path}/model.py
import os
import subprocess

# Run the environment bootstrap script that ships next to this file before
# importing anything it installs. The command is passed as an argument list
# (no shell string), so a model directory containing spaces or shell
# metacharacters cannot break or alter the invocation.
current_dir = os.path.dirname(os.path.abspath(__file__))
setup_result = subprocess.run(["bash", os.path.join(current_dir, "setup.sh")])
if setup_result.returncode != 0:
    # Still best effort (execution continues, as before), but leave a trace
    # in the logs so a later ImportError can be tied back to the setup step.
    print(f"setup.sh exited with status {setup_result.returncode}; later imports may fail")

# Re-run site initialisation; presumably needed so the packages that were
# just pip-installed in editable mode become importable in this
# already-running interpreter -- confirm.
import site
site.main()

from hy3dgen.shapegen import Hunyuan3DDiTFlowMatchingPipeline
from hy3dgen.texgen import Hunyuan3DPaintPipeline
from hy3dgen.rembg import BackgroundRemover
from hy3dgen.shapegen import FaceReducer, FloaterRemover, DegenerateFaceRemover
import base64
from io import BytesIO
from PIL import Image

class EndpointHandler:
    """Turn a base64-encoded image into a cleaned-up 3D mesh.

    The instance holds the shape-generation pipeline and a background
    remover, both created once in the constructor.
    """

    def __init__(self, path=''):
        # `path` is accepted but not used by this implementation.
        self.shape_pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained('tencent/Hunyuan3D-2')
        self.rmbg_worker = BackgroundRemover()

    def __call__(self, data):
        """Generate a mesh for one request.

        Args:
            data: Request payload. If it has an "inputs" key, that value is
                used as the options mapping (and removed from ``data``);
                otherwise ``data`` itself is used. The mapping must contain
                'image' (base64-encoded image bytes) and may override
                'num_inference_steps', 'guidance_scale' and
                'octree_resolution'.

        Returns:
            A dict with 'vertices' and 'faces', each converted to nested
            lists via ``tolist()``.
        """
        inputs = data.pop("inputs", data)

        # Defaults for the shape pipeline; any option present in the request
        # replaces the default value.
        generation_kwargs = {
            "num_inference_steps": 50,
            "guidance_scale": 7.5,
            "octree_resolution": 384,
        }
        for option in generation_kwargs:
            if option in inputs:
                generation_kwargs[option] = inputs[option]

        # Decode the image, force RGB, then strip the background.
        image_bytes = base64.b64decode(inputs['image'])
        picture = Image.open(BytesIO(image_bytes))
        picture = self.rmbg_worker(picture.convert('RGB'))

        mesh = self.shape_pipeline(picture, **generation_kwargs)[0]

        # Post-process in a fixed order: floaters, degenerate faces, then
        # face reduction.
        for cleanup_step in (FloaterRemover, DegenerateFaceRemover, FaceReducer):
            mesh = cleanup_step()(mesh)

        return {
            'vertices': mesh.vertices.tolist(),
            'faces': mesh.faces.tolist(),
        }


from djl_python.outputs import Output


model = None

def handle(inputs):
    """Serving entry point: lazily build the model, then run one request.

    Args:
        inputs: Request object exposing ``get_properties()``, ``is_empty()``
            and ``get_as_json()``.

    Returns:
        ``None`` when the request is empty (the model is still initialised
        on that call), otherwise an ``Output`` carrying the JSON-encoded
        mesh returned by the model.
    """
    global model

    # Build the handler once and keep it in the module-level `model`.
    # Compare against the None sentinel explicitly instead of relying on
    # the object's truthiness.
    if model is None:
        print("Init Model ... ")
        model = EndpointHandler(inputs.get_properties())

    # An empty request only triggers initialisation; there is nothing to infer.
    if inputs.is_empty():
        return None

    data = inputs.get_as_json()
    result = model(data)
    return Output().add_as_json(result)

In [None]:
%%writefile {local_model_path}/serving.properties
engine=Python

### Upload to S3

In [None]:
!rm model.tar.gz
!cd {local_model_path} && rm -rf ".ipynb_checkpoints"
!tar czf model.tar.gz {local_model_path}/
!ls -lh model.tar.gz

In [None]:
# Upload the packaged archive under a per-endpoint key prefix in the
# session's default bucket; the returned S3 URI is used when creating the model.
s3_model_path = "endpoint-models/" + endpoint_name
s3_code_artifact = sess.upload_data(
    "model.tar.gz",
    bucket=bucket,
    key_prefix=s3_model_path,
)
print("S3 Code or Model tar ball uploaded to --- > " + s3_code_artifact)

## 2. Create model and endpoint on SageMaker

In [None]:
# Step 0. create model

# The container definition pairs the inference image with the uploaded
# code archive.
primary_container = {
    "Image": inference_image_uri,
    "ModelDataUrl": s3_code_artifact,
}

create_model_response = sagemaker_client.create_model(
    ModelName=endpoint_name,
    ExecutionRoleArn=role,
    PrimaryContainer=primary_container,
)
print(create_model_response)
print("endpoint_model_name:", endpoint_name)

In [None]:
# Step 1. create endpoint config

# A single production variant serving the model created above.
production_variant = {
    "VariantName": "variant1",
    "ModelName": endpoint_name,
    "InstanceType": "ml.g6e.2xlarge",
    "InitialInstanceCount": 1,
    "ContainerStartupHealthCheckTimeoutInSeconds": 600,
    # "EnableSSMAccess": True,
}

endpoint_config_response = sagemaker_client.create_endpoint_config(
    EndpointConfigName=endpoint_name,
    ProductionVariants=[production_variant],
)
print(endpoint_config_response)
print("endpoint_config_name:", endpoint_name)

In [None]:
# Step 2. create endpoint

create_endpoint_response = sagemaker_client.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_name
)
print(create_endpoint_response)
print("endpoint_name:", endpoint_name)

# Poll once a minute until the endpoint leaves the "Creating" state.
while True:
    endpoint_description = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
    status = endpoint_description["EndpointStatus"]
    if status != "Creating":
        break
    print(datetime.now().strftime('%Y%m%d-%H:%M:%S') + " status: " + status)
    time.sleep(60)
print("Endpoint:", endpoint_name, status)

# Surface the failure reason right away instead of leaving the reader with
# only a bare non-InService status.
if status != "InService":
    print("FailureReason:", endpoint_description.get("FailureReason"))

## 3. Test

In [None]:
sagemaker_runtime = boto3.client('runtime.sagemaker')

import base64

# Read the sample image through a context manager so the file handle is
# closed as soon as the bytes are encoded. The handler in model.py also
# accepts optional "num_inference_steps", "guidance_scale" and
# "octree_resolution" keys next to "image".
with open("demo.png", "rb") as image_file:
    payload = {
        "image": base64.b64encode(image_file.read()).decode()
    }

response = sagemaker_runtime.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType='application/json',
    Body=json.dumps(payload)
)

result = json.loads(response['Body'].read())

# Show only the first few entries; the full mesh can be large.
print(result['vertices'][:10])
print(result['faces'][:10])