In [1]:
import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import sys
sys.path.append("..")
from RandomErasing import RandomErasing

import torchvision.models
from torchvision import transforms

from albumentations import Compose, ShiftScaleRotate, GridDistortion
from albumentations.pytorch import ToTensor

In [2]:
def prepare_data(df, unknown_to_known,
                 logmelspec_dir='../../data/logmelspec',
                 n_frames=635,
                 n_annotators=3):
    """Build the input, target and loss-mask arrays for one data split.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per annotation. After ``reset_index()`` it must expose an
        ``audio_filename`` column; the remaining columns are label columns.
        The caller's frame is not modified.
    unknown_to_known : dict
        Maps each "unknown" label column to the list of known label columns
        that it makes unreliable. The unknown columns are removed from the
        targets.
    logmelspec_dir : str, optional
        Directory holding one ``<audio_filename>.npy`` array per file.
    n_frames : int, optional
        Number of leading rows kept from each transposed array.
    n_annotators : int, optional
        Number of annotation rows used per file (the first ``n_annotators``
        rows of each file, in order of appearance).

    Returns
    -------
    X : numpy.ndarray
        Shape ``(n_files, 1, n_frames, n_bins)``, files ordered by name.
    y : numpy.ndarray
        Shape ``(n_files, n_labels, n_annotators)``.
    y_mask : numpy.ndarray
        Same shape as ``y``; 1 everywhere except 0 for the known labels of an
        annotation whose matching unknown column is above 0.5.

    Raises
    ------
    ValueError
        If any file has fewer than ``n_annotators`` annotation rows.
    """
    df = df.reset_index()
    # Number each file's annotation rows 1, 2, 3, ... in order of appearance.
    df['slno'] = df.groupby('audio_filename').cumcount() + 1
    df = df.set_index(['audio_filename', 'slno'])

    # Fail early with a clear message instead of a shape mismatch further down.
    rows_per_file = df.groupby(level='audio_filename').size()
    too_few = rows_per_file[rows_per_file < n_annotators]
    if len(too_few) > 0:
        raise ValueError(
            'expected at least {} annotation rows per audio file; '
            'too few for: {}'.format(n_annotators, too_few.index.tolist()[:5]))

    unknown_columns = list(unknown_to_known.keys())
    df_unknown = df.loc[:, unknown_columns]
    df = df.drop(columns=unknown_columns)

    # Start from an all-ones mask (same dtype/layout as the targets) and zero
    # the known labels wherever the annotator chose the unknown option.
    y_mask = df.copy()
    y_mask.loc[:, :] = 1
    for unknown, known in unknown_to_known.items():
        y_mask.loc[df_unknown[unknown] > 0.5, known] = 0

    # Index by (slno, audio_filename) so that each annotator slot is one
    # contiguous slice whose rows are sorted by file name.
    df = df.swaplevel(i=1, j=0, axis=0).sort_index()
    y_mask = y_mask.swaplevel(i=1, j=0, axis=0).sort_index()

    slots = range(1, n_annotators + 1)
    y = np.stack([df.loc[[slot], :].values for slot in slots], axis=2)
    y_mask = np.stack([y_mask.loc[[slot], :].values for slot in slots], axis=2)

    # Load the arrays in the same file order as the rows of y.
    filenames = df.loc[[1], :].index.get_level_values('audio_filename').tolist()
    X = np.stack([
        np.load('{}/{}.npy'.format(logmelspec_dir, name)).T[:n_frames, :]
        for name in filenames])
    X = np.expand_dims(X, axis=1)

    return X, y, y_mask


# Shared augmentation instance, built with RandomErasing's default arguments;
# AudioDataset.__getitem__ applies it to every augmented training sample.
random_erasing = RandomErasing()


