In [1]:
import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import sys
sys.path.append("..")
from RandomErasing import RandomErasing

import torchvision.models
from torchvision import transforms

from albumentations import Compose, ShiftScaleRotate, GridDistortion
from albumentations.pytorch import ToTensor



In [2]:
def prepare_data(df, unknown_to_known, spec_dir='../../data/logmelspec',
                 n_annotations=3, n_frames=635):
    """Build model inputs, targets and loss masks from an annotation table.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per annotation. ``reset_index`` must yield an
        ``audio_filename`` column (i.e. the index is named
        ``audio_filename``). All remaining columns are label indicators,
        including the "unknown" columns named by ``unknown_to_known``.
        The caller's frame is not modified.
    unknown_to_known : dict
        Maps each "unknown" label column to the list of known label columns
        that must be masked out whenever that unknown label is set (> 0.5).
    spec_dir : str, optional
        Directory holding one ``<audio_filename>.npy`` array per file.
    n_annotations : int, optional
        Number of annotation rows used per file (rows 1..n_annotations in
        their original order). Every file needs at least this many rows.
    n_frames : int, optional
        Each loaded array is transposed and cropped to its first
        ``n_frames`` rows.

    Returns
    -------
    X : numpy.ndarray
        Shape ``(n_files, 1, rows, cols)`` with ``rows <= n_frames``.
    y : numpy.ndarray
        Shape ``(n_files, n_labels, n_annotations)`` where ``n_labels`` is the
        number of columns left after dropping the unknown columns.
    y_mask : numpy.ndarray
        Same shape as ``y``; 1 everywhere except 0 for the known labels tied
        to an unknown label that the annotation has set.
    """
    df = df.reset_index()
    # Running 1-based annotation number within each audio file.
    df['slno'] = df.groupby('audio_filename').cumcount() + 1
    df = df.set_index(['audio_filename', 'slno'])

    unknown_cols = list(unknown_to_known.keys())
    df_unknown = df.loc[:, unknown_cols].copy()
    df = df.drop(columns=unknown_cols)

    # Start from "everything counts" and switch off the known labels that an
    # annotation marked as unknown.
    y_mask = df.copy()
    y_mask.loc[:, :] = 1
    for unknown, known in unknown_to_known.items():
        y_mask.loc[df_unknown[unknown] > 0.5, known] = 0

    # Put the annotation number first so each annotation set can be selected
    # as one block; sorting keeps the file order identical across blocks.
    df = df.swaplevel(i=1, j=0, axis=0).sort_index()
    y_mask = y_mask.swaplevel(i=1, j=0, axis=0).sort_index()

    annotation_ids = range(1, n_annotations + 1)
    y = np.stack([df.loc[[k], :].values for k in annotation_ids], axis=2)
    y_mask = np.stack([y_mask.loc[[k], :].values for k in annotation_ids], axis=2)

    # File order follows the first annotation block, matching the rows of y.
    filenames = df.loc[[1], :].reset_index(1).audio_filename.tolist()
    X = np.stack([
        np.load('{}/{}.npy'.format(spec_dir, name)).T[:n_frames, :]
        for name in filenames])
    X = np.expand_dims(X, axis=1)

    return X, y, y_mask


# Shared RandomErasing instance (default arguments); AudioDataset.__getitem__
# looks this global up and applies it to every augmented training sample.
random_erasing = RandomErasing()


