In [1]:
# Parameters
# until_x: index of the last block of the MobileNetV2 feature extractor
# (mv2.features) that keeps its pretrained weights; every block with a larger
# index is re-initialised in Task5Model. Also embedded in the checkpoint filename.
until_x = 3


In [2]:
import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import sys
sys.path.append("..")
from RandomErasing import RandomErasing

import torchvision.models
from torchvision import transforms

from albumentations import Compose, ShiftScaleRotate, GridDistortion
from albumentations.pytorch import ToTensor

In [3]:
def prepare_data(df, unknown_to_known, data_dir='../../data/logmelspec',
                 n_annotators=3, n_frames=635):
    """Build model inputs, per-annotator targets and loss masks for one split.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per (recording, annotation) label vector. After
        ``reset_index()`` it must expose an ``audio_filename`` column; the
        remaining columns are label columns, including the "unknown" columns
        named in ``unknown_to_known``.
    unknown_to_known : dict[str, list[str]]
        Maps each unknown-label column to the known-label columns it affects.
        Where an annotation has the unknown label > 0.5, that annotation's
        targets for the mapped known labels are masked out (mask = 0).
    data_dir : str
        Directory containing one ``<audio_filename>.npy`` spectrogram per file.
    n_annotators : int
        Number of annotations per recording to use (annotation slots 1..n).
    n_frames : int
        Number of leading rows kept after transposing each loaded array
        (presumably time frames, if the .npy files are (bins, time) -- TODO confirm).

    Returns
    -------
    X : np.ndarray
        Shape (n_files, 1, <=n_frames, n_columns_of_transposed_array).
    y : np.ndarray
        Shape (n_files, n_known_labels, n_annotators); y[..., k] holds the
        labels of annotation slot k + 1.
    y_mask : np.ndarray
        Same shape as ``y``; 1 = use target in the loss, 0 = ignore it.

    Raises
    ------
    ValueError
        If some recording has fewer than ``n_annotators`` annotations.
    """
    df = df.reset_index()
    # 1-based annotation index within each recording, in row order.
    df['slno'] = df.assign(slno=1).groupby('audio_filename')['slno'].cumsum()
    df = df.set_index(['audio_filename', 'slno'])

    # Every recording needs all annotation slots, otherwise the per-slot
    # arrays below would not line up.
    n_per_file = df.groupby(level='audio_filename').size()
    too_few = n_per_file[n_per_file < n_annotators]
    if not too_few.empty:
        raise ValueError(
            '{} recording(s) have fewer than {} annotations, e.g. {!r}'.format(
                len(too_few), n_annotators, too_few.index[0]))

    unknown_cols = list(unknown_to_known.keys())
    df_unknown = df.loc[:, unknown_cols].copy()
    df = df.drop(columns=unknown_cols)

    # Start with every target enabled, then disable the known labels of any
    # annotation that flagged the corresponding unknown label.
    y_mask = df.copy()
    y_mask.loc[:, :] = 1
    for unknown, known in unknown_to_known.items():
        y_mask.loc[df_unknown[unknown] > 0.5, known] = 0

    # Put the annotation slot first so .loc[[k]] selects slot k's rows,
    # sorted by audio_filename (same file order for every slot).
    df = df.swaplevel(i=1, j=0, axis=0).sort_index()
    y_mask = y_mask.swaplevel(i=1, j=0, axis=0).sort_index()

    slots = range(1, n_annotators + 1)
    y = np.stack([df.loc[[k], :].values for k in slots], axis=2)
    y_mask = np.stack([y_mask.loc[[k], :].values for k in slots], axis=2)

    filenames = df.loc[[1], :].reset_index(1).audio_filename.tolist()
    X = np.stack([
        np.load('{}/{}.npy'.format(data_dir, name)).T[:n_frames, :]
        for name in filenames])
    X = np.expand_dims(X, axis=1)

    return X, y, y_mask


# Shared RandomErasing augmentation (project-local class imported from the
# parent directory above); applied inside AudioDataset.__getitem__ to
# augmented samples only.
random_erasing = RandomErasing()


