In [128]:
%store -r \
endpoint_name \
sagemaker_session_bucket

In [129]:
# Confirm the values restored by %store above.
endpoint_name, sagemaker_session_bucket

('musicgen-large-v1-asyc-2023-11-02-16-57-23-399',
 'sagemaker-us-west-2-920487201358')

In [130]:
import sagemaker
# SageMaker session used by the S3 upload/read helpers below.
sm_session = sagemaker.Session()

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml


In [131]:
import os
import uuid
import json


def generate_json(data):
    """Serialize ``data`` to a uniquely named JSON file in the current directory.

    Args:
        data: Any JSON-serializable object (here, the async-inference payload).

    Returns:
        str: Name of the file written, of the form ``payload_<uuid>.json``.
    """
    # uuid4 is random. uuid1 embeds this host's MAC address and a timestamp,
    # which would leak into the S3 object key once the file is uploaded.
    suffix = str(uuid.uuid4())
    filename = f'payload_{suffix}.json'
    with open(filename, 'w') as fp:
        json.dump(data, fp)
    return filename


def upload_input_json(sm_session, filename, key_prefix='musicgen_large/input_payload', bucket=None):
    """Upload a local JSON payload file to S3 for an async inference request.

    Args:
        sm_session: SageMaker session (anything providing ``upload_data`` and
            ``default_bucket``).
        filename: Path of the local JSON file to upload.
        key_prefix: S3 key prefix under which the object is stored.
        bucket: Target bucket; ``None`` means ``sm_session.default_bucket()``.

    Returns:
        The S3 URI returned by ``sm_session.upload_data``.
    """
    if bucket is None:
        bucket = sm_session.default_bucket()
    return sm_session.upload_data(
        filename,
        bucket=bucket,
        key_prefix=key_prefix,
        # Store the object with a JSON Content-Type.
        extra_args={"ContentType": "application/json"},
    )


def delete_file_on_disk(filename):
    """Remove ``filename`` from local disk if it is an existing regular file.

    Missing paths and directories are left alone. A falsy ``filename`` (for
    example ``None`` from a download that found no object) is ignored instead
    of letting ``os.path.isfile`` raise ``TypeError``.
    """
    if not filename:
        return
    if os.path.isfile(filename):
        os.remove(filename)

In [132]:
import urllib, time
from botocore.exceptions import ClientError
import random

def get_output(output_location, failure_location=None, poll_interval=2, timeout=None):
    """Poll S3 until an async-inference result object exists and return its contents.

    Args:
        output_location: ``s3://bucket/key`` URI to poll (e.g.
            ``response['OutputLocation']``).
        failure_location: Optional ``s3://`` URI (e.g. ``response['FailureLocation']``).
            When given and an object appears there, ``RuntimeError`` is raised
            instead of polling forever.
        poll_interval: Seconds to sleep between polls.
        timeout: Maximum seconds to wait; ``None`` (default) waits indefinitely.

    Returns:
        The object's contents as returned by ``sm_session.read_s3_file``.

    Raises:
        RuntimeError: If an object is found at ``failure_location``.
        TimeoutError: If ``timeout`` elapses before the output appears.
        botocore.exceptions.ClientError: For S3 errors other than ``NoSuchKey``.
    """
    # ``import urllib`` alone does not guarantee that the ``parse`` submodule is loaded.
    from urllib.parse import urlparse

    def _read_if_present(s3_uri):
        # Return the object's contents, or None while it does not exist yet.
        parsed = urlparse(s3_uri)
        try:
            return sm_session.read_s3_file(bucket=parsed.netloc, key_prefix=parsed.path[1:])
        except ClientError as e:
            if e.response["Error"]["Code"] == "NoSuchKey":
                return None
            raise

    icons = ["🎵","🎶","🎷","🎸","🎺","🎼","🥁"]
    deadline = None if timeout is None else time.monotonic() + timeout
    print("generating music")
    while True:
        res = _read_if_present(output_location)
        if res is not None:
            print("\nMusic is ready!🎉")
            return res
        if failure_location is not None:
            error = _read_if_present(failure_location)
            if error is not None:
                raise RuntimeError(f"Async inference failed: {error}")
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(f"No output at {output_location} after {timeout} seconds")
        # One random icon per poll as a progress indicator; flush so it appears immediately.
        print(random.choice(icons), end='', flush=True)
        time.sleep(poll_interval)
    

