In [1]:
import numpy as np
import pandas as pd
import sklearn
import torch
import librosa
import librosa.display
import torchaudio
import os
import random

In [2]:
# load metadata
metadata = pd.read_csv("./information.csv")

metadata.head()

Unnamed: 0.1,Unnamed: 0,fname,directory,model,label
0,0,2022-05-14_06-18-26.wav,./dataset/big_fast/,X8SW,1
1,1,2022-05-14_06-18-36.wav,./dataset/big_fast/,X8SW,1
2,2,2022-05-14_06-18-47.wav,./dataset/big_fast/,X8SW,1
3,3,2022-05-14_06-18-58.wav,./dataset/big_fast/,X8SW,1
4,4,2022-05-14_06-19-08.wav,./dataset/big_fast/,X8SW,1


In [3]:
class AudioUtil():
    def open(audio_file):
        y, sr = torchaudio.load(audio_file)
        return y, sr

    # data augmentation function
    def time_shift(aud, shift_limit):
        y, sr = aud
        _, sig_len = y.shape
        shift_amt = int(random.random() * shift_limit * sig_len)
        return y.roll(shift_amt), sr
    
    def MFCCs(y, sr):
        y = y.cpu().detach().numpy()
        mfccs = librosa.feature.mfcc(y=y, sr=sr)
        mfcc_scaled = np.mean(mfccs.T, axis=0)
        return mfcc_scaled

In [4]:
from torch.utils.data import DataLoader, Dataset, random_split
import torchaudio

In [5]:
class CustomDataset(Dataset):
    def __init__(self, root, label):
        # file root
        self.root = root
        # slow = 0, fast = 1
        self.label = label

        fs = [os.path.join(root, f) for f in os.listdir(self.root)]
        # all file path
        self.data_files = [f for f in fs if os.path.isfile(f)]
        self.label = [label] * len(self.data_files)
    
    # __len__
    def __len__(self):
        return len(self.data_files)
    
    def __getitem__(self, idx):
        y, sr = AudioUtil.open(self.data_files[idx])
        mfcc = AudioUtil.MFCCs(y, sr)
        return mfcc, torch.tensor(self.label[idx])

In [6]:
# file path
big_fast_path = "./dataset/big_fast/"
big_slow_path = "./dataset/big_slow/"

In [7]:
slow_dataset = CustomDataset(big_slow_path, label = 0)
fast_dataset = CustomDataset(big_fast_path, label = 1)

slow_train, slow_valid, slow_test = torch.utils.data.random_split(slow_dataset,
[int(len(slow_dataset)*0.8), int(len(slow_dataset)*0.1), len(slow_dataset) - int(len(slow_dataset) * 0.8) - int(len(slow_dataset) * 0.1)],
generator=torch.Generator().manual_seed(42))

fast_train, fast_valid, fast_test = torch.utils.data.random_split(fast_dataset,
[int(len(fast_dataset)*0.8), int(len(fast_dataset)*0.1), len(fast_dataset) - int(len(fast_dataset) * 0.8) - int(len(fast_dataset) * 0.1)],
generator=torch.Generator().manual_seed(42))

In [8]:
train_dataset = torch.utils.data.ConcatDataset([slow_train, fast_train])
val_dataset = torch.utils.data.ConcatDataset([slow_valid, fast_valid])
test_dataset = torch.utils.data.ConcatDataset([slow_test, fast_test])

In [9]:
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)
val_loader = DataLoader(val_dataset, batch_size=16)

In [10]:
import torch.nn as nn

In [11]:
# example
class ClassifireNN(nn.Module):
    def __init__(self, drop_out=0.0):
        super(ClassifireNN, self).__init__()
        # self.cnn1 = nn.Conv1d(in_channels=20, out_channels=32, kernel_size=5, padding=2)
        # self.cnn2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
        # self.cnn3 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=5, padding=2)

        # self.pool1 = nn.MaxPool1d(4)
        # self.pool2 = nn.MaxPool1d(5)
        # self.pool3 = nn.MaxPool1d(5)

        self.fc1 = nn.Linear(2 * 20, 32)
        self.fc2 = nn.Linear(32, 16)
        self.fc3 = nn.Linear(16, 8)
        self.fc4 = nn.Linear(8, 4)
        self.fc5 = nn.Linear(4, 1)
        
        self.relu = nn.ReLU()

        self.drop_out = nn.Dropout(p=drop_out)

    def forward(self, x):
        # torch.Size([16, 20, 2]) [batch, feature, channel]

        # x = self.relu(self.cnn1(x))
        # # x = self.pool1(x)
        # x = self.relu(self.cnn2(x))
        # # x = self.pool2(x)
        # x = self.relu(self.cnn3(x))      
        # # x = self.pool3(x)
        # print(x.shape)
        # asdf
        
        x = x.view(-1, 20 * 2)  

        x = self.relu(self.fc1(x))
        x = self.drop_out(x)
        x = self.relu(self.fc2(x))
        x = self.drop_out(x)
        x = self.relu(self.fc3(x))
        x = self.drop_out(x)
        x = self.relu(self.fc4(x))
        x = self.fc5(x)

        x = torch.sigmoid(x)

        return x.view(-1)
        

