In [1]:
import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import sys
sys.path.append("..")
from RandomErasing import RandomErasing

import torchvision.models
from torchvision import transforms

from albumentations import Compose, ShiftScaleRotate, GridDistortion
from albumentations.pytorch import ToTensor

In [2]:
def prepare_data(df, unknown_to_known, logmelspec_dir='../../data/logmelspec',
                 n_frames=635, annotation_ids=(1, 2, 3)):
    """Turn an annotation table into spectrogram inputs, targets and loss masks.

    Parameters
    ----------
    df : pandas.DataFrame
        Annotation table whose index is named ``audio_filename`` and which
        holds one row per annotation of a clip. The caller's frame is not
        modified.
    unknown_to_known : dict
        Maps each "unknown" label column to the list of known label columns
        that must be masked out whenever that unknown label is active (> 0.5).
        The unknown columns themselves are removed from the targets.
    logmelspec_dir : str, optional
        Directory holding one ``<audio_filename>.npy`` array per clip.
    n_frames : int, optional
        Number of leading rows kept from every transposed spectrogram.
    annotation_ids : sequence of int, optional
        1-based per-clip row numbers to stack along the last axis of ``y`` and
        ``y_mask``. Every clip must have a row for each id.

    Returns
    -------
    X : numpy.ndarray
        Shape ``(n_clips, 1, frames, bins)``, where ``frames`` is at most
        ``n_frames``.
    y : numpy.ndarray
        Shape ``(n_clips, n_known_labels, len(annotation_ids))``.
    y_mask : numpy.ndarray
        Same shape as ``y``; 0 where the matching unknown label was active,
        1 everywhere else.
    """
    df = df.reset_index()
    # Number the rows of every clip 1, 2, 3, ... in order of appearance.
    df['slno'] = df.groupby('audio_filename').cumcount() + 1
    df = df.set_index(['audio_filename', 'slno'])

    unknown_columns = list(unknown_to_known.keys())
    df_unknown = df.loc[:, unknown_columns]
    df = df.drop(columns=unknown_columns)

    # Start from an all-ones mask and zero the known labels whose "unknown"
    # sibling was selected in that row.
    y_mask = df.copy()
    y_mask.loc[:, :] = 1
    for unknown, known in unknown_to_known.items():
        y_mask.loc[df_unknown[unknown] > 0.5, known] = 0

    # Put the row number first so each annotation can be sliced out as a block
    # whose rows are sorted by audio_filename.
    df = df.swaplevel(i=1, j=0, axis=0).sort_index()
    y_mask = y_mask.swaplevel(i=1, j=0, axis=0).sort_index()

    def stack_annotations(frame):
        # (n_clips, n_labels) per annotation -> (n_clips, n_labels, n_annotations)
        return np.stack(
            [frame.loc[[annotation], :].values for annotation in annotation_ids],
            axis=2)

    y = stack_annotations(df)
    y_mask = stack_annotations(y_mask)

    # Load the clips in the same order as the rows of the first annotation.
    filenames = (df.loc[[annotation_ids[0]], :]
                 .reset_index(1).audio_filename.tolist())
    X = np.stack([
        np.load('{}/{}.npy'.format(logmelspec_dir, name)).T[:n_frames, :]
        for name in filenames])
    X = np.expand_dims(X, axis=1)

    return X, y, y_mask


# Shared augmentation object, built with the defaults of the project-local
# RandomErasing class; AudioDataset.__getitem__ calls it on every augmented sample.
random_erasing = RandomErasing()


