In [1]:
# Parameters
# Position threshold inside the MobileNetV2 feature stack: when the model is
# built, every feature block whose position is greater than this value has
# its parameters re-initialised. The value is also embedded in the name of
# the checkpoint file written during training.
until_x = 11


In [2]:
import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import sys
sys.path.append("..")
from RandomErasing import RandomErasing

import torchvision.models
from torchvision import transforms

from albumentations import Compose, ShiftScaleRotate, GridDistortion
from albumentations.pytorch import ToTensor

In [3]:
def _stack_annotations(frame, n_annotators):
    """Stack the rows of annotation 1..n_annotators along a new last axis.

    ``frame`` must be indexed by ``(slno, audio_filename)`` and sorted, so
    that row ``r`` of every per-annotation slice refers to the same file.

    Raises
    ------
    ValueError
        If the per-annotation slices do not all contain the same number of
        rows (i.e. some file has fewer annotations than requested).
    """
    per_annotation = [frame.loc[[slot], :].values
                      for slot in range(1, n_annotators + 1)]
    if len({block.shape for block in per_annotation}) > 1:
        raise ValueError(
            'every file needs at least {} annotations; per-annotation row '
            'counts were {}'.format(
                n_annotators, [block.shape[0] for block in per_annotation]))
    return np.stack(per_annotation, axis=2)


def prepare_data(df, unknown_to_known, n_annotators=3,
                 spec_dir='../../data/logmelspec', n_frames=635):
    """Build model inputs, targets and loss masks from an annotation table.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per annotation. After ``reset_index()`` the frame must expose
        an ``audio_filename`` column (so the index is expected to carry that
        name). Every other column is treated as a label column.
    unknown_to_known : dict
        Maps an "unknown" label column to the list of label columns whose
        mask is set to 0 on rows where that unknown column is > 0.5. The
        unknown columns themselves are removed from the targets.
    n_annotators : int, optional
        How many annotations per file are stacked. Annotations are numbered
        1, 2, ... in row order within each file. Default 3.
    spec_dir : str, optional
        Directory containing one ``<audio_filename>.npy`` array per file.
    n_frames : int, optional
        Number of leading rows kept from each array after transposing it.
        Default 635.

    Returns
    -------
    X : numpy.ndarray
        Shape ``(n_files, 1, frames, bins)``; ``frames`` is at most
        ``n_frames`` and must be equal for all files.
    y : numpy.ndarray
        Shape ``(n_files, n_labels, n_annotators)``.
    y_mask : numpy.ndarray
        Same shape as ``y``; 1 where the label counts, 0 where it is masked.

    Raises
    ------
    ValueError
        If some file has fewer than ``n_annotators`` annotations.
    """
    # Number each file's annotations 1, 2, 3, ... in row order.
    df = df.reset_index()
    df['slno'] = df.assign(slno=1).groupby('audio_filename')['slno'].cumsum()
    df = df.set_index(['audio_filename', 'slno'])

    # Split the "unknown" columns off; they only drive the mask.
    unknown_cols = list(unknown_to_known.keys())
    df_unknown = df.loc[:, unknown_cols].copy()
    df = df.drop(columns=unknown_cols)

    # Start from an all-ones mask and zero the known labels that an
    # annotator flagged as uncertain via the matching unknown column.
    y_mask = df.copy()
    y_mask.loc[:, :] = 1
    for unknown, known in unknown_to_known.items():
        y_mask.loc[df_unknown[unknown] > 0.5, known] = 0

    # Put the annotation number first and sort, so each annotation slice
    # lists the files in the same order.
    df = df.swaplevel(i=1, j=0, axis=0).sort_index()
    y_mask = y_mask.swaplevel(i=1, j=0, axis=0).sort_index()

    y = _stack_annotations(df, n_annotators)
    y_mask = _stack_annotations(y_mask, n_annotators)

    # Load the arrays in the same file order as the rows of y.
    filenames = df.loc[[1], :].reset_index(1).audio_filename.tolist()
    X = np.concatenate([
        np.expand_dims(
            np.load('{}/{}.npy'.format(spec_dir, name)).T[:n_frames, :],
            axis=0)
        for name in filenames])
    X = np.expand_dims(X, axis=1)

    return X, y, y_mask


# Module-level RandomErasing augmentation built with its default arguments;
# AudioDataset.__getitem__ applies it when a transform is configured.
random_erasing = RandomErasing()