class AudioDataset(Dataset):
    """Dataset of (input, target, weight) triples with optional augmentation.

    ``X``, ``y`` and ``weights`` are indexed along their first axis. When
    ``transform`` is truthy every fetched input goes through the augmentation
    chain in ``_augment``; otherwise it is returned untouched.
    """

    def __init__(self, X, y, weights, transform=None):
        self.X, self.y, self.weights = X, y, weights
        self.transform = transform
        # Converter used to hand the tensor to the image-based transforms.
        self.pil = transforms.ToPILImage()

    def __len__(self):
        return self.X.shape[0]

    def _augment(self, spec):
        """Augment one 3-D input tensor and return it in its original value range."""
        # Rescale to [0, 1]; the bounds are kept so the scaling can be undone.
        lowest = spec.min()
        highest = spec.max()
        spread = highest - lowest
        spec = (spec - lowest) / spread

        # Circularly shift along axis 1 by a random offset.
        offset = np.random.randint(spec.shape[1])
        spec = torch.cat((spec[:, offset:, :], spec[:, :offset, :]), dim=1)

        # Run the albumentations pipeline on an image view of the tensor.
        # NOTE(review): the PIL round trip presumably quantises values to
        # 8 bits - confirm this is acceptable.
        image = np.array(self.pil(spec))
        spec = self.transform(image=image)['image']
        # Add a leading axis and swap the last two axes (presumably restoring
        # the layout the model expects - TODO confirm).
        spec = spec[None, :, :].permute(0, 2, 1)

        # Random erasing uses the module-level ``random_erasing`` instance.
        spec = random_erasing(spec.clone().detach())

        # Undo the min-max scaling.
        return (spec * spread) + lowest

    def __getitem__(self, idx):
        spec = self.X[idx, ...]
        if self.transform:
            spec = self._augment(spec)
        return spec, self.y[idx, ...], self.weights[idx, ...]

In [3]:
# Load and prepare data
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# load metadata.pkl from a trusted source.
with open('../../data/metadata.pkl', 'rb') as f:
    metadata = pickle.load(f)

# Map every fine label whose fine_id is 'X' to the list of fine labels with
# fine_id != 'X' that share its coarse category (inner join on 'coarse';
# the merge suffixes the two 'fine' columns as fine_x / fine_y).
unknown_to_known = (
    pd.merge(metadata['taxonomy_df'].loc[lambda x: x.fine_id == 'X', ['fine', 'coarse']],
             metadata['taxonomy_df'].loc[lambda x: x.fine_id != 'X', ['fine', 'coarse']],
             on='coarse', how='inner')
    .drop(columns='coarse')
    .groupby('fine_x')['fine_y']
    .apply(lambda x: list(x)).to_dict())
# Fine labels whose fine_id is not 'X'.
known_labels = metadata['taxonomy_df'].loc[lambda x: x.fine_id != 'X'].fine.tolist()

# Place the coarse and fine annotation tables side by side (column-wise
# concat). The frame called valid_df is built from the '*_test' tables.
train_df = pd.concat([metadata['coarse_train'], metadata['fine_train']], axis=1, sort=True)
valid_df = pd.concat([metadata['coarse_test'], metadata['fine_test']], axis=1, sort=True)

In [4]:
# manual correction for one data point
# Every row whose values sum to exactly 37 is overwritten with zeros, in both
# splits. NOTE(review): 37 is presumably the total number of label columns,
# i.e. an annotation with every label set - confirm against the data.
train_df.loc[(train_df.sum(axis=1) == 37).copy(), :] = 0
valid_df.loc[(valid_df.sum(axis=1) == 37).copy(), :] = 0

In [5]:
# Turn each annotation frame into (inputs, targets, loss masks) arrays; both
# splits use the same unknown -> known label mapping.
train_X, train_y, train_y_mask = prepare_data(train_df, unknown_to_known)
valid_X, valid_y, valid_y_mask = prepare_data(valid_df, unknown_to_known)

In [6]:
# Channel wise normalization
# Mean and standard deviation are computed per entry of the last axis, over
# the training inputs only, and then applied to both splits so that no
# validation statistics leak into training. The channel count is read from
# the data instead of being hard-coded (it was 128 originally).
# NOTE: this cell overwrites train_X / valid_X, so running it twice
# normalises twice; re-run the prepare_data cell first.
channel_means = train_X.reshape(-1, train_X.shape[-1]).mean(axis=0).reshape(1, 1, 1, -1)
channel_stds = train_X.reshape(-1, train_X.shape[-1]).std(axis=0).reshape(1, 1, 1, -1)
train_X = (train_X - channel_means) / channel_stds
valid_X = (valid_X - channel_means) / channel_stds

In [7]:
# Define the data augmentation transformations
# Applied in order by AudioDataset to the image form of each training input:
# a small random shift/scale/rotation (limits below), a grid distortion with
# its default settings, then conversion to a torch tensor.
albumentations_transform = Compose([
    ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0.5),
    GridDistortion(),
    ToTensor()
])

In [8]:
# Create the datasets and the dataloaders
# Only the training dataset gets the augmentation pipeline; validation inputs
# are passed through unchanged (transform=None). The loss masks are supplied
# as the datasets' per-sample weights.
train_dataset = AudioDataset(torch.Tensor(train_X),
                             torch.Tensor(train_y),
                             torch.Tensor(train_y_mask),
                             albumentations_transform)
valid_dataset = AudioDataset(torch.Tensor(valid_X),
                             torch.Tensor(valid_y),
                             torch.Tensor(valid_y_mask),
                             None)

