In [None]:
# Re-import edited modules (e.g. the ../src package imported below) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# load all dependencies
import sys
import os
import torch
from einops import rearrange
import numpy as np
import matplotlib.pyplot as plt

sys.path.append('../')

from src.datamodule.maestro_datamodule import MaestroDataModule
from src.model.jukebox_diffusion import JukeboxDiffusion
from src.model.jukebox_vqvae import JukeboxVQVAEModel
from src.module.diffusion_attn_unet_1d import DiffusionAttnUnet1D

from IPython.display import Audio

In [None]:
def play_audio(audio: torch.Tensor, num_samples: int = 1, sample_rate: int = 44100):
    """Render the first ``num_samples`` items of a batch as inline audio players.

    Args:
        audio: batch of waveforms indexed along dim 0. Each item is clamped to
            [-1, 1] and flattened to 1-D before playback (assumes mono items —
            TODO confirm channel layout for multi-channel audio).
        num_samples: how many leading batch items to play.
        sample_rate: playback rate in Hz.
    """
    # detach(): numpy() raises on tensors that still require grad, e.g. decoder
    # outputs computed outside torch.no_grad().
    clips = torch.clamp(audio[:num_samples].detach(), -1, 1).cpu().numpy()
    for clip in clips:
        display(Audio(clip.flatten(), rate=sample_rate))

# Load model

In [None]:
# CPU is forced for now; set FORCE_CPU = False to use CUDA when it is available.
FORCE_CPU = True
device = torch.device('cpu' if FORCE_CPU or not torch.cuda.is_available() else 'cuda')
print(f'Using device: {device}')

In [None]:
# Training checkpoint to inspect (its weights are read below via ckpt['state_dict']).
# Earlier run, kept for reference:
#ckpt_path = '../logs/train/runs/2022-12-13_23-55-32/checkpoints/last.ckpt'
ckpt_path = "../logs/train/runs/2023-02-07_15-34-32/checkpoints/last.ckpt"
# torch.load unpickles the file -- only load checkpoints you produced or trust.
# map_location remaps stored tensors onto the selected device.
ckpt = torch.load(ckpt_path, map_location=device)

In [None]:
# Denoising network; the hyperparameters must match those used for the checkpoint run.
seqmodel = DiffusionAttnUnet1D(
    io_channels=64,
    n_attn_layers=6,
    channel_sizes=[128, 128, 128, 128, 256, 256, 256, 256, 512, 512],
)

# Drop every checkpoint entry whose key contains 'vqvae': the wrapper is built with
# load_vqvae=False, and a VQ-VAE is attached separately afterwards (model.vqvae = ...).
state_dict = {k: v for k, v in ckpt['state_dict'].items() if 'vqvae' not in k}

model = JukeboxDiffusion(model=seqmodel, load_vqvae=False)
model.load_state_dict(state_dict=state_dict)
model = model.to(device)
# This notebook only runs inference: switch layers such as dropout to eval behaviour.
model.eval()
# NOTE(review): forces the diffusion model to work on level-1 embeddings, presumably
# overriding whatever level was stored with the run -- TODO confirm against training config.
model.hparams.jukebox_embedding_lvl = 1

In [None]:
# NOTE(review): prepare_data() is defined outside this notebook -- presumably it
# fetches/caches assets needed below; TODO confirm.
model.prepare_data()
# The VQ-VAE weights were filtered out of the checkpoint, so construct one here on the
# model's device and attach it to the wrapper.
model.vqvae = JukeboxVQVAEModel(device=model.device)

# Load dataset

In [None]:
# Load one validation batch of raw MAESTRO audio; it is encoded/decoded further below.
maestro_root = os.environ.get('MAESTRO_DATASET_DIR')
if not maestro_root:
    raise RuntimeError(
        'MAESTRO_DATASET_DIR is not set; point it at the root of the MAESTRO dataset.'
    )
datamodule = MaestroDataModule(
    root_dir=maestro_root,
    batch_size=8,
    num_workers=4,
    sample_length=131072,  # 2**17 samples, about 2.97 s at the 44.1 kHz used for playback
)
datamodule.setup()
dataloader = datamodule.val_dataloader()
audio = next(iter(dataloader)).to(device)

# Inspect dataset and VQ-VAE encode/decode round-trip

In [None]:
# Shape of a single raw-audio item (layout not visible here -- presumably (time, channels); TODO confirm).
audio[0].shape

In [None]:
# Encode the same batch at each level (2, 1, 0). This is inspection only, so autograd
# tracking is disabled to avoid building and holding computation graphs.
with torch.no_grad():
    audio_in = audio.to(model.device)
    sample_lvl2 = model.encode(audio_in, lvl=2)
    sample_lvl1 = model.encode(audio_in, lvl=1)
    sample_lvl0 = model.encode(audio_in, lvl=0)

In [None]:
# Decode each level's latents back to audio (round-trip reconstructions). Inference
# only, so skip autograd bookkeeping.
with torch.no_grad():
    audio_lvl2 = model.decode(sample_lvl2, lvl=2)
    audio_lvl1 = model.decode(sample_lvl1, lvl=1)
    audio_lvl0 = model.decode(sample_lvl0, lvl=0)

# Latent statistics per VQ-VAE level

In [None]:
import seaborn as sns
from einops import rearrange

def compute_sample_statistics(sample: torch.Tensor):
    """Return global summary statistics of ``sample`` as plain Python floats.

    The whole tensor is reduced (no per-axis breakdown). ``std`` uses PyTorch's
    default unbiased estimator. Keys are, in order: mean, std, min, max.
    """
    summary = {}
    for stat in ('mean', 'std', 'min', 'max'):
        # Each of these Tensor methods, called without arguments, reduces over all elements.
        summary[stat] = getattr(sample, stat)().item()
    return summary

