# MusicGen Async SageMaker Inference
https://github.com/aws/amazon-sagemaker-examples/blob/main/async-inference/Transcription_on_SM_endpoint.ipynb
https://huggingface.co/docs/transformers/model_doc/musicgen#generation

In [None]:
# Upgrade pip itself first, then install/upgrade the SageMaker Python SDK.
# -U upgrades to the latest available version, -q keeps pip's output quiet.
!pip install -Uq pip
!pip install -Uq sagemaker

In [None]:
!mkdir model
!mkdir model/code

In [None]:
## requirements.txt https://github.com/facebookresearch/audiocraft/blob/main/README.md
# Install notes copied from the audiocraft README, kept here for reference:
#
#   # Best to make sure you have torch installed first, in particular before installing xformers.
#   # Don't run this if you already have PyTorch installed.
#   pip install 'torch>=2.0'
#   # Then proceed to one of the following
#   pip install -U audiocraft  # stable release
#   pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft  # bleeding edge
#   pip install -e .  # or if you cloned the repo locally (mandatory if you want to train).

# Packages written to model/code/requirements.txt, one per line.
# `uuid` is intentionally not listed: inference.py uses the standard-library
# module, and the PyPI package of the same name is an old backport of it.
# NOTE(review): audiocraft is listed twice (release and git) and inference.py
# does not import it -- confirm whether either entry is still needed.
requirements = [
    "transformers==4.34.1",
    "boto3",
    "torch>=2.0",
    "scipy",
    "audiocraft",
    "git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft",
]

with open("model/code/requirements.txt", "w") as f:
    f.write("\n".join(requirements) + "\n")


In [None]:
%%writefile model/code/inference.py

import boto3
from urllib.parse import urlparse
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import scipy
import uuid
import torch
import os


def model_fn(model_dir):
    """Load the MusicGen model and place it on the available device.

    Args:
        model_dir: Model directory handed to this hook. It is not read here;
            the model is loaded with from_pretrained("facebook/musicgen-large").

    Returns:
        The loaded MusicgenForConditionalGeneration, moved to "cuda" when
        torch reports CUDA as available and to "cpu" otherwise.
    """
    musicgen = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-large")

    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"

    musicgen.to(target_device)
    return musicgen


# Processor instance shared across requests; populated on first use.
_MUSICGEN_PROCESSOR = None


def _get_processor():
    """Return the MusicGen processor, loading it once and reusing it afterwards."""
    global _MUSICGEN_PROCESSOR
    if _MUSICGEN_PROCESSOR is None:
        _MUSICGEN_PROCESSOR = AutoProcessor.from_pretrained("facebook/musicgen-large")
    return _MUSICGEN_PROCESSOR


def process_input(data):
    """Turn a request payload into processor inputs and generation settings.

    Args:
        data: Request dict. Optional keys:
            'texts'  -- prompt(s) passed to the processor as ``text``, e.g.
                        ["80s pop track with bassy drums and synth"]; a
                        built-in default prompt is used when absent.
            'config' -- dict that may override 'guidance_scale',
                        'max_new_tokens' and 'do_sample'; any key it omits
                        falls back to the default value.
            The dict (and a supplied 'config') is read, not modified.

    Returns:
        A tuple ``(inputs, processed_config)``: the processor output for the
        prompts (padded, ``return_tensors="pt"``) and a dict holding exactly
        the three generation settings.

    Raises:
        TypeError: If 'config' is supplied but is not a dict.
    """
    default_config = {'guidance_scale': 3, 'max_new_tokens': 256, 'do_sample': True}
    default_texts = ["Morning sunshine, beats, ukelele, happy swings"]

    texts = data.get('texts', default_texts)
    config = data.get('config', default_config)

    # Validate before touching the processor so a bad request fails fast.
    if not isinstance(config, dict):
        raise TypeError(f"'config' must be a dict, got {type(config).__name__}")

    # Per-key fallback: a partial config only overrides the keys it names.
    processed_config = {
        key: config.get(key, default_value)
        for key, default_value in default_config.items()
    }

    processor = _get_processor()
    inputs = processor(
        text=texts,
        padding=True,
        return_tensors="pt",
    )
    return inputs, processed_config


