In [1]:
# Parameters
# until_x: blocks of mobilenet_v2's `features` with index > until_x are
# re-initialised inside Task5Model (so blocks 0..until_x keep their pretrained
# weights); the best checkpoint is saved as ./model_system1_until_<until_x>.
until_x = 8


In [2]:
import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import sys
sys.path.append("..")
from RandomErasing import RandomErasing

import torchvision.models
from torchvision import transforms

from albumentations import Compose, ShiftScaleRotate, GridDistortion
from albumentations.pytorch import ToTensor

In [3]:
def prepare_data(df, unknown_to_known, data_dir='../../data/logmelspec',
                 n_frames=635, n_annotators=3):
    """Build model inputs, per-annotator targets and loss masks from label rows.

    ``df`` holds one row per (audio file, annotation) with one column per label.
    Rows are grouped by ``audio_filename`` (the index or a column; it is
    recovered via ``reset_index``) and numbered 1, 2, ... in order of
    appearance. Columns named in ``unknown_to_known`` are removed from the
    targets. Wherever such a column is > 0.5 for an annotation, the mask is
    zeroed for the known columns it maps to in that annotation.

    Args:
        df: label frame with rows per annotation, indexed (or with a column) by
            ``audio_filename``.
        unknown_to_known: dict mapping each "unknown" label column to the list
            of known label columns whose loss is masked when it is set.
        data_dir: directory holding ``<audio_filename>.npy`` arrays.
        n_frames: number of leading rows kept from each transposed array.
        n_annotators: number of annotations per file stacked into the targets.
            Annotations beyond this number are ignored.

    Returns:
        X: array (n_files, 1, frames, bins). Each file's ``.npy`` array is
           transposed and truncated to its first ``n_frames`` rows
           (presumably (time, mel) after the transpose; TODO confirm).
        y: array (n_files, n_known_labels, n_annotators).
        y_mask: array shaped like ``y``. 1 where the loss counts, 0 where masked.
        Files are ordered by sorted ``audio_filename`` in all three arrays.

    Raises:
        ValueError: if some file has fewer than ``n_annotators`` annotations.
            In that case the per-annotator rows could not be aligned.
    """
    unknown_cols = list(unknown_to_known.keys())

    # Number each file's annotation rows 1, 2, ... in order of appearance.
    df = df.reset_index()
    df['slno'] = df.assign(slno=1).groupby('audio_filename')['slno'].cumsum()
    df = df.set_index(['audio_filename', 'slno'])

    # Fail early with a readable message. With too few rows for a file, the
    # per-slot arrays stacked below would have mismatched lengths.
    counts = df.groupby(level='audio_filename').size()
    too_few = counts[counts < n_annotators]
    if len(too_few) > 0:
        raise ValueError(
            'prepare_data: expected at least {} annotations per file, '
            'got fewer for: {}'.format(n_annotators, too_few.index.tolist()))

    # Split off the "unknown" columns; they only drive the mask.
    df_unknown = df.loc[:, unknown_cols]
    df = df.drop(columns=unknown_cols)

    # Mask is 1 everywhere, except for the known labels tied to an unknown
    # label that this annotation marked (> 0.5).
    y_mask = df.copy()
    y_mask.loc[:, :] = 1
    for unknown, known in unknown_to_known.items():
        y_mask.loc[df_unknown[unknown] > 0.5, known] = 0

    # Make the slot number the outer index level, so each slot can be
    # selected as a block of rows sorted by filename.
    df = df.swaplevel(i=1, j=0, axis=0).sort_index()
    y_mask = y_mask.swaplevel(i=1, j=0, axis=0).sort_index()

    slots = range(1, n_annotators + 1)
    y = np.stack([df.loc[[slot], :].values for slot in slots], axis=2)
    y_mask = np.stack([y_mask.loc[[slot], :].values for slot in slots], axis=2)

    filenames = df.loc[[1], :].index.get_level_values('audio_filename').tolist()
    X = np.stack([
        np.load('{}/{}.npy'.format(data_dir, name)).T[:n_frames, :]
        for name in filenames])
    X = X[:, np.newaxis]  # add a singleton channel axis at position 1

    return X, y, y_mask


# Single RandomErasing augmenter (from the project-local RandomErasing module)
# shared by every AudioDataset item; AudioDataset.__getitem__ reads it as a
# module-level global when a transform is set.
random_erasing = RandomErasing()


