### 1. 安装 huggingface_hub 并下载模型到本地

In [None]:
#!pip install huggingface-hub -Uqq

In [None]:
from huggingface_hub import snapshot_download
from pathlib import Path

# Hugging Face repo to download and the exact commit it is pinned to (reproducible deployments).
model_name = "facebook/musicgen-large"
commit_hash = "c19300a6b2b62d29b345ae9eb7b163278e65238a"

# Local cache directory for the download (created if it does not exist yet).
local_model_path = Path("LLM_musicgen_model")
local_model_path.mkdir(exist_ok=True)

In [None]:
# Download the pinned revision into local_model_path, using the Hugging Face cache layout.
# The call's return value (the local snapshot directory) is displayed as the cell output.
snapshot_download(
    repo_id=model_name,
    revision=commit_hash,
    cache_dir=local_model_path,
)

### 2. 把模型拷贝到S3为后续部署做准备

In [None]:
import sagemaker
from sagemaker import image_uris
import boto3
import os
import time
import json

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts

# Public accessor for the session's region (instead of the private ``_region_name``).
region = sess.boto_region_name
account_id = sess.account_id()

# Pin the low-level clients to the session's region so every API call targets the same
# region as the SageMaker session; a bare boto3.client() would use the default profile's region.
s3_client = boto3.client("s3", region_name=region)
sm_client = boto3.client("sagemaker", region_name=region)
smr_client = boto3.client("sagemaker-runtime", region_name=region)

In [None]:
# S3 "folder" that the model checkpoint files are copied into.
s3_model_prefix = "llm/models/LLM_musicgen_model"

# The Hugging Face cache stores each downloaded revision under .../snapshots/<commit hash>.
# Select the revision pinned above instead of whichever snapshot glob() happens to yield first.
snapshot_candidates = sorted(local_model_path.glob(f"**/snapshots/{commit_hash}"))
if not snapshot_candidates:
    raise FileNotFoundError(
        f"No snapshot for revision {commit_hash} under {local_model_path}; "
        "run snapshot_download first."
    )
model_snapshot_path = snapshot_candidates[0]

# S3 prefix for the model.tar.gz that holds the inference code.
s3_code_prefix = "LLM-RAG/workshop/LLM_musicgen_deploy_code"
print(f"s3_code_prefix: {s3_code_prefix}")
print(f"model_snapshot_path: {model_snapshot_path}")

In [None]:
# Copy the local snapshot directory to s3://<bucket>/<s3_model_prefix>.
# IPython expands the {...} placeholders from the notebook's Python variables.
# option.s3url in serving.properties is expected to point at this S3 location.
!aws s3 cp --recursive {model_snapshot_path} s3://{bucket}/{s3_model_prefix}

### 3. 模型部署准备（entrypoint脚本，容器镜像，服务配置）

In [None]:
# AWS Deep Learning Container for DJL Serving 0.22.1 (DeepSpeed 0.9.2, CUDA 11.8),
# pulled from the DLC registry account 763104351884 in the session's region.
# NOTE(review): some regions may host DLCs under a different account ID — confirm for yours.
dlc_registry = f"763104351884.dkr.ecr.{region}.amazonaws.com"
inference_image_uri = f"{dlc_registry}/djl-inference:0.22.1-deepspeed0.9.2-cu118"

print(f"Image going to be used is ---- > {inference_image_uri}")

In [None]:
# Local folder whose files (model.py, serving.properties, requirements.txt) are written by the
# following cells and then packaged into model.tar.gz.
!mkdir -p LLM_musicgen_deploy_code

In [None]:
%%writefile LLM_musicgen_deploy_code/model.py
from djl_python import Input, Output
import torch
import logging
import os
import torch
import torchaudio
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write



