In [1]:
#!/usr/bin/env python
# coding: utf-8

import os

import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from torch import nn
from torch.nn import functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST, FashionMNIST

from argparse import ArgumentParser

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import Trainer
from collections import OrderedDict
from torch import optim
import numpy as np
from dataclasses import dataclass
from torchvision.models.resnet import conv1x1, conv3x3


def reduce_mean_dicts(list_dicts, key):
    if not list_dicts:
        return 0.

    s = sum([d[key] for d in list_dicts], 0.)
    return s / len(list_dicts)


@dataclass
class Params:
    batch_size_train: int
    batch_size_val: int
    path_data: str
    learning_rate: float

In [2]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class ResNetPart(nn.Module):

    def __init__(self, zero_init_residual=False):
        super(ResNetPart, self).__init__()
        
        self._norm_layer = nn.BatchNorm2d

        self.inplanes = 64
        self.dilation = 1
        self.groups = 1

        self.layer1 = self._make_layer(BasicBlock, 32, 1)
        
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, 1, 
                            64, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        x = self.layer1(x)
        return x

    def forward(self, x):
        return self._forward_impl(x)

In [3]:
ResNetPart()(torch.randn(1, 64, 10, 10)).size()

torch.Size([1, 32, 10, 10])

In [4]:
class MiniVGGDeeper2WoDropout(pl.LightningModule):

    def __init__(self, hparams: Params):
        super().__init__()

        self.hparams = hparams
        self.model = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            
            nn.MaxPool2d(2, ),
#             nn.Dropout(0.25),

            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            
            ResNetPart(),
            
#             nn.Dropout(0.25),

            nn.Flatten(1),
            nn.Linear(14 * 14 * 32, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512, ),
            nn.Dropout(0.5),

            nn.Linear(512, 10),
        )
        self.loss = F.cross_entropy
        self.train_dataset = None
        self.val_dataset = None
        
        self.tfms_common = [
            transforms.ToTensor(),
            transforms.Normalize((0.285,), (.3523*2,))
        ]
        self.tfms_train = [
            transforms.RandomRotation(7., fill=(0,)),
            transforms.RandomCrop(28, padding=3),
        ]
        

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
         
        loss_val = self.loss(y_hat, y, )

        
        tqdm_dict = {'train_loss': loss_val,
                    }
        
        if batch_idx % 100 == 0:
            tqdm_dict['lr'] = np.array(self.trainer.lr_schedulers[0]['scheduler'].get_lr())
            
        return OrderedDict({
            'loss': loss_val,
            'log': tqdm_dict
        })

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)

        loss_val = self.loss(y_hat, y, )

        # acc
        labels_hat = torch.argmax(y_hat, dim=1)
        val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
        val_acc = torch.tensor(val_acc)

        if self.on_gpu:
            val_acc = val_acc.cuda(loss_val.device.index)

        output = OrderedDict({
            'val_loss': loss_val,
            'val_acc': val_acc,
        })

        return output

    def validation_epoch_end(self, outputs):
        loss_mean = reduce_mean_dicts(outputs, 'val_loss')
        acc_mean = reduce_mean_dicts(outputs, 'val_acc')
        tqdm_dict = {
            'val_loss': loss_mean,
            'val_acc': acc_mean
        }
        result = {'progress_bar': tqdm_dict, 'log': tqdm_dict, 'val_loss': loss_mean}
        return result

    def prepare_data(self):

        mnist_train = FashionMNIST(self.hparams.path_data, train=True, download=True,
                                   transform=transforms.Compose(self.tfms_train + self.tfms_common)
                                  )
        mnist_test = FashionMNIST(self.hparams.path_data, train=False, download=True,
                                  transform=transforms.Compose(self.tfms_common)
                                  )

        #         mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])

        self.train_dataset = mnist_train
        self.val_dataset = mnist_test

    def train_dataloader(self):
        if self.train_dataset is None:
            self.prepare_data()
        return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size_train, num_workers=0)

    def val_dataloader(self):
        if self.val_dataset is None:
            self.prepare_data()
        return DataLoader(self.val_dataset, batch_size=self.hparams.batch_size_val, num_workers=0)

    def test_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.hparams.batch_size_val, num_workers=0)

    def configure_optimizers(self):
        """
        Return whatever optimizers and learning rate schedulers you want here.
        At least one optimizer is required.
        """
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=self.hparams.learning_rate, 
                                                  steps_per_epoch=1, 
                                                  epochs=self.trainer.max_epochs
                                                  )
        return [optimizer], [scheduler]

    @staticmethod
    def add_model_specific_args(parent_parser):  # pragma: no-cover
        parser = ArgumentParser(parents=[parent_parser])

        parser.add_argument('--learning_rate', default=0.01, type=float)
        parser.add_argument('--batch_size_train', default=64, type=int)
        parser.add_argument('--batch_size_val', default=32, type=int)
        parser.add_argument('--path_data', default="./data", type=str)

        # training params (opt)
        parser.add_argument('--epochs', default=20, type=int)
        return parser

    def test_step(self, batch, batch_idx):
        """
        Lightning calls this during testing, similar to `validation_step`,
        with the data from the test dataloader passed in as `batch`.
        """
        output = self.validation_step(batch, batch_idx)
        # Rename output keys
        output['test_loss'] = output.pop('val_loss')
        output['test_acc'] = output.pop('val_acc')
        return output

    def test_epoch_end(self, outputs):
        loss_mean = reduce_mean_dicts(outputs, 'test_loss')
        acc_mean = reduce_mean_dicts(outputs, 'test_acc')
        tqdm_dict = {
            'test_loss': loss_mean,
            'test_acc': acc_mean
        }
        result = {'progress_bar': tqdm_dict, 'log': tqdm_dict, 'test_loss': loss_mean}
        return result

