# Image Classification via WideResNet + SAM on FashionMNIST dataset
based on (Adaptive) SAM Optimizer https://github.com/davda54/sam repo.

## Initialization
To ensure results are reproducible

In [3]:
import torch
import random

def init_random(seed):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

init_random(42)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"using {device} device")

# hyper parameters
hparams = {}

using cuda:0 device


## Dataset with augmentation and preprocessing
For augmentation I use a small jitter in position via RandomCrop; random flip around vertical axis and random erasing / cutout.
I did experiments with AugMix and AutoAugment learnt on CIFAR10, but results were not better.

In [4]:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader


class FashionMNIST:
    def __init__(self, batch_size, threads):
        self.image_size = (28, 28)

        # gather statistics
        def get_statistics(train_set):
            data = torch.cat([d[0] for d in DataLoader(train_set)])
            return data.mean(dim=[0, 2, 3]), data.std(dim=[0, 2, 3])

        mean, std = get_statistics(
            datasets.FashionMNIST(root='./fashionmnist', train=True, download=True, transform=transforms.ToTensor()))

        train_transform = transforms.Compose([
            transforms.RandomCrop(size=self.image_size, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
            transforms.RandomErasing(),
        ])

        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

        self.infer_transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

        self.train_set = datasets.FashionMNIST(root='./data', train=True, download=True, transform=train_transform)
        self.test_set = datasets.FashionMNIST(root='./data', train=False, download=True, transform=test_transform)

        self.train = DataLoader(self.train_set, batch_size=batch_size, shuffle=True, num_workers=threads)
        self.test = DataLoader(self.test_set, batch_size=batch_size, shuffle=False, num_workers=threads)

        self.classes = self.train_set.classes


dataset = FashionMNIST(batch_size=128, threads=2)

## Define Wide ResNet model
I use WRN-16-8 architecture with dropouts. Experiment with WRN-28-10 did not show an improvement.

In [5]:
from collections import OrderedDict
import torch
import torch.nn as nn

hparams |= {
        "model/dropout": 0.3,
        "model/depth": 16,
        "model/width_factor": 8
    }

class BasicUnit(nn.Module):
    def __init__(self, channels: int, dropout: float):
        super(BasicUnit, self).__init__()
        self.block = nn.Sequential(OrderedDict([
            ("0_normalization", nn.BatchNorm2d(channels)),
            ("1_activation", nn.ReLU()),
            ("2_convolution", nn.Conv2d(channels, channels, (3, 3), stride=1, padding=1, bias=False)),
            ("3_normalization", nn.BatchNorm2d(channels)),
            ("4_activation", nn.ReLU()),
            ("5_dropout", nn.Dropout(dropout)),
            ("6_convolution", nn.Conv2d(channels, channels, (3, 3), stride=1, padding=1, bias=False)),
        ]))

    def forward(self, x):
        return x + self.block(x)


class DownsampleUnit(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, stride: int, dropout: float):
        super(DownsampleUnit, self).__init__()
        self.norm_act = nn.Sequential(OrderedDict([
            ("0_normalization", nn.BatchNorm2d(in_channels)),
            ("1_activation", nn.ReLU()),
        ]))
        self.block = nn.Sequential(OrderedDict([
            ("0_convolution", nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=1, bias=False)),
            ("1_normalization", nn.BatchNorm2d(out_channels)),
            ("2_activation", nn.ReLU()),
            ("3_dropout", nn.Dropout(dropout)),
            ("4_convolution", nn.Conv2d(out_channels, out_channels, (3, 3), stride=1, padding=1, bias=False)),
        ]))
        self.downsample = nn.Conv2d(in_channels, out_channels, (1, 1), stride=stride, padding=0, bias=False)

    def forward(self, x):
        x = self.norm_act(x)
        return self.block(x) + self.downsample(x)


class Block(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, stride: int, depth: int, dropout: float):
        super(Block, self).__init__()
        self.block = nn.Sequential(
            DownsampleUnit(in_channels, out_channels, stride, dropout),
            *(BasicUnit(out_channels, dropout) for _ in range(depth))
        )

    def forward(self, x):
        return self.block(x)


