In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import umap
import torch
import torch.nn as nn
import torch.nn.functional as F
import gc
from tqdm import tqdm

In [3]:
# Prefer the default CUDA device when a GPU is visible; otherwise run on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
import os

# CUDA-related environment variables. PyTorch reads these when CUDA / its
# caching allocator is initialized, so they only reliably take effect if set
# before any GPU work (ideally before `import torch`). NOTE(review): torch is
# already imported and torch.cuda.is_available() has already run in earlier
# cells; confirm these still apply, or move this cell above the imports.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

# CUDA_LAUNCH_BLOCKING=1 makes every kernel launch synchronous. It is only a
# debugging aid (it pinpoints the op behind an asynchronous CUDA error) and
# slows training substantially, so it is opt-in instead of always on.
DEBUG_CUDA_SYNC = False
if DEBUG_CUDA_SYNC:
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# The project package and its config/ directory are resolved relative to the
# IConNet repo root.
os.chdir('../IConNet/')
os.getcwd()

'/home/linh/projects/IConNet'

In [5]:
from einops import rearrange, reduce, repeat

In [6]:
# Seed the global RNGs so repeated runs start from the same state (presumably
# this covers model weight init and dataloader shuffling; confirm that the
# Trainer/get_dataloader do not create their own unseeded generators).
SEED = 42
np.random.seed(SEED)
_ = torch.manual_seed(SEED)

dataset_name = 'meld'
experiment_prefix = "scb13"
# Derive the checkpoint dir from dataset_name rather than repeating 'meld', so
# switching datasets cannot silently write into the MELD folder.
log_dir = f'../{experiment_prefix}_models/{dataset_name}/'
# Pretrained codebook from an earlier experiment (per its path: the scb11 run
# on RAVDESS, epoch 220). A plain string: there is nothing to interpolate.
codebook_pretrained_path = '../scb11_models/ravdess/epoch=220.masked_codebook.pt'
data_dir = "../data/data_preprocessed/"
sr = 16000  # sample rate in Hz, matching the *.audio16k.npy features


In [7]:
# Preprocessed 16 kHz audio for the whole dataset, loaded here for inspection.
# allow_pickle=True is needed because the array is stored with pickled entries
# (presumably an object array of variable-length clips). Only load files that
# were produced by our own preprocessing, since unpickling can run code.
audio_npy_path = f'{data_dir}{dataset_name}/preprocessing/{dataset_name}.audio16k.npy'
data_x = np.load(audio_npy_path, allow_pickle=True)
print(data_x.shape)

(4639,)


In [8]:
# Emotion label key per clip (presumably aligned index-for-index with data_x).
# The final expression displays the distinct label keys that are present.
labels_npy_path = f'{data_dir}{dataset_name}/preprocessing/{dataset_name}.label_emotion_key.npy'
data_y = np.load(labels_npy_path, allow_pickle=True)
print(data_y.shape)
np.unique(data_y)

(4639,)


array(['ang', 'hap', 'neu', 'sad'], dtype=object)

In [9]:
import pandas as pd
from IConNet.trainer.train_torch import get_dataloader
from IConNet.trainer.train_torch import Trainer_custom_model as Trainer
from IConNet.acov.model import SCB13 as SCB
from omegaconf import OmegaConf as ocf

# Model / optimisation hyperparameters.
# `sr` is intentionally NOT redefined here: it is set once in the config cell,
# and re-declaring it would silently override any change made there.
batch_size = 2
in_channels = 1
# NOTE(review): `kernel_size` and `stride` are not passed to SCB below (the
# model is built with a literal stride=1). Kept for reference; confirm whether
# they were meant to configure the model.
kernel_size = 511
stride = 125
embedding_dim = 511
num_embeddings = 384
cls_dim = 512
learning_rate = 1e-4
num_classes = 4
window_k = 5

# Fail fast if the classifier head size disagrees with the labels in the data.
n_labels_in_data = len(np.unique(data_y))
assert num_classes == n_labels_in_data, (
    f"num_classes={num_classes} but data_y contains {n_labels_in_data} labels"
)

model = SCB(
    in_channels=in_channels,
    num_embeddings=num_embeddings,
    stride=1,
    window_k=window_k,
    embedding_dim=embedding_dim,
    num_classes=num_classes,
    cls_dim=cls_dim,
    sample_rate=sr,
    codebook_pretrained_path=codebook_pretrained_path,
)

In [10]:
import IPython.display as ipd
import seaborn as sns

from IConNet.acov.visualize import (
    visualize_speech_codebook, get_embedding_color, 
    visualize_embedding_umap, visualize_training_curves,
    get_embedding_color_v2, get_zcs_color_v2
)

In [11]:
dataset_config_path = f'config/dataset/{dataset_name}.yaml'
dataset_config = ocf.load(dataset_config_path)
print(dataset_config)

# num_classes is hard-coded in the model cell and the labels come from the raw
# label file; check that both agree with the dataset config before training.
assert dataset_config.num_classes == num_classes, (
    f"config num_classes={dataset_config.num_classes} != model num_classes={num_classes}"
)
assert set(dataset_config.label_values) == set(np.unique(data_y)), (
    "label_values in the dataset config do not match the labels found in data_y"
)

{'name': 'meld', 'dataset_class': 'WaveformDataset', 'root': 'meld/', 'audio_dir': 'full_release/', 'feature_dir': 'preprocessing/', 'label_name': 'label_emotion_key', 'feature_name': 'audio16k', 'num_classes': 4, 'label_values': ['neu', 'hap', 'sad', 'ang'], 'classnames': ['neu', 'hap', 'sad', 'ang'], 'target_labels': ['ang', 'neu', 'sad', 'hap']}


In [12]:
# Pass the configured batch size instead of a literal, so the model-config cell
# stays the single place to change it. get_dataloader also returns a batch
# size, which is rebound here and displayed (presumably the value it actually
# used; confirm whether it can differ from the requested one).
train_loader, test_loader, batch_size = get_dataloader(
    dataset_config, data_dir, batch_size=batch_size)
print(batch_size)

2


In [13]:
# Gradient accumulation and clipping settings. Assuming the Trainer steps the
# optimizer once per accumulated group (as the argument name suggests), the
# effective batch per update is batch_size * accumulate_grad_batches.
accumulate_grad_batches = 16
gradient_clip_val = 1.0

trainer = Trainer(batch_size=batch_size, log_dir=log_dir,
                  experiment_prefix=experiment_prefix, device=device,
                  accumulate_grad_batches=accumulate_grad_batches,
                  gradient_clip_val=gradient_clip_val)

trainer.prepare(train_loader=train_loader,
                test_loader=test_loader,
                batch_size=batch_size)
# Use the learning rate declared with the other hyperparameters rather than
# repeating the literal, so the two cannot drift apart.
trainer.setup(model, lr=learning_rate)

In [14]:
# Training run length and test-evaluation interval. In the log below, test
# reports appear around epochs 1, 11, 21, ...
# NOTE(review): in the logged run, test accuracy stays at ~27-29%, which is
# roughly the share of the largest class (261/928). The confusion matrices are
# dominated by a single predicted column, and class index 2 is almost never
# predicted, so the model does not yet appear to learn beyond chance.
N_EPOCHS = 100
TEST_EVERY_N_EPOCHS = 10
trainer.fit(n_epoch=N_EPOCHS, test_n_epoch=TEST_EVERY_N_EPOCHS)

  0%|▏                          | 0.4766307139188397/100 [02:43<385:32:25, 13945.93s/it]

Epoch: 1	Loss: 1.295	Val_acc: 45/185 (24.32%)

Saved new best val model: ../scb13_models/meld/model.epoch=1.step=870.loss=1.295.val_acc=0.243.pt


  1%|▎                          | 0.9527478171545639/100 [05:26<382:30:05, 13902.52s/it]

Epoch: 1	Loss: 1.910	Val_acc: 51/185 (27.57%)

Saved new best val model: ../scb13_models/meld/model.epoch=1.step=1739.loss=1.910.val_acc=0.276.pt


  1%|▍                              | 1.429378531073483/100 [09:14<13:07:20, 479.25s/it]

Correct: 252/928 (0.2716)
Saved new best test model: ../scb13_models/meld/model.epoch=1.step=1739.test_acc=0.2716.pt
{'acc_unweighted': tensor(0.2425, device='cuda:0'),
 'acc_weighted': tensor(0.2716, device='cuda:0'),
 'f1s_unweighted': tensor(0.1498, device='cuda:0'),
 'f1s_weighted': tensor(0.1663, device='cuda:0'),
 'rocauc': tensor(0.5158, device='cuda:0'),
 'uar': tensor(0.2425, device='cuda:0'),
 'wap': tensor(0.2140, device='cuda:0')}
{'acc_detail': tensor([0.0736, 0.0282, 0.0062, 0.8621], device='cuda:0'),
 'f1s_detail': tensor([0.1114, 0.0504, 0.0110, 0.4265], device='cuda:0'),
 'precision_detail': tensor([0.2289, 0.2333, 0.0476, 0.2834], device='cuda:0'),
 'recall_detail': tensor([0.0736, 0.0282, 0.0062, 0.8621], device='cuda:0'),
 'rocauc_detail': tensor([0.5463, 0.4929, 0.4983, 0.5255], device='cuda:0')}
