
Commit

Use built-in infinite generation
sakemin committed Nov 1, 2023
1 parent b245c4c commit 6674ff5
Showing 1 changed file with 53 additions and 43 deletions.
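The diff below exposes a new chroma_coefficient input (range 0.5 to 2.5) and wraps the old manual long-form stitching loop in a triple-quoted comment, since long or "infinite" generation is now handled inside the model itself. As a rough illustration of what the coefficient is meant to do (the conditioner internals are not part of this diff, so the names and shapes below are assumptions), it rescales the multi-hot chord chroma that conditions generation:

import torch

# Hypothetical sketch only: a multi-hot chord chroma has one slot per pitch class,
# set to 1.0 where a chord tone is active; chroma_coefficient rescales those values.
chroma_coefficient = 1.5                      # exposed by the new Input(), default 1.0

chroma = torch.zeros(12)                      # 12 pitch classes
chroma[[0, 4, 7]] = 1.0                       # C major triad (C, E, G) as a multi-hot vector

scaled_chroma = chroma * chroma_coefficient   # values above 1.0 weight the chord conditioning more strongly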
predict.py — 53 additions & 43 deletions (96 changes)
@@ -149,6 +149,12 @@ def predict(
default="loudness",
choices=["loudness", "clip", "peak", "rms"],
),
chroma_coefficient: float = Input(
description="Coefficient value multiplied to multi-hot chord chroma.",
default=1.0,
ge=0.5,
le=2.5
),
top_k: int = Input(
description="Reduces sampling to the k most likely tokens.", default=250
),
@@ -212,16 +218,19 @@ def predict(
cfg_coef=classifier_free_guidance,
)

model.lm.condition_provider.conditioners['self_wav'].chroma_coefficient = chroma_coefficient

if not seed or seed == -1:
seed = torch.seed() % 2 ** 32 - 1
set_all_seeds(seed)
set_all_seeds(seed)
print(f"Using seed {seed}")

'''
if duration > 30:
encodec_rate = 50
sub_duration=15
sub_duration=25
overlap = 30 - sub_duration
wavs = []
wav_sr = model.sample_rate
@@ -330,54 +339,55 @@ def predict(
wav = wav.cpu()
else:
if not audio_chords:
set_generation_params(duration)
if text_chords is None or text_chords == '': # Case 4
wav, tokens = model.generate([prompt], progress=True, return_tokens=True)
else: # Case 5
wav, tokens = model.generate_with_text_chroma(descriptions = [prompt], chord_texts = [text_chords], bpm = [bpm], meter = [int(time_sig.split('/')[0])], progress=True, return_tokens=True)
else:
audio_chords, sr = torchaudio.load(audio_chords)
audio_chords = audio_chords[None] if audio_chords.dim() == 2 else audio_chords
'''
if not audio_chords:
set_generation_params(duration)
if text_chords is None or text_chords == '': # Case 4
wav, tokens = model.generate([prompt], progress=True, return_tokens=True)
else: # Case 5
wav, tokens = model.generate_with_text_chroma(descriptions = [prompt], chord_texts = [text_chords], bpm = [bpm], meter = [int(time_sig.split('/')[0])], progress=True, return_tokens=True)
else:
audio_chords, sr = torchaudio.load(audio_chords)
audio_chords = audio_chords[None] if audio_chords.dim() == 2 else audio_chords

audio_start = 0 if not audio_start else audio_start
if audio_end is None or audio_end == -1:
audio_end = audio_chords.shape[2] / sr
audio_start = 0 if not audio_start else audio_start
if audio_end is None or audio_end == -1:
audio_end = audio_chords.shape[2] / sr

if audio_start > audio_end:
raise ValueError(
"`audio_start` must be less than or equal to `audio_end`"
)
if audio_start > audio_end:
raise ValueError(
"`audio_start` must be less than or equal to `audio_end`"
)

audio_chords_wavform = audio_chords[
..., int(sr * audio_start) : int(sr * audio_end)
]
audio_chords_duration = audio_chords_wavform.shape[-1] / sr

if continuation:
set_generation_params(duration)

if text_chords is None or text_chords == '': # Case 6
wav, tokens = model.generate_continuation(
prompt=audio_chords_wavform,
prompt_sample_rate=sr,
descriptions=[prompt],
progress=True,
return_tokens=True
)
else: # Case 7
wav, tokens = model.generate_continuation_with_text_chroma(
audio_chords_wavform, sr, [prompt], [text_chords], bpm=[bpm], meter=[int(time_sig.split('/')[0])], progress=True, return_tokens=True
)
audio_chords_wavform = audio_chords[
..., int(sr * audio_start) : int(sr * audio_end)
]
audio_chords_duration = audio_chords_wavform.shape[-1] / sr

else: # Case 8
set_generation_params(duration)
wav, tokens = model.generate_with_chroma(
[prompt], audio_chords_wavform, sr, progress=True, return_tokens=True
if continuation:
set_generation_params(duration)

if text_chords is None or text_chords == '': # Case 6
wav, tokens = model.generate_continuation(
prompt=audio_chords_wavform,
prompt_sample_rate=sr,
descriptions=[prompt],
progress=True,
return_tokens=True
)
else: # Case 7
wav, tokens = model.generate_continuation_with_text_chroma(
audio_chords_wavform, sr, [prompt], [text_chords], bpm=[bpm], meter=[int(time_sig.split('/')[0])], progress=True, return_tokens=True
)

if multi_band_diffusion:
wav = self.mbd.tokens_to_wav(tokens)
else: # Case 8
set_generation_params(duration)
wav, tokens = model.generate_with_chroma(
[prompt], audio_chords_wavform, sr, progress=True, return_tokens=True
)

if multi_band_diffusion:
wav = self.mbd.tokens_to_wav(tokens)

audio_write(
"out",
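The seed handling in the second hunk above calls a set_all_seeds helper that is defined elsewhere in predict.py. A minimal sketch of what such a helper typically does, offered as an assumption rather than the file's actual definition:

import random
import numpy as np
import torch

def set_all_seeds(seed: int) -> None:
    # Seed every RNG the generation path can touch so a given seed reproduces a run.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)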
