
## Codebook Style Transfer

In this notebook we describe and demonstrate our simplest method, which we used as a baseline model. Recall that the encoder first transforms the input audio into latent vectors $h_s$, and these vectors are then quantized to the closest codebook vector $e_s$. The intuition behind this first method is to quantize the latent vectors from $x$ (content) using only the unique codebook vectors that occur in $y$ (style).

First, we obtain the latent vectors of $x$ and $y$ ($h_x$ and $h_y$). For $y$ we quantize the vectors using the full codebook to obtain $e_y$, and for $x$ we quantize the vectors using only the unique vectors from $e_y$.

In [None]:
import jukebox
from torch import float64
import torch as t
import torch
import librosa
import os
from IPython.display import Audio
from jukebox.make_models import make_vqvae, make_prior, MODELS, make_model
from jukebox.hparams import Hyperparams, setup_hparams
from jukebox.sample import sample_single_window, _sample, \
                           sample_partial_window, upsample
from jukebox.utils.dist_utils import setup_dist_from_mpi
from jukebox.utils.torch_utils import empty_cache
rank, local_rank, device = setup_dist_from_mpi()

import scipy
import torch
import numpy as np
import lucent
from lucent.optvis.render import hook_model
from lucent.modelzoo.util import get_model_layers
import matplotlib.pyplot as plt
from scipy.io.wavfile import read, write
import itertools

Using cuda True


In [None]:
# Key into MODELS that selects which VQ-VAE / prior specs are loaded below.
model = "1b_lyrics" # or "5b_lyrics"
# Settings that depend on the model size: the "5b_lyrics" branch uses the
# smaller value in each conditional below.
# NOTE(review): hps, chunk_size and max_batch_size are not referenced by any
# other cell visible in this notebook - confirm they are still needed.
hps = Hyperparams()
hps.sr = 44100
hps.n_samples = 3 if model=='5b_lyrics' else 8
hps.name = 'samples'
chunk_size = 16 if model=="5b_lyrics" else 32
max_batch_size = 3 if model=="5b_lyrics" else 16
hps.levels = 3
hps.hop_fraction = [.5,.5,.125]

# MODELS[model] unpacks into the VQ-VAE spec followed by the prior specs.
vqvae, *priors = MODELS[model]
vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length = 1048576)), device)
# Only the last prior in the list is instantiated.
# NOTE(review): top_prior is not used by the transfer code in the visible
# cells, which only needs `vqvae` - confirm whether loading it is required.
top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)
vqvae = vqvae.eval()

# NOTE(review): the four values below are not referenced by any other cell
# visible in this notebook.
f_start = 100
f_end = 4000
num_seconds = 10
sample_rate = 44100


Downloading from azure
Running  wget -O /root/.cache/jukebox/models/5b/vqvae.pth.tar https://openaipublic.azureedge.net/jukebox/models/5b/vqvae.pth.tar
Restored from /root/.cache/jukebox/models/5b/vqvae.pth.tar
0: Loading vqvae in eval mode
Creating cond. autoregress with prior bins [79, 2048], 
dims [384, 6144], 
shift [ 0 79]
input shape 6528
input bins 2127
Self copy is False
Loading artist IDs from /usr/local/lib/python3.7/dist-packages/jukebox/data/ids/v3_artist_ids.txt
Loading artist IDs from /usr/local/lib/python3.7/dist-packages/jukebox/data/ids/v3_genre_ids.txt
Level:2, Cond downsample:None, Raw to tokens:128, Sample length:786432
Downloading from azure
Running  wget -O /root/.cache/jukebox/models/1b_lyrics/prior_level_2.pth.tar https://openaipublic.azureedge.net/jukebox/models/1b_lyrics/prior_level_2.pth.tar
Restored from /root/.cache/jukebox/models/1b_lyrics/prior_level_2.pth.tar
0: Loading prior in eval mode


In [None]:


