In [None]:
import os
import sys
import io
import typing as T

from uuid import uuid4

import numpy as np
import pandas as pd
import pydub
import torch
from datetime import datetime

from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

from tools.riffusion.spectrogram_params import SpectrogramParams
from tools.riffusion.spectrogram_image_converter import SpectrogramImageConverter

from IPython.display import clear_output

In [None]:
# Pick the compute device: prefer Apple-silicon MPS, then CUDA, and fall back
# to the CPU when neither accelerator is usable.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    # Runtime CUDA availability is reported by `torch.cuda.is_available()`;
    # `torch.backends.cuda` offers `is_built()` but no `is_available()`, so the
    # previous check raised AttributeError on every non-MPS machine.
    device = "cuda"
else:
    device = "cpu"
# Tensor dtype used when the diffusion pipeline is loaded.
DTYPE = torch.float32
device


In [None]:
# Parameters

# -- Run bookkeeping: outputs are grouped by calendar day, and every run gets
#    its own random identifier.
EXPERIMENT_DATE = f"{datetime.today():%Y-%m-%d}"
OUTPUT_FOLDER = "experiments/" + EXPERIMENT_DATE
EXPERIMENT_ID = f"{uuid4()}"

# -- Input audio.
FILE_PATH = "test_reference_long.wav"
SAMPLING_RATE = 44100

# -- Diffusion settings.
SEED = 42
NUM_INFERENCE_STEPS = 50
GUIDANCE_SCALE = 7.0
DENOISING_STRENGTH = 0.45

# -- Text conditioning. An empty negative prompt is passed on as None later.
PROMPT = "classic italian tenor operatic pop"
NEGATIVE_PROMPT = ""


In [None]:
# Decode the reference audio file, then resample it to the configured rate.
segment = pydub.AudioSegment.from_file(FILE_PATH).set_frame_rate(SAMPLING_RATE)
print(f"Duration: {segment.duration_seconds:.2f}s, Sample Rate: {segment.frame_rate}Hz")


In [None]:
# Clip geometry (seconds): each clip is CLIP_DURATION long and shares
# OVERLAP_DURATION with its neighbour.
CLIP_DURATION = 5.0
OVERLAP_DURATION = 2.0

# Window of the source audio to process (seconds); by default the whole file.
START_TIME = 0.0
DURATION = segment.duration_seconds


In [None]:
# Clamp the requested duration so the processed window cannot run past the end
# of the loaded audio when START_TIME is non-zero.
duration = min(DURATION, segment.duration_seconds - START_TIME)

# Consecutive clips start `increment` seconds apart, so neighbouring clips
# share OVERLAP_DURATION seconds of audio.
increment = CLIP_DURATION - OVERLAP_DURATION
if increment <= 0:
    # A zero or negative step cannot produce a usable clip schedule.
    raise ValueError("OVERLAP_DURATION must be smaller than CLIP_DURATION")

# Build the schedule from the clamped `duration`; the raw DURATION was used
# here before, which left the clamp above without any effect.
clip_start_times = START_TIME + np.arange(0, duration - CLIP_DURATION, increment)


In [None]:
# Tabulate the clip schedule: one row per clip, all with the same length.
pd.DataFrame({"Start Time [s]": clip_start_times}).assign(
    **{
        "End Time [s]": clip_start_times + CLIP_DURATION,
        "Duration [s]": CLIP_DURATION,
    }
)


In [None]:
clip_segments: T.List[pydub.AudioSegment] = []

# Every clip has the same nominal length, so convert it to milliseconds once.
clip_duration_ms = int(CLIP_DURATION * 1000)
last_clip_index = len(clip_start_times) - 1

for clip_index, clip_start_time in enumerate(clip_start_times):
    start_ms = int(clip_start_time * 1000)
    clip_segment = segment[start_ms : start_ms + clip_duration_ms]

    # Only the final clip is padded: if the slice came out shorter than a full
    # clip, silence is appended to make up the difference.
    if clip_index == last_clip_index:
        missing_ms = clip_duration_ms - int(clip_segment.duration_seconds * 1000)
        if missing_ms > 0:
            padding = pydub.AudioSegment.silent(duration=missing_ms)
            clip_segment = clip_segment.append(padding)

    clip_segments.append(clip_segment)


In [None]:
# Sanity check on the slicing step. Success requires at least one clip, because
# `all()` over an empty list is vacuously True and would otherwise report
# success when nothing was produced. `isinstance` replaces the exact
# `type(...) ==` comparison.
if clip_segments and all(
    isinstance(clip_segment, pydub.AudioSegment) for clip_segment in clip_segments
):
    print("Segments created successfully")
else:
    # Make the failure visible instead of printing nothing.
    print("Segment creation failed: no clips were produced or a clip is not an AudioSegment")


In [None]:
# Accumulators for the generated spectrogram images and the audio decoded
# from them.
result_segments: T.List[pydub.AudioSegment] = list()
result_images: T.List[Image.Image] = list()

