## Train_Generator

In [5]:
import os;
import sys;
import numpy as np;
import random;
from tensorflow import keras;

import glob;
import math;
import time;


In [6]:
# Put the current working directory and its 'common' subfolder on the module
# search path so the project-local packages imported below can be resolved.
# NOTE(review): presumably the notebook is launched from the project root --
# confirm, since os.getcwd() depends on where the kernel was started.
sys.path.append(os.getcwd());
sys.path.append(os.path.join(os.getcwd(), 'common'));
# Project-local modules: U supplies the audio helpers used below (mix,
# random_gain, padding, random_crop, normalize, to_hms, ...); tlmodels
# supplies the transfer-learning model factory.
import common.utils as U;
import common.opts as opts;
import resources.TLModels as tlmodels;

In [3]:
class Generator(keras.utils.Sequence):
    """Keras ``Sequence`` producing between-class (BC) mixed training batches.

    Every sample of a batch is built from two randomly drawn training
    examples with different labels: both are preprocessed, mixed with a
    random ratio ``r`` via ``U.mix`` and given a random gain. The target is
    the matching blend of the two one-hot labels.

    Args:
        samples: Indexable collection of raw sounds.
        labels: Indexable collection of integer class ids, aligned with
            ``samples``.
        options: Option object; this class reads ``batchSize``, ``sr``,
            ``nClasses``, ``inputLength`` and ``strongAugment`` from it.
    """

    def __init__(self, samples, labels, options):
        # Let the Keras base class initialise its own state; recent Keras
        # versions expect subclasses to call this.
        super().__init__()
        # Fixed seed for reproducible batches. This reseeds the process-wide
        # `random` module, not a private generator.
        random.seed(42)
        # Indexing `labels` (instead of zip) keeps the original failure mode:
        # too few labels raises IndexError rather than truncating silently.
        self.data = [(sample, labels[i]) for i, sample in enumerate(samples)]
        self.opt = options
        self.batch_size = options.batchSize
        self.preprocess_funcs = self.preprocess_setup()

    def __len__(self):
        """Return the number of full batches per epoch (remainder dropped)."""
        return int(np.floor(len(self.data) / self.batch_size))

    def __getitem__(self, batchIndex):
        """Return one ``(batchX, batchY)`` pair with two unit axes added to X.

        Singleton axes are inserted at positions 1 and 3.
        Assumes each sound is 1-D, so X goes from (batch, length) to
        (batch, 1, length, 1) -- TODO confirm against the model input.
        """
        batchX, batchY = self.generate_batch(batchIndex)
        batchX = np.expand_dims(batchX, axis=1)
        batchX = np.expand_dims(batchX, axis=3)
        return batchX, batchY

    def _has_distinct_labels(self):
        """Return True when the data holds at least two different labels."""
        if not self.data:
            return False
        first_label = self.data[0][1]
        # Short-circuits on the first differing label, so this is cheap for
        # any usable dataset.
        return any(label != first_label for _, label in self.data)

    def generate_batch(self, batchIndex):
        """Build ``batch_size`` freshly mixed samples.

        ``batchIndex`` is not used: every batch is drawn at random from the
        whole training set.

        Returns:
            Tuple ``(sounds, labels)`` of numpy arrays; both hold float32
            entries.

        Raises:
            ValueError: If the data contains fewer than two distinct labels,
                in which case no valid pair could ever be drawn.
        """
        if not self._has_distinct_labels():
            # Without this guard the pair-selection loop below never ends.
            raise ValueError(
                'BC mixing needs samples from at least two different classes')

        # Loop invariants hoisted out of the per-sample loop.
        one_hot = np.eye(self.opt.nClasses)
        last_index = len(self.data) - 1

        sounds = []
        labels = []
        for _ in range(self.batch_size):
            # Draw two training examples until their labels differ.
            while True:
                sound1, label1 = self.data[random.randint(0, last_index)]
                sound2, label2 = self.data[random.randint(0, last_index)]
                if label1 != label2:
                    break
            sound1 = self.preprocess(sound1)
            sound2 = self.preprocess(sound2)

            # Mix the two examples; the same ratio blends the labels.
            r = np.array(random.random())
            sound = U.mix(sound1, sound2, r, self.opt.sr).astype(np.float32)
            label = (one_hot[label1] * r
                     + one_hot[label2] * (1 - r)).astype(np.float32)

            # For stronger augmentation.
            sound = U.random_gain(6)(sound).astype(np.float32)

            sounds.append(sound)
            labels.append(label)

        return np.asarray(sounds), np.asarray(labels)

    def preprocess_setup(self):
        """Return the ordered list of preprocessing callables.

        Random scaling is prepended only when ``opt.strongAugment`` is set;
        padding, random cropping to ``opt.inputLength`` and normalisation by
        32768.0 are always applied.
        """
        funcs = []
        if self.opt.strongAugment:
            funcs += [U.random_scale(1.25)]

        funcs += [U.padding(self.opt.inputLength // 2),
                  U.random_crop(self.opt.inputLength),
                  U.normalize(32768.0)]
        return funcs

    def preprocess(self, sound):
        """Run ``sound`` through every preprocessing callable in order."""
        for f in self.preprocess_funcs:
            sound = f(sound)
        return sound

In [4]:
def setup(opt, split):
    """Build the training ``Generator`` from every fold except ``split``.

    Reads ``<opt.data>/<opt.dataset>/wav<opt.sr // 1000>.npz``, whose entries
    ``fold1`` .. ``fold<opt.nFolds>`` each wrap a dict with ``'sounds'`` and
    ``'labels'``.

    Args:
        opt: Option object; ``data``, ``dataset``, ``sr`` and ``nFolds`` are
            read here and the object is passed on to ``Generator``.
        split: Number of the held-out fold, which is excluded from training.

    Returns:
        A ``Generator`` over the sounds and labels of the remaining folds.
    """
    archive_path = os.path.join(
        opt.data, opt.dataset, 'wav{}.npz'.format(opt.sr // 1000))
    train_sounds = []
    train_labels = []
    # NOTE: allow_pickle=True unpickles arbitrary objects -- only load
    # dataset files from a trusted source.
    # The context manager closes the archive once the folds are read.
    with np.load(archive_path, allow_pickle=True) as dataset:
        for i in range(1, opt.nFolds + 1):
            if i == split:
                # Skip before touching the archive: the held-out fold is not
                # needed and every access re-reads it from disk.
                continue
            # Read each fold once; indexing the archive loads the entry anew.
            fold = dataset['fold{}'.format(i)].item()
            train_sounds.extend(fold['sounds'])
            train_labels.extend(fold['labels'])

    return Generator(train_sounds, train_labels, opt)

## Transfer-Learning Trainer

In [None]:
class TLTrainer:
    """Transfer-learning trainer: builds the data generator and fits the model.

    Args:
        opt: Option object. This class reads ``split``, ``LR``,
            ``weightDecay``, ``momentum``, ``nEpochs``, ``schedule`` and
            ``warmup``; it is also handed to the generator factory and to
            ``CustomCallback``.
    """

    def __init__(self, opt=None):
        self.opt = opt
        # `setup` is the generator factory defined in this file. The previous
        # `train_generator.setup` referred to a module that is never imported.
        self.trainGen = setup(self.opt, self.opt.split)

    def Train(self):
        """Compile the transfer-learning model and train it for ``nEpochs``."""
        model = tlmodels.GetTLACDNet()
        model.summary()

        optimizer = keras.optimizers.SGD(
            learning_rate=self.opt.LR,
            weight_decay=self.opt.weightDecay,
            momentum=self.opt.momentum,
            nesterov=True,
        )
        model.compile(
            loss='kullback_leibler_divergence',
            optimizer=optimizer,
            metrics=['accuracy'],
        )

        # The scheduler applies GetLR at the start of each epoch; the custom
        # callback evaluates on the held-out fold and prints one line per
        # epoch. No checkpoint callback is registered.
        lrate = keras.callbacks.LearningRateScheduler(self.GetLR)
        custom_evaluator = CustomCallback(self.opt)
        callbacks_list = [lrate, custom_evaluator]

        model.fit(
            self.trainGen,
            epochs=self.opt.nEpochs,
            steps_per_epoch=len(self.trainGen.data) // self.trainGen.batch_size,
            callbacks=callbacks_list,
            verbose=0,
        )

    def GetLR(self, epoch):
        """Return the learning rate for ``epoch`` under the step schedule.

        The base rate ``opt.LR`` is multiplied by 0.1 once for every
        milestone (``opt.nEpochs * fraction`` for each fraction in
        ``opt.schedule``) that ``epoch`` has strictly passed. While
        ``epoch <= opt.warmup`` a single 0.1 factor is used instead.
        """
        milestones = np.array(
            [self.opt.nEpochs * fraction for fraction in self.opt.schedule])
        decay = sum(epoch > milestones)
        if epoch <= self.opt.warmup:
            decay = 1
        return self.opt.LR * np.power(0.1, decay)

In [7]:
class CustomCallback(keras.callbacks.Callback):
    """Per-epoch evaluator and logger.

    After every epoch the model is evaluated on the held-out fold's test
    file, ``val_acc`` / ``val_loss`` are added to ``logs``, the best accuracy
    so far is tracked, and a one-line summary is written to stdout.

    Args:
        opt: Option object; ``LR``, ``nEpochs``, ``schedule``, ``warmup``,
            ``split``, ``data``, ``dataset``, ``batchSize`` and ``nCrops``
            are read from it.
    """

    def __init__(self, opt):
        # Initialise the Keras base-class state before adding our own.
        super().__init__()
        self.opt = opt
        self.testX = None  # test inputs, loaded lazily on first validation
        self.testY = None  # test targets, loaded together with testX
        self.curEpoch = 0
        self.curLr = opt.LR
        self.cur_epoch_start_time = time.time()
        self.bestAcc = 0.0
        self.bestAccEpoch = 0

    def _scheduled_lr(self, epoch):
        """Return the step-schedule learning rate for ``epoch``.

        Same formula as ``TLTrainer.GetLR``, evaluated from ``self.opt``
        directly. Constructing a trainer here would reload the whole training
        set just to compute a number. Keep the two in sync.
        """
        milestones = np.array(
            [self.opt.nEpochs * fraction for fraction in self.opt.schedule])
        decay = sum(epoch > milestones)
        if epoch <= self.opt.warmup:
            decay = 1
        return self.opt.LR * np.power(0.1, decay)

    def on_epoch_begin(self, epoch, logs=None):
        """Record the 1-based epoch, the LR to report and the start time."""
        self.curEpoch = epoch + 1
        # NOTE(review): the logged LR uses the 1-based epoch, while
        # LearningRateScheduler calls its schedule with the 0-based one, so
        # the two can differ at a milestone -- confirm which is intended.
        self.curLr = self._scheduled_lr(epoch + 1)
        self.cur_epoch_start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        """Validate, update ``logs`` and best accuracy, print the summary."""
        train_time = time.time() - self.cur_epoch_start_time
        self.load_test_data()
        val_acc, val_loss = self.validate(self.model)
        logs['val_acc'] = val_acc
        logs['val_loss'] = val_loss
        if val_acc > self.bestAcc:
            self.bestAcc = val_acc
            self.bestAccEpoch = epoch + 1
        epoch_time = time.time() - self.cur_epoch_start_time
        val_time = epoch_time - train_time
        # The training-accuracy key is 'accuracy' or 'acc' depending on the
        # Keras version, so accept either.
        train_acc = logs['accuracy'] if 'accuracy' in logs else logs['acc']
        line = 'SP-{}, Epoch: {}/{} | Time: {} (Train {}  Val {}) | Train: LR {}  Loss {:.2f}  Acc {:.2f}% | Val: Loss {:.2f}  Acc(top1) {:.2f}% | HA {:.2f}@{}\n'.format(
            self.opt.split, epoch + 1, self.opt.nEpochs, U.to_hms(epoch_time), U.to_hms(train_time), U.to_hms(val_time),
            self.curLr, logs['loss'], train_acc, val_loss, val_acc, self.bestAcc, self.bestAccEpoch)
        sys.stdout.write(line)
        sys.stdout.flush()

    def load_test_data(self):
        """Load the held-out fold's test arrays once and cache them."""
        if self.testX is None:
            # NOTE: allow_pickle=True unpickles arbitrary objects -- only
            # load dataset files from a trusted source.
            data = np.load(os.path.join(self.opt.data, self.opt.dataset, 'test_data_20khz/fold{}_test4000.npz'.format(self.opt.split)), allow_pickle=True)
            self.testX = data['x']
            self.testY = data['y']

    def validate(self, model):
        """Predict the cached test set in batches; return ``(acc, loss)``."""
        # Use a multiple of nCrops per batch, and at least one full group so
        # that batchSize < nCrops cannot yield a zero batch size.
        batch_size = (self.opt.batchSize // self.opt.nCrops) * self.opt.nCrops
        batch_size = max(batch_size, self.opt.nCrops)

        # Collect per-batch results and join once rather than re-copying the
        # growing array after every batch.
        predictions = []
        targets = []
        for start in range(0, len(self.testX), batch_size):
            x = self.testX[start:start + batch_size]
            y = self.testY[start:start + batch_size]
            predictions.append(model.predict(x, batch_size=len(y), verbose=0))
            targets.append(y)

        y_pred = np.concatenate(predictions)
        y_target = np.concatenate(targets)
        return self.compute_accuracy(y_pred, y_target)

    def compute_accuracy(self, y_pred, y_target):
        """Return ``(accuracy_percent, mean_kl_loss)`` for the predictions.

        When ``opt.nCrops > 1`` every ``nCrops`` consecutive rows are
        averaged first, in both predictions and targets.
        Presumably those rows are crops of one test clip -- verify against
        how the test file was generated.
        """
        n_crops = self.opt.nCrops
        if n_crops > 1:
            # Reshape so each sample contains its n_crops rows, then average.
            y_pred = y_pred.reshape(
                y_pred.shape[0] // n_crops, n_crops, y_pred.shape[1]).mean(axis=1)
            y_target = y_target.reshape(
                y_target.shape[0] // n_crops, n_crops, y_target.shape[1]).mean(axis=1)

        loss = keras.losses.KLD(y_target, y_pred).numpy().mean()

        # Compare the index of the highest averaged value per sample.
        y_pred = y_pred.argmax(axis=1)
        y_target = y_target.argmax(axis=1)
        accuracy = (y_pred == y_target).mean() * 100

        return accuracy, loss