Merge pull request #10 from replicate/optimise
Code looks good and I've played around with the model at https://replicate.com/fofr/meta-prototype?prediction=orwifjbbk7rzkvzmkq4gfsvkxe
zsxkib committed Mar 27, 2024
2 parents a7b8c1d + b640bde commit 647e795
Showing 5 changed files with 126 additions and 84 deletions.
11 changes: 11 additions & 0 deletions .dockerignore
@@ -0,0 +1,11 @@
**/__pycache__
**/.git
**/.github
**/.ci

# Outputs
**/*.wav
**/*.mp3

# Models
models/
5 changes: 5 additions & 0 deletions .gitignore
@@ -0,0 +1,5 @@
.cog/
models/
**/__pycache__
*.wav
*.mp3
5 changes: 3 additions & 2 deletions cog.yaml
@@ -37,9 +37,10 @@ build:
     - "torchmetrics"
     - "encodec"
     - "protobuf"
 
   # commands run after the environment is setup
-  # run:
+  run:
+    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
 
 # predict.py defines how predictions are run on your model
 predict: "predict.py:Predictor"
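
The new `run:` step bakes the pget binary into the image; at setup time the predictor shells out to it to download and unpack weight tarballs. A minimal Python sketch of that invocation, mirroring the exact argv used in weights_downloader.py below (the URL is illustrative, built from the BASE_URL and weight name that file uses):

    import subprocess

    # Download the tar from the weights mirror and extract it into models/hub.
    # The combined -xf flags (download and extract) match the call in
    # weights_downloader.py.
    url = "https://weights.replicate.delivery/default/musicgen/models--t5-base.tar"
    subprocess.check_call(["pget", "--log-level", "warn", "-xf", url, "models/hub"])
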
153 changes: 71 additions & 82 deletions predict.py
@@ -3,24 +3,10 @@
 
 import os
 import random
-
-# We need to set `TRANSFORMERS_CACHE` before any imports, which is why this is up here.
-MODEL_PATH = "/src/models/"
-os.environ["TRANSFORMERS_CACHE"] = MODEL_PATH
-os.environ["TORCH_HOME"] = MODEL_PATH
-
-
-import shutil
-
-from tempfile import TemporaryDirectory
-from pathlib import Path
-from distutils.dir_util import copy_tree
+import time
 from typing import Optional
 from cog import BasePredictor, Input, Path
 import torch
-import datetime
-
 # Model specific imports
 import torchaudio
-import subprocess
 import typing as tp
@@ -32,63 +18,61 @@
     load_lm_model,
 )
 from audiocraft.data.audio import audio_write
+from weights_downloader import WeightsDownloader
 
+MODEL_PATH = "/src/models/"
+os.environ["HF_HOME"] = MODEL_PATH
+os.environ["TRANSFORMERS_OFFLINE"] = "1"
+os.environ["TORCH_HOME"] = MODEL_PATH
+
 
 class Predictor(BasePredictor):
     def setup(self):
         """Load the model into memory to make running multiple predictions efficient"""
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        self.mbd = MultiBandDiffusion.get_mbd_musicgen()
-
-        self.stereo_melody_model = self._load_model(
-            model_path=MODEL_PATH,
-            cls=MusicGen,
-            model_id="facebook/musicgen-stereo-melody-large",
-        )
+        start = time.time()
+        self.weights_downloader = WeightsDownloader()
+        for model, dest in [
+            ("955717e8-8726e21a.th", "models/hub/checkpoints"),
+            ("models--facebook--musicgen-small", "models/hub"),
+            ("models--facebook--encodec_32khz", "models/hub"),
+            ("models--t5-base", "models/hub"),
+        ]:
+            self.weights_downloader.download_weights(model, dest)
 
-        self.stereo_large_model = self._load_model(
-            model_path=MODEL_PATH,
-            cls=MusicGen,
-            model_id="facebook/musicgen-stereo-large",
-        )
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.loaded_models = {}
 
-        self.melody_model = self._load_model(
-            model_path=MODEL_PATH,
-            cls=MusicGen,
-            model_id="facebook/musicgen-melody-large",
-        )
+        elapsed_time = time.time() - start
+        print(f"Setup time: {elapsed_time:.2f}s")
 
-        self.large_model = self._load_model(
-            model_path=MODEL_PATH,
-            cls=MusicGen,
-            model_id="facebook/musicgen-large",
-        )
     def _load_model(
         self,
         model_path: str,
-        cls: Optional[any] = None,
-        load_args: Optional[dict] = {},
         model_id: Optional[str] = None,
-        device: Optional[str] = None,
+        model_version: Optional[str] = None,
     ) -> MusicGen:
-
-        if device is None:
-            device = self.device
+        self.weights_downloader.download_weights(
+            f"models--facebook--musicgen-{model_version}"
+        )
 
         compression_model = load_compression_model(
-            model_id, device=device, cache_dir=model_path
+            model_id, device=self.device, cache_dir=model_path
         )
-        lm = load_lm_model(model_id, device=device, cache_dir=model_path)
+        lm = load_lm_model(model_id, device=self.device, cache_dir=model_path)
 
         return MusicGen(model_id, compression_model, lm)
 
     def predict(
         self,
         model_version: str = Input(
-            description="Model to use for generation. If set to 'encode-decode', the audio specified via 'input_audio' will simply be encoded and then decoded.",
+            description="Model to use for generation",
             default="stereo-melody-large",
-            choices=["stereo-melody-large", "stereo-large", "melody-large", "large", "encode-decode"],
+            choices=[
+                "stereo-melody-large",
+                "stereo-large",
+                "melody-large",
+                "large",
+            ],
         ),
         prompt: str = Input(
             description="A description of the music you want to generate.", default=None
@@ -148,37 +132,52 @@ def predict(
             default=None,
         ),
     ) -> Path:
-
         if prompt is None and input_audio is None:
             raise ValueError("Must provide either prompt or input_audio")
         if continuation and not input_audio:
             raise ValueError("Must provide `input_audio` if continuation is `True`.")
-        if (model_version == "stereo-large" or model_version=="large") and input_audio and not continuation:
+        if (
+            (model_version == "stereo-large" or model_version == "large")
+            and input_audio
+            and not continuation
+        ):
             raise ValueError(
                 "`stereo-large` and `large` model does not support melody input. Set `model_version='stereo-melody-large'` or `model_version='melody-large'` to condition on audio input."
             )
         if "stereo" in model_version and multi_band_diffusion:
-            raise ValueError("Multi-Band Diffusion is only available with non-stereo models.")
-
-        if model_version == "stereo-melody-large":
-            model = self.stereo_melody_model
-        elif model_version == "stereo-large":
-            model = self.stereo_large_model
-        elif model_version == "melody-large":
-            model = self.melody_model
-        elif model_version == "large":
-            model = self.large_model
-
-        set_generation_params = lambda duration: model.set_generation_params(
-            duration=duration,
-            top_k=top_k,
-            top_p=top_p,
-            temperature=temperature,
-            cfg_coef=classifier_free_guidance,
-        )
+            raise ValueError(
+                "Multi-Band Diffusion is only available with non-stereo models."
+            )
 
+        if multi_band_diffusion and not hasattr(self, "mbd"):
+            print("Loading MultiBandDiffusion...")
+            self.weights_downloader.download_weights(
+                "models--facebook--multiband-diffusion", "models/hub"
+            )
+            self.mbd = MultiBandDiffusion.get_mbd_musicgen()
+            print("MultiBandDiffusion loaded successfully.")
+
+        if model_version not in self.loaded_models:
+            print(f"Loading model {model_version}...")
+            self.loaded_models[model_version] = self._load_model(
+                model_path=MODEL_PATH,
+                model_id=f"facebook/musicgen-{model_version}",
+                model_version=model_version,
+            )
+            print(f"Model {model_version} loaded successfully.")
+        model = self.loaded_models[model_version]
+
+        def set_generation_params(duration):
+            return model.set_generation_params(
+                duration=duration,
+                top_k=top_k,
+                top_p=top_p,
+                temperature=temperature,
+                cfg_coef=classifier_free_guidance,
+            )
+
         if not seed or seed == -1:
-            seed = torch.seed() % 2 ** 32 - 1
+            seed = torch.seed() % 2**32 - 1
             set_all_seeds(seed)
         set_all_seeds(seed)
         print(f"Using seed {seed}")
@@ -189,14 +188,6 @@ def predict(
             if multi_band_diffusion:
                 wav = self.mbd.tokens_to_wav(tokens)
 
-        elif model_version == "encode-decode":
-            encoded_audio = self._preprocess_audio(input_audio, model)
-            set_generation_params(duration)
-            if multi_band_diffusion:
-                wav = self.mbd.tokens_to_wav(tokens)
-            else:
-                wav = model.compression_model.decode(encoded_audio).squeeze(0)
-
         else:
             input_audio, sr = torchaudio.load(input_audio)
             input_audio = input_audio[None] if input_audio.dim() == 2 else input_audio
@@ -213,16 +204,15 @@
                 input_audio_wavform = input_audio[
                     ..., int(sr * continuation_start) : int(sr * continuation_end)
                 ]
-                input_audio_duration = input_audio_wavform.shape[-1] / sr
 
                 if continuation:
-                    set_generation_params(duration)# + input_audio_duration)
+                    set_generation_params(duration)
                     wav, tokens = model.generate_continuation(
                         prompt=input_audio_wavform,
                         prompt_sample_rate=sr,
                         descriptions=[prompt],
                         progress=True,
-                        return_tokens=True
+                        return_tokens=True,
                     )
                     if multi_band_diffusion:
                         wav = self.mbd.tokens_to_wav(tokens)
@@ -258,7 +248,6 @@ def predict(
     def _preprocess_audio(
         audio_path, model: MusicGen, duration: tp.Optional[int] = None
     ):
-
         wav, sr = torchaudio.load(audio_path)
         wav = torchaudio.functional.resample(wav, sr, model.sample_rate)
         wav = wav.mean(dim=0, keepdim=True)
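
The substance of the predict.py change: setup() no longer eagerly builds four MusicGen variants plus MultiBandDiffusion; it only downloads a small set of shared weights, and each model is loaded on first request and memoised in self.loaded_models. A minimal sketch of that pattern, with a stub standing in for the real _load_model, not the actual code:

    loaded_models = {}

    def load_model(model_id: str):
        # Stand-in for Predictor._load_model, which fetches the weights via
        # WeightsDownloader and assembles a MusicGen from its compression
        # model and language model.
        return model_id

    def get_model(model_version: str):
        # The first request for a variant pays the load cost; later
        # predictions reuse the cached instance.
        if model_version not in loaded_models:
            loaded_models[model_version] = load_model(f"facebook/musicgen-{model_version}")
        return loaded_models[model_version]

MultiBandDiffusion gets the same treatment: it is downloaded and constructed only when multi_band_diffusion is requested and self.mbd does not already exist.
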
36 changes: 36 additions & 0 deletions weights_downloader.py
@@ -0,0 +1,36 @@
import subprocess
import time
import os

BASE_URL = "https://weights.replicate.delivery/default/musicgen"


class WeightsDownloader:
def __init__(self):
pass

def download_weights(self, weight_str, dest="models"):
self.download_if_not_exists(weight_str, dest)

def download_if_not_exists(self, weight_str, dest):
if not os.path.exists(f"{dest}/{weight_str}"):
self.download(weight_str, dest)

def download(self, weight_str, dest):
url = BASE_URL + "/" + weight_str + ".tar"
print(f"Downloading {weight_str} to {dest}")
start = time.time()
subprocess.check_call(
["pget", "--log-level", "warn", "-xf", url, dest], close_fds=False
)
elapsed_time = time.time() - start
if os.path.isfile(os.path.join(dest, os.path.basename(weight_str))):
file_size_bytes = os.path.getsize(
os.path.join(dest, os.path.basename(weight_str))
)
file_size_megabytes = file_size_bytes / (1024 * 1024)
print(
f"Downloaded {weight_str} in {elapsed_time:.2f}s, size: {file_size_megabytes:.2f}MB"
)
else:
print(f"Downloaded {weight_str} in {elapsed_time:.2f}s")
