In [1]:
# Parameters
# `until_x` is read in two places further down: Task5Model re-initialises every
# child of `mv2.features` whose index is greater than this value, and the
# checkpoint filename written by the training loop embeds it.
# NOTE(review): a bare "# Parameters" cell looks like a papermill-style injected
# parameter cell, so this value may be overridden per run — confirm.
until_x = 12


In [2]:
import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import sys
sys.path.append("..")
from RandomErasing import RandomErasing

import torchvision.models
from torchvision import transforms

from albumentations import Compose, ShiftScaleRotate, GridDistortion
from albumentations.pytorch import ToTensor

In [3]:
def prepare_data(df, unknown_to_known, n_annotations=3,
                 data_dir='../../data/logmelspec', n_frames=635):
    """Turn an annotation table into input, target and loss-mask arrays.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per annotation. After ``reset_index`` it must expose an
        ``audio_filename`` column (i.e. the index is named ``audio_filename``).
        All remaining columns are label columns.
    unknown_to_known : dict
        Maps the name of an "unknown" label column to the list of label
        columns that must be masked out whenever that unknown column is
        greater than 0.5 for an annotation. The unknown columns themselves
        are removed from the targets.
    n_annotations : int, optional
        Number of annotation rows used per file (rows are numbered 1, 2, ...
        in their original order within each file). Every file is expected to
        have at least this many rows; otherwise the per-annotation slices
        have different lengths and stacking fails.
    data_dir : str, optional
        Directory that holds one ``<audio_filename>.npy`` array per file.
    n_frames : int, optional
        Maximum number of rows kept from each transposed array.

    Returns
    -------
    X : numpy.ndarray
        Shape ``(n_files, 1, rows, cols)``; each item is the transposed
        ``.npy`` array cut to its first ``n_frames`` rows.
    y : numpy.ndarray
        Shape ``(n_files, n_labels, n_annotations)``.
    y_mask : numpy.ndarray
        Same shape as ``y``; 1 where a target counts towards the loss and 0
        where it was masked through ``unknown_to_known``.
    """
    unknown_cols = list(unknown_to_known.keys())

    df = df.reset_index()
    # Number the annotations of every file 1, 2, 3, ... in row order.
    df['slno'] = df.assign(slno=1).groupby('audio_filename')['slno'].cumsum()
    df = df.set_index(['audio_filename', 'slno'])

    # Split the "unknown" indicator columns off from the real targets.
    df_unknown = df.loc[:, unknown_cols].copy()
    df = df.drop(columns=unknown_cols)

    # Start with every target enabled, then switch off the known labels that
    # an annotation flagged as "unknown" for.
    mask_df = df.copy()
    mask_df.loc[:, :] = 1
    for unknown, known in unknown_to_known.items():
        mask_df.loc[df_unknown[unknown] > 0.5, known] = 0

    # Put the annotation number first so that each annotation can be sliced
    # out as one contiguous, filename-sorted block.
    df = df.swaplevel(i=1, j=0, axis=0).sort_index()
    mask_df = mask_df.swaplevel(i=1, j=0, axis=0).sort_index()

    annotation_ids = range(1, n_annotations + 1)
    y = np.stack([df.loc[[k], :].values for k in annotation_ids], axis=2)
    y_mask = np.stack([mask_df.loc[[k], :].values for k in annotation_ids], axis=2)

    # Files are loaded in the same (sorted) order as the rows of annotation 1.
    filenames = df.loc[[1], :].reset_index(1).audio_filename.tolist()
    X = np.stack([
        np.load('{}/{}.npy'.format(data_dir, name)).T[:n_frames, :]
        for name in filenames])
    X = np.expand_dims(X, axis=1)

    return X, y, y_mask


# Module-level RandomErasing instance built with the helper's default
# arguments. AudioDataset.__getitem__ reads this global and applies it to every
# sample that goes through the augmentation path.
random_erasing = RandomErasing()


