In [1]:
# Parameters
until_x = 6


In [2]:
import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import sys
sys.path.append("..")
from RandomErasing import RandomErasing

import torchvision.models
from torchvision import transforms

from albumentations import Compose, ShiftScaleRotate, GridDistortion
from albumentations.pytorch import ToTensor

In [3]:
def prepare_data(df, unknown_to_known):
    df = df.reset_index()
    df['slno'] = df.assign(slno=1).groupby('audio_filename')['slno'].cumsum()
    df.set_index(['audio_filename', 'slno'], inplace=True)

    df_unknown = df.copy().loc[:, list(unknown_to_known.keys())]
    df.drop(columns=list(unknown_to_known.keys()), inplace=True)

    y_mask = df.copy()
    y_mask.loc[:, :] = 1
    for unknown, known in unknown_to_known.items():
        y_mask.loc[
            df_unknown[unknown] > 0.5,
            known
        ] = 0

    df = df.swaplevel(i=1, j=0, axis=0).sort_index()

    y_mask = y_mask.swaplevel(i=1, j=0, axis=0).sort_index()

    y = np.concatenate([
        df.loc[[1], :].values[..., np.newaxis],
        df.loc[[2], :].values[..., np.newaxis],
        df.loc[[3], :].values[..., np.newaxis]
    ], axis=2)

    y_mask = np.concatenate([
        y_mask.loc[[1], :].values[..., np.newaxis],
        y_mask.loc[[2], :].values[..., np.newaxis],
        y_mask.loc[[3], :].values[..., np.newaxis]
    ], axis=2)

    X = np.concatenate([
        np.expand_dims(np.load('../../data/logmelspec/{}.npy'.format(x)).T[:635, :], axis=0)
        for x in df.loc[[1], :].reset_index(1).audio_filename.tolist()])
    X = np.expand_dims(X, axis=1)

    return X, y, y_mask


random_erasing = RandomErasing()


class AudioDataset(Dataset):

    def __init__(self, X, y, weights, transform=None):
        self.X = X
        self.y = y
        self.weights = weights
        self.transform = transform
        self.pil = transforms.ToPILImage()

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        sample = self.X[idx, ...]

        if self.transform:
            # min-max transformation
            this_min = sample.min()
            this_max = sample.max()
            sample = (sample - this_min) / (this_max - this_min)

            # randomly cycle the file
            i = np.random.randint(sample.shape[1])
            sample = torch.cat([
                sample[:, i:, :],
                sample[:, :i, :]],
                dim=1)

            # apply albumentations transforms
            sample = np.array(self.pil(sample))
            sample = self.transform(image=sample)
            sample = sample['image']
            sample = sample[None, :, :].permute(0, 2, 1)

            # apply random erasing
            sample = random_erasing(sample.clone().detach())

            # revert min-max transformation
            sample = (sample * (this_max - this_min)) + this_min

        return sample, self.y[idx, ...], self.weights[idx, ...]


def weight_reset(layer):
    if hasattr(layer, 'reset_parameters'):
        layer.reset_parameters()


class Task5Model(nn.Module):

    def __init__(self, num_classes):

        super().__init__()

        self.bw2col = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 10, 1, padding=0), nn.ReLU(),
            nn.Conv2d(10, 3, 1, padding=0), nn.ReLU())

        self.mv2 = torchvision.models.mobilenet_v2(pretrained=True)
        
        # Reset after ith layer of mv2
        for i, x in enumerate(self.mv2.features.children()):
            if i > until_x:
                x.apply(weight_reset)

        self.final = nn.Sequential(
            nn.Linear(1280, 512), nn.ReLU(), nn.BatchNorm1d(512),
            nn.Linear(512, num_classes))

    def forward(self, x):
        x = self.bw2col(x)
        x = self.mv2.features(x)
        x = x.max(dim=-1)[0].max(dim=-1)[0]
        x = self.final(x)
        return x

In [4]:
# Load and prepare data
with open('../../data/metadata.pkl', 'rb') as f:
    metadata = pickle.load(f)