class AudioDataset(Dataset):
    """Dataset of ``(input, target, weight)`` triples with optional augmentation.

    With a falsy ``transform`` the stored tensors are returned as they are.
    Otherwise each input is min-max scaled, cyclically shifted along axis 1,
    passed through the albumentations ``transform`` and the module-level
    ``random_erasing``, and finally mapped back to its original value range.
    """

    def __init__(self, X, y, weights, transform=None):
        """Store the data tensors and the optional albumentations pipeline."""
        self.X, self.y, self.weights = X, y, weights
        self.transform = transform
        # Tensor -> PIL converter, needed before handing data to albumentations.
        self.pil = transforms.ToPILImage()

    def __len__(self):
        """Number of items, taken from the first axis of ``X``."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return the (possibly augmented) input, target and weight at ``idx``."""
        features = self.X[idx, ...]
        target = self.y[idx, ...]
        weight = self.weights[idx, ...]

        # No augmentation requested: hand the raw item back.
        if not self.transform:
            return features, target, weight

        # Scale to [0, 1]; the bounds are kept so the scaling can be undone.
        low = features.min()
        high = features.max()
        features = (features - low) / (high - low)

        # Rotate the item around a random cut point along axis 1
        # (assumes axis 1 is the time axis -- TODO confirm).
        cut = np.random.randint(features.shape[1])
        leading = features[:, :cut, :]
        trailing = features[:, cut:, :]
        features = torch.cat((trailing, leading), dim=1)

        # albumentations works on numpy images, so go through PIL first.
        image = np.array(self.pil(features))
        augmented = self.transform(image=image)
        features = augmented['image']
        features = features[None, :, :].permute(0, 2, 1)

        # Random erasing on a detached copy of the augmented tensor.
        features = random_erasing(features.clone().detach())

        # Undo the min-max scaling.
        features = (features * (high - low)) + low

        return features, target, weight

In [4]:
# Load and prepare data
# NOTE: pickle.load executes arbitrary code from the file; only open
# metadata files that come from a trusted source.
with open('../../data/metadata.pkl', 'rb') as metadata_file:
    metadata = pickle.load(metadata_file)

# Taxonomy rows whose fine_id is 'X' are the "unknown" fine labels; pair each
# of them with the known fine labels that share the same coarse label.
taxonomy = metadata['taxonomy_df']
unknown_rows = taxonomy.loc[taxonomy.fine_id == 'X', ['fine', 'coarse']]
known_rows = taxonomy.loc[taxonomy.fine_id != 'X', ['fine', 'coarse']]

label_pairs = pd.merge(unknown_rows, known_rows, on='coarse', how='inner')
label_pairs = label_pairs.drop(columns='coarse')

# {unknown fine label: [known fine labels of the same coarse label]}
unknown_to_known = (
    label_pairs
    .groupby('fine_x')['fine_y']
    .apply(list)
    .to_dict())
known_labels = taxonomy.loc[taxonomy.fine_id != 'X', 'fine'].tolist()

# Put the coarse and fine annotation columns side by side for each split.
train_df = pd.concat(
    [metadata['coarse_train'], metadata['fine_train']], axis=1, sort=True)
valid_df = pd.concat(
    [metadata['coarse_test'], metadata['fine_test']], axis=1, sort=True)

In [5]:
# manual correction for one data point
# Rows whose label values add up to exactly this total are blanked out
# (presumably 37 means "every label column set" -- TODO confirm).
FLAGGED_ROW_TOTAL = 37
for split_df in (train_df, valid_df):
    flagged_rows = split_df.sum(axis=1) == FLAGGED_ROW_TOTAL
    split_df.loc[flagged_rows, :] = 0

In [6]:
# Build, for each split, the input array (X), the per-annotation targets (y)
# and the matching loss masks (y_mask); the three arrays of a split share the
# same leading file axis.
train_X, train_y, train_y_mask = prepare_data(train_df, unknown_to_known)
valid_X, valid_y, valid_y_mask = prepare_data(valid_df, unknown_to_known)

In [7]:
# Channel wise normalization
# Mean and standard deviation are computed per position of the last axis,
# using the training split only, and the same statistics are then applied to
# the validation split. The bin count is read from the data instead of being
# hard-coded, so a different spectrogram size cannot be silently mis-reshaped.
# NOTE: this cell rebinds train_X / valid_X; re-running it without re-running
# the prepare_data cell first would normalise the data a second time.
n_channels = train_X.shape[-1]
channel_means = train_X.reshape(-1, n_channels).mean(axis=0).reshape(1, 1, 1, -1)
channel_stds = train_X.reshape(-1, n_channels).std(axis=0).reshape(1, 1, 1, -1)
train_X = (train_X - channel_means) / channel_stds
valid_X = (valid_X - channel_means) / channel_stds

