##### Copyright 2019 The TensorFlow Authors.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
#@title MIT License
#
# Copyright (c) 2017 François Chollet                                                                                                                    # IGNORE_COPYRIGHT: cleared by OSS licensing
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

# Transfer learning and fine-tuning with TensorFlow Hub

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/tutorials/images/transfer_learning"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/images/transfer_learning.ipynb?force_kitty_mode=1&force_corgi_mode=1"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/docs/blob/master/site/en/tutorials/images/transfer_learning.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/images/transfer_learning.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
  <td>
    <a href="https://tfhub.dev/google/collections/image/1"><img src="https://www.tensorflow.org/images/hub_logo_32px.png" />See TF Hub models</a>
  </td>
</table>

This tutorial demonstrates how to use two transfer learning strategies—_feature extraction_ and _fine-tuning_—for applying pre-trained representations to a downstream task, such as image classification. You will classify images of cats and dogs by using pre-trained models from [TensorFlow Hub](https://tfhub.dev).

Transfer learning consists of taking features learned on one task and leveraging them on a new, similar task. A pre-trained model is a saved network that was previously trained on a large dataset, typically for a particular task such as image classification. You can use the pre-trained model as is or customize it for a given task.

Customizing a pre-trained model with transfer learning can be done in two ways:

1. _Feature extraction_: Use the representations learned by a previous pre-trained model to extract meaningful features from new samples. That is, you repurpose the feature maps previously learned from the dataset the pre-trained model trained on. During feature extraction, you train only the last (few) layer(s).

  For image classification, you add a new classifier, such as a new fully-connected layer (`tf.keras.layers.Dense`), on top of the base (pre-trained) model. You freeze the base model's layers and then train only the final output layer. The base (pre-trained) model already contains features that can be useful for your particular task. (The final part of the original base model is specific to the original task, and consequently specific to the set of classes on which the model was trained.)

1. _Fine-tuning_: Use the weights/parameters of the base pre-trained model to initialize training on a new task, unfreeze a few of the top layers, and jointly train both the newly added layer(s) (similar to feature extraction) and the unfrozen layers of the base model. This method allows you to fine-tune the higher-order feature representations in the base model to make them more relevant for the specific task. Compared with training a model from scratch, fine-tuning typically requires fewer training resources.

The notebook's workflow for classification with pre-trained models is as follows:

1. Load the data and build an input pipeline.
1. Load a pre-trained base model.
1. Apply a pre-trained model on sample data without any training (weight updates) to test the model.
1. Perform classification using transfer learning: 
  - Perform feature extraction by training only the classification layer.
  - Perform fine-tuning by updating the weights of parts of the pre-trained network.

## Setup

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import tempfile
import datetime
import PIL.Image as Image
import tensorflow as tf
import tensorflow_hub as hub
%load_ext tensorboard

print("TF version:", tf.__version__)
print("Hub version:", hub.__version__)
print("GPU is", "available" if tf.config.list_physical_devices('GPU') else "not available")

## Data preprocessing

### Download the dataset

Download and extract a ZIP file containing several thousand images of cats and dogs, then create a `tf.data.Dataset` for training and validation using the `tf.keras.utils.image_dataset_from_directory` utility. You can learn more in the [Load and preprocess images](https://www.tensorflow.org/tutorials/load_data/images) tutorial.

In [None]:
_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

BATCH_SIZE = 32
IMG_SIZE = (224, 224)

# Create a training set.
train_dataset = tf.keras.utils.image_dataset_from_directory(train_dir,
                                                            shuffle=True,
                                                            batch_size=BATCH_SIZE,
                                                            image_size=IMG_SIZE)

In [None]:
# Create a validation set.
validation_dataset = tf.keras.utils.image_dataset_from_directory(validation_dir,
                                                                 shuffle=True,
                                                                 batch_size=BATCH_SIZE,
                                                                 image_size=IMG_SIZE)

Inspect the first nine images and labels from the training set:

In [None]:
class_names = train_dataset.class_names

plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")

Create a test set using the validation set, since the original dataset doesn't contain one.

First you determine how many batches of data are available in the validation set using `tf.data.experimental.cardinality`, then move 20% of them to a test set:

In [None]:
val_batches = tf.data.experimental.cardinality(validation_dataset)

# Create a test set.
test_dataset = validation_dataset.take(val_batches // 5)

# Revise the existing validation set.
validation_dataset = validation_dataset.skip(val_batches // 5)

In [None]:
print('Number of validation batches: %d' % tf.data.experimental.cardinality(validation_dataset))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_dataset))
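Behind those numbers is plain integer arithmetic on batch counts. The following sketch assumes the validation folder holds 1,000 images (500 cats and 500 dogs in `cats_and_dogs_filtered`) and `BATCH_SIZE = 32`; `split_batches` is a hypothetical helper, not a TensorFlow API. Because `image_dataset_from_directory` keeps the final partial batch, the batch count is a ceiling division:

```python
import math

def split_batches(num_images, batch_size, test_divisor=5):
    """Hypothetical helper mirroring take(val_batches // 5) / skip(val_batches // 5)."""
    # The last, partial batch is kept, so round the batch count up.
    val_batches = math.ceil(num_images / batch_size)
    # take() moves this many whole batches into the test set...
    test_batches = val_batches // test_divisor
    # ...and skip() leaves the rest in the validation set.
    remaining_val_batches = val_batches - test_batches
    return val_batches, test_batches, remaining_val_batches

print(split_batches(1000, 32))  # (32, 6, 26)
```

Under those assumptions, the result should line up with the 6 test batches and 26 validation batches printed by the cell above.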

### Rescale pixel values

The original pixel values in the dataset images are in the `[0, 255]` range. The pre-trained TensorFlow Hub image models used here expect float inputs in the `[0, 1]` range. Therefore, you need to normalize the data by using the `tf.keras.layers.Rescaling` preprocessing layer.

Note: You could also include the `tf.keras.layers.Rescaling` layer inside the model. Refer to the [Working with preprocessing layers](https://www.tensorflow.org/guide/keras/preprocessing_layers) guide for a discussion of the tradeoffs.

In [None]:
normalization_layer = tf.keras.layers.Rescaling(1./255)

train_dataset = train_dataset.map(lambda x, y: (normalization_layer(x), y)) # Where x—images, y—labels.
validation_dataset = validation_dataset.map(lambda x, y: (normalization_layer(x), y)) # Where x—images, y—labels.
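`Rescaling(1./255)` multiplies every pixel by 1/255 and adds its default `offset` of 0, which maps `[0, 255]` onto `[0, 1]`. A NumPy sketch of the same arithmetic on a few made-up pixel values, for illustration only:

```python
import numpy as np

# Hypothetical uint8 pixels spanning the full [0, 255] range.
pixels = np.array([[0, 51, 128, 255]], dtype=np.uint8)

# Same transform as tf.keras.layers.Rescaling(1./255): x * (1/255) + 0.
scaled = pixels.astype(np.float32) * (1.0 / 255)
print(scaled)
```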

### Configure the dataset for performance with `tf.data`

Finish the input pipeline by using buffered prefetching (`Dataset.prefetch`) to load images from disk without having I/O become blocking. You can learn more about this and other methods in the [Better performance with the `tf.data` API](https://www.tensorflow.org/guide/data_performance) guide.

In [None]:
AUTOTUNE = tf.data.AUTOTUNE

train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)
test_dataset = test_dataset.prefetch(buffer_size=AUTOTUNE)

In [None]:
for image_batch, labels_batch in train_dataset:
  print(image_batch.shape)
  print(labels_batch.shape)
  break

### Data augmentation

When you don't have a large image dataset, it's a good practice to artificially introduce sample diversity by applying random, yet realistic, transformations to the training images, such as rotation and horizontal flipping. This helps expose the model to different aspects of the training data and reduce [overfitting](https://www.tensorflow.org/tutorials/keras/overfit_and_underfit). For more information, check out the [Data augmentation](https://www.tensorflow.org/tutorials/images/data_augmentation) tutorial.

In [None]:
data_augmentation = tf.keras.Sequential([
  tf.keras.layers.RandomFlip('horizontal'),
  tf.keras.layers.RandomRotation(0.2),
])

Note: These layers are active only during training when you call Keras `Model.fit`. They are inactive when the model is used in inference mode in `Model.evaluate` or `Model.predict`.
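The training/inference switch can be pictured with a small NumPy sketch of a horizontal flip. `random_flip_horizontal` is a hypothetical stand-in, not the Keras implementation: with `training=True`, each image in a `(batch, height, width, channels)` array is mirrored along the width axis with probability 0.5; otherwise the batch passes through unchanged.

```python
import numpy as np

def random_flip_horizontal(images, training, rng):
    """Flip each image left-right with probability 0.5, only in training mode."""
    if not training:
        # Inference mode: the layer acts as an identity op.
        return images
    flip = rng.random(images.shape[0]) < 0.5   # one coin toss per image
    flipped = images[:, :, ::-1, :]            # reverse the width axis (NHWC)
    return np.where(flip[:, None, None, None], flipped, images)

rng = np.random.default_rng(0)
batch = np.arange(2 * 2 * 3 * 1, dtype=np.float32).reshape(2, 2, 3, 1)

print(random_flip_horizontal(batch, training=False, rng=rng) is batch)  # True
print(random_flip_horizontal(batch, training=True, rng=rng)[..., 0])
```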

Repeatedly apply these layers to the same image and view the results:

In [None]:
image_batch, label_batch = next(iter(train_dataset))
first_image = image_batch[0]
plt.figure(figsize=(10, 10))
for i in range(9):
  ax = plt.subplot(3, 3, i + 1)
  # Add a batch dimension and augment the same image on every iteration.
  augmented_image = data_augmentation(tf.expand_dims(first_image, 0))
  plt.imshow(augmented_image[0])
  plt.title(class_names[label_batch[0]])
  plt.axis("off")

## Classification with a pre-trained model

### Create the base model

Here, you will use a pre-trained <a href="https://arxiv.org/abs/1801.04381" class="external">MobileNet V2</a> benchmark classification model [from TensorFlow Hub](https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/2), and wrap it as a Keras layer with [`hub.KerasLayer`](https://www.tensorflow.org/hub/api_docs/python/hub/KerasLayer). Any <a href="https://tfhub.dev/s?q=tf2&module-type=image-classification/" class="external">compatible image classifier model</a> from TensorFlow Hub will work here, including the examples provided in the drop-down below. These can be used to easily perform transfer learning.

MobileNet V2 is pre-trained on the <a href="https://en.wikipedia.org/wiki/ImageNet" class="external">ImageNet</a> dataset, a large dataset consisting of 1.4 million images and 1,000 classes. This base of knowledge will help classify cats and dogs from your specific dataset.

In [None]:
mobilenet_v2 = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4"
inception_v3 = "https://tfhub.dev/google/imagenet/inception_v3/classification/5"

pretrained_model = mobilenet_v2 #@param ["mobilenet_v2", "inception_v3"] {type:"raw"}

In [None]:
# Create the base model from the pre-trained MobileNet V2 model.

IMAGE_SHAPE = (224, 224)

base_model = tf.keras.Sequential([
    hub.KerasLayer(pretrained_model, input_shape=IMAGE_SHAPE+(3,))
])

### Apply the pre-trained model on a test image

To test the model, download a single image:

In [None]:
grace_hopper = tf.keras.utils.get_file('image.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg')
grace_hopper = Image.open(grace_hopper).resize(IMAGE_SHAPE)
grace_hopper

Add a batch dimension (with `np.newaxis`) and pass the image to the model:

In [None]:
grace_hopper = np.array(grace_hopper)/255.0
print(grace_hopper.shape)

result = base_model.predict(grace_hopper[np.newaxis, ...])
print(result.shape)

The result is a 1,001-element vector of logits: unnormalized scores for each ImageNet class, plus a background class at index 0.

The top class ID can be found with `tf.math.argmax`:

In [None]:
predicted_class = tf.math.argmax(result[0], axis=-1)
predicted_class
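Because the outputs are logits rather than probabilities, a softmax is needed if you want a probability distribution over classes. Softmax preserves order, so it never changes which class is the `argmax`. A NumPy sketch with made-up logits for five classes (the real model has 1,001):

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)

# Hypothetical logits for one image.
logits = np.array([[1.2, -0.3, 4.0, 0.5, 2.1]])
probs = softmax(logits)

print(probs.sum(axis=-1))  # each row sums to 1
print(np.argmax(logits, axis=-1), np.argmax(probs, axis=-1))  # [2] [2]
```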

Take the `predicted_class` ID (such as `653`) and fetch the ImageNet dataset labels to decode the predictions:

In [None]:
labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
imagenet_labels = np.array(open(labels_path).read().splitlines())

plt.imshow(grace_hopper)
plt.axis('off')
predicted_class_name = imagenet_labels[predicted_class]
_ = plt.title("Prediction: " + predicted_class_name.title())
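The same lookup extends to the top few classes by sorting the logits. A minimal sketch; the five-entry label list and the logits are hypothetical stand-ins for `imagenet_labels` and one row of `result`:

```python
import numpy as np

# Hypothetical stand-ins for imagenet_labels and one row of model logits.
labels = np.array(["background", "tabby", "military uniform", "suit", "bow tie"])
logits = np.array([0.1, 1.5, 6.2, 3.3, 2.8])

def top_k(logits, labels, k=3):
    """Return the k highest-scoring labels, best first."""
    order = np.argsort(logits)[::-1][:k]  # indices sorted by descending logit
    return labels[order].tolist()

print(top_k(logits, labels))  # ['military uniform', 'suit', 'bow tie']
```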

### Apply the pre-trained model on the training set

You can also run the classifier on a batch of images from the training set, such as the `image_batch` retrieved earlier:

In [None]:
result_batch = base_model.predict(image_batch)

predicted_class_names = imagenet_labels[tf.math.argmax(result_batch, axis=-1)]
print(predicted_class_names)

# Check how the predictions line up with the images.
plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  plt.title(predicted_class_names[n])
  plt.axis('off')
_ = plt.suptitle("ImageNet predictions")

## Feature extraction

You can also build a _custom_ classifier using your own dataset, which has classes that aren't included in the original ImageNet dataset the pre-trained model was trained on.

To do that, you can:

1. Select a pre-trained SavedModel from TensorFlow Hub; and
2. Retrain the top (last) layer to recognize the classes from your custom dataset.

### Pick a headless classifier model

A headless classifier model is a model without the top classification layer.

Select a <a href="https://arxiv.org/abs/1801.04381" class="external">MobileNetV2</a> pre-trained model <a href="https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4" class="external">from TensorFlow Hub</a>. Any <a href="https://tfhub.dev/s?module-type=image-feature-vector&q=tf2" class="external">compatible image feature vector model</a> from TensorFlow Hub will work here, including the examples from the drop-down menu.


In [None]:
mobilenet_v2 = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4"
inception_v3 = "https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4"

feature_extractor_model = mobilenet_v2 #@param ["mobilenet_v2", "inception_v3"] {type:"raw"}

### Freeze the convolutional base

In this section, you will create the feature extractor by wrapping the pre-trained TensorFlow Hub model as a Keras layer with [`hub.KerasLayer`](https://www.tensorflow.org/hub/api_docs/python/hub/KerasLayer).

To freeze the convolutional base created in the previous step and use it as a feature extractor, pass the `trainable=False` argument. Freezing prevents the weights in a given layer from being updated during training, so training only modifies the new classifier you add on top of the base. It is important to freeze the convolutional base before you compile and train the model. MobileNet V2 has many layers, so setting the entire model's `trainable` flag to `False` freezes all of them.
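The effect of freezing on training can be sketched without TensorFlow. In this NumPy toy (an illustration under simplifying assumptions, not how Keras or TensorFlow Hub implement it), a fixed random projection with a ReLU stands in for the frozen base, and gradient descent on a binary cross-entropy loss updates only the weights and bias of a single-logit linear head trained on synthetic data:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic inputs: 200 samples with 8 raw values each, and binary labels.
x = rng.normal(size=(200, 8))
y = (x[:, 0] + x[:, 1] > 0).astype(np.float64)

# Toy "frozen base": a fixed random projection plus ReLU, never updated.
base_w = rng.normal(size=(8, 16))
base_w_before = base_w.copy()

def extract_features(inputs):
    return np.maximum(inputs @ base_w, 0.0)

# Trainable head: a single-logit linear classifier (binary for brevity).
head_w = np.zeros(16)
head_b = 0.0
learning_rate = 0.1

# Because the base is frozen, its features can be computed once and reused.
features = extract_features(x)
for step in range(500):
    logits = features @ head_w + head_b
    probs = 1.0 / (1.0 + np.exp(-logits))
    grad = probs - y  # gradient of binary cross-entropy w.r.t. each logit
    head_w -= learning_rate * features.T @ grad / len(x)
    head_b -= learning_rate * grad.mean()

accuracy = np.mean(((features @ head_w + head_b) > 0) == (y == 1))
print("head accuracy:", accuracy)
print("base unchanged:", np.array_equal(base_w, base_w_before))  # True
```

Since the base never changes, its features could be precomputed once, which is part of why feature extraction is cheap compared with fine-tuning.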

In [None]:
new_base_model = hub.KerasLayer(
    feature_extractor_model,
    input_shape=(224, 224, 3),
    trainable=False)

With MobileNet V2, the feature extractor returns a 1280-long vector for each image (Inception V3 returns 2048 values instead); the image batch size remains 32 in this example:

In [None]:
feature_batch = new_base_model(image_batch)
print(feature_batch.shape)

### Important note about `BatchNormalization` layers

Many models contain `tf.keras.layers.BatchNormalization` layers. This layer is a special case and precautions should be taken in the context of fine-tuning, as shown later in this tutorial. 

When you set `trainable = False` in `hub.KerasLayer`, the `BatchNormalization` layer will run in inference mode, and will not update its mean and variance statistics. 

When you unfreeze a model that contains BatchNormalization layers in order to do fine-tuning, you should keep the BatchNormalization layers in inference mode by passing `training = False` when calling the base model. Otherwise, the updates applied to the non-trainable weights will destroy what the model has learned.
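The difference between the two modes is which statistics normalize the input. Here is a simplified NumPy sketch of a single batch normalization layer; `ToyBatchNorm` is hypothetical, omits the learned scale and offset, and uses the Keras default `momentum=0.99` and `epsilon=0.001`. In training mode it normalizes with the current batch's mean and variance and nudges its moving averages toward them. In inference mode it uses the stored moving averages and leaves them untouched.

```python
import numpy as np

class ToyBatchNorm:
    """Per-feature batch normalization without gamma/beta, for illustration."""

    def __init__(self, num_features, momentum=0.99, epsilon=1e-3):
        self.moving_mean = np.zeros(num_features)
        self.moving_var = np.ones(num_features)
        self.momentum = momentum
        self.epsilon = epsilon

    def __call__(self, x, training):
        if training:
            mean, var = x.mean(axis=0), x.var(axis=0)
            # Non-trainable state is updated in training mode only.
            self.moving_mean = self.momentum * self.moving_mean + (1 - self.momentum) * mean
            self.moving_var = self.momentum * self.moving_var + (1 - self.momentum) * var
        else:
            mean, var = self.moving_mean, self.moving_var
        return (x - mean) / np.sqrt(var + self.epsilon)

bn = ToyBatchNorm(num_features=3)
batch = np.random.default_rng(0).normal(loc=5.0, scale=2.0, size=(32, 3))

before = bn.moving_mean.copy()
bn(batch, training=False)
print("inference changed stats:", not np.array_equal(before, bn.moving_mean))  # False
bn(batch, training=True)
print("training changed stats:", not np.array_equal(before, bn.moving_mean))   # True
```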

You can learn more in the Keras [Transfer learning and fine-tuning](https://www.tensorflow.org/guide/keras/transfer_learning) guide.

### Add a classification head

Apply a `tf.keras.layers.Dense` layer to convert these features into one prediction per class for each image. You don't need an activation function here because these predictions will be treated as logits, or raw prediction values. The class with the larger logit is the predicted class.

In [None]:
num_classes = len(class_names)

model = tf.keras.Sequential([
  tf.keras.layers.RandomFlip('horizontal'),
  tf.keras.layers.RandomRotation(0.2),
  new_base_model,
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(num_classes)
])

In [None]:
predictions = model(image_batch)
predictions.shape
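For two classes, a two-logit softmax head and a single-logit sigmoid head carry the same information: the softmax probability of class 1 equals the sigmoid of the difference `z1 - z0`, so class 1 wins exactly when that difference is positive. A NumPy check with made-up logits:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

# Hypothetical logits for three images; columns are [class 0, class 1].
logits = np.array([[2.0, -1.0],
                   [0.3, 0.8],
                   [-0.5, -0.5]])

p_class1 = softmax(logits)[:, 1]
p_from_diff = sigmoid(logits[:, 1] - logits[:, 0])
print(np.allclose(p_class1, p_from_diff))  # True
print(np.argmax(logits, axis=-1))          # [0 1 0] (ties go to the first index)
```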

### Compile the model

Use `Model.compile` to configure the training process and add a `tf.keras.callbacks.TensorBoard` callback to create and store logs. Since the labels are integer class indices and the model outputs one logit per class, use the `tf.keras.losses.SparseCategoricalCrossentropy` loss with `from_logits=True`.

In [None]:
BASE_LEARNING_RATE = 0.0001

model.compile(
  optimizer=tf.keras.optimizers.Adam(learning_rate=BASE_LEARNING_RATE),
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  metrics=['acc'])

log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1) # Enable histogram computation for every epoch.
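With `from_logits=True`, the loss applies the softmax internally. A NumPy sketch of the batch value, assuming the loss's default reduction, which averages over the batch: for integer labels it is the mean negative log-softmax of the true class, computed with log-sum-exp for numerical stability. The function name and the example values are illustrative.

```python
import numpy as np

def sparse_categorical_crossentropy_from_logits(labels, logits):
    """Mean of -log softmax(logits)[label] over the batch."""
    # log softmax(z)[k] = z[k] - logsumexp(z); subtract the max for stability.
    z_max = logits.max(axis=-1, keepdims=True)
    log_sum_exp = z_max[:, 0] + np.log(np.exp(logits - z_max).sum(axis=-1))
    true_logit = logits[np.arange(len(labels)), labels]
    return np.mean(log_sum_exp - true_logit)

labels = np.array([0, 1])
logits = np.array([[0.0, 0.0],    # undecided: per-sample loss is ln 2
                   [-2.0, 2.0]])  # confident and correct: small loss
print(sparse_categorical_crossentropy_from_logits(labels, logits))
```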

In [None]:
model.summary()

While MobileNet's parameters are frozen, there are _trainable_ parameters in the `Dense` layer. These are divided between two `tf.Variable` objects—the weights and biases.

In [None]:
len(model.trainable_variables)
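The number of trainable parameters follows from the `Dense` layer's input and output widths. Assuming the MobileNet V2 feature vector (1,280 values) and two classes, the arithmetic for the two variables is:

```python
feature_dim = 1280  # MobileNet V2 feature vector length (assumed base model)
n_classes = 2       # cats and dogs

kernel_params = feature_dim * n_classes  # the weight matrix variable
bias_params = n_classes                  # the bias vector variable
print(kernel_params + bias_params)  # 2562
```

`model.summary()` should list the same total for the `Dense` layer; the `Dropout` and random augmentation layers add no trainable parameters.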

### Train the model

Now use the `Model.fit` method to train the model.

To keep this example short, you'll train for just 10 epochs. To visualize the training progress in TensorBoard later, pass the [TensorBoard callback](https://www.tensorflow.org/tensorboard/get_started#using_tensorboard_with_keras_modelfit) you created, which creates and stores the logs.

In [None]:
NUM_EPOCHS = 10

history = model.fit(train_dataset,
                    validation_data=validation_dataset,
                    epochs=NUM_EPOCHS,
                    callbacks=[tensorboard_callback])

### Learning curves

Start TensorBoard to view the learning curves of the training and validation accuracy/loss when using the MobileNetV2 base model as a fixed feature extractor:

In [None]:
#docs_infra: no_execute
%tensorboard --logdir logs/fit

<!-- <img class="tfo-display-only-on-site" src="https://github.com/tensorflow/docs/blob/master/site/en/tutorials/images/images/tensorboard_transfer_learning_feature_extraction.png?raw=1"/> -->

Note: If you are wondering why the validation metrics may be better than the training metrics, the main factor is that layers like `tf.keras.layers.BatchNormalization` and `tf.keras.layers.Dropout` behave differently during training: they are turned off (or, for `BatchNormalization`, switched to inference mode) when calculating the validation loss. To a lesser extent, training metrics also report the average over an epoch, while validation metrics are evaluated after the epoch, so validation metrics observe a model that has trained slightly longer.
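Dropout is one concrete source of that gap. A NumPy sketch of inverted dropout, the scheme `tf.keras.layers.Dropout` follows, at the rate of 0.2 used in the model above: during training it zeroes about 20% of activations and scales the survivors by 1 / (1 - 0.2) so the expected value is unchanged; at inference it is the identity. The `dropout` function is a hypothetical illustration.

```python
import numpy as np

def dropout(x, rate, training, rng):
    """Inverted dropout: identity at inference, masked and rescaled in training."""
    if not training:
        return x
    keep = rng.random(x.shape) >= rate  # keep each unit with probability 1 - rate
    return np.where(keep, x / (1.0 - rate), 0.0)

rng = np.random.default_rng(0)
activations = np.ones(100_000)

train_out = dropout(activations, rate=0.2, training=True, rng=rng)
infer_out = dropout(activations, rate=0.2, training=False, rng=rng)

print("fraction zeroed:", np.mean(train_out == 0))  # close to 0.2
print("train mean:", train_out.mean())              # close to 1.0
print("inference unchanged:", np.array_equal(infer_out, activations))  # True
```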

To save the trained model in a SavedModel format for deployment to TensorFlow Serving or TensorFlow Lite, you can use `tf.saved_model.save`:

In [None]:
tmpdir = tempfile.mkdtemp()
model_save_path = os.path.join(tmpdir, "temp_saved_model")
tf.saved_model.save(model, model_save_path)

To load the SavedModel back into Python, use `tf.saved_model.load` as follows:

In [None]:
loaded = tf.saved_model.load(model_save_path)
print(list(loaded.signatures.keys()))

[Learn more](https://www.tensorflow.org/guide/saved_model) about the SavedModel format.

## Fine-tuning

In the feature extraction experiment, you were only training a few layers on top of a MobileNetV2 base model. The weights of the pre-trained network were _not_ updated during training.

One way to increase performance even further is to train—or _fine-tune_—the weights of the top layers of the pre-trained model alongside the training of the classifier you added. The training process will force the weights to be tuned from generic feature maps to features associated specifically with the dataset.

Note: This should only be attempted after you have trained the top-level classifier with the pre-trained model set to non-trainable. If you add a randomly initialized classifier on top of a pre-trained model and attempt to train all layers jointly, the magnitude of the gradient updates will be too large (due to the random weights from the classifier) and your pre-trained model will forget what it has learned.

Also, you should try to fine-tune a small number of top layers rather than the whole MobileNet model. In most convolutional networks, the higher up a layer is, the more specialized it is. The first few layers learn very simple and generic features that generalize to almost all types of images. As you go higher up, the features are increasingly more specific to the dataset on which the model was trained. The goal of fine-tuning is to adapt these specialized features to work with the new dataset, rather than overwrite the generic learning.

### Unfreeze the base model


Previously, when creating `new_base_model` with TensorFlow Hub's `hub.KerasLayer`, you set its boolean `trainable` attribute to `False`.

Now, you will _unfreeze_ the `KerasLayer` in the model, so that its parameters become trainable (`trainable=True`) and its state can be updated during training. Because `hub.KerasLayer` wraps the entire pre-trained model as a single Keras layer, this unfreezes all of its trainable weights at once rather than only the top layers.

In [None]:
# Unfreeze the pre-trained feature extractor (`new_base_model`).
new_base_model.trainable = True

### Recompile the model

As you are training a much larger model and want to readapt the pre-trained weights, it is important to use a lower learning rate at this stage. Otherwise, your model could overfit very quickly.

In [None]:
BASE_LEARNING_RATE = 0.0001

model.compile(
  optimizer=tf.keras.optimizers.Adam(learning_rate=BASE_LEARNING_RATE/10),
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  metrics=['acc'])

In [None]:
len(model.trainable_variables)

In [None]:
model.summary()

### Continue training the model

If you trained to convergence earlier, this step should improve your accuracy by a few percentage points:

In [None]:
FINE_TUNE_EPOCHS = 10
TOTAL_EPOCHS = NUM_EPOCHS + FINE_TUNE_EPOCHS

# Resume at the epoch index right after the last feature-extraction epoch.
history_fine = model.fit(train_dataset,
                         epochs=TOTAL_EPOCHS,
                         initial_epoch=len(history.epoch),
                         validation_data=validation_dataset,
                         callbacks=[tensorboard_callback])
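`Model.fit` treats `epochs` as the index to stop before and `initial_epoch` as the index to start from, so the epoch indices it runs are `range(initial_epoch, epochs)`. A plain-Python sketch of that count, assuming the first `fit` call completed all 10 epochs (so `history.epoch` is `[0, ..., 9]`); `epochs_run` is a hypothetical helper:

```python
NUM_EPOCHS = 10
FINE_TUNE_EPOCHS = 10
TOTAL_EPOCHS = NUM_EPOCHS + FINE_TUNE_EPOCHS

# What history.epoch holds after a completed first fit (assumption).
history_epoch = list(range(NUM_EPOCHS))

def epochs_run(initial_epoch, epochs):
    """Epoch indices that Model.fit iterates over."""
    return list(range(initial_epoch, epochs))

print(len(epochs_run(history_epoch[-1], TOTAL_EPOCHS)))  # 11: index 9 is reused
print(len(epochs_run(NUM_EPOCHS, TOTAL_EPOCHS)))         # 10
```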

Review the learning curves of the training and validation accuracy/loss when fine-tuning the MobileNetV2 base model and training the classifier on top of it.

- If the validation loss is much higher than the training loss, the model may be overfitting.
- You may also get some overfitting because the new training set is relatively small and similar to the original MobileNetV2 dataset.


In [None]:
%tensorboard --logdir logs/fit

<!-- <img class="tfo-display-only-on-site" src="https://github.com/tensorflow/docs/blob/master/site/en/tutorials/images/images/tensorboard_transfer_learning_fine_tuning.png?raw=1"/> -->

After fine-tuning, the model nearly reaches 98% accuracy on the validation set.

### Evaluation and prediction

Verify the performance of the model on new data using the test set:

In [None]:
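# The test set was not rescaled earlier, so apply the same Rescaling layer before evaluating.
loss, accuracy = model.evaluate(test_dataset.map(lambda x, y: (normalization_layer(x), y)))
print('Test accuracy:', accuracy)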

And now you are all set to use this model to predict if your image contains a cat or dog:

In [None]:
# Retrieve a batch of images from the test set.
image_batch, label_batch = next(test_dataset.as_numpy_iterator())

# The test set was not rescaled earlier, so rescale the batch before predicting.
logits = model.predict_on_batch(normalization_layer(image_batch))

# The model outputs one logit per class, so the predicted class is the argmax.
predictions = tf.math.argmax(logits, axis=-1).numpy()

print('Predictions:\n', predictions)
print('Labels:\n', label_batch)

plt.figure(figsize=(10, 10))
for i in range(9):
  ax = plt.subplot(3, 3, i + 1)
  plt.imshow(image_batch[i].astype("uint8"))
  plt.title(class_names[predictions[i]])
  plt.axis("off")
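The `acc` metric used with this loss compares the argmax of each row of logits with the integer label. A NumPy sketch with hypothetical logits and labels for a batch of four images, using the alphabetical class order that `image_dataset_from_directory` infers from the folder names:

```python
import numpy as np

names = ["cats", "dogs"]  # alphabetical order of the class subdirectories

# Hypothetical logits and integer labels for four test images.
logits = np.array([[ 1.7, -0.4],
                   [-0.9,  2.2],
                   [ 0.1,  0.6],
                   [ 0.8, -1.3]])
labels = np.array([0, 1, 0, 0])

predicted = np.argmax(logits, axis=-1)   # [0 1 1 0]
accuracy = np.mean(predicted == labels)  # 3 of 4 correct
print([names[i] for i in predicted])     # ['cats', 'dogs', 'dogs', 'cats']
print(accuracy)                          # 0.75
```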

## Summary

* **Using a pre-trained model for feature extraction**: When working with a small dataset, it is a common practice to take advantage of features learned by a model trained on a larger dataset in the same domain. This is done by instantiating the pre-trained model and adding a fully-connected classifier on top. The pre-trained model is frozen and only the weights of the classifier get updated during training. In this scenario, the convolutional base extracts all the features associated with each image, and you train only the classifier that determines the image class given the set of extracted features.

* **Fine-tuning a pre-trained model**: To further improve performance, you can repurpose the top-level layers of the pre-trained models to the new dataset via fine-tuning. In this case, you tune your weights such that your model learns high-level features specific to the dataset. This technique is usually recommended when the training dataset is large and very similar to the original dataset that the pre-trained model was originally trained on.

Learn more about [transfer learning in the Keras guide](https://www.tensorflow.org/guide/keras/transfer_learning).
