In [1]:
# Parameters
# Feature blocks of the pretrained MobileNetV2 with index > until_x are re-initialised
# when Task5Model is built; lower-indexed blocks keep their pretrained weights.
until_x: int = 4


In [2]:
import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import sys
sys.path.append("..")
from RandomErasing import RandomErasing

import torchvision.models
from torchvision import transforms

from albumentations import Compose, ShiftScaleRotate, GridDistortion
from albumentations.pytorch import ToTensor

In [3]:
def prepare_data(df, unknown_to_known, data_dir='../../data/logmelspec',
                 n_frames=635, n_annotators=3):
    """Build model inputs, per-annotator targets and per-annotator loss masks.

    Parameters
    ----------
    df : pandas.DataFrame
        Label frame expected to be indexed by ``audio_filename``, presumably with one row per
        (clip, annotator) annotation and one column per label. The columns named in
        ``unknown_to_known`` are consumed here and are not part of the returned targets.
    unknown_to_known : dict[str, list[str]]
        Maps each "unknown" label column to the known label columns it masks: wherever an
        annotation scores the unknown column above 0.5, that annotation's known columns get
        mask weight 0.
    data_dir : str
        Directory containing ``<audio_filename>.npy`` arrays (default keeps the original path).
    n_frames : int
        Number of leading rows kept from each loaded array after transposing it.
    n_annotators : int
        Number of annotations per clip stacked along the last axis of ``y`` / ``y_mask``.

    Returns
    -------
    X : ndarray of shape (n_clips, 1, n_frames, n_features)
        Assumes every stored array has at least ``n_frames`` frames after transposing.
    y : ndarray of shape (n_clips, n_labels, n_annotators)
    y_mask : ndarray of the same shape as ``y``; 1 where a label counts, 0 where masked.

    Raises
    ------
    ValueError
        If some clip has fewer than ``n_annotators`` annotations, since the per-annotator
        slices would not line up clip by clip.
    """
    unknown_cols = list(unknown_to_known.keys())

    # Number the annotations of every clip 1, 2, 3, ... so they can be separated below.
    df = df.reset_index()
    df['slno'] = df.assign(slno=1).groupby('audio_filename')['slno'].cumsum()
    df = df.set_index(['audio_filename', 'slno'])

    per_clip = df.groupby(level='audio_filename').size()
    if (per_clip < n_annotators).any():
        short = per_clip[per_clip < n_annotators].index.tolist()
        raise ValueError(
            'expected at least {} annotations per clip; too few for: {}'.format(
                n_annotators, short))

    df_unknown = df.loc[:, unknown_cols].copy()
    df = df.drop(columns=unknown_cols)

    # Mask out an annotation's known labels whenever it ticked the matching "unknown" label.
    mask_df = df.copy()
    mask_df.loc[:, :] = 1
    for unknown, known in unknown_to_known.items():
        mask_df.loc[df_unknown[unknown] > 0.5, known] = 0

    # Put the annotation number first so .loc[[k]] selects the k-th annotation of every
    # clip, with clips sorted by filename (identical order for every k).
    df = df.swaplevel(i=1, j=0, axis=0).sort_index()
    mask_df = mask_df.swaplevel(i=1, j=0, axis=0).sort_index()

    annotation_ids = range(1, n_annotators + 1)
    y = np.stack([df.loc[[k], :].values for k in annotation_ids], axis=2)
    y_mask = np.stack([mask_df.loc[[k], :].values for k in annotation_ids], axis=2)

    filenames = df.loc[[1], :].reset_index(1).audio_filename.tolist()
    X = np.stack([
        np.load('{}/{}.npy'.format(data_dir, name)).T[:n_frames, :]
        for name in filenames])
    X = np.expand_dims(X, axis=1)

    return X, y, y_mask


# Single shared RandomErasing augmentation (project-local module imported from "..");
# AudioDataset.__getitem__ applies it to every sample when a transform is supplied.
random_erasing = RandomErasing()


