# Audio Embedders

In [None]:
#| default_exp audio.embedding

In [None]:
#| hide
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import *

## EncoDec

In [None]:
#| export
from encodec import EncodecModel
from encodec.utils import convert_audio

import torchaudio
import torch

from lhotse.features import FeatureExtractor
from lhotse.utils import compute_num_frames, Seconds
from lhotse import CutSet, Fbank

from matplotlib import pyplot as plt
import IPython.display as ipd
import numpy as np

from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Pattern, Union
from plum import dispatch

from nimrod.audio.utils import plot_waveform
from nimrod.utils import get_device

In [None]:
#| export
class EncoDec():
    """Wrapper around the pretrained 24 kHz EnCodec model.

    Calling an instance encodes a waveform into discrete codes; `decode`
    turns codes back into a waveform. Inputs may be `torch.Tensor` or
    `np.ndarray` (selected through `plum` dispatch).
    """

    def __init__(self, device:str='cpu'):
        """Load the pretrained model, move it to `device` and set the target bandwidth to 6.0."""
        self.model = EncodecModel.encodec_model_24khz()
        self._device = device
        self.model.to(self._device)
        self.model.set_target_bandwidth(6.0)

    def _encode(self, wav:torch.Tensor, sr:int)->torch.Tensor:
        """Shared encoding path for both `__call__` overloads.

        Args:
            wav: waveform tensor laid out as (C x T).
            sr: sampling rate of `wav` in Hz.

        Returns:
            Codes laid out as (C x D x T_frames), located on the model device.
        """
        # Convert when either the rate or the channel count differs from what
        # the model expects; a rate-only check lets e.g. stereo input at the
        # model rate reach a mono model unchanged.
        rate_differs = sr != self.model.sample_rate
        channels_differ = wav.dim() >= 2 and wav.shape[-2] != self.model.channels
        if rate_differs or channels_differ:
            wav = convert_audio(wav, sr, self.model.sample_rate, self.model.channels) # model.sample_rate=24kHz
        # Add the batch dimension and make sure the input lives on the same
        # device as the model weights (otherwise a non-CPU model fails).
        wav = wav.unsqueeze(0).to(self._device)
        with torch.no_grad():
            encoded_frames = self.model.encode(wav)
        # Each encoded frame is indexed at 0 to get its codes; the frames are
        # concatenated along the last (time) axis.
        return torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)

    @dispatch
    def __call__(self, wav:torch.Tensor, sr:int)->torch.Tensor:
        """Encode a tensor waveform: (CxT) -> (CxDxT_frames)."""
        return self._encode(wav, sr)

    @dispatch
    def __call__(self, wav:np.ndarray, sr:int)->torch.Tensor:
        """Encode a NumPy waveform; it is converted to a float32 tensor first."""
        return self._encode(torch.from_numpy(wav).float(), sr)

    def decode(self, codes:torch.Tensor)->torch.Tensor:
        """Decode codes back to audio: (CxDxT_frames) -> (CxT)."""
        # The model expects a list of (codes, scale) frames; no scale is supplied.
        # Moving the codes is a no-op when they already sit on the model device.
        frames_from_code = [(codes.to(self._device), None)]
        return self.model.decode(encoded_frames=frames_from_code)

    @property
    def sample_rate(self):
        """Sampling rate (Hz) the underlying model operates at."""
        return self.model.sample_rate

    @property
    def device(self):
        """Device string the model was moved to at construction."""
        return self._device

### Usage

In [None]:
# Load a sample recording; `wav` is the waveform tensor and `sr` its sampling rate.
wav, sr = torchaudio.load("../data/audio/obama.wav")
# Alternative synthetic inputs (tensor or NumPy) for a quick smoke test:
# wav, sr = torch.rand((1, 24000)), 24000
# wav, sr = np.random.random((1, 24000)), 24000

# Encode the waveform into discrete codes on CPU.
encodec = EncoDec(device='cpu')
codes = encodec(wav,sr)
print(f"wav: {wav.shape}, code: {codes.shape} ")
# Show the code matrix as an image (singleton dimensions squeezed away).
plt.rcParams["figure.figsize"] = (5,5)
plt.xlabel('frames')
plt.ylabel('quantization')
plt.imshow(codes.squeeze().cpu().numpy())
# Round-trip: decode the codes back to audio and plot the reconstruction.
decoded = encodec.decode(codes)
plot_waveform(decoded.detach().cpu().squeeze(0), encodec.sample_rate)


In [None]:
# Plot the code values of `codes[0][0]` over time
# (presumably first batch item, first quantizer row — confirm against codes.shape).
plt.plot(codes[0][0])
print(codes[0][0].shape)

In [None]:
#| hide
# Audio player for the original (not re-synthesised) waveform at its native rate.
ipd.Audio(wav.squeeze(0).numpy(), rate=sr)

### Lhotse-style Encodec feature extractor

