In [1]:
import sys;
import os;
import glob;
import math;
import numpy as np;
import glob;
import random;
import time;
import torch;
import torch.optim as optim;
import torch.nn as nn;

sys.path.append(os.getcwd());
sys.path.append('../');
# sys.path.append(os.path.join(os.getcwd(), 'torch/resources'));
import common.utils as U;
import common.opts as opts;
import resources.models as models;
import resources.calculator as calc;
import common.tlopts as tlopts
# import resources.train_generator as train_generator;
import argparse
from itertools import repeat

In [2]:
#Reproducibility
seed = 42;
random.seed(seed);
np.random.seed(seed);
torch.manual_seed(seed);
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed);
torch.backends.cudnn.deterministic = True;
torch.backends.cudnn.benchmark = False;

## define TLTraining Generator Class
The Class is an python iterator class for generating data for trainer to train the model.

In [3]:
class TLGenerator():
    #Generates data for Keras
    def __init__(self, samples, labels, options):
        random.seed(42);
        #Initialization
        print(f"length of samples:{len(samples)}")
        self.data = [(samples[i], labels[i]) for i in range (0, len(samples))];
        self.opt = options;
        self.batch_size = options.batchSize;
        self.preprocess_funcs = self.preprocess_setup();
        self.mapdict = {
            17:1, #pouring_water
            18:2, #toilet_flushing
            21:3, #snezzing
            24:4, #coughing
            51:5, #kettle_sound
            52:6, #alarm
            #53:"53_boiling_water_bubble_sound", #boiling_water_bubble_sound
            54:7, #rington
            55:8, #shower_water
            56:9, #pain_sounds
            57:10, #footsteps
            98:11, #silence
            99:12, #other_sounds
        };

    def __len__(self):
        #Denotes the number of batches per epoch
        return int(np.floor(len(self.data) / self.batch_size));
        #return len(self.samples);

    def __getitem__(self, batchIndex):
        #Generate one batch of data
        batchX, batchY = self.generate_batch(batchIndex);
        batchX = np.expand_dims(batchX, axis=1);
        batchX = np.expand_dims(batchX, axis=3);
        return batchX, batchY

    def generate_batch(self, batchIndex):
        #Generates data containing batch_size samples
        sounds = [];
        labels = [];
        indexes = None;
        for i in range(self.batch_size):
            # Training phase of BC learning
            # Select two training examples
            while True:
                sound1, label1 = self.data[random.randint(0, len(self.data) - 1)]
                sound2, label2 = self.data[random.randint(0, len(self.data) - 1)]
                if label1 != label2:
                    break
            sound1 = self.preprocess(sound1)
            sound2 = self.preprocess(sound2)

            # Mix two examples
            r = np.array(random.random())
            sound = U.mix(sound1, sound2, r, self.opt.sr).astype(np.float32)
            # print(f"sound length after U.mix is {len(sound)}")
            eye = np.eye(self.opt.nClasses)
            idx1 = self.mapdict[label1]- 1
            idx2 = self.mapdict[label2] - 1
            label = (eye[idx1] * r + eye[idx2] * (1 - r)).astype(np.float32)
            # label = (eye[label1] * r + eye[label2] * (1 - r)).astype(np.float32)

            #For stronger augmentation
            sound = U.random_gain(6)(sound).astype(np.float32)
            # print(f"sound length after U.random_gain is {len(sound)}")
            sounds.append(sound);
            labels.append(label);

        sounds = np.asarray(sounds);
        labels = np.asarray(labels);
        print(f"total sounds is {len(sounds)}")
        # print(f"labels in generate_batch is:\n{labels}")

        return sounds, labels;

    def preprocess_setup(self):
        funcs = []
        if self.opt.strongAugment:
            funcs += [U.random_scale(1.25)]

        funcs += [U.padding(self.opt.inputLength // 2),
                  U.random_crop(self.opt.inputLength),
                  U.normalize(32768.0)]
        return funcs

    def preprocess(self, sound):
        for f in self.preprocess_funcs:
            sound = f(sound)

        return sound;

## ACDNetV2 define the acdnet model structure.

In [4]:
class ACDNetV2(nn.Module):
    def __init__(self, input_length, n_class, sr, ch_conf=None):
        super(ACDNetV2, self).__init__();
        self.input_length = input_length;
        self.ch_config = ch_conf;

        stride1 = 2;
        stride2 = 2;
        channels = 8;
        k_size = (3, 3);
        n_frames = (sr/1000)*10; #No of frames per 10ms

        sfeb_pool_size = int(n_frames/(stride1*stride2));
        # tfeb_pool_size = (2,2);
        if self.ch_config is None:
            self.ch_config = [channels, channels*8, channels*4, channels*8, channels*8, channels*16, channels*16, channels*32, channels*32, channels*64, channels*64, n_class];
        # avg_pool_kernel_size = (1,4) if self.ch_config[1] < 64 else (2,4);
        fcn_no_of_inputs = self.ch_config[-1];
        conv1, bn1 = self.make_layers(1, self.ch_config[0], (1, 9), (1, stride1));
        conv2, bn2 = self.make_layers(self.ch_config[0], self.ch_config[1], (1, 5), (1, stride2));
        conv3, bn3 = self.make_layers(1, self.ch_config[2], k_size, padding=1);
        conv4, bn4 = self.make_layers(self.ch_config[2], self.ch_config[3], k_size, padding=1);
        conv5, bn5 = self.make_layers(self.ch_config[3], self.ch_config[4], k_size, padding=1);
        conv6, bn6 = self.make_layers(self.ch_config[4], self.ch_config[5], k_size, padding=1);
        conv7, bn7 = self.make_layers(self.ch_config[5], self.ch_config[6], k_size, padding=1);
        conv8, bn8 = self.make_layers(self.ch_config[6], self.ch_config[7], k_size, padding=1);
        conv9, bn9 = self.make_layers(self.ch_config[7], self.ch_config[8], k_size, padding=1);
        conv10, bn10 = self.make_layers(self.ch_config[8], self.ch_config[9], k_size, padding=1);
        conv11, bn11 = self.make_layers(self.ch_config[9], self.ch_config[10], k_size, padding=1);
        conv12, bn12 = self.make_layers(self.ch_config[10], self.ch_config[11], (1, 1));
        fcn = nn.Linear(fcn_no_of_inputs, n_class);
        nn.init.kaiming_normal_(fcn.weight, nonlinearity='sigmoid') # kaiming with sigoid is equivalent to lecun_normal in keras

        self.sfeb = nn.Sequential(
            #Start: Filter bank
            conv1, bn1, nn.ReLU(),\
            conv2, bn2, nn.ReLU(),\
            nn.MaxPool2d(kernel_size=(1, sfeb_pool_size))
        );

        tfeb_modules = [];
        self.tfeb_width = int(((self.input_length / sr)*1000)/10); # 10ms frames of audio length in seconds
        tfeb_pool_sizes = self.get_tfeb_pool_sizes(self.ch_config[1], self.tfeb_width);
        p_index = 0;
        for i in [3,4,6,8,10]:
            tfeb_modules.extend([eval('conv{}'.format(i)), eval('bn{}'.format(i)), nn.ReLU()]);

            if i != 3:
                tfeb_modules.extend([eval('conv{}'.format(i+1)), eval('bn{}'.format(i+1)), nn.ReLU()]);

            h, w = tfeb_pool_sizes[p_index];
            if h>1 or w>1:
                tfeb_modules.append(nn.MaxPool2d(kernel_size = (h,w)));
            p_index += 1;

        tfeb_modules.append(nn.Dropout(0.2));
        tfeb_modules.extend([conv12, bn12, nn.ReLU()]);
        h, w = tfeb_pool_sizes[-1];
        if h>1 or w>1:
            tfeb_modules.append(nn.AvgPool2d(kernel_size = (h,w)));
        tfeb_modules.extend([nn.Flatten(), fcn]);

        self.tfeb = nn.Sequential(*tfeb_modules);

        self.output = nn.Sequential(
            nn.Softmax(dim=1)
        );
        

    def forward(self, x):
        x = self.sfeb(x);
        #swapaxes
        x = x.permute((0, 2, 1, 3));
        x = self.tfeb(x);
        y = self.output[0](x);
        return y;

    def make_layers(self, in_channels, out_channels, kernel_size, stride=(1,1), padding=0, bias=False):
        conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias);
        nn.init.kaiming_normal_(conv.weight, nonlinearity='relu'); # kaiming with relu is equivalent to he_normal in keras
        bn = nn.BatchNorm2d(out_channels);
        return conv, bn;

    def get_tfeb_pool_sizes(self, con2_ch, width):
        h = self.get_tfeb_pool_size_component(con2_ch);
        w = self.get_tfeb_pool_size_component(width);
        # print(w);
        pool_size = [];
        for  (h1, w1) in zip(h, w):
            pool_size.append((h1, w1));
        return pool_size;

    def get_tfeb_pool_size_component(self, length):
        # print(length);
        c = [];
        index = 1;
        while index <= 6:
            if length >= 2:
                if index == 6:
                    c.append(length);
                else:
                    c.append(2);
                    length = length // 2;
            else:
               c.append(1);

            index += 1;

        return c;

