## Prepare Data

In [1]:
import re
import os
import pandas as pd

regex = re.compile(r'\[.+\]\n', re.IGNORECASE)
file_paths, file_names, sessions, emotions = [], [], [], []
emotion_map = {'neu': 'neutral', 'ang': 'angry', 'hap': 'happy', 'exc': 'happy', 'sad': 'sad'}

for session in range(1, 6):
    emo_evaluation_dir = f'/home/jz3313/IEMOCAP_full_release/Session{session}/dialog/EmoEvaluation/'
    file_dir = f'/home/jz3313/IEMOCAP_full_release/Session{session}/sentences/wav/'
    evaluation_files = [l for l in os.listdir(emo_evaluation_dir) if 'Ses' in l]
    for file in evaluation_files:
        with open(emo_evaluation_dir + file) as f:
            content = f.read()
        lines = re.findall(regex, content)
        for line in lines[1:]:  # the first line is a header
            start_end_time, wav_file_name, emotion, val_act_dom = line.strip().split('\t')
            dir_name = '_'.join(wav_file_name.split('_')[:-1])
            if emotion in emotion_map:
                file_paths.append(file_dir+dir_name+'/'+wav_file_name+'.wav')
                file_names.append(wav_file_name)
                sessions.append(session)
                emotions.append(emotion_map[emotion])

In [2]:
file = pd.DataFrame({'path':file_paths, 'name': file_names, 'session': sessions, 'emotion': emotions})

In [3]:
file.head()

Unnamed: 0,path,name,session,emotion
0,/home/jz3313/IEMOCAP_full_release/Session1/sen...,Ses01F_impro02_F000,1,sad
1,/home/jz3313/IEMOCAP_full_release/Session1/sen...,Ses01F_impro02_F001,1,sad
2,/home/jz3313/IEMOCAP_full_release/Session1/sen...,Ses01F_impro02_F002,1,sad
3,/home/jz3313/IEMOCAP_full_release/Session1/sen...,Ses01F_impro02_F003,1,neutral
4,/home/jz3313/IEMOCAP_full_release/Session1/sen...,Ses01F_impro02_F004,1,sad


In [4]:
import torch
import torchaudio

print(torch.__version__)
print(torchaudio.__version__)

torch.random.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

1.12.1+cu113
0.12.1+cu113
cuda


In [5]:
bundle = torchaudio.pipelines.WAV2VEC2_BASE
extractor = bundle.get_model()
print(extractor.__class__)

<class 'torchaudio.models.wav2vec2.model.Wav2Vec2Model'>


In [None]:
from tqdm import tqdm

for _, row in tqdm(file.iterrows()):
    path, name, session, emotion = row[0], row[1], row[2], row[3]
    wave, sr = torchaudio.load(path)
    if sr != bundle.sample_rate:
        wave = torchaudio.functional.resample(wave, sr, bundle.sample_rate)
    with torch.inference_mode():
        feature, _ = extractor.extract_features(wave)
    feature = [f[0] for f in feature]
    feature = torch.stack(feature)
    save_path = f'../data/wav2vecbase/Session{session}/{name}.pt'
    torch.save(feature, save_path)

## Load Data

In [6]:
file['newpath'] = file.apply(lambda x: f'../data/wav2vecbase/Session{x[2]}/{x[1]}.pt', axis=1)

In [7]:
file.head()

Unnamed: 0,path,name,session,emotion,newpath
0,/home/jz3313/IEMOCAP_full_release/Session1/sen...,Ses01F_impro02_F000,1,sad,../data/wav2vecbase/Session1/Ses01F_impro02_F0...
1,/home/jz3313/IEMOCAP_full_release/Session1/sen...,Ses01F_impro02_F001,1,sad,../data/wav2vecbase/Session1/Ses01F_impro02_F0...
2,/home/jz3313/IEMOCAP_full_release/Session1/sen...,Ses01F_impro02_F002,1,sad,../data/wav2vecbase/Session1/Ses01F_impro02_F0...
3,/home/jz3313/IEMOCAP_full_release/Session1/sen...,Ses01F_impro02_F003,1,neutral,../data/wav2vecbase/Session1/Ses01F_impro02_F0...
4,/home/jz3313/IEMOCAP_full_release/Session1/sen...,Ses01F_impro02_F004,1,sad,../data/wav2vecbase/Session1/Ses01F_impro02_F0...


