In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
# sys.path.append("jukebox_clone")
# sys.path.append("lucent_clone")
import jukebox
import torch
import librosa
import soundfile as sf
import scipy
import os
import numpy as np
import nussl
import matplotlib.pyplot as plt

from IPython.display import Audio
from jukebox.make_models import make_vqvae, make_prior, MODELS, make_model
from jukebox.hparams import Hyperparams, setup_hparams
from jukebox.utils.dist_utils import setup_dist_from_mpi
from jukebox.utils.torch_utils import empty_cache
from jukebox.utils.jukebox_utils import get_forward_calls_encoder, split_model
# Set up the distributed context through the project's MPI helper and keep the
# returned rank, local rank and device; `device` is the one passed to
# `make_vqvae` and used to place the audio tensors later on.
rank, local_rank, device = setup_dist_from_mpi()

Using cuda True


In [2]:
# Load the model
# The MODELS entry is unpacked into its first element (used for the VQ-VAE) and
# the remaining elements (`priors`), which are not instantiated in this notebook.
model = "5b" # or "1b_lyrics"     
vqvae, *priors = MODELS[model]
# Hyper-parameters for the VQ-VAE with sample_length overridden (1048576 == 2**20).
hparams = setup_hparams(vqvae, dict(sample_length = 1048576))
# NOTE: `vqvae` is rebound here from the MODELS entry to the constructed model,
# so this cell must be re-run from its first line, not from the middle.
vqvae = make_vqvae(hparams, device)
# Keep the returned object from .eval() as the model used by the cells below.
vqvae = vqvae.eval()

Downloading from azure
Restored from /home/ozaydin/.cache/jukebox/models/5b/vqvae.pth.tar
0: Loading vqvae in eval mode


In [3]:
# Clip length (seconds) and sampling rate (Hz) shared by the rest of the notebook.
num_seconds = 20
sample_rate = 44100
# Time stamp (in seconds) of every sample. It is built from integer sample
# indices so that consecutive entries are exactly 1 / sample_rate apart.
# np.linspace(0, num_seconds, N) includes the end point by default, which puts
# the last sample at t == num_seconds and stretches the spacing to
# num_seconds / (N - 1).
t = np.arange(sample_rate * num_seconds) / sample_rate

## Access Hidden Layers of JukeBox

In [4]:
# Choose the level of JukeBox and hook it
from lucent.optvis.render import hook_model
from functools import partial
from jukebox.utils.jukebox_utils import get_forward_calls_encoder, get_forward_calls_decoder, split_model, compose_funclist

# Index of the VQ-VAE level to work with; the encoder, the bottleneck block and
# the decoder below are all taken from this same level.
level = 2
encoder = vqvae.encoders[level]
bottleneck = vqvae.bottleneck.level_blocks[level]
decoder = vqvae.decoders[level]

# Hook both networks and collect their forward calls together with layer names.
# NOTE(review): what the hooks / call lists contain is defined by the project's
# lucent and jukebox_utils helpers, which are not shown here. In the cells
# visible in this notebook only `encoder_layer_names` and `decoder_layer_names`
# are read afterwards (to pick the layers to split / slice at).
encoder_hook, encoder_layers = hook_model(encoder, include_class_name=False)
encoder_calls, encoder_layer_names = get_forward_calls_encoder(encoder, prefix="")
decoder_hook, decoder_layers = hook_model(decoder, include_class_name=False)
decoder_calls, decoder_layer_names = get_forward_calls_decoder(decoder, prefix="")

In [5]:
# Split Encoder and Decoder
# The split point of each network is chosen by index into its list of layer
# names; -3 selects the third entry from the end.
enc_layer_index = -3
l_enc = encoder_layer_names[enc_layer_index]
dec_layer_index = -3
l_dec = decoder_layer_names[dec_layer_index]

# split_model returns two callables per network. They are later chained as
# pre_* followed by post_*, so presumably they are the parts before and after
# the named layer -- TODO confirm against jukebox_utils.split_model.
pre_enc, post_enc = split_model(encoder, l_enc, partial(get_forward_calls_encoder, prefix=""))
pre_dec, post_dec = split_model(decoder, l_dec, partial(get_forward_calls_decoder, prefix=""))


