In [1]:
import re
import os
import pandas as pd

regex = re.compile(r'\[.+\]\n', re.IGNORECASE)
file_paths, file_names, emotions = [], [], []
emotion_map = {'Neutral': 'neutral', 'Anger': 'angry', 'Happiness': 'happy', 'Sadness': 'sad', 'Fear': 'fear',
              'Disgust': 'disgust'}

In [2]:
# Define path to datasets
DATA_NATURAL = "/home/rl3155/MESD_All"
entries = os.listdir(DATA_NATURAL)

from tqdm import tqdm

for i in tqdm(range(len(entries))):
    entry = entries[i]
    path = DATA_NATURAL + "/" + entry
    if "wav" not in path:
        continue
    emotion = emotion_map[entry.split("_")[0]]
    
    file_paths.append(path)
    file_names.append(entry)
    emotions.append(emotion)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 863/863 [00:00<00:00, 729627.97it/s]


In [3]:
file = pd.DataFrame({'path':file_paths, 'name': file_names, 'emotion': emotions})

In [4]:
file.head()

Unnamed: 0,path,name,emotion
0,/home/rl3155/MESD_All/Anger_C_B_delincuencia.wav,Anger_C_B_delincuencia.wav,angry
1,/home/rl3155/MESD_All/Happiness_M_A_hoy.wav,Happiness_M_A_hoy.wav,happy
2,/home/rl3155/MESD_All/Disgust_C_B_irrespetuoso...,Disgust_C_B_irrespetuoso.wav,disgust
3,/home/rl3155/MESD_All/Sadness_M_B_hambre.wav,Sadness_M_B_hambre.wav,sad
4,/home/rl3155/MESD_All/Neutral_F_A_izquierda.wav,Neutral_F_A_izquierda.wav,neutral


In [5]:
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split

spaths, semotions = shuffle(file_paths, emotions, random_state=42)

In [6]:
X_train, X_val, y_train, y_val = train_test_split(spaths, semotions, test_size=0.2, random_state=1)

In [7]:
import torch
import torchaudio

print(torch.__version__)
print(torchaudio.__version__)

torch.random.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

1.12.1
0.12.1+cu113
cuda


In [8]:
bundle = torchaudio.pipelines.WAV2VEC2_BASE
extractor = bundle.get_model()
print(extractor.__class__)

<class 'torchaudio.models.wav2vec2.model.Wav2Vec2Model'>


In [9]:
class MyDataSet(torch.utils.data.Dataset):
    def __init__(self, paths, labels, label_transform):
        super(MyDataSet).__init__()
        self.paths = paths
        self.labels = labels
        self.label_transform = label_transform
        
    def __getitem__(self, idx):
        path = self.paths[idx]
        label = self.label_transform[self.labels[idx]]
        wave, sr = torchaudio.load(path)
        if sr != bundle.sample_rate:
            wave = torchaudio.functional.resample(wave, sr, bundle.sample_rate)
        with torch.inference_mode():
            feature, _ = extractor.extract_features(wave)
        feature = [f[0] for f in feature]
        audio = torch.stack(feature)
        length = audio.size(1)
        return audio, length, label
    
    def __len__(self):
        return len(self.labels)

In [10]:
def collate_indic(data):
    audios, lengths, labels = zip(*data)
    max_len = max(lengths)
    n_ftrs = audios[0].size(2)
    n_dims = audios[0].size(0)
    features = torch.zeros((len(audios), n_dims, max_len, n_ftrs))
    labels = torch.tensor(labels)
    lengths = torch.tensor(lengths)

    for i in range(len(data)):
        j, k = audios[i].size(1), audios[i].size(2)
        features[i] = torch.cat([audios[i], torch.zeros((n_dims, max_len - j, k))], dim=1)

    return features, lengths, labels

In [11]:
categories = ['neutral', 'angry', 'happy', 'sad', 'fear', 'disgust']
cate_dic = {}
for i, cate in enumerate(categories):
    cate_dic[cate] = i
cate_dic

{'neutral': 0, 'angry': 1, 'happy': 2, 'sad': 3, 'fear': 4, 'disgust': 5}

In [12]:
from torch.utils.data import DataLoader

train_dataset = MyDataSet(X_train, y_train, cate_dic)
trainloader_args = dict(batch_size=16, shuffle=True)
train_dataloader = DataLoader(train_dataset, **trainloader_args, 
                              collate_fn=collate_indic)

test_dataset = MyDataSet(X_val, y_val, cate_dic)
testloader_args = dict(batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_dataset, **testloader_args, 
                             collate_fn=collate_indic)

## Train with 3CNN+LSTM

In [13]:
import torch.nn as nn
import torch.nn.functional as F

