In [3]:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np

In [2]:
def resize_and_rescale(image, label):
  """Cast an image to float32, resize it to 150x150 and scale it by 1/255.

  NOTE(review): not referenced by any pipeline in the visible cells;
  `prepreprocess` is what the datasets actually map over.

  Args:
    image: image tensor.
    label: passed through unchanged.

  Returns:
    `(rescaled_image, label)`.
  """
  as_float = tf.cast(image, tf.float32)
  resized = tf.image.resize(as_float, [150, 150])
  return resized / 255.0, label

In [4]:
def prepreprocess(image, label):
  """Resize an image to 150x150 and scale its pixel values by 1/255.

  Used directly on the validation/test splits, and inside `augment` for the
  training split.

  Args:
    image: image tensor.
    label: passed through unchanged.

  Returns:
    `(resized_image, label)`.
  """
  scaled = tf.image.resize(image, [150, 150]) / 255.0
  return scaled, label

In [5]:
def f(image, label):
  """Dataset map function for the training split: augment one example.

  Draws a fresh seed from the notebook-level `rng` generator for each
  example and hands it to `augment`.

  Returns:
    The augmented `(image, label)` pair produced by `augment`.
  """
  fresh_seed = rng.make_seeds(2)[0]
  return augment((image, label), fresh_seed)
  

In [7]:
def augment(image_label, seed):
  """Resize, rescale and randomly augment one (image, label) training pair.

  The image is passed through `prepreprocess` (resize to 150x150, scale by
  1/255), padded to 156x156, randomly cropped back to 150x150, randomly
  flipped left/right and randomly brightness-shifted, then clipped to [0, 1].

  Args:
    image_label: an `(image, label)` pair as yielded by the dataset.
    seed: shape-[2] integer seed for the stateless random ops
      (e.g. `rng.make_seeds(2)[0]`, as used by `f`).

  Returns:
    The augmented `(image, label)` pair; the label is unchanged.
  """
  image, label = image_label
  image, label = prepreprocess(image, label)
  image = tf.image.resize_with_crop_or_pad(image, 150 + 6, 150 + 6)

  # Derive an independent seed for every random op. Passing the same seed to
  # several stateless ops makes their random draws correlated (previously the
  # flip decision was tied to the crop offset because both used `seed`).
  crop_seed, flip_seed, brightness_seed = tf.unstack(
      tf.random.experimental.stateless_split(seed, num=3))

  # Random crop back to the original size.
  image = tf.image.stateless_random_crop(
      image, size=[150, 150, 3], seed=crop_seed)

  # Random horizontal flip.
  image = tf.image.stateless_random_flip_left_right(image, seed=flip_seed)

  # Random brightness shift; clip because the shift can push values outside
  # the [0, 1] range produced by the 1/255 rescale.
  image = tf.image.stateless_random_brightness(
      image, max_delta=0.5, seed=brightness_seed)
  image = tf.clip_by_value(image, 0, 1)
  return image, label

In [8]:
def preprocess(batch_size=None):
    """Load the tfds 'beans' dataset and build batched input pipelines.

    Args:
        batch_size: examples per batch. When omitted, falls back to the
            notebook-level ``BATCH_SIZE`` constant, so existing calls with no
            arguments behave as before.

    Returns:
        ``(ds_train, ds_validation, ds_test)``. The training split is cached,
        shuffled, augmented with ``f``, batched and prefetched. The validation
        and test splits are only resized/rescaled with ``prepreprocess``, then
        batched, cached and prefetched.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    autotune = tf.data.AUTOTUNE

    # Loads the train / validation / test splits as (image, label) pairs.
    (ds_train, ds_validation, ds_test), ds_info = tfds.load(
        'beans',
        split=['train', 'validation', 'test'],
        as_supervised=True,
        with_info=True)

    # Cache BEFORE the random augmentation. Caching after `map(f)` froze the
    # first epoch's random crops/flips/brightness shifts, so every later epoch
    # replayed identical images. NOTE: this caches the un-resized training
    # images in memory.
    ds_train = (
        ds_train
        .cache()
        .shuffle(ds_info.splits['train'].num_examples)
        .map(f, num_parallel_calls=autotune)
        .batch(batch_size)
        .prefetch(autotune))

    # Evaluation splits are deterministic, so caching after the map is safe.
    ds_validation = (
        ds_validation
        .map(prepreprocess, num_parallel_calls=autotune)
        .batch(batch_size)
        .cache()
        .prefetch(autotune))

    ds_test = (
        ds_test
        .map(prepreprocess, num_parallel_calls=autotune)
        .batch(batch_size)
        .cache()
        .prefetch(autotune))

    return ds_train, ds_validation, ds_test


In [23]:
def create_model(train_set, validation_set, test_set, epochs=100, checkpoint_path=None):
  """Build, compile and train the beans CNN classifier.

  Args:
    train_set: batched training dataset of (image, label) pairs with
      150x150x3 images.
    validation_set: batched validation dataset; its accuracy decides which
      weights are checkpointed.
    test_set: not used here; returned unchanged so the caller can evaluate.
    epochs: number of training epochs (default 100, as before).
    checkpoint_path: where the best-``val_accuracy`` weights are saved.
      Defaults to the notebook-level ``checkpoint_filepath``.

  Returns:
    ``(model, test_set)``. ``model`` holds the weights from the LAST epoch;
    load the checkpoint to get the best-validation weights.
  """
  if checkpoint_path is None:
    checkpoint_path = checkpoint_filepath

  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=[150, 150, 3]),
      tf.keras.layers.MaxPool2D(2),
      tf.keras.layers.Conv2D(64, 3, activation='relu'),
      tf.keras.layers.MaxPool2D(2),
      tf.keras.layers.Dropout(0.5),
      tf.keras.layers.Conv2D(128, 3, activation='relu'),
      tf.keras.layers.MaxPool2D(2),
      tf.keras.layers.Conv2D(256, 3, activation='relu'),
      tf.keras.layers.MaxPool2D(2),
      tf.keras.layers.Dropout(0.7),

      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dropout(0.7),

      # Softmax probabilities over 3 classes; matches the sparse CE loss
      # below, which is used with its default from_logits=False.
      tf.keras.layers.Dense(3, activation='softmax'),
  ])

  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      optimizer=tf.keras.optimizers.Adam(),
      metrics=['accuracy'])

  # Save weights only when validation accuracy improves on the best so far.
  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
      filepath=checkpoint_path,
      save_weights_only=True,
      monitor='val_accuracy',
      mode='max',
      save_best_only=True)

  model.fit(
      train_set,
      epochs=epochs,
      validation_data=validation_set,
      callbacks=[checkpoint_callback])

  return model, test_set

In [24]:
def beans_classifier():
  """Build the beans data pipelines, then train the classifier on them.

  Returns:
    Whatever `create_model` returns for the (train, validation, test)
    datasets produced by `preprocess`.
  """
  splits = preprocess()
  return create_model(*splits)

In [25]:
if __name__ == '__main__':
  # Notebook-level settings, read as globals by preprocess / create_model.
  BATCH_SIZE = 64
  checkpoint_filepath = '/tmp/checkpoint'
  # Stateful generator that `f` draws per-example augmentation seeds from.
  rng = tf.random.Generator.from_seed(123, alg='philox')

  model, ds_test = beans_classifier()

  # Evaluate the checkpointed best-val_accuracy weights, not the last epoch's.
  model.load_weights(checkpoint_filepath)
  test_metrics = model.evaluate(ds_test)
  loss, acc = test_metrics

  print("Loss on the test set is :", loss)
  print("Accuracy on the test set is :", acc)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78