<a href="https://colab.research.google.com/github/zahra-zarrabi/Audio_Classification_CNN/blob/main/torch_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
import torchaudio
import numpy as np
from tqdm import tqdm

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")

Using cuda


In [3]:
cd "/content/drive/MyDrive/audio_classification"

/content/drive/MyDrive/audio_classification


In [4]:
epochs = 60
lr = 0.001
batch_size = 8

In [5]:
class AudioDataset(Dataset):
    def __init__(self, root):
        self.dir_path = root
        self.classes = os.listdir(self.dir_path)

        self.data_paths = []
        self.labels = []

        for root, dirs, files in os.walk(self.dir_path):
            for file in files:
                label = os.path.basename(root)
                data_path = os.path.join(root, file)
                self.data_paths.append(data_path)
                self.labels.append(self.classes.index(label))

        print("classes: ", self.classes)
        print(f"{len(self.labels)} datas loaded from {len(set(self.labels))} classes")
        
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        data_path = self.data_paths[index]
        label = self.labels[index]

        signal, sample_rate = torchaudio.load(data_path)
        signal = torch.mean(signal, dim=0, keepdim=True)
       
        transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=8000)
        signal_transformed = transform(signal)

        return signal_transformed, label

In [6]:
dataset = AudioDataset('dataset')

classes:  ['mohammadali', 'morteza', 'zeynab', 'alireza', 'maryam', 'nahid', 'parisa', 'zahra', 'sajjad', 'hosein', 'amir']
1249 datas loaded from 11 classes


In [11]:
len(os.listdir('dataset/zeynab'))

99

In [None]:
dataset.data_paths

In [7]:
# Data loader
train_size = int(len(dataset)*0.8)
test_size = len(dataset) - train_size

train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
train_data_loader = torch.utils.data.DataLoader(train_dataset, 
                                                batch_size=batch_size, 
                                                shuffle=True)

test_data_loader = torch.utils.data.DataLoader(test_dataset, 
                                                batch_size=batch_size, 
                                                shuffle=False)

In [8]:
class M5(nn.Module):
    def __init__(self, n_input=1, n_output=11, stride=16, n_channel=32):
        super().__init__()
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.pool1 = nn.MaxPool1d(4)
        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.pool2 = nn.MaxPool1d(4)
        self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
        self.bn3 = nn.BatchNorm1d(2 * n_channel)
        self.pool3 = nn.MaxPool1d(4)
        self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.bn4 = nn.BatchNorm1d(2 * n_channel)
        self.pool4 = nn.MaxPool1d(4)
        self.fc1 = nn.Linear(2 * n_channel, n_output)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.pool1(x)
        x = self.conv2(x)
        x = F.relu(self.bn2(x))
        x = self.pool2(x)
        x = self.conv3(x)
        x = F.relu(self.bn3(x))
        x = self.pool3(x)
        x = self.conv4(x)
        x = F.relu(self.bn4(x))
        x = self.pool4(x)
        x = F.avg_pool1d(x, x.shape[-1])
        x = torch.flatten(x, start_dim=1)
        x = self.fc1(x)
        x = F.softmax(x, dim=1)
        return x

    def accuracy(self, preds, labels):
        maxs, indices = torch.max(preds, 1)
        acc = torch.sum(indices == labels) / len(preds)
        return acc.cpu()

In [9]:
model = M5(n_output=11).to(device)
print(model)

# count_parameters
n = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of parameters: %s" % n)

M5(
  (conv1): Conv1d(1, 32, kernel_size=(80,), stride=(16,))
  (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(32, 32, kernel_size=(3,), stride=(1,))
  (bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool2): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv1d(32, 64, kernel_size=(3,), stride=(1,))
  (bn3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool3): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (conv4): Conv1d(64, 64, kernel_size=(3,), stride=(1,))
  (bn4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool4): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=64, out_features=11, bias=True)
)
Numbe

In [10]:
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_function = nn.CrossEntropyLoss()

In [11]:
# train
model.train()

