In [1]:
import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import sys
sys.path.append("..")
from RandomErasing import RandomErasing

import torchvision.models
from torchvision import transforms

from albumentations import Compose, ShiftScaleRotate, GridDistortion
from albumentations.pytorch import ToTensor

In [2]:
def prepare_data(df, unknown_to_known, spec_dir='../../data/logmelspec',
                 n_frames=635, n_annotators=3):
    """Turn an annotation table into spectrogram inputs, targets and loss masks.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per annotation. After ``reset_index()`` it must expose an
        ``audio_filename`` column (i.e. the index is named ``audio_filename``);
        all remaining columns are label columns. The caller's frame is not
        modified.
    unknown_to_known : dict
        Maps the name of an "unknown" label column to the list of known label
        columns it covers. Wherever the unknown column exceeds 0.5, the mask
        of those known columns is set to 0 for that annotation row.
    spec_dir : str, optional
        Directory holding one ``<audio_filename>.npy`` array per clip.
    n_frames : int, optional
        Number of leading rows kept from each transposed array.
    n_annotators : int, optional
        Number of annotation rows per clip that are stacked along the last
        axis of ``y`` / ``y_mask``. Every clip must have at least this many.

    Returns
    -------
    X : numpy.ndarray
        Shape ``(n_clips, 1, n_frames, n_bins)``.
    y : numpy.ndarray
        Shape ``(n_clips, n_known_labels, n_annotators)``.
    y_mask : numpy.ndarray
        Same shape as ``y``; 1 where the target takes part in the loss, 0 where
        it was masked out by an unknown label.
    """
    df = df.reset_index()
    # Number the annotation rows of each clip 1, 2, 3, ... in order of appearance.
    df['slno'] = df.assign(slno=1).groupby('audio_filename')['slno'].cumsum()
    df = df.set_index(['audio_filename', 'slno'])

    # Split the "unknown" columns off; they only drive the mask.
    unknown_columns = list(unknown_to_known.keys())
    df_unknown = df.loc[:, unknown_columns].copy()
    df = df.drop(columns=unknown_columns)

    y_mask = df.copy()
    y_mask.loc[:, :] = 1
    for unknown, known in unknown_to_known.items():
        y_mask.loc[df_unknown[unknown] > 0.5, known] = 0

    # Put the annotation number on the outer level so `.loc[[k]]` selects the
    # k-th annotation of every clip, with clips in sorted filename order.
    df = df.swaplevel(i=1, j=0, axis=0).sort_index()
    y_mask = y_mask.swaplevel(i=1, j=0, axis=0).sort_index()

    slots = range(1, n_annotators + 1)
    y = np.stack([df.loc[[slot], :].values for slot in slots], axis=2)
    y_mask = np.stack([y_mask.loc[[slot], :].values for slot in slots], axis=2)

    # Load the spectrograms in the same clip order as the rows of y.
    filenames = df.loc[[1], :].reset_index(1).audio_filename.tolist()
    X = np.stack([
        np.load('{}/{}.npy'.format(spec_dir, name)).T[:n_frames, :]
        for name in filenames])
    # Add the singleton channel axis expected by the 2-D convolutions.
    X = np.expand_dims(X, axis=1)

    return X, y, y_mask


# Shared RandomErasing instance, built with its default arguments, that
# AudioDataset.__getitem__ applies to each augmented sample.
random_erasing = RandomErasing()


