In [1]:
google_colab = False

# Google Colab用事前準備

In [2]:
# install
try:
    import binarybrain as bb
except:
    if google_colab and False:
        !pip install pybind11
        %cd /content
        !nvcc -V
        !sudo rm -fr BinaryBrain
        !rm -fr BinaryBrain
        !git clone --recursive -b ver3_develop https://github.com/ryuz/BinaryBrain.git
        %cd /content/BinaryBrain/python
        !sh copy_src.sh
        !python3 setup.py build
        !python3 setup.py develop

        import binarybrain as bb

In [3]:
# mount
if google_colab:
  from google.colab import drive 
  drive.mount('/content/drive')
  %cd /content/drive/My Drive/git-work/BinaryBrain_ver3_develop/tests/python

In [4]:
device = 0

import binarybrain as bb

import importlib
importlib.reload(bb)

bb.set_device(device)
prop = bb.get_device_properties(device)
print(prop)


name                     : GeForce GTX 1660 SUPER
totalGlobalMem           : 6442450944
sharedMemPerBlock        : 49152
regsPerBlock             : 65536
warpSize                 : 32
memPitch                 : 2147483647
maxThreadsPerBlock       : 1024
maxThreadsDim[0]         : 1024
maxThreadsDim[1]         : 1024
maxThreadsDim[2]         : 64
maxGridSize[0]           : 2147483647
maxGridSize[1]           : 65535
maxGridSize[2]           : 65535
clockRate                : 1800000
totalConstMem            : 65536
major                    : 7
minor                    : 5
textureAlignment         : 512
deviceOverlap            : 1
multiProcessorCount      : 22
kernelExecTimeoutEnabled : 1
integrated               : 0
canMapHostMemory         : 1
computeMode              : 0




# メインコード

In [5]:
import binarybrain as bb
import numpy as np
import matplotlib.pyplot as plt
import random
import cv2
import os
import sys
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import OrderedDict

In [6]:
def make_test_image(src_x, src_t):
    dst_x = np.ndarray((len(src_x),  1, 28, 28), dtype=np.float32)
    dst_t = np.ndarray((len(src_x), 11, 28, 28), dtype=np.float32)
    for i in range(len(src_x)):
        x_img = src_x[i].copy()
        t_img = np.zeros((11, 28*28), dtype=np.float32)
        mask0 = x_img <= 0.5
        mask1 = x_img > 0.5
        for j in range(10):
          t_img[j,mask1] = src_t[i][j]
        t_img[10,mask0] = 0.15
        dst_x[i] = x_img.reshape(1, 28, 28)
        dst_t[i] = t_img.reshape(11, 28, 28)
    return dst_x, dst_t

