Skip to content

Commit

Permalink
cuda
Browse files (browse the repository at this point in the history)
  • Loading branch information
sakemin committed Nov 2, 2023
1 parent 620821b commit d0f4e87
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions predict.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -129,11 +129,15 @@ def predict(
default="peak",
choices=["loudness", "clip", "peak", "rms"],
),
beat_sync_threshold: float = Input(
description="When beat syncing, if the gap between generated downbeat timing and input audio downbeat timing is larger than `beat_sync_threshold`, consider the beats are not corresponding. (This will be fixed with the optimal value and be hidden, when releasing.)",
default=0.75,
),
chroma_coefficient: float = Input(
description="Coefficient value multiplied to multi-hot chord chroma.",
default=1.0,
ge=0.5,
le=2.5
le=2.0
),
top_k: int = Input(
description="Reduces sampling to the k most likely tokens.", default=250
Expand All @@ -159,22 +163,18 @@ def predict(
description="Seed for random number generator. If `None` or `-1`, a random seed will be used.",
default=None,
),
overlap: int = Input(
description="The length of overlapping part. Last `overlap` seconds of previous generation output audio is given to the next generation's audio prompt for continuation. (This will be fixed with the optimal value and be hidden, when releasing.)",
default=5, le=15, ge=1
),
in_step_beat_sync: bool = Input(
description="If `True`, beat syncing is performed every generation step. In this case, audio prompting with EnCodec token will not be used, so that the audio quality might be degraded on and on along encoding-decoding sequences of the generation steps. (This will be fixed with the optimal value and be hidden, when releasing.)",
default=False,
),
beat_sync_threshold: float = Input(
description="When beat syncing, if the gap between generated downbeat timing and input audio downbeat timing is larger than `beat_sync_threshold`, consider the beats are not corresponding. (This will be fixed with the optimal value and be hidden, when releasing.)",
default=0.75,
),
amp_rate: float = Input(
description="Amplifying the output audio to prevent volume diminishing along generations. (This will be fixed with the optimal value and be hidden, when releasing.)",
default=1.2,
),
# overlap: int = Input(
# description="The length of overlapping part. Last `overlap` seconds of previous generation output audio is given to the next generation's audio prompt for continuation. (This will be fixed with the optimal value and be hidden, when releasing.)",
# default=5, le=15, ge=1
# ),
# in_step_beat_sync: bool = Input(
# description="If `True`, beat syncing is performed every generation step. In this case, audio prompting with EnCodec token will not be used, so that the audio quality might be degraded on and on along encoding-decoding sequences of the generation steps. (This will be fixed with the optimal value and be hidden, when releasing.)",
# default=False,
# ),
# amp_rate: float = Input(
# description="Amplifying the output audio to prevent volume diminishing along generations. (This will be fixed with the optimal value and be hidden, when releasing.)",
# default=1.2,
# ),
) -> Path:

if prompt is None:
Expand All @@ -201,7 +201,7 @@ def predict(
model = self.model
model.lm.eval()

in_step_beat_sync = in_step_beat_sync
# in_step_beat_sync = in_step_beat_sync

set_generation_params = lambda duration: model.set_generation_params(
duration=duration,
Expand Down Expand Up @@ -241,7 +241,7 @@ def predict(
)

beat_sync_threshold = beat_sync_threshold
amp_rate = amp_rate
# amp_rate = amp_rate

set_generation_params(duration)

Expand Down Expand Up @@ -277,7 +277,7 @@ def predict(

downbeat_offset = input_downbeats[0]-wav_downbeats[0]
if downbeat_offset > 0:
wav = torch.concat([torch.zeros([1,1,int(downbeat_offset*wav_sr)]),wav],dim=-1)
wav = torch.concat([torch.zeros([1,1,int(downbeat_offset*wav_sr)]).cpu(),wav.cpu()],dim=-1)
for i in range(len(wav_downbeats)):
wav_downbeats[i]=wav_downbeats[i]+downbeat_offset
wav_downbeats = [0] + wav_downbeats + [wav_length]
Expand Down

0 comments on commit d0f4e87

Please sign in to comment.