def GetACDNetModel(input_len=30225, nclass=50, sr=20000, channel_config=None):
    net = ACDNetV2(input_len, nclass, sr, ch_conf=channel_config);
    return net;

## load pretrained acdnet weights of 20khz

In [5]:
# acdnet_model = GetACDNetModel()
# pretrain_weight= torch.load('./resources/pretrained_models/acdnet_20khz_trained_model_fold4_91.00.pt', map_location=torch.device('cpu'))['weight']

# model_state = acdnet_model.state_dict()
# model_state.update(pretrain_weight)
# acdnet_model.load_state_dict(pretrain_weight, strict=False)

# for k, v in pretrain_weight['weight'].items():
#     print("name:", k)
#     print("\n")

# remove the unexpected keys: weight and config
# from collections import OrderedDict
# new_state_dict = OrderedDict()
# for k, v in checkpoint.items():
#     name = k.replace("weight", "") # remove `module.`
#     new_state_dict[name] = v
#     name = k.replace("config", "") # remove `module.`
#     new_state_dict[name] = v

# model_state = acdnet_model.state_dict()
# model_state.update(new_state_dict)
# acdnet_model.load_state_dict(new_state_dict, strict=False)

# print("acdnet_model state_dict:\n",acdnet_model.state_dict())
# print("pretrain_weight: \n",pretrain_weight)  

In [6]:
# layer_38_of_tfeb = list(acdnet_model.tfeb.children())[38]

# print(layer_38_of_tfeb)
# print(nn.Sequential(*list(acdnet_model.tfeb.children())[:-6]))
# print(nn.Sequential(*list(acdnet_model.tfeb.children())))
# print(acdnet_model)
# for item_v in nn.Sequential(*list(acdnet_model.tfeb.children())):
#     for internal_k, internal_v in item_v.named_parameters():
#         print(internal_v.requires_grad)

In [7]:
# print(acdnet_model.fc)
#acdnet 包含三部份：sfeb, tfeb and output
# print(nn.Sequential(*list(acdnet_model.children())))
# print(nn.Sequential(*list(acdnet_model.children())[:-1]))
# for k, v in acdnet_model.named_parameters():
#     print("key:", k)
#     v.requires_grad = False

# acdnet_model.fcn = nn.Linear(num_ftrs, 10)
# print(acdnet_model)

In [45]:
def getOpts():
    parser = argparse.ArgumentParser(description='Transfer Learning for ACDNet');
    parser.add_argument('--netType', default='ACDNet_TL_Model_Extend',  required=False);
    parser.add_argument('--data', default='../datasets/processed/',  required=False);
    parser.add_argument('--dataset', required=False, default='uec_iot', choices=['10']);
    parser.add_argument('--BC', default=True, action='store_true', help='BC learning');
    parser.add_argument('--strongAugment', default=True,  action='store_true', help='Add scale and gain augmentation');
    #在ipynb中，不能使用parser.parse，要改用parser.parse_known_args()
    opt, unknown = parser.parse_known_args()
    #Leqarning settings
    opt.batchSize = 32;
    opt.weightDecay = 5e-4;
    opt.momentum = 0.09;
    opt.nEpochs = 1000;#2000;
    opt.LR = 0.1;
    opt.schedule = [0.3, 0.6, 0.9];
    opt.warmup = 10;
    if torch.backends.mps.is_available():
        opt.device="mps"; #for apple m2 gpu
    elif torch.cuda.is_available():
        opt.device="cuda:0"; #for nVidia gpu
    else:
        opt.device="cpu"
    print(f"***Use device:{opt.device}");
    # opt.device = torch.device("cuda:0" if  else "cpu");
    #Basic Net Settings
    opt.nClasses = 12#50;
    opt.nFolds = 5;
    opt.splits = [i for i in range(1, opt.nFolds + 1)];
    opt.sr = 20000;
    opt.inputLength = 30225;
    #Test data
    opt.nCrops = 1;
    return opt
    # opt = parser.parse_args();

In [46]:
def make_layers(in_channels, out_channels, kernel_size, stride=(1,1), padding=0, bias=False):
        conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias);
        nn.init.kaiming_normal_(conv.weight, nonlinearity='relu'); # kaiming with relu is equivalent to he_normal in keras
        bn = nn.BatchNorm2d(out_channels);
        return conv, bn;

In [47]:
ch_confing_10 = 8 * 64
ch_n_class = 12
fcn_no_of_inputs = 12
# conv12, bn12 = self.make_layers(self.ch_config[10], self.ch_config[11], (1, 1));
conv12, bn12 = make_layers(in_channels = ch_confing_10, out_channels = ch_n_class, kernel_size = (1, 1));
fcn = nn.Linear(fcn_no_of_inputs, ch_n_class);

