In [1]:
import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import sys
sys.path.append("..")
from RandomErasing import RandomErasing

import torchvision.models
from torchvision import transforms

from albumentations import Compose, ShiftScaleRotate, GridDistortion
from albumentations.pytorch import ToTensor

In [2]:
def prepare_data(df, unknown_to_known):
    df = df.reset_index()
    df['slno'] = df.assign(slno=1).groupby('audio_filename')['slno'].cumsum()
    df.set_index(['audio_filename', 'slno'], inplace=True)

    df_unknown = df.copy().loc[:, list(unknown_to_known.keys())]
    df.drop(columns=list(unknown_to_known.keys()), inplace=True)

    y_mask = df.copy()
    y_mask.loc[:, :] = 1
    for unknown, known in unknown_to_known.items():
        y_mask.loc[
            df_unknown[unknown] > 0.5,
            known
        ] = 0

    df = df.swaplevel(i=1, j=0, axis=0).sort_index()

    y_mask = y_mask.swaplevel(i=1, j=0, axis=0).sort_index()

    y = np.concatenate([
        df.loc[[1], :].values[..., np.newaxis],
        df.loc[[2], :].values[..., np.newaxis],
        df.loc[[3], :].values[..., np.newaxis]
    ], axis=2)

    y_mask = np.concatenate([
        y_mask.loc[[1], :].values[..., np.newaxis],
        y_mask.loc[[2], :].values[..., np.newaxis],
        y_mask.loc[[3], :].values[..., np.newaxis]
    ], axis=2)

    X = np.concatenate([
        np.expand_dims(np.load('../../data/logmelspec/{}.npy'.format(x)).T[:635, :], axis=0)
        for x in df.loc[[1], :].reset_index(1).audio_filename.tolist()])
    X = np.expand_dims(X, axis=1)

    return X, y, y_mask


random_erasing = RandomErasing()


class AudioDataset(Dataset):

    def __init__(self, X, y, weights, transform=None):
        self.X = X
        self.y = y
        self.weights = weights
        self.transform = transform
        self.pil = transforms.ToPILImage()

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        sample = self.X[idx, ...]

        if self.transform:
            # min-max transformation
            this_min = sample.min()
            this_max = sample.max()
            sample = (sample - this_min) / (this_max - this_min)

            # randomly cycle the file
            i = np.random.randint(sample.shape[1])
            sample = torch.cat([
                sample[:, i:, :],
                sample[:, :i, :]],
                dim=1)

            # apply albumentations transforms
            sample = np.array(self.pil(sample))
            sample = self.transform(image=sample)
            sample = sample['image']
            sample = sample[None, :, :].permute(0, 2, 1)

            # apply random erasing
            sample = random_erasing(sample.clone().detach())

            # revert min-max transformation
            sample = (sample * (this_max - this_min)) + this_min

        return sample, self.y[idx, ...], self.weights[idx, ...]

In [3]:
# Load and prepare data
with open('../../data/metadata.pkl', 'rb') as f:
    metadata = pickle.load(f)

unknown_to_known = (
    pd.merge(metadata['taxonomy_df'].loc[lambda x: x.fine_id == 'X', ['fine', 'coarse']],
             metadata['taxonomy_df'].loc[lambda x: x.fine_id != 'X', ['fine', 'coarse']],
             on='coarse', how='inner')
    .drop(columns='coarse')
    .groupby('fine_x')['fine_y']
    .apply(lambda x: list(x)).to_dict())
known_labels = metadata['taxonomy_df'].loc[lambda x: x.fine_id != 'X'].fine.tolist()

train_df = pd.concat([metadata['coarse_train'], metadata['fine_train']], axis=1, sort=True)
valid_df = pd.concat([metadata['coarse_test'], metadata['fine_test']], axis=1, sort=True)

In [4]:
# manual correction for one data point
train_df.loc[(train_df.sum(axis=1) == 37).copy(), :] = 0
valid_df.loc[(valid_df.sum(axis=1) == 37).copy(), :] = 0