class ICASSP3CNN(nn.Module):
    def __init__(self, vocab_size, dims = 12, embed_size=128, hidden_size=512, num_lstm_layers = 2, bidirectional = False, label_size=7):
        super().__init__()
        self.n_layers = num_lstm_layers 
        self.hidden = hidden_size
        self.bidirectional = bidirectional
        
        self.aggr = nn.Conv1d(in_channels=dims, out_channels=1, kernel_size=1)
        
        self.embed = nn.Linear(in_features = vocab_size, out_features = embed_size)

        self.cnn  = nn.Conv1d(embed_size, embed_size, kernel_size=3, padding=1)
        self.cnn2 = nn.Conv1d(embed_size, embed_size, kernel_size=5, padding=2)
        self.cnn3 = nn.Conv1d(embed_size, embed_size, kernel_size=7, padding=3)

        self.batchnorm = nn.BatchNorm1d(3 * embed_size)

        self.lstm = nn.LSTM(input_size = 3 * embed_size, 
                            hidden_size = hidden_size, 
                            num_layers = num_lstm_layers, 
                            bidirectional = bidirectional)

        self.linear = nn.Linear(in_features = 2 * hidden_size if bidirectional else hidden_size, 
                                out_features = label_size)


    def forward(self, x, lengths):
        """
        padded_x: (B,T) padded LongTensor
        """
        n, d, b, t = x.size(0), x.size(1), x.size(2), x.size(3)
        x = torch.flatten(x, start_dim=2)
        input = self.aggr(x)
        input = torch.reshape(input, (n, b, t))
        input = self.embed(input)

        batch_size = input.size(0)
        input = input.transpose(1,2)    # (B,T,H) -> (B,H,T)

        cnn_output = torch.cat([self.cnn(input), self.cnn2(input), self.cnn3(input)], dim=1)

        input = F.relu(self.batchnorm(cnn_output))

        input = input.transpose(1,2)

        pack_tensor = nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=True, enforce_sorted=False)
        _, (hn, cn) = self.lstm(pack_tensor)

        if self.bidirectional:
            h_n = hn.view(self.n_layers, 2, batch_size, self.hidden)
            h_n = torch.cat([ h_n[-1, 0,:], h_n[-1,1,:] ], dim = 1)
        else:
            h_n = hn[-1]

        logits = self.linear(h_n)

        return logits

### Model Traning on each layer 

In [None]:
from tqdm import tqdm
from torchsummary import summary
import torch.optim as optim

model = ICASSP3CNN(768)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.005)

epochs = 50
train_losses = []
train_accuracies = []
valid_losses = []
valid_accuracies = []

for epoch in tqdm(range(epochs)):
    train_loss = 0
    acc_cnt = 0
    err_cnt = 0
    batch_cnt = 0
    model.train()
    for batch, (x, length, y) in enumerate(train_dataloader):
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        logits = model(x, length)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.cpu().item()

        #model outputs
        out_val, out_indices = torch.max(logits, dim=1)
        tar_indices = y

        for i in range(len(out_indices)):
            if out_indices[i] == tar_indices[i]:
                acc_cnt += 1
            else:
                err_cnt += 1
        batch_cnt += 1
    
    train_loss = train_loss/batch_cnt
    train_accuracy = acc_cnt/(acc_cnt+err_cnt)
    train_accuracies.append(train_accuracy)
    train_losses.append(train_loss)
    
    valid_loss = 0
    acc_cnt = 0
    err_cnt = 0
    batch_cnt = 0
    model.eval()

    for x, lengths, y in test_dataloader:

        x = x.to(device)
        y = y.to(device)

        logits = model(x, lengths)
        loss = criterion(logits, y)
        valid_loss += loss.cpu().item()

        out_val, out_indices = torch.max(logits, dim=1)
        tar_indices = y

        for i in range(len(out_indices)):
            if out_indices[i] == tar_indices[i]:
                acc_cnt += 1
            else:
                err_cnt += 1
        batch_cnt += 1
    
    valid_loss = valid_loss/batch_cnt
    valid_accuracy = acc_cnt/(acc_cnt+err_cnt)
    valid_accuracies.append(valid_accuracy)
    valid_losses.append(valid_loss)
    
    print(f"epoch:{epoch+1}, train accu:{train_accuracy:.4f},", 
          f"train loss:{train_loss:.2f}, valid accu:{valid_accuracy:.4f},", 
          f"valid loss:{valid_loss:.2f}")

  2%|████                                                                                                                                                                                                       | 1/50 [01:45<1:25:54, 105.19s/it]

epoch:1, train accu:0.1858, train loss:1.87, valid accu:0.2428, valid loss:1.82


  4%|████████                                                                                                                                                                                                   | 2/50 [03:27<1:22:48, 103.50s/it]

