In [None]:
# Install the third-party packages this notebook relies on.
# keras-cv is pinned to 0.4.0; tensorflow is upgraded to the newest release (-U); -q keeps pip quiet.
!pip install keras-cv==0.4.0 -q
!pip install -U tensorflow -q
# webcolors is pinned to 1.3; Random-Word is installed without a version pin.
!pip install webcolors==1.3
!pip install Random-Word

In [None]:
# Mount Google Drive at /content/gdrive; the dataset, checkpoint and weight
# paths used later in this notebook all live under that mount point.
from google.colab import drive
drive.mount('/content/gdrive')

In [4]:
from textwrap import wrap
import os
import pathlib

import keras_cv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
from keras_cv.models.stable_diffusion.clip_tokenizer import SimpleTokenizer
from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModel
from keras_cv.models.stable_diffusion.image_encoder import ImageEncoder
from keras_cv.models.stable_diffusion.noise_scheduler import NoiseScheduler
from keras_cv.models.stable_diffusion.text_encoder import TextEncoder
from tensorflow import keras

import librosa
import librosa.display
import librosa.feature
from IPython.display import Audio
from IPython.display import Image as IImage
from PIL import Image
import random
from webcolors import rgb_to_name
import matplotlib.colors as mcolors
from moviepy.editor import *

# Fine-tuning a Stable Diffusion model
You do not need to run this step to generate outputs, because the fine-tuned weights have already been saved to `finetuned_stable_diffusion.h5`. The dataset used for fine-tuning is the Describable Textures Dataset.


In [None]:
######## Fine-tuning stable diffusion code taken and modified from
######## https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/generative/ipynb/finetune_stable_diffusion.ipynb
######## Authored by Sayak Paul, Chansung Park

######## FINE-TUNING THE MODEL ###################################

# Read the caption table (data.csv) stored in the dataset folder.
data_path = pathlib.Path('/content/gdrive/MyDrive/MyTextures')
csv_file = os.path.join(data_path, "data.csv")
data_frame = pd.read_csv(csv_file)

# Prefix every entry of the image_path column with the dataset folder.
data_frame["image_path"] = data_frame["image_path"].map(
    lambda relative_path: os.path.join(data_path, relative_path)
)
data_frame.head()

##################################################################

# Padding token id and fixed prompt length; both are specific to the text encoder.
PADDING_TOKEN = 49407
MAX_PROMPT_LENGTH = 77

# One tokenizer instance shared by all captions.
tokenizer = SimpleTokenizer()


def process_text(caption):
    """Tokenize `caption` and right-pad the ids with PADDING_TOKEN.

    The result is a NumPy array of MAX_PROMPT_LENGTH ids when the caption
    encodes to at most that many tokens. Longer captions are not truncated.
    """
    token_ids = tokenizer.encode(caption)
    missing = MAX_PROMPT_LENGTH - len(token_ids)
    return np.array(token_ids + [PADDING_TOKEN] * missing)

# Gather the padded token ids of every caption into one
# (num_captions, MAX_PROMPT_LENGTH) array, one row per caption.
all_captions = list(data_frame["caption"].values)
tokenized_texts = np.empty((len(data_frame), MAX_PROMPT_LENGTH))

for row, caption in enumerate(all_captions):
    tokenized_texts[row] = process_text(caption)


####################################################################

# Defining the model
RESOLUTION = 256
AUTO = tf.data.AUTOTUNE

# Position ids 0 .. MAX_PROMPT_LENGTH - 1, wrapped in a leading axis of size 1.
position_indices = list(range(MAX_PROMPT_LENGTH))
POS_IDS = tf.convert_to_tensor([position_indices], dtype=tf.int32)

# Image pipeline: centre-crop to RESOLUTION, random flip, then rescale with
# scale 1/127.5 and offset -1.
augmentation_layers = [
    keras_cv.layers.CenterCrop(RESOLUTION, RESOLUTION),
    keras_cv.layers.RandomFlip(),
    tf.keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1),
]
augmenter = keras.Sequential(layers=augmentation_layers)

