In [1]:
import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import sys
sys.path.append("..")
from RandomErasing import RandomErasing

import torchvision.models
from torchvision import transforms

from albumentations import Compose, ShiftScaleRotate, GridDistortion
from albumentations.pytorch import ToTensor

In [2]:
def prepare_data(df, unknown_to_known, logmelspec_dir='../../data/logmelspec',
                 n_frames=635, n_annotators=3):
    """Turn an annotation table into model inputs, targets and loss masks.

    Parameters
    ----------
    df : pandas.DataFrame
        Label table whose index is named ``audio_filename``; a clip appears
        once per annotation (rows of a clip are numbered 1, 2, ... in their
        order of appearance). Must contain every column named in
        ``unknown_to_known`` (keys and values). The caller's frame is not
        modified.
    unknown_to_known : dict
        Maps each "unknown" label column to the list of known label columns
        that must be ignored when that unknown label is set (> 0.5).
    logmelspec_dir : str, optional
        Directory holding one ``<audio_filename>.npy`` array per clip.
    n_frames : int, optional
        Number of leading rows kept from each transposed array.
    n_annotators : int, optional
        Number of annotations stacked per clip. Assumes every clip has at
        least this many rows; stacking fails otherwise.

    Returns
    -------
    X : numpy.ndarray, shape (n_clips, 1, n_frames, n_features)
    y : numpy.ndarray, shape (n_clips, n_labels, n_annotators)
        Label values with the "unknown" columns removed.
    y_mask : numpy.ndarray, same shape as ``y``
        1 where the label counts towards the loss, 0 where it is masked.
    """
    unknown_cols = list(unknown_to_known.keys())

    # Number the annotations of each clip so (audio_filename, slno) is unique.
    df = df.reset_index()
    df['slno'] = df.assign(slno=1).groupby('audio_filename')['slno'].cumsum()
    df = df.set_index(['audio_filename', 'slno'])

    # Split off the "unknown" columns; only the known labels become targets.
    df_unknown = df.loc[:, unknown_cols].copy()
    df = df.drop(columns=unknown_cols)

    # Start with everything unmasked, then zero the known labels of any
    # annotation that flagged the corresponding unknown label.
    y_mask = df.copy()
    y_mask.loc[:, :] = 1
    for unknown, known in unknown_to_known.items():
        y_mask.loc[df_unknown[unknown] > 0.5, known] = 0

    # Index by (slno, audio_filename) so one annotation slice is one .loc call
    # and every slice lists the clips in the same sorted order.
    df = df.swaplevel(i=1, j=0, axis=0).sort_index()
    y_mask = y_mask.swaplevel(i=1, j=0, axis=0).sort_index()

    annotation_ids = range(1, n_annotators + 1)
    y = np.stack([df.loc[[k], :].values for k in annotation_ids], axis=2)
    y_mask = np.stack([y_mask.loc[[k], :].values for k in annotation_ids], axis=2)

    # Load the stored array of each clip, in the same order as the rows of y.
    filenames = df.loc[[1], :].reset_index(1).audio_filename.tolist()
    X = np.stack([
        np.load('{}/{}.npy'.format(logmelspec_dir, name)).T[:n_frames, :]
        for name in filenames])
    X = np.expand_dims(X, axis=1)  # add the channel axis

    return X, y, y_mask


# Module-level RandomErasing instance; AudioDataset.__getitem__ applies it to
# every sample that goes through the augmentation branch.
random_erasing = RandomErasing()


