# Model training

In [None]:
from pathlib import Path
from typing import Mapping
from IPython.display import Audio
import torchaudio, torch, torch.nn as nn, torch.optim as optim
import torch.nn.functional as F
import torchaudio.transforms as T
from torch import tensor
from fastcore.all import *
from sounds.hits import data
from sounds.hits.data import *
from torch.utils.data import DataLoader, RandomSampler
import fastcore.all as fc
from fastprogress import master_bar, progress_bar

In [None]:
#|default_exp hits.models

In [None]:
path = Path('../data')
# One sub-folder per class: files under '0' and '1'; the folder name is the label.
items = (path/'0').ls() + (path/'1').ls()
labels = items.map(lambda x: int(x.parent.stem))  # label = integer name of the containing folder

In [None]:
sr = 16_000

def load_resampled(fn, new_sr=sr):
    "Load the first channel of audio file `fn` and resample it from the file's own rate to `new_sr`."
    wav, file_sr = torchaudio.load(fn)
    # `T.Resample(new_freq=sr)` on its own assumes a 16 kHz source (its default `orig_freq`) and the
    # file's real rate was discarded, so files at any other rate were never actually resampled.
    # `resample` returns the waveform unchanged when the rates already match.
    return torchaudio.functional.resample(wav[0], file_sr, new_sr)

# NOTE(review): assumes TfmDataset applies the x/y transforms in sequence — confirm in sounds.hits.data.
x_tfms = [load_resampled]
y_tfms = [lambda x: tensor(x, dtype=torch.float32)]
dsets = random_split_dataset(data.TfmDataset(merge_items(items, labels), x_tfms, y_tfms))

In [None]:
# Train/valid loaders built from the split datasets; later cells read `dls.train` and `dls.valid`.
dls = dataloaders(*dsets, batch_size=32)

In [None]:
# Grab one training batch (inputs, targets) to sanity-check the data.
xb,yb = next(iter(dls.train))

In [None]:
# Play the first clip of the batch at the training sample rate.
# NOTE(review): `.numpy()` needs a CPU tensor — assumes the loader does not move batches to GPU.
Audio(xb[0].numpy(), rate=sr)

In [None]:
# plot_spectrogram(a[0][0].numpy())

## Training

In [None]:
#|export
def run_cbs(cbs, method_nm, learn=None):
    """Call `method_nm(learn)` on each callback that defines it, lowest `order` first.

    Callbacks that lack the method are skipped silently.
    """
    for cb in sorted(cbs, key=lambda c: c.order):
        handler = getattr(cb, method_nm, None)
        if handler is None:
            continue
        handler(learn)

class CancelFitException(Exception):
    "Raise (from the loop or a callback) to stop the current `fit` quietly."

class CancelEpochException(Exception):
    "Raise (from the loop or a callback) to skip the rest of the current epoch."

class CancelBatchException(Exception):
    "Raise (from the loop or a callback) to skip the rest of the current batch."

class with_cbs:
    """Decorator that surrounds a Learner method with callback events.

    `@with_cbs('fit')` makes the wrapped method call `o.callback('before_fit')`, run the method,
    then `o.callback('after_fit')`. A `CancelFitException` raised anywhere in that sequence is
    swallowed (the remaining steps are skipped). `o.callback('cleanup_fit')` always runs. Any
    other exception propagates unchanged.
    """
    def __init__(self, nm): self.nm = nm
    def __call__(self, f):
        from functools import wraps
        @wraps(f)
        def _f(o, *args, **kwargs):
            try:
                o.callback(f'before_{self.nm}')
                f(o, *args, **kwargs)
                o.callback(f'after_{self.nm}')
            # Each event only swallows its own Cancel*Exception. If no such class exists for this
            # name, `()` matches nothing, so the original error propagates instead of a KeyError.
            except globals().get(f'Cancel{self.nm.title()}Exception', ()): pass
            finally: o.callback(f'cleanup_{self.nm}')
        return _f

def_device = 'cuda' if torch.cuda.is_available() else 'cpu'