In [5]:
train_X, train_y, train_y_mask = prepare_data(train_df, unknown_to_known)
valid_X, valid_y, valid_y_mask = prepare_data(valid_df, unknown_to_known)

In [6]:
# Channel wise normalization
channel_means = train_X.reshape(-1, 128).mean(axis=0).reshape(1, 1, 1, -1)
channel_stds = train_X.reshape(-1, 128).std(axis=0).reshape(1, 1, 1, -1)
train_X = (train_X - channel_means) / channel_stds
valid_X = (valid_X - channel_means) / channel_stds

In [7]:
# Define the data augmentation transformations
albumentations_transform = Compose([
    ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0.5),
    GridDistortion(),
    ToTensor()
])

In [8]:
# Create the datasets and the dataloaders
train_dataset = AudioDataset(torch.Tensor(train_X),
                             torch.Tensor(train_y),
                             torch.Tensor(train_y_mask),
                             albumentations_transform)
valid_dataset = AudioDataset(torch.Tensor(valid_X),
                             torch.Tensor(valid_y),
                             torch.Tensor(valid_y_mask),
                             None)

val_loader = DataLoader(valid_dataset, 64, shuffle=False)
train_loader_1 = DataLoader(train_dataset, 64, shuffle=True)
train_loader_2 = DataLoader(train_dataset, 64, shuffle=True)

In [9]:
# Define the device to be used
cuda = True
device = torch.device('cuda:0' if cuda else 'cpu')
print('Device: ', device)

Device:  cuda:0


In [10]:
def weight_reset(layer):
    if hasattr(layer, 'reset_parameters'):
        layer.reset_parameters()

In [11]:
until_x = None

In [12]:
# Parameters
until_x = 1


In [13]:
class Task5Model(nn.Module):

    def __init__(self, num_classes):

        super().__init__()

        self.bw2col = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 10, 1, padding=0), nn.ReLU(),
            nn.Conv2d(10, 3, 1, padding=0), nn.ReLU())

        self.mv2 = torchvision.models.mobilenet_v2(pretrained=True)

        self.final = nn.Sequential(
            nn.Linear(1280, 512), nn.ReLU(), nn.BatchNorm1d(512),
            nn.Linear(512, num_classes))
        
        # Reset until ith layer of mv2
        for i, x in enumerate(self.mv2.features.children()):
            if i <= until_x:
                x.apply(weight_reset)

    def forward(self, x):
        x = self.bw2col(x)
        x = self.mv2.features(x)
        x = x.max(dim=-1)[0].max(dim=-1)[0]
        x = self.final(x)
        return x

In [14]:
# Instantiate the model
model = Task5Model(31).to(device)

In [15]:
# Define optimizer, scheduler and loss criteria
optimizer = optim.Adam(model.parameters(), lr=0.001, amsgrad=True)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, verbose=True)
criterion = nn.BCEWithLogitsLoss(reduction='none')

In [16]:
epochs = 100
train_loss_hist = []
valid_loss_hist = []
lowest_val_loss = np.inf
epochs_without_new_lowest = 0

