In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from unoai.imports import *
from unoai.data.datasets import *
from unoai.train import *

In [3]:
import shutil
import gc
import contextlib
import timeit

In [4]:
# Short alias for TensorFlow's filesystem API (tf.io.gfile), which handles local
# paths as well as remote ones such as gs:// buckets.
gfile = tf.io.gfile


In [5]:
# Data / model locations and input geometry.
# EPOCHS is defined once, with the other optimisation hyper-parameters in the next
# cell (a duplicate `EPOCHS = 5` here was always overwritten before being used).
DATASET_LOC = "datasets/cifar10"  # passed to get_cifar10 as ds_dir
MODEL_LOC = "models/cifar10"      # emptied and recreated before training
BATCH_SIZE = 512
RESOLUTION = (32, 32)
NUM_CHANNELS = 3
NUM_TRAIN = 50000                 # CIFAR-10 training-set size

In [6]:
# Optimisation hyper-parameters for the SGD + one-cycle schedule path.
WARMUP = 5               # epochs of linear warm-up (converted to steps below)
EPOCHS = 15              # total epochs the LR schedule is laid out over
LEARNING_RATE = 0.9      # peak LR; divided by BATCH_SIZE when the schedule is built
WEIGHT_DECAY = 1.25e-4
MOMENTUM = 0.9

In [7]:
# Start every run from an empty model directory.
# Use tf.io.gfile for the existence check, the delete AND the create so the cell
# behaves the same for local paths and remote ones (e.g. gs://); os.path/shutil
# only see the local disk, so a remote MODEL_LOC would never have been cleared.
if gfile.isdir(MODEL_LOC):
    gfile.rmtree(MODEL_LOC)
gfile.makedirs(MODEL_LOC)

In [8]:
# Build the batched CIFAR-10 train / test pipelines (normalize=True is handled by get_cifar10).
train_data, test_data = get_cifar10(
    ds_dir=DATASET_LOC,
    batch_size=BATCH_SIZE,
    normalize=True,
)

100%|██████████| 50000/50000 [00:32<00:00, 1537.94it/s]
100%|██████████| 10000/10000 [00:06<00:00, 1535.27it/s]


In [9]:
def init_pytorch(shape, dtype=tf.float32, partition_info=None):
  """Uniform kernel initializer mirroring PyTorch's default Conv/Linear weight init.

  Draws from U(-b, b) with b = 1 / sqrt(fan_in), where fan_in is the product of all
  kernel dimensions except the last (Keras kernels keep output units last).

  Args:
    shape: kernel shape, e.g. (kh, kw, c_in, c_out) for Conv2D or (n_in, n_out) for Dense.
    dtype: dtype of the returned tensor.
    partition_info: unused; accepted for compatibility with the TF1 initializer signature.

  Returns:
    A tensor of `shape` sampled uniformly in [-b, b).
  """
  fan_in = np.prod(shape[:-1])
  limit = 1 / math.sqrt(fan_in)
  return tf.random.uniform(shape, -limit, limit, dtype=dtype)

class ConvBN(tf.keras.Model):
  """3x3 Conv2D ("SAME", no bias) -> BatchNormalization -> ReLU.

  Args:
    c_out: number of output channels (Conv2D filters).
    virtual_batch_size: forwarded to BatchNormalization. None (the default) means
      ordinary batch statistics; an int enables ghost batch norm over sub-batches
      of that size (the incoming batch must then be divisible by it).
  """
  def __init__(self, c_out, virtual_batch_size=None):
    super().__init__()
    self.conv = tf.keras.layers.Conv2D(filters=c_out, kernel_size=3, padding="SAME", kernel_initializer=init_pytorch, use_bias=False)
    # Keras momentum=0.9 is the equivalent of PyTorch's BatchNorm momentum=0.1.
    # Honour the constructor argument instead of hard-coding None, which silently ignored it.
    self.bn = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5, virtual_batch_size=virtual_batch_size)

  def call(self, inputs):
    return tf.nn.relu(self.bn(self.conv(inputs)))

class ConvPoolBNAct(tf.keras.Model):
  """3x3 Conv2D ("SAME", no bias) -> MaxPool2D -> BatchNormalization -> ReLU.

  Pooling happens before normalisation, so batch norm runs on the downsampled map.

  Args:
    c_out: number of output channels (Conv2D filters).
    virtual_batch_size: forwarded to BatchNormalization. None (the default) means
      ordinary batch statistics; an int enables ghost batch norm over sub-batches
      of that size (the incoming batch must then be divisible by it).
  """
  def __init__(self, c_out, virtual_batch_size=None):
    super().__init__()
    self.conv = tf.keras.layers.Conv2D(filters=c_out, kernel_size=3, padding="SAME", kernel_initializer=init_pytorch, use_bias=False)
    # Honour the constructor argument instead of hard-coding None, which silently ignored it.
    self.bn = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5, virtual_batch_size=virtual_batch_size)
    self.pool = tf.keras.layers.MaxPool2D()

  def call(self, inputs):
    return tf.nn.relu(self.bn(self.pool(self.conv(inputs))))

