# Jukebox representation

__author__ = "dr.seunggoo.kim@gmail.com"
__reference__ = ["https://openai.com/blog/jukebox/"]

##  Processing steps
1. ENCODING (VQ-VAE-encoder): waveform (44.1 kHz) to top-level (345 Hz)
2. GENERATION (transformers): new top-level representation (345 Hz)
3. UPSAMPLING (transformers): from top-level (345 Hz) to bottom-level (5512 Hz)
4. DECODING (VQ-VAE-decoder): bottom-level (5512 Hz) to waveform (44.1 kHz)

## Components
1. VQ-VAE: encoding & decoding
3 levels (8x, 32x, 128x) *cnn*
codebook size = 2048

2. TRANSFORMERS: sampling at the top-level & upsampling to bottom-level
Three-level *prior models* using a simplified variant of Sparse Transformers (72 layers of factorized self-attention on a context of 8192 codes).


### Limitations
- Still short (no choruses that repeat)
- Discernible noise
- Slow: 9 hours for 1-min new "sampling" (at the top-level)


In [1]:
# Load modules
import librosa as lr
import numpy as np
import torch
import scipy
import pathlib
from argparse import ArgumentParser

# imports and set up Jukebox's multi-GPU parallelization
import jukebox
from jukebox.hparams import Hyperparams, setup_hparams
from jukebox.make_models import MODELS, make_prior, make_vqvae
from jukebox.utils.dist_utils import setup_dist_from_mpi
from tqdm import tqdm
# Initialise Jukebox's distributed setup. `device` is passed to make_vqvae /
# make_prior in later cells; `rank` and `local_rank` are not used elsewhere in
# the visible part of this notebook.
rank, local_rank, device = setup_dist_from_mpi()


Using cuda True


In [2]:
# Set up paths
# Audio is read from ./input and results are written to ./output, both
# relative to the notebook's working directory.
input_dir = pathlib.Path("input")
output_dir = pathlib.Path("output")
# Later cells write .mat files into this folder, so make sure it exists.
output_dir.mkdir(parents=True, exist_ok=True)

# Sorted so that "the first file" is the same on every run.
input_paths = sorted(input_dir.iterdir())
if not input_paths:
    raise FileNotFoundError(f"No input files found in {input_dir.resolve()}")
print(input_paths[0])

input/classical.00000.wav


In [3]:
# Read audio
input_path = input_paths[0]
JUKEBOX_SAMPLE_RATE = 44100  # Hz
CROP_SECONDS = 25  # only the first 25 seconds of the clip are kept

# librosa resamples to the requested rate while loading.
audio, _ = lr.load(input_path, sr=JUKEBOX_SAMPLE_RATE)
if audio.ndim == 2:
    audio = audio.mean(axis=0)  # average channels -> mono

# Peak-normalise to [-1, 1]. The peak is taken over the whole clip, i.e.
# before cropping. A completely silent clip is left untouched instead of
# being divided by zero (which would fill it with NaNs).
peak = np.abs(audio).max()
if peak > 0:
    audio /= peak

print(audio)
print(f'#smp={audio.shape}, length={audio.shape[0]/JUKEBOX_SAMPLE_RATE} sec')
print(f'{JUKEBOX_SAMPLE_RATE * CROP_SECONDS}')
audio = audio[: JUKEBOX_SAMPLE_RATE * CROP_SECONDS]  # crop to the first CROP_SECONDS


[-0.05790375 -0.07014011 -0.05139991 ...  0.1015296   0.08913703
  0.04516252]
#smp=(1323588,), length=30.013333333333332 sec
1102500


In [4]:
# Set up & load VQ-VAE
model = "5b"  # or "1b_lyrics" or "5b_lyrics"

# The largest lyrics model gets smaller batch settings than the others.
if model == "5b_lyrics":
    n_samples, chunk_size, max_batch_size = 3, 16, 3
else:
    n_samples, chunk_size, max_batch_size = 8, 32, 16

# Sampling hyper-parameters shared by the cells below.
hps = Hyperparams()
hps.sr = 44100
hps.n_samples = n_samples
hps.name = "samples"
hps.levels = 3
hps.hop_fraction = [0.5, 0.5, 0.125]

