<a href="https://colab.research.google.com/github/prateekgulati/-Breast-Cancer-ML/blob/master/Notebook_Prateek/RG2_Notebook7_DavidNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Summary


This notebook has experiments to find good augmentation strategy to incorporate within the network.  
Augmentation-channel level: 
```
       One strategy randomly selected
          0: random_pad_crop_image(channel,padding=padding)
          1: flip_left_right(channel)
          2: cutout_channel(channel,size=cutSize)
          3: flip_left_right(random_pad_crop_image(channel,padding=padding))
          4: cutout_channel(random_pad_crop_image(channel,padding=padding),size=cutSize)
          5: cutout_channel(flip_left_right(channel),size=cutSize)
          6: cutout_channel(flip_left_right(random_pad_crop_image(channel,padding=padding)),size=cutSize)
```

|    Parameters   |     Applied at     | Train  | Test (epoch) |
|:---------------:|:------------------:|:------:|:------------:|
| pad=1,cutSize=4 | Initial Conv block |  99.73 | **90.95 (22nd)** |
| pad=1,cutSize=8 | Initial Conv block |  96.78 | **90.91 (24th)** |
| pad=2,cutSize=4 | Initial Conv block |  97.39 | 85.16 (23rd) |
| pad=2,cutSize=8 | Initial Conv block |  94.95 | 84.48 (24th) |
| pad=4,cutSize=4 | Initial Conv block |  99.20 | **91.49 (23rd)** |
| pad=4,cutSize=8 | Initial Conv block |  97.90 | **90.96 (24th)** |
| pad=1,cutSize=4 |   ResNet Block 1   |  97.44 | 84.21 (20th) |
| pad=1,cutSize=8 |   ResNet Block 1   |  93.17 | 83.63 (21st) |
| pad=2,cutSize=4 |   ResNet Block 1   |  94.19 | 83.68 (22nd) |
| pad=2,cutSize=8 |   ResNet Block 1   |  98.09 | **91.94 (24th)**    |
| pad=2,cutSize=2 |   ResNet Block 1   |  96.86 | 84.01 (18th) |
| pad=1,cutSize=2 |   ResNet Block 2   |  98.07 | 80.64 (16th) |
| pad=1,cutSize=4 |   ResNet Block 2   |  99.18 | **91.02 (24th)** |
| pad=2,cutSize=1 |   ResNet Block 2   |  97.71 |   78.54 (10th) |
| pad=2,cutSize=2 |   ResNet Block 2   |  99.74 | **90.46 (22nd)** |
| pad=2,cutSize=4 |   ResNet Block 2   |  93.46 | 81.31 (14th) |
| pad=1,cutSize=1 |   ResNet Block 3   |  100.0 | 82.16 (23nd) |
| pad=1,cutSize=2 |   ResNet Block 3   |  99.98 | 88.55 (24th) |
| pad=2,cutSize=1 |   ResNet Block 3   |  99.94 | 81.05 (24th) |
| pad=2,cutSize=2 |   ResNet Block 3   |  99.65 | 80.89 (24th) |
| pad=2,cutSize=4 |   ResNet Block 3   |  58.59 | 78.26 (24th) |

### Code



In [0]:
import numpy as np
import time, math
from tqdm import tqdm_notebook as tqdm

import tensorflow as tf
import tensorflow.contrib.eager as tfe
import random

In [0]:
tf.enable_eager_execution()


In [0]:
BATCH_SIZE = 512 #@param {type:"integer"}
MOMENTUM = 0.9 #@param {type:"number"}
LEARNING_RATE = 0.4 #@param {type:"number"}
WEIGHT_DECAY = 5e-4 #@param {type:"number"}
EPOCHS = 24 #@param {type:"integer"}

In [0]:
def init_pytorch(shape, dtype=tf.float32, partition_info=None):
  fan = np.prod(shape[:-1])
  bound = 1 / math.sqrt(fan)
  return tf.random.uniform(shape, minval=-bound, maxval=bound, dtype=dtype)

In [0]:
class Conv(tf.keras.Model):
  def __init__(self, c_out):
    super().__init__()
    self.conv = tf.keras.layers.Conv2D(filters=c_out, kernel_size=3, padding="SAME", kernel_initializer=init_pytorch, use_bias=False)    

  def call(self, inputs):
    return tf.nn.relu(self.conv(inputs))

In [0]:
class ConvBN(tf.keras.Model):
  def __init__(self, c_out):
    super().__init__()
    self.conv = tf.keras.layers.Conv2D(filters=c_out, kernel_size=3, padding="SAME", kernel_initializer=init_pytorch, use_bias=False)
    self.bn = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

  def call(self, inputs):
    return tf.nn.relu(self.bn(self.conv(inputs)))

In [0]:
class ResBlk(tf.keras.Model):
  def __init__(self, c_out, pool, res = False):
    super().__init__()
    self.conv_bn = ConvBN(c_out)
    self.pool = pool
    self.res = res
    if self.res:
      self.res1 = ConvBN(c_out)
      self.res2 = ConvBN(c_out)

  def call(self, inputs):
    h = self.pool(self.conv_bn(inputs))
    if self.res:
      h = h + self.res2(self.res1(h))
    return h

