In [1]:
# Parameters
# until_x: index of the last MobileNetV2 `features` child whose pretrained
# weights are re-initialised before training (children 0..until_x, inclusive).
# It is also embedded in the checkpoint filename written by the training loop.
# Keep this a plain literal assignment so parameter-injection tooling
# (presumably papermill, given the "Parameters" header) can override it.
until_x = 13


In [2]:
import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import sys
sys.path.append("..")
from RandomErasing import RandomErasing

import torchvision.models
from torchvision import transforms

from albumentations import Compose, ShiftScaleRotate, GridDistortion
from albumentations.pytorch import ToTensor

In [3]:
def prepare_data(df, unknown_to_known, data_dir='../../data/logmelspec',
                 n_frames=635, n_annotators=3):
    """Build model inputs, per-annotator targets and loss masks from a label frame.

    ``df`` holds one row per (audio file, annotation). Rows belonging to the
    same file are numbered 1, 2, ... in their original order, and each number
    is treated as one annotator. For every "unknown" column ``u``, any row
    where ``u > 0.5`` gets a mask of 0 on the known columns
    ``unknown_to_known[u]``. The unknown columns themselves are dropped from
    the targets.

    Args:
        df: label frame with ``audio_filename`` as its index (or as a column).
        unknown_to_known: mapping from an unknown label column to the list of
            known label columns it masks out.
        data_dir: directory containing one ``<audio_filename>.npy`` array per file.
        n_frames: number of leading rows kept from each transposed array.
        n_annotators: number of annotations stacked per file.

    Returns:
        X: array ``(n_files, 1, n_frames, n_bins)``, where ``n_bins`` is the
            first axis of the stored arrays. This assumes every file has at
            least ``n_frames`` frames.
        y: array ``(n_files, n_known_labels, n_annotators)``.
        y_mask: array shaped like ``y``; 1 keeps an element, 0 ignores it.

    Raises:
        ValueError: if an annotator block does not list the same files, in the
            same order, as annotator 1.
    """
    df = df.reset_index()
    # Number the annotations of each file 1..k in their original row order.
    df['slno'] = df.assign(slno=1).groupby('audio_filename')['slno'].cumsum()
    df = df.set_index(['audio_filename', 'slno'])

    unknown_cols = list(unknown_to_known.keys())
    df_unknown = df.loc[:, unknown_cols].copy()
    df = df.drop(columns=unknown_cols)

    # Start with every element counted, then switch off known labels whose
    # coarse-level "unknown" sibling was marked by that annotator.
    mask_df = df.copy()
    mask_df.loc[:, :] = 1
    for unknown, known in unknown_to_known.items():
        mask_df.loc[df_unknown[unknown] > 0.5, known] = 0

    # Index by (annotator number, filename) so each annotator is one block.
    df = df.swaplevel(i=1, j=0, axis=0).sort_index()
    mask_df = mask_df.swaplevel(i=1, j=0, axis=0).sort_index()

    annotators = range(1, n_annotators + 1)
    filenames = df.loc[[1], :].index.get_level_values('audio_filename').tolist()
    # y, y_mask and X rows are matched purely by position; make that explicit.
    for k in annotators:
        if df.loc[[k], :].index.get_level_values('audio_filename').tolist() != filenames:
            raise ValueError(
                'annotation {} does not cover the same files as annotation 1'.format(k))

    y = np.stack([df.loc[[k], :].values for k in annotators], axis=2)
    y_mask = np.stack([mask_df.loc[[k], :].values for k in annotators], axis=2)

    X = np.stack([
        np.load('{}/{}.npy'.format(data_dir, name)).T[:n_frames, :]
        for name in filenames])
    X = np.expand_dims(X, axis=1)

    return X, y, y_mask


# Shared RandomErasing instance. AudioDataset.__getitem__ looks it up as a
# module-level global and applies it to a sample whenever the dataset has a
# transform.
random_erasing = RandomErasing()


