In [1]:
# Parameters
# Index threshold into ``mv2.features`` used when building Task5Model: feature
# blocks whose enumeration index is greater than this value have their weights
# reset (re-initialised) instead of keeping the pretrained ones. Also used in
# the checkpoint filename written by the training loop.
until_x = 12


In [2]:
import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import sys
sys.path.append("..")
from RandomErasing import RandomErasing

import torchvision.models
from torchvision import transforms

from albumentations import Compose, ShiftScaleRotate, GridDistortion
from albumentations.pytorch import ToTensor

In [3]:
def prepare_data(df, unknown_to_known, n_annotators=3, n_frames=635,
                 spec_dir='../../data/logmelspec'):
    """Turn a per-annotation label frame into model inputs, labels and loss masks.

    ``df`` is expected to have one row per (audio file, annotation), with the
    file name in an index/column called ``audio_filename`` (presumably the
    index, since it is recovered via ``reset_index``) and one column per label.
    Rows of the same file are numbered 1, 2, ... in order of appearance and
    treated as separate annotators.

    Args:
        df: label frame as described above (not modified).
        unknown_to_known: maps each "unknown" label column to the list of known
            label columns it covers. Unknown columns are removed from the
            labels; where an annotator marked an unknown label (> 0.5), the
            corresponding known labels get mask 0 for that annotator.
        n_annotators: number of annotations used per file (default 3).
        n_frames: number of leading frames kept from each spectrogram.
        spec_dir: directory containing ``<audio_filename>.npy`` spectrograms.

    Returns:
        X: array (n_files, 1, n_frames, n_bins) built from the transposed,
           truncated ``.npy`` arrays, files in sorted filename order.
        y: array (n_files, n_label_columns, n_annotators) of labels.
        y_mask: array of the same shape as ``y``, 1 where the label counts
           towards the loss and 0 where it is masked out.

    Raises:
        ValueError: if the annotator slices do not cover the same files
            (e.g. some file has fewer than ``n_annotators`` annotations).
    """
    df = df.reset_index()
    # running annotation number (1, 2, ...) within each audio file
    df['slno'] = df.assign(slno=1).groupby('audio_filename')['slno'].cumsum()
    df = df.set_index(['audio_filename', 'slno'])

    unknown_cols = list(unknown_to_known.keys())
    df_unknown = df.loc[:, unknown_cols].copy()
    df = df.drop(columns=unknown_cols)

    # Start with every label active, then mask the known labels covered by an
    # unknown label that the annotator flagged.
    y_mask = df.copy()
    y_mask.loc[:, :] = 1
    for unknown, known in unknown_to_known.items():
        y_mask.loc[df_unknown[unknown] > 0.5, known] = 0

    # Index by (annotator, file) so each annotator slice is sorted by filename.
    df = df.swaplevel(i=1, j=0, axis=0).sort_index()
    y_mask = y_mask.swaplevel(i=1, j=0, axis=0).sort_index()

    annotators = list(range(1, n_annotators + 1))
    label_slices = [df.loc[[a], :] for a in annotators]
    mask_slices = [y_mask.loc[[a], :] for a in annotators]

    filenames = label_slices[0].index.get_level_values('audio_filename').tolist()
    for a, frame in zip(annotators, label_slices):
        if frame.index.get_level_values('audio_filename').tolist() != filenames:
            raise ValueError(
                'annotator {} does not cover the same audio files as annotator 1; '
                'every file needs at least {} annotations'.format(a, n_annotators))

    # stack along a new trailing annotator axis
    y = np.stack([frame.values for frame in label_slices], axis=2)
    y_mask = np.stack([frame.values for frame in mask_slices], axis=2)

    X = np.stack([
        np.load('{}/{}.npy'.format(spec_dir, name)).T[:n_frames, :]
        for name in filenames])
    X = X[:, np.newaxis]  # add a singleton channel axis

    return X, y, y_mask


# Shared RandomErasing instance (from the project-local ``RandomErasing``
# module); AudioDataset applies it to samples during training-time augmentation.
random_erasing = RandomErasing()


