In [None]:
import sagemaker
import boto3
import logging
from sagemaker.huggingface import HuggingFaceModel

# Configure logging: emit INFO-level messages and create a module logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Resolve the IAM role ARN handed to the model below. Ask the SageMaker SDK
# for the execution role first; if that raises ValueError, fall back to
# looking up the role named 'sagemaker_execution_role' through the IAM API.
try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

# Hub Model configuration. https://huggingface.co/models
# Selects the model id and the task; handed to the model object below
# through its `env` argument.
hub = {
    'HF_MODEL_ID': 'BELLE-2/Belle-whisper-large-v3-zh-punct',
    'HF_TASK': 'automatic-speech-recognition'
}

# Create the Hugging Face model object with pinned transformers / PyTorch /
# Python versions, the hub settings above and the role resolved earlier.
huggingface_model = HuggingFaceModel(
    transformers_version='4.37.0',
    pytorch_version='2.1.0',
    py_version='py310',
    env=hub,
    role=role,
)

# Deploy the model to SageMaker Inference under the endpoint name "whisper".
# The other cells in this notebook invoke the endpoint by that same name.
predictor = huggingface_model.deploy(
    initial_instance_count=1,  # number of instances
    instance_type="ml.g4dn.2xlarge",  # EC2 instance type
    endpoint_name = "whisper"
)

from sagemaker.serializers import DataSerializer

# Tag outgoing requests with the 'audio/x-audio' content type.
predictor.serializer = DataSerializer(content_type='audio/x-audio')

# Make sure the input audio file exists at the path opened below
with open("/home/sagemaker-user/whisper_deploy/test.wav", "rb") as f:
    data = f.read()
# Send the file's bytes to the deployed endpoint.
predictor.predict(data)

INFO:sagemaker:Creating model with name: huggingface-pytorch-inference-2024-11-14-01-42-34-706
INFO:sagemaker:Creating endpoint-config with name whisper
INFO:sagemaker:Creating endpoint with name whisper


--

In [17]:
import requests
import json
import boto3
import base64
import os

def get_aws_signature(endpoint_name, region):
    """Create a SageMaker Runtime client for the given region.

    Despite its name, this function does not compute a signature itself: it
    returns a boto3 client, and the client signs each request when it is sent.

    The client is built from a boto3 session, so credentials are resolved by
    boto3's normal provider chain. Copying the access key, secret key and
    token into the client by hand would freeze a snapshot of temporary
    credentials that is never refreshed, and would fail with AttributeError
    when the session has no credentials at all.

    Args:
        endpoint_name: Unused. Kept so existing callers keep working.
        region: AWS region name for the client, e.g. 'ap-northeast-1'.

    Returns:
        A boto3 'sagemaker-runtime' client bound to ``region``.
    """
    session = boto3.Session()
    return session.client('sagemaker-runtime', region_name=region)

def invoke_endpoint_api(audio_file_path, endpoint_name, region='ap-northeast-1'):
    """Send a local audio file to a SageMaker endpoint and return the parsed reply.

    Args:
        audio_file_path: Path of the audio file whose bytes are uploaded.
        endpoint_name: Name of the SageMaker endpoint to invoke.
        region: AWS region of the endpoint (defaults to 'ap-northeast-1').

    Returns:
        The endpoint's response body decoded as text and parsed as JSON.
    """
    # Obtain the runtime client for the target region
    runtime = get_aws_signature(endpoint_name, region)

    # Load the raw audio bytes
    with open(audio_file_path, 'rb') as audio_file:
        payload = audio_file.read()

    # Invoke the endpoint with the audio as the request body
    reply = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='audio/x-audio',
        Body=payload,
    )

    # Parse the response body
    raw_body = reply['Body'].read()
    return json.loads(raw_body.decode())

# Example usage
abs_path = "/home/sagemaker-user/output.wav"
print(f"音频文件完整路径: {abs_path}")

# When the audio file is missing, print diagnostics (working directory and
# its contents). Execution is not stopped here; the call below still runs.
if not os.path.exists(abs_path):
    print("错误：找不到音频文件")
    print(f"当前工作目录: {os.getcwd()}")
    print("当前目录文件列表:")
    print(*os.listdir('.'), sep='\n')

endpoint_name = "whisper"
result = invoke_endpoint_api(abs_path, endpoint_name)
print(result)

音频文件完整路径: /home/sagemaker-user/output.wav
{'text': '我好伤心，我想哭。'}


