# Style Transfer: Optimization

## Neural Style Transfer for Images

First proposed by Gatys et al. \[1\], neural style transfer is an algorithm, initially targeting the image domain, that reproduces an input in a new artistic style. The input image is usually referred to as the **content** image, and the artistic style is inferred from a second image, often referred to as the **style** image.

![image.png](attachment:image.png)
\[image taken from https://pytorch.org/tutorials/advanced/neural_style_tutorial.html\]

The neural style transfer algorithm in \[1\] does not train a neural network. It is an optimization technique built on a pre-trained model, VGG \[2\], which was trained on ImageNet for image classification and is believed to encode increasingly complex features in its deeper layers. Convolutional neural networks are claimed to generalize well because they capture the invariances and characteristic features that define different classes while abstracting away background noise and other nuisances, so their intermediate layers can describe both the content and the style of an image. Hence, the main objective of neural style transfer is to match the style and content representations encoded by the pre-trained model at its intermediate layers.

In a nutshell, the target image is treated as the variable to be optimized, and the content and style losses are defined on the target-content and target-style image pairs. The content loss is the discrepancy between the activations $a_l(\cdot)$ of the two images at layer $l$, i.e., $\mathcal{L}_{\text{content}} = \frac{1}{2}\sum_{ij} \left(a_l(\mathbf{C})_{ij} - a_l(\mathbf{T})_{ij}\right)^2$, where $\mathbf{C}, \mathbf{T}$ represent the content and target images, respectively. The style loss, on the other hand, is defined through the correlations between activations across different channels, which are conveniently described by a Gram matrix. The Gram matrix is computed from the flattened activations, $a_l(\mathbf{S})$ or $a_l(\mathbf{T})$, of each feature channel: entry $(i, j)$ is the dot product of the flattened activations of channels $i$ and $j$. The style loss at layer $l$ is then the discrepancy between the Gram matrix of the style image, $G_l(\mathbf{S})$, and the Gram matrix of the target image at the same layer, $G_l(\mathbf{T})$: $\mathcal{L}_{\text{style}}^{(l)} = \frac{1}{N_l^2 M_l^2}\sum_{ij} \left(G_l(\mathbf{S})_{ij} - G_l(\mathbf{T})_{ij}\right)^2$, where $N_l, M_l$ denote the number of feature channels and the dimension of the flattened activations, respectively. In \[1\], the style loss is computed at several convolutional layers, and the total loss is a weighted sum of the content and style losses.
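As a minimal sketch of the two losses described above, the following NumPy code evaluates a content loss and a single-layer style loss on activations stored as arrays of shape `(N_l, M_l)`. The helper names (`gram`, `content_loss`, `style_loss`) and the random arrays standing in for VGG activations are assumptions made for illustration; they are not taken from \[1\].

```python
import numpy as np

def gram(a):
    """Gram matrix of activations `a` with shape (N_l, M_l): channel-by-channel dot products."""
    return a @ a.T

def content_loss(a_content, a_target):
    """Half the sum of squared differences between two activation maps."""
    return 0.5 * np.sum((a_content - a_target) ** 2)

def style_loss(a_style, a_target):
    """Squared Gram-matrix difference, normalised by N_l^2 * M_l^2."""
    n_l, m_l = a_style.shape
    diff = gram(a_style) - gram(a_target)
    return np.sum(diff ** 2) / (n_l ** 2 * m_l ** 2)

rng = np.random.default_rng(0)
a_c = rng.standard_normal((4, 50))   # hypothetical stand-in for a_l(C)
a_s = rng.standard_normal((4, 50))   # hypothetical stand-in for a_l(S)
a_t = a_c.copy()                     # target initialised from the content

print(content_loss(a_c, a_t))        # 0.0, because the target equals the content
print(style_loss(a_s, a_t) > 0)      # True, because the two Gram matrices differ
```

Note that the Gram matrix has shape `(N_l, N_l)` regardless of `M_l`, which is why the style and target inputs do not need to have the same spatial size.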


## Neural Style Transfer for Audio
One straightforward application of neural style transfer to audio is presented in \[3\], where Grinstein *et al.* obtained their best results by initializing the target with the content audio and slowly modifying it to match the style, using the Gram-matrix matching described above. In this case, the features are computed from the spectrograms of the target and style audio. Because the target audio is initialized with the content, the loss function consists of the style loss only.
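For readers unfamiliar with spectrograms, the sketch below computes a log-magnitude short-time Fourier transform with plain NumPy. It only illustrates the kind of time-frequency representation such features are built on; the function name `log_spectrogram`, the window, and the FFT and hop sizes are assumptions chosen for this example and are not the settings used in \[3\].

```python
import numpy as np

def log_spectrogram(x, n_fft=512, hop=128):
    """Log-magnitude STFT of a 1-D signal, shape (n_fft // 2 + 1, n_frames)."""
    window = np.hanning(n_fft)
    n_frames = 1 + (len(x) - n_fft) // hop
    # Slice the signal into overlapping, windowed frames
    frames = np.stack([x[i * hop : i * hop + n_fft] * window for i in range(n_frames)])
    return np.log1p(np.abs(np.fft.rfft(frames, axis=1))).T

sr = 8000                                   # illustrative sample rate
t = np.arange(sr) / sr
tone = np.sin(2 * np.pi * 440.0 * t)        # one second of a 440 Hz tone
S = log_spectrogram(tone)
print(S.shape)                              # (257, 59)

# The strongest frequency bin should sit next to 440 Hz
peak_bin = int(S.mean(axis=1).argmax())
print(peak_bin * sr / 512)
```

Each row of `S` is a frequency bin and each column a time frame, so a Gram matrix over the rows plays the same role as the channel correlations in the image case.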


## Jukebox
In this notebook, rather than operating on spectrograms, we use a pre-trained network for audio and compute the Gram matrices for neural style transfer from its encoded representations.

In particular, we use the autoencoder of Jukebox \[4\], a generative model that produces audibly plausible audio clips in many different genres and maintains coherence, with high fidelity, for up to multiple minutes. It first encodes raw audio using a multi-scale VQ-VAE \[5\] and then uses the resulting discrete codes as inputs to an autoregressive transformer for conditional or unconditional generation. We borrow the VQ-VAE of Jukebox, which is trained on hours of raw audio and hence can encode raw audio signals into discrete codes.
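The step that turns continuous encoder outputs into discrete codes can be sketched as a nearest-neighbour lookup in a codebook, which is the core idea of vector quantisation in \[5\]. The function `quantise`, the toy codebook and the toy latents below are hypothetical and purely illustrative; this is not Jukebox's implementation.

```python
import numpy as np

def quantise(latents, codebook):
    """Map each latent vector (a row of `latents`) to its nearest codebook entry."""
    # Squared Euclidean distance between every latent and every code: shape (T, K)
    d = ((latents[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)
    codes = d.argmin(axis=1)           # discrete codes, shape (T,)
    return codes, codebook[codes]      # indices and the quantised vectors

codebook = np.array([[0.0, 0.0], [1.0, 1.0], [-1.0, 1.0]])     # K = 3 toy codes
latents = np.array([[0.1, -0.2], [0.9, 1.2], [-0.8, 0.7]])     # T = 3 toy encoder outputs
codes, quantised = quantise(latents, codebook)
print(codes)        # [0 1 2]
```

The notebook later compares decoding from the quantised vectors with decoding directly from the encoder outputs, i.e. with this lookup skipped.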

The VQ-VAE of Jukebox is composed of three separate autoencoders at different levels of abstraction. Each level is a residual network of noncausal 1-D dilated convolutions, interleaved with downsampling and upsampling 1-D convolutions to match the different hop lengths. One key ingredient the authors credit for high-fidelity reconstruction is an additional spectral loss computed over multiple STFT parameter settings. The autoencoders compress 44.1 kHz audio by factors of 8, 32 and 128 at the three levels of abstraction, as illustrated below.

![image-2.png](attachment:image-2.png)
\[image taken from \[4\] \]
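To make the compression factors concrete, the short calculation below estimates how many latent time steps each level produces for a 10-second clip. It is a back-of-the-envelope sketch: the helper `latent_length` is hypothetical and simply divides by the hop length, ignoring any padding or length constraints of the actual encoders.

```python
SAMPLE_RATE = 44100
HOP_LENGTHS = (8, 32, 128)  # compression factors of the three VQ-VAE levels

def latent_length(num_samples, hop):
    """Approximate number of latent time steps for `num_samples` input samples."""
    return num_samples // hop

num_samples = 10 * SAMPLE_RATE  # a 10-second clip
for level, hop in enumerate(HOP_LENGTHS):
    steps = latent_length(num_samples, hop)
    print(f"level {level}: {steps} latent steps, about {SAMPLE_RATE / hop:.1f} per second")
```

The coarsest level therefore summarizes each second of audio with only a few hundred latent vectors, which is why its reconstructions lose the most detail.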



## Dependencies

We mainly use the PyTorch library in this notebook; running it on a GPU is highly recommended.

In [1]:
import jukebox
import torch
import numpy as np 
import librosa
from IPython.display import Audio
from jukebox.make_models import make_vqvae, MODELS, make_model
from jukebox.hparams import Hyperparams, setup_hparams
from jukebox.utils.dist_utils import setup_dist_from_mpi
from jukebox.utils.torch_utils import empty_cache

from scipy.io import wavfile
from jukebox.vqvae.encdec import assert_shape
from jukebox.utils.audio_utils import audio_preprocess, audio_postprocess, load_audio, save_wav

from torch import nn
import torch.optim as optim

import matplotlib.pyplot as plt

rank, local_rank, device = setup_dist_from_mpi()

Using cuda True


Now, we load the VQ-VAE of the Jukebox model "5b", whose name refers to the number of model parameters (5 billion).

In [2]:
model = "5b" # or "1b_lyrics" or "5b_lyrics"  
vqvae, *priors = MODELS[model]
vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length = 1048576)), device)

# jukebox vqvae model: https://github.com/openai/jukebox/blob/master/jukebox/vqvae/vqvae.py
# encoders: vqvae.encoders, decoders: vqvae.decoders - lists with length = number of levels
# vqvae.encode() returns a list whose elements are the latent variables (need to apply preprocessing)
# alternatively, the encoders can be called directly:
# xs = []
# for level in range(vqvae.levels):
#     encoder = vqvae.encoders[level]
#     x_out = encoder(x_in)
#     xs.append(x_out[-1])

Downloading from azure
Restored from /home/besbinar/.cache/jukebox/models/5b/vqvae.pth.tar
0: Loading vqvae in eval mode


## Reconstruction at different layers of VQ-VAE w/o quantization

We first investigate the effect of quantization in the VQ-VAE on reconstruction. To begin, we read and play a sample audio clip:

In [3]:
data_dir = './../data/exp2/'
output_dir =  './../results/exp2/'

# sample data
# sample_audio_sf, sample_audio = wavfile.read("sample_data/piano-C4.wav")
sample_audio, sample_audio_sf = librosa.load(f"{data_dir}/piano-C4.wav", sr=44100)
print(sample_audio_sf)
Audio(sample_audio, rate=sample_audio_sf)

44100


We then encode and decode the input audio with all three autoencoders and play the reconstructions.

In [4]:
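duration = min(10, len(sample_audio)//sample_audio_sf)
x = sample_audio[ :duration*sample_audio_sf, ...]
if len(x.shape) == 1:
    x = x[:, None]  # add a channel dimension: (time, channels)
# librosa.load already returns float samples in [-1, 1], so no int16 rescaling is needed
x = x.astype(np.float32)
x_tensor = torch.tensor(x[None,...]).to(device)  # add a batch dimension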


hps = Hyperparams()
hps.sr = 44100
hps.aug_blend = False  # that is the default


x_tensor = audio_preprocess(x_tensor, hps)
x_in = vqvae.preprocess(x_tensor)

xs = []
for level in range(vqvae.levels):
    encoder = vqvae.encoders[level]
    x_out = encoder(x_in)
    xs.append(x_out[-1])

zs, xs_quantised, _, _ = vqvae.bottleneck(xs)

x_outs = []
x_outs_nonquantised = []
for level in range(vqvae.levels):
    decoder = vqvae.decoders[level]
    
    x_out = decoder(xs_quantised[level:level+1], all_levels=False)
    x_outs.append(x_out)
    
    x_out = decoder(xs[level:level+1], all_levels=False)
    x_outs_nonquantised.append(x_out)

for level in range(vqvae.levels):
    x_outs[level] = vqvae.postprocess(x_outs[level])
    # x_out[level] = audio_postprocess(x_out[level], hps)  # because: def audio_postprocess(x, hps): return x
    x_outs_nonquantised[level] = vqvae.postprocess(x_outs_nonquantised[level])

RuntimeError: CUDA out of memory. Tried to allocate 88.00 MiB (GPU 0; 31.75 GiB total capacity; 103.38 MiB already allocated; 47.00 MiB free; 110.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

### decoded from quantized codes - as in the original work

In [None]:
# Lowest compression - 8x
Audio(x_outs[0].detach().squeeze().cpu().numpy(), rate=sample_audio_sf)

In [None]:
# Middle compression - 32x
Audio(x_outs[1].detach().squeeze().cpu().numpy(), rate=sample_audio_sf)

In [None]:
# Highest compression - 128x
Audio(x_outs[2].detach().squeeze().cpu().numpy(), rate=sample_audio_sf)

### decoded from non-quantized codes: almost the same output

In [None]:
# Lowest compression for non-quantized codes - 8x
Audio(x_outs_nonquantised[0].detach().squeeze().cpu().numpy(), rate=sample_audio_sf)

In [None]:
# Middle compression for non-quantized codes - 32x
Audio(x_outs_nonquantised[1].detach().squeeze().cpu().numpy(), rate=sample_audio_sf)

In [None]:
# Highest compression for non-quantized codes - 128x
Audio(x_outs_nonquantised[2].detach().squeeze().cpu().numpy(), rate=sample_audio_sf)
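
Listening is the main check here, but one way to put a number on how close two reconstructions are is a signal-to-noise ratio. The helper `snr_db` below is a hypothetical addition, not part of the original experiment, and the synthetic signals only demonstrate that it behaves as expected.

```python
import numpy as np

def snr_db(reference, estimate):
    """Signal-to-noise ratio in dB of `estimate` with respect to `reference`."""
    reference = np.asarray(reference, dtype=np.float64)
    noise = reference - np.asarray(estimate, dtype=np.float64)
    return 10.0 * np.log10(np.sum(reference ** 2) / np.sum(noise ** 2))

t = np.linspace(0.0, 1.0, 1000, endpoint=False)
clean = np.sin(2 * np.pi * 5.0 * t)
noisy = clean + 0.1 * np.cos(2 * np.pi * 50.0 * t)   # perturbation with 1% of the signal power
print(round(snr_db(clean, noisy), 1))                # 20.0
```

In this notebook it could be applied to a pair such as `x_outs[level]` and `x_outs_nonquantised[level]` after converting both to NumPy arrays with `.detach().squeeze().cpu().numpy()`.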

# Neural Audio Style Transfer

In [None]:
content_audio_name = "violin-C4"
style_audio_name = "piano-C4"

# sr=44100 keeps the rate used throughout this notebook (librosa resamples to 22050 Hz by default)
style_clip, style_sf = librosa.load(f"{data_dir}/{style_audio_name}.wav", sr=44100, duration=10)
if len(style_clip.shape) == 1:
    style_clip = style_clip[:, None]

content_clip, content_sf = librosa.load(f"{data_dir}/{content_audio_name}.wav", sr=44100, duration=10)
if len(content_clip.shape) == 1:
    content_clip = content_clip[:, None]

assert style_sf == content_sf
duration = min(10, len(content_clip)//content_sf)

hps = Hyperparams()
hps.sr = 44100
hps.aug_blend = False  # that is the default

# preprocessing
style_clip = style_clip[:duration*style_sf,...]
style_clip = style_clip.astype(np.float32)  # librosa output is already scaled to [-1, 1]
print(style_clip.shape)
style_tensor = torch.tensor(style_clip[None,...]).to(device)  # additional batch dimension
style_tensor = audio_preprocess(style_tensor, hps)

content_clip = content_clip[:duration*content_sf,...]
content_clip = content_clip.astype(np.float32)  # librosa output is already scaled to [-1, 1]
print(content_clip.shape)
content_tensor = torch.tensor(content_clip[None,...]).to(device)  # additional batch dimension
content_tensor = audio_preprocess(content_tensor, hps)

## Audio style transfer using the Gram matrix of a particular autoencoder

In [None]:
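def gram_matrix(x):
    N, C, L = x.size()  # batch, channels, time steps

    # one row per channel, so that G holds channel-by-channel correlations
    features = x.view(N*C, L)
    G = torch.mm(features, features.t())  # compute the gram product

    return G.div(N*C*L)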

class StyleLoss(nn.Module):
    def __init__(self, target_feature):
        super(StyleLoss, self).__init__()
        self.target = gram_matrix(target_feature).detach()

    def forward(self, input):
        G = gram_matrix(input)
        self.loss = nn.functional.mse_loss(G, self.target)
        return input
    
def get_input_optimizer(input_audio):
    optimizer = optim.LBFGS([input_audio])
    return optimizer

class audio_style_transfer(nn.Module):
    def __init__(
        self, 
        vqvae,
        jukebox_level,
        style_input,
    ):
        super().__init__()

        # set up Jukebox modules
        self.vqvae = vqvae
        self.output_level = jukebox_level

        # process the style input to get VQVAE embedding (before quantization)
        style_input = self.vqvae.preprocess(style_input)
        style_codes = []
        for level in range(self.vqvae.levels):
            encoder = self.vqvae.encoders[level]
            style_code = encoder(style_input)
            style_codes.append(style_code[-1])

        style_embed = style_codes[self.output_level:self.output_level+1] 

        # initialize the loss with the Gram matrix of the style input
        self.style_loss = StyleLoss(style_embed[0])
        print(f"Style embedding shape: {style_embed[0].shape}")

        
    def forward(self, content_input):
        
        content_input = self.vqvae.preprocess(content_input)
        
        content_codes = []
        for level in range(self.vqvae.levels):
            encoder = self.vqvae.encoders[level]
            content_code = encoder(content_input)
            content_codes.append(content_code[-1])

        content_embed = content_codes[self.output_level:self.output_level+1]
        _ = self.style_loss(content_embed[0])
        
        _, content_codes_q, _, _ = self.vqvae.bottleneck(content_codes)
        content_embed_q = content_codes_q[self.output_level:self.output_level+1]
        
        decoder = self.vqvae.decoders[self.output_level]
        decoded_content_1 = decoder(content_embed, all_levels=False)
        decoded_content_1 = self.vqvae.postprocess(decoded_content_1)
        
        decoded_content_2 = decoder(content_embed_q, all_levels=False)
        decoded_content_2 = self.vqvae.postprocess(decoded_content_2)

        return decoded_content_1, decoded_content_2

    
def run_style_loss(
    jukebox_module, 
    jukebox_level, 
    content_audio, 
    style_audio, 
    content_name, 
    style_name, 
    num_iters=100, 
    sf=44100
):
    
    print(f"Style clip shape: {style_audio.shape}")
    print(f"Content clip shape: {content_audio.shape}")
    
    model = audio_style_transfer(jukebox_module, jukebox_level, style_audio)

    input_audio = content_audio.clone()
    input_audio.requires_grad_(True)
    
    model.requires_grad_(False)  # only the input audio is optimized

    optimizer = get_input_optimizer(input_audio)

    
    print('Optimizing..')
    run = [0]  # counts closure evaluations; L-BFGS may call the closure several times per step
    while run[0] <= num_iters:

        def closure():
            with torch.no_grad():
                input_audio.clamp_(-1, 1)

            optimizer.zero_grad()
            model(input_audio)  # the forward pass stores the style loss in model.style_loss.loss

            loss = model.style_loss.loss
            loss.backward()

            if run[0] % 25 == 0:
                print(f"Iteration {run[0]}, loss value: {loss.item()}")
            run[0] += 1

            return loss

        optimizer.step(closure)

    # decode the optimized input; outputs assigned inside `closure` are local to it
    with torch.no_grad():
        input_audio.clamp_(-1, 1)
        output_audio_1, output_audio_2 = model(input_audio)
    
        
    wavfile.write(
        f"{output_dir}/{content_name}_{style_name}_grammatrix_onlylevel{model.output_level}_input.wav", 
        rate=sf, 
        data=input_audio.detach().squeeze().cpu().numpy()
    )
    wavfile.write(
        f"{output_dir}/{content_name}_{style_name}_grammatrix_onlylevel{model.output_level}_decoded_noquant.wav", 
        rate=sf, 
        data=output_audio_1.detach().squeeze().cpu().numpy()
    )
    wavfile.write(
        f"{output_dir}/{content_name}_{style_name}_grammatrix_onlylevel{model.output_level}_decoded.wav", 
        rate=sf, 
        data=output_audio_2.detach().squeeze().cpu().numpy()
    )
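
The closure-based optimization pattern used above can be tried in isolation on a toy problem. The sketch below optimizes a random tensor so that its Gram matrix matches that of random "style" features. The tensor shapes are arbitrary, and `line_search_fn="strong_wolfe"` is an assumption added here to make the toy run well behaved; the notebook itself uses the default `optim.LBFGS` settings.

```python
import torch

torch.manual_seed(0)

def gram_matrix(x):
    """Channel-by-channel Gram matrix of features with shape (C, L)."""
    return x @ x.t() / x.numel()

style = torch.randn(4, 64)                      # stand-in for style features
target = gram_matrix(style)
x = torch.randn(4, 64, requires_grad=True)      # the variable being optimized

optimizer = torch.optim.LBFGS([x], max_iter=20, line_search_fn="strong_wolfe")
initial_loss = torch.nn.functional.mse_loss(gram_matrix(x), target).item()

def closure():
    # L-BFGS may call this several times per step, so it must recompute the loss
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(gram_matrix(x), target)
    loss.backward()
    return loss

for _ in range(5):
    optimizer.step(closure)

final_loss = torch.nn.functional.mse_loss(gram_matrix(x), target).item()
print(initial_loss, final_loss)
```

Because a single `optimizer.step(closure)` can evaluate the closure many times, a counter incremented inside the closure, as in `run_style_loss`, counts loss evaluations rather than optimizer steps.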

### level 0

In [None]:
run_style_loss(vqvae, 0, content_tensor, style_tensor, content_audio_name, style_audio_name, num_iters=200)

In [None]:
Audio(f'{output_dir}/{content_audio_name}_{style_audio_name}_grammatrix_onlylevel0_input.wav')

In [None]:
Audio(f'{output_dir}/{content_audio_name}_{style_audio_name}_grammatrix_onlylevel0_decoded_noquant.wav')

In [None]:
Audio(f'{output_dir}/{content_audio_name}_{style_audio_name}_grammatrix_onlylevel0_decoded.wav')

### level 1

In [None]:
run_style_loss(vqvae, 1, content_tensor, style_tensor, content_audio_name, style_audio_name)

In [None]:
Audio(f'{output_dir}/{content_audio_name}_{style_audio_name}_grammatrix_onlylevel1_input.wav')

In [None]:
Audio(f'{output_dir}/{content_audio_name}_{style_audio_name}_grammatrix_onlylevel1_decoded_noquant.wav')

In [None]:
Audio(f'{output_dir}/{content_audio_name}_{style_audio_name}_grammatrix_onlylevel1_decoded.wav')

### level 2

In [None]:
run_style_loss(vqvae, 2, content_tensor, style_tensor, content_audio_name, style_audio_name)

In [None]:
Audio(f'{output_dir}/{content_audio_name}_{style_audio_name}_grammatrix_onlylevel2_input.wav')

In [None]:
Audio(f'{output_dir}/{content_audio_name}_{style_audio_name}_grammatrix_onlylevel2_decoded_noquant.wav')

In [None]:
Audio(f'{output_dir}/{content_audio_name}_{style_audio_name}_grammatrix_onlylevel2_decoded.wav')

## Audio style transfer using the Gram matrix of all autoencoders

In [None]:
def gram_matrix(x):
    N, C, L = x.size() 

    features = x.view(N*C, L) 
    G = torch.mm(features, features.t())  # compute the gram product

    return G.div(N*C*L)

class StyleLoss(nn.Module):
    def __init__(self, target_feature):
        super(StyleLoss, self).__init__()
        self.target = gram_matrix(target_feature).detach()

    def forward(self, input):
        G = gram_matrix(input)
        self.loss = nn.functional.mse_loss(G, self.target)
        return input
    
def get_input_optimizer(input_audio):
    optimizer = optim.LBFGS([input_audio])
    return optimizer

class audio_style_transfer_all_levels(nn.Module):
    def __init__(
        self, 
        vqvae,
        style_input,
    ):
        super().__init__()

        # set up Jukebox modules
        self.vqvae = vqvae
        self.losses = []

        # process the style input to get VQVAE embedding (before quantization)
        style_input = self.vqvae.preprocess(style_input)
        for level in range(self.vqvae.levels):
            encoder = self.vqvae.encoders[level]
            style_code = encoder(style_input)
            
            self.losses.append(StyleLoss(style_code[-1]))
            
        
    def forward(self, content_input):
        
        content_input = self.vqvae.preprocess(content_input)
        
        content_codes = []
        for level in range(self.vqvae.levels):
            encoder = self.vqvae.encoders[level]
            content_code = encoder(content_input)
            content_codes.append(content_code[-1])
            
            _ = self.losses[level](content_codes[level])
        
        _, content_codes_q, _, _ = self.vqvae.bottleneck(content_codes)
        
        outputs = {}
        for level in range(self.vqvae.levels):
            decoder = self.vqvae.decoders[level]

            x_out = decoder(content_codes_q[level:level+1], all_levels=False)
            outputs[f"quantization_level{level}"] = self.vqvae.postprocess(x_out)

            x_out_nq = decoder(content_codes[level:level+1], all_levels=False)
            outputs[f"noquantization_level{level}"] = self.vqvae.postprocess(x_out_nq)
        
        return outputs

    
def run_style_loss_alllevels(
    jukebox_module, 
    content_audio, 
    style_audio, 
    content_name,
    style_name,
    num_iters=100, 
    sf=44100, 
    output_levels=(2,)  # tuple instead of a mutable list as default argument
):
    
    print(f"Style clip shape: {style_audio.shape}")
    print(f"Content clip shape: {content_audio.shape}")
    
    model = audio_style_transfer_all_levels(jukebox_module, style_audio)

    input_audio = content_audio.clone()
    input_audio.requires_grad_(True)
    
    outputs = model(input_audio)
    model.requires_grad_(False)

    optimizer = get_input_optimizer(input_audio)

    
    print('Optimizing..')
    run = [0]
    while run[0] <= num_iters:

        def closure():
#             with torch.no_grad():
#                 input_audio.clamp_(-1, 1)

            optimizer.zero_grad()
            outputs = model(input_audio)
            
            loss = 0.
            for l in model.losses:
                loss += l.loss
            loss.backward()

            if run[0] % 10 == 0:
                print(f"Iteration {run[0]}, loss value: {loss.item()}")
            run[0] += 1

            return loss

        optimizer.step(closure)
    
        
    # decode the optimized input; `outputs` assigned inside `closure` is local to it
    with torch.no_grad():
        outputs = model(input_audio)

    wavfile.write(
        f"{output_dir}/{content_name}_{style_name}_grammatrix_alllevels_input.wav", 
        rate=sf, 
        data=input_audio.detach().squeeze().cpu().numpy()
    )
    
    for level in output_levels:
        wavfile.write(
            f"{output_dir}/{content_name}_{style_name}_grammatrix_alllevels_decoded_level{level}.wav", 
            rate=sf, 
            data=outputs[f"quantization_level{level}"].detach().squeeze().cpu().numpy()
        )
        wavfile.write(
            f"{output_dir}/{content_name}_{style_name}_grammatrix_alllevels_decoded_noquant_level{level}.wav", 
            rate=sf, 
            data=outputs[f"noquantization_level{level}"].detach().squeeze().cpu().numpy()
        )
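
The all-levels objective is simply the sum of one Gram-matrix loss per level. The NumPy sketch below shows that idea in isolation, with an optional per-level weight. The function `multi_level_style_loss`, the weights and the random features are hypothetical; the notebook code above uses an unweighted sum.

```python
import numpy as np

def gram(features):
    """Normalised Gram matrix of features with shape (C, L); the result is always (C, C)."""
    c, l = features.shape
    return features @ features.T / (c * l)

def multi_level_style_loss(style_feats, target_feats, weights=None):
    """Weighted sum over levels of the mean squared Gram-matrix difference."""
    weights = weights or [1.0] * len(style_feats)
    total = 0.0
    for w, s, t in zip(weights, style_feats, target_feats):
        total += w * np.mean((gram(s) - gram(t)) ** 2)
    return total

rng = np.random.default_rng(0)
# Three levels with the same channel count but different temporal lengths
style_feats = [rng.standard_normal((8, n)) for n in (512, 128, 32)]
target_feats = [rng.standard_normal((8, n)) for n in (400, 100, 25)]  # a shorter clip

print(multi_level_style_loss(style_feats, style_feats))       # 0.0
print(multi_level_style_loss(style_feats, target_feats) > 0)  # True
```

Since each Gram matrix has shape `(C, C)`, the losses of levels with very different temporal lengths, and of clips with different durations, can be added directly.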
        

In [None]:
run_style_loss_alllevels(vqvae, content_tensor, style_tensor, content_audio_name, style_audio_name, num_iters=40)

In [None]:
Audio(f'{output_dir}/{content_audio_name}_{style_audio_name}_grammatrix_alllevels_input.wav')

In [None]:
Audio(f'{output_dir}/{content_audio_name}_{style_audio_name}_grammatrix_alllevels_decoded_level2.wav')

In [None]:
Audio(f'{output_dir}/{content_audio_name}_{style_audio_name}_grammatrix_alllevels_decoded_noquant_level2.wav')

## Inspecting spectrograms

In [None]:
def audio_spectrum(path, N_FFT=2048):
    # log-magnitude STFT of (at most) the first 10 seconds of the file
    x, fs = librosa.load(path, duration=10)
    S = librosa.stft(x, n_fft=N_FFT)
    S = np.log1p(np.abs(S))  
    
    return S


def plot_spectrums(content_audio_name, style_audio_name, model_name="grammatrix_alllevels"):
    
    # inputs are read from data_dir, style-transfer results from output_dir
    content_audio = audio_spectrum(f"{data_dir}/{content_audio_name}.wav")
    style_audio = audio_spectrum(f"{data_dir}/{style_audio_name}.wav")

    prefix = f"{output_dir}/{content_audio_name}_{style_audio_name}_{model_name}"
    output_audio1 = audio_spectrum(f"{prefix}_input.wav")
    output_audio2 = audio_spectrum(f"{prefix}_decoded_level2.wav")

    plt.figure(figsize=(16,16))
    plt.subplot(2,2,1)
    plt.title('Content')
    plt.imshow(content_audio[:500,:500])
    plt.subplot(2,2,2)
    plt.title('Style')
    plt.imshow(style_audio[:500,:500])
    plt.subplot(2,2,3)
    plt.title('Result 1 - Modified Input')
    plt.imshow(output_audio1[:500,:500])
    plt.subplot(2,2,4)
    plt.title('Result 2 - Decoded Code w/ Quantization')
    plt.imshow(output_audio2[:500,:500])
    plt.show()
                                   
plot_spectrums(content_audio_name, style_audio_name)

## References
\[1\] Gatys, Leon A., Alexander S. Ecker, and Matthias Bethge. "A neural algorithm of artistic style." arXiv preprint arXiv:1508.06576 (2015).

\[2\] Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).

\[3\] Grinstein, Eric, Ngoc QK Duong, Alexey Ozerov, and Patrick Pérez. "Audio style transfer." In 2018 IEEE international conference on acoustics, speech and signal processing (ICASSP), pp. 586-590. IEEE, 2018.

\[4\] Dhariwal, Prafulla, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford, and Ilya Sutskever. "Jukebox: A generative model for music." arXiv preprint arXiv:2005.00341 (2020).

\[5\] Van Den Oord, Aaron, and Oriol Vinyals. "Neural discrete representation learning." Advances in neural information processing systems 30 (2017).