# Model training

In [None]:
from pathlib import Path
from IPython.display import Audio
import torchaudio
import torchaudio.functional as taF
import torch.nn.functional as F
import torchaudio.transforms as T
from torch import tensor
from sounds.hits import data as D
from sounds.hits.data import *
from torch.utils.data import DataLoader, RandomSampler
import numpy as np

In [None]:
#|export
from typing import Mapping
import torch, torch.nn as nn, torch.optim as optim
import torch.nn.functional as F
import math
import fastcore.all as fc
from fastprogress import master_bar, progress_bar

In [None]:
#|default_exp hits.models

In [None]:
sr = 16_000  # sample rate (Hz) used for all audio in this notebook
# Pre-built numpy arrays: training/validation inputs (x) and their targets (y).
xs, ys = np.load('../data/train/dataset_x.npy'), np.load('../data/train/dataset_y.npy')
valid_xs, valid_ys = (
    np.load('../data/valid/dataset_x.npy'),
    np.load('../data/valid/dataset_y.npy'),
)

In [None]:
def augment_data(audio, p_gain=0.5, gain=1.5, p_noise=0.5, snr_range=(12, 30)):
    """Randomly augment a waveform tensor (training-time only).

    Two independent random steps:
      * with probability ``p_gain`` the signal is amplified by ``gain``
        using torchaudio's ``T.Vol``;
      * with probability ``p_noise`` Gaussian noise is mixed in at an SNR (dB)
        drawn uniformly from the integers in ``[snr_range[0], snr_range[1])``.

    All randomness comes from torch's RNG, so ``torch.manual_seed`` alone makes
    the augmentation reproducible.

    NOTE(review): assumes ``audio`` is a single 1-D waveform, because the scalar
    SNR must match the waveform's leading dimensions in ``add_noise`` — confirm.
    """
    if torch.rand(()) < p_gain:
        audio = T.Vol(gain=gain)(audio)
    if torch.rand(()) < p_noise:
        # add_noise rescales the noise to hit the requested SNR, so the initial
        # noise amplitude is irrelevant.
        snr = torch.randint(snr_range[0], snr_range[1], ())
        audio = taF.add_noise(audio, torch.randn_like(audio), snr=snr)
    return audio

In [None]:
from functools import partial

# Training input pipeline: numpy -> tensor -> resample -> random augmentation.
# NOTE(review): T.Resample's orig_freq defaults to 16 kHz, so with new_freq=sr
# (16 kHz) this should be a no-op; presumably the arrays are already at 16 kHz — confirm.
x_tfms = [tensor, T.Resample(new_freq=sr), augment_data]
# Validation must be deterministic: the same pipeline minus the random
# augmentation; otherwise eval metrics are computed on randomly perturbed audio.
eval_x_tfms = [tfm for tfm in x_tfms if tfm is not augment_data]
# Targets become float32 tensors, as required by binary_cross_entropy_with_logits.
y_tfms = [partial(tensor, dtype=torch.float32)]
train = D.TfmDataset(merge_items(xs, ys), x_tfms, y_tfms)
test = D.TfmDataset(merge_items(valid_xs, valid_ys), eval_x_tfms, y_tfms)

In [None]:
# Wrap the train/validation datasets in DataLoaders with batches of 32
# (`dataloaders` presumably comes from the `sounds.hits.data` star import).
# NOTE(review): whether shuffle=True also applies to the validation loader depends on `dataloaders` — confirm.
dls = dataloaders(train, test, batch_size=32, shuffle=True)

In [None]:
# Grab one training batch to sanity-check shapes and listen to a sample.
xb,yb = next(iter(dls.train))

In [None]:
# Play the second clip of the training batch (it has passed through the training transforms).
Audio(xb[1].numpy(), rate=sr)

# Model

