# ViT

## 0. Paper

### Info
* Title: An Image is 16x16 Words: Transformers for Image Recognition at Scale
* Authors: Alexey Dosovitskiy et al.
* Task: Image Classification
* Link: https://arxiv.org/abs/2010.11929


### Features
* Dataset: CIFAR-10

### Reference
* https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
* https://github.com/huggingface/transformers/blob/95037a169f5b0327288a91183dce61a9e09f2e71/src/transformers/models/vit/modeling_vit.py#L309


## 1. Setting

In [1]:
# Libraries
import os
from glob import glob
from tqdm.auto import tqdm

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchsummary import torchsummary

In [2]:
class CONFIG:
    """Hyper-parameters and paths shared by the notebook cells."""

    # Model architecture
    num_block = 8       # number of stacked transformer blocks
    num_head = 4        # attention heads per block
    dim_size = 128      # token embedding width
    drop_rate = 0.0     # dropout probability
    num_class = 10      # size of the classifier output

    # Input geometry
    image_size = 32
    patch_size = 2

    # Optimisation schedule
    batch_size = 256
    epoch_size = 100

    # NOTE(review): not referenced by the visible cells -- presumably a Colab
    # working directory; confirm before relying on it.
    base_dir = '/content/drive/Shared drives/Yoon/Project/Doing/Deep Learning Paper Implementation'

## 2. Data