def to_device(x, device=def_device):
    """Move a tensor, or a (possibly nested) mapping / sequence of tensors, to `device`.

    Mappings come back as plain dicts; other iterables are rebuilt with their own type
    (e.g. list -> list, tuple -> tuple).
    """
    if isinstance(x, torch.Tensor): return x.to(device)
    # recurse into mapping values, like `to_cpu` does, so nested containers work too
    if isinstance(x, Mapping): return {k: to_device(v, device) for k, v in x.items()}
    return type(x)(to_device(o, device) for o in x)

def to_cpu(x):
    """Detach `x` and bring it to the CPU, recursing into mappings, lists and tuples.

    Mappings come back as dicts, lists as lists and tuples as plain tuples. float16 tensors
    are upcast to float32; other dtypes are kept.
    """
    if isinstance(x, Mapping):
        return {key: to_cpu(val) for key, val in x.items()}
    if isinstance(x, list):
        return list(map(to_cpu, x))
    if isinstance(x, tuple):
        return tuple(map(to_cpu, x))
    out = x.detach().cpu()
    if out.dtype == torch.float16:
        return out.float()
    return out

In [None]:
#|export
class Learner:
    """Minimal callback-driven training loop.

    Args:
        dls: object with `.train` and `.valid` iterables of batches; each batch is indexed as
            `batch[0]` (model input) and `batch[1]` (target).
        model: the module to train.
        loss_func: called as `loss_func(preds, targets)`.
        lr: default learning rate, used by `fit` when it is called without `lr`.
        opt_func: optimizer factory called as `opt_func(params, lr)`; if falsy, no optimizer is built.
        cbs: callbacks active for every `fit` call.

    Callback events fired (via `with_cbs`): before/after/cleanup for `fit`, `epoch` and `batch`.
    """
    def __init__(self, dls, model: nn.Module, loss_func=F.mse_loss, lr=0.1, opt_func=optim.SGD, cbs=None):
        self.model = model
        self.loss_func = loss_func
        self.dls = dls
        # stored so `fit` can fall back to it; `fit` reads `self.lr` when called with lr=None
        self.lr = lr
        self.opt_func = opt_func
        self.cbs = fc.L(cbs)
    
    @with_cbs('fit')
    def _fit(self, train, valid):
        for self.epoch in self.epochs:
            if train: self.one_epoch(True)
            # evaluation pass runs without building autograd graphs
            if valid: torch.no_grad()(self.one_epoch)(False)
    
    def fit(self, n_epochs=1, train=True, valid=True, cbs=None, lr=None):
        """Run `n_epochs` of training and/or validation.

        `cbs` are added only for the duration of this call. A new optimizer is created on
        every call, using `lr` or, if that is None, the learner's default `self.lr`.
        """
        cbs = fc.L(cbs)
        for cb in cbs: self.cbs.append(cb)
        try:
            self.n_epochs = n_epochs
            self.epochs = range(n_epochs)
            if lr is None: lr = self.lr
            if self.opt_func: self.opt = self.opt_func(self.model.parameters(), lr)
            self._fit(train, valid)
        finally:
            # remove the per-call callbacks even if training raised
            for cb in cbs: self.cbs.remove(cb)
    
    @with_cbs('epoch')  
    def _one_epoch(self):
        for self.iter,self.batch in enumerate(self.dl): self._one_batch()
            
    def one_epoch(self, training):
        "Run one pass over the train (training=True) or valid DataLoader."
        self.model.train(training)
        self.dl = self.dls.train if training else self.dls.valid
        self._one_epoch()
    
    @with_cbs('batch')     
    def _one_batch(self):
        self.predict()
        self.get_loss()
        if self.training:
            self.backward()
            self.step()
            self.zero_grad()
    
    def predict(self): self.preds = self.model(self.batch[0])
    def get_loss(self): self.loss = self.loss_func(self.preds, self.batch[1])
    def backward(self): self.loss.backward()
    def step(self): self.opt.step()
    def zero_grad(self): self.opt.zero_grad()
    
    @property
    def training(self): return self.model.training
    
    def callback(self, method_nm): run_cbs(self.cbs, method_nm, self)

In [None]:
from copy import copy
from torcheval.metrics import MulticlassAccuracy,Mean

In [None]:
#|export