class AudioDataset(Dataset):
    """Map-style dataset of spectrograms with per-annotator labels and masks.

    Each item is ``(sample, y[idx], weights[idx])``. When ``transform`` is set,
    the sample is augmented in min-max-scaled space and then mapped back to
    its original value range. The augmentation consists of:

    * a random circular shift along dim 1,
    * the albumentations-style ``transform``,
    * the module-level ``random_erasing``.

    Args:
        X: tensor of inputs indexed along dim 0. In this notebook each item is
            ``(1, frames, bins)``.
        y: tensor of targets indexed along dim 0.
        weights: tensor of per-element loss weights indexed along dim 0.
        transform: optional callable invoked as ``transform(image=ndarray)``
            and returning a dict with an ``'image'`` tensor. ``None`` disables
            augmentation.
    """

    def __init__(self, X, y, weights, transform=None):
        self.X = X
        self.y = y
        self.weights = weights
        self.transform = transform
        self.pil = transforms.ToPILImage()

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        sample = self.X[idx, ...]

        if self.transform:
            # Min-max scale into [0, 1] so the PIL conversion below sees
            # values in that range.
            this_min = sample.min()
            this_max = sample.max()
            value_range = this_max - this_min
            if value_range == 0:
                # A constant sample would otherwise give 0 / 0 = NaN everywhere.
                value_range = torch.ones_like(value_range)
            sample = (sample - this_min) / value_range

            # Randomly cycle the file along dim 1 (the frame axis).
            shift = int(np.random.randint(sample.shape[1]))
            sample = torch.roll(sample, shifts=-shift, dims=1)

            # Apply the albumentations transforms.
            sample = np.array(self.pil(sample))
            sample = self.transform(image=sample)['image']
            # Presumably the tensor conversion yields (bins, frames) for this
            # 2-D image; restore (1, frames, bins). TODO confirm.
            sample = sample[None, :, :].permute(0, 2, 1)

            # Apply random erasing.
            sample = random_erasing(sample.clone().detach())

            # Revert the min-max transformation.
            sample = sample * value_range + this_min

        return sample, self.y[idx, ...], self.weights[idx, ...]


def weight_reset(layer):
    """Re-initialise ``layer`` in place if it exposes ``reset_parameters``.

    Intended for ``module.apply(weight_reset)``. Layers without a
    ``reset_parameters`` attribute are left untouched.
    """
    if not hasattr(layer, 'reset_parameters'):
        return
    layer.reset_parameters()


class Task5Model(nn.Module):
    """Multi-label classifier for single-channel spectrograms built on MobileNetV2.

    The model has three stages:

    * A 1x1-convolution stem maps the single input channel to 3 channels.
    * The MobileNetV2 feature extractor runs on that 3-channel input.
    * A global max-pool over both spatial axes feeds an MLP head that emits
      one logit per class.

    Args:
        num_classes: number of output logits.
        reset_until: index of the last ``mv2.features`` child whose pretrained
            weights are re-initialised (children ``0..reset_until``,
            inclusive). ``None`` (the default) uses the notebook-level
            ``until_x`` parameter, matching the original behaviour.
    """

    def __init__(self, num_classes, reset_until=None):

        super().__init__()

        if reset_until is None:
            # Backward-compatible default: the notebook's `until_x` parameter.
            reset_until = until_x

        # 1x1 convolutions turn the single-channel input into 3 channels for
        # the MobileNetV2 feature extractor.
        self.bw2col = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 10, 1, padding=0), nn.ReLU(),
            nn.Conv2d(10, 3, 1, padding=0), nn.ReLU())

        self.mv2 = torchvision.models.mobilenet_v2(pretrained=True)

        # Re-initialise feature children 0..reset_until (inclusive); later
        # children keep their pretrained weights.
        for index, block in enumerate(self.mv2.features.children()):
            if index <= reset_until:
                block.apply(weight_reset)

        # 1280 must equal the channel count of the MobileNetV2 feature output,
        # which is all that remains after the global max-pool in forward().
        self.final = nn.Sequential(
            nn.Linear(1280, 512), nn.ReLU(), nn.BatchNorm1d(512),
            nn.Linear(512, num_classes))

    def forward(self, x):
        """Return class logits ``(batch, num_classes)`` for input ``(batch, 1, H, W)``."""
        x = self.bw2col(x)
        x = self.mv2.features(x)
        # Global max-pool over the last two (spatial) axes -> (batch, channels).
        x = x.max(dim=-1)[0].max(dim=-1)[0]
        x = self.final(x)
        return x

In [4]:
# Load and prepare data
# NOTE(review): pickle.load can execute arbitrary code -- only load a trusted metadata file.
with open('../../data/metadata.pkl', 'rb') as f:
    metadata = pickle.load(f)

taxonomy_df = metadata['taxonomy_df']
# Fine labels whose fine_id is 'X' are the "unknown" ones; all others are known.
is_unknown_fine = taxonomy_df.fine_id == 'X'
is_known_fine = taxonomy_df.fine_id != 'X'

