In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
!nvidia-smi

Sat Feb 10 00:54:19 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 2080        On  | 00000000:01:00.0  On |                  N/A |
| 32%   33C    P8              19W / 215W |    405MiB /  8192MiB |      3%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [4]:
import os
os.getcwd()

'/home/linh/projects/test_notebooks'

In [5]:
os.chdir('../IConNet')

In [6]:
import numpy as np
from tqdm import tqdm
import itertools
import matplotlib.pyplot as plt
import pandas as pd

# Audio
import librosa
import librosa.display

# Scikit learn
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
from sklearn.preprocessing import LabelEncoder
from sklearn.utils import shuffle
from sklearn.utils import class_weight

import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

In [7]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
import sys
import IPython.display as ipd
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [8]:
dataset = 'text_emotion'
input_feature = 'signals_3channels_win5stride1'
output_label = 'labels_13emotions'
data_dir = f'../data/nlp/{dataset}/'

In [9]:
data_path_prefix = f'{data_dir}preprocessed/{dataset}'
x_train = np.load(f'{data_path_prefix}.{input_feature}.train.npy', allow_pickle=True)
x_test = np.load(f'{data_path_prefix}.{input_feature}.test.npy', allow_pickle=True)
labels = np.load(f'{data_path_prefix}.classnames.npy', allow_pickle=True)
y_train = np.load(f'{data_path_prefix}.{output_label}.test.npy', allow_pickle=True)
y_test = np.load(f'{data_path_prefix}.{output_label}.test.npy', allow_pickle=True)

In [10]:
print(x_train.shape, x_test.shape, y_train.shape, y_test.shape, labels, sep=' ')

(8000, 3, 1024) (8000, 3, 1024) (8000,) (8000,) ['anger' 'boredom' 'empty' 'enthusiasm' 'fun' 'happiness' 'hate' 'love'
 'neutral' 'relief' 'sadness' 'surprise' 'worry']


In [11]:
from torch.utils.data import TensorDataset, DataLoader

def create_data_loader(x_train, y_train, 
                       x_test, y_test, batch_size=32):
    x_train = torch.tensor(x_train, dtype=torch.float)
    y_train = torch.tensor(y_train)
    x_test = torch.tensor(x_test, dtype=torch.float)
    y_test = torch.tensor(y_test)
    train_data = TensorDataset(x_train, y_train)
    test_data = TensorDataset(x_test, y_test)
    
    train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
    test_loader = DataLoader(test_data, shuffle=True, batch_size=batch_size)
    return train_loader, test_loader

In [12]:
batch_size = 128
train_loader, test_loader = create_data_loader(x_train, y_train, 
                       x_test, y_test, batch_size=batch_size)

In [13]:
from IConNet.nn.model import M11

In [14]:
from omegaconf import OmegaConf as ocf

In [66]:
del model
gc.collect()

1133

In [67]:
model_config_path = 'config/model/m10.yaml'
model_config = ocf.load(model_config_path)
model_config

{'name': 'M10', 'description': 'gated residual FirConv', 'fe': {'n_block': 2, 'n_channel': [128, 128], 'kernel_size': [511, 127], 'stride': [2, 8], 'window_k': [2, 9], 'residual_connection_type': 'concat', 'pooling': 'None'}, 'cls': {'n_block': 2, 'n_hidden_dim': [256, 256]}}

In [68]:
labels

array(['anger', 'boredom', 'empty', 'enthusiasm', 'fun', 'happiness',
       'hate', 'love', 'neutral', 'relief', 'sadness', 'surprise',
       'worry'], dtype=object)

In [69]:
model = M11(config=model_config,
           n_input=3, n_output=len(labels))
model