class AudioDataset(Dataset):
    """Dataset of spectrogram tensors with targets and per-element loss weights.

    Parameters
    ----------
    X : torch.Tensor
        Inputs indexed along the first axis; each item is treated as a 3-D
        tensor whose axis 1 is cycled during augmentation.
    y : torch.Tensor
        Targets, indexed along the first axis like ``X``.
    weights : torch.Tensor
        Loss weights / masks, indexed along the first axis like ``X``.
    transform : callable, optional
        Albumentations-style callable invoked as ``transform(image=...)`` and
        returning a dict with an ``'image'`` entry. When falsy, samples are
        returned without any augmentation.

    Notes
    -----
    The augmented path also relies on the module-level ``random_erasing``
    object.
    """

    def __init__(self, X, y, weights, transform=None):
        self.X = X
        self.y = y
        self.weights = weights
        self.transform = transform
        self.pil = transforms.ToPILImage()

    def __len__(self):
        """Return the number of samples (size of the first axis of ``X``)."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return ``(sample, target, weight)`` for ``idx``, augmented if enabled."""
        sample = self.X[idx, ...]

        if self.transform:
            # min-max transformation
            this_min = sample.min()
            this_max = sample.max()
            value_range = this_max - this_min
            # A constant sample has zero range; dividing by it would fill the
            # sample with NaNs. Use a unit range so the sample maps to zeros
            # and is restored to its constant value below.
            if value_range == 0:
                value_range = torch.ones_like(value_range)
            sample = (sample - this_min) / value_range

            # randomly cycle the file: rotate along axis 1 by a random offset
            i = np.random.randint(sample.shape[1])
            sample = torch.cat([
                sample[:, i:, :],
                sample[:, :i, :]],
                dim=1)

            # apply albumentations transforms (they operate on a numpy image)
            sample = np.array(self.pil(sample))
            sample = self.transform(image=sample)
            sample = sample['image']
            # NOTE(review): the transform output is assumed to be a 2-D tensor
            # with the two spectrogram axes swapped; this restores a leading
            # channel axis and the original axis order - confirm against the
            # ToTensor implementation in use.
            sample = sample[None, :, :].permute(0, 2, 1)

            # apply random erasing
            sample = random_erasing(sample.clone().detach())

            # revert min-max transformation
            sample = (sample * value_range) + this_min

        return sample, self.y[idx, ...], self.weights[idx, ...]

In [3]:
# Load and prepare data
# NOTE: pickle.load can execute arbitrary code - only open a metadata file you trust.
with open('../../data/metadata.pkl', 'rb') as metadata_file:
    metadata = pickle.load(metadata_file)

# Split the taxonomy into the "unknown" fine labels (fine_id == 'X') and the rest.
taxonomy = metadata['taxonomy_df']
is_unknown = taxonomy.fine_id == 'X'
is_known = taxonomy.fine_id != 'X'
unknown_rows = taxonomy.loc[is_unknown, ['fine', 'coarse']]
known_rows = taxonomy.loc[is_known, ['fine', 'coarse']]

# For every unknown fine label, list the known fine labels sharing its coarse class.
unknown_to_known = (
    unknown_rows
    .merge(known_rows, on='coarse', how='inner')
    .drop(columns='coarse')
    .groupby('fine_x')['fine_y']
    .apply(list)
    .to_dict())
known_labels = taxonomy.loc[is_known, 'fine'].tolist()

# Coarse and fine annotations side by side, one frame per split.
train_df = pd.concat(
    [metadata[key] for key in ('coarse_train', 'fine_train')], axis=1, sort=True)
valid_df = pd.concat(
    [metadata[key] for key in ('coarse_test', 'fine_test')], axis=1, sort=True)

In [4]:
# manual correction for one data point
# Any row whose label values add up to exactly 37 is reset to all zeros.
# NOTE(review): 37 presumably equals the number of label columns, i.e. a row
# with every label set - confirm.
for annotations in (train_df, valid_df):
    row_is_saturated = annotations.sum(axis=1) == 37
    annotations.loc[row_is_saturated, :] = 0

In [5]:
# Build the arrays for both splits. prepare_data returns, per split:
#   X      - spectrograms stacked as (clips, 1, frames, bins)
#   y      - targets as (clips, labels, 3), one slice per annotation row of a clip
#   y_mask - same shape as y, 0 where a label must be ignored by the loss
train_X, train_y, train_y_mask = prepare_data(train_df, unknown_to_known)
valid_X, valid_y, valid_y_mask = prepare_data(valid_df, unknown_to_known)

In [6]:
# Channel wise normalization
# Statistics are taken per position of the last axis (128 entries), over the
# training split only, and then applied to both splits.
channel_means = np.mean(train_X.reshape(-1, 128), axis=0)[None, None, None, :]
channel_stds = np.std(train_X.reshape(-1, 128), axis=0)[None, None, None, :]
train_X = np.divide(train_X - channel_means, channel_stds)
valid_X = np.divide(valid_X - channel_means, channel_stds)

In [7]:
# Define the data augmentation transformations
# AudioDataset.__getitem__ applies this pipeline to the min-max scaled sample
# after converting it to a numpy image.
albumentations_transform = Compose([
    # Random shift / scale / rotation, each bounded by the limit given here.
    ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0.5),
    # Grid distortion with the library defaults.
    GridDistortion(),
    # Last step converts the image back to a tensor (AudioDataset calls
    # .permute on the result).
    ToTensor()
])

In [8]:
# Create the datasets and the dataloaders
# Only the training set is augmented; validation samples are served unchanged.
train_dataset = AudioDataset(
    X=torch.Tensor(train_X),
    y=torch.Tensor(train_y),
    weights=torch.Tensor(train_y_mask),
    transform=albumentations_transform)
