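# **What is Subclassing in TensorFlow?**

Model subclassing means defining a model by inheriting from `tf.keras.Model`: the layers are created in `__init__` and the forward pass is written as ordinary Python in `call`. This notebook builds a small fully connected classifier that way and trains it on MNIST with a custom `tf.GradientTape` training loop.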

In [None]:
import tensorflow as tf
from tensorflow.keras import layers

In [None]:
class MyModel(tf.keras.Model):
  """Small fully connected classifier defined by subclassing `tf.keras.Model`.

  The layers are created in `__init__` and wired together in `call`.

  Architecture: Dense(128, relu) -> Dense(64, relu) -> Dense(10, softmax).
  Because the final layer applies softmax, the model returns class
  probabilities rather than logits; a loss paired with it should therefore
  be configured with `from_logits=False`.
  """

  def __init__(self):
    super().__init__()
    # Sub-layers are tracked through these attribute names, so they are kept
    # stable (renaming them would change the saved-weights/checkpoint layout).
    self.dense1 = layers.Dense(128, activation='relu')
    self.dense2 = layers.Dense(64, activation='relu')
    self.output_layer = layers.Dense(10, activation='softmax')

  def call(self, inputs):
    """Run the forward pass.

    Args:
      inputs: a batch of feature vectors; in this notebook it is always
        shaped (batch, 784).

    Returns:
      Tensor of shape (batch, 10) with softmax class probabilities.
    """
    hidden = inputs
    # Apply the layers strictly in order: hidden, hidden, then classifier head.
    for layer in (self.dense1, self.dense2, self.output_layer):
      hidden = layer(hidden)
    return hidden

In [None]:
# Instantiate the subclassed model.
model = MyModel()

# Example input: a batch of 32 random samples with 784 features each
# (the size of a flattened 28x28 MNIST image used for training below).
input_data = tf.random.normal([32, 784])

# Forward pass
output = model(input_data)

# Check the output shape: one row of 10 class scores per sample.
print(output.shape)
assert output.shape == (32, 10), f"unexpected output shape {output.shape}"

(32, 10)


In [None]:
# Set up the pieces used by the custom training loop.

# MyModel's last Dense layer uses a softmax activation, so the model outputs
# class probabilities, not logits. The loss must therefore be built with
# from_logits=False. Declaring from_logits=True contradicts the model, makes
# Keras emit a "produced by a Softmax activation" warning, and on some Keras
# versions causes the probabilities to be passed through softmax a second time.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
optimizer = tf.keras.optimizers.Adam()

# Custom training step

@tf.function
def train_step(model, images, labels):
  """Run one optimization step on a single batch.

  Uses the module-level `loss_fn` and `optimizer`.

  Args:
    model: the Keras model to train; it is called with training=True.
    images: a batch of inputs accepted by `model` (in this notebook,
      flattened images of shape (batch, 784)).
    labels: integer class ids for the batch, as expected by
      SparseCategoricalCrossentropy.

  Returns:
    The batch loss tensor, computed before the weights are updated.
  """
  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = loss_fn(labels, predictions)

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  return loss

In [None]:
# Training the model

# Load the MNIST dataset (pixel values 0-255 and integer class labels).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Scale pixels to [0, 1] as float32 and flatten each 28x28 image once, up
# front, into the 784-feature vector MyModel expects. Dividing the integer
# arrays directly by 255.0 would produce float64 and double the memory used.
x_train = (x_train.astype('float32') / 255.0).reshape(-1, 784)
x_test = (x_test.astype('float32') / 255.0).reshape(-1, 784)

# Create TensorFlow datasets. `from_tensor_slices` is a static constructor, so
# it is called on tf.data.Dataset itself rather than through another dataset.
# Training data is reshuffled every epoch so the batches differ between epochs;
# the seed keeps runs reproducible.
SHUFFLE_BUFFER = 10_000
SHUFFLE_SEED = 42
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(SHUFFLE_BUFFER, seed=SHUFFLE_SEED)
    .batch(32)
)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

# Training loop. Report the mean loss over the whole epoch; the loss of the
# final batch alone is a very noisy estimate of training progress.
epochs = 10
for epoch in range(epochs):
  print(f"Epoch {epoch+1}/{epochs}")
  epoch_loss = tf.keras.metrics.Mean()
  for images, labels in train_dataset:
    loss = train_step(model, images, labels)
    epoch_loss.update_state(loss)
  print(f"Mean loss: {epoch_loss.result().numpy()}")


Epoch 1/10


  output, from_logits = _get_logits(


Loss: 0.02893536537885666
Epoch 2/10
Loss: 0.06303180754184723
Epoch 3/10
Loss: 0.03357594460248947
Epoch 4/10
Loss: 0.006820914335548878
Epoch 5/10
Loss: 0.007841427810490131
Epoch 6/10
Loss: 0.021507922559976578
Epoch 7/10
Loss: 0.0003730443713720888
Epoch 8/10
Loss: 0.00013930546992924064
Epoch 9/10
Loss: 0.0009413263760507107
Epoch 10/10
Loss: 0.0001044063683366403