class AudioDataset(Dataset):
    """Map-style dataset over pre-computed spectrogram tensors.

    Each item is ``(sample, targets, weights)`` taken from row ``idx`` of
    ``X``, ``y`` and ``weights``. When ``transform`` is set, ``sample`` is
    augmented:

    1. It is min-max scaled to [0, 1].
    2. It is circularly shifted along axis 1 by a random offset.
    3. It is converted to a PIL image so the albumentations ``transform`` can
       run on it.
    4. It is randomly erased.
    5. It is mapped back to its original value range.

    Parameters
    ----------
    X : torch.Tensor
        Inputs indexed along dim 0; presumably (N, 1, time, mel) -- TODO confirm.
    y : torch.Tensor
        Targets indexed along dim 0.
    weights : torch.Tensor
        Per-target loss weights/masks indexed along dim 0.
    transform : callable, optional
        Called as ``transform(image=...)`` and must return a mapping whose
        ``'image'`` entry is a 2-D tensor.
    """

    def __init__(self, X, y, weights, transform=None):
        self.X = X
        self.y = y
        self.weights = weights
        self.transform = transform
        self.pil = transforms.ToPILImage()

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        sample = self.X[idx, ...]

        if self.transform:
            # min-max transformation; a constant sample would otherwise divide
            # by zero and propagate NaNs into training.
            this_min = sample.min()
            this_max = sample.max()
            value_range = this_max - this_min
            if value_range == 0:
                value_range = torch.ones_like(value_range)
            sample = (sample - this_min) / value_range

            # randomly cycle the file along axis 1
            # (equivalent to concatenating sample[:, i:] and sample[:, :i])
            i = np.random.randint(sample.shape[1])
            sample = torch.roll(sample, shifts=-i, dims=1)

            # apply albumentations transforms (presumably ToPILImage quantises
            # the [0, 1] float input to 8 bits here)
            sample = np.array(self.pil(sample))
            sample = self.transform(image=sample)['image']
            # NOTE(review): the transform's 2-D output appears transposed
            # relative to its input; this restores a (1, axis1, axis2) layout.
            sample = sample[None, :, :].permute(0, 2, 1)

            # apply random erasing
            sample = random_erasing(sample.clone().detach())

            # revert min-max transformation
            sample = sample * value_range + this_min

        return sample, self.y[idx, ...], self.weights[idx, ...]

In [4]:
# Load and prepare data.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open('../../data/metadata.pkl', 'rb') as f:
    metadata = pickle.load(f)

# For each "unknown" fine label (fine_id == 'X'), list the known fine labels
# that belong to the same coarse category.
unknown_to_known = (
    metadata['taxonomy_df']
    .loc[lambda t: t.fine_id == 'X', ['fine', 'coarse']]
    .merge(metadata['taxonomy_df'].loc[lambda t: t.fine_id != 'X', ['fine', 'coarse']],
           on='coarse', how='inner')
    .drop(columns='coarse')
    .groupby('fine_x')['fine_y']
    .apply(list)
    .to_dict()
)
known_labels = metadata['taxonomy_df'].loc[lambda t: t.fine_id != 'X', 'fine'].tolist()

# Coarse and fine label columns side by side, one frame per split.
train_df = pd.concat([metadata['coarse_train'], metadata['fine_train']], axis=1, sort=True)
valid_df = pd.concat([metadata['coarse_test'], metadata['fine_test']], axis=1, sort=True)

In [5]:
# Manual correction for one data point: a row whose label values sum to 37
# (presumably every label column set -- TODO confirm against the column count)
# is reset to all zeros. Re-running is harmless because zeroed rows sum to 0.
for labels_df in (train_df, valid_df):
    labels_df.loc[labels_df.sum(axis=1) == 37, :] = 0
del labels_df

In [6]:
# Turn each split into (inputs, per-annotator targets, loss masks).
(train_X, train_y, train_y_mask), (valid_X, valid_y, valid_y_mask) = [
    prepare_data(split_df, unknown_to_known) for split_df in (train_df, valid_df)
]

In [7]:
# Channel wise normalization: mean/std per channel of the last axis (128 wide
# in this data; presumably mel bins), computed over every other axis of the
# *training* inputs only and applied to both splits.
# NOTE: this cell overwrites train_X / valid_X, so re-running it normalises twice.
channel_means = train_X.mean(axis=tuple(range(train_X.ndim - 1)), keepdims=True)
channel_stds = train_X.std(axis=tuple(range(train_X.ndim - 1)), keepdims=True)
train_X = (train_X - channel_means) / channel_stds
valid_X = (valid_X - channel_means) / channel_stds

In [8]:
# Define the data augmentation transformations, applied in sequence:
# a small shift/scale/rotate jitter, a grid distortion, and finally the
# conversion to a tensor.
albumentations_transform = Compose([
    ShiftScaleRotate(
        rotate_limit=0.5,
        scale_limit=0.1,
        shift_limit=0.1,
    ),
    GridDistortion(),
    ToTensor(),
])