tensor([[ 19,  11,   4, 224],
        [ 24,   7,   7, 210],
        [ 20,   5,   1, 135],
        [ 20,   7,   9, 225]], device='cuda:0')


  2%|▌                          | 1.9060092449924124/100 [12:00<378:50:38, 13903.39s/it]

Epoch: 2	Loss: 1.261	Val_acc: 39/185 (21.08%)



  2%|▋                          | 2.3821263482280743/100 [14:44<383:38:02, 14147.85s/it]

Epoch: 2	Loss: 1.469	Val_acc: 42/185 (22.70%)



  3%|▊                          | 2.8587570621467977/100 [17:32<385:40:02, 14292.61s/it]

Epoch: 3	Loss: 1.276	Val_acc: 40/185 (21.62%)



  3%|▉                           | 3.334874165382419/100 [20:18<374:32:00, 13948.37s/it]

Epoch: 3	Loss: 1.462	Val_acc: 44/185 (23.78%)



  4%|█                          | 3.8115048793011423/100 [23:04<373:40:15, 13985.21s/it]

Epoch: 4	Loss: 1.444	Val_acc: 45/185 (24.32%)



  4%|█▏                          | 4.287621982536764/100 [25:49<372:48:01, 14022.02s/it]

Epoch: 4	Loss: 1.231	Val_acc: 51/185 (27.57%)



  5%|█▎                          | 4.764252696455487/100 [28:35<369:34:30, 13970.28s/it]

Epoch: 5	Loss: 1.446	Val_acc: 51/185 (27.57%)



  5%|█▍                          | 5.240369799691108/100 [31:19<373:42:47, 14197.69s/it]

Epoch: 5	Loss: 1.808	Val_acc: 51/185 (27.57%)



  6%|█▌                          | 5.717000513609832/100 [34:05<365:39:28, 13961.88s/it]

Epoch: 6	Loss: 1.397	Val_acc: 47/185 (25.41%)



  6%|█▋                          | 6.193117616845453/100 [36:51<371:21:32, 14251.53s/it]

Epoch: 6	Loss: 1.383	Val_acc: 43/185 (23.24%)



  7%|█▉                           | 6.670261941447278/100 [39:38<255:15:40, 9846.17s/it]

Epoch: 7	Loss: 1.266	Val_acc: 48/185 (25.95%)



  7%|██                          | 7.145865433999798/100 [42:22<361:52:06, 14029.81s/it]

Epoch: 7	Loss: 1.510	Val_acc: 42/185 (22.70%)



  8%|██▏                         | 7.622496147918521/100 [45:09<363:05:32, 14149.90s/it]

Epoch: 8	Loss: 1.262	Val_acc: 61/185 (32.97%)

Saved new best val model: ../scb13_models/meld/model.epoch=8.step=13043.loss=1.262.val_acc=0.330.pt


  8%|██▎                         | 8.098613251154143/100 [47:54<367:47:12, 14407.10s/it]

Epoch: 8	Loss: 1.475	Val_acc: 51/185 (27.57%)



  9%|██▍                         | 8.575243965072866/100 [50:40<354:06:20, 13943.49s/it]

Epoch: 9	Loss: 1.282	Val_acc: 42/185 (22.70%)



  9%|██▌                         | 9.051361068308488/100 [53:25<355:51:47, 14086.05s/it]

Epoch: 9	Loss: 1.217	Val_acc: 50/185 (27.03%)



 10%|██▋                         | 9.527991782227211/100 [56:12<356:57:53, 14204.10s/it]

Epoch: 10	Loss: 1.769	Val_acc: 45/185 (24.32%)



 10%|██▋                        | 10.004108885462832/100 [58:59<359:08:21, 14366.22s/it]

Epoch: 10	Loss: 2.063	Val_acc: 51/185 (27.57%)



 10%|██▌                      | 10.480739599381556/100 [1:01:46<355:21:00, 14290.33s/it]

Epoch: 11	Loss: 1.724	Val_acc: 44/185 (23.78%)



 11%|██▋                      | 10.956856702617177/100 [1:04:30<349:50:56, 14144.34s/it]

Epoch: 11	Loss: 1.565	Val_acc: 64/185 (34.59%)

Saved new best val model: ../scb13_models/meld/model.epoch=11.step=19129.loss=1.565.val_acc=0.346.pt


 11%|███▍                          | 11.4334874165359/100 [1:08:21<11:44:54, 477.55s/it]

Correct: 252/928 (0.2716)
{'acc_unweighted': tensor(0.2533, device='cuda:0'),
 'acc_weighted': tensor(0.2716, device='cuda:0'),
 'f1s_unweighted': tensor(0.1313, device='cuda:0'),
 'f1s_weighted': tensor(0.1417, device='cuda:0'),
 'rocauc': tensor(0.5189, device='cuda:0'),
 'uar': tensor(0.2533, device='cuda:0'),
 'wap': tensor(0.2569, device='cuda:0')}
{'acc_detail': tensor([0.0233, 0.9556, 0.0000, 0.0345], device='cuda:0'),
 'f1s_detail': tensor([0.0433, 0.4191, 0.0000, 0.0627], device='cuda:0'),
 'precision_detail': tensor([0.3158, 0.2684, 0.0000, 0.3462], device='cuda:0'),
 'recall_detail': tensor([0.0233, 0.9556, 0.0000, 0.0345], device='cuda:0'),
 'rocauc_detail': tensor([0.5339, 0.5208, 0.5246, 0.4962], device='cuda:0')}
tensor([[  6, 245,   0,   7],
        [  6, 237,   0,   5],
        [  2, 154,   0,   5],
        [  5, 247,   0,   9]], device='cuda:0')


 12%|██▉                      | 11.910118130454624/100 [1:11:07<342:29:35, 13996.79s/it]

Epoch: 12	Loss: 1.609	Val_acc: 50/185 (27.03%)



 12%|███                      | 12.386235233690245/100 [1:13:51<346:38:18, 14243.17s/it]

Epoch: 12	Loss: 1.384	Val_acc: 46/185 (24.86%)



 13%|███▏                     | 12.862865947608968/100 [1:16:38<336:37:50, 13907.63s/it]

Epoch: 13	Loss: 1.319	Val_acc: 50/185 (27.03%)



 13%|███▍                      | 13.33898305084459/100 [1:19:22<347:52:45, 14451.32s/it]

Epoch: 13	Loss: 1.448	Val_acc: 44/185 (23.78%)



 14%|███▍                     | 13.815613764763313/100 [1:22:09<342:10:10, 14292.73s/it]

Epoch: 14	Loss: 1.374	Val_acc: 51/185 (27.57%)



 14%|███▌                     | 14.291730867998934/100 [1:24:54<342:45:21, 14396.77s/it]

Epoch: 14	Loss: 1.976	Val_acc: 59/185 (31.89%)



 15%|███▋                     | 14.768361581917658/100 [1:27:41<331:25:27, 13998.65s/it]

Epoch: 15	Loss: 1.302	Val_acc: 61/185 (32.97%)



 15%|███▊                     | 15.244478685153279/100 [1:30:25<332:25:15, 14119.62s/it]

Epoch: 15	Loss: 1.930	Val_acc: 51/185 (27.57%)



 16%|███▉                     | 15.721109399072002/100 [1:33:10<327:44:56, 13999.91s/it]

Epoch: 16	Loss: 0.875	Val_acc: 49/185 (26.49%)



 16%|████                     | 16.197226502308304/100 [1:35:56<331:29:37, 14240.30s/it]

Epoch: 16	Loss: 1.184	Val_acc: 50/185 (27.03%)



 17%|████▏                    | 16.673857216228676/100 [1:38:42<323:07:03, 13959.88s/it]

Epoch: 17	Loss: 0.859	Val_acc: 40/185 (21.62%)



 17%|████▎                    | 17.149974319465944/100 [1:41:28<328:04:55, 14255.83s/it]

Epoch: 17	Loss: 1.291	Val_acc: 42/185 (22.70%)



 18%|████▍                    | 17.626605033386316/100 [1:44:14<327:32:57, 14315.03s/it]

Epoch: 18	Loss: 1.375	Val_acc: 55/185 (29.73%)



 18%|████▌                    | 18.102722136623584/100 [1:46:59<320:41:53, 14097.09s/it]

Epoch: 18	Loss: 1.459	Val_acc: 60/185 (32.43%)



 19%|████▋                    | 18.579352850543955/100 [1:49:46<319:25:23, 14123.24s/it]

Epoch: 19	Loss: 1.074	Val_acc: 49/185 (26.49%)



 19%|████▊                    | 19.055469953781223/100 [1:52:31<317:36:18, 14125.45s/it]

Epoch: 19	Loss: 1.504	Val_acc: 49/185 (26.49%)



 20%|████▉                    | 19.532100667701595/100 [1:55:17<313:44:26, 14036.23s/it]

Epoch: 20	Loss: 1.359	Val_acc: 43/185 (23.24%)



 20%|█████                    | 20.008217770938863/100 [1:58:02<314:54:59, 14172.70s/it]