def get_wave(instrument, note):
  """Load the recording of a single note played on one instrument.

  Reads ``data/notes/<instrument>-<note>.wav`` relative to the working
  directory.

  Args:
    instrument: instrument name used in the file name (e.g. 'flute').
    note: note name used in the file name (e.g. 'C4').

  Returns:
    (sample_rate, samples): the rate reported by the WAV file and the samples
    copied into an int16 numpy array.
  """
  # NOTE(review): samples are returned in raw int16 scale, not normalised to
  # [-1, 1] - confirm this is the amplitude range the VQ-VAE expects.
  path = "data/notes/{}-{}.wav".format(instrument, note)
  sample_rate, samples = read(path)
  return sample_rate, np.array(samples, dtype='int16')

def encode(vqvae, wave):
  """Encode a waveform into one continuous (pre-quantisation) latent per level.

  Args:
    vqvae: model exposing `levels`, `encoders` and `preprocess`.
    wave: numpy array of audio samples; assumed 1-D, so that the model input
      has shape (1, num_samples, 1).

  Returns:
    A list with one tensor per level: the last element of that level's
    encoder output.
  """
  # Put the input on the device the model lives on instead of assuming CUDA.
  # If the model exposes no parameters, keep the previous behaviour (CUDA).
  try:
    target_device = next(vqvae.parameters()).device
  except (AttributeError, StopIteration):
    target_device = torch.device("cuda")

  # Add a leading batch axis and a trailing channel axis.
  x = torch.from_numpy(wave).unsqueeze(0).unsqueeze(2).to(target_device)
  x_in = vqvae.preprocess(x)

  # Every level encodes the same pre-processed input with its own encoder.
  return [vqvae.encoders[level](x_in)[-1] for level in range(vqvae.levels)]

def decode(vqvae, xs, quant=True):
  """Decode per-level latents back to audio, one output per level.

  Args:
    vqvae: model exposing `levels`, `decoders`, `postprocess` and, when
      `quant` is true, `bottleneck`.
    xs: sequence with one latent per level.
    quant: if true, `xs` is first passed through `vqvae.bottleneck` and the
      quantised latents are decoded; if false, `xs` is decoded as given.

  Returns:
    A list with one post-processed decoder output per level. Each level is
    decoded on its own (only that level's latent is handed to its decoder).
  """
  if quant:
    # The bottleneck returns a 4-tuple; only the quantised latents are used.
    _, latents, _, _ = vqvae.bottleneck(xs)
  else:
    latents = xs

  return [
    vqvae.postprocess(
      vqvae.decoders[level](latents[level:level+1], all_levels=False)
    )
    for level in range(vqvae.levels)
  ]


def make_continous(discrete, continous, chunk_size=1024):
  """Replace every column of `discrete` by its nearest column of `continous`.

  Nearest means smallest Euclidean distance; on ties the column with the
  lowest index wins.

  Args:
    discrete: 2-D tensor of shape (a, num_discrete); each column is a query.
    continous: 2-D tensor of shape (a, num_continous); each column is a
      candidate.
    chunk_size: number of query columns compared against all candidates at
      once. It only bounds peak memory and does not change the result.

  Returns:
    float32 tensor of shape (a, num_discrete) whose d-th column is the
    candidate closest to `discrete[:, d]`. The result lives on the device of
    `continous` (CUDA whenever the inputs are on CUDA).
  """
  reference = continous.float()
  candidates = reference.t().contiguous()                  # (num_continous, a)
  queries = discrete.float().to(reference.device).t()      # (num_discrete, a)

  # One batched distance matrix per chunk replaces the Python double loop.
  # The non-matmul mode computes plain sqrt(sum(delta ** 2)) distances.
  nearest = torch.cat([
    torch.cdist(block, candidates,
                compute_mode="donot_use_mm_for_euclid_dist").argmin(dim=1)
    for block in queries.split(chunk_size)
  ])

  return reference[:, nearest]


