In [None]:
# Install the notebook's dependencies into the kernel's environment
# (%pip targets the running kernel; -q keeps pip's output quiet).
%pip install openai-whisper==20230918 -q
%pip install torchaudio==2.1.0 -q
%pip install datasets==2.16.1 -q
%pip install sagemaker==2.184.0  -q
# NOTE(review): librosa and soundfile are not version-pinned, unlike the
# packages above — consider pinning them so a fresh run is reproducible.
%pip install librosa -q
%pip install soundfile -q

In [None]:
import json
import torch
import whisper
import torchaudio
import sagemaker
import time
import json
import boto3
import soundfile as sf
from datasets import load_dataset
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.serializers import DataSerializer
from sagemaker.deserializers import JSONDeserializer

In [None]:
# Sanity check: show which region a boto3 session resolves to when none is
# passed explicitly.
session = boto3.Session()
current_region = session.region_name
print("Region:", current_region)

In [None]:
# S3 location for the packaged model artifact.
bucket = 'whisper-bucket-unilex-new'
prefix = 'whisper_blog_post'

# IAM role handed to SageMaker when the model is created.
# NOTE(review): the account-specific role ARN and bucket name are hardcoded;
# consider reading them from configuration or the environment.
role = "arn:aws:iam::307946674662:role/service-role/AmazonSageMaker-ExecutionRole-20250306T120165"

# SageMaker session bound explicitly to us-east-1.
boto_session = boto3.Session(region_name='us-east-1')
sess = sagemaker.Session(boto_session=boto_session)

In [None]:
# Load the Whisper model registered under the name "base".
model = whisper.load_model("base")

# Bundle the weights with the model's dimension settings in one checkpoint.
# NOTE(review): presumably inference.py rebuilds the model from 'dims' before
# loading 'model_state_dict' — confirm against that script.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'dims': vars(model.dims),
}
torch.save(checkpoint, 'base.pt')

In [None]:
# Stage the saved checkpoint in a scratch directory named model/.
!mkdir -p model
!mv base.pt model
# Archive the contents of model/ as a gzip-compressed tarball; -C switches into
# the directory first, so base.pt sits at the root of model.tar.gz.
!tar cvzf model.tar.gz -C model/ .

In [None]:
model_uri = sess.upload_data('model.tar.gz', bucket=bucket, key_prefix=f"{prefix}/pytorch/model")
!rm model.tar.gz
!rm -rf model
model_uri

In [None]:
# Container image used to serve the model (a us-east-1 ECR URI).
image = "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-inference:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04"

# Whole-second Unix timestamp used as a suffix for the model name.
# NOTE(review): `id` shadows the Python builtin of the same name; it is kept
# because other cells may read it — confirm before renaming.
id = int(time.time())
model_name = f'whisper-pytorch-model-{id}'

In [None]:
# Describe the SageMaker model: the uploaded artifact, the serving image and
# the custom inference script.
whisper_pytorch_model = PyTorchModel(
    model_data=model_uri,
    image_uri=image,
    role=role,
    entry_point="inference.py",
    # NOTE(review): absolute path to the directory holding inference.py;
    # confirm it exists on the machine running this notebook.
    source_dir='/opt/var',
    name=model_name,
    # Use the session that was explicitly bound to us-east-1. Without this the
    # model falls back to a default session, whose region may differ from the
    # one the artifact was uploaded with and the one the image URI points to.
    sagemaker_session=sess,
    # NOTE(review): presumably these raise the model server's request/response
    # size limits and its response timeout — confirm units against the
    # serving container's documentation.
    env={
        'MMS_MAX_REQUEST_SIZE': '2000000000',
        'MMS_MAX_RESPONSE_SIZE': '2000000000',
        'MMS_DEFAULT_RESPONSE_TIMEOUT': '900'
    }
)

In [None]:
# Endpoint responses are decoded as JSON.
deserializer = JSONDeserializer()

# Request payloads are sent with the content type "audio/x-audio".
audio_serializer = DataSerializer(content_type="audio/x-audio")

In [None]:
%%time
endpoint_name = f'whisper-pytorch-real-time-endpoint1'

real_time_predictor = whisper_pytorch_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g4dn.xlarge",
    endpoint_name = endpoint_name,
    serializer=audio_serializer,
    entry_point="inference.py",
    deserializer = deserializer,
    region_name='us-east-1'
)