## 0. 配置sagemaker，获取 account id 等

In [1]:
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers
# IAM role of this notebook; later passed as the execution role when the model is created.
role = sagemaker.get_execution_role()
# SageMaker session used for S3 uploads and account/region lookups.
sess = sagemaker.session.Session()
# Read the region through the public property instead of the private `_region_name` attribute.
region = sess.boto_region_name
account_id = sess.account_id()

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml


## 模型导出为 onnx，并上传 S3

In [2]:
# S3 prefix that receives the Triton model repository and is handed to the
# container through the `model_s3_url` environment variable.
# NOTE(review): the bucket name is hard-coded — change it to a bucket you own.
s3_url = "s3://triton-models-xq/firered-onnx-triton/model_repository/"

In [None]:
!cp onnx_base/* model_repository/fireredasr_onnx/1/
!aws s3 sync model_repository/ {s3_url} --exclude "*/*/.ipynb_checkpoints" 

upload: model_repository/fireredasr_onnx/.ipynb_checkpoints/config-checkpoint.pbtxt to s3://triton-models-xq/firered-onnx-triton/model_repository/fireredasr_onnx/.ipynb_checkpoints/config-checkpoint.pbtxt
upload: model_repository/fireredasr_onnx/1/FireRedASR/fireredasr/data/__pycache__/asr_feat.cpython-312.pyc to s3://triton-models-xq/firered-onnx-triton/model_repository/fireredasr_onnx/1/FireRedASR/fireredasr/data/__pycache__/asr_feat.cpython-312.pyc
upload: model_repository/fireredasr_onnx/1/FireRedASR/LICENSE to s3://triton-models-xq/firered-onnx-triton/model_repository/fireredasr_onnx/1/FireRedASR/LICENSE
upload: model_repository/fireredasr_onnx/1/.ipynb_checkpoints/model-checkpoint.py to s3://triton-models-xq/firered-onnx-triton/model_repository/fireredasr_onnx/1/.ipynb_checkpoints/model-checkpoint.py
upload: model_repository/fireredasr_onnx/1/FireRedASR/fireredasr/models/module/conformer_encoder.py to s3://triton-models-xq/firered-onnx-triton/model_repository/fireredasr_onnx/1/Fi

## 1. 配置要调用的镜像

需要先在 SageMaker notebook 里的 terminal 打包镜像上传到 ECR 
```
bash build_and_push.sh
```

In [3]:
# Authenticate the local Docker client against this account's private ECR registry.
!aws ecr get-login-password --region {region} | docker login --username AWS --password-stdin {account_id}.dkr.ecr.{region}.amazonaws.com


# ECR repository and tag of the serving image (presumably the one pushed by build_and_push.sh).
# NOTE(review): "tirton" looks like a typo for "triton", but the value must match the
# repository that actually exists in ECR — confirm before renaming.
REPO_NAME = "sagemaker-endpoint/fireredasr_tirton_onnx"
CONTAINER = f"{account_id}.dkr.ecr.{region}.amazonaws.com/{REPO_NAME}:25.06"


https://docs.docker.com/engine/reference/commandline/login/#credentials-store

Login Succeeded


In [None]:
!rm ./model_data.tar.gz
!cat sagemaker_deploy/deploy_config.sh
!cat sagemaker_deploy/start_triton_and_client.sh
!tar czvf model_data.tar.gz sagemaker_deploy/ --exclude=sagemaker_deploy/.ipynb_checkpoints --exclude=sagemaker_deploy/__pycache__

In [5]:
# Upload the packaged deploy scripts to the session's default bucket under a fixed key prefix.
bucket = sess.default_bucket()
s3_code_prefix = "fireredasr_onnx_deploy"
code_artifact = sess.upload_data(
    "model_data.tar.gz",
    bucket,
    s3_code_prefix,
)
print(f"S3 Code or Model tar ball uploaded to --- > {code_artifact}")

S3 Code or Model tar ball uploaded to --- > s3://sagemaker-us-east-1-596899493901/fireredasr_onnx_deploy/model_data.tar.gz


## 2. 使用 SSH-helper 进行调试（可选）

Since we are using the BYOC (Bring Your Own Container) method to deploy the model, we can deploy and debug the code with SSH Helper once the initial code is ready. After debugging succeeds, we can deploy it using the regular method.

1. Deploy the model using SageMaker SSH Helper (this requires setting up your AWS account with the IAM and SSM configuration).
2. After getting the instance ID, SSH into the instance and debug.