class Callback:
    """Base class for Learner callbacks; `run_cbs` runs callbacks in ascending `order`."""
    order = 0

class MetricsCB(Callback):
    """Accumulate metrics plus a batch-size-weighted mean loss and log them after each epoch.

    Positional metrics are keyed by their class name, keyword metrics by the keyword. Each
    metric must provide `reset()`, `update(preds, targets)` and `compute()`. The loss is
    always tracked under the key 'loss'.
    """
    def __init__(self, *ms, **metrics):
        for o in ms: metrics[type(o).__name__] = o
        self.metrics = metrics
        self.all_metrics = copy(metrics)
        self.all_metrics['loss'] = self.loss = Mean()

    def _log(self, d): print(d)
    # expose this callback as `learn.metrics` so others (e.g. ProgressCB) can find it
    def before_fit(self, learn): learn.metrics = self

    def before_epoch(self, learn):
        for m in self.all_metrics.values(): m.reset()

    def after_epoch(self, learn):
        log = {k: f'{v.compute():.3f}' for k, v in self.all_metrics.items()}
        log['epoch'] = learn.epoch
        log['train'] = 'train' if learn.training else 'eval'
        self._log(log)

    def after_batch(self, learn):
        x, y, *_ = to_cpu(learn.batch)
        preds = to_cpu(learn.preds)  # convert once, not once per metric
        for m in self.metrics.values(): m.update(preds, y)
        # weight by batch size so a short final batch does not skew the epoch mean
        self.loss.update(to_cpu(learn.loss), weight=len(x))

class DeviceCB(Callback):
    """Put the model on `device` when fitting starts, and move every batch there before use."""
    def __init__(self, device=def_device): fc.store_attr()

    def before_fit(self, learn):
        model = learn.model
        if hasattr(model, 'to'):
            model.to(self.device)

    def before_batch(self, learn):
        learn.batch = to_device(learn.batch, device=self.device)
    
class ProgressCB(Callback):
    """Show fastprogress bars and, optionally, a live loss plot.

    Wraps the epoch iterator in a `master_bar` and each epoch's DataLoader in a child
    `progress_bar`, with the current batch loss as the bar comment. If a MetricsCB is present,
    its per-epoch log is redirected into a table on the master bar. With `plot=True`, training
    losses are recorded per batch and validation losses once per evaluation epoch.
    """
    # run after MetricsCB, so `learn.metrics` is already set in `before_fit`
    order = MetricsCB.order + 1

    def __init__(self, plot=False):
        self.plot = plot

    def before_fit(self, learn):
        self.mbar = master_bar(learn.epochs)
        learn.epochs = self.mbar
        self.first = True
        if hasattr(learn, 'metrics'):
            learn.metrics._log = self._log
        self.losses = []
        self.val_losses = []

    def _log(self, d):
        # header row (the keys) on the first call, then one row of values per call
        if self.first:
            self.mbar.write(list(d), table=True)
            self.first = False
        self.mbar.write(list(d.values()), table=True)

    def before_epoch(self, learn):
        learn.dl = progress_bar(learn.dl, leave=False, parent=self.mbar)

    def _redraw(self, learn, n_val):
        """Redraw the plot: one point per training batch, `n_val` x-positions for validation."""
        train_curve = [fc.L.range(self.losses), self.losses]
        # validation point i sits at the batch count where epoch i's training pass ended
        val_xs = fc.L.range(n_val).map(lambda i: (i + 1) * len(learn.dls.train))
        self.mbar.update_graph([train_curve, [val_xs, self.val_losses]])

    def after_batch(self, learn):
        learn.dl.comment = f'{learn.loss:.3f}'
        if not (self.plot and hasattr(learn, 'metrics') and learn.training):
            return
        self.losses.append(learn.loss.item())
        if self.val_losses:
            self._redraw(learn, learn.epoch)

    def after_epoch(self, learn):
        if learn.training or not self.plot or not hasattr(learn, 'metrics'):
            return
        self.val_losses.append(learn.metrics.all_metrics['loss'].compute())
        self._redraw(learn, learn.epoch + 1)


# Model

