In [None]:
# Imports

import pickle
from typing import List
from tqdm import tqdm

import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F

import random
import os
import torchvision
import torch

from copy import deepcopy

from google.colab import drive


drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Early Stopping

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=3, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            # self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(
                f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            # self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


In [None]:
# Reproducibility

seed = 0

torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

torch.use_deterministic_algorithms(True)

generator = torch.Generator()
generator.manual_seed(seed)
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
os.environ['NCCL_P2P_DISABLE'] = "1"


def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

In [None]:
# Dataset from Subset

class DatasetFromSubset(torch.utils.data.Dataset):
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.subset[index]
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.subset)

In [None]:
# Plotting & metrics

def plot_train_validation_results(train_losses: List, train_accuracies: List, val_losses: List, val_accuracies: List):
    plt.figure()
    plt.plot(train_losses, label='train_loss', color='red')
    plt.plot(val_losses, label='val_loss', color='blue')
    plt.legend()
    plt.show()
    plt.plot(train_accuracies, label='train_accuracy', color='red')
    plt.plot(val_accuracies, label='val_accuracy', color='blue')
    plt.legend()
    plt.show()

def print_current_epoch_stats(epoch: int, train_loss: float, train_acc: float,
                              val_loss: float, val_acc: float):
    print("""epoch %d | train_loss=%f | train_acc = %f
      | val_loss = %f | val_acc = %f""" %
          (epoch, train_loss, train_acc, val_loss, val_acc))

In [None]:
# Epoch stuff

def run_epoch(data_loader: torch.utils.data.DataLoader, training: bool,
                model: nn.Module, loss_criterion: nn.modules.loss,
                optimizer: torch.optim.Optimizer,
                device: torch.device = None):
      predictions = []
      labels = []

      epoch_loss = 0.0

      if training:
          model.train()
      else:
          model.eval()

      for batch in tqdm(data_loader):
          batch_data = batch[0]
          batch_labels = batch[1]

          if device:
              batch_data = batch_data.to(device)
              batch_labels = batch_labels.to(device)

          with torch.set_grad_enabled(training):
              output = model(batch_data)

              batch_predictions = torch.argmax(output, dim=-1)

              predictions += batch_predictions.tolist()
              labels += batch_labels.squeeze().tolist()
              loss = loss_criterion(output, batch_labels)
              epoch_loss += loss.item()

              if training:
                  loss.backward()
                  optimizer.step()
                  model.zero_grad()

      accuracy = compute_accuracy(labels, predictions)
      return epoch_loss / len(data_loader), accuracy, predictions, labels

def compute_accuracy(y: List, y_pred: List):
  return np.sum(np.array(y) == np.array(y_pred)) / len(y)

In [None]:
# EWC

class EWC(object):
      def __init__(self, model: nn.Module, data_loader: torch.utils.data.DataLoader, device=None):
          self.model = model
          self.data_loader = data_loader
          self.device = device
          self.params = {
              n: p for n, p in self.model.named_parameters() if p.requires_grad}

          self._means = {}

          self._precision_matrices = self._diag_fisher()

          for n, p in deepcopy(self.params).items():
              if self.device is not None:
                  p.data = p.data.to(self.device)
              self._means[n] = p.data

      def _diag_fisher(self):
          precision_matrices = {}
          for n, p in deepcopy(self.params).items():
              p.data.zero_()
              if self.device is not None:
                  p.data = p.data.to(self.device)
              precision_matrices[n] = p.data
          self.model.eval()

          for data, _ in tqdm(self.data_loader):
              self.model.zero_grad()

              if self.device:
                  data = data.to(self.device)

              output = self.model(data).view(1, -1)
              label = output.max(1)[1].view(-1)
              loss = F.nll_loss(F.log_softmax(output, dim=1), label)
              loss.backward()

              for n, p in self.model.named_parameters():
                  precision_matrices[n].data += p.grad.data ** 2 / \
                      len(self.data_loader)

          precision_matrices = {n: p for n, p in precision_matrices.items()}
          return precision_matrices

      def penalty(self, model: nn.Module):
          loss = 0
          for n, p in model.named_parameters():
              _loss = self._precision_matrices[n] * (p - self._means[n]) ** 2
              loss += _loss.sum()
          return loss

def ewc_train(model: nn.Module, optimizer: torch.optim, loss_criterion,
              data_loader: torch.utils.data.DataLoader, ewc: EWC,
              importance: float, device=None):
        model.train()
        epoch_loss = 0

        predictions = []
        all_labels = []

        for data, labels in tqdm(data_loader):
            if device is not None:
                data, labels = data.to(device), labels.to(device)

            optimizer.zero_grad()

            output = model(data)

            loss = loss_criterion(output, labels) + \
                importance * ewc.penalty(model)

            epoch_loss += loss.item()

            loss.backward()
            optimizer.step()

            all_labels += labels.squeeze().tolist()
            predictions += torch.argmax(output, dim=-1).tolist()

        return epoch_loss / len(data_loader), compute_accuracy(all_labels, predictions), predictions, all_labels