def upload_to_s3(wav_on_disk, bucket_name):
    """Upload a local file under ``musicgen_large/output/`` in the given bucket.

    Args:
        wav_on_disk: Path of the local file. The text after its last "/" is
            used as the object's file name.
        bucket_name: Name of the destination S3 bucket.

    Returns:
        The ``s3://`` URI of the uploaded object.
    """
    s3_resource = boto3.resource('s3')
    file_name = wav_on_disk.rsplit('/', 1)[-1]
    object_key = f'musicgen_large/output/{file_name}'
    destination = s3_resource.Bucket(bucket_name)
    destination.upload_file(wav_on_disk, object_key)
    return f"s3://{bucket_name}/{object_key}"


def delete_file_on_disk(filename):
    """Remove ``filename`` when ``os.path.isfile`` reports it as a file.

    Paths that do not exist, or that are not files (e.g. directories), are
    left untouched and no error is raised for them.
    """
    if not os.path.isfile(filename):
        return
    os.remove(filename)
        

def write_to_s3(sampling_rate, audio_values, bucket_name):
    """Write each generated clip to a wav file, upload it to S3, and clean up.

    Args:
        sampling_rate: Sample rate passed to ``scipy.io.wavfile.write``.
        audio_values: Generated audio, indexed as ``audio_values[i, 0]`` for
            clip ``i``. Assumed to be a torch tensor shaped
            (batch, channels, samples) -- TODO confirm against the model output.
        bucket_name: Destination S3 bucket for the uploads.

    Returns:
        A list with the ``s3://`` URI of every uploaded clip, in clip order.
    """
    s3_wav_keys = []
    for i in range(len(audio_values)):
        # uuid1 prefix keeps file names from different requests distinct.
        wav_file = f"{uuid.uuid1()}_{i}_musicgen_large_out.wav"
        samples = audio_values[i, 0].cpu().numpy()

        # Write the wav to local disk, preferring /tmp.
        wav_on_disk = f'/tmp/{wav_file}'
        try:
            scipy.io.wavfile.write(wav_on_disk, rate=sampling_rate, data=samples)
        except OSError:
            # /tmp could not be written; retry in the alternate output directory.
            wav_on_disk = f'/opt/ml/output/data/{wav_file}'
            scipy.io.wavfile.write(wav_on_disk, rate=sampling_rate, data=samples)

        try:
            # Record the URI so callers can find the uploaded clip.
            s3_wav_keys.append(upload_to_s3(wav_on_disk, bucket_name))
        finally:
            # Remove the local copy even when the upload fails.
            delete_file_on_disk(wav_on_disk)

    return s3_wav_keys


def predict_fn(data, model):
    """Generate audio for a request and store the resulting wav files in S3.

    Args:
        data: Request dict. Must contain a non-empty 'bucket_name' (it is
            popped from the dict); the remaining content is handed to
            ``process_input``.
        model: Model object providing ``generate`` and
            ``config.audio_encoder.sampling_rate``.

    Returns:
        A dict with 'generated_outputs_s3' (the value returned by
        ``write_to_s3``) and 'audio_values' (the raw output of
        ``model.generate``).

    Raises:
        ValueError: If 'bucket_name' is missing or empty.
    """
    # https://sagemaker.readthedocs.io/en/stable/api/inference/predictors.html#sagemaker.predictor.Predictor.predict

    bucket_name = data.pop('bucket_name', None)
    if not bucket_name:
        raise ValueError("bucket_name is required ex: sagemaker_default_bucket_007")

    inputs, processed_config = process_input(data)

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"

    # Move the processor output to the chosen device, then gather the three
    # generation settings.
    model_inputs = inputs.to(device)
    generation_kwargs = {
        setting: processed_config.pop(setting)
        for setting in ('max_new_tokens', 'guidance_scale', 'do_sample')
    }
    audio_values = model.generate(**model_inputs, **generation_kwargs)

    s3_wav_keys = write_to_s3(
        model.config.audio_encoder.sampling_rate,
        audio_values,
        bucket_name,
    )

    response = {"generated_outputs_s3": s3_wav_keys}
    response["audio_values"] = audio_values
    return response