class AudioDataset(Dataset):
    """Map-style dataset over pre-computed tensors.

    Each item is ``(sample, y[idx], weights[idx])``. When ``transform`` is set,
    the sample goes through a training-time augmentation pipeline:
    1. min-max scaling to [0, 1];
    2. a random circular shift along axis 1;
    3. a PIL/numpy round trip through ``transform`` (an albumentations-style
       callable taking ``image=`` and returning a dict with ``'image'``);
    4. the module-level ``random_erasing``;
    5. the inverse min-max scaling.

    Args:
        X: tensor indexed by sample along dim 0; presumably (N, 1, time, mel).
            TODO confirm.
        y: targets indexed by sample along dim 0.
        weights: per-sample loss weights/masks indexed along dim 0.
        transform: optional augmentation callable; ``None`` disables all
            augmentation and returns the stored sample unchanged.
    """

    def __init__(self, X, y, weights, transform=None):
        self.X = X
        self.y = y
        self.weights = weights
        self.transform = transform
        # Converts a tensor sample into a PIL image; the transform is then fed
        # the numpy version of that image.
        self.pil = transforms.ToPILImage()

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        sample = self.X[idx, ...]

        if self.transform:
            # min-max transformation to [0, 1]. A constant sample has zero
            # range, and dividing by it would fill the sample with NaNs. Use a
            # range of 1 instead: the sample maps to 0 and is restored below.
            this_min = sample.min()
            this_max = sample.max()
            value_range = this_max - this_min
            if value_range == 0:
                value_range = torch.ones_like(value_range)
            sample = (sample - this_min) / value_range

            # randomly cycle the file: circular shift by a random offset along
            # axis 1 (same result as cat([s[:, i:], s[:, :i]], dim=1)).
            shift = int(np.random.randint(sample.shape[1]))
            sample = torch.roll(sample, shifts=-shift, dims=1)

            # apply albumentations transforms via a PIL/numpy image, then bring
            # the result back to a (1, axis1, axis2) tensor layout.
            # NOTE(review): ToPILImage on a float tensor presumably quantises
            # to 8 bits; confirm that precision loss is acceptable.
            sample = np.array(self.pil(sample))
            sample = self.transform(image=sample)['image']
            sample = sample[None, :, :].permute(0, 2, 1)

            # apply random erasing (module-level `random_erasing` instance)
            sample = random_erasing(sample.clone().detach())

            # revert min-max transformation
            sample = (sample * value_range) + this_min

        return sample, self.y[idx, ...], self.weights[idx, ...]

In [4]:
# Load and prepare data
# NOTE(review): pickle.load can execute arbitrary code; only load a
# metadata.pkl produced by this project.
with open('../../data/metadata.pkl', 'rb') as f:
    metadata = pickle.load(f)

taxonomy_df = metadata['taxonomy_df']
is_unknown = taxonomy_df.fine_id == 'X'
unknown_fine = taxonomy_df.loc[is_unknown, ['fine', 'coarse']]
known_fine = taxonomy_df.loc[~is_unknown, ['fine', 'coarse']]

# Map each "unknown" fine label (fine_id == 'X') to the list of known fine
# labels sharing its coarse category.
unknown_to_known = (
    unknown_fine
    .merge(known_fine, on='coarse', how='inner')
    .drop(columns='coarse')
    .groupby('fine_x')['fine_y']
    .apply(list)
    .to_dict()
)
known_labels = known_fine.fine.tolist()

# Per-annotation label tables: coarse and fine label columns side by side.
train_df = pd.concat(
    [metadata['coarse_train'], metadata['fine_train']], axis=1, sort=True)
valid_df = pd.concat(
    [metadata['coarse_test'], metadata['fine_test']], axis=1, sort=True)

In [5]:
# manual correction for one data point: zero every label of any row whose
# label columns sum to 37. This is presumably an annotation with every box
# ticked; TODO confirm that 37 == number of label columns.
for labels_df in (train_df, valid_df):
    has_all_37 = labels_df.sum(axis=1) == 37
    labels_df.loc[has_all_37, :] = 0

In [6]:
# Turn each split's label table into (X, y, y_mask) arrays: train first, then valid.
(train_X, train_y, train_y_mask), (valid_X, valid_y, valid_y_mask) = [
    prepare_data(split_df, unknown_to_known) for split_df in (train_df, valid_df)
]

In [7]:
# Channel wise normalization: one mean/std per index of the last axis
# (presumably the mel bins). The statistics come from the training split only
# and are applied to both splits.
# NOTE: this cell rebinds train_X/valid_X, so re-running it normalizes twice.
n_bins = train_X.shape[-1]
train_flat = train_X.reshape(-1, n_bins)
channel_means = train_flat.mean(axis=0).reshape(1, 1, 1, -1)
channel_stds = train_flat.std(axis=0).reshape(1, 1, 1, -1)
train_X = (train_X - channel_means) / channel_stds
valid_X = (valid_X - channel_means) / channel_stds

