In [1]:
# Re-import edited modules (e.g. the local IConNet package) before every cell runs,
# so library edits take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import savgol_filter
import umap
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim

from torchvision.utils import make_grid

In [3]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
import os

# Work from the IConNet repository root so that relative paths such as
# config/... and ../data/... used below resolve correctly.
os.chdir(os.path.join(os.pardir, 'IConNet'))
os.getcwd()

'/home/linh/projects/IConNet'

In [5]:
import warnings

# Limit the size of blocks the CUDA caching allocator will split, to reduce fragmentation.
# PyTorch reads this variable when its CUDA caching allocator initialises.
# If CUDA is already initialised (e.g. this cell is re-run after the model
# was moved to the GPU), the new value is ignored, so say so instead of
# failing silently.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
if torch.cuda.is_initialized():
    warnings.warn(
        "PYTORCH_CUDA_ALLOC_CONF was set after CUDA was initialised; "
        "restart the kernel for it to take effect."
    )
import gc
import sys
import numpy as np
from tqdm import tqdm

In [6]:
from IConNet.acov.audio_vqvae import VqVaeClsLoss
from IConNet.trainer.train_torch import get_dataloader
from IConNet.trainer.train_torch import Trainer_SCB10 as Trainer
from IConNet.acov.model import SCB16 as SCB
from omegaconf import OmegaConf as ocf

In [7]:
# --- Experiment / dataset selection -------------------------------------
dataset_name = 'meld'
experiment_prefix = "scb16"
batch_size = 2

data_dir = "../data/data_preprocessed/"
log_dir = f'../{experiment_prefix}_models/{dataset_name}/run0/'

# The "4" suffix selects the 4-class variant (see num_classes in the printed config).
dataset_config_path = f'config/dataset/{dataset_name}4.yaml'
dataset_config = ocf.load(dataset_config_path)
print(dataset_config)

{'name': 'meld', 'dataset_class': 'WaveformDataset', 'root': 'meld/', 'audio_dir': 'full_release/', 'feature_dir': 'features_4balanced/', 'label_name': 'label_emotion', 'feature_name': 'audio16k', 'num_classes': 4, 'label_values': [0, 1, 2, 3], 'classnames': ['neu', 'hap', 'sad', 'ang'], 'target_labels': ['ang', 'neu', 'sad', 'hap']}


In [8]:
import random

# --- Model / training hyper-parameters ----------------------------------
in_channels = 1
out_channels = 8
embedding_dim = 1023      # NOTE(review): 1023 rather than 1024 — confirm this is intentional
num_embeddings = 384
commitment_cost = 0.25
learning_rate = 1e-4
num_tokens = 256
num_classes = 4           # must agree with dataset_config.num_classes (4 for the meld4 config)

# Seed every RNG the run may touch, so that weight initialisation and any
# torch/numpy/random-based shuffling are reproducible under Restart & Run All.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)   # seeds the CPU and all CUDA devices

In [9]:
# Front-end (learnable-window FIR conv) + classifier architecture config.
iconnet_config_path = 'config/model/m19win.yaml'
iconnet_config = ocf.load(iconnet_config_path)
print(iconnet_config)

{'name': 'M19', 'description': 'FirConv with learnable windows', 'fe': {'n_block': 1, 'n_channel': [256], 'kernel_size': [511], 'stride': [2], 'window_k': [5], 'pooling': 'mean', 'filter_type': 'sinc', 'learnable_bands': False, 'learnable_windows': True, 'shared_window': False, 'window_func': 'hamming', 'mel_resolution': 3, 'conv_mode': 'conv', 'norm_type': 'LocalResponseNorm'}, 'cls': {'n_block': 2, 'n_hidden_dim': [512, 512], 'norm_type': 'LayerNorm'}}


In [10]:
# Codebook pretrained in the scb7 run on the same dataset. Deriving the path
# from dataset_name (instead of hard-coding "meld") keeps it consistent with
# log_dir / dataset_config_path, so switching datasets cannot silently load
# another dataset's codebook.
codebook_pretrained_path = f'../scb7_models/{dataset_name}/codebook.epoch=75.pt'
model = SCB(
    in_channels=in_channels,
    out_channels=out_channels,
    num_embeddings=num_embeddings,
    embedding_dim=embedding_dim,
    num_tokens=num_tokens,
    num_classes=num_classes,
    cls_dim=embedding_dim,          # classifier width tied to the embedding size
    sample_rate=16000,
    commitment_cost=commitment_cost,
    distance_type='euclidean',
    codebook_pretrained_path=codebook_pretrained_path,
    freeze_codebook=True,           # presumably excludes the codebook from optimisation — confirm in SCB16
    loss_type='signal_loss',
    iconnet_config=iconnet_config,
)