In [12]:
device = torch.device(f'cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [13]:
LR = 0.0001
PATIENCE = 3
FACTOR = 0.95
DROP_OUT = 0.3
EPOCHS = 100

In [14]:
model = ClassifireNN(drop_out=DROP_OUT).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = nn.BCELoss()

In [15]:
best_auc = 0
best_epoch = -1
best_pred = []

prev_model = None

In [16]:
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from torch.utils.tensorboard import SummaryWriter

In [17]:
wirter = SummaryWriter()

In [36]:
for i in tqdm(range(EPOCHS)):

    # Train
    loss_sum = 0
    true_labels = []
    pred_labels = []
    model.train()

    for e_num, (x, y) in enumerate(train_loader):

        x, y = x.type(torch.FloatTensor).to(device), y.type(torch.FloatTensor).to(device)
        
        model.zero_grad()
        pred_y = model(x)

        loss = criterion(pred_y, y)
        loss_sum += loss.detach()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        true_labels.extend(y.cpu().numpy())
        pred_labels.extend(np.around(pred_y.cpu().detach().numpy()))

    auc = accuracy_score(true_labels, pred_labels)

    # Valid
    for e_num, (x, y) in enumerate(val_loader):
        x, y = x.type(torch.FloatTensor).to(device), y.type(torch.FloatTensor).to(device)

        pred_y = model(x)
        loss = criterion(pred_y, y)

        loss_sum += loss.detach()

        true_labels.extend(y.cpu().numpy())
        pred_labels.extend(np.around(pred_y.cpu().detach().numpy()))

    auc = accuracy_score(true_labels, pred_labels)
    
    # wirter.add_scalar("")

    if auc > best_auc:
        best_pred = pred_labels
        best_auc = auc
        best_epoch = i

        if prev_model is not None:
            os.remove(prev_model)
        prev_model = f'cnn_model_{best_auc}.h5'
        torch.save(model.state_dict(), prev_model)

print(f"best validation acc = {best_auc}, in epoch {best_epoch}")

  1%|          | 1/100 [00:28<47:33, 28.82s/it]

auc: 0.46255506607929514


  2%|▏         | 2/100 [00:56<45:30, 27.86s/it]

auc: 0.4977973568281938


  3%|▎         | 3/100 [01:25<46:13, 28.59s/it]

auc: 0.5506607929515418


  4%|▍         | 4/100 [01:53<45:22, 28.36s/it]

auc: 0.5682819383259912


  5%|▌         | 5/100 [02:22<45:07, 28.50s/it]

auc: 0.5550660792951542


  6%|▌         | 6/100 [02:52<45:46, 29.22s/it]

auc: 0.5330396475770925


  7%|▋         | 7/100 [03:18<43:39, 28.17s/it]

auc: 0.5330396475770925


  8%|▊         | 8/100 [03:50<44:59, 29.34s/it]

auc: 0.5726872246696035


  9%|▉         | 9/100 [04:16<42:51, 28.25s/it]

auc: 0.5814977973568282


 10%|█         | 10/100 [04:45<42:29, 28.33s/it]

auc: 0.5154185022026432


 11%|█         | 11/100 [05:13<41:57, 28.29s/it]

auc: 0.6079295154185022


 12%|█▏        | 12/100 [05:40<41:07, 28.04s/it]

auc: 0.5903083700440529


 13%|█▎        | 13/100 [06:09<40:52, 28.19s/it]

auc: 0.5947136563876652


 14%|█▍        | 14/100 [06:34<39:05, 27.27s/it]

auc: 0.5991189427312775


 15%|█▌        | 15/100 [07:18<45:55, 32.41s/it]

auc: 0.6167400881057269


 16%|█▌        | 16/100 [07:48<44:24, 31.73s/it]

auc: 0.5903083700440529


 17%|█▋        | 17/100 [08:17<42:27, 30.69s/it]

auc: 0.6431718061674009


 18%|█▊        | 18/100 [08:45<40:56, 29.96s/it]

auc: 0.6431718061674009


 19%|█▉        | 19/100 [09:12<39:11, 29.03s/it]

auc: 0.6167400881057269


 20%|██        | 20/100 [09:39<38:09, 28.62s/it]

auc: 0.6563876651982379


 21%|██        | 21/100 [10:07<37:13, 28.27s/it]

auc: 0.6607929515418502


 22%|██▏       | 22/100 [10:35<36:29, 28.07s/it]

auc: 0.6828193832599119


 23%|██▎       | 23/100 [11:02<35:50, 27.93s/it]

auc: 0.6343612334801763


 24%|██▍       | 24/100 [11:29<34:56, 27.59s/it]

auc: 0.6784140969162996


 25%|██▌       | 25/100 [11:57<34:35, 27.68s/it]

auc: 0.6431718061674009


 26%|██▌       | 26/100 [12:24<33:57, 27.53s/it]

auc: 0.6343612334801763


 27%|██▋       | 27/100 [12:52<33:40, 27.68s/it]

auc: 0.6651982378854625


 28%|██▊       | 28/100 [13:19<32:54, 27.43s/it]

auc: 0.6696035242290749


 29%|██▉       | 29/100 [13:47<32:40, 27.61s/it]

auc: 0.6696035242290749


 30%|███       | 30/100 [14:14<31:58, 27.40s/it]

auc: 0.6563876651982379


 31%|███       | 31/100 [14:41<31:24, 27.31s/it]

auc: 0.6784140969162996


 32%|███▏      | 32/100 [15:09<31:19, 27.64s/it]

auc: 0.6696035242290749


 33%|███▎      | 33/100 [15:37<30:59, 27.76s/it]

auc: 0.6916299559471366


 34%|███▍      | 34/100 [16:05<30:35, 27.81s/it]

auc: 0.6740088105726872


 35%|███▌      | 35/100 [16:32<29:40, 27.39s/it]

auc: 0.6916299559471366


 36%|███▌      | 36/100 [17:00<29:25, 27.58s/it]

auc: 0.7400881057268722


 37%|███▋      | 37/100 [17:27<28:47, 27.42s/it]

auc: 0.762114537444934


 38%|███▊      | 38/100 [17:55<28:37, 27.71s/it]

auc: 0.7004405286343612


 39%|███▉      | 39/100 [18:22<27:57, 27.49s/it]

auc: 0.6916299559471366


 40%|████      | 40/100 [18:50<27:32, 27.55s/it]

auc: 0.7577092511013216


 41%|████      | 41/100 [19:19<27:41, 28.16s/it]

auc: 0.7268722466960352


 42%|████▏     | 42/100 [19:46<26:53, 27.83s/it]

auc: 0.7577092511013216


 43%|████▎     | 43/100 [20:14<26:24, 27.80s/it]

auc: 0.7136563876651982


 44%|████▍     | 44/100 [20:42<25:52, 27.73s/it]

auc: 0.7444933920704846


 45%|████▌     | 45/100 [21:09<25:24, 27.73s/it]

auc: 0.788546255506608


 46%|████▌     | 46/100 [21:38<25:07, 27.92s/it]

auc: 0.7312775330396476


 47%|████▋     | 47/100 [22:06<24:39, 27.91s/it]

auc: 0.7400881057268722


 48%|████▊     | 48/100 [22:32<23:47, 27.44s/it]

auc: 0.7312775330396476


 49%|████▉     | 49/100 [23:01<23:35, 27.76s/it]

auc: 0.7444933920704846


 50%|█████     | 50/100 [23:27<22:50, 27.42s/it]

auc: 0.7444933920704846


 51%|█████     | 51/100 [23:55<22:33, 27.62s/it]

auc: 0.7797356828193832


 52%|█████▏    | 52/100 [24:22<21:54, 27.39s/it]

auc: 0.7577092511013216


 53%|█████▎    | 53/100 [24:50<21:40, 27.66s/it]

auc: 0.7092511013215859


 54%|█████▍    | 54/100 [25:18<21:14, 27.70s/it]

auc: 0.7841409691629956


 55%|█████▌    | 55/100 [25:45<20:38, 27.53s/it]

auc: 0.7841409691629956


 56%|█████▌    | 56/100 [26:13<20:12, 27.56s/it]

auc: 0.7797356828193832


 57%|█████▋    | 57/100 [26:40<19:40, 27.45s/it]

auc: 0.788546255506608


 58%|█████▊    | 58/100 [27:08<19:14, 27.50s/it]

auc: 0.775330396475771


 59%|█████▉    | 59/100 [27:36<19:00, 27.81s/it]

auc: 0.7841409691629956


 60%|██████    | 60/100 [28:02<18:12, 27.31s/it]

auc: 0.7797356828193832


 61%|██████    | 61/100 [28:29<17:40, 27.19s/it]

auc: 0.7973568281938326


 62%|██████▏   | 62/100 [28:58<17:28, 27.59s/it]

auc: 0.8325991189427313


 63%|██████▎   | 63/100 [29:26<17:07, 27.77s/it]

auc: 0.8414096916299559


 64%|██████▍   | 64/100 [29:53<16:25, 27.38s/it]

auc: 0.8149779735682819


 65%|██████▌   | 65/100 [30:21<16:06, 27.62s/it]

auc: 0.8458149779735683


 66%|██████▌   | 66/100 [30:48<15:32, 27.42s/it]

auc: 0.8237885462555066


 67%|██████▋   | 67/100 [31:15<15:08, 27.52s/it]

auc: 0.8149779735682819


 68%|██████▊   | 68/100 [31:43<14:37, 27.41s/it]

auc: 0.8281938325991189


 69%|██████▉   | 69/100 [32:11<14:19, 27.71s/it]

auc: 0.8546255506607929


 70%|███████   | 70/100 [32:38<13:40, 27.34s/it]

auc: 0.801762114537445


 71%|███████   | 71/100 [33:06<13:18, 27.54s/it]

auc: 0.8590308370044053


 72%|███████▏  | 72/100 [33:33<12:50, 27.52s/it]

auc: 0.8546255506607929


 73%|███████▎  | 73/100 [34:01<12:23, 27.52s/it]

auc: 0.9030837004405287


 74%|███████▍  | 74/100 [34:27<11:50, 27.31s/it]

auc: 0.8634361233480177


 75%|███████▌  | 75/100 [34:55<11:23, 27.34s/it]

auc: 0.8634361233480177


 76%|███████▌  | 76/100 [35:23<11:02, 27.62s/it]

auc: 0.8370044052863436


 77%|███████▋  | 77/100 [35:49<10:24, 27.16s/it]

auc: 0.9074889867841409


 78%|███████▊  | 78/100 [36:17<10:03, 27.43s/it]

auc: 0.8810572687224669


 79%|███████▉  | 79/100 [36:46<09:42, 27.72s/it]

auc: 0.8634361233480177


 80%|████████  | 80/100 [37:12<09:07, 27.39s/it]

auc: 0.8766519823788547


 81%|████████  | 81/100 [37:40<08:40, 27.38s/it]

auc: 0.8502202643171806


 82%|████████▏ | 82/100 [38:08<08:17, 27.63s/it]

auc: 0.8810572687224669


 83%|████████▎ | 83/100 [38:35<07:45, 27.37s/it]

auc: 0.8810572687224669


 84%|████████▍ | 84/100 [39:03<07:21, 27.62s/it]

auc: 0.8854625550660793


 85%|████████▌ | 85/100 [39:31<06:55, 27.67s/it]

auc: 0.8766519823788547


 86%|████████▌ | 86/100 [39:57<06:24, 27.45s/it]

auc: 0.8854625550660793


 87%|████████▋ | 87/100 [40:26<06:02, 27.89s/it]

auc: 0.8942731277533039


 88%|████████▊ | 88/100 [40:54<05:35, 27.95s/it]

auc: 0.8854625550660793


 89%|████████▉ | 89/100 [41:21<05:03, 27.55s/it]

auc: 0.8898678414096917


 90%|█████████ | 90/100 [41:49<04:36, 27.68s/it]

auc: 0.8898678414096917


 91%|█████████ | 91/100 [42:16<04:06, 27.42s/it]

auc: 0.8986784140969163


 92%|█████████▏| 92/100 [42:44<03:41, 27.70s/it]

auc: 0.8986784140969163


 93%|█████████▎| 93/100 [43:11<03:10, 27.27s/it]

auc: 0.933920704845815


 94%|█████████▍| 94/100 [43:39<02:45, 27.56s/it]

auc: 0.9383259911894273


 95%|█████████▌| 95/100 [44:06<02:16, 27.36s/it]

auc: 0.9162995594713657


 96%|█████████▌| 96/100 [44:33<01:49, 27.43s/it]

auc: 0.9118942731277533


 97%|█████████▋| 97/100 [45:01<01:22, 27.49s/it]

auc: 0.8942731277533039


 98%|█████████▊| 98/100 [45:29<00:55, 27.80s/it]

auc: 0.933920704845815


 99%|█████████▉| 99/100 [45:56<00:27, 27.58s/it]

auc: 0.9162995594713657


100%|██████████| 100/100 [46:25<00:00, 27.85s/it]

auc: 0.9074889867841409
best validation acc = 0.9383259911894273, in epoch 93