class AudioDataset(Dataset):
    """Dataset of spectrogram tensors with targets and per-target weights.

    Parameters
    ----------
    X : torch.Tensor
        Samples stacked along the first axis; ``X[idx]`` is one sample.
        The augmentation branch slices and concatenates along axis 1 of a
        sample, so a sample needs at least two axes.
    y : torch.Tensor
        Targets, indexed along the first axis like ``X``.
    weights : torch.Tensor
        Per-target weights / masks, indexed along the first axis like ``X``.
    transform : callable or None
        Callable invoked as ``transform(image=array)`` that returns a dict
        with an ``'image'`` tensor. When falsy, samples are returned
        untouched.
    """

    def __init__(self, X, y, weights, transform=None):
        self.X = X
        self.y = y
        self.weights = weights
        self.transform = transform
        # Converter used to hand the scaled tensor to the image-based transform.
        self.pil = transforms.ToPILImage()

    def __len__(self):
        """Return the number of samples (size of the first axis of ``X``)."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return ``(sample, target, weights)`` for ``idx``, augmented if a transform is set."""
        sample = self.X[idx, ...]

        if self.transform:
            # min-max transformation
            this_min = sample.min()
            this_max = sample.max()
            value_range = this_max - this_min
            # A constant-valued sample has zero range; dividing by it would
            # fill the sample with NaNs, so fall back to a range of 1.
            if value_range == 0:
                value_range = torch.ones_like(value_range)
            sample = (sample - this_min) / value_range

            # randomly cycle the file: rotate along axis 1 by a random offset
            i = np.random.randint(sample.shape[1])
            sample = torch.cat([
                sample[:, i:, :],
                sample[:, :i, :]],
                dim=1)

            # apply albumentations transforms
            sample = np.array(self.pil(sample))
            sample = self.transform(image=sample)
            sample = sample['image']
            # add a leading axis, then swap the last two axes
            sample = sample[None, :, :].permute(0, 2, 1)

            # apply random erasing (module-level `random_erasing` instance)
            sample = random_erasing(sample.clone().detach())

            # revert min-max transformation
            sample = (sample * value_range) + this_min

        return sample, self.y[idx, ...], self.weights[idx, ...]

In [3]:
# Load and prepare data
# NOTE: pickle.load executes arbitrary code; only open metadata files you trust.
with open('../../data/metadata.pkl', 'rb') as f:
    metadata = pickle.load(f)

taxonomy = metadata['taxonomy_df']
is_unknown_label = taxonomy.fine_id == 'X'
is_known_label = taxonomy.fine_id != 'X'

# Pair each unknown ('X') fine label with every known fine label that shares its
# coarse category, then collect the known labels per unknown label.
unknown_with_known = pd.merge(
    taxonomy.loc[is_unknown_label, ['fine', 'coarse']],
    taxonomy.loc[is_known_label, ['fine', 'coarse']],
    on='coarse', how='inner'
).drop(columns='coarse')
unknown_to_known = (
    unknown_with_known
    .groupby('fine_x')['fine_y']
    .apply(list)
    .to_dict())
known_labels = taxonomy.loc[is_known_label].fine.tolist()

# Coarse and fine annotations side by side, one table per split.
train_df = pd.concat([metadata['coarse_train'], metadata['fine_train']], axis=1, sort=True)
valid_df = pd.concat([metadata['coarse_test'], metadata['fine_test']], axis=1, sort=True)

In [4]:
# manual correction for one data point: any annotation row whose labels sum to
# exactly 37 is zeroed out, in both the training and the validation table
for annotations in (train_df, valid_df):
    sums_to_37 = annotations.sum(axis=1) == 37
    annotations.loc[sums_to_37, :] = 0

In [5]:
# Build inputs, targets and loss masks for each split.
train_X, train_y, train_y_mask = prepare_data(
    df=train_df, unknown_to_known=unknown_to_known)
valid_X, valid_y, valid_y_mask = prepare_data(
    df=valid_df, unknown_to_known=unknown_to_known)

In [6]:
# Channel wise normalization
# Statistics are computed per entry of the last axis, on the training set only,
# and then applied to both splits. The width is read from the data instead of
# being hard-coded to 128.
# NOTE: this cell overwrites train_X / valid_X, so running it twice normalises twice.
n_channels = train_X.shape[-1]
channel_means = train_X.reshape(-1, n_channels).mean(axis=0).reshape(1, 1, 1, -1)
channel_stds = train_X.reshape(-1, n_channels).std(axis=0).reshape(1, 1, 1, -1)
train_X = (train_X - channel_means) / channel_stds
valid_X = (valid_X - channel_means) / channel_stds

In [7]:
# Define the data augmentation transformations
# Geometric jitter first, conversion to a tensor last.
augmentation_steps = [
    ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0.5),
    GridDistortion(),
    ToTensor(),
]
albumentations_transform = Compose(augmentation_steps)

In [8]:
# Create the datasets and the dataloaders
batch_size = 64

# Only the training set is augmented; validation samples are served as-is.
train_dataset = AudioDataset(
    X=torch.Tensor(train_X),
    y=torch.Tensor(train_y),
    weights=torch.Tensor(train_y_mask),
    transform=albumentations_transform)
valid_dataset = AudioDataset(
    X=torch.Tensor(valid_X),
    y=torch.Tensor(valid_y),
    weights=torch.Tensor(valid_y_mask),
    transform=None)