epoch:2, train accu:0.1771, train loss:1.80, valid accu:0.3121, valid loss:1.57


  6%|████████████▏                                                                                                                                                                                              | 3/50 [05:12<1:21:26, 103.96s/it]

epoch:3, train accu:0.2337, train loss:1.67, valid accu:0.2890, valid loss:1.51


  8%|████████████████▏                                                                                                                                                                                          | 4/50 [06:56<1:19:49, 104.11s/it]

epoch:4, train accu:0.3149, train loss:1.55, valid accu:0.3295, valid loss:1.39


 10%|████████████████████▎                                                                                                                                                                                      | 5/50 [08:38<1:17:33, 103.40s/it]

epoch:5, train accu:0.2961, train loss:1.48, valid accu:0.2890, valid loss:1.47


 12%|████████████████████████▎                                                                                                                                                                                  | 6/50 [10:21<1:15:48, 103.38s/it]

epoch:6, train accu:0.3483, train loss:1.37, valid accu:0.3006, valid loss:1.43


 14%|████████████████████████████▍                                                                                                                                                                              | 7/50 [12:05<1:14:13, 103.57s/it]

epoch:7, train accu:0.4383, train loss:1.28, valid accu:0.3006, valid loss:1.44


 16%|████████████████████████████████▍                                                                                                                                                                          | 8/50 [13:48<1:12:16, 103.25s/it]

epoch:8, train accu:0.4906, train loss:1.18, valid accu:0.3468, valid loss:1.39


 18%|████████████████████████████████████▌                                                                                                                                                                      | 9/50 [15:30<1:10:13, 102.77s/it]

epoch:9, train accu:0.5515, train loss:1.08, valid accu:0.5376, valid loss:1.14


 20%|████████████████████████████████████████▍                                                                                                                                                                 | 10/50 [17:12<1:08:24, 102.62s/it]

epoch:10, train accu:0.6038, train loss:0.95, valid accu:0.5145, valid loss:1.02


 22%|████████████████████████████████████████████▍                                                                                                                                                             | 11/50 [18:53<1:06:20, 102.07s/it]

epoch:11, train accu:0.5922, train loss:0.95, valid accu:0.5318, valid loss:1.05


 24%|████████████████████████████████████████████████▍                                                                                                                                                         | 12/50 [20:35<1:04:37, 102.05s/it]

epoch:12, train accu:0.6415, train loss:0.78, valid accu:0.5896, valid loss:0.96


 26%|████████████████████████████████████████████████████▌                                                                                                                                                     | 13/50 [22:17<1:02:56, 102.06s/it]

epoch:13, train accu:0.6865, train loss:0.76, valid accu:0.5607, valid loss:1.10


 28%|████████████████████████████████████████████████████████▌                                                                                                                                                 | 14/50 [24:01<1:01:34, 102.63s/it]

epoch:14, train accu:0.7402, train loss:0.64, valid accu:0.6069, valid loss:1.08


 30%|█████████████████████████████████████████████████████████████▏                                                                                                                                              | 15/50 [25:43<59:52, 102.65s/it]

epoch:15, train accu:0.7039, train loss:0.77, valid accu:0.5376, valid loss:1.06


 32%|█████████████████████████████████████████████████████████████████▎                                                                                                                                          | 16/50 [27:25<57:57, 102.28s/it]

epoch:16, train accu:0.6952, train loss:0.68, valid accu:0.5954, valid loss:0.81


 34%|█████████████████████████████████████████████████████████████████████▎                                                                                                                                      | 17/50 [29:07<56:13, 102.23s/it]

epoch:17, train accu:0.7620, train loss:0.55, valid accu:0.6763, valid loss:0.83


 36%|█████████████████████████████████████████████████████████████████████████▍                                                                                                                                  | 18/50 [30:48<54:23, 101.99s/it]

epoch:18, train accu:0.7765, train loss:0.51, valid accu:0.6127, valid loss:1.00


 38%|█████████████████████████████████████████████████████████████████████████████▌                                                                                                                              | 19/50 [32:30<52:36, 101.82s/it]

epoch:19, train accu:0.7808, train loss:0.54, valid accu:0.6590, valid loss:0.84


 40%|█████████████████████████████████████████████████████████████████████████████████▌                                                                                                                          | 20/50 [34:12<50:59, 101.99s/it]

epoch:20, train accu:0.8200, train loss:0.48, valid accu:0.7341, valid loss:0.83


 42%|█████████████████████████████████████████████████████████████████████████████████████▋                                                                                                                      | 21/50 [35:54<49:13, 101.86s/it]

