In [1]:
# Parameters
# until_x: index (counting from 0) of the last MobileNetV2 feature block whose
# pretrained parameters Task5Model re-initialises; it is also embedded in the
# checkpoint file name written by the training loop.
until_x = 11


In [2]:
import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import sys
sys.path.append("..")
from RandomErasing import RandomErasing

import torchvision.models
from torchvision import transforms

from albumentations import Compose, ShiftScaleRotate, GridDistortion
from albumentations.pytorch import ToTensor

In [3]:
def prepare_data(df, unknown_to_known, n_annotations=3, n_frames=635,
                 feature_dir='../../data/logmelspec'):
    """Build model inputs, per-annotation targets and loss masks from an annotation frame.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per (clip, annotation) with one column per label. After
        ``reset_index()`` the clip name must be available as the column
        ``audio_filename``. Rows of the same clip are numbered 1, 2, ... in their
        existing order. The caller's frame is not modified.
    unknown_to_known : dict
        Maps each "unknown" label column to the list of known label columns it may
        stand for. Unknown columns are removed from the targets. Wherever an unknown
        column is > 0.5, the mask of its known columns is set to 0 for that row.
    n_annotations : int, default 3
        How many annotations per clip are stacked along the last axis of ``y``.
    n_frames : int, default 635
        Maximum number of frames kept from each (transposed) feature file.
    feature_dir : str, default '../../data/logmelspec'
        Directory containing one ``<audio_filename>.npy`` file per clip. Each file is
        transposed before the first ``n_frames`` rows are kept, so files are assumed
        to be stored as (features, time) — TODO confirm.

    Returns
    -------
    X : numpy.ndarray
        Shape (n_clips, 1, frames, features).
    y : numpy.ndarray
        Shape (n_clips, n_known_labels, n_annotations).
    y_mask : numpy.ndarray
        Same shape as ``y``; 0 where a label must be ignored by the loss, 1 elsewhere.

    Clips appear in ``audio_filename`` order in all three outputs. Every clip needs at
    least ``n_annotations`` annotations and the same number of kept frames; otherwise
    the stacking below fails.
    """
    df = df.reset_index()
    # 1-based position of each annotation within its clip.
    df['slno'] = df.assign(slno=1).groupby('audio_filename')['slno'].cumsum()
    df = df.set_index(['audio_filename', 'slno'])

    unknown_cols = list(unknown_to_known.keys())
    df_unknown = df.loc[:, unknown_cols].copy()
    df = df.drop(columns=unknown_cols)

    # Known labels that an active "unknown" label could be hiding are excluded from the loss.
    y_mask = df.copy()
    y_mask.loc[:, :] = 1
    for unknown, known in unknown_to_known.items():
        y_mask.loc[df_unknown[unknown] > 0.5, known] = 0

    # Move the annotation number to the first index level, so that .loc[[k]] selects
    # the k-th annotation of every clip, sorted by audio_filename.
    df = df.swaplevel(i=1, j=0, axis=0).sort_index()
    y_mask = y_mask.swaplevel(i=1, j=0, axis=0).sort_index()

    annotation_ids = range(1, n_annotations + 1)
    y = np.stack([df.loc[[k], :].values for k in annotation_ids], axis=2)
    y_mask = np.stack([y_mask.loc[[k], :].values for k in annotation_ids], axis=2)

    filenames = df.loc[[1], :].reset_index(1).audio_filename.tolist()
    X = np.stack([
        np.load('{}/{}.npy'.format(feature_dir, name)).T[:n_frames, :]
        for name in filenames])
    X = np.expand_dims(X, axis=1)

    return X, y, y_mask


# Module-level RandomErasing instance (project-local class imported from the
# RandomErasing module on the '..' path), built with its default arguments.
# AudioDataset.__getitem__ applies it to every augmented sample, so every dataset
# in this notebook shares this one instance.
random_erasing = RandomErasing()