# Batch size is 64 everywhere. Two independently shuffled loaders over the
# same training set are zipped in the training loop to form mixup pairs.
val_loader = DataLoader(valid_dataset, 64, shuffle=False)
train_loader_1 = DataLoader(train_dataset, 64, shuffle=True)
train_loader_2 = DataLoader(train_dataset, 64, shuffle=True)

In [9]:
# Define the device to be used
# Use the first GPU when CUDA is available and fall back to the CPU otherwise,
# so the notebook still runs (slowly) on a machine without a GPU instead of
# failing at the first .to(device) call.
cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if cuda else 'cpu')
print('Device: ', device)

Device:  cuda:0


In [10]:
def weight_reset(layer):
    """Re-initialise ``layer`` in place when it offers ``reset_parameters``.

    Objects without a ``reset_parameters`` attribute are left untouched, which
    makes the function safe to pass to ``nn.Module.apply``. Returns ``None``.
    """
    if not hasattr(layer, 'reset_parameters'):
        return
    layer.reset_parameters()

In [11]:
# Placeholder default for the notebook parameter read by Task5Model; the
# "Parameters" cell below assigns the value actually used in this run.
until_x = None

In [12]:
# Parameters
# Index of the last block of mv2.features that Task5Model re-initialises
# (blocks 0..until_x lose their pretrained weights).
# NOTE(review): this cell looks like an injected parameters cell (e.g. from
# papermill) - confirm before editing the value by hand.
until_x = 5


In [13]:
class Task5Model(nn.Module):
    """MobileNetV2-based classifier for single-channel inputs.

    Pipeline: a small 1x1-convolution stem (``bw2col``) turns the single input
    channel into the 3 channels MobileNetV2 expects, the pretrained
    ``mv2.features`` backbone extracts feature maps, a global max over the two
    spatial axes reduces them to one 1280-dimensional vector per sample, and a
    two-layer head (``final``) produces ``num_classes`` raw scores (no sigmoid
    is applied here).

    Parameters
    ----------
    num_classes : int
        Number of output scores.
    reset_until : int or None, optional
        Blocks ``0..reset_until`` of ``mv2.features`` are re-initialised,
        discarding their pretrained weights. When omitted, the notebook-level
        ``until_x`` is used, as before. If that is ``None`` as well, no block
        is reset (this previously raised ``TypeError``).
    """

    def __init__(self, num_classes, reset_until=None):

        super().__init__()

        # 1 -> 10 -> 3 channels using 1x1 convolutions only.
        self.bw2col = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 10, 1, padding=0), nn.ReLU(),
            nn.Conv2d(10, 3, 1, padding=0), nn.ReLU())

        self.mv2 = torchvision.models.mobilenet_v2(pretrained=True)

        self.final = nn.Sequential(
            nn.Linear(1280, 512), nn.ReLU(), nn.BatchNorm1d(512),
            nn.Linear(512, num_classes))

        # Reset until ith layer of mv2
        if reset_until is None:
            # Backward compatible default: fall back to the notebook parameter.
            reset_until = until_x
        if reset_until is not None:
            for i, block in enumerate(self.mv2.features.children()):
                if i > reset_until:
                    break
                block.apply(weight_reset)

    def forward(self, x):
        """Return raw class scores of shape ``(batch, num_classes)``."""
        x = self.bw2col(x)
        x = self.mv2.features(x)
        # Global max pooling: max over the last axis, then over the new last axis.
        x = x.max(dim=-1)[0].max(dim=-1)[0]
        x = self.final(x)
        return x

In [14]:
# Instantiate the model
# 31 output scores. NOTE(review): this presumably equals the number of label
# columns left after prepare_data drops the unknown ones (train_y.shape[1]) -
# confirm, since a mismatch only surfaces as a shape error in the loss.
model = Task5Model(31).to(device)

In [15]:
# Define optimizer, scheduler and loss criteria
# Adam (AMSGrad variant) over all model parameters, starting at lr=0.001.
optimizer = optim.Adam(model.parameters(), lr=0.001, amsgrad=True)
# The scheduler is stepped with the validation loss once per epoch and lowers
# the learning rate after it has stopped improving (patience=5).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, verbose=True)
# reduction='none' keeps one loss value per (sample, label) so the training
# loop can multiply by the masks before averaging.
criterion = nn.BCEWithLogitsLoss(reduction='none')

In [16]:
epochs = 100
early_stop_patience = 25  # epochs without a new best validation loss before stopping
alpha = 1  # mixup coefficients are drawn from Beta(alpha, alpha)
train_loss_hist = []
valid_loss_hist = []
lowest_val_loss = np.inf
epochs_without_new_lowest = 0


