In [1]:
import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import sys
sys.path.append("..")
from RandomErasing import RandomErasing

import torchvision.models
from torchvision import transforms

from albumentations import Compose, ShiftScaleRotate, GridDistortion
from albumentations.pytorch import ToTensor

In [2]:
def prepare_data(df, unknown_to_known):
    df = df.reset_index()
    df['slno'] = df.assign(slno=1).groupby('audio_filename')['slno'].cumsum()
    df.set_index(['audio_filename', 'slno'], inplace=True)

    df_unknown = df.copy().loc[:, list(unknown_to_known.keys())]
    df.drop(columns=list(unknown_to_known.keys()), inplace=True)

    y_mask = df.copy()
    y_mask.loc[:, :] = 1
    for unknown, known in unknown_to_known.items():
        y_mask.loc[
            df_unknown[unknown] > 0.5,
            known
        ] = 0

    df = df.swaplevel(i=1, j=0, axis=0).sort_index()

    y_mask = y_mask.swaplevel(i=1, j=0, axis=0).sort_index()

    y = np.concatenate([
        df.loc[[1], :].values[..., np.newaxis],
        df.loc[[2], :].values[..., np.newaxis],
        df.loc[[3], :].values[..., np.newaxis]
    ], axis=2)

    y_mask = np.concatenate([
        y_mask.loc[[1], :].values[..., np.newaxis],
        y_mask.loc[[2], :].values[..., np.newaxis],
        y_mask.loc[[3], :].values[..., np.newaxis]
    ], axis=2)

    X = np.concatenate([
        np.expand_dims(np.load('../../data/logmelspec/{}.npy'.format(x)).T[:635, :], axis=0)
        for x in df.loc[[1], :].reset_index(1).audio_filename.tolist()])
    X = np.expand_dims(X, axis=1)

    return X, y, y_mask


random_erasing = RandomErasing()


class AudioDataset(Dataset):

    def __init__(self, X, y, weights, transform=None):
        self.X = X
        self.y = y
        self.weights = weights
        self.transform = transform
        self.pil = transforms.ToPILImage()

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        sample = self.X[idx, ...]

        if self.transform:
            # min-max transformation
            this_min = sample.min()
            this_max = sample.max()
            sample = (sample - this_min) / (this_max - this_min)

            # randomly cycle the file
            i = np.random.randint(sample.shape[1])
            sample = torch.cat([
                sample[:, i:, :],
                sample[:, :i, :]],
                dim=1)

            # apply albumentations transforms
            sample = np.array(self.pil(sample))
            sample = self.transform(image=sample)
            sample = sample['image']
            sample = sample[None, :, :].permute(0, 2, 1)

            # apply random erasing
            sample = random_erasing(sample.clone().detach())

            # revert min-max transformation
            sample = (sample * (this_max - this_min)) + this_min

        return sample, self.y[idx, ...], self.weights[idx, ...]

In [3]:
# Load and prepare data
with open('../../data/metadata.pkl', 'rb') as f:
    metadata = pickle.load(f)

unknown_to_known = (
    pd.merge(metadata['taxonomy_df'].loc[lambda x: x.fine_id == 'X', ['fine', 'coarse']],
             metadata['taxonomy_df'].loc[lambda x: x.fine_id != 'X', ['fine', 'coarse']],
             on='coarse', how='inner')
    .drop(columns='coarse')
    .groupby('fine_x')['fine_y']
    .apply(lambda x: list(x)).to_dict())
known_labels = metadata['taxonomy_df'].loc[lambda x: x.fine_id != 'X'].fine.tolist()

train_df = pd.concat([metadata['coarse_train'], metadata['fine_train']], axis=1, sort=True)
valid_df = pd.concat([metadata['coarse_test'], metadata['fine_test']], axis=1, sort=True)

In [4]:
# manual correction for one data point
train_df.loc[(train_df.sum(axis=1) == 37).copy(), :] = 0
valid_df.loc[(valid_df.sum(axis=1) == 37).copy(), :] = 0

