# Deployment notebook

This notebook is used to deploy the endpoint using the Sagemaker SDK, both locally and 
online. This is not meant to be the main source of endpoint provision, which should be
done with terraform through the CD pipeline, but rather this is a way to test that
everything works before provisioning it.

It also register the model in the model registry for CD provisioning later.

---

Before running the cells, make sure you login to AWS using either:

- `aws configure sso` → for first time login
- `aws sso login` → for all subsequent login

In [1]:
# general settings, shared between local and online deployments

model_name = "musicgen"
model_entry_point = "../src/code/inference.py"
model_data = "../model/model.tar.gz"

endpoint_name = "endpoint-musicgen-0001-dev"
instance_type = "ml.g4dn.xlarge"

In [2]:
# set local temp folder to avoid /tmp to become full
import os
from pathlib import Path

repo_root_dir = Path(os.getcwd()).parents[2].resolve()
local_temp_folder_path = str(repo_root_dir / ".temp" / "sagemaker_local")

## Local

In [None]:
!pip install sagemaker[local]

In [3]:
import sagemaker
from sagemaker.local import LocalSession
from sagemaker.pytorch import PyTorchModel

session = LocalSession()

session.config = {
    "local": {
        "local_code": True,
        "container_root": local_temp_folder_path,
    }
}

session.settings = sagemaker.session_settings.SessionSettings(
    local_download_dir = local_temp_folder_path
)

role = sagemaker.get_execution_role()

print("Role:", role)
print("Local temp folder path:", local_temp_folder_path)

Role: arn:aws:iam::138140302683:role/aws-reserved/sso.amazonaws.com/AWSReservedSSO_AdministratorAccess_6f1d7369dc867f6b
Local temp folder path: /home/ubuntu/musicgen-endpoint-ableton/.temp/sagemaker_local


In [None]:
from sagemaker.deserializers import JSONDeserializer
from sagemaker.serializers import JSONSerializer

model_image_uri = "138140302683.dkr.ecr.us-east-1.amazonaws.com/musicgen-pytorch:1.0"

model = PyTorchModel(
    name=model_name,
    role=role,
    entry_point=model_entry_point,
    model_data=model_data,
    image_uri=model_image_uri,
    sagemaker_session=session,
)