In [7]:
def make_test_data(src_x, src_t):
    x, t = make_test_image(src_x, src_t)

    dst_x = np.ndarray((len(src_x) // 4,  1*56*56), dtype=np.float32)
    dst_t = np.ndarray((len(src_x) // 4, 11*56*56), dtype=np.float32)
    for i in range(len(x) // 4):
        x_img = np.hstack((np.dstack((x[i*4+0], x[i*4+1])), np.dstack((x[i*4+2], x[i*4+3]))))
        t_img = np.hstack((np.dstack((t[i*4+0], t[i*4+1])), np.dstack((t[i*4+2], t[i*4+3]))))
        dst_x[i] = x_img.reshape(-1)
        dst_t[i] = t_img.reshape(-1)
    return dst_x, dst_t

In [8]:
def image_transform(img):
    angle = random.uniform(-10.0, +10.0)
    scale = random.uniform(0.8, 1.1)
    trans = cv2.getRotationMatrix2D((14, 14), angle , scale)
    trans[0][2] += random.uniform(-2.0, 2.0)
    trans[1][2] += random.uniform(-2.0, 2.0)
    return cv2.warpAffine(img.reshape(28, 28), trans, (28, 28))

def make_td_image(src_x, src_t, w=28, h=28, depth=27):
    l = len(src_x)
    mx = (28 + depth + w + depth + 27) // 28
    my = (28 + depth + h + depth + 27) // 28
    dst_x = np.zeros(( 1, 28*my, 28*mx), dtype=np.float32)
    dst_t = np.zeros((11, 28*my, 28*mx), dtype=np.float32)
    for y in range(my):
        for x in range(mx):
            i = random.randrange(0, l)
            dst_x[0,y*28:y*28+28, x*28:x*28+28] = image_transform(src_x[i])
            dst_t[0:10,y*28:y*28+28, x*28:x*28+28] = np.tile(src_t[i], (28,28)).reshape(28, 28, 10).transpose(2, 0, 1)
    x_flag0 = dst_x[0] <= 0.5
    x_flag1 = dst_x[0] >  0.5
    for i in range(10):
        dst_t[i][x_flag0] = 0
    dst_t[10][x_flag0] = 0.15
    
    xs = random.randrange(0, 28)
    ys = random.randrange(0, 28)
    dst_x  = dst_x[:,ys:ys+2*depth+h,xs:xs+2*depth+w]
    x_flag = x_flag1[ys:ys+2*depth+h,xs:xs+2*depth+w]
    
    dst_t  = dst_t[:,ys+depth:ys+depth+h,xs+depth:xs+depth+w]
    t_flag = x_flag1[ys+depth:ys+depth+h,xs+depth:xs+depth+w]
    
    x_mask = np.zeros(x_flag.shape, dtype=np.float32)
    x_mask[x_flag] = 1.0
    t_mask = np.zeros(t_flag.shape, dtype=np.float32)
    t_mask[t_flag] = 1.0
    
    if random.randint(0, 1) > 0:
        dst_x = 1.0 - dst_x
    
    return dst_x.flatten(), dst_t.flatten(), x_mask.flatten(), t_mask.flatten()

def generate_td(src_td, train_size=60000//4, test_size=10000//4, w=28, h=28, depth=27):
    dst_td = {}
    
    dst_td['x_train']      = []
    dst_td['t_train']      = []
    dst_td['x_mask_train'] = []
    dst_td['t_mask_train'] = []
    for _ in tqdm(range(train_size)):
        x, t, xm, tm = make_td_image(src_td['x_train'], src_td['t_train'], w=w, h=h, depth=depth)
        dst_td['x_train'].append(x)
        dst_td['t_train'].append(t)
        dst_td['x_mask_train'].append(xm)
        dst_td['t_mask_train'].append(tm)
    dst_td['x_train']      = np.array(dst_td['x_train'])
    dst_td['t_train']      = np.array(dst_td['t_train'])
    dst_td['x_mask_train'] = np.array(dst_td['x_mask_train'])
    dst_td['t_mask_train'] = np.array(dst_td['t_mask_train'])
    
    dst_td['x_test']      = []
    dst_td['t_test']      = []
    dst_td['x_mask_test'] = []
    dst_td['t_mask_test'] = []
    for _ in tqdm(range(test_size)):
        x, t, xm, tm = make_td_image(src_td['x_test'], src_td['t_test'], w=w, h=h, depth=depth)
        dst_td['x_test'].append(x)
        dst_td['t_test'].append(t)
        dst_td['x_mask_test'].append(xm)
        dst_td['t_mask_test'].append(tm)
    dst_td['x_test']      = np.array(dst_td['x_test'])
    dst_td['t_test']      = np.array(dst_td['t_test'])
    dst_td['x_mask_test'] = np.array(dst_td['x_mask_test'])
    dst_td['t_mask_test'] = np.array(dst_td['t_mask_test'])
    
    dst_td['x_shape'] = [depth*2+w, depth*2+h, 1]
    dst_td['t_shape'] = [w, h, 11]
    
    return dst_td

In [9]:
# load MNIST data
td = bb.load_mnist()

# ネットワーク構築

In [10]:
data_path = 'MnistMobileNetDistillation'
os.makedirs(data_path, exist_ok=True)
network_name = 'mnist-mobilenet-distillation-reverse'

ref_affine_list = []
ref_norm_list   = []
target_lut_list = []

def clear_list():
    ref_affine_list.clear()
    ref_norm_list.clear()
    target_lut_list.clear()

def save_model_list(model_list, name, path='.'):
    os.makedirs(os.path.join(data_path, path), exist_ok=True)
    for i, model in enumerate(model_list):
        model.save_json(os.path.join(data_path, path, '%s_%d.json' % (name, i)))

def load_model_list(model_list, name, path='.'):
    for i, model in enumerate(model_list):
        filename = os.path.join(data_path, path, '%s_%d.json' % (name, i))
        if not os.path.exists(filename):
            return False
        model.load_json(filename)
    return True

def save_all_model(path='.'):
    save_model_list(ref_affine_list, 'ref_affine',    path)
    save_model_list(ref_norm_list,   'ref_norm_list', path)
    save_model_list(target_lut_list, 'target_lut',    path)
    
def load_all_model(path='.'):
    if not load_model_list(ref_affine_list, 'ref_affine',    path): return False
    if not load_model_list(ref_norm_list,   'ref_norm_list', path): return False
    if not load_model_list(target_lut_list, 'target_lut',    path): return False
    return True

def make_common_layer(model):
    layer = {}
    layer['type']   = 'common'
    layer['select'] = 'common'
    layer['common'] = model
    return layer    

def make_cnv_layer(ch_size, w=3, h=3, lut_size=2, bn=True, connection='random', padding='valid'):
    # setup infomation
    layer = {}
    layer['type']            = 'convolution'
    layer['select']          = 'ref'
    layer['connection']      = connection
    layer['target_lut_size'] = lut_size
    
    if connection=='deptwise':
        layer['ref_affine'] = bb.DepthwiseDenseAffine.create([ch_size])
    else:
        layer['ref_affine'] = bb.DenseAffine.create([ch_size])        
    layer['ref_norm']   = bb.BatchNormalization.create() # momentum=0.1)
    layer['ref_act']    = bb.BinarizeBit.create()
    
    layer['lut_size'] = lut_size
    for i in range(lut_size):
        layer['target_lut%d' % i] = bb.SparseLut6Bit.create([ch_size*(6**i)], bn, connection)
    
    # save
    ref_affine_list.append(layer['ref_affine'])
    ref_norm_list.append(layer['ref_norm'])
    for i in range(lut_size-1, -1, -1):
        target_lut_list.append(layer['target_lut%d' % i])
    
    # make network
    ref_subnet =  bb.Sequential.create()
    ref_subnet.add(layer['ref_affine'])
    ref_subnet.add(layer['ref_norm'])
    ref_subnet.add(layer['ref_act'])
    layer['ref'] = bb.LoweringConvolutionBit.create(ref_subnet, w, h, 1, 1, padding=padding)
    
    target_subnet = bb.Sequential.create()
    for i in range(lut_size-1, -1, -1):
        target_subnet.add(layer['target_lut%d' % i])
    layer['target'] = bb.LoweringConvolutionBit.create(target_subnet, w, h, 1, 1, padding=padding)
    
    return layer

def build_net(layer_list):
    net = bb.Sequential.create()
    for layer in layer_list:
        net.add(layer[layer['select']])
    return net

In [11]:
# build network
modulation_size = 8
layer_rel2bin = bb.RealToBinaryBit.create(modulation_size, framewise=True)
layer_bin2rel = bb.BinaryToRealBit.create(modulation_size)

clear_list()
layer_list = []
layer_list.append(make_common_layer(layer_rel2bin))
layer_list.append(make_cnv_layer(36, 3, 3))                           # 26x26

layer_list.append(make_cnv_layer(36, 1, 1))
layer_list.append(make_cnv_layer(36, 3, 3,  connection='depthwise'))  # 24x24
layer_list.append(make_cnv_layer(36, 1, 1))

layer_list.append(make_common_layer(bb.MaxPoolingBit.create(2, 2)))   # 12x12

layer_list.append(make_cnv_layer(36*2, 1, 1))
layer_list.append(make_cnv_layer(36*2, 3, 3,  connection='depthwise'))  # 10x10
layer_list.append(make_cnv_layer(36, 1, 1))

layer_list.append(make_cnv_layer(36*2, 1, 1))
layer_list.append(make_cnv_layer(36*2, 3, 3,  connection='depthwise'))  # 8x8
layer_list.append(make_cnv_layer(36, 1, 1))

layer_list.append(make_common_layer(bb.MaxPoolingBit.create(2, 2)))      # 4x4

layer_list.append(make_cnv_layer(36*3, 1, 1))
layer_list.append(make_cnv_layer(36*3, 2, 2,  connection='depthwise'))  # 3x3
layer_list.append(make_cnv_layer(36, 1, 1))

layer_list.append(make_cnv_layer(36*3, 1, 1))
layer_list.append(make_cnv_layer(36*3, 3, 3,  connection='depthwise'))  # 1x1
layer_list.append(make_cnv_layer(10, 1, 1))

layer_list.append(make_common_layer(layer_bin2rel))

In [12]:
main_net = build_net(layer_list)
main_net.set_input_shape(td['x_shape'])
main_net.send_command('binary true')

print(main_net.get_info())

----------------------------------------------------------------------
[Sequential] 
  --------------------------------------------------------------------
  [RealToBinary] 
   input  shape : {28, 28, 1}   output shape : {28, 28, 1}
  --------------------------------------------------------------------
  [LoweringConvolution] 
   filter size : (3, 3)
   input  shape : {28, 28, 1}   output shape : {26, 26, 36}
    ------------------------------------------------------------------
    [ConvolutionIm2Col] 
     input  shape : {28, 28, 1}     output shape : {3, 3, 1}
    ------------------------------------------------------------------
    [Sequential] 
      ----------------------------------------------------------------
      [DenseAffine] 
       input  shape : {3, 3, 1}       output shape : {36}
      ----------------------------------------------------------------
      [BatchNormalization] 
       input  shape : {36}       output shape : {36}
      ---------------------------------

In [13]:
# train
loss      = bb.LossSoftmaxCrossEntropy.create()
metrics   = bb.MetricsCategoricalAccuracy.create()
optimizer = bb.OptimizerAdam.create()
optimizer.set_variables(main_net.get_parameters(), main_net.get_gradients())

runner = bb.Runner(main_net, network_name, loss, metrics, optimizer)
runner.fitting(td, epoch_size=64, mini_batch_size=32, file_write=False, file_read=False)

  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.48, accuracy=1]

epoch=1 test_accuracy=0.933800 test_loss=1.533198 train_accuracy=0.927983 train_loss=1.539858


  0%|                                                              | 0/1875 [00:00<?, ?it/s, loss=1.53, accuracy=0.906]

epoch=2 test_accuracy=0.981500 test_loss=1.480170 train_accuracy=0.980100 train_loss=1.482367


  0%|                                                              | 0/1875 [00:00<?, ?it/s, loss=1.48, accuracy=0.969]

epoch=3 test_accuracy=0.984500 test_loss=1.477892 train_accuracy=0.984550 train_loss=1.478139


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=4 test_accuracy=0.987100 test_loss=1.475835 train_accuracy=0.986000 train_loss=1.476892


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.47, accuracy=1]

epoch=5 test_accuracy=0.593700 test_loss=1.835865 train_accuracy=0.586317 train_loss=1.841793


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=6 test_accuracy=0.988900 test_loss=1.472953 train_accuracy=0.990250 train_loss=1.472403


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=7 test_accuracy=0.991100 test_loss=1.471702 train_accuracy=0.990750 train_loss=1.471846


  0%|                                                              | 0/1875 [00:00<?, ?it/s, loss=1.48, accuracy=0.969]

epoch=8 test_accuracy=0.990900 test_loss=1.471995 train_accuracy=0.991017 train_loss=1.471568


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=9 test_accuracy=0.687400 test_loss=1.757574 train_accuracy=0.683133 train_loss=1.763926


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=10 test_accuracy=0.991600 test_loss=1.472492 train_accuracy=0.991333 train_loss=1.472339


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.46, accuracy=1]

epoch=11 test_accuracy=0.991000 test_loss=1.471598 train_accuracy=0.992183 train_loss=1.470754


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=12 test_accuracy=0.992400 test_loss=1.470524 train_accuracy=0.992783 train_loss=1.469881


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.47, accuracy=1]

epoch=13 test_accuracy=0.981500 test_loss=1.490945 train_accuracy=0.980033 train_loss=1.492651


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.46, accuracy=1]

epoch=14 test_accuracy=0.992100 test_loss=1.470178 train_accuracy=0.993217 train_loss=1.469602


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=15 test_accuracy=0.992100 test_loss=1.469796 train_accuracy=0.993317 train_loss=1.469150


  0%|                                                               | 0/1875 [00:00<?, ?it/s, loss=1.5, accuracy=0.969]

epoch=16 test_accuracy=0.992300 test_loss=1.470587 train_accuracy=0.992883 train_loss=1.470245


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.47, accuracy=1]