valid_dataset = AudioDataset(
    X=torch.Tensor(valid_X),
    y=torch.Tensor(valid_y),
    weights=torch.Tensor(valid_y_mask),
    transform=None)

val_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False)
# Two independently shuffled views of the same training set: the training loop
# zips them together to obtain the pairs of batches that get mixed.
train_loader_1 = DataLoader(train_dataset, batch_size=64, shuffle=True)
train_loader_2 = DataLoader(train_dataset, batch_size=64, shuffle=True)

In [9]:
# Define the device to be used
# Use the first GPU when CUDA is usable and fall back to the CPU otherwise, so
# the notebook also runs on machines without a GPU.
cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if cuda else 'cpu')
print('Device: ', device)

Device:  cuda:0


In [10]:
def weight_reset(layer):
    """Re-initialise ``layer`` in place when it offers ``reset_parameters``.

    Objects without a ``reset_parameters`` attribute are left untouched, which
    makes the function suitable for ``module.apply(weight_reset)``.
    """
    if not hasattr(layer, 'reset_parameters'):
        return
    layer.reset_parameters()

In [11]:
# Default value of the `until_x` notebook parameter; the following cell
# overrides it with the value used for this run.
until_x = None

In [12]:
# Parameters
# until_x: index of the last child of mv2.features that Task5Model
# re-initialises (children 0..until_x inclusive); it also ends up in the
# checkpoint file name.
# NOTE(review): this looks like an injected parameters cell (e.g. papermill) - confirm.
until_x = 8


In [13]:
class Task5Model(nn.Module):
    """MobileNetV2 feature extractor adapted to single-channel inputs.

    The network consists of:

    * ``bw2col`` - batch norm followed by two 1x1 convolutions (1 -> 10 -> 3
      channels, each with a ReLU), turning the single-channel input into the
      three channels the MobileNetV2 features expect;
    * ``mv2`` - a pretrained torchvision MobileNetV2, of which only
      ``features`` is used in ``forward``;
    * ``final`` - a two-layer head (1280 -> 512 -> ``num_classes``) that
      outputs raw logits.

    Parameters
    ----------
    num_classes : int
        Number of output logits.
    reset_until : int, optional
        Children ``0..reset_until`` (inclusive) of ``mv2.features`` have their
        parameters re-initialised via ``weight_reset``, discarding the
        pretrained weights. When omitted, the notebook-level ``until_x`` is
        used; if that is ``None`` as well, all pretrained weights are kept.
    """

    def __init__(self, num_classes, reset_until=None):

        super().__init__()

        self.bw2col = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 10, 1, padding=0), nn.ReLU(),
            nn.Conv2d(10, 3, 1, padding=0), nn.ReLU())

        self.mv2 = torchvision.models.mobilenet_v2(pretrained=True)

        self.final = nn.Sequential(
            nn.Linear(1280, 512), nn.ReLU(), nn.BatchNorm1d(512),
            nn.Linear(512, num_classes))

        # Fall back to the notebook parameter so existing `Task5Model(n)` calls
        # behave as before.
        if reset_until is None:
            reset_until = until_x

        # Reset until ith layer of mv2. A None limit means "reset nothing";
        # comparing an int with None would otherwise raise a TypeError.
        if reset_until is not None:
            for i, block in enumerate(self.mv2.features.children()):
                if i <= reset_until:
                    block.apply(weight_reset)

    def forward(self, x):
        """Return the logits for a batch of single-channel inputs."""
        x = self.bw2col(x)
        x = self.mv2.features(x)
        # Global max pooling: reduce the last two axes one after the other.
        x = x.max(dim=-1)[0].max(dim=-1)[0]
        x = self.final(x)
        return x

In [14]:
# Instantiate the model
# NOTE(review): 31 is presumably the number of label columns in train_y
# (train_y.shape[1]) - confirm.
model = Task5Model(31).to(device)

In [15]:
# Define optimizer, scheduler and loss criteria
# Adam with the AMSGrad variant enabled, over every parameter of the model.
optimizer = optim.Adam(model.parameters(), lr=0.001, amsgrad=True)
# Fed with the epoch validation loss by the training loop; lowers the learning
# rate once that loss has stopped improving for longer than `patience` epochs
# (the training log shows tenfold reductions).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, verbose=True)
# reduction='none' keeps one loss value per element so the training loop can
# multiply by the annotation masks before summing.
criterion = nn.BCEWithLogitsLoss(reduction='none')

