In [1]:
# Autor: Pablo Rocamora

import matplotlib.pyplot as plt
import neptune
import numpy as np
import random
import seaborn as sns
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import defaultdict
from dataset_handlers import (FiberCupHandler, 
                              HCPHandler, 
                              TractoinfernoHandler, 
                              HCP_Without_CC_Handler)
from encoders import (
    GCNEncoder,
    ProjectionHead
)
from graph_transformer import GraphTransformerEncoder
from loss_functions import MultiTaskTripletLoss
from streamline_datasets import (
    MaxMinNormalization,
    StreamlineTestDataset,
    StreamlineTripletDataset,
    collate_test_ds,
    collate_triplet_ds,
    fill_tracts_ds
)
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.classification import (
    MulticlassAUROC,
    MulticlassAccuracy,
    MulticlassConfusionMatrix,
    MulticlassF1Score
)
from tqdm import tqdm
import os
from gcn_encoder_model_v2 import SiameseGraphNetworkGCN_v2
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv,GraphConv, global_mean_pool, BatchNorm
from torch.nn import ModuleList
import torch.nn.functional as F


In [2]:

# Comando para lanzar tensorboard en el navegador local a través del puerto 8888 reenviado por ssh:
# tensorboard --logdir=runs/embedding_visualization --host 0.0.0.0 --port 8888


# Enable TensorFloat32 for better performance in matrix multiplication
torch.set_float32_matmul_precision('high')

# Seed setting function
def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Función para guardar un checkpoint
def save_checkpoint(epoch, model, optimizer, loss, filename='checkpoint.pth'):
    checkpoint_dir = '/app/trained_models'
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss
    }
    torch.save(checkpoint, os.path.join(checkpoint_dir, filename))




# log = True
# if log:
#     run = neptune.init_run(
#         project="pablorocamora/tfm-tractography",
#         api_token="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiI1ODA0YzA2NS04MjczLTQyNzItOGE5Mi05ZmI5YjZkMmY3MDcifQ==",
#     )

    
    

# Configuration class
class CFG:
    def __init__(self):
        self.seed = 42
        self.max_epochs = 4
        self.batch_size = 1024
        self.learning_rate = 1e-4
        self.max_batches_per_subject = 150# Un buen valor es 500
        self.optimizer = "AdamW"
        self.classification_weight = 1
        self.margin = 1.0
        self.encoder = "GCNEncoder_v2"
        self.embedding_projection_dim = 512
        self.dataset = "HCP_105"#"Tractoinferno"

        dataset_paths = {
            "HCP_105": ("/app/dataset/HCP_105", 72),
            "HCP_105_without_CC": ("/app/dataset/HCP_105", 71),
            "Tractoinferno": ("/app/dataset/Tractoinferno/tractoinferno_preprocessed_mni", 32),
            "FiberCup": ("/app/dataset/Fibercup", 7)
        }
        self.ds_path, self.n_classes = dataset_paths.get(self.dataset, (None, None))
    
    
# Initialize configuration
cfg = CFG()
# if log:
#     writer = SummaryWriter(log_dir=f"runs/{cfg.dataset}_embedding_visualization", filename_suffix=f"{time.time()}")
#     run["config"] = {
#         "dataset": cfg.dataset,
#         "seed": cfg.seed,
#         "max_epochs": cfg.max_epochs,
#         "batch_size": cfg.batch_size,
#         "learning_rate": cfg.learning_rate,
#         "max_batches_per_subject": cfg.max_batches_per_subject,
#         "n_classes": cfg.n_classes,
#         "embedding_projection_dim": cfg.embedding_projection_dim,
#         "classification_weight": cfg.classification_weight,
#         "margin": cfg.margin,
#         "encoder": cfg.encoder,
#         "optimizer": cfg.optimizer
#     }

# Set the seed
seed_everything(cfg.seed)



# Cargar las rutas de los sujetos de entrenamiento, validación y test
if cfg.dataset == "HCP_105":
    handler = HCPHandler(path = cfg.ds_path, scope = "trainset")
    train_data = handler.get_data()

    handler = HCPHandler(path = cfg.ds_path, scope = "validset")
    valid_data = handler.get_data()

    handler = HCPHandler(path = cfg.ds_path, scope = "testset")
    test_data = handler.get_data()

elif cfg.dataset == "HCP_105_without_CC":
    handler = HCP_Without_CC_Handler(path = cfg.ds_path, scope = "trainset")
    train_data = handler.get_data()

    handler = HCP_Without_CC_Handler(path = cfg.ds_path, scope = "validset")
    valid_data = handler.get_data()

    handler = HCP_Without_CC_Handler(path = cfg.ds_path, scope = "testset")
    test_data = handler.get_data()

elif cfg.dataset == "Tractoinferno":
    handler = TractoinfernoHandler(path = cfg.ds_path, scope = "trainset")
    train_data = handler.get_data()
    train_data = fill_tracts_ds(train_data)# Hacer que todos los sujetos tengan el mismo número de tractos 

    handler = TractoinfernoHandler(path = cfg.ds_path, scope = "validset")
    valid_data = handler.get_data()

    handler = TractoinfernoHandler(path = cfg.ds_path, scope = "testset")
    test_data = handler.get_data()

elif cfg.dataset == "FiberCup":
    handler = FiberCupHandler(path = cfg.ds_path, scope = "trainset")
    train_data = handler.get_data()

    handler = FiberCupHandler(path = cfg.ds_path, scope = "validset")
    valid_data = handler.get_data()

    handler = FiberCupHandler(path = cfg.ds_path, scope = "testset")
    test_data = handler.get_data()



{0: 'AF_left', 1: 'AF_right', 2: 'ATR_left', 3: 'ATR_right', 4: 'CA', 5: 'CC_1', 6: 'CC_2', 7: 'CC_3', 8: 'CC_4', 9: 'CC_5', 10: 'CC_6', 11: 'CC_7', 12: 'CC', 13: 'CG_left', 14: 'CG_right', 15: 'CST_left', 16: 'CST_right', 17: 'MLF_left', 18: 'MLF_right', 19: 'FPT_left', 20: 'FPT_right', 21: 'FX_left', 22: 'FX_right', 23: 'ICP_left', 24: 'ICP_right', 25: 'IFO_left', 26: 'IFO_right', 27: 'ILF_left', 28: 'ILF_right', 29: 'MCP', 30: 'OR_left', 31: 'OR_right', 32: 'POPT_left', 33: 'POPT_right', 34: 'SCP_left', 35: 'SCP_right', 36: 'SLF_I_left', 37: 'SLF_I_right', 38: 'SLF_II_left', 39: 'SLF_II_right', 40: 'SLF_III_left', 41: 'SLF_III_right', 42: 'STR_left', 43: 'STR_right', 44: 'UF_left', 45: 'UF_right', 46: 'T_PREF_left', 47: 'T_PREF_right', 48: 'T_PREM_left', 49: 'T_PREM_right', 50: 'T_PREC_left', 51: 'T_PREC_right', 52: 'T_POSTC_left', 53: 'T_POSTC_right', 54: 'T_PAR_left', 55: 'T_PAR_right', 56: 'T_OCC_left', 57: 'T_OCC_right', 58: 'ST_FO_left', 59: 'ST_FO_right', 60: 'ST_PREF_left', 6

In [3]:
# Modelo para cargar pesos y finetunear
#============================================================================
class GraphConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropout=0.0):
        super(GraphConvBlock, self).__init__()
        self.conv = GCNConv(in_channels, out_channels)
        self.bn = BatchNorm(out_channels)
        self.relu = nn.LeakyReLU()
        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else None

    def forward(self, x, edge_index):
        x = self.conv(x, edge_index)
        x = self.bn(x)
        x = self.relu(x)
        if self.dropout:
            x = self.dropout(x)
        return x
    

class GCNEncoder(nn.Module):
    def __init__(self, in_channels, hidden_dim, out_channels, dropout, n_hidden_blocks):
        super(GCNEncoder, self).__init__()
        self.input_block = GraphConvBlock(in_channels, hidden_dim, dropout)
        self.hidden_blocks = ModuleList([GraphConvBlock(hidden_dim, hidden_dim, dropout) for _ in range(n_hidden_blocks - 1)])
        self.output_block = GraphConvBlock(hidden_dim, out_channels, dropout)
        self.bn = BatchNorm(out_channels)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.input_block(x, edge_index)
        for layer in self.hidden_blocks:
            x = layer(x, edge_index)
        x = self.output_block(x, edge_index)
        x = self.bn(x)

        return global_mean_pool(x, batch) # (batch_size, out_channels)


class ProjectionHead(nn.Module):
    """
    Proyección de las embeddings de texto a un espacio de dimensión reducida.
    """
    def __init__(
        self,
        embedding_dim,# Salida del modelo de lenguaje (768)
        projection_dim, # Dimensión de la proyección (256)
        # dropout=0.1
    ):
        super(ProjectionHead, self).__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        # self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        # x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x
    

class SiameseContrastiveGraphNetwork(nn.Module):
    def __init__(self, encoder, projection_head):
        super(SiameseContrastiveGraphNetwork, self).__init__()
        self.encoder = encoder
        self.projection_head = projection_head

    def forward(self, graph):
        x_1 = self.encoder(graph)
        x_1 = self.projection_head(x_1)
        # x1_norm = F.normalize(x_1, p=2, dim=1)
        return x_1


model = SiameseContrastiveGraphNetwork(
    encoder = GCNEncoder(
        in_channels = 3, 
        hidden_dim = 128, 
        out_channels = 512, 
        dropout = 0.15, 
        n_hidden_blocks = 2
    ),

    projection_head = ProjectionHead(
        embedding_dim = 512, 
        projection_dim = 128
    )
).cuda()
model = torch.compile(model, dynamic=True)

# Cargar pesos preentrenados
checkpoint = torch.load('/app/trained_models/checkpoint_HCP_105_GCN_512_1_infonce_1723061405.9630373.pth')

model.load_state_dict(checkpoint['model_state_dict'])


# Eliminar la última capa de proyección
model = model.encoder
# Congelar los pesos del modelo
for param in model.parameters():
    param.requires_grad = False

# Añadir una capa de clasificación
class FinalModel(nn.Module):
    def __init__(self, base_model):
        super(FinalModel, self).__init__()
        self.base_model = base_model
        self.fc = nn.Linear(512, 512)
        self.classifier = nn.Linear(512, 71)

    def forward(self, graph):
        x = self.base_model(graph)
        x = self.fc(x)
        x = F.relu(x)
        x = self.classifier(x)
        return x

model = FinalModel(model).cuda()

model

