## MusicGen

- HF model hub : https://huggingface.co/facebook/musicgen-large
- DJL example (deepspeed) : https://github.com/andjsmi/musicgen/blob/main/musicgen.ipynb

In [None]:
import boto3
import json
import sagemaker
from sagemaker.utils import name_from_base
from sagemaker import image_uris

In [None]:
# Shared SageMaker session and the execution role passed to create_model.
sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()
# SageMaker control-plane and runtime clients taken from the session, plus a standalone S3 client.
sm_client = sagemaker_session.sagemaker_client
sm_runtime_client = sagemaker_session.sagemaker_runtime_client
s3_client = boto3.client('s3')
# Session default bucket; this notebook stores the model code, async inputs and async outputs in it.
default_bucket = sagemaker_session.default_bucket()

In [None]:
# DJL engine flavour; it selects which "djl-<engine>" serving image is retrieved.
llm_engine = "deepspeed"
# llm_engine = "fastertransformer"

In [None]:
# Look up the DJL serving container image for the chosen engine in the session's region.
region_name = sagemaker_session.boto_session.region_name
framework_name = f"djl-{llm_engine}"
inference_image_uri = image_uris.retrieve(
    framework=framework_name,
    region=region_name,
    version="0.22.1",
)

print(f"Inference container uri: {inference_image_uri}")

In [None]:
# S3 prefix in the session default bucket that receives the packaged inference code.
s3_target = f"s3://{sagemaker_session.default_bucket()}/llm/musicgen-large/code/"
print(s3_target)

In [None]:
!rm -rf musicgen-src.tar.gz
!tar zcvf musicgen-src.tar.gz musicgen-src --exclude ".ipynb_checkpoints" --exclude "__pycache__"
!aws s3 cp musicgen-src.tar.gz {s3_target}

In [None]:
# Full S3 URI of the uploaded code archive; passed as ModelDataUrl when the model is created.
model_uri = f"{s3_target}musicgen-src.tar.gz"
print(model_uri)

In [None]:
# Register the serving image together with the packaged code archive as a SageMaker model.
model_name = name_from_base("musicgen-djl")
print(model_name)

primary_container = {
    "Image": inference_image_uri,
    "ModelDataUrl": model_uri,
}
create_model_response = sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer=primary_container,
)
model_arn = create_model_response["ModelArn"]

print(f"Created Model: {model_arn}")

In [None]:
# S3 prefix configured as the async inference output path for this model.
default_bucket = sagemaker_session.default_bucket()
async_output_uri = f"s3://{default_bucket}/llm/outputs/{model_name}/"
print(async_output_uri)

In [None]:
# Instance type hosting the endpoint; the commented lines are alternative choices.
instance_type = "ml.g5.2xlarge"
# instance_type = "ml.g5.xlarge"
# instance_type = "ml.g4dn.xlarge"

endpoint_config_name = f"{model_name}-async-config"
endpoint_name = f"{model_name}-async-endpoint"

# One variant running a single instance, with a 600-second container startup health-check timeout.
production_variant = {
    "VariantName": "variant1",
    "ModelName": model_name,
    "InstanceType": instance_type,
    "InitialInstanceCount": 1,
    "ContainerStartupHealthCheckTimeoutInSeconds": 600,
}

# Async settings: results go to async_output_uri, one concurrent invocation per instance.
async_inference_config = {
    "OutputConfig": {
        "S3OutputPath": async_output_uri,
    },
    "ClientConfig": {
        "MaxConcurrentInvocationsPerInstance": 1,
    },
}

endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[production_variant],
    AsyncInferenceConfig=async_inference_config,
)
print(endpoint_config_response)

In [None]:
# Request creation of the endpoint from the endpoint configuration.
create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)
endpoint_arn = create_endpoint_response["EndpointArn"]
print(f"Created Endpoint: {endpoint_arn}")

In [None]:
import time

# Poll once a minute until the endpoint leaves the "Creating" state, then report the outcome.
resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

# Stop here when the endpoint did not come up, so later cells do not invoke an
# endpoint that cannot serve requests.
if status != "InService":
    failure_reason = resp.get("FailureReason", "no failure reason reported")
    raise RuntimeError(
        f"Endpoint {endpoint_name} ended in status {status}: {failure_reason}"
    )