In [11]:
# Build the train/test loaders. get_dataloader also returns a batch size,
# which replaces the one configured above.
train_loader, test_loader, batch_size = get_dataloader(
    dataset_config,
    data_dir,
    batch_size=batch_size,
)

In [12]:
# Presumably per-term loss weights: VQ, reconstruction and classification at 1, perplexity off.
# NOTE(review): the training logs show loss_vq and loss_recon stay at 0.000,
# so only loss_cls is actually driving this run — confirm that is intended.
loss_ratio = VqVaeClsLoss(
    perplexity=0,
    loss_vq=1,
    loss_recon=1,
    loss_cls=1,
)

trainer = Trainer(
    batch_size=batch_size,
    log_dir=log_dir,
    experiment_prefix=experiment_prefix,
    device=device,
    accumulate_grad_batches=4,   # presumably effective batch = 4 * batch_size — confirm in Trainer_SCB10
)
trainer.prepare(
    train_loader=train_loader,
    test_loader=test_loader,
    batch_size=batch_size,
    loss_ratio=loss_ratio,
)

In [13]:
# Attach the model and learning rate to the prepared trainer.
trainer.setup(lr=learning_rate, model=model)

In [None]:
# Supervised training for up to 70 epochs; the logs show a test evaluation every 10 epochs.
# NOTE(review): in the recorded run, the best validation accuracy (58.24%) occurs at epoch 1.
# That checkpoint's test confusion matrix shows a single predicted class, and by about
# epoch 23 the training loss is near 0 while validation accuracy falls to 35-45%.
# This looks like overfitting — consider early stopping or far fewer epochs.
trainer.fit(
    train_task='embedding',
    self_supervised=False,
    n_epoch=70,
    test_n_epoch=10,
)

  1%|▌                                                                             | 0.47632390025215354/70 [10:48<1423:50:34, 73727.90s/it]

Epoch: 1	Loss: 1.653 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=1.653]	Val_acc: 198/340 (58.24%)

Saved new best val model: ../scb16_models/meld/run0/model.epoch=1.step=1275.loss=1.653.val_acc=0.582.pt


  1%|█                                                                              | 0.9523676099747226/70 [21:38<1394:42:25, 72717.13s/it]

Epoch: 1	Loss: 0.629 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.629]	Val_acc: 198/340 (58.24%)



  2%|█▋                                                                                 | 1.4286915102270215/70 [29:13<18:08:08, 952.12s/it]

Correct: 937/1700 (0.5512)
Saved new best test model: ../scb16_models/meld/run0/model.epoch=1.step=2549.test_acc=0.5512.pt
{'acc_unweighted': tensor(0.2500, device='cuda:0'),
 'acc_weighted': tensor(0.5512, device='cuda:0'),
 'f1s_unweighted': tensor(0.1777, device='cuda:0'),
 'f1s_weighted': tensor(0.3917, device='cuda:0'),
 'rocauc': tensor(0.5436, device='cuda:0'),
 'uar': tensor(0.2500, device='cuda:0'),
 'wap': tensor(0.3038, device='cuda:0')}
{'acc_detail': tensor([0., 1., 0., 0.], device='cuda:0'),
 'f1s_detail': tensor([0.0000, 0.7107, 0.0000, 0.0000], device='cuda:0'),
 'precision_detail': tensor([0.0000, 0.5512, 0.0000, 0.0000], device='cuda:0'),
 'recall_detail': tensor([0., 1., 0., 0.], device='cuda:0'),
 'rocauc_detail': tensor([0.6269, 0.5568, 0.4411, 0.5497], device='cuda:0')}
tensor([[  0, 257,   0,   0],
        [  0, 937,   0,   0],
        [  0, 161,   0,   0],
        [  0, 345,   0,   0]], device='cuda:0')


  3%|██▏                                                                            | 1.9050154104793393/70 [40:12<1389:08:22, 73440.10s/it]

