In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# load all dependencies
import sys
import os
import torch
from src.model.jukebox_diffusion import UnconditionalJukeboxDiffusion
from transformers import JukeboxVQVAEConfig, JukeboxVQVAE
from einops import rearrange

sys.path.append('../')

from src.dataset.jukebox_dataset import JukeboxDataset

In [66]:
LVL = 0
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [67]:
dataset = JukeboxDataset(
    root_dir=os.environ["MAESTRO_DATASET_DIR"],
    split="train",
    lvl=LVL,
    #sequence_len=8192,
    )

dataset[0].shape

# 1 min of audio
# lvl2:  20671
# lvl1:  82687 | x4
# lvl0: 330750 | x4

torch.Size([330750, 64])

In [68]:
# only reload if variable 'vae' is not defined
try:
    vae
except NameError:
    vae_path = os.environ["JUKEBOX_VQVAE_PATH"]
    config = JukeboxVQVAEConfig.from_pretrained("openai/jukebox-1b-lyrics")
    vae = JukeboxVQVAE(config)
    vae.load_state_dict(torch.load(vae_path, map_location="cpu"))
    vae.eval().to(device)
    print("Loaded!")

In [69]:
@torch.no_grad()
def decode(embeddings):
    embeddings = embeddings.to(device)
    if embeddings.dim() == 2:
        embeddings = embeddings.unsqueeze(0)
    embeddings = rearrange(embeddings, "b t c -> b c t")
    # Use only lowest level
    decoder = vae.decoders[LVL]
    de_quantised_state = decoder([embeddings], all_levels=False)
    de_quantised_state = de_quantised_state.permute(0, 2, 1)
    return de_quantised_state

In [70]:
import IPython.display as ipd
def play_audio(audio):
    for a in audio:
        ipd.display(ipd.Audio(a.cpu().numpy().flatten(), rate=44100))

In [74]:
x = dataset[8]

In [75]:
play_audio(decode(x))

In [73]:
dataset.file_paths[7]

'2004/MIDI-Unprocessed_XP_21_R1_2004_03_ORIG_MID--AUDIO_21_R1_2004_04_Track04_wav.part022-of-033.jukebox.lvl0.pt'