In [None]:
#|default_exp audio.data
#|default_cls_lvl 3

In [None]:
#|export
# Contains code from:
# fastai - Apache License 2.0 - Copyright (c) 2023 fast.ai

# Audio Data
> Audio DataBlocks and show_batch

In [None]:
#|export
from __future__ import annotations

from torch.nn.functional import interpolate
import torchaudio.transforms as tatfms
import matplotlib.pyplot as plt

from fastcore.dispatch import typedispatch, explode_types
from fastcore.transform import DisplayedTransform

from fastai.basics import defaults
from fastai.torch_core import default_device
from fastai.vision.data import get_grid
from fastai.data.block import TransformBlock

from fastxtend.audio.core import TensorAudio, TensorSpec, TensorMelSpec
from fastxtend.imports import *
from fastxtend.basics import *

## Show methods -

In [None]:
#|exporti
@typedispatch
def show_batch(x:TensorAudio, y, samples, ctxs=None, max_n=9, nrows=None, ncols=None, figsize=None, **kwargs):
    "Show up to `max_n` `TensorAudio` samples on a grid, passing `hear=False` to the generic `show_batch` unless the caller sets `hear`"
    if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)
    # `hear=False` was previously passed as a fixed keyword next to `**kwargs`, so a
    # caller supplying `hear` raised "got multiple values for keyword argument 'hear'".
    # Defaulting it inside `kwargs` keeps the same default and lets an explicit value through.
    kwargs.setdefault('hear', False)
    ctxs = show_batch[object](x, y, samples, ctxs=ctxs, max_n=max_n, **kwargs)
    plt.tight_layout()
    return ctxs

In [None]:
#|exporti
@typedispatch
def show_batch(x:TensorSpec|TensorMelSpec, y, samples, ctxs=None, max_n=9, nrows=None, ncols=None, figsize=None, **kwargs):
    "Show up to `max_n` spectrogram samples on a grid of axes using the generic `show_batch`"
    if ctxs is None:
        # No axes supplied: create one per displayed sample
        n_shown = min(len(samples), max_n)
        ctxs = get_grid(n_shown, nrows=nrows, ncols=ncols, figsize=figsize)
    generic_show = show_batch[object]
    ctxs = generic_show(x, y, samples, ctxs=ctxs, max_n=max_n, **kwargs)
    plt.tight_layout()
    return ctxs

## Spectrogram Transform -

In [None]:
#|export
class Spectrogram(DisplayedTransform):
    """Convert a `TensorAudio` into one or more `TensorSpec`

    All arguments are forwarded to `torchaudio.transforms.Spectrogram`. If `n_fft` is
    list-like once the arguments are stored, one transform is created per index and
    `encodes` returns a tuple of `TensorSpec`; otherwise it returns a single `TensorSpec`.
    """
    order = 75
    def __init__(self,
        n_fft:Listified[int]=1024,
        win_length:Listified[int]|None=None,
        hop_length:Listified[int]|None=None,
        pad:Listified[int]=0,
        window_fn:Listified[Callable[..., Tensor]]=torch.hann_window,
        power:Listified[float]=2.,
        normalized:Listified[bool]=False,
        wkwargs:Listified[dict]|None=None,
        center:Listified[bool]=True,
        pad_mode:Listified[str]="reflect",
        onesided:Listified[bool]=True,
        norm:Listified[str]|None=None
    ):
        super().__init__()
        # Defined elsewhere; the code below relies on it having set every argument as an
        # attribute on `self` and recorded their names in `__stored_args__`.
        # NOTE(review): presumably it also broadcasts scalars to lists when any argument
        # is a list, since every attribute is indexed with `[i]` below - confirm.
        listify_store_attr()
        # NOTE(review): unlike `MelSpectrogram` in this file, a `None` `win_length` or
        # `hop_length` is passed through and kept as `None` in the stored settings.
        if is_listy(self.n_fft):
            # One torchaudio transform and one settings dict per configuration
            self.specs, self._attrs = nn.ModuleList(), []
            self.len, self.multiple = len(self.n_fft), True
            # The stored arguments are not modified in the loop, so read them once
            stored = self._get_attrs()
            for i in range(self.len):
                self.specs.append(tatfms.Spectrogram(n_fft=self.n_fft[i], win_length=self.win_length[i],
                                                     hop_length=self.hop_length[i], pad=self.pad[i],
                                                     window_fn=self.window_fn[i], power=self.power[i],
                                                     normalized=self.normalized[i], wkwargs=self.wkwargs[i],
                                                     center=self.center[i], pad_mode=self.pad_mode[i],
                                                     onesided=self.onesided[i], norm=self.norm[i]))
                self._attrs.append({k:v[i] for k,v in stored.items()})
        else:
            self.multiple = False
            self.spec = tatfms.Spectrogram(n_fft=self.n_fft, win_length=self.win_length, hop_length=self.hop_length,
                                           pad=self.pad, window_fn=self.window_fn, power=self.power,
                                           normalized=self.normalized, wkwargs=self.wkwargs, center=self.center,
                                           pad_mode=self.pad_mode, onesided=self.onesided, norm=self.norm)
            self._attrs = self._get_attrs()

    def encodes(self, x:TensorAudio):
        "Apply every configured spectrogram to `x`, attaching the matching settings to each result"
        if self.multiple:
            return tuple(TensorSpec.create(spec(x), settings=attrs)
                         for spec, attrs in zip(self.specs, self._attrs))
        return TensorSpec.create(self.spec(x), settings=self._attrs)

    def to(self, *args, **kwargs):
        "Move the underlying torchaudio transform(s) to the device parsed from the arguments and return `self`"
        device, *_ = torch._C._nn._parse_to(*args, **kwargs)
        if self.multiple:
            self.specs.to(device)
        else:
            self.spec.to(device)
        # Return `self` like `nn.Module.to`, so `tfm = tfm.to(device)` keeps the transform
        return self

    def _get_attrs(self):
        "Current values of the instance attributes whose names are in `__stored_args__`"
        stored_names = getattr(self,'__stored_args__',{}).keys()
        return {k:v for k,v in getattr(self,'__dict__',{}).items() if k in stored_names}