text_encoder = TextEncoder(MAX_PROMPT_LENGTH)


#####################################################################

# Formatting the dataset
def process_image(image_path, tokenized_text):
    """Load one image file and resize it to RESOLUTION x RESOLUTION.

    The file is read, decoded to 3 channels and resized; `tokenized_text`
    is returned untouched alongside the image.
    """
    raw_bytes = tf.io.read_file(image_path)
    decoded = tf.io.decode_png(raw_bytes, channels=3)
    resized = tf.image.resize(decoded, (RESOLUTION, RESOLUTION))
    return resized, tokenized_text


def apply_augmentation(image_batch, token_batch):
    """Run `augmenter` on the images; the tokens pass through unchanged."""
    augmented_images = augmenter(image_batch)
    return augmented_images, token_batch


def run_text_encoder(image_batch, token_batch):
    """Append the text-encoder output for `token_batch` to the batch tuple.

    Returns (image_batch, token_batch, encoded_text); the encoder is called
    with training=False.
    """
    encoded_text = text_encoder([token_batch, POS_IDS], training=False)
    return image_batch, token_batch, encoded_text


def prepare_dict(image_batch, token_batch, encoded_text_batch):
    """Pack the three batch components into a dict.

    Keys are "images", "tokens" and "encoded_text", in that order.
    """
    keys = ("images", "tokens", "encoded_text")
    values = (image_batch, token_batch, encoded_text_batch)
    return dict(zip(keys, values))


def prepare_dataset(image_paths, tokenized_texts, batch_size=1):
    """Build the tf.data input pipeline for fine-tuning.

    Pairs of (image path, token ids) are shuffled with a buffer of
    batch_size * 10, loaded with `process_image`, batched, then passed
    through augmentation, text encoding and dict packing before prefetching.
    """
    pairs = tf.data.Dataset.from_tensor_slices((image_paths, tokenized_texts))
    pipeline = pairs.shuffle(batch_size * 10)
    pipeline = pipeline.map(process_image, num_parallel_calls=AUTO)
    pipeline = pipeline.batch(batch_size)

    # The remaining stages all operate on whole batches, in this order.
    for stage in (apply_augmentation, run_text_encoder, prepare_dict):
        pipeline = pipeline.map(stage, num_parallel_calls=AUTO)

    return pipeline.prefetch(AUTO)

# Build the training dataset with batches of four.
image_paths = np.array(data_frame["image_path"])
training_dataset = prepare_dataset(image_paths, tokenized_texts, batch_size=4)

# Pull one batch and print the shape of every entry.
sample_batch = next(iter(training_dataset))

for k, v in sample_batch.items():
    print(k, v.shape)

plt.figure(figsize=(20, 10))

# Show the first three images of the batch with their decoded captions.
for i in range(3):
    ax = plt.subplot(1, 4, i + 1)
    # Images are shown as (x + 1) / 2, undoing the augmenter's rescaling.
    plt.imshow((sample_batch["images"][i] + 1) / 2)

    token_ids = sample_batch["tokens"][i].numpy().squeeze()
    text = (
        tokenizer.decode(token_ids)
        .replace("<|startoftext|>", "")
        .replace("<|endoftext|>", "")
    )
    text = "\n".join(wrap(text, 12))
    ax.set_title(text, fontsize=15)

    ax.axis("off")


##########################################################################

#Setting up a trainer for the fine-tuning training

