In [1]:
# Parameters
# until_x: index of the last MobileNetV2 feature block that keeps its pretrained
# weights -- Task5Model re-initialises every `mv2.features` child with a larger
# index. Also embedded in the checkpoint filename written by the training loop.
until_x = 11


In [2]:
import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import sys
sys.path.append("..")
from RandomErasing import RandomErasing

import torchvision.models
from torchvision import transforms

from albumentations import Compose, ShiftScaleRotate, GridDistortion
from albumentations.pytorch import ToTensor

In [3]:
def prepare_data(df, unknown_to_known, data_dir='../../data/logmelspec',
                 n_frames=635, n_annotators=3):
    """Build model inputs, per-annotator targets and loss masks from a label frame.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per (clip, annotation), indexed by ``audio_filename``, one column
        per label. A clip's rows are numbered 1, 2, ... in their current order and
        each number is treated as a separate annotator.
    unknown_to_known : dict
        Maps each "unknown" label column to the list of known label columns it
        relates to. Unknown columns are removed from the targets; wherever an
        unknown column is > 0.5, the related known columns get mask 0 for that
        annotation.
    data_dir : str
        Directory holding one ``<audio_filename>.npy`` array per clip.
    n_frames : int
        Number of leading rows kept from each transposed array.
    n_annotators : int
        Annotations used per clip; any beyond this count are ignored.

    Returns
    -------
    X : numpy.ndarray, shape (n_clips, 1, n_frames, n_bins)
    y : numpy.ndarray, shape (n_clips, n_known_labels, n_annotators)
    y_mask : numpy.ndarray, same shape as ``y``
        1 where the target contributes to the loss, 0 where it is masked out.
    Clips appear in sorted ``audio_filename`` order in all three arrays.

    Raises
    ------
    ValueError
        If any clip has fewer than ``n_annotators`` annotations, since the
        per-annotator arrays could not be aligned.
    """
    df = df.reset_index()
    # Number each clip's annotations 1..k in row order.
    df['slno'] = df.assign(slno=1).groupby('audio_filename')['slno'].cumsum()
    df = df.set_index(['audio_filename', 'slno'])

    counts = df.groupby(level='audio_filename').size()
    short = counts[counts < n_annotators]
    if len(short) > 0:
        raise ValueError(
            '{} clip(s) have fewer than {} annotations, e.g. {!r}'.format(
                len(short), n_annotators, short.index[0]))

    unknown_cols = list(unknown_to_known.keys())
    df_unknown = df.loc[:, unknown_cols]
    df = df.drop(columns=unknown_cols)

    # Everything is usable by default; mask the known labels of any annotation
    # that also flagged the related unknown label.
    y_mask = df.copy()
    y_mask.loc[:, :] = 1
    for unknown, known in unknown_to_known.items():
        y_mask.loc[df_unknown[unknown] > 0.5, known] = 0

    # Put the annotation number first so that .loc[[k]] selects annotation k of
    # every clip, with clips in sorted filename order.
    df = df.swaplevel(i=1, j=0, axis=0).sort_index()
    y_mask = y_mask.swaplevel(i=1, j=0, axis=0).sort_index()

    annotators = range(1, n_annotators + 1)
    y = np.stack([df.loc[[k], :].values for k in annotators], axis=2)
    y_mask = np.stack([y_mask.loc[[k], :].values for k in annotators], axis=2)

    # Same clip order as the rows of y (annotation 1 of every clip).
    filenames = df.loc[[1], :].reset_index(1).audio_filename.tolist()
    X = np.stack([
        np.load('{}/{}.npy'.format(data_dir, name)).T[:n_frames, :]
        for name in filenames])
    X = X[:, np.newaxis, ...]

    return X, y, y_mask


# Shared RandomErasing augmenter (class imported from the RandomErasing module
# found via sys.path ".."). AudioDataset.__getitem__ applies it to each sample
# when the dataset has a transform.
random_erasing = RandomErasing()