In [None]:
 # Define constants

start_bulk = False
start_sequential = True

my_drive_path = "./drive/MyDrive"

model_path = f"{my_drive_path}/model"

learning_rate = 1e-3

train_test_split_size = 0.85

sequential_train_test_split_size = 0.85

epochs = 50

epochs_sequential = 50

ewc_importance = 0.5

early_stopping_patience = 3

optimizer_momentum = 0.95

seq_optimizer_momentum = 0.95

batch_size = 128

out_features = 11

lr_scheduler_patience = 3

image_size = 128

num_workers = 2

number_of_years = 10

ewc_number_of_data_from_prev_datasets = 350

# it's for range() :)
number_of_years += 1

directory_path = f"{my_drive_path}/CLEAR-10-PUBLIC"

path = f"{directory_path}/labeled_images"

best_model_path = f"{my_drive_path}/best_model.ptk"

In [None]:
# Define a label map (e.g. { 1: "basketball", ... } )

label_map = {}

with open(f"{directory_path}/class_names.txt") as label_names:
        for index, label_name in enumerate(label_names):
            label_map[index] = label_name.strip()


In [None]:
# Define datasets

train_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.RandomCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225]),
    transforms.RandomVerticalFlip(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=90),
    transforms.GaussianBlur(kernel_size=3),
    transforms.ColorJitter()
])

validation_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225]),
])

datasets = []

for index in range(1, number_of_years, 1):
    dataset = torchvision.datasets.ImageFolder(
        f"{path}/{index}")

    # dataset = torch.utils.data.Subset(dataset, np.random.choice(len(dataset), 1000, replace=False))

    datasets.append(dataset)

concatenated_datasets = torch.utils.data.ConcatDataset(datasets)


In [None]:
# Define dataloaders

data_size = sum(len(dataset) for dataset in datasets)

train_size = int(train_test_split_size * data_size)

test_size = data_size - train_size
train_dataset, validation_dataset = torch.utils.data.random_split(
    concatenated_datasets, [train_size, test_size])
train_dataset = DatasetFromSubset(train_dataset, train_transform)
validation_dataset = DatasetFromSubset(
    validation_dataset, validation_transform)

train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                batch_size=batch_size,
                                                shuffle=True,
                                                num_workers=num_workers,
                                                worker_init_fn=seed_worker,
                                                generator=generator,
                                                drop_last=True)

validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,
                                                    batch_size=batch_size,
                                                    num_workers=num_workers,
                                                    worker_init_fn=seed_worker,
                                                    generator=generator,
                                                    drop_last=True)

In [None]:
# Device

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# For bulk

if start_bulk:
  model = torchvision.models.resnet18(pretrained=True)

  model.to(device)

  optimizer = torch.optim.SGD(lr=learning_rate,
                              params=model.parameters(),
                              momentum=optimizer_momentum)

  loss_criterion = nn.CrossEntropyLoss()

  best_val_acc = 0

  # if os.path.exists(model_path):
  #     model.load_state_dict(torch.load(model_path))

  all_train_data = []

  for epoch in range(epochs):
      train_loss, train_acc, train_pred, train_labels = run_epoch(
          data_loader=train_dataloader,
          training=True,
          device=device,
          optimizer=optimizer,
          loss_criterion=loss_criterion,
          model=model
      )

      val_loss, val_acc, val_pred, val_labels = run_epoch(
          data_loader=validation_dataloader,
          training=False,
          device=device,
          optimizer=optimizer,
          loss_criterion=loss_criterion,
          model=model
      )

      print_current_epoch_stats(
          epoch + 1,
          train_loss,
          train_acc,
          val_loss,
          val_acc)
      
      all_train_data.append((epoch, train_loss, train_acc, val_loss, val_acc, train_pred, train_labels, val_pred, val_labels))
      
      if val_acc > best_val_acc:
        torch.save(model.state_dict(), f"{best_model_path}-bulk.pth")

        best_val_acc = val_acc

  all_validation_data = []

  for index, val_dataset in enumerate(datasets):
    val_dataset = DatasetFromSubset(val_dataset, validation_transform)

    val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset,
                                                  batch_size=batch_size,
                                                  num_workers=num_workers,
                                                  drop_last=True)

    val_loss, val_acc, val_pred, val_labels = run_epoch(
        data_loader=val_dataloader,
        training=False,
        device=device,
        optimizer=optimizer,
        loss_criterion=loss_criterion,
        model=model
    )

    all_validation_data.append((index, val_loss, val_acc, val_pred, val_labels))

    print(f"Accuracy for {index + 1}: %f" % val_acc)

In [None]:
# Save stuff

if start_bulk:
  with open(f"{my_drive_path}/validation_data.pkl", "wb") as file:
    pickle.dump(all_validation_data, file)

  with open(f"{my_drive_path}/train_data.pkl", "wb") as file:
      pickle.dump(all_train_data, file)


