In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import pathlib
import tqdm
import util

EPOCHS = 100

# PyTorch

In [2]:
class AudioDataset(Dataset):
    def __init__(self, dataset_path: str) -> None:
        self.dataset_path = pathlib.Path(dataset_path)
        self.file_paths = []
        self.labels = []

        self.duration = 10000
        self.sr = 44100
        self.channel = 2
        self.shift_pct = 0.4

        self.labels_meaning = {
            0: 'нейтрально',
            1: 'спокойно',
            2: 'счастливо',
            3: 'грустно',
            4: 'сердито',
            5: 'напуганно',
            6: 'недовольно',
            7: 'удивлённо'
        }

        for elem in os.listdir(self.dataset_path):
            dirpath = self.dataset_path.joinpath(elem)
            for audio in os.listdir(dirpath):
                audiopath = dirpath.joinpath(audio)
                self.file_paths.append(audiopath)
                emotion = int(audio.split('-')[2])-1
                self.labels.append(emotion)


    def __len__(self) -> int:
        return len(self.labels)
    

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        aud = util.open(self.file_paths[idx])
  
        reaud = util.resample(aud, self.sr)
        rechan = util.rechannel(reaud, self.channel)

        dur_aud = util.pad_trunc(rechan, self.duration)
        shift_aud = util.time_shift(dur_aud, self.shift_pct)
        sgram = util.spectrogram(shift_aud, n_mels=64, n_fft=1024, hop_len=None)
        aug_sgram = util.spectro_augment(sgram, max_mask_pct=0.1, n_freq_masks=2, n_time_masks=2)
        label = torch.tensor(self.labels[idx])

        return aug_sgram, label

In [3]:
train_dataset = AudioDataset('./dataset/train')
test_dataset = AudioDataset('./dataset/test')

train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True)

In [4]:
class SpeechEmotionClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        conv_layers = []

        self.conv1 = nn.Conv2d(2, 32, kernel_size=5, stride=2, padding=2)
        self.relu1 = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(32)
        conv_layers += [self.conv1, self.relu1, self.bn1]

        self.conv2 = nn.Conv2d(32, 128, kernel_size=3, stride=2, padding=1)
        self.relu2 = nn.ReLU()
        self.bn2 = nn.BatchNorm2d(128)
        conv_layers += [self.conv2, self.relu2, self.bn2]

        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.relu3 = nn.ReLU()
        self.bn3 = nn.BatchNorm2d(256)
        conv_layers += [self.conv3, self.relu3, self.bn3]

        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
        self.relu4 = nn.ReLU()
        self.bn4 = nn.BatchNorm2d(512)
        conv_layers += [self.conv4, self.relu4, self.bn4]

        self.conv = nn.Sequential(*conv_layers)

        self.adapt_avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.fc = nn.Linear(in_features=512, out_features=8)

        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.adapt_avg_pool(x)
        x = x.flatten(start_dim=1)
        x = self.fc(x)
        return x

In [104]:
model = SpeechEmotionClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [105]:
best_model = None
best_loss = float('inf')

train_losses = []
validation_losses = []

for epoch in tqdm.tqdm(range(EPOCHS)):
    train_loss = 0
    loader_iterator = 0

    model.train()
    for batch in train_dataloader:
        features, labels = batch
        optimizer.zero_grad()
        y_pred = model(features)
        loss = criterion(y_pred, labels)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        loader_iterator += 1
    train_loss /= loader_iterator
    train_losses.append(train_loss)

    val_loss = 0
    loader_iterator = 0

    model.eval()
    for batch in test_dataloader:
        features, labels = batch
        y_pred = model(features)
        loss = criterion(y_pred, labels)
        val_loss += loss.item()
        loader_iterator += 1
    
    val_loss /= loader_iterator
    validation_losses.append(val_loss)

    if val_loss < best_loss:
        best_loss = val_loss
        best_model = model.state_dict()
    
    if epoch % 10 == 0:
        print(f'Эпоха: {epoch}. Ошибка тестовая: {train_loss}. Ошибка валидационная: {val_loss}')

  1%|          | 1/100 [00:57<1:34:16, 57.14s/it]

Эпоха: 0. Ошибка тестовая: 1.919232652366506. Ошибка валидационная: 1.8212080299854279


 11%|█         | 11/100 [09:58<1:20:05, 53.99s/it]

Эпоха: 10. Ошибка тестовая: 1.5904662829603073. Ошибка валидационная: 1.3694091141223907


 21%|██        | 21/100 [18:58<1:10:57, 53.90s/it]

Эпоха: 20. Ошибка тестовая: 1.1718974480394684. Ошибка валидационная: 1.1339569613337517


 31%|███       | 31/100 [27:57<1:01:56, 53.87s/it]

Эпоха: 30. Ошибка тестовая: 0.8645320017041499. Ошибка валидационная: 0.9403711035847664


 41%|████      | 41/100 [36:58<53:09, 54.05s/it]  

Эпоха: 40. Ошибка тестовая: 0.6356431732890923. Ошибка валидационная: 0.7042659409344196


 51%|█████     | 51/100 [47:02<48:33, 59.47s/it]

Эпоха: 50. Ошибка тестовая: 0.43553579887213734. Ошибка валидационная: 0.48439713194966316


 61%|██████    | 61/100 [56:44<37:38, 57.92s/it]

Эпоха: 60. Ошибка тестовая: 0.3537190649227772. Ошибка валидационная: 0.733573067933321


 71%|███████   | 71/100 [1:06:41<29:06, 60.23s/it]

Эпоха: 70. Ошибка тестовая: 0.2861375659381206. Ошибка валидационная: 0.7428099140524864


 81%|████████  | 81/100 [1:16:47<19:09, 60.51s/it]

Эпоха: 80. Ошибка тестовая: 0.2533519707264842. Ошибка валидационная: 0.9981588162481785


 91%|█████████ | 91/100 [1:26:38<08:50, 58.98s/it]

Эпоха: 90. Ошибка тестовая: 0.2241477800554709. Ошибка валидационная: 0.7096858527511358


100%|██████████| 100/100 [1:35:35<00:00, 57.35s/it]


In [142]:
import plotly.graph_objects as go
import plotly.offline as pyo

fig = go.Figure()

fig.add_traces(go.Scatter(x = np.arange(100), y = train_losses, name="Тренировочная ошибка"))
fig.add_traces(go.Scatter(x = np.arange(100), y = validation_losses, name="Валидационная ошибка"))
fig.update_layout(
    hovermode='x',
    xaxis_title="Эпохи",
    yaxis_title="Ошибка"
)

pyo.iplot(fig)

In [19]:
from torchmetrics import Accuracy


model = SpeechEmotionClassifier()
model.load_state_dict(torch.load('./speech_emotion_recognizer_model.pth'))

X_test, y_test = [], []

for data in test_dataset:
    x, y = data
    X_test.append(x)
    y_test.append(y)

X_test = torch.stack(X_test)
y_test = torch.tensor(y_test)

accuracy = Accuracy('multiclass', num_classes=8)
with torch.no_grad():
    y_pred = model(X_test)
    y_pred = y_pred.argmax(dim=1)
    acc = accuracy(y_pred, y_test)

print(f"Точность: {round(acc.item(), 2)*100}%")

Точность: 80.0%


In [143]:
torch.save(model.state_dict(), 'speech_emotion_recognizer_model.pth')