Epoch: 20	Loss: 1.243	Val_acc: 43/185 (23.24%)



 20%|█████                    | 20.484848484859235/100 [2:00:50<314:11:39, 14224.95s/it]

Epoch: 21	Loss: 1.744	Val_acc: 60/185 (32.43%)



 21%|█████▏                   | 20.960965588096503/100 [2:03:34<308:50:34, 14066.91s/it]

Epoch: 21	Loss: 1.351	Val_acc: 61/185 (32.97%)



 21%|██████                      | 21.437596302016875/100 [2:07:24<10:20:58, 474.26s/it]

Correct: 257/928 (0.2769)
Saved new best test model: ../scb13_models/meld/model.epoch=21.step=36519.test_acc=0.2769.pt
{'acc_unweighted': tensor(0.2584, device='cuda:0'),
 'acc_weighted': tensor(0.2769, device='cuda:0'),
 'f1s_unweighted': tensor(0.1289, device='cuda:0'),
 'f1s_weighted': tensor(0.1390, device='cuda:0'),
 'rocauc': tensor(0.5027, device='cuda:0'),
 'uar': tensor(0.2584, device='cuda:0'),
 'wap': tensor(0.1987, device='cuda:0')}
{'acc_detail': tensor([0.0000, 0.9839, 0.0000, 0.0498], device='cuda:0'),
 'f1s_detail': tensor([0.0000, 0.4258, 0.0000, 0.0897], device='cuda:0'),
 'precision_detail': tensor([0.0000, 0.2717, 0.0000, 0.4483], device='cuda:0'),
 'recall_detail': tensor([0.0000, 0.9839, 0.0000, 0.0498], device='cuda:0'),
 'rocauc_detail': tensor([0.5009, 0.5033, 0.4917, 0.5149], device='cuda:0')}
tensor([[  0, 249,   0,   9],
        [  0, 244,   0,   4],
        [  0, 158,   0,   3],
        [  1, 247,   0,  13]], device='cuda:0')


 22%|█████▍                   | 21.914227015937247/100 [2:10:11<308:45:38, 14234.84s/it]

Epoch: 22	Loss: 1.819	Val_acc: 39/185 (21.08%)



 22%|█████▌                   | 22.390344119174515/100 [2:12:55<304:35:09, 14128.52s/it]

Epoch: 22	Loss: 1.305	Val_acc: 44/185 (23.78%)



 23%|█████▋                   | 22.866974833094886/100 [2:15:42<305:35:33, 14262.81s/it]

Epoch: 23	Loss: 1.325	Val_acc: 53/185 (28.65%)



 23%|█████▊                   | 23.343091936332154/100 [2:18:27<306:22:23, 14388.05s/it]

Epoch: 23	Loss: 1.263	Val_acc: 61/185 (32.97%)



 24%|█████▉                   | 23.819722650252526/100 [2:21:14<297:06:38, 14040.36s/it]

Epoch: 24	Loss: 1.076	Val_acc: 44/185 (23.78%)



 24%|██████                   | 24.295839753489794/100 [2:23:58<296:17:15, 14089.52s/it]

Epoch: 24	Loss: 1.391	Val_acc: 51/185 (27.57%)



 25%|██████▏                  | 24.772470467410166/100 [2:26:45<300:49:50, 14396.20s/it]

Epoch: 25	Loss: 1.114	Val_acc: 42/185 (22.70%)



 25%|██████▎                  | 25.248587570647434/100 [2:29:29<294:11:57, 14168.53s/it]

Epoch: 25	Loss: 0.603	Val_acc: 52/185 (28.11%)



 26%|██████▍                  | 25.725218284567806/100 [2:32:15<290:02:20, 14057.81s/it]

Epoch: 26	Loss: 1.196	Val_acc: 41/185 (22.16%)



 26%|██████▌                  | 26.201335387805074/100 [2:35:00<293:12:38, 14303.22s/it]

Epoch: 26	Loss: 1.402	Val_acc: 42/185 (22.70%)



 27%|██████▋                  | 26.677966101725445/100 [2:37:50<294:08:29, 14441.91s/it]

Epoch: 27	Loss: 1.487	Val_acc: 64/185 (34.59%)



 27%|██████▊                  | 27.154083204962713/100 [2:40:34<285:11:19, 14093.84s/it]

Epoch: 27	Loss: 1.443	Val_acc: 50/185 (27.03%)



 28%|██████▉                  | 27.630713918883085/100 [2:43:21<283:28:08, 14101.13s/it]

Epoch: 28	Loss: 1.013	Val_acc: 56/185 (30.27%)



 28%|███████                  | 28.106831022120353/100 [2:46:06<282:24:23, 14141.31s/it]

Epoch: 28	Loss: 1.253	Val_acc: 64/185 (34.59%)



 29%|███████▏                 | 28.583461736040725/100 [2:48:53<279:34:51, 14093.26s/it]

Epoch: 29	Loss: 1.915	Val_acc: 52/185 (28.11%)



 29%|███████▎                 | 29.059578839277993/100 [2:51:38<280:32:36, 14236.68s/it]

Epoch: 29	Loss: 1.654	Val_acc: 51/185 (27.57%)



 30%|███████▍                 | 29.536209553198365/100 [2:54:27<285:25:12, 14582.13s/it]

Epoch: 30	Loss: 2.357	Val_acc: 56/185 (30.27%)



 30%|███████▌                 | 30.012326656435633/100 [2:57:14<281:31:44, 14481.19s/it]

Epoch: 30	Loss: 1.297	Val_acc: 32/185 (17.30%)



 30%|███████▌                 | 30.488957370356005/100 [3:00:01<271:30:14, 14061.28s/it]

Epoch: 31	Loss: 1.185	Val_acc: 43/185 (23.24%)



 31%|███████▋                 | 30.965074473593273/100 [3:02:47<273:41:24, 14272.26s/it]

Epoch: 31	Loss: 1.633	Val_acc: 50/185 (27.03%)



 31%|█████████                    | 31.441705187513644/100 [3:06:39<9:19:15, 489.45s/it]

Correct: 267/928 (0.2877)
Saved new best test model: ../scb13_models/meld/model.epoch=31.step=53909.test_acc=0.2877.pt
{'acc_unweighted': tensor(0.2563, device='cuda:0'),
 'acc_weighted': tensor(0.2877, device='cuda:0'),
 'f1s_unweighted': tensor(0.1372, device='cuda:0'),
 'f1s_weighted': tensor(0.1532, device='cuda:0'),
 'rocauc': tensor(0.5125, device='cuda:0'),
 'uar': tensor(0.2563, device='cuda:0'),
 'wap': tensor(0.2747, device='cuda:0')}
{'acc_detail': tensor([0.0194, 0.0403, 0.0000, 0.9655], device='cuda:0'),
 'f1s_detail': tensor([0.0368, 0.0725, 0.0000, 0.4394], device='cuda:0'),
 'precision_detail': tensor([0.3571, 0.3571, 0.0000, 0.2844], device='cuda:0'),
 'recall_detail': tensor([0.0194, 0.0403, 0.0000, 0.9655], device='cuda:0'),
 'rocauc_detail': tensor([0.4868, 0.5250, 0.5084, 0.5299], device='cuda:0')}
tensor([[  5,   6,   0, 247],
        [  5,  10,   0, 233],
        [  2,   5,   0, 154],
        [  2,   7,   0, 252]], device='cuda:0')


 32%|███████▉                 | 31.918335901434016/100 [3:09:25<271:09:19, 14338.07s/it]

Epoch: 32	Loss: 1.239	Val_acc: 45/185 (24.32%)



 32%|████████                 | 32.394453004668556/100 [3:12:10<270:07:11, 14383.90s/it]

Epoch: 32	Loss: 1.567	Val_acc: 46/185 (24.86%)



 33%|████████▌                 | 32.87108371858563/100 [3:14:57<266:37:54, 14298.98s/it]

Epoch: 33	Loss: 0.923	Val_acc: 60/185 (32.43%)



 33%|████████▎                | 33.347200821819605/100 [3:17:43<262:01:46, 14152.55s/it]

Epoch: 33	Loss: 1.420	Val_acc: 62/185 (33.51%)



 34%|████████▊                 | 33.82383153573668/100 [3:20:29<262:03:37, 14256.16s/it]

Epoch: 34	Loss: 1.841	Val_acc: 49/185 (26.49%)



 34%|████████▌                | 34.299948638970655/100 [3:23:15<258:09:44, 14145.87s/it]

Epoch: 34	Loss: 1.083	Val_acc: 50/185 (27.03%)



 35%|█████████                 | 34.77657935288773/100 [3:26:02<255:07:58, 14082.04s/it]

Epoch: 35	Loss: 0.026	Val_acc: 63/185 (34.05%)



 35%|████████▊                | 35.252696456121704/100 [3:28:47<255:25:36, 14201.93s/it]

Epoch: 35	Loss: 1.630	Val_acc: 50/185 (27.03%)



 36%|█████████▎                | 35.72932717003878/100 [3:31:34<252:36:44, 14149.60s/it]

Epoch: 36	Loss: 1.218	Val_acc: 53/185 (28.65%)



 36%|█████████                | 36.205444273272754/100 [3:34:19<251:45:36, 14207.11s/it]