In [None]:
# For sequential

if start_sequential:
  model_sequential = torchvision.models.resnet18(pretrained=True)
  
  model_sequential.fc = nn.Linear(model_sequential.fc.in_features, out_features)

  model_sequential.to(device)

  train_dataloaders = []
  val_dataloaders = []

  train_datasets = []

  seq_accuracy_matrix = []

  all_seq_train_data = []
  all_seq_validation_data = []


  for dataset in datasets:
      data_size = len(dataset)

      train_size = int(sequential_train_test_split_size * data_size)

      test_size = data_size - train_size

      train_dataset, validation_dataset = torch.utils.data.random_split(
          dataset, [train_size, test_size])

      train_dataset = DatasetFromSubset(train_dataset, train_transform)

      validation_dataset = DatasetFromSubset(
          validation_dataset, validation_transform)

      train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                      batch_size=batch_size,
                                                      shuffle=True,
                                                      num_workers=num_workers,
                                                      worker_init_fn=seed_worker,
                                                      generator=generator,
                                                      drop_last=True)

      validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,
                                                          batch_size=batch_size,
                                                          num_workers=num_workers,
                                                          worker_init_fn=seed_worker,
                                                          generator=generator,
                                                          drop_last=True)
      
      train_datasets.append(train_dataset)
      train_dataloaders.append(train_dataloader)
      val_dataloaders.append(validation_dataloader)

  datasets_used_so_far = []
  

  for index, (train_dataloader, train_dataset, val_dataloader) in enumerate(zip(train_dataloaders, train_datasets, val_dataloaders)):
      early_stopper = EarlyStopping(patience=early_stopping_patience, verbose=True)

      best_val_acc = 0

      optimizer = torch.optim.SGD(lr=learning_rate,
                                  params=model_sequential.parameters(),
                                  momentum=seq_optimizer_momentum)

      loss_criterion = nn.CrossEntropyLoss()

      scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
          patience=lr_scheduler_patience,
          optimizer=optimizer,
          verbose=True
      )

      datasets_used_so_far += torch.utils.data.Subset(train_dataset, np.random.choice(
          len(train_dataset), min(len(train_dataset), ewc_number_of_data_from_prev_datasets),
          replace=False))

      dataloader_data_used_so_far = torch.utils.data.DataLoader(dataset=datasets_used_so_far,
                                                                batch_size=batch_size,
                                                                num_workers=num_workers,
                                                                worker_init_fn=seed_worker,
                                                                generator=generator,
                                                                shuffle=True,
                                                                drop_last=True)

      for epoch in range(epochs_sequential):
          ewc = EWC(
              model=model_sequential,
              data_loader=dataloader_data_used_so_far,
              device=device
          )

          # XXX: EWC or normal train:

          train_loss, train_acc, train_pred, train_labels = ewc_train(model_sequential, optimizer, loss_criterion,
                                                                      train_dataloader, ewc, ewc_importance, device)

          # train_loss, train_acc, train_pred, train_labels = run_epoch(
          #     data_loader=train_dataloader,
          #     training=True,
          #     device=device,
          #     optimizer=optimizer,
          #     loss_criterion=loss_criterion,
          #     model=model_sequential,
          # )

          val_loss, val_acc, val_pred, val_labels = run_epoch(
              data_loader=val_dataloader,
              training=False,
              device=device,
              optimizer=optimizer,
              loss_criterion=loss_criterion,
              model=model_sequential)

          all_seq_train_data.append((index, epoch, train_loss, train_acc, val_loss, val_acc, train_pred, train_labels, val_pred, val_labels))

          scheduler.step(val_loss)

          print_current_epoch_stats(
              epoch + 1,
              train_loss,
              train_acc,
              val_loss,
              val_acc)

          early_stopper(val_loss, model_sequential)

          if val_acc > best_val_acc:
            torch.save(model_sequential.state_dict(), f"{best_model_path}-{index}-sequential.pth")

            best_val_acc = val_acc

          if early_stopper.early_stop:
              break

      current_accuracies = []
      current_val_data = []

      for index, val_dataloader in enumerate(val_dataloaders):
          val_loss, val_acc, val_pred, val_labels = run_epoch(
              data_loader=val_dataloader,
              training=False,
              device=device,
              optimizer=optimizer,
              loss_criterion=loss_criterion,
              model=model_sequential
          )

          accuracy = compute_accuracy(val_labels, val_pred)

          current_accuracies.append(accuracy)

          current_val_data.append((val_loss, val_acc, val_pred, val_labels))

      seq_accuracy_matrix.append(current_accuracies)
      all_seq_validation_data.append(current_val_data)

100%|██████████| 2/2 [00:00<00:00,  4.79it/s]
100%|██████████| 21/21 [00:21<00:00,  1.04s/it]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]


