<a href="https://colab.research.google.com/github/zeynep68/CIFAR100/blob/main/wide_resnet_main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
import numpy as np
import tensorflow as tf

from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical

from wide_resnet import create_wide_resnet

In [8]:
def load_data(num_classes=10):
  """Load CIFAR-10 and return standardized images with one-hot labels.

  Images are cast to float32 and divided by 255. They are then standardized
  with the mean and standard deviation of the *training* split, reduced over
  axes (0, 1, 2) (i.e. per channel for the channels-last arrays returned by
  ``cifar10.load_data``). The same statistics are applied to the test split.

  Args:
    num_classes: Number of columns in the one-hot label matrices.

  Returns:
    ((x_train, y_train), (x_test, y_test)) where the images are float32
    ``tf.Tensor``s and the labels are the arrays produced by
    ``to_categorical``.
  """
  (train_images, train_labels), (test_images, test_labels) = cifar10.load_data()

  def to_unit_scale(images):
    # Cast before dividing so the arithmetic happens in float32.
    return tf.cast(images, dtype=tf.float32) / 255.

  train_images = to_unit_scale(train_images)
  test_images = to_unit_scale(test_images)

  # Statistics come from the training split only, so no test-set information
  # leaks into preprocessing.
  channel_mean = np.mean(train_images, axis=(0, 1, 2))
  channel_std = np.std(train_images, axis=(0, 1, 2))

  def standardize(images):
    return (images - channel_mean) / channel_std

  train_split = (standardize(train_images), to_categorical(train_labels, num_classes))
  test_split = (standardize(test_images), to_categorical(test_labels, num_classes))
  return train_split, test_split

In [9]:
def lr_schedule(epoch, lr, milestones=(40, 56), factor=1e-1):
  """Step-decay learning-rate schedule for ``LearningRateScheduler``.

  Keras calls ``schedule(epoch, lr)`` with the optimizer's *current* learning
  rate, i.e. the value this function returned for the previous epoch. The
  decay is therefore applied only on the milestone epochs, and the rate is
  passed through unchanged otherwise. Multiplying on every epoch past a
  milestone would compound the decay and drive the rate towards zero.

  With the defaults and a starting rate ``r0`` the schedule is:

    epochs  0-39 -> r0
    epochs 40-55 -> r0 * 1e-1
    epochs 56+   -> r0 * 1e-2

  Args:
    epoch: Zero-based epoch index supplied by the callback.
    lr: Learning rate currently set on the optimizer.
    milestones: Epochs at which the rate is multiplied by ``factor``.
    factor: Multiplicative decay applied at each milestone.

  Returns:
    The learning rate to use for ``epoch``.
  """
  if epoch in milestones:
    return lr * factor
  return lr

In [10]:
""" Block types used in the wide resnet paper. """
# Six block types, indices 0-5. Each entry lists the kernel sizes of a block's
# convolutions, mirroring the paper's B(...) notation (e.g. [3,1,3] ~ B(3,1,3)).
BLOCK_TYPE = [[3,3], [3,1,3], [1,3,1], [1,3], [3,1], [3,1,1]]

In [12]:
if __name__ == "__main__":
  # Training configuration.
  EPOCHS = 70
  BATCH_SIZE = 128
  LEARNING_RATE = 1e-3
  # BLOCK_TYPE has six entries, so valid indices are 0-5; index 6 would raise
  # IndexError. Index 5 ([3,1,1], the sixth block type) is assumed intended.
  # NOTE(review): confirm the intended block type.
  BLOCK_INDEX = 5

  (x_train, y_train), (x_test, y_test) = load_data()

  lr_scheduler = LearningRateScheduler(lr_schedule)

  model = create_wide_resnet(x_train, avg_pool=8, k=4, n=[2,2,2],
                             kernels=BLOCK_TYPE[BLOCK_INDEX],
                             learning_rate=LEARNING_RATE)

  training = model.fit(x_train,
                       y_train,
                       batch_size=BATCH_SIZE,
                       epochs=EPOCHS,
                       validation_data=(x_test, y_test),
                       callbacks=[lr_scheduler])

Epoch 1/120
Epoch 2/120
Epoch 3/120
Epoch 4/120
Epoch 5/120
Epoch 6/120
Epoch 7/120
Epoch 8/120
Epoch 9/120
Epoch 10/120
Epoch 11/120
Epoch 12/120
Epoch 13/120
Epoch 14/120
Epoch 15/120
Epoch 16/120
Epoch 17/120
Epoch 18/120
Epoch 19/120
Epoch 20/120
Epoch 21/120
Epoch 22/120
Epoch 23/120
Epoch 24/120
Epoch 25/120
Epoch 26/120
Epoch 27/120
Epoch 28/120
Epoch 29/120
Epoch 30/120
Epoch 31/120
Epoch 32/120
Epoch 33/120
Epoch 34/120
Epoch 35/120
Epoch 36/120
Epoch 37/120
Epoch 38/120
Epoch 39/120
Epoch 40/120
Epoch 41/120
Epoch 42/120
Epoch 43/120
Epoch 44/120
Epoch 45/120
Epoch 46/120
Epoch 47/120
Epoch 48/120
Epoch 49/120
Epoch 50/120
Epoch 51/120
Epoch 52/120
Epoch 53/120
Epoch 54/120
Epoch 55/120
Epoch 56/120
Epoch 57/120
Epoch 58/120

KeyboardInterrupt: ignored