Epoch: 36	Loss: 1.305	Val_acc: 42/185 (22.70%)



 37%|█████████▌                | 36.68207498718983/100 [3:37:04<246:24:03, 14009.36s/it]

Epoch: 37	Loss: 1.329	Val_acc: 60/185 (32.43%)



 37%|██████████                 | 37.1581920904238/100 [3:39:49<246:31:18, 14122.42s/it]

Epoch: 37	Loss: 1.653	Val_acc: 50/185 (27.03%)



 38%|█████████▊                | 37.63482280434088/100 [3:42:36<243:17:55, 14044.30s/it]

Epoch: 38	Loss: 1.422	Val_acc: 53/185 (28.65%)



 38%|█████████▉                | 38.11093990757485/100 [3:45:21<245:17:52, 14268.64s/it]

Epoch: 38	Loss: 1.341	Val_acc: 52/185 (28.11%)



 39%|██████████                | 38.58757062149193/100 [3:48:09<240:26:37, 14094.83s/it]

Epoch: 39	Loss: 1.045	Val_acc: 62/185 (33.51%)



 39%|██████████▌                | 39.0636877247259/100 [3:50:55<243:21:20, 14376.98s/it]

Epoch: 39	Loss: 1.113	Val_acc: 61/185 (32.97%)



 40%|██████████▎               | 39.54031843864298/100 [3:53:43<238:00:39, 14172.08s/it]

Epoch: 40	Loss: 1.603	Val_acc: 44/185 (23.78%)



 40%|██████████▍               | 40.01643554187695/100 [3:56:29<237:05:34, 14229.48s/it]

Epoch: 40	Loss: 1.125	Val_acc: 42/185 (22.70%)



 40%|██████████▌               | 40.49306625579403/100 [3:59:16<234:06:36, 14163.00s/it]

Epoch: 41	Loss: 0.431	Val_acc: 66/185 (35.68%)

Saved new best val model: ../scb13_models/meld/model.epoch=41.step=70430.loss=0.431.val_acc=0.357.pt


 41%|███████████▍                | 40.969183359028/100 [4:02:02<232:31:39, 14180.72s/it]

Epoch: 41	Loss: 1.303	Val_acc: 40/185 (21.62%)



 41%|████████████                 | 41.445814072945076/100 [4:05:55<8:02:22, 494.28s/it]

Correct: 273/928 (0.2942)
Saved new best test model: ../scb13_models/meld/model.epoch=41.step=71299.test_acc=0.2942.pt
{'acc_unweighted': tensor(0.2650, device='cuda:0'),
 'acc_weighted': tensor(0.2942, device='cuda:0'),
 'f1s_unweighted': tensor(0.1838, device='cuda:0'),
 'f1s_weighted': tensor(0.2036, device='cuda:0'),
 'rocauc': tensor(0.5280, device='cuda:0'),
 'uar': tensor(0.2650, device='cuda:0'),
 'wap': tensor(0.2719, device='cuda:0')}
{'acc_detail': tensor([0.8643, 0.0806, 0.0000, 0.1149], device='cuda:0'),
 'f1s_detail': tensor([0.4305, 0.1303, 0.0000, 0.1744], device='cuda:0'),
 'precision_detail': tensor([0.2866, 0.3390, 0.0000, 0.3614], device='cuda:0'),
 'recall_detail': tensor([0.8643, 0.0806, 0.0000, 0.1149], device='cuda:0'),
 'rocauc_detail': tensor([0.5457, 0.5384, 0.5106, 0.5175], device='cuda:0')}
tensor([[223,  11,   4,  20],
        [209,  20,   2,  17],
        [134,  11,   0,  16],
        [212,  17,   2,  30]], device='cuda:0')


 42%|██████████▉               | 41.92244478686215/100 [4:08:43<233:59:23, 14504.11s/it]

Epoch: 42	Loss: 2.559	Val_acc: 62/185 (33.51%)



 42%|██████████▌              | 42.398561890096126/100 [4:11:29<229:07:45, 14320.23s/it]

Epoch: 42	Loss: 1.285	Val_acc: 51/185 (27.57%)



 43%|███████████▌               | 42.8751926040132/100 [4:14:16<225:20:48, 14201.33s/it]

Epoch: 43	Loss: 1.631	Val_acc: 64/185 (34.59%)



 43%|██████████▊              | 43.351309707247175/100 [4:17:02<225:02:00, 14300.77s/it]

Epoch: 43	Loss: 1.368	Val_acc: 57/185 (30.81%)



 44%|███████████▍              | 43.82794042116425/100 [4:19:49<222:20:32, 14249.66s/it]

Epoch: 44	Loss: 1.767	Val_acc: 59/185 (31.89%)



 44%|███████████              | 44.304057524398225/100 [4:22:35<219:16:23, 14173.09s/it]

Epoch: 44	Loss: 1.141	Val_acc: 47/185 (25.41%)



 45%|████████████               | 44.7806882383153/100 [4:25:22<217:21:45, 14170.88s/it]

Epoch: 45	Loss: 1.487	Val_acc: 63/185 (34.05%)



 45%|███████████▎             | 45.256805341549274/100 [4:28:09<216:08:22, 14213.68s/it]

Epoch: 45	Loss: 1.892	Val_acc: 51/185 (27.57%)



 46%|███████████▉              | 45.73343605546635/100 [4:30:56<218:20:57, 14485.11s/it]

Epoch: 46	Loss: 1.148	Val_acc: 53/185 (28.65%)



 46%|███████████▌             | 46.209553158700324/100 [4:33:43<217:55:16, 14584.67s/it]

Epoch: 46	Loss: 1.291	Val_acc: 49/185 (26.49%)



 47%|████████████▌              | 46.6861838726174/100 [4:36:30<209:31:51, 14148.52s/it]

Epoch: 47	Loss: 1.175	Val_acc: 42/185 (22.70%)



 47%|████████████▎             | 47.16230097585137/100 [4:39:17<208:34:30, 14210.88s/it]

Epoch: 47	Loss: 1.683	Val_acc: 62/185 (33.51%)



 48%|████████████▍             | 47.63893168976845/100 [4:42:05<207:16:56, 14251.36s/it]

Epoch: 48	Loss: 0.551	Val_acc: 36/185 (19.46%)



 48%|████████████▌             | 48.11504879300242/100 [4:44:51<205:59:27, 14292.53s/it]

Epoch: 48	Loss: 1.329	Val_acc: 60/185 (32.43%)



 49%|█████████████              | 48.5916795069195/100 [4:47:38<198:33:38, 13904.73s/it]

Epoch: 49	Loss: 1.223	Val_acc: 52/185 (28.11%)



 49%|████████████▊             | 49.06779661015347/100 [4:50:21<198:20:34, 14019.31s/it]

Epoch: 49	Loss: 1.130	Val_acc: 48/185 (25.95%)



 50%|████████████▉             | 49.54442732407055/100 [4:53:08<195:55:02, 13978.69s/it]

Epoch: 50	Loss: 1.626	Val_acc: 63/185 (34.05%)



 50%|█████████████             | 50.02054442730452/100 [4:55:53<197:03:51, 14194.46s/it]

Epoch: 50	Loss: 1.157	Val_acc: 59/185 (31.89%)



 50%|██████████████▏             | 50.4976887519047/100 [4:58:40<136:14:21, 9907.86s/it]

Epoch: 51	Loss: 1.053	Val_acc: 57/185 (30.81%)



 51%|█████████████▎            | 50.97329224445557/100 [5:01:25<196:43:58, 14445.98s/it]

Epoch: 51	Loss: 1.637	Val_acc: 66/185 (35.68%)



 51%|██████████████▉              | 51.449922958372646/100 [5:05:18<6:32:47, 485.43s/it]

Correct: 267/928 (0.2877)
{'acc_unweighted': tensor(0.2674, device='cuda:0'),
 'acc_weighted': tensor(0.2877, device='cuda:0'),
 'f1s_unweighted': tensor(0.1842, device='cuda:0'),
 'f1s_weighted': tensor(0.1994, device='cuda:0'),
 'rocauc': tensor(0.5166, device='cuda:0'),
 'uar': tensor(0.2674, device='cuda:0'),
 'wap': tensor(0.3280, device='cuda:0')}
{'acc_detail': tensor([0.1705, 0.8468, 0.0062, 0.0460], device='cuda:0'),
 'f1s_detail': tensor([0.2234, 0.4171, 0.0117, 0.0845], device='cuda:0'),
 'precision_detail': tensor([0.3235, 0.2767, 0.1000, 0.5217], device='cuda:0'),
 'recall_detail': tensor([0.1705, 0.8468, 0.0062, 0.0460], device='cuda:0'),
 'rocauc_detail': tensor([0.5489, 0.5192, 0.4988, 0.4994], device='cuda:0')}
tensor([[ 44, 205,   3,   6],
        [ 33, 210,   3,   2],
        [ 19, 138,   1,   3],
        [ 40, 206,   3,  12]], device='cuda:0')


 52%|█████████████▌            | 51.92655367228972/100 [5:08:05<187:13:08, 14019.98s/it]

