<a href="https://colab.research.google.com/github/yeungjosh/resnet-perturbations-riselab/blob/master/pytorch_resnet_mnist_frank_wolfe.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ResNet for MNIST in PyTorch

In [1]:
!pip install --upgrade --force-reinstall --quiet git+https://github.com/ZIB-IOL/StochasticFrankWolfe.git@arXiv-2010.07243v2
!pip install --quiet barbar

  Building wheel for frankwolfe-IOL (setup.py) ... [?25l[?25hdone


In [2]:
from torchvision.models.resnet import ResNet, BasicBlock
from torchvision.datasets import MNIST
from tqdm.autonotebook import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import inspect
import time

import torch

from torch import nn, optim
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from torch.utils.data import DataLoader
from copy import copy, deepcopy
import numpy as np

import frankwolfe.pytorch as fw


In [3]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [4]:
def save_last_model(input_model):
  model_save_name = 'resnet_frank_wolfe_last.pkl'
  path = F"/content/gdrive/My Drive/{model_save_name}" 
  torch.save(input_model, path)

In [5]:
class MnistResNet(ResNet):
    def __init__(self):
        super(MnistResNet, self).__init__(BasicBlock, [2, 2, 2, 2], num_classes=10)
        self.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        
    def forward(self, x):
        return torch.softmax(super(MnistResNet, self).forward(x), dim=-1)


In [15]:

def getData(name='cifar10', train_bs=128, test_bs=1000):    
    if name == 'svhn':
        train_loader = torch.utils.data.DataLoader(
    datasets.SVHN('../data', split='extra', download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor()
                   ])),
    batch_size=train_bs, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
    datasets.SVHN('../data', split='test', download=True,transform=transforms.Compose([
                       transforms.ToTensor()
                   ])),
    batch_size=test_bs, shuffle=False)
    if name == 'mnist':
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST('../data', train=True, download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=train_bs, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('../data', train=False, transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=test_bs, shuffle=False)
    if name == 'emnist':
        train_loader = torch.utils.data.DataLoader(
            datasets.EMNIST('../data', train=True, download=True, split='balanced',
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1751,), (0.3267,))
                           ])),
            batch_size=train_bs, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.EMNIST('../data', train=False, split='balanced', transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1751,), (0.3267,))
                           ])),
            batch_size=test_bs, shuffle=False)
    if name == 'cifar10':
        transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
        transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
        trainset = datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=train_bs, shuffle=True)
        testset = datasets.CIFAR10(root='../data', train=False, download=False, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=test_bs, shuffle=False)
    if name == 'cifar100':
        transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
        transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
        trainset = datasets.CIFAR100(root='../data', train=True, download=True, transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=train_bs, shuffle=True)
        testset = datasets.CIFAR100(root='../data', train=False, download=False, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=test_bs, shuffle=False)
    if name == 'tinyimagenet':      
        normalize = transforms.Normalize(mean=[0.44785526394844055, 0.41693055629730225, 0.36942949891090393],
                                     std=[0.2928885519504547, 0.28230994939804077, 0.2889912724494934])
        train_dataset = datasets.ImageFolder(
        '../data/tiny-imagenet-200/train',
        transforms.Compose([
            transforms.RandomCrop(64, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=train_bs, shuffle=True, num_workers=4, pin_memory=False)
        test_dataset = datasets.ImageFolder(
        '../data/tiny-imagenet-200/val',
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_bs, shuffle=False)
    return train_loader, test_loader

In [16]:
def calculate_metric(metric_fn, true_y, pred_y):
    if "average" in inspect.getfullargspec(metric_fn).args:
        return metric_fn(true_y, pred_y, average="macro")
    else:
        return metric_fn(true_y, pred_y)
    
def print_scores(p, r, f1, a, batch_size):
    for name, scores in zip(("precision", "recall", "F1", "accuracy"), (p, r, f1, a)):
        print(f"\t{name.rjust(14, ' ')}: {sum(scores)/batch_size:.4f}")

In [17]:
class RetractionLR(torch.optim.lr_scheduler._LRScheduler):
    """
    Retracts the learning rate as follows. Two running averages are kept, one of length n_close, one of n_far. Adjust
    the learning_rate depending on the relation of far_average and close_average. Decrease by 1-retraction_factor.
    Increase by 1/(1 - retraction_factor*growth_factor)
    """
    def __init__(self, optimizer, retraction_factor=0.3, n_close=5, n_far=10, lowerBound=1e-5, upperBound=1, growth_factor=0.2, last_epoch=-1):
        self.retraction_factor = retraction_factor
        self.n_close = n_close
        self.n_far = n_far
        self.lowerBound = lowerBound
        self.upperBound = upperBound
        self.growth_factor = growth_factor

        assert (0 <= self.retraction_factor < 1), "Retraction factor must be in [0, 1[."
        assert (0 <= self.lowerBound < self.upperBound <= 1), "Bounds must be in [0, 1]"
        assert (0 < self.growth_factor <= 1), "Growth factor must be in ]0, 1]"

        self.closeAverage = RunningAverage(self.n_close)
        self.farAverage = RunningAverage(self.n_far)

        super(RetractionLR, self).__init__(optimizer, last_epoch)

    def update_averages(self, loss):
        self.closeAverage(loss)
        self.farAverage(loss)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        factor = 1
        if self.farAverage.is_complete() and self.closeAverage.is_complete():
            if self.closeAverage.result() > self.farAverage.result():
                # Decrease the learning rate
                factor = 1 - self.retraction_factor
            elif self.farAverage.result() > self.closeAverage.result():
                # Increase the learning rate
                factor = 1./(1 - self.retraction_factor*self.growth_factor)

        return [max(self.lowerBound, min(factor * group['lr'], self.upperBound)) for group in self.optimizer.param_groups]

class RunningAverage(object):
    """Tracks the running average of n numbers"""
    def __init__(self, n):
        self.n = n
        self.reset()

    def reset(self):
        self.sum = 0
        self.avg = 0
        self.entries = []

    def result(self):
        return self.avg

    def get_count(self):
        return len(self.entries)

    def is_complete(self):
        return len(self.entries) == self.n

    def __call__(self, val):
        if len(self.entries) == self.n:
            l = self.entries.pop(0)
            self.sum -= l
        self.entries.append(val)
        self.sum += val
        self.avg = self.sum / len(self.entries)

    def __str__(self):
        return str(self.avg)

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0
        self.count = 0
        self.avg = 0

    def result(self):
        return self.avg

    def __call__(self, val, n=1):
        """val is an average over n samples. To compute the overall average, add val*n to sum and increase count by n"""
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        return str(self.avg)
        

In [18]:
# Init model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = MnistResNet().to(device)

In [19]:
#@title Choosing Lp-Norm constraints
#@markdown The following cell allows you to set Lp-norm constraints for the chosen network. For exact parameters both for the constraints and the optimizer see the last cell of this notebook.
ord =  "2" #@param [1, 2, 5, 'inf']
ord = float(ord)
value = 10 #@param {type:"number"}
mode = 'initialization' #@param ['initialization', 'radius', 'diameter']

assert value > 0

# Select constraints
constraints = fw.constraints.create_lp_constraints(model, ord=ord, value=value, mode=mode)

In [20]:
#@title Configuring the Frank-Wolfe Algorithm
#@markdown Choose momentum and learning rate rescaling, see Section 3.1 of [arXiv:2010.07243](https://arxiv.org/pdf/2010.07243.pdf).
momentum = 0.9 #@param {type:"number"}
rescale = 'gradient' #@param ['gradient', 'diameter', 'None']
rescale = None if rescale == 'None' else rescale

#@markdown Choose a learning rate for SFW. You can activate the learning rate scheduler which automatically multiplies the current learning rate by `lr_decrease_factor` every `lr_step_size epochs`
learning_rate = 0.1 #@param {type:"number"}
lr_scheduler_active = True #@param {type:"boolean"}
lr_decrease_factor = 0.1 #@param {type:"number"}
lr_step_size = 30 #@param {type:"integer"}

#@markdown You can also enable retraction of the learning rate, i.e., if enabled the learning rate is increased and decreased automatically depending on the two moving averages of different length of the train loss over the epochs.
retraction = True #@param {type:"boolean"}

assert learning_rate > 0
assert 0 <= momentum <= 1
assert lr_decrease_factor > 0
assert lr_step_size > 0


# Select optimizer
optimizer = fw.optimizers.SFW(params=model.parameters(), learning_rate=learning_rate, momentum=momentum, rescale=rescale)

In [23]:
from six.moves import urllib    
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)

In [24]:
start_ts = time.time()


epochs = 70

# train_loader, val_loader = get_data_loaders(256, 256)
train_loader, val_loader = getData(name='mnist', train_bs=128, test_bs=1000)


losses = []
loss_function = nn.CrossEntropyLoss()

# f_w
# initialize some necessary metrics objects
train_loss, train_accuracy = AverageMeter(), AverageMeter()
test_loss, test_accuracy = AverageMeter(), AverageMeter()

if lr_scheduler_active:
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=lr_step_size, gamma=lr_decrease_factor)

if retraction:
    retractionScheduler = RetractionLR(optimizer=optimizer)

# function to reset metrics
def reset_metrics():
    train_loss.reset()
    train_accuracy.reset()

    test_loss.reset()
    test_accuracy.reset()

batches = len(train_loader)
val_batches = len(val_loader)
# keep best model
accuracies=[]
best_accuracy = 0
best_model = deepcopy(model)

# training loop + eval loop
for epoch in range(epochs):
    total_loss = 0
    progress = tqdm(enumerate(train_loader), desc="Loss: ", total=batches)
    model.train()
    # # lr decay
    # optimizer = exp_lr_scheduler(epoch, optimizer, decay_eff=0.1, decayEpoch=[15])

    for i, data in progress:
        X, y = data[0].to(device), data[1].to(device)
        
        model.zero_grad()
        outputs = model(X)
        loss = loss_function(outputs, y)

        loss.backward()
        optimizer.step(constraints=constraints)
        current_loss = loss.item()
        total_loss += current_loss
        progress.set_description("Loss: {:.4f}".format(total_loss/(i+1)))

        # train_loss(loss.item(), len(y))
        # train_accuracy(Utilities.categorical_accuracy(y_true=y_target, output=output), len(y_target))

    if lr_scheduler_active:
        scheduler.step()
    # if retraction:
    #     # Learning rate retraction
    #     retractionScheduler.update_averages(train_loss.result())
    #     retractionScheduler.step()    
    torch.cuda.empty_cache()
    
    val_losses = 0
    precision, recall, f1, accuracy = [], [], [], []
    
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            X, y = data[0].to(device), data[1].to(device)
            outputs = model(X)
            val_losses += loss_function(outputs, y)

            predicted_classes = torch.max(outputs, 1)[1]
            
            for acc, metric in zip((precision, recall, f1, accuracy), 
                                   (precision_score, recall_score, f1_score, accuracy_score)):
                acc.append(
                    calculate_metric(metric, y.cpu(), predicted_classes.cpu())
                )
    
    current_model_accuracy = sum(accuracy)/val_batches
    accuracies.append(current_model_accuracy)
    if current_model_accuracy > best_accuracy:
        best_model = deepcopy(model)
        best_accuracy=current_model_accuracy
        
    print(f"Epoch {epoch+1}/{epochs}, training loss: {total_loss/batches}, validation loss: {val_losses/val_batches}")
    print_scores(precision, recall, f1, accuracy, val_batches)
    losses.append(total_loss/batches)
    print('current_model_accuracy: ',current_model_accuracy)
    print('best_accuracy: ',best_accuracy)

save_last_model(model)
model_save_name = 'resnet_mnist_frank_wolfe_L2.pkl'
path = F"/content/gdrive/My Drive/{model_save_name}" 
torch.save(best_model, path)

print(losses)
print(f"Training time: {time.time()-start_ts}s")

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw
Processing...
Done!



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 1/70, training loss: 1.510880936691756, validation loss: 1.483991026878357
	     precision: 0.9791
	        recall: 0.9788
	            F1: 0.9787
	      accuracy: 0.9789
current_model_accuracy:  0.9789
best_accuracy:  0.9789


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 2/70, training loss: 1.4806050934008699, validation loss: 1.478159785270691
	     precision: 0.9837
	        recall: 0.9832
	            F1: 0.9833
	      accuracy: 0.9834
current_model_accuracy:  0.9833999999999999
best_accuracy:  0.9833999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 3/70, training loss: 1.477046595707631, validation loss: 1.4797176122665405
	     precision: 0.9827
	        recall: 0.9819
	            F1: 0.9821
	      accuracy: 0.9819
current_model_accuracy:  0.9819000000000001
best_accuracy:  0.9833999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 4/70, training loss: 1.4736327855571756, validation loss: 1.4761543273925781
	     precision: 0.9854
	        recall: 0.9851
	            F1: 0.9850
	      accuracy: 0.9851
current_model_accuracy:  0.9850999999999999
best_accuracy:  0.9850999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 5/70, training loss: 1.4720042898202501, validation loss: 1.474468469619751
	     precision: 0.9874
	        recall: 0.9873
	            F1: 0.9872
	      accuracy: 0.9873
current_model_accuracy:  0.9873000000000001
best_accuracy:  0.9873000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 6/70, training loss: 1.4715293981373183, validation loss: 1.4720357656478882
	     precision: 0.9895
	        recall: 0.9891
	            F1: 0.9892
	      accuracy: 0.9892
current_model_accuracy:  0.9892000000000001
best_accuracy:  0.9892000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 7/70, training loss: 1.4713689790351558, validation loss: 1.4728339910507202
	     precision: 0.9885
	        recall: 0.9881
	            F1: 0.9881
	      accuracy: 0.9883
current_model_accuracy:  0.9883
best_accuracy:  0.9892000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 8/70, training loss: 1.4698607265822161, validation loss: 1.4707252979278564
	     precision: 0.9910
	        recall: 0.9910
	            F1: 0.9909
	      accuracy: 0.9910
current_model_accuracy:  0.991
best_accuracy:  0.991


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 9/70, training loss: 1.468794769569755, validation loss: 1.4706085920333862
	     precision: 0.9908
	        recall: 0.9908
	            F1: 0.9907
	      accuracy: 0.9907
current_model_accuracy:  0.9907
best_accuracy:  0.991


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 10/70, training loss: 1.4681079758764075, validation loss: 1.4811720848083496
	     precision: 0.9810
	        recall: 0.9804
	            F1: 0.9800
	      accuracy: 0.9804
current_model_accuracy:  0.9804
best_accuracy:  0.991


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 11/70, training loss: 1.4685766666428621, validation loss: 1.4705262184143066
	     precision: 0.9907
	        recall: 0.9906
	            F1: 0.9906
	      accuracy: 0.9906
current_model_accuracy:  0.9906
best_accuracy:  0.991


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 12/70, training loss: 1.4673032420022147, validation loss: 1.4712214469909668
	     precision: 0.9900
	        recall: 0.9900
	            F1: 0.9899
	      accuracy: 0.9898
current_model_accuracy:  0.9898000000000001
best_accuracy:  0.991


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 13/70, training loss: 1.4670141259236122, validation loss: 1.4682978391647339
	     precision: 0.9933
	        recall: 0.9933
	            F1: 0.9932
	      accuracy: 0.9933
current_model_accuracy:  0.9933
best_accuracy:  0.9933


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 14/70, training loss: 1.4671001828301435, validation loss: 1.475683569908142
	     precision: 0.9862
	        recall: 0.9858
	            F1: 0.9858
	      accuracy: 0.9859
current_model_accuracy:  0.9859
best_accuracy:  0.9933


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 15/70, training loss: 1.4678839760294347, validation loss: 1.4697831869125366
	     precision: 0.9920
	        recall: 0.9918
	            F1: 0.9919
	      accuracy: 0.9918
current_model_accuracy:  0.9917999999999999
best_accuracy:  0.9933


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 16/70, training loss: 1.467451391443769, validation loss: 1.4707502126693726
	     precision: 0.9904
	        recall: 0.9903
	            F1: 0.9903
	      accuracy: 0.9903
current_model_accuracy:  0.9903000000000002
best_accuracy:  0.9933


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 17/70, training loss: 1.4675382914573654, validation loss: 1.4694408178329468
	     precision: 0.9923
	        recall: 0.9921
	            F1: 0.9921
	      accuracy: 0.9922
current_model_accuracy:  0.9921999999999999
best_accuracy:  0.9933


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 18/70, training loss: 1.4675519659575114, validation loss: 1.4703984260559082
	     precision: 0.9912
	        recall: 0.9909
	            F1: 0.9910
	      accuracy: 0.9910
current_model_accuracy:  0.991
best_accuracy:  0.9933


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 19/70, training loss: 1.4666428037289618, validation loss: 1.4691704511642456
	     precision: 0.9925
	        recall: 0.9926
	            F1: 0.9925
	      accuracy: 0.9925
current_model_accuracy:  0.9924999999999999
best_accuracy:  0.9933


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 20/70, training loss: 1.4661514380339111, validation loss: 1.4729341268539429
	     precision: 0.9893
	        recall: 0.9889
	            F1: 0.9889
	      accuracy: 0.9889
current_model_accuracy:  0.9888999999999999
best_accuracy:  0.9933


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 21/70, training loss: 1.4676608985929347, validation loss: 1.4710664749145508
	     precision: 0.9905
	        recall: 0.9903
	            F1: 0.9904
	      accuracy: 0.9903
current_model_accuracy:  0.9902999999999998
best_accuracy:  0.9933


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 22/70, training loss: 1.466084392594376, validation loss: 1.470682144165039
	     precision: 0.9911
	        recall: 0.9908
	            F1: 0.9909
	      accuracy: 0.9909
current_model_accuracy:  0.9909000000000001
best_accuracy:  0.9933


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 23/70, training loss: 1.4661129481756865, validation loss: 1.4701919555664062
	     precision: 0.9914
	        recall: 0.9911
	            F1: 0.9912
	      accuracy: 0.9912
current_model_accuracy:  0.9912000000000001
best_accuracy:  0.9933


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 24/70, training loss: 1.466041695842865, validation loss: 1.4696604013442993
	     precision: 0.9916
	        recall: 0.9916
	            F1: 0.9915
	      accuracy: 0.9916
current_model_accuracy:  0.9916
best_accuracy:  0.9933


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 25/70, training loss: 1.4667344825354212, validation loss: 1.4692342281341553
	     precision: 0.9921
	        recall: 0.9920
	            F1: 0.9920
	      accuracy: 0.9921
current_model_accuracy:  0.9921
best_accuracy:  0.9933


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 26/70, training loss: 1.46620781335241, validation loss: 1.4698046445846558
	     precision: 0.9914
	        recall: 0.9913
	            F1: 0.9913
	      accuracy: 0.9913
current_model_accuracy:  0.9913000000000001
best_accuracy:  0.9933


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 27/70, training loss: 1.4662178531130239, validation loss: 1.470506191253662
	     precision: 0.9911
	        recall: 0.9910
	            F1: 0.9910
	      accuracy: 0.9911
current_model_accuracy:  0.9911000000000001
best_accuracy:  0.9933


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 28/70, training loss: 1.4654675292562065, validation loss: 1.4696149826049805
	     precision: 0.9920
	        recall: 0.9919
	            F1: 0.9919
	      accuracy: 0.9919
current_model_accuracy:  0.9919
best_accuracy:  0.9933


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 29/70, training loss: 1.467111259889501, validation loss: 1.4706517457962036
	     precision: 0.9911
	        recall: 0.9911
	            F1: 0.9910
	      accuracy: 0.9910
current_model_accuracy:  0.991
best_accuracy:  0.9933


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 30/70, training loss: 1.466518433363453, validation loss: 1.4755173921585083
	     precision: 0.9860
	        recall: 0.9860
	            F1: 0.9858
	      accuracy: 0.9860
current_model_accuracy:  0.986
best_accuracy:  0.9933


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 31/70, training loss: 1.4643576841618715, validation loss: 1.4670482873916626
	     precision: 0.9944
	        recall: 0.9943
	            F1: 0.9943
	      accuracy: 0.9943
current_model_accuracy:  0.9943000000000002
best_accuracy:  0.9943000000000002


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 32/70, training loss: 1.4630393399866914, validation loss: 1.4667377471923828
	     precision: 0.9945
	        recall: 0.9942
	            F1: 0.9943
	      accuracy: 0.9943
current_model_accuracy:  0.9943
best_accuracy:  0.9943000000000002


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 33/70, training loss: 1.4626762358618697, validation loss: 1.4665720462799072
	     precision: 0.9948
	        recall: 0.9946
	            F1: 0.9947
	      accuracy: 0.9947
current_model_accuracy:  0.9947000000000001
best_accuracy:  0.9947000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 34/70, training loss: 1.462537113029057, validation loss: 1.4664216041564941
	     precision: 0.9948
	        recall: 0.9947
	            F1: 0.9947
	      accuracy: 0.9947
current_model_accuracy:  0.9947000000000001
best_accuracy:  0.9947000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 35/70, training loss: 1.462368478144664, validation loss: 1.4662379026412964
	     precision: 0.9951
	        recall: 0.9950
	            F1: 0.9950
	      accuracy: 0.9950
current_model_accuracy:  0.9949999999999999
best_accuracy:  0.9949999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 36/70, training loss: 1.4623686934330824, validation loss: 1.4662574529647827
	     precision: 0.9951
	        recall: 0.9950
	            F1: 0.9950
	      accuracy: 0.9950
current_model_accuracy:  0.9949999999999999
best_accuracy:  0.9949999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 37/70, training loss: 1.4622848692224986, validation loss: 1.4664751291275024
	     precision: 0.9949
	        recall: 0.9948
	            F1: 0.9948
	      accuracy: 0.9948
current_model_accuracy:  0.9948
best_accuracy:  0.9949999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 38/70, training loss: 1.4622223623780046, validation loss: 1.466382622718811
	     precision: 0.9949
	        recall: 0.9948
	            F1: 0.9948
	      accuracy: 0.9948
current_model_accuracy:  0.9948
best_accuracy:  0.9949999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 39/70, training loss: 1.462124056907605, validation loss: 1.4662939310073853
	     precision: 0.9951
	        recall: 0.9950
	            F1: 0.9950
	      accuracy: 0.9950
current_model_accuracy:  0.9950000000000001
best_accuracy:  0.9950000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 40/70, training loss: 1.4621377837683347, validation loss: 1.46611487865448
	     precision: 0.9952
	        recall: 0.9951
	            F1: 0.9951
	      accuracy: 0.9951
current_model_accuracy:  0.9951000000000001
best_accuracy:  0.9951000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 41/70, training loss: 1.4620820728700552, validation loss: 1.466281533241272
	     precision: 0.9951
	        recall: 0.9950
	            F1: 0.9950
	      accuracy: 0.9950
current_model_accuracy:  0.9949999999999999
best_accuracy:  0.9951000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 42/70, training loss: 1.4620740502627927, validation loss: 1.4662268161773682
	     precision: 0.9951
	        recall: 0.9950
	            F1: 0.9950
	      accuracy: 0.9950
current_model_accuracy:  0.9949999999999999
best_accuracy:  0.9951000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 43/70, training loss: 1.4619926198967483, validation loss: 1.466268539428711
	     precision: 0.9950
	        recall: 0.9949
	            F1: 0.9949
	      accuracy: 0.9949
current_model_accuracy:  0.9949
best_accuracy:  0.9951000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 44/70, training loss: 1.461993266778714, validation loss: 1.466201663017273
	     precision: 0.9951
	        recall: 0.9950
	            F1: 0.9950
	      accuracy: 0.9950
current_model_accuracy:  0.9949999999999999
best_accuracy:  0.9951000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 45/70, training loss: 1.461983070444705, validation loss: 1.4661558866500854
	     precision: 0.9952
	        recall: 0.9951
	            F1: 0.9951
	      accuracy: 0.9951
current_model_accuracy:  0.9950999999999999
best_accuracy:  0.9951000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 46/70, training loss: 1.4619309231162325, validation loss: 1.4661221504211426
	     precision: 0.9952
	        recall: 0.9951
	            F1: 0.9951
	      accuracy: 0.9951
current_model_accuracy:  0.9951000000000001
best_accuracy:  0.9951000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 47/70, training loss: 1.4619331532704043, validation loss: 1.466295838356018
	     precision: 0.9949
	        recall: 0.9948
	            F1: 0.9949
	      accuracy: 0.9948
current_model_accuracy:  0.9948
best_accuracy:  0.9951000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 48/70, training loss: 1.4619059102621668, validation loss: 1.4661879539489746
	     precision: 0.9951
	        recall: 0.9950
	            F1: 0.9950
	      accuracy: 0.9950
current_model_accuracy:  0.9950000000000001
best_accuracy:  0.9951000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 49/70, training loss: 1.461892848838367, validation loss: 1.4662691354751587
	     precision: 0.9951
	        recall: 0.9950
	            F1: 0.9950
	      accuracy: 0.9950
current_model_accuracy:  0.9950000000000001
best_accuracy:  0.9951000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 50/70, training loss: 1.4618991120283538, validation loss: 1.4662799835205078
	     precision: 0.9950
	        recall: 0.9949
	            F1: 0.9949
	      accuracy: 0.9949
current_model_accuracy:  0.9948999999999998
best_accuracy:  0.9951000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 51/70, training loss: 1.4619137473197887, validation loss: 1.4663432836532593
	     precision: 0.9949
	        recall: 0.9948
	            F1: 0.9948
	      accuracy: 0.9948
current_model_accuracy:  0.9948
best_accuracy:  0.9951000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 52/70, training loss: 1.4618321893565944, validation loss: 1.4662762880325317
	     precision: 0.9950
	        recall: 0.9949
	            F1: 0.9949
	      accuracy: 0.9949
current_model_accuracy:  0.9949
best_accuracy:  0.9951000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 53/70, training loss: 1.4618565611747791, validation loss: 1.4663065671920776
	     precision: 0.9949
	        recall: 0.9948
	            F1: 0.9948
	      accuracy: 0.9948
current_model_accuracy:  0.9947999999999999
best_accuracy:  0.9951000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 54/70, training loss: 1.461811296721257, validation loss: 1.4662706851959229
	     precision: 0.9950
	        recall: 0.9949
	            F1: 0.9949
	      accuracy: 0.9949
current_model_accuracy:  0.9949
best_accuracy:  0.9951000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 55/70, training loss: 1.4618314435995463, validation loss: 1.4662669897079468
	     precision: 0.9949
	        recall: 0.9948
	            F1: 0.9948
	      accuracy: 0.9948
current_model_accuracy:  0.9947999999999999
best_accuracy:  0.9951000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 56/70, training loss: 1.4618610649474904, validation loss: 1.466225028038025
	     precision: 0.9951
	        recall: 0.9950
	            F1: 0.9950
	      accuracy: 0.9950
current_model_accuracy:  0.9949999999999999
best_accuracy:  0.9951000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 57/70, training loss: 1.461821335973516, validation loss: 1.4661425352096558
	     precision: 0.9952
	        recall: 0.9951
	            F1: 0.9951
	      accuracy: 0.9951
current_model_accuracy:  0.9950999999999999
best_accuracy:  0.9951000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 58/70, training loss: 1.4617944099247329, validation loss: 1.466241717338562
	     precision: 0.9950
	        recall: 0.9949
	            F1: 0.9949
	      accuracy: 0.9949
current_model_accuracy:  0.9949
best_accuracy:  0.9951000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 59/70, training loss: 1.461799789085063, validation loss: 1.4662257432937622
	     precision: 0.9951
	        recall: 0.9950
	            F1: 0.9950
	      accuracy: 0.9950
current_model_accuracy:  0.9949999999999999
best_accuracy:  0.9951000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 60/70, training loss: 1.461806454129819, validation loss: 1.466125726699829
	     precision: 0.9952
	        recall: 0.9951
	            F1: 0.9951
	      accuracy: 0.9951
current_model_accuracy:  0.9951000000000001
best_accuracy:  0.9951000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 61/70, training loss: 1.4617931903806576, validation loss: 1.466063380241394
	     precision: 0.9950
	        recall: 0.9949
	            F1: 0.9949
	      accuracy: 0.9949
current_model_accuracy:  0.9949
best_accuracy:  0.9951000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 62/70, training loss: 1.4618220781720777, validation loss: 1.4661449193954468
	     precision: 0.9954
	        recall: 0.9952
	            F1: 0.9953
	      accuracy: 0.9952
current_model_accuracy:  0.9952
best_accuracy:  0.9952


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 63/70, training loss: 1.461792488342155, validation loss: 1.4662443399429321
	     precision: 0.9952
	        recall: 0.9952
	            F1: 0.9952
	      accuracy: 0.9951
current_model_accuracy:  0.9951000000000001
best_accuracy:  0.9952


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 64/70, training loss: 1.4617765295480105, validation loss: 1.4660589694976807
	     precision: 0.9949
	        recall: 0.9948
	            F1: 0.9948
	      accuracy: 0.9948
current_model_accuracy:  0.9948
best_accuracy:  0.9952


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 65/70, training loss: 1.4617892440194007, validation loss: 1.4661086797714233
	     precision: 0.9953
	        recall: 0.9953
	            F1: 0.9953
	      accuracy: 0.9952
current_model_accuracy:  0.9952
best_accuracy:  0.9952


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 66/70, training loss: 1.4617710001687252, validation loss: 1.4660080671310425
	     precision: 0.9952
	        recall: 0.9951
	            F1: 0.9951
	      accuracy: 0.9951
current_model_accuracy:  0.9951000000000001
best_accuracy:  0.9952


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 67/70, training loss: 1.461779373795239, validation loss: 1.466129183769226
	     precision: 0.9951
	        recall: 0.9950
	            F1: 0.9950
	      accuracy: 0.9950
current_model_accuracy:  0.9950000000000001
best_accuracy:  0.9952


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 68/70, training loss: 1.4617833455742548, validation loss: 1.466232180595398
	     precision: 0.9951
	        recall: 0.9950
	            F1: 0.9950
	      accuracy: 0.9950
current_model_accuracy:  0.9950000000000001
best_accuracy:  0.9952


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 69/70, training loss: 1.4617799911926042, validation loss: 1.4661777019500732
	     precision: 0.9951
	        recall: 0.9950
	            F1: 0.9951
	      accuracy: 0.9950
current_model_accuracy:  0.9949999999999999
best_accuracy:  0.9952


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 70/70, training loss: 1.4618027377992804, validation loss: 1.4661661386489868
	     precision: 0.9951
	        recall: 0.9950
	            F1: 0.9951
	      accuracy: 0.9950
current_model_accuracy:  0.9949999999999999
best_accuracy:  0.9952
[1.510880936691756, 1.4806050934008699, 1.477046595707631, 1.4736327855571756, 1.4720042898202501, 1.4715293981373183, 1.4713689790351558, 1.4698607265822161, 1.468794769569755, 1.4681079758764075, 1.4685766666428621, 1.4673032420022147, 1.4670141259236122, 1.4671001828301435, 1.4678839760294347, 1.467451391443769, 1.4675382914573654, 1.4675519659575114, 1.4666428037289618, 1.4661514380339111, 1.4676608985929347, 1.466084392594376, 1.4661129481756865, 1.466041695842865, 1.4667344825354212, 1.46620781335241, 1.4662178531130239, 1.4654675292562065, 1.467111259889501, 1.466518433363453, 1.4643576841618715, 1.4630393399866914, 1.4626762358618697, 1.462537113029057, 1.462368478144664, 1.4623686934330824, 1.4622848692224986, 1.4622223623780046, 1.4

In [None]:
print('best_accuracy: ',best_accuracy)