<a href="https://colab.research.google.com/github/ozgekokyay/tensorflow-deep-learning/blob/main/TFFlowers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow_datasets as tfds
import numpy as np
import tensorflow as tf

# Echo the installed TensorFlow version so the run environment is recorded in the output.
tf.__version__

In [None]:
# tf_flowers ships a single 'train' split; carve it into three disjoint slices:
# the first 90% for training, the next 5% for validation, the last 5% for testing.
splits = ['train[:90%]', 'train[90%:95%]', 'train[95%:]']

In [None]:
# Load tf_flowers (cached under ./flowers) as one dataset per entry of `splits`.
# as_supervised=True yields (image, label) tuples; with_info=True additionally
# returns the DatasetInfo object, bound to `metadata`.
(train_ds, val_ds, test_ds), metadata = tfds.load(
    'tf_flowers',
    split=splits,
    as_supervised=True,
    with_info=True,
    data_dir='./flowers',
)

In [None]:
# Display the DatasetInfo returned by tfds.load (features incl. the label's class
# count, and the per-split example counts used below).
metadata

In [None]:
# Exact number of examples in each slice. Ask TFDS for the size of each
# sub-split directly rather than re-deriving it from hard-coded percentages:
# int(total * pct / 100) truncates, while TFDS rounds slice boundaries, so the
# manual numbers can disagree with what tfds.load actually returned (they need
# not even sum to the total). This also stays correct if `splits` is edited.
num_train, num_val, num_test = (metadata.splits[split].num_examples for split in splits)
num_train, num_val, num_test

In [None]:
def resize_and_normalize(image, label):
  """Turn one (image, label) example into model-ready form.

  The image is cast to float32, resized to 128x128 and divided by 255
  (presumably mapping uint8 pixels into [0, 1] -- input dtype not shown here);
  the label is passed through untouched.
  """
  resized = tf.image.resize(tf.cast(image, tf.float32), (128, 128))
  return resized / 255.0, label
  

In [None]:
def augment(image, label, seed=None):
  """Randomly flip, brighten and re-contrast one training image.

  The previous version passed the same hard-coded seed (1, 2) to every
  stateless op on every call. Stateless random ops return the same result for
  the same seed, so every image received the identical flip decision,
  brightness delta and contrast factor. That is a fixed transform, not
  augmentation. Now each call draws a fresh seed (unless one is supplied) and
  derives an independent sub-seed for each op.

  Args:
    image: float image tensor, expected to be scaled to [0, 1]
      (e.g. the output of resize_and_normalize).
    label: class label; returned unchanged.
    seed: optional shape-[2] int32/int64 seed. Pass one to make this element's
      augmentation reproducible. When None, a new seed is drawn from
      TensorFlow's global RNG on every call.

  Returns:
    Tuple (augmented image clipped to [0, 1], label).
  """
  if seed is None:
    seed = tf.random.uniform([2], maxval=tf.int32.max, dtype=tf.int32)
  # One independent sub-seed per op, so the three random draws are uncorrelated.
  flip_seed, brightness_seed, contrast_seed = tf.unstack(
      tf.random.experimental.stateless_split(seed, num=3), num=3)

  image = tf.image.stateless_random_flip_left_right(image, seed=flip_seed)
  image = tf.image.stateless_random_brightness(image, 0.2, seed=brightness_seed)
  image = tf.image.stateless_random_contrast(image, 0.8, 1.0, seed=contrast_seed)
  # Brightness/contrast jitter can push pixels outside [0, 1]; clip so training
  # inputs stay in the same range as the un-augmented validation/test images.
  image = tf.clip_by_value(image, 0.0, 1.0)
  return image, label