import botocore
def download_from_s3(url):
    """Download an S3 object into the current directory, named after the key's last segment.

    ex: url = s3://bucketname/prefix1/music.wav  ->  ./music.wav

    Args:
        url: ``s3://bucket/key`` URI of the object.

    Returns:
        The local filename, both after a download and when a file of that name
        already exists locally (then nothing is downloaded). Returns ``None``
        if S3 reports the object does not exist (404).

    Raises:
        botocore.exceptions.ClientError: For S3 errors other than 404.
    """
    url_parts = url.split("/")  # => ['s3:', '', 'sagemakerbucketname', 'data', ...
    bucket_name = url_parts[2]
    # S3 keys always use '/'; os.path.join would insert backslashes on Windows.
    key = "/".join(url_parts[3:])
    filename = url_parts[-1]
    if os.path.exists(filename):
        # Reuse the existing local copy instead of downloading again.
        return filename
    # Imported here so this helper does not rely on a later cell importing boto3.
    import boto3
    try:
        # Create an S3 resource
        s3 = boto3.resource('s3')
        print('Downloading {} to {}'.format(url, filename))
        s3.Bucket(bucket_name).download_file(key, filename)
        return filename
    except botocore.exceptions.ClientError as e:
        if e.response['Error']['Code'] == "404":
            print('The object {} does not exist in bucket {}'.format(
                key, bucket_name))
            return None
        raise
                
                
import IPython
def play_output_audio(filename):
    """Wrap the local audio file ``filename`` in an ``IPython.display.Audio`` object.

    Used as a cell's last expression, the returned object renders an audio player.
    """
    player = IPython.display.Audio(filename)
    return player


In [145]:
# Async-inference payload: prompt text(s), a bucket name (presumably where the
# endpoint stores generated audio - confirm against the inference script), and
# generation settings.
default_config = {'guidance_scale': 3, 'max_new_tokens': 512, 'do_sample': True}
prompts = ['Morning sunshine, beats, ukelele, happy swings']
data = {
    "texts": prompts,
    "bucket_name": sagemaker_session_bucket,
    "config": default_config,
}
data

{'texts': ['Morning sunshine, beats, ukelele, happy swings'],
 'bucket_name': 'sagemaker-us-west-2-920487201358',
 'config': {'guidance_scale': 3, 'max_new_tokens': 512, 'do_sample': True}}

In [146]:
# Write the payload to a temp file, upload it to S3, and always remove the
# local file, even if the upload raises.
filename = generate_json(data)
try:
    input_s3_location = upload_input_json(sm_session, filename)
finally:
    delete_file_on_disk(filename)

In [147]:
# S3 URI of the uploaded request payload.
input_s3_location

's3://sagemaker-us-west-2-920487201358/musicgen_large/input_payload/payload_eebda67c-79a3-11ee-9bf5-6fe380f23dcc.json'

In [148]:
import boto3
# Submit the request asynchronously; the endpoint reads its input from S3 and the
# response reports the S3 locations where the output (or failure) will be written.
runtime_request = {
    "EndpointName": endpoint_name,
    "InputLocation": input_s3_location,
    "ContentType": "application/json",
}
sagemaker_runtime = boto3.client('sagemaker-runtime')
response = sagemaker_runtime.invoke_endpoint_async(**runtime_request)

In [149]:
# Raw invoke_endpoint_async response; OutputLocation / FailureLocation name the S3 result objects.
response