In [0]:
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = Conv(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.blk4 = ResBlk(c*16, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight

  def call(self, x, y):
    h = self.pool(self.blk4(self.blk3(self.blk2(self.blk1(self.data_aug2(self.init_conv_bn(x)))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

In [0]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
len_train, len_test = len(x_train), len(x_test)
y_train = y_train.astype('int64').reshape(len_train)
y_test = y_test.astype('int64').reshape(len_test)

train_mean = np.mean(x_train, axis=(0,1,2))
train_std = np.std(x_train, axis=(0,1,2))

test_mean = np.mean(x_train, axis=(0,1,2))
test_std = np.std(x_train, axis=(0,1,2))

normalize = lambda x: ((x - train_mean) / train_std).astype('float32') # todo: check here
normalize_test = lambda x: ((x - test_mean) / test_std).astype('float32') # todo: check here
# pad4 = lambda x: np.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)], mode='reflect')

# x_train = normalize(pad4(x_train))
x_train = normalize(x_train)
# x_test = normalize(x_test)
x_test = normalize_test(x_test)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [0]:
model = DavidNet()
batches_per_epoch = len_train//BATCH_SIZE + 1

# lr_schedule = lambda t: np.interp([t], [0, (EPOCHS+1)//5, EPOCHS], [0, LEARNING_RATE, 0])[0]
lr_schedule = lambda t: np.interp([t], [0, (EPOCHS+1)//5, int(0.8*EPOCHS), EPOCHS], [0, LEARNING_RATE, 0.1*LEARNING_RATE, 0.005])[0]
global_step = tf.train.get_or_create_global_step()
lr_func = lambda: lr_schedule(global_step/batches_per_epoch)/BATCH_SIZE
opt = tf.train.MomentumOptimizer(lr_func, momentum=MOMENTUM, use_nesterov=True)
data_aug = lambda x, y: (tf.image.random_flip_left_right(tf.random_crop(x, [32, 32, 3])), y)

In [0]:
def cutout_channel(img, prob=100, size=8, min_size=5, use_fixed_size=True):
  return tf.cond(tf.random.uniform([], 0, 100) > prob, lambda: img , lambda: get_cutout_channel(img,size,min_size,use_fixed_size))


def get_cutout_channel(img, size=8,min_size=2,use_fixed_size=True):
    height = tf.shape(img)[0]
    width = tf.shape(img)[1]
    channel = tf.shape(img)[2]
    area = tf.cast(width*height, tf.float32)
    if (use_fixed_size==True):
      s=size
    else:  
      s=tf.random.uniform([], min_size, size, tf.int32)
    x1 = tf.random.uniform([], 0, height+1-s , tf.int32) # get the x offset from top left
    y1 = tf.random.uniform([], 0, width+1-s , tf.int32)
    img1 = tf.ones_like(img)  
    #print(tf.shape(img1))
    cut_slice = tf.slice(
    img1,
    [x1, y1, 0],
    [s, s, channel])
    #create mask similar in shape to input image with cutout area having ones and rest of the area padded with zeros 
    mask = tf.image.pad_to_bounding_box(
      cut_slice,
      x1,
      y1,
      height,
      width
    )
    mask = tf.ones_like(mask) - mask
    tmp_img = tf.multiply(img,mask)
    cut_img =tmp_img
    return cut_img

In [0]:
def random_pad_crop_image(image,padding=4):
  # global ctr
  # ctr=ctr+1
  shape=tf.shape(image)  
  image=tf.pad(image,[(padding, padding), (padding, padding), (0, 0)])
  image=tf.image.random_crop(image,size=shape)
  return image

def random_pad_crop_batch(batch,padding=4):
  # global ctr
  # ctr=ctr+1
  shape=tf.shape(batch)  
  batch=tf.pad(batch,[(0, 0), (padding, padding), (padding, padding), (0, 0)])
  batch=tf.image.random_crop(batch,size=shape)
  return batch

def flip_left_right(image):
#   global ctr
#   ctr=ctr+1
  return tf.image.random_flip_left_right(image)  

def no_augmentation(batch):
  # global ctr
  # ctr=ctr+1
  return batch

def augmentDictBatch(batch,padding):
  return tf.switch_case(
        tf.random_uniform([],0,6, dtype=tf.dtypes.int32), 
        branch_fns={
            0: lambda:random_pad_crop_batch(batch,padding=padding[0]),
            1: lambda:flip_left_right(batch),
            2: lambda:flip_left_right(random_pad_crop_batch(batch,padding=padding[0])),
            3: lambda:random_pad_crop_batch(batch,padding=padding[1]),
            4: lambda:flip_left_right(random_pad_crop_batch(batch,padding=padding[1])),
        }, 
        default= lambda: batch # 5
    )  

def augmentDictChannel(batch,padding,cutSize):
  return tf.switch_case(
        tf.random_uniform([],0,9, dtype=tf.dtypes.int32), 
        branch_fns={
            0: lambda:tf.map_fn(lambda channel: random_pad_crop_image(channel,padding=padding), batch),
            1: lambda:tf.map_fn(lambda channel: flip_left_right(channel), batch),
            2: lambda:tf.map_fn(lambda channel: cutout_channel(channel,size=cutSize), batch),
            3: lambda:tf.map_fn(lambda channel: flip_left_right(random_pad_crop_image(channel,padding=padding)), batch),            
            4: lambda:tf.map_fn(lambda channel: cutout_channel(random_pad_crop_image(channel,padding=padding),size=cutSize), batch),
            5: lambda:tf.map_fn(lambda channel: cutout_channel(flip_left_right(channel),size=cutSize), batch),
            6: lambda:tf.map_fn(lambda channel: cutout_channel(flip_left_right(random_pad_crop_image(channel,padding=padding)),size=cutSize), batch),
        }, 
        default= lambda: batch # 7,8
    )

### Without Augmentation

In [0]:
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  training=True
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  training=False
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 1 lr: 0.08 train loss: 1.6341621618652344 train acc: 0.42006 val loss: 1.2786644775390625 val acc: 0.5336 time: 84.12477540969849


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 1.0550838256835937 train acc: 0.62662 val loss: 1.0723008911132812 val acc: 0.6166 time: 168.25910019874573


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 0.7666204714965821 train acc: 0.73398 val loss: 0.8635984420776367 val acc: 0.6925 time: 252.41594672203064


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.5686230340576172 train acc: 0.80822 val loss: 0.794242236328125 val acc: 0.7234 time: 336.4668564796448


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.40912643966674805 train acc: 0.86958 val loss: 0.8083009780883789 val acc: 0.7189 time: 420.6586241722107


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37428571428571433 train loss: 0.25697577102661134 train acc: 0.92704 val loss: 0.7783981506347656 val acc: 0.733 time: 504.72028970718384


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.3485714285714286 train loss: 0.13866929374694825 train acc: 0.9678 val loss: 0.7423687622070313 val acc: 0.7564 time: 588.8281450271606


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.3228571428571429 train loss: 0.05586202363967895 train acc: 0.99316 val loss: 0.6656653503417969 val acc: 0.7868 time: 672.8910217285156


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.29714285714285715 train loss: 0.01981062524795532 train acc: 0.99924 val loss: 0.6305414154052734 val acc: 0.8026 time: 756.95898604393


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.27142857142857146 train loss: 0.00887911123752594 train acc: 0.99992 val loss: 0.6138105651855469 val acc: 0.8087 time: 841.0362412929535


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.24571428571428575 train loss: 0.005965333847999573 train acc: 0.99998 val loss: 0.6266000091552735 val acc: 0.8079 time: 925.2222571372986


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.22000000000000003 train loss: 0.004728774175643921 train acc: 1.0 val loss: 0.6255576232910156 val acc: 0.8111 time: 1009.0526669025421


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.1942857142857143 train loss: 0.0039314536690711974 train acc: 1.0 val loss: 0.6261321350097656 val acc: 0.8111 time: 1093.2347838878632


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.1685714285714286 train loss: 0.0034008057999610902 train acc: 1.0 val loss: 0.6329055755615235 val acc: 0.8096 time: 1177.3435504436493


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.1428571428571429 train loss: 0.0029965528178215028 train acc: 1.0 val loss: 0.63671826171875 val acc: 0.8103 time: 1261.516449213028


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.11714285714285716 train loss: 0.0026799788820743562 train acc: 1.0 val loss: 0.6441303314208985 val acc: 0.8091 time: 1345.632423400879


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.09142857142857147 train loss: 0.0024527893006801604 train acc: 1.0 val loss: 0.6435116027832031 val acc: 0.8109 time: 1429.7537217140198


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.06571428571428573 train loss: 0.002233844630718231 train acc: 1.0 val loss: 0.6464360290527343 val acc: 0.8113 time: 1513.8203377723694


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.04000000000000001 train loss: 0.002057218931913376 train acc: 1.0 val loss: 0.6510072509765625 val acc: 0.8106 time: 1597.9055850505829


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.03300000000000001 train loss: 0.0018887103962898254 train acc: 1.0 val loss: 0.6530312286376954 val acc: 0.8096 time: 1681.9618873596191


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.026000000000000002 train loss: 0.001757127342224121 train acc: 1.0 val loss: 0.6537623336791992 val acc: 0.8095 time: 1766.0287747383118


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.019000000000000003 train loss: 0.0016434865772724152 train acc: 1.0 val loss: 0.6562681594848633 val acc: 0.8101 time: 1850.128155708313


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.012 train loss: 0.0015578745555877686 train acc: 1.0 val loss: 0.6610282562255859 val acc: 0.8099 time: 1934.2064535617828


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.005 train loss: 0.001461819063425064 train acc: 1.0 val loss: 0.6627364303588867 val acc: 0.8099 time: 2018.2653663158417


### Augmentation inside network
Augmentation Appled at: After 1st convolution  
Augmentation Strategy: Random Pad Crop (pad=1,cutSize=4)

In [0]:
pad=1
cutSize=4
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = Conv(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.blk4 = ResBlk(c*16, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight
    self.data_aug2 = lambda x: (self.augment(x)) 
  
  def augment(self,batch):    
    if training:
      return augmentDictChannel(batch,padding=pad,cutSize=cutSize)
    else:    
      return batch   

  def call(self, x, y):
    h = self.pool(self.blk4(self.blk3(self.blk2(self.blk1(self.data_aug2(self.init_conv_bn(x)))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

model = DavidNet()
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  training=True
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  training=False
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

epoch: 1 lr: 0.08 train loss: 1.478090366821289 train acc: 0.4648 val loss: 1.1969029052734375 val acc: 0.5931 time: 336.0568616390228


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 0.7908851968383789 train acc: 0.71798 val loss: 0.8535251312255859 val acc: 0.7124 time: 660.4499454498291


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 0.5993705606079102 train acc: 0.79074 val loss: 1.1467464385986328 val acc: 0.6698 time: 994.4949586391449


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.5002247698974609 train acc: 0.82538 val loss: 1.4085215454101563 val acc: 0.6339 time: 1329.4532070159912


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.42898507537841796 train acc: 0.85166 val loss: 1.0110733062744142 val acc: 0.6936 time: 1692.034119606018


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37428571428571433 train loss: 0.3490048262023926 train acc: 0.87874 val loss: 0.6422595733642578 val acc: 0.8036 time: 2027.0390186309814


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.3485714285714286 train loss: 0.25633068420410154 train acc: 0.91272 val loss: 0.4870686050415039 val acc: 0.8461 time: 2351.6388626098633


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.3228571428571429 train loss: 0.201245678024292 train acc: 0.93064 val loss: 0.4863878540039063 val acc: 0.8482 time: 2699.0924503803253


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.29714285714285715 train loss: 0.15110100090026857 train acc: 0.94884 val loss: 0.43356957702636717 val acc: 0.8681 time: 3019.7849378585815


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.27142857142857146 train loss: 0.12441808052062989 train acc: 0.9579 val loss: 0.5314265670776367 val acc: 0.8543 time: 3331.34588599205


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.24571428571428575 train loss: 0.09737268314361572 train acc: 0.96708 val loss: 0.4919749496459961 val acc: 0.8626 time: 3640.6757690906525


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.22000000000000003 train loss: 0.074343876247406 train acc: 0.97552 val loss: 0.44954756011962893 val acc: 0.8754 time: 3972.5427720546722


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.1942857142857143 train loss: 0.0677045074748993 train acc: 0.978 val loss: 0.43599405059814456 val acc: 0.885 time: 4320.658070325851


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.1685714285714286 train loss: 0.045422395453453064 train acc: 0.98594 val loss: 0.42032209243774415 val acc: 0.8899 time: 4610.22252368927


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.1428571428571429 train loss: 0.03442320176124573 train acc: 0.98928 val loss: 0.49803761596679685 val acc: 0.8781 time: 4913.162062883377


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.11714285714285716 train loss: 0.02796058429479599 train acc: 0.99124 val loss: 0.38136397399902344 val acc: 0.9012 time: 5222.022808790207


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.09142857142857147 train loss: 0.023434138736724855 train acc: 0.99334 val loss: 0.39557887268066405 val acc: 0.901 time: 5546.266808509827


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.06571428571428573 train loss: 0.0179736554646492 train acc: 0.99502 val loss: 0.38416852416992187 val acc: 0.9029 time: 5841.579709529877


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.04000000000000001 train loss: 0.01327041626393795 train acc: 0.9964 val loss: 0.37956133956909177 val acc: 0.9035 time: 6145.184897899628


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.03300000000000001 train loss: 0.014433642601966858 train acc: 0.99588 val loss: 0.3712515808105469 val acc: 0.9079 time: 6473.598547458649


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.026000000000000002 train loss: 0.012112502950429916 train acc: 0.99662 val loss: 0.3742081428527832 val acc: 0.9074 time: 6831.260121583939


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.019000000000000003 train loss: 0.009388836225569248 train acc: 0.99772 val loss: 0.36701606674194337 val acc: 0.9095 time: 7160.159880161285


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.012 train loss: 0.007516123212873936 train acc: 0.99828 val loss: 0.3687893196105957 val acc: 0.9091 time: 7483.458249330521


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.005 train loss: 0.01032744450300932 train acc: 0.99736 val loss: 0.37061681213378905 val acc: 0.9087 time: 7810.0663385391235


### Augmentation inside network
Augmentation Appled at: After 1st convolution  
Augmentation Strategy: Random Pad Crop (pad=1,cutSize=8)

In [0]:
pad=1
cutSize=8
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = Conv(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.blk4 = ResBlk(c*16, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight
    self.data_aug2 = lambda x: (self.augment(x)) 
  
  def augment(self,batch):    
    if training:
      return augmentDictChannel(batch,padding=pad,cutSize=cutSize)
    else:    
      return batch   

  def call(self, x, y):
    h = self.pool(self.blk4(self.blk3(self.blk2(self.blk1(self.data_aug2(self.init_conv_bn(x)))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

model = DavidNet()
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  training=True
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  training=False
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

epoch: 1 lr: 0.08 train loss: 1.458164061279297 train acc: 0.46978 val loss: 1.0126893096923828 val acc: 0.6366 time: 312.94167137145996


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 0.7973662045288086 train acc: 0.71806 val loss: 0.9206479278564453 val acc: 0.6877 time: 619.7380905151367


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 0.6153512570190429 train acc: 0.78678 val loss: 0.7740326385498046 val acc: 0.7448 time: 932.5455813407898


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.5288537127685546 train acc: 0.81496 val loss: 0.9886033142089844 val acc: 0.7004 time: 1281.8481171131134


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.4507023260498047 train acc: 0.84454 val loss: 0.6566224060058594 val acc: 0.7779 time: 1621.7052855491638


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37428571428571433 train loss: 0.38679775421142576 train acc: 0.86824 val loss: 0.5948684997558594 val acc: 0.8151 time: 1974.3353941440582


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.3485714285714286 train loss: 0.29687256912231447 train acc: 0.89698 val loss: 0.4699892852783203 val acc: 0.8461 time: 2312.106716632843


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.3228571428571429 train loss: 0.22726799507141113 train acc: 0.92134 val loss: 0.5540105331420898 val acc: 0.8279 time: 2639.175468683243


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.29714285714285715 train loss: 0.18844825637817383 train acc: 0.9351 val loss: 0.4553110336303711 val acc: 0.8591 time: 2965.077484369278


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.27142857142857146 train loss: 0.15442725616455077 train acc: 0.94782 val loss: 0.5010489044189453 val acc: 0.8556 time: 3284.877879858017


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.24571428571428575 train loss: 0.11953215757369995 train acc: 0.95938 val loss: 0.46530227966308596 val acc: 0.8647 time: 3591.630217075348


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.22000000000000003 train loss: 0.09659168193817139 train acc: 0.96826 val loss: 0.5509934921264649 val acc: 0.861 time: 3887.999388694763


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.1942857142857143 train loss: 0.08268924755096435 train acc: 0.9732 val loss: 0.4546795555114746 val acc: 0.8789 time: 4203.695697546005


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.1685714285714286 train loss: 0.0693309010219574 train acc: 0.9779 val loss: 0.39172375564575196 val acc: 0.8955 time: 4519.3546822071075


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.1428571428571429 train loss: 0.053219323744773867 train acc: 0.9831 val loss: 0.4325897796630859 val acc: 0.8864 time: 4838.512014389038


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.11714285714285716 train loss: 0.04007314752817154 train acc: 0.98686 val loss: 0.365929776763916 val acc: 0.9007 time: 5159.555776119232


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.09142857142857147 train loss: 0.030581861295700074 train acc: 0.99068 val loss: 0.3871241943359375 val acc: 0.903 time: 5477.784622907639


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.06571428571428573 train loss: 0.02946550108909607 train acc: 0.99064 val loss: 0.37179121322631836 val acc: 0.9056 time: 5822.70593214035


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.04000000000000001 train loss: 0.02568089189887047 train acc: 0.99228 val loss: 0.3770712127685547 val acc: 0.9031 time: 6171.384992599487


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.03300000000000001 train loss: 0.018307467818260193 train acc: 0.99426 val loss: 0.3601479965209961 val acc: 0.9075 time: 6513.843440055847


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.026000000000000002 train loss: 0.016564615815877915 train acc: 0.99492 val loss: 0.3574476318359375 val acc: 0.9065 time: 6856.54891037941


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.019000000000000003 train loss: 0.012975655930042266 train acc: 0.9962 val loss: 0.35752385940551756 val acc: 0.9087 time: 7145.695568084717


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.012 train loss: 0.015529093179702759 train acc: 0.99518 val loss: 0.35403167877197267 val acc: 0.9089 time: 7472.154883861542


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.005 train loss: 0.013565014469623566 train acc: 0.99614 val loss: 0.35403314895629884 val acc: 0.9091 time: 7793.047049045563


### Augmentation inside network
Augmentation Appled at: After 1st convolution  
Augmentation Strategy: Random Pad Crop (pad=2,cutsize=4)

In [0]:
pad=2
cutSize=4
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = Conv(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.blk4 = ResBlk(c*16, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight
    self.data_aug2 = lambda x: (self.augment(x)) 
  
  def augment(self,batch):    
    if training:
      return augmentDictChannel(batch,padding=pad,cutSize=cutSize)
    else:    
      return batch  

  def call(self, x, y):
    h = self.pool(self.blk4(self.blk3(self.blk2(self.blk1(self.data_aug2(self.init_conv_bn(x)))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

model = DavidNet()
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  training=True
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  training=False
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 1 lr: 0.08 train loss: 1.6665190509033203 train acc: 0.41174 val loss: 1.3188296173095704 val acc: 0.5366 time: 306.766872882843


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 1.1134955206298829 train acc: 0.60544 val loss: 0.9823518981933593 val acc: 0.645 time: 602.5171127319336


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 0.8627255871582031 train acc: 0.6958 val loss: 0.9149790283203125 val acc: 0.6778 time: 934.1883838176727


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.717933583984375 train acc: 0.74958 val loss: 0.7978549179077148 val acc: 0.7245 time: 1227.2634735107422


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.6073503503417969 train acc: 0.79008 val loss: 0.7257555465698242 val acc: 0.7465 time: 1558.592827796936


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37428571428571433 train loss: 0.5191341183471679 train acc: 0.82364 val loss: 0.6633151504516601 val acc: 0.7703 time: 1864.507393360138


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.3485714285714286 train loss: 0.45633617370605467 train acc: 0.8459 val loss: 0.6351809524536133 val acc: 0.7776 time: 2199.958199262619


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.3228571428571429 train loss: 0.4000434927368164 train acc: 0.86462 val loss: 0.6287017684936523 val acc: 0.782 time: 2528.0474622249603


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.29714285714285715 train loss: 0.34671671249389646 train acc: 0.88266 val loss: 0.5542364120483398 val acc: 0.8065 time: 2820.5770049095154


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.27142857142857146 train loss: 0.2964210111999512 train acc: 0.90156 val loss: 0.6739828521728516 val acc: 0.7809 time: 3091.4576795101166


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.24571428571428575 train loss: 0.27508279800415036 train acc: 0.91046 val loss: 0.559666325378418 val acc: 0.8105 time: 3409.1214253902435


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.22000000000000003 train loss: 0.2506812582397461 train acc: 0.9192 val loss: 0.5346502288818359 val acc: 0.8236 time: 3767.356528520584


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.1942857142857143 train loss: 0.2377200437927246 train acc: 0.92262 val loss: 0.5410789749145508 val acc: 0.8191 time: 4127.662163734436


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.1685714285714286 train loss: 0.20429827323913574 train acc: 0.93452 val loss: 0.5461317840576172 val acc: 0.82 time: 4475.734708786011


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.1428571428571429 train loss: 0.18187912620544433 train acc: 0.9418 val loss: 0.6914864288330078 val acc: 0.7904 time: 4769.743026018143


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.11714285714285716 train loss: 0.15070935550689699 train acc: 0.95426 val loss: 0.6147449508666992 val acc: 0.8097 time: 5062.876434087753


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.09142857142857147 train loss: 0.1327720813369751 train acc: 0.96058 val loss: 0.5734961334228516 val acc: 0.8198 time: 5344.670038700104


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.06571428571428573 train loss: 0.12253231967926026 train acc: 0.96226 val loss: 0.5650581146240234 val acc: 0.8234 time: 5650.835715532303


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.04000000000000001 train loss: 0.12683837894439698 train acc: 0.96002 val loss: 0.5248507827758789 val acc: 0.8358 time: 5973.284979104996


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.03300000000000001 train loss: 0.12287731004714966 train acc: 0.96172 val loss: 0.59318916015625 val acc: 0.8221 time: 6295.713894605637


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.026000000000000002 train loss: 0.11210089860916138 train acc: 0.96538 val loss: 0.6026990142822266 val acc: 0.8194 time: 6623.849063634872


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.019000000000000003 train loss: 0.1270035678100586 train acc: 0.95946 val loss: 0.5297336853027343 val acc: 0.8346 time: 6953.815659046173


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.012 train loss: 0.09680546756744385 train acc: 0.96982 val loss: 0.5035704879760742 val acc: 0.8516 time: 7242.827726840973


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.005 train loss: 0.08664260103225709 train acc: 0.9739 val loss: 0.502859732055664 val acc: 0.8454 time: 7568.04896402359


### Augmentation inside network
Augmentation Appled at: After 1st convolution  
Augmentation Strategy: Random Pad Crop (pad=2,cutsize=8)

In [0]:
pad=2
cutSize=8
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = Conv(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.blk4 = ResBlk(c*16, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight
    self.data_aug2 = lambda x: (self.augment(x)) 
  
  def augment(self,batch):    
    if training:
      return augmentDictChannel(batch,padding=pad,cutSize=cutSize)
    else:    
      return batch  

  def call(self, x, y):
    h = self.pool(self.blk4(self.blk3(self.blk2(self.blk1(self.data_aug2(self.init_conv_bn(x)))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

model = DavidNet()
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  training=True
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  training=False
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 1 lr: 0.08 train loss: 1.6671758679199218 train acc: 0.40916 val loss: 1.3326741790771484 val acc: 0.5145 time: 308.8178584575653


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 1.1341986395263672 train acc: 0.5953 val loss: 1.0001943145751953 val acc: 0.6484 time: 626.3598575592041


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 0.8928193743896484 train acc: 0.68358 val loss: 0.8606158950805664 val acc: 0.6973 time: 954.7900516986847


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.7312901742553711 train acc: 0.74674 val loss: 0.7395172241210938 val acc: 0.7407 time: 1237.4006350040436


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.6241932290649415 train acc: 0.78412 val loss: 0.6859697784423828 val acc: 0.7595 time: 1517.785236120224


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37428571428571433 train loss: 0.5498588790893555 train acc: 0.81028 val loss: 0.6255640380859375 val acc: 0.7814 time: 1835.9508001804352


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.3485714285714286 train loss: 0.48285829620361326 train acc: 0.83556 val loss: 0.6406650604248046 val acc: 0.7775 time: 2146.171797513962


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.3228571428571429 train loss: 0.429098742980957 train acc: 0.85352 val loss: 0.6842124984741211 val acc: 0.7651 time: 2460.7081587314606


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.29714285714285715 train loss: 0.3815388682556152 train acc: 0.8696 val loss: 0.6517674682617187 val acc: 0.7778 time: 2786.1070635318756


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.27142857142857146 train loss: 0.33750039932250975 train acc: 0.8858 val loss: 0.5708413604736328 val acc: 0.8025 time: 3110.338561296463


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.24571428571428575 train loss: 0.2988060968017578 train acc: 0.90034 val loss: 0.5695644561767578 val acc: 0.8119 time: 3412.5537457466125


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.22000000000000003 train loss: 0.26661588653564455 train acc: 0.91236 val loss: 0.6020491500854492 val acc: 0.8059 time: 3714.085151195526


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.1942857142857143 train loss: 0.25469062644958496 train acc: 0.9162 val loss: 0.5649087478637695 val acc: 0.8128 time: 4030.9537296295166


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.1685714285714286 train loss: 0.21969031616210938 train acc: 0.92902 val loss: 0.5769761779785156 val acc: 0.8088 time: 4333.34406542778


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.1428571428571429 train loss: 0.23029907409667968 train acc: 0.92384 val loss: 0.6579989471435547 val acc: 0.7926 time: 4692.258258342743


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.11714285714285716 train loss: 0.18311012104034424 train acc: 0.94166 val loss: 0.5380835067749024 val acc: 0.8267 time: 4997.085289955139


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.09142857142857147 train loss: 0.17409007511138916 train acc: 0.94448 val loss: 0.5639341476440429 val acc: 0.8241 time: 5338.322377920151


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.06571428571428573 train loss: 0.15983867683410644 train acc: 0.94952 val loss: 0.48949864807128907 val acc: 0.8448 time: 5628.0934336185455


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.04000000000000001 train loss: 0.13471918029785157 train acc: 0.95802 val loss: 0.5444285430908203 val acc: 0.8271 time: 5914.419582366943


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.03300000000000001 train loss: 0.14139594734191893 train acc: 0.95444 val loss: 0.597479817199707 val acc: 0.82 time: 6238.96818113327


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.026000000000000002 train loss: 0.13843443103790284 train acc: 0.95552 val loss: 0.5362965423583984 val acc: 0.8405 time: 6547.7421452999115


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.019000000000000003 train loss: 0.11942970338821411 train acc: 0.96188 val loss: 0.5366804824829101 val acc: 0.8348 time: 6854.317105054855


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.012 train loss: 0.11858439964294433 train acc: 0.9622 val loss: 0.5613835708618164 val acc: 0.8286 time: 7183.854322910309


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.005 train loss: 0.11282662817001343 train acc: 0.96382 val loss: 0.5903560653686524 val acc: 0.8241 time: 7499.541564464569


### Augmentation inside network
Augmentation Appled at: After 1st convolution  
Augmentation Strategy: Random Pad Crop (pad=4,cutsize=4)

In [0]:
pad=4
cutSize=4
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = Conv(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.blk4 = ResBlk(c*16, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight
    self.data_aug2 = lambda x: (self.augment(x)) 
  
  def augment(self,batch):    
    if training:
      return augmentDictChannel(batch,padding=pad,cutSize=cutSize)
    else:    
      return batch  

  def call(self, x, y):
    h = self.pool(self.blk4(self.blk3(self.blk2(self.blk1(self.data_aug2(self.init_conv_bn(x)))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

model = DavidNet()
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  training=True
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  training=False
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

epoch: 1 lr: 0.08 train loss: 1.5187113116455078 train acc: 0.44504 val loss: 1.4419800170898438 val acc: 0.5321 time: 335.9128267765045


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 0.8664142315673828 train acc: 0.69084 val loss: 0.914839013671875 val acc: 0.6949 time: 654.8884637355804


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 0.6443114953613281 train acc: 0.7744 val loss: 0.8579335800170899 val acc: 0.7074 time: 948.1267635822296


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.5481588903808594 train acc: 0.81108 val loss: 0.8897675415039062 val acc: 0.7215 time: 1278.2319159507751


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.47748458526611326 train acc: 0.8356 val loss: 0.8229106842041015 val acc: 0.7517 time: 1564.429339647293


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37428571428571433 train loss: 0.38484749755859377 train acc: 0.86844 val loss: 0.6322223678588867 val acc: 0.7981 time: 1866.876712322235


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.3485714285714286 train loss: 0.32171222396850585 train acc: 0.8909 val loss: 0.6327680541992188 val acc: 0.8022 time: 2203.7777841091156


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.3228571428571429 train loss: 0.2630330041503906 train acc: 0.90966 val loss: 0.52029755859375 val acc: 0.8363 time: 2497.946801185608


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.29714285714285715 train loss: 0.2052902182006836 train acc: 0.92952 val loss: 0.4570541442871094 val acc: 0.8578 time: 2801.801237344742


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.27142857142857146 train loss: 0.18521284950256348 train acc: 0.93628 val loss: 0.47762760467529297 val acc: 0.8576 time: 3095.008405447006


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.24571428571428575 train loss: 0.14593953174591065 train acc: 0.95064 val loss: 0.44182633361816404 val acc: 0.8718 time: 3451.02352142334


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.22000000000000003 train loss: 0.12394369434356689 train acc: 0.95758 val loss: 0.39411910400390626 val acc: 0.8825 time: 3773.891178369522


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.1942857142857143 train loss: 0.11345506689071655 train acc: 0.96302 val loss: 0.4172412498474121 val acc: 0.8773 time: 4128.20783162117


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.1685714285714286 train loss: 0.0973874268913269 train acc: 0.96752 val loss: 0.35621521072387696 val acc: 0.8945 time: 4449.583037614822


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.1428571428571429 train loss: 0.07721847107887268 train acc: 0.974 val loss: 0.37397850494384766 val acc: 0.8953 time: 4768.577209949493


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.11714285714285716 train loss: 0.06672696594238281 train acc: 0.97792 val loss: 0.3506197868347168 val acc: 0.9009 time: 5056.237311840057


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.09142857142857147 train loss: 0.043648429579734804 train acc: 0.9862 val loss: 0.34258751373291013 val acc: 0.9076 time: 5339.173730611801


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.06571428571428573 train loss: 0.04712669388771057 train acc: 0.9851 val loss: 0.3353434844970703 val acc: 0.9075 time: 5682.603320598602


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.04000000000000001 train loss: 0.035997266855239865 train acc: 0.98868 val loss: 0.32623980865478514 val acc: 0.9105 time: 5995.714622497559


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.03300000000000001 train loss: 0.0294507517015934 train acc: 0.99112 val loss: 0.3231544403076172 val acc: 0.9121 time: 6321.427857875824


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.026000000000000002 train loss: 0.02961469107091427 train acc: 0.99052 val loss: 0.3180964477539063 val acc: 0.9128 time: 6636.2335550785065


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.019000000000000003 train loss: 0.02613851903319359 train acc: 0.9917 val loss: 0.3181461250305176 val acc: 0.913 time: 6962.006323099136


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.012 train loss: 0.026843778510093688 train acc: 0.99166 val loss: 0.31983998565673827 val acc: 0.9149 time: 7290.126756906509


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.005 train loss: 0.026478399572372437 train acc: 0.99202 val loss: 0.3156652931213379 val acc: 0.9145 time: 7613.910309314728


### Augmentation inside network
Augmentation Appled at: After 1st convolution  
Augmentation Strategy: Random Pad Crop (pad=4,cutsize=8)

In [0]:
pad=4
cutSize=8
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = Conv(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.blk4 = ResBlk(c*16, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight
    self.data_aug2 = lambda x: (self.augment(x)) 
  
  def augment(self,batch):    
    if training:
      return augmentDictChannel(batch,padding=pad,cutSize=cutSize)
    else:    
      return batch  

  def call(self, x, y):
    h = self.pool(self.blk4(self.blk3(self.blk2(self.blk1(self.data_aug2(self.init_conv_bn(x)))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

model = DavidNet()
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  training=True
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  training=False
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

epoch: 1 lr: 0.08 train loss: 1.513200711669922 train acc: 0.44656 val loss: 1.286187078857422 val acc: 0.5724 time: 303.1924624443054


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 0.8872401934814453 train acc: 0.68434 val loss: 1.351294024658203 val acc: 0.5962 time: 613.5565748214722


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 0.6717058825683594 train acc: 0.76646 val loss: 0.9458463623046875 val acc: 0.7094 time: 920.2210185527802


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.5733397882080078 train acc: 0.79958 val loss: 0.6886352966308594 val acc: 0.7766 time: 1239.274215221405


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.4983666015625 train acc: 0.82746 val loss: 1.1944845336914063 val acc: 0.6603 time: 1538.3806099891663


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37428571428571433 train loss: 0.41712127395629883 train acc: 0.8557 val loss: 0.6799238189697265 val acc: 0.7954 time: 1847.2626869678497


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.3485714285714286 train loss: 0.3430701301574707 train acc: 0.8822 val loss: 0.5068866516113282 val acc: 0.8352 time: 2161.5307517051697


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.3228571428571429 train loss: 0.2931094972229004 train acc: 0.89822 val loss: 0.47407769012451173 val acc: 0.8473 time: 2485.0943682193756


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.29714285714285715 train loss: 0.2442338941192627 train acc: 0.91626 val loss: 0.463582763671875 val acc: 0.8506 time: 2827.4560000896454


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.27142857142857146 train loss: 0.2015924895477295 train acc: 0.93122 val loss: 0.42662528991699217 val acc: 0.8654 time: 3120.9496307373047


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.24571428571428575 train loss: 0.18272987606048585 train acc: 0.93772 val loss: 0.3974978462219238 val acc: 0.878 time: 3455.7736880779266


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.22000000000000003 train loss: 0.14144361465454103 train acc: 0.95234 val loss: 0.3721118560791016 val acc: 0.8868 time: 3763.5707507133484


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.1942857142857143 train loss: 0.13909770568847657 train acc: 0.9531 val loss: 0.35671177368164064 val acc: 0.8933 time: 4046.098435640335


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.1685714285714286 train loss: 0.10609025215148926 train acc: 0.96536 val loss: 0.4155512710571289 val acc: 0.8819 time: 4337.141392230988


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.1428571428571429 train loss: 0.10838248022079468 train acc: 0.96278 val loss: 0.5047716201782226 val acc: 0.8681 time: 4666.657077789307


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.11714285714285716 train loss: 0.08513743560791015 train acc: 0.97146 val loss: 0.3959774444580078 val acc: 0.8889 time: 5003.785927534103


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.09142857142857147 train loss: 0.06459820172309876 train acc: 0.97904 val loss: 0.33313369827270506 val acc: 0.9069 time: 5299.007705688477


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.06571428571428573 train loss: 0.05546225331306458 train acc: 0.9823 val loss: 0.3331344970703125 val acc: 0.9085 time: 5618.981608867645


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.04000000000000001 train loss: 0.0447618389248848 train acc: 0.98542 val loss: 0.323142911529541 val acc: 0.9123 time: 5895.641102552414


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.03300000000000001 train loss: 0.039879504640102384 train acc: 0.98694 val loss: 0.322952473449707 val acc: 0.9126 time: 6213.5818157196045


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.026000000000000002 train loss: 0.04296023806333542 train acc: 0.9857 val loss: 0.32143545989990235 val acc: 0.9128 time: 6531.505244731903


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.019000000000000003 train loss: 0.04071170263051987 train acc: 0.98706 val loss: 0.3218938507080078 val acc: 0.9129 time: 6843.077108383179


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.012 train loss: 0.04040976460814476 train acc: 0.98706 val loss: 0.31473096923828126 val acc: 0.9145 time: 7167.435603380203


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.005 train loss: 0.0369829578435421 train acc: 0.98822 val loss: 0.3137911018371582 val acc: 0.9156 time: 7474.438971281052


### Augmentation inside network
Augmentation Appled at: After ResNet block 1  
Augmentation Strategy: Random Pad Crop (pad=1, cutSize=4)

In [0]:
pad=1
cutSize=4
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = Conv(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.blk4 = ResBlk(c*16, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight
    self.data_aug2 = lambda x: (self.augment(x)) 
  
  def augment(self,batch):    
    if training:
      return augmentDictChannel(batch,padding=pad,cutSize=cutSize)
    else:    
      return batch  

  def call(self, x, y):
    h = self.pool(self.blk4(self.blk3(self.blk2(self.data_aug2(self.blk1(self.init_conv_bn(x)))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

model = DavidNet()
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  training=True
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  training=False
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 1 lr: 0.08 train loss: 1.6739215759277344 train acc: 0.40466 val loss: 1.286931475830078 val acc: 0.5357 time: 294.3380506038666


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 1.1258365887451172 train acc: 0.59782 val loss: 1.0048131561279297 val acc: 0.6437 time: 628.3957533836365


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 0.8804691302490234 train acc: 0.68812 val loss: 0.8588029602050781 val acc: 0.701 time: 948.2764320373535


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.729170259399414 train acc: 0.74394 val loss: 0.7886862258911133 val acc: 0.7235 time: 1275.215219259262


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.610307648010254 train acc: 0.78756 val loss: 0.8139519317626953 val acc: 0.7198 time: 1582.7947027683258


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37428571428571433 train loss: 0.5314647354125976 train acc: 0.81624 val loss: 0.687806071472168 val acc: 0.7605 time: 1918.8916370868683


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.3485714285714286 train loss: 0.45858206634521487 train acc: 0.8419 val loss: 0.7263565002441407 val acc: 0.7517 time: 2212.3136405944824


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.3228571428571429 train loss: 0.4037501385498047 train acc: 0.8622 val loss: 0.6400754211425781 val acc: 0.7777 time: 2522.642644405365


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.29714285714285715 train loss: 0.3541432540893555 train acc: 0.87988 val loss: 0.5921943267822266 val acc: 0.8026 time: 2830.28989982605


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.27142857142857146 train loss: 0.3109328625488281 train acc: 0.8969 val loss: 0.6760240859985351 val acc: 0.7764 time: 3129.2538566589355


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.24571428571428575 train loss: 0.27130808547973634 train acc: 0.91014 val loss: 0.6022467254638671 val acc: 0.8054 time: 3435.739649295807


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.22000000000000003 train loss: 0.2503069636535645 train acc: 0.91704 val loss: 0.516441975402832 val acc: 0.8286 time: 3754.5171298980713


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.1942857142857143 train loss: 0.24248864929199218 train acc: 0.91924 val loss: 0.6711727676391601 val acc: 0.7936 time: 4112.373661279678


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.1685714285714286 train loss: 0.2134328729248047 train acc: 0.92902 val loss: 0.6352698471069336 val acc: 0.8002 time: 4449.2785551548


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.1428571428571429 train loss: 0.18809035888671874 train acc: 0.93798 val loss: 0.5561625518798828 val acc: 0.8182 time: 4785.040047168732


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.11714285714285716 train loss: 0.1510826558303833 train acc: 0.9526 val loss: 0.5825430084228516 val acc: 0.8145 time: 5078.757688045502


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.09142857142857147 train loss: 0.14049275344848633 train acc: 0.95544 val loss: 0.713926838684082 val acc: 0.7889 time: 5388.363402366638


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.06571428571428573 train loss: 0.14062550857543946 train acc: 0.95364 val loss: 0.5839700149536133 val acc: 0.8245 time: 5715.908813714981


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.04000000000000001 train loss: 0.11841514308929443 train acc: 0.96248 val loss: 0.5754383621215821 val acc: 0.8233 time: 6011.902843475342


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.03300000000000001 train loss: 0.1190552806854248 train acc: 0.96178 val loss: 0.533566079711914 val acc: 0.8421 time: 6342.409291028976


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.026000000000000002 train loss: 0.0958775252532959 train acc: 0.97042 val loss: 0.623066682434082 val acc: 0.8173 time: 6638.141429901123


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.019000000000000003 train loss: 0.10080661396026612 train acc: 0.96834 val loss: 0.6682453765869141 val acc: 0.8139 time: 6951.047869920731


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.012 train loss: 0.08566309440612793 train acc: 0.97266 val loss: 0.5962306350708008 val acc: 0.827 time: 7251.440149784088


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.005 train loss: 0.08281389841079712 train acc: 0.97444 val loss: 0.8060619613647461 val acc: 0.7816 time: 7564.801788806915


### Augmentation inside network
Augmentation Appled at: After ResNet block 1  
Augmentation Strategy: Random Pad Crop (pad=1, cutSize=8)

In [0]:
pad=1
cutSize=8
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = Conv(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.blk4 = ResBlk(c*16, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight
    self.data_aug2 = lambda x: (self.augment(x)) 
  
  def augment(self,batch):    
    if training:
      return augmentDictChannel(batch,padding=pad,cutSize=cutSize)
    else:    
      return batch  

  def call(self, x, y):
    h = self.pool(self.blk4(self.blk3(self.blk2(self.data_aug2(self.blk1(self.init_conv_bn(x)))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

model = DavidNet()
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  training=True
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  training=False
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 1 lr: 0.08 train loss: 1.7417633923339844 train acc: 0.37364 val loss: 1.4205133575439453 val acc: 0.4753 time: 362.96311473846436


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 1.221131226196289 train acc: 0.56084 val loss: 1.0864014678955078 val acc: 0.6079 time: 652.9500730037689


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 0.9776904449462891 train acc: 0.6505 val loss: 0.9139682312011719 val acc: 0.6768 time: 965.6120913028717


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.8347844329833984 train acc: 0.70426 val loss: 0.9207713226318359 val acc: 0.6778 time: 1286.5324108600616


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.7412512194824219 train acc: 0.73696 val loss: 0.8405061187744141 val acc: 0.7032 time: 1629.4906277656555


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37428571428571433 train loss: 0.645372565612793 train acc: 0.77476 val loss: 0.7364095520019531 val acc: 0.7403 time: 1946.4454190731049


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.3485714285714286 train loss: 0.5896704220581055 train acc: 0.79412 val loss: 0.6915450866699219 val acc: 0.7614 time: 2270.974874973297


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.3228571428571429 train loss: 0.5276736862182617 train acc: 0.81724 val loss: 0.686688768005371 val acc: 0.7661 time: 2581.677731513977


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.29714285714285715 train loss: 0.4799621844482422 train acc: 0.83318 val loss: 0.6544844970703125 val acc: 0.7742 time: 2892.768614768982


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.27142857142857146 train loss: 0.4309447244262695 train acc: 0.85248 val loss: 0.6881894073486328 val acc: 0.7668 time: 3187.1751911640167


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.24571428571428575 train loss: 0.40132603240966797 train acc: 0.86282 val loss: 1.0576175018310547 val acc: 0.683 time: 3487.8769223690033


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.22000000000000003 train loss: 0.37930518814086917 train acc: 0.86986 val loss: 0.6659305053710938 val acc: 0.784 time: 3803.021507501602


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.1942857142857143 train loss: 0.3820316633605957 train acc: 0.869 val loss: 0.585911181640625 val acc: 0.8002 time: 4165.135120630264


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.1685714285714286 train loss: 0.3249649366760254 train acc: 0.88758 val loss: 0.7038083435058594 val acc: 0.7814 time: 4487.204869508743


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.1428571428571429 train loss: 0.3010501708984375 train acc: 0.89786 val loss: 0.585987158203125 val acc: 0.8049 time: 4792.301978349686


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.11714285714285716 train loss: 0.29365438484191897 train acc: 0.89934 val loss: 0.5719682083129883 val acc: 0.8133 time: 5122.308615446091


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.09142857142857147 train loss: 0.26596750465393065 train acc: 0.91056 val loss: 0.645276333618164 val acc: 0.7997 time: 5432.259700536728


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.06571428571428573 train loss: 0.25222910316467284 train acc: 0.91402 val loss: 0.5956802703857422 val acc: 0.8192 time: 5745.676739692688


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.04000000000000001 train loss: 0.24271513999938965 train acc: 0.91818 val loss: 0.6321960098266601 val acc: 0.8043 time: 6079.257763147354


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.03300000000000001 train loss: 0.22528325637817384 train acc: 0.9251 val loss: 0.6059334396362305 val acc: 0.808 time: 6393.437779188156


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.026000000000000002 train loss: 0.19785403045654296 train acc: 0.9346 val loss: 0.5388290817260742 val acc: 0.8298 time: 6686.985191822052


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.019000000000000003 train loss: 0.18326292053222656 train acc: 0.94054 val loss: 0.5831904495239257 val acc: 0.8201 time: 6972.777536869049


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.012 train loss: 0.20300742404937744 train acc: 0.93142 val loss: 0.5215156661987305 val acc: 0.8363 time: 7303.324038267136


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.005 train loss: 0.2001133811569214 train acc: 0.93176 val loss: 0.6801018936157227 val acc: 0.8033 time: 7634.223736524582


### Augmentation inside network
Augmentation Appled at: After ResNet block 1  
Augmentation Strategy: Random Pad Crop pad=2,cutSize=4


In [0]:
pad=2
cutSize=4
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = Conv(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.blk4 = ResBlk(c*16, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight
    self.data_aug2 = lambda x: (self.augment(x)) 
  
  def augment(self,batch):    
    if training:
      return augmentDictChannel(batch,padding=pad,cutSize=cutSize)
    else:    
      return batch  

  def call(self, x, y):
    h = self.pool(self.blk4(self.blk3(self.blk2(self.data_aug2(self.blk1(self.init_conv_bn(x)))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

model = DavidNet()
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  training=True
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  training=False
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 1 lr: 0.08 train loss: 1.6936198791503907 train acc: 0.39504 val loss: 1.3851702270507813 val acc: 0.4816 time: 351.40392565727234


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 1.1764398583984375 train acc: 0.58058 val loss: 1.020491323852539 val acc: 0.6382 time: 674.0564365386963


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 0.9189568292236329 train acc: 0.67274 val loss: 0.8705541320800781 val acc: 0.6879 time: 985.7614064216614


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.7653041217041016 train acc: 0.72922 val loss: 0.8306340744018554 val acc: 0.7026 time: 1301.734842300415


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.651315456237793 train acc: 0.77382 val loss: 0.7953774230957031 val acc: 0.7211 time: 1631.6571145057678


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37428571428571433 train loss: 0.5704860989379883 train acc: 0.8012 val loss: 0.6841114044189454 val acc: 0.7635 time: 1939.1657915115356


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.3485714285714286 train loss: 0.501130337524414 train acc: 0.8273 val loss: 0.5481068130493164 val acc: 0.8107 time: 2246.2735271453857


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.3228571428571429 train loss: 0.4561416677856445 train acc: 0.84438 val loss: 0.6130719650268555 val acc: 0.7923 time: 2554.7375254631042


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.29714285714285715 train loss: 0.3880428555297852 train acc: 0.86688 val loss: 0.700805224609375 val acc: 0.7673 time: 2877.760489463806


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.27142857142857146 train loss: 0.35717813873291016 train acc: 0.87828 val loss: 0.6130901840209961 val acc: 0.7977 time: 3196.616599559784


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.24571428571428575 train loss: 0.32746177642822266 train acc: 0.88878 val loss: 0.5873039611816406 val acc: 0.8066 time: 3525.9958806037903


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.22000000000000003 train loss: 0.29938445907592776 train acc: 0.89984 val loss: 0.7805122543334961 val acc: 0.7592 time: 3845.622848510742


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.1942857142857143 train loss: 0.2789775830078125 train acc: 0.9063 val loss: 0.8882104446411133 val acc: 0.7217 time: 4183.654020309448


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.1685714285714286 train loss: 0.2322754125213623 train acc: 0.9224 val loss: 0.7506556594848632 val acc: 0.7709 time: 4484.77699804306


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.1428571428571429 train loss: 0.22159534660339356 train acc: 0.92574 val loss: 0.5435021423339844 val acc: 0.8273 time: 4821.688076734543


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.11714285714285716 train loss: 0.18903969856262207 train acc: 0.93856 val loss: 0.5628855163574219 val acc: 0.8223 time: 5109.13251042366


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.09142857142857147 train loss: 0.18523602172851564 train acc: 0.93822 val loss: 0.5777410018920899 val acc: 0.8176 time: 5420.739627838135


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.06571428571428573 train loss: 0.17430288627624513 train acc: 0.94198 val loss: 0.8096920272827148 val acc: 0.7601 time: 5741.821891784668


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.04000000000000001 train loss: 0.16010369445800782 train acc: 0.94714 val loss: 0.8561150848388672 val acc: 0.7567 time: 6074.319520235062


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.03300000000000001 train loss: 0.1381177857208252 train acc: 0.95532 val loss: 0.6125480285644531 val acc: 0.8168 time: 6382.351236343384


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.026000000000000002 train loss: 0.13153160533905028 train acc: 0.95698 val loss: 0.6833514389038086 val acc: 0.801 time: 6703.561769485474


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.019000000000000003 train loss: 0.12263097290039063 train acc: 0.9606 val loss: 0.5654070877075196 val acc: 0.8368 time: 7005.043070793152


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.012 train loss: 0.12072544479370118 train acc: 0.96022 val loss: 0.5980527084350586 val acc: 0.8277 time: 7324.938195705414


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.005 train loss: 0.11084938835144043 train acc: 0.96386 val loss: 0.6247633514404297 val acc: 0.8209 time: 7647.524335384369


### Augmentation inside network
Augmentation Appled at: After ResNet block 1  
Augmentation Strategy: Random Pad Crop pad=2,cutSize=8


In [0]:
pad=2
cutSize=8
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = Conv(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.blk4 = ResBlk(c*16, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight
    self.data_aug2 = lambda x: (self.augment(x)) 
  
  def augment(self,batch):    
    if training:
      return augmentDictChannel(batch,padding=pad,cutSize=cutSize)
    else:    
      return batch  

  def call(self, x, y):
    h = self.pool(self.blk4(self.blk3(self.blk2(self.data_aug2(self.blk1(self.init_conv_bn(x)))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

model = DavidNet()
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  training=True
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  training=False
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

epoch: 1 lr: 0.08 train loss: 1.56013619140625 train acc: 0.42906 val loss: 2.167406573486328 val acc: 0.4234 time: 344.6038703918457


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 0.965200517578125 train acc: 0.65212 val loss: 1.076610043334961 val acc: 0.6574 time: 690.1129419803619


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 0.7549077197265625 train acc: 0.73518 val loss: 1.7094183288574218 val acc: 0.5879 time: 989.68430352211


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.6578508767700195 train acc: 0.77156 val loss: 0.6544062484741211 val acc: 0.7791 time: 1289.2147526741028


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.5648462893676758 train acc: 0.80576 val loss: 0.640186262512207 val acc: 0.7943 time: 1588.2952919006348


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37428571428571433 train loss: 0.482644225769043 train acc: 0.8317 val loss: 0.8021152587890625 val acc: 0.7664 time: 1881.745693206787


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.3485714285714286 train loss: 0.39812656677246094 train acc: 0.86452 val loss: 0.4745824554443359 val acc: 0.8491 time: 2179.9176921844482


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.3228571428571429 train loss: 0.35581936584472657 train acc: 0.8763 val loss: 0.4423580307006836 val acc: 0.8599 time: 2500.214649915695


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.29714285714285715 train loss: 0.3129488165283203 train acc: 0.89162 val loss: 0.42932523193359373 val acc: 0.8603 time: 2836.9301550388336


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.27142857142857146 train loss: 0.27010210990905764 train acc: 0.90718 val loss: 0.5416159118652344 val acc: 0.829 time: 3156.654623746872


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.24571428571428575 train loss: 0.24858723266601562 train acc: 0.91346 val loss: 0.4067100631713867 val acc: 0.8737 time: 3502.527079820633


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.22000000000000003 train loss: 0.1955993025970459 train acc: 0.9331 val loss: 0.4513424118041992 val acc: 0.8596 time: 3798.212708711624


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.1942857142857143 train loss: 0.2015037843322754 train acc: 0.92946 val loss: 0.42796974639892577 val acc: 0.8702 time: 4149.867446184158


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.1685714285714286 train loss: 0.16885636573791504 train acc: 0.94204 val loss: 0.34698438034057616 val acc: 0.8947 time: 4487.894048213959


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.1428571428571429 train loss: 0.1444437463951111 train acc: 0.94938 val loss: 0.4649825141906738 val acc: 0.8689 time: 4819.229304552078


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.11714285714285716 train loss: 0.12881217723846436 train acc: 0.956 val loss: 0.3160697731018066 val acc: 0.9056 time: 5167.450182199478


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.09142857142857147 train loss: 0.10430334272384643 train acc: 0.9642 val loss: 0.3421485733032227 val acc: 0.9024 time: 5495.758043766022


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.06571428571428573 train loss: 0.08075898592948913 train acc: 0.97198 val loss: 0.31704113159179687 val acc: 0.9106 time: 5792.678310871124


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.04000000000000001 train loss: 0.08794741238594055 train acc: 0.96948 val loss: 0.3086310401916504 val acc: 0.914 time: 6145.064413785934


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.03300000000000001 train loss: 0.07266371158599853 train acc: 0.97594 val loss: 0.307125821685791 val acc: 0.9126 time: 6481.260657310486


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.026000000000000002 train loss: 0.06529813711166382 train acc: 0.97796 val loss: 0.3039977523803711 val acc: 0.9174 time: 6814.184375524521


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.019000000000000003 train loss: 0.06450075716018677 train acc: 0.97856 val loss: 0.3092188621520996 val acc: 0.9154 time: 7158.5696449279785


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.012 train loss: 0.05043654715776443 train acc: 0.98334 val loss: 0.2976011734008789 val acc: 0.9192 time: 7450.325302600861


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.005 train loss: 0.05761500784635544 train acc: 0.98092 val loss: 0.30888068771362304 val acc: 0.9194 time: 7787.095217704773


### Augmentation inside network
Augmentation Appled at: After ResNet block 1  
Augmentation Strategy: Random Pad Crop pad=2,cutSize=2


In [0]:
pad=2
cutSize=2
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = Conv(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.blk4 = ResBlk(c*16, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight
    self.data_aug2 = lambda x: (self.augment(x)) 
  
  def augment(self,batch):    
    if training:
      return augmentDictChannel(batch,padding=pad,cutSize=cutSize)
    else:    
      return batch  

  def call(self, x, y):
    h = self.pool(self.blk4(self.blk3(self.blk2(self.data_aug2(self.blk1(self.init_conv_bn(x)))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

model = DavidNet()
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  training=True
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  training=False
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 1 lr: 0.08 train loss: 1.681621817626953 train acc: 0.40154 val loss: 1.349463589477539 val acc: 0.5132 time: 278.8710696697235


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 1.1535218408203125 train acc: 0.587 val loss: 1.058209457397461 val acc: 0.6206 time: 557.1495904922485


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 0.9144048077392578 train acc: 0.674 val loss: 0.8521651641845703 val acc: 0.6957 time: 880.1622431278229


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.7642893084716796 train acc: 0.73104 val loss: 0.8233490661621093 val acc: 0.7105 time: 1197.211720943451


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.6539117254638672 train acc: 0.77204 val loss: 0.766605207824707 val acc: 0.7237 time: 1485.4809234142303


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37428571428571433 train loss: 0.5584352188110352 train acc: 0.8045 val loss: 0.660071371459961 val acc: 0.7673 time: 1810.4562468528748


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.3485714285714286 train loss: 0.4880517501831055 train acc: 0.83108 val loss: 0.6562359558105468 val acc: 0.7763 time: 2125.8841054439545


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.3228571428571429 train loss: 0.4336975518798828 train acc: 0.85322 val loss: 0.5844624099731446 val acc: 0.7966 time: 2458.009047985077


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.29714285714285715 train loss: 0.37811716567993164 train acc: 0.87066 val loss: 0.7144609512329102 val acc: 0.7649 time: 2753.780818462372


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.27142857142857146 train loss: 0.335387043762207 train acc: 0.8874 val loss: 0.8494595504760742 val acc: 0.7305 time: 3065.982987165451


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.24571428571428575 train loss: 0.31733752166748047 train acc: 0.893 val loss: 0.5459375411987305 val acc: 0.8166 time: 3381.8521780967712


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.22000000000000003 train loss: 0.2821679844665527 train acc: 0.90602 val loss: 0.5681502746582031 val acc: 0.8112 time: 3714.83931183815


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.1942857142857143 train loss: 0.24049335052490234 train acc: 0.91962 val loss: 0.5573489822387695 val acc: 0.8174 time: 4043.3804252147675


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.1685714285714286 train loss: 0.22237898468017578 train acc: 0.92756 val loss: 0.667802896118164 val acc: 0.7868 time: 4343.04744553566


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.1428571428571429 train loss: 0.1888622364807129 train acc: 0.93752 val loss: 0.5214978775024414 val acc: 0.832 time: 4639.617775917053


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.11714285714285716 train loss: 0.1870682447052002 train acc: 0.93822 val loss: 0.6182326553344727 val acc: 0.8093 time: 4962.405160188675


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.09142857142857147 train loss: 0.1647936566543579 train acc: 0.9465 val loss: 0.8381474746704102 val acc: 0.7615 time: 5280.349648952484


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.06571428571428573 train loss: 0.14775977916717528 train acc: 0.95252 val loss: 0.5066740158081054 val acc: 0.8401 time: 5590.857746601105


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.04000000000000001 train loss: 0.14504272411346436 train acc: 0.9528 val loss: 0.6919011077880859 val acc: 0.7938 time: 5899.785016536713


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.03300000000000001 train loss: 0.13212342210769654 train acc: 0.95744 val loss: 0.6826962219238282 val acc: 0.8067 time: 6239.917277812958


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.026000000000000002 train loss: 0.11413112407684327 train acc: 0.96404 val loss: 0.6130329574584961 val acc: 0.821 time: 6525.923155546188


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.019000000000000003 train loss: 0.1229829419708252 train acc: 0.95982 val loss: 0.5516336563110351 val acc: 0.8337 time: 6824.373297452927


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.012 train loss: 0.10450569353103638 train acc: 0.96646 val loss: 0.6602913742065429 val acc: 0.8145 time: 7145.423587560654


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.005 train loss: 0.09760386112213135 train acc: 0.96862 val loss: 0.5850477844238281 val acc: 0.8396 time: 7460.729785442352


### Augmentation inside network
Augmentation Appled at: After 2nd Resnet Block  
Augmentation Strategy: Random Pad Crop (pad=1,cutSize=2)

In [0]:
pad=1
cutSize=2
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = Conv(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.blk4 = ResBlk(c*16, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight
    self.data_aug2 = lambda x: (self.augment(x)) 
  
  def augment(self,batch):    
    if training:
      return augmentDictChannel(batch,padding=pad,cutSize=cutSize)
    else:    
      return batch    

  def call(self, x, y):
    h = self.pool(self.blk4(self.blk3(self.data_aug2(self.blk2(self.blk1(self.init_conv_bn(x)))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

model = DavidNet()
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  training=True
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  training=False
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 1 lr: 0.08 train loss: 1.7001617083740235 train acc: 0.3922 val loss: 1.339414193725586 val acc: 0.5143 time: 299.69847774505615


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 1.1771773071289062 train acc: 0.57568 val loss: 1.040238494873047 val acc: 0.6246 time: 598.1110906600952


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 0.9128782604980469 train acc: 0.67452 val loss: 0.9061826934814453 val acc: 0.6762 time: 887.5305528640747


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.7596427526855469 train acc: 0.73126 val loss: 0.8658771270751953 val acc: 0.6931 time: 1187.039624452591


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.6420351281738281 train acc: 0.77646 val loss: 0.6830802795410156 val acc: 0.7627 time: 1508.1716122627258


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37428571428571433 train loss: 0.5449663381958008 train acc: 0.80988 val loss: 0.7115470825195312 val acc: 0.7529 time: 1796.8148293495178


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.3485714285714286 train loss: 0.47986305908203125 train acc: 0.83512 val loss: 1.0534038818359375 val acc: 0.6672 time: 2141.227771997452


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.3228571428571429 train loss: 0.41097577514648437 train acc: 0.86054 val loss: 0.9614496704101563 val acc: 0.6955 time: 2482.4825761318207


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.29714285714285715 train loss: 0.34957275405883786 train acc: 0.88178 val loss: 0.868968603515625 val acc: 0.7169 time: 2798.983127117157


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.27142857142857146 train loss: 0.2912677648925781 train acc: 0.90228 val loss: 0.8332010162353516 val acc: 0.7377 time: 3076.972693681717


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.24571428571428575 train loss: 0.25441468437194825 train acc: 0.91564 val loss: 0.9593683349609375 val acc: 0.7288 time: 3373.016696691513


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.22000000000000003 train loss: 0.2257516583251953 train acc: 0.9244 val loss: 1.0600367736816407 val acc: 0.7126 time: 3701.1024572849274


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.1942857142857143 train loss: 0.19037203239440917 train acc: 0.93668 val loss: 1.1685471954345703 val acc: 0.6979 time: 4015.9299478530884


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.1685714285714286 train loss: 0.15398020126342774 train acc: 0.94944 val loss: 0.6649199417114258 val acc: 0.7991 time: 4270.930648088455


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.1428571428571429 train loss: 0.15207443607330323 train acc: 0.95024 val loss: 1.2173089965820312 val acc: 0.6981 time: 4572.343306779861


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.11714285714285716 train loss: 0.11619427242279053 train acc: 0.96318 val loss: 0.7147664001464844 val acc: 0.8004 time: 4874.649482727051


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.09142857142857147 train loss: 0.10646742778778076 train acc: 0.9652 val loss: 1.32191201171875 val acc: 0.6962 time: 5179.399294614792


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.06571428571428573 train loss: 0.10058665325164795 train acc: 0.96774 val loss: 1.5574500366210937 val acc: 0.6804 time: 5483.3747754096985


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.04000000000000001 train loss: 0.08758197143554687 train acc: 0.97184 val loss: 1.0305211944580077 val acc: 0.7602 time: 5809.739863395691


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.03300000000000001 train loss: 0.09022929483413697 train acc: 0.97062 val loss: 0.8588045776367188 val acc: 0.7794 time: 6146.781136512756


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.026000000000000002 train loss: 0.07202292917251588 train acc: 0.97752 val loss: 0.9430787750244141 val acc: 0.7693 time: 6446.616165399551


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.019000000000000003 train loss: 0.0692700567150116 train acc: 0.97806 val loss: 1.1475429138183593 val acc: 0.729 time: 6764.447984933853


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.012 train loss: 0.0592445977973938 train acc: 0.98094 val loss: 1.122024041748047 val acc: 0.7434 time: 7088.589679718018


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.005 train loss: 0.061602747039794924 train acc: 0.98076 val loss: 1.58162099609375 val acc: 0.6831 time: 7419.569298744202


### Augmentation inside network
Augmentation Appled at: After 2nd Resnet Block  
Augmentation Strategy: Random Pad Crop (pad=1,cutSize=4)

In [0]:
pad=1
cutSize=4
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = Conv(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.blk4 = ResBlk(c*16, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight
    self.data_aug2 = lambda x: (self.augment(x)) 
  
  def augment(self,batch):    
    if training:
      return augmentDictChannel(batch,padding=pad,cutSize=cutSize)
    else:    
      return batch    

  def call(self, x, y):
    h = self.pool(self.blk4(self.blk3(self.data_aug2(self.blk2(self.blk1(self.init_conv_bn(x)))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

model = DavidNet()
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  training=True
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  training=False
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

epoch: 1 lr: 0.08 train loss: 1.5853421716308593 train acc: 0.41802 val loss: 1.358955926513672 val acc: 0.5566 time: 386.5805432796478


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 0.9566800061035157 train acc: 0.65616 val loss: 1.0506786956787109 val acc: 0.6646 time: 690.7436211109161


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 0.7710095367431641 train acc: 0.72892 val loss: 0.8113151550292969 val acc: 0.7305 time: 1043.1320261955261


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.6450969400024414 train acc: 0.77518 val loss: 1.0972739318847655 val acc: 0.6964 time: 1384.8567266464233


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.5720924963378906 train acc: 0.80236 val loss: 0.6103920150756836 val acc: 0.7946 time: 1725.2505073547363


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37428571428571433 train loss: 0.48200303344726564 train acc: 0.83314 val loss: 0.7544891830444336 val acc: 0.7502 time: 2057.6667096614838


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.3485714285714286 train loss: 0.40410492736816406 train acc: 0.85962 val loss: 0.6334926605224609 val acc: 0.7923 time: 2405.041672706604


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.3228571428571429 train loss: 0.3336856343078613 train acc: 0.88366 val loss: 0.5002815658569336 val acc: 0.842 time: 2738.419492483139


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.29714285714285715 train loss: 0.30376605361938475 train acc: 0.89452 val loss: 0.4765020004272461 val acc: 0.8546 time: 3102.5365273952484


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.27142857142857146 train loss: 0.24308241004943848 train acc: 0.91582 val loss: 0.5127298950195313 val acc: 0.8486 time: 3432.5988824367523


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.24571428571428575 train loss: 0.20115719596862794 train acc: 0.93046 val loss: 0.5192940032958985 val acc: 0.8488 time: 3735.101157426834


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.22000000000000003 train loss: 0.1723120157623291 train acc: 0.94084 val loss: 0.44833980255126954 val acc: 0.8685 time: 4044.547892332077


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.1942857142857143 train loss: 0.15373443157196046 train acc: 0.947 val loss: 0.44375172729492185 val acc: 0.8777 time: 4389.097256660461


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.1685714285714286 train loss: 0.1201962823677063 train acc: 0.95824 val loss: 0.5140770034790039 val acc: 0.8642 time: 4711.458118438721


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.1428571428571429 train loss: 0.10476486225128173 train acc: 0.9641 val loss: 0.5117241271972657 val acc: 0.8745 time: 5038.9163699150085


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.11714285714285716 train loss: 0.09139425546646118 train acc: 0.96906 val loss: 0.4586080047607422 val acc: 0.8813 time: 5388.031049728394


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.09142857142857147 train loss: 0.08559274609565735 train acc: 0.97126 val loss: 0.4208559600830078 val acc: 0.8962 time: 5746.805028676987


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.06571428571428573 train loss: 0.061603286681175234 train acc: 0.97952 val loss: 0.43602896575927735 val acc: 0.8954 time: 6082.71479845047


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.04000000000000001 train loss: 0.05085516353368759 train acc: 0.98308 val loss: 0.39649828033447265 val acc: 0.9028 time: 6402.985944032669


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.03300000000000001 train loss: 0.04070400715827942 train acc: 0.98734 val loss: 0.38150966796875 val acc: 0.9051 time: 6732.285389661789


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.026000000000000002 train loss: 0.032720000525712965 train acc: 0.98952 val loss: 0.39148032455444337 val acc: 0.9094 time: 7040.745693206787


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.019000000000000003 train loss: 0.033570353409051894 train acc: 0.9889 val loss: 0.3844665252685547 val acc: 0.9062 time: 7360.630423784256


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.012 train loss: 0.02908505045413971 train acc: 0.99058 val loss: 0.37829129943847656 val acc: 0.9079 time: 7669.931704521179


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.005 train loss: 0.02618369418203831 train acc: 0.99184 val loss: 0.369765153503418 val acc: 0.9102 time: 7981.225943803787


### Augmentation inside network
Augmentation Appled at: After 2nd Resnet Block  
Augmentation Strategy: Random Pad Crop pad=2,cutSize=1

In [0]:
pad=2
cutSize=1
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = Conv(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.blk4 = ResBlk(c*16, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight
    self.data_aug2 = lambda x: (self.augment(x)) 
  
  def augment(self,batch):    
    if training:
      return augmentDictChannel(batch,padding=pad, cutSize=cutSize)
    else:    
      return batch   

  def call(self, x, y):
    h = self.pool(self.blk4(self.blk3(self.data_aug2(self.blk2(self.blk1(self.init_conv_bn(x)))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

model = DavidNet()
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  training=True
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  training=False
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 1 lr: 0.08 train loss: 1.7440642529296875 train acc: 0.37016 val loss: 1.3235881774902343 val acc: 0.5149 time: 331.13750195503235


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 1.2021481072998046 train acc: 0.5652 val loss: 1.042958135986328 val acc: 0.6245 time: 631.006489276886


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 0.9656727734375 train acc: 0.6541 val loss: 0.9025577514648437 val acc: 0.6709 time: 983.6305432319641


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.8030006311035156 train acc: 0.71528 val loss: 1.0826359649658204 val acc: 0.6282 time: 1316.047456741333


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.680142325744629 train acc: 0.7622 val loss: 0.7625468200683594 val acc: 0.7322 time: 1612.1160216331482


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37428571428571433 train loss: 0.5801594180297851 train acc: 0.79648 val loss: 0.7150931732177734 val acc: 0.7587 time: 1907.2568881511688


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.3485714285714286 train loss: 0.5105638766479492 train acc: 0.82312 val loss: 1.032940771484375 val acc: 0.6619 time: 2230.7731804847717


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.3228571428571429 train loss: 0.4341296328735352 train acc: 0.85192 val loss: 0.7190421890258789 val acc: 0.7587 time: 2513.3250036239624


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.29714285714285715 train loss: 0.3720885061645508 train acc: 0.87278 val loss: 0.9093431945800782 val acc: 0.7219 time: 2813.369023323059


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.27142857142857146 train loss: 0.33543074661254885 train acc: 0.88546 val loss: 0.6635377777099609 val acc: 0.785 time: 3149.389950275421


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.24571428571428575 train loss: 0.2902013006591797 train acc: 0.90048 val loss: 0.811066470336914 val acc: 0.7547 time: 3460.8586044311523


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.22000000000000003 train loss: 0.24742906028747558 train acc: 0.91564 val loss: 0.9020707168579102 val acc: 0.7393 time: 3789.7239112854004


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.1942857142857143 train loss: 0.21496099304199218 train acc: 0.92814 val loss: 0.8926973327636719 val acc: 0.7466 time: 4089.1376535892487


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.1685714285714286 train loss: 0.2119521659851074 train acc: 0.92868 val loss: 0.9550775512695312 val acc: 0.7363 time: 4469.419666528702


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.1428571428571429 train loss: 0.16665176670074464 train acc: 0.94578 val loss: 1.2346364532470704 val acc: 0.682 time: 4763.278495311737


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.11714285714285716 train loss: 0.15976272579193115 train acc: 0.9458 val loss: 1.7006939392089844 val acc: 0.6375 time: 5087.874419927597


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.09142857142857147 train loss: 0.1316972043609619 train acc: 0.95756 val loss: 0.9314070831298829 val acc: 0.754 time: 5401.1941385269165


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.06571428571428573 train loss: 0.12765422370910645 train acc: 0.95818 val loss: 1.0035457336425782 val acc: 0.7599 time: 5736.084035158157


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.04000000000000001 train loss: 0.10181630262374879 train acc: 0.96812 val loss: 1.0830271667480469 val acc: 0.7416 time: 6056.566239356995


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.03300000000000001 train loss: 0.1048877684211731 train acc: 0.96536 val loss: 0.9587921401977539 val acc: 0.7572 time: 6370.264518022537


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.026000000000000002 train loss: 0.0878136919784546 train acc: 0.9715 val loss: 1.5601193420410155 val acc: 0.6761 time: 6696.578418970108


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.019000000000000003 train loss: 0.078529844789505 train acc: 0.97476 val loss: 1.9059832458496093 val acc: 0.6555 time: 7017.565333366394


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.012 train loss: 0.07627339252471924 train acc: 0.9754 val loss: 1.166117562866211 val acc: 0.7372 time: 7312.741885900497


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.005 train loss: 0.07019679019927978 train acc: 0.97716 val loss: 1.4624857940673828 val acc: 0.7026 time: 7637.79608130455


### Augmentation inside network
Augmentation Appled at: After 2nd Resnet Block  
Augmentation Strategy: Random Pad Crop pad=2,cutSize=2

In [0]:
pad=2
cutSize=2
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = Conv(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.blk4 = ResBlk(c*16, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight
    self.data_aug2 = lambda x: (self.augment(x)) 
  
  def augment(self,batch):    
    if training:
      return augmentDictChannel(batch,padding=pad, cutSize=cutSize)
    else:    
      return batch   

  def call(self, x, y):
    h = self.pool(self.blk4(self.blk3(self.data_aug2(self.blk2(self.blk1(self.init_conv_bn(x)))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

model = DavidNet()
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  training=True
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  training=False
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

epoch: 1 lr: 0.08 train loss: 1.5699828369140625 train acc: 0.42164 val loss: 1.4111896392822265 val acc: 0.5401 time: 359.60451221466064


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 0.9010613348388672 train acc: 0.67792 val loss: 1.5593408813476564 val acc: 0.5732 time: 691.2611236572266


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 0.6858496139526368 train acc: 0.76004 val loss: 0.8301117980957031 val acc: 0.7208 time: 1017.5803172588348


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.5677341485595703 train acc: 0.80348 val loss: 0.914662255859375 val acc: 0.7237 time: 1336.6949391365051


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.49014785888671875 train acc: 0.83036 val loss: 2.1888024047851564 val acc: 0.5737 time: 1678.2270486354828


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37428571428571433 train loss: 0.4262572467041016 train acc: 0.8523 val loss: 0.7685572082519532 val acc: 0.7714 time: 2019.3684606552124


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.3485714285714286 train loss: 0.3232641970825195 train acc: 0.88886 val loss: 0.5338567459106446 val acc: 0.8323 time: 2340.7029106616974


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.3228571428571429 train loss: 0.24722392807006835 train acc: 0.91402 val loss: 0.628175277709961 val acc: 0.8079 time: 2630.2538521289825


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.29714285714285715 train loss: 0.2093665239715576 train acc: 0.92614 val loss: 0.6196473434448242 val acc: 0.8207 time: 2948.8222160339355


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.27142857142857146 train loss: 0.17179642822265626 train acc: 0.93994 val loss: 0.5093546035766602 val acc: 0.8576 time: 3261.8264541625977


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.24571428571428575 train loss: 0.1287149279022217 train acc: 0.9547 val loss: 0.7897833801269531 val acc: 0.8043 time: 3569.6386165618896


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.22000000000000003 train loss: 0.1050363921546936 train acc: 0.96366 val loss: 0.5890096710205078 val acc: 0.8456 time: 3871.7978398799896


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.1942857142857143 train loss: 0.08525451932907105 train acc: 0.97068 val loss: 0.6893838455200195 val acc: 0.8449 time: 4195.3444147109985


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.1685714285714286 train loss: 0.06799303292274475 train acc: 0.97674 val loss: 0.7026835052490235 val acc: 0.8493 time: 4515.09099316597


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.1428571428571429 train loss: 0.048103700447082516 train acc: 0.98348 val loss: 0.652638752746582 val acc: 0.8682 time: 4834.642048120499


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.11714285714285716 train loss: 0.041296620037555694 train acc: 0.98598 val loss: 0.6665777969360351 val acc: 0.8662 time: 5171.921990394592


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.09142857142857147 train loss: 0.03243982293844223 train acc: 0.98906 val loss: 0.49391004486083984 val acc: 0.8943 time: 5507.161457538605


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.06571428571428573 train loss: 0.022930190098285674 train acc: 0.99282 val loss: 0.4676080062866211 val acc: 0.9001 time: 5836.704733610153


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.04000000000000001 train loss: 0.019880100751519202 train acc: 0.9936 val loss: 0.45605255126953126 val acc: 0.8992 time: 6190.601444005966


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.03300000000000001 train loss: 0.012502557260990144 train acc: 0.99622 val loss: 0.4685919906616211 val acc: 0.9032 time: 6494.327630996704


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.026000000000000002 train loss: 0.011852895233929158 train acc: 0.99632 val loss: 0.47170534057617186 val acc: 0.9008 time: 6803.415456056595


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.019000000000000003 train loss: 0.010550170094370842 train acc: 0.99696 val loss: 0.45179154663085935 val acc: 0.9046 time: 7113.258825778961


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.012 train loss: 0.011364735050201417 train acc: 0.99652 val loss: 0.4750121398925781 val acc: 0.9046 time: 7441.08469581604


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.005 train loss: 0.008613952351510525 train acc: 0.99744 val loss: 0.46613668670654296 val acc: 0.9038 time: 7772.370632648468


### Augmentation inside network
Augmentation Appled at: After 2nd Resnet Block  
Augmentation Strategy: Random Pad Crop pad=2,cutSize=4

In [0]:
pad=2
cutSize=4
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = Conv(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.blk4 = ResBlk(c*16, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight
    self.data_aug2 = lambda x: (self.augment(x)) 
  
  def augment(self,batch):    
    if training:
      return augmentDictChannel(batch,padding=pad, cutSize=cutSize)
    else:    
      return batch   

  def call(self, x, y):
    h = self.pool(self.blk4(self.blk3(self.data_aug2(self.blk2(self.blk1(self.init_conv_bn(x)))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

model = DavidNet()
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  training=True
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  training=False
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 1 lr: 0.08 train loss: 1.8136161712646484 train acc: 0.3405 val loss: 1.4393283813476563 val acc: 0.4635 time: 344.3578281402588


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 1.3118828790283203 train acc: 0.5239 val loss: 1.2169591522216796 val acc: 0.5473 time: 656.7080726623535


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 1.0834223120117188 train acc: 0.61122 val loss: 1.077284375 val acc: 0.6134 time: 988.170604467392


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.9143649081420898 train acc: 0.6733 val loss: 0.8646190643310547 val acc: 0.6961 time: 1294.3894147872925


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.793438695678711 train acc: 0.71972 val loss: 0.8408568817138672 val acc: 0.7037 time: 1615.2668969631195


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37428571428571433 train loss: 0.7094413830566406 train acc: 0.7507 val loss: 0.9215629638671875 val acc: 0.6877 time: 1932.135535955429


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.3485714285714286 train loss: 0.6303590606689453 train acc: 0.77838 val loss: 1.0273768890380859 val acc: 0.6699 time: 2227.7413613796234


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.3228571428571429 train loss: 0.5914051119995117 train acc: 0.79196 val loss: 0.7815988006591796 val acc: 0.7388 time: 2564.1458559036255


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.29714285714285715 train loss: 0.5391319128417968 train acc: 0.8115 val loss: 0.8936631103515625 val acc: 0.7104 time: 2904.581349849701


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.27142857142857146 train loss: 0.4852535159301758 train acc: 0.83148 val loss: 0.6960880081176758 val acc: 0.7673 time: 3222.7769691944122


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.24571428571428575 train loss: 0.4607682925415039 train acc: 0.83834 val loss: 0.6751264221191406 val acc: 0.776 time: 3560.983355283737


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.22000000000000003 train loss: 0.41434790954589845 train acc: 0.85686 val loss: 0.7212955642700195 val acc: 0.7675 time: 3883.8344910144806


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.1942857142857143 train loss: 0.38055245559692386 train acc: 0.86852 val loss: 0.9754956848144531 val acc: 0.717 time: 4212.796271562576


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.1685714285714286 train loss: 0.36487687805175784 train acc: 0.87482 val loss: 0.577402848815918 val acc: 0.8131 time: 4558.939116001129


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.1428571428571429 train loss: 0.3111564887237549 train acc: 0.89332 val loss: 0.7388348709106445 val acc: 0.7752 time: 4877.466774225235


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.11714285714285716 train loss: 0.2867245548248291 train acc: 0.90196 val loss: 0.6503302291870117 val acc: 0.7979 time: 5182.444878816605


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.09142857142857147 train loss: 0.2817683576202393 train acc: 0.9027 val loss: 1.2320720153808593 val acc: 0.6937 time: 5507.987241029739


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.06571428571428573 train loss: 0.2592455077362061 train acc: 0.9111 val loss: 0.8462742141723633 val acc: 0.7635 time: 5829.974852800369


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.04000000000000001 train loss: 0.2488051342010498 train acc: 0.91406 val loss: 0.8903042175292969 val acc: 0.7585 time: 6162.555616378784


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.03300000000000001 train loss: 0.23353861396789552 train acc: 0.91892 val loss: 1.1219321685791015 val acc: 0.7376 time: 6499.231781721115


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.026000000000000002 train loss: 0.2400260569000244 train acc: 0.91724 val loss: 0.9329610504150391 val acc: 0.7595 time: 6864.624772071838


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.019000000000000003 train loss: 0.20351605892181396 train acc: 0.9298 val loss: 1.1155522521972656 val acc: 0.7364 time: 7197.193012475967


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.012 train loss: 0.1783711735534668 train acc: 0.93966 val loss: 0.8108752380371094 val acc: 0.789 time: 7506.890388727188


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.005 train loss: 0.1894309169769287 train acc: 0.93496 val loss: 1.2548774505615234 val acc: 0.7123 time: 7863.731426000595


### Augmentation inside network
Augmentation Appled at: After 3rd Resnet Block  
Augmentation Strategy: Random Pad Crop pad=1,cutSize=1

In [0]:
pad=1
cutSize=1
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = Conv(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.blk4 = ResBlk(c*16, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight
    self.data_aug2 = lambda x: (self.augment(x)) 
  
  def augment(self,batch):    
    if training:
      return augmentDictChannel(batch,padding=pad,cutSize=cutSize)
    else:    
      return batch 

  def call(self, x, y):
    h = self.pool(self.blk4(self.data_aug2(self.blk3(self.blk2(self.blk1(self.init_conv_bn(x)))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

model = DavidNet()
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  training=True
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  training=False
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 1 lr: 0.08 train loss: 1.7425555780029296 train acc: 0.37724 val loss: 1.3972201293945312 val acc: 0.4872 time: 314.78137254714966


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 1.1797134692382814 train acc: 0.58026 val loss: 1.0741511077880859 val acc: 0.6126 time: 650.1410677433014


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 0.9001570202636718 train acc: 0.6829 val loss: 0.9994837921142579 val acc: 0.6354 time: 991.2281901836395


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.7089501928710937 train acc: 0.754 val loss: 0.9240078216552734 val acc: 0.6751 time: 1310.907786846161


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.5627576852416992 train acc: 0.80846 val loss: 0.7926905624389649 val acc: 0.7284 time: 1615.0288844108582


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37428571428571433 train loss: 0.43477358489990237 train acc: 0.85446 val loss: 0.9416111114501953 val acc: 0.6998 time: 1908.6389989852905


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.3485714285714286 train loss: 0.3228309481811523 train acc: 0.89374 val loss: 0.8915937088012695 val acc: 0.7268 time: 2246.170470237732


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.3228571428571429 train loss: 0.2216616067504883 train acc: 0.931 val loss: 0.9958732482910156 val acc: 0.7137 time: 2546.0595285892487


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.29714285714285715 train loss: 0.14375533042907715 train acc: 0.96048 val loss: 1.1398704895019531 val acc: 0.6999 time: 2850.8621819019318


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.27142857142857146 train loss: 0.08775323032379151 train acc: 0.97864 val loss: 0.7845830520629883 val acc: 0.7721 time: 3169.45161318779


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.24571428571428575 train loss: 0.0530165531539917 train acc: 0.98956 val loss: 1.2320170654296876 val acc: 0.7167 time: 3510.2970678806305


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.22000000000000003 train loss: 0.029713790168762206 train acc: 0.99586 val loss: 0.8258864929199219 val acc: 0.7768 time: 3842.1682057380676


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.1942857142857143 train loss: 0.016502563333511352 train acc: 0.9987 val loss: 0.8384071838378906 val acc: 0.7879 time: 4161.286209106445


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.1685714285714286 train loss: 0.008588755111694336 train acc: 0.99974 val loss: 0.7658482772827149 val acc: 0.8046 time: 4504.526997089386


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.1428571428571429 train loss: 0.004931837475299836 train acc: 0.99996 val loss: 0.7207237258911133 val acc: 0.8106 time: 4812.2693729400635


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.11714285714285716 train loss: 0.004067521086931229 train acc: 0.99996 val loss: 0.7267495498657227 val acc: 0.8101 time: 5120.516424894333


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.09142857142857147 train loss: 0.0037180251717567445 train acc: 1.0 val loss: 0.7065423370361328 val acc: 0.816 time: 5459.761605501175


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.06571428571428573 train loss: 0.002824152933359146 train acc: 1.0 val loss: 0.7084783432006836 val acc: 0.8139 time: 5776.5591769218445


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.04000000000000001 train loss: 0.0025700886881351473 train acc: 1.0 val loss: 0.6884792434692383 val acc: 0.8181 time: 6079.783547878265


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.03300000000000001 train loss: 0.002298788585662842 train acc: 1.0 val loss: 0.7047800231933594 val acc: 0.8201 time: 6414.562806606293


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.026000000000000002 train loss: 0.002134357281923294 train acc: 1.0 val loss: 0.7026110488891602 val acc: 0.819 time: 6742.912936925888


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.019000000000000003 train loss: 0.0018622981816530228 train acc: 1.0 val loss: 0.6964565673828125 val acc: 0.8187 time: 7049.991405010223


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.012 train loss: 0.0017173825323581696 train acc: 1.0 val loss: 0.7065638732910157 val acc: 0.8216 time: 7357.684489250183


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.005 train loss: 0.0016665874671936034 train acc: 1.0 val loss: 0.7127612518310547 val acc: 0.8188 time: 7666.398751974106


### Augmentation inside network
Augmentation Appled at: After 3rd Resnet Block  
Augmentation Strategy: Random Pad Crop pad=1,cutSize=2

In [0]:
pad=1
cutSize=2
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = Conv(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.blk4 = ResBlk(c*16, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight
    self.data_aug2 = lambda x: (self.augment(x)) 
  
  def augment(self,batch):    
    if training:
      return augmentDictChannel(batch,padding=pad,cutSize=cutSize)
    else:    
      return batch 

  def call(self, x, y):
    h = self.pool(self.blk4(self.data_aug2(self.blk3(self.blk2(self.blk1(self.init_conv_bn(x)))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

model = DavidNet()
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  training=True
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  training=False
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

epoch: 1 lr: 0.08 train loss: 1.5618844573974608 train acc: 0.42312 val loss: 1.4428502960205078 val acc: 0.5443 time: 327.4184856414795


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 0.8504948211669922 train acc: 0.69358 val loss: 0.9592930053710937 val acc: 0.7016 time: 633.9718420505524


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 0.6257404537963868 train acc: 0.78086 val loss: 1.0131699584960938 val acc: 0.7188 time: 948.2642438411713


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.49548582733154295 train acc: 0.8285 val loss: 1.1664123779296875 val acc: 0.6816 time: 1261.5304386615753


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.4180797412109375 train acc: 0.85646 val loss: 0.7690294967651368 val acc: 0.7744 time: 1605.7766950130463


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37428571428571433 train loss: 0.31621252975463865 train acc: 0.89278 val loss: 1.2904512512207031 val acc: 0.6873 time: 1918.848009109497


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.3485714285714286 train loss: 0.22369056747436522 train acc: 0.92344 val loss: 0.6963192001342774 val acc: 0.8123 time: 2279.6135189533234


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.3228571428571429 train loss: 0.1465373974609375 train acc: 0.94894 val loss: 0.7879324951171875 val acc: 0.8066 time: 2623.820252895355


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.29714285714285715 train loss: 0.1032205267715454 train acc: 0.96458 val loss: 0.8334302688598633 val acc: 0.8207 time: 2946.776596069336


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.27142857142857146 train loss: 0.06315186033248901 train acc: 0.97842 val loss: 1.0746515594482422 val acc: 0.806 time: 3261.589106798172


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.24571428571428575 train loss: 0.03612145703315735 train acc: 0.9876 val loss: 1.1621117004394532 val acc: 0.8216 time: 3559.005670070648


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.22000000000000003 train loss: 0.02364085692882538 train acc: 0.99234 val loss: 0.9324852935791016 val acc: 0.8381 time: 3882.827954053879


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.1942857142857143 train loss: 0.015428113152980804 train acc: 0.99478 val loss: 0.9626380264282226 val acc: 0.8359 time: 4186.7441511154175


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.1685714285714286 train loss: 0.007945958777666092 train acc: 0.99762 val loss: 0.6821072280883789 val acc: 0.8784 time: 4514.050767660141


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.1428571428571429 train loss: 0.003288201222717762 train acc: 0.9991 val loss: 0.7138492156982422 val acc: 0.8799 time: 4815.200405836105


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.11714285714285716 train loss: 0.001943566466718912 train acc: 0.99966 val loss: 0.6728310668945312 val acc: 0.8819 time: 5130.21745967865


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.09142857142857147 train loss: 0.0018578449414670467 train acc: 0.99966 val loss: 0.674441796875 val acc: 0.885 time: 5468.124856710434


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.06571428571428573 train loss: 0.0014543863439559937 train acc: 0.99974 val loss: 0.6971058319091797 val acc: 0.8843 time: 5778.490402698517


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.04000000000000001 train loss: 0.0008784255516529083 train acc: 0.99988 val loss: 0.6539590881347657 val acc: 0.8833 time: 6077.007429361343


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.03300000000000001 train loss: 0.0014309308955445885 train acc: 0.99964 val loss: 0.6868367446899414 val acc: 0.884 time: 6409.046169519424


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.026000000000000002 train loss: 0.0012175669299811125 train acc: 0.9997 val loss: 0.6830407089233398 val acc: 0.8851 time: 6705.023818254471


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.019000000000000003 train loss: 0.0008372579548507928 train acc: 0.99984 val loss: 0.6648986877441406 val acc: 0.8854 time: 7024.939955234528


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.012 train loss: 0.000932557633407414 train acc: 0.9998 val loss: 0.6990473846435546 val acc: 0.8851 time: 7353.696537017822


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.005 train loss: 0.0007508684581145644 train acc: 0.99988 val loss: 0.699462158203125 val acc: 0.8855 time: 7675.524048805237


### Augmentation inside network
Augmentation Appled at: After 3rd Resnet Block  
Augmentation Strategy: Random Pad Crop pad=2,cutSize=1

In [0]:
pad=2
cutSize=1
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = Conv(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.blk4 = ResBlk(c*16, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight
    self.data_aug2 = lambda x: (self.augment(x)) 
  
  def augment(self,batch):    
    if training:
      return augmentDictChannel(batch,padding=pad,cutSize=cutSize)
    else:    
      return batch 

  def call(self, x, y):
    h = self.pool(self.blk4(self.data_aug2(self.blk3(self.blk2(self.blk1(self.init_conv_bn(x)))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

model = DavidNet()
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  training=True
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  training=False
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 1 lr: 0.08 train loss: 1.8231514172363281 train acc: 0.33814 val loss: 1.442267401123047 val acc: 0.4613 time: 289.6646785736084


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 1.2966751123046876 train acc: 0.53362 val loss: 1.2198311065673828 val acc: 0.5684 time: 609.1216335296631


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 1.0168110852050782 train acc: 0.64206 val loss: 0.9677573272705078 val acc: 0.6539 time: 946.1997685432434


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.8102151486206055 train acc: 0.71464 val loss: 0.882210791015625 val acc: 0.6835 time: 1258.6346690654755


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.657751589050293 train acc: 0.77368 val loss: 0.7634411819458008 val acc: 0.731 time: 1550.4228093624115


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37428571428571433 train loss: 0.5410019778442383 train acc: 0.81554 val loss: 1.024207080078125 val acc: 0.6699 time: 1857.3733820915222


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.3485714285714286 train loss: 0.4390115887451172 train acc: 0.85386 val loss: 1.0161621490478516 val acc: 0.6824 time: 2183.4457437992096


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.3228571428571429 train loss: 0.3410350389099121 train acc: 0.88772 val loss: 0.8240613174438477 val acc: 0.7336 time: 2501.3228421211243


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.29714285714285715 train loss: 0.2462901609802246 train acc: 0.922 val loss: 1.2525902709960937 val acc: 0.668 time: 2820.3648734092712


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.27142857142857146 train loss: 0.18701446708679198 train acc: 0.94274 val loss: 1.2468136138916015 val acc: 0.6663 time: 3118.928178548813


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.24571428571428575 train loss: 0.133069365234375 train acc: 0.9614 val loss: 0.850958642578125 val acc: 0.7599 time: 3439.417964696884


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.22000000000000003 train loss: 0.08632393199920654 train acc: 0.97692 val loss: 0.8625077087402344 val acc: 0.763 time: 3730.286325454712


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.1942857142857143 train loss: 0.06199743335723877 train acc: 0.98482 val loss: 1.0957192443847656 val acc: 0.7381 time: 4056.5225880146027


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.1685714285714286 train loss: 0.046207111797332764 train acc: 0.98996 val loss: 1.1503233489990234 val acc: 0.7171 time: 4371.243454456329


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.1428571428571429 train loss: 0.031063398456573486 train acc: 0.99388 val loss: 0.9255937103271484 val acc: 0.7755 time: 4670.227818250656


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.11714285714285716 train loss: 0.025014823894500732 train acc: 0.99566 val loss: 1.659578369140625 val acc: 0.6584 time: 4963.375111579895


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.09142857142857147 train loss: 0.02083930778503418 train acc: 0.99668 val loss: 1.131957177734375 val acc: 0.7385 time: 5284.282741069794


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.06571428571428573 train loss: 0.017835675187110902 train acc: 0.99724 val loss: 1.3921947937011718 val acc: 0.7214 time: 5622.742198467255


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.04000000000000001 train loss: 0.011290256700515746 train acc: 0.99858 val loss: 0.850003564453125 val acc: 0.7954 time: 5940.482671260834


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.03300000000000001 train loss: 0.009628053575754166 train acc: 0.99906 val loss: 0.8504127410888672 val acc: 0.7917 time: 6273.854773283005


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.026000000000000002 train loss: 0.0065399713027477265 train acc: 0.99952 val loss: 0.8535119537353516 val acc: 0.7986 time: 6571.15731549263


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.019000000000000003 train loss: 0.005081442323923111 train acc: 0.99962 val loss: 0.8374009582519532 val acc: 0.7963 time: 6858.866739749908


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.012 train loss: 0.004174236983656883 train acc: 0.99996 val loss: 0.7926763214111329 val acc: 0.805 time: 7174.6339502334595


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.005 train loss: 0.002823784787654877 train acc: 0.99994 val loss: 0.7744853057861328 val acc: 0.8105 time: 7478.011656522751


### Augmentation inside network
Augmentation Appled at: After 3rd Resnet Block  
Augmentation Strategy: Random Pad Crop pad=2,cutSize=2

In [0]:
pad=2
cutSize=2
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = Conv(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.blk4 = ResBlk(c*16, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight
    self.data_aug2 = lambda x: (self.augment(x)) 
  
  def augment(self,batch):    
    if training:
      return augmentDictChannel(batch,padding=pad,cutSize=cutSize)
    else:    
      return batch 

  def call(self, x, y):
    h = self.pool(self.blk4(self.data_aug2(self.blk3(self.blk2(self.blk1(self.init_conv_bn(x)))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

model = DavidNet()
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  training=True
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  training=False
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 1 lr: 0.08 train loss: 1.8160484887695312 train acc: 0.3489 val loss: 1.4179361694335937 val acc: 0.4607 time: 295.72327518463135


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 1.3301858636474608 train acc: 0.52034 val loss: 1.1543548431396484 val acc: 0.5706 time: 597.4277584552765


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 1.0542542346191406 train acc: 0.62694 val loss: 0.9825055938720703 val acc: 0.6487 time: 930.5684466362


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.8620135485839844 train acc: 0.69728 val loss: 0.9976129333496093 val acc: 0.6483 time: 1269.125559091568


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.7095697296142578 train acc: 0.75462 val loss: 0.8873738189697266 val acc: 0.6979 time: 1584.3024711608887


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37428571428571433 train loss: 0.5873761788940429 train acc: 0.7997 val loss: 0.8121135925292968 val acc: 0.7318 time: 1916.1334824562073


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.3485714285714286 train loss: 0.4766591751098633 train acc: 0.83772 val loss: 0.787276628112793 val acc: 0.7443 time: 2248.8994665145874


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.3228571428571429 train loss: 0.3780060563659668 train acc: 0.87554 val loss: 0.8838701446533204 val acc: 0.7304 time: 2566.835261106491


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.29714285714285715 train loss: 0.3032449752807617 train acc: 0.9001 val loss: 0.9936376922607422 val acc: 0.722 time: 2872.620388507843


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.27142857142857146 train loss: 0.21862754013061522 train acc: 0.93012 val loss: 0.986840771484375 val acc: 0.7249 time: 3176.694655895233


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.24571428571428575 train loss: 0.1663838233947754 train acc: 0.94806 val loss: 1.5670072570800782 val acc: 0.6339 time: 3482.7506511211395


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.22000000000000003 train loss: 0.12688231060028077 train acc: 0.96276 val loss: 1.1245011535644531 val acc: 0.7242 time: 3813.8556747436523


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.1942857142857143 train loss: 0.08399612201690673 train acc: 0.97642 val loss: 0.960992367553711 val acc: 0.7518 time: 4102.440300703049


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.1685714285714286 train loss: 0.07620792812347413 train acc: 0.97964 val loss: 1.0601198669433594 val acc: 0.7501 time: 4426.462034463882


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.1428571428571429 train loss: 0.059406269769668577 train acc: 0.9846 val loss: 1.1409329162597657 val acc: 0.7449 time: 4743.403435468674


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.11714285714285716 train loss: 0.05704031600952148 train acc: 0.98468 val loss: 1.2960627532958984 val acc: 0.7176 time: 5081.524017572403


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.09142857142857147 train loss: 0.039068621249198916 train acc: 0.99028 val loss: 1.1554222900390625 val acc: 0.7601 time: 5373.199042797089


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.06571428571428573 train loss: 0.03160713667869568 train acc: 0.9926 val loss: 0.8171152923583984 val acc: 0.7948 time: 5693.762620449066


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.04000000000000001 train loss: 0.026868118770122528 train acc: 0.99368 val loss: 1.013483090209961 val acc: 0.7723 time: 6006.010609388351


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.03300000000000001 train loss: 0.01939874532222748 train acc: 0.99566 val loss: 0.9407024597167969 val acc: 0.7874 time: 6288.347442626953


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.026000000000000002 train loss: 0.019704329578876495 train acc: 0.99552 val loss: 1.0317566680908203 val acc: 0.7664 time: 6582.805755376816


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.019000000000000003 train loss: 0.012853345379829406 train acc: 0.99702 val loss: 0.8391115295410156 val acc: 0.7989 time: 6875.430192708969


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.012 train loss: 0.017622909216284753 train acc: 0.99514 val loss: 0.8445834136962891 val acc: 0.8063 time: 7203.672232866287


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.005 train loss: 0.013363480752706528 train acc: 0.99656 val loss: 0.8369158813476563 val acc: 0.8089 time: 7539.292580127716


### Augmentation inside network
Augmentation Appled at: After 3rd Resnet Block  
Augmentation Strategy: Random Pad Crop pad=2,cutSize=4

In [0]:
pad=2
cutSize=4
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = Conv(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.blk4 = ResBlk(c*16, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight
    self.data_aug2 = lambda x: (self.augment(x)) 
  
  def augment(self,batch):    
    if training:
      return augmentDictChannel(batch,padding=pad,cutSize=cutSize)
    else:    
      return batch 

  def call(self, x, y):
    h = self.pool(self.blk4(self.data_aug2(self.blk3(self.blk2(self.blk1(self.init_conv_bn(x)))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

model = DavidNet()
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  training=True
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  training=False
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 1 lr: 0.08 train loss: 2.101784052734375 train acc: 0.21598 val loss: 1.6178125579833984 val acc: 0.4219 time: 309.32568979263306


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 1.8587495947265624 train acc: 0.2993 val loss: 1.5619397552490235 val acc: 0.4682 time: 625.7331986427307


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 1.711361294555664 train acc: 0.35638 val loss: 1.7987959777832032 val acc: 0.5005 time: 930.9837007522583


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 1.5483372924804688 train acc: 0.41706 val loss: 1.039345263671875 val acc: 0.632 time: 1221.2258667945862


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 1.651786128540039 train acc: 0.37214 val loss: 1.0984959533691405 val acc: 0.6484 time: 1563.8383326530457


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37428571428571433 train loss: 1.4641927365112304 train acc: 0.45082 val loss: 1.093222280883789 val acc: 0.6502 time: 1875.6064884662628


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.3485714285714286 train loss: 1.4893231008911132 train acc: 0.43522 val loss: 0.9154947967529297 val acc: 0.6952 time: 2202.2212829589844


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.3228571428571429 train loss: 1.4962323779296876 train acc: 0.4319 val loss: 1.177291064453125 val acc: 0.6799 time: 2541.656369447708


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.29714285714285715 train loss: 1.3922363619995117 train acc: 0.47142 val loss: 0.9436907257080078 val acc: 0.7039 time: 2857.6204600334167


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.27142857142857146 train loss: 1.2618546868896485 train acc: 0.52358 val loss: 0.8283124053955078 val acc: 0.7299 time: 3164.876187324524


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.24571428571428575 train loss: 1.575549006652832 train acc: 0.39222 val loss: 0.8418470153808594 val acc: 0.7486 time: 3534.6037640571594


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.22000000000000003 train loss: 1.2040432522583009 train acc: 0.54366 val loss: 0.9051451904296876 val acc: 0.7287 time: 3825.062479496002


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.1942857142857143 train loss: 1.2590344192504883 train acc: 0.52042 val loss: 1.165890576171875 val acc: 0.6795 time: 4148.470786809921


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.1685714285714286 train loss: 1.3502400874328613 train acc: 0.48416 val loss: 0.7973780151367188 val acc: 0.7698 time: 4486.6832275390625


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.1428571428571429 train loss: 1.369206212310791 train acc: 0.47552 val loss: 1.1922957244873047 val acc: 0.7017 time: 4819.912066459656


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.11714285714285716 train loss: 1.2105461709594727 train acc: 0.53912 val loss: 1.0249920684814453 val acc: 0.7213 time: 5144.5895075798035


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.09142857142857147 train loss: 1.1223518669128418 train acc: 0.57232 val loss: 0.9113342926025391 val acc: 0.7391 time: 5459.744210958481


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.06571428571428573 train loss: 1.0472087786102295 train acc: 0.60098 val loss: 0.883955404663086 val acc: 0.747 time: 5757.145701646805


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.04000000000000001 train loss: 1.072152600479126 train acc: 0.58752 val loss: 1.209646356201172 val acc: 0.6942 time: 6054.5174350738525


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.03300000000000001 train loss: 1.0048241163635254 train acc: 0.61474 val loss: 1.073508563232422 val acc: 0.7185 time: 6350.1826910972595


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.026000000000000002 train loss: 0.9919131321716309 train acc: 0.61846 val loss: 0.9957437072753906 val acc: 0.7323 time: 6633.752651929855


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.019000000000000003 train loss: 1.0913423049545288 train acc: 0.57872 val loss: 1.6057800079345703 val acc: 0.6483 time: 6953.001893758774


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.012 train loss: 1.2155848331069947 train acc: 0.52984 val loss: 3.0441604614257813 val acc: 0.58 time: 7273.564126491547


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.005 train loss: 1.0722630225753784 train acc: 0.58592 val loss: 0.7727822174072265 val acc: 0.7826 time: 7573.882131814957