class AudioDataset(Dataset):
    """Dataset yielding ``(sample, target, weight)`` triples.

    Without a ``transform`` the stored tensors are returned untouched. With a
    ``transform`` every sample is augmented on the fly: it is min-max scaled,
    cyclically shifted along dimension 1 by a random offset, passed through
    the albumentations-style ``transform``, through the module-level
    ``random_erasing`` callable, and finally mapped back to its original
    value range.
    """

    def __init__(self, X, y, weights, transform=None):
        """Store the tensors and the optional augmentation callable.

        Parameters
        ----------
        X : torch.Tensor
            Inputs, indexed along the first dimension. The augmentation path
            slices each item with three indices, so items must be 3-D.
        y : torch.Tensor
            Targets, indexed along the first dimension.
        weights : torch.Tensor
            Per-target weights / masks, indexed along the first dimension.
        transform : callable or None
            Called as ``transform(image=array)`` and expected to return a
            mapping with an ``'image'`` entry. ``None`` disables augmentation.
        """
        self.X = X
        self.y = y
        self.weights = weights
        self.transform = transform
        self.pil = transforms.ToPILImage()

    def __len__(self):
        """Return the number of samples (size of the first dimension of X)."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return ``(sample, y[idx], weights[idx])``, augmenting if enabled."""
        sample = self.X[idx, ...]

        if self.transform:
            # min-max transformation
            this_min = sample.min()
            this_max = sample.max()
            value_range = this_max - this_min
            # A constant sample has a zero range; dividing by it would turn
            # the whole sample into NaN. Scale by 1 instead: the sample then
            # becomes all zeros and the reverse transform below (which still
            # multiplies by the true range) restores the constant value.
            scale = value_range if value_range > 0 else 1.0
            sample = (sample - this_min) / scale

            # randomly cycle the file
            i = np.random.randint(sample.shape[1])
            sample = torch.cat([
                sample[:, i:, :],
                sample[:, :i, :]],
                dim=1)

            # apply albumentations transforms (they operate on an image array)
            sample = np.array(self.pil(sample))
            sample = self.transform(image=sample)
            sample = sample['image']
            # NOTE(review): the permute presumably undoes the axis order that
            # the transform's tensor conversion produces — confirm.
            sample = sample[None, :, :].permute(0, 2, 1)

            # apply random erasing
            sample = random_erasing(sample.clone().detach())

            # revert min-max transformation
            sample = (sample * value_range) + this_min

        return sample, self.y[idx, ...], self.weights[idx, ...]

In [4]:
# Load and prepare data
# NOTE: pickle.load can execute arbitrary code while unpickling; only open a
# metadata.pkl that comes from a trusted source.
with open('../../data/metadata.pkl', 'rb') as f:
    metadata = pickle.load(f)

taxonomy_df = metadata['taxonomy_df']
# Rows whose fine_id is 'X' are the "unknown" fine labels; all other rows are
# the known fine labels.
unknown_fine = taxonomy_df.loc[taxonomy_df.fine_id == 'X', ['fine', 'coarse']]
known_fine = taxonomy_df.loc[taxonomy_df.fine_id != 'X', ['fine', 'coarse']]

# For every unknown fine label, collect the known fine labels that share its
# coarse class: {unknown fine label: [known fine labels]}.
unknown_known_pairs = pd.merge(unknown_fine, known_fine, on='coarse', how='inner')
unknown_to_known = (
    unknown_known_pairs
    .drop(columns='coarse')
    .groupby('fine_x')['fine_y']
    .apply(list)
    .to_dict())
known_labels = known_fine.fine.tolist()

# Coarse and fine annotation tables are placed side by side, column-wise.
train_parts = [metadata['coarse_train'], metadata['fine_train']]
valid_parts = [metadata['coarse_test'], metadata['fine_test']]
train_df = pd.concat(train_parts, axis=1, sort=True)
valid_df = pd.concat(valid_parts, axis=1, sort=True)

In [5]:
# manual correction for one data point
# Any annotation row whose label values sum to 37 is zeroed out in place.
# NOTE(review): 37 is presumably the total number of label columns, i.e. a row
# with every label set — confirm.
for frame in (train_df, valid_df):
    rows_to_clear = frame.sum(axis=1) == 37
    frame.loc[rows_to_clear, :] = 0