In [8]:
# Define the data augmentation transformations
# Steps run in list order; the final step converts the image to a tensor.
augmentation_steps = [
    ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0.5),
    GridDistortion(),
    ToTensor(),
]
albumentations_transform = Compose(augmentation_steps)

In [9]:
# Create the datasets and the dataloaders
BATCH_SIZE = 64

# Only the training set is augmented; validation items are served unchanged.
train_dataset = AudioDataset(
    X=torch.Tensor(train_X),
    y=torch.Tensor(train_y),
    weights=torch.Tensor(train_y_mask),
    transform=albumentations_transform)
valid_dataset = AudioDataset(
    X=torch.Tensor(valid_X),
    y=torch.Tensor(valid_y),
    weights=torch.Tensor(valid_y_mask),
    transform=None)

val_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
# Two independently shuffled loaders over the same training set: the training
# loop zips them to obtain the pairs of batches that are mixed together.
train_loader_1 = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
train_loader_2 = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [10]:
# Define the device to be used
# Use the first GPU when CUDA is actually available and fall back to the CPU
# otherwise, so the notebook does not crash on machines without a GPU.
cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if cuda else 'cpu')
print('Device: ', device)

Device:  cuda:0


In [11]:
def weight_reset(layer):
    """Re-initialise ``layer`` in place if it offers ``reset_parameters``.

    Objects without that attribute are left untouched, which makes the
    function safe to pass to ``Module.apply``.
    """
    if not hasattr(layer, 'reset_parameters'):
        return
    layer.reset_parameters()

In [12]:
class Task5Model(nn.Module):
    """Classifier built on the feature stack of a pretrained MobileNetV2.

    A small stack of 1x1 convolutions turns the single-channel input into a
    three-channel map, the MobileNetV2 feature blocks process it, the result
    is max-pooled over its last two axes and a two-layer head produces
    ``num_classes`` logits. Feature blocks whose position is greater than the
    module-level ``until_x`` are re-initialised at construction time.
    """

    def __init__(self, num_classes):
        """Create the layers and reset the late MobileNetV2 feature blocks."""
        super().__init__()

        # 1 -> 10 -> 3 channels using 1x1 convolutions.
        colour_layers = [
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 10, 1, padding=0),
            nn.ReLU(),
            nn.Conv2d(10, 3, 1, padding=0),
            nn.ReLU(),
        ]
        self.bw2col = nn.Sequential(*colour_layers)

        # Only ``self.mv2.features`` is used in ``forward``.
        self.mv2 = torchvision.models.mobilenet_v2(pretrained=True)

        head_layers = [
            nn.Linear(1280, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, num_classes),
        ]
        self.final = nn.Sequential(*head_layers)

        # Reset after ith layer of mv2
        for position, block in enumerate(self.mv2.features.children()):
            if position <= until_x:
                continue
            block.apply(weight_reset)

    def forward(self, x):
        """Return the logits for a batch ``x``."""
        three_channel = self.bw2col(x)
        feature_map = self.mv2.features(three_channel)
        # Global max pooling: reduce the last axis, then the new last axis.
        row_max = feature_map.max(dim=-1)[0]
        pooled = row_max.max(dim=-1)[0]
        return self.final(pooled)

In [13]:
# Instantiate the model
# 31 is the number of output logits. NOTE(review): presumably one per label
# column that prepare_data keeps after dropping the "unknown" columns --
# confirm against train_y.shape[1].
model = Task5Model(31).to(device)

In [14]:
# Define optimizer, scheduler and loss criteria
# Element-wise loss (no reduction) so the training loop can apply its
# per-label masks before summing.
criterion = nn.BCEWithLogitsLoss(reduction='none')

optimizer = optim.Adam(params=model.parameters(), lr=1e-3, amsgrad=True)

# The scheduler is stepped with the validation loss once per epoch and lowers
# the learning rate after `patience` epochs without improvement.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    patience=5,
    verbose=True)

In [15]:
epochs = 100
train_loss_hist = []
valid_loss_hist = []
lowest_val_loss = np.inf
epochs_without_new_lowest = 0


def masked_annotation_loss(logits, targets, weights):
    """Masked BCE over the three annotation slots, normalised by the mask sum.

    The same ``logits`` are scored against each of the three annotation
    slices of ``targets`` (last axis); every element-wise loss is multiplied
    by the matching slice of ``weights`` before being summed.
    """
    masked = [
        criterion(logits, targets[:, :, slot]) * weights[:, :, slot]
        for slot in range(3)
    ]
    total = masked[0].sum() + masked[1].sum() + masked[2].sum()
    return total / weights.sum()