class AudioDataset(Dataset):
    """Map-style dataset returning (spectrogram, targets, weights) triples.

    Parameters
    ----------
    X : torch.Tensor
        Spectrograms indexed along dim 0. In this notebook each item is
        (1, frames, bins), as produced by ``prepare_data``.
    y : torch.Tensor
        Targets, indexed along dim 0 in step with ``X``.
    weights : torch.Tensor
        Per-target loss weights/masks, indexed along dim 0 in step with ``X``.
    transform : callable, optional
        Albumentations-style transform, called as ``transform(image=...)`` and
        returning a dict with key ``'image'``. When set, every item goes through
        ``_augment``; when ``None``, items are returned exactly as stored.
    """

    def __init__(self, X, y, weights, transform=None):
        self.X = X
        self.y = y
        self.weights = weights
        self.transform = transform
        self.pil = transforms.ToPILImage()

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        sample = self.X[idx, ...]
        if self.transform:
            sample = self._augment(sample)
        return sample, self.y[idx, ...], self.weights[idx, ...]

    @staticmethod
    def _minmax_scale(sample):
        """Scale ``sample`` into [0, 1]; return (scaled, minimum, range).

        A constant sample has range 0; the range is replaced by 1 so the result
        is all zeros instead of NaN (0 / 0), and ``scaled * range + minimum``
        still restores the original values.
        """
        lo = sample.min()
        span = sample.max() - lo
        if span == 0:
            span = torch.ones_like(span)
        return (sample - lo) / span, lo, span

    def _augment(self, sample):
        """Apply the training-time augmentation chain to one (1, T, F) sample."""
        # The image round-trip below goes through PIL, so work in [0, 1] and
        # undo the scaling at the end.
        sample, lo, span = self._minmax_scale(sample)

        # Random circular shift along dim 1 (the frame axis).
        shift = np.random.randint(sample.shape[1])
        sample = torch.cat([sample[:, shift:, :], sample[:, :shift, :]], dim=1)

        # Tensor -> PIL -> numpy for albumentations. Afterwards add a leading
        # channel dim and swap the last two axes to restore the input layout
        # (assumes ToTensor moved the last image axis first -- TODO confirm for
        # the installed albumentations version).
        image = np.array(self.pil(sample))
        sample = self.transform(image=image)['image']
        sample = sample[None, :, :].permute(0, 2, 1)

        sample = random_erasing(sample.clone().detach())

        return sample * span + lo

In [4]:
# Load and prepare data
# NOTE: pickle.load can execute arbitrary code -- only use with this trusted local file.
with open('../../data/metadata.pkl', 'rb') as f:
    metadata = pickle.load(f)

_taxonomy = metadata['taxonomy_df']
_is_unknown = _taxonomy.fine_id == 'X'
_unknown_pairs = _taxonomy.loc[_is_unknown, ['fine', 'coarse']]
_known_pairs = _taxonomy.loc[~_is_unknown, ['fine', 'coarse']]

# Each "unknown" fine label (fine_id == 'X') -> list of the known fine labels
# sharing its coarse category.
_paired = _unknown_pairs.merge(_known_pairs, on='coarse', how='inner')
unknown_to_known = (_paired
                    .drop(columns='coarse')
                    .groupby('fine_x')['fine_y']
                    .apply(list)
                    .to_dict())
known_labels = _known_pairs.fine.tolist()

# Coarse and fine label columns side by side, rows aligned (and sorted) on the index.
train_df = pd.concat([metadata['coarse_train'], metadata['fine_train']], axis=1, sort=True)
valid_df = pd.concat([metadata['coarse_test'], metadata['fine_test']], axis=1, sort=True)

del _taxonomy, _is_unknown, _unknown_pairs, _known_pairs, _paired