In [5]:
train_X, train_y, train_y_mask = prepare_data(train_df, unknown_to_known)
valid_X, valid_y, valid_y_mask = prepare_data(valid_df, unknown_to_known)

In [6]:
# Channel wise normalization
channel_means = train_X.reshape(-1, 128).mean(axis=0).reshape(1, 1, 1, -1)
channel_stds = train_X.reshape(-1, 128).std(axis=0).reshape(1, 1, 1, -1)
train_X = (train_X - channel_means) / channel_stds
valid_X = (valid_X - channel_means) / channel_stds

In [7]:
# Define the data augmentation transformations
albumentations_transform = Compose([
    ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0.5),
    GridDistortion(),
    ToTensor()
])

In [8]:
# Create the datasets and the dataloaders
train_dataset = AudioDataset(torch.Tensor(train_X),
                             torch.Tensor(train_y),
                             torch.Tensor(train_y_mask),
                             albumentations_transform)
valid_dataset = AudioDataset(torch.Tensor(valid_X),
                             torch.Tensor(valid_y),
                             torch.Tensor(valid_y_mask),
                             None)

val_loader = DataLoader(valid_dataset, 64, shuffle=False)
train_loader_1 = DataLoader(train_dataset, 64, shuffle=True)
train_loader_2 = DataLoader(train_dataset, 64, shuffle=True)

In [9]:
# Define the device to be used
cuda = True
device = torch.device('cuda:0' if cuda else 'cpu')
print('Device: ', device)

Device:  cuda:0


In [10]:
def weight_reset(layer):
    if hasattr(layer, 'reset_parameters'):
        layer.reset_parameters()

In [11]:
until_x = None

In [12]:
# Parameters
until_x = 11


In [13]:
class Task5Model(nn.Module):

    def __init__(self, num_classes):

        super().__init__()

        self.bw2col = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 10, 1, padding=0), nn.ReLU(),
            nn.Conv2d(10, 3, 1, padding=0), nn.ReLU())

        self.mv2 = torchvision.models.mobilenet_v2(pretrained=True)

        self.final = nn.Sequential(
            nn.Linear(1280, 512), nn.ReLU(), nn.BatchNorm1d(512),
            nn.Linear(512, num_classes))
        
        # Reset until ith layer of mv2
        for i, x in enumerate(self.mv2.features.children()):
            if i <= until_x:
                x.apply(weight_reset)

    def forward(self, x):
        x = self.bw2col(x)
        x = self.mv2.features(x)
        x = x.max(dim=-1)[0].max(dim=-1)[0]
        x = self.final(x)
        return x

In [14]:
# Instantiate the model
model = Task5Model(31).to(device)

In [15]:
# Define optimizer, scheduler and loss criteria
optimizer = optim.Adam(model.parameters(), lr=0.001, amsgrad=True)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, verbose=True)
criterion = nn.BCEWithLogitsLoss(reduction='none')

In [16]:
epochs = 100
train_loss_hist = []
valid_loss_hist = []
lowest_val_loss = np.inf
epochs_without_new_lowest = 0