In [8]:
# Define the data augmentation transformations applied to each training
# spectrogram image: a small shift/scale/rotation, a grid distortion, then
# conversion to a tensor.
augmentation_steps = [
    ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0.5),
    GridDistortion(),
    ToTensor(),
]
albumentations_transform = Compose(augmentation_steps)

In [9]:
# Create the datasets and the dataloaders.
# Training items go through the augmentation pipeline; validation items are
# returned unaugmented (transform=None).
train_dataset = AudioDataset(
    torch.Tensor(train_X), torch.Tensor(train_y), torch.Tensor(train_y_mask),
    albumentations_transform)
valid_dataset = AudioDataset(
    torch.Tensor(valid_X), torch.Tensor(valid_y), torch.Tensor(valid_y_mask),
    None)

# Two independently shuffled loaders over the same training set; the training
# loop zips them to draw the pairs that get mixed up.
val_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False)
train_loader_1, train_loader_2 = (
    DataLoader(train_dataset, batch_size=64, shuffle=True) for _ in range(2))

In [10]:
# Define the device to be used: the first GPU when one is available, otherwise
# the CPU. This avoids failing later on `.to(device)` on CPU-only machines.
# Set `cuda = False` explicitly to force CPU.
cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if cuda else 'cpu')
print('Device: ', device)

Device:  cuda:0


In [11]:
def weight_reset(layer):
    """Re-initialise ``layer`` in place if it exposes ``reset_parameters``.

    Meant for ``nn.Module.apply``. Modules without that method (e.g.
    containers or activations) are left untouched.
    """
    try:
        reset = layer.reset_parameters
    except AttributeError:
        return
    reset()

In [12]:
class Task5Model(nn.Module):
    """Multi-label classifier: channel adapter + MobileNetV2 features + MLP head.

    ``bw2col`` (BatchNorm plus two 1x1 convs) turns the single-channel input
    into 3 channels for the pretrained MobileNetV2 feature extractor. Its
    feature map is max-pooled over the last two axes and fed to a small MLP
    head that produces ``num_classes`` logits.

    Feature blocks of ``mv2.features`` whose index is greater than the
    module-level ``until_x`` are re-initialised via ``weight_reset``, so only
    blocks ``0..until_x`` keep their pretrained weights.
    """

    def __init__(self, num_classes):
        super().__init__()

        # 1-channel input -> 3 channels expected by the pretrained backbone.
        self.bw2col = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 10, 1, padding=0),
            nn.ReLU(),
            nn.Conv2d(10, 3, 1, padding=0),
            nn.ReLU(),
        )

        self.mv2 = torchvision.models.mobilenet_v2(pretrained=True)

        # Head: 1280 pooled backbone channels -> num_classes logits.
        self.final = nn.Sequential(
            nn.Linear(1280, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, num_classes),
        )

        # Keep pretrained weights up to block `until_x`; re-initialise the rest.
        for block_idx, block in enumerate(self.mv2.features.children()):
            if block_idx <= until_x:
                continue
            block.apply(weight_reset)

    def forward(self, x):
        feature_map = self.mv2.features(self.bw2col(x))
        # Global max pooling: reduce the last axis, then the new last axis.
        pooled, _ = feature_map.max(dim=-1)
        pooled, _ = pooled.max(dim=-1)
        return self.final(pooled)

In [13]:
# Instantiate the model (31 output logits) and move it to the selected device.
model = Task5Model(num_classes=31)
model = model.to(device)

In [14]:
# Define optimizer, scheduler and loss criteria.
# Adam (AMSGrad variant) over all model parameters, initial learning rate 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3, amsgrad=True)
# Lower the learning rate once the value passed to scheduler.step (the
# validation loss) has not improved for 5 epochs; verbose logs each change.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', patience=5, verbose=True)
# Element-wise BCE on logits (no reduction), so per-label masks can be applied
# before averaging.
criterion = nn.BCEWithLogitsLoss(reduction='none')

In [15]:
# Train with mixup and a masked multi-annotator BCE loss. The LR is scheduled
# on the validation loss, the best model is checkpointed, and training stops
# early when the validation loss stalls.
epochs = 100
early_stopping_patience = 25  # epochs without a new lowest validation loss
mixup_alpha = 1               # Beta(alpha, alpha) parameter for mixup
train_loss_hist = []
valid_loss_hist = []
lowest_val_loss = np.inf
epochs_without_new_lowest = 0


