## 0. 配置sagemaker，获取 account id 等

In [1]:
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers
role = sagemaker.get_execution_role()
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
region = sess._region_name # region name of the current SageMaker Studio environment
account_id = sess.account_id()

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml


## 1. 配置要调用的镜像

之前 terminal 打包上传到 ECR 的镜像

In [15]:
# login
!aws ecr get-login-password --region {region} | docker login --username AWS --password-stdin {account_id}.dkr.ecr.{region}.amazonaws.com


REPO_NAME = "sagemaker-endpoint/whisper-triton-byoc"
CONTAINER = f"{account_id}.dkr.ecr.{region}.amazonaws.com/{REPO_NAME}:latest"


https://docs.docker.com/engine/reference/commandline/login/#credentials-store

Login Succeeded


In [4]:
!rm model_date.tar.gz
!cat model_data/deploy_config.sh
!tar czvf model_data.tar.gz model_data/ --exclude=model_data/.ipynb_checkpoints --exclude=model_data/__pycache__

rm: cannot remove ‘model_date.tar.gz’: No such file or directory
model_data/
model_data/triton_client_preprocessed.py
model_data/inference.py
model_data/start_triton_and_client.sh
model_data/.ipynb_checkpoints/
model_data/.ipynb_checkpoints/inference-checkpoint.py
model_data/.ipynb_checkpoints/download_model_from_s3-checkpoint.py
model_data/.ipynb_checkpoints/triton_client-checkpoint.py
model_data/.ipynb_checkpoints/start_triton_and_client-checkpoint.sh
model_data/.ipynb_checkpoints/triton_client_preprocessed-checkpoint.py
model_data/.ipynb_checkpoints/ssh_helper_start-checkpoint.py
model_data/download_model_from_s3.py
model_data/ssh_helper_start.py
model_data/triton_client.py


In [5]:
s3_code_prefix = f"whisper_deploy_codes"
bucket = sess.default_bucket()
code_artifact = sess.upload_data("model_data.tar.gz", bucket, s3_code_prefix)
print(f"S3 Code or Model tar ball uploaded to --- > {code_artifact}")

S3 Code or Model tar ball uploaded to --- > s3://sagemaker-us-east-1-596899493901/model_repo_whisper_trtll/model_data.tar.gz


## 2. 使用 SSH-helper 进行调试（可选）

Since we are using the BYOC (Bring Your Own Container) method to deploy model, we can deploy and debug the code using SSH Helper after preparing the initial code. Once the debugging is successful, we can then deploy it using the regular method.

1. Deploy the model using SageMaker SSH Helper(Setting up your AWS account with IAM and SSM configuration)
2. After got the instance_id, ssh to the instance and debug.


在部署的时候使用 SSH-helper 进行推理调试，也会启动一个 sagemaker 实例，在不使用时，使用最后清理步骤，删除节点

In [3]:
%pip install sagemaker_ssh_helper==2.2.0

Collecting sagemaker_ssh_helper==2.2.0
  Downloading sagemaker_ssh_helper-2.2.0-py3-none-any.whl.metadata (3.1 kB)
Downloading sagemaker_ssh_helper-2.2.0-py3-none-any.whl (98 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m98.8/98.8 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sagemaker_ssh_helper
Successfully installed sagemaker_ssh_helper-2.2.0
Note: you may need to restart the kernel to use updated packages.


In [6]:
from sagemaker_ssh_helper.wrapper import SSHModelWrapper
model = Model(image_uri=CONTAINER, model_data=code_artifact, role=role,dependencies=[SSHModelWrapper.dependency_dir()])

In [7]:
from sagemaker_ssh_helper.wrapper import SSHModelWrapper
from time import gmtime, strftime
from sagemaker import Predictor
instance_type = "ml.g5.4xlarge"
# instance_type = "ml.p4d.24xlarge"
endpoint_name = sagemaker.utils.name_from_base("whisper-trt-triton-sshelper")
# endpointName="gpt-sovits-sagemaker-endpoint-v2-"+strftime("%Y-%m-%d-%H-%M-%S", gmtime())

ssh_wrapper = SSHModelWrapper.create(model, connection_wait_time_seconds=0)  # <--NEW--

predictor = model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
    wait=False
)


# instance_ids = ssh_wrapper.get_instance_ids(timeout_in_sec=900)  # <--NEW-- 
# print(f"To connect over SSM run: aws ssm start-session --target {instance_ids[0]}")