class AudioDataset(Dataset):
    """Dataset of (input, target, weight) triples with optional augmentation.

    Parameters
    ----------
    X : torch.Tensor
        Inputs, indexed along the first dimension; each item is treated as a
        3-D tensor (it is cycled along dim 1).
    y : torch.Tensor
        Targets, indexed along the first dimension.
    weights : torch.Tensor
        Per-target weights / masks, indexed along the first dimension.
    transform : callable or None
        Called as ``transform(image=ndarray)`` and expected to return a
        mapping with an ``'image'`` entry. When falsy, samples are returned
        untouched.
    """

    def __init__(self, X, y, weights, transform=None):
        self.X = X
        self.y = y
        self.weights = weights
        self.transform = transform
        self.pil = transforms.ToPILImage()

    def __len__(self):
        """Return the number of samples (size of X along dim 0)."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return ``(sample, y[idx], weights[idx])``, augmenting if enabled."""
        sample = self.X[idx, ...]

        if self.transform:
            # min-max transformation to [0, 1]
            this_min = sample.min()
            this_max = sample.max()
            value_range = this_max - this_min
            if value_range == 0:
                # A constant sample would otherwise be divided by zero and
                # turn into NaNs; with a unit range it maps to all zeros and
                # is restored to its constant value by the revert step.
                value_range = 1
            sample = (sample - this_min) / value_range

            # randomly cycle the file: rotate along dim 1 from a random offset
            i = np.random.randint(sample.shape[1])
            sample = torch.cat([
                sample[:, i:, :],
                sample[:, :i, :]],
                dim=1)

            # apply albumentations transforms (they operate on a numpy image)
            sample = np.array(self.pil(sample))
            sample = self.transform(image=sample)
            sample = sample['image']
            # add a leading axis and swap the last two axes of the result
            # (presumably restoring the input's (1, dim1, dim2) layout -- confirm
            # against the transform in use)
            sample = sample[None, :, :].permute(0, 2, 1)

            # apply random erasing
            sample = random_erasing(sample.clone().detach())

            # revert min-max transformation
            sample = (sample * value_range) + this_min

        return sample, self.y[idx, ...], self.weights[idx, ...]

In [3]:
# Load and prepare data
# NOTE: pickle.load can execute arbitrary code -- only point this at a trusted
# metadata file.
with open('../../data/metadata.pkl', 'rb') as f:
    metadata = pickle.load(f)

# Taxonomy rows whose fine_id is 'X' are the "unknown" labels; all other rows
# are the known fine labels.
taxonomy_df = metadata['taxonomy_df']
label_columns = ['fine', 'coarse']
unknown_rows = taxonomy_df.loc[taxonomy_df.fine_id == 'X', label_columns]
known_rows = taxonomy_df.loc[taxonomy_df.fine_id != 'X', label_columns]

# Pair every unknown label with the known fine labels sharing its coarse
# category: {unknown fine label: [known fine labels, ...]}.
unknown_known_pairs = pd.merge(
    unknown_rows, known_rows, on='coarse', how='inner'
).drop(columns='coarse')
unknown_to_known = (
    unknown_known_pairs
    .groupby('fine_x')['fine_y']
    .apply(list)
    .to_dict())
known_labels = known_rows.fine.tolist()

# Put the coarse and fine label tables of each split side by side.
train_df = pd.concat([metadata['coarse_train'], metadata['fine_train']], axis=1, sort=True)
valid_df = pd.concat([metadata['coarse_test'], metadata['fine_test']], axis=1, sort=True)

In [4]:
# manual correction for one data point:
# zero out every row whose values sum to exactly 37
for split_df in (train_df, valid_df):
    flagged_rows = split_df.sum(axis=1) == 37
    split_df.loc[flagged_rows, :] = 0

In [5]:
# Build the inputs (X), targets (y) and loss masks (y_mask) of each split.
train_X, train_y, train_y_mask = prepare_data(train_df, unknown_to_known)
valid_X, valid_y, valid_y_mask = prepare_data(valid_df, unknown_to_known)

In [6]:
# Channel wise normalization
# Statistics are computed on the training split only, one value per position
# of the last axis, and then applied to both splits.
# NOTE: this cell overwrites train_X / valid_X, so it must be run exactly once
# per kernel session.
n_channels = train_X.shape[-1]
channel_means = train_X.reshape(-1, n_channels).mean(axis=0).reshape(1, 1, 1, -1)
channel_stds = train_X.reshape(-1, n_channels).std(axis=0).reshape(1, 1, 1, -1)
# A channel with zero variance would cause a division by zero; leave it centred only.
channel_stds[channel_stds == 0] = 1
train_X = (train_X - channel_means) / channel_stds
valid_X = (valid_X - channel_means) / channel_stds

In [7]:
# Define the data augmentation transformations
augmentation_steps = [
    # random shift / scale / rotation within the given limits
    ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0.5),
    # grid distortion with the library defaults
    GridDistortion(),
    # hand the result back as a tensor
    ToTensor(),
]
albumentations_transform = Compose(augmentation_steps)

In [8]:
# Create the datasets and the dataloaders
train_tensors = [torch.Tensor(array) for array in (train_X, train_y, train_y_mask)]
valid_tensors = [torch.Tensor(array) for array in (valid_X, valid_y, valid_y_mask)]