In [8]:
holdout = 4
train = file[file['session'] != holdout]
test = file[file['session'] == holdout]

In [9]:
class MyDataSet(torch.utils.data.Dataset):
    def __init__(self, datas, labels, label_transform):
        super(MyDataSet).__init__()
        self.datas = datas
        self.labels = labels
        self.label_transform = label_transform
        
    def __getitem__(self, idx):
        audio = self.datas[idx]
        label = self.label_transform[self.labels[idx]]
        length = audio.size(1)
        return audio, length, label
    
    def __len__(self):
        return len(self.labels)

In [10]:
def collate_indic(data):
    audios, lengths, labels = zip(*data)
    max_len = max(lengths)
    n_ftrs = audios[0].size(2)
    n_dims = audios[0].size(0)
    features = torch.zeros((len(audios), n_dims, max_len, n_ftrs))
    labels = torch.tensor(labels)
    lengths = torch.tensor(lengths)

    for i in range(len(data)):
        j, k = audios[i].size(1), audios[i].size(2)
        features[i] = torch.cat([audios[i], torch.zeros((n_dims, max_len - j, k))], dim=1)

    return features, lengths, labels

In [11]:
categories = ['angry', 'happy', 'neutral', 'sad']
cate_dic = {}
for i, cate in enumerate(categories):
    cate_dic[cate] = i
cate_dic

{'angry': 0, 'happy': 1, 'neutral': 2, 'sad': 3}

## Train with 3CNN+LSTM

In [12]:
import torch.nn as nn
import torch.nn.functional as F

class ICASSP3CNN(nn.Module):
    def __init__(self, vocab_size, dims = 12, embed_size=128, hidden_size=512, num_lstm_layers = 2, bidirectional = False, label_size=7):
        super().__init__()
        self.n_layers = num_lstm_layers 
        self.hidden = hidden_size
        self.bidirectional = bidirectional
        
        self.aggr = nn.Conv1d(in_channels=dims, out_channels=1, kernel_size=1)
        
        self.embed = nn.Linear(in_features = vocab_size, out_features = embed_size)

        self.cnn  = nn.Conv1d(embed_size, embed_size, kernel_size=3, padding=1)
        self.cnn2 = nn.Conv1d(embed_size, embed_size, kernel_size=5, padding=2)
        self.cnn3 = nn.Conv1d(embed_size, embed_size, kernel_size=7, padding=3)

        self.batchnorm = nn.BatchNorm1d(3 * embed_size)

        self.lstm = nn.LSTM(input_size = 3 * embed_size, 
                            hidden_size = hidden_size, 
                            num_layers = num_lstm_layers, 
                            bidirectional = bidirectional)

        self.linear = nn.Linear(in_features = 2 * hidden_size if bidirectional else hidden_size, 
                                out_features = label_size)


    def forward(self, x, lengths):
        """
        padded_x: (B,T) padded LongTensor
        """
        n, d, b, t = x.size(0), x.size(1), x.size(2), x.size(3)
        x = torch.flatten(x, start_dim=2)
        input = self.aggr(x)
        input = torch.reshape(input, (n, b, t))
        input = self.embed(input)

        batch_size = input.size(0)
        input = input.transpose(1,2)    # (B,T,H) -> (B,H,T)

        cnn_output = torch.cat([self.cnn(input), self.cnn2(input), self.cnn3(input)], dim=1)

        input = F.relu(self.batchnorm(cnn_output))

        input = input.transpose(1,2)

        pack_tensor = nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=True, enforce_sorted=False)
        _, (hn, cn) = self.lstm(pack_tensor)

        if self.bidirectional:
            h_n = hn.view(self.n_layers, 2, batch_size, self.hidden)
            h_n = torch.cat([ h_n[-1, 0,:], h_n[-1,1,:] ], dim = 1)
        else:
            h_n = hn[-1]

        logits = self.linear(h_n)

        return logits

### Model Traning on each layer 

In [13]:
from tqdm import tqdm
from torch.utils.data import DataLoader

traindata = []
for _, row in tqdm(train.iterrows()):
    traindata.append(torch.load(row[4]))