In [None]:
#|export
def conv1d(n_in, n_out, k_size=3, stride=2, act=nn.ReLU(), p=None, norm=False, padding=2):
    """Build a 1-D conv block: Conv1d -> [Dropout] -> [GroupNorm] -> [activation].

    Args:
        n_in, n_out: input / output channels of the convolution.
        k_size, stride: kernel size and stride of the convolution.
        act: activation module appended as-is, or ``None`` for no activation.
            The same module object is reused if passed to several calls, so a
            parametrised activation (e.g. ``nn.PReLU``) shares its weights.
        p: dropout probability; ``None`` skips the dropout layer.
        norm: if True, add ``GroupNorm(1, n_out)`` (normalises over all channels).
        padding: zero-padding on each side of the time axis (default 2).

    Returns:
        ``nn.Sequential`` of the selected layers. The conv has no bias.
    """
    layers = [nn.Conv1d(n_in, n_out, k_size, stride, padding=padding, bias=False)]
    if p is not None: layers.append(nn.Dropout(p))
    if norm: layers.append(nn.GroupNorm(1, n_out))
    if act is not None: layers.append(act)
    return nn.Sequential(*layers)

In [None]:
#|export
class ConvModel(nn.Module):
    """1-D convolutional feature extractor modelled on the wav2vec encoder.

    Based on wav2vec (http://arxiv.org/abs/1904.05862): a stack of strided
    ``conv1d`` blocks with optional scaled residual connections and optional
    log compression of the output.

    Args:
        conv_k_sizes: kernel size of each layer; a layer with kernel ``k`` uses stride ``k // 2``.
        conv_dims: output channels of each layer (zipped with ``conv_k_sizes``).
        dropout: dropout probability inside every conv block (GroupNorm is enabled too).
        log_compression: if True, the output is ``log1p(|x|)``.
        skip_connections: if True, add a time-subsampled residual whenever a
            layer's input and output channel counts match.
        residual_scale: residual sums are multiplied by ``sqrt(residual_scale)``.
        act: activation module used in every conv block (``None`` = no activation).
            By default a fresh ``nn.PReLU()`` is created for each model and
            shared by that model's layers.
    """

    # Sentinel meaning "use the default activation". None cannot serve here,
    # because None already means "no activation" for conv1d.
    _DEFAULT_ACT = object()

    def __init__(
        self,
        conv_k_sizes=(10,8,8,4,4),
        conv_dims = (32,32,32,32,32),
        dropout=0.7,
        log_compression=True,
        skip_connections=True,
        residual_scale=0.5,
        act=_DEFAULT_ACT,
    ):
        super().__init__()
        if act is ConvModel._DEFAULT_ACT:
            # Built per instance: a definition-time `nn.PReLU()` default would
            # share one learnable slope between every ConvModel ever created.
            act = nn.PReLU()

        in_d = 1  # raw waveform has a single channel
        self.conv_layers = nn.ModuleList()
        for dim, k in zip(conv_dims, conv_k_sizes):
            self.conv_layers.append(conv1d(in_d, dim, k, k//2, act, p=dropout, norm=True))
            in_d = dim

        self.log_compression = log_compression
        self.skip_connections = skip_connections
        self.residual_scale = math.sqrt(residual_scale)

    def forward(self, x: torch.Tensor):
        """Map waveforms to conv features.

        Assumes ``x`` is ``(batch, time)``; a channel axis is added with
        ``unsqueeze(1)``. The result is ``(batch, conv_dims[-1], frames)``.
        """
        x = x.unsqueeze(1)
        for conv in self.conv_layers:
            residual = x
            x = conv(x)
            # A residual only fits when the channel counts match. With the
            # defaults this excludes the first layer (1 -> 32 channels). The
            # residual is subsampled in time by the length ratio and trimmed to
            # the output length.
            if self.skip_connections and x.size(1) == residual.size(1):
                tsz = x.size(2)
                r_tsz = residual.size(2)
                residual = residual[..., :: r_tsz // tsz][..., :tsz]
                x = (x + residual) * self.residual_scale

        if self.log_compression: x = x.abs().log1p()
        return x

In [None]:
#|export
class AudioModel(nn.Module):
    """Binary audio classifier: ConvModel features -> flatten -> one linear logit.

    The head ``nn.Linear(2432, 1)`` hard-codes the flattened feature size. With
    the default ConvModel, that corresponds to 32 channels x 76 frames, e.g. for
    24000-sample (1.5 s at 16 kHz) inputs. Other input lengths will fail at the
    linear layer.

    ``forward`` returns one raw logit per example (trailing dim squeezed).
    Apply a sigmoid for probabilities.
    """

    def __init__(self):
        super().__init__()
        layers = [
            ConvModel(),
            nn.Flatten(),
            # ConvModel's default log compression already yields values >= 0,
            # so this ReLU is effectively a no-op. It is kept so the module
            # indices (and state_dict keys) stay the same.
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(2432, 1),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.model(x)
        return logits.squeeze(-1)

In [None]:
from torcheval.metrics import BinaryAccuracy
from sounds.learner import *

In [None]:
# Build the classifier and a Learner that optimises BCE on raw logits with Adam.
# NOTE(review): callback behaviour lives in sounds.learner. Presumably DeviceCB moves
# model and batches to the accelerator, MetricsCB tracks binary accuracy, and
# ProgressCB draws progress bars — confirm.
model = AudioModel()
cbs = [DeviceCB(), MetricsCB(acc=BinaryAccuracy()), ProgressCB()]
learn = Learner(dls,model, F.binary_cross_entropy_with_logits, opt_func=optim.Adam, cbs=cbs)

In [None]:
# Train for 5 epochs at a learning rate of 1e-4 (per-epoch train/eval metrics are printed below).
learn.fit(5, lr=1e-4)

acc,loss,epoch,train
0.53,0.592,0,train
0.723,0.418,0,eval
0.685,0.45,1,train
0.949,0.283,1,eval
0.88,0.355,2,train
0.943,0.226,2,eval
0.923,0.296,3,train
0.93,0.21,3,eval
0.942,0.254,4,train
0.93,0.195,4,eval


In [None]:
# Pickle the whole module (architecture + weights). Loading it back requires the
# model classes to be importable and an unpickling-enabled torch.load.
# Saving `model.state_dict()` instead would be more portable.
torch.save(model, '../models/model.pth')

# Test

In [None]:
import librosa

In [None]:
# The checkpoint is a whole pickled nn.Module, not just a state_dict.
# Since PyTorch 2.6, torch.load defaults to weights_only=True and refuses to
# unpickle modules, so opt out explicitly. Unpickling can execute code, so only
# do this for trusted files. .eval() disables dropout so inference is deterministic.
model = torch.load('../models/model.pth', weights_only=False).eval()

In [None]:
# Settings for scanning a longer recording with fixed-length windows.
path = Path('../data/thanos_message.wav')
sr = 16_000  # target sample rate (Hz), same value as used for training
max_length_s = 1.5 # window length in seconds
max_l = int(max_length_s*sr)  # window length in samples (not used by the cells shown here)

In [None]:
# Load the recording resampled to `sr`. librosa returns (waveform, sample_rate),
# so `sr` is rebound to the rate actually used, which is the requested one.
s, sr = librosa.load(path, sr=sr)

In [None]:
# Inference must not use the random training augmentation in `x_tfms` (volume
# boost / added noise), so keep only the deterministic transforms.
infer_x_tfms = [tfm for tfm in x_tfms if tfm is not augment_data]
# Slice the recording into windows. Each frame is used as its own "label" so the
# raw audio of detected frames can be recovered from the batches.
frames = split_audio(s, sr, max_length_s, stride=0.9)
dl = DataLoader(D.TfmDataset(merge_items(frames, frames), infer_x_tfms), batch_size=512)

In [None]:
threshold = 0.95  # minimum sigmoid probability for a frame to count as a detection
# Make sure dropout/normalisation run in inference mode; the model may have been
# left in training mode.
model.eval()
detected = []  # (raw frame, probability) pairs above the threshold
with torch.no_grad():
    for xb, yb in progress_bar(dl):
        probs = torch.sigmoid(model(to_device(xb))).cpu()
        detected += [(v, p.item()) for v, p in zip(yb, probs) if p >= threshold]

In [None]:
# Show only the probabilities of the detected frames.
fc.L([prob for _, prob in detected])

(#21) [0.9517852663993835,0.9554659724235535,0.9521438479423523,0.953728973865509,0.9534500241279602,0.9508523941040039,0.950554370880127,0.9516769647598267,0.9584924578666687,0.9532450437545776...]

In [None]:
# Number of frames whose probability passed the detection threshold.
len(detected)

21

In [None]:
# Play all detected frames back to back (torch.cat raises if nothing was detected).
Audio(torch.cat([clip for clip, _ in detected]), rate=sr)