# GPT-SoVITS on Sagemaker

## build image

In [None]:
!chmod +x ./*.sh && ./build_and_push.sh 

In [None]:
!pip install boto3 sagemaker sagemaker_ssh_helper -U

In [1]:
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

# role = sagemaker.get_execution_role()  # execution role for the endpoint
role = "arn:aws:iam::596899493901:role/sagemaker_full_access"
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
region = sess._region_name  # region name of the current SageMaker Studio environment
account_id = sess.account_id()  # account_id of the current SageMaker Studio environment
bucket = sess.default_bucket()
image="gpt-sovits-inference-v2"
s3_client = boto3.client("s3")
sm_client = boto3.client("sagemaker")
smr_client = boto3.client("sagemaker-runtime")

full_image_uri=f"{account_id}.dkr.ecr.{region}.amazonaws.com/{image}:latest"
print(full_image_uri)


sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/dongxq/.config/sagemaker/config.yaml
596899493901.dkr.ecr.us-east-1.amazonaws.com/gpt-sovits-inference-v2:latest


## remote debug test

In [2]:
## empty model data for byoc with webserver
!touch dummy
!tar czvf model.tar.gz dummy
assets_dir = 's3://{0}/{1}/assets/'.format(bucket, 'gpt_sovits')
model_data = 's3://{0}/{1}/assets/model.tar.gz'.format(bucket, 'gpt_sovits')
!aws s3 cp model.tar.gz $assets_dir
!rm -f dummy model.tar.gz

dummy
upload: ./model.tar.gz to s3://sagemaker-us-east-1-596899493901/gpt_sovits/assets/model.tar.gz


In [3]:
from sagemaker_ssh_helper.wrapper import SSHModelWrapper
model = Model(image_uri=full_image_uri, model_data=model_data, role=role,dependencies=[SSHModelWrapper.dependency_dir()] )

In [10]:
from sagemaker_ssh_helper.wrapper import SSHModelWrapper
from sagemaker import Predictor
instance_type = "ml.g5.xlarge"
endpoint_name = sagemaker.utils.name_from_base("gpt-sovits-inference")


ssh_wrapper = SSHModelWrapper.create(model, connection_wait_time_seconds=0)  # <--NEW--

predictor: Predictor = model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
    wait=False
)


instance_ids = ssh_wrapper.get_instance_ids(timeout_in_sec=900)  # <--NEW-- 
print(f"To connect over SSM run: aws ssm start-session --target {instance_ids[0]}")

AssertionError: You should call wrapper.create() before model.deploy().

In [5]:
instance_ids = ssh_wrapper.get_instance_ids(timeout_in_sec=0)

In [8]:
ssh_wrapper.print_ssh_info()

Remote endpoint logs are at https://us-east-1.console.aws.amazon.com/cloudwatch/home?region=us-east-1#logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252FEndpoints$252Fgpt-sovits-inference-2024-08-21-11-32-04-151
Endpoint metadata is at https://us-east-1.console.aws.amazon.com/sagemaker/home?region=us-east-1#/endpoints/gpt-sovits-inference-2024-08-21-11-32-04-151
Endpoint config metadata is at https://us-east-1.console.aws.amazon.com/sagemaker/home?region=us-east-1#/endpointConfig/gpt-sovits-inference-2024-08-21-11-32-04-151
Model metadata is at https://us-east-1.console.aws.amazon.com/sagemaker/home?region=us-east-1#/models/gpt-sovits-inference-v2-2024-08-21-11-32-04-514


KeyboardInterrupt: 

In [9]:
instance_ids[0]

IndexError: list index out of range

## SM endpoint test

### create sagemaker model

In [None]:
import boto3
import re
import os
import json
import uuid
import boto3
import sagemaker
from time import gmtime, strftime
## for debug only
from sagemaker_ssh_helper.wrapper import SSHModelWrapper
sm_client = boto3.client(service_name='sagemaker')



def create_model():
    image=full_image_uri
    model_name="gpt-sovits-sagemaker-"+strftime("%Y-%m-%d-%H-%M-%S", gmtime())
    create_model_response = sm_client.create_model(
        ModelName=model_name,
        ExecutionRoleArn=role,
        Containers=[{"Image": image}],
    )
    print(create_model_response)
    return model_name

In [None]:
model_name=create_model()


### create endpoint configuration

In [None]:
endpointConfigName = "gpt-sovits-sagemaker-configuration-"+strftime("%Y-%m-%d-%H-%M-%S", gmtime())
def create_endpoint_configuration():
    create_endpoint_config_response = sm_client.create_endpoint_config(     
        EndpointConfigName=endpointConfigName,
        ProductionVariants=[
            {
                #"ModelName":"gpt-sovits-sagemaker-012024-03-28-04-00-03",
                "ModelName":model_name,
                "VariantName": "gpt-sovits-sagemaker"+"-variant",
                "InstanceType": "ml.g5.xlarge",  # 指定 g5.xlarge 机器
                "InitialInstanceCount": 1,
                "ModelDataDownloadTimeoutInSeconds": 1200,
                "ContainerStartupHealthCheckTimeoutInSeconds": 1200,
                "EnableSSMAccess": true,
            }
        ],
    )
    print(create_endpoint_config_response)
    return endpointConfigName


In [None]:
create_endpoint_configuration()


### create endpoint

In [None]:
endpointName="gpt-sovits-sagemaker-endpoint-v2-"+strftime("%Y-%m-%d-%H-%M-%S", gmtime())
def create_endpoint():
    create_endpoint_response = sm_client.create_endpoint(
        EndpointName=endpointName,
        #EndpointConfigName="gpt-sovits-sagemaker-configuration2024-03-28-04-03-53",
        EndpointConfigName=endpointConfigName
    )
    print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])
    resp = sm_client.describe_endpoint(EndpointName=endpointName)
    print("Endpoint Status: " + resp["EndpointStatus"])
    print("Waiting for {} endpoint to be in service".format("gpt-sovits-sagemaker-endpoint"))
    waiter = sm_client.get_waiter("endpoint_in_service")
    waiter.wait(EndpointName=endpointName)

In [None]:
create_endpoint()

## Realtime inferecne with sagemaker endpoint

In [None]:
import json
import boto3
# endpointName="gpt-sovits-inference-2024-05-07-08-46-43-537"
runtime_sm_client = boto3.client(service_name="sagemaker-runtime")
#endpointName="gpt-sovits-sagemaker-endpoint2024-04-03-23-49-44"


request = {"refer_wav_path":"s3://tts-xq/test-data/音质好.wav",
    "prompt_text": "脚下当心！这位客官，想照顾我们往生堂的生意，也不必这么心急嘛？你没什么事吧？嗯？麻烦的家伙",
    "prompt_language":"zh",
    "text":"作为 SAP 基础架构专家,我来解释一下 SAP Basis 的含义: SAP Basis 是指 SAP 系统的基础设施层,负责管理和维护整个 SAP 系统环境的运行。它包括以下几个主要方面: SAP 系统管理包括 SAP 系统实例的安装、启动、监控、备份、升级等日常管理任务。Basis 团队负责保证系统的正常运行。",
    "text_language" :"zh",
    "output_s3uri":"s3://tts-xq/gpt_sovits_output/wav/"}


def invoke_endpoint():
    content_type = "application/json"
    request_body = request
    payload = json.dumps(request_body)
    print(payload)
    response = runtime_sm_client.invoke_endpoint(
        EndpointName=endpointName,
        ContentType=content_type,
        Body=payload,
    )
    result = response['Body'].read().decode()
    print('返回：',result)

In [None]:
response=invoke_endpoint()

In [None]:
!aws s3 cp s3://sagemaker-us-west-2-687912291502/gpt-sovits/wav/speech_20240425104005663.mp3 ./
!aws s3 cp s3://sagemaker-us-west-2-687912291502/gpt_sovits_output/wav/gpt_sovits_1715140344.wav ./

In [None]:
!aws s3 cp s3://sagemaker-us-west-2-687912291502/gpt_sovits_output/wav/gpt_sovits_1715150796.wav ./

## Streams test (only for stream branch deployment)

In [None]:
import requests

chunk_bytes=None

def upsert(lst, new_dict):
    for i, item in enumerate(lst):
        if new_dict['index'] == i:
            lst[i] = new_dict
            return lst
    lst.append(new_dict)
    return lst

def invoke_streams_endpoint(smr_client,endpointName, request):
    global chunk_bytes
    content_type = "application/json"
    payload = json.dumps(request,ensure_ascii=False)

    response_model = smr_client.invoke_endpoint_with_response_stream(
        EndpointName=endpointName,
        ContentType=content_type,
        Body=payload,
    )

    result = []
    print(response_model['ResponseMetadata'])
    event_stream = iter(response_model['Body'])
    index = 0
    try: 
        while True:
            event = next(event_stream)
            eventChunk = event['PayloadPart']['Bytes']
            chunk_dict = {}
            if index == 0:
                print("Received first chunk")
                chunk_dict['first_chunk'] = True
                chunk_dict['bytes'] = eventChunk
                chunk_bytes = eventChunk
                chunk_dict['last_chunk'] = False
                chunk_dict['index'] = index
            else:
                chunk_dict['first_chunk'] = False
                chunk_dict['bytes'] = eventChunk
                chunk_bytes = eventChunk
                chunk_dict['last_chunk'] = False
                chunk_dict['index'] = index
            print("chunk len:",len(chunk_dict['bytes']))
            result.append(chunk_dict)    
            index += 1
            #print('返回chunk：', chunk_dict['bytes'])
    except StopIteration:
        print("All chunks processed")
        chunk_dict = {}
        chunk_dict['first_chunk'] = False
        chunk_dict['bytes'] = chunk_bytes
        chunk_dict['last_chunk'] = True
        chunk_dict['index'] = index-1
        result = upsert(result,chunk_dict)
    print("result",result)
    return result





In [None]:
import json
import boto3
# endpointName="gpt-sovits-inference-2024-05-17-13-49-58-483"
runtime_sm_client = boto3.client(service_name="sagemaker-runtime")
#endpointName="gpt-sovits-sagemaker-endpoint2024-04-03-23-49-44"



request = {"refer_wav_path":"s3://sagemaker-us-west-2-687912291502/gpt-sovits/wav/speech_20240425104005663.mp3",
    "prompt_text": "私はスポーツが好きな女の子で、私は中華料理が大好きで、私は中国へ旅行するのが好きで、特に杭州、成都が好きです",
    "prompt_language":"ja",
    "text":"『白夜行』はとても美しい小説で、私はとても夢中になって読んで、時には何時間も休まないで、私は中の主人公が大好きです",
    "text_language" :"ja",
    "output_s3uri":"s3://sagemaker-us-west-2-687912291502/gpt_sovits_output/wav/",
    "cut_punc":"、"}


In [None]:
response=invoke_streams_endpoint(runtime_sm_client,endpointName,request)

In [None]:
!aws s3 cp s3://sagemaker-us-west-2-687912291502/gpt_sovits_output/wav/gpt_sovits_1715954949 ./

In [None]:
endpointName="gpt-sovits-sagemaker-endpoint-v22024-08-21-11-02-12"
sess.delete_endpoint(endpointName)
sess.delete_endpoint_config(endpointName)
model.delete_model()

In [12]:
# sess.delete_endpoint_config(endpointName)
model.delete_model()