train_dataset = MyDataSet(traindata, train['emotion'].tolist(), cate_dic)
trainloader_args = dict(batch_size=64, shuffle=True)
train_dataloader = DataLoader(train_dataset, **trainloader_args, 
                              collate_fn=collate_indic)

4500it [05:45, 13.03it/s]


In [15]:
train_dataset = MyDataSet(traindata, train['emotion'].tolist(), cate_dic)
trainloader_args = dict(batch_size=64, shuffle=True)
train_dataloader = DataLoader(train_dataset, **trainloader_args, 
                              collate_fn=collate_indic)

In [17]:
from tqdm import tqdm
from torchsummary import summary
import torch.optim as optim

model = ICASSP3CNN(768)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [18]:
epochs = 50
train_losses = []
train_accuracies = []
valid_losses = []
valid_accuracies = []

for epoch in tqdm(range(epochs)):
    train_loss = 0
    acc_cnt = 0
    err_cnt = 0
    batch_cnt = 0
    model.train()
    for batch, (x, length, y) in enumerate(train_dataloader):
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        logits = model(x, length)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.cpu().item()

        #model outputs
        out_val, out_indices = torch.max(logits, dim=1)
        tar_indices = y

        for i in range(len(out_indices)):
            if out_indices[i] == tar_indices[i]:
                acc_cnt += 1
            else:
                err_cnt += 1
        batch_cnt += 1
    
    train_loss = train_loss/batch_cnt
    train_accuracy = acc_cnt/(acc_cnt+err_cnt)
    train_accuracies.append(train_accuracy)
    train_losses.append(train_loss)
    
    print(f"epoch:{epoch+1}, train accu:{train_accuracy:.4f},", f"train loss:{train_loss:.2f}")

  2%|███▎                                                                                                                                                                 | 1/50 [01:49<1:29:07, 109.13s/it]

epoch:1, train accu:0.4889, train loss:1.12


  4%|██████▌                                                                                                                                                              | 2/50 [03:30<1:23:43, 104.66s/it]

epoch:2, train accu:0.5831, train loss:0.97


  6%|█████████▉                                                                                                                                                           | 3/50 [05:14<1:21:36, 104.17s/it]

epoch:3, train accu:0.6427, train loss:0.88


  8%|█████████████▏                                                                                                                                                       | 4/50 [06:56<1:19:25, 103.59s/it]

epoch:4, train accu:0.6693, train loss:0.83


 10%|████████████████▌                                                                                                                                                    | 5/50 [08:37<1:16:46, 102.37s/it]

epoch:5, train accu:0.7007, train loss:0.76


 12%|███████████████████▊                                                                                                                                                 | 6/50 [10:20<1:15:17, 102.67s/it]

epoch:6, train accu:0.7129, train loss:0.73


 14%|███████████████████████                                                                                                                                              | 7/50 [12:02<1:13:30, 102.57s/it]

epoch:7, train accu:0.7264, train loss:0.70


 16%|██████████████████████████▍                                                                                                                                          | 8/50 [13:42<1:11:10, 101.67s/it]

epoch:8, train accu:0.7362, train loss:0.67


 18%|█████████████████████████████▋                                                                                                                                       | 9/50 [15:26<1:10:03, 102.53s/it]

epoch:9, train accu:0.7540, train loss:0.64


 20%|████████████████████████████████▊                                                                                                                                   | 10/50 [17:10<1:08:31, 102.80s/it]

epoch:10, train accu:0.7600, train loss:0.62


 22%|████████████████████████████████████                                                                                                                                | 11/50 [18:51<1:06:28, 102.26s/it]

epoch:11, train accu:0.7853, train loss:0.56


 24%|███████████████████████████████████████▎                                                                                                                            | 12/50 [20:31<1:04:22, 101.66s/it]

epoch:12, train accu:0.7862, train loss:0.55


 26%|██████████████████████████████████████████▋                                                                                                                         | 13/50 [22:13<1:02:39, 101.61s/it]

epoch:13, train accu:0.8082, train loss:0.51


 28%|█████████████████████████████████████████████▉                                                                                                                      | 14/50 [23:50<1:00:15, 100.43s/it]

epoch:14, train accu:0.8211, train loss:0.47


 30%|█████████████████████████████████████████████████▊                                                                                                                    | 15/50 [25:33<59:02, 101.23s/it]