In [9]:
# Create the datasets and the dataloaders. Every array is wrapped with
# torch.Tensor; only the training set gets the augmentation transform.
train_dataset = AudioDataset(
    *(torch.Tensor(arr) for arr in (train_X, train_y, train_y_mask)),
    transform=albumentations_transform)
valid_dataset = AudioDataset(
    *(torch.Tensor(arr) for arr in (valid_X, valid_y, valid_y_mask)),
    transform=None)

val_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False)
# Two independently shuffled loaders over the same training set; the training
# loop zips them batch-by-batch to build mixup pairs.
train_loader_1 = DataLoader(train_dataset, batch_size=64, shuffle=True)
train_loader_2 = DataLoader(train_dataset, batch_size=64, shuffle=True)

In [10]:
# Define the device to be used: the first GPU when CUDA is available,
# otherwise the CPU (set `cuda = False` to force CPU execution).
cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if cuda else 'cpu')
print('Device: ', device)

Device:  cuda:0


In [11]:
def weight_reset(layer):
    """Re-initialise ``layer`` in place when it defines ``reset_parameters``.

    Meant to be passed to ``module.apply``; objects without a
    ``reset_parameters`` attribute are left untouched. Returns None.
    """
    if not hasattr(layer, 'reset_parameters'):
        return
    layer.reset_parameters()

In [12]:
class Task5Model(nn.Module):
    """MobileNetV2-based multi-label classifier for single-channel inputs.

    Pipeline:

    1. A learned 1x1-conv stack (``bw2col``) maps 1 input channel to 3 channels.
    2. The torchvision MobileNetV2 feature extractor runs, loaded with
       ``pretrained=True``; later blocks are re-initialised.
    3. Max pooling is applied over the two trailing axes.
    4. An MLP head produces ``num_classes`` logits.

    Parameters
    ----------
    num_classes : int
        Number of output logits.
    reset_after : int, optional
        Blocks of ``mv2.features`` with an index greater than this are
        re-initialised; blocks ``0..reset_after`` keep the pretrained weights.
        Defaults to the notebook-level ``until_x`` parameter.
    """

    def __init__(self, num_classes, reset_after=None):

        super().__init__()

        if reset_after is None:
            reset_after = until_x

        self.bw2col = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 10, 1, padding=0), nn.ReLU(),
            nn.Conv2d(10, 3, 1, padding=0), nn.ReLU())

        self.mv2 = torchvision.models.mobilenet_v2(pretrained=True)

        # NOTE: BatchNorm1d in the head needs more than one sample per batch
        # in training mode.
        self.final = nn.Sequential(
            nn.Linear(1280, 512), nn.ReLU(), nn.BatchNorm1d(512),
            nn.Linear(512, num_classes))

        # Keep the pretrained weights of feature blocks 0..reset_after and
        # re-initialise every later block.
        for i, block in enumerate(self.mv2.features.children()):
            if i > reset_after:
                block.apply(weight_reset)

    def forward(self, x):
        """Return logits of shape (batch, num_classes).

        ``x`` is presumably (batch, 1, time, mel) -- TODO confirm.
        """
        x = self.bw2col(x)
        x = self.mv2.features(x)
        # global max pooling over the last two axes
        x = x.max(dim=-1)[0].max(dim=-1)[0]
        return self.final(x)

In [13]:
# Instantiate the model with one logit per label column of the targets
# (train_y is (n_files, n_labels, n_annotators), so n_labels = train_y.shape[1];
# the loss compares outputs against train_y[:, :, k], so these must match).
model = Task5Model(train_y.shape[1]).to(device)

In [14]:
# Define optimizer, scheduler and loss criteria:
# - Adam with the AMSGrad variant;
# - the learning rate is multiplied by 0.1 whenever the monitored (validation)
#   loss stops decreasing for longer than `patience` epochs;
# - the loss is kept per element (reduction='none') so it can be masked
#   before averaging.
optimizer = optim.Adam(model.parameters(), lr=0.001, amsgrad=True)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=5, verbose=True)
criterion = nn.BCEWithLogitsLoss(reduction='none')

In [15]:
epochs = 100
EARLY_STOPPING_PATIENCE = 25  # stop after this many epochs without a new best val loss
MIXUP_ALPHA = 1  # mixup weights are drawn from Beta(MIXUP_ALPHA, MIXUP_ALPHA)