Epoch: 2	Loss: 2.037 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=2.037]	Val_acc: 198/340 (58.24%)



  3%|██▋                                                                            | 2.3810591202020968/70 [51:16<1369:46:19, 72926.01s/it]

Epoch: 2	Loss: 1.134 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=1.134]	Val_acc: 198/340 (58.24%)



  4%|███▏                                                                         | 2.8573830204544146/70 [1:02:24<1370:37:35, 73489.17s/it]

Epoch: 3	Loss: 2.131 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=2.131]	Val_acc: 198/340 (58.24%)



  5%|███▋                                                                         | 3.3334267301771723/70 [1:13:26<1364:36:01, 73688.53s/it]

Epoch: 3	Loss: 0.626 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.626]	Val_acc: 198/340 (58.24%)



  5%|████▎                                                                          | 3.80975063042949/70 [1:24:39<1360:46:55, 74011.13s/it]

Epoch: 4	Loss: 1.360 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=1.360]	Val_acc: 189/340 (55.59%)



  6%|████▊                                                                         | 4.285794340152248/70 [1:35:49<1373:02:44, 75219.11s/it]

Epoch: 4	Loss: 2.424 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=2.424]	Val_acc: 198/340 (58.24%)



  7%|█████▎                                                                        | 4.762118240404566/70 [1:47:06<1347:21:17, 74350.63s/it]

Epoch: 5	Loss: 1.776 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=1.776]	Val_acc: 198/340 (58.24%)



  7%|█████▊                                                                        | 5.238161950127323/70 [1:58:16<1346:15:50, 74836.52s/it]

Epoch: 5	Loss: 1.080 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=1.080]	Val_acc: 198/340 (58.24%)



  8%|██████▎                                                                       | 5.714485850379641/70 [2:09:32<1322:13:34, 74044.91s/it]

Epoch: 6	Loss: 1.029 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=1.029]	Val_acc: 197/340 (57.94%)



  9%|██████▉                                                                       | 6.190529560102399/70 [2:20:44<1317:17:18, 74318.73s/it]

Epoch: 6	Loss: 0.898 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.898]	Val_acc: 194/340 (57.06%)



 10%|███████▍                                                                      | 6.666853460354717/70 [2:31:57<1283:56:19, 72982.00s/it]

Epoch: 7	Loss: 1.658 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=1.658]	Val_acc: 185/340 (54.41%)



 10%|███████▉                                                                      | 7.142897170077474/70 [2:43:05<1299:02:24, 74399.61s/it]

Epoch: 7	Loss: 0.387 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.387]	Val_acc: 197/340 (57.94%)



 11%|████████▍                                                                     | 7.619221070329792/70 [2:54:22<1293:46:50, 74664.20s/it]

Epoch: 8	Loss: 0.645 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.645]	Val_acc: 195/340 (57.35%)



 12%|█████████                                                                     | 8.095264780052549/70 [3:05:38<1294:17:32, 75268.12s/it]

Epoch: 8	Loss: 0.506 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.506]	Val_acc: 194/340 (57.06%)



 12%|█████████▌                                                                    | 8.571588680304867/70 [3:17:01<1280:00:06, 75014.26s/it]

Epoch: 9	Loss: 1.073 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=1.073]	Val_acc: 192/340 (56.47%)



 13%|██████████                                                                    | 9.047632390027625/70 [3:28:18<1263:32:07, 74627.57s/it]

Epoch: 9	Loss: 2.355 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=2.355]	Val_acc: 186/340 (54.71%)



 14%|██████████▌                                                                   | 9.523956290279942/70 [3:39:40<1248:53:12, 74343.37s/it]

Epoch: 10	Loss: 0.441 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.441]	Val_acc: 194/340 (57.06%)



 14%|███████████▎                                                                   | 10.0000000000027/70 [3:51:00<1245:49:21, 74749.35s/it]

Epoch: 10	Loss: 0.261 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.261]	Val_acc: 180/340 (52.94%)



 15%|███████████▌                                                                 | 10.476323900255018/70 [4:02:27<1266:13:26, 76581.40s/it]