for epoch in range(epochs):
    train_loss = 0.0
    train_acc = 0.0
    for audios, labels in tqdm(train_data_loader):
        audios, labels = audios.to(device), labels.to(device)
        labels_one_hot = torch.nn.functional.one_hot(labels, num_classes=11).type(torch.FloatTensor).to(device)

        preds = model(audios)
        loss = loss_function(preds, labels_one_hot)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss
        train_acc += model.accuracy(preds, labels)
    
    total_loss = train_loss / len(train_data_loader)
    total_acc = train_acc / len(train_data_loader)

    print(f"Epoch: {epoch}, Loss: {total_loss}, Acc: {total_acc}")

100%|██████████| 125/125 [13:42<00:00,  6.58s/it]


Epoch: 0, Loss: 2.300555467605591, Acc: 0.2774285674095154


100%|██████████| 125/125 [00:03<00:00, 32.94it/s]


Epoch: 1, Loss: 2.1951823234558105, Acc: 0.35757145285606384


100%|██████████| 125/125 [00:03<00:00, 34.37it/s]


Epoch: 2, Loss: 2.138993263244629, Acc: 0.40542855858802795


100%|██████████| 125/125 [00:03<00:00, 33.39it/s]


Epoch: 3, Loss: 2.0759379863739014, Acc: 0.5035714507102966


100%|██████████| 125/125 [00:03<00:00, 33.53it/s]


Epoch: 4, Loss: 2.0400569438934326, Acc: 0.5375714302062988


100%|██████████| 125/125 [00:03<00:00, 33.94it/s]


Epoch: 5, Loss: 1.9961564540863037, Acc: 0.5737143158912659


100%|██████████| 125/125 [00:03<00:00, 33.80it/s]


Epoch: 6, Loss: 1.978256344795227, Acc: 0.5907142758369446


100%|██████████| 125/125 [00:03<00:00, 34.05it/s]


Epoch: 7, Loss: 1.951163649559021, Acc: 0.6148571372032166


100%|██████████| 125/125 [00:03<00:00, 33.60it/s]


Epoch: 8, Loss: 1.9234580993652344, Acc: 0.6467142701148987


100%|██████████| 125/125 [00:03<00:00, 33.68it/s]


Epoch: 9, Loss: 1.9086214303970337, Acc: 0.6578571200370789


100%|██████████| 125/125 [00:03<00:00, 32.45it/s]


Epoch: 10, Loss: 1.8784090280532837, Acc: 0.6954286098480225


100%|██████████| 125/125 [00:04<00:00, 29.89it/s]


Epoch: 11, Loss: 1.8719780445098877, Acc: 0.6875714063644409


100%|██████████| 125/125 [00:03<00:00, 33.97it/s]


Epoch: 12, Loss: 1.8553218841552734, Acc: 0.7055714130401611


100%|██████████| 125/125 [00:03<00:00, 33.14it/s]


Epoch: 13, Loss: 1.8363265991210938, Acc: 0.7247142791748047


100%|██████████| 125/125 [00:03<00:00, 32.82it/s]


Epoch: 14, Loss: 1.8305360078811646, Acc: 0.7295714020729065


100%|██████████| 125/125 [00:03<00:00, 32.74it/s]


Epoch: 15, Loss: 1.8121057748794556, Acc: 0.7527142763137817


100%|██████████| 125/125 [00:03<00:00, 33.23it/s]


Epoch: 16, Loss: 1.7945915460586548, Acc: 0.7799999713897705


100%|██████████| 125/125 [00:03<00:00, 33.47it/s]


Epoch: 17, Loss: 1.801680564880371, Acc: 0.7617142796516418


100%|██████████| 125/125 [00:03<00:00, 33.75it/s]


Epoch: 18, Loss: 1.7873834371566772, Acc: 0.7697142958641052


100%|██████████| 125/125 [00:03<00:00, 33.96it/s]


Epoch: 19, Loss: 1.7711857557296753, Acc: 0.7870000004768372