def install_s5cmd(src="./s5cmd", dst="/tmp/s5cmd"):
    """Copy an s5cmd binary to ``dst`` and make it executable; return True on success.

    ``src`` is used when it exists, otherwise an ``s5cmd`` found on PATH. Failures are
    logged rather than raised so the model can still load, but S3 uploads that rely on
    ``/tmp/s5cmd`` will not work.
    """
    import shutil
    import stat

    source = src if os.path.isfile(src) else shutil.which("s5cmd")
    if source is None:
        logging.warning("s5cmd not found at %s or on PATH; S3 uploads will fail", src)
        return False
    try:
        # Avoid shutil.SameFileError when the binary on PATH already is ``dst``.
        if os.path.abspath(source) != os.path.abspath(dst):
            shutil.copy(source, dst)
        mode = os.stat(dst).st_mode
        os.chmod(dst, mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
    except OSError:
        logging.exception("Could not install s5cmd from %s to %s", source, dst)
        return False
    return True


install_s5cmd()


def load_model(properties):
    """Load the pretrained MusicGen model described by the DJL serving properties.

    Args:
        properties: DJL serving properties (string values). ``model_id``, when present,
            overrides ``model_dir`` as the location passed to ``MusicGen.get_pretrained``.
            An optional ``duration`` sets the clip length in seconds (default 8).

    Returns:
        The loaded MusicGen model with its generation parameters set.

    Raises:
        KeyError: if neither ``model_id`` nor ``model_dir`` is present.
    """
    # tensor_parallel_degree is intentionally not read: MusicGen is loaded as one model
    # here, so requiring that property only caused a KeyError when it was absent.
    if "model_id" in properties:
        model_location = properties["model_id"]
    else:
        model_location = properties["model_dir"]
    logging.info("Loading model in %s", model_location)

    model = MusicGen.get_pretrained(model_location)
    # Property values arrive as strings; default keeps the original 8-second clips.
    duration = float(properties.get("duration", 8))
    model.set_generation_params(duration=duration)
    return model


# Lazily-initialised MusicGen model, shared by every request this worker process handles.
model = None

# Used only when serving.properties defines neither ``output_s3_url`` nor ``s3url``;
# this is the account-specific prefix the handler originally hard-coded.
_FALLBACK_OUTPUT_S3_URL = (
    "s3://sagemaker-us-west-2-687912291502/llm/models/LLM_musicgen_model/output/"
)


def _output_s3_url(properties):
    """Return the S3 prefix (always ending in "/") that generated WAV files go to.

    Precedence: the ``output_s3_url`` property, then ``<s3url>/output/``, then the
    fallback constant. NOTE(review): assumes DJL exposes ``option.*`` entries of
    serving.properties without the ``option.`` prefix — confirm for your DJL version.
    """
    url = properties.get("output_s3_url")
    if not url:
        s3url = properties.get("s3url")
        url = f"{s3url.rstrip('/')}/output/" if s3url else _FALLBACK_OUTPUT_S3_URL
    return url if url.endswith("/") else url + "/"


def _s5cmd_binary():
    """Return the s5cmd executable: the copy in /tmp if usable, else one on PATH."""
    import shutil

    if os.access("/tmp/s5cmd", os.X_OK):
        return "/tmp/s5cmd"
    found = shutil.which("s5cmd")
    if found is None:
        raise FileNotFoundError("s5cmd not found in /tmp or on PATH; cannot upload audio to S3")
    return found


def _parse_descriptions(data):
    """Return the non-empty text prompts from the request JSON.

    ``descriptions`` (or the alias ``description``) may be a comma-separated string
    or a list of strings; surrounding whitespace is stripped and empty items dropped.
    """
    raw = data.get("descriptions", data.get("description"))
    if raw is None:
        raise ValueError('request JSON must contain a "descriptions" field')
    if isinstance(raw, str):
        raw = raw.split(",")
    prompts = (str(item).strip() for item in raw)
    return [p for p in prompts if p]


def handle(inputs: Input):
    """DJL Serving entry point: generate one audio clip per description and upload it to S3.

    Request JSON: ``{"descriptions": "happy rock, sad jazz"}`` (string or list of strings).
    Response JSON: ``{"outputs": [<S3 URI of each generated .wav>, ...]}``, in prompt order.
    Raises if s5cmd is unavailable or an upload fails, so callers never get URIs that
    do not exist.
    """
    import subprocess
    import tempfile
    import uuid

    global model
    properties = inputs.get_properties()
    # The model is loaded before the empty-input check, so an empty request also loads it.
    if model is None:
        model = load_model(properties)

    if inputs.is_empty():
        return None

    descriptions = _parse_descriptions(inputs.get_as_json())
    if not descriptions:
        return Output().add_as_json({"outputs": []})

    s5cmd = _s5cmd_binary()
    # Unique per-request prefix so concurrent requests never overwrite each other's clips.
    target_prefix = f"{_output_s3_url(properties)}{uuid.uuid4().hex}/"

    wav = model.generate(descriptions)

    uploaded = []
    # Private temporary directory (removed afterwards) instead of the shared working directory.
    with tempfile.TemporaryDirectory() as workdir:
        for idx, one_wav in enumerate(wav):
            stem = os.path.join(workdir, str(idx))
            # Will save under {stem}.wav, with loudness normalization at -14 db LUFS.
            audio_write(stem, one_wav.cpu(), model.sample_rate, strategy="loudness")
            target = f"{target_prefix}{idx}.wav"
            subprocess.run([s5cmd, "cp", f"{stem}.wav", target], check=True)
            uploaded.append(target)

    return Output().add_as_json({"outputs": uploaded})

#### 注意: option.s3url 需要改成自己账号下的模型 S3 路径（即第2步上传到的 s3://{bucket}/{s3_model_prefix}/）

In [None]:
%%writefile LLM_musicgen_deploy_code/serving.properties
# DJL Serving configuration for the MusicGen endpoint.
engine=DeepSpeed
# The endpoint instance (ml.g5.12xlarge) has 4 GPUs.
# NOTE(review): model.py loads MusicGen via audiocraft and never uses this value or DeepSpeed;
# confirm whether engine=Python on a single GPU would be sufficient.
option.tensor_parallel_degree=4
# S3 location of the model files uploaded in step 2 (s3://<bucket>/<s3_model_prefix>/);
# replace the bucket with your own account's bucket.
option.s3url = s3://sagemaker-us-west-2-687912291502/llm/models/LLM_musicgen_model/

#### 注意: musicgen 模型 config 中记录的 transformers 版本是 4.29.2

In [None]:
%%writefile LLM_musicgen_deploy_code/requirements.txt
# NOTE(review): the PyPI package "ffmpeg" is not the ffmpeg binary itself; confirm the
# container already provides ffmpeg if audiocraft's audio I/O needs it.
ffmpeg
# audiocraft (MusicGen) installed from the repository's default branch; consider pinning
# a commit for reproducible builds.
git+https://github.com/facebookresearch/audiocraft.git


In [None]:
# Rebuild the code tarball from scratch; -f keeps the first run from erroring when none exists yet.
!rm -f model.tar.gz
# Drop Jupyter's autosave folder so it is not shipped inside the tarball.
!rm -rf LLM_musicgen_deploy_code/.ipynb_checkpoints
!tar czvf model.tar.gz LLM_musicgen_deploy_code

In [None]:
# Upload the packaged inference code; the returned value is the S3 URI of the uploaded tarball.
s3_code_artifact = sess.upload_data(
    path="model.tar.gz",
    bucket=bucket,
    key_prefix=s3_code_prefix,
)
print(f"S3 Code or Model tar ball uploaded to --- > {s3_code_artifact}")

### 4. 创建模型 & 创建endpoint

In [None]:
from sagemaker.utils import name_from_base
import boto3

# Unique SageMaker model name: name_from_base appends a timestamp to the base string.
# Note: this rebinds model_name, which held the Hugging Face repo id in step 1.
model_name = name_from_base("musicgen")
print(model_name)
print(f"Image going to be used is ---- > {inference_image_uri}")

# Register the model: the DJL container image plus the S3 tarball with the inference code.
create_model_response = sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer={
        "Image": inference_image_uri,
        "ModelDataUrl": s3_code_artifact,
    },
)
model_arn = create_model_response["ModelArn"]