class Trainer(tf.keras.Model):
    """Keras model wrapper that fine-tunes the diffusion model only.

    Each training step encodes the images with the frozen `vae`, samples
    latents from its output, adds scheduler noise at a random timestep and
    trains `diffusion_model` to predict that noise.

    Args:
        diffusion_model: model called with
            [noisy_latents, timestep_embedding, encoded_text]; its
            trainable variables are the only ones updated.
        vae: image encoder whose output is split in two along the last axis
            into mean and log-variance. It is marked non-trainable.
        noise_scheduler: object providing `train_timesteps` and `add_noise`.
        use_mixed_precision: when True, the loss is scaled and the gradients
            unscaled through the optimizer.
        max_grad_norm: norm at which each gradient tensor is clipped.
    """

    # Reference:
    # https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py

    def __init__(
        self,
        diffusion_model,
        vae,
        noise_scheduler,
        use_mixed_precision=False,
        max_grad_norm=1.0,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.diffusion_model = diffusion_model
        self.vae = vae
        self.noise_scheduler = noise_scheduler
        self.max_grad_norm = max_grad_norm

        self.use_mixed_precision = use_mixed_precision
        self.vae.trainable = False

    def train_step(self, inputs):
        """Run one optimisation step on a batch dict.

        `inputs` must provide "images" and "encoded_text". Returns a dict
        mapping each metric name to its current result.
        """
        images = inputs["images"]
        encoded_text = inputs["encoded_text"]
        batch_size = tf.shape(images)[0]

        with tf.GradientTape() as tape:
            # Project image into the latent space and sample from it.
            latents = self.sample_from_encoder_outputs(self.vae(images, training=False))
            # Know more about the magic number here:
            # https://keras.io/examples/generative/fine_tune_via_textual_inversion/
            latents = latents * 0.18215

            # Sample noise that we'll add to the latents.
            noise = tf.random.normal(tf.shape(latents))

            # Sample a random timestep for each image.
            timesteps = tnp.random.randint(
                0, self.noise_scheduler.train_timesteps, (batch_size,)
            )

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process).
            noisy_latents = self.noise_scheduler.add_noise(
                tf.cast(latents, noise.dtype), noise, timesteps
            )

            # Get the target for loss depending on the prediction type
            # just the sampled noise for now.
            target = noise  # noise_schedule.predict_epsilon == True

            # Predict the noise residual and compute loss.
            # `fn_output_signature` replaces the deprecated `dtype` argument
            # of tf.map_fn; the two are documented as equivalent.
            timestep_embedding = tf.map_fn(
                lambda t: self.get_timestep_embedding(t),
                timesteps,
                fn_output_signature=tf.float32,
            )
            # Each per-timestep embedding has shape (1, dim); drop that axis.
            timestep_embedding = tf.squeeze(timestep_embedding, 1)
            model_pred = self.diffusion_model(
                [noisy_latents, timestep_embedding, encoded_text], training=True
            )
            loss = self.compiled_loss(target, model_pred)
            if self.use_mixed_precision:
                loss = self.optimizer.get_scaled_loss(loss)

        # Update parameters of the diffusion model.
        trainable_vars = self.diffusion_model.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        if self.use_mixed_precision:
            gradients = self.optimizer.get_unscaled_gradients(gradients)
        # Clip every gradient tensor individually to max_grad_norm.
        gradients = [tf.clip_by_norm(g, self.max_grad_norm) for g in gradients]
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        return {m.name: m.result() for m in self.metrics}

    def get_timestep_embedding(self, timestep, dim=320, max_period=10000):
        """Return a (1, dim) sinusoidal embedding of a single timestep.

        The first half of the vector holds cosines and the second half sines
        of `timestep` multiplied by exponentially decaying frequencies.
        """
        half = dim // 2
        log_max_period = tf.math.log(tf.cast(max_period, tf.float32))
        freqs = tf.math.exp(
            -log_max_period * tf.range(0, half, dtype=tf.float32) / half
        )
        args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
        embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
        embedding = tf.reshape(embedding, [1, -1])
        return embedding

    def sample_from_encoder_outputs(self, outputs):
        """Draw a latent sample from encoder outputs.

        `outputs` is split in two along the last axis into mean and
        log-variance; the log-variance is clipped to [-30, 20] before the
        sample `mean + std * eps` is returned.
        """
        mean, logvar = tf.split(outputs, 2, axis=-1)
        logvar = tf.clip_by_value(logvar, -30.0, 20.0)
        std = tf.exp(0.5 * logvar)
        sample = tf.random.normal(tf.shape(mean), dtype=mean.dtype)
        return mean + std * sample

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        """Save only the weights of `diffusion_model` to `filepath`."""
        # Overriding this method will allow us to use the `ModelCheckpoint`
        # callback directly with this trainer class. In this case, it will
        # only checkpoint the `diffusion_model` since that's what we're training
        # during fine-tuning.
        self.diffusion_model.save_weights(
            filepath=filepath,
            overwrite=overwrite,
            save_format=save_format,
            options=options,
        )