for i in range(epochs):
    print('Epoch: ', i)

    this_epoch_train_loss = 0
    for i1, i2 in zip(train_loader_1, train_loader_2):

        # mixup the inputs ---------
        alpha = 1
        mixup_vals = np.random.beta(alpha, alpha, i1[0].shape[0])

        lam = torch.Tensor(mixup_vals.reshape(mixup_vals.shape[0], 1, 1, 1))
        inputs = (lam * i1[0]) + ((1 - lam) * i2[0])

        lam = torch.Tensor(mixup_vals.reshape(mixup_vals.shape[0], 1, 1))
        labels = (lam * i1[1]) + ((1 - lam) * i2[1])
        masks = (lam * i1[2]) + ((1 - lam) * i2[2])
        # mixup ends ----------

        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        with torch.set_grad_enabled(True):
            model = model.train()
            outputs = model(inputs)
            # calculate loss for each set of annotations
            loss_0 = criterion(outputs, labels[:, :, 0]) * masks[:, :, 0]
            loss_1 = criterion(outputs, labels[:, :, 1]) * masks[:, :, 1]
            loss_2 = criterion(outputs, labels[:, :, 2]) * masks[:, :, 2]
            loss = (loss_0.sum() + loss_1.sum() + loss_2.sum()) / masks.sum()
            loss.backward()
            optimizer.step()
            this_epoch_train_loss += loss.detach().cpu().numpy()

    this_epoch_valid_loss = 0
    for inputs, labels, masks in val_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)
        optimizer.zero_grad()
        with torch.set_grad_enabled(False):
            model = model.eval()
            outputs = model(inputs)
            loss_0 = criterion(outputs, labels[:, :, 0]) * masks[:, :, 0]
            loss_1 = criterion(outputs, labels[:, :, 1]) * masks[:, :, 1]
            loss_2 = criterion(outputs, labels[:, :, 2]) * masks[:, :, 2]
            loss = (loss_0.sum() + loss_1.sum() + loss_2.sum()) / masks.sum()
            this_epoch_valid_loss += loss.detach().cpu().numpy()

    this_epoch_train_loss /= len(train_loader_1)
    this_epoch_valid_loss /= len(val_loader)

    train_loss_hist.append(this_epoch_train_loss)
    valid_loss_hist.append(this_epoch_valid_loss)

    if this_epoch_valid_loss < lowest_val_loss:
        lowest_val_loss = this_epoch_valid_loss
        torch.save(model.state_dict(), './model_system1_until_{}'.format(until_x))
        epochs_without_new_lowest = 0
    else:
        epochs_without_new_lowest += 1

    if epochs_without_new_lowest >= 25:
        break

    print(this_epoch_train_loss, this_epoch_valid_loss)

    scheduler.step(this_epoch_valid_loss)

Epoch:  0


0.6353200029682469 0.4803148295198168
Epoch:  1


0.3352491609953545 0.20229683816432953
Epoch:  2


0.1825929847923485 0.16395507752895355
Epoch:  3


0.1672119546580959 0.1436991702233042
Epoch:  4


0.16255627169802384 0.13607036641665868
Epoch:  5


0.16071750948557983 0.138501943222114
Epoch:  6


0.16004917790760864 0.1324333559189524
Epoch:  7


0.15764375030994415 0.13201800733804703
Epoch:  8


0.1557336002588272 0.1328009037034852
Epoch:  9


0.15554669418850461 0.13450467267206737
Epoch:  10


0.1549801999652708 0.15822033690554754
Epoch:  11


0.1558316269436398 0.1308967566915921
Epoch:  12


0.1547152275169218 0.13556811958551407
Epoch:  13


0.15343697812106158 0.13086060860327312
Epoch:  14


0.15314464512709025 0.12569886978183473
Epoch:  15


0.1524210749445735 0.1285751685500145
Epoch:  16


0.15172832197434194 0.1293634527495929
Epoch:  17


0.15233501348946546 0.12743094989231654
Epoch:  18


0.15197081622239705 0.12611872277089528
Epoch:  19


0.15164922137518186 0.12579600193670817
Epoch:  20


0.1510129933421676 0.12995730021170207
Epoch    20: reducing learning rate of group 0 to 1.0000e-04.
Epoch:  21


0.14984770844111572 0.12389863814626421
Epoch:  22


0.1484632250424978 0.12314373893397194
Epoch:  23


0.1468907369149698 0.12360561639070511
Epoch:  24


0.146875136204668 0.12290025289569582
Epoch:  25


0.14616971805288986 0.12178941603217806
Epoch:  26


0.14597041139731537 0.12179980214153018
Epoch:  27


0.1454719135890136 0.12203990561621529
Epoch:  28


0.14732865103193232 0.1216811825122152
Epoch:  29


0.14721294955627337 0.12213710589068276
Epoch:  30


0.1444960047264357 0.1211018179144178
Epoch:  31


