# FireRedASR-AED-L torchserve SageMaker Deployment

In [1]:
!pip install kaldiio



In [2]:
import sagemaker
import boto3

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

print(f"sagemaker role arn: {role}")
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
region = sess._region_name # region name of the current SageMaker Studio environment



sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml


sagemaker role arn: arn:aws:iam::596899493901:role/service-role/AmazonSageMaker-ExecutionRole-20240126T153870


In [3]:
%%writefile code/inference.py
import os
import sys

import torch
import logging
import json
import numpy as np
import torch.nn.functional as F
import uuid


from fireredasr.models.fireredasr import FireRedAsr

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def generate_request_id():
    """生成唯一的request_id"""
    return str(uuid.uuid4())

def model_fn(model_dir):
    model = FireRedAsr.from_pretrained("aed", "pretrained_models/FireRedASR-AED-L/")
    return model

def input_fn(request_body, request_content_type):
    audio = np.frombuffer(request_body, dtype=np.int16)
    request_id = generate_request_id()
    return (audio, request_id)


def predict_fn(input_data, model):
    audio_data, request_id = input_data
    results = model.transcribe(
        [request_id],
        audio_data,
        {
            "use_gpu": 1,
            "beam_size": 3,
            "nbest": 1,
            "decode_max_len": 0,
            "softmax_smoothing": 1.0,
            "aed_length_penalty": 0.0,
            "eos_penalty": 1.0
        }
    )
    logger.info(f'ori results: {results}')

    logger.info(
        f'Transcription generated: {results[0]["text"]}') 
    return results


def output_fn(prediction, response_content_type):
    logger.info(
        f'Formatting output with content type: {response_content_type}')
    if response_content_type == 'application/json':
        return json.dumps({'transcription': prediction})
    raise ValueError(f'Unsupported content type: {response_content_type}')


Overwriting code/inference.py


In [4]:
!rm model.tar.gz
!tar -czvf model.tar.gz ./code --exclude='*.ipynb' --exclude='*/.ipynb_checkpoints' --exclude='*/*/__pycache__'

./code/
./code/inference.py
./code/pretrained_models/
./code/pretrained_models/README.md
./code/pretrained_models/FireRedASR-AED-L/
./code/pretrained_models/FireRedASR-AED-L/cmvn.ark
./code/pretrained_models/FireRedASR-AED-L/train_bpe1000.model
./code/pretrained_models/FireRedASR-AED-L/dict.txt
./code/pretrained_models/FireRedASR-AED-L/README.md
./code/pretrained_models/FireRedASR-AED-L/cmvn.txt
./code/pretrained_models/FireRedASR-AED-L/config.yaml
./code/pretrained_models/FireRedASR-AED-L/model.pth.tar
./code/requirements.txt
./code/test.py
./code/output.wav
./code/fireredasr/
./code/fireredasr/utils/
./code/fireredasr/utils/wer.py
./code/fireredasr/utils/param.py
./code/fireredasr/tokenizer/
./code/fireredasr/tokenizer/aed_tokenizer.py
./code/fireredasr/tokenizer/llm_tokenizer.py
./code/fireredasr/data/
./code/fireredasr/data/token_dict.py
./code/fireredasr/data/asr_feat.py
./code/fireredasr/speech2text.py
./code/fireredasr/models/
./code/fireredasr/models/module/
./code/fireredasr/m

In [5]:
s3_code_prefix = f"fireredasr_deploy_codes"
bucket = sess.default_bucket()
code_artifact = sess.upload_data("model.tar.gz", bucket, s3_code_prefix)
print(f"S3 Code or Model tar ball uploaded to --- > {code_artifact}")

S3 Code or Model tar ball uploaded to --- > s3://sagemaker-us-east-1-596899493901/fireredasr_deploy_codes/model.tar.gz


In [6]:
from datetime import datetime
import sagemaker 
sm = boto3.Session().client("sagemaker")
image_uri = sagemaker.image_uris.retrieve(
    framework="pytorch",
    region=region,
    py_version="py312",
    image_scope="inference",
    version="2.6.0",
    instance_type="ml.g5.2xlarge",
)

print("image_uri: ", image_uri)

env_variables_dict = {
    "SAGEMAKER_TS_BATCH_SIZE": "4",
    "SAGEMAKER_TS_MAX_BATCH_DELAY": "100",
    "SAGEMAKER_TS_MIN_WORKERS": "4",
    "SAGEMAKER_TS_MAX_WORKERS": "4",
}
endpoint_name = f"fireredasr-torchserve-{datetime.now():%Y-%m-%d-%H-%M-%S}"
resp = sm.create_model(
    ModelName=endpoint_name,
    ExecutionRoleArn=role,
    Containers=[{"Image": image_uri, "ModelDataUrl": code_artifact, "Environment": env_variables_dict}]
)
print(f"Created Model: {resp}")