100%|██████████| 125/125 [00:03<00:00, 33.87it/s]


Epoch: 20, Loss: 1.768113136291504, Acc: 0.7967143058776855


100%|██████████| 125/125 [00:03<00:00, 34.45it/s]


Epoch: 21, Loss: 1.734413981437683, Acc: 0.8367142677307129


100%|██████████| 125/125 [00:03<00:00, 33.65it/s]


Epoch: 22, Loss: 1.7026920318603516, Acc: 0.8640000224113464


100%|██████████| 125/125 [00:03<00:00, 34.55it/s]


Epoch: 23, Loss: 1.6923978328704834, Acc: 0.8730000257492065


100%|██████████| 125/125 [00:03<00:00, 33.57it/s]


Epoch: 24, Loss: 1.698943853378296, Acc: 0.8629999756813049


100%|██████████| 125/125 [00:03<00:00, 33.96it/s]


Epoch: 25, Loss: 1.6659891605377197, Acc: 0.8965713977813721


100%|██████████| 125/125 [00:03<00:00, 33.32it/s]


Epoch: 26, Loss: 1.6667472124099731, Acc: 0.8918570876121521


100%|██████████| 125/125 [00:03<00:00, 34.56it/s]


Epoch: 27, Loss: 1.6651240587234497, Acc: 0.8920000195503235


100%|██████████| 125/125 [00:03<00:00, 33.73it/s]


Epoch: 28, Loss: 1.665522813796997, Acc: 0.8938571214675903


100%|██████████| 125/125 [00:03<00:00, 33.77it/s]


Epoch: 29, Loss: 1.6543798446655273, Acc: 0.9070000052452087


100%|██████████| 125/125 [00:03<00:00, 34.37it/s]


Epoch: 30, Loss: 1.6674779653549194, Acc: 0.8928571343421936


100%|██████████| 125/125 [00:03<00:00, 34.08it/s]


Epoch: 31, Loss: 1.6676651239395142, Acc: 0.890999972820282


100%|██████████| 125/125 [00:03<00:00, 34.26it/s]


Epoch: 32, Loss: 1.6604597568511963, Acc: 0.8999999761581421


100%|██████████| 125/125 [00:03<00:00, 34.13it/s]


Epoch: 33, Loss: 1.6481878757476807, Acc: 0.9120000004768372


100%|██████████| 125/125 [00:03<00:00, 32.98it/s]


Epoch: 34, Loss: 1.6633201837539673, Acc: 0.8928571343421936


100%|██████████| 125/125 [00:03<00:00, 34.34it/s]


Epoch: 35, Loss: 1.6600468158721924, Acc: 0.8938571214675903


100%|██████████| 125/125 [00:03<00:00, 33.75it/s]


Epoch: 36, Loss: 1.645414113998413, Acc: 0.9087142944335938


100%|██████████| 125/125 [00:03<00:00, 34.50it/s]


Epoch: 37, Loss: 1.6534931659698486, Acc: 0.9089999794960022


100%|██████████| 125/125 [00:03<00:00, 34.31it/s]


Epoch: 38, Loss: 1.638113260269165, Acc: 0.9147142767906189


100%|██████████| 125/125 [00:03<00:00, 34.17it/s]


Epoch: 39, Loss: 1.6411434412002563, Acc: 0.9150000214576721


100%|██████████| 125/125 [00:03<00:00, 34.32it/s]


Epoch: 40, Loss: 1.6508140563964844, Acc: 0.9027143120765686


100%|██████████| 125/125 [00:03<00:00, 34.26it/s]


Epoch: 41, Loss: 1.635029911994934, Acc: 0.9169999957084656


100%|██████████| 125/125 [00:03<00:00, 33.44it/s]


Epoch: 42, Loss: 1.6541664600372314, Acc: 0.8998571038246155


100%|██████████| 125/125 [00:03<00:00, 33.92it/s]


Epoch: 43, Loss: 1.6440719366073608, Acc: 0.9070000052452087