class AudioDataset(Dataset):
    """Dataset of spectrogram tensors with their targets and loss weights.

    Parameters
    ----------
    X : torch.Tensor
        Samples stacked along the first axis; ``X[idx]`` is one sample and is
        expected to be 3-D (the augmentation path slices it as
        ``sample[:, i:, :]``).
    y : torch.Tensor
        Targets, indexed along the first axis in step with ``X``.
    weights : torch.Tensor
        Per-target loss weights/masks, indexed the same way.
    transform : callable, optional
        Augmentation called as ``transform(image=array)`` and returning a
        mapping with an ``'image'`` entry. When falsy, samples are returned
        untouched.
    """

    def __init__(self, X, y, weights, transform=None):
        self.X = X
        self.y = y
        self.weights = weights
        self.transform = transform
        # Used to hand the sample to the transform as an image-like array.
        self.pil = transforms.ToPILImage()

    def __len__(self):
        """Number of samples (size of the first axis of ``X``)."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return ``(sample, target, weights)`` for ``idx``.

        With a transform set, the sample is min-max scaled, cyclically
        shifted along axis 1 by a random offset, passed through the
        transform and the module-level ``random_erasing``, then scaled back
        to its original value range.
        """
        sample = self.X[idx, ...]

        if self.transform:
            # min-max transformation
            this_min = sample.min()
            this_max = sample.max()
            value_range = this_max - this_min
            if value_range == 0:
                # A constant sample would otherwise divide by zero and turn
                # the whole sample into NaNs; scaling by 1 keeps it constant.
                value_range = 1.0
            sample = (sample - this_min) / value_range

            # randomly cycle the file
            i = np.random.randint(sample.shape[1])
            sample = torch.cat([
                sample[:, i:, :],
                sample[:, :i, :]],
                dim=1)

            # apply albumentations transforms
            sample = np.array(self.pil(sample))
            sample = self.transform(image=sample)
            sample = sample['image']
            # Restore the leading channel axis and swap the last two axes
            # back after the transform's own axis reordering.
            sample = sample[None, :, :].permute(0, 2, 1)

            # apply random erasing
            sample = random_erasing(sample.clone().detach())

            # revert min-max transformation
            sample = (sample * value_range) + this_min

        return sample, self.y[idx, ...], self.weights[idx, ...]

In [3]:
# Load and prepare data
# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only open a metadata.pkl that comes from a trusted source.
with open('../../data/metadata.pkl', 'rb') as f:
    metadata = pickle.load(f)

taxonomy = metadata['taxonomy_df']
is_unknown = taxonomy.fine_id == 'X'
is_known = taxonomy.fine_id != 'X'

# Pair every "unknown" fine label (fine_id == 'X') with the known fine labels
# that share its coarse class; the merge suffixes give fine_x (unknown) and
# fine_y (known).
label_pairs = pd.merge(
    taxonomy.loc[is_unknown, ['fine', 'coarse']],
    taxonomy.loc[is_known, ['fine', 'coarse']],
    on='coarse',
    how='inner',
).drop(columns='coarse')
unknown_to_known = label_pairs.groupby('fine_x')['fine_y'].apply(list).to_dict()

known_labels = taxonomy.loc[is_known].fine.tolist()

# Coarse and fine annotations side by side, one frame per split.
train_df = pd.concat([metadata['coarse_train'], metadata['fine_train']], axis=1, sort=True)
valid_df = pd.concat([metadata['coarse_test'], metadata['fine_test']], axis=1, sort=True)

In [4]:
# manual correction for one data point
# Any row whose label columns sum to exactly 37 is zeroed out, in both splits.
for split_df in (train_df, valid_df):
    row_sums_to_37 = split_df.sum(axis=1) == 37
    split_df.loc[row_sums_to_37, :] = 0

In [5]:
# Build the numpy inputs (X), targets (y) and loss masks (y_mask) for each
# split; prepare_data also loads the spectrogram .npy files from disk.
train_X, train_y, train_y_mask = prepare_data(train_df, unknown_to_known)
valid_X, valid_y, valid_y_mask = prepare_data(valid_df, unknown_to_known)

In [6]:
# Channel wise normalization
# One mean/std per position of the last axis of X, estimated on the training
# split only and reused for the validation split. The channel count is read
# from the data instead of being hard-coded to 128.
# NOTE: this cell overwrites train_X / valid_X, so running it twice
# normalizes already-normalized arrays.
n_channels = train_X.shape[-1]
channel_means = train_X.reshape(-1, n_channels).mean(axis=0).reshape(1, 1, 1, -1)
channel_stds = train_X.reshape(-1, n_channels).std(axis=0).reshape(1, 1, 1, -1)
train_X = (train_X - channel_means) / channel_stds
valid_X = (valid_X - channel_means) / channel_stds

In [7]:
# Define the data augmentation transformations
# Passed to the training AudioDataset only (the validation dataset is built
# with transform=None). AudioDataset calls it as transform(image=...) and
# reads the 'image' entry of the result.
albumentations_transform = Compose([
    ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0.5),
    GridDistortion(),
    # Final step returns a tensor; AudioDataset calls .permute on the result.
    ToTensor()
])

In [8]:
# Create the datasets and the dataloaders
batch_size = 64