epoch=17 test_accuracy=0.992600 test_loss=1.470425 train_accuracy=0.994200 train_loss=1.469153


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.47, accuracy=1]

epoch=18 test_accuracy=0.992300 test_loss=1.470556 train_accuracy=0.992517 train_loss=1.469899


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.47, accuracy=1]

epoch=19 test_accuracy=0.919400 test_loss=1.573935 train_accuracy=0.908850 train_loss=1.579997


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.48, accuracy=1]

epoch=20 test_accuracy=0.992700 test_loss=1.470332 train_accuracy=0.994050 train_loss=1.469178


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.47, accuracy=1]

epoch=21 test_accuracy=0.915500 test_loss=1.562860 train_accuracy=0.911867 train_loss=1.567182


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.46, accuracy=1]

epoch=22 test_accuracy=0.986300 test_loss=1.477005 train_accuracy=0.987400 train_loss=1.477479


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=23 test_accuracy=0.993000 test_loss=1.470549 train_accuracy=0.993367 train_loss=1.469969


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=24 test_accuracy=0.984100 test_loss=1.483549 train_accuracy=0.982233 train_loss=1.485368


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=25 test_accuracy=0.993600 test_loss=1.469621 train_accuracy=0.994400 train_loss=1.468531


  0%|                                                          | 1/1875 [00:00<04:08,  7.55it/s, loss=1.48, accuracy=1]

