In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

In [2]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32')/255.
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32')/255.

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [3]:
model = keras.Sequential(
    [
     layers.Input(shape=(28,28,1)),
     layers.Conv2D(64, 3, padding='same'),
     layers.ReLU(),
     layers.Conv2D(128, 3, padding='same'),
     layers.ReLU(),
     layers.Flatten(),
     layers.Dense(10),
    ],

    name = 'basic_model'

)

# Custom fit

In [5]:
class CustomFit(keras.Model):
  def __init__(self, model):
    super(CustomFit, self).__init__()
    self.model = model
  
  def train_step(self, data):
    x, y = data

    # Forward Propagation
    with tf.GradientTape() as tape: # records operations for gradients
      y_pred = self.model(x, training=True)
      loss   = self.compiled_loss(y, y_pred) # loss in compile function

    # BackPropagation    
    training_vars = self.trainable_variables
    gradients = tape.gradient(loss, training_vars)
    self.optimizer.apply_gradients(zip(gradients, training_vars))

    # Updating metrics (mentioned in compile function)
    self.compiled_metrics.update_state(y, y_pred)

    return {m.name : m.result() for m in self.metrics}




training = CustomFit(model)
training.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
training.fit(x_train, y_train, batch_size=32, epochs=2)

Epoch 1/2
Epoch 2/2


<tensorflow.python.keras.callbacks.History at 0x7fcc049ad7b8>

## Custom compile_fit_evaluate <br>
<pre>  => compile (need to use own accuracy, loss and not compiled ones)
  => fit (train_step)
  => evaluate (test_step)
</pre>

In [13]:
class CustomTrainer(keras.Model):

  def __init__(self, model):
    super(CustomTrainer, self).__init__()
    self.model = model
  
  def compile(self, optimizer, loss):
    super(CustomTrainer, self).compile()
    self.optimizer = optimizer
    self.loss = loss
    self.acc_metric = keras.metrics.SparseCategoricalAccuracy(name='accuracy')

  def train_step(self, data):
    x, y = data

    # Forward Propagation
    with tf.GradientTape() as tape: # record operations for gradients
      y_pred = self.model(x, training=True)
      loss   = self.loss(y, y_pred) # here not compiled since custom compilation
    
    # Back Propagation
    training_vars = self.trainable_variables
    gradients = tape.gradient(loss, training_vars)
    self.optimizer.apply_gradients(zip(gradients, training_vars))

    # Metrics update(custom acc_metric and not compiled_metric)
    self.acc_metric.update_state(y, y_pred)

    return {'loss': loss, 'accuracy': self.acc_metric.result()}

  def test_step(self, data):
    x,y = data

    # Training is false, because BatchNormalization, Dropout etc
    # have different behaviour during train and test
    y_pred = self.model(x, training=False)
    loss   = self.loss(y, y_pred)
    self.acc_metric.update_state(y, y_pred)

    return {'loss':loss, 'accuracy':self.acc_metric.result()}




In [15]:
model = keras.Sequential(
    [
     layers.Input(shape=(28,28,1)),
     layers.Conv2D(64, 3, padding='same'),
     layers.ReLU(),
     layers.Conv2D(128, 3, padding='same'),
     layers.ReLU(),
     layers.Flatten(),
     layers.Dense(10),
    ],

    name = 'basic_model'

)

training = CustomTrainer(model)
training.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
training.fit(x_train, y_train, batch_size=32, epochs=2)
training.evaluate(x_test, y_test, batch_size=32)

Epoch 1/2
Epoch 2/2


0.9872000217437744