# Colab Jukemir: quickly extract representations with OpenAI's Jukebox

While the [`jukemir` GitHub repo](https://github.com/p-lambda/jukemir) prioritizes exactly replicating results from our [ISMIR 2021 paper](https://arxiv.org/abs/2107.05677), this notebook is a user-friendly companion that enables fast experimentation by making extracting representations from [Jukebox](https://arxiv.org/abs/2005.00341) as easy and painless as possible. It'll take about an hour to get the model weights downloaded, but once they're saved you'll be able to extract representations in seconds.

We showcase:
- extracting mean-pooled intermediate representations from 24 seconds of a music clip
    - [useful for clip-level tasks such as genre classification or tagging](https://arxiv.org/abs/2107.05677)
- extracting frame-wise representations from later layers from 24 seconds of a music clip
    - [useful for time-varying tasks such as transcription](https://archives.ismir.net/ismir2021/latebreaking/000049.pdf)
- extracting representations from clips shorter than 24 seconds at speeds commensurate with clip length
    - useful for short loops

In general, we enable you to extract representations from the layers of your choice, with the pooling method of your choice, from an audio file of your choice with a simple, clean API! And, thanks to some additional optimizations, we now enable you to extract representations from the **full model** with only 12GB RAM and 12GB VRAM! Previously, you either needed a 24GB VRAM GPU or to perform two separate forward passes and cache intermediate representations.

A disclaimer:
- We use 24 seconds because Jukebox's context window is 1048576 / 44100 ≈ 23.8 seconds, so you can only extract up to that length at once. For clips longer than 24 seconds, feel free to extract from consecutive windows and concatenate the results.
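
As a rough sketch of what covering a longer clip could look like, the snippet below computes the context window length and the `(offset, duration)` pairs for consecutive, non-overlapping windows. The helper name `window_offsets` and the non-overlapping strategy are assumptions made for illustration; they are not part of `jukemir`.

```python
JUKEBOX_SAMPLE_RATE = 44100
CONTEXT_SAMPLES = 1048576  # Jukebox's context window, in audio samples

# Context window length in seconds
CONTEXT_SECONDS = CONTEXT_SAMPLES / JUKEBOX_SAMPLE_RATE


def window_offsets(total_seconds, window_seconds=CONTEXT_SECONDS):
    """Hypothetical helper: (offset, duration) pairs that tile a clip."""
    windows = []
    offset = 0.0
    while offset < total_seconds:
        # The last window may be shorter than a full context window
        duration = min(window_seconds, total_seconds - offset)
        windows.append((offset, duration))
        offset += window_seconds
    return windows


windows = window_offsets(60.0)
for offset, duration in windows:
    print(f"offset={offset:.2f}s duration={duration:.2f}s")
```

Each pair could then be passed as `offset` and `duration` to `load_audio_from_file` (defined later in this notebook), with the resulting frame-wise representations concatenated along the time axis. Note that each window would be processed independently, so frames near the start of a window would not see the audio that precedes it.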

Last tested October 5th, 2022.

## Environment setup

Set up the environment:
- get the Jukebox codebase
- get our demo song
- install a couple packages
- patch a file in the Jukebox codebase, which allows us to monitor download progress

In [None]:
#@title Installation

# jukebox codebase
!pip install git+https://github.com/openai/jukebox.git

# youtube-dl
!sudo curl -L https://yt-dl.org/downloads/latest/youtube-dl -o /usr/local/bin/youtube-dl
!sudo chmod a+rx /usr/local/bin/youtube-dl

# rickroll
!youtube-dl -x --audio-format mp3 https://www.youtube.com/watch?v=dQw4w9WgXcQ

!pip install wget accelerate

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/openai/jukebox.git
  Cloning https://github.com/openai/jukebox.git to /tmp/pip-req-build-jgp9ba87
  Running command git clone -q https://github.com/openai/jukebox.git /tmp/pip-req-build-jgp9ba87
Collecting fire==0.1.3
  Downloading fire-0.1.3.tar.gz (33 kB)
Collecting tqdm==4.45.0
  Downloading tqdm-4.45.0-py2.py3-none-any.whl (60 kB)
     |████████████████████████████████| 60 kB 9.8 MB/s 
Collecting unidecode==1.1.1
  Downloading Unidecode-1.1.1-py2.py3-none-any.whl (238 kB)
     |████████████████████████████████| 238 kB 58.7 MB/s 
Collecting numba==0.48.0
  Downloading numba-0.48.0-1-cp37-cp37m-manylinux2014_x86_64.whl (3.5 MB)
     |████████████████████████████████| 3.5 MB 56.4 MB/s 
Collecting librosa==0.7.2
  Downloading librosa-0.7.2.tar.gz (1.6 MB)
     |████████████████████████████████| 1.6 MB 48.8 MB/s 
Collecting mp

In [None]:
#@title Create patch file for make_models.py
%%writefile make_models.py.patched
--- make_models.py	2022-09-11 11:45:47.000000000 -0400
+++ make_models.py.patched	2022-09-11 12:06:46.000000000 -0400
@@ -14,6 +14,9 @@
 from jukebox.vqvae.vqvae import calculate_strides
 import fire

+import wget
+import sys
+
 MODELS = {
     '5b': ("vqvae", "upsampler_level_0", "upsampler_level_1", "prior_5b"),
     '5b_lyrics': ("vqvae", "upsampler_level_0", "upsampler_level_1", "prior_5b_lyrics"),
@@ -31,7 +34,15 @@
             if not os.path.exists(os.path.dirname(local_path)):
                 os.makedirs(os.path.dirname(local_path))
             if not os.path.exists(local_path):
-                download(remote_path, local_path)
+                # create this bar_progress method which is invoked automatically from wget
+                def bar_progress(current, total, width=80):
+                  progress_message = "Downloading: %d%% [%d / %d] bytes" % (current / total * 100, current, total)
+                  # Don't use print() as it will print in new line every time.
+                  sys.stdout.write("\r" + progress_message)
+                  sys.stdout.flush()
+
+                wget.download(remote_path, local_path, bar=bar_progress)
+                #download(remote_path, local_path)
         restore = local_path
     dist.barrier()
     checkpoint = t.load(restore, map_location=t.device('cpu'))
@@ -58,7 +69,7 @@
         #     if checkpoint_hps.get(k, None) != hps.get(k, None):
         #         print(k, "Checkpoint:", checkpoint_hps.get(k, None), "Ours:", hps.get(k, None))
         checkpoint['model'] = {k[7:] if k[:7] == 'module.' else k: v for k, v in checkpoint['model'].items()}
-        model.load_state_dict(checkpoint['model'])
+        model.load_state_dict(checkpoint['model'], strict=False)
         if 'step' in checkpoint: model.step = checkpoint['step']

 def restore_opt(opt, shd, checkpoint_path):

Writing make_models.py.patched


In [None]:
#@title Patch it
!patch /usr/local/lib/python3.7/dist-packages/jukebox/make_models.py make_models.py.patched

patching file /usr/local/lib/python3.7/dist-packages/jukebox/make_models.py
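
To make the effect of the patch easier to see, here is a standalone sketch of the progress callback it installs. The message format is copied from the patch above; splitting the formatting into a separate `format_progress` function is an assumption made here so the message can be inspected without starting a download.

```python
import sys


def format_progress(current, total):
    """Build the same style of message as the patched `bar_progress` callback."""
    percent = current / total * 100
    return "Downloading: %d%% [%d / %d] bytes" % (percent, current, total)


def bar_progress(current, total, width=80):
    # "\r" returns to the start of the line so each update overwrites the last
    sys.stdout.write("\r" + format_progress(current, total))
    sys.stdout.flush()


# Simulate a few callbacks for a hypothetical 1000-byte download
for done in (250, 500, 1000):
    bar_progress(done, 1000)
print()
```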


## Model setup

Here we'll set up the VQ-VAE and LM prior models, whose weights will be downloaded from Azure. Note that it'll take a long time (on the order of an hour) to download the LM prior; the weights are ~10GB. To avoid having to download them every time you run this notebook, we allow you to cache the weights to a Google Drive folder of your choice.

Thanks to the [`accelerate` library from Hugging Face](https://huggingface.co/docs/accelerate/index) and some engineering work, we've performed significant memory optimizations on top of `jukemir` that now enable you to load in and perform inference with the *entire* Jukebox model in a Colab free instance, which includes only 12GB RAM and 12GB VRAM; previously, we were only able to fit 36 layers under these tight memory constraints. This optimization makes the full model much more accessible to the community.

The second code snippet sets up the extraction code, which bears similarities to [this file on our GitHub](https://github.com/p-lambda/jukemir/blob/main/representations/jukebox/main.py).
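
The memory saving in the setup cell below comes from building the model without real weights and then filling in one tensor at a time, keyed by the dotted names in the checkpoint. The sketch below isolates just the name-resolution step in plain Python. The `Module` class is a tiny stand-in for `torch.nn.Module`, and the helper `resolve_owner` and the key `prior.x_emb.weight` are hypothetical, chosen only to illustrate the idea.

```python
class Module:
    """Tiny stand-in for `torch.nn.Module`, used only for this illustration."""

    def __init__(self, **children):
        for name, child in children.items():
            setattr(self, name, child)


def resolve_owner(module, tensor_name):
    """Walk a dotted name and return (owning module, leaf attribute name)."""
    *path, leaf = tensor_name.split(".")
    for part in path:
        child = getattr(module, part, None)
        if child is None:
            raise ValueError(f"{module} has no attribute {part}.")
        module = child
    return module, leaf


# Hypothetical nesting that mimics a checkpoint key
x_emb = Module(weight="placeholder for a meta tensor")
model = Module(prior=Module(x_emb=x_emb))

owner, leaf = resolve_owner(model, "prior.x_emb.weight")
setattr(owner, leaf, "real weights")  # swap in the loaded value, one tensor at a time
print(leaf, "->", model.prior.x_emb.weight)
```

Because only one tensor is materialized at a time, the process never needs to hold a second full copy of the model in memory alongside the loaded checkpoint.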

In [None]:
#@title Set up hyperparameters + download model weights
gdrive_cache_dir = "drive/Shareddrives/Jukemir" #@param {type:"string"}
cache_gdrive = True #@param {type:"boolean"}

import os
import shutil
from pathlib import Path

VQVAE_CACHE_PATH = '/root/.cache/jukebox/models/5b/vqvae.pth.tar'
PRIOR_CACHE_PATH = '/root/.cache/jukebox/models/5b/prior_level_2.pth.tar'

if cache_gdrive:
    from google.colab import drive
    drive.mount('drive')

    VQVAE_CACHE_PATH = gdrive_cache_dir + '/vqvae.pth.tar'
    PRIOR_CACHE_PATH = gdrive_cache_dir + '/prior_level_2.pth.tar'

    os.makedirs(gdrive_cache_dir, exist_ok=True)
else:
    os.makedirs(Path(VQVAE_CACHE_PATH).parent, exist_ok=True)

# imports and set up Jukebox's multi-GPU parallelization
import jukebox
from jukebox.hparams import Hyperparams, setup_hparams
from jukebox.make_models import MODELS, make_prior, make_vqvae
from jukebox.utils.dist_utils import setup_dist_from_mpi
from tqdm import tqdm

from accelerate import init_empty_weights

# Set up MPI
rank, local_rank, device = setup_dist_from_mpi()

# Set up VQVAE
model = "5b"  # or "1b_lyrics"
hps = Hyperparams()
hps.sr = 44100
hps.n_samples = 3 if model == "5b_lyrics" else 8
hps.name = "samples"
chunk_size = 16 if model == "5b_lyrics" else 32
max_batch_size = 3 if model == "5b_lyrics" else 16
hps.levels = 3
hps.hop_fraction = [0.5, 0.5, 0.125]
vqvae, *priors = MODELS[model]

hparams = setup_hparams(vqvae, dict(sample_length=1048576))

if cache_gdrive:
    hparams.restore_vqvae = VQVAE_CACHE_PATH

# don't actually load any weights in yet,
# leave it for later. memory optimization
with init_empty_weights():
    vqvae = make_vqvae(
        hparams, 'meta'#device
    )

# Set up language model
hparams = setup_hparams(priors[-1], dict())

# IMPORTANT LINE: only include layers UP TO prior_depth
#hparams["prior_depth"] = 72

if cache_gdrive:
    hparams.restore_prior = PRIOR_CACHE_PATH

# don't actually load any weights in yet,
# leave it for later. memory optimization
with init_empty_weights():
    top_prior = make_prior(hparams, vqvae, 'meta')#device)

# flips a bit that tells the model to return activations
# instead of projecting to tokens and getting loss for
# forward pass
top_prior.prior.only_encode = True

##############################################
# actually loading in the model weights now! #
##############################################

import torch
from tqdm import tqdm
import torch.nn as nn

top_prior_weights = torch.load(PRIOR_CACHE_PATH, map_location='cpu')

def set_module_tensor_to_device(
    module: nn.Module, tensor_name: str, device, value=None
):
    """
    A helper function to set a given tensor (parameter or buffer) of a module on a specific device (note that doing
    `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function).
    Args:
        module (`torch.nn.Module`): The module in which the tensor we want to move lives.
        tensor_name (`str`): The full name of the parameter/buffer.
        device (`int`, `str` or `torch.device`): The device on which to set the tensor.
        value (`torch.Tensor`, *optional*): The value of the tensor (useful when going from the meta device to any
            other device).
    """
    # Recurse if needed
    if "." in tensor_name:
        splits = tensor_name.split(".")
        for split in splits[:-1]:
            new_module = getattr(module, split)
            if new_module is None:
                raise ValueError(f"{module} has no attribute {split}.")
            module = new_module
        tensor_name = splits[-1]

    if tensor_name not in module._parameters and tensor_name not in module._buffers:
        raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
    is_buffer = tensor_name in module._buffers
    old_value = getattr(module, tensor_name)

    if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
        raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put it on {device}.")

    with torch.no_grad():
        if value is None:
            new_value = old_value.to(device)
        elif isinstance(value, torch.Tensor):
            new_value = value.to(device)
        else:
            new_value = torch.tensor(value, device=device)

        if is_buffer:
            module._buffers[tensor_name] = new_value
        elif value is not None or torch.device(device) != module._parameters[tensor_name].device:
            param_cls = type(module._parameters[tensor_name])
            kwargs = module._parameters[tensor_name].__dict__
            new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(device)
            module._parameters[tensor_name] = new_value

# load_state_dict, basically
for k in tqdm(top_prior_weights['model'].keys()):
    set_module_tensor_to_device(top_prior, k, 'cuda', value=top_prior_weights['model'][k])

del top_prior_weights

import gc
gc.collect()

vqvae_weights = torch.load(VQVAE_CACHE_PATH, map_location='cpu')

for k in tqdm(vqvae_weights['model'].keys()):
    set_module_tensor_to_device(vqvae, k, 'cuda', value=vqvae_weights['model'][k])

Mounted at drive
Using cuda True
Restored from drive/Shareddrives/Jukemir/vqvae.pth.tar
0: Loading vqvae in eval mode
Loading artist IDs from /usr/local/lib/python3.7/dist-packages/jukebox/data/ids/v2_artist_ids.txt
Loading artist IDs from /usr/local/lib/python3.7/dist-packages/jukebox/data/ids/v2_genre_ids.txt
Level:2, Cond downsample:None, Raw to tokens:128, Sample length:1048576
0: Converting to fp16 params
Restored from drive/Shareddrives/Jukemir/prior_level_2.pth.tar
0: Loading prior in eval mode


100%|██████████| 872/872 [00:02<00:00, 317.92it/s]
100%|██████████| 669/669 [00:00<00:00, 16270.28it/s]


In [None]:
#@title Jukebox extraction code

###########################
# Jukebox extraction code #
###########################

# Note: this code was written by reverse-engineering the model, which entailed
# combing through https://github.com/openai/jukebox all the way down the stack
# trace together with the readily-executable Colab example https://colab.research.google.com/github/openai/jukebox/blob/master/jukebox/Interacting_with_Jukebox.ipynb
# and modifying values as necessary to get what we needed.

import librosa as lr
import torch
import torch as t
import gc
import numpy as np

JUKEBOX_SAMPLE_RATE = 44100
T = 8192

# 1048576 found in paper, last page
DEFAULT_DURATION = 1048576 / JUKEBOX_SAMPLE_RATE

VQVAE_RATE = T / DEFAULT_DURATION

def empty_cache():
    torch.cuda.empty_cache()
    gc.collect()

def load_audio_from_file(fpath, offset=0.0, duration=None):
    if duration is not None:
        audio, _ = lr.load(fpath,
                           sr=JUKEBOX_SAMPLE_RATE,
                           offset=offset,
                           duration=duration)
    else:
        audio, _ = lr.load(fpath,
                           sr=JUKEBOX_SAMPLE_RATE,
                           offset=offset)

    if audio.ndim == 1:
        audio = audio[np.newaxis]
    audio = audio.mean(axis=0)

    # normalize audio
    norm_factor = np.abs(audio).max()
    if norm_factor > 0:
        audio /= norm_factor

    return audio.flatten()


def get_z(audio):
    # don't compute unnecessary discrete encodings
    audio = audio[: JUKEBOX_SAMPLE_RATE * 25]

    zs = vqvae.encode(torch.cuda.FloatTensor(audio[np.newaxis, :, np.newaxis]))

    z = zs[-1].flatten()[np.newaxis, :]

    return z


def get_cond(hps, top_prior):
    # model only accepts sample length conditioning of
    # >60 seconds
    sample_length_in_seconds = 62

    hps.sample_length = (
        int(sample_length_in_seconds * hps.sr) // top_prior.raw_to_tokens
    ) * top_prior.raw_to_tokens

    # NOTE: the 'lyrics' parameter is required, which is why it is included,
    # but it doesn't actually change anything about the `x_cond`, `y_cond`,
    # nor the `prime` variables. The `prime` variable is supposed to represent
    # the lyrics, but the LM prior we're using does not condition on lyrics,
    # so it's just an empty tensor.
    metas = [
        dict(
            artist="unknown",
            genre="unknown",
            total_length=hps.sample_length,
            offset=0,
            lyrics="""lyrics go here!!!""",
        ),
    ] * hps.n_samples

    labels = [None, None, top_prior.labeller.get_batch_labels(metas, "cuda")]

    x_cond, y_cond, prime = top_prior.get_cond(None, top_prior.get_y(labels[-1], 0))

    x_cond = x_cond[0, :T][np.newaxis, ...]
    y_cond = y_cond[0][np.newaxis, ...]

    return x_cond, y_cond

def downsample(representation,
               target_rate=30,
               method=None):
    if method is None:
        method = 'librosa_fft'

    if method == 'librosa_kaiser':
        resampled_reps = lr.resample(np.asfortranarray(representation.T),
                                     T / DEFAULT_DURATION,
                                     target_rate).T
    elif method in ['librosa_fft', 'librosa_scipy']:
        resampled_reps = lr.resample(np.asfortranarray(representation.T),
                                     T / DEFAULT_DURATION,
                                     target_rate,
                                     res_type='fft').T
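    elif method == 'mean':
        raise NotImplementedError
    else:
        # fail loudly instead of hitting an UnboundLocalError at the return below
        raise ValueError(f"Unknown downsampling method: {method}")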

    return resampled_reps

def get_final_activations(z, x_cond, y_cond, top_prior):

    x = z[:, :T]

    input_length = x.shape[1]

    if x.shape[1] < T:
        # arbitrary choices
        min_token = 0
        max_token = 100

        x = torch.cat((x,
                       torch.randint(min_token, max_token, size=(1, T - input_length,), device='cuda')),
                      dim=-1)

    # encoder_kv and fp16 are set to the defaults, but explicitly so
    out = top_prior.prior.forward(
        x, x_cond=x_cond, y_cond=y_cond, encoder_kv=None, fp16=False
    )

    # chop off, in case input was already chopped
    out = out[:,:input_length]

    return out

def roll(x, n):
    return t.cat((x[:, -n:], x[:, :-n]), dim=1)

def get_activations_custom(x,
                           x_cond,
                           y_cond,
                           layers_to_extract=None,
                           fp16=False,
                           fp16_out=False):

    # this function is adapted from:
    # https://github.com/openai/jukebox/blob/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/jukebox/prior/autoregressive.py#L116

    # custom jukemir stuff
    if layers_to_extract is None:
        layers_to_extract = [36]

    x = x[:,:T]  # limit to max context window of Jukebox

    input_seq_length = x.shape[1]

    # chop x_cond if input is short
    x_cond = x_cond[:, :input_seq_length]

    # Preprocess.
    with t.no_grad():
        x = top_prior.prior.preprocess(x)

    N, D = x.shape
    assert isinstance(x, t.cuda.LongTensor)
    assert (0 <= x).all() and (x < top_prior.prior.bins).all()

    if top_prior.prior.y_cond:
        assert y_cond is not None
        assert y_cond.shape == (N, 1, top_prior.prior.width)
    else:
        assert y_cond is None

    if top_prior.prior.x_cond:
        assert x_cond is not None
        assert x_cond.shape == (N, D, top_prior.prior.width) or x_cond.shape == (N, 1, top_prior.prior.width), f"{x_cond.shape} != {(N, D, top_prior.prior.width)} nor {(N, 1, top_prior.prior.width)}. Did you pass the correct --sample_length?"
    else:
        assert x_cond is None
        x_cond = t.zeros((N, 1, top_prior.prior.width), device=x.device, dtype=t.float)

    x_t = x # Target
    # self.x_emb is just a straightforward embedding, no trickery here
    x = top_prior.prior.x_emb(x) # X emb
    # this is to be able to fit in a start token/conditioning info: just shift to the right by 1
    x = roll(x, 1) # Shift by 1, and fill in start token
    # self.y_cond == True always, so we just use y_cond here
    if top_prior.prior.y_cond:
        x[:,0] = y_cond.view(N, top_prior.prior.width)
    else:
        x[:,0] = top_prior.prior.start_token

    # for some reason, p=0.0, so the dropout stuff does absolutely nothing
    x = top_prior.prior.x_emb_dropout(x) + top_prior.prior.pos_emb_dropout(top_prior.prior.pos_emb())[:input_seq_length] + x_cond # Pos emb and dropout

    layers = top_prior.prior.transformer._attn_mods

    reps = {}

    if fp16:
        x = x.half()

    for i, l in enumerate(layers):
        # to be able to take in shorter clips, we set sample to True,
        # but as a consequence the forward function becomes stateful
        # and its state changes when we apply a layer (attention layer
        # stores k/v's to cache), so we need to clear its cache religiously
        l.attn.del_cache()

        x = l(x, encoder_kv=None, sample=True)

        l.attn.del_cache()

        if i + 1 in layers_to_extract:
            reps[i + 1] = np.array(x.squeeze().cpu())

            # break once the deepest requested layer has been extracted
            # (this also works when layers_to_extract is not sorted)
            if i + 1 >= max(layers_to_extract):
                break

    return reps


# important, gradient info takes up too much space,
# causes CUDA OOM
@torch.no_grad()
def get_acts_from_audio(audio=None,
                        fpath=None,
                        meanpool=False,
                        # pick which layer(s) to extract from
                        layers=None,
                        # pick which part of the clip to load in
                        offset=0.0,
                        duration=None,
                        # downsampling frame-wise reps
                        downsample_target_rate=None,
                        downsample_method=None,
                        # for speed-saving
                        fp16=False,
                        # for space-saving
                        fp16_out=False,
                        # for GPU VRAM. potentially slows it
                        # down but we clean up garbage VRAM.
                        # disable if your GPU has a lot of memory
                        # or if you're extracting from earlier
                        # layers.
                        force_empty_cache=True):

    # main function that runs extraction end-to-end.

    if layers is None:
        layers = [36]  # by default

    if audio is None:
        assert fpath is not None
        audio = load_audio_from_file(fpath, offset=offset, duration=duration)
    elif fpath is None:
        assert audio is not None

    if force_empty_cache: empty_cache()

    # run vq-vae on the audio to get discretized audio
    z = get_z(audio)

    if force_empty_cache: empty_cache()

    # get conditioning info
    x_cond, y_cond = get_cond(hps, top_prior)

    if force_empty_cache: empty_cache()

    # get the activations from the LM
    acts = get_activations_custom(z,
                                  x_cond,
                                  y_cond,
                                  layers_to_extract=layers,
                                  fp16=fp16,
                                  fp16_out=fp16_out)

    if force_empty_cache: empty_cache()

    # postprocessing
    if downsample_target_rate is not None:
        for num in acts.keys():
            acts[num] = downsample(acts[num],
                                   target_rate=downsample_target_rate,
                                   method=downsample_method)

    if meanpool:
        acts = {num: act.mean(axis=0) for num, act in acts.items()}

    if not fp16_out:
        acts = {num: act.astype(np.float32) for num, act in acts.items()}

    return acts


## Extract!

Now, extract 24 seconds of representations from a single audio file, offset by a particular duration in seconds of your choosing.

In [None]:
#@title Pick file ("Cancel upload" to use default song)
offset = 0 #@param {type:"number"}

from IPython.display import Audio

from google.colab import files

fnames = files.upload()

if len(fnames.keys()) == 0:
    fname = 'Rick Astley - Never Gonna Give You Up (Official Music Video)-dQw4w9WgXcQ.mp3'
else:
    fname = list(fnames.keys())[0]

audio, sr = lr.load(fname,
                    sr=None,
                    offset=offset,
                    duration=25)

Audio(data=audio, rate=sr)



### Extract mean-pooled intermediate representations

Here, we'll demonstrate extracting mean-pooled intermediate representations from our example clip. We'll be extracting from *layer 36*, which is exactly in the center. It'll only take a couple of lines of code!

First, we get the audio from the given file. This will apply the same preprocessing as we applied to obtain the results in our paper (resample to `JUKEBOX_SAMPLE_RATE`, mean across channels, then normalize).
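
For reference, the channel-averaging and normalization steps can be sketched on their own with NumPy. This mirrors the logic of `load_audio_from_file` but leaves out file loading and resampling (which the notebook delegates to librosa). The function name `preprocess` and the synthetic input are assumptions for illustration.

```python
import numpy as np


def preprocess(audio):
    """Downmix to mono and peak-normalize, mirroring `load_audio_from_file`."""
    audio = np.asarray(audio, dtype=np.float32)
    if audio.ndim == 1:
        audio = audio[np.newaxis]  # treat mono input as a single channel
    audio = audio.mean(axis=0)  # mean across channels

    norm_factor = np.abs(audio).max()
    if norm_factor > 0:  # leave silent clips untouched
        audio = audio / norm_factor
    return audio


# Synthetic two-channel signal standing in for decoded audio
stereo = np.array([[0.2, -0.4, 0.1], [0.0, -0.2, 0.3]])
mono = preprocess(stereo)
print(mono)
```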

In [None]:
audio = load_audio_from_file(fname, offset=0.0, duration=25)



And now we extract, passing in the audio, desired layers, and `meanpool` flag to indicate that we would like the representations meanpooled across time.

In [None]:
representations = get_acts_from_audio(audio=audio,
                                      layers=[36],
                                      meanpool=True)

print(f"Got representations {representations}")
print(f"Its shape is {representations[36].shape}")

Got representations {36: array([ 1.1042459 ,  0.5964859 , -0.06048424, ..., -0.4644476 ,
        0.37787658,  0.30803093], dtype=float32)}
Its shape is (4800,)


As we can see, we've got a dictionary that maps each layer number to the activations extracted from that layer, and the mean-pooled representation has size 4800.
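
To make the pooling step concrete, here is a minimal sketch of what mean pooling does to a frame-wise array, and how pooled vectors from several clips could be stacked into a feature matrix for a clip-level classifier. The array sizes are deliberately tiny stand-ins (the real feature dimension is 4800), and the random data is purely illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for frame-wise activations: (time frames, feature dimension)
frames, dim = 16, 8
framewise = {36: rng.standard_normal((frames, dim)).astype(np.float32)}

# Mean pooling collapses the time axis, leaving one vector per layer
pooled = {layer: acts.mean(axis=0) for layer, acts in framewise.items()}
print(pooled[36].shape)

# Stacking pooled vectors from several clips gives a (clips, dim) feature matrix
features = np.stack([pooled[36], pooled[36]])
print(features.shape)
```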

### Extracting frame-wise representations from a later layer

Notice that the API exposes a `layers` parameter, which allows us to extract multiple layers at once. Let's extract frame-wise representations from layers 54, 58, and 62 simultaneously...

In [None]:
audio = load_audio_from_file(fname, offset=0.0, duration=25)

representations = get_acts_from_audio(audio=audio,
                                      layers=[54, 58, 62],
                                      meanpool=False)

print(f"Got representations {representations}")
print(f"The shape of each array is {[rep.shape for rep in representations.values()]}")



Got representations {54: array([[-6.5488410e-01, -1.9020307e+00,  1.1130141e-01, ...,
        -5.2601069e-01, -2.4407668e+00, -2.1879044e+00],
       [ 9.8869938e-01,  2.7309281e-01,  1.7522010e-01, ...,
        -3.2298431e-01,  1.2508178e+00,  1.9323468e+00],
       [ 1.1082835e+00, -3.6255893e-01, -1.1494322e+00, ...,
        -4.2663392e-01,  1.5655483e+00,  2.4646497e-01],
       ...,
       [ 1.7140074e+00, -2.1496266e-03,  1.8859807e+00, ...,
        -3.2228312e-01,  2.0356667e+00,  2.2030625e+00],
       [ 6.5384918e-01,  1.6972164e+00,  1.7530643e+00, ...,
        -3.7936896e-02, -3.4884918e-01,  7.2903872e-02],
       [ 2.2959871e+00,  1.3109661e+00,  3.6313772e+00, ...,
        -2.1705849e+00,  4.6568050e+00, -1.0985476e+00]], dtype=float32), 58: array([[-0.3483994 , -0.92356634,  0.8723887 , ..., -0.46199495,
        -2.8107705 , -0.91144466],
       [ 0.39271677, -0.67747015, -0.01698004, ...,  0.548148  ,
         1.5664817 ,  2.2167678 ],
       [ 0.82011145, -0.8566189 , 

If you're using these representations for a downstream task whose prediction frames are set at a different rate than Jukebox's native rate of roughly 345 frames per second (8192 frames over ~23.8 seconds), you'll need to resample them. Fortunately, we've added a couple of extra arguments to handle this automatically. Let's extract, downsampling the representations to 30 fps.

In [None]:
audio = load_audio_from_file(fname, offset=0.0, duration=25)

representations = get_acts_from_audio(audio=audio,
                                      layers=[54, 58, 62],
                                      meanpool=False,
                                      downsample_target_rate=30)

print(f"Got representations {representations}")
print(f"The shape of each array is {[rep.shape for rep in representations.values()]}")



Got representations {54: array([[ 0.54728496, -0.2811083 ,  0.32937825, ...,  0.05592334,
         1.0386577 ,  0.3298235 ],
       [ 0.5496505 , -0.16039667, -0.92832434, ..., -0.45035437,
        -0.27714533, -0.38926223],
       [ 0.11733392, -0.53121537, -0.6412263 , ..., -0.12891561,
        -0.258517  , -0.21131144],
       ...,
       [ 1.421645  ,  0.2956466 , -0.5130327 , ...,  0.90778184,
         0.46138397, -1.2169237 ],
       [ 2.4711752 , -0.08333597, -0.7736691 , ...,  1.8281366 ,
        -0.06805634, -0.7634933 ],
       [ 2.028346  ,  1.2069563 ,  1.3009052 , ...,  1.6956192 ,
         0.43768117,  0.14164865]], dtype=float32), 58: array([[ 0.37905717, -0.50836253,  0.25710008, ...,  1.1788434 ,
         0.77161634,  0.60939085],
       [ 0.0557559 , -0.20562318,  0.02686153, ..., -1.1484089 ,
        -0.2374229 ,  0.06390017],
       [-0.35016868, -0.17340958, -0.36217585, ..., -0.17487833,
        -0.2589977 ,  0.09674562],
       ...,
       [-0.05398118,  0.081335

And now it's downsampled! (under the hood, this is just librosa resampling)
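
The arithmetic behind the resampling can be sketched as follows, using the same constants as the extraction code. The frame count is rounded up here as an assumption; the exact output length depends on how librosa rounds, so treat the number as approximate.

```python
import math

JUKEBOX_SAMPLE_RATE = 44100
T = 8192  # frames in a full context window
DEFAULT_DURATION = 1048576 / JUKEBOX_SAMPLE_RATE  # seconds

native_rate = T / DEFAULT_DURATION  # frames per second before resampling
target_rate = 30

# Approximate frame count after resampling a full window (rounding up is an assumption)
approx_frames = math.ceil(T * target_rate / native_rate)

print(f"native rate: {native_rate:.2f} fps")
print(f"approx. frames at {target_rate} fps: {approx_frames}")
```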

### Extracting representations from short clips

What if you want to extract representations from short clips or loops? We can extract from these without padding and incurring additional performance cost.

In [None]:
audio = load_audio_from_file(fname, offset=0.0, duration=5)
print(f"the audio clip is {audio.shape[0] / JUKEBOX_SAMPLE_RATE} seconds long")

representations = get_acts_from_audio(audio=audio,
                                      layers=[54, 58, 62],
                                      meanpool=False)

print(f"Got representations {representations}")
print(f"The shape of each array is {[rep.shape for rep in representations.values()]}")



the audio clip is 5.0 seconds long
Got representations {54: array([[-0.6548871 , -1.9020298 ,  0.1112991 , ..., -0.52601165,
        -2.4407587 , -2.1879015 ],
       [ 0.9886987 ,  0.27309182,  0.17522153, ..., -0.32298163,
         1.2508204 ,  1.9323473 ],
       [ 1.1082867 , -0.36255932, -1.1494346 , ..., -0.4266355 ,
         1.5655509 ,  0.24646433],
       ...,
       [ 0.246226  ,  3.1807413 , -0.5397909 , ...,  1.3804241 ,
         0.28671706,  3.5623958 ],
       [ 1.3964403 ,  1.0338751 , -0.8635645 , ...,  0.14509486,
         1.3929801 ,  0.9724273 ],
       [-0.09438097, -1.5011432 , -0.63438535, ...,  0.21416903,
         1.0999573 ,  1.2860197 ]], dtype=float32), 58: array([[-0.34840524, -0.92357194,  0.8723905 , ..., -0.46199757,
        -2.8107626 , -0.911443  ],
       [ 0.39271784, -0.6774707 , -0.01697882, ...,  0.54814947,
         1.5664833 ,  2.2167695 ],
       [ 0.82011515, -0.856619  , -0.10058692, ..., -0.22079834,
         1.0946267 ,  0.55781764],
       

The representations now have only 1722 frames instead of 8192! And we can verify that extraction was faster, too.

In [None]:
import time

start = time.time()
representations = get_acts_from_audio(audio=audio,
                                      layers=[54, 58, 62],
                                      meanpool=False)

end = time.time()

print(f"5-second clip took {end-start:.1f} seconds.")

padded_audio = np.concatenate((audio, np.zeros((1048576 - audio.shape[0],))))

start = time.time()
representations = get_acts_from_audio(audio=padded_audio,
                                      layers=[54, 58, 62],
                                      meanpool=False)

end = time.time()

print(f"24-second clip took {end-start:.1f} seconds.")

5-second clip took 4.9 seconds.
24-second clip took 22.3 seconds.
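
The frame counts mentioned above can be reproduced with a little arithmetic, assuming the encoder emits one top-level token per 128 audio samples (the "Raw to tokens:128" value in the setup log) and drops any remainder. The helper name `num_frames` is hypothetical.

```python
JUKEBOX_SAMPLE_RATE = 44100
RAW_TO_TOKENS = 128  # audio samples per top-level token, per the setup log
MAX_FRAMES = 8192  # frames in a full context window


def num_frames(num_samples):
    """Frames produced for a clip of `num_samples`, capped at the context window."""
    return min(num_samples // RAW_TO_TOKENS, MAX_FRAMES)


short = num_frames(5 * JUKEBOX_SAMPLE_RATE)  # 5-second clip
full = num_frames(1048576)  # full context window
print(short, full, f"{full / short:.2f}x more frames")
```

The ratio of frame counts is in the same ballpark as the ratio of the measured timings above (22.3 seconds versus 4.9 seconds), which is consistent with extraction time growing with clip length.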


As a disclaimer, you will get different (albeit valid) representations if you first chop the audio and then pass it through the model, rather than passing the full 24 seconds through and then chopping. These differences occur at the end of the chopped subclip, where the VQ-VAE encoder's input differs (it either includes audio past the endpoint or it does not).
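
A toy example can illustrate why the order matters. The `smooth` function below is a simple centered moving average standing in for any encoder that looks slightly ahead of the current position; it is not the actual VQ-VAE and is used only to show where the two orders disagree.

```python
def smooth(signal):
    """Toy stand-in for an encoder with look-ahead: centered 3-tap average."""
    padded = [0.0] + list(signal) + [0.0]  # zero padding at both ends
    return [sum(padded[i:i + 3]) / 3 for i in range(len(signal))]


audio = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
cut = 4

chop_then_encode = smooth(audio[:cut])
encode_then_chop = smooth(audio)[:cut]

# Positions where the two orders disagree
diff = [i for i in range(cut) if chop_then_encode[i] != encode_then_chop[i]]
print(diff)
```

In this toy setting only the final position differs, because it is the only one whose look-ahead reaches past the cut point.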