RuntimeError: Error(s) in loading state_dict for OptimizedModule:
	size mismatch for _orig_mod.encoder.input_block.conv.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for _orig_mod.encoder.input_block.conv.lin.weight: copying a param with shape torch.Size([64, 3]) from checkpoint, the shape in current model is torch.Size([128, 3]).
	size mismatch for _orig_mod.encoder.input_block.bn.module.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for _orig_mod.encoder.input_block.bn.module.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for _orig_mod.encoder.input_block.bn.module.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for _orig_mod.encoder.input_block.bn.module.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for _orig_mod.encoder.hidden_blocks.0.conv.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for _orig_mod.encoder.hidden_blocks.0.conv.lin.weight: copying a param with shape torch.Size([64, 64]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for _orig_mod.encoder.hidden_blocks.0.bn.module.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for _orig_mod.encoder.hidden_blocks.0.bn.module.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for _orig_mod.encoder.hidden_blocks.0.bn.module.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for _orig_mod.encoder.hidden_blocks.0.bn.module.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for _orig_mod.encoder.output_block.conv.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for _orig_mod.encoder.output_block.conv.lin.weight: copying a param with shape torch.Size([128, 64]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for _orig_mod.encoder.output_block.bn.module.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for _orig_mod.encoder.output_block.bn.module.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for _orig_mod.encoder.output_block.bn.module.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for _orig_mod.encoder.output_block.bn.module.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for _orig_mod.encoder.bn.module.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for _orig_mod.encoder.bn.module.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for _orig_mod.encoder.bn.module.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for _orig_mod.encoder.bn.module.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for _orig_mod.projection_head.projection.weight: copying a param with shape torch.Size([64, 128]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for _orig_mod.projection_head.projection.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for _orig_mod.projection_head.fc.weight: copying a param with shape torch.Size([64, 64]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for _orig_mod.projection_head.fc.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for _orig_mod.projection_head.layer_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for _orig_mod.projection_head.layer_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).

In [4]:
# Comprobar que capas son entrenables
for name, param in model.named_parameters():
    print(name, param.requires_grad)

base_model.input_block.conv.bias False
base_model.input_block.conv.lin.weight False
base_model.input_block.bn.module.weight False
base_model.input_block.bn.module.bias False
base_model.hidden_blocks.0.conv.bias False
base_model.hidden_blocks.0.conv.lin.weight False
base_model.hidden_blocks.0.bn.module.weight False
base_model.hidden_blocks.0.bn.module.bias False
base_model.output_block.conv.bias False
base_model.output_block.conv.lin.weight False
base_model.output_block.bn.module.weight False
base_model.output_block.bn.module.bias False
base_model.bn.module.weight False
base_model.bn.module.bias False
fc.weight True
fc.bias True
classifier.weight True
classifier.bias True


In [5]:



# Definir loss function, optimizador y scheduler
# Cross entropy loss
criterion = nn.CrossEntropyLoss(
    weight = torch.tensor([1.0]*cfg.n_classes)
).cuda()


optimizer = torch.optim.AdamW(
    model.parameters(), 
    lr = cfg.learning_rate,
    weight_decay = 1e-4
)


scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, 
    step_size = 10, 
    gamma = 0.1
)




# Crear las métricas con torchmetrics
subj_accuracy_train = MulticlassAccuracy(num_classes = cfg.n_classes, average='macro').cuda()
subj_accuracy_val = MulticlassAccuracy(num_classes = cfg.n_classes, average='macro').cuda()

subj_f1_train = MulticlassF1Score(num_classes = cfg.n_classes, average='macro').cuda()
subj_f1_val = MulticlassF1Score(num_classes = cfg.n_classes, average='macro').cuda()

subj_auroc_train = MulticlassAUROC(num_classes = cfg.n_classes).cuda()
subj_auroc_val = MulticlassAUROC(num_classes = cfg.n_classes).cuda()

subj_confusion_matrix = MulticlassConfusionMatrix(num_classes = cfg.n_classes).cuda()




# Entrenar el modelo
for epoch in range(cfg.max_epochs):

    #shuffle train_data for each epoch
    random.shuffle(train_data)

    model.train()# Establecer el modelo en modo de entrenamiento

    for idx_suj, subject in enumerate(train_data):# Iterar sobre los sujetos de entrenamiento

        # Crear el dataset y el dataloader
        train_ds = StreamlineTripletDataset(
            subject, 
            handler, 
            transform = MaxMinNormalization(dataset=cfg.dataset)
        )
        
        train_dl = DataLoader(
            train_ds, 
            batch_size = cfg.batch_size, 
            shuffle = True, 
            num_workers = 1,              
            collate_fn = collate_triplet_ds
        )
        
        # Bucle de entrenamiento del modelo
        prog_bar = tqdm(
            iterable = train_dl, 
            total = cfg.max_batches_per_subject
        )

        for i, (graph_anch, _, _) in enumerate(prog_bar):

            # Enviar a la gpu
            graph_anch = graph_anch.to('cuda')
            target = graph_anch.y

            # Reiniciar los gradientes
            optimizer.zero_grad()

            # Forward pass
            preds = model(graph_anch)# Batch size necesario para calcular la media de los embeddings

            # Calcular la pérdida
            loss = criterion(preds, target)
            
            # Backward pass
            loss.backward()

            # Actualizar los pesos
            optimizer.step()

            # Concatenar las predicciones de los dos grafos
            # preds = torch.argmax(pred, dim=-1)
            #argmax(pred, dim=-1)

            # Concatenar las etiquetas de los dos grafos
            targets = target


            # calcular métricas de entrenamiento
            subj_accuracy_train.update(preds, targets)
            subj_f1_train.update(preds, targets)
            subj_auroc_train.update(preds, targets)


            # Mostar métricas de entrenamiento cada 100 batches
            if i % 25 == 0:
                print(f"[TRAIN] Epoch {epoch+1}/{cfg.max_epochs} - Subj {idx_suj} - Batch {i} - Loss {loss.item()} - Acc.: {subj_accuracy_train.compute().item():.4f}, F1: {subj_f1_train.compute().item():.4f}, AUROC: {subj_auroc_train.compute().item():.4f}")
                
                # if log:# Loggear las métricas de batch
                #     run["train/batch/loss"].log(loss.item())
                #     run["train/batch/acc"].log(subj_accuracy_train.compute().item())
                #     run["train/batch/f1"].log(subj_f1_train.compute().item())
                #     run["train/batch/auroc"].log(subj_auroc_train.compute().item())

            # Condicion de parada del bucle
            if i == cfg.max_batches_per_subject:

                # if log:# Loggear las métricas de subject
                #     run["train/subject/acc"].log(subj_accuracy_train.compute().item())
                #     run["train/subject/f1"].log(subj_f1_train.compute().item())
                #     run["train/subject/auroc"].log(subj_auroc_train.compute().item())
                
                # Reiniciar las métricas
                subj_accuracy_train.reset()
                subj_f1_train.reset()
                subj_auroc_train.reset()

                break
        
        # Actualizar el learning rate
        scheduler.step()

    # Save the model checkpoint
    checkpoint_name = f'checkpoint_{cfg.dataset}_{cfg.encoder}_{cfg.embedding_projection_dim}_{epoch}__finetuned.pth'
    save_checkpoint(epoch, model, optimizer, loss, filename=checkpoint_name)


  2%|▏         | 3/150 [00:00<00:27,  5.29it/s]

[TRAIN] Epoch 1/4 - Subj 0 - Batch 0 - Loss 4.264024257659912 - Acc.: 0.0061, F1: 0.0056, AUROC: 0.4925


 18%|█▊        | 27/150 [00:04<00:21,  5.86it/s]

[TRAIN] Epoch 1/4 - Subj 0 - Batch 25 - Loss 3.813040018081665 - Acc.: 0.2299, F1: 0.1890, AUROC: 0.8526


 35%|███▍      | 52/150 [00:08<00:15,  6.33it/s]

[TRAIN] Epoch 1/4 - Subj 0 - Batch 50 - Loss 3.4241340160369873 - Acc.: 0.3602, F1: 0.3192, AUROC: 0.9129


 51%|█████     | 76/150 [00:12<00:12,  6.15it/s]

[TRAIN] Epoch 1/4 - Subj 0 - Batch 75 - Loss 2.834901809692383 - Acc.: 0.4524, F1: 0.4245, AUROC: 0.9306


 68%|██████▊   | 102/150 [00:17<00:07,  6.13it/s]

[TRAIN] Epoch 1/4 - Subj 0 - Batch 100 - Loss 2.2784104347229004 - Acc.: 0.5147, F1: 0.4924, AUROC: 0.9392


 84%|████████▍ | 126/150 [00:21<00:03,  6.01it/s]

[TRAIN] Epoch 1/4 - Subj 0 - Batch 125 - Loss 1.8214569091796875 - Acc.: 0.5640, F1: 0.5475, AUROC: 0.9456


100%|██████████| 150/150 [00:25<00:00,  5.82it/s]

[TRAIN] Epoch 1/4 - Subj 0 - Batch 150 - Loss 1.4761323928833008 - Acc.: 0.6054, F1: 0.5942, AUROC: 0.9516



  1%|▏         | 2/150 [00:00<00:27,  5.39it/s]

[TRAIN] Epoch 1/4 - Subj 1 - Batch 0 - Loss 1.5398733615875244 - Acc.: 0.7755, F1: 0.7593, AUROC: 0.9948


 18%|█▊        | 27/150 [00:04<00:20,  5.89it/s]

[TRAIN] Epoch 1/4 - Subj 1 - Batch 25 - Loss 1.2416636943817139 - Acc.: 0.7865, F1: 0.7824, AUROC: 0.9953


 35%|███▍      | 52/150 [00:08<00:15,  6.32it/s]

[TRAIN] Epoch 1/4 - Subj 1 - Batch 50 - Loss 1.0198630094528198 - Acc.: 0.8044, F1: 0.8008, AUROC: 0.9958


 51%|█████     | 76/150 [00:12<00:12,  6.15it/s]

[TRAIN] Epoch 1/4 - Subj 1 - Batch 75 - Loss 0.8902873396873474 - Acc.: 0.8149, F1: 0.8117, AUROC: 0.9961


 68%|██████▊   | 102/150 [00:17<00:07,  6.26it/s]

[TRAIN] Epoch 1/4 - Subj 1 - Batch 100 - Loss 0.7729455828666687 - Acc.: 0.8220, F1: 0.8190, AUROC: 0.9963


 85%|████████▍ | 127/150 [00:21<00:03,  5.84it/s]

[TRAIN] Epoch 1/4 - Subj 1 - Batch 125 - Loss 0.6747633814811707 - Acc.: 0.8282, F1: 0.8255, AUROC: 0.9966


100%|██████████| 150/150 [00:25<00:00,  5.82it/s]

[TRAIN] Epoch 1/4 - Subj 1 - Batch 150 - Loss 0.6375866532325745 - Acc.: 0.8326, F1: 0.8298, AUROC: 0.9968



  1%|▏         | 2/150 [00:00<00:27,  5.29it/s]

[TRAIN] Epoch 1/4 - Subj 2 - Batch 0 - Loss 0.7069101929664612 - Acc.: 0.8401, F1: 0.8314, AUROC: 0.9973


 18%|█▊        | 27/150 [00:04<00:20,  5.90it/s]

[TRAIN] Epoch 1/4 - Subj 2 - Batch 25 - Loss 0.6233890056610107 - Acc.: 0.8449, F1: 0.8417, AUROC: 0.9975


 35%|███▍      | 52/150 [00:08<00:15,  6.37it/s]

[TRAIN] Epoch 1/4 - Subj 2 - Batch 50 - Loss 0.5196290016174316 - Acc.: 0.8520, F1: 0.8497, AUROC: 0.9978


 51%|█████▏    | 77/150 [00:13<00:12,  5.89it/s]

[TRAIN] Epoch 1/4 - Subj 2 - Batch 75 - Loss 0.54063481092453 - Acc.: 0.8556, F1: 0.8530, AUROC: 0.9979


 68%|██████▊   | 102/150 [00:17<00:07,  6.31it/s]

[TRAIN] Epoch 1/4 - Subj 2 - Batch 100 - Loss 0.48316317796707153 - Acc.: 0.8594, F1: 0.8570, AUROC: 0.9980


 84%|████████▍ | 126/150 [00:21<00:03,  6.19it/s]

[TRAIN] Epoch 1/4 - Subj 2 - Batch 125 - Loss 0.50927734375 - Acc.: 0.8618, F1: 0.8596, AUROC: 0.9980


100%|██████████| 150/150 [00:25<00:00,  5.85it/s]

[TRAIN] Epoch 1/4 - Subj 2 - Batch 150 - Loss 0.4562292695045471 - Acc.: 0.8641, F1: 0.8618, AUROC: 0.9981



  1%|▏         | 2/150 [00:00<00:28,  5.26it/s]

[TRAIN] Epoch 1/4 - Subj 3 - Batch 0 - Loss 0.4878895580768585 - Acc.: 0.8428, F1: 0.8357, AUROC: 0.9984


 18%|█▊        | 27/150 [00:04<00:21,  5.84it/s]

[TRAIN] Epoch 1/4 - Subj 3 - Batch 25 - Loss 0.4654131829738617 - Acc.: 0.8556, F1: 0.8536, AUROC: 0.9980


 35%|███▍      | 52/150 [00:08<00:15,  6.31it/s]

[TRAIN] Epoch 1/4 - Subj 3 - Batch 50 - Loss 0.4126361012458801 - Acc.: 0.8616, F1: 0.8599, AUROC: 0.9981


 51%|█████▏    | 77/150 [00:13<00:12,  5.82it/s]

[TRAIN] Epoch 1/4 - Subj 3 - Batch 75 - Loss 0.42903074622154236 - Acc.: 0.8668, F1: 0.8650, AUROC: 0.9982


 68%|██████▊   | 102/150 [00:17<00:07,  6.27it/s]

[TRAIN] Epoch 1/4 - Subj 3 - Batch 100 - Loss 0.44026312232017517 - Acc.: 0.8691, F1: 0.8671, AUROC: 0.9982


 84%|████████▍ | 126/150 [00:21<00:03,  6.14it/s]

[TRAIN] Epoch 1/4 - Subj 3 - Batch 125 - Loss 0.3608866333961487 - Acc.: 0.8711, F1: 0.8693, AUROC: 0.9983


100%|██████████| 150/150 [00:25<00:00,  5.83it/s]

[TRAIN] Epoch 1/4 - Subj 3 - Batch 150 - Loss 0.3440038859844208 - Acc.: 0.8734, F1: 0.8716, AUROC: 0.9983



  1%|▏         | 2/150 [00:00<00:28,  5.27it/s]

[TRAIN] Epoch 1/4 - Subj 4 - Batch 0 - Loss 0.4320756494998932 - Acc.: 0.8581, F1: 0.8412, AUROC: 0.9982


 18%|█▊        | 27/150 [00:04<00:20,  5.86it/s]

[TRAIN] Epoch 1/4 - Subj 4 - Batch 25 - Loss 0.3865158259868622 - Acc.: 0.8655, F1: 0.8634, AUROC: 0.9983


 35%|███▍      | 52/150 [00:08<00:15,  6.41it/s]

[TRAIN] Epoch 1/4 - Subj 4 - Batch 50 - Loss 0.3373969793319702 - Acc.: 0.8698, F1: 0.8679, AUROC: 0.9983


 51%|█████▏    | 77/150 [00:13<00:12,  5.90it/s]

[TRAIN] Epoch 1/4 - Subj 4 - Batch 75 - Loss 0.33927270770072937 - Acc.: 0.8715, F1: 0.8698, AUROC: 0.9983


 68%|██████▊   | 102/150 [00:17<00:07,  6.31it/s]

[TRAIN] Epoch 1/4 - Subj 4 - Batch 100 - Loss 0.3588148057460785 - Acc.: 0.8719, F1: 0.8700, AUROC: 0.9984


 85%|████████▍ | 127/150 [00:21<00:03,  5.84it/s]

[TRAIN] Epoch 1/4 - Subj 4 - Batch 125 - Loss 0.35142043232917786 - Acc.: 0.8738, F1: 0.8722, AUROC: 0.9984


100%|██████████| 150/150 [00:25<00:00,  5.81it/s]

[TRAIN] Epoch 1/4 - Subj 4 - Batch 150 - Loss 0.3456278145313263 - Acc.: 0.8751, F1: 0.8734, AUROC: 0.9984



  1%|▏         | 2/150 [00:00<00:28,  5.27it/s]

[TRAIN] Epoch 1/4 - Subj 5 - Batch 0 - Loss 0.3660668134689331 - Acc.: 0.8557, F1: 0.8508, AUROC: 0.9985


 18%|█▊        | 27/150 [00:04<00:20,  5.87it/s]

[TRAIN] Epoch 1/4 - Subj 5 - Batch 25 - Loss 0.38295719027519226 - Acc.: 0.8649, F1: 0.8638, AUROC: 0.9982


 35%|███▍      | 52/150 [00:08<00:15,  6.35it/s]

[TRAIN] Epoch 1/4 - Subj 5 - Batch 50 - Loss 0.34489327669143677 - Acc.: 0.8727, F1: 0.8714, AUROC: 0.9984


 51%|█████▏    | 77/150 [00:13<00:12,  5.88it/s]

[TRAIN] Epoch 1/4 - Subj 5 - Batch 75 - Loss 0.33035406470298767 - Acc.: 0.8755, F1: 0.8743, AUROC: 0.9984


 68%|██████▊   | 102/150 [00:17<00:07,  6.29it/s]

[TRAIN] Epoch 1/4 - Subj 5 - Batch 100 - Loss 0.3055591881275177 - Acc.: 0.8771, F1: 0.8760, AUROC: 0.9985


 84%|████████▍ | 126/150 [00:21<00:03,  6.14it/s]

[TRAIN] Epoch 1/4 - Subj 5 - Batch 125 - Loss 0.36036956310272217 - Acc.: 0.8790, F1: 0.8779, AUROC: 0.9985


100%|██████████| 150/150 [00:25<00:00,  5.85it/s]

[TRAIN] Epoch 1/4 - Subj 5 - Batch 150 - Loss 0.2834504246711731 - Acc.: 0.8805, F1: 0.8795, AUROC: 0.9985



  1%|▏         | 2/150 [00:00<00:28,  5.27it/s]

[TRAIN] Epoch 1/4 - Subj 6 - Batch 0 - Loss 0.35019466280937195 - Acc.: 0.8519, F1: 0.8406, AUROC: 0.9984


 18%|█▊        | 27/150 [00:04<00:20,  5.92it/s]

[TRAIN] Epoch 1/4 - Subj 6 - Batch 25 - Loss 0.33373987674713135 - Acc.: 0.8673, F1: 0.8655, AUROC: 0.9984


 35%|███▍      | 52/150 [00:08<00:15,  6.39it/s]

[TRAIN] Epoch 1/4 - Subj 6 - Batch 50 - Loss 0.3396244943141937 - Acc.: 0.8711, F1: 0.8697, AUROC: 0.9984


 51%|█████▏    | 77/150 [00:13<00:12,  5.88it/s]

[TRAIN] Epoch 1/4 - Subj 6 - Batch 75 - Loss 0.3036070466041565 - Acc.: 0.8738, F1: 0.8724, AUROC: 0.9985


 68%|██████▊   | 102/150 [00:17<00:07,  6.40it/s]

[TRAIN] Epoch 1/4 - Subj 6 - Batch 100 - Loss 0.29908838868141174 - Acc.: 0.8754, F1: 0.8740, AUROC: 0.9985


 84%|████████▍ | 126/150 [00:21<00:03,  6.20it/s]

[TRAIN] Epoch 1/4 - Subj 6 - Batch 125 - Loss 0.34328633546829224 - Acc.: 0.8771, F1: 0.8756, AUROC: 0.9985


100%|██████████| 150/150 [00:25<00:00,  5.86it/s]

[TRAIN] Epoch 1/4 - Subj 6 - Batch 150 - Loss 0.2543441951274872 - Acc.: 0.8783, F1: 0.8769, AUROC: 0.9985



  1%|▏         | 2/150 [00:00<00:27,  5.33it/s]

[TRAIN] Epoch 1/4 - Subj 7 - Batch 0 - Loss 0.38215702772140503 - Acc.: 0.8717, F1: 0.8640, AUROC: 0.9980


 18%|█▊        | 27/150 [00:04<00:20,  5.87it/s]

[TRAIN] Epoch 1/4 - Subj 7 - Batch 25 - Loss 0.3157341182231903 - Acc.: 0.8775, F1: 0.8759, AUROC: 0.9985


 35%|███▍      | 52/150 [00:08<00:15,  6.34it/s]

[TRAIN] Epoch 1/4 - Subj 7 - Batch 50 - Loss 0.31676962971687317 - Acc.: 0.8834, F1: 0.8823, AUROC: 0.9985


 51%|█████▏    | 77/150 [00:13<00:12,  5.80it/s]

[TRAIN] Epoch 1/4 - Subj 7 - Batch 75 - Loss 0.29895737767219543 - Acc.: 0.8853, F1: 0.8842, AUROC: 0.9986


 68%|██████▊   | 102/150 [00:17<00:07,  6.29it/s]

[TRAIN] Epoch 1/4 - Subj 7 - Batch 100 - Loss 0.28201958537101746 - Acc.: 0.8876, F1: 0.8866, AUROC: 0.9986


 85%|████████▍ | 127/150 [00:21<00:03,  5.77it/s]

[TRAIN] Epoch 1/4 - Subj 7 - Batch 125 - Loss 0.30839914083480835 - Acc.: 0.8886, F1: 0.8877, AUROC: 0.9986


100%|██████████| 150/150 [00:25<00:00,  5.81it/s]

[TRAIN] Epoch 1/4 - Subj 7 - Batch 150 - Loss 0.2746686041355133 - Acc.: 0.8899, F1: 0.8890, AUROC: 0.9987



  1%|▏         | 2/150 [00:00<00:28,  5.21it/s]

[TRAIN] Epoch 1/4 - Subj 8 - Batch 0 - Loss 0.3192024827003479 - Acc.: 0.8914, F1: 0.8873, AUROC: 0.9985


 18%|█▊        | 27/150 [00:04<00:21,  5.77it/s]

[TRAIN] Epoch 1/4 - Subj 8 - Batch 25 - Loss 0.27228835225105286 - Acc.: 0.8841, F1: 0.8829, AUROC: 0.9986


 35%|███▍      | 52/150 [00:08<00:15,  6.28it/s]

[TRAIN] Epoch 1/4 - Subj 8 - Batch 50 - Loss 0.2980805039405823 - Acc.: 0.8864, F1: 0.8857, AUROC: 0.9986


 51%|█████     | 76/150 [00:13<00:12,  6.10it/s]

[TRAIN] Epoch 1/4 - Subj 8 - Batch 75 - Loss 0.34231477975845337 - Acc.: 0.8872, F1: 0.8862, AUROC: 0.9987


 68%|██████▊   | 102/150 [00:17<00:07,  6.11it/s]

[TRAIN] Epoch 1/4 - Subj 8 - Batch 100 - Loss 0.2743087708950043 - Acc.: 0.8878, F1: 0.8868, AUROC: 0.9987


 84%|████████▍ | 126/150 [00:21<00:03,  6.12it/s]

[TRAIN] Epoch 1/4 - Subj 8 - Batch 125 - Loss 0.26896780729293823 - Acc.: 0.8890, F1: 0.8881, AUROC: 0.9987


100%|██████████| 150/150 [00:25<00:00,  5.79it/s]

[TRAIN] Epoch 1/4 - Subj 8 - Batch 150 - Loss 0.2565577030181885 - Acc.: 0.8891, F1: 0.8881, AUROC: 0.9987



  1%|▏         | 2/150 [00:00<00:28,  5.25it/s]

[TRAIN] Epoch 1/4 - Subj 9 - Batch 0 - Loss 0.35356801748275757 - Acc.: 0.8649, F1: 0.8559, AUROC: 0.9981


 18%|█▊        | 27/150 [00:04<00:20,  5.93it/s]

[TRAIN] Epoch 1/4 - Subj 9 - Batch 25 - Loss 0.2880344092845917 - Acc.: 0.8813, F1: 0.8806, AUROC: 0.9986


 35%|███▍      | 52/150 [00:08<00:15,  6.40it/s]

[TRAIN] Epoch 1/4 - Subj 9 - Batch 50 - Loss 0.2639438509941101 - Acc.: 0.8848, F1: 0.8843, AUROC: 0.9986


 51%|█████▏    | 77/150 [00:13<00:12,  5.84it/s]

[TRAIN] Epoch 1/4 - Subj 9 - Batch 75 - Loss 0.2819863259792328 - Acc.: 0.8868, F1: 0.8861, AUROC: 0.9986


 68%|██████▊   | 102/150 [00:17<00:07,  6.35it/s]

[TRAIN] Epoch 1/4 - Subj 9 - Batch 100 - Loss 0.2962312400341034 - Acc.: 0.8895, F1: 0.8886, AUROC: 0.9987


 85%|████████▍ | 127/150 [00:21<00:03,  5.88it/s]

[TRAIN] Epoch 1/4 - Subj 9 - Batch 125 - Loss 0.2624397277832031 - Acc.: 0.8902, F1: 0.8894, AUROC: 0.9987


100%|██████████| 150/150 [00:25<00:00,  5.86it/s]

[TRAIN] Epoch 1/4 - Subj 9 - Batch 150 - Loss 0.2628917694091797 - Acc.: 0.8905, F1: 0.8896, AUROC: 0.9987



  1%|▏         | 2/150 [00:00<00:28,  5.28it/s]

[TRAIN] Epoch 1/4 - Subj 10 - Batch 0 - Loss 0.39764297008514404 - Acc.: 0.8555, F1: 0.8405, AUROC: 0.9981


 18%|█▊        | 27/150 [00:04<00:21,  5.84it/s]

[TRAIN] Epoch 1/4 - Subj 10 - Batch 25 - Loss 0.2926792502403259 - Acc.: 0.8765, F1: 0.8687, AUROC: 0.9985


 34%|███▍      | 51/150 [00:08<00:16,  6.00it/s]

[TRAIN] Epoch 1/4 - Subj 10 - Batch 50 - Loss 0.2958134412765503 - Acc.: 0.8814, F1: 0.8756, AUROC: 0.9985


 51%|█████▏    | 77/150 [00:13<00:12,  6.06it/s]

[TRAIN] Epoch 1/4 - Subj 10 - Batch 75 - Loss 0.3177144229412079 - Acc.: 0.8820, F1: 0.8772, AUROC: 0.9986


 68%|██████▊   | 102/150 [00:17<00:07,  6.41it/s]

[TRAIN] Epoch 1/4 - Subj 10 - Batch 100 - Loss 0.31093263626098633 - Acc.: 0.8833, F1: 0.8792, AUROC: 0.9986


 85%|████████▍ | 127/150 [00:21<00:03,  5.85it/s]

[TRAIN] Epoch 1/4 - Subj 10 - Batch 125 - Loss 0.28426653146743774 - Acc.: 0.8845, F1: 0.8811, AUROC: 0.9986


100%|██████████| 150/150 [00:25<00:00,  5.77it/s]

[TRAIN] Epoch 1/4 - Subj 10 - Batch 150 - Loss 0.2972167730331421 - Acc.: 0.8854, F1: 0.8825, AUROC: 0.9986



  1%|▏         | 2/150 [00:00<00:28,  5.19it/s]

[TRAIN] Epoch 1/4 - Subj 11 - Batch 0 - Loss 0.37134450674057007 - Acc.: 0.8775, F1: 0.8665, AUROC: 0.9983


 18%|█▊        | 27/150 [00:04<00:20,  5.88it/s]

[TRAIN] Epoch 1/4 - Subj 11 - Batch 25 - Loss 0.3471134305000305 - Acc.: 0.8759, F1: 0.8735, AUROC: 0.9984


 35%|███▍      | 52/150 [00:08<00:15,  6.31it/s]

[TRAIN] Epoch 1/4 - Subj 11 - Batch 50 - Loss 0.2981206774711609 - Acc.: 0.8800, F1: 0.8779, AUROC: 0.9985


 51%|█████▏    | 77/150 [00:13<00:12,  5.86it/s]

[TRAIN] Epoch 1/4 - Subj 11 - Batch 75 - Loss 0.3058043122291565 - Acc.: 0.8810, F1: 0.8792, AUROC: 0.9985


 68%|██████▊   | 102/150 [00:17<00:07,  6.32it/s]

[TRAIN] Epoch 1/4 - Subj 11 - Batch 100 - Loss 0.2676425576210022 - Acc.: 0.8825, F1: 0.8811, AUROC: 0.9985


 84%|████████▍ | 126/150 [00:21<00:03,  6.13it/s]

[TRAIN] Epoch 1/4 - Subj 11 - Batch 125 - Loss 0.2612578868865967 - Acc.: 0.8845, F1: 0.8833, AUROC: 0.9985


100%|██████████| 150/150 [00:25<00:00,  5.83it/s]

[TRAIN] Epoch 1/4 - Subj 11 - Batch 150 - Loss 0.2803753614425659 - Acc.: 0.8849, F1: 0.8838, AUROC: 0.9985



  1%|▏         | 2/150 [00:00<00:27,  5.31it/s]

[TRAIN] Epoch 1/4 - Subj 12 - Batch 0 - Loss 0.30409833788871765 - Acc.: 0.8724, F1: 0.8644, AUROC: 0.9986


 18%|█▊        | 27/150 [00:04<00:20,  5.94it/s]

[TRAIN] Epoch 1/4 - Subj 12 - Batch 25 - Loss 0.28888022899627686 - Acc.: 0.8738, F1: 0.8710, AUROC: 0.9984


 35%|███▍      | 52/150 [00:08<00:15,  6.39it/s]

[TRAIN] Epoch 1/4 - Subj 12 - Batch 50 - Loss 0.28901398181915283 - Acc.: 0.8778, F1: 0.8764, AUROC: 0.9985


 51%|█████▏    | 77/150 [00:13<00:12,  5.90it/s]

[TRAIN] Epoch 1/4 - Subj 12 - Batch 75 - Loss 0.2978450357913971 - Acc.: 0.8792, F1: 0.8782, AUROC: 0.9985


 68%|██████▊   | 102/150 [00:17<00:07,  6.37it/s]

[TRAIN] Epoch 1/4 - Subj 12 - Batch 100 - Loss 0.2741561830043793 - Acc.: 0.8812, F1: 0.8805, AUROC: 0.9985


 85%|████████▍ | 127/150 [00:21<00:03,  5.92it/s]

[TRAIN] Epoch 1/4 - Subj 12 - Batch 125 - Loss 0.30076128244400024 - Acc.: 0.8822, F1: 0.8814, AUROC: 0.9985


100%|██████████| 150/150 [00:25<00:00,  5.87it/s]

[TRAIN] Epoch 1/4 - Subj 12 - Batch 150 - Loss 0.2940554618835449 - Acc.: 0.8832, F1: 0.8824, AUROC: 0.9985



  1%|▏         | 2/150 [00:00<00:28,  5.25it/s]

[TRAIN] Epoch 1/4 - Subj 13 - Batch 0 - Loss 0.33599305152893066 - Acc.: 0.8800, F1: 0.8738, AUROC: 0.9984


 18%|█▊        | 27/150 [00:04<00:20,  5.89it/s]

[TRAIN] Epoch 1/4 - Subj 13 - Batch 25 - Loss 0.3355705738067627 - Acc.: 0.8745, F1: 0.8725, AUROC: 0.9985


 35%|███▍      | 52/150 [00:08<00:15,  6.36it/s]

[TRAIN] Epoch 1/4 - Subj 13 - Batch 50 - Loss 0.28699105978012085 - Acc.: 0.8759, F1: 0.8747, AUROC: 0.9985


 51%|█████     | 76/150 [00:12<00:11,  6.19it/s]

[TRAIN] Epoch 1/4 - Subj 13 - Batch 75 - Loss 0.2955588400363922 - Acc.: 0.8774, F1: 0.8766, AUROC: 0.9985


 68%|██████▊   | 102/150 [00:17<00:07,  6.33it/s]

[TRAIN] Epoch 1/4 - Subj 13 - Batch 100 - Loss 0.2508509159088135 - Acc.: 0.8785, F1: 0.8778, AUROC: 0.9985


 84%|████████▍ | 126/150 [00:21<00:03,  6.17it/s]

[TRAIN] Epoch 1/4 - Subj 13 - Batch 125 - Loss 0.28509724140167236 - Acc.: 0.8801, F1: 0.8795, AUROC: 0.9985


100%|██████████| 150/150 [00:25<00:00,  5.86it/s]

[TRAIN] Epoch 1/4 - Subj 13 - Batch 150 - Loss 0.25660577416419983 - Acc.: 0.8812, F1: 0.8804, AUROC: 0.9985



  1%|▏         | 2/150 [00:00<00:28,  5.12it/s]

[TRAIN] Epoch 1/4 - Subj 14 - Batch 0 - Loss 0.3154997229576111 - Acc.: 0.8721, F1: 0.8657, AUROC: 0.9985


 18%|█▊        | 27/150 [00:04<00:20,  5.93it/s]

[TRAIN] Epoch 1/4 - Subj 14 - Batch 25 - Loss 0.3148740530014038 - Acc.: 0.8830, F1: 0.8819, AUROC: 0.9987


 35%|███▍      | 52/150 [00:08<00:15,  6.28it/s]

[TRAIN] Epoch 1/4 - Subj 14 - Batch 50 - Loss 0.29673483967781067 - Acc.: 0.8828, F1: 0.8817, AUROC: 0.9986


 51%|█████▏    | 77/150 [00:13<00:12,  5.75it/s]

[TRAIN] Epoch 1/4 - Subj 14 - Batch 75 - Loss 0.31533873081207275 - Acc.: 0.8843, F1: 0.8833, AUROC: 0.9986


 68%|██████▊   | 102/150 [00:17<00:07,  6.31it/s]

[TRAIN] Epoch 1/4 - Subj 14 - Batch 100 - Loss 0.31576433777809143 - Acc.: 0.8840, F1: 0.8829, AUROC: 0.9986


 84%|████████▍ | 126/150 [00:21<00:03,  6.19it/s]

[TRAIN] Epoch 1/4 - Subj 14 - Batch 125 - Loss 0.2643440365791321 - Acc.: 0.8848, F1: 0.8837, AUROC: 0.9987


100%|██████████| 150/150 [00:25<00:00,  5.82it/s]

[TRAIN] Epoch 1/4 - Subj 14 - Batch 150 - Loss 0.2825338840484619 - Acc.: 0.8854, F1: 0.8842, AUROC: 0.9987



  1%|▏         | 2/150 [00:00<00:28,  5.18it/s]

[TRAIN] Epoch 1/4 - Subj 15 - Batch 0 - Loss 0.266427606344223 - Acc.: 0.9111, F1: 0.9035, AUROC: 0.9989


 18%|█▊        | 27/150 [00:04<00:20,  5.91it/s]

[TRAIN] Epoch 1/4 - Subj 15 - Batch 25 - Loss 0.30826660990715027 - Acc.: 0.8967, F1: 0.8946, AUROC: 0.9988


 35%|███▍      | 52/150 [00:08<00:15,  6.47it/s]

[TRAIN] Epoch 1/4 - Subj 15 - Batch 50 - Loss 0.28593090176582336 - Acc.: 0.8961, F1: 0.8945, AUROC: 0.9988


 51%|█████▏    | 77/150 [00:13<00:12,  5.87it/s]

[TRAIN] Epoch 1/4 - Subj 15 - Batch 75 - Loss 0.26944392919540405 - Acc.: 0.8962, F1: 0.8947, AUROC: 0.9988


 68%|██████▊   | 102/150 [00:17<00:07,  6.33it/s]

[TRAIN] Epoch 1/4 - Subj 15 - Batch 100 - Loss 0.2872856855392456 - Acc.: 0.8961, F1: 0.8947, AUROC: 0.9988


 85%|████████▍ | 127/150 [00:21<00:03,  5.86it/s]

[TRAIN] Epoch 1/4 - Subj 15 - Batch 125 - Loss 0.24780258536338806 - Acc.: 0.8961, F1: 0.8948, AUROC: 0.9988


100%|██████████| 150/150 [00:25<00:00,  5.79it/s]

[TRAIN] Epoch 1/4 - Subj 15 - Batch 150 - Loss 0.2975994944572449 - Acc.: 0.8963, F1: 0.8951, AUROC: 0.9988



  1%|▏         | 2/150 [00:00<00:28,  5.23it/s]

[TRAIN] Epoch 1/4 - Subj 16 - Batch 0 - Loss 0.32563334703445435 - Acc.: 0.8714, F1: 0.8638, AUROC: 0.9985


 18%|█▊        | 27/150 [00:04<00:21,  5.83it/s]

[TRAIN] Epoch 1/4 - Subj 16 - Batch 25 - Loss 0.25135281682014465 - Acc.: 0.8847, F1: 0.8840, AUROC: 0.9987


 35%|███▍      | 52/150 [00:08<00:15,  6.45it/s]

[TRAIN] Epoch 1/4 - Subj 16 - Batch 50 - Loss 0.28192955255508423 - Acc.: 0.8853, F1: 0.8840, AUROC: 0.9987


 51%|█████▏    | 77/150 [00:13<00:12,  5.91it/s]

[TRAIN] Epoch 1/4 - Subj 16 - Batch 75 - Loss 0.29319387674331665 - Acc.: 0.8863, F1: 0.8850, AUROC: 0.9987


 68%|██████▊   | 102/150 [00:17<00:07,  6.36it/s]

[TRAIN] Epoch 1/4 - Subj 16 - Batch 100 - Loss 0.2656921446323395 - Acc.: 0.8880, F1: 0.8869, AUROC: 0.9987


 85%|████████▍ | 127/150 [00:21<00:03,  5.83it/s]

[TRAIN] Epoch 1/4 - Subj 16 - Batch 125 - Loss 0.2722371518611908 - Acc.: 0.8891, F1: 0.8880, AUROC: 0.9987


100%|██████████| 150/150 [00:25<00:00,  5.82it/s]

[TRAIN] Epoch 1/4 - Subj 16 - Batch 150 - Loss 0.25091397762298584 - Acc.: 0.8894, F1: 0.8882, AUROC: 0.9987



  1%|▏         | 2/150 [00:00<00:28,  5.19it/s]

[TRAIN] Epoch 1/4 - Subj 17 - Batch 0 - Loss 0.30562517046928406 - Acc.: 0.8827, F1: 0.8745, AUROC: 0.9986


 18%|█▊        | 27/150 [00:04<00:21,  5.83it/s]

[TRAIN] Epoch 1/4 - Subj 17 - Batch 25 - Loss 0.2980409264564514 - Acc.: 0.8953, F1: 0.8941, AUROC: 0.9988


 35%|███▍      | 52/150 [00:08<00:15,  6.29it/s]

[TRAIN] Epoch 1/4 - Subj 17 - Batch 50 - Loss 0.2866879105567932 - Acc.: 0.8958, F1: 0.8949, AUROC: 0.9988


 51%|█████▏    | 77/150 [00:13<00:12,  5.80it/s]

[TRAIN] Epoch 1/4 - Subj 17 - Batch 75 - Loss 0.22015607357025146 - Acc.: 0.8975, F1: 0.8966, AUROC: 0.9988


 68%|██████▊   | 102/150 [00:17<00:07,  6.28it/s]

[TRAIN] Epoch 1/4 - Subj 17 - Batch 100 - Loss 0.25557219982147217 - Acc.: 0.8982, F1: 0.8974, AUROC: 0.9988


 84%|████████▍ | 126/150 [00:21<00:03,  6.12it/s]

[TRAIN] Epoch 1/4 - Subj 17 - Batch 125 - Loss 0.2724795937538147 - Acc.: 0.8985, F1: 0.8976, AUROC: 0.9988


100%|██████████| 150/150 [00:25<00:00,  5.78it/s]

[TRAIN] Epoch 1/4 - Subj 17 - Batch 150 - Loss 0.25701215863227844 - Acc.: 0.8987, F1: 0.8978, AUROC: 0.9989



  1%|▏         | 2/150 [00:00<00:27,  5.30it/s]

[TRAIN] Epoch 1/4 - Subj 18 - Batch 0 - Loss 0.32690173387527466 - Acc.: 0.8621, F1: 0.8550, AUROC: 0.9982


 18%|█▊        | 27/150 [00:04<00:20,  5.87it/s]

[TRAIN] Epoch 1/4 - Subj 18 - Batch 25 - Loss 0.29735490679740906 - Acc.: 0.8714, F1: 0.8692, AUROC: 0.9985


 35%|███▍      | 52/150 [00:08<00:15,  6.29it/s]

[TRAIN] Epoch 1/4 - Subj 18 - Batch 50 - Loss 0.36025717854499817 - Acc.: 0.8731, F1: 0.8714, AUROC: 0.9985


 51%|█████▏    | 77/150 [00:13<00:12,  5.90it/s]

[TRAIN] Epoch 1/4 - Subj 18 - Batch 75 - Loss 0.28443604707717896 - Acc.: 0.8748, F1: 0.8732, AUROC: 0.9985


 68%|██████▊   | 102/150 [00:17<00:07,  6.34it/s]

[TRAIN] Epoch 1/4 - Subj 18 - Batch 100 - Loss 0.30189692974090576 - Acc.: 0.8760, F1: 0.8743, AUROC: 0.9985


 84%|████████▍ | 126/150 [00:21<00:03,  6.15it/s]

[TRAIN] Epoch 1/4 - Subj 18 - Batch 125 - Loss 0.3038196563720703 - Acc.: 0.8779, F1: 0.8762, AUROC: 0.9985


100%|██████████| 150/150 [00:25<00:00,  5.79it/s]

[TRAIN] Epoch 1/4 - Subj 18 - Batch 150 - Loss 0.29697877168655396 - Acc.: 0.8787, F1: 0.8771, AUROC: 0.9985



  1%|▏         | 2/150 [00:00<00:28,  5.23it/s]

[TRAIN] Epoch 1/4 - Subj 19 - Batch 0 - Loss 0.3254700303077698 - Acc.: 0.8836, F1: 0.8739, AUROC: 0.9982


 18%|█▊        | 27/150 [00:04<00:20,  5.86it/s]

[TRAIN] Epoch 1/4 - Subj 19 - Batch 25 - Loss 0.30742573738098145 - Acc.: 0.8839, F1: 0.8836, AUROC: 0.9986


 35%|███▍      | 52/150 [00:08<00:15,  6.36it/s]

[TRAIN] Epoch 1/4 - Subj 19 - Batch 50 - Loss 0.2629905939102173 - Acc.: 0.8848, F1: 0.8841, AUROC: 0.9986


 51%|█████▏    | 77/150 [00:13<00:12,  5.84it/s]

[TRAIN] Epoch 1/4 - Subj 19 - Batch 75 - Loss 0.27515941858291626 - Acc.: 0.8856, F1: 0.8848, AUROC: 0.9987


 68%|██████▊   | 102/150 [00:17<00:07,  6.28it/s]

[TRAIN] Epoch 1/4 - Subj 19 - Batch 100 - Loss 0.28615936636924744 - Acc.: 0.8869, F1: 0.8859, AUROC: 0.9987


 85%|████████▍ | 127/150 [00:21<00:03,  5.82it/s]

[TRAIN] Epoch 1/4 - Subj 19 - Batch 125 - Loss 0.3024059236049652 - Acc.: 0.8875, F1: 0.8865, AUROC: 0.9987


100%|██████████| 150/150 [00:25<00:00,  5.79it/s]

[TRAIN] Epoch 1/4 - Subj 19 - Batch 150 - Loss 0.2994612455368042 - Acc.: 0.8871, F1: 0.8861, AUROC: 0.9987



  1%|▏         | 2/150 [00:00<00:28,  5.12it/s]

[TRAIN] Epoch 1/4 - Subj 20 - Batch 0 - Loss 0.2981555163860321 - Acc.: 0.8817, F1: 0.8764, AUROC: 0.9986


 18%|█▊        | 27/150 [00:04<00:20,  5.89it/s]

[TRAIN] Epoch 1/4 - Subj 20 - Batch 25 - Loss 0.25795459747314453 - Acc.: 0.8803, F1: 0.8769, AUROC: 0.9986


 35%|███▍      | 52/150 [00:08<00:15,  6.44it/s]

[TRAIN] Epoch 1/4 - Subj 20 - Batch 50 - Loss 0.29857370257377625 - Acc.: 0.8805, F1: 0.8772, AUROC: 0.9986


 51%|█████▏    | 77/150 [00:13<00:12,  5.88it/s]

[TRAIN] Epoch 1/4 - Subj 20 - Batch 75 - Loss 0.2950688302516937 - Acc.: 0.8806, F1: 0.8774, AUROC: 0.9986


 68%|██████▊   | 102/150 [00:17<00:07,  6.44it/s]

[TRAIN] Epoch 1/4 - Subj 20 - Batch 100 - Loss 0.3220479488372803 - Acc.: 0.8810, F1: 0.8778, AUROC: 0.9986


 85%|████████▍ | 127/150 [00:21<00:03,  5.86it/s]

[TRAIN] Epoch 1/4 - Subj 20 - Batch 125 - Loss 0.3259499669075012 - Acc.: 0.8813, F1: 0.8782, AUROC: 0.9986


100%|██████████| 150/150 [00:25<00:00,  5.81it/s]

[TRAIN] Epoch 1/4 - Subj 20 - Batch 150 - Loss 0.33841371536254883 - Acc.: 0.8817, F1: 0.8787, AUROC: 0.9986



  1%|▏         | 2/150 [00:00<00:27,  5.31it/s]

[TRAIN] Epoch 1/4 - Subj 21 - Batch 0 - Loss 0.346479594707489 - Acc.: 0.8570, F1: 0.8545, AUROC: 0.9983


 18%|█▊        | 27/150 [00:04<00:20,  6.00it/s]

[TRAIN] Epoch 1/4 - Subj 21 - Batch 25 - Loss 0.3342551589012146 - Acc.: 0.8631, F1: 0.8613, AUROC: 0.9984


 35%|███▍      | 52/150 [00:09<00:15,  6.37it/s]

[TRAIN] Epoch 1/4 - Subj 21 - Batch 50 - Loss 0.31275689601898193 - Acc.: 0.8636, F1: 0.8616, AUROC: 0.9984


 51%|█████▏    | 77/150 [00:13<00:12,  5.84it/s]

[TRAIN] Epoch 1/4 - Subj 21 - Batch 75 - Loss 0.3087359666824341 - Acc.: 0.8642, F1: 0.8624, AUROC: 0.9985


 68%|██████▊   | 102/150 [00:17<00:07,  6.28it/s]

[TRAIN] Epoch 1/4 - Subj 21 - Batch 100 - Loss 0.32671117782592773 - Acc.: 0.8642, F1: 0.8624, AUROC: 0.9985


 85%|████████▍ | 127/150 [00:21<00:03,  5.81it/s]

[TRAIN] Epoch 1/4 - Subj 21 - Batch 125 - Loss 0.30057236552238464 - Acc.: 0.8646, F1: 0.8628, AUROC: 0.9985


100%|██████████| 150/150 [00:26<00:00,  5.74it/s]

[TRAIN] Epoch 1/4 - Subj 21 - Batch 150 - Loss 0.3444768190383911 - Acc.: 0.8652, F1: 0.8636, AUROC: 0.9985



  1%|▏         | 2/150 [00:00<00:29,  4.95it/s]

[TRAIN] Epoch 1/4 - Subj 22 - Batch 0 - Loss 0.2538411319255829 - Acc.: 0.9043, F1: 0.8938, AUROC: 0.9990


 18%|█▊        | 27/150 [00:04<00:21,  5.81it/s]

[TRAIN] Epoch 1/4 - Subj 22 - Batch 25 - Loss 0.26913517713546753 - Acc.: 0.8879, F1: 0.8852, AUROC: 0.9988


 35%|███▍      | 52/150 [00:08<00:15,  6.23it/s]

[TRAIN] Epoch 1/4 - Subj 22 - Batch 50 - Loss 0.26708680391311646 - Acc.: 0.8900, F1: 0.8878, AUROC: 0.9988


 51%|█████▏    | 77/150 [00:13<00:12,  5.77it/s]

[TRAIN] Epoch 1/4 - Subj 22 - Batch 75 - Loss 0.33013463020324707 - Acc.: 0.8897, F1: 0.8875, AUROC: 0.9988


 68%|██████▊   | 102/150 [00:17<00:07,  6.27it/s]

[TRAIN] Epoch 1/4 - Subj 22 - Batch 100 - Loss 0.3226638734340668 - Acc.: 0.8894, F1: 0.8872, AUROC: 0.9988


 84%|████████▍ | 126/150 [00:21<00:03,  6.06it/s]

[TRAIN] Epoch 1/4 - Subj 22 - Batch 125 - Loss 0.2693278193473816 - Acc.: 0.8898, F1: 0.8877, AUROC: 0.9988


100%|██████████| 150/150 [00:26<00:00,  5.74it/s]

[TRAIN] Epoch 1/4 - Subj 22 - Batch 150 - Loss 0.29212668538093567 - Acc.: 0.8895, F1: 0.8874, AUROC: 0.9988



  1%|▏         | 2/150 [00:00<00:27,  5.31it/s]

[TRAIN] Epoch 1/4 - Subj 23 - Batch 0 - Loss 0.36324650049209595 - Acc.: 0.8644, F1: 0.8585, AUROC: 0.9982


 18%|█▊        | 27/150 [00:04<00:20,  5.88it/s]

[TRAIN] Epoch 1/4 - Subj 23 - Batch 25 - Loss 0.2834169864654541 - Acc.: 0.8810, F1: 0.8795, AUROC: 0.9986


 35%|███▍      | 52/150 [00:08<00:15,  6.28it/s]

[TRAIN] Epoch 1/4 - Subj 23 - Batch 50 - Loss 0.3041233718395233 - Acc.: 0.8806, F1: 0.8790, AUROC: 0.9986


 51%|█████▏    | 77/150 [00:13<00:12,  5.85it/s]

[TRAIN] Epoch 1/4 - Subj 23 - Batch 75 - Loss 0.25450193881988525 - Acc.: 0.8806, F1: 0.8791, AUROC: 0.9986


 68%|██████▊   | 102/150 [00:17<00:07,  6.33it/s]

[TRAIN] Epoch 1/4 - Subj 23 - Batch 100 - Loss 0.28344032168388367 - Acc.: 0.8801, F1: 0.8784, AUROC: 0.9986


 85%|████████▍ | 127/150 [00:21<00:03,  5.81it/s]

[TRAIN] Epoch 1/4 - Subj 23 - Batch 125 - Loss 0.30415427684783936 - Acc.: 0.8800, F1: 0.8783, AUROC: 0.9986


100%|██████████| 150/150 [00:25<00:00,  5.81it/s]

[TRAIN] Epoch 1/4 - Subj 23 - Batch 150 - Loss 0.2931031584739685 - Acc.: 0.8805, F1: 0.8788, AUROC: 0.9986



  1%|▏         | 2/150 [00:00<00:29,  4.95it/s]

[TRAIN] Epoch 1/4 - Subj 24 - Batch 0 - Loss 0.2810722589492798 - Acc.: 0.8924, F1: 0.8866, AUROC: 0.9988


 18%|█▊        | 27/150 [00:04<00:20,  6.08it/s]

[TRAIN] Epoch 1/4 - Subj 24 - Batch 25 - Loss 0.2335224598646164 - Acc.: 0.8935, F1: 0.8919, AUROC: 0.9988


 35%|███▍      | 52/150 [00:08<00:15,  6.41it/s]

[TRAIN] Epoch 1/4 - Subj 24 - Batch 50 - Loss 0.25886794924736023 - Acc.: 0.8932, F1: 0.8913, AUROC: 0.9987


 51%|█████▏    | 77/150 [00:13<00:12,  5.90it/s]

[TRAIN] Epoch 1/4 - Subj 24 - Batch 75 - Loss 0.26922711730003357 - Acc.: 0.8934, F1: 0.8918, AUROC: 0.9987


 68%|██████▊   | 102/150 [00:17<00:07,  6.31it/s]

[TRAIN] Epoch 1/4 - Subj 24 - Batch 100 - Loss 0.22327247262001038 - Acc.: 0.8934, F1: 0.8919, AUROC: 0.9988


 85%|████████▍ | 127/150 [00:21<00:04,  5.66it/s]

[TRAIN] Epoch 1/4 - Subj 24 - Batch 125 - Loss 0.28200432658195496 - Acc.: 0.8931, F1: 0.8917, AUROC: 0.9988


100%|██████████| 150/150 [00:25<00:00,  5.80it/s]

[TRAIN] Epoch 1/4 - Subj 24 - Batch 150 - Loss 0.2817757725715637 - Acc.: 0.8934, F1: 0.8920, AUROC: 0.9987



  1%|▏         | 2/150 [00:00<00:29,  5.06it/s]

[TRAIN] Epoch 1/4 - Subj 25 - Batch 0 - Loss 0.3649210035800934 - Acc.: 0.8615, F1: 0.8472, AUROC: 0.9982


 18%|█▊        | 27/150 [00:04<00:20,  5.87it/s]

[TRAIN] Epoch 1/4 - Subj 25 - Batch 25 - Loss 0.4066183269023895 - Acc.: 0.8640, F1: 0.8592, AUROC: 0.9983


 35%|███▍      | 52/150 [00:08<00:15,  6.32it/s]

[TRAIN] Epoch 1/4 - Subj 25 - Batch 50 - Loss 0.29476261138916016 - Acc.: 0.8689, F1: 0.8646, AUROC: 0.9983


 51%|█████▏    | 77/150 [00:13<00:12,  5.84it/s]

[TRAIN] Epoch 1/4 - Subj 25 - Batch 75 - Loss 0.3459431827068329 - Acc.: 0.8706, F1: 0.8666, AUROC: 0.9983


 68%|██████▊   | 102/150 [00:17<00:07,  6.30it/s]

[TRAIN] Epoch 1/4 - Subj 25 - Batch 100 - Loss 0.31748268008232117 - Acc.: 0.8701, F1: 0.8662, AUROC: 0.9983


 84%|████████▍ | 126/150 [00:21<00:03,  6.16it/s]

[TRAIN] Epoch 1/4 - Subj 25 - Batch 125 - Loss 0.2944258749485016 - Acc.: 0.8709, F1: 0.8671, AUROC: 0.9984


100%|██████████| 150/150 [00:25<00:00,  5.79it/s]

[TRAIN] Epoch 1/4 - Subj 25 - Batch 150 - Loss 0.30719056725502014 - Acc.: 0.8709, F1: 0.8672, AUROC: 0.9984



  1%|▏         | 2/150 [00:00<00:29,  5.09it/s]

[TRAIN] Epoch 1/4 - Subj 26 - Batch 0 - Loss 0.3004818856716156 - Acc.: 0.8758, F1: 0.8712, AUROC: 0.9987


 18%|█▊        | 27/150 [00:04<00:20,  6.10it/s]

[TRAIN] Epoch 1/4 - Subj 26 - Batch 25 - Loss 0.3016616106033325 - Acc.: 0.8853, F1: 0.8828, AUROC: 0.9987


 35%|███▍      | 52/150 [00:08<00:15,  6.46it/s]

[TRAIN] Epoch 1/4 - Subj 26 - Batch 50 - Loss 0.30327606201171875 - Acc.: 0.8851, F1: 0.8831, AUROC: 0.9987


 51%|█████▏    | 77/150 [00:13<00:12,  5.85it/s]

[TRAIN] Epoch 1/4 - Subj 26 - Batch 75 - Loss 0.3084215521812439 - Acc.: 0.8853, F1: 0.8833, AUROC: 0.9987


 68%|██████▊   | 102/150 [00:17<00:07,  6.41it/s]

[TRAIN] Epoch 1/4 - Subj 26 - Batch 100 - Loss 0.29103007912635803 - Acc.: 0.8844, F1: 0.8826, AUROC: 0.9987


 85%|████████▍ | 127/150 [00:21<00:03,  6.08it/s]

[TRAIN] Epoch 1/4 - Subj 26 - Batch 125 - Loss 0.27548372745513916 - Acc.: 0.8840, F1: 0.8822, AUROC: 0.9987


100%|██████████| 150/150 [00:25<00:00,  5.78it/s]

[TRAIN] Epoch 1/4 - Subj 26 - Batch 150 - Loss 0.27307888865470886 - Acc.: 0.8839, F1: 0.8821, AUROC: 0.9987



  1%|▏         | 2/150 [00:00<00:28,  5.21it/s]

[TRAIN] Epoch 1/4 - Subj 27 - Batch 0 - Loss 0.2614499032497406 - Acc.: 0.9038, F1: 0.8971, AUROC: 0.9989


 18%|█▊        | 27/150 [00:04<00:20,  6.10it/s]

[TRAIN] Epoch 1/4 - Subj 27 - Batch 25 - Loss 0.25182783603668213 - Acc.: 0.8930, F1: 0.8913, AUROC: 0.9988


 35%|███▍      | 52/150 [00:08<00:15,  6.47it/s]

[TRAIN] Epoch 1/4 - Subj 27 - Batch 50 - Loss 0.2987789213657379 - Acc.: 0.8907, F1: 0.8891, AUROC: 0.9987


 51%|█████▏    | 77/150 [00:13<00:12,  6.00it/s]

[TRAIN] Epoch 1/4 - Subj 27 - Batch 75 - Loss 0.27280882000923157 - Acc.: 0.8906, F1: 0.8891, AUROC: 0.9987


 68%|██████▊   | 102/150 [00:17<00:07,  6.47it/s]

[TRAIN] Epoch 1/4 - Subj 27 - Batch 100 - Loss 0.28707000613212585 - Acc.: 0.8897, F1: 0.8883, AUROC: 0.9987


 85%|████████▍ | 127/150 [00:21<00:03,  5.90it/s]

[TRAIN] Epoch 1/4 - Subj 27 - Batch 125 - Loss 0.26382362842559814 - Acc.: 0.8901, F1: 0.8888, AUROC: 0.9987


100%|██████████| 150/150 [00:25<00:00,  5.81it/s]

[TRAIN] Epoch 1/4 - Subj 27 - Batch 150 - Loss 0.27537158131599426 - Acc.: 0.8904, F1: 0.8892, AUROC: 0.9987



  1%|▏         | 2/150 [00:00<00:28,  5.20it/s]

[TRAIN] Epoch 1/4 - Subj 28 - Batch 0 - Loss 0.30703699588775635 - Acc.: 0.8795, F1: 0.8708, AUROC: 0.9845


 18%|█▊        | 27/150 [00:04<00:20,  6.05it/s]

[TRAIN] Epoch 1/4 - Subj 28 - Batch 25 - Loss 0.30579906702041626 - Acc.: 0.8785, F1: 0.8767, AUROC: 0.9846


 35%|███▍      | 52/150 [00:08<00:15,  6.42it/s]

[TRAIN] Epoch 1/4 - Subj 28 - Batch 50 - Loss 0.3019632399082184 - Acc.: 0.8775, F1: 0.8758, AUROC: 0.9845


 51%|█████▏    | 77/150 [00:13<00:12,  5.89it/s]

[TRAIN] Epoch 1/4 - Subj 28 - Batch 75 - Loss 0.3045803904533386 - Acc.: 0.8776, F1: 0.8758, AUROC: 0.9845


 68%|██████▊   | 102/150 [00:17<00:07,  6.37it/s]

[TRAIN] Epoch 1/4 - Subj 28 - Batch 100 - Loss 0.32846060395240784 - Acc.: 0.8767, F1: 0.8751, AUROC: 0.9845


 85%|████████▍ | 127/150 [00:21<00:03,  5.80it/s]

[TRAIN] Epoch 1/4 - Subj 28 - Batch 125 - Loss 0.27946776151657104 - Acc.: 0.8776, F1: 0.8761, AUROC: 0.9845


100%|██████████| 150/150 [00:25<00:00,  5.81it/s]

[TRAIN] Epoch 1/4 - Subj 28 - Batch 150 - Loss 0.2818537652492523 - Acc.: 0.8773, F1: 0.8757, AUROC: 0.9845



  1%|▏         | 2/150 [00:00<00:28,  5.14it/s]

[TRAIN] Epoch 1/4 - Subj 29 - Batch 0 - Loss 0.3555064797401428 - Acc.: 0.8696, F1: 0.8641, AUROC: 0.9981


 18%|█▊        | 27/150 [00:04<00:20,  6.12it/s]

[TRAIN] Epoch 1/4 - Subj 29 - Batch 25 - Loss 0.3198513090610504 - Acc.: 0.8731, F1: 0.8710, AUROC: 0.9984


 35%|███▍      | 52/150 [00:08<00:15,  6.43it/s]

[TRAIN] Epoch 1/4 - Subj 29 - Batch 50 - Loss 0.3444714844226837 - Acc.: 0.8721, F1: 0.8702, AUROC: 0.9984


 51%|█████▏    | 77/150 [00:13<00:12,  5.90it/s]

[TRAIN] Epoch 1/4 - Subj 29 - Batch 75 - Loss 0.328108549118042 - Acc.: 0.8724, F1: 0.8704, AUROC: 0.9984


 68%|██████▊   | 102/150 [00:17<00:07,  6.30it/s]

[TRAIN] Epoch 1/4 - Subj 29 - Batch 100 - Loss 0.3745537996292114 - Acc.: 0.8720, F1: 0.8698, AUROC: 0.9984


 85%|████████▍ | 127/150 [00:21<00:03,  5.91it/s]

[TRAIN] Epoch 1/4 - Subj 29 - Batch 125 - Loss 0.32712453603744507 - Acc.: 0.8723, F1: 0.8702, AUROC: 0.9984


100%|██████████| 150/150 [00:25<00:00,  5.82it/s]

[TRAIN] Epoch 1/4 - Subj 29 - Batch 150 - Loss 0.34715327620506287 - Acc.: 0.8730, F1: 0.8710, AUROC: 0.9984



  1%|▏         | 2/150 [00:00<00:29,  5.10it/s]

[TRAIN] Epoch 1/4 - Subj 30 - Batch 0 - Loss 0.26360511779785156 - Acc.: 0.8799, F1: 0.8787, AUROC: 0.9987


 18%|█▊        | 27/150 [00:04<00:20,  6.10it/s]

[TRAIN] Epoch 1/4 - Subj 30 - Batch 25 - Loss 0.25695520639419556 - Acc.: 0.8955, F1: 0.8930, AUROC: 0.9988


 34%|███▍      | 51/150 [00:08<00:17,  5.61it/s]

[TRAIN] Epoch 1/4 - Subj 30 - Batch 50 - Loss 0.2765048146247864 - Acc.: 0.8964, F1: 0.8944, AUROC: 0.9988


 51%|█████▏    | 77/150 [00:13<00:11,  6.09it/s]

[TRAIN] Epoch 1/4 - Subj 30 - Batch 75 - Loss 0.2727664113044739 - Acc.: 0.8959, F1: 0.8936, AUROC: 0.9988


 68%|██████▊   | 102/150 [00:17<00:07,  6.43it/s]

[TRAIN] Epoch 1/4 - Subj 30 - Batch 100 - Loss 0.2489060014486313 - Acc.: 0.8960, F1: 0.8936, AUROC: 0.9988


 85%|████████▍ | 127/150 [00:21<00:03,  5.88it/s]

[TRAIN] Epoch 1/4 - Subj 30 - Batch 125 - Loss 0.2449025809764862 - Acc.: 0.8970, F1: 0.8947, AUROC: 0.9988


100%|██████████| 150/150 [00:25<00:00,  5.78it/s]

[TRAIN] Epoch 1/4 - Subj 30 - Batch 150 - Loss 0.2628341615200043 - Acc.: 0.8977, F1: 0.8954, AUROC: 0.9988



  1%|▏         | 2/150 [00:00<00:28,  5.18it/s]

[TRAIN] Epoch 1/4 - Subj 31 - Batch 0 - Loss 0.2685057520866394 - Acc.: 0.8812, F1: 0.8741, AUROC: 0.9987


 18%|█▊        | 27/150 [00:04<00:20,  6.07it/s]

[TRAIN] Epoch 1/4 - Subj 31 - Batch 25 - Loss 0.28321805596351624 - Acc.: 0.8730, F1: 0.8677, AUROC: 0.9985


 35%|███▍      | 52/150 [00:08<00:15,  6.32it/s]

[TRAIN] Epoch 1/4 - Subj 31 - Batch 50 - Loss 0.32293617725372314 - Acc.: 0.8733, F1: 0.8681, AUROC: 0.9985


 51%|█████▏    | 77/150 [00:13<00:12,  5.98it/s]

[TRAIN] Epoch 1/4 - Subj 31 - Batch 75 - Loss 0.3190813362598419 - Acc.: 0.8734, F1: 0.8684, AUROC: 0.9985


 67%|██████▋   | 101/150 [00:17<00:08,  6.06it/s]

[TRAIN] Epoch 1/4 - Subj 31 - Batch 100 - Loss 0.31397950649261475 - Acc.: 0.8739, F1: 0.8689, AUROC: 0.9986


 85%|████████▍ | 127/150 [00:21<00:03,  6.16it/s]

[TRAIN] Epoch 1/4 - Subj 31 - Batch 125 - Loss 0.3113219738006592 - Acc.: 0.8742, F1: 0.8693, AUROC: 0.9986


100%|██████████| 150/150 [00:25<00:00,  5.85it/s]

[TRAIN] Epoch 1/4 - Subj 31 - Batch 150 - Loss 0.2792128622531891 - Acc.: 0.8747, F1: 0.8699, AUROC: 0.9986



  1%|▏         | 2/150 [00:00<00:29,  5.00it/s]

[TRAIN] Epoch 1/4 - Subj 32 - Batch 0 - Loss 0.26816031336784363 - Acc.: 0.8949, F1: 0.8865, AUROC: 0.9989


 18%|█▊        | 27/150 [00:04<00:20,  6.14it/s]

[TRAIN] Epoch 1/4 - Subj 32 - Batch 25 - Loss 0.2997334897518158 - Acc.: 0.8789, F1: 0.8783, AUROC: 0.9986


 35%|███▍      | 52/150 [00:08<00:15,  6.48it/s]

[TRAIN] Epoch 1/4 - Subj 32 - Batch 50 - Loss 0.27766475081443787 - Acc.: 0.8792, F1: 0.8785, AUROC: 0.9986


 51%|█████▏    | 77/150 [00:13<00:12,  5.96it/s]

[TRAIN] Epoch 1/4 - Subj 32 - Batch 75 - Loss 0.32286256551742554 - Acc.: 0.8785, F1: 0.8777, AUROC: 0.9986


 68%|██████▊   | 102/150 [00:17<00:07,  6.41it/s]

[TRAIN] Epoch 1/4 - Subj 32 - Batch 100 - Loss 0.27884534001350403 - Acc.: 0.8800, F1: 0.8793, AUROC: 0.9986


 85%|████████▍ | 127/150 [00:21<00:04,  5.60it/s]

[TRAIN] Epoch 1/4 - Subj 32 - Batch 125 - Loss 0.3137739896774292 - Acc.: 0.8798, F1: 0.8791, AUROC: 0.9986


100%|██████████| 150/150 [00:25<00:00,  5.84it/s]

[TRAIN] Epoch 1/4 - Subj 32 - Batch 150 - Loss 0.29334473609924316 - Acc.: 0.8799, F1: 0.8790, AUROC: 0.9986



  1%|▏         | 2/150 [00:00<00:28,  5.20it/s]

[TRAIN] Epoch 1/4 - Subj 33 - Batch 0 - Loss 0.290261447429657 - Acc.: 0.8884, F1: 0.8840, AUROC: 0.9987


 18%|█▊        | 27/150 [00:04<00:20,  6.13it/s]

[TRAIN] Epoch 1/4 - Subj 33 - Batch 25 - Loss 0.3090709447860718 - Acc.: 0.8883, F1: 0.8874, AUROC: 0.9988


 35%|███▍      | 52/150 [00:08<00:15,  6.47it/s]

[TRAIN] Epoch 1/4 - Subj 33 - Batch 50 - Loss 0.3235216438770294 - Acc.: 0.8853, F1: 0.8843, AUROC: 0.9987


 51%|█████▏    | 77/150 [00:13<00:12,  5.93it/s]

[TRAIN] Epoch 1/4 - Subj 33 - Batch 75 - Loss 0.30610957741737366 - Acc.: 0.8866, F1: 0.8856, AUROC: 0.9987


 68%|██████▊   | 102/150 [00:17<00:07,  6.39it/s]

[TRAIN] Epoch 1/4 - Subj 33 - Batch 100 - Loss 0.2859106659889221 - Acc.: 0.8863, F1: 0.8852, AUROC: 0.9987


 85%|████████▍ | 127/150 [00:21<00:03,  5.97it/s]

[TRAIN] Epoch 1/4 - Subj 33 - Batch 125 - Loss 0.29612571001052856 - Acc.: 0.8868, F1: 0.8858, AUROC: 0.9987


100%|██████████| 150/150 [00:25<00:00,  5.87it/s]

[TRAIN] Epoch 1/4 - Subj 33 - Batch 150 - Loss 0.2823062539100647 - Acc.: 0.8869, F1: 0.8859, AUROC: 0.9987



  1%|▏         | 2/150 [00:00<00:28,  5.21it/s]

[TRAIN] Epoch 1/4 - Subj 34 - Batch 0 - Loss 0.2853623032569885 - Acc.: 0.8900, F1: 0.8838, AUROC: 0.9987


 18%|█▊        | 27/150 [00:04<00:20,  6.14it/s]

[TRAIN] Epoch 1/4 - Subj 34 - Batch 25 - Loss 0.28117769956588745 - Acc.: 0.8804, F1: 0.8778, AUROC: 0.9986


 35%|███▍      | 52/150 [00:08<00:15,  6.51it/s]

[TRAIN] Epoch 1/4 - Subj 34 - Batch 50 - Loss 0.321466863155365 - Acc.: 0.8802, F1: 0.8778, AUROC: 0.9986


 51%|█████▏    | 77/150 [00:13<00:12,  5.99it/s]

[TRAIN] Epoch 1/4 - Subj 34 - Batch 75 - Loss 0.2863115966320038 - Acc.: 0.8821, F1: 0.8800, AUROC: 0.9986


 68%|██████▊   | 102/150 [00:17<00:07,  6.53it/s]

[TRAIN] Epoch 1/4 - Subj 34 - Batch 100 - Loss 0.2756502330303192 - Acc.: 0.8829, F1: 0.8810, AUROC: 0.9987


 85%|████████▍ | 127/150 [00:21<00:03,  6.13it/s]

[TRAIN] Epoch 1/4 - Subj 34 - Batch 125 - Loss 0.3266022205352783 - Acc.: 0.8829, F1: 0.8810, AUROC: 0.9987


100%|██████████| 150/150 [00:25<00:00,  5.97it/s]

[TRAIN] Epoch 1/4 - Subj 34 - Batch 150 - Loss 0.29265883564949036 - Acc.: 0.8825, F1: 0.8807, AUROC: 0.9986


100%|██████████| 150/150 [00:25<00:00,  5.83it/s]
  1%|▏         | 2/150 [00:00<00:28,  5.19it/s]

[TRAIN] Epoch 1/4 - Subj 35 - Batch 0 - Loss 0.27445125579833984 - Acc.: 0.8929, F1: 0.8874, AUROC: 0.9989


 18%|█▊        | 27/150 [00:04<00:20,  6.12it/s]

[TRAIN] Epoch 1/4 - Subj 35 - Batch 25 - Loss 0.3208111822605133 - Acc.: 0.8870, F1: 0.8854, AUROC: 0.9987


 35%|███▍      | 52/150 [00:08<00:17,  5.50it/s]

[TRAIN] Epoch 1/4 - Subj 35 - Batch 50 - Loss 0.2923020124435425 - Acc.: 0.8872, F1: 0.8856, AUROC: 0.9987


 51%|█████▏    | 77/150 [00:13<00:11,  6.08it/s]

[TRAIN] Epoch 1/4 - Subj 35 - Batch 75 - Loss 0.29319044947624207 - Acc.: 0.8883, F1: 0.8867, AUROC: 0.9987


 68%|██████▊   | 102/150 [00:17<00:07,  6.51it/s]

[TRAIN] Epoch 1/4 - Subj 35 - Batch 100 - Loss 0.309912770986557 - Acc.: 0.8880, F1: 0.8865, AUROC: 0.9987


 85%|████████▍ | 127/150 [00:21<00:03,  6.12it/s]

[TRAIN] Epoch 1/4 - Subj 35 - Batch 125 - Loss 0.23384030163288116 - Acc.: 0.8874, F1: 0.8859, AUROC: 0.9987


100%|██████████| 150/150 [00:25<00:00,  5.83it/s]

[TRAIN] Epoch 1/4 - Subj 35 - Batch 150 - Loss 0.2537730932235718 - Acc.: 0.8875, F1: 0.8861, AUROC: 0.9987



  1%|▏         | 2/150 [00:00<00:28,  5.15it/s]

[TRAIN] Epoch 1/4 - Subj 36 - Batch 0 - Loss 0.2769280970096588 - Acc.: 0.9014, F1: 0.8976, AUROC: 0.9988


 18%|█▊        | 27/150 [00:04<00:20,  6.08it/s]

[TRAIN] Epoch 1/4 - Subj 36 - Batch 25 - Loss 0.2791823148727417 - Acc.: 0.8890, F1: 0.8875, AUROC: 0.9986


 34%|███▍      | 51/150 [00:08<00:16,  6.01it/s]

[TRAIN] Epoch 1/4 - Subj 36 - Batch 50 - Loss 0.2587010860443115 - Acc.: 0.8909, F1: 0.8896, AUROC: 0.9987


 51%|█████▏    | 77/150 [00:13<00:12,  6.06it/s]

[TRAIN] Epoch 1/4 - Subj 36 - Batch 75 - Loss 0.2589242458343506 - Acc.: 0.8912, F1: 0.8898, AUROC: 0.9987


 68%|██████▊   | 102/150 [00:17<00:07,  6.38it/s]

[TRAIN] Epoch 1/4 - Subj 36 - Batch 100 - Loss 0.2892121970653534 - Acc.: 0.8908, F1: 0.8893, AUROC: 0.9987


 85%|████████▍ | 127/150 [00:21<00:03,  5.89it/s]

[TRAIN] Epoch 1/4 - Subj 36 - Batch 125 - Loss 0.284710556268692 - Acc.: 0.8909, F1: 0.8894, AUROC: 0.9987


100%|██████████| 150/150 [00:25<00:00,  5.79it/s]

[TRAIN] Epoch 1/4 - Subj 36 - Batch 150 - Loss 0.29065948724746704 - Acc.: 0.8907, F1: 0.8892, AUROC: 0.9987



  1%|▏         | 2/150 [00:00<00:29,  5.07it/s]

[TRAIN] Epoch 1/4 - Subj 37 - Batch 0 - Loss 0.2755090296268463 - Acc.: 0.8853, F1: 0.8834, AUROC: 0.9989


 18%|█▊        | 27/150 [00:04<00:20,  6.08it/s]

[TRAIN] Epoch 1/4 - Subj 37 - Batch 25 - Loss 0.30410778522491455 - Acc.: 0.8878, F1: 0.8858, AUROC: 0.9986


 35%|███▍      | 52/150 [00:08<00:15,  6.50it/s]

[TRAIN] Epoch 1/4 - Subj 37 - Batch 50 - Loss 0.3483922779560089 - Acc.: 0.8879, F1: 0.8863, AUROC: 0.9986


 51%|█████▏    | 77/150 [00:13<00:12,  5.95it/s]

[TRAIN] Epoch 1/4 - Subj 37 - Batch 75 - Loss 0.29022014141082764 - Acc.: 0.8874, F1: 0.8858, AUROC: 0.9986


 68%|██████▊   | 102/150 [00:17<00:07,  6.51it/s]

[TRAIN] Epoch 1/4 - Subj 37 - Batch 100 - Loss 0.2767517566680908 - Acc.: 0.8879, F1: 0.8863, AUROC: 0.9986


 85%|████████▍ | 127/150 [00:21<00:03,  6.17it/s]

[TRAIN] Epoch 1/4 - Subj 37 - Batch 125 - Loss 0.3305405378341675 - Acc.: 0.8877, F1: 0.8862, AUROC: 0.9986


100%|██████████| 150/150 [00:25<00:00,  5.84it/s]

[TRAIN] Epoch 1/4 - Subj 37 - Batch 150 - Loss 0.27656176686286926 - Acc.: 0.8885, F1: 0.8870, AUROC: 0.9986



  1%|▏         | 2/150 [00:00<00:28,  5.15it/s]

[TRAIN] Epoch 1/4 - Subj 38 - Batch 0 - Loss 0.24346576631069183 - Acc.: 0.8815, F1: 0.8755, AUROC: 0.9991


 18%|█▊        | 27/150 [00:04<00:20,  6.10it/s]

[TRAIN] Epoch 1/4 - Subj 38 - Batch 25 - Loss 0.35579410195350647 - Acc.: 0.8847, F1: 0.8839, AUROC: 0.9988


 35%|███▍      | 52/150 [00:08<00:15,  6.46it/s]

[TRAIN] Epoch 1/4 - Subj 38 - Batch 50 - Loss 0.28967002034187317 - Acc.: 0.8855, F1: 0.8845, AUROC: 0.9987


 51%|█████▏    | 77/150 [00:13<00:12,  5.83it/s]

[TRAIN] Epoch 1/4 - Subj 38 - Batch 75 - Loss 0.29979726672172546 - Acc.: 0.8872, F1: 0.8862, AUROC: 0.9987


 68%|██████▊   | 102/150 [00:17<00:07,  6.39it/s]

[TRAIN] Epoch 1/4 - Subj 38 - Batch 100 - Loss 0.26240602135658264 - Acc.: 0.8874, F1: 0.8863, AUROC: 0.9987


 85%|████████▍ | 127/150 [00:21<00:03,  5.88it/s]

[TRAIN] Epoch 1/4 - Subj 38 - Batch 125 - Loss 0.28293338418006897 - Acc.: 0.8875, F1: 0.8866, AUROC: 0.9987


100%|██████████| 150/150 [00:25<00:00,  5.79it/s]

[TRAIN] Epoch 1/4 - Subj 38 - Batch 150 - Loss 0.2822864055633545 - Acc.: 0.8872, F1: 0.8864, AUROC: 0.9987



  1%|          | 1/150 [00:00<00:40,  3.66it/s]

[TRAIN] Epoch 1/4 - Subj 39 - Batch 0 - Loss 0.28626400232315063 - Acc.: 0.8868, F1: 0.8815, AUROC: 0.9987


 18%|█▊        | 27/150 [00:04<00:20,  6.11it/s]

[TRAIN] Epoch 1/4 - Subj 39 - Batch 25 - Loss 0.2902900278568268 - Acc.: 0.8741, F1: 0.8733, AUROC: 0.9985


 35%|███▍      | 52/150 [00:08<00:15,  6.45it/s]

[TRAIN] Epoch 1/4 - Subj 39 - Batch 50 - Loss 0.3140396177768707 - Acc.: 0.8729, F1: 0.8722, AUROC: 0.9985


 51%|█████▏    | 77/150 [00:13<00:12,  5.75it/s]

[TRAIN] Epoch 1/4 - Subj 39 - Batch 75 - Loss 0.313055157661438 - Acc.: 0.8743, F1: 0.8734, AUROC: 0.9985


 68%|██████▊   | 102/150 [00:17<00:07,  6.32it/s]

[TRAIN] Epoch 1/4 - Subj 39 - Batch 100 - Loss 0.3258112967014313 - Acc.: 0.8741, F1: 0.8730, AUROC: 0.9985


 85%|████████▍ | 127/150 [00:21<00:03,  5.82it/s]

[TRAIN] Epoch 1/4 - Subj 39 - Batch 125 - Loss 0.27380937337875366 - Acc.: 0.8735, F1: 0.8723, AUROC: 0.9985


100%|██████████| 150/150 [00:25<00:00,  5.81it/s]

[TRAIN] Epoch 1/4 - Subj 39 - Batch 150 - Loss 0.30003049969673157 - Acc.: 0.8724, F1: 0.8713, AUROC: 0.9985



  1%|▏         | 2/150 [00:00<00:28,  5.16it/s]

[TRAIN] Epoch 1/4 - Subj 40 - Batch 0 - Loss 0.2846922278404236 - Acc.: 0.8853, F1: 0.8799, AUROC: 0.9986


 18%|█▊        | 27/150 [00:04<00:20,  5.87it/s]

[TRAIN] Epoch 1/4 - Subj 40 - Batch 25 - Loss 0.2587887644767761 - Acc.: 0.8836, F1: 0.8822, AUROC: 0.9986


 34%|███▍      | 51/150 [00:08<00:16,  6.10it/s]

[TRAIN] Epoch 1/4 - Subj 40 - Batch 50 - Loss 0.3258083462715149 - Acc.: 0.8826, F1: 0.8814, AUROC: 0.9986


 51%|█████▏    | 77/150 [00:13<00:11,  6.13it/s]

[TRAIN] Epoch 1/4 - Subj 40 - Batch 75 - Loss 0.28984981775283813 - Acc.: 0.8815, F1: 0.8800, AUROC: 0.9985


 67%|██████▋   | 101/150 [00:17<00:08,  6.04it/s]

[TRAIN] Epoch 1/4 - Subj 40 - Batch 100 - Loss 0.29572054743766785 - Acc.: 0.8817, F1: 0.8801, AUROC: 0.9985


 85%|████████▍ | 127/150 [00:21<00:03,  5.96it/s]

[TRAIN] Epoch 1/4 - Subj 40 - Batch 125 - Loss 0.26515647768974304 - Acc.: 0.8818, F1: 0.8802, AUROC: 0.9986


100%|██████████| 150/150 [00:25<00:00,  5.82it/s]

[TRAIN] Epoch 1/4 - Subj 40 - Batch 150 - Loss 0.30241501331329346 - Acc.: 0.8814, F1: 0.8799, AUROC: 0.9985



  1%|          | 1/150 [00:00<00:41,  3.61it/s]

[TRAIN] Epoch 1/4 - Subj 41 - Batch 0 - Loss 0.277883917093277 - Acc.: 0.9002, F1: 0.8972, AUROC: 0.9987


 18%|█▊        | 27/150 [00:04<00:19,  6.15it/s]

[TRAIN] Epoch 1/4 - Subj 41 - Batch 25 - Loss 0.3223961591720581 - Acc.: 0.8859, F1: 0.8845, AUROC: 0.9986


 34%|███▍      | 51/150 [00:08<00:16,  6.09it/s]

[TRAIN] Epoch 1/4 - Subj 41 - Batch 50 - Loss 0.312569260597229 - Acc.: 0.8853, F1: 0.8840, AUROC: 0.9986


 51%|█████▏    | 77/150 [00:13<00:11,  6.11it/s]

[TRAIN] Epoch 1/4 - Subj 41 - Batch 75 - Loss 0.4030589163303375 - Acc.: 0.8840, F1: 0.8825, AUROC: 0.9986


 68%|██████▊   | 102/150 [00:17<00:07,  6.40it/s]

[TRAIN] Epoch 1/4 - Subj 41 - Batch 100 - Loss 0.2735387980937958 - Acc.: 0.8843, F1: 0.8828, AUROC: 0.9985


 85%|████████▍ | 127/150 [00:21<00:03,  5.88it/s]

[TRAIN] Epoch 1/4 - Subj 41 - Batch 125 - Loss 0.3123050928115845 - Acc.: 0.8848, F1: 0.8832, AUROC: 0.9986


100%|██████████| 150/150 [00:25<00:00,  5.82it/s]

[TRAIN] Epoch 1/4 - Subj 41 - Batch 150 - Loss 0.26619628071784973 - Acc.: 0.8843, F1: 0.8827, AUROC: 0.9986



  1%|          | 1/150 [00:00<00:42,  3.54it/s]

[TRAIN] Epoch 1/4 - Subj 42 - Batch 0 - Loss 0.2839941680431366 - Acc.: 0.8855, F1: 0.8817, AUROC: 0.9988


 18%|█▊        | 27/150 [00:04<00:20,  6.09it/s]

[TRAIN] Epoch 1/4 - Subj 42 - Batch 25 - Loss 0.25832560658454895 - Acc.: 0.8912, F1: 0.8903, AUROC: 0.9988


 34%|███▍      | 51/150 [00:08<00:16,  6.04it/s]

[TRAIN] Epoch 1/4 - Subj 42 - Batch 50 - Loss 0.2867819368839264 - Acc.: 0.8945, F1: 0.8934, AUROC: 0.9989


 51%|█████▏    | 77/150 [00:13<00:11,  6.11it/s]

[TRAIN] Epoch 1/4 - Subj 42 - Batch 75 - Loss 0.2347380369901657 - Acc.: 0.8941, F1: 0.8931, AUROC: 0.9989


 68%|██████▊   | 102/150 [00:17<00:07,  6.42it/s]

[TRAIN] Epoch 1/4 - Subj 42 - Batch 100 - Loss 0.2566666603088379 - Acc.: 0.8947, F1: 0.8939, AUROC: 0.9989


 85%|████████▍ | 127/150 [00:21<00:03,  5.89it/s]

[TRAIN] Epoch 1/4 - Subj 42 - Batch 125 - Loss 0.2748925983905792 - Acc.: 0.8951, F1: 0.8941, AUROC: 0.9989


100%|██████████| 150/150 [00:25<00:00,  5.81it/s]

[TRAIN] Epoch 1/4 - Subj 42 - Batch 150 - Loss 0.2605270743370056 - Acc.: 0.8948, F1: 0.8938, AUROC: 0.9989



  1%|          | 1/150 [00:00<00:41,  3.62it/s]

[TRAIN] Epoch 1/4 - Subj 43 - Batch 0 - Loss 0.25727301836013794 - Acc.: 0.8907, F1: 0.8806, AUROC: 0.9990


 18%|█▊        | 27/150 [00:04<00:20,  6.02it/s]

[TRAIN] Epoch 1/4 - Subj 43 - Batch 25 - Loss 0.24755068123340607 - Acc.: 0.8964, F1: 0.8952, AUROC: 0.9988


 35%|███▍      | 52/150 [00:08<00:15,  6.40it/s]

[TRAIN] Epoch 1/4 - Subj 43 - Batch 50 - Loss 0.28642553091049194 - Acc.: 0.8950, F1: 0.8939, AUROC: 0.9988


 51%|█████▏    | 77/150 [00:13<00:12,  5.83it/s]

[TRAIN] Epoch 1/4 - Subj 43 - Batch 75 - Loss 0.25147104263305664 - Acc.: 0.8941, F1: 0.8929, AUROC: 0.9988


 68%|██████▊   | 102/150 [00:17<00:07,  6.30it/s]

[TRAIN] Epoch 1/4 - Subj 43 - Batch 100 - Loss 0.2744048237800598 - Acc.: 0.8943, F1: 0.8929, AUROC: 0.9988


 85%|████████▍ | 127/150 [00:21<00:03,  5.78it/s]

[TRAIN] Epoch 1/4 - Subj 43 - Batch 125 - Loss 0.2433452606201172 - Acc.: 0.8947, F1: 0.8933, AUROC: 0.9988


 93%|█████████▎| 139/150 [00:23<00:01,  5.83it/s]


KeyboardInterrupt: 

In [None]:


#     # Fase de validación del modelo
#     model.eval()
    
#     for idx_val, subject in enumerate(valid_data):
#         val_ds = StreamlineTestDataset(subject, handler, 
#                                         transform = MaxMinNormalization(dataset = cfg.dataset))
        
#         val_dl = DataLoader(val_ds, batch_size = cfg.batch_size , 
#                               shuffle = False, num_workers = 1,
#                               collate_fn = collate_test_ds)
        
#         # Diccionario para guardar los embeddings de los grafos por clase cuyas claves son las etiquetas 0 - 71
#         # Diccionario para guardar los embeddings por clase
#         embeddings_list_by_class = defaultdict(list)
#         max_embeddings_per_class = 100
        

#         with torch.no_grad():
#             for i, graph in enumerate(val_dl):
                
#                 # Enviar a la gpu
#                 graph = graph.to('cuda')
#                 target = graph.y

#                 # Forward pass
#                 embedding, pred = model(graph)# Batch size necesario para calcular la media de los embeddings

#                 # Calcular métricas de validación
#                 subj_accuracy_val.update(pred, target)
#                 subj_f1_val.update(pred, target)
#                 subj_auroc_val.update(pred, target)
#                 subj_confusion_matrix.update(pred, target)

#                 # Guardar los embeddings y etiquetas
#                 if idx_val == 0:

#                     embedding = embedding.cpu().numpy()
#                     target = graph.y.cpu().numpy()

#                     # Guardar embeddings y etiquetas en el diccionario por clase
#                     for emb, label in zip(embedding, target):
#                         if len(embeddings_list_by_class[label]) < max_embeddings_per_class:
#                             embeddings_list_by_class[label].append(emb)
                    

#                 if log:# Loggear las métricas de batch
#                     run["val/batch/acc"].log(subj_accuracy_val.compute().item())
#                     run["val/batch/f1"].log(subj_f1_val.compute().item())
#                     run["val/batch/auroc"].log(subj_auroc_val.compute().item())

#                 if i % 25 == 0:
#                     print(f"[VAL] Epoch {epoch+1}/{cfg.max_epochs} - Subj {idx_val} - Batch {i} - Acc.: {subj_accuracy_val.compute().item():.4f}, F1: {subj_f1_val.compute().item():.4f}, AUROC: {subj_auroc_val.compute().item():.4f}")

#             # Para el idx_val = 1, guardar los embeddings de los grafos con sus etiquetas para visualizarlos en TensorBoard Projector
#             if idx_val == 0:

#                 # Preparar los datos para TensorBoard
#                 all_embeddings = []
#                 all_labels = []

#                 for label, embeddings in embeddings_list_by_class.items():
#                     all_embeddings.extend(embeddings)
#                     # Obtener la etiqueta textual de la clase a través del handler
#                     label = handler.get_tract_from_label(label)
#                     print(label)
#                     all_labels.extend([label] * len(embeddings))

#                 # Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /opt/conda/conda-bld/pytorch_1704987394225/work/torch/csrc/utils/tensor_new.cpp:275.)
#                 # Convertir listas a numpy.ndarray
#                 all_embeddings = np.array(all_embeddings)
#                 all_labels = np.array(all_labels)


#                 # Convertir a tensores
#                 # Convertir a tensores
#                 all_embeddings = torch.tensor(all_embeddings)
#                 all_labels = all_labels.tolist()  # TensorBoard necesita etiquetas como lista de strings


#                 # Guardar los embeddings y etiquetas en TensorBoard
#                 writer.add_embedding(
#                     all_embeddings, 
#                     metadata = all_labels, 
#                     global_step = epoch,
#                     tag = f"{cfg.dataset}_{cfg.encoder}_{cfg.embedding_projection_dim}_{time.time()}"
#                 ) 

#             if log:# Loggear las métricas de subject
#                 run["val/subject/acc"].log(subj_accuracy_val.compute().item())
#                 run["val/subject/f1"].log(subj_f1_val.compute().item())
#                 run["val/subject/auroc"].log(subj_auroc_val.compute().item())

#                 cm = subj_confusion_matrix.compute()
#                 # Convertir la matriz de confusión a numpy
#                 cm = cm.cpu().numpy()

#                 # Visualiza la matriz de confusión y guárdala como imagen
#                 plt.figure(figsize=(35, 35))
#                 sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')

#                 text_labels = [handler.get_tract_from_label(i) for i in range(cfg.n_classes)]
#                 plt.xticks(ticks = range(cfg.n_classes), labels = text_labels, rotation = 90)
#                 plt.yticks(ticks = range(cfg.n_classes), labels = text_labels, rotation = 0)
#                 plt.xlabel('Predicted Labels')
#                 plt.ylabel('True Labels')
#                 plt.title(f'Confusion Matrix Subj {idx_val}')
#                 plt.tight_layout()

#                 # Guarda la imagen
#                 img_path = f'/app/confusion_matrix_imgs/confusion_matrix_val_suj{idx_val}.png'
#                 plt.savefig(img_path)
#                 plt.close()

#                 # Sube la imagen a Neptune
#                 run["confusion_matrix_fig"].upload(img_path)
            
#             subj_accuracy_val.reset()
#             subj_f1_val.reset()
#             subj_auroc_val.reset()
#             subj_confusion_matrix.reset()

#         if idx_val == 6:
#             break

# # Cerrar el SummaryWriter
# writer.close()

                   


Ejemplo de configuracion del clasificador preentrenado 

In [7]:
# Autor: Pablo Rocamora

import torch
import torch.nn as nn
import torch.nn.functional as F
from encoders import (
    ClassifierHead,
    GATEncoder,
    GCNEncoder,
    ProjectionHead,
    SiameseGraphNetwork
)

from torch_geometric.nn import GCNConv, global_mean_pool, BatchNorm

from encoders import (
    ClassifierHead,
    GCNEncoder,
    ProjectionHead,
)


class GraphConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropout=0.0):
        super(GraphConvBlock, self).__init__()
        self.conv = GCNConv(in_channels, out_channels)
        self.bn = BatchNorm(out_channels)
        self.relu = nn.LeakyReLU()
        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else None

    def forward(self, x, edge_index):
        x = self.conv(x, edge_index)
        x = self.relu(x)
        x = self.bn(x)
        if self.training:  # Aplicación de Dropout solo durante el entrenamiento
            x = self.dropout(x)
        
        return x
    

class GCNEncoder(nn.Module):
    def __init__(self, in_channels, hidden_dim, out_channels, dropout, n_hidden_blocks):
        super(GCNEncoder, self).__init__()
        self.input_block = GraphConvBlock(in_channels, hidden_dim, dropout)
        self.hidden_blocks = nn.ModuleList([GraphConvBlock(hidden_dim, hidden_dim, dropout) for _ in range(n_hidden_blocks - 1)])
        self.output_block = GraphConvBlock(hidden_dim, out_channels, dropout)
        # self.bn = BatchNorm(out_channels)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.input_block(x, edge_index)
        for layer in self.hidden_blocks:
            x = layer(x, edge_index)
        x = self.output_block(x, edge_index)
        # x = self.bn(x)

        return global_mean_pool(x, batch) # (batch_size, out_channels)


class ProjectionHead(nn.Module):
    """
    Proyección de las embeddings de texto a un espacio de dimensión reducida.
    """
    def __init__(
        self,
        embedding_dim,# Salida del modelo de lenguaje (768)
        projection_dim, # Dimensión de la proyección (256)
        # dropout=0.1
    ):
        super(ProjectionHead, self).__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        # self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        # x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x
    

class SiameseContrastiveGraphNetwork(nn.Module):
    def __init__(self, encoder, projection_head):
        super(SiameseContrastiveGraphNetwork, self).__init__()
        self.encoder = encoder
        self.projection_head = projection_head

    def forward(self, graph):
        x_1 = self.encoder(graph)
        x_1 = self.projection_head(x_1)
        return x_1

# ========================================================================


In [14]:

model = SiameseContrastiveGraphNetwork(
    encoder = GCNEncoder(
        in_channels = 3, 
        hidden_dim = 64, 
        out_channels = 128, 
        dropout = 0.5, 
        n_hidden_blocks = 4
    ),

    projection_head = ProjectionHead(
        embedding_dim = 128, 
        projection_dim = 64
    )
)

model = torch.compile(model, dynamic=True)

# Cargar pesos preentrenados
checkpoint = torch.load('/app/trained_models/checkpoint_HCP_105_GCN_512_5_infonce_1723148869.1645732.pth')

model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [15]:
model

OptimizedModule(
  (_orig_mod): SiameseContrastiveGraphNetwork(
    (encoder): GCNEncoder(
      (input_block): GraphConvBlock(
        (conv): GCNConv(3, 64)
        (bn): BatchNorm(64)
        (relu): LeakyReLU(negative_slope=0.01)
        (dropout): Dropout(p=0.5, inplace=False)
      )
      (hidden_blocks): ModuleList(
        (0-2): 3 x GraphConvBlock(
          (conv): GCNConv(64, 64)
          (bn): BatchNorm(64)
          (relu): LeakyReLU(negative_slope=0.01)
          (dropout): Dropout(p=0.5, inplace=False)
        )
      )
      (output_block): GraphConvBlock(
        (conv): GCNConv(64, 128)
        (bn): BatchNorm(128)
        (relu): LeakyReLU(negative_slope=0.01)
        (dropout): Dropout(p=0.5, inplace=False)
      )
    )
    (projection_head): ProjectionHead(
      (projection): Linear(in_features=128, out_features=64, bias=True)
      (gelu): GELU(approximate='none')
      (fc): Linear(in_features=64, out_features=64, bias=True)
      (layer_norm): LayerNorm((64,

In [17]:
# Delete dropout layers and freeze the weights
model.encoder.input_block.dropout = None
model.encoder.hidden_blocks[0].dropout = None
model.encoder.hidden_blocks[1].dropout = None
model.encoder.hidden_blocks[2].dropout = None

model.encoder.output_block.dropout = None

# Delete projection head
model = model.encoder

for param in model.parameters():
    param.requires_grad = False

model    

GCNEncoder(
  (input_block): GraphConvBlock(
    (conv): GCNConv(3, 64)
    (bn): BatchNorm(64)
    (relu): LeakyReLU(negative_slope=0.01)
    (dropout): None
  )
  (hidden_blocks): ModuleList(
    (0-2): 3 x GraphConvBlock(
      (conv): GCNConv(64, 64)
      (bn): BatchNorm(64)
      (relu): LeakyReLU(negative_slope=0.01)
      (dropout): None
    )
  )
  (output_block): GraphConvBlock(
    (conv): GCNConv(64, 128)
    (bn): BatchNorm(128)
    (relu): LeakyReLU(negative_slope=0.01)
    (dropout): None
  )
)

In [19]:

class GraphClassifier(nn.Module):
    def __init__(self, encoder, n_classes):
        super(GraphClassifier, self).__init__()
        self.encoder = encoder

        self.classifier = ClassifierHead(
            projection_dim = 128, 
            n_classes = n_classes
        )   

    def forward(self, graph):
        x = self.encoder(graph)
        x = self.classifier(x)
        return x

model = GraphClassifier(model, 72)

model

GraphClassifier(
  (encoder): GCNEncoder(
    (input_block): GraphConvBlock(
      (conv): GCNConv(3, 64)
      (bn): BatchNorm(64)
      (relu): LeakyReLU(negative_slope=0.01)
      (dropout): None
    )
    (hidden_blocks): ModuleList(
      (0-2): 3 x GraphConvBlock(
        (conv): GCNConv(64, 64)
        (bn): BatchNorm(64)
        (relu): LeakyReLU(negative_slope=0.01)
        (dropout): None
      )
    )
    (output_block): GraphConvBlock(
      (conv): GCNConv(64, 128)
      (bn): BatchNorm(128)
      (relu): LeakyReLU(negative_slope=0.01)
      (dropout): None
    )
  )
  (classifier): ClassifierHead(
    (fc): Linear(in_features=128, out_features=72, bias=True)
    (softmax): LogSoftmax(dim=1)
  )
)