epoch:15, train accu:0.8429, train loss:0.43


 32%|█████████████████████████████████████████████████████                                                                                                                 | 16/50 [27:16<57:32, 101.55s/it]

epoch:16, train accu:0.8640, train loss:0.38


 34%|████████████████████████████████████████████████████████▍                                                                                                             | 17/50 [28:55<55:25, 100.76s/it]

epoch:17, train accu:0.8836, train loss:0.32


 36%|███████████████████████████████████████████████████████████▊                                                                                                          | 18/50 [30:36<53:46, 100.83s/it]

epoch:18, train accu:0.8978, train loss:0.29


 38%|███████████████████████████████████████████████████████████████                                                                                                       | 19/50 [32:19<52:29, 101.60s/it]

epoch:19, train accu:0.8987, train loss:0.28


 40%|██████████████████████████████████████████████████████████████████▍                                                                                                   | 20/50 [33:58<50:27, 100.91s/it]

epoch:20, train accu:0.9093, train loss:0.25


 42%|█████████████████████████████████████████████████████████████████████▋                                                                                                | 21/50 [35:41<49:04, 101.54s/it]

epoch:21, train accu:0.9422, train loss:0.17


 44%|█████████████████████████████████████████████████████████████████████████                                                                                             | 22/50 [37:23<47:20, 101.43s/it]

epoch:22, train accu:0.9311, train loss:0.20


 46%|████████████████████████████████████████████████████████████████████████████▎                                                                                         | 23/50 [39:05<45:44, 101.64s/it]

epoch:23, train accu:0.9424, train loss:0.16


 48%|███████████████████████████████████████████████████████████████████████████████▋                                                                                      | 24/50 [40:44<43:48, 101.09s/it]

epoch:24, train accu:0.9593, train loss:0.13


 50%|███████████████████████████████████████████████████████████████████████████████████                                                                                   | 25/50 [42:27<42:16, 101.46s/it]

epoch:25, train accu:0.9753, train loss:0.08


 52%|██████████████████████████████████████████████████████████████████████████████████████▎                                                                               | 26/50 [44:04<40:01, 100.08s/it]

epoch:26, train accu:0.9751, train loss:0.08


 54%|██████████████████████████████████████████████████████████████████████████████████████████▏                                                                            | 27/50 [45:41<38:05, 99.37s/it]

epoch:27, train accu:0.9756, train loss:0.07


 56%|█████████████████████████████████████████████████████████████████████████████████████████████▌                                                                         | 28/50 [47:20<36:23, 99.27s/it]

epoch:28, train accu:0.9653, train loss:0.10


 58%|████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                      | 29/50 [48:59<34:40, 99.08s/it]

epoch:29, train accu:0.9638, train loss:0.10


 60%|████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                  | 30/50 [50:39<33:06, 99.32s/it]

epoch:30, train accu:0.9902, train loss:0.03


 62%|███████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                               | 31/50 [52:19<31:31, 99.56s/it]

epoch:31, train accu:0.9878, train loss:0.03


 64%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                            | 32/50 [53:56<29:37, 98.74s/it]

epoch:32, train accu:0.9951, train loss:0.02


 66%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                        | 33/50 [55:40<28:25, 100.31s/it]

epoch:33, train accu:0.9747, train loss:0.07


 68%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                     | 34/50 [57:18<26:32, 99.55s/it]

epoch:34, train accu:0.9860, train loss:0.04


 70%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                  | 35/50 [58:59<24:59, 99.97s/it]

epoch:35, train accu:0.9818, train loss:0.05


 72%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                              | 36/50 [1:00:37<23:14, 99.64s/it]

epoch:36, train accu:0.9876, train loss:0.03


 74%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                           | 37/50 [1:02:16<21:30, 99.30s/it]

epoch:37, train accu:0.9876, train loss:0.03


 76%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                       | 38/50 [1:03:52<19:41, 98.45s/it]

epoch:38, train accu:0.9793, train loss:0.06


 78%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                    | 39/50 [1:05:34<18:14, 99.48s/it]

epoch:39, train accu:0.9702, train loss:0.09


 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                 | 40/50 [1:07:13<16:32, 99.27s/it]

epoch:40, train accu:0.9838, train loss:0.05


 82%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                             | 41/50 [1:08:51<14:50, 98.89s/it]