def masked_bce(outputs, labels, masks, loss_fn):
    """Masked loss summed over all annotator slices, normalised by mask weight.

    ``labels`` and ``masks`` hold one slice per annotator along dim 2. Each
    slice is scored against the same ``outputs`` with the element-wise
    ``loss_fn``, weighted by its mask, and summed. The total is divided by
    ``masks.sum()``.
    """
    total = sum((loss_fn(outputs, labels[:, :, k]) * masks[:, :, k]).sum()
                for k in range(labels.shape[2]))
    return total / masks.sum()


for i in range(epochs):
    print('Epoch: ', i)

    # --- training --------------------------------------------------------
    model.train()
    this_epoch_train_loss = 0
    for i1, i2 in zip(train_loader_1, train_loader_2):
        # mixup: blend each example of batch i1 with the same-position example
        # of the independently shuffled batch i2, using one weight per example.
        mixup_vals = np.random.beta(mixup_alpha, mixup_alpha, i1[0].shape[0])

        lam = torch.Tensor(mixup_vals.reshape(-1, 1, 1, 1))
        inputs = (lam * i1[0]) + ((1 - lam) * i2[0])

        lam = torch.Tensor(mixup_vals.reshape(-1, 1, 1))
        labels = (lam * i1[1]) + ((1 - lam) * i2[1])
        masks = (lam * i1[2]) + ((1 - lam) * i2[2])

        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = masked_bce(outputs, labels, masks, criterion)
        loss.backward()
        optimizer.step()
        this_epoch_train_loss += loss.item()

    # --- validation (no gradients) ---------------------------------------
    model.eval()
    this_epoch_valid_loss = 0
    with torch.no_grad():
        for inputs, labels, masks in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            masks = masks.to(device)
            outputs = model(inputs)
            loss = masked_bce(outputs, labels, masks, criterion)
            this_epoch_valid_loss += loss.item()

    # Per-epoch losses are the mean of the per-batch losses.
    this_epoch_train_loss /= len(train_loader_1)
    this_epoch_valid_loss /= len(val_loader)

    train_loss_hist.append(this_epoch_train_loss)
    valid_loss_hist.append(this_epoch_valid_loss)

    # Log before any early-stopping break, so the final epoch is reported too.
    print(this_epoch_train_loss, this_epoch_valid_loss)

    if this_epoch_valid_loss < lowest_val_loss:
        lowest_val_loss = this_epoch_valid_loss
        # Checkpoint the best model so far (state_dict only).
        torch.save(model.state_dict(), './model_system1_until_{}'.format(until_x))
        epochs_without_new_lowest = 0
    else:
        epochs_without_new_lowest += 1

    if epochs_without_new_lowest >= early_stopping_patience:
        break

    scheduler.step(this_epoch_valid_loss)

Epoch:  0


0.6057378891352061 0.42473587819508146
Epoch:  1


0.27268568607600957 0.15750389013971602
Epoch:  2


0.17829044888148438 0.15493414657456533
Epoch:  3


0.16893098523487915 0.15959249436855316
Epoch:  4


0.16461970516153285 0.15709615179470607
Epoch:  5


0.1612029417946532 0.14738116626228606
Epoch:  6


0.1604593794893574 0.15677236233438765
Epoch:  7


0.1599487916037843 0.13475075364112854
Epoch:  8


0.15664100123418345 0.13921200803347997
Epoch:  9


0.15572020451764804 0.12963042833975383
Epoch:  10


0.1552010802803813 0.13197545494352067
Epoch:  11


0.15297164264562968 0.12855156511068344
Epoch:  12


0.15371871195934914 0.12742630392313004
Epoch:  13


0.15205073276081602 0.13002037044082368
Epoch:  14


0.15256365368495117 0.14044741647584097
Epoch:  15


0.1522854143703306 0.13285842537879944
Epoch:  16


0.1511002666241414 0.12952015974691936
Epoch:  17


0.15217414739969615 0.13624989347798483
Epoch:  18


0.15149986945294044 4.156020879745483
Epoch    18: reducing learning rate of group 0 to 1.0000e-04.
Epoch:  19


0.15014741589894165 0.12371935588972909
Epoch:  20


0.1483028913671906 0.12381060421466827
Epoch:  21


0.1489975271192757 0.12358859394277845
Epoch:  22


0.14750122057425008 0.12340769384588514
Epoch:  23


0.14774471400557337 0.12353818118572235
Epoch:  24


0.1484568127103754 0.1230454051068851
Epoch:  25


0.1473827652029089 0.12315412397895541
Epoch:  26