# Only the training set is augmented.
train_dataset = AudioDataset(*train_tensors, albumentations_transform)
valid_dataset = AudioDataset(*valid_tensors, None)

batch_size = 64
val_loader = DataLoader(valid_dataset, batch_size, shuffle=False)
# Two independently shuffled loaders over the same training set; the training
# loop zips them to obtain the two batches that are mixed together.
train_loader_1 = DataLoader(train_dataset, batch_size, shuffle=True)
train_loader_2 = DataLoader(train_dataset, batch_size, shuffle=True)

In [9]:
# Define the device to be used
# Use the first GPU when one is available and fall back to the CPU otherwise,
# so the notebook also runs on machines without CUDA.
cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if cuda else 'cpu')
print('Device: ', device)

Device:  cuda:0


In [10]:
def weight_reset(layer):
    """Re-initialise ``layer`` through its ``reset_parameters`` method.

    Objects that do not expose ``reset_parameters`` are left untouched, which
    makes the function suitable for ``Module.apply``.
    """
    if not hasattr(layer, 'reset_parameters'):
        return
    layer.reset_parameters()

In [11]:
# Placeholder default for the `until_x` parameter; the "Parameters" cell that
# follows assigns the value actually used.
until_x = None

In [12]:
# Parameters
# `until_x`: Task5Model re-initialises every child of `mv2.features` whose
# index is <= this value.
until_x = 12


In [13]:
class Task5Model(nn.Module):
    """MobileNetV2 feature extractor with a single-channel input adapter.

    Parameters
    ----------
    num_classes : int
        Number of output logits.
    reset_until : int or None, optional
        Children of ``mv2.features`` with index ``<= reset_until`` are
        re-initialised (their pretrained weights are discarded). When omitted,
        the notebook-level ``until_x`` parameter is used. If that is ``None``
        too, no layer is reset and all pretrained weights are kept.
    """

    def __init__(self, num_classes, reset_until=None):

        super().__init__()

        # Map the 1-channel input to the 3 channels MobileNetV2 expects,
        # using 1x1 convolutions.
        self.bw2col = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 10, 1, padding=0), nn.ReLU(),
            nn.Conv2d(10, 3, 1, padding=0), nn.ReLU())

        self.mv2 = torchvision.models.mobilenet_v2(pretrained=True)

        # Classification head on top of the 1280 pooled feature channels.
        self.final = nn.Sequential(
            nn.Linear(1280, 512), nn.ReLU(), nn.BatchNorm1d(512),
            nn.Linear(512, num_classes))

        # Fall back to the notebook parameter when no explicit value is given.
        if reset_until is None:
            reset_until = until_x

        # Reset until ith layer of mv2. A None threshold cannot be compared
        # with an int, so it is treated as "reset nothing".
        if reset_until is not None:
            for i, x in enumerate(self.mv2.features.children()):
                if i <= reset_until:
                    x.apply(weight_reset)

    def forward(self, x):
        """Return ``num_classes`` logits for each item of the batch ``x``."""
        x = self.bw2col(x)
        x = self.mv2.features(x)
        # Global max pooling: reduce the last two dimensions in turn.
        x = x.max(dim=-1)[0].max(dim=-1)[0]
        x = self.final(x)
        return x

In [14]:
# Instantiate the model with 31 output logits and move it to the selected device
model = Task5Model(31).to(device)

In [15]:
# Define optimizer, scheduler and loss criteria
# Adam with the AMSGrad variant over all model parameters.
optimizer = optim.Adam(model.parameters(), lr=0.001, amsgrad=True)
# ReduceLROnPlateau with patience=5; the training loop steps it with the
# epoch's validation loss.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, verbose=True)
# reduction='none' keeps the element-wise losses so the training loop can
# multiply them by the masks before averaging.
criterion = nn.BCEWithLogitsLoss(reduction='none')

In [16]:
epochs = 100
early_stopping_patience = 25  # stop after this many epochs without a new best validation loss
alpha = 1  # mixup coefficients are drawn from Beta(alpha, alpha)
train_loss_hist = []
valid_loss_hist = []
lowest_val_loss = np.inf
epochs_without_new_lowest = 0