predictor = model.deploy(
    initial_instance_count=1,
    instance_type="local_gpu",
    endpoint_name=endpoint_name,
    sagemaker_session=session,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

## Online

In [4]:
import boto3
import sagemaker
from sagemaker.pytorch import PyTorchModel

boto_session = boto3.Session()
client = boto3.client(service_name="sagemaker")

session = sagemaker.Session()

session.settings = sagemaker.session_settings.SessionSettings(
    local_download_dir = local_temp_folder_path
)

role = "arn:aws:iam::138140302683:role/service-role/AmazonSageMaker-ExecutionRole-20230522T162566"

In [9]:
# step 1: create the model
from sagemaker.deserializers import JSONDeserializer
from sagemaker.serializers import JSONSerializer

model_image_uri = "138140302683.dkr.ecr.us-east-1.amazonaws.com/musicgen-pytorch:1.0"

model = PyTorchModel(
    name=model_name,
    role=role,
    entry_point=model_entry_point,
    model_data=model_data,
    image_uri=model_image_uri,
    sagemaker_session=session,
)

In [7]:
# step 2: register the model

model.register(
    model_package_group_name="musicgen-model",
    content_types=["application/json"],
    response_types=["application/json"],
    inference_instances=[instance_type],
    approval_status="Approved",
)

<sagemaker.model.ModelPackage at 0x7f432c445390>

In [10]:
# step 3: create endpoint

predictor = model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
    sagemaker_session=session,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

---------------!

# Predictor tests

In [49]:
import base64

def base64_to_audio_file(base64_string: str, audio_file_path: str) -> None:
    """Converts a base64-encoded string in an audio file"""

    with open(audio_file_path, "wb") as audio_file:
        audio_file.write(base64.b64decode((base64_string)))

input_data = {
    "prompt": "berghain acid techno",
    "duration": 3,
    "temperature": 1.0,
    "top_p": 0.0,
    "top_k": 250,
    "cfg_coefficient": 3.0,
}

In [42]:
response = predictor.predict(data=input_data)
print("Response:", response)

base64_audio = response["result"]["prediction"]
base64_to_audio_file(base64_audio, "predictor_response.mp3")

Response: {'result': {'prediction': '//voxAAAOvInHjW8gAeKw6gjO7AAAM+zePCnDbjgLEJpLKbu/nBwpwsSca9nRSRx0Qb87G3MxsS4awqGoJRoRwZULGqOY5pptnHWeOp78nvmdLocmjejmmQYyBsPHBIckhyTHREetp0xmyeGDvEhIUsR3M1M6KDomOSI4pDosPDI9MjwsOhw5XTbbNlk2VzXXNdc1UTLDT7eNIcuGYIpklmaWY4LFgQEa7JsqmiWBh1rySHF5mCKayxrLGkgZxxmGGYUZRRlHGkcb7JrpmaKWvUHWHL/mCOZ5puNG48bzhtMGskZhACERQcSflbAy1ZhjmWaZ5pkhprxF343buPouRMRMBFcu+WXMEEwQSzZcsuWW3MIAwgjCEMQYxiDIIMYYwgEFFN0vy4ZZctOg+oOsOmOXgSIWI1x2GsLCJCJEF4y75acAgmCGYYphhmGCWvLiFsAAAYQRhBGEAWQLaIoJForgEEwQSzZeNQdx78rvwwwxUigCYiEgvAWTLJmCCYIJhhmGGAgy26KaYCxF2M4ZwzhdjEH4jbtqZoSy5ZctAOkWsdYjOGuOQ1hhipFSMQYmsOmOiuiuiuiuhLLlmEAYAACALMF4EHExFSKCIB0A6AdFdMdY7E3Lh+HH8chnDOGGLsXYsRFdFdItMdQcCCAAAAAAQxxOcwWAsz/KgxOB0XEQyHI8xgBMaE4GAMYQgeCAiMBBQV8YNgyYCAQQg+YYAMFBNL2kQQYQGDxCa0mGqB5bgLhjDgCMmWhZlCMcCRBVWOzcRYDV2RBQOER4UMpRwMZAocBRobNHGzsAoAiw4DhomAFVnMS2BAMYshGOBRkI6ZeCraNPmDFjUzoRAKQctlG6tBbNEwYAhoUVEijimuaOimri6P5m5qCiswQKM5NyJDNmIQRMG8ggFQjZwwSuhAyGYnsNLfTXW6r9hbQGTrCrmDgcwwTMTGTNwMxgWRDBQu

In [50]:
import boto3
import json

runtime_client = boto3.client('sagemaker-runtime')

response = runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="application/json",
    Accept="application/json",
    Body=json.dumps(input_data),
)

response_body = json.loads(response["Body"].read().decode())
print(response_body)

{'result': {'prediction': '//voxAAAPe4nEjXdAAdxQ6YDPaAAAE0FMkxHHA0SY43mho4ohI3iZw1OQQzKK4ywKgzvN8zVMUyzKMyTIExtFkxDC8wZAEUAwwmCw5805sc3rM1J8zJMxocBB0F4PXOoO860AIEMQOMgUMsaNAcNanNGdMqJQrXwWwQUZiYEOZg0Zw8Z48aBEaJIaREZ40ZQcChCP6mZf8suYMGYUKBg6sbIDAgzDiTFjTHkzLmTNmTLkwEXa0YEGXTU6AgAwQBBSKsPZfDjAC1BhBRkChljBmDBljBmihozpmSICJr3iFEquARJkyppkxqFBpkhoDhmCRjAhbyBVzlyC2iKDaMzLhlk0N3wAoEwIEsmWnLboA1dxpYQtmWTLJlmy26D6Y6Q6AcBACyAAAGCCGIFGMHGOHGMEAoIgDLvgECYMCWjV3KmsPxK2BlwwCBAQNBd15RLKdw0E5gQ5jSJlS5lypkRJeuK0zgLsYg5DWEvCyBjRpjx5jxpixJhQJdtv00DCBjLlzKkTFgy+br0DsMMVIgHQDoBy75bMtOXHTrkb7sPUDRXQDorqDqBoBwEALiMkgtQNCWkWy+XLnRPLZllyzaANY7vv+1x3JZdfdh6X4FDmPImiQGgOGYIGIAJWS125cpgAQJgwZgQJZsvGpvIloFyDDhTDhTBgS26RbL5+AF2IrqDsvn6huylZmE2FeasCZhhXARG/0jqYJwVRgdgohcAQwmQDAcEEYDANwKCgFhGzAnAwMDkF4IAiJggjBdAgZODiAlmMSDKpUypZMsEihkRFlJs0AoVVYGESIWLDwMbEZFDcGDIYM+cBAAuqVQo04VSUKR9Z4j4OAiBMnGu5qj+gocZhqNHAwOnqDhpgSgXAwyFwhhgz1gaIakqDs4BHhBwCDDEFUDlJhciZswLCQGiZ6FFphNxkS5lYxkwawhQnBQseArqIhA8cFhYwBQTKzNokw2EKBi7YGjoc

In [11]:
# delete endpoint 
# NOTE: this doesn't delete the model in the s3 bucket, nor it deletes the model from
# model registry, nor the image from the ECR 

predictor.delete_model()
predictor.delete_endpoint()