In [1]:
import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import sys
sys.path.append("..")
from RandomErasing import RandomErasing

import torchvision.models
from torchvision import transforms

from albumentations import Compose, ShiftScaleRotate, GridDistortion
from albumentations.pytorch import ToTensor

In [2]:
def prepare_data(df, unknown_to_known,
                 logmelspec_dir='../../data/logmelspec',
                 n_frames=635, n_annotations=3):
    """Build model inputs, per-annotation targets and loss masks from a label table.

    Parameters
    ----------
    df : pandas.DataFrame
        Label table whose index is named ``audio_filename``; every annotation
        of a clip is its own row, so rows of one clip share the index value.
    unknown_to_known : dict
        Maps each "unknown" label column to the list of known label columns
        that are masked out (mask value 0) for an annotation in which that
        unknown column is greater than 0.5. The unknown columns themselves
        are removed from the targets.
    logmelspec_dir : str, optional
        Directory containing one ``<audio_filename>.npy`` array per clip.
    n_frames : int, optional
        Number of leading rows kept from each transposed array. Every array
        must provide at least this many rows, otherwise stacking fails.
    n_annotations : int, optional
        Number of annotations used per clip (the first ``n_annotations`` rows
        of each clip, in table order). Extra annotations are ignored.

    Returns
    -------
    X : numpy.ndarray
        Shape ``(n_clips, 1, n_frames, n_features)``, clips sorted by
        ``audio_filename``.
    y : numpy.ndarray
        Shape ``(n_clips, n_known_labels, n_annotations)``.
    y_mask : numpy.ndarray
        Same shape as ``y``; 1 where the entry contributes to the loss, 0
        where it is masked.

    Raises
    ------
    ValueError
        If any clip has fewer than ``n_annotations`` annotation rows.
    """
    unknown_cols = list(unknown_to_known.keys())

    df = df.reset_index()

    # Fail early with a readable message instead of a shape mismatch later on.
    annotation_counts = df.groupby('audio_filename').size()
    if (annotation_counts < n_annotations).any():
        raise ValueError(
            'every audio_filename needs at least {} annotation rows'.format(n_annotations))

    # Number the annotations of each clip 1, 2, 3, ... in table order.
    df['slno'] = df.groupby('audio_filename').cumcount() + 1
    df = df.set_index(['audio_filename', 'slno'])

    df_unknown = df.loc[:, unknown_cols]
    df = df.drop(columns=unknown_cols)

    # Start from "everything counts" and switch off the known labels that an
    # annotator covered with the corresponding unknown label.
    y_mask = df.copy()
    y_mask.loc[:, :] = 1
    for unknown, known in unknown_to_known.items():
        y_mask.loc[df_unknown[unknown] > 0.5, known] = 0

    # Put the annotation number first so that selecting one annotation slot
    # yields the clips in the same (sorted) order for every slot.
    df = df.swaplevel(i=1, j=0, axis=0).sort_index()
    y_mask = y_mask.swaplevel(i=1, j=0, axis=0).sort_index()

    slots = range(1, n_annotations + 1)
    y = np.stack([df.loc[[k], :].values for k in slots], axis=2)
    y_mask = np.stack([y_mask.loc[[k], :].values for k in slots], axis=2)

    filenames = df.loc[[1], :].index.get_level_values('audio_filename').tolist()
    X = np.stack([
        np.load('{}/{}.npy'.format(logmelspec_dir, name)).T[:n_frames, :]
        for name in filenames])
    # Add the singleton channel axis expected by the convolutional model.
    X = np.expand_dims(X, axis=1)

    return X, y, y_mask


# Module-level RandomErasing instance (built with its default arguments);
# AudioDataset.__getitem__ applies it to every augmented sample.
random_erasing = RandomErasing()