epoch=26 test_accuracy=0.992800 test_loss=1.471700 train_accuracy=0.992833 train_loss=1.471375


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.47, accuracy=1]

epoch=27 test_accuracy=0.993100 test_loss=1.470926 train_accuracy=0.994217 train_loss=1.469481


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.47, accuracy=1]

epoch=28 test_accuracy=0.992800 test_loss=1.470241 train_accuracy=0.994233 train_loss=1.469087


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=29 test_accuracy=0.993000 test_loss=1.470905 train_accuracy=0.992683 train_loss=1.470691


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=30 test_accuracy=0.991800 test_loss=1.470922 train_accuracy=0.992533 train_loss=1.470640


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.46, accuracy=1]

epoch=31 test_accuracy=0.992500 test_loss=1.470699 train_accuracy=0.994217 train_loss=1.469018


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=32 test_accuracy=0.992900 test_loss=1.470293 train_accuracy=0.995017 train_loss=1.469184


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.48, accuracy=1]

epoch=33 test_accuracy=0.924100 test_loss=1.543026 train_accuracy=0.923133 train_loss=1.545758


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.47, accuracy=1]

epoch=34 test_accuracy=0.982400 test_loss=1.487940 train_accuracy=0.981300 train_loss=1.488694


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.47, accuracy=1]