In [6]:
from jukebox.utils.jukebox_utils import slice_model
# Slice Encoder and Decoder
# Layers at which each network is cut, given as indices into the layer-name
# lists and then translated to names.
enc_layer_indices = [1,3,5,8]
enc_layer_names = [encoder_layer_names[layer_index] for layer_index in enc_layer_indices]
dec_layer_indices = [2,5,]
dec_layer_names = [decoder_layer_names[layer_index] for layer_index in dec_layer_indices]

# The resulting slices are iterated in order and each one is called on the
# output of the previous one (see sliced_features / sliced_AdaIN / sliced_WCT).
# NOTE(review): how many slices slice_model returns for N layer names is not
# visible here -- confirm against jukebox_utils.slice_model.
enc_slices = slice_model(encoder, enc_layer_names, partial(get_forward_calls_encoder, prefix=""))
dec_slices = slice_model(decoder, dec_layer_names, partial(get_forward_calls_decoder, prefix=""))

In [7]:
# Merge Encoder and Decoder
# `pre` chains both encoder parts, so it maps a waveform to the features the
# encoder produces (the bottleneck step is deliberately left commented out).
# Assumes compose_funclist applies the callables in list order -- TODO confirm.
pre = compose_funclist([
    pre_enc,
    post_enc,
    #lambda x: bottleneck(x)[1],
])

# `post` first sends the features through the bottleneck block, keeping only
# element [1] of what it returns, and then through both decoder parts.
post = compose_funclist([
    lambda x: bottleneck(x)[1],
    pre_dec,
    post_dec,
])


In [8]:
# Functions to play and save audio
from jukebox.utils.jukebox_utils import compose_funclist


def _first_channel_numpy(x):
    """Detach a tensor, move it to the CPU and return element [0, 0] as NumPy."""
    return x.detach().cpu().numpy()[0,0]


def _peak_normalize(x):
    """Scale a waveform so that its largest absolute sample becomes 1.

    A waveform whose peak is not positive (e.g. pure silence) is returned
    unchanged instead of being divided by zero.
    """
    peak = max(x.max(), -x.min())
    return x / peak if peak > 0 else x


# Tensor -> embedded audio player. The player is created with the notebook's
# `sample_rate` so that playback stays consistent with the rate the clips were
# loaded at, instead of a second hard-coded 44100.
play = compose_funclist([
    _first_channel_numpy,
    _peak_normalize,
    lambda x: nussl.AudioSignal(audio_data_array=x, sample_rate=sample_rate),
    lambda x: x.embed_audio()
])

# Tensor -> peak-normalised NumPy waveform (same first two steps as `play`),
# e.g. for writing the result to disk.
post_process = compose_funclist([
    _first_channel_numpy,
    _peak_normalize,
])

# Style Transfer Functions

Given source and target audio signals $x, y \in \mathbb{R}^{1 \times T}$ where $T$ is temporal length of the signal, our aim is to generate a translation $x \rightarrow y$ that preserves the 'content' of $x$ and adopts the 'style' from $y$.

To that end, we first extract deep features of $x$ and $y$ from a state-of-the-art autoencoder trained on musical signals, JukeBox. We split JukeBox into two parts $F_{pre}: \mathbb{R}^{1 \times T} \rightarrow\mathbb{R}^{C \times T'}$ and $F_{post}: \mathbb{R}^{C \times T'} \rightarrow \mathbb{R}^{1 \times T}$ where $C$ denotes the number of feature channels and $T'$ is the temporal dimension of the deep features. The features $h_x, h_y \in \mathbb{R}^{C \times T'}$ are formally computed as 

$$
h_s = F_{pre}(s)
$$

where $s \in \{x,y\}$ is the input signal to JukeBox.