for i in range(epochs):
    print('Epoch: ', i)

    this_epoch_train_loss = 0
    for i1, i2 in zip(train_loader_1, train_loader_2):

        # mixup the inputs ---------
        alpha = 1
        mixup_vals = np.random.beta(alpha, alpha, i1[0].shape[0])

        lam = torch.Tensor(mixup_vals.reshape(mixup_vals.shape[0], 1, 1, 1))
        inputs = (lam * i1[0]) + ((1 - lam) * i2[0])

        lam = torch.Tensor(mixup_vals.reshape(mixup_vals.shape[0], 1, 1))
        labels = (lam * i1[1]) + ((1 - lam) * i2[1])
        masks = (lam * i1[2]) + ((1 - lam) * i2[2])
        # mixup ends ----------

        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        with torch.set_grad_enabled(True):
            model = model.train()
            outputs = model(inputs)
            # calculate loss for each set of annotations
            loss_0 = criterion(outputs, labels[:, :, 0]) * masks[:, :, 0]
            loss_1 = criterion(outputs, labels[:, :, 1]) * masks[:, :, 1]
            loss_2 = criterion(outputs, labels[:, :, 2]) * masks[:, :, 2]
            loss = (loss_0.sum() + loss_1.sum() + loss_2.sum()) / masks.sum()
            loss.backward()
            optimizer.step()
            this_epoch_train_loss += loss.detach().cpu().numpy()

    this_epoch_valid_loss = 0
    for inputs, labels, masks in val_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)
        optimizer.zero_grad()
        with torch.set_grad_enabled(False):
            model = model.eval()
            outputs = model(inputs)
            loss_0 = criterion(outputs, labels[:, :, 0]) * masks[:, :, 0]
            loss_1 = criterion(outputs, labels[:, :, 1]) * masks[:, :, 1]
            loss_2 = criterion(outputs, labels[:, :, 2]) * masks[:, :, 2]
            loss = (loss_0.sum() + loss_1.sum() + loss_2.sum()) / masks.sum()
            this_epoch_valid_loss += loss.detach().cpu().numpy()

    this_epoch_train_loss /= len(train_loader_1)
    this_epoch_valid_loss /= len(val_loader)

    train_loss_hist.append(this_epoch_train_loss)
    valid_loss_hist.append(this_epoch_valid_loss)

    if this_epoch_valid_loss < lowest_val_loss:
        lowest_val_loss = this_epoch_valid_loss
        torch.save(model.state_dict(), './model_system1_until_{}'.format(until_x))
        epochs_without_new_lowest = 0
    else:
        epochs_without_new_lowest += 1

    if epochs_without_new_lowest >= 25:
        break

    print(this_epoch_train_loss, this_epoch_valid_loss)

    scheduler.step(this_epoch_valid_loss)

Epoch:  0


0.632578427727158 0.5134319067001343
Epoch:  1


0.3265053412398776 0.20186608604022435
Epoch:  2


0.19167866014145515 0.18420386740139552
Epoch:  3


0.18127720903705907 0.1746014803647995
Epoch:  4


0.17996078125528386 0.1692842479263033
Epoch:  5


0.1786469643985903 0.16711391934326716
Epoch:  6


0.17685176875140216 0.16342199061598098
Epoch:  7


0.17478624831985784 0.16542219477040426
Epoch:  8


0.17576525662396406 0.16796679369040898
Epoch:  9


0.1750767758569202 0.15855515216078078
Epoch:  10


0.17460839973913655 0.1632007701056344
Epoch:  11


0.17222400210999153 0.16089353816849844
Epoch:  12


0.17122325623357618 0.1555887120110648
Epoch:  13


0.1702898905889408 0.15877550414630345
Epoch:  14


0.16961370247441368 0.15716042901788438
Epoch:  15


0.16842224227415548 0.15336031786033086
Epoch:  16


0.16672327470135045 0.1549155797277178
Epoch:  17


0.16729878573804288 0.1529561302491597
Epoch:  18


0.16741013245002642 0.16220851455416
Epoch:  19


0.1661634231741364 0.1492470247404916
Epoch:  20


0.16654845588916056 0.1559013170855386
Epoch:  21


0.16466163984827092 0.1472198771578925
Epoch:  22


0.16421468193466598 0.14532046445778438
Epoch:  23


0.1656221199680019 0.14696895011833735
Epoch:  24


0.16375867055880056 0.14231559421334947
Epoch:  25


0.16199111012188164 0.15163537859916687
Epoch:  26


0.16296546523635452 0.14575286954641342
Epoch:  27


0.16208578404542562 0.1418973228761128
Epoch:  28


0.16341456045975555 0.14881627155201776
Epoch:  29


0.16401856896039602 0.14231730784688676
Epoch:  30


0.16146425620929614 0.14157386549881526
Epoch:  31


0.16251981137572108 0.1450320439679282
Epoch:  32


0.1604765814703864 0.13777297948087966
Epoch:  33