class AudioDataset(Dataset):
    """Dataset of (input, target, weight) triples with optional augmentation.

    Parameters
    ----------
    X : torch.Tensor
        Inputs, indexed along the first axis. The augmentation path treats
        each item as a (1, rows, cols) tensor.
    y : torch.Tensor
        Targets, indexed along the first axis.
    weights : torch.Tensor
        Per-target weights / masks, indexed along the first axis.
    transform : callable, optional
        Albumentations-style transform called as ``transform(image=array)``
        and returning a dict with an ``'image'`` tensor. When falsy, items
        are returned without any augmentation.
    erasing : callable, optional
        Applied to the augmented tensor after ``transform``. When ``None``
        the module-level ``random_erasing`` object is used.
    """

    def __init__(self, X, y, weights, transform=None, erasing=None):
        self.X = X
        self.y = y
        self.weights = weights
        self.transform = transform
        self.erasing = erasing
        self.pil = transforms.ToPILImage()

    def __len__(self):
        """Number of items (size of the first axis of ``X``)."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return ``(sample, y[idx], weights[idx])``, augmenting the sample if a transform is set."""
        sample = self.X[idx, ...]

        if self.transform:
            # min-max transformation to [0, 1]; a constant sample has a zero
            # range, which would otherwise turn the whole sample into NaN.
            this_min = sample.min()
            this_max = sample.max()
            value_range = this_max - this_min
            if value_range == 0:
                value_range = torch.ones_like(value_range)
            sample = (sample - this_min) / value_range

            # randomly cycle the file: rotate along axis 1 by a random offset
            i = np.random.randint(sample.shape[1])
            sample = torch.cat([
                sample[:, i:, :],
                sample[:, :i, :]],
                dim=1)

            # apply albumentations transforms (they operate on numpy images)
            sample = np.array(self.pil(sample))
            sample = self.transform(image=sample)
            sample = sample['image']
            # restore the leading channel axis and swap the last two axes back
            sample = sample[None, :, :].permute(0, 2, 1)

            # apply random erasing
            eraser = self.erasing if self.erasing is not None else random_erasing
            sample = eraser(sample.clone().detach())

            # revert min-max transformation
            sample = (sample * value_range) + this_min

        return sample, self.y[idx, ...], self.weights[idx, ...]

In [3]:
# Load the pickled metadata bundle and derive the label bookkeeping from it.
# NOTE(review): pickle.load executes arbitrary code; only use trusted files.
with open('../../data/metadata.pkl', 'rb') as metadata_file:
    metadata = pickle.load(metadata_file)

taxonomy = metadata['taxonomy_df']
unknown_fine = taxonomy.loc[taxonomy.fine_id == 'X', ['fine', 'coarse']]
known_fine = taxonomy.loc[taxonomy.fine_id != 'X', ['fine', 'coarse']]

# For every fine label whose fine_id is 'X', collect the other fine labels
# that share its coarse class.
fine_pairs = unknown_fine.merge(known_fine, on='coarse', how='inner').drop(columns='coarse')
unknown_to_known = fine_pairs.groupby('fine_x')['fine_y'].apply(list).to_dict()
known_labels = taxonomy.loc[taxonomy.fine_id != 'X', 'fine'].tolist()

# Coarse and fine label columns side by side, one table per split.
train_df = pd.concat([metadata['coarse_train'], metadata['fine_train']], axis=1, sort=True)
valid_df = pd.concat([metadata['coarse_test'], metadata['fine_test']], axis=1, sort=True)

In [4]:
# Manual correction: any annotation row whose label values sum to 37
# (presumably every label column set at once -- TODO confirm) is zeroed out.
train_df.loc[train_df.sum(axis=1).eq(37), :] = 0
valid_df.loc[valid_df.sum(axis=1).eq(37), :] = 0

In [5]:
# Turn both label tables into (inputs, targets, masks) arrays.
train_X, train_y, train_y_mask = prepare_data(
    df=train_df, unknown_to_known=unknown_to_known)
valid_X, valid_y, valid_y_mask = prepare_data(
    df=valid_df, unknown_to_known=unknown_to_known)

In [6]:
# Channel-wise standardisation: statistics are computed per position of the
# last axis (128 entries) on the training inputs only, then applied to both splits.
channel_means = np.mean(train_X.reshape(-1, 128), axis=0)[np.newaxis, np.newaxis, np.newaxis, :]
channel_stds = np.std(train_X.reshape(-1, 128), axis=0)[np.newaxis, np.newaxis, np.newaxis, :]
train_X = (train_X - channel_means) / channel_stds
valid_X = (valid_X - channel_means) / channel_stds

In [7]:
# Data augmentation pipeline: a small random shift/scale/rotation, a grid
# distortion, and finally conversion of the image to a tensor.
augmentation_steps = [
    ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0.5),
    GridDistortion(),
    ToTensor(),
]
albumentations_transform = Compose(augmentation_steps)

In [8]:
# Wrap the numpy arrays as tensors; only the training set is augmented.
train_dataset = AudioDataset(
    *(torch.Tensor(arr) for arr in (train_X, train_y, train_y_mask)),
    transform=albumentations_transform)