# MODELS[model] lists the VQ-VAE name first, followed by the prior names.
vqvae_name, *priors = MODELS[model]
# sample_length = 1048576 = 2**20 audio samples
vqvae_hps = setup_hparams(vqvae_name, dict(sample_length=1048576))
vqvae = make_vqvae(vqvae_hps, device)
print(vqvae)

Downloading from azure
Restored from /home/seung-goo.kim/.cache/jukebox/models/5b/vqvae.pth.tar
0: Loading vqvae in eval mode
VQVAE(
  (encoders): ModuleList(
    (0): Encoder(
      (level_blocks): ModuleList(
        (0): EncoderConvBlock(
          (model): Sequential(
            (0): Sequential(
              (0): Conv1d(1, 64, kernel_size=(4,), stride=(2,), padding=(1,))
              (1): Resnet1D(
                (model): Sequential(
                  (0): ResConv1DBlock(
                    (model): Sequential(
                      (0): ReLU()
                      (1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
                      (2): ReLU()
                      (3): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
                    )
                  )
                  (1): ResConv1DBlock(
                    (model): Sequential(
                      (0): ReLU()
                      (1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(3,), dil

In [5]:
# Compute the compressed audio (VQ-VAE codes at every level).
# The waveform is reshaped to (1, n_samples, 1) and placed on the same
# `device` the VQ-VAE was built on, rather than on a hard-coded CUDA device.
audio_batch = torch.as_tensor(
    audio[np.newaxis, :, np.newaxis], dtype=torch.float32, device=device
)
with torch.no_grad():  # inference only, no autograd bookkeeping needed
    zs = vqvae.encode(audio_batch)

# zs[-1] is the coarsest level (128x, see the shapes printed below), i.e. the
# top level; keep it as a (1, n_tokens) tensor.
z = zs[-1].flatten()[np.newaxis, :]

# for given 1102500 audio samples
print(zs[0].shape)  # short-scale (137812 tokens; 8x; 5512 Hz)
print(zs[1].shape)  # mid-scale (34453 tokens; 32x; 1378 Hz)
print(zs[2].shape)  # long-scale (8613 tokens; 128x; 345 Hz)


torch.Size([1, 137812])
torch.Size([1, 34453])
torch.Size([1, 8613])


In [6]:
# Largest code index at the top level; codes are integers in 0..2047
# (2048 codebook entries).
zs[-1].max() # and integers... 0 to 2047 (2048 codes)

tensor(2047, device='cuda:0')

In [7]:
# save them in MAT format
# Import the submodule explicitly: a bare `import scipy` does not guarantee
# that `scipy.io` has been loaded.
import scipy.io

# savemat does not create missing folders.
pathlib.Path('output').mkdir(parents=True, exist_ok=True)

# One file per VQ-VAE level: codes_lvl0.mat (finest) ... codes_lvl2.mat (coarsest).
for level, codes in enumerate(zs):
    scipy.io.savemat(f'output/codes_lvl{level}.mat', {'Codes': codes.cpu().numpy()})

In [8]:
# Set up & load language model (Sparse Transformers)
# ['upsampler_level_0', 'upsampler_level_1', 'prior_5b']
#hparams = setup_hparams('prior_5b', dict()) 
#hparams["prior_depth"] = 1
#top_prior = make_prior(hparams, vqvae, device)
#print(top_prior)

# 1048576 / 8192 = 128 audio samples (at 44.1 kHz) per top-level token
# 1048576 / 44100 / 8192 * 1000 = 2.9025 ms per token (one hop)
# 8192 / (1048576 / 44100) = 344.5312 Hz token rate
# 8192 tokens x 4800 dimensions (each of the 2048 codes is embedded in 4800 features)

In [9]:
def get_cond(hps, top_prior, n_tokens=None):
    """Build the conditioning inputs for a forward pass through `top_prior`.

    Side effect: sets ``hps.sample_length`` to 62 seconds' worth of audio
    samples, rounded down to a multiple of ``top_prior.raw_to_tokens``.

    Args:
        hps: hyper-parameter object; ``sr`` and ``n_samples`` are read and
            ``sample_length`` is written.
        top_prior: prior providing ``raw_to_tokens``, ``labeller``, ``get_y``
            and ``get_cond``.
        n_tokens: number of leading positions of ``x_cond`` to keep. When
            omitted, the notebook-level constant ``T`` is used, so existing
            two-argument calls behave as before.

    Returns:
        ``(x_cond, y_cond)`` for the first batch item only, each with a new
        leading axis of length 1; ``x_cond`` is truncated to ``n_tokens``
        positions along its first remaining axis.
    """
    if n_tokens is None:
        n_tokens = T  # notebook global; must be defined before this call

    sample_length_in_seconds = 62
    hps.sample_length = (
        int(sample_length_in_seconds * hps.sr) // top_prior.raw_to_tokens
    ) * top_prior.raw_to_tokens

    # NOTE: the 'lyrics' parameter is required, which is why it is included,
    # but it doesn't actually change anything about the `x_cond`, `y_cond`,
    # nor the `prime` variables
    # One independent dict per sample (not n references to the same dict).
    metas = [
        dict(
            artist="unknown",
            genre="unknown",
            total_length=hps.sample_length,
            offset=0,
            lyrics="""lyrics go here!!!""",
        )
        for _ in range(hps.n_samples)
    ]

    top_level_labels = top_prior.labeller.get_batch_labels(metas, "cuda")
    # The third return value (`prime`) is not needed here.
    x_cond, y_cond, _prime = top_prior.get_cond(
        None, top_prior.get_y(top_level_labels, 0)
    )

    # Keep only the first batch item, re-adding a batch axis of length 1.
    x_cond = x_cond[0, :n_tokens][np.newaxis, ...]
    y_cond = y_cond[0][np.newaxis, ...]

    return x_cond, y_cond

In [10]:
# Get conditions for this prior
# Builds a depth-1 version of the 'prior_5b' model (the load log below reports
# "Level:2") and computes its conditioning inputs.
JUKEBOX_SAMPLE_RATE = 44100 # Hz
# Number of top-level tokens to use. Read as a global inside get_cond() and
# used to slice `z` in the forward-pass cells.
T = 8192 # hops
# ['upsampler_level_0', 'upsampler_level_1', 'prior_5b']
hparams = setup_hparams('prior_5b', dict()) 
# presumably the number of transformer layers to instantiate; confirm against Jukebox
hparams["prior_depth"] = 1
top_prior = make_prior(hparams, vqvae, device)
# Also sets hps.sample_length as a side effect (see get_cond).
x_cond, y_cond = get_cond(hps, top_prior)

Loading artist IDs from /mnt/beegfs/users/seung-goo.kim/jukebox/jukebox/data/ids/v2_artist_ids.txt
Loading artist IDs from /mnt/beegfs/users/seung-goo.kim/jukebox/jukebox/data/ids/v2_genre_ids.txt
Level:2, Cond downsample:None, Raw to tokens:128, Sample length:1048576
0: Converting to fp16 params
Downloading from azure
Restored from /home/seung-goo.kim/.cache/jukebox/models/5b/prior_level_2.pth.tar
0: Loading prior in eval mode


In [11]:
# Shape of x_cond for the single retained batch item: T rows (8192) by 4800
# features in the recorded output.
x_cond[0].shape

torch.Size([8192, 4800])

In [12]:
# Shape of y_cond for the single retained batch item: one row of 4800
# features in the recorded output.
y_cond[0].shape

torch.Size([1, 4800])

In [13]:
# Forward the first T top-level codes through the prior's transformer.
# `only_encode` presumably makes forward() return activations rather than a
# loss; confirm against the Jukebox source.
transformer = top_prior.prior
transformer.only_encode = True
act = transformer.forward(
    z[:, :T],
    x_cond=x_cond,
    y_cond=y_cond,
    encoder_kv=None,
    fp16=False,
)

In [14]:
# Drop singleton axes, move the activations to host memory and show them as a
# NumPy array (float32 in the recorded output).
act.squeeze().cpu().numpy()

array([[-0.9825132 ,  3.623651  ,  0.12038398, ..., -1.9669914 ,
         1.1183643 , -3.11454   ],
       [-0.86627686,  1.3010759 ,  1.5062073 , ...,  0.08549738,
         1.4286535 ,  1.8755095 ],
       [-1.2182832 ,  1.4547569 ,  1.3098716 , ..., -3.3949356 ,
         3.1289217 , -1.965195  ],
       ...,
       [-1.6962414 ,  0.5384073 , -0.01441598, ..., -3.957605  ,
         0.7643536 , -0.7796922 ],
       [-1.60969   ,  0.07450581, -0.209427  , ..., -3.4036498 ,
         0.3819468 , -0.9061338 ],
       [ 0.94400036,  2.2806125 ,  3.6798608 , ...,  0.4285879 ,
         6.7440853 , -1.7993827 ]], dtype=float32)

In [15]:
# Extract the activations for every prior depth and cache each as a MAT file.
# NOTE(review): the notes at the top of this notebook mention 72 layers, but
# this loop runs depths 1..76 - confirm that depths 73-76 are intended.
for prior_depth in range(1, 77):  # end of range is exclusive -> 1..76
    fname_mat = 'output/act_trans_depth%02i.mat' % prior_depth
    if pathlib.Path(fname_mat).is_file():
        continue  # already computed in a previous run

    # 'prior_5b' is the top-level prior (the load log reports "Level:2").
    hparams = setup_hparams('prior_5b', dict())
    hparams["prior_depth"] = prior_depth
    top_prior = make_prior(hparams, vqvae, device)
    x_cond, y_cond = get_cond(hps, top_prior)
    top_prior.prior.only_encode = True
    with torch.no_grad():  # inference only
        act = top_prior.prior.forward(
            z[:, :T], x_cond=x_cond, y_cond=y_cond, encoder_kv=None, fp16=False
        ).squeeze().cpu().numpy()

    scipy.io.savemat(fname_mat, {'Act': act})

    # Release this depth's model and its GPU tensors before the next, larger
    # model is built. `vqvae` is deliberately kept: make_prior() needs it on
    # every iteration, so deleting it here would raise a NameError on the
    # next uncached depth.
    del top_prior, x_cond, y_cond
    torch.cuda.empty_cache()  # return freed blocks to the GPU

Loading artist IDs from /mnt/beegfs/users/seung-goo.kim/jukebox/jukebox/data/ids/v2_artist_ids.txt
Loading artist IDs from /mnt/beegfs/users/seung-goo.kim/jukebox/jukebox/data/ids/v2_genre_ids.txt
Level:2, Cond downsample:None, Raw to tokens:128, Sample length:1048576
0: Converting to fp16 params
Downloading from azure
Restored from /home/seung-goo.kim/.cache/jukebox/models/5b/prior_level_2.pth.tar
0: Loading prior in eval mode
Loading artist IDs from /mnt/beegfs/users/seung-goo.kim/jukebox/jukebox/data/ids/v2_artist_ids.txt
Loading artist IDs from /mnt/beegfs/users/seung-goo.kim/jukebox/jukebox/data/ids/v2_genre_ids.txt
Level:2, Cond downsample:None, Raw to tokens:128, Sample length:1048576
0: Converting to fp16 params
Downloading from azure
Restored from /home/seung-goo.kim/.cache/jukebox/models/5b/prior_level_2.pth.tar
0: Loading prior in eval mode
Loading artist IDs from /mnt/beegfs/users/seung-goo.kim/jukebox/jukebox/data/ids/v2_artist_ids.txt
Loading artist IDs from /mnt/beegfs/u

RuntimeError: CUDA out of memory. Tried to allocate 1.17 GiB (GPU 0; 23.65 GiB total capacity; 20.88 GiB already allocated; 883.56 MiB free; 22.01 GiB reserved in total by PyTorch)

In [None]:
# Debug leftover: file name for the depth at which the loop above stopped
# (`prior_depth` is the loop variable).
'output/act_trans_depth%02i.mat' %prior_depth

In [None]:
# Debug leftover: last value taken by the loop variable of the depth loop.
prior_depth