Epoch: 11	Loss: 1.251 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=1.251]	Val_acc: 174/340 (51.18%)



 16%|████████████                                                                 | 10.952367609977776/70 [4:13:49<1229:24:17, 74954.02s/it]

Epoch: 11	Loss: 0.283 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.283]	Val_acc: 192/340 (56.47%)



 16%|█████████████                                                                   | 11.428691510230093/70 [4:21:50<16:25:36, 1009.65s/it]

Correct: 918/1700 (0.5400)
{'acc_unweighted': tensor(0.2760, device='cuda:0'),
 'acc_weighted': tensor(0.5400, device='cuda:0'),
 'f1s_unweighted': tensor(0.2304, device='cuda:0'),
 'f1s_weighted': tensor(0.4194, device='cuda:0'),
 'rocauc': tensor(0.5558, device='cuda:0'),
 'uar': tensor(0.2760, device='cuda:0'),
 'wap': tensor(0.3562, device='cuda:0')}
{'acc_detail': tensor([0.1712, 0.9328, 0.0000, 0.0000], device='cuda:0'),
 'f1s_detail': tensor([0.2211, 0.7003, 0.0000, 0.0000], device='cuda:0'),
 'precision_detail': tensor([0.3121, 0.5606, 0.0000, 0.0000], device='cuda:0'),
 'recall_detail': tensor([0.1712, 0.9328, 0.0000, 0.0000], device='cuda:0'),
 'rocauc_detail': tensor([0.6125, 0.5653, 0.5042, 0.5413], device='cuda:0')}
tensor([[ 44, 213,   0,   0],
        [ 63, 874,   0,   0],
        [  7, 154,   0,   0],
        [ 27, 318,   0,   0]], device='cuda:0')


 17%|█████████████                                                                | 11.905015410482411/70 [4:33:12<1201:33:46, 74457.82s/it]

Epoch: 12	Loss: 0.229 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.229]	Val_acc: 172/340 (50.59%)



 18%|█████████████▌                                                               | 12.381059120205169/70 [4:44:31<1207:15:03, 75428.39s/it]

Epoch: 12	Loss: 0.411 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.411]	Val_acc: 175/340 (51.47%)



 18%|██████████████▏                                                              | 12.857383020457487/70 [4:55:56<1194:17:01, 75240.20s/it]

Epoch: 13	Loss: 0.839 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.839]	Val_acc: 138/340 (40.59%)



 19%|██████████████▋                                                              | 13.333426730180244/70 [5:07:20<1202:34:38, 76399.16s/it]

Epoch: 13	Loss: 0.154 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.154]	Val_acc: 187/340 (55.00%)



 20%|███████████████▏                                                             | 13.809750630432562/70 [5:18:51<1174:24:53, 75242.48s/it]

Epoch: 14	Loss: 0.940 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.940]	Val_acc: 173/340 (50.88%)



 20%|███████████████▉                                                              | 14.28579434015532/70 [5:30:16<1173:28:09, 75824.28s/it]

Epoch: 14	Loss: 0.674 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.674]	Val_acc: 172/340 (50.59%)



 21%|████████████████▏                                                            | 14.762118240407638/70 [5:41:46<1179:15:13, 76855.12s/it]

Epoch: 15	Loss: 0.083 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.083]	Val_acc: 157/340 (46.18%)



 22%|████████████████▊                                                            | 15.238161950130396/70 [5:53:17<1162:45:17, 76438.59s/it]

Epoch: 15	Loss: 0.360 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.360]	Val_acc: 154/340 (45.29%)



 22%|█████████████████▎                                                           | 15.714485850382713/70 [6:04:52<1156:01:38, 76663.15s/it]

Epoch: 16	Loss: 0.539 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.539]	Val_acc: 166/340 (48.82%)



 23%|█████████████████▊                                                           | 16.190529560104263/70 [6:16:23<1148:30:45, 76838.61s/it]

Epoch: 16	Loss: 0.310 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.310]	Val_acc: 118/340 (34.71%)



 24%|██████████████████▌                                                           | 16.66685346035356/70 [6:27:58<1132:52:06, 76468.90s/it]

Epoch: 17	Loss: 1.651 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=1.651]	Val_acc: 151/340 (44.41%)



 24%|███████████████████▎                                                           | 17.1428971700733/70 [6:39:32<1112:59:08, 75803.41s/it]