100%|██████████| 125/125 [00:03<00:00, 34.27it/s]


Epoch: 44, Loss: 1.6363333463668823, Acc: 0.9169999957084656


100%|██████████| 125/125 [00:03<00:00, 34.14it/s]


Epoch: 45, Loss: 1.6288135051727295, Acc: 0.921999990940094


100%|██████████| 125/125 [00:03<00:00, 34.34it/s]


Epoch: 46, Loss: 1.6405571699142456, Acc: 0.9068571329116821


100%|██████████| 125/125 [00:03<00:00, 34.37it/s]


Epoch: 47, Loss: 1.6377733945846558, Acc: 0.9158571362495422


100%|██████████| 125/125 [00:03<00:00, 33.73it/s]


Epoch: 48, Loss: 1.6254404783248901, Acc: 0.9240000247955322


100%|██████████| 125/125 [00:03<00:00, 34.40it/s]


Epoch: 49, Loss: 1.6360265016555786, Acc: 0.9179999828338623


100%|██████████| 125/125 [00:03<00:00, 33.49it/s]


Epoch: 50, Loss: 1.6420707702636719, Acc: 0.9120000004768372


100%|██████████| 125/125 [00:03<00:00, 33.31it/s]


Epoch: 51, Loss: 1.6368178129196167, Acc: 0.9118571281433105


100%|██████████| 125/125 [00:03<00:00, 34.17it/s]


Epoch: 52, Loss: 1.6311153173446655, Acc: 0.9175714254379272


100%|██████████| 125/125 [00:03<00:00, 33.82it/s]


Epoch: 53, Loss: 1.6362879276275635, Acc: 0.916857123374939


100%|██████████| 125/125 [00:03<00:00, 34.04it/s]


Epoch: 54, Loss: 1.6232967376708984, Acc: 0.9259999990463257


100%|██████████| 125/125 [00:03<00:00, 34.12it/s]


Epoch: 55, Loss: 1.6491506099700928, Acc: 0.902999997138977


100%|██████████| 125/125 [00:03<00:00, 34.03it/s]


Epoch: 56, Loss: 1.6413424015045166, Acc: 0.9128571152687073


100%|██████████| 125/125 [00:03<00:00, 33.92it/s]


Epoch: 57, Loss: 1.6326700448989868, Acc: 0.9208571314811707


100%|██████████| 125/125 [00:03<00:00, 34.13it/s]


Epoch: 58, Loss: 1.633940577507019, Acc: 0.916857123374939


100%|██████████| 125/125 [00:03<00:00, 33.11it/s]

Epoch: 59, Loss: 1.6233198642730713, Acc: 0.9269999861717224





In [12]:
model.eval()

test_loss = 0.0
test_acc = 0.0
for audios, labels in tqdm(test_data_loader):
    audios, labels = audios.to(device), labels.to(device)
    labels_one_hot = torch.nn.functional.one_hot(labels, num_classes=11).type(torch.FloatTensor).to(device)

    preds = model(audios)
    loss = loss_function(preds, labels_one_hot)

    test_loss += loss
    test_acc += model.accuracy(preds, labels)

total_loss = test_loss / len(test_data_loader)
total_acc = test_acc / len(test_data_loader)

print(f"Loss: {total_loss}, Acc: {total_acc}")

100%|██████████| 32/32 [03:26<00:00,  6.47s/it]

Loss: 1.677049160003662, Acc: 0.875





In [13]:
torch.save(model.state_dict(), "weights_5.pth")

In [None]:
# Inference

signal, sample_rate = torchaudio.load("/content/drive/MyDrive/audio_classification/audios/zahra_2.opus")

# preprocess
signal = torch.mean(signal, dim=0, keepdim=True)
new_sample_rate = 8000
transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=new_sample_rate)
signal = transform(signal)
signal = signal[:, 32000:40000]
signal = signal.unsqueeze(0).to(device)

# process
preds = model(signal)

# postprocess
preds = preds.cpu().detach().numpy()
output = np.argmax(preds)
print(output)