Epoch: 52	Loss: 1.586	Val_acc: 67/185 (36.22%)

Saved new best val model: ../scb13_models/meld/model.epoch=52.step=89559.loss=1.586.val_acc=0.362.pt


 52%|█████████████            | 52.402670775523696/100 [5:10:50<187:03:38, 14148.25s/it]

Epoch: 52	Loss: 2.467	Val_acc: 64/185 (34.59%)



 53%|█████████████▋            | 52.87930148944077/100 [5:13:38<183:25:50, 14014.03s/it]

Epoch: 53	Loss: 0.859	Val_acc: 47/185 (25.41%)



 53%|█████████████▎           | 53.355418592674745/100 [5:16:24<188:18:40, 14533.74s/it]

Epoch: 53	Loss: 0.753	Val_acc: 67/185 (36.22%)



 54%|█████████████▉            | 53.83204930659182/100 [5:19:10<180:26:23, 14070.01s/it]

Epoch: 54	Loss: 1.449	Val_acc: 60/185 (32.43%)



 54%|█████████████▌           | 54.308166409825795/100 [5:21:55<179:36:04, 14150.55s/it]

Epoch: 54	Loss: 0.547	Val_acc: 62/185 (33.51%)



 55%|██████████████▏           | 54.78479712374287/100 [5:24:42<177:58:07, 14169.73s/it]

Epoch: 55	Loss: 1.627	Val_acc: 66/185 (35.68%)



 55%|█████████████▊           | 55.260914226976844/100 [5:27:28<176:47:25, 14225.70s/it]

Epoch: 55	Loss: 1.153	Val_acc: 51/185 (27.57%)



 56%|██████████████▍           | 55.73754494089392/100 [5:30:15<172:01:25, 13991.21s/it]

Epoch: 56	Loss: 1.647	Val_acc: 40/185 (21.62%)



 56%|██████████████           | 56.213662044127894/100 [5:33:01<172:05:19, 14148.70s/it]

Epoch: 56	Loss: 1.302	Val_acc: 48/185 (25.95%)



 57%|██████████████▋           | 56.69080636872807/100 [5:35:50<121:50:57, 10128.52s/it]

Epoch: 57	Loss: 1.167	Val_acc: 48/185 (25.95%)



 57%|██████████████▊           | 57.16640986127894/100 [5:38:34<173:16:00, 14562.41s/it]

Epoch: 57	Loss: 1.185	Val_acc: 58/185 (31.35%)



 58%|███████████████▌           | 57.64355418587912/100 [5:41:22<116:37:32, 9912.35s/it]

Epoch: 58	Loss: 1.547	Val_acc: 42/185 (22.70%)



 58%|███████████████           | 58.11915767842999/100 [5:44:08<166:01:25, 14271.11s/it]

Epoch: 58	Loss: 0.470	Val_acc: 62/185 (33.51%)



 59%|███████████████▊           | 58.59630200303017/100 [5:46:56<113:44:03, 9889.05s/it]

Epoch: 59	Loss: 0.886	Val_acc: 63/185 (34.05%)



 59%|███████████████▎          | 59.07190549558104/100 [5:49:41<165:33:22, 14562.18s/it]

Epoch: 59	Loss: 1.741	Val_acc: 28/185 (15.14%)



 60%|███████████████▍          | 59.54853620949812/100 [5:52:29<160:02:54, 14243.60s/it]

Epoch: 60	Loss: 2.300	Val_acc: 50/185 (27.03%)



 60%|███████████████▌          | 60.02465331273209/100 [5:55:17<156:26:43, 14088.77s/it]

Epoch: 60	Loss: 2.485	Val_acc: 52/185 (28.11%)



 61%|███████████████▋          | 60.50128402664917/100 [5:58:04<155:12:10, 14145.53s/it]

Epoch: 61	Loss: 2.018	Val_acc: 64/185 (34.59%)



 61%|███████████████▊          | 60.97740112988314/100 [6:00:50<153:40:21, 14176.94s/it]

Epoch: 61	Loss: 1.530	Val_acc: 48/185 (25.95%)



 61%|█████████████████▊           | 61.454031843800216/100 [6:04:43<5:11:44, 485.25s/it]

Correct: 274/928 (0.2953)
Saved new best test model: ../scb13_models/meld/model.epoch=61.step=106079.test_acc=0.2953.pt
{'acc_unweighted': tensor(0.2706, device='cuda:0'),
 'acc_weighted': tensor(0.2953, device='cuda:0'),
 'f1s_unweighted': tensor(0.2154, device='cuda:0'),
 'f1s_weighted': tensor(0.2327, device='cuda:0'),
 'rocauc': tensor(0.5335, device='cuda:0'),
 'uar': tensor(0.2706, device='cuda:0'),
 'wap': tensor(0.3504, device='cuda:0')}
{'acc_detail': tensor([0.6434, 0.3589, 0.0186, 0.0613], device='cuda:0'),
 'f1s_detail': tensor([0.3981, 0.3207, 0.0357, 0.1070], device='cuda:0'),
 'precision_detail': tensor([0.2882, 0.2899, 0.4286, 0.4211], device='cuda:0'),
 'recall_detail': tensor([0.6434, 0.3589, 0.0186, 0.0613], device='cuda:0'),
 'rocauc_detail': tensor([0.5221, 0.5440, 0.5492, 0.5189], device='cuda:0')}
tensor([[166,  83,   1,   8],
        [151,  89,   1,   7],
        [ 96,  55,   3,   7],
        [163,  80,   2,  16]], device='cuda:0')


 62%|████████████████          | 61.93066255771729/100 [6:07:32<152:05:12, 14381.97s/it]

Epoch: 62	Loss: 1.116	Val_acc: 56/185 (30.27%)



 62%|███████████████▌         | 62.406779660951265/100 [6:10:16<147:24:25, 14115.98s/it]

Epoch: 62	Loss: 1.596	Val_acc: 60/185 (32.43%)



 63%|████████████████▎         | 62.88341037486834/100 [6:13:03<146:05:12, 14169.21s/it]

Epoch: 63	Loss: 1.213	Val_acc: 66/185 (35.68%)



 63%|███████████████▊         | 63.359527478102315/100 [6:15:50<145:08:11, 14259.95s/it]

Epoch: 63	Loss: 1.163	Val_acc: 60/185 (32.43%)



 64%|████████████████▌         | 63.83615819201939/100 [6:18:38<144:34:37, 14392.21s/it]

Epoch: 64	Loss: 1.182	Val_acc: 32/185 (17.30%)



 64%|████████████████▋         | 64.31227529525337/100 [6:21:25<143:32:02, 14479.01s/it]

Epoch: 64	Loss: 1.564	Val_acc: 64/185 (34.59%)



 65%|████████████████▊         | 64.78890600917045/100 [6:24:12<138:15:06, 14134.94s/it]

Epoch: 65	Loss: 1.087	Val_acc: 50/185 (27.03%)



 65%|████████████████▉         | 65.26502311240442/100 [6:27:00<137:25:31, 14243.03s/it]

Epoch: 65	Loss: 1.053	Val_acc: 54/185 (29.19%)



 66%|█████████████████▊         | 65.7416538263215/100 [6:29:48<135:17:23, 14216.77s/it]

Epoch: 66	Loss: 1.282	Val_acc: 49/185 (26.49%)



 66%|█████████████████▏        | 66.21777092955547/100 [6:32:34<134:44:33, 14358.83s/it]

Epoch: 66	Loss: 1.403	Val_acc: 44/185 (23.78%)



 67%|█████████████████▎        | 66.69440164347255/100 [6:35:23<130:20:59, 14089.50s/it]

Epoch: 67	Loss: 1.129	Val_acc: 57/185 (30.81%)



 67%|█████████████████▍        | 67.17051874670652/100 [6:38:08<129:52:25, 14241.63s/it]

Epoch: 67	Loss: 0.818	Val_acc: 65/185 (35.14%)



 68%|██████████████████▎        | 67.6471494606236/100 [6:40:57<127:34:49, 14196.27s/it]

Epoch: 68	Loss: 1.607	Val_acc: 58/185 (31.35%)



 68%|█████████████████▋        | 68.12326656385757/100 [6:43:43<126:03:10, 14235.80s/it]

Epoch: 68	Loss: 1.661	Val_acc: 51/185 (27.57%)



 69%|█████████████████▊        | 68.59989727777464/100 [6:46:32<124:59:57, 14331.09s/it]

Epoch: 69	Loss: 1.215	Val_acc: 44/185 (23.78%)



 69%|█████████████████▉        | 69.07601438100862/100 [6:49:19<123:16:02, 14350.12s/it]

Epoch: 69	Loss: 2.105	Val_acc: 30/185 (16.22%)



 70%|███████████████████▍        | 69.5531587056088/100 [6:52:07<85:51:47, 10152.37s/it]

Epoch: 70	Loss: 1.830	Val_acc: 39/185 (21.08%)



 70%|██████████████████▏       | 70.02876219815967/100 [6:54:54<119:22:52, 14339.49s/it]

Epoch: 70	Loss: 0.956	Val_acc: 34/185 (18.38%)



 71%|██████████████████▎       | 70.50539291207674/100 [6:57:43<115:51:54, 14142.06s/it]