for epoch in range(epochs):
    print('Epoch: ', epoch)

    # ---------------- training pass ----------------
    this_epoch_train_loss = 0
    for batch_a, batch_b in zip(train_loader_1, train_loader_2):
        inputs_a, labels_a, masks_a = batch_a
        inputs_b, labels_b, masks_b = batch_b

        # mixup: blend each item of batch_a with the item at the same
        # position in batch_b, using one Beta(alpha, alpha) weight per item
        alpha = 1
        batch_len = inputs_a.shape[0]
        mixup_vals = np.random.beta(alpha, alpha, batch_len)

        # same weights, shaped once for the 4-D inputs and once for the
        # 3-D labels / masks
        lam_inputs = torch.Tensor(mixup_vals.reshape(batch_len, 1, 1, 1))
        lam_targets = torch.Tensor(mixup_vals.reshape(batch_len, 1, 1))

        inputs = (lam_inputs * inputs_a) + ((1 - lam_inputs) * inputs_b)
        labels = (lam_targets * labels_a) + ((1 - lam_targets) * labels_b)
        masks = (lam_targets * masks_a) + ((1 - lam_targets) * masks_b)

        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        with torch.set_grad_enabled(True):
            model = model.train()
            outputs = model(inputs)
            loss = masked_annotation_loss(outputs, labels, masks)
            loss.backward()
            optimizer.step()
            this_epoch_train_loss += loss.detach().cpu().numpy()

    # ---------------- validation pass ----------------
    this_epoch_valid_loss = 0
    for inputs, labels, masks in val_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)
        optimizer.zero_grad()
        with torch.set_grad_enabled(False):
            model = model.eval()
            outputs = model(inputs)
            loss = masked_annotation_loss(outputs, labels, masks)
            this_epoch_valid_loss += loss.detach().cpu().numpy()

    # Average the per-batch losses over the number of batches.
    this_epoch_train_loss /= len(train_loader_1)
    this_epoch_valid_loss /= len(val_loader)

    train_loss_hist.append(this_epoch_train_loss)
    valid_loss_hist.append(this_epoch_valid_loss)

    # Checkpoint on every new best validation loss; otherwise count the
    # epoch towards early stopping.
    improved = this_epoch_valid_loss < lowest_val_loss
    if improved:
        lowest_val_loss = this_epoch_valid_loss
        torch.save(model.state_dict(), './model_system1_until_{}'.format(until_x))
        epochs_without_new_lowest = 0
    else:
        epochs_without_new_lowest += 1

    # Early stopping after 25 epochs without a new best validation loss.
    if epochs_without_new_lowest >= 25:
        break

    print(this_epoch_train_loss, this_epoch_valid_loss)

    scheduler.step(this_epoch_valid_loss)

Epoch:  0


0.6060611935886177 0.4982283243111202
Epoch:  1


0.27512132034108444 0.16742666278566634
Epoch:  2


0.17476113301676674 0.15809792705944606
Epoch:  3


0.16378406898395434 0.14386284670659474
Epoch:  4


0.15877634488247536 0.13547926715442113
Epoch:  5


0.1571028667527276 0.13582603633403778
Epoch:  6


0.15535428717329697 0.13447102265698568
Epoch:  7


0.15501336874188604 0.13010448749576295
Epoch:  8


0.15399231781830658 0.12963319889136724
Epoch:  9


0.1533797580648113 0.13175921142101288
Epoch:  10


0.15377817725813067 0.1253752378480775
Epoch:  11


0.15281749174401565 0.12674830853939056
Epoch:  12


0.1516914814710617 0.12952517398766109
Epoch:  13


0.1512842343465702 0.12852127530745097
Epoch:  14


0.1517953695477666 0.12863175570964813
Epoch:  15


0.14975567766138026 0.12754622846841812
Epoch:  16


0.1497245728969574 0.1276489379150527
Epoch    16: reducing learning rate of group 0 to 1.0000e-04.
Epoch:  17


0.1476654281487336 0.12368808473859515
Epoch:  18


0.14782464222328082 0.12291506784302848
Epoch:  19


0.14775628455587336 0.12271623313426971
Epoch:  20


0.14665174363432704 0.12301105899470192
Epoch:  21


0.1472435343909908 0.12293922475406102
Epoch:  22