def masked_bce_loss(outputs, labels, masks):
    """Masked mean of the element-wise criterion over all annotation sets.

    ``labels`` and ``masks`` carry one slice per set of annotations along
    their last dimension. The same ``outputs`` are scored against each slice,
    every element is weighted by its mask, and the total is divided by the sum
    of the masks.
    """
    total = sum(
        (criterion(outputs, labels[:, :, k]) * masks[:, :, k]).sum()
        for k in range(labels.shape[2]))
    return total / masks.sum()


for epoch in range(epochs):
    print('Epoch: ', epoch)

    # ---- training pass ----
    model.train()
    this_epoch_train_loss = 0
    for batch_a, batch_b in zip(train_loader_1, train_loader_2):

        # mixup the inputs ---------
        # One coefficient per sample, shared by inputs, labels and masks.
        mixup_vals = np.random.beta(alpha, alpha, batch_a[0].shape[0])

        lam = torch.Tensor(mixup_vals.reshape(mixup_vals.shape[0], 1, 1, 1))
        inputs = (lam * batch_a[0]) + ((1 - lam) * batch_b[0])

        lam = torch.Tensor(mixup_vals.reshape(mixup_vals.shape[0], 1, 1))
        labels = (lam * batch_a[1]) + ((1 - lam) * batch_b[1])
        masks = (lam * batch_a[2]) + ((1 - lam) * batch_b[2])
        # mixup ends ----------

        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        with torch.set_grad_enabled(True):
            outputs = model(inputs)
            loss = masked_bce_loss(outputs, labels, masks)
            loss.backward()
            optimizer.step()
        this_epoch_train_loss += loss.item()

    # ---- validation pass (no gradients, no parameter updates) ----
    model.eval()
    this_epoch_valid_loss = 0
    with torch.set_grad_enabled(False):
        for inputs, labels, masks in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            masks = masks.to(device)
            outputs = model(inputs)
            loss = masked_bce_loss(outputs, labels, masks)
            this_epoch_valid_loss += loss.item()

    # Average of the per-batch losses.
    this_epoch_train_loss /= len(train_loader_1)
    this_epoch_valid_loss /= len(val_loader)

    train_loss_hist.append(this_epoch_train_loss)
    valid_loss_hist.append(this_epoch_valid_loss)

    # Report before the early-stopping check so the last epoch is logged too.
    print(this_epoch_train_loss, this_epoch_valid_loss)

    # Checkpoint whenever the validation loss reaches a new minimum.
    if this_epoch_valid_loss < lowest_val_loss:
        lowest_val_loss = this_epoch_valid_loss
        torch.save(model.state_dict(), './model_system1_until_{}'.format(until_x))
        epochs_without_new_lowest = 0
    else:
        epochs_without_new_lowest += 1

    if epochs_without_new_lowest >= early_stopping_patience:
        break

    scheduler.step(this_epoch_valid_loss)

Epoch:  0


0.6351420814926559 0.5022036858967373
Epoch:  1


0.3351763088155437 0.20105066469737462
Epoch:  2


0.1921169745760995 0.17418464592524938
Epoch:  3


0.18052732904215116 0.1819809377193451
Epoch:  4


0.17567438008011999 0.17418428191116878
Epoch:  5


0.1734967884179708 0.15866474381514958
Epoch:  6


0.1720557502798132 0.1647740602493286
Epoch:  7


0.17079201703136032 0.15067538831915175
Epoch:  8


0.17070055370395248 0.16790652913706644
Epoch:  9


0.16910990750467456 0.1554839653628213
Epoch:  10


0.16824918300718875 0.1539553361279624
Epoch:  11


0.16770069019214526 0.14547744393348694
Epoch:  12


0.16733467820528392 0.1493618403162275
Epoch:  13


0.1653703930410179 0.1495118790439197
Epoch:  14


0.16607062599143466 0.14802950939961843
Epoch:  15


0.16606145533355507 0.14691916640315736
Epoch:  16


0.16463330831076647 0.145212037222726
Epoch:  17


0.1649940984474646 0.14895038093839372
Epoch:  18


0.1642183729925671 0.1437478129352842
Epoch:  19


0.16246752440929413 0.14449641002076014
Epoch:  20


0.1631380114200953 0.14499598102910177
Epoch:  21


0.1635578709679681 0.14949784108570643
Epoch:  22


0.16311018611933734 0.14243677790675843
Epoch:  23


0.1608902940878997 0.13909272317375457
Epoch:  24


0.16268126706819278 0.1404229604772159
Epoch:  25


0.16043865761241397 0.141819566488266
Epoch:  26


0.15980542632373604 0.1371953029717718
Epoch:  27


