In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
# sys.path.append("jukebox_clone")
# sys.path.append("lucent_clone")
import jukebox
import torch
import librosa
import soundfile as sf
import scipy
import os
import numpy as np
import nussl
import matplotlib.pyplot as plt

from IPython.display import Audio
from jukebox.make_models import make_vqvae, make_prior, MODELS, make_model
from jukebox.hparams import Hyperparams, setup_hparams
from jukebox.utils.dist_utils import setup_dist_from_mpi
from jukebox.utils.torch_utils import empty_cache
from jukebox.utils.jukebox_utils import get_forward_calls_encoder, split_model
rank, local_rank, device = setup_dist_from_mpi()

Using cuda True


In [2]:
# Load the model
model = "5b" # or "1b_lyrics"     
vqvae, *priors = MODELS[model]
hparams = setup_hparams(vqvae, dict(sample_length = 1048576))
vqvae = make_vqvae(hparams, device)
vqvae = vqvae.eval()

Downloading from azure
Restored from /home/ozaydin/.cache/jukebox/models/5b/vqvae.pth.tar
0: Loading vqvae in eval mode


In [3]:
num_seconds = 20
sample_rate = 44100
t = np.linspace(0, num_seconds, sample_rate * num_seconds)

In [4]:
# Choose the level of JukeBox and hook it
from lucent.optvis.render import hook_model
from functools import partial
from jukebox.utils.jukebox_utils import get_forward_calls_encoder, get_forward_calls_decoder, split_model, compose_funclist

level = 2
encoder = vqvae.encoders[level]
bottleneck = vqvae.bottleneck.level_blocks[level]
decoder = vqvae.decoders[level]

encoder_hook, encoder_layers = hook_model(encoder, include_class_name=False)
encoder_calls, encoder_layer_names = get_forward_calls_encoder(encoder, prefix="")
decoder_hook, decoder_layers = hook_model(decoder, include_class_name=False)
decoder_calls, decoder_layer_names = get_forward_calls_decoder(decoder, prefix="")

In [5]:
# Split Encoder and Decoder
enc_layer_index = -3
l_enc = encoder_layer_names[enc_layer_index]
dec_layer_index = -3
l_dec = decoder_layer_names[dec_layer_index]

pre_enc, post_enc = split_model(encoder, l_enc, partial(get_forward_calls_encoder, prefix=""))
pre_dec, post_dec = split_model(decoder, l_dec, partial(get_forward_calls_decoder, prefix=""))

In [6]:
# Merge Encoder and Decoder
pre = compose_funclist([
    pre_enc,
    post_enc,
    
])

post = compose_funclist([
    lambda x: bottleneck(x)[1],
    pre_dec,
    post_dec,
])


In [7]:
# Functions to play and save audio
from jukebox.utils.jukebox_utils import compose_funclist
play = compose_funclist([
    lambda x: x.detach().cpu().numpy()[0,0],
    lambda x: x / max(x.max(), -x.min()),
    lambda x: nussl.AudioSignal(audio_data_array=x, sample_rate=44100),
    lambda x: x.embed_audio()
])

post_process = compose_funclist([
    lambda x: x.detach().cpu().numpy()[0,0],
    lambda x: x / max(x.max(), -x.min()),
])

In [8]:
# AdaIN functions
def compute_stats(x):
    mean = x.mean(dim=-1).contiguous()[...,None]
    std = x.std(dim=-1).contiguous()[...,None]
    return mean, std

def normalize(x, eps=1e-10):
    normalize_mean, normalize_std = compute_stats(x)
    x_normalized = (x - normalize_mean)/(normalize_std+eps)
    #out = F.batch_norm(a, None, None, training=True)
    return  x_normalized

def modulate(x, mean, std):
    x_modulated = x * std + mean
    return  x_modulated

def AdaIN(x, mean, std):
    x_normalized = normalize(x)
    x_modulated = modulate(x_normalized, mean, std)
    return x_modulated

# AdaIN using torch
def AdaIN_torch(x, mean, std):
    return torch.nn.functional.batch_norm(x, None, None, training=True, weight=std, bias=mean)


In [15]:
# Load wave wiles
wav_1, sr_1 = librosa.load('piano-C4.wav', sr=sample_rate)
wav_2, sr_2 = librosa.load('violin-C4.wav', sr=sample_rate)

wav_1 = wav_1[:num_seconds * sample_rate]
wav_2 = wav_2[:num_seconds * sample_rate]

x_1 = torch.Tensor(wav_1).to(device)[None,None,:] / max(wav_1.max(), -wav_1.min())
x_2 = torch.Tensor(wav_2).to(device)[None,None,:] / max(wav_2.max(), -wav_2.min())