unknown_to_known = (
    pd.merge(metadata['taxonomy_df'].loc[lambda x: x.fine_id == 'X', ['fine', 'coarse']],
             metadata['taxonomy_df'].loc[lambda x: x.fine_id != 'X', ['fine', 'coarse']],
             on='coarse', how='inner')
    .drop(columns='coarse')
    .groupby('fine_x')['fine_y']
    .apply(lambda x: list(x)).to_dict())
known_labels = metadata['taxonomy_df'].loc[lambda x: x.fine_id != 'X'].fine.tolist()

train_df = pd.concat([metadata['coarse_train'], metadata['fine_train']], axis=1, sort=True)
valid_df = pd.concat([metadata['coarse_test'], metadata['fine_test']], axis=1, sort=True)

In [5]:
# manual correction for one data point
train_df.loc[(train_df.sum(axis=1) == 37).copy(), :] = 0
valid_df.loc[(valid_df.sum(axis=1) == 37).copy(), :] = 0

In [6]:
train_X, train_y, train_y_mask = prepare_data(train_df, unknown_to_known)
valid_X, valid_y, valid_y_mask = prepare_data(valid_df, unknown_to_known)

In [7]:
# Channel wise normalization
channel_means = train_X.reshape(-1, 128).mean(axis=0).reshape(1, 1, 1, -1)
channel_stds = train_X.reshape(-1, 128).std(axis=0).reshape(1, 1, 1, -1)
train_X = (train_X - channel_means) / channel_stds
valid_X = (valid_X - channel_means) / channel_stds

In [8]:
# Define the data augmentation transformations
albumentations_transform = Compose([
    ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0.5),
    GridDistortion(),
    ToTensor()
])

In [9]:
# Create the datasets and the dataloaders
train_dataset = AudioDataset(torch.Tensor(train_X),
                             torch.Tensor(train_y),
                             torch.Tensor(train_y_mask),
                             albumentations_transform)
valid_dataset = AudioDataset(torch.Tensor(valid_X),
                             torch.Tensor(valid_y),
                             torch.Tensor(valid_y_mask),
                             None)

val_loader = DataLoader(valid_dataset, 96, shuffle=False)
train_loader_1 = DataLoader(train_dataset, 96, shuffle=True)
train_loader_2 = DataLoader(train_dataset, 96, shuffle=True)

In [10]:
# Define the device to be used
cuda = True
device = torch.device('cuda:0' if cuda else 'cpu')
print('Device: ', device)

Device:  cuda:0


In [11]:
# Instantiate the model
model = Task5Model(31).to(device)

In [12]:
# Define optimizer, scheduler and loss criteria
optimizer = optim.Adam(model.parameters(), lr=0.001, amsgrad=True)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, verbose=True)
criterion = nn.BCEWithLogitsLoss(reduction='none')

In [13]:
epochs = 100
train_loss_hist = []
valid_loss_hist = []
lowest_val_loss = np.inf
epochs_without_new_lowest = 0