In [5]:
# manual correction for one data point
# Zero out every row whose label values sum to exactly 37 (presumably an
# annotation with every label switched on -- confirm against the raw data).
for _frame in (train_df, valid_df):
    _frame.loc[_frame.sum(axis=1) == 37, :] = 0
del _frame

In [6]:
# Spectrograms, per-annotator targets and loss masks for each split (train first).
(train_X, train_y, train_y_mask), (valid_X, valid_y, valid_y_mask) = (
    prepare_data(split_df, unknown_to_known) for split_df in (train_df, valid_df))

In [7]:
# Channel wise normalization
# Per-bin mean/std over the training set only (X layout: clips, 1, frames, bins),
# applied to both splits. The bin count comes from the data instead of a
# hard-coded 128.
_n_bins = train_X.shape[-1]
channel_means = train_X.reshape(-1, _n_bins).mean(axis=0).reshape(1, 1, 1, -1)
channel_stds = train_X.reshape(-1, _n_bins).std(axis=0).reshape(1, 1, 1, -1)
# A bin that is constant over the training set would give std 0 -> inf/NaN inputs.
channel_stds = np.where(channel_stds == 0, 1.0, channel_stds)
train_X = (train_X - channel_means) / channel_stds
valid_X = (valid_X - channel_means) / channel_stds
del _n_bins

In [8]:
# Define the data augmentation transformations
# Used by the training AudioDataset: small shift/scale/rotation, a grid
# distortion, then conversion back to a tensor.
albumentations_transform = Compose(transforms=[
    ShiftScaleRotate(shift_limit=0.1,
                     scale_limit=0.1,
                     rotate_limit=0.5),
    GridDistortion(),
    ToTensor(),
])

In [9]:
# Create the datasets and the dataloaders
# Only the training set is augmented; validation items come back unchanged.
train_dataset = AudioDataset(
    X=torch.Tensor(train_X),
    y=torch.Tensor(train_y),
    weights=torch.Tensor(train_y_mask),
    transform=albumentations_transform)
valid_dataset = AudioDataset(
    X=torch.Tensor(valid_X),
    y=torch.Tensor(valid_y),
    weights=torch.Tensor(valid_y_mask),
    transform=None)

val_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False)
# Two independently shuffled loaders over the same training set; the training
# loop zips them to form mixup pairs.
train_loader_1 = DataLoader(train_dataset, batch_size=64, shuffle=True)
train_loader_2 = DataLoader(train_dataset, batch_size=64, shuffle=True)

In [10]:
# Define the device to be used
# Set cuda = False to force the CPU. With cuda = True the first GPU is used when
# one is available; otherwise fall back to the CPU instead of failing later at
# the first .to(device).
cuda = True
device = torch.device('cuda:0' if cuda and torch.cuda.is_available() else 'cpu')
print('Device: ', device)

Device:  cuda:0


In [11]:
def weight_reset(layer):
    """Re-initialise ``layer`` in place if it defines ``reset_parameters``.

    Intended for ``Module.apply``; objects without that method are left untouched.
    """
    if not hasattr(layer, 'reset_parameters'):
        return
    layer.reset_parameters()