0.14689319319016225 0.1226734316774777
Epoch:  23


0.1455010622739792 0.12272036927086967
Epoch:  24


0.1462730515647579 0.12284009477921895
Epoch:  25


0.14535228283824148 0.12290104691471372
Epoch:  26


0.14630384944580696 0.12211120980126518
Epoch:  27


0.1458910985572918 0.12290621761764799
Epoch:  28


0.14551840239279978 0.12243080990655082
Epoch:  29


0.1452622880806794 0.12181428181273597
Epoch:  30


0.14591326020859383 0.12228013681513923
Epoch:  31


0.14663685898523074 0.121712079005582
Epoch:  32


0.14460526648405436 0.12217841297388077
Epoch:  33


0.1452406085020787 0.1217828808086259
Epoch:  34


0.14395966038510605 0.12199090953384127
Epoch:  35


0.14441258359599757 0.12232991201536995
Epoch:  36


0.14460796721883723 0.1219325235911778
Epoch:  37


0.1438934315700789 0.1214829853602818
Epoch:  38


0.1448226772450112 0.12311415587152753
Epoch:  39


0.14411229982569412 0.12183372250625066
Epoch:  40


0.14354040171649005 0.12169028179986137
Epoch:  41


0.14395789963168068 0.12220050926719393
Epoch:  42


0.14415843140434575 0.12272343039512634
Epoch:  43


0.1453003686022114 0.12221078148909978
Epoch    43: reducing learning rate of group 0 to 1.0000e-05.
Epoch:  44


0.143469965135729 0.12229933376823153
Epoch:  45


0.14295820328029427 0.12177935029779162
Epoch:  46


0.14200489182729978 0.12180209585598537
Epoch:  47


0.14218847453594208 0.12178120974983488
Epoch:  48


0.1432879075810716 0.12190549394914083
Epoch:  49


0.14323097507695895 0.12156569106238228
Epoch    49: reducing learning rate of group 0 to 1.0000e-06.
Epoch:  50


0.14373355419249148 0.12149550872189659
Epoch:  51


0.1427496756250794 0.12159454822540283
Epoch:  52


0.14259006646839348 0.12149780988693237
Epoch:  53


0.14344387360521266 0.12157341625009264
Epoch:  54


0.14296617942887382 0.12164880761078425
Epoch:  55


0.1428637468331569 0.12157793236630303
Epoch    55: reducing learning rate of group 0 to 1.0000e-07.
Epoch:  56


0.1424456575029605 0.12167272610323769
Epoch:  57


0.14255855856715022 0.12142878983701978
Epoch:  58


0.14246745709631894 0.1216496473976544
Epoch:  59


0.14362796214786736 0.12177306839397975
Epoch:  60


0.14297022932284587 0.12168753785746438
Epoch:  61


0.14305781955654556 0.12171016952821187
Epoch:  62


0.14325687410058202 0.12140425188200814
Epoch:  63


0.14282992764099225 0.12150026964289802
Epoch:  64


0.14262082770064072 0.12145811106477465
Epoch:  65


0.14350124831135208 0.1218690669962338
Epoch:  66


0.14318869846898155 0.12157068401575089
Epoch:  67


0.14328761680706129 0.12157987803220749
Epoch:  68


0.1425863294988065 0.1216547531741006
Epoch    68: reducing learning rate of group 0 to 1.0000e-08.
Epoch:  69


0.14341881468489365 0.12153603562286922
Epoch:  70


0.14326674752944224 0.1215033871786935
Epoch:  71


0.14205621827293086 0.12162565759250096
Epoch:  72


0.14294319378363118 0.12177050965172904
Epoch:  73


0.14326396060956492 0.12166617384978703
Epoch:  74


0.14191946951118675 0.12151450025183815
Epoch:  75


0.14402082602720004 0.12157705000468663
Epoch:  76


0.14272403998954877 0.12179391724722725
Epoch:  77


0.1442989207602836 0.12141615152359009
Epoch:  78


0.1430015261914279 0.12158834614924022
Epoch:  79


0.14258054100178383 0.12144295232636589
Epoch:  80


0.14273917916658763 0.12152915873697825
Epoch:  81


0.141944403583939 0.12164172210863658
Epoch:  82


0.1428594838928532 0.12155956881386894
Epoch:  83


0.14359368464431246 0.12162601522036962
Epoch:  84


0.1434924000421086 0.12170192173549108
Epoch:  85


0.14279712253325694 0.12147649590458189
Epoch:  86


0.14285293503387556 0.12151127415043968
Epoch:  87