In [6]:
# Each call returns NumPy arrays (see prepare_data):
#   X      -> (n_files, 1, rows, cols): the transposed .npy array of every file,
#             cut to its first 635 rows
#   y      -> (n_files, n_labels, 3): one slice per annotation of a file
#   y_mask -> same shape as y; 0 where a known label is masked because the
#             matching "unknown" column was set for that annotation, else 1
train_X, train_y, train_y_mask = prepare_data(train_df, unknown_to_known)
valid_X, valid_y, valid_y_mask = prepare_data(valid_df, unknown_to_known)

In [7]:
# Channel wise normalization
# Statistics are computed per position of the last axis, over the training set
# only, and then applied to both splits. The channel count is read from the
# data instead of being hard-coded.
n_channels = train_X.shape[-1]
train_X_flat = train_X.reshape(-1, n_channels)
channel_means = train_X_flat.mean(axis=0).reshape(1, 1, 1, -1)
channel_stds = train_X_flat.std(axis=0).reshape(1, 1, 1, -1)
train_X = (train_X - channel_means) / channel_stds
valid_X = (valid_X - channel_means) / channel_stds

In [8]:
# Define the data augmentation transformations
# The pipeline is applied by AudioDataset.__getitem__ as
# `transform(image=array)['image']`, so the last step has to produce the tensor
# that the dataset then permutes and passes on to random erasing.
albumentations_transform = Compose([
    # NOTE(review): per the albumentations docs rotate_limit is in degrees, so
    # 0.5 allows only a very small rotation — confirm this is intended.
    ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0.5),
    # Default-parameter grid distortion.
    GridDistortion(),
    # Converts the augmented image array to a torch tensor.
    ToTensor()
])

In [9]:
# Create the datasets and the dataloaders
batch_size = 64

# Only the training set gets the augmentation pipeline; validation samples are
# returned exactly as stored.
train_dataset = AudioDataset(
    torch.Tensor(train_X),
    torch.Tensor(train_y),
    torch.Tensor(train_y_mask),
    transform=albumentations_transform)
valid_dataset = AudioDataset(
    torch.Tensor(valid_X),
    torch.Tensor(valid_y),
    torch.Tensor(valid_y_mask),
    transform=None)

val_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
# Two independently shuffled loaders over the same training set: the training
# loop zips them to pair up batches for mixup.
train_loader_1 = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
train_loader_2 = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [10]:
# Define the device to be used
# `cuda` is the user's request for a GPU (set it to False to force the CPU).
# The GPU is only selected when one is actually available, so the notebook
# still runs on CPU-only machines instead of failing at the first `.to(device)`.
cuda = True
device = torch.device('cuda:0' if cuda and torch.cuda.is_available() else 'cpu')
print('Device: ', device)

Device:  cuda:0


In [11]:
def weight_reset(layer):
    """Re-initialise ``layer`` in place when it supports it.

    Objects without a ``reset_parameters`` attribute are left untouched, which
    makes the function safe to pass to ``Module.apply``.
    """
    if not hasattr(layer, 'reset_parameters'):
        return
    layer.reset_parameters()