def transfer_codebook(vqvae, hx, hy):
  """Re-quantise the content latents using only code vectors present in the style.

  For every level, each column of the content latent is replaced by the
  nearest (Euclidean) vector among the distinct quantised vectors of the style.

  Args:
    vqvae: model exposing `bottleneck`.
    hx: content latents, one tensor per level; only the first batch item
      (`hx[level][0]`, shape (len_emb, num_queries)) is used.
    hy: style latents, one tensor per level, quantised with the full codebook.

  Returns:
    A list with one float32 tensor of shape (1, len_emb, num_queries) per
    level, on the device of the style reference vectors.
  """
  # Style: quantise with the full codebook. The bottleneck returns a 4-tuple;
  # only the quantised latents are needed here.
  _, y_quantised, _, _ = vqvae.bottleneck(hy)

  chunk_size = 4096  # content frames compared per block; bounds peak memory

  x_transfered = []
  for layer, content in enumerate(hx):
    style = y_quantised[layer][0]

    # Distinct style vectors: de-duplicate the floored columns, then map each
    # one back to an actual quantised column.
    # NOTE(review): two different code vectors whose floors coincide collapse
    # into a single reference vector here - confirm this is intended.
    ref_discrete = torch.unique(torch.floor(style), dim=1)
    ref = make_continous(ref_discrete, style)

    queries = content[0].to(device=ref.device, dtype=ref.dtype).t()  # (num_queries, len_emb)
    candidates = ref.t().contiguous()                                # (len_codebook, len_emb)

    # Batched nearest-neighbour search; ties resolve to the lowest index,
    # like the strict `<` comparison of a sequential scan.
    nearest = torch.cat([
      torch.cdist(block, candidates,
                  compute_mode="donot_use_mm_for_euclid_dist").argmin(dim=1)
      for block in queries.split(chunk_size)
    ])

    x_transfered.append(ref[:, nearest].float().unsqueeze(0))

  return x_transfered

def transfer_notes(vqvae, instrumet1, instrumet2):
  """Run the codebook style transfer on two waveforms and decode the result.

  Args:
    vqvae: the VQ-VAE used for encoding, quantising and decoding.
    instrumet1: waveform providing the content (the base).
    instrumet2: waveform providing the style.

  Returns:
    The list returned by `decode`: one decoded output per level, each decoded
    from that level's transferred latent alone (no further quantisation).
  """
  # Pure inference: skip building autograd graphs through the model.
  with torch.no_grad():
    hx = encode(vqvae, instrumet1)
    hy = encode(vqvae, instrumet2)

    transfer_enc = transfer_codebook(vqvae, hx, hy)

    # Works for any number of levels; each transferred latent is placed on
    # the device of the matching content latent.
    transfer_enc = [
      transferred.float().to(content.device)
      for transferred, content in zip(transfer_enc, hx)
    ]

    # The latents are already code vectors, so decode without quantising.
    return decode(vqvae, transfer_enc, False)

# Experiments
In the following cells we run the transfer experiments

## Experiment 1

In [None]:

# Single-note transfers: for every note and every unordered instrument pair,
# the first instrument is the content and the second the style.
notes_list = ['C4','G4']
instruments_list =  ['flute', 'piano', 'trumpet','violin']

single_notes_dir = './data/JukeBox Outputs/Single notes'
# (index into the decoded outputs, file-name prefix), in writing order.
level_prefixes = [
  (0, 'codebook_'),
  (2, 'codebook-onlylevel2-'),
  (1, 'codebook-onlylevel1-'),
]

os.makedirs(single_notes_dir, exist_ok=True)