In [None]:
# Switch the notebook's working directory to model/ so the relative paths in
# the following cells (code/, model.tar.gz) resolve inside it.
%cd model

In [None]:
!rm model.tar.gz

In [None]:
# Delete Jupyter checkpoint directories under code/ so they are not packaged
# into the archive.
!rm -rf code/.ipynb_checkpoints*

In [None]:
# Create the gzip-compressed archive model.tar.gz from everything in the
# current directory (z = gzip, c = create, v = verbose, f = output file).
!tar zcvf model.tar.gz *

In [None]:
import sagemaker
import boto3

# Initial session, used to look up the default bucket when none is set below.
sess = sagemaker.Session()

# Leave as None to use the session's default bucket.
sagemaker_session_bucket = None
if sess is not None and sagemaker_session_bucket is None:
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    # get_execution_role raised ValueError: look up the role named
    # "sagemaker_execution_role" through IAM instead.
    iam = boto3.client("iam")
    role_description = iam.get_role(RoleName="sagemaker_execution_role")
    role = role_description["Role"]["Arn"]

# Recreate the session bound to the chosen bucket.
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

# Summary of the resolved role, bucket and region.
print(
    f"sagemaker role arn: {role}",
    f"sagemaker bucket: {sess.default_bucket()}",
    f"sagemaker session region: {sess.boto_region_name}",
    sep="\n",
)


In [None]:
# S3 key and full s3:// URI under which the packaged model archive is stored.
musicgen_prefix = "musicgen_large"
s3_model_key = "{}/model/model.tar.gz".format(musicgen_prefix)
s3_model_location = "s3://{}/{}".format(sagemaker_session_bucket, s3_model_key)

In [None]:
# Upload model.tar.gz (relative to the current working directory) to the
# session bucket under s3_model_key.
s3 = boto3.resource("s3")
model_bucket = s3.Bucket(sagemaker_session_bucket)
model_bucket.upload_file("model.tar.gz", s3_model_key)

## Async Inference

In [None]:
from sagemaker.huggingface.model import HuggingFaceModel
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
from sagemaker.s3 import s3_path_join
from sagemaker.utils import name_from_base

# One generated name is used for both the model and the endpoint.
async_endpoint_name = name_from_base("musicgen-large-v1-asyc")

# Container versions used to run the packaged inference code.
container_versions = {
    "transformers_version": "4.28",
    "pytorch_version": "2.0",
    "py_version": "py310",
}

# Hugging Face model definition: the archive in S3 plus the IAM role that is
# allowed to create the endpoint.
huggingface_model = HuggingFaceModel(
    name=async_endpoint_name,
    model_data=s3_model_location,
    role=role,
    **container_versions,
)

# S3 location where the async inference results are written.
async_output_path = s3_path_join(
    "s3://", sagemaker_session_bucket, "musicgen_large/async_inference/music_output"
)

# Optional SNS notifications; fill in the topic ARNs if they are needed.
notification_config = {
    # "SuccessTopic": "PUT YOUR SUCCESS SNS TOPIC ARN",
    # "ErrorTopic": "PUT YOUR ERROR SNS TOPIC ARN",
}

# Async endpoint configuration.
async_config = AsyncInferenceConfig(
    output_path=async_output_path,
    notification_config=notification_config,
)

# Deploy the async endpoint on a single GPU instance.
async_predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
    async_inference_config=async_config,
    endpoint_name=async_endpoint_name,
)


In [None]:
# Keep the deployed endpoint's name in its own variable for the cells below.
endpoint_name = async_predictor.endpoint_name

In [None]:
# Persist both variables with IPython's %store magic so that another notebook
# can restore them later (with `%store -r`).
%store \
endpoint_name \
sagemaker_session_bucket

In [None]:
# Show the endpoint name as the cell's output.
endpoint_name