valid_dataset = AudioDataset(
    *(torch.Tensor(arr) for arr in (valid_X, valid_y, valid_y_mask)),
    transform=None)

# Two independently shuffled loaders over the same training set are built
# here; the validation loader keeps a fixed order.
val_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False)
train_loader_1 = DataLoader(train_dataset, batch_size=64, shuffle=True)
train_loader_2 = DataLoader(train_dataset, batch_size=64, shuffle=True)

In [9]:
# Define the device to be used: the first GPU when CUDA is usable, otherwise
# the CPU (a hard-coded `cuda = True` fails on machines without CUDA as soon
# as anything is moved to the device).
cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if cuda else 'cpu')
print('Device: ', device)

Device:  cuda:0


In [10]:
def weight_reset(layer):
    """Call ``layer.reset_parameters()`` when the layer offers it; otherwise do nothing."""
    missing = object()
    reset = getattr(layer, 'reset_parameters', missing)
    if reset is missing:
        return
    reset()

In [11]:
# Default for the notebook parameter `until_x`: the highest index (inclusive)
# of the MobileNetV2 feature blocks that Task5Model re-initialises.
# NOTE(review): looks like a papermill-style default that the following
# "Parameters" cell overrides -- confirm.
until_x = None

In [12]:
# Parameters
# Overrides the default above: Task5Model re-initialises feature blocks 0..3,
# and the checkpoint file name carries this value.
until_x = 3


In [13]:
class Task5Model(nn.Module):
    """MobileNetV2 feature extractor with a 1-to-3 channel adapter and a small classifier head.

    Parameters
    ----------
    num_classes : int
        Number of output logits.
    reset_until : int, optional
        Highest index (inclusive) of the children of ``mv2.features`` whose
        parameters are re-initialised via ``weight_reset``. When ``None`` the
        notebook-level ``until_x`` is used; if that is ``None`` or undefined
        as well, all pretrained weights are kept.
    """

    def __init__(self, num_classes, reset_until=None):

        super().__init__()

        # Map the single input channel to the 3 channels MobileNetV2 expects,
        # using 1x1 convolutions only.
        self.bw2col = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 10, 1, padding=0), nn.ReLU(),
            nn.Conv2d(10, 3, 1, padding=0), nn.ReLU())

        self.mv2 = torchvision.models.mobilenet_v2(pretrained=True)

        self.final = nn.Sequential(
            nn.Linear(1280, 512), nn.ReLU(), nn.BatchNorm1d(512),
            nn.Linear(512, num_classes))

        # Fall back to the notebook parameter so that existing
        # `Task5Model(n)` calls behave as before.
        if reset_until is None:
            reset_until = globals().get('until_x')

        # Reset until ith layer of mv2; skipped when no index is configured
        # (comparing an int with None would raise a TypeError).
        if reset_until is not None:
            for i, block in enumerate(self.mv2.features.children()):
                if i <= reset_until:
                    block.apply(weight_reset)

    def forward(self, x):
        """Return the logits computed from ``x`` by adapter, features, global max and head."""
        x = self.bw2col(x)
        x = self.mv2.features(x)
        # Global max pooling over the last two dimensions.
        x = x.max(dim=-1)[0].max(dim=-1)[0]
        x = self.final(x)
        return x

In [14]:
# Build the network with 31 output logits and move it to the selected device.
model = Task5Model(num_classes=31)
model = model.to(device)

In [15]:
# Optimiser: Adam with the AMSGrad option over every model parameter.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3, amsgrad=True)
# Scheduler: driven by the value handed to `scheduler.step(...)`, patience 5.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, patience=5, verbose=True)
# Element-wise loss (no reduction) so that each entry can be masked individually.
criterion = nn.BCEWithLogitsLoss(reduction='none')

In [16]:
epochs = 100
early_stop_patience = 25  # epochs without a new best validation loss before stopping
mixup_alpha = 1           # both parameters of the Beta distribution used for mixup
train_loss_hist = []
valid_loss_hist = []
lowest_val_loss = np.inf
epochs_without_new_lowest = 0


def masked_loss(criterion, outputs, labels, masks):
    """Masked loss averaged over the mask total.

    `criterion` must return element-wise losses. `labels` and `masks` carry
    one slice per annotation along their last axis; every slice is scored
    against the same `outputs`. The summed masked loss is divided by
    `masks.sum()`.
    """
    per_annotation_sums = [
        (criterion(outputs, labels[:, :, k]) * masks[:, :, k]).sum()
        for k in range(labels.shape[2])]
    return sum(per_annotation_sums) / masks.sum()