In [48]:
class ACDNet_TL_Model_Extend(nn.Module):
    def __init__(self, PretrainedWeights='./resources/pretrained_models/acdnet_20khz_trained_model_fold4_91.00.pt',opt=None):
        super(ACDNet_TL_Model_Extend, self).__init__()
        acdnet_model = GetACDNetModel(); # load original acdnet model first
        # device = opt#torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(f"device is {opt.device}")
        pretrain_weight= torch.load(PretrainedWeights, map_location=torch.device(opt.device))['weight']
        model_state = acdnet_model.state_dict()
        model_state.update(pretrain_weight)
        acdnet_model.load_state_dict(pretrain_weight, strict=False)
        # print(type(acdnet_model))
        # count = 0;
        for k, v in acdnet_model.named_parameters():
            # count += 1;
            # print(f"set {k} required_grade to False");
            v.requires_grad = False
        # print(f"count is {count}");
        self.sfeb = nn.Sequential(*list(acdnet_model.children())[0])
        tfeb_modules = []
        tfeb_modules.extend([*list(acdnet_model.tfeb.children())[:-6]])
        tfeb_modules.extend([conv12, bn12, nn.ReLU()]);
        tfeb_modules.append(nn.AvgPool2d(kernel_size = (2,4)));
        tfeb_modules.extend([nn.Flatten(), fcn]);
        # self.retrained_layers = nn.Sequential(*list(acdnet_model.tfeb.children())[:-1])
        # fcn_no_of_inputs = 50, n_class=10
        # n_class=6
        # fc = nn.Linear(50, n_class);
        # fc.requires_grad = True
        # tfeb_modules.extend([fc])
        self.tfeb = nn.Sequential(*tfeb_modules)
        self.output = nn.Sequential(
        nn.Softmax(dim=1));
        # print(f"type of self.tfeb is {type(self.tfeb)}")
        # for k2, v2 in self.tfeb:
        #     print(f"k:{k}'s requires_grad is {v2.requires_grad}");

    def forward(self, x):
        x = self.sfeb(x);
        #swapaxes
        x = x.permute((0, 2, 1, 3));
        x = self.tfeb(x);
        y = self.output[0](x);
        return y;

In [49]:
def GetTLACDNet():
    model = ACDNet_TL_Model_Extend(opt=getOpts());#ACDNet_TL_Model()
    return model

In [50]:
# test_model = GetTLACDNet()
# calc.summary(test_model, (1,1,30225))
# print(test_model)
# print(test_model.state_dict())

In [51]:
from datetime import datetime;

In [52]:
def genDataTimeStr():
    return datetime.today().strftime('%Y-%m-%d %H:%M:%S').replace('-',"").replace(' ',"").replace(':',"");

In [53]:
class TLTrainer:
    def __init__(self, opt=None):
        self.opt = opt;
        self.testX = None;
        self.testY = None;
        self.bestAcc = 0.0;
        self.bestAccEpoch = 0;
        self.trainGen = getTrainGen(opt)#train_generator.setup(opt, split);
        self.opt = opt;
        # self.opt.trainer = self;
        self.trainGen = getTrainGen(self.opt, self.opt.splits)#train_generator.setup(self.opt, self.opt.split);
        # self.pretrainedmodelpath = "./resources/pretrained_models/acdnet20_20khz_fold4.h5"

    def Train(self):
        train_start_time = time.time();
        net = GetTLACDNet().to(self.opt.device)#models.GetACDNetModel().to(self.opt.device);
        #print networks parameters' require_grade value
        for k_, v_ in net.named_parameters():
            print(f"{k_}:{v_.requires_grad}")
        print('ACDNet model has been prepared for training');

        calc.summary(net, (1,1,self.opt.inputLength));

        # training_text = "Re-Training" if self.opt.retrain else "Training from Scratch";
        # print("{} has been started. You will see update after finishing every training epoch and validation".format(training_text));

        lossFunc = torch.nn.KLDivLoss(reduction='batchmean');
        optimizer = optim.SGD(net.parameters(), lr=self.opt.LR, weight_decay=self.opt.weightDecay, momentum=self.opt.momentum, nesterov=True);

        # self.opt.nEpochs = 1957 if self.opt.split == 4 else 2000;
        for epochIdx in range(self.opt.nEpochs):
            epoch_start_time = time.time();
            optimizer.param_groups[0]['lr'] = self.__get_lr(epochIdx+1);
            cur_lr = optimizer.param_groups[0]['lr'];
            running_loss = 0.0;
            running_acc = 0.0;
            n_batches = math.ceil(len(self.trainGen.data)/self.opt.batchSize);
            for batchIdx in range(n_batches):
                # with torch.no_grad():
                x,y = self.trainGen.__getitem__(batchIdx)
                x = torch.tensor(np.moveaxis(x, 3, 1)).to(self.opt.device);
                y = torch.tensor(y).to(self.opt.device);
                # zero the parameter gradients
                optimizer.zero_grad();

                # forward + backward + optimize
                outputs = net(x);
                running_acc += (((outputs.data.argmax(dim=1) == y.argmax(dim=1))*1).float().mean()).item();
                loss = lossFunc(outputs.log(), y);
                loss.backward();
                optimizer.step();

                running_loss += loss.item();

            tr_acc = (running_acc / n_batches)*100;
            tr_loss = running_loss / n_batches;

            #Epoch wise validation Validation
            epoch_train_time = time.time() - epoch_start_time;

            net.eval();
            val_acc, val_loss = self.__validate(net, lossFunc);
            #Save best model
            self.__save_model(val_acc, epochIdx, net);
            self.__on_epoch_end(epoch_start_time, epoch_train_time, epochIdx, cur_lr, tr_loss, tr_acc, val_loss, val_acc);

            running_loss = 0;
            running_acc = 0;
            net.train();

        total_time_taken = time.time() - train_start_time;
        print("Execution finished in: {}".format(U.to_hms(total_time_taken)));

    def load_test_data(self):
        # data = np.load(os.path.join(self.opt.data, self.opt.dataset, 'test_data_{}khz/fold{}_test4000.npz'.format(self.opt.sr//1000, self.opt.split)), allow_pickle=True);
        data = np.load(self.opt.testData, allow_pickle=True);
        self.testX = torch.tensor(np.moveaxis(data['x'], 3, 1)).to(self.opt.device);
        self.testY = torch.tensor(data['y']).type(torch.float32).to(self.opt.device);

    def __get_lr(self, epoch):
        divide_epoch = np.array([self.opt.nEpochs * i for i in self.opt.schedule]);
        decay = sum(epoch > divide_epoch);
        if epoch <= self.opt.warmup:
            decay = 1;
        return self.opt.LR * np.power(0.1, decay);

    def __get_batch(self, index):
        x = self.trainX[index*self.opt.batchSize : (index+1)*self.opt.batchSize];
        y = self.trainY[index*self.opt.batchSize : (index+1)*self.opt.batchSize];
        return x.to(self.opt.device), y.to(self.opt.device);

    def __validate(self, net, lossFunc):
        if self.testX is None:
            self.load_test_data();
        net.eval();
        with torch.no_grad():
            y_pred = None;
            batch_size = len(self.testX);#(self.opt.batchSize//self.opt.nCrops)*self.opt.nCrops;