# Map every unknown fine label to the list of known fine labels that share
# its coarse label.
unknown_to_known = (
    taxonomy_df.loc[is_unknown_fine, ['fine', 'coarse']]
    .merge(taxonomy_df.loc[is_known_fine, ['fine', 'coarse']],
           on='coarse', how='inner')
    .drop(columns='coarse')
    .groupby('fine_x')['fine_y']
    .apply(list)
    .to_dict()
)
known_labels = taxonomy_df.loc[is_known_fine, 'fine'].tolist()

# Put coarse and fine label columns side by side for each split.
train_df = pd.concat([metadata['coarse_train'], metadata['fine_train']], axis=1, sort=True)
valid_df = pd.concat([metadata['coarse_test'], metadata['fine_test']], axis=1, sort=True)

In [5]:
# Manual correction for one data point: any row whose values sum to exactly 37
# (presumably a row with every label column set -- TODO confirm) is reset to
# all zeros, in both splits.
train_df.loc[train_df.sum(axis=1).eq(37), :] = 0
valid_df.loc[valid_df.sum(axis=1).eq(37), :] = 0

In [6]:
# Build (X, y, y_mask) for each split with the same unknown->known masking rules.
(train_X, train_y, train_y_mask), (valid_X, valid_y, valid_y_mask) = (
    prepare_data(frame, unknown_to_known) for frame in (train_df, valid_df)
)

In [7]:
# Channel-wise normalization: standardise every position of the last axis
# (presumably the mel bins -- TODO confirm) using training-set statistics only,
# then apply the same statistics to the validation set.
# NOTE: not idempotent -- re-running this cell normalises train_X/valid_X twice.
n_channels = train_X.shape[-1]
channel_means = train_X.reshape(-1, n_channels).mean(axis=0).reshape(1, 1, 1, -1)
channel_stds = train_X.reshape(-1, n_channels).std(axis=0).reshape(1, 1, 1, -1)
# A constant channel has zero spread; leave it centred instead of dividing by 0.
channel_stds[channel_stds == 0] = 1
train_X = (train_X - channel_means) / channel_stds
valid_X = (valid_X - channel_means) / channel_stds

In [8]:
# Define the data augmentation transformations. They are applied to the
# min-max-scaled spectrogram image inside AudioDataset.__getitem__.
augmentation_steps = [
    # Small random shift/scale. rotate_limit=0.5 is a very small rotation
    # (albumentations documents this limit in degrees) -- TODO confirm intent.
    ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0.5),
    GridDistortion(),
    ToTensor(),
]
albumentations_transform = Compose(augmentation_steps)

In [9]:
# Create the datasets and the dataloaders
BATCH_SIZE = 96

train_dataset = AudioDataset(
    *(torch.Tensor(arr) for arr in (train_X, train_y, train_y_mask)),
    albumentations_transform,
)
valid_dataset = AudioDataset(
    *(torch.Tensor(arr) for arr in (valid_X, valid_y, valid_y_mask)),
    None,
)

val_loader = DataLoader(valid_dataset, BATCH_SIZE, shuffle=False)
# Two independently shuffled loaders over the same training set. The training
# loop zips them to draw random partner samples for mixup.
train_loader_1 = DataLoader(train_dataset, BATCH_SIZE, shuffle=True)
train_loader_2 = DataLoader(train_dataset, BATCH_SIZE, shuffle=True)

In [10]:
# Define the device to be used
cuda = True  # set to False to force CPU
# Use the GPU only when one is actually available, so the notebook still runs
# on CPU-only machines instead of failing at the first .to(device).
device = torch.device('cuda:0' if cuda and torch.cuda.is_available() else 'cpu')
print('Device: ', device)

Device:  cuda:0


In [11]:
# Instantiate the model. Its output width must equal the number of label
# columns, because the loss compares the outputs with labels[:, :, k].
model = Task5Model(train_y.shape[1]).to(device)

In [12]:
# Define optimizer, scheduler and loss criteria.
# Adam uses the AMSGrad variant. The scheduler lowers the learning rate when
# the value passed to scheduler.step() stops improving for `patience` epochs.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3, amsgrad=True)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    patience=5,
    verbose=True,
)
# Unreduced (per-element) loss, so every element can be weighted by its mask.
criterion = nn.BCEWithLogitsLoss(reduction='none')