# Only the training dataset is augmented.
train_dataset = AudioDataset(
    X=torch.Tensor(train_X),
    y=torch.Tensor(train_y),
    weights=torch.Tensor(train_y_mask),
    transform=albumentations_transform,
)
valid_dataset = AudioDataset(
    X=torch.Tensor(valid_X),
    y=torch.Tensor(valid_y),
    weights=torch.Tensor(valid_y_mask),
    transform=None,
)

val_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
# Two independently shuffled loaders over the same training data: the training
# loop zips them to obtain the pairs of batches that mixup blends together.
train_loader_1 = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
train_loader_2 = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [9]:
# Define the device to be used
# Use the first GPU when CUDA is available and fall back to the CPU otherwise,
# so the later `.to(device)` calls do not fail on a CPU-only machine.
cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if cuda else 'cpu')
print('Device: ', device)

Device:  cuda:0


In [10]:
def weight_reset(layer):
    """Re-initialise ``layer`` when it provides ``reset_parameters``.

    Written to be passed to ``Module.apply`` (see ``Task5Model``): objects
    without a ``reset_parameters`` attribute are skipped silently.
    """
    if not hasattr(layer, 'reset_parameters'):
        return
    layer.reset_parameters()

In [11]:
# Highest index of the mv2.features children that Task5Model re-initialises.
# Placeholder default; the "Parameters" cell that follows assigns the real value.
until_x = None

In [12]:
# Parameters
# Task5Model resets the mv2.features children with index 0..until_x (inclusive);
# the value is also embedded in the checkpoint file name by the training loop.
until_x = 8