class AudioDataset(Dataset):
    """Map-style dataset over pre-computed spectrogram tensors with optional augmentation.

    Each item is ``(sample, y[idx], weights[idx])``. When ``transform`` is set, the sample is
    min-max scaled to [0, 1], circularly shifted along axis 1 by a random offset, passed
    through a PIL image and the albumentations ``transform``, randomly erased, and finally
    scaled back to its original min/max range. Without a transform, samples are returned as-is.
    """

    def __init__(self, X, y, weights, transform=None):
        """
        Args:
            X: tensor indexed by example along dim 0; each sample is assumed to be
                (1, time, freq) -- TODO confirm against the caller.
            y: per-example targets, indexed along dim 0 like ``X``.
            weights: per-example loss weights / masks, indexed along dim 0 like ``X``.
            transform: optional albumentations-style callable invoked as
                ``transform(image=...)`` and returning a dict with key ``'image'``;
                ``None`` disables all augmentation.
        """
        self.X = X
        self.y = y
        self.weights = weights
        self.transform = transform
        self.pil = transforms.ToPILImage()

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        sample = self.X[idx, ...]

        if self.transform:
            # min-max scale to [0, 1] so the 8-bit PIL round-trip below keeps the full range
            this_min = sample.min()
            this_max = sample.max()
            value_range = this_max - this_min
            if value_range == 0:
                # constant sample: avoid 0/0 -> NaN. It scales to all zeros, and the inverse
                # scaling at the end (factor this_max - this_min == 0) restores the constant.
                value_range = torch.ones_like(value_range)
            sample = (sample - this_min) / value_range

            # randomly cycle the file: circular shift along axis 1 by a random offset
            shift = int(np.random.randint(sample.shape[1]))
            sample = torch.roll(sample, shifts=-shift, dims=1)

            # apply albumentations transforms (ToPILImage converts the float tensor to 8-bit)
            sample = np.array(self.pil(sample))
            sample = self.transform(image=sample)['image']
            # NOTE(review): the ToTensor used here appears to return the image transposed;
            # this permute restores the (1, time, freq) layout -- confirm for your version.
            sample = sample[None, :, :].permute(0, 2, 1)

            # apply random erasing
            sample = random_erasing(sample.clone().detach())

            # revert min-max transformation
            sample = (sample * (this_max - this_min)) + this_min

        return sample, self.y[idx, ...], self.weights[idx, ...]

In [4]:
# Load and prepare data
# NOTE: pickle.load can execute arbitrary code -- only load a metadata.pkl from a trusted source.
with open('../../data/metadata.pkl', 'rb') as f:
    metadata = pickle.load(f)

# For every "unknown" fine label (fine_id == 'X'), list the known fine labels that share
# its coarse category.
unknown_to_known = (
    metadata['taxonomy_df']
    .loc[lambda t: t.fine_id == 'X', ['fine', 'coarse']]
    .merge(metadata['taxonomy_df'].loc[lambda t: t.fine_id != 'X', ['fine', 'coarse']],
           on='coarse', how='inner')
    .drop(columns='coarse')
    .groupby('fine_x')['fine_y']
    .apply(list)
    .to_dict()
)
known_labels = metadata['taxonomy_df'].loc[lambda t: t.fine_id != 'X', 'fine'].tolist()

# Coarse and fine label columns side by side, one frame per split.
train_df = pd.concat([metadata['coarse_train'], metadata['fine_train']], axis=1, sort=True)
valid_df = pd.concat([metadata['coarse_test'], metadata['fine_test']], axis=1, sort=True)

In [5]:
# manual correction for one data point: zero every column of any row whose values sum to 37.
# NOTE(review): 37 presumably equals the number of label columns (a row with every label
# ticked) -- confirm against the metadata.
train_df.loc[train_df.sum(axis=1).eq(37), :] = 0
valid_df.loc[valid_df.sum(axis=1).eq(37), :] = 0

In [6]:
# Build (inputs, per-annotator targets, per-annotator masks) for both splits.
train_X, train_y, train_y_mask = prepare_data(df=train_df, unknown_to_known=unknown_to_known)
valid_X, valid_y, valid_y_mask = prepare_data(df=valid_df, unknown_to_known=unknown_to_known)

In [7]:
# Channel wise normalization: standardise each feature on the last axis (presumably the mel
# bins) using statistics from the training split only, and apply the same scaling to the
# validation split. The channel count is read from the data rather than hard-coded.
# NOTE: this overwrites train_X / valid_X, so re-running this cell alone normalises twice.
channel_means = train_X.reshape(-1, train_X.shape[-1]).mean(axis=0).reshape(1, 1, 1, -1)
channel_stds = train_X.reshape(-1, train_X.shape[-1]).std(axis=0).reshape(1, 1, 1, -1)
train_X = (train_X - channel_means) / channel_stds
valid_X = (valid_X - channel_means) / channel_stds