In [None]:
# Resize/normalize every split. num_parallel_calls lets tf.data run the
# per-image map on several threads instead of one; by default tf.data still
# preserves element order.
train_ds = train_ds.map(resize_and_normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
val_ds = val_ds.map(resize_and_normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test_ds = test_ds.map(resize_and_normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE)


In [None]:
# Random augmentation is applied to the training split only, in parallel.
# Note: this rebinds train_ds, so re-running this cell alone stacks a second
# augmentation pass -- re-run from the loading cell instead.
train_ds = train_ds.map(augment, num_parallel_calls=tf.data.experimental.AUTOTUNE)


In [None]:
# Shuffle (training only), batch into 32s, and prefetch every split so input
# preparation overlaps with model execution during fit, validation and predict.
# Previously only the training pipeline was prefetched.
train_ds = train_ds.shuffle(1024).batch(32).prefetch(tf.data.experimental.AUTOTUNE)
val_ds = val_ds.batch(32).prefetch(tf.data.experimental.AUTOTUNE)
test_ds = test_ds.batch(32).prefetch(tf.data.experimental.AUTOTUNE)

In [None]:

def create_model(input_shape=(128, 128, 3), num_classes=None):
    """Build the flower-classification CNN (uncompiled; outputs logits).

    Three Conv2D(3x3, ReLU) -> MaxPooling2D(2x2) stages with 16/32/64 filters,
    then Flatten -> Dropout(0.2) -> Dense(64, ReLU) -> Dropout(0.5) ->
    Dense(num_classes). The last layer has no activation, so pair the model
    with a loss built with from_logits=True.

    Args:
        input_shape: shape of a single input image (height, width, channels).
            The default matches the 128x128 RGB images from the preprocessing.
        num_classes: number of output units. If None, it is read from the
            notebook-level ``metadata.features['label'].num_classes``.

    Returns:
        A ``tf.keras.Model`` mapping images to per-class logits.
    """
    if num_classes is None:
        num_classes = metadata.features['label'].num_classes

    img_inputs = tf.keras.Input(shape=input_shape)
    conv_1 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu')(img_inputs)
    maxpool_1 = tf.keras.layers.MaxPooling2D((2, 2))(conv_1)
    conv_2 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu')(maxpool_1)
    maxpool_2 = tf.keras.layers.MaxPooling2D((2, 2))(conv_2)
    conv_3 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu')(maxpool_2)
    maxpool_3 = tf.keras.layers.MaxPooling2D((2, 2))(conv_3)
    # Flatten the *pooled* features. Previously Flatten consumed conv_3, so
    # maxpool_3 was built but never connected. For a 128x128 input that fed
    # 28*28*64 = 50,176 features into Dense(64) instead of 14*14*64 = 12,544,
    # quadrupling that layer's weight matrix.
    flatten = tf.keras.layers.Flatten()(maxpool_3)
    drop_1 = tf.keras.layers.Dropout(0.2)(flatten)
    dense_1 = tf.keras.layers.Dense(64, activation='relu')(drop_1)
    drop_2 = tf.keras.layers.Dropout(0.5)(dense_1)
    output = tf.keras.layers.Dense(num_classes)(drop_2)

    return tf.keras.Model(inputs=img_inputs, outputs=output)


In [None]:
# Build a fresh (not yet compiled) model and print its layer-by-layer summary.
model = create_model()
model.summary()

In [None]:
# Save a diagram of the network, annotated with tensor shapes, to a PNG file
# (tf.keras.utils.plot_model needs the pydot and graphviz packages installed).
tf.keras.utils.plot_model(model, 'flower_model_with_shape_info.png', show_shapes=True)

In [None]:
import datetime, os

# One TensorBoard log directory per run, named after the launch time.
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# exist_ok=True keeps the cell re-runnable: the timestamp has one-second
# resolution, so running the cell twice within the same second would otherwise
# raise FileExistsError.
os.makedirs(log_dir, exist_ok=True)

In [None]:
import math

# Batches per epoch / per validation pass for the .repeat()'d datasets given
# to model.fit. Round *up* so the final partial batch counts. With floor
# division each training epoch was one batch short of a full pass, and each
# validation pass skipped up to 31 examples. The 32 must match the batch size
# used when batching the datasets.
steps_per_epoch = math.ceil(num_train / 32)
validation_steps = math.ceil(num_val / 32)

In [None]:
# Adam with default settings. The model's last Dense layer has no activation,
# so the sparse (integer-label) cross-entropy is built with from_logits=True.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

In [None]:
# Create the checkpoint directory up front, before any callback writes into it.
os.makedirs('training_checkpoints/', exist_ok=True)

# TensorBoard logging into this run's log_dir, with weight histograms every epoch.
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

# Save the model every 5 epochs; epoch number and val_loss go in the filename.
# NOTE(review): `period` is deprecated in tf.keras (its replacement `save_freq`
# counts batches, and val_loss is not available for batch-level saves) and is
# not accepted by Keras 3 -- revisit when upgrading.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        'training_checkpoints/weights.{epoch:02d}-{val_loss:.2f}.hdf5', period=5)

# Stop once val_loss has not improved for 5 consecutive epochs, restoring the
# best epoch's weights when it stops. The previous patience of 20 could never
# trigger in the 20-epoch fit(): stopping needs `patience` epochs without
# improvement *after* the best epoch, and epoch 1 always counts as an improvement.
early_stopping_checkpoint = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=5, restore_best_weights=True)

    

In [None]:
# Load the TensorBoard notebook extension and embed a dashboard that reads every
# run logged under ./logs (the TensorBoard callback writes to logs/fit/<timestamp>).
%load_ext tensorboard
%tensorboard --logdir logs


In [None]:
# Train on infinitely repeated datasets: steps_per_epoch and validation_steps
# are what mark the end of each epoch and of each validation pass.
history = model.fit(
    train_ds.repeat(),
    validation_data=val_ds.repeat(),
    epochs=20,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    callbacks=[
        tensorboard_callback,
        model_checkpoint_callback,
        early_stopping_checkpoint,
    ],
)
    
    

In [None]:
# Run inference on the first test batch only. The model's final Dense layer has
# no activation, so these are unnormalized logits (apply tf.nn.softmax for
# probabilities).
preds = model.predict(test_ds.take(1))

In [None]:
# Inspect the logits: shape is (size of the first test batch, number of classes).
preds.shape, preds

In [None]:
# Predicted class id per example: position of the largest logit on the last axis.
preds_indices = preds.argmax(axis=-1)
preds_indices.shape, preds_indices