class AudioDataset(Dataset):
    """Map-style dataset over pre-computed spectrogram tensors.

    Item ``idx`` is ``(sample, target, weight)`` taken from row ``idx`` of ``X``,
    ``y`` and ``weights``. When ``transform`` is truthy, ``sample`` is augmented:

    1. It is min-max scaled to [0, 1].
    2. It is randomly rolled along axis 1.
    3. It is converted to a PIL image and passed through the albumentations
       ``transform``.
    4. The module-level ``random_erasing`` is applied.
    5. It is scaled back to its original value range.

    Parameters
    ----------
    X : torch.Tensor
        Inputs indexed along dim 0. Each sample is assumed to be (1, time, mel) —
        TODO confirm against the caller.
    y : torch.Tensor
        Targets indexed along dim 0.
    weights : torch.Tensor
        Per-target loss weights / masks indexed along dim 0.
    transform : callable, optional
        albumentations-style callable used as ``transform(image=ndarray)['image']``.
    """

    def __init__(self, X, y, weights, transform=None):
        self.X = X
        self.y = y
        self.weights = weights
        self.transform = transform
        self.pil = transforms.ToPILImage()

    def __len__(self):
        return self.X.shape[0]

    @staticmethod
    def _min_max_scale(sample):
        """Scale ``sample`` into [0, 1] and return ``(scaled, offset, span)``.

        ``scaled * span + offset`` restores the original values. A constant sample
        has ``span == 0``; it is mapped to all zeros instead of NaN (0 / 0), and the
        inverse still restores the constant.
        """
        offset = sample.min()
        span = sample.max() - offset
        if span == 0:
            return torch.zeros_like(sample), offset, span
        return (sample - offset) / span, offset, span

    def __getitem__(self, idx):
        sample = self.X[idx, ...]

        if self.transform:
            # Bring values into [0, 1] before the image conversion below.
            sample, offset, span = self._min_max_scale(sample)

            # Randomly cycle the clip along axis 1: element `shift` moves to the front.
            shift = np.random.randint(sample.shape[1])
            sample = torch.roll(sample, shifts=-shift, dims=1)

            # Apply the albumentations pipeline to an image version of the sample
            # (ToPILImage presumably quantises to 8 bits here).
            sample = np.array(self.pil(sample))
            sample = self.transform(image=sample)['image']
            # Add a leading axis and swap the last two axes.
            sample = sample[None, :, :].permute(0, 2, 1)

            sample = random_erasing(sample.clone().detach())

            # Undo the min-max scaling.
            sample = (sample * span) + offset

        return sample, self.y[idx, ...], self.weights[idx, ...]


def weight_reset(layer):
    """Re-initialise ``layer`` in place via its ``reset_parameters()`` method, if it has one.

    Designed to be passed to ``nn.Module.apply`` so every submodule is visited;
    submodules without ``reset_parameters`` are skipped. Returns None.
    """
    if not hasattr(layer, 'reset_parameters'):
        return
    layer.reset_parameters()