In [8]:
# Define the data augmentation transformations:
# mild affine jitter (albumentations' rotate_limit is in degrees, so rotation is tiny),
# then grid distortion, then conversion to a torch tensor.
albumentations_transform = Compose(transforms=[
    ShiftScaleRotate(rotate_limit=0.5, scale_limit=0.1, shift_limit=0.1),
    GridDistortion(),
    ToTensor(),
])

In [9]:
# Create the datasets and the dataloaders (augmentation only on the training split)
train_dataset = AudioDataset(
    torch.Tensor(train_X), torch.Tensor(train_y), torch.Tensor(train_y_mask),
    transform=albumentations_transform)
valid_dataset = AudioDataset(
    torch.Tensor(valid_X), torch.Tensor(valid_y), torch.Tensor(valid_y_mask),
    transform=None)

val_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False)
# Two independently shuffled loaders over the same training set; zipping them in the
# training loop pairs random examples for mixup.
train_loader_1 = DataLoader(train_dataset, batch_size=64, shuffle=True)
train_loader_2 = DataLoader(train_dataset, batch_size=64, shuffle=True)

In [10]:
# Define the device to be used: the first CUDA GPU when one is available, otherwise the CPU,
# so the notebook also runs on machines without a GPU.
cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if cuda else 'cpu')
print('Device: ', device)

Device:  cuda:0


In [11]:
def weight_reset(layer):
    """Re-initialise ``layer`` in place through its ``reset_parameters`` method, if any.

    Meant for ``nn.Module.apply``; modules without ``reset_parameters`` are left untouched.
    """
    if not hasattr(layer, 'reset_parameters'):
        return
    layer.reset_parameters()

In [12]:
class Task5Model(nn.Module):
    """MobileNetV2-based multi-label classifier for single-channel spectrogram input.

    A 1x1-conv stem maps the 1-channel input to the 3 channels MobileNetV2 expects, the
    pretrained ``mv2.features`` extract 1280-channel feature maps, a global max over the last
    two dimensions pools them, and an MLP head produces ``num_classes`` logits.
    Input is assumed to be (batch, 1, H, W), as required by the BatchNorm2d(1) stem.
    """

    def __init__(self, num_classes, reset_after=None):
        """
        Args:
            num_classes: number of output logits.
            reset_after: blocks of ``mv2.features`` whose index is greater than this are
                re-initialised, discarding their pretrained weights. ``None`` falls back to
                the notebook-level ``until_x`` parameter (the original behaviour).
        """
        super().__init__()

        if reset_after is None:
            reset_after = until_x

        self.bw2col = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 10, 1, padding=0), nn.ReLU(),
            nn.Conv2d(10, 3, 1, padding=0), nn.ReLU())

        self.mv2 = torchvision.models.mobilenet_v2(pretrained=True)

        self.final = nn.Sequential(
            nn.Linear(1280, 512), nn.ReLU(), nn.BatchNorm1d(512),
            nn.Linear(512, num_classes))

        # Keep pretrained weights only up to block `reset_after`; re-initialise the rest.
        for i, block in enumerate(self.mv2.features.children()):
            if i > reset_after:
                block.apply(weight_reset)

    def forward(self, x):
        x = self.bw2col(x)
        x = self.mv2.features(x)
        # global max-pool over the two trailing (spatial) dims -> (batch, 1280)
        x = x.max(dim=-1)[0].max(dim=-1)[0]
        x = self.final(x)
        return x

In [13]:
# Instantiate the model with one output per label column: the loss compares the outputs
# with labels[:, :, k], so their widths must match (31 in the original run).
model = Task5Model(train_y.shape[1]).to(device)

In [14]:
# Define optimizer, scheduler and loss criteria
optimizer = optim.Adam(params=model.parameters(), lr=1e-3, amsgrad=True)
# Lower the learning rate (default factor 0.1) after 5 epochs without a new best
# validation loss.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, verbose=True)
# reduction='none' keeps per-element losses so they can be masked per annotator before summing.
criterion = nn.BCEWithLogitsLoss(reduction='none')

In [15]:
def _masked_bce_sum(outputs, labels, masks):
    """Mask-weighted BCE summed over examples, classes and every annotator.

    ``labels`` and ``masks`` are indexed [example, class, annotator]; the same ``outputs``
    logits ([example, class]) are scored against each annotator's labels.
    """
    return sum((criterion(outputs, labels[:, :, k]) * masks[:, :, k]).sum()
               for k in range(labels.shape[2]))