image_uri:  763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:2.6.0-gpu-py312
Created Model: {'ModelArn': 'arn:aws:sagemaker:us-east-1:596899493901:model/fireredasr-torchserve-2025-05-20-04-39-39', 'ResponseMetadata': {'RequestId': 'a98a04e2-d07d-4743-9f3e-353c26b8c28c', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': 'a98a04e2-d07d-4743-9f3e-353c26b8c28c', 'content-type': 'application/x-amz-json-1.1', 'content-length': '103', 'date': 'Tue, 20 May 2025 04:39:40 GMT'}, 'RetryAttempts': 0}}


In [7]:
resp = sm.create_endpoint_config(
    EndpointConfigName=endpoint_name,
    ProductionVariants=[
        {
            "VariantName": "AllTraffic",
            "ModelName": endpoint_name,
            "InstanceType": "ml.g5.2xlarge",
            "InitialInstanceCount": 1,
        }
    ],
)
print(f"Created Endpoint Config: {resp}")

Created Endpoint Config: {'EndpointConfigArn': 'arn:aws:sagemaker:us-east-1:596899493901:endpoint-config/fireredasr-torchserve-2025-05-20-04-39-39', 'ResponseMetadata': {'RequestId': '3fa078dd-000e-4e5d-b145-bc534980160f', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '3fa078dd-000e-4e5d-b145-bc534980160f', 'content-type': 'application/x-amz-json-1.1', 'content-length': '122', 'date': 'Tue, 20 May 2025 04:39:45 GMT'}, 'RetryAttempts': 0}}


In [8]:
# creating endpoint with the first endpoint config (ep_config_name)
resp = sm.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=endpoint_name)
print(f"\nCreated Endpoint: {resp}")