epoch=35 test_accuracy=0.992600 test_loss=1.470100 train_accuracy=0.994433 train_loss=1.469092


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=36 test_accuracy=0.887700 test_loss=1.603756 train_accuracy=0.884333 train_loss=1.608033


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=37 test_accuracy=0.991500 test_loss=1.473066 train_accuracy=0.992200 train_loss=1.472449


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.48, accuracy=1]

epoch=38 test_accuracy=0.981800 test_loss=1.486896 train_accuracy=0.980217 train_loss=1.488124


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=39 test_accuracy=0.963000 test_loss=1.515554 train_accuracy=0.958867 train_loss=1.518846


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=40 test_accuracy=0.993100 test_loss=1.469937 train_accuracy=0.994567 train_loss=1.469194


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.47, accuracy=1]

epoch=41 test_accuracy=0.993300 test_loss=1.470345 train_accuracy=0.994600 train_loss=1.468935


  0%|                                                              | 0/1875 [00:00<?, ?it/s, loss=1.49, accuracy=0.969]

epoch=42 test_accuracy=0.994200 test_loss=1.470857 train_accuracy=0.994733 train_loss=1.469613


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.47, accuracy=1]

epoch=43 test_accuracy=0.991500 test_loss=1.471682 train_accuracy=0.993317 train_loss=1.470717


  0%|                                                               | 0/1875 [00:00<?, ?it/s, loss=1.5, accuracy=0.969]