Epoch: 17	Loss: 0.725 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.725]	Val_acc: 161/340 (47.35%)



 25%|███████████████████▉                                                           | 17.6192210703226/70 [6:51:09<1102:01:37, 75739.57s/it]

Epoch: 18	Loss: 0.314 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.314]	Val_acc: 118/340 (34.71%)



 26%|████████████████████▏                                                         | 18.09526478004234/70 [7:02:43<1112:01:07, 77127.22s/it]

Epoch: 18	Loss: 0.057 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.057]	Val_acc: 148/340 (43.53%)



 27%|████████████████████▍                                                        | 18.571588680291637/70 [7:14:22<1101:44:17, 77121.91s/it]

Epoch: 19	Loss: 2.037 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=2.037]	Val_acc: 163/340 (47.94%)



 27%|████████████████████▉                                                        | 19.047632390011376/70 [7:25:57<1091:37:13, 77127.60s/it]

Epoch: 19	Loss: 0.974 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.974]	Val_acc: 165/340 (48.53%)



 28%|█████████████████████▍                                                       | 19.523956290260674/70 [7:37:40<1099:54:15, 78446.24s/it]

Epoch: 20	Loss: 1.265 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=1.265]	Val_acc: 163/340 (47.94%)



 29%|█████████████████████▉                                                       | 19.999999999980414/70 [7:49:24<1070:23:41, 77068.44s/it]

Epoch: 20	Loss: 0.489 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.489]	Val_acc: 137/340 (40.29%)



 29%|██████████████████████▌                                                      | 20.476323900229712/70 [8:01:07<1064:03:20, 77348.87s/it]

Epoch: 21	Loss: 0.357 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.357]	Val_acc: 155/340 (45.59%)



 30%|███████████████████████▎                                                      | 20.95236760994945/70 [8:12:50<1053:51:36, 77351.26s/it]

Epoch: 21	Loss: 0.185 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.185]	Val_acc: 166/340 (48.82%)



 31%|████████████████████████▊                                                        | 21.42869151019875/70 [8:21:10<14:27:23, 1071.48s/it]

Correct: 804/1700 (0.4729)
{'acc_unweighted': tensor(0.3100, device='cuda:0'),
 'acc_weighted': tensor(0.4729, device='cuda:0'),
 'f1s_unweighted': tensor(0.2991, device='cuda:0'),
 'f1s_weighted': tensor(0.4486, device='cuda:0'),
 'rocauc': tensor(0.5688, device='cuda:0'),
 'uar': tensor(0.3100, device='cuda:0'),
 'wap': tensor(0.4428, device='cuda:0')}
{'acc_detail': tensor([0.2957, 0.6852, 0.0186, 0.2406], device='cuda:0'),
 'f1s_detail': tensor([0.2653, 0.6410, 0.0341, 0.2562], device='cuda:0'),
 'precision_detail': tensor([0.2405, 0.6023, 0.2000, 0.2739], device='cuda:0'),
 'recall_detail': tensor([0.2957, 0.6852, 0.0186, 0.2406], device='cuda:0'),
 'rocauc_detail': tensor([0.5967, 0.5796, 0.5549, 0.5439], device='cuda:0')}
tensor([[ 76, 132,   2,  47],
        [146, 642,   4, 145],
        [ 19, 111,   3,  28],
        [ 75, 181,   6,  83]], device='cuda:0')


 31%|████████████████████████                                                     | 21.905015410448048/70 [8:33:24<1076:15:34, 80560.06s/it]

Epoch: 22	Loss: 0.385 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.385]	Val_acc: 132/340 (38.82%)



 32%|████████████████████████▌                                                    | 22.381059120167787/70 [8:45:36<1068:04:36, 80746.78s/it]

Epoch: 22	Loss: 0.008 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.008]	Val_acc: 153/340 (45.00%)



 33%|█████████████████████████▏                                                   | 22.857383020417085/70 [8:57:50<1053:55:27, 80481.91s/it]

Epoch: 23	Loss: 0.000 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.000]	Val_acc: 156/340 (45.88%)



 33%|█████████████████████████▋                                                   | 23.333426730136825/70 [9:10:07<1043:46:12, 80519.58s/it]

