# Deploy MusicGen Large model on Amazon SageMaker for Asynchronous Inferencing

https://github.com/aws/amazon-sagemaker-examples/blob/main/async-inference/Transcription_on_SM_endpoint.ipynb
https://huggingface.co/docs/transformers/model_doc/musicgen#generation

Install necessary packages to run this notebook on SageMaker Studio.

In [None]:
# Quietly (-q) upgrade (-U) pip itself, then the SageMaker Python SDK.
!pip install -Uq pip
!pip install -Uq sagemaker

## Prepare Inference Scripts

We will create a `model` directory that holds the code artefacts: the inference Python script and a `requirements.txt` listing the Python packages that will be installed when the model is deployed.

In [None]:
!mkdir model
!mkdir model/code

In [None]:
## requirements.txt https://github.com/facebookresearch/audiocraft/blob/main/README.md
'''
# Best to make sure you have torch installed first, in particular before installing xformers.
# Don't run this if you already have PyTorch installed.
pip install 'torch>=2.0'
# Then proceed to one of the following
pip install -U audiocraft  # stable release
pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft  # bleeding edge
pip install -e .  # or if you cloned the repo locally (mandatory if you want to train).
'''
# Packages that are pip-installed into the inference container at deploy time.
#
# NOTE: "uuid" is intentionally not listed: it is part of the Python standard
# library, and the PyPI distribution of that name is an obsolete backport.
# "audiocraft" is listed once (the PyPI release) rather than both as the PyPI
# release and as a git URL for the same package.
requirements = [
    "transformers==4.34.1",
    "boto3",
    "torch>=2.0",
    "scipy",
    "audiocraft",
    "ffmpeg",
    "ffmpeg-python",
]