In [8]:
import time
sm_client = boto3.client("sagemaker")
resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: InService
Arn: arn:aws:sagemaker:us-east-1:596899493901:endpoint/whisper-trt-triton-sshelper-2024-09-18-03-50-09-257
Status: InService


In [None]:
# aws ssm start-session --target <Your_instance_ids> 
instance_ids = ssh_wrapper.get_instance_ids(timeout_in_sec=0)
print(instance_ids[0])

## 2. 正式部署

In [14]:
model = Model(
    model_data=code_artifact,
    image_uri=CONTAINER,
    role=role,
)

# 部署模型到endpoint
endpoint_name = sagemaker.utils.name_from_base("whisper-large-v3")
print(f"endpoint_name: {endpoint_name}")
predictor = model.deploy(
    initial_instance_count=1,
    instance_type='ml.g5.4xlarge',
    endpoint_name=endpoint_name,
)

endpoint_name: whisper-large-v3-2024-09-18-05-29-54-792


Using already existing model: whisper-triton-preprocessed


--------------------!

## 3. 推理调用测试

In [12]:
%pip install pydub

Collecting pydub
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Downloading pydub-0.25.1-py2.py3-none-any.whl (32 kB)
Installing collected packages: pydub
Successfully installed pydub-0.25.1
Note: you may need to restart the kernel to use updated packages.


In [None]:
import boto3
import json
import base64
import os
import io
from pydub import AudioSegment

endpoint_name = endpoint_name
def encode_audio(audio_file_path):
    # 加载音频文件
    audio = AudioSegment.from_wav(audio_file_path)
    
    # 检查是否为双通道
    if audio.channels == 2:
        print("检测到双通道音频，正在转换为单通道...")
        # 将双通道转换为单通道
        audio = audio.set_channels(1)
    
    # 将音频数据写入内存缓冲区
    buffer = io.BytesIO()
    audio.export(buffer, format="wav")
    buffer.seek(0)
    
    # 将缓冲区的内容编码为 base64
    return base64.b64encode(buffer.read()).decode('utf-8')

def invoke_sagemaker_endpoint(runtime_client, endpoint_name, audio_data, whisper_prompt=""):
    """Invoke SageMaker endpoint with audio data"""
    payload = {
        "whisper_prompt": whisper_prompt,
        "audio_data": audio_data
    }
    
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='application/json',
        Body=json.dumps(payload)
    )
    
    result = json.loads(response['Body'].read().decode())
    return result

def transcribe_audio(audio_path, endpoint_name, whisper_prompt=""):
    # Convert audio to 16kHz mono WAV if it's not already
    # Read and encode the audio file
    print("Reading and encoding audio file...")
    audio_data = encode_audio(audio_path)

    # Create a SageMaker runtime client
    runtime_client = boto3.client('sagemaker-runtime')

    # Invoke the SageMaker endpoint
    print(f"Invoking SageMaker endpoint: {endpoint_name}")
    result = invoke_sagemaker_endpoint(
        runtime_client,
        endpoint_name,
        audio_data,
        whisper_prompt
    )

    return result

# Example usage
if __name__ == "__main__":
    # Set your parameters here
    audio_path = "./English_04.wav"
    whisper_prompt = ""  # Optional: add a prompt if needed, the defualt is <|startoftranscript|><|en|><|transcribe|><|notimestamps|>

    # Call the function
    result = transcribe_audio(audio_path, endpoint_name, whisper_prompt)

    # Print the result
    print("Transcription result:")
    print(result)


In [None]:
%%time
audio_path = "./English_04.wav"
endpoint_name = endpoint_name # "whisper-trt-triton-sshelper-2024-09-17-10-47-56-767"
whisper_prompt = ""  # Optional: add a prompt if needed

# Call the function
result = transcribe_audio(audio_path, endpoint_name, whisper_prompt)

# Print the result
print("Transcription result:")
print(result)

Reading and encoding audio file...
Invoking SageMaker endpoint: whisper-trt-triton-sshelper-2024-09-17-10-47-56-767
Transcription result:
{'code': 0, 'message': 'Success', 'transcribe_text': ' I want to play Sawyer.'}
CPU times: user 17.1 ms, sys: 4.23 ms, total: 21.3 ms
Wall time: 206 ms


In [11]:
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_name)
sess.delete_model(model.name)