Epoch: 23	Loss: 0.223 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.223]	Val_acc: 149/340 (43.82%)



 34%|██████████████████████████▏                                                  | 23.809750630386123/70 [9:22:25<1030:34:22, 80321.33s/it]

Epoch: 24	Loss: 0.107 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.107]	Val_acc: 132/340 (38.82%)



 35%|██████████████████████████▋                                                  | 24.285794340105863/70 [9:34:42<1025:37:16, 80767.80s/it]

Epoch: 24	Loss: 0.576 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.576]	Val_acc: 116/340 (34.12%)



 35%|███████████████████████████▌                                                  | 24.76211824035516/70 [9:47:03<1023:53:17, 81480.32s/it]

Epoch: 25	Loss: 0.251 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.251]	Val_acc: 142/340 (41.76%)



 36%|████████████████████████████▍                                                  | 25.2381619500749/70 [9:59:21<1005:17:50, 80851.69s/it]

Epoch: 25	Loss: 0.136 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.136]	Val_acc: 145/340 (42.65%)



 37%|█████████████████████████████                                                  | 25.7144858503242/70 [10:11:41<993:06:22, 80730.30s/it]

Epoch: 26	Loss: 0.320 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.320]	Val_acc: 141/340 (41.47%)



 37%|████████████████████████████▊                                                | 26.190529560043938/70 [10:23:55<957:50:20, 78709.48s/it]

Epoch: 26	Loss: 0.002 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.002]	Val_acc: 148/340 (43.53%)



 38%|█████████████████████████████▎                                               | 26.666853460293236/70 [10:35:53<950:15:54, 78945.44s/it]

Epoch: 27	Loss: 0.219 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.219]	Val_acc: 145/340 (42.65%)



 39%|█████████████████████████████▊                                               | 27.142897170012976/70 [10:47:50<938:32:36, 78837.73s/it]

Epoch: 27	Loss: 0.049 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.049]	Val_acc: 145/340 (42.65%)



 39%|██████████████████████████████▍                                              | 27.619221070262274/70 [10:59:48<930:01:37, 79000.37s/it]

Epoch: 28	Loss: 1.188 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=1.188]	Val_acc: 130/340 (38.24%)



 40%|██████████████████████████████▉                                              | 28.095264779982013/70 [11:11:43<922:32:57, 79255.42s/it]

Epoch: 28	Loss: 0.234 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.234]	Val_acc: 153/340 (45.00%)



 41%|███████████████████████████████▊                                              | 28.57158868023131/70 [11:23:46<921:30:13, 80075.81s/it]

Epoch: 29	Loss: 0.213 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.213]	Val_acc: 149/340 (43.82%)



 41%|████████████████████████████████▎                                             | 29.04763238995105/70 [11:35:45<901:11:00, 79220.35s/it]

Epoch: 29	Loss: 1.519 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=1.519]	Val_acc: 123/340 (36.18%)



 42%|████████████████████████████████▉                                             | 29.52395629020035/70 [11:47:47<889:28:09, 79110.74s/it]

Epoch: 30	Loss: 0.073 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.073]	Val_acc: 165/340 (48.53%)



 43%|█████████████████████████████████▍                                            | 29.99999999992009/70 [11:59:46<888:11:17, 79936.93s/it]

Epoch: 30	Loss: 0.000 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.000]	Val_acc: 152/340 (44.71%)



 44%|█████████████████████████████████▌                                           | 30.476323900169387/70 [12:11:50<874:10:25, 79623.82s/it]

Epoch: 31	Loss: 0.036 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.036]	Val_acc: 152/340 (44.71%)



 44%|██████████████████████████████████                                           | 30.952367609889126/70 [12:23:51<877:50:25, 80932.57s/it]

Epoch: 31	Loss: 0.226 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.226]	Val_acc: 136/340 (40.00%)



 45%|███████████████████████████████████▍                                           | 31.428691510138425/70 [12:32:33<11:42:15, 1092.41s/it]