Created Endpoint: {'EndpointArn': 'arn:aws:sagemaker:us-east-1:596899493901:endpoint/fireredasr-torchserve-2025-05-20-04-39-39', 'ResponseMetadata': {'RequestId': 'ba7e8699-ae29-4c41-a48b-74a266803c16', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': 'ba7e8699-ae29-4c41-a48b-74a266803c16', 'content-type': 'application/x-amz-json-1.1', 'content-length': '109', 'date': 'Tue, 20 May 2025 04:39:49 GMT'}, 'RetryAttempts': 0}}


In [9]:
%%time
import time
def wait_for_endpoint_in_service(endpoint_name):
    print("Waiting for endpoint in service")
    while True:
        details = sm.describe_endpoint(EndpointName=endpoint_name)
        status = details["EndpointStatus"]
        if status in ["InService", "Failed"]:
            print("\nDone!")
            break
        print(".", end="", flush=True)
        time.sleep(30)


wait_for_endpoint_in_service(endpoint_name)

sm.describe_endpoint(EndpointName=endpoint_name)

Waiting for endpoint in service
..............
Done!
CPU times: user 61.4 ms, sys: 14.9 ms, total: 76.3 ms
Wall time: 7min 1s


{'EndpointName': 'fireredasr-torchserve-2025-05-20-04-39-39',
 'EndpointArn': 'arn:aws:sagemaker:us-east-1:596899493901:endpoint/fireredasr-torchserve-2025-05-20-04-39-39',
 'EndpointConfigName': 'fireredasr-torchserve-2025-05-20-04-39-39',
 'ProductionVariants': [{'VariantName': 'AllTraffic',
   'DeployedImages': [{'SpecifiedImage': '763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:2.6.0-gpu-py312',
     'ResolvedImage': '763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference@sha256:d1b173191e87ab763e1c567506b03ec7621302148759f106aa781300231d86ac',
     'ResolutionTime': datetime.datetime(2025, 5, 20, 4, 39, 50, 248000, tzinfo=tzlocal())}],
   'CurrentWeight': 1.0,
   'DesiredWeight': 1.0,
   'CurrentInstanceCount': 1,
   'DesiredInstanceCount': 1}],
 'EndpointStatus': 'InService',
 'CreationTime': datetime.datetime(2025, 5, 20, 4, 39, 49, 493000, tzinfo=tzlocal()),
 'LastModifiedTime': datetime.datetime(2025, 5, 20, 4, 46, 59, 775000, tzinfo=tzlocal()),
 'Respon

In [17]:
import boto3
import json
import base64
import kaldiio

def invoke_sagemaker_endpoint(runtime_client, endpoint_name, audio_data, request_id=None, whisper_prompt=""):
    """Invoke SageMaker endpoint with audio data"""
    
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='application/octet-stream',
        Body=audio_data,
    )

    result = json.loads(response['Body'].read().decode())
    # # 在返回结果中添加request_id
    # result['request_id'] = request_id
    return result

def prepare_audio(audio_file, target_sr=16000):
    sample_rate, wav_np = kaldiio.load_mat(audio_file)    
    audio_bytes = wav_np.tobytes()
    return audio_bytes

def transcribe_audio(audio_path, endpoint_name):
    """
    转录音频文件
    :param audio_path: 音频文件路径
    :param endpoint_name: SageMaker端点名称
    :param request_id: 可选的请求ID
    :return: 转录结果
    """
    try:
        audio_data = prepare_audio(audio_path)
        runtime_client = boto3.client('sagemaker-runtime')
        result = invoke_sagemaker_endpoint(
            runtime_client,
            endpoint_name,
            audio_data
        )

        return result

    except Exception as e:
        error_response = {
            'error': str(e),
        }
        return error_response

if __name__ == "__main__":
    audio_path = "examples/wav/BAC009S0764W0121.wav"

    result = transcribe_audio(audio_path, endpoint_name)

    print("Transcription result:")
    print(json.dumps(result, indent=2, ensure_ascii=False))


Transcription result:
{
  "transcription": [
    {
      "request_id": "0ee2e8e9-3fa4-4c7e-9d01-f23358602c97",
      "text": "甚至出现交易几乎停滞的情况",
      "rtf": "0.0692"
    }
  ]
}


In [18]:
st = time.monotonic()
audio_path = "./output.wav"
result = transcribe_audio(audio_path, endpoint_name)
time_consume = time.monotonic() - st
print("Transcription result:")
print(json.dumps(result, indent=2, ensure_ascii=False))
print("time_consume: ", time_consume)

Transcription result:
{
  "transcription": [
    {
      "request_id": "87f010f3-b76f-40b8-877c-f60cc01a9c34",
      "text": "THIS WAY THE EQUIPMENT YOU WISH TO SELL WILL ALREADY BE SELECTED WHEN YOU OPEN THE SHOP THE NEXT TIME AND AFTER SELLING IT YOU CAN CLOSE THE SHOP AND BUY THE NEXT PIECE OF EQUIPMENT DIRECTLY",
      "rtf": "0.1290"
    }
  ]
}
time_consume:  1.4025085670000408


In [34]:
!ffmpeg -i 223.wav -ar 16000 -ac 1 -acodec pcm_s16le -f wav output.wav

ffmpeg version 4.3 Copyright (c) 2000-2020 the FFmpeg developers
  built with gcc 7.3.0 (crosstool-NG 1.23.0.449-a04d0)
  configuration: --prefix=/home/ec2-user/anaconda3/envs/pytorch_p310 --cc=/opt/conda/conda-bld/ffmpeg_1597178665428/_build_env/bin/x86_64-conda_cos6-linux-gnu-cc --disable-doc --disable-openssl --enable-avresample --enable-gnutls --enable-hardcoded-tables --enable-libfreetype --enable-libopenh264 --enable-pic --enable-pthreads --enable-shared --disable-static --enable-version3 --enable-zlib --enable-libmp3lame
  libavutil      56. 51.100 / 56. 51.100
  libavcodec     58. 91.100 / 58. 91.100
  libavformat    58. 45.100 / 58. 45.100
  libavdevice    58. 10.100 / 58. 10.100
  libavfilter     7. 85.100 /  7. 85.100
  libavresample   4.  0.  0 /  4.  0.  0
  libswscale      5.  7.100 /  5.  7.100
  libswresample   3.  7.100 /  3.  7.100
[0;33mGuessed Channel Layout for Input Stream #0.0 : mono
[0mInput #0, wav, from '223.wav':
  Duration: 00:00:09.57, bitrate: 256 kb/s
 

In [23]:
# delete endpoint
sm.delete_endpoint(EndpointName=endpoint_name)
sm.delete_endpoint_config(EndpointConfigName=endpoint_name)
sm.delete_model(ModelName=endpoint_name)

{'ResponseMetadata': {'RequestId': '5040fbe9-1c20-4632-8912-9838a3429811',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '5040fbe9-1c20-4632-8912-9838a3429811',
   'content-type': 'application/x-amz-json-1.1',
   'date': 'Tue, 20 May 2025 04:30:20 GMT',
   'content-length': '0'},
  'RetryAttempts': 0}}