In [17]:
try:
  %tensorflow 2.x
except Exception:
  pass

In [18]:
import tensorflow as tf
import tensorflow_datasets as tdfs

In [19]:
class Block(tf.keras.Model):
    """A stack of identical Conv2D layers followed by a single max-pool.

    Each convolution uses ReLU activation and 'same' padding; the pooling
    layer is applied once, after the last convolution.

    Args:
        filters: Number of output filters given to every Conv2D layer.
        kernel_size: Kernel size given to every Conv2D layer.
        repetitions: Number of Conv2D layers applied in sequence.
        pool_size: Window size of the trailing MaxPool2D layer (default 2).
        strides: Stride of the trailing MaxPool2D layer (default 2).
    """

    def __init__(self, filters, kernel_size, repetitions, pool_size=2, strides=2):
        super().__init__()
        self.filters = filters
        self.kernel_size = kernel_size
        self.repetitions = repetitions
        for i in range(self.repetitions):
            # Use setattr rather than writing into vars(self): Keras registers
            # sub-layers inside __setattr__, and a direct __dict__ write skips
            # that hook, which can leave the conv weights out of
            # trainable_variables, summary() and checkpoints. The attribute
            # names (conv2D_0, conv2D_1, ...) are unchanged.
            setattr(
                self,
                f'conv2D_{i}',
                tf.keras.layers.Conv2D(
                    self.filters,
                    self.kernel_size,
                    activation='relu',
                    padding='same',
                ),
            )
        self.max_pool = tf.keras.layers.MaxPool2D(pool_size, strides=strides)

    def call(self, inputs):
        """Run `inputs` through every conv layer in order, then max-pool.

        With `repetitions == 0` there are no conv layers and the input is
        only pooled (the earlier implementation raised KeyError in that case).
        """
        x = inputs
        for i in range(self.repetitions):
            x = getattr(self, f'conv2D_{i}')(x)
        return self.max_pool(x)

In [20]:
class VGG(tf.keras.Model):
    """VGG-style classifier built from five `Block` stages and a dense head.

    The stages use 64, 128, 256, 512 and 512 filters with 2, 2, 3, 3 and 3
    convolutions respectively, all with 3x3 kernels. Their output is
    flattened, passed through a 256-unit ReLU layer, and finally through a
    softmax layer with `classes` units.

    Args:
        classes: Number of units in the final softmax layer.
    """

    def __init__(self, classes):
        super().__init__()

        # Convolutional feature extractor.
        self.block_a = Block(filters=64, kernel_size=3, repetitions=2)
        self.block_b = Block(filters=128, kernel_size=3, repetitions=2)
        self.block_c = Block(filters=256, kernel_size=3, repetitions=3)
        self.block_d = Block(filters=512, kernel_size=3, repetitions=3)
        self.block_e = Block(filters=512, kernel_size=3, repetitions=3)

        # Classification head.
        self.flatten = tf.keras.layers.Flatten()
        self.fc = tf.keras.layers.Dense(256, activation='relu')
        self.classifier = tf.keras.layers.Dense(classes, activation='softmax')

    def call(self, inputs):
        """Apply the five blocks and the dense head; return the softmax output."""
        pipeline = (
            self.block_a,
            self.block_b,
            self.block_c,
            self.block_d,
            self.block_e,
            self.flatten,
            self.fc,
            self.classifier,
        )
        out = inputs
        for stage in pipeline:
            out = stage(out)
        return out


In [21]:
import tensorflow_datasets as tfds
# Load the training split of cats_vs_dogs; files are stored under data/.
dataset = tfds.load(
    'cats_vs_dogs',
    split=tfds.Split.TRAIN,
    data_dir='data/',
)

# Two-class VGG; sparse categorical cross-entropy takes integer labels, so
# no one-hot encoding is applied anywhere in this notebook.
model = VGG(classes=2)
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
def preprocess(features):
    """Turn one dataset record into an `(image, label)` training pair.

    Args:
        features: Mapping with an 'image' entry and a 'label' entry.

    Returns:
        A tuple of the image resized to 224x224, cast to float32 and divided
        by 255, and the record's label unchanged.
    """
    resized = tf.image.resize(features['image'], (224, 224))
    # Presumably the source pixels are in 0..255, making this a [0, 1] rescale.
    scaled = tf.cast(resized, tf.float32) / 255.
    label = features['label']
    return scaled, label

# Training / input-pipeline settings, named so they can be tuned in one place.
BATCH_SIZE = 32
EPOCHS = 10
# Number of records kept in memory for shuffling; lower it if RAM is tight.
SHUFFLE_BUFFER = 1024

# Keras does not shuffle tf.data.Dataset inputs itself, so without an explicit
# shuffle every epoch would see the same batches in the same order.
dataset = (
    dataset
    .shuffle(SHUFFLE_BUFFER, reshuffle_each_iteration=True)
    # Resize/rescale records in parallel instead of one at a time.
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    # Prepare the next batches while the current one is being trained on.
    .prefetch(tf.data.AUTOTUNE)
)
model.fit(dataset, epochs=EPOCHS)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7ffa9ab7ac90>