class AudioDataset(Dataset):
    """Dataset of spectrogram tensors with labels and per-label loss weights.

    When ``transform`` is given, every sample is augmented on the fly (see
    ``_augment``). Otherwise samples are returned unchanged.
    """

    def __init__(self, X, y, weights, transform=None):
        """
        Args:
            X: input tensor indexed by sample along dim 0; presumably
               (N, 1, frames, bins) as produced by ``prepare_data`` -- TODO confirm.
            y: label tensor indexed by sample along dim 0.
            weights: per-sample loss masks/weights indexed along dim 0.
            transform: optional albumentations-style callable invoked as
               ``transform(image=...)`` and returning a dict with key ``'image'``.
        """
        self.X = X
        self.y = y
        self.weights = weights
        self.transform = transform
        self.pil = transforms.ToPILImage()

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return ``(sample, labels, weights)`` for index ``idx``."""
        sample = self.X[idx, ...]
        if self.transform:
            sample = self._augment(sample)
        return sample, self.y[idx, ...], self.weights[idx, ...]

    def _augment(self, sample):
        """Random cyclic shift, image transforms and random erasing.

        The sample is min-max scaled to [0, 1] for the image pipeline and the
        scaling is inverted afterwards, so the output keeps the input's range.
        """
        this_min = sample.min()
        value_range = sample.max() - this_min
        # A constant sample has zero range; dividing by it would yield NaNs
        # that propagate into the loss. Use a unit scale instead.
        if value_range == 0:
            value_range = torch.ones_like(value_range)
        sample = (sample - this_min) / value_range

        # randomly cycle the sample along dim 1 (the frame axis from prepare_data)
        i = np.random.randint(sample.shape[1])
        sample = torch.cat([sample[:, i:, :], sample[:, :i, :]], dim=1)

        # albumentations works on images: go through PIL / numpy and back
        sample = np.array(self.pil(sample))
        sample = self.transform(image=sample)['image']
        sample = sample[None, :, :].permute(0, 2, 1)

        sample = random_erasing(sample.clone().detach())

        # undo the min-max scaling
        return (sample * value_range) + this_min

In [4]:
# Load and prepare data.
# NOTE: pickle.load can execute arbitrary code -- only use a trusted,
# locally generated metadata.pkl here.
with open('../../data/metadata.pkl', 'rb') as f:
    metadata = pickle.load(f)

# Map every "unknown" fine label (fine_id == 'X') to the list of known fine
# labels that share its coarse category.
unknown_to_known = (
    metadata['taxonomy_df']
    .loc[lambda t: t.fine_id == 'X', ['fine', 'coarse']]
    .merge(metadata['taxonomy_df'].loc[lambda t: t.fine_id != 'X', ['fine', 'coarse']],
           on='coarse', how='inner')
    .drop(columns='coarse')
    .groupby('fine_x')['fine_y']
    .apply(list)
    .to_dict()
)
known_labels = metadata['taxonomy_df'].loc[lambda t: t.fine_id != 'X', 'fine'].tolist()

# Coarse and fine label columns side by side, rows aligned on the shared index.
train_df = pd.concat([metadata['coarse_train'], metadata['fine_train']], axis=1, sort=True)
valid_df = pd.concat([metadata['coarse_test'], metadata['fine_test']], axis=1, sort=True)

In [5]:
# Manual correction for one data point: zero out every row whose label sum is 37.
# NOTE(review): 37 presumably equals the number of label columns, i.e. a row
# with every label set -- confirm before reusing on other data.
train_df.loc[train_df.sum(axis=1).eq(37), :] = 0
valid_df.loc[valid_df.sum(axis=1).eq(37), :] = 0

In [6]:
# Build inputs, per-annotator labels and loss masks for both splits.
(train_X, train_y, train_y_mask), (valid_X, valid_y, valid_y_mask) = (
    prepare_data(split_df, unknown_to_known) for split_df in (train_df, valid_df))

In [7]:
# Channel-wise normalization: standardise each bin of the last axis using
# statistics computed on the training set only, then apply them to both splits.
# The bin count is taken from the data rather than hard-coded, so a different
# spectrogram configuration does not silently mix bins in reshape(-1, n).
n_bins = train_X.shape[-1]
channel_means = train_X.reshape(-1, n_bins).mean(axis=0).reshape(1, 1, 1, -1)
channel_stds = train_X.reshape(-1, n_bins).std(axis=0).reshape(1, 1, 1, -1)
train_X = (train_X - channel_means) / channel_stds
valid_X = (valid_X - channel_means) / channel_stds

In [8]:
# Training-time augmentation pipeline applied to the spectrogram "image":
# mild shift/scale/rotation, a grid distortion, then conversion to a tensor.
albumentations_transform = Compose(
    [
        ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0.5),
        GridDistortion(),
        ToTensor(),
    ]
)

In [9]:
# Wrap the arrays as float tensors; only the training split is augmented.
train_dataset = AudioDataset(
    *(torch.Tensor(arr) for arr in (train_X, train_y, train_y_mask)),
    transform=albumentations_transform)
valid_dataset = AudioDataset(
    *(torch.Tensor(arr) for arr in (valid_X, valid_y, valid_y_mask)),
    transform=None)

# Two independently shuffled loaders over the same training set provide the
# batch pairs that the training loop blends with mixup.
val_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False)
train_loader_1, train_loader_2 = (
    DataLoader(train_dataset, batch_size=64, shuffle=True) for _ in range(2))

In [10]:
# Define the device to be used: the first GPU when CUDA is available,
# otherwise the CPU, so the notebook also runs on CPU-only machines.
cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if cuda else 'cpu')
print('Device: ', device)

Device:  cuda:0


In [11]:
def weight_reset(layer):
    """Re-initialise ``layer`` in place if it defines ``reset_parameters``.

    Intended for ``module.apply(weight_reset)``; modules without a
    ``reset_parameters`` attribute are left untouched. Returns None.
    """
    try:
        reset = layer.reset_parameters
    except AttributeError:
        return
    reset()

In [12]:
class Task5Model(nn.Module):
    """MobileNetV2-based multi-label classifier for single-channel inputs.

    Pipeline:
      * ``bw2col``: BatchNorm and two 1x1 convolutions mapping 1 input channel to 3.
      * ``mv2.features``: the torchvision MobileNetV2 feature extractor loaded
        with ``pretrained=True``, optionally partly re-initialised.
      * Global max pooling over the last two dimensions.
      * ``final``: MLP head producing ``num_classes`` logits.
    """

    def __init__(self, num_classes, reset_after=None):
        """Build the model.

        Args:
            num_classes: number of output logits.
            reset_after: children of ``mv2.features`` whose enumeration index is
                greater than this value are re-initialised with ``weight_reset``,
                discarding their pretrained weights. ``None`` (default) uses the
                notebook-level ``until_x`` parameter, read at construction time.
        """
        super().__init__()

        if reset_after is None:
            reset_after = until_x

        self.bw2col = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 10, 1, padding=0), nn.ReLU(),
            nn.Conv2d(10, 3, 1, padding=0), nn.ReLU())

        self.mv2 = torchvision.models.mobilenet_v2(pretrained=True)

        # 1280 = channel count of the MobileNetV2 feature extractor output
        self.final = nn.Sequential(
            nn.Linear(1280, 512), nn.ReLU(), nn.BatchNorm1d(512),
            nn.Linear(512, num_classes))

        # Re-initialise every feature block after index ``reset_after``.
        for i, block in enumerate(self.mv2.features.children()):
            if i > reset_after:
                block.apply(weight_reset)

    def forward(self, x):
        """Return logits of shape (batch, num_classes).

        ``x`` must have a single channel on dim 1; presumably
        (batch, 1, frames, bins) -- TODO confirm.
        """
        x = self.bw2col(x)
        x = self.mv2.features(x)
        # global max pool over the two trailing spatial dimensions
        x = x.max(dim=-1)[0].max(dim=-1)[0]
        return self.final(x)

In [13]:
# Instantiate the model with one logit per label column: the loss compares the
# outputs against labels[:, :, k], so the widths must match train_y.
model = Task5Model(train_y.shape[1]).to(device)

In [14]:
# Adam (AMSGrad variant); ReduceLROnPlateau lowers the learning rate (by its
# default factor) after 5 epochs without improvement of the monitored loss.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3, amsgrad=True)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', patience=5, verbose=True)
# Element-wise BCE on logits (no reduction) so per-label masks can be applied
# before averaging in the training loop.
criterion = nn.BCEWithLogitsLoss(reduction='none')

In [15]:
# Train with mixup, a masked multi-annotator BCE loss, checkpointing on the best
# validation loss and early stopping.
epochs = 100
early_stop_patience = 25  # epochs without a new lowest validation loss before stopping
alpha = 1                 # mixup coefficients are drawn from Beta(alpha, alpha)
train_loss_hist = []
valid_loss_hist = []
lowest_val_loss = np.inf
epochs_without_new_lowest = 0


def masked_bce(outputs, labels, masks):
    """Masked BCE-with-logits loss over all annotator slices.

    ``labels`` and ``masks`` carry one slice per annotator on their last axis;
    each slice is compared against the same ``outputs``. The masked element-wise
    losses are summed and divided by the total mask weight.
    """
    total = sum((criterion(outputs, labels[:, :, k]) * masks[:, :, k]).sum()
                for k in range(labels.shape[-1]))
    return total / masks.sum()


for i in range(epochs):
    print('Epoch: ', i)

    # ---- training ----
    model.train()
    this_epoch_train_loss = 0
    for i1, i2 in zip(train_loader_1, train_loader_2):
        # mixup: blend two independently shuffled batches sample by sample
        mixup_vals = np.random.beta(alpha, alpha, i1[0].shape[0])

        lam = torch.Tensor(mixup_vals.reshape(-1, 1, 1, 1))
        inputs = (lam * i1[0]) + ((1 - lam) * i2[0])

        lam = torch.Tensor(mixup_vals.reshape(-1, 1, 1))
        labels = (lam * i1[1]) + ((1 - lam) * i2[1])
        masks = (lam * i1[2]) + ((1 - lam) * i2[2])

        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = masked_bce(outputs, labels, masks)
        loss.backward()
        optimizer.step()
        this_epoch_train_loss += loss.detach().cpu().numpy()

    # ---- validation (no gradients needed, so no zero_grad either) ----
    model.eval()
    this_epoch_valid_loss = 0
    with torch.no_grad():
        for inputs, labels, masks in val_loader:
            outputs = model(inputs.to(device))
            loss = masked_bce(outputs, labels.to(device), masks.to(device))
            this_epoch_valid_loss += loss.detach().cpu().numpy()

    # mean of the per-batch losses
    this_epoch_train_loss /= len(train_loader_1)
    this_epoch_valid_loss /= len(val_loader)

    train_loss_hist.append(this_epoch_train_loss)
    valid_loss_hist.append(this_epoch_valid_loss)

    # log before a possible early-stopping break so the last epoch is reported
    print(this_epoch_train_loss, this_epoch_valid_loss)

    if this_epoch_valid_loss < lowest_val_loss:
        lowest_val_loss = this_epoch_valid_loss
        torch.save(model.state_dict(), './model_system1_until_{}'.format(until_x))
        epochs_without_new_lowest = 0
    else:
        epochs_without_new_lowest += 1

    if epochs_without_new_lowest >= early_stop_patience:
        break

    scheduler.step(this_epoch_valid_loss)

Epoch:  0


0.5988876473259281 0.4327020261968885
Epoch:  1


0.2657741791493184 0.17036830740315573
Epoch:  2


0.17186190630938555 0.17656629426138742
Epoch:  3


0.16239520222754092 0.13897136918136052
Epoch:  4


0.15863027524303747 0.1345299333333969
Epoch:  5


0.15669840171530441 0.13659135677984782
Epoch:  6


0.1553009667912045 0.13390735111066274
Epoch:  7


0.15603121229120204 0.14276938779013498
Epoch:  8


0.1545815044963682 0.1364899000951222
Epoch:  9


0.15241690060576876 0.13087034225463867
Epoch:  10


0.15189038579528397 0.1287146881222725
Epoch:  11


0.1528266176984117 0.13122416606971196
Epoch:  12


0.15237375610583537 0.13321753484862192
Epoch:  13


0.15026273276354815 0.12805854103394917
Epoch:  14


0.1515792737136016 0.12752396081175124
Epoch:  15


0.1496420792631201 0.12637033313512802
Epoch:  16


0.14918837394263293 0.13111051810639246
Epoch:  17


0.14957403492283178 0.12697371946913855
Epoch:  18


0.15020548773778453 0.12983592918940953
Epoch:  19


0.14961580287765813 0.12515862392527716
Epoch:  20


0.14793231680586533 0.12908248709780829
Epoch:  21


0.14885373051102097 0.13086782821587153
Epoch:  22


0.14861220846305023 0.1270049980708531
Epoch:  23


0.14787150436156504 0.12807512815509522
Epoch:  24


0.14678905139098297 0.12448563213859286
Epoch:  25


0.14720412484697393 0.12744794466665813
Epoch:  26


0.1467862684984465 0.12861369124480657
Epoch:  27


0.1473269736444628 0.12679022869893483
Epoch:  28


0.1469081970485481 0.1247212045959064
Epoch:  29


0.14543498851157524 0.13054343525852477
Epoch:  30


0.14646553509944193 0.12500136026314326
Epoch    30: reducing learning rate of group 0 to 1.0000e-04.
Epoch:  31


0.14438290894031525 0.12188434387956347
Epoch:  32


0.14446242959112734 0.12125929551465171
Epoch:  33


0.1431217974907643 0.12074090327535357
Epoch:  34


0.14328607675191518 0.12120421869414193
Epoch:  35


0.14205211562079353 0.12074736079999379
Epoch:  36


0.1442470715658085 0.12084536999464035
Epoch:  37


0.14234360126224724 0.12133496361119407
Epoch:  38


0.1416058665191805 0.12101885144199644
Epoch:  39


0.1423569563272837 0.12062379717826843
Epoch:  40


0.14161402067622622 0.12100145433630262
Epoch:  41


0.14207901382768476 0.12109936773777008
Epoch:  42


0.1427859684100022 0.12046781395162855
Epoch:  43


0.14136402993588834 0.12112301694495338
Epoch:  44


0.14236093534005656 0.12041492866618293
Epoch:  45


0.14190225544813517 0.12107318746192115
Epoch:  46


0.14210238086210714 0.12136386228459221
Epoch:  47


0.14207692404051084 0.12132872215339116
Epoch:  48


0.14169476604139483 0.12155399684395109
Epoch:  49


0.14203957448134552 0.1211745451603617
Epoch:  50


0.14105733463893066 0.12104938498565129
Epoch    50: reducing learning rate of group 0 to 1.0000e-05.
Epoch:  51


0.1397402761755763 0.12091532562460218
Epoch:  52


0.14016139789207563 0.12063380969422204
Epoch:  53


0.1402065802264858 0.12063011952808925
Epoch:  54


0.13996336729945363 0.12049270101955958
Epoch:  55


0.14152996604507034 0.12051080167293549
Epoch:  56


0.14104822600210035 0.12058015593460628
Epoch    56: reducing learning rate of group 0 to 1.0000e-06.
Epoch:  57


0.14049036035666596 0.12066874014479774
Epoch:  58


0.1411388109664659 0.12061711187873568
Epoch:  59


0.14145064233122645 0.12079026124307088
Epoch:  60


0.1404706191372227 0.12100557770047869
Epoch:  61


0.14050851359560684 0.12067355215549469
Epoch:  62


0.14107073642112114 0.1205466804759843
Epoch    62: reducing learning rate of group 0 to 1.0000e-07.
Epoch:  63


0.14030808975567688 0.12035380942480904
Epoch:  64


0.13971731992992195 0.12028587928840093
Epoch:  65


0.1406449812489587 0.12070037531001228
Epoch:  66


0.14054615272058024 0.12062521065984454
Epoch:  67


0.14096409162959536 0.12072265786784035
Epoch:  68


0.14042788623152552 0.1206751646740096
Epoch:  69


0.1391592996345984 0.12074087453739983
Epoch:  70


0.14008261786924825 0.12038966906922204
Epoch    70: reducing learning rate of group 0 to 1.0000e-08.
Epoch:  71


0.1416784367851309 0.12053843587636948
Epoch:  72


0.14018613261145516 0.12064190634659358
Epoch:  73


0.14054243427676125 0.12045030295848846
Epoch:  74


0.14115939752475634 0.12065159529447556
Epoch:  75


0.14088251703494303 0.12037189091954913
Epoch:  76


0.1407218104278719 0.12068374135664531
Epoch:  77


0.14113150375920372 0.12060516859803881
Epoch:  78


0.14084646387680158 0.120799007160323
Epoch:  79


0.14084744413156766 0.12061007427317756
Epoch:  80


0.1389378635464488 0.12045944588524955
Epoch:  81


0.14191015185536565 0.120487708066191
Epoch:  82


0.14012120865486763 0.12090179643460683
Epoch:  83


0.14078366514798757 0.12061264791658946
Epoch:  84


0.13980428189844699 0.12063910918576377
Epoch:  85


0.1398257053381688 0.12048350487436567
Epoch:  86


0.14049961317229914 0.12046802150351661
Epoch:  87


0.14098265243543162 0.12053713628223964
Epoch:  88


0.141017254139926 0.12053063192537852
Epoch:  89
