## Script to sample from experiments of training different control methods

Requires a GPU; best run in Google Colab. Clone the entire GitHub repository into your Google Drive and pull the model from HuggingFace as directed below. Make sure you have also set up the `control` conda environment.

In [1]:
# UNCOMMENT AND RUN THIS BLOCK IF USING GOOGLE COLAB

from google.colab import drive
drive.mount("/content/drive/")

## cd into desired directory 
%cd drive/MyDrive/audio-inpainting-project/riff-cnet

# install dependencies
!pip install -q einops
!pip install -q omegaconf
!pip install -q transformers
!pip install -q open-clip-torch
!pip install -q pydub
!pip install -q pytorch_lightning==1.7.7 
!pip install -q fastcore -U
!pip install -q Pillow==9.1.0

In [8]:
# Update the local clone with the latest commits from the remote.
# Only needed if the repository has changed since it was cloned.
!git pull

imports

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import cv2
from PIL import Image

from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams

from cldm.ddim_hacked import DDIMSampler
from cnet_riff_dataset import CnetRiffDataset
from utils.cnet_utils import get_model, sample_ddim

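Loop through every control method and, for a subset of the validation set, save the generated samples, the control source, and the target (as both spectrogram images and audio).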

In [42]:
# # PARAMETERS
control_methods = ["canny","fullspec","sobel","sobeldenoise"]
num_samples = 2
# only sample a subset: every `sample_stride`-th validation item
# (around 15 or so samples should be good)
sample_stride = 11

# checkpoint to load into pytorch for each control method
mdl_paths = dict(canny="lightning_logs/version_6/checkpoints/epoch=5-step=27365.ckpt",
                 fullspec="lightning_logs/version_7/checkpoints/epoch=5-step=27365.ckpt",
                 sobel="lightning_logs/version_4/checkpoints/epoch=8-step=41048.ckpt",
                 sobeldenoise="lightning_logs/version_8/checkpoints/epoch=9-step=43448.ckpt")

img_converter_to_audio = SpectrogramImageConverter(SpectrogramParams(sample_rate=44100, min_frequency=0, max_frequency=10000))


def _safe_filename(text):
    """Return `text` as a string with path separators replaced by underscores.

    Prompts are used verbatim as file-name prefixes; a separator inside a prompt
    would otherwise point into a directory that does not exist.
    """
    name = str(text)
    for sep in (os.sep, os.altsep):
        if sep:  # os.altsep is None on POSIX
            name = name.replace(sep, "_")
    return name


def _save_spectrogram_and_audio(spec_array, path_stem, converter):
    """Write `spec_array` to `<path_stem>.png` and its audio to `<path_stem>.wav`.

    The image is written with OpenCV and then re-opened with PIL, and that PIL
    image is what gets converted to (2-channel) audio.

    Raises:
        IOError: if OpenCV reports that the PNG could not be written.
    """
    png_path = path_stem + ".png"
    # cv2.imwrite signals failure through its return value rather than raising.
    if not cv2.imwrite(png_path, spec_array):
        raise IOError(f"cv2.imwrite failed to write {png_path}")
    spec_img = Image.open(png_path)
    audio = converter.audio_from_spectrogram_image(spec_img, apply_filters=True).set_channels(2)
    audio.export(path_stem + ".wav", format="wav")


# generate the samples for the desired dataset
os.makedirs("experiment_samples", exist_ok=True)

for control_method in control_methods:

    # get new model
    model = get_model(mdl_paths[control_method])
    ddim_sampler = DDIMSampler(model)
    save_dir = os.path.join("experiment_samples",control_method)
    os.makedirs(save_dir, exist_ok=True)

    # get dataset
    val_dataset = CnetRiffDataset("val-data/", promptfile="prompt-"+control_method+".json")

    print(f"{control_method} model loaded!")

    for i, item in enumerate(val_dataset):
        if i % sample_stride != 0:
            continue

        print(f"Sampling for prompt: {item['txt']}")
        results, _ = sample_ddim(item['hint'], item['txt'], model, ddim_sampler, num_samples=num_samples)

        # NOTE(review): files are keyed by prompt text only, so two sampled items
        # with the same prompt overwrite each other's outputs.
        file_prefix = os.path.join(save_dir, _safe_filename(item['txt']))

        # save each sample spectrogram and its audio
        for k, sample in enumerate(results):
            _save_spectrogram_and_audio(sample, f"{file_prefix}_samp_{k}", img_converter_to_audio)

        # save source for reference
        source = item['hint']
        if (np.max(source) <= 1) and (np.min(source) >= 0):
            print("switching control scale from [0.,1.] to [0,255]")
            source = np.uint8(source  * 255)
        _save_spectrogram_and_audio(source, f"{file_prefix}_source", img_converter_to_audio)

        # save target too
        # presumably item['jpg'] is scaled to [-1, 1]; this maps it to [0, 255] — confirm against the dataset
        target = (item['jpg'] + 1) / 2 * 255
        _save_spectrogram_and_audio(target, f"{file_prefix}_target", img_converter_to_audio)

    # drop references to this method's model, sampler and dataset before loading the next one
    del model
    del ddim_sampler
    del val_dataset

Get one example

In [None]:
# # PARAMETERS
num_samples = 2
control_method = "canny" # do canny, fullspec, sobel, sobeldenoise

# checkpoint for each control method; only the one selected above is loaded
mdl_paths = {
    "canny": "lightning_logs/version_6/checkpoints/epoch=5-step=27365.ckpt",
    "fullspec": "lightning_logs/version_7/checkpoints/epoch=5-step=27365.ckpt",
    "sobel": "lightning_logs/version_4/checkpoints/epoch=8-step=41048.ckpt",
    "sobeldenoise": "lightning_logs/version_8/checkpoints/epoch=9-step=43448.ckpt",
}

# load model into pytorch and build its sampler
model = get_model(mdl_paths[control_method])
ddim_sampler = DDIMSampler(model)

# converter used to turn spectrogram images back into audio
spec_params = SpectrogramParams(sample_rate=44100, min_frequency=0, max_frequency=10000)
img_converter_to_audio = SpectrogramImageConverter(spec_params)

# validation set for the selected control method
val_dataset = CnetRiffDataset("val-data/", promptfile=f"prompt-{control_method}.json")

# before looping, just try one item and view the first sampled spectrogram
item = val_dataset[0]
print(f"Sampling for prompt: {item['txt']}")
all_samples, _ = sample_ddim(item['hint'], item['txt'], model, ddim_sampler, num_samples=num_samples)
results = all_samples[0]

# write the sample to disk, then re-open it with PIL so it is displayed as the cell output
sample_path = "test_samp.png"
cv2.imwrite(sample_path, results)
sample_img = Image.open(sample_path)
sample_img

In [None]:
# listen to sample audio:
out_audio_recon = img_converter_to_audio.audio_from_spectrogram_image(sample_img, apply_filters=True).set_channels(2)
out_audio_recon

In [None]:
# look at target
target = (item['jpg'] + 1) / 2 * 255
cv2.imwrite("test_target.png", target)
target_img = Image.open("test_target.png") 
target_img

In [None]:
# listen to target audio 
out_audio_recon = img_converter_to_audio.audio_from_spectrogram_image(target_img, apply_filters=True).set_channels(2)
out_audio_recon