In [16]:
# Extract features from JukeBox and modulate them with AdaIN
with torch.no_grad():
    x_1_z = pre(x_1)
    x_2_z = pre(x_2)

    x_1_z_stats = compute_stats(x_1_z)
    x_2_z_stats = compute_stats(x_2_z)
    
    x_1_z_normalized = normalize(x_1_z)
    x_1_r_normalized = post(x_1_z_normalized)
    
    x_2_z_normalized = normalize(x_2_z)
    x_1_r_normalized = post(x_2_z_normalized)

    x_1_z_modulated = AdaIN(x_1_z, *x_2_z_stats)
    x_2_z_modulated = AdaIN(x_2_z, *x_1_z_stats)

    x_1_r = post(x_1_z)
    x_1_r_modulated = post(x_1_z_modulated)
    x_2_r = post(x_2_z)
    x_2_r_modulated = post(x_2_z_modulated)

    
#sf.write('adain_piano2guitar_moonlight.wav', post_process(x_1_r_modulated), sample_rate)
#sf.write('adain_guitar2piano_moonlight.wav', post_process(x_2_r_modulated), sample_rate)

In [17]:
print(x_1_r_modulated.min(), x_1_r_modulated.max(), x_1_r_modulated.min(), x_1_r_modulated.max())
play(x_1_r_modulated)
#play(x_2_r_modulated)

tensor(-1.1102, device='cuda:0') tensor(1.1451, device='cuda:0') tensor(-1.1102, device='cuda:0') tensor(1.1451, device='cuda:0')


ffmpeg version 5.0.1 Copyright (c) 2000-2022 the FFmpeg developers
  built with gcc 10.3.0 (GCC)
  configuration: --prefix=/home/conda/feedstock_root/build_artifacts/ffmpeg_1649114005999/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_plac --cc=/home/conda/feedstock_root/build_artifacts/ffmpeg_1649114005999/_build_env/bin/x86_64-conda-linux-gnu-cc --disable-doc --disable-openssl --enable-demuxer=dash --enable-gnutls --enable-gpl --enable-hardcoded-tables --enable-libfreetype --enable-libopenh264 --enable-vaapi --enable-libx264 --enable-libx265 --enable-libaom --enable-libsvtav1 --enable-libxml2 --enable-libvpx --enable-pic --enable-pthreads --enable-shared --disable-static --enable-version3 --enable-zlib --enable-libmp3lame --pkg-config=/home/conda/feedstock_root/build_artifacts/ffmpeg_1649114005999/_build_env/bin/pkg-config
  libavutil      57. 17.

In [None]:
import scipy
from lucent.optvis.render import hook_model
from functools import partial
from jukebox.utils.jukebox_utils import get_forward_calls_encoder, get_forward_calls_decoder, split_model

#x = torch.zeros(1,1,44100)
level = 2
encoder = vqvae.encoders[level]
bottleneck = vqvae.bottleneck.level_blocks[level]
decoder = vqvae.decoders[level]


wave = scipy.signal.chirp(t, f_start, num_seconds, f_end)
x = scipy.signal.chirp(t, f_start, num_seconds, f_end)[None, None, :]
x = torch.from_numpy(x).cuda().float()


hook, layers = hook_model(encoder, include_class_name=False)
print(list(layers.keys())[:10])
x_z = encoder(x)[-1]
encoder_calls, encoder_layer_names = get_forward_calls_encoder(encoder, prefix="")

for l in encoder_layer_names:
    pre, post = split_model(encoder, l, partial(get_forward_calls_encoder, prefix=""))
    try:
        h1 = hook(l)
        h2 = pre(x)
        print((h2 == h1).all().item())
        if not (h2 == h1).all().item():
            error = torch.abs(h1 - h2).mean()
            print(f"Layer {l} failed with error = {error}")
    except:
            continue
      
    
hook, layers = hook_model(decoder, include_class_name=False)
_, xs_quantized, _, _ = bottleneck(x_z)
decoder([xs_quantized], all_levels=False)
decoder_calls, decoder_layer_names = get_forward_calls_decoder(decoder, prefix="")
# print(decoder)
for l in decoder_layer_names:
    pre, post = split_model(decoder, l, partial(get_forward_calls_decoder, prefix=""))
    try:
        h1 = hook(l)
        h2 = pre(xs_quantized)
        print((h2 == h1).all().item())
        if not (h2 == h1).all().item():
            error = torch.abs(h1 - h2).mean()
            print(f"Layer {l} failed with error = {error}")
        
    except:
        continue

['level_blocks-0-model-0-0', 'level_blocks-0-model-0-1-model-0-model-0', 'level_blocks-0-model-0-1-model-0-model-1', 'level_blocks-0-model-0-1-model-0-model-2', 'level_blocks-0-model-0-1-model-0-model-3', 'level_blocks-0-model-0-1-model-0-model', 'level_blocks-0-model-0-1-model-0', 'level_blocks-0-model-0-1-model-1-model-0', 'level_blocks-0-model-0-1-model-1-model-1', 'level_blocks-0-model-0-1-model-1-model-2']
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