with open("model/code/requirements.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(requirements) + "\n")


In [None]:
%%writefile model/code/inference.py

import boto3
from transformers import AutoProcessor, MusicgenForConditionalGeneration
from concurrent.futures import ThreadPoolExecutor
import multiprocessing
import scipy
import uuid
import torch
import os


def model_fn(model_dir):
    """Load the MusicGen Large model and place it on the inference device.

    Args:
        model_dir: Directory passed in by the serving container. It is not
            used here: the weights are loaded by model id
            ("facebook/musicgen-large") via ``from_pretrained``.

    Returns:
        The loaded ``MusicgenForConditionalGeneration`` model, moved to the
        first GPU when CUDA is available and kept on CPU otherwise.
    """
    model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-large")
    # A torch device string carries at most one device index; the previous
    # "cuda:0,1,2,3" is rejected by torch as an invalid device string.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model.to(device)
    return model


# Lazily created processor, shared by all requests handled by this process so
# that it is not re-loaded on every invocation.
_MUSICGEN_PROCESSOR = None


def _get_processor():
    """Return the shared MusicGen processor, loading it on first use."""
    global _MUSICGEN_PROCESSOR
    if _MUSICGEN_PROCESSOR is None:
        _MUSICGEN_PROCESSOR = AutoProcessor.from_pretrained("facebook/musicgen-large")
    return _MUSICGEN_PROCESSOR


def _process_input(data):
    """Turn a request payload into model inputs and generation settings.

    Args:
        data: Request payload (dict-like). Optional keys:
            ``texts``  - list of text prompts; a single default prompt is
                         used when absent.
            ``config`` - dict that may override ``guidance_scale``,
                         ``max_new_tokens`` and ``do_sample``; any key left
                         out falls back to its default.

    Returns:
        A tuple ``(inputs, processed_config)``: the processor output for the
        prompts (padded, PyTorch tensors) and a dict holding exactly the
        three generation settings.

    Raises:
        TypeError: If ``config`` is present but is not a dict.
    """
    # defaults
    default_config = {'guidance_scale': 3, 'max_new_tokens': 256, 'do_sample': True}
    default_texts = ["Morning sunshine, beats, ukelele, happy swings"]

    # obtain input data
    texts = data.get('texts', default_texts)
    config = data.get('config', default_config)

    # Validate before the (expensive) processor is loaded, and with a real
    # exception: ``assert`` statements are stripped when Python runs with -O.
    if not isinstance(config, dict):
        raise TypeError(f"config must be a dict, got {type(config).__name__}")

    # Only the known generation settings are forwarded; unknown keys in the
    # request's config are ignored.
    processed_config = {
        key: config.get(key, default) for key, default in default_config.items()
    }

    inputs = _get_processor()(
        text=texts,
        padding=True,
        return_tensors="pt",
    )
    return inputs, processed_config


def _upload_to_s3(wav_on_disk, bucket_name):
    """Upload one local wav file to S3 and return its ``s3://`` URI.

    The object is stored under the ``musicgen_large/output`` prefix, using the
    last ``/``-separated component of ``wav_on_disk`` as the object name.

    Args:
        wav_on_disk: Path of the local file to upload.
        bucket_name: Name of the destination S3 bucket.

    Returns:
        The ``s3://<bucket>/<key>`` location of the uploaded object.
    """
    file_name = wav_on_disk.rsplit('/', 1)[-1]
    key = '/'.join(('musicgen_large/output', file_name))
    bucket = boto3.resource('s3').Bucket(bucket_name)
    bucket.upload_file(wav_on_disk, key)
    return f"s3://{bucket_name}/{key}"


def _upload_wav_files(wav_files, bucket_name):
    """Upload several wav files to S3 concurrently.

    Args:
        wav_files: Local paths of the files to upload.
        bucket_name: Name of the destination S3 bucket.

    Returns:
        A list with one entry per input file, in input order: the ``s3://``
        URI returned by ``_upload_to_s3``, or ``None`` when that upload
        failed or did not finish within 60 seconds. Failures are logged
        rather than raised, so one bad file does not abort the others.
    """
    import logging

    if not wav_files:
        return []

    logger = logging.getLogger(__name__)
    # Uploads are I/O bound, so use more threads than there are cores.
    max_workers = multiprocessing.cpu_count() * 2 + 1
    results = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        submitted = [
            (path, executor.submit(_upload_to_s3, path, bucket_name))
            for path in wav_files
        ]
        for path, future in submitted:
            try:
                results.append(future.result(timeout=60))
            except Exception:
                # Best effort: record the failure and keep the remaining uploads.
                logger.exception("Failed to upload %s to bucket %s", path, bucket_name)
                results.append(None)

    return results


def _write_wavs_to_disk(sampling_rate, audio_values):
    """Write each generated clip to its own wav file under ``/tmp``.

    Args:
        sampling_rate: Sample rate written into every wav file.
        audio_values: Tensor of generated audio, indexed as
            ``audio_values[i, 0]`` for clip ``i`` (only index 0 of the second
            dimension is written). Presumably shaped
            (batch, channels, samples) -- confirm against the model output.

    Returns:
        The list of file paths written, one per clip, in clip order. All
        files of one call share a UUID prefix so their names do not collide
        with those of other calls.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.io.wavfile`` has been loaded.
    from scipy.io import wavfile

    prefix = str(uuid.uuid1())
    disk_wav_locations = []
    for i, clip in enumerate(audio_values):
        wav_on_disk = f'/tmp/{prefix}_{i}_musicgen_large_out.wav'
        # The tensor is copied to host memory before being written out.
        wavfile.write(wav_on_disk, rate=sampling_rate, data=clip[0].cpu().numpy())
        disk_wav_locations.append(wav_on_disk)
    return disk_wav_locations


def _delete_file_on_disk(filename):
    """Remove ``filename`` when it is an existing regular file.

    Paths that do not exist, or that are not regular files (for example
    directories), are left untouched.
    """
    if not os.path.isfile(filename):
        return
    os.remove(filename)


def predict_fn(data, model):
    """Generate audio for a request and upload the resulting wav files to S3.

    Args:
        data: Request payload (dict). ``bucket_name`` is required and is
            removed from ``data``; the remaining keys are handed to
            ``_process_input``.
        model: The model returned by ``model_fn``.

    Returns:
        ``{"generated_outputs_s3": [...]}`` with one entry per generated
        clip, as returned by ``_upload_wav_files``.

    Raises:
        ValueError: If ``bucket_name`` is missing or empty.
    """
    # https://sagemaker.readthedocs.io/en/stable/api/inference/predictors.html#sagemaker.predictor.Predictor.predict

    bucket_name = data.pop('bucket_name', None)
    if not bucket_name:
        raise ValueError("bucket_name is required ex: sagemaker_default_bucket_007")

    inputs, processed_config = _process_input(data)

    # Send the inputs to the device the model actually lives on. The previous
    # hard-coded "cuda:0,1,2,3" is not a valid torch device string.
    audio_values = model.generate(
        **inputs.to(model.device),
        max_new_tokens=processed_config['max_new_tokens'],
        guidance_scale=processed_config['guidance_scale'],
        do_sample=processed_config['do_sample'],
    )

    # process generated output audio_values
    sampling_rate = model.config.audio_encoder.sampling_rate
    disk_wav_locations = _write_wavs_to_disk(sampling_rate, audio_values)

    try:
        # Upload wavs to S3
        results = _upload_wav_files(disk_wav_locations, bucket_name)
    finally:
        # Clean up disk even when the upload step raises, so that temporary
        # files do not accumulate on the endpoint instance.
        for wav_on_disk in disk_wav_locations:
            _delete_file_on_disk(wav_on_disk)

    return {
        "generated_outputs_s3": results
    }

In [None]:
# Change into model/ so that the commands in the following cells run relative to it.
%cd model

## Place the inference scripts archive on Amazon S3

We will create an archive of the inference scripts and upload it to an Amazon S3 bucket. The S3 URI of the uploaded object will later be used to create the HuggingFace Model.

In [None]:
!rm model.tar.gz

In [None]:
# Delete any .ipynb_checkpoints entries under code/ so they are not packaged into the archive.
!rm -rf code/.ipynb_checkpoints*

In [None]:
# Pack everything in the current directory into a gzip-compressed model.tar.gz.
!tar zcvf model.tar.gz *

In [None]:
import sagemaker
import boto3

# Set this to a bucket name to use a specific bucket; while it is None the
# default bucket of the SageMaker session is used instead.
sagemaker_session_bucket = None

sess = sagemaker.Session()
if sess is not None and sagemaker_session_bucket is None:
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    # get_execution_role() raised ValueError: fall back to looking up the
    # ARN of the role named "sagemaker_execution_role" through IAM.
    iam = boto3.client("iam")
    role = iam.get_role(RoleName="sagemaker_execution_role")["Role"]["Arn"]

# Re-create the session so that it is bound to the chosen bucket.
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")


In [None]:
# S3 key prefix under which this notebook stores the model archive.
musicgen_prefix = 'musicgen_large'
# Object key of the archive inside the bucket, and the full s3:// URI built from it.
s3_model_key = f'{musicgen_prefix}/model/model.tar.gz'
s3_model_location = f"s3://{sagemaker_session_bucket}/{s3_model_key}"

In [None]:
# Upload the local model.tar.gz to the bucket under s3_model_key.
s3 = boto3.resource("s3")
s3.Bucket(sagemaker_session_bucket).upload_file("model.tar.gz", s3_model_key)

## Deploy Asynchronous Inference Endpoint on Amazon SageMaker

To create the Asynchronous Inference Endpoint, we will perform the following steps:

1. Create a model of type HuggingFaceModel, since we are using the MusicGen Large model with Hugging Face as the model provider.
2. Create an asynchronous endpoint configuration whose notification configuration is associated with Amazon SNS topics for success and error notifications.
3. Finally, deploy the model to create a SageMaker endpoint for asynchronous inferencing.

### Create a Model

In [None]:
from sagemaker.huggingface.model import HuggingFaceModel
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
from sagemaker.s3 import s3_path_join
from sagemaker.utils import name_from_base

# name_from_base derives a unique name from the given base string; the same
# name is used for the SageMaker model here and for the endpoint at deploy time.
async_endpoint_name = name_from_base("musicgen-large-v1-asyc")

# Hugging Face model definition.
# AudioCraft's README asks for Python 3.9 and PyTorch 2.0.0 or later:
# https://github.com/facebookresearch/audiocraft/blob/main/README.md
huggingface_model = HuggingFaceModel(
    role=role,  # IAM role with permissions to create an endpoint
    model_data=s3_model_location,  # S3 URI of the model.tar.gz with the inference code
    name=async_endpoint_name,
    py_version="py310",  # Python version of the serving container
    pytorch_version="2.0",  # PyTorch version of the serving container
    transformers_version="4.28",  # transformers version of the serving container
)

### Asynchronous Inference Configuration

#### Create Amazon SNS topics for Success and Failure Notification configuration

In [None]:
# Return to the parent directory (the one the notebook was in before `%cd model`).
%cd ..

In [None]:
import os
import sys

# https://stackoverflow.com/a/8015152
# Make the parent of the current working directory importable just long enough
# to import `utils.sns_client`. (The previous
# dirname(dirname(abspath("__file__"))) resolved to this same directory,
# because "__file__" was a plain string joined onto the working directory.)
_utils_parent_dir = os.path.dirname(os.getcwd())
sys.path.insert(0, _utils_parent_dir)
try:
    from utils.sns_client import SnsClient
finally:
    # Always undo the sys.path change, even when the import fails, and remove
    # the entry by value rather than assuming it is still at index 0.
    if _utils_parent_dir in sys.path:
        sys.path.remove(_utils_parent_dir)

In [None]:
# Create SNS topic
import time
#from utils.sns_client import SnsClient

sns_client = SnsClient(boto3.client("sns"))
timestamp = time.time_ns()
topic_names = [f"musicgen-large-topic-SuccessTopic-{timestamp}", f"musicgen-large-topic-ErrorTopic-{timestamp}"]

topic_arns = []
for topic_name in topic_names:
    print(f"Creating topic {topic_name}.")
    response = sns_client.create_topic(topic_name)
    topic_arns.append(response.get('TopicArn'))

In [None]:
# Display the ARNs of the created topics (success topic first, then error topic).
topic_arns

#### Create Async Inference Configuration and associate it with SNS topics using notification configuration

In [None]:
# create async endpoint configuration
# Asynchronous inference configuration: where results are written and which
# SNS topics are notified.
async_config = AsyncInferenceConfig(
    output_path=s3_path_join(
        "s3://", sagemaker_session_bucket, "musicgen_large/async_inference/music_output"
    ),  # S3 location where the inference results will be stored
    # SNS topics to notify on success and on error
    notification_config={
        "SuccessTopic": topic_arns[0],
        "ErrorTopic": topic_arns[1],
    },  # Notification configuration
)

### Deploy the model to generate a SageMaker Endpoint for asynchronous inferencing

In [None]:
# deploy the endpoint
# Deploy the model behind an asynchronous endpoint: one ml.g5.12xlarge
# instance, using the async configuration and endpoint name defined above.
async_predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.12xlarge",
    async_inference_config=async_config,
    endpoint_name=async_endpoint_name,
)

While the model is being deployed, you can refer to the following resources about Amazon SageMaker and Facebook MusicGen.

- https://huggingface.co/facebook/musicgen-large
- https://huggingface.co/docs/transformers/model_doc/musicgen#generation
- https://github.com/facebookresearch/audiocraft/blob/main/README.md
- https://sagemaker.readthedocs.io/en/stable/frameworks/huggingface/sagemaker.huggingface.html#hugging-face-model
- https://sagemaker.readthedocs.io/en/stable/api/inference/predictors.html#sagemaker.predictor.Predictor.predict
- https://github.com/aws/amazon-sagemaker-examples/blob/main/async-inference/Transcription_on_SM_endpoint.ipynb

Let's save the variables that will be re-used in the infer notebook.

In [None]:
# Keep the deployed endpoint's name in its own variable so it can be stored below.
endpoint_name=async_predictor.endpoint_name

In [None]:
# Persist these variables with IPython's %store magic so another notebook can restore them.
%store \
endpoint_name \
sagemaker_session_bucket \
topic_arns

In [None]:
# Display the endpoint name.
endpoint_name