epochs = 100
alpha = 1  # mixup: per-example blend weights are drawn from Beta(alpha, alpha)
train_loss_hist = []
valid_loss_hist = []
lowest_val_loss = np.inf
epochs_without_new_lowest = 0

for i in range(epochs):
    print('Epoch: ', i)

    # ---- training (with mixup) ----
    model.train()
    this_epoch_train_loss = 0
    for i1, i2 in zip(train_loader_1, train_loader_2):

        # mixup: blend each example of batch i1 with its partner in batch i2;
        # inputs, labels and masks are blended with the same weight
        mixup_vals = np.random.beta(alpha, alpha, i1[0].shape[0])

        lam = torch.Tensor(mixup_vals.reshape(-1, 1, 1, 1))
        inputs = (lam * i1[0]) + ((1 - lam) * i2[0])

        lam = torch.Tensor(mixup_vals.reshape(-1, 1, 1))
        labels = (lam * i1[1]) + ((1 - lam) * i2[1])
        masks = (lam * i1[2]) + ((1 - lam) * i2[2])

        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        with torch.enable_grad():
            outputs = model(inputs)
            loss = _masked_bce_sum(outputs, labels, masks) / masks.sum()
            loss.backward()
        optimizer.step()
        this_epoch_train_loss += loss.item()

    # ---- validation ----
    model.eval()
    this_epoch_valid_loss = 0
    with torch.no_grad():
        for inputs, labels, masks in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            masks = masks.to(device)
            outputs = model(inputs)
            loss = _masked_bce_sum(outputs, labels, masks) / masks.sum()
            this_epoch_valid_loss += loss.item()

    # mean of the per-batch losses (each batch weighted equally)
    this_epoch_train_loss /= len(train_loader_1)
    this_epoch_valid_loss /= len(val_loader)

    train_loss_hist.append(this_epoch_train_loss)
    valid_loss_hist.append(this_epoch_valid_loss)
    # report before the early-stopping check so the final epoch is printed too
    print(this_epoch_train_loss, this_epoch_valid_loss)

    # checkpoint whenever the validation loss reaches a new low
    if this_epoch_valid_loss < lowest_val_loss:
        lowest_val_loss = this_epoch_valid_loss
        torch.save(model.state_dict(), './model_system1_until_{}'.format(until_x))
        epochs_without_new_lowest = 0
    else:
        epochs_without_new_lowest += 1

    # early stopping after 25 epochs without improvement
    if epochs_without_new_lowest >= 25:
        break

    scheduler.step(this_epoch_valid_loss)

Epoch:  0


0.6061520487875551 0.41978244270597187
Epoch:  1


0.27460576070321574 0.18008349410125188
Epoch:  2


0.18244864207667275 0.16717689165047236
Epoch:  3


0.17237484092647964 0.17269554947103774
Epoch:  4


0.1700085731776985 0.1649267077445984
Epoch:  5


0.16644091501429276 0.15634199338299887
Epoch:  6


0.1655160007444588 0.14208813543830598
Epoch:  7


0.163294811103795 0.16691283030169352
Epoch:  8


0.16338366108971672 0.14477527567318507
Epoch:  9


0.16048137360327952 0.13553863763809204
Epoch:  10


0.1602558425149402 0.13392042262213572
Epoch:  11


0.1587804004147246 0.1318444237112999
Epoch:  12


0.1584013442735414 0.14023691415786743
Epoch:  13


0.15753631213226835 0.13357895932027272
Epoch:  14


0.15739306847791415 0.13088830879756383
Epoch:  15


0.15500884080255353 0.13165166229009628
Epoch:  16


0.15564408495619492 0.12890015968254634
Epoch:  17


0.15446152598471255 0.12764829929385865
Epoch:  18


0.1541582007665892 0.13540214725903102
Epoch:  19


0.15565206070204038 0.13110518561942236
Epoch:  20


0.15371123600650477 0.13002987844603403
Epoch:  21


0.15464942197541934 0.13011359529835836
Epoch:  22


0.15269306986718564 0.12921968528202601
Epoch:  23


0.15376620236280802 0.1273867530482156
Epoch:  24


0.15234996338148374 0.12928128987550735
Epoch:  25


0.1531403837977229 0.1277741374714034
Epoch:  26


