## Need to test out SNR on augerino to get solid width

In [75]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import *
from dataset import *

In [112]:
# torch layer:
class GaussianNoiseAug(nn.Module):
    """
    Differetiable Gaussian noise injection
    """
    def __init__(self):
        super().__init__()
        self.aug=True
        self.log_lims = nn.Parameter(torch.tensor([0., 1.]))

    @property
    def lims(self):
        return F.sigmoid(self.log_lims) * 2 - 1

    def forward(self, x):
        bs = x.shape[0]
        sample_length = x.shape[1]

        g_noise = np.random.randn(sample_length)
        snr = torch.rand(bs, device=self.lims.device) * (self.lims[1] - self.lims[0]) + self.lims[0]

        noise_power = np.mean(np.power(g_noise, 2))
        sig_power = torch.mean(torch.pow(x, 2))

        snr_linear = 10**(snr / 10.0)
        noise_factor = torch.sqrt( (sig_power / noise_power) * (1 / snr_linear) )

        noise_factor = noise_factor.unsqueeze(dim=1).expand(bs, sample_length)
        g_noise = torch.from_numpy(g_noise)
        vals = noise_factor * g_noise
        
        return torch.add(x, vals)
    

In [153]:
class PitchShiftAug(nn.Module):
    """
    Differetiable pitch shift
    """
    def __init__(self):
        super().__init__()
        self.aug=True
        self.log_lims = nn.Parameter(torch.tensor([0., 5.]))

    @property
    def lims(self):
        return F.sigmoid(self.log_lims) * 2 - 1

    def forward(self, x):
        bs = x.shape[0]
        out = np.zeros(x.shape)
        factor = torch.rand(bs, device=self.lims.device) * (self.lims[1] - self.lims[0]) + self.lims[0]
        for index, row in enumerate(x):
            out[index] = librosa.effects.pitch_shift(row.numpy(), BASE_SAMPLE_RATE, (1 + factor[index]))
        return torch.from_numpy(out)

In [154]:
r = torch.rand(18, 40000)
model = PitchShiftAug()
res = model(r)
print(res)

(18, 40000)
tensor([[ 0.8393,  0.4487,  0.5009,  ...,  0.4946,  0.4023,  0.0000],
        [ 0.9275,  0.8033, -0.0428,  ...,  0.5360,  0.5367,  0.8003],
        [ 0.5906,  0.9398,  0.5518,  ...,  0.4537,  0.2921,  0.4374],
        ...,
        [ 0.5391,  0.4780,  0.1669,  ...,  0.4785,  0.4414,  0.2954],
        [ 0.2867,  0.5631,  0.4280,  ...,  0.3832,  0.3686,  0.0000],
        [ 0.2516,  0.0224,  0.0853,  ...,  0.1919,  0.4658,  0.3742]],
       dtype=torch.float64)


In [70]:
args
print('Loading dataset')
GTZAN, data_count = load_wave_data(
            'data',
            aug_params=aug_params,
            segmented=self.model_config.segmented,
            is_pre_augmented=self.model_config.pre_augment,
            is_local=self.args.local)
loaders = {
    'train': torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True
    ),
    'test': torch.utils.data.DataLoader(
        test_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True
    )
}
num_classes = max(train_set.targets) + 1

print('Preparing model')
model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
aug = nn.Sequential(ContrastAug(), BrightnessAug(), 
                    Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
model = AugAveragedModel(model, aug, ncopies=1)
model.cuda()


criterion = F.cross_entropy
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=args.lr_init,
    momentum=args.momentum,
    weight_decay=args.wd
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

start_epoch = 0
if args.resume is not None:
    print('Resume training from %s' % args.resume)
    checkpoint = torch.load(args.resume)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])

columns = ['ep', 'lr', 'tr_loss', 'tr_acc', 'te_loss', 'te_acc', 'time']

utils.save_checkpoint(
    args.dir,
    start_epoch,
    state_dict=model.state_dict(),
    optimizer=optimizer.state_dict()
)

for epoch in range(start_epoch, args.epochs):
    time_ep = time.time()

    train_res = utils.train_epoch(loaders['train'], model, criterion, optimizer, aug_reg=args.aug_reg)
    if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
        test_res = utils.eval(loaders['test'], model, criterion)
    else:
        test_res = {'loss': None, 'accuracy': None}

    lr = optimizer.param_groups[0]['lr']
    print("Brightness", model.aug[0].lims)
    print("Contrast", model.aug[1].lims)
    scheduler.step()

    if (epoch + 1) % args.save_freq == 0:
        utils.save_checkpoint(
            args.dir,
            epoch + 1,
            state_dict=model.state_dict(),
            optimizer=optimizer.state_dict()
        )

    time_ep = time.time() - time_ep
    values = [epoch + 1, lr, train_res['loss'], train_res['accuracy'], test_res['loss'], test_res['accuracy'], time_ep]
    table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f')
    if epoch % 40 == 0:
        table = table.split('\n')
        table = '\n'.join([table[1]] + table)
    else:
        table = table.split('\n')[2]
    print(table)

if args.epochs % args.save_freq != 0:
    utils.save_checkpoint(
        args.dir,
        args.epochs,
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict()
    )

0.4436741488860185