在部署的时候使用 SSH-helper 进行推理调试，也会启动一个 sagemaker 实例，在不使用时，使用最后清理步骤，删除节点

In [6]:
# sagemaker_ssh_helper is used by the SSH debugging section; kaldiio is used by the inference test to load audio.
# NOTE(review): kaldiio is not version-pinned — consider pinning it for reproducibility.
%pip install sagemaker_ssh_helper==2.3.0 kaldiio

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting sagemaker_ssh_helper
  Downloading sagemaker_ssh_helper-2.3.0-py3-none-any.whl.metadata (3.0 kB)
Downloading sagemaker_ssh_helper-2.3.0-py3-none-any.whl (102 kB)
Installing collected packages: sagemaker_ssh_helper
Successfully installed sagemaker_ssh_helper-2.3.0
Note: you may need to restart the kernel to use updated packages.


In [6]:
from time import gmtime, strftime
from sagemaker_ssh_helper.wrapper import SSHModelWrapper
# Timestamped (UTC) name used for the debug model and, later, its endpoint.
name = "fireresasr-onnx-triton-sshelper" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
# Tells the container where the Triton model repository lives in S3.
env_variables_dict = dict(model_s3_url=s3_url)

# Bundle the SSH helper's dependency directory with the model.
model = Model(
    image_uri=CONTAINER,
    model_data=code_artifact,
    role=role,
    dependencies=[SSHModelWrapper.dependency_dir()],
    name=name,
    env=env_variables_dict,
)

In [7]:
from sagemaker_ssh_helper.wrapper import SSHModelWrapper
from time import gmtime, strftime

from sagemaker import Predictor
instance_type = "ml.g5.2xlarge"

# Wrap the model with the SSH helper before deploying; do not wait for a connection here.
ssh_wrapper = SSHModelWrapper.create(model, connection_wait_time_seconds=0)

# Deployment settings collected in one place so they are easy to review.
deploy_options = dict(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=name,
    inference_ami_version="al2-ami-sagemaker-inference-gpu-3-1",
    wait=True,
    timeout=180,
)
predictor = model.deploy(**deploy_options)


# instance_ids = ssh_wrapper.get_instance_ids(timeout_in_sec=900)  # <--NEW-- 
# print(f"To connect over SSM run: aws ssm start-session --target {instance_ids[0]}")


-----------------!

In [10]:
# Use the printed id with: aws ssm start-session --target <instance_id>
instance_ids = ssh_wrapper.get_instance_ids(timeout_in_sec=0)
debug_instance_id = instance_ids[0]
print(debug_instance_id)

mi-0bbd331c8cd384bda


## 2. 正式部署

In [26]:
from datetime import datetime
import sagemaker 
# Low-level SageMaker client used for the production deployment.
sm = boto3.Session().client("sagemaker")

# One timestamped name is used for the model here and reused by the later cells.
endpoint_name = f"fireredasr-onnx-triton-{datetime.now():%Y-%m-%d-%H-%M-%S}"
env_variables_dict = dict(model_s3_url=s3_url)

primary_container = {
    "Image": CONTAINER,
    "ModelDataUrl": code_artifact,
    "Environment": env_variables_dict,
}
resp = sm.create_model(
    ModelName=endpoint_name,
    ExecutionRoleArn=role,
    Containers=[primary_container],
)
print(f"Created Model: {resp}")

Created Model: {'ModelArn': 'arn:aws:sagemaker:us-east-1:596899493901:model/fireredasr-onnx-triton-2025-07-24-13-34-30', 'ResponseMetadata': {'RequestId': '4717931e-fdca-47b3-b8c7-aeeea0ec5083', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '4717931e-fdca-47b3-b8c7-aeeea0ec5083', 'content-type': 'application/x-amz-json-1.1', 'content-length': '104', 'date': 'Thu, 24 Jul 2025 13:34:30 GMT'}, 'RetryAttempts': 0}}


In [27]:
# Single production variant: one GPU instance serving the model created above.
production_variant = {
    "VariantName": "AllTraffic",
    "ModelName": endpoint_name,
    "InstanceType": "ml.g5.2xlarge",
    "InitialInstanceCount": 1,
    "InferenceAmiVersion": "al2-ami-sagemaker-inference-gpu-3-1",
}
resp = sm.create_endpoint_config(
    EndpointConfigName=endpoint_name,
    ProductionVariants=[production_variant],
)
print(f"Created Endpoint Config: {resp}")