for i in range(epochs):
    print('Epoch: ', i)

    this_epoch_train_loss = 0
    for i1, i2 in zip(train_loader_1, train_loader_2):

        # mixup the inputs ---------
        alpha = 1
        mixup_vals = np.random.beta(alpha, alpha, i1[0].shape[0])

        lam = torch.Tensor(mixup_vals.reshape(mixup_vals.shape[0], 1, 1, 1))
        inputs = (lam * i1[0]) + ((1 - lam) * i2[0])

        lam = torch.Tensor(mixup_vals.reshape(mixup_vals.shape[0], 1, 1))
        labels = (lam * i1[1]) + ((1 - lam) * i2[1])
        masks = (lam * i1[2]) + ((1 - lam) * i2[2])
        # mixup ends ----------

        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        with torch.set_grad_enabled(True):
            model = model.train()
            outputs = model(inputs)
            # calculate loss for each set of annotations
            loss_0 = criterion(outputs, labels[:, :, 0]) * masks[:, :, 0]
            loss_1 = criterion(outputs, labels[:, :, 1]) * masks[:, :, 1]
            loss_2 = criterion(outputs, labels[:, :, 2]) * masks[:, :, 2]
            loss = (loss_0.sum() + loss_1.sum() + loss_2.sum()) / masks.sum()
            loss.backward()
            optimizer.step()
            this_epoch_train_loss += loss.detach().cpu().numpy()

    this_epoch_valid_loss = 0
    for inputs, labels, masks in val_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)
        optimizer.zero_grad()
        with torch.set_grad_enabled(False):
            model = model.eval()
            outputs = model(inputs)
            loss_0 = criterion(outputs, labels[:, :, 0]) * masks[:, :, 0]
            loss_1 = criterion(outputs, labels[:, :, 1]) * masks[:, :, 1]
            loss_2 = criterion(outputs, labels[:, :, 2]) * masks[:, :, 2]
            loss = (loss_0.sum() + loss_1.sum() + loss_2.sum()) / masks.sum()
            this_epoch_valid_loss += loss.detach().cpu().numpy()

    this_epoch_train_loss /= len(train_loader_1)
    this_epoch_valid_loss /= len(val_loader)

    train_loss_hist.append(this_epoch_train_loss)
    valid_loss_hist.append(this_epoch_valid_loss)

    if this_epoch_valid_loss < lowest_val_loss:
        lowest_val_loss = this_epoch_valid_loss
        torch.save(model.state_dict(), './model_system1_until_{}'.format(until_x))
        epochs_without_new_lowest = 0
    else:
        epochs_without_new_lowest += 1

    if epochs_without_new_lowest >= 15:
        break

    print(this_epoch_train_loss, this_epoch_valid_loss)

    scheduler.step(this_epoch_valid_loss)

Epoch:  0


0.649959762096405 0.5966796517372132
Epoch:  1


0.4239409959316254 0.2875231564044952
Epoch:  2


0.222657208442688 0.1888997882604599
Epoch:  3


0.18192430794239045 0.5959265947341919
Epoch:  4


0.17134532153606416 0.1543442040681839
Epoch:  5


0.16831530690193175 0.15229724049568177
Epoch:  6


0.1646013629436493 0.1455976754426956
Epoch:  7


0.1630277991294861 0.14050363302230834
Epoch:  8


0.1622041267156601 0.14209021031856536
Epoch:  9


0.15961744129657746 0.1370520293712616
Epoch:  10


0.15955488085746766 0.14225030541419983
Epoch:  11


0.15945787847042084 0.13418609499931336
Epoch:  12


0.15837912261486053 0.14610586762428285
Epoch:  13


0.15690551400184632 0.13808976858854294
Epoch:  14


0.15675989985466005 0.1397963285446167
Epoch:  15


0.1549801617860794 0.1320664867758751
Epoch:  16


0.15659677386283874 0.13165130317211152
Epoch:  17


0.15450064539909364 0.13486450910568237
Epoch:  18


0.15460721909999847 0.13491789102554322
Epoch:  19


0.15413872241973878 0.13373906165361404
Epoch:  20


0.15483469963073732 0.13308287262916565
Epoch:  21


0.1536756706237793 0.13006315976381302
Epoch:  22


0.15380186915397645 0.13132989257574082
Epoch:  23


0.15309106767177583 0.13067835122346877
Epoch:  24


0.1535819309949875 0.1286262422800064
Epoch:  25


0.15321243941783905 0.13475569635629653
Epoch:  26


0.1525271040201187 0.1312226802110672
Epoch:  27


0.1520086407661438 0.1302150309085846
Epoch:  28


0.151662962436676 0.1297777220606804
Epoch:  29


0.15030689239501954 0.1293005168437958
Epoch:  30


0.1505553287267685 0.12704911082983017
Epoch:  31


0.15063648641109467 0.13379756808280946
Epoch:  32


0.15113854944705962 0.1276216596364975
Epoch:  33


0.15045201718807222 0.12619773149490357
Epoch:  34


0.14944968521595 0.1277695804834366
Epoch:  35


