# GPT-SoVITS on SageMaker

## build image

In [None]:
!chmod +x ./*.sh && ./build_and_push.sh 

In [None]:
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

# Execution role and SageMaker session used by the rest of the notebook.
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
# Use the public attribute instead of the private `_region_name`.
region = sess.boto_region_name  # region name of the current SageMaker Studio environment
account_id = sess.account_id()  # account_id of the current SageMaker Studio environment
bucket = sess.default_bucket()
image = "gpt-sovits-inference"
s3_client = boto3.client("s3")
sm_client = boto3.client("sagemaker")
smr_client = boto3.client("sagemaker-runtime")

# ECR URI of the inference image, built from the account, region and image name above.
full_image_uri = f"{account_id}.dkr.ecr.{region}.amazonaws.com/{image}:latest"
print(full_image_uri)


687912291502.dkr.ecr.us-west-2.amazonaws.com/gpt-sovits-inference:latest


## remote debug deploy test

In [None]:
!pip list|grep -i sagemaker
!pip list|grep -i boto3

sagemaker                               2.243.1
sagemaker-core                          1.0.29
sagemaker-mlflow                        0.1.0
sagemaker-pyspark                       1.4.5
sagemaker-ssh-helper                    2.1.0
aioboto3                                12.0.0
boto3                                   1.37.33


In [None]:
# Placeholder model archive for the bring-your-own-container image that runs
# its own web server: a tarball holding a single empty file named `dummy`.
!touch dummy
!tar czvf model.tar.gz dummy
# S3 prefix the archive is copied to, and the full S3 URI of the archive
# (passed to `Model(model_data=...)` in a later cell).
assets_dir = 's3://{0}/{1}/assets/'.format(bucket, 'gpt_sovits')
model_data = 's3://{0}/{1}/assets/model.tar.gz'.format(bucket, 'gpt_sovits')
!aws s3 cp model.tar.gz $assets_dir
# Remove the local temporary files after the copy.
!rm -f dummy model.tar.gz

dummy
upload: ./model.tar.gz to s3://sagemaker-us-west-2-687912291502/gpt_sovits/assets/model.tar.gz


In [None]:
from sagemaker_ssh_helper.wrapper import SSHModelWrapper
# SageMaker model backed by the custom inference image; the SSH helper's
# dependency directory is shipped along with it.
model = Model(
    image_uri=full_image_uri,
    model_data=model_data,
    role=role,
    dependencies=[SSHModelWrapper.dependency_dir()],
)

In [None]:
from sagemaker_ssh_helper.wrapper import SSHModelWrapper
# Instance type for the endpoint and an endpoint name derived from a fixed base.
instance_type = "ml.g5.xlarge"
endpoint_name = sagemaker.utils.name_from_base("gpt-sovits-inference")


# Wrap the model with the SSH helper before deploying.
# NOTE(review): connection_wait_time_seconds=0 presumably means the container
# does not pause waiting for an SSH connection - confirm in the helper's docs.
ssh_wrapper = SSHModelWrapper.create(model, connection_wait_time_seconds=0)

# wait=False: do not block this cell until the endpoint is in service.
# NOTE(review): `model` is a plain `Model` with no predictor class, so
# `predictor` is presumably None here - confirm against the SDK docs.
predictor = model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
    wait=False
)


# Optional: look up the managed-instance ids and print the SSM connect command.
#instance_ids = ssh_wrapper.get_instance_ids(timeout_in_sec=900)
#print(f"To connect over SSM run: aws ssm start-session --target {instance_ids[0]}")

In [18]:
instance_ids = ssh_wrapper.get_instance_ids(timeout_in_sec=100)

In [7]:
#!aws ssm start-session --target mi-0c8f4ad16595250cd

## Streaming test (only for the stream-branch deployment)

In [19]:
import requests

chunk_bytes=None

def upsert(lst, new_dict):
    """Replace the entry at position ``new_dict['index']`` or append it.

    If ``new_dict['index']`` equals one of the existing positions of ``lst``,
    that slot is overwritten with ``new_dict``; otherwise ``new_dict`` is
    appended. The list is modified in place and also returned.
    """
    # Lazy search: the 'index' key is only read when lst has at least one slot.
    slot = next(
        (pos for pos in range(len(lst)) if new_dict['index'] == pos),
        None,
    )
    if slot is None:
        lst.append(new_dict)
    else:
        lst[slot] = new_dict
    return lst

def invoke_streams_endpoint(smr_client,endpointName, request):
    """Invoke a streaming endpoint and collect the returned chunks.

    Args:
        smr_client: Client object exposing
            ``invoke_endpoint_with_response_stream`` (the notebook passes a
            boto3 ``sagemaker-runtime`` client).
        endpointName: Name of the endpoint to invoke.
        request: JSON-serialisable request payload.

    Returns:
        A list with one dict per received chunk, in arrival order. Each dict
        has the keys ``first_chunk``, ``bytes``, ``last_chunk`` and ``index``.
        The final chunk has ``last_chunk`` set to True; a stream with a single
        chunk yields one dict flagged as both first and last. The list is
        empty when the stream delivers no chunk.

    Raises:
        KeyError: If a stream event carries no ``PayloadPart``/``Bytes``.
    """
    # Imported locally because this cell does not import json itself.
    import json

    global chunk_bytes
    payload = json.dumps(request, ensure_ascii=False)

    response_model = smr_client.invoke_endpoint_with_response_stream(
        EndpointName=endpointName,
        ContentType="application/json",
        Body=payload,
    )
    print(response_model['ResponseMetadata'])

    result = []
    for index, event in enumerate(response_model['Body']):
        eventChunk = event['PayloadPart']['Bytes']
        if index == 0:
            print("Received first chunk")
        # Kept only so code that inspects the module-level `chunk_bytes` after
        # a call keeps working; this function no longer reads it back.
        chunk_bytes = eventChunk
        print("chunk len:", len(eventChunk))
        result.append({
            'first_chunk': index == 0,
            'bytes': eventChunk,
            'last_chunk': False,
            'index': index,
        })
    print("All chunks processed")

    # Flag the final chunk in place. Doing nothing for an empty stream avoids
    # appending a bogus entry built from a previous call's bytes (or None).
    if result:
        result[-1]['last_chunk'] = True
    return result





In [20]:
import json
import boto3
# Endpoint to call and the runtime client used for the streaming invocation.
endpointName="gpt-sovits-inference-2025-04-25-00-48-02-079"
runtime_sm_client = boto3.client(service_name="sagemaker-runtime")
# Previously used endpoint name, kept for reference.
#endpointName="gpt-sovits-sagemaker-endpoint2024-04-03-23-49-44"


# NOTE(review): `text` is not referenced by `request` below, which carries its
# own literal under the "text" key - confirm which sentence is intended.
text="它包括以下几个主要方面:SAP系统管理包括SAP系统实例的安装、启动、监控、备份、升级等日常管理任务。Basis团队负责保证系统的正常运行。"

# Request payload: reference audio (S3 URI) with its transcript and language,
# the text to synthesise with its language, and the punctuation to cut on.
request = {"refer_wav_path": "s3://sagemaker-us-west-2-687912291502/gpt-sovits/wav/out003.wav",
    "prompt_text": "后来我就在直播间里认识了越来越多的听友，渐渐的这份工作，也为我带来了一些兼职收入，我就决定把这份工作做下去。",
    "prompt_language": "zh",
    "text": "作为SAP基础架构专家和SAP系统管理员,我来解释一下SAP Basis的含义:SAP Basis是指SAP系统的基础设施层,负责管理和维护整个SAP系统环境的运行。",
    "text_language": "zh",
    "cut_punc":","}


In [21]:
response=invoke_streams_endpoint(runtime_sm_client,endpointName,request)

{'RequestId': '9b329f7a-a3c4-46f0-a10b-40a29a2dd623', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '9b329f7a-a3c4-46f0-a10b-40a29a2dd623', 'x-amzn-invoked-production-variant': 'AllTraffic', 'x-amzn-sagemaker-content-type': 'audio/ogg', 'date': 'Sun, 27 Apr 2025 03:00:17 GMT', 'content-type': 'application/vnd.amazon.eventstream', 'transfer-encoding': 'chunked', 'connection': 'keep-alive'}, 'RetryAttempts': 0}
Received first chunk
chunk len: 139
chunk len: 208
chunk len: 272
chunk len: 21
chunk len: 308
chunk len: 332
chunk len: 379
chunk len: 261
chunk len: 427
chunk len: 213
chunk len: 475
chunk len: 165
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640

In [22]:
import json
import boto3
# Endpoint to call. The name must not contain whitespace: the original value
# ended with a trailing space, which is not a valid endpoint name.
endpointName = "gpt-sovits-inference-2025-04-28-23-03-39-424"
runtime_sm_client = boto3.client(service_name="sagemaker-runtime")

# Request payload: Japanese reference audio and transcript, English text to
# synthesise, and an S3 prefix passed under the "output_s3uri" key.
request = {
    "refer_wav_path": "s3://sagemaker-us-west-2-687912291502/gpt-sovits/wav/speech_20240425104005663.mp3",
    "prompt_text": "私はスポーツが好きな女の子で、私は中華料理が大好きで、私は中国へ旅行するのが好きで、特に杭州、成都が好きです",
    "prompt_language": "ja",
    "text": "When I practice my spells in weather like this, I can do half the work for double the impact..",
    "text_language": "en",
    "output_s3uri": "s3://sagemaker-us-west-2-687912291502/gpt_sovits_output/wav/",
    "cut_punc": ",",
}


In [23]:
response=invoke_streams_endpoint(runtime_sm_client,endpointName,request)

{'RequestId': '8731f65d-0f20-48c5-a8f5-546330af5481', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '8731f65d-0f20-48c5-a8f5-546330af5481', 'x-amzn-invoked-production-variant': 'AllTraffic', 'x-amzn-sagemaker-content-type': 'audio/ogg', 'date': 'Sun, 27 Apr 2025 03:01:05 GMT', 'content-type': 'application/vnd.amazon.eventstream', 'transfer-encoding': 'chunked', 'connection': 'keep-alive'}, 'RetryAttempts': 0}
Received first chunk
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 488
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 640
chunk len: 64

## 二进制客户端pack测试

In [4]:
!pip install pydub




In [None]:
import requests
import json
import boto3
import io
import wave
import struct
import numpy as np
import pydub
from pydub import AudioSegment
import os

# Set the ffmpeg and ffprobe paths
ffmpeg_path = "/home/ec2-user/anaconda3/envs/pytorch_p310/bin/ffmpeg"
# NOTE(review): ffprobe_path is assigned but not referenced in this cell.
ffprobe_path = "/home/ec2-user/anaconda3/envs/pytorch_p310/bin/ffprobe"

# Append the directory containing ffmpeg to PATH (presumably so pydub can
# locate the binaries - confirm).
os.environ["PATH"] += os.pathsep + os.path.dirname(ffmpeg_path)

# Module-level holder that invoke_streams_endpoint overwrites with the most
# recently received chunk.
chunk_bytes = None

def upsert(lst, new_dict):
    """Replace the entry at position ``new_dict['index']`` or append it.

    If ``new_dict['index']`` equals one of the existing positions of ``lst``,
    that slot is overwritten with ``new_dict``; otherwise ``new_dict`` is
    appended. The list is modified in place and also returned.
    """
    # Lazy search: the 'index' key is only read when lst has at least one slot.
    slot = next(
        (pos for pos in range(len(lst)) if new_dict['index'] == pos),
        None,
    )
    if slot is None:
        lst.append(new_dict)
    else:
        lst[slot] = new_dict
    return lst

def invoke_streams_endpoint(smr_client, endpointName, request):
    """Invoke a streaming endpoint and collect the returned chunks.

    Args:
        smr_client: Client object exposing
            ``invoke_endpoint_with_response_stream`` (the notebook passes a
            boto3 ``sagemaker-runtime`` client).
        endpointName: Name of the endpoint to invoke.
        request: JSON-serialisable request payload.

    Returns:
        A list with one dict per received chunk, in arrival order. Each dict
        has the keys ``first_chunk``, ``bytes``, ``last_chunk`` and ``index``.
        The final chunk has ``last_chunk`` set to True; a stream with a single
        chunk yields one dict flagged as both first and last. The list is
        empty when the stream delivers no chunk.

    Raises:
        KeyError: If a stream event carries no ``PayloadPart``/``Bytes``.
    """
    global chunk_bytes
    payload = json.dumps(request, ensure_ascii=False)

    response_model = smr_client.invoke_endpoint_with_response_stream(
        EndpointName=endpointName,
        ContentType="application/json",
        Body=payload,
    )
    print(response_model['ResponseMetadata'])

    result = []
    for index, event in enumerate(response_model['Body']):
        eventChunk = event['PayloadPart']['Bytes']
        if index == 0:
            print("Received first chunk")
        # Kept only so code that inspects the module-level `chunk_bytes` after
        # a call keeps working; this function no longer reads it back.
        chunk_bytes = eventChunk
        print("chunk len:", len(eventChunk))
        result.append({
            'first_chunk': index == 0,
            'bytes': eventChunk,
            'last_chunk': False,
            'index': index,
        })
    print("All chunks processed")

    # Flag the final chunk in place. Doing nothing for an empty stream avoids
    # appending a bogus entry built from a previous call's bytes (or None),
    # which would break a later b''.join over the chunks.
    if result:
        result[-1]['last_chunk'] = True
    return result


def save_ogg(ogg_data, filename):
    """Write ``ogg_data`` to ``filename`` in binary mode, replacing any existing file."""
    with open(filename, 'wb') as sink:
        sink.write(ogg_data)

def ogg_to_wav(ogg_data):
    """Decode Ogg bytes with pydub and return the audio exported as WAV bytes."""
    source = io.BytesIO(ogg_data)
    wav_buffer = io.BytesIO()
    AudioSegment.from_ogg(source).export(wav_buffer, format="wav")
    return wav_buffer.getvalue()

def save_wav(wav_data, filename, channels=2, sampwidth=2, framerate=44100):
    """Save audio bytes as a WAV file.

    ``wav_data`` may be either a complete WAV file (as produced by
    ``ogg_to_wav``) or raw PCM frames.

    Args:
        wav_data: Bytes of a complete RIFF/WAVE file, or raw PCM frames.
        filename: Destination path; an existing file is overwritten.
        channels: Channel count, used only for raw PCM input.
        sampwidth: Sample width in bytes, used only for raw PCM input.
        framerate: Sample rate in Hz, used only for raw PCM input.
    """
    # A complete WAV file already carries its own header describing the real
    # channel count and sample rate. Wrapping it in a second container with
    # fixed parameters would store the header as audio and mislabel the data.
    if wav_data[:4] == b'RIFF' and wav_data[8:12] == b'WAVE':
        with open(filename, 'wb') as wav_file:
            wav_file.write(wav_data)
        return

    with wave.open(filename, 'wb') as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(sampwidth)
        wav_file.setframerate(framerate)
        wav_file.writeframes(wav_data)

def main(endpointName="gpt-sovits-inference-2025-04-28-23-03-38-889"):
    """Synthesise a fixed Chinese sentence and save it as OGG and WAV.

    Sends the request to the streaming endpoint, joins the returned chunks,
    writes them to ``output.ogg``, converts them to WAV and writes
    ``output.wav``.

    Args:
        endpointName: Name of the endpoint to invoke. Defaults to the
            endpoint this notebook was last run against.
    """
    runtime_sm_client = boto3.client(service_name="sagemaker-runtime")

    text = "它包括以下几个主要方面:SAP系统管理包括SAP系统实例的安装、启动、监控、备份、升级等日常管理任务。Basis团队负责保证系统的正常运行。"

    request = {
        "refer_wav_path": "s3://sagemaker-us-west-2-687912291502/gpt-sovits/wav/out003.wav",
        "prompt_text": "后来我就在直播间里认识了越来越多的听友，渐渐的这份工作，也为我带来了一些兼职收入，我就决定把这份工作做下去。",
        "prompt_language": "zh",
        "text": text,
        "text_language": "zh",
        "cut_punc": "。"
    }

    response = invoke_streams_endpoint(runtime_sm_client, endpointName, request)

    # Concatenate all OGG chunks once and reuse the result for both outputs.
    ogg_data = b''.join(chunk['bytes'] for chunk in response)
    save_ogg(ogg_data, 'output.ogg')

    # Convert OGG to WAV
    wav_data = ogg_to_wav(ogg_data)

    # Save WAV file
    save_wav(wav_data, 'output.wav')
    print("WAV file saved as 'output.wav'")


if __name__ == "__main__":
    main()


In [23]:
!aws s3 cp s3://sagemaker-us-west-2-687912291502/gpt_sovits_output/wav/gpt_sovits_1745885495.mp3 ./

download: s3://sagemaker-us-west-2-687912291502/gpt_sovits_output/wav/gpt_sovits_1745885495.mp3 to ./gpt_sovits_1745885495.mp3


## stream api_v2测试

In [None]:
import requests
import json
import boto3
import io
import wave
import struct
import numpy as np
import pydub
from pydub import AudioSegment
import os

# Set the ffmpeg and ffprobe paths
ffmpeg_path = "/home/ec2-user/anaconda3/envs/pytorch_p310/bin/ffmpeg"
# NOTE(review): ffprobe_path is assigned but not referenced in this cell.
ffprobe_path = "/home/ec2-user/anaconda3/envs/pytorch_p310/bin/ffprobe"

# Append the directory containing ffmpeg to PATH (presumably so pydub can
# locate the binaries - confirm).
os.environ["PATH"] += os.pathsep + os.path.dirname(ffmpeg_path)

# Module-level holder that invoke_streams_endpoint overwrites with the most
# recently received chunk.
chunk_bytes = None

def upsert(lst, new_dict):
    """Replace the entry at position ``new_dict['index']`` or append it.

    If ``new_dict['index']`` equals one of the existing positions of ``lst``,
    that slot is overwritten with ``new_dict``; otherwise ``new_dict`` is
    appended. The list is modified in place and also returned.
    """
    # Lazy search: the 'index' key is only read when lst has at least one slot.
    slot = next(
        (pos for pos in range(len(lst)) if new_dict['index'] == pos),
        None,
    )
    if slot is None:
        lst.append(new_dict)
    else:
        lst[slot] = new_dict
    return lst

def invoke_streams_endpoint(smr_client, endpointName, request):
    """Invoke a streaming endpoint and collect the returned chunks.

    Args:
        smr_client: Client object exposing
            ``invoke_endpoint_with_response_stream`` (the notebook passes a
            boto3 ``sagemaker-runtime`` client).
        endpointName: Name of the endpoint to invoke.
        request: JSON-serialisable request payload.

    Returns:
        A list with one dict per received chunk, in arrival order. Each dict
        has the keys ``first_chunk``, ``bytes``, ``last_chunk`` and ``index``.
        The final chunk has ``last_chunk`` set to True; a stream with a single
        chunk yields one dict flagged as both first and last. The list is
        empty when the stream delivers no chunk.

    Raises:
        KeyError: If a stream event carries no ``PayloadPart``/``Bytes``.
    """
    global chunk_bytes
    payload = json.dumps(request, ensure_ascii=False)

    response_model = smr_client.invoke_endpoint_with_response_stream(
        EndpointName=endpointName,
        ContentType="application/json",
        Body=payload,
    )
    print(response_model['ResponseMetadata'])

    result = []
    for index, event in enumerate(response_model['Body']):
        eventChunk = event['PayloadPart']['Bytes']
        if index == 0:
            print("Received first chunk")
        # Kept only so code that inspects the module-level `chunk_bytes` after
        # a call keeps working; this function no longer reads it back.
        chunk_bytes = eventChunk
        print("chunk len:", len(eventChunk))
        result.append({
            'first_chunk': index == 0,
            'bytes': eventChunk,
            'last_chunk': False,
            'index': index,
        })
    print("All chunks processed")

    # Flag the final chunk in place. Doing nothing for an empty stream avoids
    # appending a bogus entry built from a previous call's bytes (or None),
    # which would break a later b''.join over the chunks.
    if result:
        result[-1]['last_chunk'] = True
    return result


def save_ogg(ogg_data, filename):
    """Write ``ogg_data`` to ``filename`` in binary mode, replacing any existing file."""
    with open(filename, 'wb') as sink:
        sink.write(ogg_data)

def ogg_to_wav(ogg_data):
    """Decode Ogg bytes with pydub and return the audio exported as WAV bytes."""
    source = io.BytesIO(ogg_data)
    wav_buffer = io.BytesIO()
    AudioSegment.from_ogg(source).export(wav_buffer, format="wav")
    return wav_buffer.getvalue()

def save_wav(wav_data, filename, channels=2, sampwidth=2, framerate=44100):
    """Save audio bytes as a WAV file.

    ``wav_data`` may be either a complete WAV file (as produced by
    ``ogg_to_wav``) or raw PCM frames.

    Args:
        wav_data: Bytes of a complete RIFF/WAVE file, or raw PCM frames.
        filename: Destination path; an existing file is overwritten.
        channels: Channel count, used only for raw PCM input.
        sampwidth: Sample width in bytes, used only for raw PCM input.
        framerate: Sample rate in Hz, used only for raw PCM input.
    """
    # A complete WAV file already carries its own header describing the real
    # channel count and sample rate. Wrapping it in a second container with
    # fixed parameters would store the header as audio and mislabel the data.
    if wav_data[:4] == b'RIFF' and wav_data[8:12] == b'WAVE':
        with open(filename, 'wb') as wav_file:
            wav_file.write(wav_data)
        return

    with wave.open(filename, 'wb') as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(sampwidth)
        wav_file.setframerate(framerate)
        wav_file.writeframes(wav_data)

def main(endpointName="gpt-sovits-inference-2025-04-28-23-03-38-889"):
    """Send an api_v2-style request to the endpoint and save the audio.

    Joins the returned chunks, writes them to ``output.ogg``, converts them
    to WAV and writes ``output.wav``.

    Args:
        endpointName: Name of the endpoint to invoke. Defaults to the
            endpoint this notebook was last run against.
    """
    runtime_sm_client = boto3.client(service_name="sagemaker-runtime")

    text = "它包括以下几个主要方面:SAP系统管理包括SAP系统实例的安装、启动、监控、备份、升级等日常管理任务。Basis团队负责保证系统的正常运行。"

    data = {
        "text": text,
        "text_lang": "zh",
        "ref_audio_path": "s3://sagemaker-us-west-2-687912291502/gpt-sovits/wav/out003.wav",
        "prompt_lang": "zh",
        "prompt_text": "后来我就在直播间里认识了越来越多的听友，渐渐的这份工作，也为我带来了一些兼职收入，我就决定把这份工作做下去。",
        "top_k": 5,
        "top_p": 1.0,
        "temperature": 0.7,
        "text_split_method": "cut5",
        "batch_size": 1,
        "batch_threshold": 0.75,
        "split_bucket": True,
        "speed_factor": 1.0,
        "fragment_interval": 0.3,
        "seed": -1,
        "media_type": "wav",
        "streaming_mode": False,
        "parallel_infer": True,
        "repetition_penalty": 1.35,
        "sample_steps": 32,
        "super_sampling": False
    }

    response = invoke_streams_endpoint(runtime_sm_client, endpointName, data)

    # Concatenate all chunks once and reuse the result for both outputs.
    # NOTE(review): the request asks for media_type "wav", yet the bytes are
    # saved as 'output.ogg' and decoded as Ogg below - confirm which format
    # the endpoint actually returns for this request.
    ogg_data = b''.join(chunk['bytes'] for chunk in response)
    save_ogg(ogg_data, 'output.ogg')

    # Convert OGG to WAV
    wav_data = ogg_to_wav(ogg_data)

    # Save WAV file
    save_wav(wav_data, 'output.wav')
    print("WAV file saved as 'output.wav'")


if __name__ == "__main__":
    main()