# Default-constructed spectrogram parameters configure the converter that
# turns audio into spectrogram images and back.
params = SpectrogramParams()
image_converter = SpectrogramImageConverter(device=device, params=params)


In [None]:
def _passthrough_safety_checker(images, **kwargs):
    """Return `images` unchanged together with one False flag per image."""
    return images, [False] * len(images)


# Load the img2img pipeline weights, then move the pipeline to the selected device.
pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
    pretrained_model_name_or_path="riffusion/riffusion-model-v1",
    revision="main",
    torch_dtype=DTYPE,
    safety_checker=_passthrough_safety_checker,
)
pipeline = pipeline.to(device)

# The random generator is created on the CPU when running on MPS; otherwise it
# lives on the compute device itself.
if device.lower().startswith("mps"):
    generator_device = "cpu"
else:
    generator_device = device
generator = torch.Generator(device=generator_device)
generator = generator.manual_seed(SEED)

# Inference steps scaled by the denoising strength, never fewer than one.
num_expected_steps = max(1, int(DENOISING_STRENGTH * NUM_INFERENCE_STEPS))


In [None]:
# Start from empty result lists so re-running this cell never mixes in output
# from a previous run.
result_images = list()
result_segments = list()

# Re-seed the generator so a re-run of this cell reproduces the same clips
# instead of continuing the random stream left over from the previous run.
generator.manual_seed(SEED)

for clip_segment in clip_segments:
    # TODO: implement intermediary saving of the source clip here if needed
    clear_output(wait=False)
    print(f"Processing clip {len(result_images) + 1}/{len(clip_segments)}...")

    # Convert the audio clip into its spectrogram image.
    init_image = image_converter.spectrogram_image_from_audio(clip_segment)

    # Round both sides up to the next multiple of 32 before diffusion.
    closest_width = int(np.ceil(init_image.width / 32) * 32)
    closest_height = int(np.ceil(init_image.height / 32) * 32)
    init_image_resized = init_image.resize(
        (closest_width, closest_height), Image.BICUBIC
    )

    result = pipeline(
        prompt=PROMPT,
        # Feed the resized image: the un-resized `init_image` was passed here
        # before, leaving `init_image_resized` computed but unused.
        image=init_image_resized,
        strength=DENOISING_STRENGTH,
        num_inference_steps=NUM_INFERENCE_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        # An empty negative prompt is passed as None.
        negative_prompt=NEGATIVE_PROMPT or None,
        num_images_per_prompt=1,
        generator=generator,
        callback=None,
        callback_steps=1,
    )
    # Take the first generated image and scale it back to the source
    # spectrogram's size before decoding it to audio.
    result_image = result[0][0]
    result_image = result_image.resize(init_image.size, Image.BICUBIC)

    result_segment = image_converter.audio_from_spectrogram_image(result_image)

    result_images.append(result_image)
    result_segments.append(result_segment)


In [None]:
# Combine clips with a crossfade based on overlap
if not result_segments:
    # Fail with a clear message instead of an IndexError on result_segments[0].
    raise ValueError(
        "result_segments is empty: no clips were generated, so there is nothing to combine"
    )

crossfade_ms = int(OVERLAP_DURATION * 1000)
combined_segment = result_segments[0]
# The loop variable is deliberately not called `segment`: that name holds the
# source audio loaded earlier, and rebinding it here would make the earlier
# cells slice a generated clip when they are re-run.
for next_segment in result_segments[1:]:
    combined_segment = combined_segment.append(next_segment, crossfade=crossfade_ms)

print(f"#### Final Audio ({combined_segment.duration_seconds}s)")
combined_segment


In [None]:
# Create the dated output folder first; nothing else creates it, so a fresh
# run would otherwise fail with FileNotFoundError on export.
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# `export` hands back the file object it wrote to; close it right away so the
# handle is not left open.
combined_segment.export(f"{OUTPUT_FOLDER}/{EXPERIMENT_ID}.wav", format="wav").close()

# Store the run settings next to the audio so the experiment can be reproduced.
# UTF-8 is set explicitly so non-ASCII prompts are written the same way on
# every platform.
with open(f"{OUTPUT_FOLDER}/{EXPERIMENT_ID}.txt", "w", encoding="utf-8") as f:
    f.write(f"Experiment date: {EXPERIMENT_DATE}\n")
    f.write(f"Experiment ID: {EXPERIMENT_ID}\n")
    f.write(f"Prompt: {PROMPT}\n")
    f.write(f"Negative prompt: {NEGATIVE_PROMPT}\n")
    f.write(f"Denoising strength: {DENOISING_STRENGTH}\n")
    f.write(f"Guidance scale: {GUIDANCE_SCALE}\n")
    f.write(f"Number of inference steps: {NUM_INFERENCE_STEPS}\n")
    f.write(f"Seed: {SEED}\n")