0.14722939117534742 0.12153908184596471
Epoch:  32


0.14532393378180428 0.12132534916911807
Epoch:  33


0.14517807356409124 0.12145977360861641
Epoch:  34


0.14589981533385613 0.12124288082122803
Epoch:  35


0.14534443375226613 0.1214570722409657
Epoch:  36


0.1449005897786166 0.12126321877752032
Epoch    36: reducing learning rate of group 0 to 1.0000e-05.
Epoch:  37


0.14536937869883873 0.12085235970360893
Epoch:  38


0.1449738608018772 0.12088640034198761
Epoch:  39


0.14566898869501577 0.12094464365925107
Epoch:  40


0.14542791891742396 0.12085960060358047
Epoch:  41


0.1436763299075333 0.12092311786753791
Epoch:  42


0.14328711298671928 0.12102537708623069
Epoch:  43


0.14452135442076502 0.12087612173386983
Epoch    43: reducing learning rate of group 0 to 1.0000e-06.
Epoch:  44


0.14467260845609614 0.12093242577144078
Epoch:  45


0.14436045450133247 0.12097179463931493
Epoch:  46


0.14468793732088966 0.12084893243653434
Epoch:  47


0.14410344934141314 0.1206920753632273
Epoch:  48


0.14457408358921875 0.120902062526771
Epoch:  49


0.14450277025635178 0.12080476539475578
Epoch:  50


0.14440509475566246 0.12094572080033165
Epoch:  51


0.1460902928500562 0.12075928279331752
Epoch:  52


0.14509451067125476 0.12073627433606557
Epoch:  53


0.1449633236672427 0.12104063161781856
Epoch    53: reducing learning rate of group 0 to 1.0000e-07.
Epoch:  54


0.14558898556876826 0.1207789129444531
Epoch:  55


0.14567156981777502 0.12063218866075788
Epoch:  56


0.14409524202346802 0.12082941297973905
Epoch:  57


0.14586680281806635 0.12060873955488205
Epoch:  58


0.14404803957488085 0.12080162763595581
Epoch:  59


0.14398456626647227 0.12079172368560519
Epoch:  60


0.14482830182926074 0.12070986202784947
Epoch:  61


0.14472923568777135 0.1206290242927415
Epoch:  62


0.1438636441488524 0.12051525499139513
Epoch:  63


0.14411671177760974 0.12086916182722364
Epoch:  64


0.14396515045617078 0.12064078237329211
Epoch:  65


0.14449420369960167 0.12081821688583919
Epoch:  66


0.14490847490929268 0.12088109765733991
Epoch:  67


0.14461507265632217 0.12104207809482302
Epoch:  68


0.14607052505016327 0.12090536526271276
Epoch    68: reducing learning rate of group 0 to 1.0000e-08.
Epoch:  69


0.14487224132628054 0.12079589707510811
Epoch:  70


0.1453430330431139 0.1209438613482884
Epoch:  71


0.14460289357481776 0.12082890527588981
Epoch:  72


0.1438567755190102 0.12073018401861191
Epoch:  73


0.14384422068660324 0.12089640434299197
Epoch:  74


0.1448806685370368 0.12088513800076076
Epoch:  75


0.14548289252294078 0.12070087024143764
Epoch:  76


0.14419008952540321 0.12091538948672158
Epoch:  77


0.14533976163413073 0.12088222588811602
Epoch:  78


0.14522819825120875 0.12083084561995097
Epoch:  79


0.1455233358853572 0.1207889850650515
Epoch:  80


0.1437675807927106 0.12079182692936488
Epoch:  81


0.14578605785563187 0.12083057633468083
Epoch:  82


0.14488692199056213 0.1206949097769601
Epoch:  83


0.14524647150490735 0.12074031680822372
Epoch:  84


0.1449647755236239 0.12080709955521993
Epoch:  85


0.14371385646832957 0.12082694683756147
Epoch:  86


0.144909866758295 0.120869480073452
Epoch:  87