In [16]:
epochs = 100
early_stopping_patience = 25  # epochs without a new best validation loss
alpha = 1  # both parameters of the Beta distribution the mixup weights come from
train_loss_hist = []
valid_loss_hist = []
lowest_val_loss = np.inf
epochs_without_new_lowest = 0


def masked_annotation_loss(outputs, labels, masks):
    """Masked BCE-with-logits loss averaged over all annotation slices.

    `labels` and `masks` carry one slice per annotation along their last axis.
    The element-wise loss of `outputs` against each slice is multiplied by the
    matching mask slice, everything is summed, and the total is divided by the
    sum of all mask values.
    """
    masked_sum = 0
    for k in range(labels.shape[2]):
        masked_sum = masked_sum + (criterion(outputs, labels[:, :, k]) * masks[:, :, k]).sum()
    return masked_sum / masks.sum()


for i in range(epochs):
    print('Epoch: ', i)

    # ---- training pass ----
    model.train()
    this_epoch_train_loss = 0
    for i1, i2 in zip(train_loader_1, train_loader_2):

        # mixup the inputs ---------
        # One mixing weight per sample, shared by inputs, labels and masks.
        mixup_vals = np.random.beta(alpha, alpha, i1[0].shape[0])

        lam = torch.Tensor(mixup_vals.reshape(mixup_vals.shape[0], 1, 1, 1))
        inputs = (lam * i1[0]) + ((1 - lam) * i2[0])

        lam = torch.Tensor(mixup_vals.reshape(mixup_vals.shape[0], 1, 1))
        labels = (lam * i1[1]) + ((1 - lam) * i2[1])
        masks = (lam * i1[2]) + ((1 - lam) * i2[2])
        # mixup ends ----------

        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        with torch.set_grad_enabled(True):
            outputs = model(inputs)
            loss = masked_annotation_loss(outputs, labels, masks)
            loss.backward()
            optimizer.step()
        this_epoch_train_loss += loss.item()

    # ---- validation pass (no gradients, no parameter updates) ----
    model.eval()
    this_epoch_valid_loss = 0
    with torch.no_grad():
        for inputs, labels, masks in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            masks = masks.to(device)
            outputs = model(inputs)
            loss = masked_annotation_loss(outputs, labels, masks)
            this_epoch_valid_loss += loss.item()

    # Average of the per-batch losses.
    this_epoch_train_loss /= len(train_loader_1)
    this_epoch_valid_loss /= len(val_loader)

    train_loss_hist.append(this_epoch_train_loss)
    valid_loss_hist.append(this_epoch_valid_loss)

    # Report before the early-stopping check so the losses of the final epoch
    # are printed as well.
    print(this_epoch_train_loss, this_epoch_valid_loss)

    # Checkpoint whenever the validation loss reaches a new minimum.
    if this_epoch_valid_loss < lowest_val_loss:
        lowest_val_loss = this_epoch_valid_loss
        torch.save(model.state_dict(), './model_system1_until_{}'.format(until_x))
        epochs_without_new_lowest = 0
    else:
        epochs_without_new_lowest += 1

    if epochs_without_new_lowest >= early_stopping_patience:
        break

    scheduler.step(this_epoch_valid_loss)

Epoch:  0


0.6333145708651156 0.5076573661395481
Epoch:  1


0.33268645768230026 0.4284894934722355
Epoch:  2


0.19414825213922038 0.2054254753249032
Epoch:  3


0.18233601910036965 15.867994921548027
Epoch:  4


0.18005339197210363 0.1711828772510801
Epoch:  5


0.17981402012141975 0.17798065074852534
Epoch:  6


0.17830396624835762 0.16783657457147325
Epoch:  7


0.17692624918512395 0.17217853239604405
Epoch:  8


0.176333766934034 0.179555743932724
Epoch:  9


0.17404893403117722 0.16709985051836287
Epoch:  10


0.17343766866503535 0.16906969036374772
Epoch:  11


0.1718233143155639 0.15781070291996002
Epoch:  12


0.1697982589940767 0.1502976587840489
Epoch:  13


0.16973144983923114 0.14718834630080632
Epoch:  14


0.16768356593879494 0.14660001546144485
Epoch:  15


0.16614012702091321 0.1632964015007019
Epoch:  16


0.1656844185010807 0.14328960329294205
Epoch:  17


0.16537845255555333 0.1432361283472606
Epoch:  18


0.16390017802650864 0.14334222887243545
Epoch:  19


0.16296235210186727 0.14778009482792445
Epoch:  20


0.1625478460982039 0.14330507389136724
Epoch:  21


0.16264771368052508 0.1383202618786267
Epoch:  22