def masked_annotator_loss(outputs, labels, masks):
    """Masked loss averaged over all unmasked (sample, label, annotator) entries.

    ``outputs`` is ``(batch, n_labels)``; ``labels`` and ``masks`` are
    ``(batch, n_labels, n_annotators)``. The same model output is scored
    against every annotator's labels with the module-level ``criterion``,
    each element is weighted by its mask, and the total is divided by the sum
    of the masks.
    """
    total = 0
    for k in range(labels.shape[2]):
        total = total + (criterion(outputs, labels[:, :, k]) * masks[:, :, k]).sum()
    return total / masks.sum()


for i in range(epochs):
    print('Epoch: ', i)

    # ---- training pass ----
    model.train()
    this_epoch_train_loss = 0
    # Two independently shuffled passes over the training set supply the
    # pairs that are blended by mixup.
    for i1, i2 in zip(train_loader_1, train_loader_2):

        # mixup the inputs ---------
        # One blending coefficient per sample in the batch.
        mixup_vals = np.random.beta(alpha, alpha, i1[0].shape[0])

        lam = torch.Tensor(mixup_vals.reshape(mixup_vals.shape[0], 1, 1, 1))
        inputs = (lam * i1[0]) + ((1 - lam) * i2[0])

        # Same coefficients, reshaped to broadcast over the 3-D labels/masks.
        lam = torch.Tensor(mixup_vals.reshape(mixup_vals.shape[0], 1, 1))
        labels = (lam * i1[1]) + ((1 - lam) * i2[1])
        masks = (lam * i1[2]) + ((1 - lam) * i2[2])
        # mixup ends ----------

        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = masked_annotator_loss(outputs, labels, masks)
        loss.backward()
        optimizer.step()
        this_epoch_train_loss += loss.item()

    # ---- validation pass (no augmentation, no mixup, no gradients) ----
    model.eval()
    this_epoch_valid_loss = 0
    with torch.set_grad_enabled(False):
        for inputs, labels, masks in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            masks = masks.to(device)
            outputs = model(inputs)
            loss = masked_annotator_loss(outputs, labels, masks)
            this_epoch_valid_loss += loss.item()

    # Epoch losses are means of the per-batch losses.
    this_epoch_train_loss /= len(train_loader_1)
    this_epoch_valid_loss /= len(val_loader)

    train_loss_hist.append(this_epoch_train_loss)
    valid_loss_hist.append(this_epoch_valid_loss)

    # Report before the early-stopping check so the final epoch's losses are
    # logged too (they used to be skipped by the break below).
    print(this_epoch_train_loss, this_epoch_valid_loss)

    # Checkpoint whenever the validation loss reaches a new minimum.
    if this_epoch_valid_loss < lowest_val_loss:
        lowest_val_loss = this_epoch_valid_loss
        torch.save(model.state_dict(), './model_system1_until_{}'.format(until_x))
        epochs_without_new_lowest = 0
    else:
        epochs_without_new_lowest += 1

    if epochs_without_new_lowest >= early_stop_patience:
        break

    scheduler.step(this_epoch_valid_loss)

Epoch:  0


0.632320850282102 0.5045398431164878
Epoch:  1


0.3273960017674678 0.19559329535279954
Epoch:  2


0.18907777924795408 0.16418684806142533
Epoch:  3


0.17724887342066378 0.16467522723334177
Epoch:  4


0.17146071831922274 0.16073731439454214
Epoch:  5


0.17030857463140744 0.1596076978104455
Epoch:  6


0.16735472872450546 0.150810284273965
Epoch:  7


0.1670508231665637 0.15296080069882528
Epoch:  8


0.16315299514177684 0.15596179238387517
Epoch:  9


0.1641501607121648 0.14381276390382222
Epoch:  10


0.163403551723506 0.1421619345034872
Epoch:  11


0.1610866064155424 0.1436224611742156
Epoch:  12


0.1610564565336382 0.14048796040671213
Epoch:  13


0.16060865730852694 0.13842494360038213
Epoch:  14


0.15968547318432783 0.13835788731064116
Epoch:  15


0.15966107716431488 0.14687093879495347
Epoch:  16


0.15817846680009687 0.13769724113600595
Epoch:  17


0.15811604543312177 0.1381825762135642
Epoch:  18


0.15764420177485491 0.13519214625869477
Epoch:  19


0.15729724716495824 0.1368390821984836
Epoch:  20


0.15793463265573657 0.1390912596668516
Epoch:  21


0.15676342474447713 0.13270719136510575
Epoch:  22