In [13]:
class Task5Model(nn.Module):
    """MobileNetV2-based classifier for single-channel inputs.

    A small 1x1-convolution stem maps the single input channel to the three
    channels fed to the pretrained MobileNetV2 feature extractor. The feature
    map is max-pooled over its last two axes and passed to a two-layer head
    that emits ``num_classes`` logits.

    The module-level ``until_x`` selects how many leading children of
    ``mv2.features`` (indices ``0..until_x`` inclusive) are re-initialised
    instead of keeping their pretrained weights. ``None`` is treated as
    "reset nothing".

    Parameters
    ----------
    num_classes : int
        Size of the output layer.
    """

    def __init__(self, num_classes):

        super().__init__()

        # 1 -> 10 -> 3 channels using 1x1 convolutions.
        self.bw2col = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 10, 1, padding=0), nn.ReLU(),
            nn.Conv2d(10, 3, 1, padding=0), nn.ReLU())

        self.mv2 = torchvision.models.mobilenet_v2(pretrained=True)

        # Head consuming the 1280 pooled feature channels.
        self.final = nn.Sequential(
            nn.Linear(1280, 512), nn.ReLU(), nn.BatchNorm1d(512),
            nn.Linear(512, num_classes))

        # Reset until ith layer of mv2.
        # `until_x` defaults to None earlier in the notebook; comparing an
        # int with None raises TypeError, so None skips the reset entirely.
        if until_x is not None:
            for i, x in enumerate(self.mv2.features.children()):
                if i <= until_x:
                    x.apply(weight_reset)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for input ``x``."""
        x = self.bw2col(x)
        x = self.mv2.features(x)
        # Global max pooling over the last two axes of the feature map.
        x = x.max(dim=-1)[0].max(dim=-1)[0]
        x = self.final(x)
        return x

In [14]:
# Instantiate the model
# 31 is passed as num_classes (the width of the final linear layer).
# NOTE(review): presumably the number of label columns in train_y — confirm.
model = Task5Model(31).to(device)

In [15]:
# Define optimizer, scheduler and loss criteria
optimizer = optim.Adam(model.parameters(), lr=0.001, amsgrad=True)
# Stepped once per epoch with the validation loss in the training loop.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, verbose=True)
# reduction='none' keeps element-wise losses so the training loop can multiply
# them by the masks before summing.
criterion = nn.BCEWithLogitsLoss(reduction='none')

In [16]:
epochs = 100
early_stopping_patience = 25  # epochs without a new lowest validation loss
alpha = 1  # Beta(alpha, alpha) parameter for the mixup coefficients
train_loss_hist = []
valid_loss_hist = []
lowest_val_loss = np.inf
epochs_without_new_lowest = 0


def masked_bce_loss(outputs, labels, masks):
    """Masked BCE summed over every annotation set, divided by the mask total.

    ``labels`` and ``masks`` carry one slice per annotation set on their last
    axis; the same ``outputs`` are scored against each slice with the
    element-wise ``criterion`` and weighted by the matching mask slice.
    """
    total = sum(
        (criterion(outputs, labels[:, :, k]) * masks[:, :, k]).sum()
        for k in range(labels.shape[2]))
    return total / masks.sum()


for i in range(epochs):
    print('Epoch: ', i)

    # ---- training pass ----
    model.train()
    this_epoch_train_loss = 0
    # The two loaders shuffle independently, so zipping them pairs each batch
    # with a different random batch of the same training set.
    for i1, i2 in zip(train_loader_1, train_loader_2):

        # mixup the inputs ---------
        mixup_vals = np.random.beta(alpha, alpha, i1[0].shape[0])

        # One coefficient per sample, broadcast over the 4-D inputs ...
        lam = torch.Tensor(mixup_vals.reshape(mixup_vals.shape[0], 1, 1, 1))
        inputs = (lam * i1[0]) + ((1 - lam) * i2[0])

        # ... and over the 3-D labels and masks.
        lam = torch.Tensor(mixup_vals.reshape(mixup_vals.shape[0], 1, 1))
        labels = (lam * i1[1]) + ((1 - lam) * i2[1])
        masks = (lam * i1[2]) + ((1 - lam) * i2[2])
        # mixup ends ----------

        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        with torch.set_grad_enabled(True):
            outputs = model(inputs)
            loss = masked_bce_loss(outputs, labels, masks)
            loss.backward()
            optimizer.step()
            this_epoch_train_loss += loss.detach().cpu().numpy()

    # ---- validation pass (no gradients, no optimizer involvement) ----
    model.eval()
    this_epoch_valid_loss = 0
    with torch.set_grad_enabled(False):
        for inputs, labels, masks in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            masks = masks.to(device)
            outputs = model(inputs)
            loss = masked_bce_loss(outputs, labels, masks)
            this_epoch_valid_loss += loss.detach().cpu().numpy()

    # Average of the per-batch losses.
    this_epoch_train_loss /= len(train_loader_1)
    this_epoch_valid_loss /= len(val_loader)

    train_loss_hist.append(this_epoch_train_loss)
    valid_loss_hist.append(this_epoch_valid_loss)

    # Log before the early-stopping check so the last epoch is reported too.
    print(this_epoch_train_loss, this_epoch_valid_loss)

    # Checkpoint whenever the validation loss reaches a new minimum.
    if this_epoch_valid_loss < lowest_val_loss:
        lowest_val_loss = this_epoch_valid_loss
        torch.save(model.state_dict(), './model_system1_until_{}'.format(until_x))
        epochs_without_new_lowest = 0
    else:
        epochs_without_new_lowest += 1

    if epochs_without_new_lowest >= early_stopping_patience:
        break

    scheduler.step(this_epoch_valid_loss)

Epoch:  0


0.6328648167687494 0.5170100842203412
Epoch:  1


0.33302852148945267 0.21512831321784429
Epoch:  2


0.19094356330665382 0.177444726228714
Epoch:  3


0.17964836753703453 0.17853395427976335
Epoch:  4


0.17672164214623942 0.1669597178697586
Epoch:  5


0.1741034368405471 0.15659765047686441
Epoch:  6


0.1740356103793995 0.1623335416827883
Epoch:  7


0.1729664754223179 0.15615052623408182
Epoch:  8


0.1718916252658174 0.30681165414197104
Epoch:  9


0.17245753471915787 0.16913566631930216
Epoch:  10


0.17110784150458672 0.16378241990293776
Epoch:  11


0.16996237716159304 0.20769240174974715
Epoch:  12


0.16833325415044217 0.14736050154481614
Epoch:  13


0.16644026983428645 0.148767956665584
Epoch:  14


0.16517397882165136 0.16486348424639022
Epoch:  15


0.1662451987330978 0.14774611379419053
Epoch:  16


0.1630992176564964 0.14516595963920867
Epoch:  17


0.16169167128769127 0.13915332087448665
Epoch:  18


0.16249732230160688 0.14692489270653045
Epoch:  19


0.16206138037346504 0.1362525394984654
Epoch:  20


0.16040464873249466 0.13786695791142328
Epoch:  21


0.16013550315354322 0.14463522817407334
Epoch:  22


0.16054210147342166 0.13452864225421632
Epoch:  23


0.1597537213080638 0.13323906809091568
Epoch:  24


0.15774331946630735 0.13197136563914164
Epoch:  25


0.15742638989074811 0.13481378448860987
Epoch:  26


0.15951653066519145 0.13717029775891984
Epoch:  27


0.15598538558225375 0.13434152411563055
Epoch:  28


0.15748778791040988 0.1329316252044269
Epoch:  29


0.15760267948782122 0.13364008494785853
Epoch:  30


0.15667266016070908 0.13432497424738749
Epoch    30: reducing learning rate of group 0 to 1.0000e-04.
Epoch:  31


0.15464550578916395 0.12876785333667481
Epoch:  32


0.1547487861401326 0.12828227026121958
Epoch:  33


0.15331975753242905 0.12775120777743204
Epoch:  34


0.15314363788914037 0.12799507592405593
Epoch:  35


0.1539394436655818 0.12869067915848323
Epoch:  36


0.15312700134676857 0.12753537084375108
Epoch:  37


0.15358260797487722 0.12708889267274312
Epoch:  38


0.1530090075086903 0.12723674731595175
Epoch:  39


0.15281806604282275 0.1269450177039419
Epoch:  40


0.15157489156400836 0.12688233809811728
Epoch:  41


0.15243889874703176 0.12722772785595485
Epoch:  42


0.15161966069324598 0.12716420739889145
Epoch:  43


0.15297861880547292 0.12647063497986114
Epoch:  44


0.1519719185055913 0.12661113696438925
Epoch:  45


0.15225625481154467 0.12591173819133214
Epoch:  46


0.15171678364276886 0.1263027393392154
Epoch:  47


0.1510256822850253 0.12734978646039963
Epoch:  48


0.15166558526657722 0.12636055797338486
Epoch:  49


0.152408456882915 0.1263416941676821
Epoch:  50


0.15136826964648994 0.1264177741748946
Epoch:  51


0.1528633937642381 0.12634393998554774
Epoch    51: reducing learning rate of group 0 to 1.0000e-05.
Epoch:  52


0.15074930400461764 0.12625489490372793
Epoch:  53


0.1517894106942254 0.12567034363746643
Epoch:  54


0.15254041955277725 0.12559419976813452
Epoch:  55


0.15103641555115982 0.12583978261266435
Epoch:  56


0.15128517835526853 0.12569401200328553
Epoch:  57


0.15170190261827932 0.1258447766304016
Epoch:  58


0.15065814232504046 0.12574439815112523
Epoch:  59


0.1499120818602072 0.1255045724766595
Epoch:  60


0.15123209719722336 0.12552128838641302
Epoch:  61


0.1520591881629583 0.12574443966150284
Epoch:  62


0.1513501361415193 0.1257265955209732
Epoch:  63


0.14974544821558772 0.12542925029993057
Epoch:  64


0.15059300894672806 0.12564090639352798
Epoch:  65


0.15227167952705073 0.1256039153252329
Epoch:  66


0.15232599667600683 0.12598420679569244
Epoch:  67


0.15222694825481725 0.125726546560015
Epoch:  68


0.15029512345790863 0.12592318228312901
Epoch:  69


0.1504160375208468 0.12573779374361038
Epoch    69: reducing learning rate of group 0 to 1.0000e-06.
Epoch:  70


0.15069028735160828 0.12581389184509004
Epoch:  71


0.15130393931994568 0.12567978565182006
Epoch:  72


0.1500583641432427 0.12559024670294353
Epoch:  73


0.1514351267266918 0.12582973603691375
Epoch:  74


0.15080045606638934 0.12568010815552302
Epoch:  75


0.15121985085912654 0.12557838750737055
Epoch    75: reducing learning rate of group 0 to 1.0000e-07.
Epoch:  76


0.15098655586307114 0.12577358420406068
Epoch:  77


0.15070466173661723 0.12592349733625138
Epoch:  78


0.15100068417755333 0.12589235710246222
Epoch:  79


0.15048123251747442 0.12561548820563725
Epoch:  80


0.15037954175794446 0.1257752084306308
Epoch:  81


0.1502130659850868 0.12562565611941473
Epoch    81: reducing learning rate of group 0 to 1.0000e-08.
Epoch:  82


0.1512684246172776 0.125571251979896
Epoch:  83


0.15161818346461733 0.1256106293627194
Epoch:  84


0.15209889774386948 0.12570962735584804
Epoch:  85


0.15075862166043874 0.12571737063782557
Epoch:  86


0.15115164099512873 0.12551976953233993
Epoch:  87


0.1501866107856905 0.1254795715212822
Epoch:  88
