In [1]:
# Parameters
until_x = -2


In [2]:
import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import sys
sys.path.append("..")
from RandomErasing import RandomErasing

import torchvision.models
from torchvision import transforms

from albumentations import Compose, ShiftScaleRotate, GridDistortion
from albumentations.pytorch import ToTensor

In [3]:
def prepare_data(df, unknown_to_known):
    df = df.reset_index()
    df['slno'] = df.assign(slno=1).groupby('audio_filename')['slno'].cumsum()
    df.set_index(['audio_filename', 'slno'], inplace=True)

    df_unknown = df.copy().loc[:, list(unknown_to_known.keys())]
    df.drop(columns=list(unknown_to_known.keys()), inplace=True)

    y_mask = df.copy()
    y_mask.loc[:, :] = 1
    for unknown, known in unknown_to_known.items():
        y_mask.loc[
            df_unknown[unknown] > 0.5,
            known
        ] = 0

    df = df.swaplevel(i=1, j=0, axis=0).sort_index()

    y_mask = y_mask.swaplevel(i=1, j=0, axis=0).sort_index()

    y = np.concatenate([
        df.loc[[1], :].values[..., np.newaxis],
        df.loc[[2], :].values[..., np.newaxis],
        df.loc[[3], :].values[..., np.newaxis]
    ], axis=2)

    y_mask = np.concatenate([
        y_mask.loc[[1], :].values[..., np.newaxis],
        y_mask.loc[[2], :].values[..., np.newaxis],
        y_mask.loc[[3], :].values[..., np.newaxis]
    ], axis=2)

    X = np.concatenate([
        np.expand_dims(np.load('../../data/logmelspec/{}.npy'.format(x)).T[:635, :], axis=0)
        for x in df.loc[[1], :].reset_index(1).audio_filename.tolist()])
    X = np.expand_dims(X, axis=1)

    return X, y, y_mask


random_erasing = RandomErasing()


class AudioDataset(Dataset):

    def __init__(self, X, y, weights, transform=None):
        self.X = X
        self.y = y
        self.weights = weights
        self.transform = transform
        self.pil = transforms.ToPILImage()

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        sample = self.X[idx, ...]

        if self.transform:
            # min-max transformation
            this_min = sample.min()
            this_max = sample.max()
            sample = (sample - this_min) / (this_max - this_min)

            # randomly cycle the file
            i = np.random.randint(sample.shape[1])
            sample = torch.cat([
                sample[:, i:, :],
                sample[:, :i, :]],
                dim=1)

            # apply albumentations transforms
            sample = np.array(self.pil(sample))
            sample = self.transform(image=sample)
            sample = sample['image']
            sample = sample[None, :, :].permute(0, 2, 1)

            # apply random erasing
            sample = random_erasing(sample.clone().detach())

            # revert min-max transformation
            sample = (sample * (this_max - this_min)) + this_min

        return sample, self.y[idx, ...], self.weights[idx, ...]


def weight_reset(layer):
    if hasattr(layer, 'reset_parameters'):
        layer.reset_parameters()


class Task5Model(nn.Module):

    def __init__(self, num_classes):

        super().__init__()

        self.bw2col = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 10, 1, padding=0), nn.ReLU(),
            nn.Conv2d(10, 3, 1, padding=0), nn.ReLU())

        self.mv2 = torchvision.models.mobilenet_v2(pretrained=True)
        
        # Reset after ith layer of mv2
        for i, x in enumerate(self.mv2.features.children()):
            if i > until_x:
                x.apply(weight_reset)

        self.final = nn.Sequential(
            nn.Linear(1280, 512), nn.ReLU(), nn.BatchNorm1d(512),
            nn.Linear(512, num_classes))

    def forward(self, x):
        x = self.bw2col(x)
        x = self.mv2.features(x)
        x = x.max(dim=-1)[0].max(dim=-1)[0]
        x = self.final(x)
        return x

In [4]:
# Load and prepare data
with open('../../data/metadata.pkl', 'rb') as f:
    metadata = pickle.load(f)

unknown_to_known = (
    pd.merge(metadata['taxonomy_df'].loc[lambda x: x.fine_id == 'X', ['fine', 'coarse']],
             metadata['taxonomy_df'].loc[lambda x: x.fine_id != 'X', ['fine', 'coarse']],
             on='coarse', how='inner')
    .drop(columns='coarse')
    .groupby('fine_x')['fine_y']
    .apply(lambda x: list(x)).to_dict())