epoch 1 | train_loss=2.378750 | train_acc = 0.174107
      | val_loss = 1.951929 | val_acc = 0.335938


100%|██████████| 2/2 [00:00<00:00,  4.77it/s]
100%|██████████| 21/21 [00:21<00:00,  1.04s/it]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]


epoch 2 | train_loss=1.648080 | train_acc = 0.496652
      | val_loss = 1.277649 | val_acc = 0.562500


100%|██████████| 2/2 [00:00<00:00,  4.85it/s]
100%|██████████| 21/21 [00:21<00:00,  1.04s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 3 | train_loss=1.223155 | train_acc = 0.606399
      | val_loss = 1.035997 | val_acc = 0.648438


100%|██████████| 2/2 [00:00<00:00,  4.65it/s]
100%|██████████| 21/21 [00:21<00:00,  1.04s/it]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]


epoch 4 | train_loss=1.024356 | train_acc = 0.659226
      | val_loss = 0.938239 | val_acc = 0.669271


100%|██████████| 2/2 [00:00<00:00,  4.73it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]


epoch 5 | train_loss=0.882747 | train_acc = 0.702381
      | val_loss = 0.869143 | val_acc = 0.700521


100%|██████████| 2/2 [00:00<00:00,  4.93it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 6 | train_loss=0.832296 | train_acc = 0.712426
      | val_loss = 0.794645 | val_acc = 0.731771


100%|██████████| 2/2 [00:00<00:00,  4.64it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]


epoch 7 | train_loss=0.748758 | train_acc = 0.746280
      | val_loss = 0.792513 | val_acc = 0.739583


100%|██████████| 2/2 [00:00<00:00,  4.78it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]


epoch 8 | train_loss=0.686685 | train_acc = 0.776414
      | val_loss = 0.734420 | val_acc = 0.781250


100%|██████████| 2/2 [00:00<00:00,  4.86it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]


epoch 9 | train_loss=0.643578 | train_acc = 0.784598
      | val_loss = 0.735761 | val_acc = 0.765625
EarlyStopping counter: 1 out of 3


100%|██████████| 2/2 [00:00<00:00,  4.91it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.14it/s]


epoch 10 | train_loss=0.590156 | train_acc = 0.803199
      | val_loss = 0.706288 | val_acc = 0.776042


100%|██████████| 2/2 [00:00<00:00,  4.67it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.13it/s]


epoch 11 | train_loss=0.564105 | train_acc = 0.816964
      | val_loss = 0.702657 | val_acc = 0.778646


100%|██████████| 2/2 [00:00<00:00,  4.90it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.13it/s]


epoch 12 | train_loss=0.513204 | train_acc = 0.828869
      | val_loss = 0.659683 | val_acc = 0.799479


100%|██████████| 2/2 [00:00<00:00,  4.88it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.13it/s]


epoch 13 | train_loss=0.488025 | train_acc = 0.837054
      | val_loss = 0.654790 | val_acc = 0.794271


100%|██████████| 2/2 [00:00<00:00,  4.69it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]


epoch 14 | train_loss=0.465558 | train_acc = 0.850074
      | val_loss = 0.664944 | val_acc = 0.781250
EarlyStopping counter: 1 out of 3


100%|██████████| 2/2 [00:00<00:00,  4.84it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]


epoch 15 | train_loss=0.452024 | train_acc = 0.853795
      | val_loss = 0.623830 | val_acc = 0.796875


100%|██████████| 2/2 [00:00<00:00,  4.78it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]


epoch 16 | train_loss=0.421630 | train_acc = 0.864955
      | val_loss = 0.622578 | val_acc = 0.783854


100%|██████████| 2/2 [00:00<00:00,  4.88it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]


epoch 17 | train_loss=0.410438 | train_acc = 0.866443
      | val_loss = 0.623105 | val_acc = 0.799479
EarlyStopping counter: 1 out of 3


100%|██████████| 2/2 [00:00<00:00,  4.71it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.14it/s]


epoch 18 | train_loss=0.378820 | train_acc = 0.876860
      | val_loss = 0.617941 | val_acc = 0.809896


100%|██████████| 2/2 [00:00<00:00,  4.82it/s]
100%|██████████| 21/21 [00:21<00:00,  1.04s/it]
100%|██████████| 3/3 [00:02<00:00,  1.14it/s]


epoch 19 | train_loss=0.367577 | train_acc = 0.882068
      | val_loss = 0.635832 | val_acc = 0.804688
EarlyStopping counter: 1 out of 3


100%|██████████| 2/2 [00:00<00:00,  4.65it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.12it/s]


epoch 20 | train_loss=0.356191 | train_acc = 0.883929
      | val_loss = 0.608669 | val_acc = 0.807292


100%|██████████| 2/2 [00:00<00:00,  4.70it/s]
100%|██████████| 21/21 [00:21<00:00,  1.04s/it]
100%|██████████| 3/3 [00:02<00:00,  1.14it/s]


epoch 21 | train_loss=0.314602 | train_acc = 0.891369
      | val_loss = 0.618480 | val_acc = 0.804688
EarlyStopping counter: 1 out of 3


100%|██████████| 2/2 [00:00<00:00,  4.84it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.14it/s]


epoch 22 | train_loss=0.314766 | train_acc = 0.897321
      | val_loss = 0.622236 | val_acc = 0.815104
EarlyStopping counter: 2 out of 3


100%|██████████| 2/2 [00:00<00:00,  4.87it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.14it/s]


epoch 23 | train_loss=0.300438 | train_acc = 0.901042
      | val_loss = 0.604636 | val_acc = 0.815104


100%|██████████| 2/2 [00:00<00:00,  4.69it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.13it/s]


epoch 24 | train_loss=0.274314 | train_acc = 0.911086
      | val_loss = 0.588472 | val_acc = 0.817708


100%|██████████| 2/2 [00:00<00:00,  4.72it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.14it/s]


epoch 25 | train_loss=0.254574 | train_acc = 0.920015
      | val_loss = 0.587172 | val_acc = 0.822917


100%|██████████| 2/2 [00:00<00:00,  4.53it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 26 | train_loss=0.242752 | train_acc = 0.921875
      | val_loss = 0.584723 | val_acc = 0.815104


100%|██████████| 2/2 [00:00<00:00,  5.01it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.13it/s]


epoch 27 | train_loss=0.226809 | train_acc = 0.929315
      | val_loss = 0.604173 | val_acc = 0.802083
EarlyStopping counter: 1 out of 3


100%|██████████| 2/2 [00:00<00:00,  4.64it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.13it/s]


epoch 28 | train_loss=0.221552 | train_acc = 0.927827
      | val_loss = 0.610894 | val_acc = 0.807292
EarlyStopping counter: 2 out of 3


100%|██████████| 2/2 [00:00<00:00,  4.60it/s]
100%|██████████| 21/21 [00:21<00:00,  1.04s/it]
100%|██████████| 3/3 [00:02<00:00,  1.13it/s]


epoch 29 | train_loss=0.211081 | train_acc = 0.939360
      | val_loss = 0.627829 | val_acc = 0.807292
EarlyStopping counter: 3 out of 3


100%|██████████| 3/3 [00:02<00:00,  1.14it/s]
100%|██████████| 3/3 [00:02<00:00,  1.13it/s]
100%|██████████| 3/3 [00:02<00:00,  1.14it/s]
100%|██████████| 3/3 [00:02<00:00,  1.14it/s]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]
100%|██████████| 3/3 [00:02<00:00,  1.12it/s]
100%|██████████| 3/3 [00:02<00:00,  1.14it/s]
100%|██████████| 3/3 [00:02<00:00,  1.12it/s]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]
100%|██████████| 5/5 [00:00<00:00,  5.98it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]