for note in notes_list:
  for i1, i2 in itertools.combinations(instruments_list, 2):
    print(i1,i2)
    sr, instrumet1 = get_wave(i1, note)
    sr, instrumet2 = get_wave(i2, note)

    transfer_dec = transfer_notes(vqvae, instrumet1, instrumet2)

    # One output file per decoded level; the sample rate written is the one
    # of the style recording (the last one loaded).
    for level, prefix in level_prefixes:
      audio = transfer_dec[level].detach().squeeze().cpu().numpy()
      out_path = '{out_dir}/{prefix}{i1}-{note}_{i2}-{note}.wav'.format(
        out_dir=single_notes_dir, prefix=prefix, note=note, i1=i1, i2=i2)
      write(out_path, sr, audio)


flute piano
flute trumpet
flute violin
piano trumpet
piano violin
trumpet violin
flute piano
flute trumpet
flute violin
piano trumpet
piano violin
trumpet violin


## Experiment 2

In [None]:
song_dir = "./data/JukeBox Outputs/Moonlight/"


# NOTE(review): this redefinition shadows the two-argument get_wave(instrument,
# note) defined earlier in the notebook; after running this cell, Experiment 1
# only works again once the earlier definition has been re-run.
def get_wave(instrumet):
  """Load the first 900000 frames of an instrument's Moonlight recording as mono.

  Reads ``<song_dir>/<instrumet>-moonlight.wav``.

  Returns:
    (sample_rate, wave): the rate reported by the WAV file and a float64
    array; multi-channel audio is averaged over the channel axis.
  """
  file_name = os.path.join(song_dir, "{}-moonlight.wav".format(instrumet))
  sample_rate, samples = read(file_name)
  clip = samples[:900000]
  if clip.ndim == 1:
    # Already mono: there is no channel axis to average over.
    wave = clip.astype(np.float64)
  else:
    wave = np.mean(clip, axis=1)
  return sample_rate, wave


instrument_list = ['guitar', 'piano']

for i1, i2 in itertools.product(instrument_list, repeat=2):
  if i1 == i2:
    continue  # a transfer needs two different instruments
  print(i1,i2)
  sr, instrumet1 = get_wave(i1)
  sr, instrumet2 = get_wave(i2)

  transfer_dec = transfer_notes(vqvae, instrumet1, instrumet2)

  audio = transfer_dec[0].detach().squeeze().cpu().numpy()
  out_path = os.path.join(
    song_dir, 'codebook_{i1}-moonlight_{i2}-moonlight.wav'.format(i1=i1, i2=i2))
  write(out_path, sr, audio)


guitar guitar
guitar piano


## Experiment 3

In [None]:
# librosa is already imported in the first cell of the notebook.
# NOTE(review): librosa.load resamples to 22050 Hz unless `sr` is given, while
# hps.sr is 44100 and the other experiments keep each file's own rate -
# confirm that feeding 22050 Hz audio to the VQ-VAE is intended.
clip_samples = 150000  # number of leading samples kept from each recording

banjo, banjo_sr = librosa.load("./data/JukeBox Outputs/SS VQ-VAE dataset/banjo.mp3")
church, church_sr = librosa.load("./data/JukeBox Outputs/SS VQ-VAE dataset/church-organ.mp3")

banjo = banjo[:clip_samples]
church = church[:clip_samples]




In [None]:
# Banjo is the content, church organ the style.
transfer_dec = transfer_notes(vqvae, banjo, church)

# Save the first decoded output at the content recording's sample rate.
audio = transfer_dec[0].detach().squeeze().cpu().numpy()
out_path = './data/JukeBox Outputs/SS VQ-VAE dataset/codebook_banjo_church-organ.wav'
write(out_path, banjo_sr, audio)


In [None]:
# Church organ is the content, banjo the style.
transfer_dec = transfer_notes(vqvae, church, banjo)

# Save the first decoded output at the content recording's sample rate.
audio = transfer_dec[0].detach().squeeze().cpu().numpy()
out_path = './data/JukeBox Outputs/SS VQ-VAE dataset/codebook_church-organ_banjo.wav'
write(out_path, church_sr, audio)