In [3]:
def create_dataloader(batch_size):
    """Build the CIFAR-10 train / validation / test data loaders.

    The 50,000 training images are split into a training part and a
    5,000-image validation part. Only the training part is augmented
    (random crop + horizontal flip); validation and test images are just
    converted to tensors and normalized.

    Args:
        batch_size: Number of samples per batch for all three loaders.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2439, 0.2616)
    val_size = 5000

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # Two views over the same training images: one augmented (for training)
    # and one plain (for validation). Splitting a single augmented dataset
    # would apply the random crop/flip to the validation samples as well.
    train_aug = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
    train_plain = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=test_transform)
    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

    # One shared permutation keeps the two subsets disjoint.
    indices = torch.randperm(len(train_aug)).tolist()
    train_dataset = torch.utils.data.Subset(train_aug, indices[:-val_size])
    val_dataset = torch.utils.data.Subset(train_plain, indices[-val_size:])

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

In [4]:
# Build the loaders and peek at one training batch to confirm tensor shapes.
train_loader, val_loader, test_loader = create_dataloader(batch_size=CONFIG.batch_size)
first_batch = iter(train_loader)
inputs, targets = next(first_batch)
inputs.size(), targets.size()

Files already downloaded and verified
Files already downloaded and verified


(torch.Size([256, 3, 32, 32]), torch.Size([256]))

## 3. Model

In [5]:
class Embedding(nn.Module):
    """Turn an image batch into a token sequence.

    A strided convolution projects each non-overlapping patch to a
    ``dim_size`` vector, a learnable class token is prepended, a learnable
    position embedding is added, and dropout is applied.
    """

    def __init__(self, dim_size, image_size, patch_size, drop_rate):
        super().__init__()
        patches_per_side = image_size // patch_size
        seq_len = patches_per_side ** 2 + 1  # +1 for the class token
        self.conv = nn.Conv2d(3, dim_size, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim_size))
        self.pos_embedding = nn.Parameter(torch.zeros(1, seq_len, dim_size))
        self.dropout = nn.Dropout(drop_rate)

    def forward(self, x):
        # (batch, dim, H', W') -> (batch, dim, num_patches) -> (batch, num_patches, dim)
        patches = self.conv(x).flatten(2).transpose(1, 2)

        # One copy of the class token per sample, placed at sequence index 0.
        batch = patches.size(0)
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([cls, patches], dim=1)

        tokens = tokens + self.pos_embedding
        return self.dropout(tokens)


class Attention(nn.Module):
    """Multi-head self-attention over a ``(batch, seq_len, dim_size)`` input.

    Args:
        num_head: Number of attention heads; must divide ``dim_size``.
        dim_size: Width of the input and output tokens.
        drop_rate: Dropout probability, applied both to the attention
            weights and to the projected output.

    Raises:
        ValueError: If ``dim_size`` is not divisible by ``num_head``.
    """

    def __init__(self, num_head, dim_size, drop_rate):
        super(Attention, self).__init__()
        if dim_size % num_head != 0:
            raise ValueError(
                f'dim_size ({dim_size}) must be divisible by num_head ({num_head})'
            )
        self.num_head = num_head
        self.dim_size = dim_size
        self.head_dim = dim_size // num_head
        # Scaled dot-product attention divides the logits by sqrt(d_k), where
        # d_k is the width of a single head's key, not the full model width.
        self.temperature = self.head_dim ** 0.5

        self.qkv = nn.Linear(dim_size, 3*dim_size, bias=False)
        self.proj = nn.Linear(dim_size, dim_size)
        self.drop = nn.Dropout(drop_rate)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, dim_size)``.
            mask: Accepted for interface compatibility; currently ignored.

        Returns:
            Tensor with the same shape as ``x``.
        """
        B, N, C = x.size()
        # One linear layer produces q, k and v; split them and move the head
        # axis forward: (3, batch_size, num_head, seq_len, head_dim).
        qkv = self.qkv(x).reshape(B, N, 3, self.num_head, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = torch.matmul(q, k.transpose(3, 2)) / self.temperature  # (batch_size, num_head, seq_len, seq_len)
        attn = F.softmax(attn, dim=-1)
        attn = self.drop(attn)

        # Merge the heads back into a single dim_size-wide token.
        x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.drop(x)
        return x


class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Linear.

    Args:
        in_C: Input feature width.
        C: Hidden feature width.
        out_C: Output feature width.
    """

    def __init__(self, in_C, C, out_C):
        super().__init__()
        self.fc1 = nn.Linear(in_C, C)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(C, out_C)

    def forward(self, x):
        hidden = self.fc1(x)
        hidden = self.gelu(hidden)
        return self.fc2(hidden)

    
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each in a residual branch.

    The MLP hidden width is half of ``dim_size``.
    """

    def __init__(self, num_head, dim_size, drop_rate):
        super().__init__()
        hidden_size = dim_size // 2
        self.norm1 = nn.LayerNorm(dim_size)
        self.attn = Attention(num_head, dim_size, drop_rate)
        self.norm2 = nn.LayerNorm(dim_size)
        self.mlp = MLP(dim_size, hidden_size, dim_size)

    def forward(self, x):
        # Residual branch 1: LayerNorm followed by self-attention.
        attended = self.attn(self.norm1(x))
        x = x + attended
        # Residual branch 2: LayerNorm followed by the feed-forward network.
        transformed = self.mlp(self.norm2(x))
        return x + transformed

In [6]:
class VisionTransformer(nn.Module):
    """Vision Transformer classifier.

    Images are embedded into a token sequence, passed through ``num_block``
    transformer blocks, and the first token (the class token) is mapped to
    ``num_class`` logits by a linear layer.
    """

    def __init__(self, num_block, num_head, num_class, dim_size, image_size, patch_size, drop_rate):
        super().__init__()
        self.embedding = Embedding(dim_size, image_size, patch_size, drop_rate)
        layers = [Block(num_head, dim_size, drop_rate) for _ in range(num_block)]
        self.blocks = nn.Sequential(*layers)
        self.classifier = nn.Linear(dim_size, num_class)

    def forward(self, x):
        tokens = self.blocks(self.embedding(x))
        cls_token = tokens[:, 0]  # sequence position 0 holds the class token
        return self.classifier(cls_token)

## 4. Experiment

In [7]:
class AverageMeter:
    """Keeps a weighted running average of a named scalar."""

    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        """Clear the accumulated sum, sample count and average."""
        self.sum = self.count = self.avg = 0

    def update(self, val, n=1):
        """Fold in ``val`` as the mean of ``n`` samples and refresh the average."""
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        return f'{self.name:10s} {self.avg:.3f}'


class ProgressMeter:
    """Groups several AverageMeter objects and updates them together.

    After each ``update`` the current average of every meter is also exposed
    as an attribute named after that meter (e.g. ``progress.valid_acc``).
    """

    def __init__(self, meters):
        self.meters = [AverageMeter(name) for name in meters]

    def reset(self):
        """Reset every tracked meter."""
        for meter in self.meters:
            meter.reset()

    def update(self, values, n=1):
        """Update the meters, in order, with ``values``, each weighted by ``n``."""
        for meter, value in zip(self.meters, values):
            meter.update(value, n)
            setattr(self, meter.name, meter.avg)

    def log(self):
        """Return all meters formatted on one line, separated by ' | '."""
        return ' | '.join(str(meter) for meter in self.meters)


def accuracy(logits, targets):
    """Return the fraction of rows whose highest-scoring class equals the target.

    The result is a zero-dimensional float tensor.
    """
    predicted = logits.max(1)[1]
    hits = predicted.eq(targets)
    return hits.float().mean()

In [8]:
class Trainer(object):
    """Runs the training, validation and test loops for a classifier.

    Validation tracks the best accuracy seen so far and saves the matching
    model weights to ``ckpt.pt``; testing reloads that checkpoint first.

    Args:
        model: Network to optimize and evaluate.
        criterion: Loss called as ``criterion(outputs, targets)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per training epoch.
        device: Device that every batch is moved to before the forward pass.
    """

    def __init__(self, model, criterion, optimizer, scheduler, device):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.best_epoch, self.best_score = 0, 0

    def train(self, train_loader, epoch):
        """Train for one pass over ``train_loader``, then step the scheduler.

        Args:
            train_loader: Iterable yielding ``(inputs, targets)`` batches.
            epoch: Epoch index, used only for the progress-bar label.
        """
        progress = ProgressMeter(["train_loss", "train_acc"])
        self.model.train()

        pbar = tqdm(train_loader)
        pbar.set_description(f'TRAIN {epoch:03d}')
        for inputs, targets in pbar:
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            loss = loss.item()
            acc = accuracy(outputs, targets).item()
            progress.update([loss, acc], n=inputs.size(0))
            pbar.set_postfix(log=progress.log())

        self.scheduler.step()

    def _evaluate(self, loader, meter_names, description):
        """Run the model in eval mode over ``loader`` without gradients.

        Args:
            loader: Iterable yielding ``(inputs, targets)`` batches.
            meter_names: Names of the loss and accuracy meters, in that order.
            description: Label shown on the progress bar.

        Returns:
            The ProgressMeter holding the sample-weighted loss and accuracy.
        """
        progress = ProgressMeter(meter_names)
        self.model.eval()

        pbar = tqdm(loader)
        pbar.set_description(description)
        with torch.no_grad():
            for inputs, targets in pbar:
                # Use the trainer's own device; relying on a module-level
                # `device` global breaks when the class is used elsewhere.
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets).item()
                acc = accuracy(outputs, targets).item()
                progress.update([loss, acc], n=inputs.size(0))
                pbar.set_postfix(log=progress.log())

        return progress

    def validate(self, valid_loader, epoch):
        """Evaluate on ``valid_loader`` and checkpoint the model on improvement.

        When the validation accuracy exceeds the best seen so far, the epoch,
        score and model weights are written to ``ckpt.pt``.

        Args:
            valid_loader: Iterable yielding ``(inputs, targets)`` batches.
            epoch: Epoch index, used for the label and stored in the checkpoint.
        """
        progress = self._evaluate(
            valid_loader, ["valid_loss", "valid_acc"], f'VALID {epoch:03d}'
        )

        if progress.valid_acc > self.best_score:
            self.best_epoch = epoch
            self.best_score = progress.valid_acc
            ckpt = {
                'best_epoch': self.best_epoch,
                'best_score': self.best_score,
                'model_state_dict': self.model.state_dict()
            }
            torch.save(ckpt, 'ckpt.pt')

    def test(self, test_loader):
        """Reload the best checkpoint from ``ckpt.pt`` and evaluate it.

        Results are reported through the progress bar only.

        Args:
            test_loader: Iterable yielding ``(inputs, targets)`` batches.
        """
        # map_location lets a checkpoint saved on one device be loaded on another.
        ckpt = torch.load('ckpt.pt', map_location=self.device)
        self.model.load_state_dict(ckpt['model_state_dict'])
        self._evaluate(test_loader, ["test_loss", "test_acc"], 'TEST')