class ResBlk(tf.keras.Model):
  """Downsampling stage: ConvPoolBNAct, optionally followed by a residual ConvBN pair.

  Output is h = conv_pool_bn_act(x), then h + res2(res1(h)) when `res` is True.

  Args:
    c_out: output channels for every convolution in the block.
    pool: pooling layer supplied by the caller. NOTE(review): stored for
      compatibility but not used by `call` (ConvPoolBNAct has its own pool).
    res: whether to add the residual ConvBN -> ConvBN branch.
    virtual_batch_size: forwarded to ConvPoolBNAct's BatchNormalization.
      Defaults to None (ordinary batch norm), which matches the previous
      effective behaviour: the hard-coded 8 used to be ignored downstream.
  """
  def __init__(self, c_out, pool, res=False, virtual_batch_size=None):
    super().__init__()
    # The previously-built `self.conv_bn` was never used in `call`; it is dropped.
    self.conv_pool_bn_act = ConvPoolBNAct(c_out, virtual_batch_size)
    self.pool = pool
    self.res = res
    if self.res:
      self.res1 = ConvBN(c_out)
      self.res2 = ConvBN(c_out)

  def call(self, inputs):
    h = self.conv_pool_bn_act(inputs)
    if self.res:
      h = h + self.res2(self.res1(h))
    return h

