# Train the model

In [1]:
import os
import datetime
import tensorflow as tf
import tensorflow_datasets as tfds

## Prepare training data

In [2]:
tfds.disable_progress_bar()

### Load images from the "cats vs dogs" dataset

In [3]:
(raw_train, raw_validation, raw_test), metadata = tfds.load(
    'cats_vs_dogs',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)

### Fit images to network inputs

In [4]:
IMG_SIZE = 160 # All images will be resized to 160x160
def format_example(image, label):
    image = tf.cast(image, tf.float32)
    image = (image/127.5) - 1
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    return image, label

train = raw_train.map(format_example)
validation = raw_validation.map(format_example)
test = raw_test.map(format_example)

### Shuffle training data

In [5]:
BATCH_SIZE = 32
SHUFFLE_BUFFER_SIZE = 1000

train_batches = train.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
validation_batches = validation.batch(BATCH_SIZE)
test_batches = test.batch(BATCH_SIZE)

## Load the untrained model

In [6]:
model = tf.keras.models.load_model('saved_model/mobilenetv2-untrained')






In [7]:
is_trainable = False
for layer in model.layers:
    if layer.name == 'top_start':
        is_trainable = True
    layer.trainable = is_trainable

## Compile the model

In [8]:
base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])

## Train the top of the model
Progress can be visualized using the command:
```
tensorboard --logdir logs/scalars
```

In [9]:
initial_epochs = 10
logdir = "logs/scalars/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
history = model.fit(train_batches,
                    epochs=initial_epochs,
                    validation_data=validation_batches,
                    callbacks=[tensorboard_callback])

Epoch 1/10
Instructions for updating:
use `tf.profiler.experimental.stop` instead.
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


Instructions for updating:
use `tf.profiler.experimental.stop` instead.


## Fine tune the last layers of the network

In [10]:
n_fine_tune_layers = 30
for layer in model.layers[:-n_fine_tune_layers]:
    layer.trainable =  False
for layer in model.layers[-n_fine_tune_layers:]:
    layer.trainable = True

model.compile(optimizer = tf.keras.optimizers.RMSprop(lr=base_learning_rate/10),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])

fine_tune_epochs = 10
total_epochs =  initial_epochs + fine_tune_epochs
history_fine = model.fit(train_batches,
                         epochs=total_epochs,
                         initial_epoch=history.epoch[-1],
                         validation_data=validation_batches,
                         callbacks=[tensorboard_callback])

Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20




## Save the model

In [11]:
os.makedirs('saved_model', exist_ok=True)
model.save('saved_model/mobilenetv2')

Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: saved_model/mobilenetv2/assets


Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: saved_model/mobilenetv2/assets