Created Endpoint Config: {'EndpointConfigArn': 'arn:aws:sagemaker:us-east-1:596899493901:endpoint-config/fireredasr-onnx-triton-2025-07-24-13-34-30', 'ResponseMetadata': {'RequestId': '4fdb606d-0b00-45e5-b877-f8c0877f70d4', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '4fdb606d-0b00-45e5-b877-f8c0877f70d4', 'content-type': 'application/x-amz-json-1.1', 'content-length': '123', 'date': 'Thu, 24 Jul 2025 13:34:34 GMT'}, 'RetryAttempts': 0}}


In [28]:
# Create the endpoint from the config of the same name; the call returns before the endpoint is ready.
resp = sm.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=endpoint_name)
print(f"\nCreated Endpoint: {resp}")


Created Endpoint: {'EndpointArn': 'arn:aws:sagemaker:us-east-1:596899493901:endpoint/fireredasr-onnx-triton-2025-07-24-13-34-30', 'ResponseMetadata': {'RequestId': '0730a82a-fb38-4728-8b30-16070a3dd4b2', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '0730a82a-fb38-4728-8b30-16070a3dd4b2', 'content-type': 'application/x-amz-json-1.1', 'content-length': '110', 'date': 'Thu, 24 Jul 2025 13:34:37 GMT'}, 'RetryAttempts': 0}}


In [29]:
%%time
import time
def wait_for_endpoint_in_service(endpoint_name, poll_interval=30, timeout=None, client=None):
    """Poll a SageMaker endpoint until it reaches a terminal status.

    Args:
        endpoint_name: Name of the endpoint to watch.
        poll_interval: Seconds to sleep between ``describe_endpoint`` calls.
        timeout: Maximum number of seconds to wait; ``None`` waits indefinitely.
        client: Object exposing ``describe_endpoint(EndpointName=...)``.
            Defaults to the notebook-level ``sm`` client.

    Returns:
        The final ``EndpointStatus`` string, ``"InService"`` or ``"Failed"``,
        so callers can tell success from failure.

    Raises:
        TimeoutError: If ``timeout`` elapses before a terminal status is seen.
    """
    if client is None:
        client = sm
    deadline = None if timeout is None else time.monotonic() + timeout

    print("Waiting for endpoint in service")
    while True:
        details = client.describe_endpoint(EndpointName=endpoint_name)
        status = details["EndpointStatus"]
        if status == "InService":
            print("\nDone!")
            return status
        if status == "Failed":
            # Do not report a failed deployment as "Done!"; surface the reason instead.
            reason = details.get("FailureReason", "no failure reason reported")
            print(f"\nEndpoint creation failed: {reason}")
            return status
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"Endpoint {endpoint_name!r} still in status {status!r} after {timeout} seconds"
            )
        print(".", end="", flush=True)
        time.sleep(poll_interval)


# Block until the endpoint is InService or Failed, then show its full description as the cell output.
wait_for_endpoint_in_service(endpoint_name)

sm.describe_endpoint(EndpointName=endpoint_name)

Waiting for endpoint in service
.................
Done!
CPU times: user 79 ms, sys: 16.4 ms, total: 95.4 ms
Wall time: 8min 31s