0.16312130802386515 0.13756293271269118
Epoch:  23


0.16218231376763936 0.14259453862905502
Epoch:  24


0.1616097209421364 0.136851861008576
Epoch:  25


0.15949590262529012 0.13602991721459798
Epoch:  26


0.1591348273528589 0.1356144238795553
Epoch:  27


0.15858268536425926 0.1336292518036706
Epoch:  28


0.15849140971093564 0.1388596615621022
Epoch:  29


0.1592165344470256 0.14516397884913854
Epoch:  30


0.1583898458126429 0.13808914061103547
Epoch:  31


0.1567155034155459 0.13291863352060318
Epoch:  32


0.1553748785643964 0.13183154591492244
Epoch:  33


0.15633912062322772 0.13275827573878424
Epoch:  34


0.15708297209159747 0.13732951028006418
Epoch:  35


0.15596857828062935 0.1337012659226145
Epoch:  36


0.15605453945494988 0.13246345413582666
Epoch:  37


0.1550858866524052 0.134955672281129
Epoch:  38


0.15643537165345373 0.13287814600127085
Epoch    38: reducing learning rate of group 0 to 1.0000e-04.
Epoch:  39


0.15391374359259735 0.12773610012871878
Epoch:  40


0.15188111002380783 0.12731256442410605
Epoch:  41


0.15196704783955137 0.1264817863702774
Epoch:  42


0.1521295973578015 0.1263952755502292
Epoch:  43


0.15298858083583214 0.12594011745282582
Epoch:  44


0.15161818265914917 0.12557560418333327
Epoch:  45


0.15200696885585785 0.12638669993196214
Epoch:  46


0.15246674942003713 0.12551546628986085
Epoch:  47


0.15184613820668813 0.12551439659936087
Epoch:  48


0.15086347549348264 0.12596592200653894
Epoch:  49


0.1509584744234343 0.12588034038032805
Epoch:  50


0.14915047827604655 0.1251679084130696
Epoch:  51


0.15121404866914492 0.12554945370980672
Epoch:  52


0.1516163159866591 0.12546632332461222
Epoch:  53


0.15002814056100072 0.12546622966017043
Epoch:  54


0.15021809009281364 0.1257390614066805
Epoch:  55


0.15117179783614906 0.12569457292556763
Epoch:  56


0.15017035361882802 0.12472973444632121
Epoch:  57


0.1498642068456959 0.1252472081354686
Epoch:  58


0.1496681623362206 0.12511839611189707
Epoch:  59


0.1500442684502215 0.12521889699356897
Epoch:  60


0.1511239804126121 0.1247362270951271
Epoch:  61


0.15192697056241938 0.12565944769552775
Epoch:  62


0.1483198275437226 0.12434182528938566
Epoch:  63


0.15061015212858045 0.12471384980848857
Epoch:  64


0.14996353154246872 0.1252325645514897
Epoch:  65


0.14855275001074816 0.12527128521885192
Epoch:  66


0.15002824204999046 0.12492751755884715
Epoch:  67


0.15059429809853836 0.12475221178361348
Epoch:  68


0.14999738536976479 0.12488105041640145
Epoch    68: reducing learning rate of group 0 to 1.0000e-05.
Epoch:  69


0.14910197257995605 0.1250354243176324
Epoch:  70


0.15015419069174174 0.12480912570442472
Epoch:  71


0.14982491852463903 0.12477796737636838
Epoch:  72


0.1490487637552055 0.12476048299244472
Epoch:  73


0.14929565545674917 0.12474723905324936
Epoch:  74


0.14890496754968488 0.12475930367197309
Epoch    74: reducing learning rate of group 0 to 1.0000e-06.
Epoch:  75


0.14926250762230642 0.12456923616783959
Epoch:  76


0.14934310961414027 0.1245512728180204
Epoch:  77


0.15074793390325597 0.12455433074917112
Epoch:  78


0.14916379347040848 0.12454061103718621
Epoch:  79


0.148367777466774 0.12437779349940163
Epoch:  80


0.14996525121701731 0.12446207872458867
Epoch    80: reducing learning rate of group 0 to 1.0000e-07.
Epoch:  81


0.14981668382077604 0.12474086774247033
Epoch:  82


0.1492502926974683 0.12503240470375335
Epoch:  83


0.14994427481213132 0.12448264126266752
Epoch:  84


0.14765016734600067 0.12478896230459213
Epoch:  85


0.14942600960667068 0.12453021321977888
Epoch:  86


0.14985866844654083 0.12439385056495667
Epoch    86: reducing learning rate of group 0 to 1.0000e-08.
Epoch:  87