M11(
  (fe_blocks): FeBlocks(
    (blocks): ModuleList(
      (0-1): 2 x Sequential(
        (layer): FirConvLayer()
        (norm): CustomNormLayer(
          (layer): LocalResponseNorm(2, alpha=0.0001, beta=0.75, k=1.0)
        )
      )
    )
    (act): NLReLU()
  )
  (seq_blocks): SeqBlocks(
    (blocks): LSTM(128, 64, num_layers=2, batch_first=True)
  )
  (cls_head): Classifier(
    (blocks): ModuleList(
      (0): Sequential(
        (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (layer): Linear(in_features=64, out_features=256, bias=True)
      )
      (1): Sequential(
        (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (layer): Linear(in_features=256, out_features=256, bias=True)
      )
    )
    (act): LeakyReLU(negative_slope=0.01)
    (output_layer): Linear(in_features=256, out_features=13, bias=True)
  )
)

In [70]:
train_loader_length = len(train_loader.dataset)
test_loader_length = len(test_loader.dataset)

def train(model, epoch, log_interval):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # data = transform(data)
        data_length = len(data)
        data = data.to(device)
        target = target.to(device)
        output = model(data)
        del data
        gc.collect()
        torch.cuda.empty_cache()
        loss = F.cross_entropy(output.squeeze(), target)
        del target
        gc.collect()
        torch.cuda.empty_cache()
        optimizer.zero_grad()
        loss.backward()
        # nn.utils.clip_grad.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # print training stats
        if batch_idx % log_interval == 0:
            print(f"Train Epoch: {epoch} [{batch_idx * data_length}/{train_loader_length} ({100. * batch_idx / train_loader_length:.0f}%)]\tLoss: {loss.item():.6f}")

        # update progress bar
        pbar.update(pbar_update)
        # record loss
        train_losses.append(loss.item())

def test(model, epoch):
    model.eval()
    correct = 0
    for data, target in test_loader:
        # data = transform(data)
        data = data.to(device)
        target = target.to(device)
        output = model(data).squeeze()
        del data
        gc.collect()
        torch.cuda.empty_cache()
        probs = F.softmax(output, dim=-1)
        pred = probs.argmax(dim=-1)
        correct += pred.eq(target).sum().item()
        del target
        gc.collect()
        torch.cuda.empty_cache()
        pbar.update(pbar_update)
    acc = correct / test_loader_length
    print(f"\nTest Epoch: {epoch}\tAccuracy: {correct}/{test_loader_length} ({100. * acc:.0f}%)\n")
    return acc

In [71]:
n_epoch = 10
train_losses = []
test_accuracy = []
optimizer = optim.RAdam(model.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.2,
    steps_per_epoch=len(train_loader), epochs=n_epoch)

In [None]:
log_interval = 40
pbar_update = 1 / (len(train_loader) + len(test_loader))
model.to(device)
with tqdm(total=n_epoch) as pbar:
    for epoch in range(1, n_epoch + 1):
        train(model, epoch, log_interval)
        acc = test(model, epoch)
        test_accuracy += [acc]
        scheduler.step()

  0%|                                                                           | 0.015873015873015872/10 [00:00<04:11, 25.18s/it]



  3%|██▌                                                                         | 0.33333333333333304/10 [00:08<04:12, 26.13s/it]



 10%|███████▋                                                                     | 0.9999999999999973/10 [00:24<03:24, 22.70s/it]


Test Epoch: 1	Accuracy: 1955/8000 (24%)



 10%|███████▊                                                                     | 1.0158730158730132/10 [00:25<03:37, 24.17s/it]



 13%|██████████▎                                                                  | 1.3333333333333295/10 [00:33<03:48, 26.34s/it]



 20%|███████████████▍                                                             | 1.9999999999999938/10 [00:49<03:01, 22.72s/it]


Test Epoch: 2	Accuracy: 1812/8000 (23%)



 20%|███████████████▍                                                             | 2.0079365079365017/10 [00:49<03:11, 23.99s/it]



 23%|██████████████████▏                                                           | 2.333333333333326/10 [00:58<03:22, 26.40s/it]



 30%|███████████████████████                                                      | 2.9999999999999902/10 [01:14<02:42, 23.28s/it]


Test Epoch: 3	Accuracy: 1998/8000 (25%)



 30%|███████████████████████▌                                                      | 3.015873015873006/10 [01:14<02:48, 24.11s/it]



 33%|█████████████████████████▋                                                   | 3.3333333333333224/10 [01:23<02:53, 26.09s/it]



 40%|██████████████████████████████▊                                              | 3.9999999999999867/10 [01:39<02:19, 23.26s/it]


Test Epoch: 4	Accuracy: 2004/8000 (25%)



 40%|██████████████████████████████▉                                              | 4.0158730158730025/10 [01:39<02:25, 24.26s/it]



 43%|█████████████████████████████████▊                                            | 4.333333333333319/10 [01:48<02:28, 26.21s/it]



 50%|██████████████████████████████████████▉                                       | 4.999999999999983/10 [02:04<01:55, 23.13s/it]


Test Epoch: 5	Accuracy: 1728/8000 (22%)



 50%|███████████████████████████████████████                                       | 5.015873015872999/10 [02:04<02:00, 24.20s/it]



 53%|█████████████████████████████████████████▌                                    | 5.333333333333315/10 [02:13<02:05, 26.79s/it]



 60%|███████████████████████████████████████████████                                | 5.96031746031744/10 [02:28<01:32, 22.83s/it]

In [39]:
n_epoch = 10
n_epoch2 = 50
# train_losses = []
# test_accuracy = []
optimizer = optim.RAdam(model.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.2,
    steps_per_epoch=len(train_loader), epochs=n_epoch2)

log_interval = 40
# pbar_update = 1 / (len(train_loader) + len(test_loader))
# model.to(device)
with tqdm(total=n_epoch2) as pbar:
    for epoch in range(n_epoch+1, n_epoch2 + n_epoch + 1):
        train(model, epoch, log_interval)
        acc = test(model, epoch)
        test_accuracy += [acc]
        scheduler.step()

  0%|                                                                                        | 0.015873015873015872/50 [00:00<17:05, 20.52s/it]



  1%|▌                                                                                        | 0.33333333333333304/50 [00:06<16:23, 19.80s/it]



  2%|█▊                                                                                        | 1.0079365079365052/50 [00:19<16:16, 19.93s/it]


Test Epoch: 11	Accuracy: 1969/8000 (25%)



  3%|██▍                                                                                       | 1.3333333333333295/50 [00:26<16:07, 19.88s/it]



  4%|███▌                                                                                      | 2.0079365079365017/50 [00:39<16:14, 20.30s/it]


Test Epoch: 12	Accuracy: 1974/8000 (25%)



  5%|████▏                                                                                      | 2.333333333333326/50 [00:46<15:45, 19.83s/it]



  6%|█████▍                                                                                     | 3.007936507936498/50 [00:59<15:38, 19.98s/it]


Test Epoch: 13	Accuracy: 2017/8000 (25%)



  7%|█████▉                                                                                    | 3.3333333333333224/50 [01:06<15:26, 19.85s/it]



  8%|███████▎                                                                                   | 4.007936507936495/50 [01:19<15:18, 19.98s/it]


Test Epoch: 14	Accuracy: 1951/8000 (24%)



  9%|███████▉                                                                                   | 4.333333333333319/50 [01:26<15:06, 19.84s/it]



 10%|█████████                                                                                  | 5.007936507936491/50 [01:39<14:47, 19.73s/it]


Test Epoch: 15	Accuracy: 1947/8000 (24%)



 11%|█████████▋                                                                                 | 5.333333333333315/50 [01:45<14:34, 19.59s/it]



 12%|██████████▊                                                                               | 6.0079365079364875/50 [01:58<14:27, 19.73s/it]


Test Epoch: 16	Accuracy: 1949/8000 (24%)



 13%|███████████▌                                                                               | 6.333333333333312/50 [02:05<14:13, 19.55s/it]



 14%|████████████▊                                                                              | 7.007936507936484/50 [02:18<14:08, 19.74s/it]


Test Epoch: 17	Accuracy: 1979/8000 (25%)



 15%|█████████████▎                                                                             | 7.333333333333308/50 [02:25<14:07, 19.85s/it]



 16%|██████████████▌                                                                            | 8.007936507936481/50 [02:38<14:22, 20.53s/it]


Test Epoch: 18	Accuracy: 1960/8000 (24%)



 17%|███████████████▏                                                                           | 8.333333333333306/50 [02:45<13:59, 20.14s/it]



 18%|████████████████▍                                                                          | 9.007936507936478/50 [02:58<13:49, 20.23s/it]


Test Epoch: 19	Accuracy: 1949/8000 (24%)



 19%|████████████████▉                                                                          | 9.333333333333302/50 [03:05<13:25, 19.80s/it]



 20%|██████████████████                                                                        | 10.007936507936474/50 [03:18<13:13, 19.84s/it]


Test Epoch: 20	Accuracy: 1866/8000 (23%)



 21%|██████████████████▌                                                                       | 10.333333333333298/50 [03:25<13:09, 19.91s/it]



 22%|████████████████████                                                                       | 11.00793650793647/50 [03:38<12:43, 19.58s/it]


Test Epoch: 21	Accuracy: 2017/8000 (25%)



 23%|████████████████████▍                                                                     | 11.333333333333295/50 [03:44<12:35, 19.53s/it]



 24%|█████████████████████▌                                                                    | 12.007936507936467/50 [03:57<12:21, 19.51s/it]


Test Epoch: 22	Accuracy: 1993/8000 (25%)



 25%|██████████████████████▏                                                                   | 12.333333333333291/50 [04:04<12:27, 19.85s/it]



 26%|███████████████████████▍                                                                  | 13.007936507936463/50 [04:17<12:09, 19.73s/it]


Test Epoch: 23	Accuracy: 1958/8000 (24%)



 27%|███████████████████████▉                                                                  | 13.333333333333288/50 [04:24<12:10, 19.92s/it]



 28%|█████████████████████████▍                                                                 | 14.00793650793646/50 [04:37<11:50, 19.75s/it]


Test Epoch: 24	Accuracy: 1952/8000 (24%)



 29%|█████████████████████████▊                                                                | 14.333333333333284/50 [04:43<11:35, 19.51s/it]



 30%|███████████████████████████                                                               | 15.007936507936456/50 [04:56<11:27, 19.65s/it]


Test Epoch: 25	Accuracy: 1991/8000 (25%)



 31%|███████████████████████████▉                                                               | 15.33333333333328/50 [05:03<11:21, 19.66s/it]



 32%|████████████████████████████▊                                                             | 16.007936507936453/50 [05:16<11:07, 19.65s/it]


Test Epoch: 26	Accuracy: 1905/8000 (24%)



 33%|█████████████████████████████▋                                                             | 16.33333333333335/50 [05:22<10:59, 19.60s/it]



 34%|██████████████████████████████▌                                                           | 17.007936507936673/50 [05:35<10:48, 19.64s/it]


Test Epoch: 27	Accuracy: 1868/8000 (23%)



 35%|███████████████████████████████▌                                                           | 17.33333333333357/50 [05:42<10:40, 19.61s/it]



 36%|████████████████████████████████▍                                                         | 18.007936507936893/50 [05:55<10:27, 19.62s/it]


Test Epoch: 28	Accuracy: 1959/8000 (24%)



 37%|█████████████████████████████████▎                                                         | 18.33333333333379/50 [06:01<10:24, 19.72s/it]



 38%|██████████████████████████████████▏                                                       | 19.007936507937114/50 [06:15<10:11, 19.74s/it]


Test Epoch: 29	Accuracy: 2002/8000 (25%)



 39%|███████████████████████████████████▏                                                       | 19.33333333333401/50 [06:21<10:01, 19.61s/it]



 40%|████████████████████████████████████                                                      | 20.007936507937334/50 [06:34<09:51, 19.74s/it]


Test Epoch: 30	Accuracy: 1943/8000 (24%)



 41%|█████████████████████████████████████                                                      | 20.33333333333423/50 [06:41<09:42, 19.64s/it]



 42%|█████████████████████████████████████▊                                                    | 21.007936507937554/50 [06:54<09:33, 19.77s/it]


Test Epoch: 31	Accuracy: 1996/8000 (25%)



 43%|██████████████████████████████████████▊                                                    | 21.33333333333445/50 [07:00<09:24, 19.69s/it]



 44%|███████████████████████████████████████▌                                                  | 22.007936507937774/50 [07:14<09:25, 20.20s/it]


Test Epoch: 32	Accuracy: 1927/8000 (24%)



 45%|████████████████████████████████████████▋                                                  | 22.33333333333467/50 [07:20<09:51, 21.39s/it]



 46%|█████████████████████████████████████████▍                                                | 23.007936507937995/50 [07:34<08:52, 19.74s/it]


Test Epoch: 33	Accuracy: 1952/8000 (24%)



 47%|██████████████████████████████████████████                                                | 23.333333333334892/50 [07:40<08:46, 19.74s/it]



 48%|███████████████████████████████████████████▏                                              | 24.007936507938215/50 [07:53<08:30, 19.65s/it]


Test Epoch: 34	Accuracy: 1856/8000 (23%)



 49%|███████████████████████████████████████████▊                                              | 24.333333333335112/50 [08:00<08:22, 19.58s/it]



 50%|█████████████████████████████████████████████                                             | 25.007936507938435/50 [08:13<08:14, 19.79s/it]


Test Epoch: 35	Accuracy: 1952/8000 (24%)



 51%|█████████████████████████████████████████████▌                                            | 25.333333333335332/50 [08:19<08:07, 19.77s/it]



 52%|██████████████████████████████████████████████▊                                           | 26.007936507938656/50 [08:32<07:48, 19.54s/it]


Test Epoch: 36	Accuracy: 1965/8000 (25%)



 53%|███████████████████████████████████████████████▍                                          | 26.333333333335553/50 [08:39<07:42, 19.54s/it]



 54%|████████████████████████████████████████████████▌                                         | 27.007936507938876/50 [08:52<07:31, 19.65s/it]


Test Epoch: 37	Accuracy: 1954/8000 (24%)



 55%|█████████████████████████████████████████████████▏                                        | 27.333333333335773/50 [08:58<07:21, 19.49s/it]



 56%|██████████████████████████████████████████████████▍                                       | 28.007936507939096/50 [09:11<07:34, 20.67s/it]


Test Epoch: 38	Accuracy: 1862/8000 (23%)



 57%|███████████████████████████████████████████████████                                       | 28.333333333335993/50 [09:18<07:03, 19.53s/it]



 58%|████████████████████████████████████████████████████▏                                     | 29.007936507939316/50 [09:31<06:50, 19.55s/it]


Test Epoch: 39	Accuracy: 1974/8000 (25%)



 59%|████████████████████████████████████████████████████▊                                     | 29.333333333336213/50 [09:37<06:44, 19.59s/it]



 60%|██████████████████████████████████████████████████████                                    | 30.007936507939537/50 [09:51<06:32, 19.61s/it]


Test Epoch: 40	Accuracy: 1899/8000 (24%)



 61%|██████████████████████████████████████████████████████▌                                   | 30.333333333336434/50 [09:57<06:32, 19.97s/it]



 62%|███████████████████████████████████████████████████████▌                                  | 30.857142857146073/50 [10:07<06:22, 19.96s/it]Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7fda41dce100>>
Traceback (most recent call last):
  File "/opt/anaconda3/envs/audio/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 770, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 
 62%|███████████████████████████████████████████████████████▊                                  | 31.007936507939757/50 [10:10<06:20, 20.05s/it]


Test Epoch: 41	Accuracy: 1975/8000 (25%)



 62%|████████████████████████████████████████████████████████▋                                  | 31.14285714286042/50 [10:13<06:22, 20.28s/it]Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7fda41dce100>>
Traceback (most recent call last):
  File "/opt/anaconda3/envs/audio/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 770, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 
 63%|████████████████████████████████████████████████████████▎                                 | 31.285714285717596/50 [10:16<06:18, 20.21s/it]Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7fda41dce100>>
Traceback (most recent call last):
  File "/opt/anaconda3/envs/audio/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 770, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInte



 63%|████████████████████████████████████████████████████████▋                                 | 31.484126984130338/50 [10:20<06:06, 19.79s/it]Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7fda41dce100>>
Traceback (most recent call last):
  File "/opt/anaconda3/envs/audio/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 770, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 
 63%|████████████████████████████████████████████████████████▋                                 | 31.492063492066848/50 [10:20<06:08, 19.91s/it]Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7fda41dce100>>
Traceback (most recent call last):
  File "/opt/anaconda3/envs/audio/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 770, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInte

KeyboardInterrupt: 