#             for idx in range(math.ceil(len(self.testX)/batch_size)):
#             for idx in range(len(self.testX)):
#             x = self.testX[idx*batch_size : (idx+1)*batch_size];
            x = self.testX[:];
            scores = net(x);
            y_pred = scores.data if y_pred is None else torch.cat((y_pred, scores.data));
            acc, loss = self.__compute_accuracy(y_pred, self.testY, lossFunc);
#         with torch.no_grad():
#             y_pred = None;
#             batch_size = (self.opt.batchSize//self.opt.nCrops)*self.opt.nCrops;
#             for idx in range(math.ceil(len(self.testX)/batch_size)):
#                 x = self.testX[idx*batch_size : (idx+1)*batch_size];
#                 scores = net(x);
#                 y_pred = scores.data if y_pred is None else torch.cat((y_pred, scores.data));

#             acc, loss = self.__compute_accuracy(y_pred, self.testY, lossFunc);
        net.train();
        return acc, loss;

    #Calculating average prediction (10 crops) and final accuracy
    def __compute_accuracy(self, y_pred, y_target, lossFunc):
        # print(f"shape of y_pred:{y_pred.shape}");
        # print(f"shape of y_target:{y_target.shape}");
        with torch.no_grad():
            #Reshape to shape theme like each sample comtains 10 samples, calculate mean and find theindices that has highest average value for each sample
            if self.opt.nCrops == 1:
                y_pred = y_pred.argmax(dim=1);
                y_target = y_target.argmax(dim=1);
            else:
                y_pred = (y_pred.reshape(y_pred.shape[0]//self.opt.nCrops, self.opt.nCrops, y_pred.shape[1])).mean(dim=1).argmax(dim=1);
                y_target = (y_target.reshape(y_target.shape[0]//self.opt.nCrops, self.opt.nCrops, y_target.shape[1])).mean(dim=1).argmax(dim=1);
            acc = (((y_pred==y_target)*1).float().mean()*100).item();
            # valLossFunc = torch.nn.KLDivLoss();
            loss = lossFunc(y_pred.float().log(), y_target.float()).item();
            # loss = 0.0;
        return acc, loss;

    def __on_epoch_end(self, start_time, train_time, epochIdx, lr, tr_loss, tr_acc, val_loss, val_acc):
        epoch_time = time.time() - start_time;
        val_time = epoch_time - train_time;
        line = 'SP-{} Epoch: {}/{} | Time: {} (Train {}  Val {}) | Train: LR {}  Loss {:.2f}  Acc {:.2f}% | Val: Loss {:.2f}  Acc(top1) {:.2f}% | HA {:.2f}@{}\n'.format(
            self.opt.splits, epochIdx+1, self.opt.nEpochs, U.to_hms(epoch_time), U.to_hms(train_time), U.to_hms(val_time),
            lr, tr_loss, tr_acc, val_loss, val_acc, self.bestAcc, self.bestAccEpoch);
        # print(line)
        sys.stdout.write(line);
        sys.stdout.flush();

    def __save_model(self, acc, epochIdx, net):
        print("__save_model is called")
        print(f"current best Acc is {self.bestAcc}")
        print(f"pass in acc is {acc}")
        if acc > self.bestAcc:
            dir = os.getcwd();
            save_path = "./trained_models/{}".format(self.opt.model_name.format(genDataTimeStr(),acc));
            # fname = "{}/torch/trained_models/{}_fold{}.pt";
            # fname = "{}/trained_models/acdnet_torch_20231218.pt";
            # old_model = fname.format(dir, self.opt.model_name.lower(), self.opt.splits);
            # if os.path.isfile(old_model):
            #     os.remove(old_model);
            self.bestAcc = acc;
            self.bestAccEpoch = epochIdx +1;
            # torch.save({'weight':net.state_dict(), 'config':net.ch_config}, fname.format(dir, self.opt.model_name.lower(), self.opt.split));
            # torch.save({'weight':net.state_dict()}, fname.format(dir, self.opt.model_name.lower(), self.opt.splits));
            torch.save({'weight':net.state_dict()}, save_path);
            print(f"model saved....., acc: {acc}")


In [54]:
def getTrainGen(opt=None, split=None):
    # dataset = np.load(os.path.join(opt.data, opt.dataset, 'wav{}.npz'.format(opt.sr // 1000)), allow_pickle=True);
    # dataset = np.load("../datasets/fold1_test16000.npz", allow_pickle=True);
    dataset = np.load(opt.trainData, allow_pickle=True);
    train_sounds = []
    train_labels = []
    # print(len(dataset['x']))
    # for i in range(1, opt.nFolds + 1):

    # train_sounds = [dataset['x'][i][0] for i in range(len(dataset['x']))]
    # train_labels = [dataset['y'][i][0] for i in range(len(dataset['y']))]
    train_sounds = dataset['fold{}'.format(1)].item()['sounds']
    train_labels = dataset['fold{}'.format(1)].item()['labels']
    # print(train_sounds)

    trainGen = TLGenerator(train_sounds, train_labels, opt);
    return trainGen

In [55]:
def main():
    opt = getOpts();
    opt.sr = 20000;
    opt.inputLength = 30225;
    opt.trainer = None
    opt.trainData="../../acd_datasets/single_fold/train_fsd50_20K__202401041450.npz";
    opt.testData="../../acd_datasets/single_fold/test_data_20K.npz";
    # import torch;
    
    tlopts.display_info(opt)
    opt.model_name = "acdnet_fsd50k_{}_acc_{}.pt"
    # valid_path = False;
    print("Initializing TLTrainer Object.....")
    trainer = TLTrainer(opt)
    print("Start to training.....")
    trainer.Train();

In [None]:
main()

***Use device:mps
+------------------------------+
| ACDNet_TL_Model_Extend Sound classification
+------------------------------+
| dataset  : uec_iot
| nEpochs  : 1000
| LRInit   : 0.1
| schedule : [0.3, 0.6, 0.9]
| warmup   : 10
| batchSize: 32
| nFolds: 5
| Splits: [1, 2, 3, 4, 5]
+------------------------------+
Initializing TLTrainer Object.....
length of samples:332
length of samples:332
Start to training.....
***Use device:mps
device is mps
sfeb.0.weight:False
sfeb.1.weight:False
sfeb.1.bias:False
sfeb.3.weight:False
sfeb.4.weight:False
sfeb.4.bias:False
tfeb.0.weight:False
tfeb.1.weight:False
tfeb.1.bias:False
tfeb.4.weight:False
tfeb.5.weight:False
tfeb.5.bias:False
tfeb.7.weight:False
tfeb.8.weight:False
tfeb.8.bias:False
tfeb.11.weight:False
tfeb.12.weight:False
tfeb.12.bias:False
tfeb.14.weight:False
tfeb.15.weight:False
tfeb.15.bias:False
tfeb.18.weight:False
tfeb.19.weight:False
tfeb.19.bias:False
tfeb.21.weight:False
tfeb.22.weight:False
tfeb.22.bias:False
tfeb.25.weight

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.59003633 0.         0.4099637 ]
 [0.         0.         0.         0.6341814  0.         0.
  0.         0.         0.         0.36581862 0.         0.        ]
 [0.         0.         0.         0.         0.63613445 0.
  0.         0.         0.         0.         0.         0.36386555]
 [0.         0.         0.         0.15922242 0.         0.
  0.         0.         0.         0.         0.         0.8407776 ]
 [0.         0.         0.78118086 0.         0.         0.
  0.         0.         0.         0.21881911 0.         0.        ]
 [0.         0.3880931  0.         0.6119069  0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.16512774 0.         0.         0.83487225 0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0. 

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.         0.         0.06514243 0.93485755
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.8772937  0.         0.         0.
  0.         0.         0.         0.12270629 0.         0.        ]
 [0.         0.         0.         0.         0.5581397  0.
  0.         0.         0.         0.44186032 0.         0.        ]
 [0.         0.         0.1520535  0.84794647 0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.29961053 0.         0.7003895  0.         0.        ]
 [0.         0.         0.         0.         0.41173586 0.
  0.         0.         0.         0.         0.         0.5882641 ]
 [0.         0.         0.         0.7370197  0.         0.26298028
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.       

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.         0.         0.         0.15274915
  0.         0.         0.         0.8472508  0.         0.        ]
 [0.         0.         0.         0.         0.         0.4189693
  0.         0.         0.         0.         0.         0.5810307 ]
 [0.         0.         0.0043142  0.9956858  0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.94860154 0.         0.         0.
  0.         0.         0.         0.05139844 0.         0.        ]
 [0.         0.         0.03880381 0.         0.9611962  0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.9738018  0.
  0.         0.         0.         0.         0.         0.02619822]
 [0.         0.         0.         0.         0.7690876  0.
  0.         0.         0.         0.         0.         0.23091237]
 [0.         0.3739594  0.        

total sounds is 32
labels in generate_batch is:
[[0.70849675 0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.29150325]
 [0.         0.         0.         0.08149838 0.         0.
  0.         0.9185016  0.         0.         0.         0.        ]
 [0.         0.15236326 0.         0.         0.         0.84763676
  0.         0.         0.         0.         0.         0.        ]
 [0.8207533  0.         0.         0.         0.         0.17924672
  0.         0.         0.         0.         0.         0.        ]
 [0.28822634 0.         0.         0.         0.71177363 0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.5554216  0.         0.         0.4445784  0.         0.        ]
 [0.         0.         0.         0.         0.         0.13427745
  0.86572254 0.         0.         0.         0.         0.        ]
 [0.08466344 0.         0

total sounds is 32
labels in generate_batch is:
[[0.58177614 0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.41822383 0.        ]
 [0.         0.         0.         0.855004   0.14499597 0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.7208291  0.27917087 0.         0.        ]
 [0.         0.         0.         0.         0.8982518  0.
  0.10174821 0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.8212913  0.17870867 0.         0.         0.        ]
 [0.         0.05006611 0.         0.9499339  0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.8381972  0.16180283 0.
  0.         0.         0.         0.         0.         0.        ]
 [0.18494573 0.         0.         0.         0. 

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.06471041 0.         0.         0.93528956
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.5752958  0.         0.4247042
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.9501388  0.
  0.         0.         0.         0.         0.         0.04986118]
 [0.         0.         0.         0.         0.5106586  0.
  0.         0.         0.         0.48934138 0.         0.        ]
 [0.32189256 0.         0.67810744 0.         0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.09633363
  0.         0.         0.         0.         0.         0.9036664 ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.50974965 0.49025035 0.         0.        ]
 [0.         0.         0.

total sounds is 32
labels in generate_batch is:
[[0.         0.16504633 0.83495367 0.         0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.5260154  0.
  0.         0.         0.         0.4739846  0.         0.        ]
 [0.         0.9740378  0.         0.         0.         0.
  0.02596217 0.         0.         0.         0.         0.        ]
 [0.         0.0288745  0.         0.         0.         0.9711255
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.67057306 0.3294269  0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.19844204 0.         0.         0.         0.80155796]
 [0.2551639  0.7448361  0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.7150641  0.         0.     

total sounds is 32
labels in generate_batch is:
[[0.48117644 0.51882356 0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.3942903
  0.         0.60570973 0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.8663455  0.         0.         0.13365445]
 [0.         0.         0.         0.         0.         0.04192051
  0.9580795  0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.6301092  0.
  0.         0.         0.         0.         0.         0.36989078]
 [0.         0.         0.         0.         0.9551683  0.
  0.         0.         0.         0.04483167 0.         0.        ]
 [0.12028318 0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.8797168  0.        ]
 [0.5557705  0.         0.44422948

shape of y_pred:torch.Size([352, 12])
shape of y_target:torch.Size([352, 12])
__save_model is called
current best Acc is 5.68181848526001
pass in acc is 10.795454978942871
model saved....., acc: 10.795454978942871
SP-[1, 2, 3, 4, 5] Epoch: 3/10 | Time: 0m25s (Train 0m20s  Val 0m05s) | Train: LR 0.010000000000000002  Loss 2.02  Acc 9.66% | Val: Loss nan  Acc(top1) 10.80% | HA 10.80@3
total sounds is 32
labels in generate_batch is:
[[0.         0.         0.         0.         0.         0.7928004
  0.         0.         0.         0.         0.         0.2071996 ]
 [0.         0.         0.         0.4286296  0.         0.
  0.         0.57137036 0.         0.         0.         0.        ]
 [0.         0.         0.         0.55549467 0.         0.4445053
  0.         0.         0.         0.         0.         0.        ]
 [0.532449   0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.467551  ]
 [0.         0.         0.         0

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.01233739 0.9876626  0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.48146296
  0.         0.         0.         0.         0.         0.51853704]
 [0.         0.         0.03231657 0.         0.         0.
  0.96768343 0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.86941797
  0.         0.         0.         0.         0.         0.13058203]
 [0.         0.         0.         0.         0.         0.300991
  0.         0.         0.         0.         0.         0.699009  ]
 [0.         0.         0.524962   0.         0.         0.
  0.         0.         0.47503802 0.         0.         0.        ]
 [0.         0.         0.51352876 0.         0.         0.
  0.48647127 0.         0.         0.         0.         0.        ]
 [0.         0.7499157  0. 

total sounds is 32
labels in generate_batch is:
[[0.0000000e+00 0.0000000e+00 0.0000000e+00 7.9846889e-01 0.0000000e+00
  0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 2.0153110e-01]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 3.0518743e-01
  0.0000000e+00 6.9481260e-01 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 7.6693214e-02
  0.0000000e+00 9.2330676e-01 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 9.2958903e-01 7.0410997e-02 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00]
 [5.1501149e-01 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 4.8498851e-01 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 2.8583992e-01 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0

total sounds is 32
labels in generate_batch is:
[[0.         0.32305342 0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.6769466 ]
 [0.         0.         0.         0.         0.7935016  0.
  0.         0.         0.20649841 0.         0.         0.        ]
 [0.         0.         0.         0.53247446 0.         0.46752557
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.18870792 0.         0.         0.         0.8112921
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.36239845 0.         0.63760155 0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.7053027  0.29469728 0.         0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.8827642  0.         0.         0.11723581 0.         0.        ]
 [0.         0.         0.10600983

total sounds is 32
labels in generate_batch is:
[[0.         0.13306095 0.         0.         0.86693907 0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.90722287 0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.09277714]
 [0.         0.         0.         0.926424   0.         0.07357597
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.9083024  0.         0.0916976  0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.6066549  0.         0.39334515 0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.9265832  0.         0.         0.07341682]
 [0.8668743  0.         0.         0.         0.         0.
  0.         0.         0.         0.13312574 0.         0.        ]
 [0.0601519  0.         0.         0.9398

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.         0.         0.         0.97232705
  0.         0.02767294 0.         0.         0.         0.        ]
 [0.         0.7573543  0.         0.         0.         0.
  0.24264565 0.         0.         0.         0.         0.        ]
 [0.         0.         0.75533515 0.         0.         0.
  0.         0.         0.         0.         0.         0.24466483]
 [0.9153409  0.         0.         0.         0.         0.08465909
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.3313076  0.         0.         0.         0.6686924
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.8558481  0.         0.1441519 ]
 [0.         0.         0.28085795 0.         0.         0.
  0.         0.         0.719142   0.         0.         0.        ]
 [0.         0.         0.

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.02490117 0.         0.97509885 0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.36830193
  0.         0.6316981  0.         0.         0.         0.        ]
 [0.         0.         0.5107961  0.         0.         0.
  0.         0.         0.         0.         0.         0.48920387]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.03259953 0.         0.9674005 ]
 [0.         0.36967564 0.         0.         0.         0.63032436
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.52767086 0.         0.         0.         0.
  0.         0.         0.         0.47232917 0.         0.        ]
 [0.         0.         0.         0.         0.75692666 0.
  0.24307334 0.         0.         0.         0.         0.        ]
 [0.         0.         0.       

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.8207742  0.         0.         0.17922579
  0.         0.         0.         0.         0.         0.        ]
 [0.45859358 0.         0.         0.         0.5414064  0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.11665682 0.         0.         0.         0.88334316]
 [0.3672748  0.         0.         0.6327252  0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.10297628 0.         0.
  0.         0.         0.         0.89702374 0.         0.        ]
 [0.         0.         0.6448301  0.35516986 0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.09077026 0.
  0.         0.         0.         0.         0.         0.90922976]
 [0.14333333 0.         0.         0.8566

total sounds is 32
labels in generate_batch is:
[[0.         0.80835634 0.         0.19164364 0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.37423638 0.
  0.         0.         0.         0.         0.         0.6257636 ]
 [0.         0.         0.         0.         0.95425785 0.
  0.04574218 0.         0.         0.         0.         0.        ]
 [0.         0.         0.6572128  0.         0.3427872  0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.01807292 0.         0.         0.         0.9819271 ]
 [0.         0.42076477 0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.5792352 ]
 [0.         0.         0.         0.         0.         0.6803572
  0.         0.         0.         0.3196428  0.         0.        ]
 [0.         0.         0.17001341 0.     

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.         0.         0.2914194  0.
  0.         0.         0.         0.         0.         0.7085806 ]
 [0.         0.7472732  0.         0.         0.2527268  0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.25535777 0.
  0.         0.7446422  0.         0.         0.         0.        ]
 [0.798097   0.         0.201903   0.         0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.95978826 0.         0.         0.         0.         0.
  0.         0.         0.         0.04021174 0.         0.        ]
 [0.         0.         0.         0.         0.7344338  0.
  0.         0.         0.         0.2655662  0.         0.        ]
 [0.         0.42607477 0.         0.         0.         0.
  0.         0.         0.         0.57392526 0.         0.        ]
 [0.         0.5487114  0.         0.45128858 0. 

total sounds is 32
labels in generate_batch is:
[[0.20938101 0.         0.         0.         0.79061896 0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.81546944 0.         0.18453054]
 [0.         0.96590364 0.         0.         0.         0.03409638
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.37820026 0.         0.         0.6217997  0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.1941025  0.         0.8058975
  0.         0.         0.         0.         0.         0.        ]
 [0.2997717  0.         0.         0.         0.7002283  0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.7116656  0.         0.         0.28833443
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.

shape of y_pred:torch.Size([352, 12])
shape of y_target:torch.Size([352, 12])
__save_model is called
current best Acc is 13.068181037902832
pass in acc is 14.772727012634277
model saved....., acc: 14.772727012634277
SP-[1, 2, 3, 4, 5] Epoch: 5/10 | Time: 0m25s (Train 0m20s  Val 0m05s) | Train: LR 0.010000000000000002  Loss 1.99  Acc 14.77% | Val: Loss nan  Acc(top1) 14.77% | HA 14.77@5
total sounds is 32
labels in generate_batch is:
[[0.0000000e+00 0.0000000e+00 8.5951941e-04 9.9914050e-01 0.0000000e+00
  0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00]
 [5.8962858e-01 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  4.1037145e-01 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 4.2114064e-01 0.0000000e+00
  5.7885933e-01 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 0.000000

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.         0.         0.965466   0.
  0.         0.         0.         0.         0.         0.03453399]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.49667662 0.5033234  0.        ]
 [0.         0.9723194  0.         0.         0.         0.
  0.         0.         0.         0.02768057 0.         0.        ]
 [0.         0.         0.         0.         0.         0.47532067
  0.52467936 0.         0.         0.         0.         0.        ]
 [0.         0.         0.8409395  0.         0.         0.15906046
  0.         0.         0.         0.         0.         0.        ]
 [0.52130675 0.         0.         0.47869325 0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.61247474 0.38752526 0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.28525782 0.         0.       

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.         0.         0.83637244 0.16362756
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.646041   0.         0.         0.
  0.         0.         0.35395905 0.         0.         0.        ]
 [0.         0.         0.43274152 0.5672585  0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.36952114 0.         0.         0.         0.
  0.         0.         0.         0.63047886 0.         0.        ]
 [0.         0.         0.         0.         0.         0.1446211
  0.         0.8553789  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.076546
  0.         0.         0.923454   0.         0.         0.        ]
 [0.         0.         0.4672609  0.         0.         0.
  0.         0.         0.         0.         0.         0.5327391 ]
 [0.         0.5135179  0.  

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.         0.         0.         0.8935032
  0.         0.         0.1064968  0.         0.         0.        ]
 [0.         0.         0.         0.3629929  0.         0.
  0.         0.6370071  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.48117366 0.         0.         0.5188263  0.         0.        ]
 [0.         0.15404594 0.         0.         0.84595406 0.
  0.         0.         0.         0.         0.         0.        ]
 [0.87910587 0.         0.         0.         0.         0.12089416
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.24323587
  0.         0.         0.         0.         0.         0.7567641 ]
 [0.         0.         0.89934015 0.         0.         0.10065983
  0.         0.         0.         0.         0.         0.        ]
 [0.6479698  0.   

total sounds is 32
labels in generate_batch is:
[[0.72345096 0.         0.         0.27654907 0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.10695365 0.         0.         0.8930463  0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.34047976 0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.6595202 ]
 [0.         0.         0.         0.8119993  0.         0.
  0.1880007  0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.00305739 0.         0.
  0.         0.         0.9969426  0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.14069787 0.         0.85930216 0.         0.        ]
 [0.652157   0.         0.         0.         0.         0.
  0.347843   0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.7303605  0. 

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.         0.         0.193295   0.
  0.         0.         0.         0.         0.         0.806705  ]
 [0.         0.         0.         0.         0.         0.
  0.         0.45562208 0.         0.         0.         0.5443779 ]
 [0.         0.5950055  0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.40499452]
 [0.0818696  0.         0.         0.9181304  0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.7202861  0.27971393
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.33458847 0.66541153 0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.3116033  0.         0.         0.
  0.         0.         0.         0.6883967  0.         0.        ]
 [0.         0.         0.         0.    

total sounds is 32
labels in generate_batch is:
[[9.1564751e-01 0.0000000e+00 8.4352508e-02 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 9.9215078e-01 7.8491988e-03 0.0000000e+00
  0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00]
 [8.9509910e-01 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  1.0490092e-01 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 9.2724675e-01 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  7.2753280e-02 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 5.5752641e-01
  0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 4.4247359e-01
  0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 1.3192235e-01 8.6807764e-01
  0.0

total sounds is 32
labels in generate_batch is:
[[0.         0.5479795  0.45202053 0.         0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.68601453 0.         0.         0.31398547 0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.6933196  0.30668038 0.         0.        ]
 [0.         0.         0.         0.         0.         0.9243026
  0.         0.         0.         0.07569742 0.         0.        ]
 [0.64540684 0.         0.         0.35459316 0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.04590581 0.
  0.         0.         0.         0.         0.         0.9540942 ]
 [0.         0.2207795  0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.7792205 ]
 [0.         0.         0.         0.     

total sounds is 32
labels in generate_batch is:
[[0.         0.899587   0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.10041301]
 [0.         0.         0.7228653  0.         0.         0.
  0.         0.         0.2771347  0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.3981582  0.         0.         0.         0.6018418 ]
 [0.3763146  0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.6236854 ]
 [0.         0.         0.         0.1516894  0.         0.
  0.         0.         0.         0.         0.         0.8483106 ]
 [0.         0.         0.5799974  0.         0.         0.
  0.         0.         0.         0.4200026  0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.03233874 0.         0.         0.         0.         0.96766126]
 [0.         0.         0.         0.61408025 0.3

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.92415375 0.         0.         0.
  0.         0.         0.         0.07584624 0.         0.        ]
 [0.         0.         0.         0.98509717 0.         0.
  0.         0.         0.         0.01490283 0.         0.        ]
 [0.         0.91004074 0.         0.         0.         0.
  0.         0.         0.         0.08995925 0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.18217014 0.81782985 0.         0.        ]
 [0.         0.71063703 0.         0.         0.         0.289363
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.7861688  0.         0.
  0.         0.         0.         0.         0.         0.21383122]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.3163922  0.         0.68360776]
 [0.08400071 0.9159993  0.         0.      

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.         0.         0.         0.
  0.         0.         0.571193   0.         0.42880705 0.        ]
 [0.7372321  0.         0.         0.         0.         0.
  0.         0.         0.         0.26276794 0.         0.        ]
 [0.         0.         0.         0.6622124  0.         0.
  0.         0.         0.33778763 0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.54152375
  0.         0.         0.         0.         0.         0.45847628]
 [0.         0.         0.8284423  0.         0.         0.17155772
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.28586906 0.         0.71413094
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.29754356 0.7024565  0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0

shape of y_pred:torch.Size([352, 12])
shape of y_target:torch.Size([352, 12])
__save_model is called
current best Acc is 16.47727394104004
pass in acc is 17.89772605895996
model saved....., acc: 17.89772605895996
SP-[1, 2, 3, 4, 5] Epoch: 7/10 | Time: 0m24s (Train 0m19s  Val 0m05s) | Train: LR 0.010000000000000002  Loss 1.99  Acc 13.35% | Val: Loss nan  Acc(top1) 17.90% | HA 17.90@7
total sounds is 32
labels in generate_batch is:
[[0.         0.         0.81169975 0.         0.         0.
  0.         0.18830028 0.         0.         0.         0.        ]
 [0.         0.62929976 0.         0.         0.3707002  0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.32548365 0.         0.         0.
  0.6745164  0.         0.         0.         0.         0.        ]
 [0.06571258 0.         0.9342874  0.         0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.03234045 0.  

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.8050972  0.         0.         0.
  0.19490281 0.         0.         0.         0.         0.        ]
 [0.         0.         0.6632857  0.3367143  0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.78143096 0.         0.21856906 0.         0.         0.        ]
 [0.         0.         0.         0.         0.6185709  0.
  0.         0.         0.         0.3814291  0.         0.        ]
 [0.43815416 0.         0.56184584 0.         0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.94155467 0.         0.
  0.         0.05844535 0.         0.         0.         0.        ]
 [0.5778757  0.         0.         0.         0.         0.4221243
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.     

total sounds is 32
labels in generate_batch is:
[[0.0000000e+00 0.0000000e+00 2.9183542e-02 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 9.7081643e-01]
 [0.0000000e+00 0.0000000e+00 3.2536682e-01 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00 6.7463315e-01 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00]
 [6.5064263e-01 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  3.4935737e-01 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  4.7735333e-01 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 5.2264667e-01]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 5.5540532e-01 4.4459468e-01
  0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.         0.         0.         0.
  0.98607147 0.         0.         0.0139285  0.         0.        ]
 [0.         0.2621115  0.         0.         0.         0.
  0.         0.         0.         0.7378885  0.         0.        ]
 [0.         0.85219306 0.14780696 0.         0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.656993   0.         0.         0.         0.         0.34300706]
 [0.2311978  0.         0.         0.         0.         0.
  0.         0.         0.         0.7688022  0.         0.        ]
 [0.         0.         0.87813944 0.12186054 0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.01524113 0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.98475885]
 [0.         0.         0.         0.         0.4

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.         0.         0.8827316  0.
  0.11726835 0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.83182746 0.         0.         0.16817254 0.         0.        ]
 [0.         0.         0.         0.9026624  0.         0.
  0.         0.         0.09733763 0.         0.         0.        ]
 [0.         0.         0.         0.8887701  0.         0.1112299
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.20167878 0.         0.79832125]
 [0.         0.13886371 0.         0.         0.         0.86113626
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.74001265 0.         0.         0.         0.25998735]
 [0.05719725 0.         0.        

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.         0.         0.         0.
  0.69698143 0.         0.         0.         0.         0.30301854]
 [0.         0.17654857 0.82345146 0.         0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.9792251  0.         0.         0.         0.         0.
  0.         0.         0.02077487 0.         0.         0.        ]
 [0.         0.         0.         0.29718372 0.         0.
  0.         0.         0.         0.70281625 0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.14472054 0.         0.85527945]
 [0.19063109 0.         0.         0.8093689  0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.91159195 0.         0.         0.         0.08840804 0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0. 

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.         0.57427657 0.         0.42572343
  0.         0.         0.         0.         0.         0.        ]
 [0.09761424 0.         0.         0.9023858  0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.35383287 0.         0.6461671  0.
  0.         0.         0.         0.         0.         0.        ]
 [0.10036752 0.         0.         0.         0.8996325  0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.51488554 0.48511446 0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.30757585 0.         0.         0.
  0.         0.         0.         0.         0.         0.6924241 ]
 [0.         0.797895   0.         0.         0.         0.
  0.         0.         0.20210499 0.         0.         0.        ]
 [0.         0.5168975  0.         0.    

total sounds is 32
labels in generate_batch is:
[[0.44964117 0.55035883 0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.8101491  0.         0.         0.
  0.         0.         0.         0.         0.         0.18985093]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.25750086 0.         0.7424992 ]
 [0.17797396 0.         0.         0.822026   0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.61363536 0.         0.         0.         0.
  0.         0.         0.         0.38636464 0.         0.        ]
 [0.         0.         0.         0.         0.42010164 0.
  0.         0.         0.         0.         0.         0.57989836]
 [0.92100805 0.07899193 0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.95361865 0.         0.         0. 

total sounds is 32
labels in generate_batch is:
[[0.09054302 0.         0.90945697 0.         0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.3425458  0.
  0.6574542  0.         0.         0.         0.         0.        ]
 [0.49686244 0.         0.         0.5031375  0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.5583891  0.         0.         0.         0.         0.
  0.44161084 0.         0.         0.         0.         0.        ]
 [0.09265643 0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.90734357]
 [0.         0.         0.         0.         0.18924312 0.
  0.         0.         0.         0.         0.         0.81075686]
 [0.         0.34291023 0.         0.         0.         0.65708977
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.9676514  0.0323

total sounds is 32
labels in generate_batch is:
[[0.         0.43761513 0.         0.         0.         0.
  0.         0.         0.         0.5623849  0.         0.        ]
 [0.5630626  0.         0.         0.4369374  0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.17262499 0.         0.
  0.         0.         0.         0.         0.         0.827375  ]
 [0.         0.         0.50830483 0.         0.         0.
  0.         0.         0.         0.         0.         0.49169517]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.5258596  0.4741404  0.         0.        ]
 [0.9326075  0.         0.         0.         0.         0.
  0.         0.         0.06739254 0.         0.         0.        ]
 [0.         0.42446545 0.         0.         0.         0.5755345
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.     

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.         0.29192254 0.         0.
  0.         0.         0.         0.         0.         0.7080775 ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.81940275 0.         0.         0.18059723]
 [0.         0.         0.         0.         0.386844   0.
  0.61315596 0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.3660515  0.6339485  0.         0.        ]
 [0.         0.         0.23352143 0.         0.         0.
  0.         0.         0.         0.         0.         0.7664786 ]
 [0.         0.         0.         0.66857326 0.         0.
  0.         0.         0.         0.33142674 0.         0.        ]
 [0.         0.         0.         0.         0.09047023 0.
  0.         0.         0.         0.         0.         0.90952975]
 [0.         0.7525611  0.         0.         0. 

shape of y_pred:torch.Size([352, 12])
shape of y_target:torch.Size([352, 12])
__save_model is called
current best Acc is 18.75
pass in acc is 20.738636016845703
model saved....., acc: 20.738636016845703
SP-[1, 2, 3, 4, 5] Epoch: 9/10 | Time: 0m24s (Train 0m18s  Val 0m05s) | Train: LR 0.010000000000000002  Loss 1.92  Acc 17.90% | Val: Loss nan  Acc(top1) 20.74% | HA 20.74@9
total sounds is 32
labels in generate_batch is:
[[0.8692013  0.13079873 0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.18548176 0.         0.8145182  0.         0.         0.
  0.         0.         0.         0.         0.         0.        ]
 [0.20800643 0.         0.         0.         0.         0.
  0.         0.         0.         0.79199356 0.         0.        ]
 [0.         0.78299725 0.         0.         0.21700278 0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.6328347  0.         0.36716536 0.


total sounds is 32
labels in generate_batch is:
[[0.         0.         0.         0.         0.709438   0.
  0.         0.         0.         0.29056194 0.         0.        ]
 [0.         0.         0.         0.45853588 0.         0.5414641
  0.         0.         0.         0.         0.         0.        ]
 [0.9850537  0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.01494631]
 [0.         0.         0.         0.         0.         0.6205826
  0.         0.3794174  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.35996407 0.
  0.         0.         0.         0.         0.         0.6400359 ]
 [0.         0.         0.8281393  0.         0.         0.
  0.         0.         0.         0.17186067 0.         0.        ]
 [0.         0.         0.51462704 0.         0.         0.48537293
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.5

total sounds is 32
labels in generate_batch is:
[[0.20737937 0.         0.         0.         0.79262066 0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.70558923
  0.         0.         0.         0.         0.         0.29441074]
 [0.         0.         0.         0.7704637  0.         0.22953628
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.09457846 0.         0.90542156 0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.92903054
  0.         0.         0.         0.07096946 0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.20239377 0.         0.         0.7976062 ]
 [0.         0.         0.         0.         0.         0.5700025
  0.         0.         0.         0.4299975  0.         0.        ]
 [0.         0.   

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.         0.         0.         0.
  0.         0.         0.90469855 0.         0.         0.09530147]
 [0.         0.         0.05480929 0.         0.         0.
  0.9451907  0.         0.         0.         0.         0.        ]
 [0.9079783  0.         0.         0.         0.0920217  0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.25571218
  0.         0.         0.         0.         0.         0.74428785]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.06025572 0.9397443  0.         0.        ]
 [0.         0.33734    0.         0.         0.         0.66266
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.43843567 0.         0.
  0.         0.         0.         0.5615643  0.         0.        ]
 [0.         0.         0.         0

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.         0.         0.         0.
  0.20747323 0.         0.         0.7925268  0.         0.        ]
 [0.         0.         0.         0.         0.7825846  0.
  0.         0.         0.2174154  0.         0.         0.        ]
 [0.         0.         0.         0.         0.88574636 0.11425365
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.9283816  0.         0.
  0.07161838 0.         0.         0.         0.         0.        ]
 [0.         0.5113587  0.         0.         0.4886413  0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.26193187
  0.         0.         0.         0.7380681  0.         0.        ]
 [0.         0.         0.         0.59315526 0.         0.40684476
  0.         0.         0.         0.         0.         0.        ]
 [0.00298991 0.         0

total sounds is 32
labels in generate_batch is:
[[0.         0.         0.         0.         0.         0.
  0.7417859  0.         0.         0.25821412 0.         0.        ]
 [0.         0.58424205 0.         0.         0.         0.
  0.         0.41575792 0.         0.         0.         0.        ]
 [0.         0.         0.         0.2148928  0.         0.
  0.         0.         0.         0.         0.         0.7851072 ]
 [0.         0.5678649  0.         0.         0.         0.
  0.         0.4321351  0.         0.         0.         0.        ]
 [0.         0.31377503 0.         0.         0.         0.686225
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.7159134  0.28408659 0.
  0.         0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.01864655
  0.         0.         0.         0.         0.         0.98135346]
 [0.57661414 0.         0.         