print(f"Created Model: {model_arn}")

In [None]:
# Resource names derived from the timestamped model name.
endpoint_name = f"{model_name}-endpoint"
endpoint_config_name = f"{model_name}-config"

production_variant = {
    "VariantName": "variant1",
    "ModelName": model_name,
    "InstanceType": "ml.g5.12xlarge",
    "InitialInstanceCount": 1,
    # Optional settings, left disabled:
    # "VolumeSizeInGB" : 400,
    # "ModelDataDownloadTimeoutInSeconds": 2400,
    # Give the container up to 15 minutes to start and pass its health check.
    "ContainerStartupHealthCheckTimeoutInSeconds": 15 * 60,
}

endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[production_variant],
)
endpoint_config_response

In [None]:
# Start provisioning the endpoint; this call returns immediately while the endpoint is "Creating".
create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)
endpoint_arn = create_endpoint_response["EndpointArn"]
print(f"Created Endpoint: {endpoint_arn}")

#### 持续检测模型部署进度

In [None]:
import time

# Poll once a minute until the endpoint leaves the "Creating" state.
resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)
if status == "Failed":
    # Show why deployment failed instead of only the bare "Failed" status.
    print("FailureReason: " + resp.get("FailureReason", "<not provided>"))

### 5. 模型测试

In [None]:
# Invokes the endpoint created above. After a kernel restart, set the name explicitly, e.g.:
# endpoint_name = "falcon-2023-06-14-06-30-46-229-endpoint"

# Comma-separated text prompts; model.py splits on "," and generates one clip per prompt.
prompts1 = "happy rock, energetic EDM, sad jazz"

response_model = smr_client.invoke_endpoint(
    EndpointName=endpoint_name,
    # model.py reads the prompts from the "descriptions" key.
    Body=json.dumps({"descriptions": prompts1}),
    ContentType="application/json",
)

# The handler answers with JSON; decode it instead of showing raw bytes.
json.loads(response_model["Body"].read())

#### 清除模型Endpoint和config

In [None]:
# Delete the endpoint created in this notebook (previously this targeted an unrelated chatglm endpoint).
sm_client.delete_endpoint(EndpointName=endpoint_name)

In [None]:
# Delete this notebook's endpoint configuration and the SageMaker model registered in step 4.
sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm_client.delete_model(ModelName=model_name)