epoch:41, train accu:0.9940, train loss:0.02


 84%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                          | 42/50 [1:10:32<13:15, 99.39s/it]

epoch:42, train accu:0.9998, train loss:0.00


 86%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                       | 43/50 [1:12:11<11:35, 99.41s/it]

epoch:43, train accu:0.9978, train loss:0.01


 88%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                   | 44/50 [1:13:53<10:00, 100.11s/it]

epoch:44, train accu:0.9942, train loss:0.02


 90%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                | 45/50 [1:15:32<08:19, 99.91s/it]

epoch:45, train accu:0.9873, train loss:0.04


 92%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊             | 46/50 [1:17:11<06:37, 99.44s/it]

epoch:46, train accu:0.9902, train loss:0.03


 94%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████          | 47/50 [1:18:43<04:51, 97.21s/it]

epoch:47, train accu:0.9800, train loss:0.06


 96%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍      | 48/50 [1:20:20<03:14, 97.12s/it]

epoch:48, train accu:0.9960, train loss:0.02


 98%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋   | 49/50 [1:21:59<01:37, 97.96s/it]

epoch:49, train accu:0.9960, train loss:0.01


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [1:23:34<00:00, 100.30s/it]

epoch:50, train accu:0.9962, train loss:0.01





## Model Test

In [19]:
from tqdm import tqdm

testdata = []
for _, row in tqdm(test.iterrows()):
    testdata.append(torch.load(row[4]))
    
test_dataset = MyDataSet(testdata, test['emotion'].tolist(), cate_dic)
testloader_args = dict(batch_size=1, shuffle=True)
test_dataloader = DataLoader(test_dataset, **testloader_args, 
                             collate_fn=collate_indic)

1031it [01:14, 13.92it/s]


In [20]:
test_loss = 0
acc_cnt = 0
err_cnt = 0
batch_cnt = 0
model.eval()

for x, lengths, y in test_dataloader:

    x = x.to(device)
    y = y.to(device)

    logits = model(x, lengths)
    loss = criterion(logits, y)
    test_loss += loss.cpu().item()

    out_val, out_indices = torch.max(logits, dim=1)
    tar_indices = y

    for i in range(len(out_indices)):
        if out_indices[i] == tar_indices[i]:
            acc_cnt += 1
        else:
            err_cnt += 1
    batch_cnt += 1

test_loss = test_loss/batch_cnt
test_accuracy = acc_cnt/(acc_cnt+err_cnt)
print(f'test accuracy: {test_accuracy}')

test accuracy: 0.597478176527643


In [21]:
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.data)

aggr.weight tensor([[[-0.1419],
         [ 0.1803],
         [ 0.2438],
         [ 0.0492],
         [ 0.1098],
         [-0.1890],
         [-0.3274],
         [ 0.0142],
         [ 0.0680],
         [-0.0127],
         [-0.0288],
         [ 0.0642]]], device='cuda:0')
aggr.bias tensor([0.0273], device='cuda:0')
embed.weight tensor([[-0.0940, -0.0012, -0.0701,  ..., -0.0345,  0.0468,  0.0314],
        [ 0.0636,  0.0790,  0.0462,  ..., -0.0323,  0.0046, -0.0669],
        [ 0.0461,  0.0560,  0.0467,  ...,  0.0322,  0.0976,  0.1140],
        ...,
        [-0.0691, -0.0931, -0.0305,  ..., -0.0112,  0.0366,  0.0109],
        [ 0.0146,  0.0125, -0.0390,  ..., -0.0882, -0.0384, -0.1333],
        [ 0.0336, -0.0391, -0.0775,  ..., -0.0416, -0.0158,  0.1385]],
       device='cuda:0')
embed.bias tensor([ 0.0509,  0.0562, -0.0691, -0.0017, -0.0205,  0.0345,  0.0351,  0.0864,
         0.0757,  0.1065, -0.0912,  0.0041, -0.0271,  0.1938,  0.0095,  0.0783,
         0.0093,  0.0909,  0.0383,  0.1205,

In [22]:
model_path = f'../models/wav2vecbase/holdout{holdout}.pth'

torch.save({'epoch':epochs,
            'model_state_dict':model.state_dict(),
            'optimizer_state_dict':optimizer.state_dict()},
            model_path)