In [13]:
def masked_annotation_loss(loss_fn, outputs, labels, masks):
    """Mask-weighted mean loss of one set of logits against every annotator.

    Args:
        loss_fn: element-wise loss (``reduction='none'``) called as ``loss_fn(outputs, target)``.
        outputs: logits, shaped like ``labels[:, :, k]``.
        labels: targets with the annotators on dim 2.
        masks: per-element weights shaped like ``labels``; 0 drops an element.

    Returns:
        Scalar tensor: the sum of masked element losses over all annotators,
        divided by ``masks.sum()``.
    """
    total = sum(
        (loss_fn(outputs, labels[:, :, k]) * masks[:, :, k]).sum()
        for k in range(labels.shape[2]))
    return total / masks.sum()


epochs = 100
mixup_alpha = 1               # mixing weights are drawn from Beta(alpha, alpha)
early_stopping_patience = 15  # epochs without a new lowest validation loss
train_loss_hist = []
valid_loss_hist = []
lowest_val_loss = np.inf
epochs_without_new_lowest = 0

for i in range(epochs):
    print('Epoch: ', i)

    # ---- training: mixup between two independently shuffled loaders ----
    model.train()
    this_epoch_train_loss = 0
    for i1, i2 in zip(train_loader_1, train_loader_2):
        mixup_vals = np.random.beta(mixup_alpha, mixup_alpha, i1[0].shape[0])

        # One mixing weight per sample, broadcast over input / label dims.
        lam = torch.Tensor(mixup_vals.reshape(-1, 1, 1, 1))
        inputs = (lam * i1[0]) + ((1 - lam) * i2[0])

        lam = torch.Tensor(mixup_vals.reshape(-1, 1, 1))
        labels = (lam * i1[1]) + ((1 - lam) * i2[1])
        masks = (lam * i1[2]) + ((1 - lam) * i2[2])

        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = masked_annotation_loss(criterion, outputs, labels, masks)
        loss.backward()
        optimizer.step()
        this_epoch_train_loss += loss.item()

    # ---- validation ----
    model.eval()
    this_epoch_valid_loss = 0
    with torch.no_grad():
        for inputs, labels, masks in val_loader:
            outputs = model(inputs.to(device))
            loss = masked_annotation_loss(
                criterion, outputs, labels.to(device), masks.to(device))
            this_epoch_valid_loss += loss.item()

    this_epoch_train_loss /= len(train_loader_1)
    this_epoch_valid_loss /= len(val_loader)

    train_loss_hist.append(this_epoch_train_loss)
    valid_loss_hist.append(this_epoch_valid_loss)
    # Report before the early-stopping check so the last epoch is logged too.
    print(this_epoch_train_loss, this_epoch_valid_loss)

    if this_epoch_valid_loss < lowest_val_loss:
        lowest_val_loss = this_epoch_valid_loss
        torch.save(model.state_dict(), './model_system1_until_{}'.format(until_x))
        epochs_without_new_lowest = 0
    else:
        epochs_without_new_lowest += 1

    if epochs_without_new_lowest >= early_stopping_patience:
        break

    scheduler.step(this_epoch_valid_loss)

Epoch:  0


0.6569922542572022 0.5845218539237976
Epoch:  1


0.44113431334495545 0.2993956983089447
Epoch:  2


0.23165853798389435 0.18530858755111695
Epoch:  3


0.18524786531925203 0.17136319279670714
Epoch:  4


0.17933459877967833 0.1591965973377228
Epoch:  5


0.17636922478675843 0.16398302614688873
Epoch:  6


0.1746460509300232 0.16275509297847748
Epoch:  7


0.17266107618808746 0.1687749892473221
Epoch:  8


0.17122861742973328 0.15322066843509674
Epoch:  9


0.1704242032766342 0.1599793553352356
Epoch:  10


0.16981348872184754 0.15611501634120942
Epoch:  11


0.16793382465839385 0.15431393980979918
Epoch:  12


0.16863242030143738 0.1539464771747589
Epoch:  13


0.1672990483045578 0.150533989071846
Epoch:  14


0.1661658263206482 0.14920894503593446
Epoch:  15


0.16587673366069794 0.1468768149614334
Epoch:  16


0.16496883869171142 0.14955613017082214
Epoch:  17


0.16272362530231477 0.14482787549495696
Epoch:  18


0.1637067073583603 0.1482359766960144
Epoch:  19


0.16309078097343443 0.14457378089427947
Epoch:  20


0.16125577211380004 0.1415846139192581
Epoch:  21


0.15993319392204286 0.1470915287733078
Epoch:  22


0.16034017086029054 0.13722630590200424
Epoch:  23


0.16139744579792023 0.13714206218719482
Epoch:  24