known_labels = metadata['taxonomy_df'].loc[lambda x: x.fine_id != 'X'].fine.tolist()

train_df = pd.concat([metadata['coarse_train'], metadata['fine_train']], axis=1, sort=True)
valid_df = pd.concat([metadata['coarse_test'], metadata['fine_test']], axis=1, sort=True)

In [5]:
# manual correction for one data point
train_df.loc[(train_df.sum(axis=1) == 37).copy(), :] = 0
valid_df.loc[(valid_df.sum(axis=1) == 37).copy(), :] = 0

In [6]:
train_X, train_y, train_y_mask = prepare_data(train_df, unknown_to_known)
valid_X, valid_y, valid_y_mask = prepare_data(valid_df, unknown_to_known)

In [7]:
# Channel wise normalization
channel_means = train_X.reshape(-1, 128).mean(axis=0).reshape(1, 1, 1, -1)
channel_stds = train_X.reshape(-1, 128).std(axis=0).reshape(1, 1, 1, -1)
train_X = (train_X - channel_means) / channel_stds
valid_X = (valid_X - channel_means) / channel_stds

In [8]:
# Define the data augmentation transformations
albumentations_transform = Compose([
    ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0.5),
    GridDistortion(),
    ToTensor()
])

In [9]:
# Create the datasets and the dataloaders
train_dataset = AudioDataset(torch.Tensor(train_X),
                             torch.Tensor(train_y),
                             torch.Tensor(train_y_mask),
                             albumentations_transform)
valid_dataset = AudioDataset(torch.Tensor(valid_X),
                             torch.Tensor(valid_y),
                             torch.Tensor(valid_y_mask),
                             None)

val_loader = DataLoader(valid_dataset, 96, shuffle=False)
train_loader_1 = DataLoader(train_dataset, 96, shuffle=True)
train_loader_2 = DataLoader(train_dataset, 96, shuffle=True)

In [10]:
# Define the device to be used
cuda = True
device = torch.device('cuda:0' if cuda else 'cpu')
print('Device: ', device)

Device:  cuda:0


In [11]:
# Instantiate the model
model = Task5Model(31).to(device)

Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/checkpoints/mobilenet_v2-b0353104.pth


  0%|                                                                                                                                                                       | 0/14212972 [00:00<?, ?it/s]

  0%|▌                                                                                                                                                      | 49152/14212972 [00:00<00:40, 346880.02it/s]

  2%|██▌                                                                                                                                                   | 237568/14212972 [00:00<00:31, 449852.93it/s]

  5%|███████                                                                                                                                               | 663552/14212972 [00:00<00:22, 614801.05it/s]

 13%|██████████████████▉                                                                                                                                  | 1802240/14212972 [00:00<00:14, 854401.85it/s]

 31%|██████████████████████████████████████████████▍                                                                                                     | 4464640/14212972 [00:00<00:08, 1204014.16it/s]

 82%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                          | 11665408/14212972 [00:00<00:01, 1707781.66it/s]

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14212972/14212972 [00:00<00:00, 18971079.59it/s]




In [12]:
# Define optimizer, scheduler and loss criteria
optimizer = optim.Adam(model.parameters(), lr=0.001, amsgrad=True)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, verbose=True)
criterion = nn.BCEWithLogitsLoss(reduction='none')

In [13]:
epochs = 100
train_loss_hist = []
valid_loss_hist = []
lowest_val_loss = np.inf
epochs_without_new_lowest = 0

