# Custom Training in Keras & TF 2.X
> Use `fit()` even if you have a custom training loop (CIFAR-10 example)
- toc: true
- branch: master
- badges: true
- comments: true
- image: images/train_loop.png
- author: Sajjad Ayoubi
- categories: [tips]

- Since TF 2.2, Keras lets you override `train_step` in a `tf.keras.Model` subclass, so you can write a complex training loop with model subclassing and still train with `fit()`


## Example 
- CIFAR-10 with a ResNet-ish model

```python
import tensorflow as tf
from tensorflow.keras import layers
```

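```python
# download CIFAR-10 and scale pixel values to [0, 1]
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# one-hot labels, as expected by 'categorical_crossentropy'
y_train = tf.keras.utils.to_categorical(y_train)
y_test = tf.keras.utils.to_categorical(y_test)
```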

```
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
```
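
Besides NumPy arrays, `fit()` also accepts a `tf.data.Dataset`, and each batch the dataset yields is exactly what a custom `train_step` receives as its `data` argument. A minimal sketch, assuming random stand-in arrays shaped like CIFAR-10 instead of the real download (the sample count and pipeline settings are illustrative assumptions):

```python
import numpy as np
import tensorflow as tf

# random stand-ins shaped like CIFAR-10: 32x32 RGB images, 10 classes
x = np.random.rand(100, 32, 32, 3).astype("float32")
y = tf.keras.utils.to_categorical(np.random.randint(0, 10, size=(100, 1)), num_classes=10)

# shuffle, batch and prefetch; each element is an (images, labels) tuple
ds = (tf.data.Dataset.from_tensor_slices((x, y))
      .shuffle(100)
      .batch(64)
      .prefetch(tf.data.AUTOTUNE))

xb, yb = next(iter(ds))  # one batch: what train_step would get as `data`
print(xb.shape, yb.shape)  # (64, 32, 32, 3) (64, 10)
```

Such a dataset can be passed straight to `model.fit(ds, epochs=...)`; the batch size then comes from the pipeline, not from a `batch_size` argument.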


- Create a new layer by subclassing `tf.keras.layers.Layer`

```python
# new layer: a small residual block
class ResBlock(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.c1 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')
        self.c2 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')

    # forward step
    def call(self, inputs):
        x1 = self.c1(inputs)
        # skip connection: the input is added to c1's output before c2,
        # so `inputs` must already have 32 channels
        x2 = self.c2(x1 + inputs)
        return x2
```
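
`ResBlock` adds the skip connection before the second convolution. A more conventional residual layout adds the input after both convolutions and applies the activation after the sum. The sketch below shows that variant as a hypothetical `PostAddResBlock` (not part of the original model) and checks that the output shape matches the input shape, which any skip connection requires:

```python
import tensorflow as tf
from tensorflow.keras import layers


class PostAddResBlock(layers.Layer):
    """Hypothetical variant: y = relu(conv2(conv1(x)) + x)."""

    def __init__(self, filters=32, **kwargs):
        super().__init__(**kwargs)
        self.c1 = layers.Conv2D(filters, 3, activation="relu", padding="same")
        self.c2 = layers.Conv2D(filters, 3, padding="same")  # no activation before the add

    def call(self, inputs):
        x = self.c2(self.c1(inputs))
        # skip connection: `inputs` must have `filters` channels for the add to work
        return tf.nn.relu(x + inputs)


block = PostAddResBlock()
out = block(tf.random.normal((2, 32, 32, 32)))
print(out.shape)  # (2, 32, 32, 32)
```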

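- Create a new model by subclassing `tf.keras.Model`
- Override `train_step` (one training batch) and, optionally, `test_step` (one evaluation batch); `fit()` and `evaluate()` call them for you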

```python
# new model: a small ResNet-style network (not the full ResNet-18)
class Resnet18(tf.keras.Model):
    def __init__(self, n_class=10):
        super().__init__()
        self.first_conv = layers.Conv2D(32, (3, 3), activation='relu', padding='same')
        self.blocks = tf.keras.Sequential([ResBlock(), ResBlock()])
        self.flatten = layers.Flatten()
        self.fc = layers.Dense(n_class, activation='softmax')

    # forward step
    def call(self, x):
        x = self.first_conv(x)
        x = self.blocks(x)
        x = self.fc(self.flatten(x))
        return x

    # one training batch: called by fit()
    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # loss given to compile(), plus any regularization losses
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    # one evaluation batch: called by evaluate() (and by fit() for validation data)
    def test_step(self, data):
        x, y = data
        y_pred = self(x, training=False)
        self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    # other useful things to override: compile() (call super().compile() inside it)
    # and the `metrics` property; stubs such as `def compile(self): pass` or
    # `def metrics(self): pass` would break model.compile(...) and self.metrics
```
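
Overriding `compile()` and `metrics` correctly means keeping `super().compile()` (so `fit()` knows the model is compiled) and defining `metrics` as a property that lists your metric trackers (so `fit()` and `evaluate()` can reset them each epoch). A minimal sketch on random regression data; `TrackedModel`, `loss_fn`, and the tracker names are hypothetical:

```python
import numpy as np
import tensorflow as tf


class TrackedModel(tf.keras.Model):
    """Hypothetical model with a custom compile() and metrics property."""

    def __init__(self):
        super().__init__()
        self.fc = tf.keras.layers.Dense(1)
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.mae_tracker = tf.keras.metrics.MeanAbsoluteError(name="mae")

    def call(self, x):
        return self.fc(x)

    def compile(self, optimizer, loss_fn):
        # keep the base compile() so fit() accepts the model as compiled
        super().compile(optimizer=optimizer)
        self.loss_fn = loss_fn

    @property
    def metrics(self):
        # a property, not a method: these trackers are reset at each epoch
        return [self.loss_tracker, self.mae_tracker]

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.loss_fn(y, y_pred)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.loss_tracker.update_state(loss)
        self.mae_tracker.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}


x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")
model = TrackedModel()
model.compile(optimizer="adam", loss_fn=tf.keras.losses.MeanSquaredError())
history = model.fit(x, y, batch_size=16, epochs=2, verbose=0)
print(sorted(history.history))  # ['loss', 'mae']
```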

- Training: `compile()` and `fit()` work as usual and run the custom `train_step`

```python
model = Resnet18()
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
H = model.fit(x_train, y_train, epochs=5, batch_size=64)
```

```
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
```
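
In each of these epochs `fit()` called `train_step` once per batch: with 50,000 training images and `batch_size=64` that is ⌈50000 / 64⌉ = 782 steps. The sketch below counts the calls directly on a hypothetical `CountingModel` with random data. It compiles with `run_eagerly=True` so that a plain Python counter increments on every batch; inside a traced `tf.function` it would only increment when the step is traced.

```python
import numpy as np
import tensorflow as tf


class CountingModel(tf.keras.Model):
    """Hypothetical model that counts how often each step method runs."""

    def __init__(self):
        super().__init__()
        self.fc = tf.keras.layers.Dense(1)
        self.train_calls = 0
        self.test_calls = 0

    def call(self, x):
        return self.fc(x)

    def train_step(self, data):
        self.train_calls += 1  # once per training batch (eager mode)
        return super().train_step(data)

    def test_step(self, data):
        self.test_calls += 1  # once per evaluation batch (eager mode)
        return super().test_step(data)


x = np.random.rand(10, 4).astype("float32")
y = np.random.rand(10, 1).astype("float32")

model = CountingModel()
model.compile(optimizer="sgd", loss="mse", run_eagerly=True)
model.fit(x, y, batch_size=4, epochs=2, verbose=0)  # ceil(10 / 4) = 3 batches x 2 epochs
model.evaluate(x, y, batch_size=5, verbose=0)       # 10 / 5 = 2 batches
print(model.train_calls, model.test_calls)          # 6 2
```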


## Template
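- You don't have to write your training loop from scratch: subclass `tf.keras.Model`, override `train_step`/`test_step`, and let `fit()` handle batching, epochs, callbacks, and logging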

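```python
import tensorflow as tf

class Learner(tf.keras.Model):
    def train_step(self, data):
        # `data` is one batch exactly as the input pipeline yields it
        # (here an (x, y) tuple, e.g. from a tf.data.Dataset)
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        # usually you don't need to change the rest
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        x, y = data
        y_pred = self(x, training=False)
        self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def predict_step(self, data):
        # called by predict(): drop any targets/sample weights, run the forward pass
        x, _, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
        return self(x, training=False)


# any built model with defined inputs/outputs: Sequential, Functional, or tf.keras.applications
# (a subclassed model usually has no `inputs`/`outputs`; subclass Learner instead of tf.keras.Model)
model = tf.keras.applications.ResNet50(weights=None, input_shape=(32, 32, 3), classes=10)
learner = Learner(model.inputs, model.outputs)
learner.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
learner.fit(x_train, y_train, epochs=5, batch_size=64)  # enjoy the abilities of keras.fit :)
```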
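
The `data` argument of `train_step` is not limited to `(x, y)` tuples: it is one batch exactly as the dataset yields it. A minimal sketch, assuming a dataset of dicts with hypothetical keys `"features"` and `"label"` and a hypothetical `DictLearner`; the model is called once before `fit()` so its weights exist before training starts:

```python
import numpy as np
import tensorflow as tf


class DictLearner(tf.keras.Model):
    """Hypothetical model whose dataset yields dicts instead of (x, y) tuples."""

    def __init__(self, n_class=3):
        super().__init__()
        self.fc = tf.keras.layers.Dense(n_class, activation="softmax")
        self.loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

    def call(self, x):
        return self.fc(x)

    def train_step(self, data):
        # `data` is one batch from the dataset: here a dict of tensors
        x, y = data["features"], data["label"]
        with tf.GradientTape() as tape:
            loss = self.loss_fn(y, self(x, training=True))
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": loss}


features = np.random.rand(32, 8).astype("float32")
labels = np.random.randint(0, 3, size=(32,))
ds = tf.data.Dataset.from_tensor_slices({"features": features, "label": labels}).batch(8)

learner = DictLearner()
learner(features[:1])  # build the weights before training
learner.compile(optimizer="adam")
history = learner.fit(ds, epochs=2, verbose=0)
print(sorted(history.history))  # ['loss']
```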