0.16157788320167646 0.14024039570774352
Epoch:  28


0.16083903932893598 0.1436158259000097
Epoch:  29


0.16146357196408348 0.13445134780236653
Epoch:  30


0.15978338790906443 0.13535504362412862
Epoch:  31


0.15956125267454097 0.1394774807350976
Epoch:  32


0.15873690110606117 0.13722056043999536
Epoch:  33


0.15872978237835136 0.1426125168800354
Epoch:  34


0.1594635163610046 0.1324458909886224
Epoch:  35


0.1573567281703691 0.13691458318914687
Epoch:  36


0.15693408452175758 0.1348840328199523
Epoch:  37


0.15836353600025177 0.1385918887598174
Epoch:  38


0.15694540336325363 0.13357648040567124
Epoch:  39


0.15659616484835343 0.13382441124745778
Epoch:  40


0.1575881162205258 0.13310407102108002
Epoch    40: reducing learning rate of group 0 to 1.0000e-04.
Epoch:  41


0.1541543349220946 0.12793634406157903
Epoch:  42


0.15295225140210744 0.1277360532964979
Epoch:  43


0.15341279192550764 0.12728231506688253
Epoch:  44


0.15282213486529686 0.12682250993592398
Epoch:  45


0.1536671869658135 0.1271331352846963
Epoch:  46


0.15286742191056948 0.12699393502303533
Epoch:  47


0.15300218801240664 0.12632904840367182
Epoch:  48


0.15171252674347646 0.12682908879859106
Epoch:  49


0.15218571593632568 0.12657883763313293
Epoch:  50


0.15374894279080467 0.12837132811546326
Epoch:  51


0.15229489795259527 0.12731085185493743
Epoch:  52


0.15270207822322845 0.12707005547625677
Epoch:  53


0.1525779930320946 0.12643505526440485
Epoch    53: reducing learning rate of group 0 to 1.0000e-05.
Epoch:  54


0.15141474314638087 0.12616359548909323
Epoch:  55


0.15184666055279808 0.12631815565483911
Epoch:  56


0.1533327569832673 0.12604444473981857
Epoch:  57


0.15207730878043818 0.12649258013282502
Epoch:  58


0.15269601908889976 0.12678719418389456
Epoch:  59


0.1513287300193632 0.12626390478440694
Epoch:  60


0.1525123473760244 0.1265091033918517
Epoch:  61


0.15226890549466415 0.12661136899675643
Epoch:  62


0.1518864011442339 0.12648983938353403
Epoch    62: reducing learning rate of group 0 to 1.0000e-06.
Epoch:  63


0.15132499747985118 0.1259847953915596
Epoch:  64


0.1509278831449715 0.1263710292322295
Epoch:  65


0.15243909729493632 0.12631163958992278
Epoch:  66


0.15182130159558477 0.12625312272991454
Epoch:  67


0.1519779478375976 0.1264314992087228
Epoch:  68


0.15119306582051353 0.12615734870944703
Epoch:  69


0.1518493624957832 0.12632584146090917
Epoch    69: reducing learning rate of group 0 to 1.0000e-07.
Epoch:  70


0.15140739284657143 0.12642860944781983
Epoch:  71


0.15206127432552544 0.12629746113504683
Epoch:  72


0.15067776997347135 0.12613060431821005
Epoch:  73


0.15232354402542114 0.12627451228243963
Epoch:  74


0.15334698920314377 0.126144640147686
Epoch:  75


0.1521760591784039 0.12622443586587906
Epoch    75: reducing learning rate of group 0 to 1.0000e-08.
Epoch:  76


0.15207107163764336 0.12632855347224645
Epoch:  77


0.1513951618123699 0.12640558076756342
Epoch:  78


0.15261023995038625 0.12649568915367126
Epoch:  79


0.15197052504565264 0.1262932015316827
Epoch:  80


0.15283836263257103 0.12614475190639496
Epoch:  81


0.15225212678715988 0.126291067472526
Epoch:  82


0.15191229533504796 0.12615292093583516
Epoch:  83


0.15218041515028155 0.12647233584097453
Epoch:  84


0.15219334976093188 0.12654537281819753
Epoch:  85


0.15109391953494097 0.12632245463984354
Epoch:  86


0.1517612901893822 0.12634060744728362
Epoch:  87


0.15216203879665685 0.1264984788639205
Epoch:  88