epoch:21, train accu:0.8752, train loss:0.38, valid accu:0.7052, valid loss:0.81


 44%|█████████████████████████████████████████████████████████████████████████████████████████▊                                                                                                                  | 22/50 [37:35<47:25, 101.64s/it]

epoch:22, train accu:0.8970, train loss:0.32, valid accu:0.7283, valid loss:0.73


 46%|█████████████████████████████████████████████████████████████████████████████████████████████▊                                                                                                              | 23/50 [39:18<45:53, 101.97s/it]

epoch:23, train accu:0.8810, train loss:0.34, valid accu:0.6879, valid loss:0.86


 48%|█████████████████████████████████████████████████████████████████████████████████████████████████▉                                                                                                          | 24/50 [41:00<44:13, 102.05s/it]

epoch:24, train accu:0.8824, train loss:0.33, valid accu:0.7399, valid loss:0.70


 50%|██████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                                      | 25/50 [42:44<42:47, 102.70s/it]

epoch:25, train accu:0.9216, train loss:0.26, valid accu:0.7225, valid loss:0.98


 52%|██████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                                  | 26/50 [44:29<41:18, 103.25s/it]

epoch:26, train accu:0.9390, train loss:0.23, valid accu:0.7052, valid loss:1.02


 54%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                                             | 27/50 [46:14<39:51, 103.98s/it]

epoch:27, train accu:0.9202, train loss:0.31, valid accu:0.6994, valid loss:0.77


 56%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                                         | 28/50 [48:00<38:18, 104.46s/it]

epoch:28, train accu:0.8853, train loss:0.34, valid accu:0.6994, valid loss:0.81


 58%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                                                     | 29/50 [49:46<36:42, 104.89s/it]

epoch:29, train accu:0.9158, train loss:0.24, valid accu:0.6243, valid loss:1.13


 60%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                                                 | 30/50 [51:31<34:57, 104.85s/it]

epoch:30, train accu:0.9289, train loss:0.25, valid accu:0.6647, valid loss:1.00


 62%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                                             | 31/50 [53:15<33:07, 104.62s/it]

epoch:31, train accu:0.9448, train loss:0.19, valid accu:0.6647, valid loss:1.00


 64%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                         | 32/50 [54:59<31:20, 104.49s/it]

epoch:32, train accu:0.9594, train loss:0.15, valid accu:0.6994, valid loss:1.04


 66%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                     | 33/50 [56:41<29:25, 103.84s/it]

epoch:33, train accu:0.9724, train loss:0.20, valid accu:0.6763, valid loss:1.05


 68%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                 | 34/50 [58:26<27:46, 104.14s/it]

epoch:34, train accu:0.9521, train loss:0.16, valid accu:0.6705, valid loss:0.98


In [None]:
model_path = '/home/rl3155/models/wav2vecbase_new.pth'

torch.save({'epoch':epochs,
            'model_state_dict':model.state_dict(),
            'optimizer_state_dict':optimizer.state_dict()},
            model_path)

metadata = pd.DataFrame({'epoch':range(epochs), 'train loss':train_losses, 
                         'valid loss':valid_losses, 'train accu':train_accuracies, 
                         'valid_accu':valid_accuracies})
metadata.to_csv('/home/rl3155/results/acc_loss/wav2vecbase_new.csv ', 
                index=False)

In [None]:

import matplotlib.pyplot as plt

plt.plot(range(epochs), train_losses, label='train')
plt.plot(range(epochs), valid_losses, label='valid')
plt.legend()
plt.title('training and validation loss')
plt.show()
     

In [None]:
plt.plot(range(epochs), train_accuracies, label='train')
plt.plot(range(epochs), valid_accuracies, label='valid')
plt.legend()
plt.title('training and validation accuracy')
plt.show()

In [None]:
from sklearn.metrics import confusion_matrix

y_pred = []
y_true = []

for inputs, lengths, labels in test_dataloader:
    inputs = inputs.to(device)
    labels = labels.to(device)
    output = model(inputs, lengths) # Feed Network

    output = (torch.max(torch.exp(output), 1)[1]).data.cpu().numpy()
    y_pred.extend(output) # Save Prediction

    labels = labels.data.cpu().numpy()
    y_true.extend(labels) # Save Truth

In [None]:
import numpy as np
cf = confusion_matrix(y_true, y_pred)
classes = list(set([v[1] for k,v in waveforms_results_dict.items()]))
df_cm = pd.DataFrame(cf, index = [i for i in classes],
                     columns = [i for i in classes])
sns.heatmap(df_cm, annot=True)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Spanish Wav2VecBase Confusion Matrix')
plt.savefig('/home/rl3155/results/confusion_matrix/wav2vecbase_new.png')