Correct: 674/1700 (0.3965)
{'acc_unweighted': tensor(0.3152, device='cuda:0'),
 'acc_weighted': tensor(0.3965, device='cuda:0'),
 'f1s_unweighted': tensor(0.3073, device='cuda:0'),
 'f1s_weighted': tensor(0.4088, device='cuda:0'),
 'rocauc': tensor(0.5766, device='cuda:0'),
 'uar': tensor(0.3152, device='cuda:0'),
 'wap': tensor(0.4304, device='cuda:0')}
{'acc_detail': tensor([0.3268, 0.4909, 0.1242, 0.3188], device='cuda:0'),
 'f1s_detail': tensor([0.2824, 0.5393, 0.1278, 0.2799], device='cuda:0'),
 'precision_detail': tensor([0.2485, 0.5982, 0.1316, 0.2494], device='cuda:0'),
 'recall_detail': tensor([0.3268, 0.4909, 0.1242, 0.3188], device='cuda:0'),
 'rocauc_detail': tensor([0.5992, 0.5748, 0.5868, 0.5456], device='cuda:0')}
tensor([[ 84,  91,  22,  60],
        [157, 460,  89, 231],
        [ 21,  80,  20,  40],
        [ 76, 138,  21, 110]], device='cuda:0')


 46%|███████████████████████████████████                                          | 31.905015410387723/70 [12:44:41<849:22:00, 80265.69s/it]

Epoch: 32	Loss: 1.701 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=1.701]	Val_acc: 158/340 (46.47%)



 46%|███████████████████████████████████▌                                         | 32.381059120107466/70 [12:56:44<844:32:10, 80819.13s/it]

Epoch: 32	Loss: 0.292 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.292]	Val_acc: 147/340 (43.24%)



 47%|████████████████████████████████████▏                                        | 32.857383020356764/70 [13:08:53<837:09:07, 81139.88s/it]

Epoch: 33	Loss: 0.126 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.126]	Val_acc: 129/340 (37.94%)



 48%|█████████████████████████████████████▌                                         | 33.3334267300765/70 [13:21:00<827:28:16, 81242.84s/it]

Epoch: 33	Loss: 0.768 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.768]	Val_acc: 140/340 (41.18%)



 48%|██████████████████████████████████████▏                                        | 33.8097506303258/70 [13:33:11<805:53:51, 80166.12s/it]

Epoch: 34	Loss: 0.057 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.057]	Val_acc: 128/340 (37.65%)



 49%|██████████████████████████████████████▏                                       | 34.28579434004554/70 [13:45:17<796:07:08, 80248.98s/it]

Epoch: 34	Loss: 1.093 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=1.093]	Val_acc: 142/340 (41.76%)



 50%|██████████████████████████████████████▋                                       | 34.76211824029484/70 [13:57:28<787:10:20, 80419.72s/it]

Epoch: 35	Loss: 0.016 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.016]	Val_acc: 134/340 (39.41%)



 50%|███████████████████████████████████████▎                                      | 35.23816195001458/70 [14:09:39<787:42:36, 81576.71s/it]

Epoch: 35	Loss: 0.060 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.060]	Val_acc: 117/340 (34.41%)



 51%|███████████████████████████████████████▊                                      | 35.71448585026388/70 [14:21:53<759:29:58, 79747.93s/it]

Epoch: 36	Loss: 0.058 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.058]	Val_acc: 149/340 (43.82%)



 52%|████████████████████████████████████████▎                                     | 36.19052955998362/70 [14:34:07<758:19:27, 80745.64s/it]

Epoch: 36	Loss: 0.468 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.468]	Val_acc: 137/340 (40.29%)



 52%|████████████████████████████████████████▎                                    | 36.666853460232915/70 [14:46:24<751:40:04, 81180.58s/it]

Epoch: 37	Loss: 0.008 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.008]	Val_acc: 137/340 (40.29%)



 53%|████████████████████████████████████████▊                                    | 37.142897169952654/70 [14:58:39<738:11:44, 80880.68s/it]

Epoch: 37	Loss: 0.171 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.171]	Val_acc: 140/340 (41.18%)



 54%|█████████████████████████████████████████▉                                    | 37.61922107020195/70 [15:11:00<722:18:17, 80303.73s/it]

Epoch: 38	Loss: 0.814 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.814]	Val_acc: 138/340 (40.59%)



 54%|██████████████████████████████████████████▍                                   | 38.09526477992169/70 [15:23:15<705:01:18, 79551.78s/it]

Epoch: 38	Loss: 0.516 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.516]	Val_acc: 136/340 (40.00%)



 55%|██████████████████████████████████████████▉                                   | 38.57158868017099/70 [15:35:38<704:20:15, 80679.08s/it]