epoch 1 | train_loss=0.799271 | train_acc = 0.756324
      | val_loss = 0.503792 | val_acc = 0.830729


100%|██████████| 5/5 [00:00<00:00,  6.15it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]


epoch 2 | train_loss=0.673906 | train_acc = 0.785342
      | val_loss = 0.473403 | val_acc = 0.838542


100%|██████████| 5/5 [00:00<00:00,  6.21it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]


epoch 3 | train_loss=0.593554 | train_acc = 0.800967
      | val_loss = 0.453395 | val_acc = 0.856771


100%|██████████| 5/5 [00:00<00:00,  6.02it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]


epoch 4 | train_loss=0.498950 | train_acc = 0.834449
      | val_loss = 0.439061 | val_acc = 0.848958


100%|██████████| 5/5 [00:00<00:00,  6.14it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]


epoch 5 | train_loss=0.452719 | train_acc = 0.849702
      | val_loss = 0.454410 | val_acc = 0.848958
EarlyStopping counter: 1 out of 3


100%|██████████| 5/5 [00:00<00:00,  6.18it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.14it/s]


epoch 6 | train_loss=0.438029 | train_acc = 0.857887
      | val_loss = 0.443088 | val_acc = 0.856771
EarlyStopping counter: 2 out of 3


100%|██████████| 5/5 [00:00<00:00,  6.01it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.14it/s]


epoch 7 | train_loss=0.391921 | train_acc = 0.873140
      | val_loss = 0.443914 | val_acc = 0.848958
EarlyStopping counter: 3 out of 3


100%|██████████| 3/3 [00:02<00:00,  1.15it/s]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]
100%|██████████| 8/8 [00:01<00:00,  6.77it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]


epoch 1 | train_loss=0.728923 | train_acc = 0.765625
      | val_loss = 0.562130 | val_acc = 0.822917


100%|██████████| 8/8 [00:01<00:00,  6.82it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]


epoch 2 | train_loss=0.652125 | train_acc = 0.786830
      | val_loss = 0.538742 | val_acc = 0.830729