Epoch: 71	Loss: 1.120	Val_acc: 47/185 (25.41%)



 71%|██████████████████▍       | 70.98151001531072/100 [7:00:29<113:58:16, 14139.14s/it]

Epoch: 71	Loss: 0.831	Val_acc: 55/185 (29.73%)



 71%|██████████████████████▏        | 71.4581407292278/100 [7:04:21<3:49:00, 481.42s/it]

Correct: 259/928 (0.2791)
{'acc_unweighted': tensor(0.2546, device='cuda:0'),
 'acc_weighted': tensor(0.2791, device='cuda:0'),
 'f1s_unweighted': tensor(0.2159, device='cuda:0'),
 'f1s_weighted': tensor(0.2328, device='cuda:0'),
 'rocauc': tensor(0.5378, device='cuda:0'),
 'uar': tensor(0.2546, device='cuda:0'),
 'wap': tensor(0.2659, device='cuda:0')}
{'acc_detail': tensor([0.1163, 0.2097, 0.0373, 0.6552], device='cuda:0'),
 'f1s_detail': tensor([0.1709, 0.2369, 0.0583, 0.3977], device='cuda:0'),
 'precision_detail': tensor([0.3226, 0.2723, 0.1333, 0.2855], device='cuda:0'),
 'recall_detail': tensor([0.1163, 0.2097, 0.0373, 0.6552], device='cuda:0'),
 'rocauc_detail': tensor([0.5689, 0.5373, 0.5213, 0.5235], device='cuda:0')}
tensor([[ 30,  53,  15, 160],
        [ 23,  52,  13, 160],
        [ 11,  36,   6, 108],
        [ 29,  50,  11, 171]], device='cuda:0')


 72%|██████████████████▋       | 71.93477144314487/100 [7:07:10<109:58:13, 14106.18s/it]

Epoch: 72	Loss: 0.800	Val_acc: 54/185 (29.19%)



 72%|██████████████████▊       | 72.41088854637884/100 [7:09:56<111:35:14, 14560.61s/it]

Epoch: 72	Loss: 1.883	Val_acc: 41/185 (22.16%)



 73%|██████████████████▉       | 72.88751926029592/100 [7:12:46<109:16:59, 14510.63s/it]

Epoch: 73	Loss: 1.401	Val_acc: 63/185 (34.05%)



 73%|███████████████████       | 73.36363636352989/100 [7:15:33<104:50:26, 14169.58s/it]

Epoch: 73	Loss: 1.381	Val_acc: 53/185 (28.65%)



 74%|███████████████████▏      | 73.84026707744697/100 [7:18:21<102:52:16, 14156.73s/it]

Epoch: 74	Loss: 1.338	Val_acc: 54/185 (29.19%)



 74%|███████████████████▎      | 74.31638418068094/100 [7:21:09<103:41:56, 14535.22s/it]

Epoch: 74	Loss: 1.116	Val_acc: 60/185 (32.43%)



 75%|████████████████████▏      | 74.79301489459802/100 [7:23:57<98:59:37, 14138.06s/it]

Epoch: 75	Loss: 0.786	Val_acc: 34/185 (18.38%)



 75%|████████████████████▎      | 75.26913199783199/100 [7:26:43<97:34:09, 14202.89s/it]

Epoch: 75	Loss: 1.121	Val_acc: 50/185 (27.03%)



 76%|████████████████████▍      | 75.74576271174907/100 [7:29:31<95:33:15, 14182.91s/it]

Epoch: 76	Loss: 1.857	Val_acc: 56/185 (30.27%)



 76%|████████████████████▌      | 76.22187981498304/100 [7:32:18<96:19:01, 14582.39s/it]

Epoch: 76	Loss: 0.764	Val_acc: 58/185 (31.35%)



 77%|████████████████████▋      | 76.69851052890012/100 [7:35:07<93:15:50, 14408.97s/it]

Epoch: 77	Loss: 1.525	Val_acc: 43/185 (23.24%)



 77%|████████████████████▊      | 77.17462763213409/100 [7:37:54<90:53:34, 14335.55s/it]

Epoch: 77	Loss: 0.751	Val_acc: 34/185 (18.38%)



 78%|████████████████████▉      | 77.65125834605116/100 [7:40:44<89:10:02, 14363.34s/it]

Epoch: 78	Loss: 1.771	Val_acc: 36/185 (19.46%)



 78%|█████████████████████      | 78.12737544928514/100 [7:43:30<87:05:12, 14333.57s/it]

Epoch: 78	Loss: 1.244	Val_acc: 60/185 (32.43%)



 79%|█████████████████████▏     | 78.60400616320221/100 [7:46:19<83:52:30, 14112.47s/it]

Epoch: 79	Loss: 2.059	Val_acc: 44/185 (23.78%)



 79%|█████████████████████▎     | 79.08012326643619/100 [7:49:05<83:22:30, 14347.63s/it]

Epoch: 79	Loss: 1.827	Val_acc: 41/185 (22.16%)



 80%|█████████████████████▍     | 79.55675398035326/100 [7:51:53<81:12:27, 14300.47s/it]

Epoch: 80	Loss: 0.623	Val_acc: 48/185 (25.95%)



 80%|█████████████████████▌     | 80.03287108358724/100 [7:54:40<79:43:35, 14374.39s/it]

Epoch: 80	Loss: 1.259	Val_acc: 57/185 (30.81%)



 81%|█████████████████████▋     | 80.50950179750431/100 [7:57:29<76:31:48, 14135.53s/it]

Epoch: 81	Loss: 1.143	Val_acc: 63/185 (34.05%)



 81%|█████████████████████▊     | 80.98561890073829/100 [8:00:15<76:15:39, 14438.50s/it]

Epoch: 81	Loss: 1.382	Val_acc: 45/185 (24.32%)



 81%|████████████████████████▍     | 81.46224961465536/100 [8:04:09<2:32:45, 494.44s/it]

Correct: 248/928 (0.2672)
{'acc_unweighted': tensor(0.2501, device='cuda:0'),
 'acc_weighted': tensor(0.2672, device='cuda:0'),
 'f1s_unweighted': tensor(0.2265, device='cuda:0'),
 'f1s_weighted': tensor(0.2382, device='cuda:0'),
 'rocauc': tensor(0.5107, device='cuda:0'),
 'uar': tensor(0.2501, device='cuda:0'),
 'wap': tensor(0.2571, device='cuda:0')}
{'acc_detail': tensor([0.5155, 0.0806, 0.1056, 0.2989], device='cuda:0'),
 'f1s_detail': tensor([0.3720, 0.1246, 0.1273, 0.2821], device='cuda:0'),
 'precision_detail': tensor([0.2910, 0.2740, 0.1604, 0.2671], device='cuda:0'),
 'recall_detail': tensor([0.5155, 0.0806, 0.1056, 0.2989], device='cuda:0'),
 'rocauc_detail': tensor([0.5244, 0.5035, 0.5426, 0.4721], device='cuda:0')}
tensor([[133,  19,  24,  82],
        [117,  20,  31,  80],
        [ 79,  13,  17,  52],
        [128,  21,  34,  78]], device='cuda:0')


 82%|██████████████████████     | 81.93888032857244/100 [8:06:57<70:51:42, 14124.42s/it]

Epoch: 82	Loss: 0.640	Val_acc: 50/185 (27.03%)



 82%|██████████████████████▎    | 82.41499743180641/100 [8:09:44<70:37:22, 14457.92s/it]

Epoch: 82	Loss: 0.952	Val_acc: 47/185 (25.41%)



 83%|██████████████████████▍    | 82.89162814572349/100 [8:12:33<68:01:04, 14312.56s/it]

Epoch: 83	Loss: 0.977	Val_acc: 63/185 (34.05%)



 83%|██████████████████████▌    | 83.36774524895746/100 [8:15:21<67:18:07, 14567.32s/it]

Epoch: 83	Loss: 0.759	Val_acc: 56/185 (30.27%)



 84%|███████████████████████▍    | 83.84488957355764/100 [8:18:10<44:51:06, 9994.75s/it]

Epoch: 84	Loss: 0.482	Val_acc: 49/185 (26.49%)



 84%|██████████████████████▊    | 84.32049306610851/100 [8:20:57<62:54:59, 14445.56s/it]

Epoch: 84	Loss: 1.220	Val_acc: 62/185 (33.51%)



 85%|██████████████████████▉    | 84.79712378002559/100 [8:23:46<59:50:36, 14170.77s/it]

Epoch: 85	Loss: 1.438	Val_acc: 69/185 (37.30%)

Saved new best val model: ../scb13_models/meld/model.epoch=85.step=146946.loss=1.438.val_acc=0.373.pt


 85%|███████████████████████    | 85.27324088325956/100 [8:26:33<58:51:28, 14387.98s/it]

Epoch: 85	Loss: 3.312	Val_acc: 63/185 (34.05%)



 86%|███████████████████████▏   | 85.74987159717664/100 [8:29:23<56:04:22, 14165.69s/it]

Epoch: 86	Loss: 0.305	Val_acc: 51/185 (27.57%)



 86%|███████████████████████▎   | 86.22598870041061/100 [8:32:10<56:00:05, 14636.63s/it]

