In [5]:
from pathlib import Path
from IPython.display import Audio
from transformers import EncodecModel

import numpy as np
import torch
from essentia.standard import MonoLoader, MonoWriter
from torch import Tensor
from encodecmae import load_model

device = "cuda:1"

In [2]:
mae = load_model("base", device=device)
mae.eval()



EncodecMAE(
  (wav_encoder): EncodecEncoder(
    (model): SEANetEncoder(
      (model): Sequential(
        (0): SConv1d(
          (conv): NormConv1d(
            (conv): Conv1d(1, 32, kernel_size=(7,), stride=(1,))
            (norm): Identity()
          )
        )
        (1): SEANetResnetBlock(
          (block): Sequential(
            (0): ELU(alpha=1.0)
            (1): SConv1d(
              (conv): NormConv1d(
                (conv): Conv1d(32, 16, kernel_size=(3,), stride=(1,))
                (norm): Identity()
              )
            )
            (2): ELU(alpha=1.0)
            (3): SConv1d(
              (conv): NormConv1d(
                (conv): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
                (norm): Identity()
              )
            )
          )
          (shortcut): SConv1d(
            (conv): NormConv1d(
              (conv): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
              (norm): Identity()
            )
          )
        )
 

In [3]:
encodec = EncodecModel.from_pretrained("facebook/encodec_24khz")
encodec.eval()

EncodecModel(
  (encoder): EncodecEncoder(
    (layers): ModuleList(
      (0): EncodecConv1d(
        (conv): Conv1d(1, 32, kernel_size=(7,), stride=(1,))
      )
      (1): EncodecResnetBlock(
        (block): ModuleList(
          (0): ELU(alpha=1.0)
          (1): EncodecConv1d(
            (conv): Conv1d(32, 16, kernel_size=(3,), stride=(1,))
          )
          (2): ELU(alpha=1.0)
          (3): EncodecConv1d(
            (conv): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
          )
        )
        (shortcut): EncodecConv1d(
          (conv): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
        )
      )
      (2): ELU(alpha=1.0)
      (3): EncodecConv1d(
        (conv): Conv1d(32, 64, kernel_size=(4,), stride=(2,))
      )
      (4): EncodecResnetBlock(
        (block): ModuleList(
          (0): ELU(alpha=1.0)
          (1): EncodecConv1d(
            (conv): Conv1d(64, 32, kernel_size=(3,), stride=(1,))
          )
          (2): ELU(alpha=1.0)
          (3): EncodecC

In [7]:
# work with 4 second chunks
chunk_time = 4
sr = 24000
chunk_size = int(chunk_time * sr)
hop_size = chunk_size

wav_file = Path("/mnt/mtgdb-audio/stable/genre_tzanetakis/audio/22kmono/blu/blues.00000.wav")
audio = MonoLoader(filename=str(wav_file), sampleRate=sr)()

In [8]:
i = 0
audio_t = torch.tensor(audio, device=device)
audio_t = audio_t.unsqueeze(0)

x = {'wav': audio_t[:, i:i+chunk_size], 'wav_lens': torch.tensor([audio_t[:, i:i+chunk_size].shape[1]], device=device)}

chunk = x["wav"].clone().detach().cpu().numpy().squeeze()

MonoWriter(sampleRate=sr, filename="input_chunk.wav")(chunk)
Audio(x["wav"].clone().detach().cpu().numpy().squeeze(), rate=sr)

[wav @ 0x12f85f80] Using AVStream.codec.time_base as a timebase hint to the muxer is deprecated. Set AVStream.time_base instead.
[wav @ 0x12f85f80] Encoder did not produce proper pts, making some up.


In [13]:
with torch.no_grad():
    mae.encode_wav(x)
    # Do not mask
    mae.mask(x, ignore_mask=True)
    # audio features to Encoded embeddings
    x['visible_embeddings'] = mae.visible_encoder(x['visible_tokens'], padding_mask=x['visible_padding_mask'])

    # get the codes logits
    mae.decode(x)
    mae.predict_tokens(x)

    # get the estimated codes
    y = torch.argmax(x["predicted_tokens"], dim=-1)

    # fix shape
    y = torch.permute(y,(0, 2, 1))
    y = y.unsqueeze(0)

    # use HuggingFace's Encocded to decode the codes.
    # TODO: Is this model exactly the same? 
    # TODO: Check if EncodecMAE finetunes the encoder
    decoded_audio = encodec.decode(y.cpu(), [None])[0].numpy().squeeze()

Audio(decoded_audio, rate=sr)
MonoWriter(filename="decoded_chunk.wav", sampleRate=sr)(decoded_audio)

[wav @ 0x3e287940] Using AVStream.codec.time_base as a timebase hint to the muxer is deprecated. Set AVStream.time_base instead.


In [None]:
for k, v in x.items():
    if type(v) == torch.Tensor:
        print(f"{k}, {v.shape}")

wav, torch.Size([1, 96000])
wav_lens, torch.Size([1])
wav_features, torch.Size([1, 300, 128])
projected_wav_features, torch.Size([1, 300, 768])
features_len, torch.Size([1])
feature_padding_mask, torch.Size([1, 300])
visible_tokens, torch.Size([1, 300, 768])
visible_mask, torch.Size([1, 300])
non_visible_mask, torch.Size([1, 300])
visible_padding_mask, torch.Size([1, 300])
visible_lens, torch.Size([1])
visible_embeddings, torch.Size([1, 300, 768])
decoder_in, torch.Size([1, 300, 768])
decoder_out, torch.Size([1, 300, 768])
predicted_tokens, torch.Size([1, 300, 8, 1024])