Extracted features $h_s$ can be directly decoded to reconstruct the input audio signal $\hat s \sim F_{post}(F_{pre}(s))$. In our work, we propose transforming the deep features in order to perform style transfer. We denote our transformation function as $F_{trans}: (\mathbb{R}^{C \times T'}, \mathbb{R}^{C \times T'}) \rightarrow \mathbb{R}^{C \times T'} $ and the transformed features as $h_{x \rightarrow y}$.

$$
h_{x \rightarrow y} = F_{trans}(h_x, h_y)\\
$$

Finally, we obtain the translation as

$$
x \rightarrow y = F_{post}(h_{x \rightarrow y})
$$

In the next sections, we describe different transformation functions we used.

## Adaptive Instance Normalization (AdaIN)

We first use AdaIN to implement $F_{trans}$ as $F_{AdaIN}$.

$$
F_{AdaIN}(h_x, h_y) = \sigma_{h_y} \frac{h_x - \mu_{h_x}}{\sigma_{h_x}} + \mu_{h_y}
$$

where $\mu_{h_s}, \sigma_{h_s} \in \mathbb{R}^{C \times 1}$ are the temporal mean and standard deviation of $h_s$, respectively.

The translation using AdaIN is defined as
$$
\begin{align}
x \xrightarrow{AdaIN} y &= F_{post}(F_{AdaIN}(F_{pre}(x), F_{pre}(y)))\\
                        &= F_{post}(F_{AdaIN}(h_x, h_y))
\end{align}
$$

In [9]:
# AdaIN functions
def compute_stats(x):
    """Return the mean and standard deviation of ``x`` over its last axis.

    Both statistics keep a trailing singleton axis, so they broadcast against
    ``x`` directly.

    Returns:
        A ``(mean, std)`` tuple.
    """
    last_axis = -1
    return (
        x.mean(dim=last_axis, keepdim=True),
        x.std(dim=last_axis, keepdim=True),
    )

def normalize(x, eps=1e-10):
    """Standardise ``x`` along its last axis.

    The statistics come from ``compute_stats``; ``eps`` is added to the
    standard deviation before dividing so that constant rows do not divide by
    zero.
    """
    mu, sigma = compute_stats(x)
    centered = x - mu
    return centered / (sigma + eps)

def modulate(x, mean, std):
    """Scale ``x`` by ``std`` and then shift it by ``mean``."""
    scaled = x * std
    return scaled + mean

def AdaIN(x, mean, std):
    """Adaptive instance normalisation of ``x``.

    ``x`` is first standardised with ``normalize`` and the result is then
    rescaled by ``std`` and shifted by ``mean`` with ``modulate``.
    """
    return modulate(normalize(x), mean, std)

# AdaIN using torch
def AdaIN_torch(x, mean, std):
    """AdaIN expressed through ``torch.nn.functional.batch_norm``.

    ``batch_norm`` is run in training mode without running statistics, so it
    normalises ``x`` with statistics computed from ``x`` itself; ``std`` and
    ``mean`` are handed over as its ``weight`` and ``bias``.
    """
    return torch.nn.functional.batch_norm(
        x,
        running_mean=None,
        running_var=None,
        weight=std,
        bias=mean,
        training=True,
    )

## Whitening and Coloring Transform (WCT)
We then use WCT to implement $F_{trans}$ as $F_{WCT}$. We first center the features along the temporal dimension, $\hat{h}_s = h_s - \mu_{h_s}$. Then we apply eigen-decomposition to the covariance matrix of the centered features

$$
\frac{1}{T'} \hat{h}_s \hat{h}_s^T = E_s A_s E_s^T
$$

We perform the whitening on the source features $h_x$ to obtain an uncorrelated feature map.

$$
h_x^{white} = E_x A_x^{-1/2} E_x^T \hat{h}_x
$$

We then colorize $h_x^{white}$ with $\hat{h}_y$

$$
h_x^{colored_y} = E_y A_y^{1/2} E_y^T h_x^{white}
$$

Finally, we add the mean of $h_y$ to the colored features

$$
h_{x \rightarrow y} = h_x^{colored_y} + \mu_{h_y}
$$

The transformation function for WCT then becomes

$$
F_{WCT}(h_x, h_y) = E_y A_y^{1/2} E_y^T (E_x A_x^{-1/2} E_x^T (h_x - \mu_{h_x})) + \mu_{h_y}
$$

Notice that similar to AdaIN, we normalized the features along the temporal dimension. However, WCT applies whitening and coloring transforms instead of the standardization used in AdaIN. In that sense, WCT is a stronger transformation compared to AdaIN. Applying stronger transforms is advantageous as it brings the feature statistics of the target and translation signals closer to each other. On the other hand, stronger test time transformations create discrepancy between testing and training, which might result in less realistic outputs.

In [10]:
# WCT functions
def get_gram_matrix(h, remove_mean=True):
    """Channel-by-channel Gram matrix of a single ``(1, c, T)`` feature map.

    Args:
        h: Tensor with three axes whose first axis has size 1.
        remove_mean: When true, each channel has its mean over the last axis
            subtracted first.

    Returns:
        A ``(c, c)`` tensor: the product of the features with their transpose,
        divided by ``T``.

    Raises:
        AssertionError: If the first axis of ``h`` is not of size 1.
    """
    batch, _, length = h.shape
    assert batch == 1
    feats = h[0]
    if remove_mean:
        feats = feats - feats.mean(dim=1, keepdim=True)
    return torch.matmul(feats, feats.T) / length

def whitening_transform(h):
    """Whiten a single ``(1, c, T)`` feature map along its last axis.

    The features are cast to double precision, their Gram matrix (see
    ``get_gram_matrix``) is decomposed with an SVD, and the centered features
    are multiplied by ``E diag(D ** -0.5) E^T``. Directions whose singular
    value is below ``1e-10`` are dropped to avoid dividing by (almost) zero.

    Args:
        h: Tensor with three axes whose first axis has size 1.

    Returns:
        A ``torch.double`` tensor with the same shape as ``h``.

    Raises:
        AssertionError: If the first axis of ``h`` is not of size 1.
    """
    eps = 1e-10
    b, c, T = h.shape
    assert b == 1

    h = h.to(torch.double)
    # h is already double precision, so the Gram matrix is as well.
    G = get_gram_matrix(h)

    # Only the left singular vectors and the singular values are needed.
    E, D, _ = torch.linalg.svd(G, full_matrices=True)

    # Keep every direction that precedes the first singular value below eps.
    # One vectorised comparison replaces the element-by-element Python loop.
    below = torch.nonzero(D < eps)
    keep = int(below[0, 0]) if below.numel() > 0 else D.numel()

    E_kept = E[:, :keep]
    A = E_kept @ torch.diag(D[:keep] ** (-0.5)) @ E_kept.T

    # Whitening acts on the centered features.
    f = h.reshape(c, T)
    f = f - f.mean(dim=1, keepdim=True)
    wh = A @ f

    return wh.reshape(b, c, T)

def coloring_transform(wh, h_style):
    """Colour whitened features with the statistics of ``h_style``.

    The Gram matrix of ``h_style`` (double precision, see ``get_gram_matrix``)
    is decomposed with an SVD, ``wh`` is multiplied by
    ``E diag(D ** 0.5) E^T`` and the per-channel mean of ``h_style`` is added.

    Args:
        wh: Whitened features with three axes.
        h_style: Style features with three axes whose first axis has size 1.

    Returns:
        A tensor with the same shape as ``wh``.

    Raises:
        AssertionError: If the first axis of ``h_style`` is not of size 1.
    """
    style_batch, style_channels, style_length = h_style.shape
    assert style_batch == 1

    style = h_style.to(torch.double)
    style_gram = get_gram_matrix(style).to(torch.double)

    U, S, _ = torch.linalg.svd(style_gram, full_matrices=True)
    coloring = U @ torch.diag(S ** (0.5)) @ U.T

    b2, c2, T2 = wh.shape
    style_mean = style.reshape(style_channels, style_length).mean(dim=1, keepdim=True)
    colored = coloring @ wh.reshape(c2, T2) + style_mean

    return colored.reshape(b2, c2, T2)

## Applying AdaIN and WCT to the hidden layers
In order to have a stronger control over the style, we propose to use our transformation functions in multiple feature layers of JukeBox.

In [11]:
#### 
# Extract features from a sliced model
def sliced_features(x, slices):
    """Run ``x`` through ``slices`` and record what goes into each of them.

    Returns:
        ``(stats, feats)``: ``stats`` holds ``compute_stats`` of the input of
        every slice; ``feats`` holds the input of every slice followed by the
        output of the last one, so it is one element longer than ``stats``.
    """
    stats, feats = [], []
    current = x
    for block in slices:
        feats.append(current)
        stats.append(compute_stats(current))
        current = block(current)
    feats.append(current)
    return stats, feats

# AdaIN for a sliced model
def sliced_AdaIN(x, y_stats, slices):
    """Apply ``AdaIN`` with the matching entry of ``y_stats`` before every slice.

    Returns:
        The AdaIN-modulated input of every slice, followed by the output of
        the last slice.

    Raises:
        AssertionError: If ``y_stats`` and ``slices`` differ in length.
    """
    assert len(y_stats) == len(slices)
    feats = []
    current = x
    for (mean, std), block in zip(y_stats, slices):
        modulated = AdaIN(current, mean, std)
        feats.append(modulated)
        current = block(modulated)
    feats.append(current)
    return feats

# WCT for a sliced model
def sliced_WCT(x, y_feats, slices):
    """Whiten and colour the features with ``y_feats`` before every slice.

    ``y_feats`` must contain one more entry than ``slices``; its entries are
    paired with the slices in order, so the extra last entry is not used.

    Returns:
        The transformed (float32) input of every slice, followed by the output
        of the last slice.

    Raises:
        AssertionError: If ``len(y_feats) != len(slices) + 1``.
    """
    assert len(y_feats) == len(slices) + 1, 'y: {}, slices: {}'.format(len(y_feats), len(slices))
    feats = []
    current = x
    for style_feat, block in zip(y_feats, slices):
        whitened = whitening_transform(current)
        stylised = coloring_transform(whitened, style_feat).to(torch.float32)
        feats.append(stylised)
        current = block(stylised)
    feats.append(current)
    return feats

## Loading the files

In [12]:
# Load wave files
wav_1, sr_1 = librosa.load('piano-C4.wav', sr=sample_rate)
wav_2, sr_2 = librosa.load('violin-C4.wav', sr=sample_rate)

# Keep at most `num_seconds` of audio from each clip.
clip_samples = num_seconds * sample_rate
wav_1, wav_2 = wav_1[:clip_samples], wav_2[:clip_samples]

# Peak-normalise each clip and lift it to a (1, 1, T) tensor on `device`.
peak_1 = max(wav_1.max(), -wav_1.min())
peak_2 = max(wav_2.max(), -wav_2.min())
x1 = torch.Tensor(wav_1).to(device)[None, None, :] / peak_1
x2 = torch.Tensor(wav_2).to(device)[None, None, :] / peak_2

## AdaIN and WCT in the end of the encoder

In [13]:
# Extract features from JukeBox and modulate them with AdaIN
# Naming: `x12_*` is clip 1 transformed with the statistics of clip 2, and
# `x21_*` is the reverse direction.
with torch.no_grad():
    # Features of both clips and their plain (untransformed) reconstructions.
    x1z = pre(x1)
    x2z = pre(x2)
    x1r = post(x1z)
    x2r = post(x2z)

    #### AdaIN
    # Per-channel mean / std over the last axis of each feature map.
    x1z_stats = compute_stats(x1z)
    x2z_stats = compute_stats(x2z)

    # Swap the statistics between the two clips.
    x1z_adain = AdaIN(x1z, *x2z_stats)
    x2z_adain = AdaIN(x2z, *x1z_stats)

    x12_adain = post(x1z_adain)
    x21_adain = post(x2z_adain)
    
    #### WCT
    # The WCT helpers work in double precision, so the result is cast back to
    # float32 before it is passed to `post`.
    x1z_w = whitening_transform(x1z)
    x1z_wct = coloring_transform(x1z_w, x2z).to(torch.float32)
    
    x2z_w = whitening_transform(x2z)
    x2z_wct = coloring_transform(x2z_w, x1z).to(torch.float32)
    
    x12_wct = post(x1z_wct)
    x21_wct = post(x2z_wct)
# NOTE(review): the disabled exports below reference `x_1_r_modulated` /
# `x_2_r_modulated`, which no visible cell defines -- update the names before
# re-enabling them.
#sf.write('adain_piano2guitar_moonlight.wav', post_process(x_1_r_modulated), sample_rate)
#sf.write('adain_guitar2piano_moonlight.wav', post_process(x_2_r_modulated), sample_rate)

In [14]:
# Listen to clip 1 modulated with the AdaIN statistics of clip 2.
play(x12_adain)

ffmpeg version 4.3.2 Copyright (c) 2000-2021 the FFmpeg developers
  built with gcc 10.3.0 (GCC)
  configuration: --prefix=/home/conda/feedstock_root/build_artifacts/ffmpeg_1645955405450/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_plac --cc=/home/conda/feedstock_root/build_artifacts/ffmpeg_1645955405450/_build_env/bin/x86_64-conda-linux-gnu-cc --disable-doc --disable-openssl --enable-avresample --enable-gnutls --enable-gpl --enable-hardcoded-tables --enable-libfreetype --enable-libopenh264 --enable-libx264 --enable-pic --enable-pthreads --enable-shared --disable-static --enable-version3 --enable-zlib --enable-libmp3lame --pkg-config=/home/conda/feedstock_root/build_artifacts/ffmpeg_1645955405450/_build_env/bin/pkg-config
  libavutil      56. 51.100 / 56. 51.100
  libavcodec     58. 91.100 / 58. 91.100
  libavformat    58. 45.100 / 58. 45.100
  l

In [15]:
# Listen to clip 2 modulated with the AdaIN statistics of clip 1.
play(x21_adain)

ffmpeg version 4.3.2 Copyright (c) 2000-2021 the FFmpeg developers
  built with gcc 10.3.0 (GCC)
  configuration: --prefix=/home/conda/feedstock_root/build_artifacts/ffmpeg_1645955405450/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_plac --cc=/home/conda/feedstock_root/build_artifacts/ffmpeg_1645955405450/_build_env/bin/x86_64-conda-linux-gnu-cc --disable-doc --disable-openssl --enable-avresample --enable-gnutls --enable-gpl --enable-hardcoded-tables --enable-libfreetype --enable-libopenh264 --enable-libx264 --enable-pic --enable-pthreads --enable-shared --disable-static --enable-version3 --enable-zlib --enable-libmp3lame --pkg-config=/home/conda/feedstock_root/build_artifacts/ffmpeg_1645955405450/_build_env/bin/pkg-config
  libavutil      56. 51.100 / 56. 51.100
  libavcodec     58. 91.100 / 58. 91.100
  libavformat    58. 45.100 / 58. 45.100
  l

In [None]:
# Listen to clip 1 whitened and then coloured with the features of clip 2.
play(x12_wct)

In [None]:
# Listen to clip 2 whitened and then coloured with the features of clip 1.
play(x21_wct)

## AdaIN and WCT in multiple layers of the decoder

In [22]:
# WCT applied before every decoder slice instead of once at the end of the encoder.
with torch.no_grad():
    x1z = pre(x1)
    x2z = pre(x2)
    
    # Inputs of every decoder slice (plus the final output) for both clips.
    # NOTE(review): `dec_slices` receive the output of `pre` directly, whereas
    # `post` first passes the features through `bottleneck(x)[1]` -- confirm
    # that skipping the bottleneck here is intended.
    x1_stats, x1_feats = sliced_features(x1z, dec_slices)
    x2_stats, x2_feats = sliced_features(x2z, dec_slices)
    
    x12_wct_z = sliced_WCT(x1z, x2_feats, dec_slices)
    
    x21_wct_z = sliced_WCT(x2z, x1_feats, dec_slices)
    
    # The last entry is the output of the final slice.
    x12_wct = x12_wct_z[-1]
    x21_wct = x21_wct_z[-1]


In [15]:
# Listen to the clip 2 -> clip 1 result of the multi-layer WCT.
play(x21_wct)

ffmpeg version 4.3.2 Copyright (c) 2000-2021 the FFmpeg developers
  built with gcc 10.3.0 (GCC)
  configuration: --prefix=/home/conda/feedstock_root/build_artifacts/ffmpeg_1645955405450/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_plac --cc=/home/conda/feedstock_root/build_artifacts/ffmpeg_1645955405450/_build_env/bin/x86_64-conda-linux-gnu-cc --disable-doc --disable-openssl --enable-avresample --enable-gnutls --enable-gpl --enable-hardcoded-tables --enable-libfreetype --enable-libopenh264 --enable-libx264 --enable-pic --enable-pthreads --enable-shared --disable-static --enable-version3 --enable-zlib --enable-libmp3lame --pkg-config=/home/conda/feedstock_root/build_artifacts/ffmpeg_1645955405450/_build_env/bin/pkg-config
  libavutil      56. 51.100 / 56. 51.100
  libavcodec     58. 91.100 / 58. 91.100
  libavformat    58. 45.100 / 58. 45.100
  l

In [1]:
# Listen to the clip 1 -> clip 2 result of the multi-layer WCT.
# NOTE: needs the cell that defines `play` to have been run in the current kernel.
play(x12_wct)

NameError: name 'play' is not defined

In [15]:
# AdaIN applied before every decoder slice instead of once at the end of the encoder.
with torch.no_grad():
    x1z = pre(x1)
    x2z = pre(x2)
    
    # Statistics of the input of every decoder slice for both clips.
    # NOTE(review): `dec_slices` receive the output of `pre` directly, whereas
    # `post` first passes the features through `bottleneck(x)[1]` -- confirm
    # that skipping the bottleneck here is intended.
    x1_stats, x1_feats = sliced_features(x1z, dec_slices)
    x2_stats, x2_feats = sliced_features(x2z, dec_slices)
    
    x12_adain_z = sliced_AdaIN(x1z, x2_stats, dec_slices)
    x21_adain_z = sliced_AdaIN(x2z, x1_stats, dec_slices)
    
    # The last entry is the output of the final slice.
    x12_adain = x12_adain_z[-1]
    x21_adain = x21_adain_z[-1]

In [22]:
# Listen to the clip 1 -> clip 2 result of the multi-layer AdaIN.
play(x12_adain)
# NOTE(review): the disabled exports below reference `x21` / `x12`, which no
# visible cell defines -- update the names before re-enabling them.
#sf.write('sliced_adain_guitar2piano_moonlight.wav', post_process(x21), sample_rate)
#sf.write('sliced_adain_piano2guitar_moonlight.wav', post_process(x12), sample_rate)

ffmpeg version 4.3.2 Copyright (c) 2000-2021 the FFmpeg developers
  built with gcc 10.3.0 (GCC)
  configuration: --prefix=/home/conda/feedstock_root/build_artifacts/ffmpeg_1645955405450/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_plac --cc=/home/conda/feedstock_root/build_artifacts/ffmpeg_1645955405450/_build_env/bin/x86_64-conda-linux-gnu-cc --disable-doc --disable-openssl --enable-avresample --enable-gnutls --enable-gpl --enable-hardcoded-tables --enable-libfreetype --enable-libopenh264 --enable-libx264 --enable-pic --enable-pthreads --enable-shared --disable-static --enable-version3 --enable-zlib --enable-libmp3lame --pkg-config=/home/conda/feedstock_root/build_artifacts/ffmpeg_1645955405450/_build_env/bin/pkg-config
  libavutil      56. 51.100 / 56. 51.100
  libavcodec     58. 91.100 / 58. 91.100
  libavformat    58. 45.100 / 58. 45.100
  l

In [2]:
# Listen to the clip 2 -> clip 1 result of the multi-layer AdaIN.
# NOTE: needs the cell that defines `play` to have been run in the current kernel.
play(x21_adain)

NameError: name 'play' is not defined