In [12]:
class Task5Model(nn.Module):
    """MobileNetV2 feature extractor wrapped for single-channel inputs.

    The network consists of three stages:

    * ``bw2col``: batch norm plus two 1x1 convolutions mapping the single
      input channel to the 3 channels fed into MobileNetV2.
    * ``mv2.features``: the pretrained torchvision MobileNetV2 feature stack,
      whose later children are re-initialised (see ``reset_after``).
    * ``final``: a 1280 -> 512 -> ``num_classes`` head applied after a global
      max over the last two dimensions of the feature map.

    The forward pass returns raw logits (no sigmoid is applied).
    """

    def __init__(self, num_classes, reset_after=None):
        """Build the model.

        Parameters
        ----------
        num_classes : int
            Number of output logits.
        reset_after : int or None, optional
            Children of ``mv2.features`` with an index greater than this value
            have their parameters re-initialised; children up to and including
            it keep their pretrained weights. ``None`` (the default) falls
            back to the notebook-level ``until_x`` parameter.
        """

        super().__init__()

        if reset_after is None:
            reset_after = until_x

        self.bw2col = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 10, 1, padding=0), nn.ReLU(),
            nn.Conv2d(10, 3, 1, padding=0), nn.ReLU())

        self.mv2 = torchvision.models.mobilenet_v2(pretrained=True)

        self.final = nn.Sequential(
            nn.Linear(1280, 512), nn.ReLU(), nn.BatchNorm1d(512),
            nn.Linear(512, num_classes))

        # Reset every child of mv2.features that comes after index `reset_after`
        for i, x in enumerate(self.mv2.features.children()):
            if i > reset_after:
                x.apply(weight_reset)

    def forward(self, x):
        """Map a batch of single-channel inputs to ``num_classes`` logits."""
        x = self.bw2col(x)
        x = self.mv2.features(x)
        # Global max pooling: reduce the last two dimensions one after another.
        x = x.max(dim=-1)[0].max(dim=-1)[0]
        x = self.final(x)
        return x

In [13]:
# Instantiate the model
# 31 is the number of output logits. The training loop compares the outputs
# with labels[:, :, k], so it has to equal the label dimension (axis 1) of the
# target arrays.
model = Task5Model(31).to(device)

In [14]:
# Define optimizer, scheduler and loss criteria
# All model parameters are optimised, including the pretrained MobileNetV2 ones.
optimizer = optim.Adam(model.parameters(), lr=0.001, amsgrad=True)
# The scheduler is stepped once per epoch with the validation loss by the
# training loop; the captured log shows the learning rate dropping by a factor
# of 10 each time it triggers.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, verbose=True)
# reduction='none' keeps the element-wise losses so that the training loop can
# multiply them by the masks before summing.
criterion = nn.BCEWithLogitsLoss(reduction='none')

In [15]:
epochs = 100
early_stopping_patience = 25  # stop after this many epochs without a new best validation loss
mixup_alpha = 1  # both parameters of the Beta distribution the mixup weights are drawn from
train_loss_hist = []
valid_loss_hist = []
lowest_val_loss = np.inf
epochs_without_new_lowest = 0


def masked_bce_loss(criterion, outputs, labels, masks):
    """Masked loss averaged over all enabled targets of a batch.

    ``labels`` and ``masks`` carry one slice per annotation along their last
    dimension. The element-wise ``criterion`` is evaluated against every
    slice, multiplied by the matching mask slice and summed; the total is
    divided by the sum of all mask values.
    """
    per_annotation = (
        (criterion(outputs, labels[:, :, k]) * masks[:, :, k]).sum()
        for k in range(labels.shape[2]))
    return sum(per_annotation) / masks.sum()