## Mel Transform -

In [None]:
#|export
class MelSpectrogram(DisplayedTransform):
    """Convert a `TensorAudio` into one or more `TensorMelSpec`

    All arguments are forwarded to `torchaudio.transforms.MelSpectrogram`. If `n_fft` is
    list-like once the arguments are stored, one transform is created per index and
    `encodes` returns a tuple of `TensorMelSpec`; otherwise it returns a single `TensorMelSpec`.
    The settings attached to each result also carry the sample rate under the key `'sr'`.
    """
    order = 75
    def __init__(self,
        sample_rate:Listified[int]=16000,
        n_fft:Listified[int]=1024,
        win_length:Listified[int]|None=None,
        hop_length:Listified[int]|None=None,
        f_min:Listified[float]=0.,
        f_max:Listified[float]|None=None,
        pad:Listified[int]=0,
        n_mels:Listified[int]=128,
        window_fn:Listified[Callable[..., Tensor]]=torch.hann_window,
        power:Listified[float]=2.,
        normalized:Listified[bool]=False,
        wkwargs:Listified[dict]|None=None,
        center:Listified[bool]=True,
        pad_mode:Listified[str]="reflect",
        norm:Listified[str]|None=None,
        mel_scale:Listified[str]="htk",
        # size:tuple[int,int]|None=None, # If set, resize MelSpectrogram to `size`
        # mode='bilinear'
    ):
        super().__init__()
        # Defined elsewhere; the code below relies on it having set every argument as an
        # attribute on `self` and recorded their names in `__stored_args__`.
        # NOTE(review): presumably it also broadcasts scalars to lists when any argument
        # is a list, since every attribute is indexed with `[i]` below - confirm.
        listify_store_attr()
        # self.resize = size is not None
        if is_listy(self.n_fft):
            # One torchaudio transform and one settings dict per configuration
            self.mels, self._attrs = nn.ModuleList(), []
            self.len, self.multiple = len(self.n_fft), True
            for i in range(self.len):
                # Fill in unset window/hop sizes before they are read back into the settings:
                # window defaults to `n_fft`, hop defaults to half the window
                if self.win_length[i] is None: self.win_length[i] = self.n_fft[i]
                if self.hop_length[i] is None: self.hop_length[i] = self.win_length[i] // 2
                self.mels.append(tatfms.MelSpectrogram(sample_rate=self.sample_rate[i], n_fft=self.n_fft[i],
                                                       win_length=self.win_length[i], hop_length=self.hop_length[i],
                                                       f_min=self.f_min[i], f_max=self.f_max[i], pad=self.pad[i],
                                                       n_mels=self.n_mels[i], window_fn=self.window_fn[i], power=self.power[i],
                                                       normalized=self.normalized[i], wkwargs=self.wkwargs[i],
                                                       center=self.center[i], pad_mode=self.pad_mode[i],
                                                       norm=self.norm[i], mel_scale=self.mel_scale[i]))
                settings = {k:v[i] for k,v in self._get_attrs().items()}
                settings['sr'] = self.sample_rate[i]
                self._attrs.append(settings)
        else:
            self.multiple = False
            # Same defaults as the list branch: window -> `n_fft`, hop -> half the window
            if self.win_length is None: self.win_length = self.n_fft
            if self.hop_length is None: self.hop_length = self.win_length // 2
            self.mel = tatfms.MelSpectrogram(sample_rate=self.sample_rate, n_fft=self.n_fft, win_length=self.win_length,
                                             hop_length=self.hop_length, f_min=self.f_min, f_max=self.f_max, pad=self.pad,
                                             n_mels=self.n_mels, window_fn=self.window_fn, power=self.power,
                                             normalized=self.normalized, wkwargs=self.wkwargs, center=self.center,
                                             pad_mode=self.pad_mode, norm=self.norm, mel_scale=self.mel_scale)
            self._attrs = self._get_attrs()
            self._attrs['sr'] = self.sample_rate

    def encodes(self, x:TensorAudio):
        "Apply every configured mel spectrogram to `x`, attaching the matching settings to each result"
        if self.multiple:
            return tuple(TensorMelSpec.create(mel(x), settings=attrs)
                         for mel, attrs in zip(self.mels, self._attrs))
        return TensorMelSpec.create(self.mel(x), settings=self._attrs)

    def to(self, *args, **kwargs):
        "Move the underlying torchaudio transform(s) to the device parsed from the arguments and return `self`"
        device, *_ = torch._C._nn._parse_to(*args, **kwargs)
        if self.multiple:
            self.mels.to(device)
        else:
            self.mel.to(device)
        # Return `self` like `nn.Module.to`, so `tfm = tfm.to(device)` keeps the transform
        return self

    def _get_attrs(self):
        "Current values of the instance attributes whose names are in `__stored_args__`, excluding `size` and `mode`"
        stored_names = getattr(self,'__stored_args__',{}).keys()
        return {k:v for k,v in getattr(self,'__dict__',{}).items() if k in stored_names and k not in ['size', 'mode']}

