Skip to content

Commit

Permalink
train code (testing)
Browse files Browse the repository at this point in the history
  • Loading branch information
sakemin committed Nov 24, 2023
1 parent a43cc47 commit 79c9fb7
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 27 deletions.
3 changes: 1 addition & 2 deletions audiocraft/solvers/musicgen_chord.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,6 @@ def _compute_cross_entropy(
ce = ce / K
return ce, ce_per_codebook

@torch.no_grad()
def _prepare_tokens_and_attributes(
self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
check_synchronization_points: bool = False
Expand Down Expand Up @@ -548,7 +547,7 @@ def get_hydrated_conditions(meta: tp.List[SegmentWithAttributes]):
rtf = 1.
else:
gen_unprompted_outputs = self.run_generate_step(
batch, gen_duration=target_duration, prompt_duration=prompt_duration,
batch, gen_duration=target_duration, prompt_duration=None,
**self.generation_params)
gen_unprompted_audio = gen_unprompted_outputs['gen_audio'].cpu()
rtf = gen_unprompted_outputs['rtf']
Expand Down
2 changes: 1 addition & 1 deletion config/conditioner/clapemb2music.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ conditioners:
checkpoint: //reference/clap/music_audioset_epoch_15_esc_90.14.pt
model_arch: 'HTSAT-base'
enable_fusion: false
sample_rate: 44100
sample_rate: 48000
max_audio_length: 10
audio_stride: 1
dim: 512
Expand Down
6 changes: 5 additions & 1 deletion config/solver/musicgen/musicgen_chord_32khz.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,12 @@ logging:
log_tensorboard: false

schedule:
lr_scheduler: cosine
lr_scheduler: null
cosine:
warmup: 4000
lr_min_ratio: 0.0
cycle_length: 1.0

checkpoint:
continue_from: null
continue_best: False
37 changes: 26 additions & 11 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,11 @@ def _delete_param(cfg, full_name: str):
del cfg[parts[-1]]
OmegaConf.set_struct(cfg, True)

def load_ckpt(path, device):
loaded = torch.load(str(path))
def load_ckpt(path, device, url=False):
if url:
loaded = torch.hub.load_state_dict_from_url(str(path))
else:
loaded = torch.load(str(path))
cfg = OmegaConf.create(loaded['xp.cfg'])
cfg.device = str(device)
if cfg.device == 'cpu':
Expand All @@ -73,6 +76,13 @@ def setup(self, weights: Optional[Path] = None):

self.mbd = MultiBandDiffusion.get_mbd_musicgen()

if str(weights) == "weights":
weights = None

if weights is not None:
print("Fine-tuned model weights loaded!")
self.model = load_ckpt(weights, self.device, url=True)

def _load_model(
self,
model_path: str,
Expand All @@ -95,8 +105,8 @@ def _load_model(
def predict(
self,
model_version: str = Input(
description="Model type", default="stereo-chord-large",
choices=["chord", "chord-large", "stereo-chord", "stereo-chord-large"]
description="Model type. Select `fine-tuned` if you trained the model into your own repository.", default="stereo-chord-large",
choices=["chord", "chord-large", "stereo-chord", "stereo-chord-large", "fine-tuned"]
),
prompt: str = Input(
description="A description of the music you want to generate.", default=None
Expand Down Expand Up @@ -209,14 +219,19 @@ def predict(
else:
prompt = prompt + f', bpm : {bpm}'


if os.path.isfile(f'musicgen-{model_version}.th'):
pass
if model_version == "fine-tuned":
try:
self.model
except AttributeError:
raise Exception("ERROR: Fine-tuned weights don't exist! Is the model trained from `sakemin/musicgen-chord`? If not, set `model_version` from `chord`, `chord-large`, `stereo-chord` and `stereo-chord-large`.")
else:
url = f"https://weights.replicate.delivery/default/musicgen-chord/musicgen-{model_version}.th"
dest = f"musicgen-{model_version}.th"
subprocess.check_call(["pget", url, dest], close_fds=False)
self.model = load_ckpt(f'musicgen-{model_version}.th', self.device)
if os.path.isfile(f'musicgen-{model_version}.th'):
pass
else:
url = f"https://weights.replicate.delivery/default/musicgen-chord/musicgen-{model_version}.th"
dest = f"musicgen-{model_version}.th"
subprocess.check_call(["pget", url, dest], close_fds=False)
self.model = load_ckpt(f'musicgen-{model_version}.th', self.device)
self.model.lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True

model = self.model
Expand Down
27 changes: 15 additions & 12 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,6 @@ def prepare_data(

for filename in tqdm(os.listdir(target_path)):
if filename.endswith(('.mp3', '.wav', '.flac')):
# if drop_vocals and separator is not None:
# print('Separating Vocals from ' + filename)
# origin, separated = separator.separate_audio_file(target_path + '/' + filename)
# mixed = separated["bass"] + separated["drums"] + separated["other"]
# torchaudio.save(target_path + '/' + filename, mixed, separator.samplerate)


# Chunking audio files into 30-second chunks
audio = AudioSegment.from_file(target_path + '/' + filename)

Expand Down Expand Up @@ -291,7 +284,8 @@ def get_audio_features(audio_filename):
return max_sample_rate, filelen

def train(
dataset_path: Path = Input("Path to dataset directory. Input audio files will be chunked into multiple 30 second audio files. Must be one of 'tar', 'tar.gz', 'gz', 'zip' types of compressed file, or a single 'wav', 'mp3', 'flac' file. Audio files must be longer than 5 seconds.",),
model_version: str = Input(description="Model version to train.", default="stereo-chord", choices=["stereo-chord", "chord"]),
dataset_path: Path = Input(description="Path to dataset directory. Input audio files will be chunked into multiple 30 second audio files. Must be one of 'tar', 'tar.gz', 'gz', 'zip' types of compressed file, or a single 'wav', 'mp3', 'flac' file. Audio files must be longer than 5 seconds.",),
auto_labeling: bool = Input(description="Creating label data like genre, mood, theme, instrumentation, key, bpm for each track. Using `essentia-tensorflow` for music information retrieval.", default=True),
drop_vocals: bool = Input(description="Dropping the vocal tracks from the audio files in dataset, by separating sources with Demucs.", default=True),
one_same_description: str = Input(description="A description for all of audio data", default=None),
Expand Down Expand Up @@ -336,13 +330,23 @@ def train(
solver = "musicgen/musicgen_chord_32khz"
model_scale = "medium"
conditioner = "chord2music"
continue_from = "/src/musicgen_chord.th"
continue_from = f"/src/musicgen-{model_version}.th"

if not os.path.isfile(continue_from):
if os.path.isfile(f'musicgen-{model_version}.th'):
pass
else:
print("Downloading the model weights!")
sp.call(["curl", "https://musicgen-chord.s3.ap-southeast-2.amazonaws.com/musicgen_chord_popped.th", "--output", continue_from])
url = f"https://weights.replicate.delivery/default/musicgen-chord/musicgen-{model_version}.th"
dest = f"musicgen-{model_version}.th"
subprocess.check_call(["pget", url, dest], close_fds=False)

args = ["run", "-d", "--", f"solver={solver}", f"model/lm/model_scale={model_scale}", f"continue_from={continue_from}", f"conditioner={conditioner}"]
if "stereo" in model_version:
args.append(f"codebooks_pattern.delay.delays={[0, 0, 1, 1, 2, 2, 3, 3]}")
args.append('transformer_lm.n_q=8')
args.append('interleave_stereo_codebooks.use=True')
args.append('channels=2')
args.append(f"datasource.max_sample_rate={max_sample_rate}")
args.append(f"datasource.max_sample_rate={max_sample_rate}")
args.append(f"datasource.train={meta_path}")
args.append(f"dataset.train.num_samples={len_dataset}")
Expand All @@ -368,7 +372,6 @@ def train(
args.append("dataset.train.permutation_on_files=True")
args.append(f"optim.updates_per_epoch={updates_per_epoch}")

print(os.getcwd())
sp.call(["dora"]+args)

for dirpath, dirnames, filenames in os.walk("tmp"):
Expand Down

0 comments on commit 79c9fb7

Please sign in to comment.