train_loss_hist = []
valid_loss_hist = []
lowest_val_loss = np.inf
epochs_without_new_lowest = 0


def masked_loss(outputs, labels, masks):
    """Masked mean of ``criterion`` over every annotation slot.

    ``outputs`` is (batch, n_labels); ``labels`` and ``masks`` are
    (batch, n_labels, n_annotators). Each slot's labels are scored against the
    same logits, multiplied by that slot's mask, summed, and the total is
    divided by the sum of all mask values.
    """
    total = sum(
        (criterion(outputs, labels[:, :, k]) * masks[:, :, k]).sum()
        for k in range(labels.shape[2]))
    return total / masks.sum()


for i in range(epochs):
    print('Epoch: ', i)

    # ---- training pass with mixup ----
    model.train()
    this_epoch_train_loss = 0
    for i1, i2 in zip(train_loader_1, train_loader_2):
        # mixup: blend two independently shuffled batches with per-sample weights
        mixup_vals = np.random.beta(MIXUP_ALPHA, MIXUP_ALPHA, i1[0].shape[0])

        lam = torch.Tensor(mixup_vals.reshape(-1, 1, 1, 1))
        inputs = (lam * i1[0]) + ((1 - lam) * i2[0])

        lam = torch.Tensor(mixup_vals.reshape(-1, 1, 1))
        labels = (lam * i1[1]) + ((1 - lam) * i2[1])
        masks = (lam * i1[2]) + ((1 - lam) * i2[2])

        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = masked_loss(outputs, labels, masks)
        loss.backward()
        optimizer.step()
        this_epoch_train_loss += loss.item()

    # ---- validation pass ----
    model.eval()
    this_epoch_valid_loss = 0
    with torch.no_grad():
        for inputs, labels, masks in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            masks = masks.to(device)
            outputs = model(inputs)
            this_epoch_valid_loss += masked_loss(outputs, labels, masks).item()

    # mean of the per-batch losses
    this_epoch_train_loss /= len(train_loader_1)
    this_epoch_valid_loss /= len(val_loader)

    train_loss_hist.append(this_epoch_train_loss)
    valid_loss_hist.append(this_epoch_valid_loss)

    # Report before the early-stopping check so the final epoch is logged too.
    print(this_epoch_train_loss, this_epoch_valid_loss)

    if this_epoch_valid_loss < lowest_val_loss:
        lowest_val_loss = this_epoch_valid_loss
        torch.save(model.state_dict(), './model_system1_until_{}'.format(until_x))
        epochs_without_new_lowest = 0
    else:
        epochs_without_new_lowest += 1

    if epochs_without_new_lowest >= EARLY_STOPPING_PATIENCE:
        break

    scheduler.step(this_epoch_valid_loss)

# NOTE: after the loop `model` holds the last epoch's weights; the best ones
# (lowest validation loss) are in the saved checkpoint.

Epoch:  0


0.6063063410488335 0.44287195801734924
Epoch:  1


0.27665077351235057 0.1909938007593155
Epoch:  2


0.1840638812329318 1.2454593692507063
Epoch:  3


0.1778539813853599 0.19749448341982706
Epoch:  4


0.17428551694831332 0.15752170554229192
Epoch:  5


0.17079397311081757 0.15202712161200388
Epoch:  6


0.16876336408628 0.1545067791427885
Epoch:  7


0.16713067325385841 0.1517723309142249
Epoch:  8


0.16494724879393707 0.15252495876380376
Epoch:  9


0.16216612949564652 0.153171392423766
Epoch:  10


0.16271157159998612 0.14952084847858974
Epoch:  11


0.16203230780524178 0.13872445800474711
Epoch:  12


0.16100016800132957 0.14312104455062322
Epoch:  13


0.15952263410026962 0.13797736380781447
Epoch:  14


0.16059004052265272 0.14130199274846486
Epoch:  15


0.15879370634620255 0.13822189079863684
Epoch:  16


0.15849207824951894 0.14203637944800512
Epoch:  17


0.15754541996363047 0.1342030167579651
Epoch:  18


0.15832677684925697 0.14289222019059317
Epoch:  19


0.15687945444841642 0.13598633451121195
Epoch:  20


0.15616568479989026 0.13260340584175928
Epoch:  21


0.15464627017845978 0.13449554038899286
Epoch:  22


0.1558626511612454 0.13216954043933324
Epoch:  23


0.155061969885955 0.1320826272879328
Epoch:  24


0.15695066991690043 0.13551586440631322
Epoch:  25