Epoch: 86	Loss: 1.058	Val_acc: 49/185 (26.49%)



 87%|███████████████████████▍   | 86.70261941432769/100 [8:35:00<52:54:45, 14325.06s/it]

Epoch: 87	Loss: 0.612	Val_acc: 39/185 (21.08%)



 87%|███████████████████████▌   | 87.17873651756166/100 [8:37:47<51:34:39, 14482.15s/it]

Epoch: 87	Loss: 0.899	Val_acc: 58/185 (31.35%)



 88%|███████████████████████▋   | 87.65536723147873/100 [8:40:36<48:31:51, 14152.80s/it]

Epoch: 88	Loss: 0.446	Val_acc: 58/185 (31.35%)



 88%|███████████████████████▊   | 88.13148433471271/100 [8:43:23<47:00:00, 14256.24s/it]

Epoch: 88	Loss: 1.156	Val_acc: 55/185 (29.73%)



 89%|███████████████████████▉   | 88.60811504862978/100 [8:46:13<45:11:42, 14282.30s/it]

Epoch: 89	Loss: 1.276	Val_acc: 63/185 (34.05%)



 89%|████████████████████████   | 89.08423215186376/100 [8:48:59<43:25:55, 14323.82s/it]

Epoch: 89	Loss: 1.339	Val_acc: 52/185 (28.11%)



 90%|████████████████████████▏  | 89.56086286578083/100 [8:51:48<42:14:38, 14568.16s/it]

Epoch: 90	Loss: 0.947	Val_acc: 66/185 (35.68%)



 90%|████████████████████████▎  | 90.03697996901481/100 [8:54:36<39:28:48, 14265.63s/it]

Epoch: 90	Loss: 1.340	Val_acc: 55/185 (29.73%)



 91%|████████████████████████▍  | 90.51361068293188/100 [8:57:26<38:29:09, 14605.04s/it]

Epoch: 91	Loss: 1.031	Val_acc: 63/185 (34.05%)



 91%|████████████████████████▌  | 90.98972778616586/100 [9:00:14<36:21:38, 14527.66s/it]

Epoch: 91	Loss: 1.069	Val_acc: 39/185 (21.08%)



 91%|███████████████████████████▍  | 91.46635850008293/100 [9:04:09<1:08:55, 484.65s/it]

Correct: 256/928 (0.2759)
{'acc_unweighted': tensor(0.2947, device='cuda:0'),
 'acc_weighted': tensor(0.2759, device='cuda:0'),
 'f1s_unweighted': tensor(0.2575, device='cuda:0'),
 'f1s_weighted': tensor(0.2555, device='cuda:0'),
 'rocauc': tensor(0.5401, device='cuda:0'),
 'uar': tensor(0.2947, device='cuda:0'),
 'wap': tensor(0.2888, device='cuda:0')}
{'acc_detail': tensor([0.3760, 0.0645, 0.4969, 0.2414], device='cuda:0'),
 'f1s_detail': tensor([0.3573, 0.1026, 0.2930, 0.2769], device='cuda:0'),
 'precision_detail': tensor([0.3404, 0.2500, 0.2078, 0.3247], device='cuda:0'),
 'recall_detail': tensor([0.3760, 0.0645, 0.4969, 0.2414], device='cuda:0'),
 'rocauc_detail': tensor([0.5492, 0.5509, 0.5597, 0.5006], device='cuda:0')}
tensor([[ 97,  16,  86,  59],
        [ 63,  16, 117,  52],
        [ 46,  15,  80,  20],
        [ 79,  17, 102,  63]], device='cuda:0')


 92%|████████████████████████▊  | 91.94298921400001/100 [9:06:58<31:44:03, 14179.44s/it]

Epoch: 92	Loss: 1.393	Val_acc: 50/185 (27.03%)



 92%|████████████████████████▉  | 92.41910631723398/100 [9:09:45<30:08:51, 14316.49s/it]

Epoch: 92	Loss: 0.797	Val_acc: 63/185 (34.05%)



 93%|█████████████████████████  | 92.89573703115106/100 [9:12:33<28:00:20, 14191.54s/it]

Epoch: 93	Loss: 0.203	Val_acc: 59/185 (31.89%)



 93%|█████████████████████████▏ | 93.37185413438503/100 [9:15:21<26:28:35, 14380.44s/it]

Epoch: 93	Loss: 0.646	Val_acc: 60/185 (32.43%)



 94%|██████████████████████████▎ | 93.8484848483021/100 [9:18:11<24:39:00, 14425.73s/it]

Epoch: 94	Loss: 0.874	Val_acc: 58/185 (31.35%)



 94%|█████████████████████████▍ | 94.32460195153608/100 [9:20:59<22:29:16, 14264.43s/it]

Epoch: 94	Loss: 1.844	Val_acc: 48/185 (25.95%)



 95%|██████████████████████████▌ | 94.80174627613626/100 [9:23:49<14:25:31, 9990.18s/it]

Epoch: 95	Loss: 0.537	Val_acc: 43/185 (23.24%)



 95%|█████████████████████████▋ | 95.27734976868713/100 [9:26:36<18:44:02, 14280.57s/it]

Epoch: 95	Loss: 1.787	Val_acc: 52/185 (28.11%)



 96%|██████████████████████████▊ | 95.7539804826042/100 [9:29:26<16:53:13, 14317.86s/it]

Epoch: 96	Loss: 4.353	Val_acc: 43/185 (23.24%)



 96%|█████████████████████████▉ | 96.23009758583818/100 [9:32:14<14:58:35, 14301.56s/it]

Epoch: 96	Loss: 0.929	Val_acc: 53/185 (28.65%)



 97%|██████████████████████████ | 96.70672829975526/100 [9:35:03<13:01:49, 14244.18s/it]

Epoch: 97	Loss: 0.610	Val_acc: 52/185 (28.11%)



 97%|██████████████████████████▏| 97.18284540298923/100 [9:37:50<11:15:18, 14382.76s/it]

Epoch: 97	Loss: 1.758	Val_acc: 42/185 (22.70%)



 98%|████████████████████████████▎| 97.6594761169063/100 [9:40:40<9:21:29, 14393.92s/it]

Epoch: 98	Loss: 1.104	Val_acc: 62/185 (33.51%)



 98%|███████████████████████████▍| 98.13559322014028/100 [9:43:26<7:24:34, 14307.32s/it]

Epoch: 98	Loss: 0.741	Val_acc: 48/185 (25.95%)



 99%|███████████████████████████▌| 98.61222393405735/100 [9:46:16<5:27:23, 14155.00s/it]

Epoch: 99	Loss: 1.280	Val_acc: 66/185 (35.68%)



 99%|███████████████████████████▋| 99.08834103729133/100 [9:49:03<3:38:32, 14382.84s/it]

Epoch: 99	Loss: 0.522	Val_acc: 49/185 (26.49%)



100%|████████████████████████████▊| 99.5649717512084/100 [9:51:53<1:42:56, 14196.81s/it]