class WideResNet(nn.Module):
    def __init__(self, depth: int, width_factor: int, dropout: float, in_channels: int, labels: int, input_size: int):
        super(WideResNet, self).__init__()

        self.filters = [16, 1 * 16 * width_factor, 2 * 16 * width_factor, 4 * 16 * width_factor, 7]
        self.block_depth = (depth - 4) // (3 * 2)

        self.f = nn.Sequential(OrderedDict([
            # 128 x 1 x 28 x 28
            ("0_convolution", nn.Conv2d(in_channels, self.filters[0], (3, 3), stride=1, padding=1, bias=False)),
            # 128 x 16 x 28 x 28
            ("1_block", Block(self.filters[0], self.filters[1], 1, self.block_depth, dropout)),
            # 128 x 128 x 28 x 28
            ("2_block", Block(self.filters[1], self.filters[2], 2, self.block_depth, dropout)),
            # 128 x 256 x 14 x 14
            ("3_block", Block(self.filters[2], self.filters[3], 2, self.block_depth, dropout)),
            # 128 x 128 x 7 x 7
            ("4_normalization", nn.BatchNorm2d(self.filters[3])),
            # 128 x 128 x 7 x 7
            ("5_activation", nn.ReLU()),
            # 128 x 128 x 7 x 7
            ("6_pooling", nn.AvgPool2d(kernel_size=input_size // 4)),
            # 128 x 128 x 1 x 1
            ("7_flattening", nn.Flatten()),
            # 128 x 128
            ("8_classification", nn.Linear(in_features=self.filters[3], out_features=labels)),
            # 128 x 10
        ]))

        self._initialize()

    def _initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight.data, mode="fan_in", nonlinearity="relu")
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.zero_()
                m.bias.data.zero_()

    def forward(self, x):
        return self.f(x)



model = WideResNet(depth=hparams["model/depth"], width_factor=hparams["model/width_factor"], dropout=hparams["model/dropout"],
                   in_channels=1, labels=10, input_size=dataset.image_size[0]).to(device)
print(model)
print(hparams)

WideResNet(
  (f): Sequential(
    (0_convolution): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1_block): Block(
      (block): Sequential(
        (0): DownsampleUnit(
          (norm_act): Sequential(
            (0_normalization): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (1_activation): ReLU()
          )
          (block): Sequential(
            (0_convolution): Conv2d(16, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1_normalization): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2_activation): ReLU()
            (3_dropout): Dropout(p=0.3, inplace=False)
            (4_convolution): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          )
          (downsample): Conv2d(16, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (1): BasicUnit(
          (block):

## Optimizer: SAM + SGD with nesterov momentum
SAM is Sharpness-Aware Minimization that
seeks parameters that lie in neighborhoods having uniformly low loss. It improves model generalization and provides robustness to label noise.
See [SAM](https://arxiv.org/abs/2010.01412) and [Adaptive SAM](https://arxiv.org/abs/2102.11600) papers.
I tested different values of rho for 30 epochs with no data augmentation and rho=1 seems to be the best.

In [6]:
hparams |= {"SAM/rho": 1,
            "SAM/adaptive": True}
hparams |= {"SGD/lr": 0.1,
            "SGD/momentum": 0.9,
            "SGD/weight_decay":0.0005,
            "SGD/nesterov": True
           }

class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):

        # put everything on the same device, in case of model parallelism
        shared_device = self.param_groups[0]["params"][0].device

        norm = torch.norm(
            torch.stack([
                ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                for group in self.param_groups for p in group["params"]
                if p.grad is not None
            ]),
            p=2
        )
        return norm

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

optimizer = SAM(model.parameters(), torch.optim.SGD,
                rho=hparams["SAM/rho"], adaptive=hparams["SAM/adaptive"],
                lr=hparams["SGD/lr"], momentum=hparams["SGD/momentum"], weight_decay=hparams["SGD/weight_decay"], nesterov=hparams["SGD/nesterov"])
print(optimizer)

SAM (
Parameter Group 0
    adaptive: True
    dampening: 0
    differentiable: False
    foreach: None
    lr: 0.1
    maximize: False
    momentum: 0.9
    nesterov: True
    rho: 1
    weight_decay: 0.0005
)


### Functions to set BN momentum to zero to bypass the running statistics during the second pass backpropagation. (Two passes are due to SAM)

In [7]:
from torch.nn.modules.batchnorm import _BatchNorm

class disable_running_stats:
    def __init__(self, model):
        self.model = model

    def __enter__(self):
        def _disable(module):
            if isinstance(module, _BatchNorm):
                module.backup_momentum = module.momentum
                module.momentum = 0

        self.model.apply(_disable)

    def __exit__(self, exc_type, exc_val, exc_tb):
        def _enable(module):
            if isinstance(module, _BatchNorm) and hasattr(module, "backup_momentum"):
                module.momentum = module.backup_momentum

        self.model.apply(_enable)


## Train duration
Original paper used 200 epochs

In [51]:
hparams["total_epochs"] = 10

## Scheduler
Decrease learning rate 5 times at 0.3, 0.6 and 0.8 relative epochs

In [52]:
class StepLR:
    def __init__(self, optimizer, learning_rate: float, total_epochs: int):
        self.optimizer = optimizer
        self.total_epochs = total_epochs
        self.base = learning_rate

    def __call__(self, epoch):
        if epoch < self.total_epochs * 3 / 10:
            lr = self.base
        elif epoch < self.total_epochs * 6 / 10:
            lr = self.base * 0.2
        elif epoch < self.total_epochs * 8 / 10:
            lr = self.base * 0.2 ** 2
        else:
            lr = self.base * 0.2 ** 3

        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def lr(self) -> float:
        return self.optimizer.param_groups[0]["lr"]

scheduler = StepLR(optimizer, learning_rate=hparams["SGD/lr"], total_epochs=hparams["total_epochs"])

## Loss
Smooth labels and compute probabilities difference

In [53]:
import torch.nn.functional as F

hparams |= {"loss/label_smoothing": 0.1}

def smooth_crossentropy(pred, gold, smoothing=hparams["loss/label_smoothing"]):
    n_class = pred.size(1)

    one_hot = torch.full_like(pred, fill_value=smoothing / (n_class - 1))
    one_hot.scatter_(dim=1, index=gold.unsqueeze(1), value=1.0 - smoothing)
    log_prob = F.log_softmax(pred, dim=1)

    return F.kl_div(input=log_prob, target=one_hot, reduction='none').sum(-1)

## Convenient Logging class

In [54]:
import time


class LoadingBar:
    def __init__(self, length: int = 40):
        self.length = length
        self.symbols = ['┈', '░', '▒', '▓']

    def __call__(self, progress: float) -> str:
        p = int(progress * self.length * 4 + 0.5)
        d, r = p // 4, p % 4
        return '┠┈' + d * '█' + (
            (self.symbols[r]) + max(0, self.length - 1 - d) * '┈' if p < self.length * 4 else '') + "┈┨"


class Log:
    def __init__(self, log_each: int, initial_epoch=-1, writer=None):
        self.loading_bar = LoadingBar(length=27)
        self.best_accuracy = 0.0
        self.log_each = log_each
        self.epoch = initial_epoch
        self.is_train = False
        self.is_eval = False
        self.initial_epoch = initial_epoch
        self.writer = writer
        self.epoch_state = {}

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.flush()
        self._print_footer()

    def train(self, len_dataset: int) -> None:
        self.epoch += 1
        if self.epoch == self.initial_epoch + 1:
            self._print_header()
        else:
            self.flush()

        self.is_train = True
        self.is_eval = False
        self.last_steps_state = {"loss": 0.0, "accuracy": 0.0, "steps": 0}
        self._reset(len_dataset)

    def eval(self, len_dataset: int) -> None:
        self.flush()
        self.is_train = False
        self.is_eval = True
        self._reset(len_dataset)

    def __call__(self, model, loss, accuracy, learning_rate: float = None) -> None:
        if self.is_train:
            self._train_step(model, loss, accuracy, learning_rate)
        else:
            self._eval_step(loss, accuracy)

    def flush(self) -> None:
        if self.is_train:
            self.train_loss = self.epoch_state["loss"] / self.epoch_state["steps"]
            self.train_accuracy = self.epoch_state["accuracy"] / self.epoch_state["steps"]
            if self.writer:
                self.writer.add_scalar("Loss/train", self.train_loss, self.epoch)
                self.writer.add_scalar("Accuracy/train", self.train_accuracy, self.epoch)
                self.writer.flush()

            print(
                f"\r┃{self.epoch:12d}  ┃{self.train_loss:12.4f}  │{100 * self.train_accuracy:10.2f} %  ┃{self.learning_rate:12.3e}  │{self._time():>12}  ┃",
                end="",
                flush=True,
            )

        elif self.is_eval:
            loss = self.epoch_state["loss"] / self.epoch_state["steps"]
            accuracy = self.epoch_state["accuracy"] / self.epoch_state["steps"]
            if self.writer:
                self.writer.add_scalar("Loss/eval", loss, self.epoch)
                self.writer.add_scalar("Accuracy/eval", accuracy, self.epoch)
                self.writer.flush()

            print(f"\r┃{self.epoch:12d}  ┃{self.train_loss:12.4f}  │{100 * self.train_accuracy:10.2f} %  ┃{self.learning_rate:12.3e}  │{self._time():>12}  ┃{loss:12.4f}  │{100 * accuracy:10.2f} %  ┃", flush=True)

            if accuracy > self.best_accuracy:
                self.best_accuracy = accuracy

    def loss(self):
        return self.epoch_state["loss"] / self.epoch_state["steps"]

    def accuracy(self):
        return self.epoch_state["accuracy"] / self.epoch_state["steps"]

    def _train_step(self, model, loss, accuracy, learning_rate: float) -> None:
        self.learning_rate = learning_rate
        self.last_steps_state["loss"] += loss.sum().item()
        self.last_steps_state["accuracy"] += accuracy.sum().item()
        self.last_steps_state["steps"] += loss.size(0)
        self.epoch_state["loss"] += loss.sum().item()
        self.epoch_state["accuracy"] += accuracy.sum().item()
        self.epoch_state["steps"] += loss.size(0)
        self.step += 1

        if self.step % self.log_each == self.log_each - 1:
            loss = self.last_steps_state["loss"] / self.last_steps_state["steps"]
            accuracy = self.last_steps_state["accuracy"] / self.last_steps_state["steps"]

            self.last_steps_state = {"loss": 0.0, "accuracy": 0.0, "steps": 0}
            progress = self.step / self.len_dataset

            print(
                f"\r┃{self.epoch:12d}  ┃{loss:12.4f}  │{100 * accuracy:10.2f} %  ┃{learning_rate:12.3e}  │{self._time():>12}  {self.loading_bar(progress)}",
                end="",
                flush=True,
            )

    def _eval_step(self, loss, accuracy) -> None:
        self.epoch_state["loss"] += loss.sum().item()
        self.epoch_state["accuracy"] += accuracy.sum().item()
        self.epoch_state["steps"] += loss.size(0)

    def _reset(self, len_dataset: int) -> None:
        self.start_time = time.time()
        self.step = 0
        self.len_dataset = len_dataset
        self.epoch_state = {"loss": 0.0, "accuracy": 0.0, "steps": 0}

    def _time(self) -> str:
        time_seconds = int(time.time() - self.start_time)
        return f"{time_seconds // 60:02d}:{time_seconds % 60:02d} min"

    def _print_header(self) -> None:
        print(
            f"┏━━━━━━━━━━━━━━┳━━━━━━━╸T╺╸R╺╸A╺╸I╺╸N╺━━━━━━━┳━━━━━━━╸S╺╸T╺╸A╺╸T╺╸S╺━━━━━━━┳━━━━━━━╸V╺╸A╺╸L╺╸I╺╸D╺━━━━━━━┓")
        print(
            f"┃              ┃              ╷              ┃              ╷              ┃              ╷              ┃")
        print(
            f"┃       epoch  ┃        loss  │    accuracy  ┃        l.r.  │     elapsed  ┃        loss  │    accuracy  ┃")
        print(
            f"┠──────────────╂──────────────┼──────────────╂──────────────┼──────────────╂──────────────┼──────────────┨")

    def _print_footer(self) -> None:
        print(
            f"┗━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛")

## Training loop
Use tensor board to record progress and history

In [55]:
from torch.utils.tensorboard import SummaryWriter


def training(train_dl, test_dl, model, optimizer, scheduler, total_epochs, device, start_epoch=0, writer=None):
    with Log(log_each=1, initial_epoch=start_epoch - 1, writer=writer) as log:
        for epoch in range(start_epoch, total_epochs):
            model.train()
            log.train(len_dataset=len(train_dl))

            for nb, batch in enumerate(train_dl):
                inputs, targets = (b.to(device) for b in batch)

                # first forward-backward step
                predictions = model(inputs)
                loss = smooth_crossentropy(predictions, targets, smoothing=0.1)
                loss.mean().backward()
                optimizer.first_step(zero_grad=True)

                # second forward-backward step
                with disable_running_stats(model):
                    smooth_crossentropy(model(inputs), targets, smoothing=0.1).mean().backward()
                    optimizer.second_step(zero_grad=True)

                # gather stats and update learning rate
                with torch.no_grad():
                    correct = torch.argmax(predictions.data, 1) == targets
                    log(model, loss.cpu(), correct.cpu(), scheduler.lr())
                    scheduler(epoch)

            # update stat on test set
            model.eval()
            log.eval(len_dataset=len(test_dl))

            with torch.no_grad():
                for batch in test_dl:
                    inputs, targets = (b.to(device) for b in batch)
                    predictions = model(inputs)
                    loss = smooth_crossentropy(predictions, targets)
                    correct = torch.argmax(predictions, 1) == targets
                    log(model, loss.cpu(), correct.cpu())
        final_loss, final_accuracy = log.loss(), log.accuracy()
    return final_loss, final_accuracy


with SummaryWriter() as writer:
    print(hparams)
    # dataset.test.next
    loss, accuracy = training(train_dl=dataset.train, test_dl=dataset.test, model=model, optimizer=optimizer,
                              scheduler=scheduler,
                              total_epochs=hparams["total_epochs"], device=device, writer=writer)
    print(F"loss: {loss}, accuracy: {accuracy}")

{'model/dropout': 0.3, 'model/depth': 16, 'model/width_factor': 8, 'SAM/rho': 1, 'SAM/adaptive': True, 'SGD/lr': 0.1, 'SGD/momentum': 0.9, 'SGD/weight_decay': 0.0005, 'SGD/nesterov': True, 'total_epochs': 10, 'loss/label_smoothing': 0.1}
┏━━━━━━━━━━━━━━┳━━━━━━━╸T╺╸R╺╸A╺╸I╺╸N╺━━━━━━━┳━━━━━━━╸S╺╸T╺╸A╺╸T╺╸S╺━━━━━━━┳━━━━━━━╸V╺╸A╺╸L╺╸I╺╸D╺━━━━━━━┓
┃              ┃              ╷              ┃              ╷              ┃              ╷              ┃
┃       epoch  ┃        loss  │    accuracy  ┃        l.r.  │     elapsed  ┃        loss  │    accuracy  ┃
┠──────────────╂──────────────┼──────────────╂──────────────┼──────────────╂──────────────┼──────────────┨
┃           1  ┃      0.1953  │     91.05 %  ┃   1.000e-01  │   00:03 min  ┃      0.1799  │     91.72 %  ┃
┃           2  ┃      0.1932  │     91.14 %  ┃   1.000e-01  │   00:03 min  ┃      0.2795  │     88.80 %  ┃
┃           3  ┃      0.1935  │     91.25 %  ┃   1.000e-01  │   00:03 min  ┃      0.1727  │     92.13 %  ┃
┃           4

## Applying result to unrelated images and form resulting json

In [68]:
from PIL import Image
from pathlib import Path
import json

result = {"test_acc": accuracy, "top_predictions": {}}

model.eval()
for image_fname in ['images/img1.jpg', 'images/img2.jpg', 'images/img3.jpg']:
    img = dataset.infer_transform(Image.open(image_fname))[None, :]
    with torch.no_grad():
        infer = model(img.to(device))
    cl = int(infer.argmax(1)[0])
    cl_name = dataset.classes[cl]
    prob = F.softmax(infer, dim=1).cpu().numpy()[0, cl].astype(float)
    short_fname = str(Path(image_fname).stem)
    print(f"{short_fname}: {cl_name} with {100 * prob:.2f}% probability")
    result["top_predictions"][short_fname] = {cl_name : round(prob, 2)}
print (result)
with open('result.json', 'w') as f:
    json.dump(result, f, indent=2)

img1: Bag with 88.45% probability
img2: Bag with 88.27% probability
img3: Bag with 92.54% probability
{'test_acc': 0.9503, 'top_predictions': {'img1': {'Bag': 0.88}, 'img2': {'Bag': 0.88}, 'img3': {'Bag': 0.93}}}