for i in range(epochs):
    print('Epoch: ', i)

    # ---- training pass -------------------------------------------------
    model.train()
    this_epoch_train_loss = 0
    for i1, i2 in zip(train_loader_1, train_loader_2):

        # mixup: blend two independently shuffled batches with one random
        # coefficient per sample; targets and masks are blended identically.
        mixup_vals = np.random.beta(mixup_alpha, mixup_alpha, i1[0].shape[0])
        lam = torch.Tensor(mixup_vals)

        lam_inputs = lam.view(-1, 1, 1, 1)
        inputs = (lam_inputs * i1[0]) + ((1 - lam_inputs) * i2[0])

        lam_targets = lam.view(-1, 1, 1)
        labels = (lam_targets * i1[1]) + ((1 - lam_targets) * i2[1])
        masks = (lam_targets * i1[2]) + ((1 - lam_targets) * i2[2])

        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = masked_loss(criterion, outputs, labels, masks)
        loss.backward()
        optimizer.step()
        this_epoch_train_loss += loss.item()

    # ---- validation pass -----------------------------------------------
    model.eval()
    this_epoch_valid_loss = 0
    with torch.no_grad():
        for inputs, labels, masks in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            masks = masks.to(device)
            outputs = model(inputs)
            loss = masked_loss(criterion, outputs, labels, masks)
            this_epoch_valid_loss += loss.item()

    # Epoch losses are the mean of the per-batch losses.
    this_epoch_train_loss /= len(train_loader_1)
    this_epoch_valid_loss /= len(val_loader)

    train_loss_hist.append(this_epoch_train_loss)
    valid_loss_hist.append(this_epoch_valid_loss)

    # Report before the early-stopping check so the last epoch is logged too.
    print(this_epoch_train_loss, this_epoch_valid_loss)

    # Keep the weights of the best epoch so far.
    if this_epoch_valid_loss < lowest_val_loss:
        lowest_val_loss = this_epoch_valid_loss
        torch.save(model.state_dict(), './model_system1_until_{}'.format(until_x))
        epochs_without_new_lowest = 0
    else:
        epochs_without_new_lowest += 1

    if epochs_without_new_lowest >= early_stop_patience:
        break

    scheduler.step(this_epoch_valid_loss)

Epoch:  0


0.6320178009368278 0.47642860242298674
Epoch:  1


0.3288944460250236 0.19187440829617636
Epoch:  2


0.18467195613964185 0.15578473465783255
Epoch:  3


0.17272557519577644 0.14856116686548507
Epoch:  4


0.16706682661095182 0.1464797705411911
Epoch:  5


0.164354980797381 0.1376021472471101
Epoch:  6


0.16249406176644401 0.1383364275097847
Epoch:  7


0.16225968583210096 0.138007681284632
Epoch:  8


0.16132832217860865 0.13840030772345407
Epoch:  9


0.15827872302081133 0.13810549144233977
Epoch:  10


0.15849686997967796 0.14309721333639963
Epoch:  11


0.15879366486459165 0.1341507307120732
Epoch:  12


0.15636040230055112 0.13731919441904342
Epoch:  13


0.15718174625087428 0.13523863681725093
Epoch:  14


0.15632006284352895 0.13446136457579477
Epoch:  15


0.1553875824084153 0.1372474506497383
Epoch:  16


0.1548872283987097 0.13176527619361877
Epoch:  17


0.15414486301911845 0.15102614888123103
Epoch:  18


0.15429819234319636 0.12692862536226
Epoch:  19


0.15391888449320923 0.1275988934295518
Epoch:  20


0.15362481772899628 0.13158217923981802
Epoch:  21


0.15453709581413785 0.1305398525936263
Epoch:  22


0.15296299755573273 0.12738699040242604
Epoch:  23


0.1518205533156524 0.13144434882061823
Epoch:  24


0.15212505773918047 0.1333294894014086
Epoch    24: reducing learning rate of group 0 to 1.0000e-04.
Epoch:  25


0.15029169981544083 0.1236930319241115
Epoch:  26


0.1501656698214041 0.12375576687710625
Epoch:  27


0.1497905032860266 0.12346844800880977
Epoch:  28