Epoch: 100	Loss: 1.510	Val_acc: 51/185 (27.57%)



  full_bar = Bar(frac,
100%|████████████████████████████| 100.04108885444238/100 [9:54:40<-1:59:46, 356.66s/it]

Epoch: 100	Loss: 1.180	Val_acc: 53/185 (28.65%)






In [15]:
# Continue training the current model for another 100 epochs (the epoch counter
# picks up where the previous fit left off); test_n_epoch=10 schedules the
# test-set evaluation that shows up in the log every 10 epochs.
trainer.fit(test_n_epoch=10, n_epoch=100)

  0%|▏                          | 0.4766307139188397/100 [02:47<393:47:18, 14244.28s/it]

Epoch: 101	Loss: 1.486	Val_acc: 48/185 (25.95%)



  1%|▎                              | 0.9522342064714617/100 [05:10<8:21:50, 304.00s/it]Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7b5c2f971070>>
Traceback (most recent call last):
  File "/opt/anaconda3/envs/audio/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 770, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 
  1%|▎                              | 0.9522342064714617/100 [05:16<9:08:42, 332.39s/it]

KeyboardInterrupt



In [16]:
# Swap the pretrained codebook: the config cell pointed at
# 'epoch=220.masked_codebook.pt'; models built after this cell load
# 'epoch=220.codebook.pt' from the same scb11/ravdess directory instead.
# (Plain string literal — there is nothing to interpolate.)
codebook_pretrained_path = '../scb11_models/ravdess/epoch=220.codebook.pt'

In [21]:
# Second configuration: same hyperparameters as the first SCB model from the
# config cell, except window_k goes from 5 to 9 and codebook_pretrained_path
# now holds whatever the preceding cell assigned.
window_k = 9

# NOTE(review): stride is passed as a literal 1 here, while the config cell
# defines `stride = 125` (unused by this call) — confirm that is intentional.
model = SCB(
    sample_rate=sr,
    in_channels=in_channels,
    embedding_dim=embedding_dim,
    num_embeddings=num_embeddings,
    codebook_pretrained_path=codebook_pretrained_path,
    window_k=window_k,
    stride=1,
    cls_dim=cls_dim,
    num_classes=num_classes,
)

In [22]:
# Build a fresh trainer for the model created in the previous cell.
#
# accumulate_grad_batches=16 presumably means 16 mini-batches of `batch_size`
# are accumulated per optimizer step (effective batch = 16 * batch_size) —
# confirm against Trainer_custom_model.
#
# NOTE(review): log_dir is the same directory the earlier run saved its
# checkpoints to, and checkpoint filenames only encode epoch/step/loss/acc,
# not window_k or the codebook — checkpoints from different configurations
# will be mixed together. Consider a run-specific log_dir.
trainer = Trainer(
    batch_size=batch_size,
    log_dir=log_dir,
    experiment_prefix=experiment_prefix,
    device=device,
    accumulate_grad_batches=16,
    gradient_clip_val=1.,
)

trainer.prepare(
    train_loader=train_loader,
    test_loader=test_loader,
    batch_size=batch_size,
)
# Take the learning rate from the config cell rather than repeating the
# literal, so changing `learning_rate` actually affects this run.
trainer.setup(model, lr=learning_rate)

In [23]:
# Train the freshly set-up model for 100 epochs (the log restarts at epoch 1),
# with a test-set evaluation every 10 epochs.
trainer.fit(test_n_epoch=10, n_epoch=100)

  0%|▏                          | 0.4766307139188397/100 [02:54<416:12:06, 15055.02s/it]

Epoch: 1	Loss: 1.581	Val_acc: 53/185 (28.65%)

Saved new best val model: ../scb13_models/meld/model.epoch=1.step=870.loss=1.581.val_acc=0.286.pt


  1%|▎                          | 0.9527478171545639/100 [05:51<420:14:03, 15273.95s/it]

Epoch: 1	Loss: 1.709	Val_acc: 51/185 (27.57%)



  1%|▍                              | 1.429378531073483/100 [09:59<14:04:15, 513.91s/it]

Correct: 262/928 (0.2823)
Saved new best test model: ../scb13_models/meld/model.epoch=1.step=1739.test_acc=0.2823.pt
{'acc_unweighted': tensor(0.2510, device='cuda:0'),
 'acc_weighted': tensor(0.2823, device='cuda:0'),
 'f1s_unweighted': tensor(0.1138, device='cuda:0'),
 'f1s_weighted': tensor(0.1280, device='cuda:0'),
 'rocauc': tensor(0.4967, device='cuda:0'),
 'uar': tensor(0.2510, device='cuda:0'),
 'wap': tensor(0.1490, device='cuda:0')}
{'acc_detail': tensor([0.0078, 0.0000, 0.0000, 0.9962], device='cuda:0'),
 'f1s_detail': tensor([0.0150, 0.0000, 0.0000, 0.4403], device='cuda:0'),
 'precision_detail': tensor([0.2500, 0.0000, 0.0000, 0.2826], device='cuda:0'),
 'recall_detail': tensor([0.0078, 0.0000, 0.0000, 0.9962], device='cuda:0'),
 'rocauc_detail': tensor([0.4558, 0.4917, 0.4990, 0.5401], device='cuda:0')}
tensor([[  2,   0,   0, 256],
        [  4,   0,   0, 244],
        [  1,   0,   0, 160],
        [  1,   0,   0, 260]], device='cuda:0')


  2%|▌                          | 1.9060092449924124/100 [12:58<417:12:55, 15311.59s/it]

Epoch: 2	Loss: 1.784	Val_acc: 51/185 (27.57%)



  2%|▋                          | 2.3821263482280743/100 [15:54<413:24:21, 15245.79s/it]

Epoch: 2	Loss: 1.520	Val_acc: 51/185 (27.57%)



  3%|▊                          | 2.8587570621467977/100 [18:53<411:19:59, 15243.78s/it]

Epoch: 3	Loss: 1.015	Val_acc: 50/185 (27.03%)



  3%|▉                           | 3.334874165382419/100 [21:48<409:18:33, 15243.49s/it]

Epoch: 3	Loss: 1.216	Val_acc: 60/185 (32.43%)

Saved new best val model: ../scb13_models/meld/model.epoch=3.step=5217.loss=1.216.val_acc=0.324.pt


  4%|█                          | 3.8115048793011423/100 [24:44<404:20:14, 15132.94s/it]

Epoch: 4	Loss: 1.684	Val_acc: 61/185 (32.97%)

Saved new best val model: ../scb13_models/meld/model.epoch=4.step=6087.loss=1.684.val_acc=0.330.pt


  4%|█▏                          | 4.287621982536764/100 [27:39<393:47:17, 14811.43s/it]

Epoch: 4	Loss: 1.401	Val_acc: 31/185 (16.76%)



  5%|█▎                          | 4.764252696455487/100 [30:36<408:05:18, 15426.13s/it]

Epoch: 5	Loss: 1.769	Val_acc: 60/185 (32.43%)



  5%|█▍                          | 5.240369799691108/100 [33:30<392:41:35, 14918.76s/it]

Epoch: 5	Loss: 1.265	Val_acc: 51/185 (27.57%)



  6%|█▌                          | 5.717000513609832/100 [36:27<393:53:41, 15040.06s/it]

Epoch: 6	Loss: 1.294	Val_acc: 51/185 (27.57%)



  6%|█▋                          | 6.193117616845453/100 [39:20<392:31:00, 15063.50s/it]

Epoch: 6	Loss: 1.600	Val_acc: 59/185 (31.89%)



  7%|█▊                          | 6.669748330764176/100 [42:16<391:11:15, 15089.17s/it]

Epoch: 7	Loss: 1.473	Val_acc: 61/185 (32.97%)



  7%|██                          | 7.145865433999798/100 [45:12<397:04:37, 15394.87s/it]

Epoch: 7	Loss: 1.023	Val_acc: 51/185 (27.57%)



  8%|██▏                         | 7.622496147918521/100 [48:11<391:19:44, 15250.30s/it]

Epoch: 8	Loss: 1.657	Val_acc: 32/185 (17.30%)



  8%|██▎                         | 8.098613251154143/100 [51:08<389:57:37, 15275.70s/it]

Epoch: 8	Loss: 1.063	Val_acc: 50/185 (27.03%)



  9%|██▍                         | 8.575243965072866/100 [54:03<377:23:57, 14860.72s/it]

Epoch: 9	Loss: 1.231	Val_acc: 51/185 (27.57%)



  9%|██▌                         | 9.051361068308488/100 [56:57<380:53:39, 15076.86s/it]

Epoch: 9	Loss: 1.409	Val_acc: 46/185 (24.86%)



 10%|██▋                         | 9.527991782227211/100 [59:54<386:27:47, 15377.88s/it]

Epoch: 10	Loss: 1.248	Val_acc: 47/185 (25.41%)



 10%|██▌                      | 10.004108885462832/100 [1:02:49<378:14:15, 15130.19s/it]

Epoch: 10	Loss: 1.449	Val_acc: 43/185 (23.24%)



 10%|██▌                      | 10.480739599381556/100 [1:05:45<378:44:13, 15230.84s/it]

Epoch: 11	Loss: 1.201	Val_acc: 49/185 (26.49%)



 11%|██▋                      | 10.956856702617177/100 [1:08:39<375:51:18, 15195.77s/it]

Epoch: 11	Loss: 1.574	Val_acc: 43/185 (23.24%)



 11%|███▍                          | 11.4334874165359/100 [1:12:45<12:47:47, 520.14s/it]

Correct: 260/928 (0.2802)
{'acc_unweighted': tensor(0.2520, device='cuda:0'),
 'acc_weighted': tensor(0.2802, device='cuda:0'),
 'f1s_unweighted': tensor(0.1146, device='cuda:0'),
 'f1s_weighted': tensor(0.1274, device='cuda:0'),
 'rocauc': tensor(0.4988, device='cuda:0'),
 'uar': tensor(0.2520, device='cuda:0'),
 'wap': tensor(0.4477, device='cuda:0')}
{'acc_detail': tensor([0.9961, 0.0040, 0.0000, 0.0077], device='cuda:0'),
 'f1s_detail': tensor([0.4352, 0.0080, 0.0000, 0.0152], device='cuda:0'),
 'precision_detail': tensor([0.2784, 0.3333, 0.0000, 1.0000], device='cuda:0'),
 'recall_detail': tensor([0.9961, 0.0040, 0.0000, 0.0077], device='cuda:0'),
 'rocauc_detail': tensor([0.4717, 0.5140, 0.4913, 0.5180], device='cuda:0')}
tensor([[257,   1,   0,   0],
        [247,   1,   0,   0],
        [160,   1,   0,   0],
        [259,   0,   0,   2]], device='cuda:0')


 12%|██▉                      | 11.910118130454624/100 [1:15:42<369:00:48, 15080.60s/it]

Epoch: 12	Loss: 1.413	Val_acc: 41/185 (22.16%)



 12%|███▍                         | 11.929635336412503/100 [1:15:48<7:39:56, 313.35s/it]Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7b5c2f971070>>
Traceback (most recent call last):
  File "/opt/anaconda3/envs/audio/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 770, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 
 12%|███▍                         | 11.930148947095605/100 [1:15:48<9:19:40, 381.29s/it]

KeyboardInterrupt