In [12]:
import boto3
import soundfile as sf
import io

def validate_and_convert_audio(input_file_path, target_format='wav'):
    """Validate an audio file by decoding it, then re-encode it in memory.

    Args:
        input_file_path: Path of the audio file to read.
        target_format: Format name passed to ``sf.write`` for the re-encoded
            output (defaults to 'wav').

    Returns:
        The re-encoded audio as a ``bytes`` object.

    Raises:
        Exception: Any error raised while reading or writing the audio is
            printed and then re-raised unchanged.
    """
    try:
        # Decoding doubles as validation: sf.read raises if the file cannot
        # be opened or its format is not recognised.
        data, samplerate = sf.read(input_file_path)

        # In-memory buffer that receives the re-encoded audio
        # (fixed: the class is io.BytesIO; io.BytIO does not exist).
        audio_buffer = io.BytesIO()

        # Encode into the requested format (WAV by default)
        sf.write(audio_buffer, data, samplerate, format=target_format)

        return audio_buffer.getvalue()
    except Exception as e:
        print(f"音频文件验证/转换错误: {str(e)}")
        raise

def invoke_whisper_endpoint(audio_data, endpoint_name='whisper'):
    """Send audio bytes to the Whisper endpoint and return the reply as text.

    Args:
        audio_data: Request body sent to the endpoint.
        endpoint_name: Name of the SageMaker endpoint (defaults to 'whisper').

    Returns:
        The response body decoded to ``str`` (not parsed any further).

    Raises:
        Exception: Any error from creating the client or invoking the
            endpoint is printed and then re-raised unchanged.
    """
    try:
        runtime = boto3.client('sagemaker-runtime')

        reply = runtime.invoke_endpoint(
            Body=audio_data,
            ContentType='audio/x-audio',
            EndpointName=endpoint_name,
        )

        raw_body = reply['Body'].read()
        return raw_body.decode()
    except Exception as err:
        print(f"调用endpoint错误: {err}")
        raise

# Example usage
try:
    # Step 1: validate the audio file and re-encode it
    audio_data = validate_and_convert_audio(
        '/home/sagemaker-user/whisper_deploy/test.wav'
    )

    # Step 2: send the audio to the endpoint and show the reply
    result = invoke_whisper_endpoint(audio_data)
    print(result)
except Exception as exc:
    # Report the failure without re-raising.
    print(f"处理失败: {exc}")

音频文件验证/转换错误: Error opening '/home/sagemaker-user/whisper_deploy/test.wav': Format not recognised.
处理失败: Error opening '/home/sagemaker-user/whisper_deploy/test.wav': Format not recognised.


In [14]:
from sagemaker.serializers import DataSerializer
from sagemaker.predictor import Predictor

def transcribe_audio(file_path, endpoint_name='whisper'):
    """Send an audio file to an existing SageMaker endpoint for transcription.

    Args:
        file_path: Path of the audio file whose raw bytes are sent.
        endpoint_name: Name of the deployed endpoint (defaults to 'whisper').

    Returns:
        Whatever ``predictor.predict`` returns; no deserializer is configured
        here, so the result is not post-processed by this function.
    """
    # Attach a predictor to the already-deployed endpoint and tag requests
    # with the 'audio/x-audio' content type.
    predictor = Predictor(endpoint_name=endpoint_name)
    predictor.serializer = DataSerializer(content_type='audio/x-audio')

    # Read the file as-is; the bytes are passed through without conversion.
    with open(file_path, "rb") as f:
        data = f.read()

    # Let the model process the audio (leftover debug print removed).
    result = predictor.predict(data)
    return result

# Example usage: transcribe the test file through the default 'whisper'
# endpoint and print the returned value.
result = transcribe_audio("/home/sagemaker-user/whisper_deploy/test.wav")
print(result)

zzz


ModelError: An error occurred (ModelError) when calling the InvokeEndpoint operation: Received client error (400) from primary with message "{
  "code": 400,
  "type": "InternalServerException",
  "message": "Soundfile is either not in the correct format or is malformed. Ensure that the soundfile has a valid audio file extension (e.g. wav, flac or mp3) and is not corrupted. If reading from a remote URL, ensure that the URL is the full address to **download** the audio file."
}
". See https://ap-northeast-1.console.aws.amazon.com/cloudwatch/home?region=ap-northeast-1#logEventViewer:group=/aws/sagemaker/Endpoints/whisper in account 034362076319 for more information.