In [12]:
class Task5Model(nn.Module):
    """MobileNetV2-based multi-label classifier for single-channel spectrograms.

    Input is a batch shaped (batch, 1, H, W). A 1x1-conv stem maps the single
    channel to 3 so the pretrained torchvision MobileNetV2 feature extractor can
    be reused; its 1280-channel output is max-pooled over both spatial axes and
    passed through a small MLP head that produces ``num_classes`` logits.

    Parameters
    ----------
    num_classes : int
        Number of output logits.
    reset_after : int, optional
        ``mv2.features`` children with an index greater than this are
        re-initialised; children 0..reset_after keep their pretrained weights.
        Defaults to the notebook-level ``until_x`` parameter.
    """

    def __init__(self, num_classes, reset_after=None):

        super().__init__()

        if reset_after is None:
            reset_after = until_x

        # (B, 1, H, W) -> (B, 3, H, W) via learned 1x1 channel mixing.
        self.bw2col = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 10, 1, padding=0), nn.ReLU(),
            nn.Conv2d(10, 3, 1, padding=0), nn.ReLU())

        # Only mv2.features is called in forward; any other mv2 submodules (e.g.
        # its own classifier head) stay registered but unused.
        # NOTE(review): newer torchvision versions deprecate pretrained=True in
        # favour of weights=...; kept to match the installed version.
        self.mv2 = torchvision.models.mobilenet_v2(pretrained=True)

        self.final = nn.Sequential(
            nn.Linear(1280, 512), nn.ReLU(), nn.BatchNorm1d(512),
            nn.Linear(512, num_classes))

        # Keep pretrained weights up to child `reset_after`; re-initialise the rest.
        for i, block in enumerate(self.mv2.features.children()):
            if i > reset_after:
                block.apply(weight_reset)

    def forward(self, x):
        """Map (batch, 1, H, W) inputs to (batch, num_classes) logits."""
        x = self.bw2col(x)
        x = self.mv2.features(x)
        # Global max pool over the two trailing spatial axes -> (batch, 1280).
        x = x.max(dim=-1)[0].max(dim=-1)[0]
        x = self.final(x)
        return x

In [13]:
# Instantiate the model
# One logit per known label, read from the target array instead of hard-coding
# 31; it must equal train_y.shape[1] for the loss to line up with the labels.
model = Task5Model(train_y.shape[1]).to(device)

In [14]:
# Define optimizer, scheduler and loss criteria
# Adam with the AMSGrad variant, learning rate 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3, amsgrad=True)
# Lower the LR when the value passed to scheduler.step() stops improving
# (patience=5 steps).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, verbose=True)
# Element-wise BCE on logits; masking and reduction happen in the training loop.
criterion = nn.BCEWithLogitsLoss(reduction='none')

In [15]:
# Train with mixup, masked multi-annotator BCE, LR-on-plateau and early stopping.
epochs = 100
mixup_alpha = 1           # Beta(alpha, alpha) parameter for the mixup weights
early_stop_patience = 25  # epochs without a new best validation loss before stopping
checkpoint_path = './model_system1_until_{}'.format(until_x)

train_loss_hist = []
valid_loss_hist = []
lowest_val_loss = np.inf
epochs_without_new_lowest = 0


def _mixup(batch_a, batch_b, alpha):
    """Mix two (inputs, labels, masks) batches with per-sample Beta(alpha, alpha) weights."""
    lam = torch.Tensor(np.random.beta(alpha, alpha, batch_a[0].shape[0]))

    def mix(a, b):
        # Broadcast the per-sample weight over every non-batch axis.
        w = lam.view(-1, *([1] * (a.dim() - 1)))
        return (w * a) + ((1 - w) * b)

    return tuple(mix(a, b) for a, b in zip(batch_a, batch_b))


def _masked_bce(outputs, labels, masks):
    """Masked mean BCE over all annotators.

    outputs: (batch, classes) logits; labels/masks: (batch, classes, annotators).
    The same logits are scored against every annotator's labels, each term is
    weighted by its mask, and the total is divided by the sum of the masks.
    """
    per_annotator = criterion(outputs.unsqueeze(-1).expand_as(labels), labels)
    return (per_annotator * masks).sum() / masks.sum()


