In [1]:
# # For COLAB
# from google.colab import drive
# drive.mount('/content/gdrive')

In [1]:
# !pip install -Uq kaggle torchsummary matplotlib tensorboard scikit-learn
# !mkdir ~/.kaggle
# !cp kaggle.json ~/.kaggle/
# !chmod 600 ~/.kaggle/kaggle.json
# ! kaggle datasets download xhlulu/140k-real-and-fake-faces

# import shutil
# shutil.unpack_archive('140k-real-and-fake-faces.zip')

Dataset URL: https://www.kaggle.com/datasets/xhlulu/140k-real-and-fake-faces
License(s): other


In [2]:
import os
import numpy as np
import datetime
from copy import deepcopy
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision.utils import save_image
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchsummary import summary
torch.manual_seed(69)
%load_ext tensorboard

In [4]:
# Hyperparameters
CHECKPOINT_GEN = "wgan_generator_128_2.pth"
CHECKPOINT_CRITIC = "wgan_critic_128_2.pth"
LOG_FOLDER = "logs/"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_EPOCHS = 10
BATCH_SIZE = 512
TRAIN_SUBSET_SIZE = 25000
VALID_SUBSET_SIZE = 5000
TEST_SUBSET_SIZE = 5000
LEARNING_RATE = 1e-4
IMAGE_SIZE = 128
CHANNELS_IMG = 3
Z_DIM = 256
FEATURES_CRITIC = 16
FEATURES_GEN = 16
NUM_WORKERS = 2
# os.makedirs("/content/gdrive/MyDrive/generated-image-detection/models/", exist_ok=True)
os.makedirs(LOG_FOLDER, exist_ok=True)

# Data

In [5]:
def create_dataloader(
    data_path,
    subset_size,
    batch_size,
    img_size,
    only_real=True
):
    data_transforms = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Normalize(
           [0.5 for _ in range(3)],
           [0.5 for _ in range(3)]),
    ])

    dataset = datasets.ImageFolder(root=data_path, transform=data_transforms)
    if only_real:
      class_idx = dataset.class_to_idx["real"]
      class_indices = [i for i, label in enumerate(dataset.targets) if label == class_idx]
      subset_indices = class_indices[:subset_size]
    else:
      real_class_idx = dataset.class_to_idx["real"]
      real_class_indices = [
          i for i, label in enumerate(dataset.targets) if label == real_class_idx
      ]
      real_subset_indices = real_class_indices[:subset_size//2]

      fake_class_idx = dataset.class_to_idx["fake"]
      fake_class_indices = [
          i for i, label in enumerate(dataset.targets) if label == fake_class_idx
      ]
      fake_subset_indices = fake_class_indices[:subset_size//2]
      subset_indices = real_subset_indices + fake_subset_indices

    subset_dataset = Subset(dataset, subset_indices)
    dataloader = DataLoader(
       subset_dataset,
       batch_size=batch_size,
       shuffle=True,
       num_workers=NUM_WORKERS,
       pin_memory=True,
       drop_last=True
    )
    return dataloader

In [6]:
train_loader = create_dataloader("real_vs_fake/real-vs-fake/train", TRAIN_SUBSET_SIZE, BATCH_SIZE, IMAGE_SIZE)
valid_loader = create_dataloader("real_vs_fake/real-vs-fake/valid", VALID_SUBSET_SIZE, BATCH_SIZE , IMAGE_SIZE)
test_loader = create_dataloader("real_vs_fake/real-vs-fake/test", TEST_SUBSET_SIZE, BATCH_SIZE , IMAGE_SIZE)

print(f"Training \t Batches: {len(train_loader)} \t Images: {len(train_loader.dataset)}")
print(f"Validation \t Batches: {len(valid_loader)} \t Images: {len(valid_loader.dataset)}")
print(f"Test \t\t Batches: {len(test_loader)} \t Images: {len(test_loader.dataset)}")

Training 	 Batches: 48 	 Images: 25000
Validation 	 Batches: 9 	 Images: 5000
Test 		 Batches: 9 	 Images: 5000


In [7]:
# For generators
fixed_train_vector = torch.randn(len(train_loader), train_loader.batch_size, Z_DIM, 1, 1).to(DEVICE)
fixed_valid_vector = torch.randn(len(valid_loader), valid_loader.batch_size, Z_DIM, 1, 1).to(DEVICE)
fixed_test_vector = torch.randn(len(test_loader), test_loader.batch_size, Z_DIM, 1, 1).to(DEVICE)

# Utils

In [8]:
def save_checkpoint(model, optimizer, epoch=None, loss=None, filename="my_checkpoint.pth.tar"):
    torch.save({
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss,
    }, filename)
    print("=> Model saved")


def load_checkpoint(checkpoint_file, model, optimizer=None, lr=None):
    checkpoint = torch.load(checkpoint_file, map_location="cuda", weights_only=True)
    model.load_state_dict(checkpoint["state_dict"])

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer"])

        # If we don't do this then it will just have learning rate of old checkpoint
        # and it will lead to many hours of debugging \:
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
    print("=> Loaded checkpoint")

## Train Function

In [11]:
import torch
import torch.utils.data
import torch.utils.tensorboard
import torch.optim as optim
import numpy as np
from tqdm import tqdm
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score,
    f1_score, roc_auc_score
)
from typing import Optional, Tuple, List

def calculate_metrics(targets: List, predictions: List, scores: List) -> dict:
    """Calculate and return all evaluation metrics."""
    return {
        'accuracy': accuracy_score(targets, predictions),
        'precision': precision_score(targets, predictions, zero_division=1),
        'recall': recall_score(targets, predictions, zero_division=1),
        'f1': f1_score(targets, predictions, zero_division=1),
        'roc_auc': roc_auc_score(targets, scores)
    }

def process_batch(
    disc: torch.nn.Module,
    real_data: torch.Tensor,
    real_labels: torch.Tensor,
    gen: torch.nn.Module,
    fixed_vector: torch.Tensor,
    criterion: torch.nn.Module,
    device: torch.device
) -> Tuple[torch.Tensor, list, list, list]:
    """Process a single batch of data."""
    torch.cuda.empty_cache()
    # Process real samples
    real_outputs = disc(real_data).squeeze()
    real_loss = criterion(real_outputs, real_labels)

    # Generate and process fake samples
    fake = gen(fixed_vector)
    fake_labels = torch.zeros_like(real_labels).to(device).float()
    fake_outputs = disc(fake).squeeze()
    fake_loss = criterion(fake_outputs, fake_labels)

    # print(real_labels[:5], fake_labels[:5])
    # Combine predictions and targets
    scores = torch.cat((real_outputs, fake_outputs)).cpu().detach().numpy()
    targets = torch.cat((real_labels, fake_labels)).cpu().detach().numpy()
    preds = (scores >= 0.5).astype(int) # Converting to discreet 1s and 0s instead of probability
    # print(preds[:5], scores[:5], targets[:5])
    return real_loss + fake_loss, preds, scores, targets

def train_epoch(
    disc: torch.nn.Module,
    gen: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,
    optimizer: Optional[torch.optim.Optimizer],
    criterion: torch.nn.Module,
    fixed_vector: torch.Tensor,
    device: torch.device,
    phase: str
) -> dict:
    """Run one epoch of training or validation."""
    running_loss = 0.0
    all_preds, all_scores, all_targets = [], [], []

    loop = tqdm(data_loader, leave=True, desc=phase)
    for batch_idx, (real, real_labels) in enumerate(loop):
        real = real.to(device)
        real_labels = real_labels.to(device).float()

        loss, preds, scores, targets = process_batch(
            disc, real, real_labels, gen,
            fixed_vector[batch_idx], criterion, device
        )


        if optimizer is not None:  # Training phase
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Update metrics
        running_loss += loss.item()
        all_preds.extend(preds)
        all_scores.extend(scores)
        all_targets.extend(targets)

        # Update progress bar
        metrics = calculate_metrics(all_targets, all_preds, all_scores)
        metrics['loss'] = running_loss / (batch_idx + 1)
        loop.set_postfix(metrics)

    # Calculate final metrics
    metrics = calculate_metrics(all_targets, all_preds, all_scores)
    metrics['loss'] = running_loss / len(data_loader)
    return metrics

def log_metrics(
    writer: torch.utils.tensorboard.SummaryWriter,
    train_metrics: dict,
    valid_metrics: dict,
    epoch: int
) -> None:
    """Log metrics to TensorBoard."""
    for metric_name in train_metrics:
        writer.add_scalar(f"{metric_name}/train", train_metrics[metric_name], global_step=epoch)
        writer.add_scalar(f"{metric_name}/valid", valid_metrics[metric_name], global_step=epoch)

def train_discriminator(
    disc: torch.nn.Module,
    gen: torch.nn.Module,
    writer: torch.utils.tensorboard.SummaryWriter,
    train_loader: torch.utils.data.DataLoader,
    valid_loader: torch.utils.data.DataLoader,
    fixed_train_vector: torch.Tensor,
    fixed_valid_vector: torch.Tensor,
    num_epochs: int = 100,
    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    criterion: torch.nn.Module = torch.nn.BCELoss(),
    optimizer: Optional[torch.optim.Optimizer] = None,
    checkpoint_file: str = "discriminator_checkpoint.pth"
) -> Tuple[torch.nn.Module, List[float], List[float]]:
    """
    Train a discriminator model in a GAN setup.

    Args:
        disc: Discriminator model
        gen: Generator model
        writer: TensorBoard writer
        train_loader: Training data loader
        valid_loader: Validation data loader
        fixed_train_vector: Fixed noise vector for generating fake training samples
        fixed_valid_vector: Fixed noise vector for generating fake validation samples
        num_epochs: Number of training epochs
        device: Device to train on
        criterion: Loss function
        optimizer: Optimizer (defaults to Adam if None)
        checkpoint_file: Path to save model checkpoints

    Returns:
        Tuple of (trained discriminator, training losses, validation losses)
    """
    optimizer = optimizer or optim.Adam(disc.parameters())
    train_losses, valid_losses = [], []
    least_loss = float('inf')
    gen.eval()

    for epoch in range(num_epochs):
        print(f"Epoch [{epoch + 1}/{num_epochs}]")

        # Training phase
        disc.train()
        train_metrics = train_epoch(
            disc, gen, train_loader, optimizer, criterion,
            fixed_train_vector, device, "Train"
        )

        # Validation phase
        disc.eval()
        with torch.no_grad():
            valid_metrics = train_epoch(
                disc, gen, valid_loader, None, criterion,
                fixed_valid_vector, device, "Validation"
            )

        # Log metrics and save checkpoint
        log_metrics(writer, train_metrics, valid_metrics, epoch)
        train_losses.append(train_metrics['loss'])
        valid_losses.append(valid_metrics['loss'])

        if valid_metrics['loss'] < least_loss:
          least_loss = valid_metrics['loss']
          save_checkpoint(disc, optimizer, epoch, least_loss, checkpoint_file)

    return disc, train_losses, valid_losses

## Test Function

In [12]:
def evaluate_model(discriminator, generator, test_loader, fixed_test_vector, device='cuda'):
    """
    Evaluates a discriminator on both real and generated samples.

    Args:
        discriminator: Discriminator model in eval mode
        generator: Generator model
        test_loader: DataLoader for real samples
        fixed_test_vector: Latent vectors for generating fake samples
        device: Device to run evaluation on

    Returns:
        dict: Evaluation metrics
    """
    discriminator.eval()
    generator.eval()
    criterion = torch.nn.BCELoss()
    running_loss = 0.0
    preds, scores, targets = [], [], []

    with torch.no_grad():
        loop = tqdm(test_loader, leave=True, desc="Testing")
        for batch_idx, (real, real_labels) in enumerate(loop):
            real = real.to(device)
            real_labels = real_labels.to(device).float()
            real_outputs = discriminator(real).squeeze()
            real_loss = criterion(real_outputs, real_labels)

            fake = generator(fixed_test_vector[batch_idx])
            fake_labels = torch.zeros_like(real_labels).to(device).float()
            fake_outputs = discriminator(fake).squeeze()
            fake_loss = criterion(fake_outputs, fake_labels)

            loss = real_loss + fake_loss
            running_loss += loss.item()

            outputs = torch.cat((real_outputs, fake_outputs), dim=0).cpu().numpy()
            labels = torch.cat((real_labels, fake_labels), dim=0).cpu().numpy()
            scores = np.concatenate((scores, outputs))
            targets = np.concatenate((targets, labels))
            preds = (scores >= 0.5).astype(int) # Converting to discreet 1s and 0s instead of probability
            metrics = {
                'loss': running_loss / (batch_idx + 1),
                'accuracy': accuracy_score(targets, preds),
                'precision': precision_score(targets, preds, zero_division=1),
                'recall': recall_score(targets, preds, zero_division=1),
                'f1': f1_score(targets, preds, zero_division=1),
                'roc_auc': roc_auc_score(targets, scores)
            }

            loop.set_postfix(**metrics)

    final_metrics = {
        k: v/len(test_loader) if k == 'loss' else v
        for k, v in metrics.items()
    }

    print(f"Test: Loss: {final_metrics['loss']:.4f} | "
          f"Accuracy: {final_metrics['accuracy']:.4f} | "
          f"Precision: {final_metrics['precision']:.4f} | "
          f"Recall: {final_metrics['recall']:.4f} | "
          f"F1: {final_metrics['f1']:.4f} | "
          f"ROC AUC: {final_metrics['roc_auc']:.4f}")

    return final_metrics

## Abalation function

In [13]:
import copy

def abalation(discriminator, generator, no_target_layers):
  time = datetime.datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
  log_folder = os.path.join(LOG_FOLDER, f"wgan_abalation/{time}")
  writer = SummaryWriter(log_folder, comment=f"WGAN discriminator {time}")

  for param in discriminator.parameters():
      param.requires_grad = False

  layers = None

  results = {}

  for n in range(1, len(no_target_layers) + 1):
    CHECKPOINT_WGAN_DISC = f"models/wgan_{n}unfreeze.pth"
    for param in layers[-n].parameters():
      param.requires_grad = True

    discriminator_copy = copy.deepcopy(discriminator)

    discriminator_copy, wgan_disc_train_losses, wgan_disc_valid_losses = train_discriminator(
      disc=discriminator_copy,
      gen=generator,
      writer=writer,
      train_loader=train_loader,
      valid_loader=valid_loader,
      fixed_train_vector=fixed_train_vector,
      fixed_valid_vector=fixed_valid_vector,
      num_epochs=NUM_EPOCHS,
      device=DEVICE,
      criterion=torch.nn.BCELoss(),
      optimizer=optim.Adam(wgan_disc.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9)),
      checkpoint_file=CHECKPOINT_WGAN_DISC,
    )

    load_checkpoint(CHECKPOINT_WGAN_DISC, discriminator_copy)

    wgan_disc_test_metrics = evaluate_model(
      discriminator=discriminator_copy,
      generator=generator,
      test_loader=test_loader,
      fixed_test_vector=fixed_test_vector,
    )

    results[CHECKPOINT_WGAN_DISC] = wgan_disc_test_metrics
  return results