## TransformBlocks for audio
Audio data blocks for use with the fastai [data block API](https://docs.fast.ai/data.block.html#general-api).

In [None]:
#|export
def AudioBlock(cls=TensorAudio):
    "Create a `TransformBlock` whose item transform is `cls.create`"
    create_item = cls.create
    return TransformBlock(type_tfms=create_item)

In [None]:
#|export
def SpecBlock(cls=TensorAudio,
    # Spectrogram args
    n_fft:Listified[int]=1024,
    win_length:Listified[int]|None=None,
    hop_length:Listified[int]|None=None,
    pad:Listified[int]=0,
    window_fn:Listified[Callable[..., Tensor]]=torch.hann_window,
    power:Listified[float]=2.,
    normalized:Listified[bool]=False,
    wkwargs:Listified[dict]|None=None,
    center:Listified[bool]=True,
    pad_mode:Listified[str]="reflect",
    norm:Listified[str]|None=None,
    onesided:Listified[bool]=True
):
    "A `TransformBlock` to read `TensorAudio` and then use the GPU to turn audio into one or more `Spectrogram`s"
    # `onesided` is accepted by `Spectrogram` but was not exposed by this block. It is
    # added last, with `Spectrogram`'s own default, so existing positional calls are unaffected.
    return TransformBlock(type_tfms=cls.create,
                          batch_tfms=[Spectrogram(n_fft=n_fft, win_length=win_length, hop_length=hop_length,
                                                  pad=pad, window_fn=window_fn, power=power, normalized=normalized,
                                                  wkwargs=wkwargs, center=center, pad_mode=pad_mode,
                                                  onesided=onesided, norm=norm)])

In [None]:
#|export
def MelSpecBlock(cls=TensorAudio,
    # MelSpectrogram args
    sr:Listified[int]=16000,
    n_fft:Listified[int]=1024,
    win_length:Listified[int]|None=None,
    hop_length:Listified[int]|None=None,
    f_min:Listified[float]=0.,
    f_max:Listified[float]|None=None,
    pad:Listified[int]=0,
    n_mels:Listified[int]=128,
    window_fn:Listified[Callable[..., Tensor]]=torch.hann_window,
    power:Listified[float]=2.,
    normalized:Listified[bool]=False,
    wkwargs:Listified[dict]|None=None,
    center:Listified[bool]=True,
    pad_mode:Listified[str]="reflect",
    norm:Listified[str]|None=None,
    mel_scale:Listified[str]="htk"
):
    "Build a `TransformBlock` that creates items with `cls.create` and converts batches with a `MelSpectrogram` transform"
    item_tfm = cls.create
    # This block's `sr` argument is passed to the transform as `sample_rate`
    mel_kwargs = dict(
        sample_rate=sr, n_fft=n_fft, win_length=win_length, hop_length=hop_length,
        f_min=f_min, f_max=f_max, pad=pad, n_mels=n_mels,
        window_fn=window_fn, power=power, normalized=normalized, wkwargs=wkwargs,
        center=center, pad_mode=pad_mode, norm=norm, mel_scale=mel_scale,
    )
    mel_tfm = MelSpectrogram(**mel_kwargs)
    return TransformBlock(type_tfms=item_tfm, batch_tfms=[mel_tfm])