0.1549877983492774 0.13490273909909384
Epoch:  26


0.15403718320099083 0.13164194353989192
Epoch:  27


0.154964225920471 0.13043577756200517
Epoch:  28


0.15387652532474413 0.13232270628213882
Epoch:  29


0.15315779801961538 0.12898415433509008
Epoch:  30


0.1538828212667156 0.1323664507695607
Epoch:  31


0.152290445727271 0.12830716903720582
Epoch:  32


0.15178159120920542 0.12761075049638748
Epoch:  33


0.15204250409796433 0.1284942403435707
Epoch:  34


0.15254551091709653 0.12625564634799957
Epoch:  35


0.1521509535409309 0.1281192877462932
Epoch:  36


0.15258789223593636 0.12881217471190862
Epoch:  37


0.15225561607528376 0.13400732832295553
Epoch:  38


0.15149845263442477 0.12973022993121827
Epoch:  39


0.1507941536806725 0.13296340086630412
Epoch:  40


0.1509839775594505 0.12964026629924774
Epoch    40: reducing learning rate of group 0 to 1.0000e-04.
Epoch:  41


0.14943803605195638 0.12496213614940643
Epoch:  42


0.148560942024798 0.12518012949398585
Epoch:  43


0.1495183247166711 0.12540914650474275
Epoch:  44


0.1478192649177603 0.1250898391008377
Epoch:  45


0.14846216544911667 0.1255639802132334
Epoch:  46


0.149484447530798 0.12543652206659317
Epoch:  47


0.14779064945272496 0.12487300698246274
Epoch:  48


0.14703080742745786 0.12445283042533058
Epoch:  49


0.14717885569946185 0.12479863102946963
Epoch:  50


0.14782103211493106 0.1254398918577603
Epoch:  51


0.14666951830322678 0.12486412801912852
Epoch:  52


0.1484319978469127 0.12474714006696429
Epoch:  53


0.14746826403849833 0.1247170375926154
Epoch:  54


0.1474078567446889 0.12486587464809418
Epoch    54: reducing learning rate of group 0 to 1.0000e-05.
Epoch:  55


0.14686501589981285 0.12455062355313982
Epoch:  56


0.1483287658240344 0.12449042286191668
Epoch:  57


0.14771119364210078 0.12474240681954793
Epoch:  58


0.14683285919395653 0.1246682054230145
Epoch:  59


0.14770461498080073 0.12455639988183975
Epoch:  60


0.14723710554677086 0.12458704518420356
Epoch    60: reducing learning rate of group 0 to 1.0000e-06.
Epoch:  61


0.1474487257970346 0.12464178672858647
Epoch:  62


0.14791124817487356 0.12462595424481801
Epoch:  63


0.14788711111287814 0.12440574062722069
Epoch:  64


0.14623641726132985 0.12443199860198158
Epoch:  65


0.147357958796862 0.12482737004756927
Epoch:  66


0.14700229264594414 0.12479140822376524
Epoch:  67


0.14772708794555148 0.12452619203499385
Epoch:  68


0.14670236811444565 0.12466528053794589
Epoch:  69


0.14681534267760613 0.12446326549564089
Epoch    69: reducing learning rate of group 0 to 1.0000e-07.
Epoch:  70


0.14680004160146456 0.12460423793111529
Epoch:  71


0.1464615121886537 0.12470602244138718
Epoch:  72


0.14620397179513364 0.12447627208062581
Epoch:  73


0.14667990843991977 0.12477337888308934
Epoch:  74


0.14655736773400693 0.12491690793207713
Epoch:  75


0.1471208100383346 0.12451285975319999
Epoch    75: reducing learning rate of group 0 to 1.0000e-08.
Epoch:  76


0.1464361991431262 0.1246945858001709
Epoch:  77


0.14623883769318863 0.12475648841687612
Epoch:  78


0.1474328395482656 0.12493557589394706
Epoch:  79


0.14766429686868512 0.12470971367188863
Epoch:  80


0.14888999872916453 0.12455199126686368
Epoch:  81


0.14741298153593735 0.12471611478499003
Epoch:  82


0.14560488712143255 0.12485865609986442
Epoch:  83


0.14683706937609492 0.12459675967693329
Epoch:  84


0.14854065751707232 0.1248260959982872
Epoch:  85


0.14638097705067815 0.12456956718649183
Epoch:  86


0.14736981045555425 0.12485182072435107
Epoch:  87


0.1481682734714972 0.12469462837491717
Epoch:  88