Epoch: 39	Loss: 0.645 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.645]	Val_acc: 133/340 (39.12%)



 56%|███████████████████████████████████████████▌                                  | 39.04763238989073/70 [15:47:58<699:55:32, 81406.78s/it]

Epoch: 39	Loss: 0.047 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.047]	Val_acc: 127/340 (37.35%)



 56%|████████████████████████████████████████████                                  | 39.52395629014003/70 [16:00:22<698:02:10, 82455.92s/it]

Epoch: 40	Loss: 0.019 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.019]	Val_acc: 124/340 (36.47%)



 57%|████████████████████████████████████████████▌                                 | 39.99999999985977/70 [16:12:46<680:36:49, 81673.63s/it]

Epoch: 40	Loss: 0.396 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.396]	Val_acc: 151/340 (44.41%)



 58%|████████████████████████████████████████████▌                                | 40.476323900109065/70 [16:25:15<673:56:16, 82177.33s/it]

Epoch: 41	Loss: 0.053 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.053]	Val_acc: 138/340 (40.59%)



 59%|█████████████████████████████████████████████                                | 40.952367609828805/70 [16:37:35<667:23:25, 82712.60s/it]

Epoch: 41	Loss: 0.174 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.174]	Val_acc: 143/340 (42.06%)



 59%|████████████████████████████████████████████████▌                                 | 41.4286915100781/70 [16:46:38<9:01:29, 1137.15s/it]

Correct: 698/1700 (0.4106)
{'acc_unweighted': tensor(0.3286, device='cuda:0'),
 'acc_weighted': tensor(0.4106, device='cuda:0'),
 'f1s_unweighted': tensor(0.3194, device='cuda:0'),
 'f1s_weighted': tensor(0.4227, device='cuda:0'),
 'rocauc': tensor(0.5775, device='cuda:0'),
 'uar': tensor(0.3286, device='cuda:0'),
 'wap': tensor(0.4457, device='cuda:0')}
{'acc_detail': tensor([0.1868, 0.5219, 0.2609, 0.3449], device='cuda:0'),
 'f1s_detail': tensor([0.2238, 0.5608, 0.1875, 0.3055], device='cuda:0'),
 'precision_detail': tensor([0.2791, 0.6059, 0.1463, 0.2742], device='cuda:0'),
 'recall_detail': tensor([0.1868, 0.5219, 0.2609, 0.3449], device='cuda:0'),
 'rocauc_detail': tensor([0.5930, 0.5729, 0.5801, 0.5641], device='cuda:0')}
tensor([[ 48, 102,  37,  70],
        [ 78, 489, 160, 210],
        [  8,  76,  42,  35],
        [ 38, 140,  48, 119]], device='cuda:0')


 60%|███████████████████████████████████████████████▎                               | 41.9050154103274/70 [16:59:10<646:01:26, 82779.41s/it]

Epoch: 42	Loss: 0.131 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.131]	Val_acc: 126/340 (37.06%)



 61%|███████████████████████████████████████████████▏                              | 42.38105912004714/70 [17:11:38<642:19:55, 83724.99s/it]

Epoch: 42	Loss: 0.116 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.116]	Val_acc: 145/340 (42.65%)



 61%|███████████████████████████████████████████████▊                              | 42.85738302029644/70 [17:24:09<620:56:22, 82356.93s/it]

Epoch: 43	Loss: 0.402 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.402]	Val_acc: 133/340 (39.12%)



 62%|████████████████████████████████████████████████▎                             | 43.33342673001618/70 [17:36:37<616:29:07, 83225.83s/it]

Epoch: 43	Loss: 0.000 [perplexity=0.000, loss_vq=0.000, loss_recon=0.000, loss_cls=0.000]	Val_acc: 141/340 (41.47%)



 62%|█████████████████████████████████████████████████▉                              | 43.67834127190258/70 [17:44:52<10:14:42, 1401.22s/it]

In [None]:
# Second fit call on the same trainer: 30 epochs with a test pass every epoch.
# Presumably continues from the current weights — confirm Trainer.fit does not re-initialise.
trainer.fit(
    train_task='embedding',
    self_supervised=False,
    n_epoch=30,
    test_n_epoch=1,
)