In [None]:
# default_exp audio.learner

# Audio Learner
> A `Learner` which automatically stacks tuples of `TensorSpec` or `TensorMelSpec` inputs into a single tensor

In [None]:
#export
from __future__ import annotations

from fastcore.dispatch import retain_type

from fastai.callback.core import Callback
from fastai.callback.fp16 import MixedPrecision
from fastai.learner import Learner, defaults
from fastai.optimizer import Adam

from fastxtend.audio.core import TensorSpec, TensorMelSpec
from fastxtend.audio.data import MelSpectrogram, Spectrogram
from fastxtend.imports import *

In [None]:
#hide
from nbdev.showdoc import *

## StackSpecCallback -

In [None]:
#export
class StackSpecCallback(Callback):
    """Stacks tuples of TensorSpec or TensorMelSpec into a single input. ToDo: add resizing

    Before each batch, every `TensorSpec`/`TensorMelSpec` found in `self.xb` is
    stacked along a new dim 2, then dims 1 and 2 are flattened together. N inputs
    of shape `(d0, d1, ...)` therefore become one tensor of shape `(d0, N*d1, ...)`;
    `torch.stack` requires all of them to have the same shape.

    The stacked tensor takes the position of the first spectrogram in `xb`; the
    remaining spectrograms are removed and any other inputs are left in place.
    If `xb` contains no spectrograms the batch is left untouched.
    """
    # One less than MixedPrecision's order.
    order = MixedPrecision.order-1
    def before_batch(self):
        "Replace the spectrogram inputs in `self.learn.xb` with a single stacked tensor"
        xb = L(self.xb)
        idx = xb.argwhere(lambda x: isinstance(x, (TensorSpec, TensorMelSpec)))
        # Nothing to stack: `torch.stack` would raise on an empty list.
        if len(idx) == 0: return

        specs = [xb[i] for i in idx]
        stacked = torch.flatten(torch.stack(specs, dim=2), 1, 2)
        # Cast the result back to the type of the last spectrogram input.
        stacked = retain_type(stacked, specs[-1])

        # Wrap in a tuple explicitly: calling `tuple()` on the tensor itself would
        # iterate over dim 0 and split the batch into per-item tensors.
        first, dropped = idx[0], set(idx[1:])
        self.learn.xb = tuple(
            stacked if i == first else x
            for i, x in enumerate(xb) if i not in dropped
        )

## audio_learner -

In [None]:
  #export
def audio_learner(
    dls,
    model,
    loss_func=None,
    opt_func=Adam,
    lr=defaults.lr,
    splitter=trainable_params,
    cbs=None,
    metrics=None,
    path=None,
    model_dir='models',
    wd=None,
    wd_bn_bias=False,
    train_bn=True,
    moms=(0.95,0.85,0.95)
) -> Learner:
    """An Audio specific Learner that stacks tuples of TensorSpec or TensorMelSpec

    If any `Spectrogram` or `MelSpectrogram` transform in `dls.train.after_batch`
    has a list-like `n_fft`, a `StackSpecCallback` is appended to `cbs`.
    All other arguments are passed to `Learner` unchanged.
    """
    spec_tfms = (Spectrogram, MelSpectrogram)
    needs_stacking = any(
        isinstance(tfm, spec_tfms) and is_listy(tfm.n_fft)
        for tfm in dls.train.after_batch.fs
    )

    if needs_stacking:
        # `L(None)` is empty, so this covers both "no callbacks" and user-supplied ones.
        cbs = L(cbs) + L(StackSpecCallback())

    return Learner(dls=dls, model=model, loss_func=loss_func, opt_func=opt_func, lr=lr, splitter=splitter, cbs=cbs,
                    metrics=metrics, path=path, model_dir=model_dir, wd=wd, wd_bn_bias=wd_bn_bias, train_bn=train_bn,
                    moms=moms)