for i in range(epochs):
    print('Epoch: ', i)

    this_epoch_train_loss = 0
    for i1, i2 in zip(train_loader_1, train_loader_2):

        # mixup the inputs ---------
        alpha = 1
        mixup_vals = np.random.beta(alpha, alpha, i1[0].shape[0])

        lam = torch.Tensor(mixup_vals.reshape(mixup_vals.shape[0], 1, 1, 1))
        inputs = (lam * i1[0]) + ((1 - lam) * i2[0])

        lam = torch.Tensor(mixup_vals.reshape(mixup_vals.shape[0], 1, 1))
        labels = (lam * i1[1]) + ((1 - lam) * i2[1])
        masks = (lam * i1[2]) + ((1 - lam) * i2[2])
        # mixup ends ----------

        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        with torch.set_grad_enabled(True):
            model = model.train()
            outputs = model(inputs)
            # calculate loss for each set of annotations
            loss_0 = criterion(outputs, labels[:, :, 0]) * masks[:, :, 0]
            loss_1 = criterion(outputs, labels[:, :, 1]) * masks[:, :, 1]
            loss_2 = criterion(outputs, labels[:, :, 2]) * masks[:, :, 2]
            loss = (loss_0.sum() + loss_1.sum() + loss_2.sum()) / masks.sum()
            loss.backward()
            optimizer.step()
            this_epoch_train_loss += loss.detach().cpu().numpy()

    this_epoch_valid_loss = 0
    for inputs, labels, masks in val_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)
        optimizer.zero_grad()
        with torch.set_grad_enabled(False):
            model = model.eval()
            outputs = model(inputs)
            loss_0 = criterion(outputs, labels[:, :, 0]) * masks[:, :, 0]
            loss_1 = criterion(outputs, labels[:, :, 1]) * masks[:, :, 1]
            loss_2 = criterion(outputs, labels[:, :, 2]) * masks[:, :, 2]
            loss = (loss_0.sum() + loss_1.sum() + loss_2.sum()) / masks.sum()
            this_epoch_valid_loss += loss.detach().cpu().numpy()

    this_epoch_train_loss /= len(train_loader_1)
    this_epoch_valid_loss /= len(val_loader)

    train_loss_hist.append(this_epoch_train_loss)
    valid_loss_hist.append(this_epoch_valid_loss)

    if this_epoch_valid_loss < lowest_val_loss:
        lowest_val_loss = this_epoch_valid_loss
        torch.save(model.state_dict(), './model_system1_until_{}'.format(until_x))
        epochs_without_new_lowest = 0
    else:
        epochs_without_new_lowest += 1

    if epochs_without_new_lowest >= 15:
        break

    print(this_epoch_train_loss, this_epoch_valid_loss)

    scheduler.step(this_epoch_valid_loss)

Epoch:  0


0.6548934149742126 0.6088310599327087
Epoch:  1


0.43257912278175353 0.339142781496048
Epoch:  2


0.23011729657649993 0.2056727647781372
Epoch:  3


0.18757471799850464 0.1786372721195221
Epoch:  4


0.18037922203540802 0.17809796035289766
Epoch:  5


0.176707022190094 0.19625794887542725
Epoch:  6


0.1766222071647644 0.18495525419712067
Epoch:  7


0.17356403648853302 0.17231752574443818
Epoch:  8


0.17322466671466827 0.17770058512687684
Epoch:  9


0.1722402322292328 0.2407662957906723
Epoch:  10


0.17062930047512054 0.1874350756406784
Epoch:  11


0.17094181954860688 0.1576792061328888
Epoch:  12


0.16941840887069703 0.1526887059211731
Epoch:  13


0.1673532736301422 0.16020306944847107
Epoch:  14


0.16779638171195985 0.194276362657547
Epoch:  15


0.16589277982711792 0.1479882389307022
Epoch:  16


0.1668925577402115 0.19421503841876983
Epoch:  17


0.16667916417121886 0.1525086998939514
Epoch:  18


0.1644790291786194 0.14901018440723418
Epoch:  19


0.1639570873975754 0.1525702714920044
Epoch:  20


0.1641359531879425 0.1466073513031006
Epoch:  21


0.1640711259841919 0.1434185564517975
Epoch:  22


0.16271285176277162 0.1630036175251007
Epoch:  23


0.16129158616065978 9.265300750732422
Epoch:  24


0.1639593994617462 0.16198918223381042
Epoch:  25


0.16245438516139984 0.16569050550460815
Epoch:  26


0.16216805398464204 0.151149719953537
Epoch:  27


0.1621774035692215 0.1470192015171051
Epoch    27: reducing learning rate of group 0 to 1.0000e-04.
Epoch:  28


0.16187590718269348 0.13806078732013702
Epoch:  29


0.15766273856163024 0.1363751620054245
Epoch:  30


0.1585393500328064 0.13626811653375626
Epoch:  31


0.15919922947883605 0.1356632187962532
Epoch:  32


0.15814794301986695 0.13524107336997987
Epoch:  33