In [None]:
import json
import uuid

In [None]:
# Text prompt sent to the endpoint; the commented lines are alternative prompts to try.
# prompt = "Sage are playing games with his pet, disney style"
# prompt = "A man holds a phone with tiger, picasso style, detailed, 8k"
# prompt = "beautiful, edm style, focusing music for study"
prompt = "chillstep feels like lay down in the beach, good for focusing"

In [None]:
# Request body for the endpoint: a one-element list of prompts plus the default bucket name.
# NOTE(review): presumably the handler in musicgen-src uploads the generated audio to
# "upload_s3_bucket" — confirm against the handler code.
payload = {
    "text": [prompt],
    "upload_s3_bucket": sagemaker_session.default_bucket(),
}

In [None]:
# Upload the JSON payload to S3 under a unique key; the async invocation takes its
# input as an S3 location.
# Note: despite its name, `s3_uri` holds the object key; the full s3:// URI is built below.
s3_uri = f"llm/inputs/{model_name}/{uuid.uuid4()}.json"
s3_client.put_object(
    Bucket=default_bucket,
    Key=s3_uri,
    Body=json.dumps(payload))

input_data_uri = f"s3://{default_bucket}/{s3_uri}"
input_location = input_data_uri

In [None]:
# Submit the async request; the response carries the S3 location of the result object.
response = sm_runtime_client.invoke_endpoint_async(
    EndpointName=endpoint_name, 
    InputLocation=input_location,
    ContentType="application/json"
)
output_location = response["OutputLocation"]
print(output_location)
# Drop the leading "s3://<bucket>/" components to get the object key of the result.
output_key_uri = "/".join(output_location.split("/")[3:])

In [None]:
# Fetch the async inference result. Until the output object has been written,
# S3 answers with a client error, which normally just means "not ready yet".
try:
    output_body = s3_client.get_object(Bucket=default_bucket, Key=output_key_uri)["Body"].read()
except s3_client.exceptions.ClientError as err:
    # Keep the original hint, and also show the S3 error code so that permission or
    # configuration problems are not mistaken for a result that is still pending.
    print("Data is not exist yet. Wait until inference finished or check the CW log")
    print(f"S3 error code: {err.response['Error']['Code']}")
else:
    # `text` is only assigned on success; the download cell reads it afterwards.
    text = output_body.decode("utf-8")
    print(text)

In [None]:
# Local directory that get_s3_file downloads into.
!mkdir -p test-output

In [None]:
import os
import boto3
from IPython.display import Audio

# Same S3 client construction as in the setup cell; presumably repeated so that this
# cell can be run independently.
s3_client = boto3.client('s3')

def get_s3_file(s3_uri, output_dir="./test-output"):
    """Download the object at an ``s3://bucket/key`` URI into a local directory.

    Args:
        s3_uri: Location of the object, e.g. ``s3://my-bucket/path/clip.wav``.
            Surrounding whitespace (such as a trailing newline) is ignored.
        output_dir: Directory that receives the file; created when missing.

    Returns:
        Path of the downloaded file: ``output_dir`` joined with the last
        component of the object key.

    Raises:
        ValueError: If ``s3_uri`` is not of the form ``s3://bucket/key`` or
            the key does not end in a file name.
    """
    scheme, _, remainder = s3_uri.strip().partition("://")
    bucket, _, object_name = remainder.partition("/")
    filename = object_name.rsplit("/", 1)[-1]
    # Validate before touching the file system or S3 so bad input fails with a clear message.
    if scheme != "s3" or not bucket or not filename:
        raise ValueError(f"Expected an S3 URI like s3://bucket/key, got: {s3_uri!r}")

    os.makedirs(output_dir, exist_ok=True)
    local_path = os.path.join(output_dir, filename)
    s3_client.download_file(bucket, object_name, local_path)
    return local_path


In [None]:
# Download the file referenced by the endpoint's response and play it inline.
# `text` is passed to get_s3_file as an S3 URI — presumably the response body is exactly
# that URI; confirm against the handler.
# NOTE(review): `text` is only assigned when the result-fetching cell succeeds; re-run
# that cell first if this raises NameError.
local_path = get_s3_file(text)
Audio(local_path)