##############################################################################

# Start training process

# Enable mixed-precision training if the underlying GPU has tensor cores.
USE_MP = True
if USE_MP:
    keras.mixed_precision.set_global_policy("mixed_float16")

image_encoder = ImageEncoder(RESOLUTION, RESOLUTION)
diffusion_ft_trainer = Trainer(
    diffusion_model=DiffusionModel(RESOLUTION, RESOLUTION, MAX_PROMPT_LENGTH),
    # Remove the top layer from the encoder, which cuts off the variance and only
    # returns the mean.
    vae=tf.keras.Model(
        image_encoder.input,
        image_encoder.layers[-2].output,
    ),
    noise_scheduler=NoiseScheduler(),
    use_mixed_precision=USE_MP,
)

# These hyperparameters come from this tutorial by Hugging Face:
# https://huggingface.co/docs/diffusers/training/text2image
lr = 1e-5
beta_1, beta_2 = 0.9, 0.999
# Plain scalar: the previous `(1e-2,)` was a one-element tuple caused by a
# stray trailing comma.
weight_decay = 1e-2
epsilon = 1e-08

optimizer = tf.keras.optimizers.experimental.AdamW(
    learning_rate=lr,
    weight_decay=weight_decay,
    beta_1=beta_1,
    beta_2=beta_2,
    epsilon=epsilon,
)
diffusion_ft_trainer.compile(optimizer=optimizer, loss="mse")

# Actually train. 70 epochs were chosen because the source stated that this was
# the number of epochs for which they got the best results.

epochs = 70
ckpt_path = "/content/gdrive/MyDrive/finetuned_stable_diffusion.h5"
# NOTE(review): `save_best_only` is left at its default, so `monitor`/`mode`
# presumably do not gate saving and the checkpoint is rewritten every epoch.
# Confirm that this is intended.
ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
    ckpt_path,
    save_weights_only=True,
    monitor="loss",
    mode="min",
)
diffusion_ft_trainer.fit(training_dataset, epochs=epochs, callbacks=[ckpt_callback])

# Generating outputs
From here we load the fine-tuned weights and use them to generate images inspired by an input audio clip. We then interpolate between prompts parametrised by musical characteristics, walking through the latent space to produce a short video set to the audio.

In [None]:
######## Loading fine-tuned weights
weights_path = '/content/gdrive/MyDrive/finetuned_stable_diffusion.h5'

# Square 512 x 512 output images.
img_height = 512
img_width = img_height
textures_model = keras_cv.models.StableDiffusion(
    img_height=img_height,
    img_width=img_width,
)

# Only the diffusion model's weights are replaced by the fine-tuned ones.
fine_tuned_diffusion = textures_model.diffusion_model
fine_tuned_diffusion.load_weights(weights_path)

###########################################################################

######## Loading audio file

# Surrounding whitespace (e.g. from a pasted path) is dropped.
audioname = input('Enter a filepath for a .wav audio file: ').strip()

# Check the real file extension and stop here if it is unsupported, rather
# than printing a warning and failing only after the expensive image
# generation. .mp3 is accepted because the later GIF/video naming code
# handles it as well.
audio_extension = os.path.splitext(audioname)[1].lower()
if audio_extension not in ('.wav', '.mp3'):
    raise ValueError(
        f'Not a .wav file, please try again (got {audioname!r}).'
    )

y, sr = librosa.load(audioname)

# Keep at most the first clip_length_seconds seconds of audio.
clip_length_seconds = 15
max_samples = clip_length_seconds * sr
if len(y) > max_samples:
  y = y[:max_samples]