0.15209881357244542 0.1287240151848112
Epoch:  27


0.15139747390875946 0.1328326708504132
Epoch:  28


0.15095075562193588 0.12789128720760345
Epoch:  29


0.15047923898374713 0.1316822607602392
Epoch    29: reducing learning rate of group 0 to 1.0000e-04.
Epoch:  30


0.1498740520831701 0.12492276302405766
Epoch:  31


0.14878068703251915 0.12485261580773763
Epoch:  32


0.14849061498770844 0.12424758608852114
Epoch:  33


0.1479114382653623 0.12436453785215106
Epoch:  34


0.14928082155214772 0.1252893422331129
Epoch:  35


0.14781232178211212 0.1247869176524026
Epoch:  36


0.14709828391268448 0.12422172512326922
Epoch:  37


0.14766108143973994 0.12419188129050392
Epoch:  38


0.14625776176517075 0.12474579789808818
Epoch:  39


0.14775713513026367 0.12430769630840846
Epoch:  40


0.14689665227322965 0.1240817978978157
Epoch:  41


0.14809240884072072 0.12456888066870826
Epoch:  42


0.14792820368264173 0.12384438621146339
Epoch:  43


0.14803306677856962 0.12373146414756775
Epoch:  44


0.1462707314136866 0.124337765787329
Epoch:  45


0.14578699662878708 0.12378069545541491
Epoch:  46


0.1458838622312288 0.1238163805433682
Epoch:  47


0.1470519529806601 0.1251214380775179
Epoch:  48


0.14679776977848363 0.12367444485425949
Epoch:  49


0.1473096564814851 0.12452406755515508
Epoch:  50


0.1459473323177647 0.1237095051578113
Epoch:  51


0.14749984966742025 0.12385598357234683
Epoch:  52


0.1464871131890529 0.12341741153172084
Epoch:  53


0.14704512422149246 0.12316253887755531
Epoch:  54


0.14673146927678907 0.12429832773549217
Epoch:  55


0.1471463973457749 0.12380389017718178
Epoch:  56


0.14605874749454292 0.12313016929796763
Epoch:  57


0.14611619630375425 0.12363540487630027
Epoch:  58


0.14504601705718684 0.12265682433332716
Epoch:  59


0.14669588651206042 0.12299640370266778
Epoch:  60


0.1455587672220694 0.12336491048336029
Epoch:  61


0.1452559385750745 0.12287587885345731
Epoch:  62


0.14579596632235758 0.12333189483199801
Epoch:  63


0.1466650197634826 0.12238589567797524
Epoch:  64


0.14518369653740446 0.12328494340181351
Epoch:  65


0.14553578761783806 0.12336784494774682
Epoch:  66


0.14527213734549446 0.12304315716028214
Epoch:  67


0.14584939141531247 0.12334563370261874
Epoch:  68


0.14522075210068677 0.12282990983554296
Epoch:  69


0.14591614016004512 0.1238269933632442
Epoch    69: reducing learning rate of group 0 to 1.0000e-05.
Epoch:  70


0.1460934635755178 0.12362215561526162
Epoch:  71


0.1443912680890109 0.12306992709636688
Epoch:  72


0.14487118133016536 0.12293356444154467
Epoch:  73


0.14485139742090897 0.12283702939748764
Epoch:  74


0.14410344128673142 0.12287710181304387
Epoch:  75


0.14450669570549116 0.12289553561380931
Epoch    75: reducing learning rate of group 0 to 1.0000e-06.
Epoch:  76


0.14572522406642502 0.12265534166778837
Epoch:  77


0.14482049966180646 0.12304566374846868
Epoch:  78


0.14648099968562256 0.12269969178097588
Epoch:  79


0.14491866408167658 0.12304185330867767
Epoch:  80


0.14400901383644826 0.1228697555405753
Epoch:  81


0.1437240122137843 0.12300137430429459
Epoch    81: reducing learning rate of group 0 to 1.0000e-07.
Epoch:  82


0.14477406683805827 0.12358261219092778
Epoch:  83


0.14457222336047404 0.12257899982588631
Epoch:  84


0.14587501980162956 0.12300697820527214
Epoch:  85


0.14519262434663 0.12325058664594378
Epoch:  86


0.14459855653144219 0.12282926695687431
Epoch:  87


0.14413221301259221 0.12281428171055657
Epoch    87: reducing learning rate of group 0 to 1.0000e-08.
Epoch:  88