{'EndpointName': 'fireredasr-onnx-triton-2025-07-24-13-34-30',
 'EndpointArn': 'arn:aws:sagemaker:us-east-1:596899493901:endpoint/fireredasr-onnx-triton-2025-07-24-13-34-30',
 'EndpointConfigName': 'fireredasr-onnx-triton-2025-07-24-13-34-30',
 'ProductionVariants': [{'VariantName': 'AllTraffic',
   'DeployedImages': [{'SpecifiedImage': '596899493901.dkr.ecr.us-east-1.amazonaws.com/sagemaker-endpoint/fireredasr_tirton_onnx:25.06',
     'ResolvedImage': '596899493901.dkr.ecr.us-east-1.amazonaws.com/sagemaker-endpoint/fireredasr_tirton_onnx@sha256:6409edb20a5eef8e19878f7f492a31c7d040783610b79db11a3bc65fc44759f4',
     'ResolutionTime': datetime.datetime(2025, 7, 24, 13, 34, 38, 84000, tzinfo=tzlocal())}],
   'CurrentWeight': 1.0,
   'DesiredWeight': 1.0,
   'CurrentInstanceCount': 1,
   'DesiredInstanceCount': 1}],
 'EndpointStatus': 'InService',
 'CreationTime': datetime.datetime(2025, 7, 24, 13, 34, 37, 460000, tzinfo=tzlocal()),
 'LastModifiedTime': datetime.datetime(2025, 7, 24, 13, 

## 3. 推理调用测试

In [32]:
import boto3
import json
import base64
import kaldiio

def invoke_sagemaker_endpoint(runtime_client, endpoint_name, audio_data, request_id=None, whisper_prompt=""):
    """Send raw audio bytes to a SageMaker endpoint and return the decoded JSON reply.

    ``request_id`` and ``whisper_prompt`` are accepted but currently unused.
    """
    request = {
        "EndpointName": endpoint_name,
        "ContentType": "application/octet-stream",
        "Body": audio_data,
    }
    reply = runtime_client.invoke_endpoint(**request)
    # The response body is a stream; read it fully before parsing.
    payload = reply["Body"].read()
    return json.loads(payload.decode())

def prepare_audio(audio_file, target_sr=16000):
    """Load an audio file with kaldiio and return its sample array as raw bytes.

    The first value returned by ``kaldiio.load_mat`` is discarded and
    ``target_sr`` is not used, so no resampling or rate check happens here.
    """
    _unused_first, samples = kaldiio.load_mat(audio_file)
    raw_bytes = samples.tobytes()
    return raw_bytes

def transcribe_audio(audio_path, endpoint_name, runtime_client=None):
    """Transcribe one audio file through a SageMaker endpoint.

    Args:
        audio_path: Path of the audio file to send.
        endpoint_name: Name of the SageMaker endpoint to invoke.
        runtime_client: Optional ``sagemaker-runtime`` client to reuse across
            calls; a new client is created when omitted.

    Returns:
        The endpoint's parsed JSON response, or ``{"error": <message>}`` when
        loading the audio or invoking the endpoint raised an exception.
    """
    try:
        audio_data = prepare_audio(audio_path)
        if runtime_client is None:
            runtime_client = boto3.client('sagemaker-runtime')
        return invoke_sagemaker_endpoint(
            runtime_client,
            endpoint_name,
            audio_data
        )
    except Exception as e:
        # Best-effort by design: report the failure to the caller as data
        # instead of interrupting the notebook.
        return {
            'error': str(e),
        }

if __name__ == "__main__":
    # Smoke test against the endpoint held in `endpoint_name`.
    # To test the SSH-helper debug endpoint instead, set `endpoint_name` to that endpoint's name first.
    audio_path = "zh_1.wav"
    result = transcribe_audio(audio_path, endpoint_name)

    print(
        "Transcription result:",
        json.dumps(result, indent=2, ensure_ascii=False),
        sep="\n",
    )


Transcription result:
{
  "transcription": [
    "每一天都要快乐喔"
  ],
  "status": "success"
}


In [None]:
%%time
# Time one round trip for the Chinese sample.
audio_path = "zh_1.wav"
result = transcribe_audio(audio_path, endpoint_name)

print(
    "Transcription result:",
    json.dumps(result, indent=2, ensure_ascii=False),
    sep="\n",
)

In [24]:
%%time
# Time one round trip for the English sample.
audio_path = "223.wav"
result = transcribe_audio(audio_path, endpoint_name)

print(
    "Transcription result:",
    json.dumps(result, indent=2, ensure_ascii=False),
    sep="\n",
)

Transcription result:
{
  "transcription": [
    "THIS WAY THE EQUIPMENT YOU WISH TO SELL WILL ALREADY BE SELECTED WHEN YOU OPEN THE SHOP THE NEXT TIME AND AFTER SELLING IT YOU CAN CLOSE THE SHOP AND BUY THE NEXT PIECE OF EQUIPMENT DIRECTLY"
  ],
  "status": "success"
}
CPU times: user 16.1 ms, sys: 1.03 ms, total: 17.1 ms
Wall time: 930 ms


In [None]:
# Cleanup for the SSH-helper debug deployment only.
# That endpoint was deployed under `name`; `endpoint_name` belongs to the
# production deployment, so using it here would delete the wrong endpoint
# (or fail if the production cells were never run).
sess.delete_endpoint(name)
# NOTE(review): assumes the SDK created the endpoint config under the endpoint's own name — confirm.
sess.delete_endpoint_config(name)
sess.delete_model(model.name)

In [None]:
# Tear down the production deployment: the endpoint, its config and the model
# were all created under `endpoint_name`.
sm.delete_endpoint(EndpointName=endpoint_name)
sm.delete_endpoint_config(EndpointConfigName=endpoint_name)
sm.delete_model(ModelName=endpoint_name)