class Task5Model(nn.Module):
    """Spectrogram classifier built on a MobileNetV2 feature extractor.

    The forward pass runs in four stages:

    - ``bw2col`` (BatchNorm + two 1x1 convolutions) maps the 1-channel input to
      3 channels.
    - MobileNetV2 ``features`` produce feature maps.
    - A global max over the last two axes pools them.
    - The MLP ``final`` turns the 1280 pooled values into ``num_classes`` logits.

    Parameters
    ----------
    num_classes : int
        Number of output logits.
    reset_until : int, optional
        Feature blocks ``mv2.features[0]`` .. ``mv2.features[reset_until]``
        (inclusive) have their pretrained parameters re-initialised; later blocks
        keep them. Defaults to the notebook-level ``until_x`` parameter, which is
        what this class always used before the argument existed.
    """

    def __init__(self, num_classes, reset_until=None):

        super().__init__()

        if reset_until is None:
            # Backward-compatible default: the value from the notebook's parameters cell.
            reset_until = until_x

        self.bw2col = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 10, 1, padding=0), nn.ReLU(),
            nn.Conv2d(10, 3, 1, padding=0), nn.ReLU())

        self.mv2 = torchvision.models.mobilenet_v2(pretrained=True)

        # Re-initialise feature blocks 0..reset_until (inclusive).
        for i, block in enumerate(self.mv2.features.children()):
            if i <= reset_until:
                block.apply(weight_reset)

        self.final = nn.Sequential(
            nn.Linear(1280, 512), nn.ReLU(), nn.BatchNorm1d(512),
            nn.Linear(512, num_classes))

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for a batch of 1-channel 4-D inputs."""
        x = self.bw2col(x)
        x = self.mv2.features(x)
        # Global max pooling over the last two (spatial) axes.
        x = x.max(dim=-1)[0].max(dim=-1)[0]
        x = self.final(x)
        return x

In [4]:
# Load and prepare data.
# NOTE: pickle.load can execute arbitrary code; only use it on this trusted,
# locally produced metadata file.
with open('../../data/metadata.pkl', 'rb') as f:
    metadata = pickle.load(f)

# unknown_to_known: for every "unknown" fine label (fine_id == 'X'), the list of
# known fine labels that share its coarse category.
unknown_to_known = (
    metadata['taxonomy_df']
    .loc[lambda tax: tax.fine_id == 'X', ['fine', 'coarse']]
    .merge(metadata['taxonomy_df'].loc[lambda tax: tax.fine_id != 'X', ['fine', 'coarse']],
           on='coarse', how='inner')
    .drop(columns='coarse')
    .groupby('fine_x')['fine_y']
    .apply(list)
    .to_dict()
)
known_labels = metadata['taxonomy_df'].loc[lambda tax: tax.fine_id != 'X', 'fine'].tolist()

# Coarse and fine label columns side by side; valid_df is built from the test split.
train_df = pd.concat((metadata['coarse_train'], metadata['fine_train']), axis='columns', sort=True)
valid_df = pd.concat((metadata['coarse_test'], metadata['fine_test']), axis='columns', sort=True)

In [5]:
# Manual correction for one data point: any annotation row whose label columns sum
# to exactly 37 is reset to all zeros, in both splits. Presumably this is a row with
# every label switched on — TODO confirm against the raw annotations.
train_df.loc[train_df.sum(axis=1).eq(37), :] = 0
valid_df.loc[valid_df.sum(axis=1).eq(37), :] = 0

In [6]:
# Build (inputs, targets, masks) for each split with prepare_data. The unknown
# label columns are dropped from the targets and only used to mask related known labels.
train_X, train_y, train_y_mask = prepare_data(train_df, unknown_to_known)
valid_X, valid_y, valid_y_mask = prepare_data(valid_df, unknown_to_known)

In [7]:
# Channel-wise (last-axis) standardisation. The statistics come from the training
# set only and are applied unchanged to the validation set.
channel_means = train_X.reshape(-1, train_X.shape[-1]).mean(axis=0).reshape(1, 1, 1, -1)
channel_stds = train_X.reshape(-1, train_X.shape[-1]).std(axis=0).reshape(1, 1, 1, -1)
# A constant channel has std == 0 and would otherwise become inf/NaN after division.
channel_stds = np.where(channel_stds == 0, 1, channel_stds)
train_X = (train_X - channel_means) / channel_stds
valid_X = (valid_X - channel_means) / channel_stds

In [8]:
# Define the data augmentation transformations. Only the training dataset receives
# this pipeline; AudioDataset feeds it the min-max scaled, image-converted sample.
albumentations_transform = Compose([
    # small random shift/scale (limits 0.1) and rotation (limit 0.5, presumably degrees)
    ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0.5),
    GridDistortion(),
    # convert the augmented image back to a torch tensor
    ToTensor()
])

In [9]:
# Create the datasets and the dataloaders. The mask arrays are passed as the
# per-target weights; only the training set gets the augmentation pipeline.
train_dataset = AudioDataset(
    torch.Tensor(train_X), torch.Tensor(train_y), torch.Tensor(train_y_mask),
    transform=albumentations_transform)
valid_dataset = AudioDataset(
    torch.Tensor(valid_X), torch.Tensor(valid_y), torch.Tensor(valid_y_mask),
    transform=None)

# The training loop zips the two independently shuffled training loaders together
# to draw mixup pairs.
val_loader = DataLoader(valid_dataset, batch_size=96, shuffle=False)
train_loader_1 = DataLoader(train_dataset, batch_size=96, shuffle=True)
train_loader_2 = DataLoader(train_dataset, batch_size=96, shuffle=True)

In [10]:
# Define the device to be used: the first CUDA device when one is available,
# otherwise the CPU. Hard-coding CUDA made the notebook fail on CPU-only machines.
cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if cuda else 'cpu')
print('Device: ', device)

Device:  cuda:0


In [11]:
# Instantiate the model with one output logit per target label column. The loss
# compares outputs against train_y[:, :, k], so the two widths must match.
model = Task5Model(train_y.shape[1]).to(device)

In [12]:
# Define optimizer, scheduler and loss criteria.
# AMSGrad variant of Adam. The scheduler cuts the learning rate when the validation
# loss (passed to scheduler.step) has not decreased for 5 epochs.
optimizer = optim.Adam(model.parameters(), lr=1e-3, amsgrad=True)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, verbose=True)
# Element-wise (unreduced) BCE so the per-label masks can be applied before averaging.
criterion = nn.BCEWithLogitsLoss(reduction='none')

In [13]:
epochs = 100
early_stopping_patience = 15  # stop after this many epochs without a new best validation loss
alpha = 1  # mixup Beta(alpha, alpha) parameter
train_loss_hist = []
valid_loss_hist = []
lowest_val_loss = np.inf
epochs_without_new_lowest = 0


def masked_loss(outputs, labels, masks):
    """Masked BCE, averaged over the mask weight.

    Each slice of the last axis of ``labels``/``masks`` is one annotation. Every
    slice is scored against the same ``outputs`` with the element-wise
    ``criterion`` and multiplied by its mask slice. All terms are summed and
    divided by the total mask weight.
    """
    total = sum((criterion(outputs, labels[:, :, k]) * masks[:, :, k]).sum()
                for k in range(labels.shape[2]))
    return total / masks.sum()


for i in range(epochs):
    print('Epoch: ', i)

    # ---- training: mixup of two independently shuffled batches ----
    model.train()
    this_epoch_train_loss = 0.0
    for i1, i2 in zip(train_loader_1, train_loader_2):
        mixup_vals = np.random.beta(alpha, alpha, i1[0].shape[0])

        # Per-sample mixing weight, shaped to broadcast over the 4-D inputs ...
        lam = torch.Tensor(mixup_vals.reshape(-1, 1, 1, 1))
        inputs = (lam * i1[0]) + ((1 - lam) * i2[0])

        # ... and over the 3-D labels and masks.
        lam = torch.Tensor(mixup_vals.reshape(-1, 1, 1))
        labels = (lam * i1[1]) + ((1 - lam) * i2[1])
        masks = (lam * i1[2]) + ((1 - lam) * i2[2])

        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = masked_loss(outputs, labels, masks)
        loss.backward()
        optimizer.step()
        this_epoch_train_loss += loss.item()

    # ---- validation: no augmentation, no gradients ----
    model.eval()
    this_epoch_valid_loss = 0.0
    with torch.no_grad():
        for inputs, labels, masks in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            masks = masks.to(device)
            outputs = model(inputs)
            loss = masked_loss(outputs, labels, masks)
            this_epoch_valid_loss += loss.item()

    # Mean of the per-batch losses.
    this_epoch_train_loss /= len(train_loader_1)
    this_epoch_valid_loss /= len(val_loader)

    train_loss_hist.append(this_epoch_train_loss)
    valid_loss_hist.append(this_epoch_valid_loss)

    # Checkpoint the best model so far (lowest validation loss).
    if this_epoch_valid_loss < lowest_val_loss:
        lowest_val_loss = this_epoch_valid_loss
        torch.save(model.state_dict(), './model_system1_until_{}'.format(until_x))
        epochs_without_new_lowest = 0
    else:
        epochs_without_new_lowest += 1

    # Log before the early-stopping check so the final epoch is reported as well.
    print(this_epoch_train_loss, this_epoch_valid_loss)

    if epochs_without_new_lowest >= early_stopping_patience:
        break

    scheduler.step(this_epoch_valid_loss)

Epoch:  0


0.6573796582221985 0.5871642827987671
Epoch:  1


0.44100417375564577 0.2942626476287842
Epoch:  2


0.23199626326560974 0.19117619693279267
Epoch:  3


0.18669704616069793 0.16937430799007416
Epoch:  4


0.17857955157756805 0.17632395327091216
Epoch:  5


0.17660362184047698 0.16670210361480714
Epoch:  6


0.17520296990871428 0.16421490013599396
Epoch:  7


0.17220871806144714 0.17440535128116608
Epoch:  8


0.1721055829524994 0.1697957545518875
Epoch:  9


0.17037399709224701 0.15950251817703248
Epoch:  10


0.16879774868488312 0.15706695914268493
Epoch:  11


0.1680523121356964 0.15775453746318818
Epoch:  12


0.16567965745925903 0.15463703870773315
Epoch:  13


0.16599717378616333 0.14588253498077391
Epoch:  14


0.16519246220588685 0.14541817009449004
Epoch:  15


0.16360439479351044 0.1409922033548355
Epoch:  16


0.1630954897403717 0.13958994448184966
Epoch:  17


0.16307302951812744 0.1416044145822525
Epoch:  18


0.16190525650978088 0.1414241909980774
Epoch:  19


0.15965012907981874 0.1395222544670105
Epoch:  20


0.1604899126291275 0.13807631134986878
Epoch:  21


0.16100211083889007 0.14336706101894378
Epoch:  22


0.1597912847995758 0.1360596239566803
Epoch:  23


0.1593809622526169 0.140662282705307
Epoch:  24


0.1593678629398346 0.13471782803535462
Epoch:  25


0.15971291482448577 0.14064353704452515
Epoch:  26


0.1589425188302994 0.1360001802444458
Epoch:  27


0.15861113607883454 0.13700391799211503
Epoch:  28


0.15741988360881806 0.13603025376796724
Epoch:  29


0.15784075319767 0.14404643177986146
Epoch:  30


0.15726967930793762 0.13228271454572677
Epoch:  31


0.15533959805965425 0.1364932417869568
Epoch:  32


0.156905180811882 0.13155368119478225
Epoch:  33


0.15677837252616883 0.13386076986789702
Epoch:  34


0.15567866027355193 0.1356778621673584
Epoch:  35


0.15597333312034606 0.13413832038640977
Epoch:  36


0.15765925347805024 0.13338450640439986
Epoch:  37


0.1554953783750534 0.13378492891788482
Epoch:  38


0.15518288791179657 0.13554245978593826
Epoch    38: reducing learning rate of group 0 to 1.0000e-04.
Epoch:  39


0.15364849269390107 0.12821654230356216
Epoch:  40


0.15250653445720672 0.1277000829577446
Epoch:  41


0.15098841071128846 0.12814818024635316
Epoch:  42


0.15249750196933745 0.12721634954214095
Epoch:  43


0.1524213445186615 0.12754081487655639
Epoch:  44


0.15206636726856232 0.12727797478437425
Epoch:  45


0.1515219908952713 0.12698612213134766
Epoch:  46


0.15177830815315246 0.1270136296749115
Epoch:  47


0.15096842050552367 0.1267359271645546
Epoch:  48


0.1511205893754959 0.12695482671260833
Epoch:  49


0.15198862969875335 0.12748500257730483
Epoch:  50


0.15078220784664154 0.12733033150434495
Epoch:  51


0.15150575697422028 0.12715983688831328
Epoch:  52


0.15019946098327636 0.12665632367134094
Epoch:  53


0.15107200145721436 0.1272321820259094
Epoch:  54


0.15029142081737518 0.12661171555519105
Epoch:  55


0.15200840592384338 0.1269172713160515
Epoch:  56


0.15123094499111175 0.12693236917257308
Epoch:  57


0.15231478035449983 0.12738105952739714
Epoch:  58


0.1502407443523407 0.12700539082288742
Epoch:  59


0.15216370284557343 0.12737804800271987
Epoch:  60


0.15093180119991303 0.12791045904159545
Epoch    60: reducing learning rate of group 0 to 1.0000e-05.
Epoch:  61


0.14877516627311707 0.12735113650560378
Epoch:  62


0.15117193818092345 0.12744565457105636
Epoch:  63


0.1502450329065323 0.126961287856102
Epoch:  64


0.15130356073379517 0.12698195725679398
Epoch:  65


0.15038065671920775 0.12651017755270005
Epoch:  66


0.1505451852083206 0.12673879265785218
Epoch:  67


0.1502487391233444 0.12662488371133804
Epoch:  68


0.1497390967607498 0.12660699635744094
Epoch:  69


0.15064841985702515 0.12653855383396148
Epoch:  70


0.15186129808425902 0.12628263980150223
Epoch:  71


0.14945589900016784 0.1262471005320549
Epoch:  72


0.15021683514118195 0.12630759477615355
Epoch:  73


0.1480940216779709 0.1264030009508133
Epoch:  74


0.15010285377502441 0.12601725310087203
Epoch:  75


0.14890552282333375 0.12637526392936707
Epoch:  76


0.1502268707752228 0.12653405964374542
Epoch:  77


0.14963542222976683 0.12652814090251924
Epoch:  78


0.15018563449382782 0.12659180760383607
Epoch:  79


0.1499471426010132 0.12645024359226226
Epoch:  80


0.1514018452167511 0.1263727441430092
Epoch    80: reducing learning rate of group 0 to 1.0000e-06.
Epoch:  81


0.14963926434516905 0.12644476741552352
Epoch:  82


0.1504477071762085 0.12615735828876495
Epoch:  83


0.1498405557870865 0.12626439183950425
Epoch:  84


0.15016643822193146 0.12642237693071365
Epoch:  85


0.15054866313934326 0.12643856704235076
Epoch:  86


0.15118359088897704 0.12609044313430787
Epoch    86: reducing learning rate of group 0 to 1.0000e-07.
Epoch:  87


0.15139950633049012 0.12609817683696747
Epoch:  88


0.1505480593442917 0.12634956538677217
Epoch:  89