0.15021296381950378 0.12601051479578018
Epoch:  36


0.14979396045207977 0.1280433714389801
Epoch:  37


0.14984191238880157 0.1278060033917427
Epoch:  38


0.1490360814332962 0.12791645973920823
Epoch:  39


0.14970313549041747 0.12687293738126754
Epoch:  40


0.14894254446029664 0.1288110211491585
Epoch:  41


0.14788398861885071 0.1275141268968582
Epoch    41: reducing learning rate of group 0 to 1.0000e-04.
Epoch:  42


0.1482547092437744 0.12379605621099472
Epoch:  43


0.1455703717470169 0.12393564283847809
Epoch:  44


0.14630764544010164 0.12368223667144776
Epoch:  45


0.14597957253456115 0.12322979420423508
Epoch:  46


0.14653838396072388 0.12376777529716491
Epoch:  47


0.14624236643314362 0.12364142388105392
Epoch:  48


0.14495065867900847 0.12334907650947571
Epoch:  49


0.14527494609355926 0.12384555786848069
Epoch:  50


0.14624543368816376 0.12399297952651978
Epoch:  51


0.145425066947937 0.12307642549276351
Epoch:  52


0.14506561636924745 0.12337618619203568
Epoch:  53


0.1445257580280304 0.12325832396745681
Epoch:  54


0.14492532312870027 0.1232464775443077
Epoch:  55


0.14565094113349913 0.12364208847284316
Epoch:  56


0.1450740510225296 0.12392875403165818
Epoch:  57


0.1445484507083893 0.12317905873060227
Epoch    57: reducing learning rate of group 0 to 1.0000e-05.
Epoch:  58


0.1441653299331665 0.12322867959737778
Epoch:  59


0.14522033393383027 0.1229033574461937
Epoch:  60


0.1435594654083252 0.12294123023748398
Epoch:  61


0.1438668304681778 0.12277159690856934
Epoch:  62


0.14589782536029816 0.12297784239053726
Epoch:  63


0.14528716921806337 0.12299241721630097
Epoch:  64


0.14370645940303803 0.12293245494365693
Epoch:  65


0.1446381038427353 0.12291350215673447
Epoch:  66


0.14501454412937165 0.12298314273357391
Epoch:  67


0.14458676397800446 0.12287571281194687
Epoch    67: reducing learning rate of group 0 to 1.0000e-06.
Epoch:  68


0.14487191081047057 0.12307467609643936
Epoch:  69


0.14411366522312163 0.12293485850095749
Epoch:  70


0.14433532416820527 0.12309885323047638
Epoch:  71


0.14560889303684235 0.12315445244312287
Epoch:  72


0.14501870572566986 0.12332547754049301
Epoch:  73


0.14462798357009887 0.12310681194067001
Epoch    73: reducing learning rate of group 0 to 1.0000e-07.
Epoch:  74


0.14548887252807619 0.12295255959033966
Epoch:  75


0.1443530476093292 0.12292411774396897
Epoch:  76


0.14440439105033875 0.12276145815849304
Epoch:  77


0.14467704117298127 0.12297793179750442
Epoch:  78


0.1443518477678299 0.12334464639425277
Epoch:  79


0.1446109867095947 0.1229987159371376
Epoch    79: reducing learning rate of group 0 to 1.0000e-08.
Epoch:  80


0.14507648408412932 0.12303566038608552
Epoch:  81


0.1446985524892807 0.12322977930307388
Epoch:  82


0.14396515130996704 0.12325798869132995
Epoch:  83


0.14443361580371858 0.1232581153512001
Epoch:  84


0.14544577598571778 0.12312759757041931
Epoch:  85


0.1444365417957306 0.12318299561738968
Epoch:  86


0.14469388723373414 0.12321760505437851
Epoch:  87


0.1438256937265396 0.12287847846746444
Epoch:  88


0.14417918384075165 0.12318172156810761
Epoch:  89


0.14428744494915008 0.12314479798078537
Epoch:  90


0.1449721497297287 0.1235574096441269
Epoch:  91