class DavidNet(tf.keras.Model):
  """DavidNet-style ResNet for 10-class image classification.

  Layout: ConvBN(c) -> ResBlk(3c, residual) -> ResBlk(6c) -> ResBlk(9c, residual)
  -> global max pool -> bias-free Dense(10), with logits scaled by `weight`.

  Args:
    c: base channel width (default 32 gives stage widths 96 / 192 / 288).
    weight: constant multiplier applied to the output logits.
  """
  def __init__(self, c=32, weight=0.125):
    super().__init__()
    # One pooling layer handed to every ResBlk (created first to keep layer naming stable).
    shared_pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = ConvBN(c)
    self.blk1 = ResBlk(c * 3, shared_pool, res=True)
    self.blk2 = ResBlk(c * 6, shared_pool)
    self.blk3 = ResBlk(c * 9, shared_pool, res=True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight

  def call(self, x):
    features = self.init_conv_bn(x)
    for stage in (self.blk1, self.blk2, self.blk3):
      features = stage(features)
    logits = self.linear(self.pool(features))
    return logits * self.weight

In [16]:
# LR Schedule + SGD code taken from Fenwicks
# TODO: Can be removed since it's mostly the same code as old one

def warmup_lr_sched(step: tf.Tensor, warmup_steps: int, init_lr: float, lr) -> tf.Tensor:
    """Wrap a learning rate with a linear warm-up phase.

    Returns init_lr * step / warmup_steps while step < warmup_steps, and `lr`
    afterwards.

    Args:
        step: current training step (any numeric dtype; cast to float32).
        warmup_steps: length of the warm-up phase in steps (Python int or tensor).
        init_lr: learning rate reached at the end of warm-up.
        lr: learning rate to use once warm-up is over (float or float-castable tensor).

    Returns:
        Scalar float32 tensor with the learning rate for `step`.
    """
    step = tf.cast(step, tf.float32)
    # tf.cast (unlike tf.constant) also accepts a tensor-valued warmup_steps.
    warmup_steps = tf.cast(warmup_steps, tf.float32)
    # Guard the divisor: with warmup_steps == 0 the old mask-multiply form produced
    # 0 * inf / 0 * nan = nan even though the warm-up branch was never selected.
    warmup_lr = init_lr * step / tf.maximum(warmup_steps, 1.0)
    lr = tf.cast(lr, tf.float32)
    return tf.where(step < warmup_steps, warmup_lr, lr)
  
def linear_decay() -> Callable:
    """Return a linear (power=1, non-cycling) decay-to-zero schedule function.

    The returned callable has polynomial_decay's remaining signature:
    (learning_rate, global_step, decay_steps, ...).
    """
    decay_fn = tf.compat.v1.train.polynomial_decay
    return functools.partial(
        decay_fn,
        end_learning_rate=0.0,
        power=1.0,
        cycle=False,
    )


def one_cycle_lr(init_lr: float, total_steps: int, warmup_steps: int, decay_sched: Callable) -> Callable:
    """Build a one-cycle LR function: linear warm-up to `init_lr`, then `decay_sched`.

    Args:
        init_lr: peak learning rate reached at the end of warm-up.
        total_steps: total number of training steps the schedule spans.
        warmup_steps: number of linear warm-up steps (0 disables warm-up).
        decay_sched: callable (lr, step, decay_steps) -> lr, e.g. `linear_decay()`.

    Returns:
        lr_func(step=None) -> scalar learning-rate tensor. When `step` is omitted the
        TF1 global step (created if needed) is used.
    """
    # Avoid decay_steps == 0 (nan from 0/0 inside the decay) when there is no decay phase.
    decay_steps = max(total_steps - warmup_steps, 1)

    def lr_func(step: tf.Tensor = None):
        if step is None:
            step = tf.compat.v1.train.get_or_create_global_step()

        lr = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
        lr = decay_sched(lr, step - warmup_steps, decay_steps)
        # Under eager execution the tf.compat.v1 decay functions return a zero-arg
        # callable instead of a Tensor; resolve it so the arithmetic below works.
        if callable(lr):
            lr = lr()
        return lr if warmup_steps == 0 else warmup_lr_sched(step, warmup_steps, init_lr, lr)

    return lr_func
  

class SGD(tf.compat.v1.train.MomentumOptimizer):
    """Nesterov-momentum SGD with L2 weight decay folded into the gradients.

    Every non-None gradient g for variable v returned by the base class is replaced by
    g + wd * v before the momentum update is applied.

    Args:
        lr: learning rate (tensor, float, or a zero-arg callable returning one; the
            base MomentumOptimizer resolves callables when applying gradients).
        mom: momentum coefficient.
        wd: weight-decay coefficient added to each gradient.
    """

    def __init__(self, lr: tf.Tensor, mom: float, wd: float):
        super().__init__(lr, momentum=mom, use_nesterov=True)
        self.wd = wd

    def compute_gradients(self, loss: tf.Tensor, var_list: List[tf.Tensor] = None, **kwargs) -> List[
        Tuple[tf.Tensor, tf.Tensor]]:
        # Forward gate_gradients / aggregation_method / colocate_gradients_with_ops /
        # grad_loss (passed e.g. by `minimize`) instead of silently dropping them.
        grads_and_vars = super().compute_gradients(loss, var_list=var_list, **kwargs)

        # Variables not connected to the loss get a None gradient; leave those as-is
        # rather than failing on `None + tensor`.
        return [
            (g if g is None else g + v * self.wd, v)
            for g, v in grads_and_vars
        ]
      
def sgd_optimizer(lr_func: Callable, mom: float = 0.9, wd: float = 0.0) -> Callable:
    """Return a zero-arg factory that builds an `SGD` optimizer driven by `lr_func`.

    Args:
        lr_func: zero-arg-callable learning-rate schedule (e.g. from `one_cycle_lr`).
        mom: Nesterov momentum coefficient.
        wd: weight-decay coefficient added to every gradient.

    Returns:
        opt_func() -> SGD instance.
    """
    def opt_func():
        # Pass the schedule itself rather than lr_func(): MomentumOptimizer evaluates a
        # callable learning rate each time gradients are applied, so under eager
        # execution the LR tracks the global step instead of being frozen at the value
        # it had when the optimizer was constructed. In graph mode it builds the same ops.
        return SGD(lr_func, mom=mom, wd=wd)

    return opt_func

In [17]:
# Convert epoch-based settings into optimizer steps for the LR schedule.
# NOTE(review): floor division counts full batches only; if the input pipeline keeps a
# trailing partial batch, each epoch runs one more step than counted here.
steps_per_epoch = NUM_TRAIN // BATCH_SIZE
warmup_steps = WARMUP * steps_per_epoch
total_steps = EPOCHS * steps_per_epoch

In [18]:
# One-cycle schedule: linear warm-up, then linear decay to zero over the remaining steps.
# NOTE(review): the peak LR is divided by BATCH_SIZE, which presumes a summed (not
# averaged) loss; tf.compat.v1.losses.sparse_softmax_cross_entropy averages by default — confirm.
lr_decay = linear_decay()
lr_func = one_cycle_lr(
    init_lr=LEARNING_RATE / BATCH_SIZE,
    total_steps=total_steps,
    warmup_steps=warmup_steps,
    decay_sched=lr_decay,
)

In [21]:
# Optimizer factory passed to train_model_custom as opt_fn.
# The Nesterov-SGD + one-cycle path (commented out below) is currently disabled in
# favour of the Keras Adam class, so lr_func, MOMENTUM and WEIGHT_DECAY are unused in
# this run. NOTE(review): Adam is presumably instantiated with its default arguments by
# train_model_custom — confirm.
# opt_func = sgd_optimizer(lr_func, mom=MOMENTUM, wd=WEIGHT_DECAY*BATCH_SIZE)
opt_func = tf.keras.optimizers.Adam

In [22]:
# Train DavidNet on CIFAR-10.
# NOTE(review): epochs is hard-coded to 2 here, while total_steps / warmup_steps above
# are derived from EPOCHS; use EPOCHS if the SGD + one-cycle schedule is re-enabled.
model = train_model_custom(
    train_ds=train_data,
    test_ds=test_data,
    epochs=2,
    model_fn=DavidNet,
    opt_fn=opt_func,
    loss_fn=tf.compat.v1.losses.sparse_softmax_cross_entropy,
)

  0%|          | 0/2 [00:00<?, ?it/s]

KeyboardInterrupt: 