0.1574694412946701 0.13577266782522202
Epoch:  34


0.1582166635990143 0.13506992757320405
Epoch:  35


0.15753716766834258 0.1360290139913559
Epoch:  36


0.15808188498020173 0.13530665189027785
Epoch:  37


0.1586457884311676 0.13519082814455033
Epoch:  38


0.15840860188007355 0.1346396893262863
Epoch:  39


0.15853190124034883 0.13584672510623932
Epoch:  40


0.1573820835351944 0.13439613431692124
Epoch:  41


0.1572451514005661 0.13476865440607072
Epoch:  42


0.15861910760402678 0.13376125395298005
Epoch:  43


0.15715474665164947 0.13530684560537337
Epoch:  44


0.15697694957256317 0.1353336438536644
Epoch:  45


0.1576683646440506 0.13582926988601685
Epoch:  46


0.15638633847236633 0.13445681482553482
Epoch:  47


0.1560369026660919 0.13317035287618637
Epoch:  48


0.15581931829452514 0.13340434581041336
Epoch:  49


0.15618660151958466 0.13361142873764037
Epoch:  50


0.15706174552440644 0.13363925814628602
Epoch:  51


0.15588376462459563 0.1342515155673027
Epoch:  52


0.15681926846504213 0.13370444029569625
Epoch:  53


0.15639331102371215 0.13548713624477388
Epoch    53: reducing learning rate of group 0 to 1.0000e-05.
Epoch:  54


0.15668463110923767 0.1341100350022316
Epoch:  55


0.15664627373218537 0.13311704248189926
Epoch:  56


0.1565949535369873 0.13317803740501405
Epoch:  57


0.1568106883764267 0.13290374130010604
Epoch:  58


0.1559508115053177 0.1326472282409668
Epoch:  59


0.15724566400051118 0.13258527517318724
Epoch:  60


0.15709789752960204 0.13243540823459626
Epoch:  61


0.15579065203666687 0.13245181739330292
Epoch:  62


0.15639423489570617 0.13251364082098008
Epoch:  63


0.15564496874809264 0.13244261741638183
Epoch:  64


0.15649852573871612 0.1324290633201599
Epoch:  65


0.15626378774642943 0.1324884980916977
Epoch:  66


0.15562741577625275 0.13305846303701402
Epoch    66: reducing learning rate of group 0 to 1.0000e-06.
Epoch:  67


0.15518531441688538 0.1327175170183182
Epoch:  68


0.1554580533504486 0.132789246737957
Epoch:  69


0.15693839013576508 0.13273971080780028
Epoch:  70


0.15692468047142027 0.13237301260232925
Epoch:  71


0.15596487283706664 0.13255232721567153
Epoch:  72


0.15628661394119261 0.13256585747003555
Epoch:  73


0.15744387567043305 0.13235826194286346
Epoch:  74


0.1559969663619995 0.13268985599279404
Epoch:  75


0.15594672858715058 0.13228463381528854
Epoch:  76


0.1551085066795349 0.13251342624425888
Epoch:  77


0.15554336011409758 0.13265000134706498
Epoch:  78


0.15702551186084748 0.13253233730793
Epoch:  79


0.15631657779216768 0.13264208137989045
Epoch:  80


0.15651027143001556 0.1322397455573082
Epoch:  81


0.15629860162734985 0.13251782357692718
Epoch:  82


0.1540591961145401 0.13224446326494216
Epoch:  83


0.15553224444389344 0.13248414248228074
Epoch:  84


0.156521492600441 0.13243428319692613
Epoch:  85


0.15645773828029633 0.13252596259117128
Epoch:  86


0.1555822694301605 0.13260658383369445
Epoch    86: reducing learning rate of group 0 to 1.0000e-07.
Epoch:  87


0.1555362570285797 0.13225715160369872
Epoch:  88


0.15605585932731628 0.13263252824544908
Epoch:  89


0.15591474533081054 0.13259219229221345
Epoch:  90


0.15501319587230683 0.13226936608552933
Epoch:  91


0.15541959106922149 0.1324088156223297
Epoch:  92


0.15513144850730895 0.13229720443487167
Epoch    92: reducing learning rate of group 0 to 1.0000e-08.
Epoch:  93


0.1550103932619095 0.13254128992557526
Epoch:  94


0.15741642773151399 0.13239381164312364
Epoch:  95