100%|██████████| 8/8 [00:01<00:00,  6.80it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]


epoch 3 | train_loss=0.562548 | train_acc = 0.815104
      | val_loss = 0.499562 | val_acc = 0.851562


100%|██████████| 8/8 [00:01<00:00,  6.69it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]


epoch 4 | train_loss=0.500833 | train_acc = 0.829613
      | val_loss = 0.482755 | val_acc = 0.856771


100%|██████████| 8/8 [00:01<00:00,  6.88it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]


epoch 5 | train_loss=0.452751 | train_acc = 0.851562
      | val_loss = 0.491164 | val_acc = 0.828125
EarlyStopping counter: 1 out of 3


100%|██████████| 8/8 [00:01<00:00,  6.85it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]


epoch 6 | train_loss=0.410638 | train_acc = 0.862723
      | val_loss = 0.480291 | val_acc = 0.846354


100%|██████████| 8/8 [00:01<00:00,  6.91it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 7 | train_loss=0.403989 | train_acc = 0.862723
      | val_loss = 0.474392 | val_acc = 0.841146


100%|██████████| 8/8 [00:01<00:00,  6.88it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 8 | train_loss=0.385235 | train_acc = 0.873884
      | val_loss = 0.480691 | val_acc = 0.854167
EarlyStopping counter: 1 out of 3


100%|██████████| 8/8 [00:01<00:00,  6.89it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 9 | train_loss=0.358670 | train_acc = 0.880952
      | val_loss = 0.470242 | val_acc = 0.843750


100%|██████████| 8/8 [00:01<00:00,  6.76it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 10 | train_loss=0.322256 | train_acc = 0.899182
      | val_loss = 0.478885 | val_acc = 0.851562
EarlyStopping counter: 1 out of 3


100%|██████████| 8/8 [00:01<00:00,  6.82it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]


epoch 11 | train_loss=0.314502 | train_acc = 0.895089
      | val_loss = 0.482765 | val_acc = 0.856771
EarlyStopping counter: 2 out of 3


100%|██████████| 8/8 [00:01<00:00,  6.94it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.19it/s]


epoch 12 | train_loss=0.303561 | train_acc = 0.899182
      | val_loss = 0.478104 | val_acc = 0.841146
EarlyStopping counter: 3 out of 3


100%|██████████| 3/3 [00:02<00:00,  1.16it/s]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]
100%|██████████| 3/3 [00:02<00:00,  1.19it/s]
100%|██████████| 10/10 [00:01<00:00,  7.11it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]


epoch 1 | train_loss=0.610887 | train_acc = 0.808408
      | val_loss = 0.397734 | val_acc = 0.885417


100%|██████████| 10/10 [00:01<00:00,  6.91it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.19it/s]


epoch 2 | train_loss=0.553142 | train_acc = 0.819568
      | val_loss = 0.396701 | val_acc = 0.888021


100%|██████████| 10/10 [00:01<00:00,  7.14it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]


epoch 3 | train_loss=0.494532 | train_acc = 0.834449
      | val_loss = 0.393937 | val_acc = 0.888021


100%|██████████| 10/10 [00:01<00:00,  7.06it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]


epoch 4 | train_loss=0.420367 | train_acc = 0.860491
      | val_loss = 0.407731 | val_acc = 0.890625
EarlyStopping counter: 1 out of 3


100%|██████████| 10/10 [00:01<00:00,  7.17it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 5 | train_loss=0.387798 | train_acc = 0.873512
      | val_loss = 0.403152 | val_acc = 0.875000
EarlyStopping counter: 2 out of 3


100%|██████████| 10/10 [00:01<00:00,  7.08it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.14it/s]


epoch 6 | train_loss=0.364613 | train_acc = 0.873884
      | val_loss = 0.376175 | val_acc = 0.888021


100%|██████████| 10/10 [00:01<00:00,  7.09it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]


epoch 7 | train_loss=0.337044 | train_acc = 0.886905
      | val_loss = 0.381824 | val_acc = 0.888021
EarlyStopping counter: 1 out of 3


100%|██████████| 10/10 [00:01<00:00,  7.10it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 8 | train_loss=0.323455 | train_acc = 0.890625
      | val_loss = 0.389877 | val_acc = 0.875000
EarlyStopping counter: 2 out of 3


100%|██████████| 10/10 [00:01<00:00,  7.20it/s]
100%|██████████| 21/21 [00:21<00:00,  1.00s/it]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]


epoch 9 | train_loss=0.280649 | train_acc = 0.910714
      | val_loss = 0.411411 | val_acc = 0.882812
EarlyStopping counter: 3 out of 3


100%|██████████| 3/3 [00:02<00:00,  1.16it/s]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]
100%|██████████| 3/3 [00:02<00:00,  1.20it/s]
100%|██████████| 3/3 [00:02<00:00,  1.19it/s]
100%|██████████| 3/3 [00:02<00:00,  1.19it/s]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]
100%|██████████| 3/3 [00:02<00:00,  1.19it/s]
100%|██████████| 13/13 [00:01<00:00,  7.27it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 1 | train_loss=0.659763 | train_acc = 0.795759
      | val_loss = 0.517609 | val_acc = 0.815104


100%|██████████| 13/13 [00:01<00:00,  7.34it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 2 | train_loss=0.541582 | train_acc = 0.820312
      | val_loss = 0.500043 | val_acc = 0.825521


100%|██████████| 13/13 [00:01<00:00,  7.30it/s]
100%|██████████| 21/21 [00:21<00:00,  1.04s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 3 | train_loss=0.498773 | train_acc = 0.832589
      | val_loss = 0.500947 | val_acc = 0.820312
EarlyStopping counter: 1 out of 3


100%|██████████| 13/13 [00:01<00:00,  7.34it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.19it/s]


epoch 4 | train_loss=0.436240 | train_acc = 0.850074
      | val_loss = 0.493592 | val_acc = 0.843750


100%|██████████| 13/13 [00:01<00:00,  7.41it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 5 | train_loss=0.402707 | train_acc = 0.862723
      | val_loss = 0.478466 | val_acc = 0.856771


100%|██████████| 13/13 [00:01<00:00,  7.32it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 6 | train_loss=0.359173 | train_acc = 0.877604
      | val_loss = 0.466295 | val_acc = 0.848958


100%|██████████| 13/13 [00:01<00:00,  7.28it/s]
100%|██████████| 21/21 [00:21<00:00,  1.00s/it]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]


epoch 7 | train_loss=0.347430 | train_acc = 0.881696
      | val_loss = 0.467664 | val_acc = 0.838542
EarlyStopping counter: 1 out of 3


100%|██████████| 13/13 [00:01<00:00,  7.35it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]


epoch 8 | train_loss=0.335013 | train_acc = 0.892485
      | val_loss = 0.491787 | val_acc = 0.851562
EarlyStopping counter: 2 out of 3


100%|██████████| 13/13 [00:01<00:00,  7.29it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]


epoch 9 | train_loss=0.300125 | train_acc = 0.896577
      | val_loss = 0.490507 | val_acc = 0.841146
EarlyStopping counter: 3 out of 3


100%|██████████| 3/3 [00:02<00:00,  1.16it/s]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]
100%|██████████| 3/3 [00:02<00:00,  1.19it/s]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]
100%|██████████| 3/3 [00:02<00:00,  1.19it/s]
100%|██████████| 16/16 [00:02<00:00,  7.51it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 1 | train_loss=0.600982 | train_acc = 0.804688
      | val_loss = 0.406953 | val_acc = 0.867188


100%|██████████| 16/16 [00:02<00:00,  7.52it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]


epoch 2 | train_loss=0.531450 | train_acc = 0.833333
      | val_loss = 0.378603 | val_acc = 0.869792


100%|██████████| 16/16 [00:02<00:00,  7.51it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 3 | train_loss=0.491297 | train_acc = 0.837054
      | val_loss = 0.392130 | val_acc = 0.875000
EarlyStopping counter: 1 out of 3


100%|██████████| 16/16 [00:02<00:00,  7.48it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 4 | train_loss=0.447748 | train_acc = 0.848586
      | val_loss = 0.384817 | val_acc = 0.867188
EarlyStopping counter: 2 out of 3


100%|██████████| 16/16 [00:02<00:00,  7.50it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 5 | train_loss=0.412297 | train_acc = 0.856027
      | val_loss = 0.410745 | val_acc = 0.864583
EarlyStopping counter: 3 out of 3


100%|██████████| 3/3 [00:02<00:00,  1.16it/s]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]
100%|██████████| 3/3 [00:02<00:00,  1.20it/s]
100%|██████████| 19/19 [00:02<00:00,  7.59it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]


epoch 1 | train_loss=0.617538 | train_acc = 0.794271
      | val_loss = 0.429393 | val_acc = 0.872396


100%|██████████| 19/19 [00:02<00:00,  7.59it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 2 | train_loss=0.562375 | train_acc = 0.815848
      | val_loss = 0.426685 | val_acc = 0.877604


100%|██████████| 19/19 [00:02<00:00,  7.58it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]


epoch 3 | train_loss=0.496890 | train_acc = 0.829241
      | val_loss = 0.424934 | val_acc = 0.867188


100%|██████████| 19/19 [00:02<00:00,  7.60it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]


epoch 4 | train_loss=0.448860 | train_acc = 0.850818
      | val_loss = 0.406538 | val_acc = 0.877604


100%|██████████| 19/19 [00:02<00:00,  7.62it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 5 | train_loss=0.412709 | train_acc = 0.864583
      | val_loss = 0.415749 | val_acc = 0.867188
EarlyStopping counter: 1 out of 3


100%|██████████| 19/19 [00:02<00:00,  7.58it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]


epoch 6 | train_loss=0.396672 | train_acc = 0.870536
      | val_loss = 0.424268 | val_acc = 0.856771
EarlyStopping counter: 2 out of 3


100%|██████████| 19/19 [00:02<00:00,  7.61it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 7 | train_loss=0.352202 | train_acc = 0.889137
      | val_loss = 0.437757 | val_acc = 0.848958
EarlyStopping counter: 3 out of 3


100%|██████████| 3/3 [00:02<00:00,  1.16it/s]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]
100%|██████████| 3/3 [00:02<00:00,  1.13it/s]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]
100%|██████████| 21/21 [00:02<00:00,  7.65it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]


epoch 1 | train_loss=0.593774 | train_acc = 0.800595
      | val_loss = 0.498365 | val_acc = 0.835938


100%|██████████| 21/21 [00:02<00:00,  7.62it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]


epoch 2 | train_loss=0.540901 | train_acc = 0.819940
      | val_loss = 0.474510 | val_acc = 0.854167


100%|██████████| 21/21 [00:02<00:00,  7.72it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]


epoch 3 | train_loss=0.493169 | train_acc = 0.833333
      | val_loss = 0.465571 | val_acc = 0.848958


100%|██████████| 21/21 [00:02<00:00,  7.65it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]


epoch 4 | train_loss=0.448136 | train_acc = 0.844122
      | val_loss = 0.460969 | val_acc = 0.851562


100%|██████████| 21/21 [00:02<00:00,  7.67it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]


epoch 5 | train_loss=0.403989 | train_acc = 0.860863
      | val_loss = 0.469246 | val_acc = 0.841146
EarlyStopping counter: 1 out of 3


100%|██████████| 21/21 [00:02<00:00,  7.70it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]


epoch 6 | train_loss=0.383625 | train_acc = 0.870536
      | val_loss = 0.466152 | val_acc = 0.848958
EarlyStopping counter: 2 out of 3


100%|██████████| 21/21 [00:02<00:00,  7.62it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]


epoch 7 | train_loss=0.367370 | train_acc = 0.873884
      | val_loss = 0.473005 | val_acc = 0.848958
EarlyStopping counter: 3 out of 3


100%|██████████| 3/3 [00:02<00:00,  1.13it/s]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]
100%|██████████| 3/3 [00:02<00:00,  1.20it/s]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]
100%|██████████| 3/3 [00:02<00:00,  1.14it/s]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]
100%|██████████| 24/24 [00:03<00:00,  7.79it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]


epoch 1 | train_loss=0.556688 | train_acc = 0.808780
      | val_loss = 0.449499 | val_acc = 0.867188


100%|██████████| 24/24 [00:03<00:00,  7.75it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]


epoch 2 | train_loss=0.475091 | train_acc = 0.833705
      | val_loss = 0.437753 | val_acc = 0.854167


100%|██████████| 24/24 [00:03<00:00,  7.74it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]


epoch 3 | train_loss=0.443278 | train_acc = 0.855283
      | val_loss = 0.422295 | val_acc = 0.877604


100%|██████████| 24/24 [00:03<00:00,  7.75it/s]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 4 | train_loss=0.407465 | train_acc = 0.861607
      | val_loss = 0.424264 | val_acc = 0.875000
EarlyStopping counter: 1 out of 3


100%|██████████| 24/24 [00:03<00:00,  7.72it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 5 | train_loss=0.362902 | train_acc = 0.876116
      | val_loss = 0.420659 | val_acc = 0.872396


100%|██████████| 24/24 [00:03<00:00,  7.73it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]


epoch 6 | train_loss=0.348068 | train_acc = 0.881696
      | val_loss = 0.418608 | val_acc = 0.872396


100%|██████████| 24/24 [00:03<00:00,  7.69it/s]
100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


epoch 7 | train_loss=0.324921 | train_acc = 0.895089
      | val_loss = 0.434652 | val_acc = 0.859375
EarlyStopping counter: 1 out of 3


100%|██████████| 24/24 [00:03<00:00,  7.76it/s]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 3/3 [00:02<00:00,  1.16it/s]


epoch 8 | train_loss=0.310876 | train_acc = 0.894717
      | val_loss = 0.437840 | val_acc = 0.861979
EarlyStopping counter: 2 out of 3


100%|██████████| 24/24 [00:03<00:00,  7.65it/s]
 76%|███████▌  | 16/21 [00:16<00:04,  1.17it/s]

In [None]:
# Save stuff

if start_sequential:
  with open(f"{my_drive_path}/accuracy_matrix-seq.pkl", "wb") as file:
      pickle.dump(seq_accuracy_matrix, file)

  with open(f"{my_drive_path}/validation_data-seq.pkl", "wb") as file:
      pickle.dump(all_seq_validation_data, file)

  with open(f"{my_drive_path}/train_data-seq.pkl", "wb") as file:
      pickle.dump(all_seq_train_data, file)