In [None]:
#|export
def conv1d(n_in, n_out, k_size=3, stride=2, act=nn.ReLU(), p=None, norm=False, padding=2):
    """Build `Conv1d -> [Dropout] -> [GroupNorm] -> [act]` as an `nn.Sequential`.

    Args:
        n_in, n_out: input / output channel counts.
        k_size, stride: convolution kernel size and stride (the conv has no bias).
        act: activation appended last; `None` adds no activation. The default `nn.ReLU()` is a
            single instance reused by every call that relies on it; that is harmless because
            ReLU has no parameters, but pass a fresh module for learnable activations.
        p: dropout probability; `None` adds no dropout layer.
        norm: if True, add `GroupNorm(1, n_out)` (one group over all channels).
        padding: zero padding on each side of the input (default 2).
    """
    res = [nn.Conv1d(n_in, n_out, k_size, stride, padding=padding, bias=False)]
    if p is not None: res.append(nn.Dropout(p))
    if norm: res.append(nn.GroupNorm(1, n_out))
    if act is not None: res.append(act)
    return nn.Sequential(*res)

In [None]:
#|export
class ConvModel(nn.Module):
    """Stack of strided 1-D conv blocks that turns raw audio into a feature map.

    Feature extractor based on wav2vec http://arxiv.org/abs/1904.05862

    Args:
        conv_k_sizes: kernel size per layer; each layer's stride is `k // 2`.
        conv_dims: output channels per layer (paired with `conv_k_sizes` via zip).
        dropout: dropout probability inside every conv block (GroupNorm is always on).
        log_compression: if True, return `log(1 + |x|)` of the final features.
        skip_connections: add a strided/truncated residual whenever a layer keeps the channel count.
        residual_scale: residual sums are multiplied by `sqrt(residual_scale)`.
        act: activation module used by every conv block; `None` disables activations. When
            omitted, a fresh `nn.PReLU()` is created for this model (one instance shared by all
            of its layers).
    """
    # Sentinel for "act not given". A module default such as `act=nn.PReLU()` is built once, when
    # the class is defined, and would then be shared (learnable weight included) by every
    # ConvModel created with the default.
    _DEFAULT_ACT = object()

    def __init__(
        self,
        conv_k_sizes=(10,8,8,4,4),
        conv_dims = (128,128,128,128,128),
        dropout=0.7,
        log_compression=True,
        skip_connections=True,
        residual_scale=0.5,
        act=_DEFAULT_ACT,
    ):
        super().__init__()
        if act is self._DEFAULT_ACT: act = nn.PReLU()

        in_d = 1
        self.conv_layers = nn.ModuleList()
        for dim, k in zip(conv_dims,conv_k_sizes):
            self.conv_layers.append(conv1d(in_d, dim, k, k//2, act, p=dropout, norm=True))
            in_d = dim

        self.log_compression = log_compression
        self.skip_connections = skip_connections
        self.residual_scale = math.sqrt(residual_scale)

    def forward(self, x: torch.Tensor):
        # expects (batch, samples); Conv1d needs a channel dim -> (batch, 1, samples)
        x = x.unsqueeze(1)
        for conv in self.conv_layers:
            residual = x
            x = conv(x)
            tsz, r_tsz = x.size(2), residual.size(2)
            # The residual is strided down to at least the output length, then truncated. If a
            # layer made the sequence longer (possible with padding on very short inputs) the
            # stride `r_tsz // tsz` would be 0, so no residual is added in that case.
            if self.skip_connections and x.size(1) == residual.size(1) and r_tsz >= tsz:
                residual = residual[..., :: r_tsz // tsz][..., :tsz]
                x = (x + residual) * self.residual_scale

        if self.log_compression: x = x.abs().log1p()
        return x

In [None]:
#|export
class AudioModel(nn.Module):
    """Binary audio classifier: ConvModel features -> flatten -> MLP -> one logit per example.

    `forward` returns raw scores of shape (batch,) (no sigmoid; the trailing size-1 dim is
    squeezed), so pair it with a logits-based loss such as `F.binary_cross_entropy_with_logits`.

    Args:
        n_features: size of the flattened ConvModel output. The default 7808 = 128 channels x 61
            frames, which is what the default ConvModel produces for a 19,200-sample input
            (1.2 s at 16 kHz). Other clip lengths need a different value.
    """
    def __init__(self, n_features=7808):
        super().__init__()
        self.model = nn.Sequential(ConvModel(),
                  nn.Flatten(),
                  nn.Linear(n_features, 256), 
                  nn.Dropout(), 
                  nn.ReLU(), 
                  nn.Linear(256,1))
    def forward(self, x):
        return self.model(x).squeeze(-1)

In [None]:
from torcheval.metrics import BinaryAccuracy

In [None]:
model = AudioModel()
# The model outputs logits (trained with binary_cross_entropy_with_logits), so accuracy must
# threshold at 0 in logit space (= probability 0.5). BinaryAccuracy's default threshold of 0.5,
# applied to logits, would count logits in [0, 0.5) (probability ~0.50-0.62) as negatives.
cbs = [DeviceCB(), MetricsCB(acc=BinaryAccuracy(threshold=0.)), ProgressCB()]
learn = Learner(dls,model, F.binary_cross_entropy_with_logits, opt_func=optim.Adam, cbs=cbs)

In [None]:
# 3 epochs at lr=1e-2 with the optimizer chosen above; each epoch runs a training pass then a validation pass.
learn.fit(3, lr=1e-2)

acc,loss,epoch,train
0.886,1.325,0,train
0.928,0.165,0,eval
0.982,0.049,1,train
0.94,0.126,1,eval
0.986,0.046,2,train
0.97,0.097,2,eval


In [None]:
# Saves the whole pickled module (not just a state_dict); loading it needs AudioModel importable.
model_path = Path('../models/model.pth')
model_path.parent.mkdir(parents=True, exist_ok=True)  # torch.save does not create missing folders
torch.save(model, model_path)

# Test

In [None]:
import librosa

In [None]:
# The checkpoint is a full pickled nn.Module, which requires `weights_only=False` on PyTorch
# versions where `torch.load` defaults to weights-only loading. Unpickling can run arbitrary
# code: only load checkpoints you produced yourself. `map_location` lets a checkpoint saved
# from a GPU load on a CPU-only machine (and matches the device `to_device` targets).
model = torch.load('../models/model.pth', map_location=def_device, weights_only=False)

In [None]:
path = Path('../data/thanos_message.wav')  # recording to scan for detections
sr = 16_000  # should match the sample rate used for training
max_length_s = 1.2 # seconds
max_l = int(max_length_s*sr)  # window length in samples (19,200)

In [None]:
# Load the recording resampled to `sr`; librosa returns (samples, sample_rate).
s, sr = librosa.load(path, sr=sr)

In [None]:
# Cut the recording into fixed-length windows for the classifier.
# NOTE(review): assumes split_audio's `stride` is the hop in seconds (overlapping 1.2 s windows) — confirm in sounds.hits.data.
frames = split_audio(s, sr, max_length_s, stride=0.8)
frames = tensor(frames)  # 2-D: (n_windows, samples_per_window), as unpacked in the next cell

In [None]:
n, xs = frames.shape
bs = 128
# Batches of `bs` windows plus one final partial batch. Unlike slicing off `frames[(n//bs)*bs:]`,
# `split` never yields an empty trailing batch when n is an exact multiple of bs.
# NOTE(review): `data` shadows the `sounds.hits.data` module imported at the top, so re-running the
# dataset cell afterwards fails; the name is kept because the detection cell reads it.
data = list(frames.split(bs))

In [None]:
# Turn dropout off explicitly instead of relying on the train/eval flag pickled with the model.
model.eval()
detected = []
with torch.no_grad():
    for b in data:
        probs = torch.sigmoid(model(to_device(b))).cpu()
        # keep (window, probability) pairs the model scores above 0.8
        detected += [(v, p.item()) for v, p in zip(b, probs) if p > 0.8]

In [None]:
# Show just the probabilities of the detected windows.
L(detected).map(lambda x: x[1])

(#244) [0.988308310508728,0.9218431115150452,0.8049905896186829,0.933873176574707,0.9953827261924744,0.9419439435005188,0.8049966096878052,0.9319188594818115,0.9356750249862671,0.999231219291687...]