0.1474401531992732 0.12338430966649737
Epoch:  27


0.14658841772659406 0.12302233065877642
Epoch:  28


0.14764466156830658 0.12338516541889735
Epoch:  29


0.1476542563051791 0.12279456108808517
Epoch:  30


0.14788177488623439 0.12359450863940376
Epoch:  31


0.1485101081229545 0.1226645005600793
Epoch:  32


0.1458551581646945 0.12295822692768914
Epoch:  33


0.14667217957006917 0.12355674377509526
Epoch:  34


0.14782314042787295 0.12272025538342339
Epoch:  35


0.14614944321078224 0.1226749558533941
Epoch:  36


0.1468110567814595 0.1236185325043542
Epoch:  37


0.1471495769313864 0.12246407994202205
Epoch:  38


0.14661358296871185 0.12340327139411654
Epoch:  39


0.1462197529303061 0.1223956367799214
Epoch:  40


0.1468493435028437 0.12278869428804942
Epoch:  41


0.14567949401365743 0.1225017426269395
Epoch:  42


0.14657264988164645 0.12228991729872567
Epoch:  43


0.14611218628045675 0.12260016053915024
Epoch:  44


0.146015329940899 0.12298862316778728
Epoch:  45


0.1445913524241061 0.12280936964920589
Epoch:  46


0.14426555021389112 0.1231338062456676
Epoch:  47


0.14571715609447375 0.12317531023706708
Epoch:  48


0.14642939817261053 0.12289907038211823
Epoch    48: reducing learning rate of group 0 to 1.0000e-05.
Epoch:  49


0.14450725067306208 0.12282827283654894
Epoch:  50


0.14533821634344152 0.12229846205030169
Epoch:  51


0.14455343219074043 0.12216273588793618
Epoch:  52


0.14457567076425296 0.12257192922489983
Epoch:  53


0.14489243884344358 0.12248001992702484
Epoch:  54


0.14430687999403155 0.12245171517133713
Epoch:  55


0.14467063787821177 0.1225691916687148
Epoch:  56


0.14521039821006157 0.12288764651332583
Epoch:  57


0.14472466747503024 0.12256285016025815
Epoch    57: reducing learning rate of group 0 to 1.0000e-06.
Epoch:  58


0.14413468136980728 0.12267640658787318
Epoch:  59


0.14461862638189987 0.1225204308118139
Epoch:  60


0.1443143829300597 0.1222415342926979
Epoch:  61


0.14454391155693982 0.12216182372399739
Epoch:  62


0.14526845998055227 0.12240226034607206
Epoch:  63


0.14546944482906446 0.12231054902076721
Epoch    63: reducing learning rate of group 0 to 1.0000e-07.
Epoch:  64


0.14409770756154447 0.12226344645023346
Epoch:  65


0.14555582565230293 0.1222981618983405
Epoch:  66


0.14323145231685122 0.12212173747164863
Epoch:  67


0.14470964589634458 0.12229831197432109
Epoch:  68


0.14495019050868782 0.12238009167569024
Epoch:  69


0.1451965252289901 0.12229423650673457
Epoch:  70


0.14462003917307467 0.1223688753587859
Epoch:  71


0.1441996725829872 0.1221852217401777
Epoch:  72


0.14402759437625473 0.12217513578278678
Epoch    72: reducing learning rate of group 0 to 1.0000e-08.
Epoch:  73


0.1438765111001762 0.12243775065456118
Epoch:  74


0.14504802106200038 0.12230400953974042
Epoch:  75


0.1452136502878086 0.12234172544309072
Epoch:  76


0.14480891864042025 0.12250956680093493
Epoch:  77


0.14378264063113444 0.1224074438214302
Epoch:  78


0.14460097736603506 0.12257215593542371
Epoch:  79


0.1449532295401032 0.12237072097403663
Epoch:  80


0.14375118629352465 0.12241283804178238
Epoch:  81


0.1451578643676397 0.12227541421140943
Epoch:  82


0.14457156569571109 0.12223413693053382
Epoch:  83


0.14503800828714628 0.12240920428718839
Epoch:  84


0.14505980304769567 0.12242175319365092
Epoch:  85


0.14412939790132884 0.1223981071795736
Epoch:  86


0.14501441531890147 0.12218238094023295
Epoch:  87


0.14418346414694916 0.12240368872880936
Epoch:  88


0.14389595671280012 0.12248273193836212
Epoch:  89


0.14537612167564598 0.12244603570018496
Epoch:  90


0.1456521652840279 0.12221455574035645
Epoch:  91