for i in range(epochs):
    print('Epoch: ', i)

    # ---- training ----
    model.train()
    this_epoch_train_loss = 0
    # The two loaders shuffle independently, so zipping them pairs every batch
    # with a random partner batch for mixup.
    for (x1, y1, m1), (x2, y2, m2) in zip(train_loader_1, train_loader_2):

        # mixup the inputs ---------
        # One mixing weight per sample, shared by inputs, labels and masks.
        mixup_vals = np.random.beta(mixup_alpha, mixup_alpha, x1.shape[0])

        lam = torch.Tensor(mixup_vals.reshape(-1, 1, 1, 1))
        inputs = (lam * x1) + ((1 - lam) * x2)

        lam = torch.Tensor(mixup_vals.reshape(-1, 1, 1))
        labels = (lam * y1) + ((1 - lam) * y2)
        masks = (lam * m1) + ((1 - lam) * m2)
        # mixup ends ----------

        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = masked_bce_loss(criterion, outputs, labels, masks)
        loss.backward()
        optimizer.step()
        this_epoch_train_loss += loss.item()

    # ---- validation ----
    model.eval()
    this_epoch_valid_loss = 0
    with torch.no_grad():
        for inputs, labels, masks in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            masks = masks.to(device)
            outputs = model(inputs)
            loss = masked_bce_loss(criterion, outputs, labels, masks)
            this_epoch_valid_loss += loss.item()

    # Epoch losses are the plain mean of the per-batch losses.
    this_epoch_train_loss /= len(train_loader_1)
    this_epoch_valid_loss /= len(val_loader)

    train_loss_hist.append(this_epoch_train_loss)
    valid_loss_hist.append(this_epoch_valid_loss)

    # Log before the early-stopping check so the last epoch is reported too.
    print(this_epoch_train_loss, this_epoch_valid_loss)

    # Keep only the checkpoint with the lowest validation loss.
    if this_epoch_valid_loss < lowest_val_loss:
        lowest_val_loss = this_epoch_valid_loss
        torch.save(model.state_dict(), './model_system1_until_{}'.format(until_x))
        epochs_without_new_lowest = 0
    else:
        epochs_without_new_lowest += 1

    if epochs_without_new_lowest >= early_stopping_patience:
        break

    scheduler.step(this_epoch_valid_loss)

Epoch:  0


0.6054281350728627 0.44007506540843416
Epoch:  1


0.2727736730027843 0.19380945605891092
Epoch:  2


0.17558864765876048 0.1598533626113619
Epoch:  3


0.16863976217604973 0.14973929524421692
Epoch:  4


0.16478194457453652 0.14408890902996063
Epoch:  5


0.16304105721615456 0.1446259468793869
Epoch:  6


0.16009539325495023 0.13443399965763092
Epoch:  7


0.1586087410514419 0.13828631064721517
Epoch:  8


0.15578236370473295 0.1323447727731296
Epoch:  9


0.15457985167567795 0.13525321653911046
Epoch:  10


0.1556176478798325 0.13130017689296178
Epoch:  11


0.1554213854912165 0.12989413206066405
Epoch:  12


0.15334830654634013 0.13202126643487386
Epoch:  13


0.15419813748952504 0.1309652967112405
Epoch:  14


0.15346179217905612 0.13215162285736629
Epoch:  15


0.15307682994249705 0.132175974547863
Epoch:  16


0.15205801741496935 0.12851317652634212
Epoch:  17


0.15244331673995867 0.12737982720136642
Epoch:  18


0.1507236953522708 0.12882687151432037
Epoch:  19


0.15034351276384816 0.1294101785336222
Epoch:  20


0.15005258168723132 0.1317060717514583
Epoch:  21


0.15033515280968435 0.127898208796978
Epoch:  22


0.14883695583085757 0.1306373711143221
Epoch:  23


0.1499198961096841 0.12561932099717005
Epoch:  24


0.1503357581190161 0.1255275649683816
Epoch:  25


0.15050089882837758 0.1266453436442784
Epoch:  26


0.1481111512796299 0.12545087720666612
Epoch:  27


0.1498931642319705 0.1593907486115183
Epoch:  28


0.15030282413637316 0.1284727645771844
Epoch:  29


0.14870077291050474 0.12742890736886434
Epoch:  30


0.14907439294699076 0.1260977972831045
Epoch:  31


0.14749463909381144 0.12708141654729843
Epoch:  32


0.14699501725467476 0.12803080465112412
Epoch    32: reducing learning rate of group 0 to 1.0000e-04.
Epoch:  33


0.14682565710029086 0.12235870957374573
Epoch:  34


0.14562018739210592 0.12185706411089216
Epoch:  35


0.14603593140035062 0.12149677638496671
Epoch:  36


0.1462021953350789 0.12130063559327807
Epoch:  37


0.14502547600784818 0.1212541578071458
Epoch:  38


0.1444305382870339 0.12109073890107018
Epoch:  39


0.14510290042774096 0.12142443124737058
Epoch:  40


0.1438431180006749 0.12118880344288689
Epoch:  41


0.1427815387377868 0.12134509001459394
Epoch:  42


0.14349490444402438 0.12113134669406074
Epoch:  43


