In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are reloaded
# before each cell executes, so edits to the project sources are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# load all dependencies
import sys
import os
import torch
from einops import rearrange

sys.path.append('../')

from src.dataset.jukebox_dataset import JukeboxDataset
from src.model.jukebox_diffusion import JukeboxDiffusion
from src.module.diffusion_attn_unet_1d import DiffusionAttnUnet1D

from IPython.display import Audio

In [None]:
def play_audio(audio: torch.Tensor, sample_rate: int = 44100) -> None:
    """Render one inline audio player per entry along the first dimension of `audio`.

    Args:
        audio: Tensor of waveforms; it is iterated along its first dimension and
            each entry is flattened to 1-D before playback. Values outside
            [-1, 1] are clipped.
        sample_rate: Playback rate in Hz handed to `IPython.display.Audio`.
            Defaults to 44100, the value that was previously hard-coded.
    """
    # detach() first: Tensor.numpy() raises on tensors that require grad, so
    # without it every caller had to wrap playback in torch.no_grad().
    clipped = torch.clamp(audio.detach(), -1, 1).cpu().numpy()
    for clip in clipped:
        display(Audio(clip.flatten(), rate=sample_rate))

In [None]:
# The notebook is currently pinned to the CPU. Set FORCE_CPU to False to use a
# GPU whenever CUDA is available. `device` is always a torch.device (it used to
# be silently replaced by the plain string "cpu" right after auto-detection).
FORCE_CPU = True
use_cuda = torch.cuda.is_available() and not FORCE_CPU
device = torch.device('cuda' if use_cuda else 'cpu')
print(f'Using device: {device}')

In [None]:
# Training checkpoint to restore; the path is relative to the notebook's
# working directory.
ckpt_path = '../logs/train/runs/2022-12-13_23-55-32/checkpoints/last.ckpt'
# map_location places the stored tensors on `device` while loading.
# NOTE(review): torch.load unpickles the file, so only point this at
# checkpoints from a trusted source.
ckpt = torch.load(ckpt_path, map_location=device)

In [None]:

# Sequence model that JukeboxDiffusion wraps below. These hyperparameters have
# to match the ones used for training, because the checkpoint is loaded with
# the default strict key matching.
seqmodel = DiffusionAttnUnet1D(
    io_channels=64,
    n_attn_layers=6,
    channel_sizes=[128, 128, 128, 128,  256, 256, 256, 256,  512, 512, 512],
)

# Drop every checkpoint entry whose key starts with 'jukebox_vqvae'; only the
# remaining weights are restored into the model here.
state_dict = {k: v for k, v in ckpt['state_dict'].items() if not k.startswith('jukebox_vqvae')}

model = JukeboxDiffusion(model=seqmodel)
model.load_state_dict(state_dict=state_dict)
model = model.to(device)
# The notebook only runs inference: switch to eval mode so that any dropout or
# normalisation layers behave deterministically instead of as in training.
model.eval()
# Override the embedding level stored in the hyperparameters.
# NOTE(review): presumably the level used when decode() is called without an
# explicit `lvl` - confirm in JukeboxDiffusion.
model.hparams.jukebox_embedding_lvl = 2

In [None]:
# NOTE(review): presumably sets up the Jukebox VQ-VAE whose 'jukebox_vqvae'
# weights were filtered out of the restored state dict - confirm in
# JukeboxDiffusion.prepare_data.
model.prepare_data()

In [None]:
# Fetch one training example, requested at levels 2, 1 and 0, from the dataset
# rooted at $MAESTRO_DATASET_DIR. It serves as the reference input for the
# decoding checks in the following cells.
dataset = JukeboxDataset(
    root_dir=os.environ['MAESTRO_DATASET_DIR'],
    split='train',
    lvl=[2, 1, 0],
    sequence_len=4096*4,
    samples_per_file=2,
)
sample = dataset[0]
# Disabled single-tensor variant of the preprocessing check:
#test_audio = model.preprocess(dataset[0].unsqueeze(0).to(device))

In [None]:
# Inspect the shape of the entry stored under index 1 of the sample.
sample[1].shape

In [None]:
# Take the first item of each level from the sample.
# NOTE(review): the dataset was created with lvl=[2, 1, 0]; this indexing
# assumes sample[k] holds level k rather than the k-th requested level -
# confirm against JukeboxDataset.
high_lvl_sample = sample[2][0]
mid_lvl_sample = sample[1][0]
low_lvl_sample = sample[0][0]

# Decode each level back to audio. This is inference only, so run it without
# autograd: no graph is built and the results do not require grad.
with torch.no_grad():
    high_lvl_audio = model.decode(model.preprocess(high_lvl_sample.unsqueeze(0).to(device)), lvl=2)
    mid_lvl_audio = model.decode(model.preprocess(mid_lvl_sample.unsqueeze(0).to(device)), lvl=1)
    low_lvl_audio = model.decode(model.preprocess(low_lvl_sample.unsqueeze(0).to(device)), lvl=0)

# Backward compatibility: this cell has always left the decoded low-level audio
# in `low_lvl_sample`, and the playback cell reads it from that name.
low_lvl_sample = low_lvl_audio

In [None]:
# `test_audio` was only ever assigned by a line that is now commented out, so
# this cell raised NameError on a fresh kernel. Recreate it from the high-level
# sample. Despite its name it holds the output of model.preprocess (the input
# to decode), not decoded audio.
with torch.no_grad():
    test_audio = model.preprocess(high_lvl_sample.unsqueeze(0).to(device))
# Value range of the preprocessed real sample.
test_audio.min(), test_audio.max()

In [None]:
# Listen to the three reconstructions in turn. Note that `low_lvl_sample` holds
# decoded audio at this point: the decoding cell stores its result under that name.
with torch.no_grad():
    for decoded in (high_lvl_audio, mid_lvl_audio, low_lvl_sample):
        play_audio(decoded)

In [None]:
# Sampling settings for unconditional generation, kept in one place so they are
# easy to tweak.
generation_kwargs = dict(
    batch_size=4,
    seq_len=4096,
    num_inference_steps=100,
    seed=420,
)

# Sample embeddings from the diffusion model and decode them to audio, with
# autograd disabled.
# NOTE(review): decode() is called without `lvl` here - presumably it falls back
# to model.hparams.jukebox_embedding_lvl; confirm.
with torch.no_grad():
    embeddings = model.generate_unconditionally(**generation_kwargs)
    audio = model.decode(embeddings)

In [None]:
# Sanity check: value range of the generated embeddings.
embeddings.min(), embeddings.max()

In [None]:
# Listen to the decoded generations: play_audio renders one player per entry
# along the first dimension of `audio`.
play_audio(audio)