0.16019932764607506 0.14305298775434494
Epoch:  34


0.15961216108219042 0.1465893451656614
Epoch:  35


0.15920941813572034 0.1371255080614771
Epoch:  36


0.15852929893377665 0.13501690328121185
Epoch:  37


0.1583502429562646 0.13499782873051508
Epoch:  38


0.1588420457131154 0.1366928858416421
Epoch:  39


0.15783221335024447 0.13738214863198145
Epoch:  40


0.15898802554285205 0.14375687284129007
Epoch:  41


0.15721023123006564 0.1370434643966811
Epoch:  42


0.15750448647383097 0.13454977103642055
Epoch:  43


0.1566130341710271 0.13541418739727565
Epoch:  44


0.1566978217782201 0.13280046837670462
Epoch:  45


0.1567739095236804 0.13044081734759466
Epoch:  46


0.1567383106495883 0.1331470725791795
Epoch:  47


0.15677420029769074 0.1317978235227721
Epoch:  48


0.15540452744509722 0.13310126747403825
Epoch:  49


0.1552306079381221 0.12899198595966613
Epoch:  50


0.15345672981159106 0.13412117106573923
Epoch:  51


0.1541334092617035 0.1312471181154251
Epoch:  52


0.15459243550493912 0.13314029787267959
Epoch:  53


0.1544432221232234 0.13257361203432083
Epoch:  54


0.15463849660512563 0.13283518063170568
Epoch:  55


0.15393284768671603 0.1314729984317507
Epoch    55: reducing learning rate of group 0 to 1.0000e-04.
Epoch:  56


0.15276827965233777 0.12808399966784886
Epoch:  57


0.1524096433375333 0.126903130539826
Epoch:  58


0.15274857105435552 0.1266293397971562
Epoch:  59


0.15164606998095642 0.12670188503605978
Epoch:  60


0.15108204330946948 0.1271546802350453
Epoch:  61


0.1507462025494189 0.12717059361083166
Epoch:  62


0.15086313759958422 0.12611609165157592
Epoch:  63


0.1513377975773167 0.12712675545896804
Epoch:  64


0.1510052809844146 0.12589771619864873
Epoch:  65


0.1511259638779872 0.12635833557162965
Epoch:  66


0.1514574007408039 0.1262084139244897
Epoch:  67


0.14996894549679113 0.12622291488306864
Epoch:  68


0.1504355056865795 0.1274733149579593
Epoch:  69


0.1504669982839275 0.12696680745908193
Epoch:  70


0.15003661207250646 0.12724962936980383
Epoch    70: reducing learning rate of group 0 to 1.0000e-05.
Epoch:  71


0.15136467282836502 0.12683855848652975
Epoch:  72


0.14876031875610352 0.12638166546821594
Epoch:  73


0.1508565287332277 0.1285245716571808
Epoch:  74


0.15013027674443014 0.12629663731370652
Epoch:  75


0.14923705845265775 0.12654539197683334
Epoch:  76


0.15004347264766693 0.12624131355966842
Epoch    76: reducing learning rate of group 0 to 1.0000e-06.
Epoch:  77


0.1495629057690904 0.12641726647104537
Epoch:  78


0.15037511208572904 0.12655358548675263
Epoch:  79


0.15013712604303617 0.1263121941259929
Epoch:  80


0.14962522967441663 0.12610107766730444
Epoch:  81


0.15012379635024714 0.12620866298675537
Epoch:  82


0.14932905580546404 0.1265116802283696
Epoch    82: reducing learning rate of group 0 to 1.0000e-07.
Epoch:  83


0.14994903392082937 0.12633730896881648
Epoch:  84


0.1493065945200018 0.12626836129597255
Epoch:  85


0.14933335700550596 0.12652731899704253
Epoch:  86


0.1488335921957686 0.12659495111022676
Epoch:  87


0.14876942296285886 0.12617199548653193
Epoch:  88


0.1499305468153309 0.1267425673348563
Epoch    88: reducing learning rate of group 0 to 1.0000e-08.
Epoch:  89