In [None]:
#| export
# https://lhotse.readthedocs.io/en/v0.6_ba/features.html#creating-custom-feature-extractor
@dataclass
class EncoDecConfig:
    """Settings for the EnCodec feature extractor."""
    # Time between consecutive code frames, in seconds. For 24 kHz input the
    # encoder emits embeddings at 75 Hz, i.e. one frame every 320 samples.
    frame_shift: float = 320 / 24_000
    # Value reported as the feature dimension
    # (presumably the number of EnCodec quantizers/codebooks — confirm).
    n_q: int = 8

class EncoDecExtractor(FeatureExtractor):
    """Lhotse feature extractor that emits EnCodec codes as features."""
    name = 'encodec'
    config_type = EncoDecConfig

    def __init__(self, config: Optional[EncoDecConfig] = None):
        """Create the extractor.

        Args:
            config: extractor settings; a fresh default config is built when omitted.
        """
        # Build the default per instance. A default evaluated in the signature
        # would be a single mutable dataclass shared by every extractor.
        if config is None:
            config = self.config_type()
        super().__init__(config)
        self.encodec = EncoDec()

    def extract(self, samples:Union[torch.Tensor, np.ndarray], sampling_rate: int) -> np.ndarray:
        """Encode `samples` and return the codes as a (frames x codes) array.

        Args:
            samples: waveform whose last axis is time.
            sampling_rate: sampling rate of `samples` in Hz.

        Returns:
            NumPy array with the batch axis removed and the two remaining axes
            swapped, so time comes first.

        Raises:
            AssertionError: if the number of produced frames differs from the
                expected count by more than one.
        """
        codes = self.encodec(samples, sampling_rate)
        # Duration in seconds, rounded to avoid float noise in the frame count.
        duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
        expected_num_frames = compute_num_frames(
            duration=duration,
            frame_shift=self.frame_shift,
            sampling_rate=sampling_rate,
        )
        actual_num_frames = codes.shape[-1]
        # Sanity check on the model output: tolerate an off-by-one frame count.
        assert abs(actual_num_frames - expected_num_frames) <= 1, (
            f"EnCodec produced {actual_num_frames} frames, "
            f"expected {expected_num_frames} (+/- 1)"
        )
        # Drop a possible extra trailing frame.
        codes = codes[..., :expected_num_frames]
        return codes.cpu().squeeze(0).permute(1, 0).numpy()

    @property
    def frame_shift(self)->float:
        """Frame shift in seconds, taken from the config."""
        return self.config.frame_shift

    def feature_dim(self, sampling_rate: int) -> int:
        """Feature dimension, taken from the config's `n_q` (independent of `sampling_rate`)."""
        return self.config.n_q

    

In [None]:
# Extractor with its default configuration.
encodec_extractor = EncoDecExtractor()
# cuts = CutSet.from_file("../recipes/tts/ljspeech/data/first_3.jsonl.gz")
# Load a small cut manifest and print its first two entries as a sanity check.
cuts = CutSet.from_file("../data/en/LJSpeech-1.1/first_3.encodec.jsonl.gz")
print(cuts[0])
print(cuts[1])


In [None]:
# torch.set_num_threads(1)
# torch.set_num_interop_threads(1)

In [None]:
# feats = cuts.compute_and_store_features(extractor=Fbank(), storage_path="../recipes/tts/ljspeech/data/feats")

In [None]:
# storage_path = "../.data/en/LJSpeech-1.1"
# # storage_path = "../recipes/tts/ljspeech/data/feats"
# # TODO: make it work with num_jobs>1
# cuts = cuts.compute_and_store_features(
#     extractor=encodec_extractor,
#     storage_path=storage_path,
#     num_jobs=1,
# )
# cuts.to_file("../recipes/tts/ljspeech/data/cuts_encodec.jsonl.gz")
# print(cuts[0])
# cuts[0].plot_features()
# print(cuts)

In [None]:
# Reload the cut manifest from disk and print it
# (judging by the file name, the one with EnCodec features — confirm).
files = "../data/en/LJSpeech-1.1/cuts_encodec.jsonl.gz"
# files = "../recipes/tts/ljspeech/data/cuts_encodec.jsonl.gz"
cuts = CutSet.from_file(files)
print(cuts)

In [None]:
### HF
from datasets import load_dataset, Audio
# Import the Hugging Face class under an alias: a plain
# `from transformers import EncodecModel` would rebind the `EncodecModel`
# already imported from the `encodec` package, which `EncoDec` relies on.
from transformers import EncodecModel as HFEncodecModel, AutoProcessor

# dummy dataset, however you can swap this with a dataset on the 🤗 hub or bring your own
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

# load the model + processor (for pre-processing the audio)
model = HFEncodecModel.from_pretrained("facebook/encodec_24khz")
processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")
# cast the audio data to the correct sampling rate for the model
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
# raw waveform array of the first example, resampled by the cast above
audio_sample = librispeech_dummy[0]["audio"]["array"]

## AudioLM

In [None]:
# TODO: add an AudioLM-based embedder

In [None]:
#| hide
# Run nbdev's export (the `#| export` cells are the ones marked for the module).
import nbdev; nbdev.nbdev_export()