0.14833105214544245 0.12284959533384868
Epoch:  29


0.14860025209349556 0.12361730315855571
Epoch:  30


0.14784363999560074 0.12444242089986801
Epoch:  31


0.14840115365144368 0.12373245933226176
Epoch:  32


0.1483164071231275 0.12328929028340749
Epoch:  33


0.1472659356690742 0.12429360193865639
Epoch:  34


0.14866131906573837 0.12349429726600647
Epoch    34: reducing learning rate of group 0 to 1.0000e-05.
Epoch:  35


0.14815096436320124 0.12284592964819499
Epoch:  36


0.14789162415104942 0.12265158763953618
Epoch:  37


0.14660498701237343 0.12249794176646642
Epoch:  38


0.14791555501319267 0.12240790256432124
Epoch:  39


0.14676501098516825 0.12268669477530889
Epoch:  40


0.14741318451391683 0.12258339673280716
Epoch:  41


0.14655962586402893 0.12284997531345912
Epoch:  42


0.14636542305753036 0.12246719534908022
Epoch:  43


0.14860394113772624 0.1224904911858695
Epoch:  44


0.14788691256497358 0.12309081213814872
Epoch    44: reducing learning rate of group 0 to 1.0000e-06.
Epoch:  45


0.1466731682822511 0.12247677040951592
Epoch:  46


0.14787185715662465 0.12285023501941136
Epoch:  47


0.14743079446457527 0.12316552230290004
Epoch:  48


0.14723262593552872 0.12261195693697248
Epoch:  49


0.147016148712184 0.12238882588488716
Epoch:  50


0.14727688480067896 0.12299563629286629
Epoch:  51


0.14636845383289698 0.1225217655301094
Epoch:  52


0.1477820192639892 0.12251636385917664
Epoch:  53


0.1471952820146406 0.12231781333684921
Epoch:  54


0.14780791949581457 0.1227851254599435
Epoch:  55


0.14672966583355054 0.12263858211891991
Epoch:  56


0.14760015461895917 0.12274202810866493
Epoch:  57


0.14665588494893667 0.12272496095725469
Epoch:  58


0.14773545595439705 0.12251866395984377
Epoch:  59


0.14682531598451976 0.12234705580132348
Epoch    59: reducing learning rate of group 0 to 1.0000e-07.
Epoch:  60


0.1482076644897461 0.12244765354054314
Epoch:  61


0.1478016501342928 0.12249135971069336
Epoch:  62


0.14871715049485904 0.12277393575225558
Epoch:  63


0.14740657081475128 0.12301383912563324
Epoch:  64


0.14587662993250666 0.12260176773582186
Epoch:  65


0.14739675578233358 0.12242529967001506
Epoch    65: reducing learning rate of group 0 to 1.0000e-08.
Epoch:  66


0.14677987147021937 0.12230951019695827
Epoch:  67


0.14797390876589595 0.12236676258700234
Epoch:  68


0.14714190988121806 0.12241261239562716
Epoch:  69


0.14715740004101316 0.12288590627057212
Epoch:  70


0.14762911562984055 0.12262441962957382
Epoch:  71


0.14599247196236173 0.1229138235960688
Epoch:  72


0.1476416024001869 0.12266013771295547
Epoch:  73


0.14913257312130285 0.12237817687647683
Epoch:  74


0.14823726302868612 0.12236382705824715
Epoch:  75


0.1492680363558434 0.12295239631618772
Epoch:  76


0.14768181096863103 0.1226827757699149
Epoch:  77


0.1463212092985978 0.1226905956864357
Epoch:  78


0.14780477816994125 0.12291485284055982
Epoch:  79


0.1473932576340598 0.12271634915045329
Epoch:  80


0.14761009973448677 0.1223150874887194
Epoch:  81


0.14877764840383786 0.12270997677530561
Epoch:  82


0.1474463379866368 0.12280745059251785
Epoch:  83


0.1471410156101794 0.12291407691580909
Epoch:  84


0.14711890953618126 0.12234485255820411
Epoch:  85


0.14611819305935422 0.12237720617226192
Epoch:  86


0.14659384899848216 0.12279485698257174
Epoch:  87


0.14746245258563273 0.12254691549709865
Epoch:  88


0.1476825918700244 0.12290334169353757
Epoch:  89


0.14849957622386314 0.12268817850521632
Epoch:  90


0.1473884852351369 0.12309291958808899
Epoch:  91