val_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
# Two independently shuffled views of the training set; the training loop
# zips them to obtain the pairs of batches it mixes together.
train_loader_1 = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
train_loader_2 = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [9]:
# Define the device to be used
# Use the first GPU when one is available and fall back to the CPU otherwise,
# so the notebook also runs on machines without CUDA.
cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if cuda else 'cpu')
print('Device: ', device)

Device:  cuda:0


In [10]:
def weight_reset(layer):
    """Re-initialise ``layer`` in place if it offers ``reset_parameters()``.

    Objects without that method are left untouched, which makes the function
    suitable for ``module.apply(weight_reset)``.
    """
    if not hasattr(layer, 'reset_parameters'):
        return
    layer.reset_parameters()

In [11]:
# Placeholder default; reassigned by the "Parameters" cell that follows.
# Task5Model reads this global to decide which leading children of
# `mv2.features` get their weights re-initialised.
until_x = None

In [12]:
# Parameters
# Index of the last child of `mv2.features` that Task5Model re-initialises
# (children 0..until_x inclusive); also embedded in the checkpoint filename
# written by the training loop.
until_x = 14


In [13]:
class Task5Model(nn.Module):
    """MobileNetV2 feature extractor with a small classification head.

    A single-channel input is first mapped to three channels by 1x1
    convolutions, passed through ``mobilenet_v2().features``, max-pooled over
    the last two axes and classified by a two-layer head.

    Parameters
    ----------
    num_classes : int
        Width of the final linear layer, i.e. number of logits per sample.

    Notes
    -----
    The module-level ``until_x`` controls how much of the pretrained backbone
    is kept: children ``0..until_x`` (inclusive) of ``mv2.features`` are
    re-initialised via ``weight_reset``. ``until_x = None`` keeps every
    pretrained weight.
    """

    def __init__(self, num_classes):

        super().__init__()

        # Learnable mapping from 1 input channel to the 3 channels the backbone expects.
        self.bw2col = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 10, 1, padding=0), nn.ReLU(),
            nn.Conv2d(10, 3, 1, padding=0), nn.ReLU())

        self.mv2 = torchvision.models.mobilenet_v2(pretrained=True)

        # Head on top of the 1280-wide pooled backbone features.
        self.final = nn.Sequential(
            nn.Linear(1280, 512), nn.ReLU(), nn.BatchNorm1d(512),
            nn.Linear(512, num_classes))

        # Reset until ith layer of mv2. Comparing an int with None raises a
        # TypeError, so None is treated as "do not reset anything".
        if until_x is not None:
            for i, x in enumerate(self.mv2.features.children()):
                if i <= until_x:
                    x.apply(weight_reset)

    def forward(self, x):
        """Return the ``num_classes`` logits for a batch of single-channel inputs."""
        x = self.bw2col(x)
        x = self.mv2.features(x)
        # Global max pooling: reduce the last axis, then the (new) last axis.
        x = x.max(dim=-1)[0].max(dim=-1)[0]
        x = self.final(x)
        return x

In [14]:
# Instantiate the model
# 31 is passed as `num_classes`, i.e. the number of logits the head produces
# (presumably the number of label columns returned by prepare_data — TODO confirm).
model = Task5Model(31).to(device)

In [15]:
# Define optimizer, scheduler and loss criteria
optimizer = optim.Adam(params=model.parameters(), lr=1e-3, amsgrad=True)
# The scheduler is stepped with the validation loss once per epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer, patience=5, verbose=True)
# Unreduced loss: the training loop applies the masks before summing.
criterion = nn.BCEWithLogitsLoss(reduction='none')

In [16]:
epochs = 100
early_stopping_patience = 25  # epochs without a new best validation loss
alpha = 1                     # Beta(alpha, alpha) parameter of the mixup weights
train_loss_hist = []
valid_loss_hist = []
lowest_val_loss = np.inf
epochs_without_new_lowest = 0


def _masked_annotation_loss(outputs, labels, masks):
    """Masked BCE averaged over all unmasked targets of a batch.

    `labels` and `masks` carry one slice per annotation set on their last
    axis; the same `outputs` are scored against every slice with the
    module-level, unreduced `criterion`, weighted by the matching mask and
    summed. The total is divided by the sum of all mask entries.
    """
    total = sum(
        (criterion(outputs, labels[:, :, k]) * masks[:, :, k]).sum()
        for k in range(labels.shape[2]))
    return total / masks.sum()