for i in range(epochs):
    print('Epoch: ', i)

    # ---- training ----
    model.train()
    this_epoch_train_loss = 0.0
    for batch_a, batch_b in zip(train_loader_1, train_loader_2):
        inputs, labels, masks = _mixup(batch_a, batch_b, mixup_alpha)
        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = _masked_bce(outputs, labels, masks)
        loss.backward()
        optimizer.step()
        this_epoch_train_loss += loss.item()

    # ---- validation ----
    model.eval()
    this_epoch_valid_loss = 0.0
    with torch.no_grad():
        for inputs, labels, masks in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            masks = masks.to(device)
            outputs = model(inputs)
            loss = _masked_bce(outputs, labels, masks)
            this_epoch_valid_loss += loss.item()

    # Mean of per-batch losses (a short final batch counts like a full one).
    this_epoch_train_loss /= len(train_loader_1)
    this_epoch_valid_loss /= len(val_loader)

    train_loss_hist.append(this_epoch_train_loss)
    valid_loss_hist.append(this_epoch_valid_loss)
    # Print before the early-stopping check so the final epoch is reported too.
    print(this_epoch_train_loss, this_epoch_valid_loss)

    if this_epoch_valid_loss < lowest_val_loss:
        lowest_val_loss = this_epoch_valid_loss
        torch.save(model.state_dict(), checkpoint_path)
        epochs_without_new_lowest = 0
    else:
        epochs_without_new_lowest += 1

    if epochs_without_new_lowest >= early_stop_patience:
        break

    scheduler.step(this_epoch_valid_loss)

Epoch:  0


0.605525222179052 0.419086400951658
Epoch:  1


0.27238961087690816 0.18917773025376455
Epoch:  2


0.17612618530118787 0.18400304445198604
Epoch:  3


0.16857081933601484 0.14740507623979024
Epoch:  4


0.1632860774123991 0.14183673049722398
Epoch:  5


0.15897365599065214 0.13790873757430486
Epoch:  6


0.15874695898713292 0.133261824292796
Epoch:  7


0.1579408375798045 0.1364217453769275
Epoch:  8


0.15644721686840057 0.1331552192568779
Epoch:  9


0.15456248819828033 0.13357102764504297
Epoch:  10


0.15367567901675766 0.13149378022977284
Epoch:  11


0.15322063541090167 0.15054039018494741
Epoch:  12


0.15368466038961667 0.13784470515591757
Epoch:  13


0.1535441424395587 0.12967385138784135
Epoch:  14


0.15210765157196973 0.1264821578349386
Epoch:  15


0.15119879108828468 0.13678020132439478
Epoch:  16


0.15168600428748774 0.12768382791961944
Epoch:  17


0.15084038956745252 0.12872731579201563
Epoch:  18


0.1505330563397021 0.1318497434258461
Epoch:  19


0.15170737174717155 0.1263251251408032
Epoch:  20


0.1508278423869932 0.12977759220770427
Epoch:  21


0.15047034540691892 0.13061193376779556
Epoch:  22


0.14899537128371163 0.12538239998476847
Epoch:  23


0.14849345466575106 0.12669298691408976
Epoch:  24


0.1479934582839141 0.12652855260031565
Epoch:  25


0.14831957422398231 0.12939845025539398
Epoch:  26


0.14785860317784386 0.12809589079448155
Epoch:  27


0.14771204139735247 0.12595948044742858
Epoch:  28


0.14796940136600184 0.12972843327692576
Epoch    28: reducing learning rate of group 0 to 1.0000e-04.
Epoch:  29


0.1457162310142775 0.12314704805612564
Epoch:  30


0.14544545154313784 0.12231022651706423
Epoch:  31


0.1450396333191846 0.1225813552737236
Epoch:  32


0.14341757748578046 0.12192315076078687
Epoch:  33


0.1435491201039907 0.12222588913781303
Epoch:  34


0.14413523190730326 0.12209547523941312
Epoch:  35


0.14340705323863673 0.12150584906339645
Epoch:  36


0.14377428248927399 0.12153600369180952
Epoch:  37


0.14396896958351135 0.121783618416105
Epoch:  38


0.14262851669981674 0.12233369903905052
Epoch:  39


0.1426356801310101 0.12189895233937673
Epoch:  40


0.14367161006540866 0.12171323703868049
Epoch:  41