{'ResponseMetadata': {'RequestId': '2b13ee5a-7588-46ba-ad91-4caf6f0e5c7a',
  'HTTPStatusCode': 202,
  'HTTPHeaders': {'x-amzn-requestid': '2b13ee5a-7588-46ba-ad91-4caf6f0e5c7a',
   'x-amzn-sagemaker-outputlocation': 's3://sagemaker-us-west-2-920487201358/musicgen_large/async_inference/music_output/ceac81f2-a57e-4509-8b62-677488b6661b.out',
   'x-amzn-sagemaker-failurelocation': 's3://sagemaker-us-west-2-920487201358/async-endpoint-failures/musicgen-large-v1-asyc-2023-11-02-16-57-23-399-1698944244-e1dc/ceac81f2-a57e-4509-8b62-677488b6661b-error.out',
   'date': 'Thu, 02 Nov 2023 17:19:10 GMT',
   'content-type': 'application/json',
   'content-length': '54',
   'connection': 'keep-alive'},
  'RetryAttempts': 0},
 'OutputLocation': 's3://sagemaker-us-west-2-920487201358/musicgen_large/async_inference/music_output/ceac81f2-a57e-4509-8b62-677488b6661b.out',
 'FailureLocation': 's3://sagemaker-us-west-2-920487201358/async-endpoint-failures/musicgen-large-v1-asyc-2023-11-02-16-57-23-399-1698

In [153]:
# S3 URI that get_output() polls for the result.
response.get('OutputLocation')

's3://sagemaker-us-west-2-920487201358/musicgen_large/async_inference/music_output/ceac81f2-a57e-4509-8b62-677488b6661b.out'

In [None]:
%%time
output = get_output(response.get('OutputLocation'))

In [None]:
# get_output returns the raw object text; decode it once. The type guard keeps
# this cell safe to re-run (json.loads on the decoded dict would raise TypeError).
if isinstance(output, (str, bytes, bytearray)):
    output = json.loads(output)
output.keys()

In [None]:
# Same S3 output URI as shown earlier, repeated for reference.
response.get('OutputLocation')

In [142]:
# S3 URI of the generated audio reported in the endpoint's JSON output (downloaded below).
output.get('generated_output_s3')

In [None]:
# Fetch the generated audio file into the working directory.
generated_s3_url = output.get('generated_output_s3')
music = download_from_s3(generated_s3_url)

In [None]:
# Render an inline player for the downloaded audio file.
play_output_audio(music)

In [None]:
# Remove the downloaded audio file from local disk.
delete_file_on_disk(music)

## Cleanup

In [156]:
# Look up the endpoint config and model(s) behind the endpoint so they can be deleted.
sm_client = boto3.client('sagemaker')
endpoint = sm_client.describe_endpoint(EndpointName=endpoint_name)
endpoint_config_name = endpoint['EndpointConfigName']
endpoint_config = sm_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
# ProductionVariants is a list; report the model of every variant, not just the first.
model_names = [variant['ModelName'] for variant in endpoint_config['ProductionVariants']]
model_name = model_names[0]

print(f"""
About to delete the following sagemaker resources:
Endpoint: {endpoint_name}
Endpoint Config: {endpoint_config_name}
Model: {', '.join(model_names)}
""")


About to delete the following sagemaker resources:
Endpoint: musicgen-large-v1-asyc-2023-11-02-16-57-23-399
Endpoint Config: musicgen-large-v1-asyc-2023-11-02-16-57-23-399
Model: musicgen-large-v1-asyc-2023-11-02-16-57-23-399



In [158]:
# Set to True to tear down the resources listed above. It defaults to False so
# that "Run All" never deletes the endpoint by accident.
DELETE_RESOURCES = False

if DELETE_RESOURCES:
    # delete endpoint
    sm_client.delete_endpoint(EndpointName=endpoint_name)
    # delete endpoint config
    sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    # delete the model behind every production variant of the endpoint config
    for variant in endpoint_config['ProductionVariants']:
        sm_client.delete_model(ModelName=variant['ModelName'])