In [5]:
def main(hparams):
    model = MiniVGGDeeper2WoDropout(hparams)

    name = type(model).__name__
    logger = TensorBoardLogger("lightning_logs", name=name)

    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join('models', name, name+'_{epoch:02d}-{val_acc:.4f}-{val_loss:.4f}'),
        save_top_k=1,
        verbose=True,
        monitor='val_acc',
        mode='max',
        prefix=''
    )

    trainer = pl.Trainer(
        max_epochs=hparams.epochs,
        gpus=hparams.gpus,
        logger=logger,
        checkpoint_callback=checkpoint_callback,
        # precision=16 if hparams.use_16bit else 32,
        progress_bar_refresh_rate=50,
    )

    trainer.fit(model)

    trainer.test(model)

In [6]:
if __name__ == '__main__':
    parent_parser = ArgumentParser(add_help=False)

    # gpu args
    parent_parser.add_argument(
        '--gpus',
        type=int,
        default=1,
        help='how many gpus'
    )

    parser = MiniVGGDeeper2WoDropout.add_model_specific_args(parent_parser)
    hyperparams = parser.parse_args([
        "--learning_rate", "0.001",
        "--epochs", "40"])
    main(hyperparams)

INFO:lightning:GPU available: True, used: True
INFO:lightning:VISIBLE GPUS: 0
INFO:lightning:
   | Name                           | Type        | Params
-----------------------------------------------------------
0  | model                          | Sequential  | 3 M   
1  | model.0                        | Conv2d      | 320   
2  | model.1                        | ReLU        | 0     
3  | model.2                        | BatchNorm2d | 64    
4  | model.3                        | Conv2d      | 9 K   
5  | model.4                        | ReLU        | 0     
6  | model.5                        | BatchNorm2d | 64    
7  | model.6                        | MaxPool2d   | 0     
8  | model.7                        | Conv2d      | 18 K  
9  | model.8                        | ReLU        | 0     
10 | model.9                        | BatchNorm2d | 128   
11 | model.10                       | ResNetPart  | 29 K  
12 | model.10.layer1                | Sequential  | 29 K  
13 | model.10.layer1

HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …



HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00000: val_acc reached 0.83387 (best 0.83387), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=00-val_acc=0.8339-val_loss=0.4587.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00001: val_acc reached 0.85992 (best 0.85992), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=01-val_acc=0.8599-val_loss=0.3836.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00002: val_acc reached 0.87720 (best 0.87720), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=02-val_acc=0.8772-val_loss=0.3469.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00003: val_acc reached 0.88069 (best 0.88069), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=03-val_acc=0.8807-val_loss=0.3324.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00004: val_acc reached 0.89547 (best 0.89547), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=04-val_acc=0.8955-val_loss=0.2969.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00005: val_acc  was not in top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00006: val_acc reached 0.90276 (best 0.90276), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=06-val_acc=0.9028-val_loss=0.2803.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00007: val_acc reached 0.90375 (best 0.90375), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=07-val_acc=0.9038-val_loss=0.2728.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00008: val_acc  was not in top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00009: val_acc  was not in top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00010: val_acc reached 0.91344 (best 0.91344), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=10-val_acc=0.9134-val_loss=0.2461.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00011: val_acc reached 0.91444 (best 0.91444), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=11-val_acc=0.9144-val_loss=0.2380.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00012: val_acc reached 0.91613 (best 0.91613), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=12-val_acc=0.9161-val_loss=0.2303.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00013: val_acc reached 0.91793 (best 0.91793), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=13-val_acc=0.9179-val_loss=0.2235.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00014: val_acc reached 0.91883 (best 0.91883), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=14-val_acc=0.9188-val_loss=0.2241.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00015: val_acc reached 0.91963 (best 0.91963), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=15-val_acc=0.9196-val_loss=0.2514.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00016: val_acc reached 0.92572 (best 0.92572), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=16-val_acc=0.9257-val_loss=0.2071.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00017: val_acc reached 0.92742 (best 0.92742), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=17-val_acc=0.9274-val_loss=0.2361.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00018: val_acc reached 0.93111 (best 0.93111), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=18-val_acc=0.9311-val_loss=0.2014.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00019: val_acc  was not in top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00020: val_acc  was not in top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00021: val_acc  was not in top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00022: val_acc reached 0.93361 (best 0.93361), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=22-val_acc=0.9336-val_loss=0.1936.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00023: val_acc  was not in top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00024: val_acc  was not in top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00025: val_acc reached 0.93450 (best 0.93450), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=25-val_acc=0.9345-val_loss=0.1927.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00026: val_acc  was not in top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00027: val_acc reached 0.93700 (best 0.93700), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=27-val_acc=0.9370-val_loss=0.1829.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00028: val_acc reached 0.93820 (best 0.93820), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=28-val_acc=0.9382-val_loss=0.1777.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00029: val_acc reached 0.93840 (best 0.93840), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=29-val_acc=0.9384-val_loss=0.1769.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00030: val_acc reached 0.93860 (best 0.93860), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=30-val_acc=0.9386-val_loss=0.1763.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00031: val_acc reached 0.93940 (best 0.93940), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=31-val_acc=0.9394-val_loss=0.1741.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00032: val_acc reached 0.93990 (best 0.93990), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=32-val_acc=0.9399-val_loss=0.1736.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00033: val_acc reached 0.94079 (best 0.94079), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=33-val_acc=0.9408-val_loss=0.1717.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00034: val_acc reached 0.94099 (best 0.94099), saving model to models/MiniVGGDeeper2WoDropout/MiniVGGDeeper2WoDropout_epoch=34-val_acc=0.9410-val_loss=0.1714.ckpt as top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00035: val_acc  was not in top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00036: val_acc  was not in top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00037: val_acc  was not in top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00038: val_acc  was not in top 1


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

INFO:lightning:
Epoch 00039: val_acc  was not in top 1







HBox(children=(FloatProgress(value=0.0, description='Testing', layout=Layout(flex='2'), max=313.0, style=Progr…

--------------------------------------------------------------------------------
TEST RESULTS
{'test_acc': 0.940994381904602, 'test_loss': 0.17086780071258545}
--------------------------------------------------------------------------------