# Models

In [14]:
class Generator(nn.Module):
    def __init__(self, channels_noise, channels_img, features_g):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            # Input: N x channels_noise x 1 x 1
            self._block(channels_noise, features_g * 32, 4, 1, 0),  # img: 4
            self._block(features_g * 32, features_g * 16, 4, 2, 1),  # img: 8
            self._block(features_g * 16, features_g * 8, 4, 2, 1),  # img: 16
            self._block(features_g * 8, features_g * 4, 4, 2, 1),  # img: 32
            self._block(features_g * 4, features_g * 2, 4, 2, 1),  # img: 64
            self._block(features_g * 2, features_g * 1, 4, 2, 1),  # img: 128
            nn.Conv2d(
                features_g * 1, channels_img, kernel_size=3, stride=1, padding=1
            ),
            # Output: N x channels_img x 128 x 128
            nn.Tanh(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)

wgan_gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(DEVICE)
load_checkpoint(CHECKPOINT_GEN, wgan_gen)
wgan_gen.eval()

=> Loaded checkpoint


Generator(
  (net): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(256, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): Sequential(
      (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (4): Sequential(
      (0): ConvTr

In [15]:
class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # input: N x channels_img x 128 x 128
            nn.Conv2d(channels_img, features_d, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            # _block(in_channels, out_channels, kernel_size, stride, padding)
            self._block(features_d, features_d * 2, 4, 2, 1),
            self._block(features_d * 2, features_d * 4, 4, 2, 1),
            self._block(features_d * 4, features_d * 8, 4, 2, 1),
            self._block(features_d * 8, features_d * 16, 4, 2, 1),
            self._block(features_d * 16, features_d * 32, 4, 2, 1),
            # After all _block img output is 4x4 (Conv2d below makes into 1x1)
            nn.Conv2d(features_d * 32, features_d * 32, kernel_size=4, stride=1, padding=0, bias=False),
            nn.LeakyReLU(0.2),
            nn.Conv2d(features_d * 32, 1, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Flatten(),
            nn.Sigmoid(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.disc(x)

wgan_disc = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(DEVICE)
load_checkpoint(CHECKPOINT_CRITIC, wgan_disc)
for param in wgan_disc.parameters():
    param.requires_grad = True
wgan_disc.train()

=> Loaded checkpoint


Discriminator(
  (disc): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=

In [16]:
import os
import copy
import datetime
import torch
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, List
from tqdm import tqdm

# Assumes the following are defined externally:
#  discriminator, generator: pretrained models
#  train_loader, valid_loader, test_loader: DataLoader
#  fixed_train_vector, fixed_valid_vector, fixed_test_vector: torch.Tensor noise vectors
#  DEVICE: torch.device
#  NUM_EPOCHS: int
#  LEARNING_RATE: float
#  LOG_FOLDER: str
#  train_discriminator, evaluate_model are imported

def ablation_study_incremental(discriminator: torch.nn.Module,
                               generator: torch.nn.Module,
                               train_loader,
                               valid_loader,
                               test_loader,
                               fixed_train_vector: torch.Tensor,
                               fixed_valid_vector: torch.Tensor,
                               fixed_test_vector: torch.Tensor,
                               device: torch.device,
                               num_epochs: int,
                               learning_rate: float,
                               log_folder: str) -> Dict[str, dict]:
    """
    Conduct incremental ablation where:
      1) First, unfreeze only the penultimate of the last 5 layers
      2) Then unfreeze the last 2 layers
      3) Then last 3, up to all 5
    Train only those layers each run and evaluate on the test set.

    Returns a dict mapping descriptive labels to test metrics.
    """
    # Extract all parameterized modules from the Sequential
    all_modules = list(discriminator.disc.children())
    param_modules = [m for m in all_modules if sum(p.numel() for p in m.parameters()) > 0]
    last_five = param_modules[-5:]

    results: Dict[str, dict] = {}
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    # Iterate over configurations with progress bar
    for n in tqdm(range(1, len(last_five) + 1), desc="Ablation configs", unit="config"):
        # Determine which modules to unfreeze
        if n == 1:
            modules_to_unfreeze = last_five[-2:-1]
            label = "penultimate_only"
        else:
            modules_to_unfreeze = last_five[-n:]
            label = f"last_{n}_layers"

        # Deep copy and freeze all parameters
        disc_copy = copy.deepcopy(discriminator).to(device)
        for p in disc_copy.parameters():
            p.requires_grad = False

        # Map copies of parameterized modules
        modules_copy = list(disc_copy.disc.children())
        param_copy = [m for m in modules_copy if sum(p.numel() for p in m.parameters()) > 0]

        # Slice target copies
        if n == 1:
            target_copies = param_copy[-2:-1]
        else:
            target_copies = param_copy[-n:]

        # Unfreeze target modules
        for module in target_copies:
            for p in module.parameters():
                p.requires_grad = True

        # Collect trainable params
        trainable = [p for p in disc_copy.parameters() if p.requires_grad]
        if not trainable:
            raise ValueError(f"No parameters to optimize for config '{label}'")

        # Optimizer for unfreezed params
        optimizer = torch.optim.Adam(trainable, lr=learning_rate, betas=(0.0, 0.9))

        # TensorBoard logging
        log_dir = os.path.join(log_folder, f"ablation_{label}_{timestamp}")
        writer = SummaryWriter(log_dir)

        # Train
        disc_copy, train_losses, valid_losses = train_discriminator(
            disc=disc_copy,
            gen=generator,
            writer=writer,
            train_loader=train_loader,
            valid_loader=valid_loader,
            fixed_train_vector=fixed_train_vector,
            fixed_valid_vector=fixed_valid_vector,
            num_epochs=num_epochs,
            device=device,
            criterion=torch.nn.BCELoss(),
            optimizer=optimizer,
            checkpoint_file=os.path.join(log_folder, f"disc_ablation_{label}.pth")
        )

        # Test
        test_metrics = evaluate_model(
            discriminator=disc_copy,
            generator=generator,
            test_loader=test_loader,
            fixed_test_vector=fixed_test_vector,
            device=device
        )

        results[label] = test_metrics
        writer.close()

    return results

# Example usage:
results = ablation_study_incremental(
    discriminator=wgan_disc,
    generator=wgan_gen,
    train_loader=train_loader,
    valid_loader=valid_loader,
    test_loader=test_loader,
    fixed_train_vector=fixed_train_vector,
    fixed_valid_vector=fixed_valid_vector,
    fixed_test_vector=fixed_test_vector,
    device=DEVICE,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    log_folder=LOG_FOLDER
)
print(results)


Ablation configs:   0%|          | 0/5 [00:00<?, ?config/s]

Epoch [1/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:03<?, ?it/s, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.503, loss=99.8]
[Ain:   2%|▏         | 1/48 [00:03<02:47,  3.57s/it, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.503, loss=99.8]
[Ain:   2%|▏         | 1/48 [00:03<02:47,  3.57s/it, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.502, loss=99.9]
[Ain:   4%|▍         | 2/48 [00:03<01:12,  1.58s/it, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.502, loss=99.9]
[Ain:   4%|▍         | 2/48 [00:03<01:12,  1.58s/it, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.503, loss=99.8]
[Ain:   6%|▋         | 3/48 [00:03<00:42,  1.06it/s, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.503, loss=99.8]
[Ain:   6%|▋         | 3/48 [00:04<00:42,  1.06it/s, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.504, loss=99.7]
[Ain:   8%|▊         | 4/48 [00:04<00:28,  1.54it/s, accuracy=0.5, precision=1, recall=0, f1=0, roc

=> Model saved
Epoch [2/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.89, precision=0.908, recall=0.867, f1=0.887, roc_auc=0.951, loss=2.77]
[Ain:   2%|▏         | 1/48 [00:01<00:51,  1.09s/it, accuracy=0.89, precision=0.908, recall=0.867, f1=0.887, roc_auc=0.951, loss=2.77]
[Ain:   2%|▏         | 1/48 [00:01<00:51,  1.09s/it, accuracy=0.896, precision=0.889, recall=0.905, f1=0.897, roc_auc=0.956, loss=2.25]
[Ain:   4%|▍         | 2/48 [00:01<00:25,  1.79it/s, accuracy=0.896, precision=0.889, recall=0.905, f1=0.897, roc_auc=0.956, loss=2.25]
[Ain:   4%|▍         | 2/48 [00:01<00:25,  1.79it/s, accuracy=0.895, precision=0.899, recall=0.889, f1=0.894, roc_auc=0.957, loss=2.09]
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.84it/s, accuracy=0.895, precision=0.899, recall=0.889, f1=0.894, roc_auc=0.957, loss=2.09]
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.84it/s, accuracy=0.894, precision=0.891, recall=0.899, f1=0.895, roc_auc=0.958, loss=2.1] 
[Ai

=> Model saved
Epoch [3/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.938, precision=0.959, recall=0.916, f1=0.937, roc_auc=0.984, loss=0.693]
[Ain:   2%|▏         | 1/48 [00:01<00:52,  1.12s/it, accuracy=0.938, precision=0.959, recall=0.916, f1=0.937, roc_auc=0.984, loss=0.693]
[Ain:   2%|▏         | 1/48 [00:01<00:52,  1.12s/it, accuracy=0.938, precision=0.945, recall=0.932, f1=0.938, roc_auc=0.983, loss=0.614]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.74it/s, accuracy=0.938, precision=0.945, recall=0.932, f1=0.938, roc_auc=0.983, loss=0.614]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.74it/s, accuracy=0.943, precision=0.948, recall=0.938, f1=0.943, roc_auc=0.984, loss=0.624]
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.82it/s, accuracy=0.943, precision=0.948, recall=0.938, f1=0.943, roc_auc=0.984, loss=0.624]
[Ain:   6%|▋         | 3/48 [00:02<00:24,  1.82it/s, accuracy=0.945, precision=0.95, recall=0.939, f1=0.944, roc_auc=0.986, loss=0.5

=> Model saved
Epoch [4/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:00<?, ?it/s, accuracy=0.956, precision=0.932, recall=0.984, f1=0.957, roc_auc=0.989, loss=0.666]
[Ain:   2%|▏         | 1/48 [00:00<00:46,  1.01it/s, accuracy=0.956, precision=0.932, recall=0.984, f1=0.957, roc_auc=0.989, loss=0.666]
[Ain:   2%|▏         | 1/48 [00:01<00:46,  1.01it/s, accuracy=0.951, precision=0.956, recall=0.945, f1=0.95, roc_auc=0.985, loss=0.586] 
[Ain:   4%|▍         | 2/48 [00:01<00:23,  1.94it/s, accuracy=0.951, precision=0.956, recall=0.945, f1=0.95, roc_auc=0.985, loss=0.586]
[Ain:   4%|▍         | 2/48 [00:01<00:23,  1.94it/s, accuracy=0.945, precision=0.932, recall=0.96, f1=0.946, roc_auc=0.985, loss=0.632]
[Ain:   6%|▋         | 3/48 [00:01<00:23,  1.94it/s, accuracy=0.945, precision=0.932, recall=0.96, f1=0.946, roc_auc=0.985, loss=0.632]
[Ain:   6%|▋         | 3/48 [00:01<00:23,  1.94it/s, accuracy=0.942, precision=0.945, recall=0.938, f1=0.942, roc_auc=0.982, loss=0.656

=> Model saved
Epoch [5/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.966, precision=0.954, recall=0.979, f1=0.966, roc_auc=0.989, loss=0.611]
[Ain:   2%|▏         | 1/48 [00:01<00:51,  1.09s/it, accuracy=0.966, precision=0.954, recall=0.979, f1=0.966, roc_auc=0.989, loss=0.611]
[Ain:   2%|▏         | 1/48 [00:01<00:51,  1.09s/it, accuracy=0.967, precision=0.968, recall=0.966, f1=0.967, roc_auc=0.992, loss=0.42] 
[Ain:   4%|▍         | 2/48 [00:01<00:25,  1.80it/s, accuracy=0.967, precision=0.968, recall=0.966, f1=0.967, roc_auc=0.992, loss=0.42]
[Ain:   4%|▍         | 2/48 [00:01<00:25,  1.80it/s, accuracy=0.968, precision=0.966, recall=0.969, f1=0.968, roc_auc=0.992, loss=0.429]
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.87it/s, accuracy=0.968, precision=0.966, recall=0.969, f1=0.968, roc_auc=0.992, loss=0.429]
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.87it/s, accuracy=0.968, precision=0.967, recall=0.968, f1=0.968, roc_auc=0.993, loss=0.3

=> Model saved
Epoch [6/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.972, precision=0.962, recall=0.982, f1=0.972, roc_auc=0.992, loss=0.49]
[Ain:   2%|▏         | 1/48 [00:01<00:48,  1.04s/it, accuracy=0.972, precision=0.962, recall=0.982, f1=0.972, roc_auc=0.992, loss=0.49]
[Ain:   2%|▏         | 1/48 [00:01<00:48,  1.04s/it, accuracy=0.97, precision=0.973, recall=0.968, f1=0.97, roc_auc=0.993, loss=0.359] 
[Ain:   4%|▍         | 2/48 [00:01<00:24,  1.87it/s, accuracy=0.97, precision=0.973, recall=0.968, f1=0.97, roc_auc=0.993, loss=0.359]
[Ain:   4%|▍         | 2/48 [00:01<00:24,  1.87it/s, accuracy=0.971, precision=0.968, recall=0.975, f1=0.971, roc_auc=0.993, loss=0.377]
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.82it/s, accuracy=0.971, precision=0.968, recall=0.975, f1=0.971, roc_auc=0.993, loss=0.377]
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.82it/s, accuracy=0.971, precision=0.972, recall=0.971, f1=0.971, roc_auc=0.994, loss=0.333]


=> Model saved
Epoch [7/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.972, precision=0.967, recall=0.977, f1=0.972, roc_auc=0.993, loss=0.305]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.13s/it, accuracy=0.972, precision=0.967, recall=0.977, f1=0.972, roc_auc=0.993, loss=0.305]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.13s/it, accuracy=0.972, precision=0.974, recall=0.97, f1=0.972, roc_auc=0.994, loss=0.257] 
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.75it/s, accuracy=0.972, precision=0.974, recall=0.97, f1=0.972, roc_auc=0.994, loss=0.257]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.75it/s, accuracy=0.974, precision=0.973, recall=0.976, f1=0.974, roc_auc=0.995, loss=0.284]
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.81it/s, accuracy=0.974, precision=0.973, recall=0.976, f1=0.974, roc_auc=0.995, loss=0.284]
[Ain:   6%|▋         | 3/48 [00:02<00:24,  1.81it/s, accuracy=0.974, precision=0.975, recall=0.974, f1=0.974, roc_auc=0.995, loss=0.2

=> Model saved
Epoch [8/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.979, precision=0.966, recall=0.992, f1=0.979, roc_auc=0.995, loss=0.247]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.13s/it, accuracy=0.979, precision=0.966, recall=0.992, f1=0.979, roc_auc=0.995, loss=0.247]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.13s/it, accuracy=0.975, precision=0.976, recall=0.974, f1=0.975, roc_auc=0.995, loss=0.231]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.74it/s, accuracy=0.975, precision=0.976, recall=0.974, f1=0.975, roc_auc=0.995, loss=0.231]
[Ain:   4%|▍         | 2/48 [00:02<00:26,  1.74it/s, accuracy=0.976, precision=0.975, recall=0.977, f1=0.976, roc_auc=0.995, loss=0.273]
[Ain:   6%|▋         | 3/48 [00:02<00:28,  1.56it/s, accuracy=0.976, precision=0.975, recall=0.977, f1=0.976, roc_auc=0.995, loss=0.273]
[Ain:   6%|▋         | 3/48 [00:02<00:28,  1.56it/s, accuracy=0.977, precision=0.975, recall=0.98, f1=0.977, roc_auc=0.996, loss=0.2

=> Model saved
Epoch [9/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.979, precision=0.973, recall=0.984, f1=0.979, roc_auc=0.996, loss=0.241]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.14s/it, accuracy=0.979, precision=0.973, recall=0.984, f1=0.979, roc_auc=0.996, loss=0.241]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.14s/it, accuracy=0.979, precision=0.981, recall=0.976, f1=0.978, roc_auc=0.994, loss=0.258]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.74it/s, accuracy=0.979, precision=0.981, recall=0.976, f1=0.978, roc_auc=0.994, loss=0.258]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.74it/s, accuracy=0.978, precision=0.979, recall=0.977, f1=0.978, roc_auc=0.995, loss=0.285]
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.81it/s, accuracy=0.978, precision=0.979, recall=0.977, f1=0.978, roc_auc=0.995, loss=0.285]
[Ain:   6%|▋         | 3/48 [00:02<00:24,  1.81it/s, accuracy=0.979, precision=0.979, recall=0.979, f1=0.979, roc_auc=0.996, loss=0.

=> Model saved
Epoch [10/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.981, precision=0.975, recall=0.988, f1=0.982, roc_auc=0.997, loss=0.17]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.13s/it, accuracy=0.981, precision=0.975, recall=0.988, f1=0.982, roc_auc=0.997, loss=0.17]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.13s/it, accuracy=0.98, precision=0.982, recall=0.979, f1=0.98, roc_auc=0.997, loss=0.143] 
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.75it/s, accuracy=0.98, precision=0.982, recall=0.979, f1=0.98, roc_auc=0.997, loss=0.143]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.75it/s, accuracy=0.982, precision=0.982, recall=0.982, f1=0.982, roc_auc=0.997, loss=0.147]
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.86it/s, accuracy=0.982, precision=0.982, recall=0.982, f1=0.982, roc_auc=0.997, loss=0.147]
[Ain:   6%|▋         | 3/48 [00:02<00:24,  1.86it/s, accuracy=0.983, precision=0.983, recall=0.983, f1=0.983, roc_auc=0.997, loss=0.173]


=> Model saved



[Ating:   0%|          | 0/9 [00:00<?, ?it/s]
[Ating:   0%|          | 0/9 [00:01<?, ?it/s, accuracy=0.981, f1=0.981, loss=0.515, precision=0.979, recall=0.984, roc_auc=0.995]
[Ating:  11%|█         | 1/9 [00:01<00:08,  1.04s/it, accuracy=0.981, f1=0.981, loss=0.515, precision=0.979, recall=0.984, roc_auc=0.995]
[Ating:  11%|█         | 1/9 [00:01<00:08,  1.04s/it, accuracy=0.982, f1=0.982, loss=0.32, precision=0.982, recall=0.981, roc_auc=0.997] 
[Ating:  11%|█         | 1/9 [00:01<00:08,  1.04s/it, accuracy=0.983, f1=0.983, loss=0.257, precision=0.982, recall=0.983, roc_auc=0.997]
[Ating:  33%|███▎      | 3/9 [00:01<00:03,  1.99it/s, accuracy=0.983, f1=0.983, loss=0.257, precision=0.982, recall=0.983, roc_auc=0.997]
[Ating:  33%|███▎      | 3/9 [00:01<00:03,  1.99it/s, accuracy=0.983, f1=0.983, loss=0.224, precision=0.983, recall=0.983, roc_auc=0.997]
[Ating:  33%|███▎      | 3/9 [00:02<00:03,  1.99it/s, accuracy=0.982, f1=0.982, loss=0.208, precision=0.982, recall=0.982, ro

Test: Loss: 0.0283 | Accuracy: 0.9811 | Precision: 0.9807 | Recall: 0.9816 | F1: 0.9811 | ROC AUC: 0.9964
Epoch [1/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.502, loss=99.9]
[Ain:   2%|▏         | 1/48 [00:01<00:54,  1.16s/it, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.502, loss=99.9]
[Ain:   2%|▏         | 1/48 [00:01<00:54,  1.16s/it, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.502, loss=99.9]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.71it/s, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.502, loss=99.9]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.71it/s, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.503, loss=99.8]
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.83it/s, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.503, loss=99.8]
[Ain:   6%|▋         | 3/48 [00:02<00:24,  1.83it/s, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.503, loss=99.8]
[Ain:   8%|▊         | 4/48 [00:02<00:17,  2.46it/s, accuracy=0.5, precision=1, recall=0, f1=0, roc

=> Model saved
Epoch [2/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.917, precision=0.898, recall=0.941, f1=0.919, roc_auc=0.967, loss=1.98]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.13s/it, accuracy=0.917, precision=0.898, recall=0.941, f1=0.919, roc_auc=0.967, loss=1.98]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.13s/it, accuracy=0.922, precision=0.924, recall=0.919, f1=0.922, roc_auc=0.971, loss=1.46]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.74it/s, accuracy=0.922, precision=0.924, recall=0.919, f1=0.922, roc_auc=0.971, loss=1.46]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.74it/s, accuracy=0.913, precision=0.903, recall=0.926, f1=0.915, roc_auc=0.969, loss=1.43]
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.82it/s, accuracy=0.913, precision=0.903, recall=0.926, f1=0.915, roc_auc=0.969, loss=1.43]
[Ain:   6%|▋         | 3/48 [00:02<00:24,  1.82it/s, accuracy=0.911, precision=0.916, recall=0.905, f1=0.91, roc_auc=0.968, loss=1.34] 
[

=> Model saved
Epoch [3/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.904, precision=0.977, recall=0.828, f1=0.896, roc_auc=0.986, loss=1.12]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.13s/it, accuracy=0.904, precision=0.977, recall=0.828, f1=0.896, roc_auc=0.986, loss=1.12]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.13s/it, accuracy=0.909, precision=0.91, recall=0.907, f1=0.909, roc_auc=0.967, loss=1.07] 
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.75it/s, accuracy=0.909, precision=0.91, recall=0.907, f1=0.909, roc_auc=0.967, loss=1.07]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.75it/s, accuracy=0.904, precision=0.931, recall=0.873, f1=0.901, roc_auc=0.969, loss=1.13]
[Ain:   6%|▋         | 3/48 [00:01<00:23,  1.88it/s, accuracy=0.904, precision=0.931, recall=0.873, f1=0.901, roc_auc=0.969, loss=1.13]
[Ain:   6%|▋         | 3/48 [00:01<00:23,  1.88it/s, accuracy=0.904, precision=0.904, recall=0.903, f1=0.904, roc_auc=0.967, loss=1.07]
[A

Epoch [4/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.902, precision=0.841, recall=0.992, f1=0.91, roc_auc=0.983, loss=2]
[Ain:   2%|▏         | 1/48 [00:01<00:54,  1.16s/it, accuracy=0.902, precision=0.841, recall=0.992, f1=0.91, roc_auc=0.983, loss=2]
[Ain:   2%|▏         | 1/48 [00:01<00:54,  1.16s/it, accuracy=0.89, precision=0.901, recall=0.875, f1=0.888, roc_auc=0.952, loss=1.78]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.71it/s, accuracy=0.89, precision=0.901, recall=0.875, f1=0.888, roc_auc=0.952, loss=1.78]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.71it/s, accuracy=0.878, precision=0.852, recall=0.915, f1=0.882, roc_auc=0.954, loss=1.82]
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.81it/s, accuracy=0.878, precision=0.852, recall=0.915, f1=0.882, roc_auc=0.954, loss=1.82]
[Ain:   6%|▋         | 3/48 [00:02<00:24,  1.81it/s, accuracy=0.883, precision=0.881, recall=0.886, f1=0.883, roc_auc=0.953, loss=1.69]
[Ain:   8%|

=> Model saved
Epoch [5/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.942, precision=0.901, recall=0.994, f1=0.945, roc_auc=0.99, loss=0.94]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.14s/it, accuracy=0.942, precision=0.901, recall=0.994, f1=0.945, roc_auc=0.99, loss=0.94]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.14s/it, accuracy=0.934, precision=0.94, recall=0.926, f1=0.933, roc_auc=0.981, loss=0.769]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.73it/s, accuracy=0.934, precision=0.94, recall=0.926, f1=0.933, roc_auc=0.981, loss=0.769]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.73it/s, accuracy=0.928, precision=0.911, recall=0.95, f1=0.93, roc_auc=0.981, loss=0.778] 
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.82it/s, accuracy=0.928, precision=0.911, recall=0.95, f1=0.93, roc_auc=0.981, loss=0.778]
[Ain:   6%|▋         | 3/48 [00:02<00:24,  1.82it/s, accuracy=0.922, precision=0.929, recall=0.914, f1=0.921, roc_auc=0.977, loss=0.822]
[Ai

=> Model saved
Epoch [6/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.967, precision=0.965, recall=0.969, f1=0.967, roc_auc=0.994, loss=0.309]
[Ain:   2%|▏         | 1/48 [00:01<00:52,  1.12s/it, accuracy=0.967, precision=0.965, recall=0.969, f1=0.967, roc_auc=0.994, loss=0.309]
[Ain:   2%|▏         | 1/48 [00:01<00:52,  1.12s/it, accuracy=0.973, precision=0.971, recall=0.975, f1=0.973, roc_auc=0.995, loss=0.247]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.76it/s, accuracy=0.973, precision=0.971, recall=0.975, f1=0.973, roc_auc=0.995, loss=0.247]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.76it/s, accuracy=0.973, precision=0.973, recall=0.973, f1=0.973, roc_auc=0.994, loss=0.31] 
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.85it/s, accuracy=0.973, precision=0.973, recall=0.973, f1=0.973, roc_auc=0.994, loss=0.31]
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.85it/s, accuracy=0.974, precision=0.972, recall=0.976, f1=0.974, roc_auc=0.995, loss=0.2

Epoch [7/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.972, precision=0.96, recall=0.984, f1=0.972, roc_auc=0.994, loss=0.464]
[Ain:   2%|▏         | 1/48 [00:01<00:54,  1.16s/it, accuracy=0.972, precision=0.96, recall=0.984, f1=0.972, roc_auc=0.994, loss=0.464]
[Ain:   2%|▏         | 1/48 [00:01<00:54,  1.16s/it, accuracy=0.975, precision=0.972, recall=0.978, f1=0.975, roc_auc=0.994, loss=0.333]
[Ain:   4%|▍         | 2/48 [00:01<00:27,  1.69it/s, accuracy=0.975, precision=0.972, recall=0.978, f1=0.975, roc_auc=0.994, loss=0.333]
[Ain:   4%|▍         | 2/48 [00:01<00:27,  1.69it/s, accuracy=0.977, precision=0.974, recall=0.98, f1=0.977, roc_auc=0.995, loss=0.337] 
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.82it/s, accuracy=0.977, precision=0.974, recall=0.98, f1=0.977, roc_auc=0.995, loss=0.337]
[Ain:   6%|▋         | 3/48 [00:02<00:24,  1.82it/s, accuracy=0.975, precision=0.975, recall=0.976, f1=0.975, roc_auc=0.995, loss=0.318

=> Model saved
Epoch [8/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.97, precision=0.967, recall=0.973, f1=0.97, roc_auc=0.994, loss=0.285]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.14s/it, accuracy=0.97, precision=0.967, recall=0.973, f1=0.97, roc_auc=0.994, loss=0.285]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.14s/it, accuracy=0.974, precision=0.973, recall=0.976, f1=0.974, roc_auc=0.996, loss=0.208]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.73it/s, accuracy=0.974, precision=0.973, recall=0.976, f1=0.974, roc_auc=0.996, loss=0.208]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.73it/s, accuracy=0.976, precision=0.975, recall=0.977, f1=0.976, roc_auc=0.995, loss=0.296]
[Ain:   6%|▋         | 3/48 [00:01<00:25,  1.79it/s, accuracy=0.976, precision=0.975, recall=0.977, f1=0.976, roc_auc=0.995, loss=0.296]
[Ain:   6%|▋         | 3/48 [00:02<00:25,  1.79it/s, accuracy=0.976, precision=0.975, recall=0.977, f1=0.976, roc_auc=0.995, loss=0.27] 

=> Model saved
Epoch [9/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.975, precision=0.975, recall=0.975, f1=0.975, roc_auc=0.996, loss=0.207]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.13s/it, accuracy=0.975, precision=0.975, recall=0.975, f1=0.975, roc_auc=0.996, loss=0.207]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.13s/it, accuracy=0.978, precision=0.978, recall=0.977, f1=0.978, roc_auc=0.997, loss=0.167]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.75it/s, accuracy=0.978, precision=0.978, recall=0.977, f1=0.978, roc_auc=0.997, loss=0.167]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.75it/s, accuracy=0.98, precision=0.98, recall=0.981, f1=0.98, roc_auc=0.997, loss=0.211]   
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.84it/s, accuracy=0.98, precision=0.98, recall=0.981, f1=0.98, roc_auc=0.997, loss=0.211]
[Ain:   6%|▋         | 3/48 [00:02<00:24,  1.84it/s, accuracy=0.979, precision=0.981, recall=0.977, f1=0.979, roc_auc=0.997, loss=0.202

Epoch [10/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.978, precision=0.969, recall=0.986, f1=0.978, roc_auc=0.996, loss=0.21]
[Ain:   2%|▏         | 1/48 [00:01<00:52,  1.13s/it, accuracy=0.978, precision=0.969, recall=0.986, f1=0.978, roc_auc=0.996, loss=0.21]
[Ain:   2%|▏         | 1/48 [00:01<00:52,  1.13s/it, accuracy=0.98, precision=0.976, recall=0.985, f1=0.981, roc_auc=0.997, loss=0.159]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.75it/s, accuracy=0.98, precision=0.976, recall=0.985, f1=0.981, roc_auc=0.997, loss=0.159]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.75it/s, accuracy=0.98, precision=0.979, recall=0.98, f1=0.98, roc_auc=0.997, loss=0.171]  
[Ain:   6%|▋         | 3/48 [00:01<00:23,  1.91it/s, accuracy=0.98, precision=0.979, recall=0.98, f1=0.98, roc_auc=0.997, loss=0.171]
[Ain:   6%|▋         | 3/48 [00:01<00:23,  1.91it/s, accuracy=0.98, precision=0.978, recall=0.982, f1=0.98, roc_auc=0.997, loss=0.157]
[Ain

=> Model saved



[Ating:   0%|          | 0/9 [00:00<?, ?it/s]
[Ating:   0%|          | 0/9 [00:01<?, ?it/s, accuracy=0.975, f1=0.975, loss=0.548, precision=0.978, recall=0.971, roc_auc=0.995]
[Ating:  11%|█         | 1/9 [00:01<00:08,  1.05s/it, accuracy=0.975, f1=0.975, loss=0.548, precision=0.978, recall=0.971, roc_auc=0.995]
[Ating:  11%|█         | 1/9 [00:01<00:08,  1.05s/it, accuracy=0.977, f1=0.977, loss=0.369, precision=0.983, recall=0.971, roc_auc=0.995]
[Ating:  11%|█         | 1/9 [00:01<00:08,  1.05s/it, accuracy=0.979, f1=0.979, loss=0.281, precision=0.984, recall=0.975, roc_auc=0.996]
[Ating:  33%|███▎      | 3/9 [00:01<00:03,  1.95it/s, accuracy=0.979, f1=0.979, loss=0.281, precision=0.984, recall=0.975, roc_auc=0.996]
[Ating:  33%|███▎      | 3/9 [00:01<00:03,  1.95it/s, accuracy=0.98, f1=0.98, loss=0.242, precision=0.984, recall=0.977, roc_auc=0.996]  
[Ating:  33%|███▎      | 3/9 [00:02<00:03,  1.95it/s, accuracy=0.979, f1=0.979, loss=0.233, precision=0.983, recall=0.975, ro

Test: Loss: 0.0299 | Accuracy: 0.9792 | Precision: 0.9825 | Recall: 0.9757 | F1: 0.9791 | ROC AUC: 0.9961
Epoch [1/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.504, loss=99.7]
[Ain:   2%|▏         | 1/48 [00:01<00:52,  1.11s/it, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.504, loss=99.7]
[Ain:   2%|▏         | 1/48 [00:01<00:52,  1.11s/it, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.503, loss=99.8]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.76it/s, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.503, loss=99.8]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.76it/s, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.503, loss=99.8]
[Ain:   6%|▋         | 3/48 [00:01<00:25,  1.77it/s, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.503, loss=99.8]
[Ain:   6%|▋         | 3/48 [00:02<00:25,  1.77it/s, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.502, loss=99.8]
[Ain:   8%|▊         | 4/48 [00:02<00:18,  2.38it/s, accuracy=0.5, precision=1, recall=0, f1=0, roc

=> Model saved
Epoch [2/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.61, precision=0.562, recall=1, f1=0.72, roc_auc=0.862, loss=29.3]
[Ain:   2%|▏         | 1/48 [00:01<00:51,  1.10s/it, accuracy=0.61, precision=0.562, recall=1, f1=0.72, roc_auc=0.862, loss=29.3]
[Ain:   2%|▏         | 1/48 [00:01<00:51,  1.10s/it, accuracy=0.685, precision=0.661, recall=0.759, f1=0.706, roc_auc=0.799, loss=18]
[Ain:   4%|▍         | 2/48 [00:01<00:25,  1.77it/s, accuracy=0.685, precision=0.661, recall=0.759, f1=0.706, roc_auc=0.799, loss=18]
[Ain:   4%|▍         | 2/48 [00:01<00:25,  1.77it/s, accuracy=0.708, precision=0.666, recall=0.834, f1=0.741, roc_auc=0.825, loss=15.8]
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.86it/s, accuracy=0.708, precision=0.666, recall=0.834, f1=0.741, roc_auc=0.825, loss=15.8]
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.86it/s, accuracy=0.731, precision=0.712, recall=0.775, f1=0.742, roc_auc=0.836, loss=13]  
[Ain:   8%|▊     

Epoch [3/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.13s/it, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.13s/it, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.74it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.74it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.83it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   6%|▋         | 3/48 [00:02<00:24,  1.83it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   8%|▊         | 4/48 [00:02<00:17,  2.45it/s, accuracy=0.5, precision=0

=> Model saved
Epoch [4/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.978, precision=0.973, recall=0.982, f1=0.978, roc_auc=0.998, loss=0.178]
[Ain:   2%|▏         | 1/48 [00:01<00:55,  1.18s/it, accuracy=0.978, precision=0.973, recall=0.982, f1=0.978, roc_auc=0.998, loss=0.178]
[Ain:   2%|▏         | 1/48 [00:01<00:55,  1.18s/it, accuracy=0.979, precision=0.981, recall=0.976, f1=0.978, roc_auc=0.998, loss=0.177]
[Ain:   4%|▍         | 2/48 [00:01<00:27,  1.67it/s, accuracy=0.979, precision=0.981, recall=0.976, f1=0.978, roc_auc=0.998, loss=0.177]
[Ain:   4%|▍         | 2/48 [00:01<00:27,  1.67it/s, accuracy=0.978, precision=0.977, recall=0.979, f1=0.978, roc_auc=0.997, loss=0.251]
[Ain:   6%|▋         | 3/48 [00:01<00:25,  1.75it/s, accuracy=0.978, precision=0.977, recall=0.979, f1=0.978, roc_auc=0.997, loss=0.251]
[Ain:   6%|▋         | 3/48 [00:02<00:25,  1.75it/s, accuracy=0.977, precision=0.978, recall=0.976, f1=0.977, roc_auc=0.997, loss=0.

Epoch [5/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.924, precision=0.938, recall=0.908, f1=0.923, roc_auc=0.978, loss=2.01]
[Ain:   2%|▏         | 1/48 [00:01<00:52,  1.12s/it, accuracy=0.924, precision=0.938, recall=0.908, f1=0.923, roc_auc=0.978, loss=2.01]
[Ain:   2%|▏         | 1/48 [00:01<00:52,  1.12s/it, accuracy=0.942, precision=0.956, recall=0.927, f1=0.941, roc_auc=0.987, loss=1.19]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.75it/s, accuracy=0.942, precision=0.956, recall=0.927, f1=0.941, roc_auc=0.987, loss=1.19]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.75it/s, accuracy=0.951, precision=0.958, recall=0.944, f1=0.951, roc_auc=0.989, loss=1.01]
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.85it/s, accuracy=0.951, precision=0.958, recall=0.944, f1=0.951, roc_auc=0.989, loss=1.01]
[Ain:   6%|▋         | 3/48 [00:02<00:24,  1.85it/s, accuracy=0.957, precision=0.964, recall=0.949, f1=0.957, roc_auc=0.991, loss=0.813]


Epoch [6/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.758, precision=1, recall=0.516, f1=0.68, roc_auc=0.993, loss=6.48]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.13s/it, accuracy=0.758, precision=1, recall=0.516, f1=0.68, roc_auc=0.993, loss=6.48]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.13s/it, accuracy=0.72, precision=0.705, recall=0.758, f1=0.73, roc_auc=0.839, loss=12.6]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.72it/s, accuracy=0.72, precision=0.705, recall=0.758, f1=0.73, roc_auc=0.839, loss=12.6]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.72it/s, accuracy=0.741, precision=0.766, recall=0.693, f1=0.728, roc_auc=0.868, loss=10.2]
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.83it/s, accuracy=0.741, precision=0.766, recall=0.693, f1=0.728, roc_auc=0.868, loss=10.2]
[Ain:   6%|▋         | 3/48 [00:02<00:24,  1.83it/s, accuracy=0.765, precision=0.763, recall=0.768, f1=0.766, roc_auc=0.87, loss=8.95] 
[Ain:   8%|▊   

=> Model saved
Epoch [7/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.99, precision=0.988, recall=0.992, f1=0.99, roc_auc=1, loss=0.048]
[Ain:   2%|▏         | 1/48 [00:01<00:55,  1.17s/it, accuracy=0.99, precision=0.988, recall=0.992, f1=0.99, roc_auc=1, loss=0.048]
[Ain:   2%|▏         | 1/48 [00:01<00:55,  1.17s/it, accuracy=0.991, precision=0.99, recall=0.991, f1=0.991, roc_auc=1, loss=0.0486]
[Ain:   4%|▍         | 2/48 [00:01<00:27,  1.69it/s, accuracy=0.991, precision=0.99, recall=0.991, f1=0.991, roc_auc=1, loss=0.0486]
[Ain:   4%|▍         | 2/48 [00:01<00:27,  1.69it/s, accuracy=0.989, precision=0.989, recall=0.99, f1=0.989, roc_auc=0.999, loss=0.0695]
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.82it/s, accuracy=0.989, precision=0.989, recall=0.99, f1=0.989, roc_auc=0.999, loss=0.0695]
[Ain:   6%|▋         | 3/48 [00:02<00:24,  1.82it/s, accuracy=0.989, precision=0.99, recall=0.989, f1=0.989, roc_auc=0.999, loss=0.0814]
[Ain:   8%|▊  

Epoch [8/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.984, precision=0.984, recall=0.984, f1=0.984, roc_auc=0.998, loss=0.152]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.15s/it, accuracy=0.984, precision=0.984, recall=0.984, f1=0.984, roc_auc=0.998, loss=0.152]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.15s/it, accuracy=0.985, precision=0.989, recall=0.981, f1=0.985, roc_auc=0.999, loss=0.12] 
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.74it/s, accuracy=0.985, precision=0.989, recall=0.981, f1=0.985, roc_auc=0.999, loss=0.12]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.74it/s, accuracy=0.987, precision=0.988, recall=0.986, f1=0.987, roc_auc=0.999, loss=0.117]
[Ain:   6%|▋         | 3/48 [00:01<00:23,  1.92it/s, accuracy=0.987, precision=0.988, recall=0.986, f1=0.987, roc_auc=0.999, loss=0.117]
[Ain:   6%|▋         | 3/48 [00:01<00:23,  1.92it/s, accuracy=0.988, precision=0.989, recall=0.987, f1=0.988, roc_auc=0.999, loss=0.1

=> Model saved
Epoch [9/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.992, precision=0.99, recall=0.994, f1=0.992, roc_auc=1, loss=0.0378]
[Ain:   2%|▏         | 1/48 [00:01<00:55,  1.19s/it, accuracy=0.992, precision=0.99, recall=0.994, f1=0.992, roc_auc=1, loss=0.0378]
[Ain:   2%|▏         | 1/48 [00:01<00:55,  1.19s/it, accuracy=0.992, precision=0.992, recall=0.991, f1=0.992, roc_auc=1, loss=0.043]
[Ain:   4%|▍         | 2/48 [00:01<00:27,  1.67it/s, accuracy=0.992, precision=0.992, recall=0.991, f1=0.992, roc_auc=1, loss=0.043]
[Ain:   4%|▍         | 2/48 [00:01<00:27,  1.67it/s, accuracy=0.992, precision=0.991, recall=0.993, f1=0.992, roc_auc=1, loss=0.0469]
[Ain:   6%|▋         | 3/48 [00:01<00:25,  1.76it/s, accuracy=0.992, precision=0.991, recall=0.993, f1=0.992, roc_auc=1, loss=0.0469]
[Ain:   6%|▋         | 3/48 [00:02<00:25,  1.76it/s, accuracy=0.991, precision=0.991, recall=0.991, f1=0.991, roc_auc=0.999, loss=0.0649]
[Ain:   8%|▊   

Epoch [10/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.83, precision=1, recall=0.66, f1=0.795, roc_auc=0.999, loss=3.46]
[Ain:   2%|▏         | 1/48 [00:01<00:52,  1.12s/it, accuracy=0.83, precision=1, recall=0.66, f1=0.795, roc_auc=0.999, loss=3.46]
[Ain:   2%|▏         | 1/48 [00:01<00:52,  1.12s/it, accuracy=0.904, precision=0.981, recall=0.823, f1=0.895, roc_auc=0.973, loss=1.94]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.75it/s, accuracy=0.904, precision=0.981, recall=0.823, f1=0.895, roc_auc=0.973, loss=1.94]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.75it/s, accuracy=0.93, precision=0.981, recall=0.878, f1=0.926, roc_auc=0.979, loss=1.35] 
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.85it/s, accuracy=0.93, precision=0.981, recall=0.878, f1=0.926, roc_auc=0.979, loss=1.35]
[Ain:   6%|▋         | 3/48 [00:02<00:24,  1.85it/s, accuracy=0.944, precision=0.982, recall=0.904, f1=0.942, roc_auc=0.984, loss=1.04]
[Ain:   8%|▊  

Test: Loss: 0.0195 | Accuracy: 0.9845 | Precision: 0.9906 | Recall: 0.9783 | F1: 0.9844 | ROC AUC: 0.9986
Epoch [1/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.501, loss=99.9]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.14s/it, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.501, loss=99.9]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.14s/it, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.504, loss=99.8]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.73it/s, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.504, loss=99.8]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.73it/s, accuracy=0.501, precision=1, recall=0.00195, f1=0.0039, roc_auc=0.517, loss=98.8]
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.86it/s, accuracy=0.501, precision=1, recall=0.00195, f1=0.0039, roc_auc=0.517, loss=98.8]
[Ain:   6%|▋         | 3/48 [00:02<00:24,  1.86it/s, accuracy=0.521, precision=0.938, recall=0.0439, f1=0.084, roc_auc=0.578, loss=87.4]
[Ain:   8%|▊         | 4/48 [00:02<00:17,  2.47it/s, accur

=> Model saved
Epoch [2/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.578, precision=1, recall=0.156, f1=0.27, roc_auc=0.98, loss=24.1]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.14s/it, accuracy=0.578, precision=1, recall=0.156, f1=0.27, roc_auc=0.98, loss=24.1]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.14s/it, accuracy=0.649, precision=0.677, recall=0.572, f1=0.62, roc_auc=0.752, loss=25.8]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.71it/s, accuracy=0.649, precision=0.677, recall=0.572, f1=0.62, roc_auc=0.752, loss=25.8]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.71it/s, accuracy=0.707, precision=0.761, recall=0.604, f1=0.674, roc_auc=0.812, loss=18.9]
[Ain:   6%|▋         | 3/48 [00:01<00:25,  1.78it/s, accuracy=0.707, precision=0.761, recall=0.604, f1=0.674, roc_auc=0.812, loss=18.9]
[Ain:   6%|▋         | 3/48 [00:02<00:25,  1.78it/s, accuracy=0.727, precision=0.74, recall=0.7, f1=0.719, roc_auc=0.806, loss=17.9]   
[Ain:   8%|▊   

=> Model saved
Epoch [3/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.88, precision=1, recall=0.76, f1=0.863, roc_auc=0.993, loss=2.13]
[Ain:   2%|▏         | 1/48 [00:01<00:51,  1.10s/it, accuracy=0.88, precision=1, recall=0.76, f1=0.863, roc_auc=0.993, loss=2.13]
[Ain:   2%|▏         | 1/48 [00:01<00:51,  1.10s/it, accuracy=0.873, precision=0.869, recall=0.878, f1=0.873, roc_auc=0.95, loss=3.54]
[Ain:   4%|▍         | 2/48 [00:01<00:25,  1.77it/s, accuracy=0.873, precision=0.869, recall=0.878, f1=0.873, roc_auc=0.95, loss=3.54]
[Ain:   4%|▍         | 2/48 [00:01<00:25,  1.77it/s, accuracy=0.857, precision=0.901, recall=0.802, f1=0.848, roc_auc=0.949, loss=3.42]
[Ain:   6%|▋         | 3/48 [00:01<00:28,  1.60it/s, accuracy=0.857, precision=0.901, recall=0.802, f1=0.848, roc_auc=0.949, loss=3.42]
[Ain:   6%|▋         | 3/48 [00:02<00:28,  1.60it/s, accuracy=0.799, precision=0.77, recall=0.852, f1=0.809, roc_auc=0.895, loss=9.51] 
[Ain:   8%|▊   

=> Model saved
Epoch [4/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.937, precision=0.993, recall=0.879, f1=0.933, roc_auc=0.995, loss=0.72]
[Ain:   2%|▏         | 1/48 [00:01<00:52,  1.11s/it, accuracy=0.937, precision=0.993, recall=0.879, f1=0.933, roc_auc=0.995, loss=0.72]
[Ain:   2%|▏         | 1/48 [00:01<00:52,  1.11s/it, accuracy=0.956, precision=0.975, recall=0.936, f1=0.955, roc_auc=0.992, loss=0.704]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.76it/s, accuracy=0.956, precision=0.975, recall=0.936, f1=0.955, roc_auc=0.992, loss=0.704]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.76it/s, accuracy=0.961, precision=0.976, recall=0.945, f1=0.96, roc_auc=0.993, loss=0.621] 
[Ain:   6%|▋         | 3/48 [00:01<00:28,  1.61it/s, accuracy=0.961, precision=0.976, recall=0.945, f1=0.96, roc_auc=0.993, loss=0.621]
[Ain:   6%|▋         | 3/48 [00:02<00:28,  1.61it/s, accuracy=0.966, precision=0.978, recall=0.955, f1=0.966, roc_auc=0.995, loss=0.518

=> Model saved
Epoch [5/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.986, precision=0.99, recall=0.982, f1=0.986, roc_auc=0.999, loss=0.0942]
[Ain:   2%|▏         | 1/48 [00:01<00:52,  1.12s/it, accuracy=0.986, precision=0.99, recall=0.982, f1=0.986, roc_auc=0.999, loss=0.0942]
[Ain:   2%|▏         | 1/48 [00:01<00:52,  1.12s/it, accuracy=0.989, precision=0.99, recall=0.987, f1=0.989, roc_auc=0.999, loss=0.0966]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.75it/s, accuracy=0.989, precision=0.99, recall=0.987, f1=0.989, roc_auc=0.999, loss=0.0966]
[Ain:   4%|▍         | 2/48 [00:02<00:26,  1.75it/s, accuracy=0.989, precision=0.989, recall=0.988, f1=0.989, roc_auc=0.999, loss=0.106]
[Ain:   6%|▋         | 3/48 [00:02<00:28,  1.58it/s, accuracy=0.989, precision=0.989, recall=0.988, f1=0.989, roc_auc=0.999, loss=0.106]
[Ain:   6%|▋         | 3/48 [00:02<00:28,  1.58it/s, accuracy=0.986, precision=0.99, recall=0.981, f1=0.986, roc_auc=0.999, loss=0.1

Epoch [6/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   2%|▏         | 1/48 [00:01<00:54,  1.16s/it, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   2%|▏         | 1/48 [00:01<00:54,  1.16s/it, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.71it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   4%|▍         | 2/48 [00:02<00:26,  1.71it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   6%|▋         | 3/48 [00:02<00:29,  1.55it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   6%|▋         | 3/48 [00:02<00:29,  1.55it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   8%|▊         | 4/48 [00:02<00:20,  2.14it/s, accuracy=0.5, precision=0

Epoch [7/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.14s/it, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.14s/it, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.71it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   4%|▍         | 2/48 [00:02<00:26,  1.71it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   6%|▋         | 3/48 [00:02<00:28,  1.57it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   6%|▋         | 3/48 [00:02<00:28,  1.57it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   8%|▊         | 4/48 [00:02<00:20,  2.16it/s, accuracy=0.5, precision=0

Epoch [8/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   2%|▏         | 1/48 [00:01<00:54,  1.16s/it, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   2%|▏         | 1/48 [00:01<00:54,  1.16s/it, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.71it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   4%|▍         | 2/48 [00:02<00:26,  1.71it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   6%|▋         | 3/48 [00:02<00:29,  1.50it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   6%|▋         | 3/48 [00:02<00:29,  1.50it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   8%|▊         | 4/48 [00:02<00:21,  2.08it/s, accuracy=0.5, precision=0

Epoch [9/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   2%|▏         | 1/48 [00:01<01:05,  1.39s/it, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   2%|▏         | 1/48 [00:01<01:05,  1.39s/it, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   4%|▍         | 2/48 [00:01<00:31,  1.46it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   4%|▍         | 2/48 [00:02<00:31,  1.46it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   6%|▋         | 3/48 [00:02<00:28,  1.60it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   6%|▋         | 3/48 [00:02<00:28,  1.60it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   8%|▊         | 4/48 [00:02<00:20,  2.20it/s, accuracy=0.5, precision=0

Epoch [10/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   2%|▏         | 1/48 [00:01<01:01,  1.32s/it, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   2%|▏         | 1/48 [00:01<01:01,  1.32s/it, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   4%|▍         | 2/48 [00:01<00:29,  1.53it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   4%|▍         | 2/48 [00:02<00:29,  1.53it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   6%|▋         | 3/48 [00:02<00:26,  1.67it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   6%|▋         | 3/48 [00:02<00:26,  1.67it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   8%|▊         | 4/48 [00:02<00:19,  2.27it/s, accuracy=0.5, precision=0

Test: Loss: 11.1111 | Accuracy: 0.5000 | Precision: 0.5000 | Recall: 1.0000 | F1: 0.6667 | ROC AUC: 0.5000
Epoch [1/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.504, loss=99.8]
[Ain:   2%|▏         | 1/48 [00:01<00:55,  1.19s/it, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.504, loss=99.8]
[Ain:   2%|▏         | 1/48 [00:01<00:55,  1.19s/it, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.503, loss=99.9]
[Ain:   4%|▍         | 2/48 [00:01<00:27,  1.67it/s, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.503, loss=99.9]
[Ain:   4%|▍         | 2/48 [00:01<00:27,  1.67it/s, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.503, loss=99.8]
[Ain:   6%|▋         | 3/48 [00:01<00:25,  1.77it/s, accuracy=0.5, precision=1, recall=0, f1=0, roc_auc=0.503, loss=99.8]
[Ain:   6%|▋         | 3/48 [00:02<00:25,  1.77it/s, accuracy=0.503, precision=1, recall=0.00586, f1=0.0117, roc_auc=0.533, loss=96.9]
[Ain:   8%|▊         | 4/48 [00:02<00:18,  2.36it/s, accuracy=0.503, precision=1, reca

=> Model saved
Epoch [2/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.503, precision=0.501, recall=1, f1=0.668, roc_auc=0.53, loss=94.3]
[Ain:   2%|▏         | 1/48 [00:01<00:56,  1.20s/it, accuracy=0.503, precision=0.501, recall=1, f1=0.668, roc_auc=0.53, loss=94.3]
[Ain:   2%|▏         | 1/48 [00:01<00:56,  1.20s/it, accuracy=0.513, precision=0.507, recall=1, f1=0.673, roc_auc=0.576, loss=85.7]
[Ain:   4%|▍         | 2/48 [00:01<00:28,  1.64it/s, accuracy=0.513, precision=0.507, recall=1, f1=0.673, roc_auc=0.576, loss=85.7]
[Ain:   4%|▍         | 2/48 [00:01<00:28,  1.64it/s, accuracy=0.588, precision=0.549, recall=0.984, f1=0.705, roc_auc=0.653, loss=66.6]
[Ain:   6%|▋         | 3/48 [00:01<00:25,  1.76it/s, accuracy=0.588, precision=0.549, recall=0.984, f1=0.705, roc_auc=0.653, loss=66.6]
[Ain:   6%|▋         | 3/48 [00:02<00:25,  1.76it/s, accuracy=0.646, precision=0.595, recall=0.911, f1=0.72, roc_auc=0.689, loss=51.6] 
[Ain:   8%|▊       

=> Model saved
Epoch [3/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.838, precision=0.994, recall=0.68, f1=0.807, roc_auc=0.984, loss=4.71]
[Ain:   2%|▏         | 1/48 [00:01<00:55,  1.18s/it, accuracy=0.838, precision=0.994, recall=0.68, f1=0.807, roc_auc=0.984, loss=4.71]
[Ain:   2%|▏         | 1/48 [00:01<00:55,  1.18s/it, accuracy=0.859, precision=0.877, recall=0.835, f1=0.855, roc_auc=0.932, loss=5.56]
[Ain:   4%|▍         | 2/48 [00:01<00:27,  1.67it/s, accuracy=0.859, precision=0.877, recall=0.835, f1=0.855, roc_auc=0.932, loss=5.56]
[Ain:   4%|▍         | 2/48 [00:01<00:27,  1.67it/s, accuracy=0.865, precision=0.909, recall=0.81, f1=0.857, roc_auc=0.949, loss=4.44] 
[Ain:   6%|▋         | 3/48 [00:01<00:25,  1.74it/s, accuracy=0.865, precision=0.909, recall=0.81, f1=0.857, roc_auc=0.949, loss=4.44]
[Ain:   6%|▋         | 3/48 [00:02<00:25,  1.74it/s, accuracy=0.853, precision=0.851, recall=0.856, f1=0.853, roc_auc=0.933, loss=5.41]
[Ain

Epoch [4/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.667, precision=0.989, recall=0.338, f1=0.504, roc_auc=0.949, loss=17.7]
[Ain:   2%|▏         | 1/48 [00:01<00:51,  1.10s/it, accuracy=0.667, precision=0.989, recall=0.338, f1=0.504, roc_auc=0.949, loss=17.7]
[Ain:   2%|▏         | 1/48 [00:01<00:51,  1.10s/it, accuracy=0.683, precision=0.688, recall=0.668, f1=0.678, roc_auc=0.77, loss=26.4] 
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.76it/s, accuracy=0.683, precision=0.688, recall=0.668, f1=0.678, roc_auc=0.77, loss=26.4]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.76it/s, accuracy=0.733, precision=0.758, recall=0.686, f1=0.72, roc_auc=0.825, loss=19.9]
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.85it/s, accuracy=0.733, precision=0.758, recall=0.686, f1=0.72, roc_auc=0.825, loss=19.9]
[Ain:   6%|▋         | 3/48 [00:02<00:24,  1.85it/s, accuracy=0.756, precision=0.752, recall=0.763, f1=0.757, roc_auc=0.829, loss=18.4]
[Ain

=> Model saved
Epoch [5/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.916, precision=0.975, recall=0.854, f1=0.91, roc_auc=0.985, loss=1.76]
[Ain:   2%|▏         | 1/48 [00:01<00:54,  1.16s/it, accuracy=0.916, precision=0.975, recall=0.854, f1=0.91, roc_auc=0.985, loss=1.76]
[Ain:   2%|▏         | 1/48 [00:01<00:54,  1.16s/it, accuracy=0.935, precision=0.962, recall=0.906, f1=0.933, roc_auc=0.986, loss=1.39]
[Ain:   4%|▍         | 2/48 [00:01<00:27,  1.69it/s, accuracy=0.935, precision=0.962, recall=0.906, f1=0.933, roc_auc=0.986, loss=1.39]
[Ain:   4%|▍         | 2/48 [00:01<00:27,  1.69it/s, accuracy=0.948, precision=0.966, recall=0.929, f1=0.947, roc_auc=0.989, loss=1.15]
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.81it/s, accuracy=0.948, precision=0.966, recall=0.929, f1=0.947, roc_auc=0.989, loss=1.15]
[Ain:   6%|▋         | 3/48 [00:02<00:24,  1.81it/s, accuracy=0.955, precision=0.97, recall=0.94, f1=0.955, roc_auc=0.991, loss=0.934] 
[Ai

=> Model saved
Epoch [6/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.992, precision=0.988, recall=0.996, f1=0.992, roc_auc=1, loss=0.0694]
[Ain:   2%|▏         | 1/48 [00:01<00:54,  1.16s/it, accuracy=0.992, precision=0.988, recall=0.996, f1=0.992, roc_auc=1, loss=0.0694]
[Ain:   2%|▏         | 1/48 [00:01<00:54,  1.16s/it, accuracy=0.99, precision=0.992, recall=0.988, f1=0.99, roc_auc=0.999, loss=0.095]
[Ain:   4%|▍         | 2/48 [00:01<00:27,  1.69it/s, accuracy=0.99, precision=0.992, recall=0.988, f1=0.99, roc_auc=0.999, loss=0.095]
[Ain:   4%|▍         | 2/48 [00:01<00:27,  1.69it/s, accuracy=0.989, precision=0.989, recall=0.989, f1=0.989, roc_auc=0.999, loss=0.11]
[Ain:   6%|▋         | 3/48 [00:01<00:25,  1.78it/s, accuracy=0.989, precision=0.989, recall=0.989, f1=0.989, roc_auc=0.999, loss=0.11]
[Ain:   6%|▋         | 3/48 [00:02<00:25,  1.78it/s, accuracy=0.989, precision=0.991, recall=0.987, f1=0.989, roc_auc=0.999, loss=0.11]
[Ain:  

Epoch [7/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   2%|▏         | 1/48 [00:01<00:56,  1.20s/it, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=100]
[Ain:   2%|▏         | 1/48 [00:01<00:56,  1.20s/it, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=99.9]
[Ain:   4%|▍         | 2/48 [00:01<00:27,  1.64it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=99.9]
[Ain:   4%|▍         | 2/48 [00:01<00:27,  1.64it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.501, loss=99.9]
[Ain:   6%|▋         | 3/48 [00:01<00:26,  1.73it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.501, loss=99.9]
[Ain:   6%|▋         | 3/48 [00:02<00:26,  1.73it/s, accuracy=0.5, precision=0.5, recall=1, f1=0.667, roc_auc=0.5, loss=99.9]  
[Ain:   8%|▊         | 4/48 [00:02<00:18,  2.32it/s, accuracy=0.5, 

Epoch [8/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.987, precision=0.992, recall=0.982, f1=0.987, roc_auc=0.999, loss=0.112]
[Ain:   2%|▏         | 1/48 [00:01<00:54,  1.16s/it, accuracy=0.987, precision=0.992, recall=0.982, f1=0.987, roc_auc=0.999, loss=0.112]
[Ain:   2%|▏         | 1/48 [00:01<00:54,  1.16s/it, accuracy=0.991, precision=0.993, recall=0.989, f1=0.991, roc_auc=1, loss=0.0801]   
[Ain:   4%|▍         | 2/48 [00:01<00:27,  1.69it/s, accuracy=0.991, precision=0.993, recall=0.989, f1=0.991, roc_auc=1, loss=0.0801]
[Ain:   4%|▍         | 2/48 [00:01<00:27,  1.69it/s, accuracy=0.991, precision=0.993, recall=0.988, f1=0.991, roc_auc=1, loss=0.094] 
[Ain:   6%|▋         | 3/48 [00:01<00:24,  1.84it/s, accuracy=0.991, precision=0.993, recall=0.988, f1=0.991, roc_auc=1, loss=0.094]
[Ain:   6%|▋         | 3/48 [00:02<00:24,  1.84it/s, accuracy=0.991, precision=0.992, recall=0.991, f1=0.991, roc_auc=1, loss=0.0879]
[Ain:  

=> Model saved
Epoch [9/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.996, precision=1, recall=0.992, f1=0.996, roc_auc=1, loss=0.0201]
[Ain:   2%|▏         | 1/48 [00:01<00:54,  1.17s/it, accuracy=0.996, precision=1, recall=0.992, f1=0.996, roc_auc=1, loss=0.0201]
[Ain:   2%|▏         | 1/48 [00:01<00:54,  1.17s/it, accuracy=0.997, precision=0.999, recall=0.995, f1=0.997, roc_auc=1, loss=0.0219]
[Ain:   4%|▍         | 2/48 [00:01<00:27,  1.67it/s, accuracy=0.997, precision=0.999, recall=0.995, f1=0.997, roc_auc=1, loss=0.0219]
[Ain:   4%|▍         | 2/48 [00:01<00:27,  1.67it/s, accuracy=0.997, precision=0.998, recall=0.995, f1=0.997, roc_auc=1, loss=0.0333]
[Ain:   6%|▋         | 3/48 [00:01<00:25,  1.74it/s, accuracy=0.997, precision=0.998, recall=0.995, f1=0.997, roc_auc=1, loss=0.0333]
[Ain:   6%|▋         | 3/48 [00:02<00:25,  1.74it/s, accuracy=0.996, precision=0.998, recall=0.995, f1=0.996, roc_auc=1, loss=0.0389]
[Ain:   8%|▊         | 

Epoch [10/10]



[Ain:   0%|          | 0/48 [00:00<?, ?it/s]
[Ain:   0%|          | 0/48 [00:01<?, ?it/s, accuracy=0.981, precision=0.964, recall=1, f1=0.982, roc_auc=1, loss=0.128]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.14s/it, accuracy=0.981, precision=0.964, recall=1, f1=0.982, roc_auc=1, loss=0.128]
[Ain:   2%|▏         | 1/48 [00:01<00:53,  1.14s/it, accuracy=0.968, precision=0.981, recall=0.954, f1=0.967, roc_auc=0.996, loss=0.408]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.71it/s, accuracy=0.968, precision=0.981, recall=0.954, f1=0.967, roc_auc=0.996, loss=0.408]
[Ain:   4%|▍         | 2/48 [00:01<00:26,  1.71it/s, accuracy=0.831, precision=0.759, recall=0.969, f1=0.851, roc_auc=0.9, loss=16]     
[Ain:   6%|▋         | 3/48 [00:01<00:23,  1.92it/s, accuracy=0.831, precision=0.759, recall=0.969, f1=0.851, roc_auc=0.9, loss=16]
[Ain:   6%|▋         | 3/48 [00:01<00:23,  1.92it/s, accuracy=0.761, precision=0.765, recall=0.753, f1=0.759, roc_auc=0.835, loss=18.3]
[Ain:   8%|▊    

Test: Loss: 0.0324 | Accuracy: 0.9839 | Precision: 0.9844 | Recall: 0.9835 | F1: 0.9839 | ROC AUC: 0.9984
{'penultimate_only': {'loss': 0.028321669793423307, 'accuracy': 0.9811197916666666, 'precision': 0.980702515177797, 'recall': 0.9815538194444444, 'f1': 0.9811279826464209, 'roc_auc': np.float64(0.996413431049865)}, 'last_2_layers': {'loss': 0.02987432038342511, 'accuracy': 0.9791666666666666, 'precision': 0.9825174825174825, 'recall': 0.9756944444444444, 'f1': 0.9790940766550522, 'roc_auc': np.float64(0.9961034751232759)}, 'last_3_layers': {'loss': 0.01948475635346071, 'accuracy': 0.9844835069444444, 'precision': 0.990551527136893, 'recall': 0.9782986111111112, 'f1': 0.984386941805874, 'roc_auc': np.float64(0.9985877378487292)}, 'last_4_layers': {'loss': 11.11111111111111, 'accuracy': 0.5, 'precision': 0.5, 'recall': 1.0, 'f1': 0.6666666666666666, 'roc_auc': np.float64(0.5)}, 'last_5_layers': {'loss': 0.03242345227871412, 'accuracy': 0.9839409722222222, 'precision': 0.9843614248479




In [17]:
results

{'penultimate_only': {'loss': 0.028321669793423307,
  'accuracy': 0.9811197916666666,
  'precision': 0.980702515177797,
  'recall': 0.9815538194444444,
  'f1': 0.9811279826464209,
  'roc_auc': np.float64(0.996413431049865)},
 'last_2_layers': {'loss': 0.02987432038342511,
  'accuracy': 0.9791666666666666,
  'precision': 0.9825174825174825,
  'recall': 0.9756944444444444,
  'f1': 0.9790940766550522,
  'roc_auc': np.float64(0.9961034751232759)},
 'last_3_layers': {'loss': 0.01948475635346071,
  'accuracy': 0.9844835069444444,
  'precision': 0.990551527136893,
  'recall': 0.9782986111111112,
  'f1': 0.984386941805874,
  'roc_auc': np.float64(0.9985877378487292)},
 'last_4_layers': {'loss': 11.11111111111111,
  'accuracy': 0.5,
  'precision': 0.5,
  'recall': 1.0,
  'f1': 0.6666666666666666,
  'roc_auc': np.float64(0.5)},
 'last_5_layers': {'loss': 0.03242345227871412,
  'accuracy': 0.9839409722222222,
  'precision': 0.9843614248479583,
  'recall': 0.9835069444444444,
  'f1': 0.98393399913