######## Processing audio file

# Magnitude of the short-time Fourier transform, plus its transpose so that
# XT[i] is the spectrum of frame i.
X = np.abs(librosa.stft(y))
XT = X.T

# Frame-wise descriptors of the signal.
SC = librosa.feature.spectral_centroid(y=y, sr=sr)    # spectral centroid over time
onset_env = librosa.onset.onset_strength(y=y, sr=sr)  # spectral flux over time
zcr = librosa.feature.zero_crossing_rate(y)[0]        # zero crossing rate over time

volume = []
prompts = []

# Word pools loosely inspired by timbral dimensions; one word from each
# dimension ends up in every prompt.
high_centroid = ['bright', 'shiny', 'up', 'light', 'airy']
low_centroid = ['dark', 'dull', 'unsettling', 'down', 'deep']

high_flux = ['intense', 'swirl', 'thick', 'colourful', 'irregular']
low_flux = ['undulating', 'regular', 'round', 'sphere', 'bland']

high_zcr = ['rough', 'harsh', 'crackly', 'bumpy', 'intricate', 'pointed']
low_zcr = ['flat', 'smooth', 'soft', 'square', 'satin', 'linear']

######## Mappings ########

def map_SC_to_prob(SC):
    """Map a spectral centroid value to a probability in [0.1, 0.9].

    Centroids between 0 and 2000 are mapped linearly onto 0.1 .. 0.9.
    Values outside that range are clamped to the nearest bound, so the
    result always stays inside the declared probability range.

    Args:
        SC: spectral centroid, as a scalar or a NumPy array.

    Returns:
        The probability, with the same shape as `SC`.
    """
    # Spectral centroid range covered by the linear mapping.
    SC_min = 0
    SC_max = 2000

    # Probability range the centroid is mapped onto.
    prob_min = 0.1
    prob_max = 0.9

    # Linear interpolation between (SC_min, prob_min) and (SC_max, prob_max).
    prob = (SC - SC_min) / (SC_max - SC_min) * (prob_max - prob_min) + prob_min

    # Without the clamp, centroids above SC_max gave "probabilities" above
    # prob_max (and above 1 beyond 2250).
    return np.clip(prob, prob_min, prob_max)

# Making a dictionary for colour prompts

def get_rgb(name):
    """Return the RGB tuple matplotlib associates with the colour `name`."""
    return mcolors.to_rgb(name)


# CSS4 colour names, ordered by their RGB tuples.
colours_dict = mcolors.CSS4_COLORS
sorted_colours = list(colours_dict)
sorted_colours.sort(key=get_rgb)


######## Prompt creation ########
print('Creating prompts...')
prompt_step = 40  # build one prompt every `prompt_step` analysis frames
for i in range(0, len(SC[0]), prompt_step):

    # Spectral centroid prompt: the centroid sets the probability of drawing
    # from the "high" pool rather than the "low" one.
    prob = map_SC_to_prob(SC[0][i])
    rand_num = random.random()
    # random.choice can return every entry; the previous
    # randrange(len(pool) - 1) could never return the last word of a pool.
    hue = random.choice(high_centroid if prob >= rand_num else low_centroid)

    # Spectral flux prompt
    flux = random.choice(high_flux if onset_env[i] > 2 else low_flux)

    # Zero crossing rate prompt
    texture = random.choice(high_zcr if zcr[i] > 0.05 else low_zcr)

    # Colours prompt: position of the frame's mean magnitude between its
    # minimum and maximum, scaled to a colour index.
    spectrum = XT[i]
    volume = np.mean(spectrum)
    lowest = spectrum.min()
    highest = spectrum.max()
    map_colours = (volume - lowest) / (highest - lowest + 0.01) * (147 * 2)
    # The scaled value can reach 293, so clamp it to the last valid index
    # instead of raising IndexError.
    index = min(int(map_colours), len(sorted_colours) - 1)
    colour = sorted_colours[index]

    string = hue + ' ' + flux + ' ' + texture + ' ' + colour
    prompts.append(string)