epoch=44 test_accuracy=0.989800 test_loss=1.476602 train_accuracy=0.992083 train_loss=1.475681


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.47, accuracy=1]

epoch=45 test_accuracy=0.990700 test_loss=1.477780 train_accuracy=0.992183 train_loss=1.476391


  0%|                                                              | 0/1875 [00:00<?, ?it/s, loss=1.48, accuracy=0.969]

epoch=46 test_accuracy=0.993600 test_loss=1.469713 train_accuracy=0.995233 train_loss=1.467974


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.46, accuracy=1]

epoch=47 test_accuracy=0.994000 test_loss=1.469672 train_accuracy=0.995700 train_loss=1.468328


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=48 test_accuracy=0.994300 test_loss=1.469372 train_accuracy=0.995600 train_loss=1.467994


  0%|                                                               | 0/1875 [00:00<?, ?it/s, loss=1.5, accuracy=0.969]

epoch=49 test_accuracy=0.994000 test_loss=1.469405 train_accuracy=0.995500 train_loss=1.468138


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=50 test_accuracy=0.993400 test_loss=1.469636 train_accuracy=0.994733 train_loss=1.468157


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.47, accuracy=1]

epoch=51 test_accuracy=0.970600 test_loss=1.504411 train_accuracy=0.972217 train_loss=1.506097


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=52 test_accuracy=0.889600 test_loss=1.615806 train_accuracy=0.879367 train_loss=1.625528


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=53 test_accuracy=0.992900 test_loss=1.470678 train_accuracy=0.994133 train_loss=1.469551


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=54 test_accuracy=0.992000 test_loss=1.470314 train_accuracy=0.995300 train_loss=1.468477


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.47, accuracy=1]

epoch=55 test_accuracy=0.994500 test_loss=1.469436 train_accuracy=0.995433 train_loss=1.468559


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=56 test_accuracy=0.993700 test_loss=1.470310 train_accuracy=0.995617 train_loss=1.468485


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=57 test_accuracy=0.991600 test_loss=1.474060 train_accuracy=0.992133 train_loss=1.473881


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.47, accuracy=1]

epoch=58 test_accuracy=0.992400 test_loss=1.470280 train_accuracy=0.994983 train_loss=1.468607


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=59 test_accuracy=0.992400 test_loss=1.471621 train_accuracy=0.994083 train_loss=1.469747


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.47, accuracy=1]

epoch=60 test_accuracy=0.992300 test_loss=1.470970 train_accuracy=0.995417 train_loss=1.468823


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=61 test_accuracy=0.992400 test_loss=1.471027 train_accuracy=0.994367 train_loss=1.469255


  0%|                                                                  | 0/1875 [00:00<?, ?it/s, loss=1.47, accuracy=1]

epoch=62 test_accuracy=0.994000 test_loss=1.469570 train_accuracy=0.995717 train_loss=1.468031


  0%|                                                                                         | 0/1875 [00:00<?, ?it/s]

epoch=63 test_accuracy=0.993600 test_loss=1.469591 train_accuracy=0.995300 train_loss=1.468460


                                                                                                                       

epoch=64 test_accuracy=0.992900 test_loss=1.471016 train_accuracy=0.994617 train_loss=1.469567