0.15555874036776052 0.13052824352468764
Epoch:  23


0.15402852764000763 0.12940695668969834
Epoch:  24


0.15515135913281827 0.13611033984592982
Epoch:  25


0.1537709075051385 0.1288115531206131
Epoch:  26


0.15493725602691238 0.1286552090729986
Epoch:  27


0.15369367035659584 0.13323165263448442
Epoch:  28


0.15389729633524613 0.1286602754678045
Epoch:  29


0.15500696526991353 0.1281371031488691
Epoch:  30


0.15275599384630048 0.1280520898955209
Epoch:  31


0.15418212156038028 0.1335630076272147
Epoch:  32


0.15358650241349195 0.13148655742406845
Epoch:  33


0.15314314534535278 0.13012469985655375
Epoch:  34


0.15235992621731115 0.12962742043393
Epoch:  35


0.15145953120412053 0.12706391087600163
Epoch:  36


0.15238485626272252 0.13010619474308832
Epoch:  37


0.15176802108416687 0.1323470420071057
Epoch:  38


0.1509215380694415 0.12741758674383163
Epoch:  39


0.15075900909062978 0.12664852610656194
Epoch:  40


0.15089984479788188 0.1295751311949321
Epoch:  41


0.15008853013451035 0.12923524422304972
Epoch:  42


0.150497312884073 0.13127123351608003
Epoch:  43


0.15021727173715024 0.12696275966508047
Epoch:  44


0.15043833086619507 0.13083018788269588
Epoch:  45


0.150567660460601 0.1334210176553045
Epoch    45: reducing learning rate of group 0 to 1.0000e-04.
Epoch:  46


0.1511098871359954 0.12590683038745606
Epoch:  47


0.14859250348967476 0.12482423761061259
Epoch:  48


0.1478515923828692 0.12429438531398773
Epoch:  49


0.1478810839959093 0.1256997904607228
Epoch:  50


0.14747090798777504 0.13080441313130514
Epoch:  51


0.14743868523352854 0.13084521889686584
Epoch:  52


0.1470673213133941 0.12503225888524735
Epoch:  53


0.14558489419318535 0.12631457831178391
Epoch:  54


0.14657618790059476 0.12397377086537224
Epoch:  55


0.14708354988613645 0.12482888996601105
Epoch:  56


0.14652974702216484 0.12394401324646813
Epoch:  57


0.1457562160653037 0.1243415315236364
Epoch:  58


0.14536111056804657 0.12397528120449611
Epoch:  59


0.14562117771522418 0.12427159292357308
Epoch:  60


0.14598029771366636 0.1239529720374516
Epoch:  61


0.14649680941491514 0.12421674707106181
Epoch:  62


0.14504140333549395 0.12332552884306226
Epoch:  63


0.1454813875056602 0.12379335718495506
Epoch:  64


0.14463054288077998 0.12416037917137146
Epoch:  65


0.14420525326922135 0.12427998121295657
Epoch:  66


0.14507392975124153 0.12442340701818466
Epoch:  67


0.14570241481871218 0.12827398840870177
Epoch:  68


0.14452957865354177 0.12556511802332743
Epoch    68: reducing learning rate of group 0 to 1.0000e-05.
Epoch:  69


0.14475813065026258 0.12407748188291277
Epoch:  70


0.14416314500409202 0.1256020345858165
Epoch:  71


0.14527315787366918 0.12553443972553527
Epoch:  72


0.14458073836726112 0.12839135208300181
Epoch:  73


0.14303722977638245 0.12572755771023886
Epoch:  74


0.14418870814748713 0.1260424086025783
Epoch    74: reducing learning rate of group 0 to 1.0000e-06.
Epoch:  75


0.14464044772289894 0.12491630656378609
Epoch:  76


0.14344324453456983 0.12333716452121735
Epoch:  77


0.14379589501264933 0.12379065049546105
Epoch:  78


0.14369631659340215 0.12458890250750951
Epoch:  79


0.14593795586276698 0.12766845737184798
Epoch:  80


0.14486113392017982 0.12372833703245435
Epoch    80: reducing learning rate of group 0 to 1.0000e-07.
Epoch:  81


0.1447714277215906 0.12716479918786458
Epoch:  82


0.14490614025979429 0.1277762587581362
Epoch:  83


0.14397176053072955 0.12808956844466074
Epoch:  84


0.1438411192314045 0.12591216287442616
Epoch:  85


0.14351237182681625 0.12494482738631112
Epoch:  86


0.14526451479744268 0.1250943605388914
Epoch    86: reducing learning rate of group 0 to 1.0000e-08.
Epoch:  87
