In [None]:
import os
import boto3
import sagemaker
from typing import Final
from time import sleep

# Client setup: one boto3 client per AWS service used in this notebook,
# plus the waiter that blocks until an endpoint reports InService.
smr_client: Final = boto3.client('sagemaker-runtime')
sm_client: Final = boto3.client('sagemaker')
s3_client: Final = boto3.client('s3')
endpoint_waiter: Final = sm_client.get_waiter('endpoint_in_service')

# Shared settings. A single SageMaker session is created once and reused so
# the role, region and default bucket are all resolved from the same session
# instead of constructing a fresh Session for every lookup.
sagemaker_session: Final = sagemaker.Session()
role: Final[str] = sagemaker.get_execution_role(sagemaker_session=sagemaker_session)
region: Final[str] = sagemaker_session.boto_region_name
bucket: Final[str] = sagemaker_session.default_bucket()
# Index (rather than .get) so a missing account id fails here with a KeyError
# instead of silently putting "None" into the ECR image URI built later.
account_id: Final[str] = boto3.client('sts').get_caller_identity()['Account']

In [None]:
# Switch to the working directory; the relative archive path below is
# resolved against it.
os.chdir('/home/ec2-user/SageMaker/')

# Directory that holds the packaged model
model_dir: Final[str] = 'whisper-model'

# Upload the model archive to S3 and show the resulting URI
model_archive_path = f'./{model_dir}/model.tar.gz'
model_s3_uri: Final[str] = sagemaker.Session().upload_data(
    model_archive_path,
    key_prefix='whisper-transcribe',
)
print(model_s3_uri)

In [None]:
# Names of the resources to deploy. The endpoint and endpoint-config names
# are derived from the model name so all three stay in sync.
model_name: Final[str] = 'WhisperTranscribeModel'
# NOTE(review): 'AllTrafic' looks like a misspelling of 'AllTraffic'; it is
# kept as-is because the value is used as the deployed variant name.
variant_name: Final[str] = 'AllTrafic'
endpoint_name: Final[str] = f'{model_name}Endpoint'
endpoint_config_name: Final[str] = f'{model_name}EndpointConfig'

In [None]:
# Create the SageMaker Model

# Container image pulled from this account's ECR registry in the current region
inference_image_uri = f'{account_id}.dkr.ecr.{region}.amazonaws.com/whisper-transcribe:GPU'

# Environment variables handed to the inference container
container_environment = {
    'SAGEMAKER_CONTAINER_LOG_LEVEL': '20',
    'SAGEMAKER_PROGRAM': 'inference.py',
    'SAGEMAKER_REGION': region,
    'SAGEMAKER_SUBMIT_DIRECTORY': '/opt/ml/model/code',
    # Request/response size limits and response timeout, passed as strings
    'TS_MAX_REQUEST_SIZE': '1000000000',
    'TS_MAX_RESPONSE_SIZE': '1000000000',
    'TS_DEFAULT_RESPONSE_TIMEOUT': '3600',
}

sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer={
        'Image': inference_image_uri,
        'ModelDataUrl': model_s3_uri,
        'Environment': container_environment,
    },
)

In [None]:
# Create the SageMaker EndpointConfig

# Single production variant serving the model on one GPU instance
production_variant = {
    'VariantName': variant_name,
    'ModelName': model_name,
    'InitialInstanceCount': 1,
    'InstanceType': 'ml.g4dn.xlarge',
}

# S3 prefix under which asynchronous inference results are written
async_output_path = f"s3://{bucket}/whisper-transcribe/async-inference/output"

sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[production_variant],
    AsyncInferenceConfig={
        'OutputConfig': {'S3OutputPath': async_output_path},
    },
)

In [None]:
# Create the asynchronous endpoint
sm_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)

# Update the asynchronous endpoint (use instead of create_endpoint when the
# endpoint already exists)
# sm_client.update_endpoint(
#     EndpointName=endpoint_name,
#     EndpointConfigName=endpoint_config_name,
# )

# Block until the endpoint is in service.
# The waiter's total budget is Delay * MaxAttempts. Lowering Delay to 5 s
# while leaving MaxAttempts at its default of 120 would cap the wait at
# 10 minutes, which a GPU endpoint start can exceed. MaxAttempts is raised
# so the overall budget stays at one hour (5 s * 720).
endpoint_waiter.wait(
    EndpointName=endpoint_name,
    WaiterConfig={'Delay': 5, 'MaxAttempts': 720}
)

In [None]:
# Invoke the asynchronous endpoint
response = smr_client.invoke_endpoint_async(
    EndpointName=endpoint_name,
    InputLocation=f"s3://{bucket}/whisper-transcribe/async-inference/input/test.mp3",
    ContentType='audio/mpeg',
    Accept='text/plain'
)

# The call returns immediately; the result is written later to OutputLocation.
output_s3_uri = response['OutputLocation']
output_key = output_s3_uri.replace(f's3://{bucket}/', '')

# If the response reports where a failure record would be written, watch that
# key too so a failed inference is surfaced instead of waiting forever.
failure_s3_uri = response.get('FailureLocation')
failure_key = failure_s3_uri.replace(f's3://{bucket}/', '') if failure_s3_uri else None

# Poll S3 once per second, for at most one hour (the same value as the
# TS_DEFAULT_RESPONSE_TIMEOUT configured on the model).
poll_interval_seconds = 1
max_wait_seconds = 3600

for _ in range(max_wait_seconds // poll_interval_seconds):
    result = s3_client.list_objects(Bucket=bucket, Prefix=output_key)
    if 'Contents' in result:
        # Output object exists: download and print the transcription
        print('!')
        obj = s3_client.get_object(Bucket=bucket, Key=output_key)
        predictions = obj['Body'].read().decode()
        print(predictions)
        break

    if failure_key and 'Contents' in s3_client.list_objects(Bucket=bucket, Prefix=failure_key):
        failure_obj = s3_client.get_object(Bucket=bucket, Key=failure_key)
        failure_detail = failure_obj['Body'].read().decode(errors='replace')
        raise RuntimeError(f'Async inference failed ({failure_s3_uri}): {failure_detail}')

    # Not ready yet: show progress and wait before the next check
    print('.', end='')
    sleep(poll_interval_seconds)
else:
    # Loop finished without a break: neither output nor failure appeared in time
    raise TimeoutError(
        f'No async inference result at {output_s3_uri} after {max_wait_seconds} seconds'
    )

In [None]:
# Tear down the deployed resources in the reverse order of their creation:
# endpoint first, then its endpoint config, then the model.
sm_client.delete_endpoint(
    EndpointName=endpoint_name,
)
sm_client.delete_endpoint_config(
    EndpointConfigName=endpoint_config_name,
)
sm_client.delete_model(
    ModelName=model_name,
)