def compute_stats_per_channel(sample: torch.Tensor):
    """Per-channel mean/std/min/max of a channels-last tensor.

    Args:
        sample: [B, T, C] (any number of leading dims is accepted; all of them are
            pooled and the last dim is treated as channels).

    Returns:
        dict with keys 'mean', 'std', 'min', 'max', each a list of C floats.
        ``std`` uses PyTorch's default unbiased estimator.
    """
    # Equivalent to rearrange(sample, 'b t c -> (b t) c') for 3-D input.
    flat = sample.detach().reshape(-1, sample.shape[-1])
    return {
        'mean': flat.mean(dim=0).tolist(),
        'std': flat.std(dim=0).tolist(),
        # Tensor.min/max(dim=...) return a (values, indices) pair; keep only the values
        # so every entry has the same list-of-floats form.
        'min': flat.min(dim=0).values.tolist(),
        'max': flat.max(dim=0).values.tolist(),
    }

def sample_histogram(sample: torch.Tensor, n_channels: int = 0, flatten=True):
    """Plot value histograms of a channels-last latent tensor.

    Args:
        sample: [B, T, C] (any number of leading dims is accepted; the last dim is
            treated as channels)
        n_channels: number of leading channels to include. 0, or any value larger
            than C, means all channels.
        flatten: if True, pool the selected channels into one histogram; if False,
            draw one subplot per channel (at most 8 per row) with a dashed red
            line at 0.
    """
    total_channels = sample.shape[-1]
    if n_channels == 0 or n_channels > total_channels:
        n_channels = total_channels
    # numpy() refuses CUDA tensors and tensors that require grad, so convert once here.
    values = sample.detach().cpu()
    if flatten:
        pooled = values.reshape(-1, total_channels)[:, :n_channels].flatten()
        plt.hist(pooled.numpy(), bins=100)
        plt.xlabel("Value", fontsize=16)
        plt.ylabel("Frequency", fontsize=16)
    else:
        ncols = min(n_channels, 8)
        nrows = (n_channels + ncols - 1) // ncols  # ceil division
        # squeeze=False keeps `axes` a 2-D array even when there is a single subplot.
        _, axes = plt.subplots(
            nrows=nrows, ncols=ncols, figsize=(10 * ncols, 6 * nrows), squeeze=False
        )
        for i, ax in enumerate(axes.flat):
            if i >= n_channels:
                # nrows * ncols exceeds n_channels when it is not a multiple of ncols;
                # hide the spare cells rather than plotting channels nobody asked for.
                ax.set_visible(False)
                continue
            ax.hist(values[..., i].numpy().flatten(), bins=100)
            ax.axvline(x=0, color='red', linestyle='--')
            ax.set_title(f"Channel {i}", fontsize=16)
            ax.set_xlabel("Value", fontsize=16)
            ax.set_ylabel("Frequency", fontsize=16)
    plt.show()

In [None]:
# Layout of the level-1 latents; the statistics/histogram helpers expect [B, T, C].
sample_lvl1.shape

### LVL2

In [None]:
# Global statistics of the level-2 latents.
compute_sample_statistics(sample_lvl2)

In [None]:
# One histogram per channel for the first 64 channels of the level-2 latents.
sample_histogram(sample_lvl2, n_channels=64, flatten=False)

### LVL1

In [None]:
# Global statistics of the level-1 latents.
compute_sample_statistics(sample_lvl1)

In [None]:
# Per-channel statistics of the level-1 latents (one value per channel for each statistic).
compute_stats_per_channel(sample_lvl1)

In [None]:
# One histogram per channel for the first 64 channels of the level-1 latents.
sample_histogram(sample_lvl1, n_channels=64, flatten=False)

### LVL0

In [None]:
# Global statistics of the level-0 latents.
compute_sample_statistics(sample_lvl0)

In [None]:
# One histogram per channel for the first 64 channels of the level-0 latents.
sample_histogram(sample_lvl0, n_channels=64, flatten=False)

# Play reconstructed audio

### LVL2

In [None]:
# Listen to the first 4 level-2 round-trip reconstructions.
play_audio(audio_lvl2, num_samples=4)

### LVL1

In [None]:
# Listen to the first 4 level-1 round-trip reconstructions.
play_audio(audio_lvl1, num_samples=4)

### LVL0

In [None]:
# Listen to the first 4 level-0 round-trip reconstructions.
play_audio(audio_lvl0, num_samples=4)

In [None]:
with torch.no_grad():
    # Unconditional sampling of 4 latent sequences with 100 denoising steps;
    # the fixed seed is presumably there for reproducible reruns.
    embeddings = model.generate_unconditionally(
        batch_size=4,
        seq_len=2048,
        num_inference_steps=100,
        seed=420,
    )
    

In [None]:
# Statistics of the generated embeddings after the same x8 rescaling that the next cell
# applies before decoding, so these numbers describe what the decoder actually receives.
compute_sample_statistics(embeddings * 8)

In [None]:
# Decode the generated embeddings at level 1 (the level set via jukebox_embedding_lvl).
# NOTE(review): the x8 factor is a magic constant -- presumably it undoes a scaling
# applied to the embeddings during training; TODO confirm and name it.
audio_gen = model.decode(embeddings*8, lvl=1)

In [None]:
# Play every generated clip, not just the first one in the batch.
play_audio(audio_gen, num_samples=audio_gen.shape[0])