# `i` is the epoch index.
for i in range(epochs):
    print('Epoch: ', i)

    # ---- training ----
    model.train()
    this_epoch_train_loss = 0.0
    for batch_a, batch_b in zip(train_loader_1, train_loader_2):

        # mixup: blend two independently shuffled batches, one weight per sample
        mixup_vals = np.random.beta(alpha, alpha, batch_a[0].shape[0])
        lam = torch.Tensor(mixup_vals)

        lam_x = lam.view(-1, 1, 1, 1)
        inputs = (lam_x * batch_a[0]) + ((1 - lam_x) * batch_b[0])

        lam_y = lam.view(-1, 1, 1)
        labels = (lam_y * batch_a[1]) + ((1 - lam_y) * batch_b[1])
        masks = (lam_y * batch_a[2]) + ((1 - lam_y) * batch_b[2])

        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = _masked_annotation_loss(outputs, labels, masks)
        loss.backward()
        optimizer.step()
        this_epoch_train_loss += loss.item()

    # ---- validation (no gradients, no optimizer involvement) ----
    model.eval()
    this_epoch_valid_loss = 0.0
    with torch.no_grad():
        for inputs, labels, masks in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            masks = masks.to(device)
            outputs = model(inputs)
            loss = _masked_annotation_loss(outputs, labels, masks)
            this_epoch_valid_loss += loss.item()

    # Mean of the per-batch losses.
    this_epoch_train_loss /= len(train_loader_1)
    this_epoch_valid_loss /= len(val_loader)

    train_loss_hist.append(this_epoch_train_loss)
    valid_loss_hist.append(this_epoch_valid_loss)

    # Report before the early-stopping check so the last epoch is logged too.
    print(this_epoch_train_loss, this_epoch_valid_loss)

    # Keep only the weights of the best epoch seen so far.
    if this_epoch_valid_loss < lowest_val_loss:
        lowest_val_loss = this_epoch_valid_loss
        torch.save(model.state_dict(), './model_system1_until_{}'.format(until_x))
        epochs_without_new_lowest = 0
    else:
        epochs_without_new_lowest += 1

    if epochs_without_new_lowest >= early_stopping_patience:
        break

    scheduler.step(this_epoch_valid_loss)

Epoch:  0


0.6412889909099888 0.5304120864186969
Epoch:  1


0.3442933575527088 0.21056132657187326
Epoch:  2


0.19527183271743156 0.17951489772115434
Epoch:  3


0.18297465306681557 76.76427187238421
Epoch:  4


0.18065059144754667 0.24453447120530264
Epoch:  5


0.1770550628771653 0.17369910223143442
Epoch:  6


0.1760823279619217 0.16466938384941646
Epoch:  7


0.17596210780981425 0.1628604531288147
Epoch:  8


0.17583467912029577 0.16254407806055887
Epoch:  9


0.17400875002951235 0.15977048235280172
Epoch:  10


0.17306092663391218 0.16108546086720057
Epoch:  11


0.17214438117839195 0.15719552976744516
Epoch:  12


0.16984656936413534 0.1563571755375181
Epoch:  13


0.17103882375601176 0.15922896989754268
Epoch:  14


0.1702447497361415 0.15927584895065852
Epoch:  15


0.1691485047340393 0.1502262247460229
Epoch:  16


0.1688663013078071 0.1558086403778621
Epoch:  17


0.1672763192170375 0.14902339556387492
Epoch:  18


0.16857156358860634 0.15382086166313716
Epoch:  19


0.16910263614074603 0.1487365790775844
Epoch:  20


0.1687653338586962 0.15198007651737758
Epoch:  21


0.16716919114460815 0.15383281026567733
Epoch:  22


0.16656766670781212 0.1479517945221492
Epoch:  23


0.1663185019750853 0.15163916562284743
Epoch:  24


0.16528782369317235 0.15169228187629155
Epoch:  25


0.16412059036461082 0.14638047239610127
Epoch:  26


0.16365894634981412 0.14948359344686782
Epoch:  27


0.16344357503427043 0.1490591892174312
Epoch:  28


0.16435540850098068 0.14603646951062338
Epoch:  29


0.1633767352716343 0.144979898418699
Epoch:  30


0.16302998444518527 0.14628750937325613
Epoch:  31


