# GPT-SoVITS on Amazon SageMaker

## build image

In [None]:
# Make the helper scripts executable and run build_and_push.sh, which presumably builds the
# inference container and pushes it to the ECR repository referenced by `full_image_uri` below.
!chmod +x ./*.sh && ./build_and_push.sh 

In [None]:
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
# `boto_region_name` is the public accessor; `_region_name` is a private attribute
# that the SDK may rename or remove without notice.
region = sess.boto_region_name  # region name of the current SageMaker Studio environment
account_id = sess.account_id()  # account_id of the current SageMaker Studio environment
bucket = sess.default_bucket()  # session default bucket, used for model data and audio I/O
image = "gpt-sovits-inference"  # image/repository name used to build the ECR URI below

# Pin every client to the session's region so S3, SageMaker and the runtime API
# all talk to the same region the image URI is built for.
s3_client = boto3.client("s3", region_name=region)
sm_client = boto3.client("sagemaker", region_name=region)
smr_client = boto3.client("sagemaker-runtime", region_name=region)

full_image_uri = f"{account_id}.dkr.ecr.{region}.amazonaws.com/{image}:latest"
print(full_image_uri)


## remote debug test

In [None]:
## empty model data for byoc with webserver
!touch dummy
!tar czvf model.tar.gz dummy
assets_dir = 's3://{0}/{1}/assets/'.format(bucket, 'gpt_sovits')
model_data = 's3://{0}/{1}/assets/model.tar.gz'.format(bucket, 'gpt_sovits')
!aws s3 cp model.tar.gz $assets_dir
!rm -f dummy model.tar.gz

In [None]:
from sagemaker_ssh_helper.wrapper import SSHModelWrapper
# Model for the remote-debug deployment; the extra dependency directory comes from
# sagemaker-ssh-helper (presumably what SSHModelWrapper needs inside the container).
model = Model(
    image_uri=full_image_uri,
    model_data=model_data,
    role=role,
    dependencies=[SSHModelWrapper.dependency_dir()],
)

In [None]:
from sagemaker_ssh_helper.wrapper import SSHModelWrapper
endpoint_name = sagemaker.utils.name_from_base("gpt-sovits-inference")
instance_type = "ml.g5.xlarge"

# Wrap the model with sagemaker-ssh-helper before deploying so the endpoint instance
# can be reached for remote debugging (connection_wait_time_seconds=0).
ssh_wrapper = SSHModelWrapper.create(model, connection_wait_time_seconds=0)

# Kick off endpoint creation without blocking (wait=False); the next cells look up the
# instance IDs once the endpoint instance has registered.
predictor = model.deploy(
    wait=False,
    endpoint_name=endpoint_name,
    instance_type=instance_type,
    initial_instance_count=1,
)

In [None]:
# The endpoint was deployed with wait=False, so give the endpoint instance time to
# register (up to 900 s, the value used in the sagemaker-ssh-helper example) instead of
# timeout_in_sec=0, which may return before any instance exists and yield an empty list.
instance_ids = ssh_wrapper.get_instance_ids(timeout_in_sec=900)

In [None]:
# Fail with an explicit message, rather than a bare IndexError, when no instance has registered yet.
assert instance_ids, "no instance registered yet; re-run get_instance_ids with a longer timeout"
instance_ids[0]

## SageMaker endpoint test

### Create the SageMaker model

In [None]:
import boto3
import re
import os
import json
import uuid
import boto3
import sagemaker
from time import gmtime, strftime
## for debug only
from sagemaker_ssh_helper.wrapper import SSHModelWrapper
sm_client = boto3.client(service_name='sagemaker')


def create_model():
    """Register the inference image as a SageMaker model and return the model name.

    The model has a single container built from ``full_image_uri`` (no model data
    URL), runs under ``role``, and is named ``gpt-sovits-sagemaker-<UTC timestamp>``.
    The raw ``create_model`` response is printed for inspection.
    """
    stamp = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
    new_model_name = "gpt-sovits-sagemaker-" + stamp
    container = {"Image": full_image_uri}
    api_response = sm_client.create_model(
        ModelName=new_model_name,
        ExecutionRoleArn=role,
        Containers=[container],
    )
    print(api_response)
    return new_model_name

In [None]:
model_name = create_model()  # name of the newly registered SageMaker model


### create endpoint configuration

In [None]:
endpointConfigName = "gpt-sovits-sagemaker-configuration-"+strftime("%Y-%m-%d-%H-%M-%S", gmtime())


def create_endpoint_configuration(instance_type="ml.g5.xlarge", instance_count=1):
    """Create an endpoint configuration with one production variant serving ``model_name``.

    Uses the notebook-level ``sm_client``, ``endpointConfigName`` and ``model_name``.

    Args:
        instance_type: Instance type for the variant (default ``ml.g5.xlarge``).
        instance_count: Initial number of instances for the variant (default 1).

    Returns:
        The endpoint configuration name, ``endpointConfigName``.
    """
    create_endpoint_config_response = sm_client.create_endpoint_config(
        EndpointConfigName=endpointConfigName,
        ProductionVariants=[
            {
                "ModelName": model_name,
                "VariantName": "gpt-sovits-sagemaker-variant",
                "InstanceType": instance_type,
                "InitialInstanceCount": instance_count,
                # Allow up to 20 minutes each for model-data download and for the
                # container to pass its start-up health check.
                "ModelDataDownloadTimeoutInSeconds": 1200,
                "ContainerStartupHealthCheckTimeoutInSeconds": 1200,
            }
        ],
    )
    print(create_endpoint_config_response)
    return endpointConfigName


In [None]:
# Create the endpoint configuration for `model_name`; the returned config name is displayed.
create_endpoint_configuration()


### create endpoint

In [None]:
endpointName="gpt-sovits-sagemaker-endpoint"+strftime("%Y-%m-%d-%H-%M-%S", gmtime())


def create_endpoint():
    """Create ``endpointName`` from ``endpointConfigName`` and block until it is InService.

    Uses the notebook-level ``sm_client``, ``endpointName`` and ``endpointConfigName``.
    Prints the endpoint ARN and its initial status, then waits on the boto3
    ``endpoint_in_service`` waiter, which raises a botocore ``WaiterError`` if the
    endpoint fails or the waiter gives up.
    """
    create_endpoint_response = sm_client.create_endpoint(
        EndpointName=endpointName,
        EndpointConfigName=endpointConfigName,
    )
    print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])
    resp = sm_client.describe_endpoint(EndpointName=endpointName)
    print("Endpoint Status: " + resp["EndpointStatus"])
    # Report the endpoint actually being created, not a fixed placeholder name.
    print("Waiting for {} endpoint to be in service".format(endpointName))
    waiter = sm_client.get_waiter("endpoint_in_service")
    waiter.wait(EndpointName=endpointName)

In [None]:
# Create the endpoint and block until the endpoint_in_service waiter succeeds.
create_endpoint()

## Real-time inference with the SageMaker endpoint

In [26]:
import json
import boto3
# Invoke the endpoint created in the "create endpoint" section (`endpointName`).
# To target a different endpoint (e.g. the remote-debug `endpoint_name`), assign
# `endpointName` here. A hard-coded timestamped name only exists in one account and session.
runtime_sm_client = boto3.client(service_name="sagemaker-runtime")
# Audio paths live in this session's default bucket (`bucket`) instead of a hard-coded,
# account-specific bucket; upload the reference clip there before invoking.
request = {
    # Reference clip and (presumably) the transcript spoken in it.
    "refer_wav_path": f"s3://{bucket}/gpt-sovits/wav/speech_20240425104005663.mp3",
    "prompt_text": "私はスポーツが好きな女の子で、私は中華料理が大好きで、私は中国へ旅行するのが好きで、特に杭州、成都が好きです",
    "prompt_language": "ja",
    # Text to synthesize.
    "text": "私には気にしないで、あなたは四海を家とすることを約束します,私を待っていても気にしないで、あなたの白髪を許す",
    "text_language": "ja",
    # S3 prefix under which the endpoint writes the generated wav.
    "output_s3uri": f"s3://{bucket}/gpt_sovits_output/wav/",
}

def invoke_endpoint(endpoint_name=None, request_body=None, client=None):
    """Send a JSON synthesis request to a SageMaker endpoint and return its reply.

    Args:
        endpoint_name: Endpoint to call. Defaults to the notebook-level ``endpointName``.
        request_body: JSON-serialisable request. Defaults to the notebook-level ``request``.
        client: ``sagemaker-runtime`` client. Defaults to ``runtime_sm_client``.

    Returns:
        The response body parsed as JSON (e.g. ``{"result": "<s3 uri of the wav>"}``),
        or the raw decoded text when the body is not valid JSON.
    """
    if endpoint_name is None:
        endpoint_name = endpointName
    if request_body is None:
        request_body = request
    if client is None:
        client = runtime_sm_client
    content_type = "application/json"
    payload = json.dumps(request_body)
    print(payload)
    response = client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType=content_type,
        Body=payload,
    )
    result = response['Body'].read().decode()
    print('返回：',result)  # "返回：" means "Returned:"
    try:
        return json.loads(result)
    except json.JSONDecodeError:
        # Non-JSON replies are handed back verbatim.
        return result

In [27]:
response = invoke_endpoint()  # send `request` to the endpoint and keep the return value

{"refer_wav_path": "s3://sagemaker-us-west-2-687912291502/gpt-sovits/wav/speech_20240425104005663.mp3", "prompt_text": "\u79c1\u306f\u30b9\u30dd\u30fc\u30c4\u304c\u597d\u304d\u306a\u5973\u306e\u5b50\u3067\u3001\u79c1\u306f\u4e2d\u83ef\u6599\u7406\u304c\u5927\u597d\u304d\u3067\u3001\u79c1\u306f\u4e2d\u56fd\u3078\u65c5\u884c\u3059\u308b\u306e\u304c\u597d\u304d\u3067\u3001\u7279\u306b\u676d\u5dde\u3001\u6210\u90fd\u304c\u597d\u304d\u3067\u3059", "prompt_language": "ja", "text": "\u79c1\u306b\u306f\u6c17\u306b\u3057\u306a\u3044\u3067\u3001\u3042\u306a\u305f\u306f\u56db\u6d77\u3092\u5bb6\u3068\u3059\u308b\u3053\u3068\u3092\u7d04\u675f\u3057\u307e\u3059,\u79c1\u3092\u5f85\u3063\u3066\u3044\u3066\u3082\u6c17\u306b\u3057\u306a\u3044\u3067\u3001\u3042\u306a\u305f\u306e\u767d\u9aea\u3092\u8a31\u3059", "text_language": "ja", "output_s3uri": "s3://sagemaker-us-west-2-687912291502/gpt_sovits_output/wav/"}
返回： {"result":"s3://sagemaker-us-west-2-687912291502/gpt_sovits_output/wav/gpt_sovits_17140422

In [28]:
# Download the synthesized wav to the working directory. The S3 key below appears to be
# copied from the "result" printed by the previous cell; it is specific to that run and
# account, so update it to the URI your own invocation returned.
!aws s3 cp s3://sagemaker-us-west-2-687912291502/gpt_sovits_output/wav/gpt_sovits_1714042297.wav ./

download: s3://sagemaker-us-west-2-687912291502/gpt_sovits_output/wav/gpt_sovits_1714042297.wav to ./gpt_sovits_1714042297.wav