0.14385245418226397 0.12227184006146022
Epoch    41: reducing learning rate of group 0 to 1.0000e-05.
Epoch:  42


0.14292753548235507 0.12166907319000789
Epoch:  43


0.14260550890419935 0.12182438905750002
Epoch:  44


0.14302126660540299 0.12183068266936711
Epoch:  45


0.14307681934253588 0.12169075225080762
Epoch:  46


0.14299517668582298 0.1218553323830877
Epoch:  47


0.14325867593288422 0.12162699018205915
Epoch    47: reducing learning rate of group 0 to 1.0000e-06.
Epoch:  48


0.1422309114320858 0.12166837070669446
Epoch:  49


0.14320449007524028 0.12177323124238423
Epoch:  50


0.1427824505277582 0.12156311422586441
Epoch:  51


0.14191236528190407 0.12228686256068093
Epoch:  52


0.14287577126477216 0.12181589752435684
Epoch:  53


0.14277794755793907 0.12176420220306941
Epoch    53: reducing learning rate of group 0 to 1.0000e-07.
Epoch:  54


0.14343519831025922 0.12153003364801407
Epoch:  55


0.14295790968714533 0.12178719150168556
Epoch:  56


0.1430104580279943 0.12183466979435512
Epoch:  57


0.14163006479675705 0.12151090055704117
Epoch:  58


0.14311298727989197 0.12147656721728188
Epoch:  59


0.1427300354918918 0.12179561172212873
Epoch:  60


0.1432980528554401 0.12175915922437396
Epoch:  61


0.1415880495632017 0.12176055248294558
Epoch:  62


0.1423478565506033 0.12162655698401588
Epoch:  63


0.14301593641977053 0.12170002396617617
Epoch:  64


0.143192428189355 0.12169954074280602
Epoch    64: reducing learning rate of group 0 to 1.0000e-08.
Epoch:  65


0.1422005557530635 0.12167723796197347
Epoch:  66


0.14317456733536077 0.12162962662322181
Epoch:  67


0.1426800847858996 0.12190797179937363
Epoch:  68


0.14258766738144127 0.12179042505366462
Epoch:  69


0.14243619989704442 0.1216541656426021
Epoch:  70


0.1421404690355868 0.12186821763004575
Epoch:  71


0.14438286705597028 0.12140935978719167
Epoch:  72


0.1427187279269502 0.12147813077483859
Epoch:  73


0.142383649542525 0.12160139211586543
Epoch:  74


0.14313310505570592 0.1214622323002134
Epoch:  75


0.1428391353504078 0.12168090684073311
Epoch:  76


0.14249590925268224 0.12188697393451418
Epoch:  77


0.14306798617582064 0.12144529074430466
Epoch:  78


0.1429334959468326 0.12149521069867271
Epoch:  79


0.14246070304432432 0.12183865691934313
Epoch:  80


0.14252542120379372 0.12177020630666188
Epoch:  81


0.14201927064238368 0.12161433058125633
Epoch:  82


0.14394210560901746 0.12149254871266228
Epoch:  83


0.14250620071952408 0.1216212904879025
Epoch:  84


0.1425529275391553 0.12173158888305936
Epoch:  85


0.14281483357017105 0.12157445613827024
Epoch:  86


0.14241450138994166 0.12176202663353511
Epoch:  87


0.14381199430775 0.12180211820772716
Epoch:  88


0.14297845033374992 0.12148999209914889
Epoch:  89


0.1444406602028254 0.121707341500691
Epoch:  90


0.1427319291475657 0.12169059472424644
Epoch:  91


0.14253622414292516 0.1214405255658286
Epoch:  92


0.1426057251724037 0.12157793662377767
Epoch:  93


0.14404086808900576 0.12155018853289741
Epoch:  94


0.14193868475991325 0.12168099411896297
Epoch:  95


0.14246135486944303 0.12180105809654508
Epoch:  96