0.1442774409377897 0.12118918235812869
Epoch:  44


0.14436796146470146 0.12138437799045018
Epoch    44: reducing learning rate of group 0 to 1.0000e-05.
Epoch:  45


0.14426072707047333 0.12141705730131694
Epoch:  46


0.14369942489508036 0.12123779633215495
Epoch:  47


0.1441567222008834 0.12087457946368627
Epoch:  48


0.14412047170303963 0.12102554951395307
Epoch:  49


0.1443413532263524 0.12068257587296623
Epoch:  50


0.1440759478388606 0.1207289525440761
Epoch:  51


0.14312602176859573 0.12075649521180562
Epoch:  52


0.14279930051919576 0.12052539416721889
Epoch:  53


0.14349928538541537 0.1206759512424469
Epoch:  54


0.14325904805917997 0.12069602310657501
Epoch:  55


0.1433217134830114 0.12075392050402504
Epoch:  56


0.14295020538407402 0.12044055121285575
Epoch:  57


0.1438802500834336 0.12066077866724559
Epoch:  58


0.14332569893952962 0.12040141757045474
Epoch:  59


0.14398085144726006 0.12070105757032122
Epoch:  60


0.1428023608955177 0.12072488239833287
Epoch:  61


0.14240809932753845 0.12037430597203118
Epoch:  62


0.14164065388408867 0.12056072162730354
Epoch:  63


0.14177765959017985 0.12062932870217732
Epoch:  64


0.1452480244475442 0.1204384086387498
Epoch:  65


0.1436282581574208 0.12045469986540931
Epoch:  66


0.14319113984301285 0.12079111273799624
Epoch:  67


0.14320903855401115 0.12049497451101031
Epoch    67: reducing learning rate of group 0 to 1.0000e-06.
Epoch:  68


0.14360911499809575 0.12052408712250846
Epoch:  69


0.14276948329564687 0.12064566143921443
Epoch:  70


0.14291576557868235 0.12060591259172984
Epoch:  71


0.14389916730893626 0.1204485542007855
Epoch:  72


0.1429041853627643 0.12047013427530016
Epoch:  73


0.14289117463537165 0.12047682915415082
Epoch    73: reducing learning rate of group 0 to 1.0000e-07.
Epoch:  74


0.14354939154676488 0.12074348969118935
Epoch:  75


0.14302445907850522 0.12034792133740016
Epoch:  76


0.14275920874363668 0.12055550941399165
Epoch:  77


0.14411851360991196 0.12060287807668958
Epoch:  78


0.14336542503253832 0.12036888194935662
Epoch:  79


0.14235361886991038 0.12057970677103315
Epoch:  80


0.14249367246756683 0.12069697997399739
Epoch:  81


0.14320714288466685 0.12076922293220248
Epoch    81: reducing learning rate of group 0 to 1.0000e-08.
Epoch:  82


0.14210737154290481 0.12058349485908236
Epoch:  83


0.1433732038414156 0.12061957057033267
Epoch:  84


0.14371921444261396 0.12039777849401746
Epoch:  85


0.1433004724818307 0.12041679556880679
Epoch:  86


0.14226993314317754 0.12074991315603256
Epoch:  87


0.14462709668520335 0.12075355648994446
Epoch:  88


0.14412445393768517 0.12060724198818207
Epoch:  89


0.14325453582647685 0.12036857860428947
Epoch:  90


0.1434785320952132 0.12050749680825643
Epoch:  91


0.14281253073666547 0.12054315741573061
Epoch:  92


0.1442434968980583 0.12076765085969653
Epoch:  93


0.1429920111959045 0.12054759051118578
Epoch:  94


0.1427245538782429 0.12054330110549927
Epoch:  95


0.143244388941172 0.12059851948704038
Epoch:  96


0.14279785309288953 0.12056450758661542
Epoch:  97


0.14390221520050153 0.12042991391250066
Epoch:  98


0.14227267978964625 0.12054592477423805
Epoch:  99


0.1431135402337925 0.12056953779288701