0.15965670883655547 0.14401413202285768
Epoch:  25


0.15936348140239714 0.14470045864582062
Epoch:  26


0.1596132743358612 0.140705344080925
Epoch:  27


0.1593293833732605 0.14007100462913513
Epoch:  28


0.15762666404247283 0.13384694308042527
Epoch:  29


0.1590091735124588 0.13749209195375442
Epoch:  30


0.15691735625267028 0.13779704868793488
Epoch:  31


0.15864198803901672 0.13673101663589476
Epoch:  32


0.15887419641017914 0.1454220622777939
Epoch:  33


0.1577167510986328 0.13585546612739563
Epoch:  34


0.15631625890731812 0.1380698263645172
Epoch    34: reducing learning rate of group 0 to 1.0000e-04.
Epoch:  35


0.15769070744514466 0.13045313060283661
Epoch:  36


0.15602068185806275 0.12993387877941132
Epoch:  37


0.15507366478443146 0.12930963337421417
Epoch:  38


0.15464132487773896 0.12903889268636703
Epoch:  39


0.15475778698921203 0.12909804582595824
Epoch:  40


0.15248705446720123 0.1286198690533638
Epoch:  41


0.15393961548805238 0.12849113941192628
Epoch:  42


0.15469785869121552 0.12865707129240037
Epoch:  43


0.15343756675720216 0.12863367348909377
Epoch:  44


0.1527475506067276 0.1287848860025406
Epoch:  45


0.15344662189483643 0.12797483652830124
Epoch:  46


0.15451207101345063 0.12816119343042373
Epoch:  47


0.1529935747385025 0.12804806530475615
Epoch:  48


0.15408552050590515 0.12845190465450287
Epoch:  49


0.15305789709091186 0.12832201421260833
Epoch:  50


0.15229999303817748 0.12754171043634416
Epoch:  51


0.15293217420578004 0.12827301472425462
Epoch:  52


0.15195662081241607 0.12808602601289748
Epoch:  53


0.15306683182716369 0.1276179000735283
Epoch:  54


0.15248985290527345 0.12816448211669923
Epoch:  55


0.1532538950443268 0.1271611288189888
Epoch:  56


0.15274016678333283 0.12725678831338882
Epoch:  57


0.15179262101650237 0.12728870064020156
Epoch:  58


0.15224353611469268 0.12722889930009842
Epoch:  59


0.15284062445163726 0.12713863402605058
Epoch:  60


0.1524561047554016 0.12706537991762162
Epoch:  61


0.1520821452140808 0.12717092633247376
Epoch:  62


0.15176173865795137 0.12696593105793
Epoch:  63


0.1530014967918396 0.12724064886569977
Epoch:  64


0.15185380041599272 0.12674609571695328
Epoch:  65


0.15319251120090485 0.12670467793941498
Epoch:  66


0.15222687482833863 0.12615988850593568
Epoch:  67


0.15322932958602906 0.1262487754225731
Epoch:  68


0.150639368891716 0.12645349651575089
Epoch:  69


0.15102996408939362 0.1258950263261795
Epoch:  70


0.1516800880432129 0.12654831558465957
Epoch:  71


0.15177641332149505 0.12657639384269714
Epoch:  72


0.15187307476997375 0.12705631405115128
Epoch:  73


0.1514297044277191 0.1263577312231064
Epoch:  74


0.15132799685001375 0.1269499957561493
Epoch:  75


0.15136198222637176 0.12582869529724122
Epoch:  76


0.1522689062356949 0.12634607553482055
Epoch:  77


0.15187624990940093 0.12648411095142365
Epoch:  78


0.15134099304676055 0.12606944888830185
Epoch:  79


0.1509278702735901 0.1262400358915329
Epoch:  80


0.15123852729797363 0.12611347883939744
Epoch:  81


0.1523605614900589 0.12704361826181412
Epoch    81: reducing learning rate of group 0 to 1.0000e-05.
Epoch:  82


0.1516638743877411 0.12660836130380632
Epoch:  83


0.15086017608642577 0.12627707421779633
Epoch:  84


0.1506313145160675 0.12606717348098756
Epoch:  85


0.15181243300437927 0.12622664719820023
Epoch:  86


0.15139634609222413 0.12586598992347717
Epoch:  87


0.15095956802368163 0.12599768191576005
Epoch    87: reducing learning rate of group 0 to 1.0000e-06.
Epoch:  88


0.15096316277980804 0.12622511833906175
Epoch:  89


0.1507941687107086 0.1258951485157013
Epoch:  90