0.16308084613568075 0.15316297113895416
Epoch:  32


0.16302315568601763 0.14254931147609437
Epoch:  33


0.1644921451807022 0.14408886964832032
Epoch:  34


0.1618877000905372 0.14636564893381937
Epoch:  35


0.16250130091164564 0.1407532755817686
Epoch:  36


0.16102860545789874 0.13904462648289545
Epoch:  37


0.16242694693642692 0.14317000125135695
Epoch:  38


0.1616980530120231 0.1404579154082707
Epoch:  39


0.16114394165374138 0.14800024245466506
Epoch:  40


0.161325501026334 0.1433380799634116
Epoch:  41


0.15948378999490995 0.13938501051494054
Epoch:  42


0.16223647264209953 0.13822930731943675
Epoch:  43


0.15995858811043404 0.14259509848696844
Epoch:  44


0.1605024825076799 0.13708603062799998
Epoch:  45


0.1603655891643988 0.13865084201097488
Epoch:  46


0.15995562922310186 0.14538982723440444
Epoch:  47


0.16007100971969399 0.1437725658927645
Epoch:  48


0.15897867848744263 0.13889224082231522
Epoch:  49


0.15983867604990262 0.13867436349391937
Epoch:  50


0.1588875538594014 0.13975124061107635
Epoch    50: reducing learning rate of group 0 to 1.0000e-04.
Epoch:  51


0.15694020527440147 0.1353517845273018
Epoch:  52


0.1564032652893582 0.1344939363854272
Epoch:  53


0.15493261371109937 0.13420868345669337
Epoch:  54


0.1556674399085947 0.13449170972619737
Epoch:  55


0.15636967323921822 0.13472647752080644
Epoch:  56


0.1558192980450553 0.1337073062147413
Epoch:  57


0.15608688383489042 0.1356436876314027
Epoch:  58


0.15477878741315892 0.13485867849418096
Epoch:  59


0.15595568676252622 0.13417813288314
Epoch:  60


0.15607733500970378 0.13419996734176362
Epoch:  61


0.1558292838367256 0.13373469774212157
Epoch:  62


0.15565343603894516 0.1337944737502507
Epoch    62: reducing learning rate of group 0 to 1.0000e-05.
Epoch:  63


0.15600902808679118 0.13378336493458068
Epoch:  64


0.15530249879166885 0.13386573110307967
Epoch:  65


0.15559768233750318 0.13362565530197962
Epoch:  66


0.15443642236090996 0.1338990405201912
Epoch:  67


0.1553926790082777 0.13304355208362853
Epoch:  68


0.15498631870424426 0.13406448172671454
Epoch:  69


0.1560859043855925 0.1336688527039119
Epoch:  70


0.1542112307774054 0.13380882463284902
Epoch:  71


0.15573649954151464 0.13394241567168916
Epoch:  72


0.1565667221675048 0.13374395349196025
Epoch:  73


0.1542743331677205 0.13361197710037231
Epoch    73: reducing learning rate of group 0 to 1.0000e-06.
Epoch:  74


0.1554685597484176 0.13361380675009318
Epoch:  75


0.15482377240786682 0.13364970896925246
Epoch:  76


0.15546900154771032 0.1334480717778206
Epoch:  77


0.15504644931973638 0.13344050092356546
Epoch:  78


0.1553339805151965 0.13398289041859762
Epoch:  79


0.1558319096629684 0.1331819487469537
Epoch    79: reducing learning rate of group 0 to 1.0000e-07.
Epoch:  80


0.15570542780128685 0.1339655720761844
Epoch:  81


0.15560270080695282 0.1339577574815069
Epoch:  82


0.1555239223950618 0.13309952084507262
Epoch:  83


0.1545420925359468 0.13381136102335794
Epoch:  84


0.15643418721250585 0.13382831322295324
Epoch:  85


0.1552198863512761 0.13402190165860312
Epoch    85: reducing learning rate of group 0 to 1.0000e-08.
Epoch:  86


0.1560617218146453 0.13378159808261053
Epoch:  87


0.15463218415105665 0.13377434334584645
Epoch:  88


0.15583063903692607 0.13331399857997894
Epoch:  89


0.15354415854892214 0.13386973419359752
Epoch:  90


0.1546312582653922 0.13413652990545546
Epoch:  91


0.15566683782113566 0.13360036696706498
Epoch:  92