In [9]:
# Use the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the network from the shared configuration and move it to the device.
model = VisionTransformer(
    num_block=CONFIG.num_block,
    num_head=CONFIG.num_head,
    num_class=CONFIG.num_class,
    dim_size=CONFIG.dim_size,
    image_size=CONFIG.image_size,
    patch_size=CONFIG.patch_size,
    drop_rate=CONFIG.drop_rate,
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-4)
# The cosine schedule spans the full number of training epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG.epoch_size)

torchsummary.summary(model, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 128, 16, 16]           1,664
           Dropout-2             [-1, 257, 128]               0
         Embedding-3             [-1, 257, 128]               0
         LayerNorm-4             [-1, 257, 128]             256
            Linear-5             [-1, 257, 384]          49,152
           Dropout-6          [-1, 4, 257, 257]               0
            Linear-7             [-1, 257, 128]          16,512
           Dropout-8             [-1, 257, 128]               0
         Attention-9             [-1, 257, 128]               0
        LayerNorm-10             [-1, 257, 128]             256
           Linear-11              [-1, 257, 64]           8,256
             GELU-12              [-1, 257, 64]               0
           Linear-13             [-1, 257, 128]           8,320
              MLP-14             [-1, 2

In [10]:
trainer = Trainer(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
)

In [11]:
# One training pass followed by one validation pass per epoch; validate()
# writes ckpt.pt whenever the validation accuracy improves.
for ep in range(CONFIG.epoch_size):
    print('=' * 80)  # visual separator between epochs
    trainer.train(train_loader, ep)
    trainer.validate(val_loader, ep)



HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




In [12]:
# Evaluate on the test set using the weights stored in ckpt.pt.
trainer.test(test_loader)

HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