#######################################################################

# Walk through the latent diffusion space:
# https://keras.io/examples/generative/random_walks_with_stable_diffusion/
# Authored by Ian Stenbit, fchollet, lukewood

# Helper function from https://keras.io/examples/generative/random_walks_with_stable_diffusion/
# Creates gifs

print('Setting up image generation...')
def export_as_gif(filename, images, frames_per_second=10, rubber_band=False):
    """Save a sequence of PIL images as a looping animated GIF.

    Args:
        filename: destination path of the GIF.
        images: non-empty sequence of images; the first one's `save` method
            is used to write the file. The sequence is not modified.
        frames_per_second: playback speed; each frame is shown for
            1000 // frames_per_second milliseconds.
        rubber_band: when True, the frames images[2:-1] are appended in
            reverse order so the animation plays forwards and then back.

    Raises:
        ValueError: if `images` is empty.
    """
    # Work on a copy so the caller's list is not extended as a side effect.
    frames = list(images)
    if not frames:
        raise ValueError('export_as_gif needs at least one image')

    if rubber_band:
        frames += frames[2:-1][::-1]

    frames[0].save(
        filename,
        save_all=True,
        append_images=frames[1:],
        duration=1000 // frames_per_second,
        loop=0,
    )

# Encode every prompt with the model's text encoder.
encoding = []
for prompt in prompts:
  encoding.append(tf.squeeze(textures_model.encode_text(prompt)))

# Set parameters for generation
interpolation_steps = 6
batch_size = 3
# The interpolations are later split into equal batches, so the step count
# must be a multiple of the batch size.
if interpolation_steps % batch_size != 0:
  raise ValueError('interpolation_steps must be a multiple of batch_size')
batches = interpolation_steps // batch_size
seed = 12345
# One fixed noise tensor shared by all frames. Its spatial size is derived
# from the model's image size instead of a hard-coded 512, so it stays in
# sync if img_height / img_width are changed.
noise = tf.random.normal((img_height // 8, img_width // 8, 4), seed=seed)
images = []


# For every pair of consecutive prompt encodings, generate
# `interpolation_steps` frames that blend from the first to the second.
print('Generating gifs...')
for i, (start, end) in enumerate(zip(encoding[:-1], encoding[1:])):
    interpolations = tf.linspace(start, end, interpolation_steps)
    batched_encodings = tf.split(interpolations, batches)

    # Generate each batch of interpolated encodings with the shared noise.
    for batch, batch_encodings in enumerate(batched_encodings):
        generated = textures_model.generate_image(
            batch_encodings,
            batch_size=batch_size,
            num_steps=25,
            diffusion_noise=noise,
        )
        images.extend(Image.fromarray(img) for img in generated)
# Name the GIF after the audio file by swapping its extension.
# str.strip('wav') removes any of the characters 'w', 'a', 'v' from BOTH ends
# of the path (e.g. 'wave.wav' -> 'e.'), and gif_name was left undefined for
# paths containing neither 'wav' nor 'mp3'. os.path.splitext avoids both.
gif_name = os.path.splitext(audioname)[0] + '.gif'
export_as_gif(gif_name, images, rubber_band=True)
  
######## Combine video with audio

print('Creating video...')

video_clip = VideoFileClip(gif_name)
audio_clip = AudioFileClip(audioname)
# Trim the audio to the analysed clip length, but never past its real end:
# recordings shorter than clip_length_seconds are used in full.
audio_end = min(clip_length_seconds, audio_clip.duration)
audio_clip = audio_clip.subclip(0, audio_end)
final_clip = video_clip.set_audio(audio_clip)

# Swap the audio extension for .mp4. str.strip('mp3') / strip('wav') removed
# matching characters from both ends of the path instead of the suffix, and
# also overwrote `audioname`.
output_filename = os.path.splitext(audioname)[0] + '.mp4'
final_clip.write_videofile(output_filename)