##### Copyright 2022 The TensorFlow Authors.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# On-device Text-to-Image Search with TensorFlow Lite Searcher Library

In this colab, we showcase an end to end example of how to train an image-text dual encoder model and how to perform retrieval with TFLite Searcher Library. We are going to use the [COCO 2014](https://cocodataset.org/#home) dataset, and in the end you'll be able to retrieve images using a text description.

First, we need to encode the images into high-dimensional vectors. Then we index them with [Model Maker Searcher API](https://www.tensorflow.org/lite/api_docs/python/tflite_model_maker/searcher/). During inference, a TFLite text embedder encodes the text query into another high-dimensional vector in the same embedding space, and invokes the [on-device ScaNN searcher](https://github.com/tensorflow/tflite-support/tree/master/tensorflow_lite_support/scann_ondevice) to retrieve similar images.


You can download the pre-trained searcher model packed with ScaNN index from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/searcher/text_to_image_blogpost/searcher_model.tflite) and skip to [inference](#scrollTo=EeZwqEnxW5Xl). Be sure to name it `searcher_model.tflite` and upload it to colab under the current working directory.

In [1]:
!pip install -q -U tensorflow tensorflow-hub tensorflow-addons
!pip install -q -U tflite-support
!pip install -q -U tflite-model-maker
!pip install -q -U tensorflow-text==2.10.0b2
!sudo apt-get -qq install libportaudio2  # Needed by tflite-support

[31mERROR: Could not find a version that satisfies the requirement tensorflow-addons (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for tensorflow-addons[0m[31m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m390.3/390.3 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for tflite-support (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.7/10.7 MB[0m [31m94.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mGetting requirements to build wheel[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m See above for output.
  
  [1;35mnote[0m: This error originates from a subprocess, and is likely not a problem with pip.
  Getting requirements to build wheel ... [?25l[?25herro

Note you might need to restart the runtime after installation.

In [4]:
# 1. Install specific compatible versions
# We remove tensorflow-addons entirely.
!pip install -q -U "tensorflow-text==2.15.*" "tensorflow==2.15.*"

import json
import math
import os
import pickle
import random
import shutil
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
import tensorflow.compat.v1 as tf1
from tensorflow.keras import layers
# import tensorflow_addons as tfa  <-- DELETED (No longer needed)
import tensorflow_hub as hub
import tensorflow_text as text
from tensorflow_text.python.ops import fast_sentencepiece_tokenizer as sentencepiece_tokenizer

# Suppressing tf.hub warnings
tf.get_logger().setLevel('ERROR')

print("Imports successful. TensorFlow version:", tf.__version__)

[31mERROR: Could not find a version that satisfies the requirement tensorflow-text==2.15.* (from versions: 2.18.1, 2.19.0rc0, 2.19.0)[0m[31m
[0m[31mERROR: No matching distribution found for tensorflow-text==2.15.*[0m[31m
[0mImports successful. TensorFlow version: 2.19.0


In [5]:
DATASET_DIR = 'datasets'
CAPTION_URL = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip'
TRAIN_IMAGE_URL = 'http://images.cocodataset.org/zips/train2014.zip'
VALID_IMAGE_URL = 'http://images.cocodataset.org/zips/val2014.zip'
TRAIN_IMAGE_DIR = os.path.join(DATASET_DIR, 'train2014')
VALID_IMAGE_DIR = os.path.join(DATASET_DIR, 'val2014')
TRAIN_IMAGE_PREFIX = 'COCO_train2014_'
VALID_IMAGE_PREFIX = 'COCO_val2014_'

In [6]:
IMAGE_SIZE = (384, 384)
EFFICIENT_NET_URL = 'https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_s/feature_vector/2'
UNIVERSAL_SENTENCE_ENCODER_URL = 'https://tfhub.dev/google/universal-sentence-encoder-lite/2'

BATCH_SIZE = 256
NUM_EPOCHS = 10
SEQ_LENGTH = 128
EMB_SIZE = 128

## Get COCO dataset

We are not using Tensorflow Dataset to get the [coco_captions](https://www.tensorflow.org/datasets/catalog/coco_captions) dataset due to disk space concerns. The following code will download and process the dataset.

In [14]:
#@title Functions for downloading and parsing annotations.

def parse_annotation_json(json_path):
  # Assuming the json file is already downloaded.
  with open(json_path, 'r') as f:
    json_obj = json.load(f)

  # Parsing out the following information from the annotation json: the COCO
  # image id and their corresponding flickr post id, as well as the captions.
  mapping = dict()
  for caption in json_obj['annotations']:
    image_id = caption['image_id']
    if image_id not in mapping:
      mapping[image_id] = [[]]
    mapping[image_id][0].append(caption['caption'])
  for image in json_obj['images']:
    # The flickr url here is the CDN url. We need to split it to get the post
    # id.
    flickr_url = image['flickr_url']
    url_parts = flickr_url.split('/')
    flickr_id = url_parts[-1].split('_')[0]
    mapping[image['id']].append(flickr_id)
  return list(mapping.items())


def get_train_valid_captions():
  # Define paths for the final output files
  train_pickle_path = os.path.join(DATASET_DIR, 'train_captions.pickle')
  valid_pickle_path = os.path.join(DATASET_DIR, 'valid_captions.pickle')

  if not os.path.exists(train_pickle_path) or not os.path.exists(valid_pickle_path):

    # 1. Clean up previous mess to ensure a fresh start
    tmp_dir = os.path.join(DATASET_DIR, 'tmp')
    if os.path.exists(tmp_dir):
        shutil.rmtree(tmp_dir)

    print("Downloading and extracting annotations...")

    # 2. Download and extract
    # We set extract=True, which usually creates the folder structure automatically
    annotation_zip = tf.keras.utils.get_file(
        'annotations.zip',
        cache_dir=os.path.abspath('.'),
        cache_subdir=tmp_dir,
        origin=CAPTION_URL,
        extract=True,
    )

    # 3. Dynamic Path Finding (The Fix)
    # We look for the 'annotations' folder inside tmp.
    # It might be in 'tmp/annotations' or 'tmp/annotations_extracted/annotations'
    base_path = None

    # Walk through the tmp directory to find where 'captions_train2014.json' actually is
    for root, dirs, files in os.walk(tmp_dir):
        if 'captions_train2014.json' in files:
            base_path = root
            break

    if base_path is None:
        # Fallback: Print what we have so we can debug if it fails again
        raise FileNotFoundError(f"Could not find caption JSON files anywhere in {tmp_dir}.")

    print(f"Found JSON files in: {base_path}")

    # 4. Parse the JSON files
    train_json_path = os.path.join(base_path, 'captions_train2014.json')
    valid_json_path = os.path.join(base_path, 'captions_val2014.json')

    train_img_cap = parse_annotation_json(train_json_path)
    valid_img_cap = parse_annotation_json(valid_json_path)

    # 5. Save to pickles
    with open(train_pickle_path, 'wb') as f:
      pickle.dump(train_img_cap, f)
    with open(valid_pickle_path, 'wb') as f:
      pickle.dump(valid_img_cap, f)

    # 6. Cleanup
    if os.path.exists(tmp_dir):
        shutil.rmtree(tmp_dir)

  else:
    print("Loading existing pickle files...")
    with open(train_pickle_path, 'rb') as f:
      train_img_cap = pickle.load(f)
    with open(valid_pickle_path, 'rb') as f:
      valid_img_cap = pickle.load(f)

  return train_img_cap, valid_img_cap

In [18]:
#@title Functions for downloading the images and create the dataset.

def get_sentencepiece_tokenizer_in_tf2():
  # The universal sentence encoder model from TFHub is in TF1 Module format. We
  # need to directly access the asset_paths to get the sentencepiece tokenizer
  # proto path.
  module = hub.load(UNIVERSAL_SENTENCE_ENCODER_URL)
  spm_path = module.asset_paths[0].asset_path.numpy()
  with tf.io.gfile.GFile(spm_path, mode='rb') as f:
    return sentencepiece_tokenizer.FastSentencepieceTokenizer(f.read())


def prepare_dataset(id_image_info_list,
                    image_file_prefix,
                    image_dir,
                    image_zip_url,
                    shuffle=False):
  # Download and unzip the dataset if it's not there already.
  if not os.path.exists(image_dir):
    image_zip = tf.keras.utils.get_file(
        'image.zip',
        cache_dir=os.path.abspath('.'),
        cache_subdir=os.path.join(DATASET_DIR),
        origin=image_zip_url,
        extract=True,
    )

    # --- FIX STARTS HERE ---
    # The variable 'image_zip' might point to a directory (the extraction folder)
    # instead of the actual .zip file. We must check before deleting.
    if os.path.isfile(image_zip):
        os.remove(image_zip)
    else:
        # If image_zip is a folder, we try to find the actual 'image.zip' file
        # in the dataset directory to delete it (to save disk space).
        actual_zip_path = os.path.join(DATASET_DIR, 'image.zip')
        if os.path.exists(actual_zip_path):
            os.remove(actual_zip_path)
    # --- FIX ENDS HERE ---

  # Convert the lists into tensors so that we can index into it in the dataset
  # transformations later.
  coco_ids, image_info = zip(*id_image_info_list)
  captions, flickr_ids = zip(*image_info)

  file_names = list(
      map(
          lambda id: os.path.join(image_dir, '%s%012d.jpg' %
                                  (image_file_prefix, id)), coco_ids))
  coco_ids_tensor = tf.constant(coco_ids)
  captions_tensor = tf.ragged.constant(captions)
  file_names_tensor = tf.constant(file_names)
  flickr_ids_tensor = tf.constant(flickr_ids)

  # The initial dataset only contains the index. This is to make sure the
  # dataset has a known size.
  dataset = tf.data.Dataset.range(len(coco_ids))
  sp = get_sentencepiece_tokenizer_in_tf2()

  def _load_image_and_select_caption(i):
    image_id = coco_ids_tensor[i]
    captions = captions_tensor[i]
    image_path = file_names_tensor[i]
    flickr_id = flickr_ids_tensor[i]

    # Decode the image
    image = tf.image.decode_jpeg(tf.io.read_file(image_path), channels=3)

    # Randomly select one caption from the many captions we have for each image
    caption_idx = tf.random.uniform((1,),
                                    minval=0,
                                    maxval=tf.shape(captions)[0],
                                    dtype=tf.int32)[0]
    caption = captions[caption_idx]
    caption = tf.sparse.from_dense(sp.tokenize(caption))

    example = {
        'image': image,
        'image_id': image_id,
        'caption': caption,
        'flickr_id': flickr_id
    }
    return example

  def _resize_image(example):
    # Efficient net requires the pixels to be in range of [0, 1].
    example['image'] = tf.image.resize(example['image'], size=IMAGE_SIZE) / 255
    return example

  dataset = (
      # Load the images from disk and decode them into numpy arrays.
      dataset.map(
          _load_image_and_select_caption,
          num_parallel_calls=tf.data.AUTOTUNE,
          deterministic=not shuffle)
      # Resizing image is slow. We put the stage into a separate map so that it
      # could get more threads to not be the bottleneck.
      .map(
          _resize_image,
          num_parallel_calls=tf.data.AUTOTUNE,
          deterministic=not shuffle))

  if shuffle:
    dataset = dataset.shuffle(BATCH_SIZE * 10)

  dataset = dataset.batch(BATCH_SIZE)
  return dataset

Download the datasets and preprocess them.

In [16]:
# We parse the caption json files first.
train_img_cap, valid_img_cap = get_train_valid_captions()
print(f'Train number of images: {len(train_img_cap)}')
print(f'Valid number of images: {len(valid_img_cap)}')

example = train_img_cap[0]
print(f'COCO image id: {example[0]}')
print(f'Captions: {example[1][0]}')
print(f'Flickr post url: http://flickr.com/photo.gne?id={example[1][1]}')

Downloading and extracting annotations...
Downloading data from http://images.cocodataset.org/annotations/annotations_trainval2014.zip
[1m252872794/252872794[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 0us/step
Found JSON files in: datasets/tmp/annotations_extracted/annotations
Train number of images: 82783
Valid number of images: 40504
COCO image id: 318556
Captions: ['A very clean and well decorated empty bathroom', 'A blue and white bathroom with butterfly themed wall tiles.', 'A bathroom with a border of butterflies and blue paint on the walls above it.', 'An angled view of a beautifully decorated bathroom.', 'A clock that blends in with the wall hangs in a bathroom. ']
Flickr post url: http://flickr.com/photo.gne?id=3378902101


In [19]:
# Shuffle both the train and validation sets
random.shuffle(valid_img_cap)
random.shuffle(train_img_cap)

# We randomly sample 5000 image-caption pairs from validation set for validation
# during training, to match the setup of
# https://www.tensorflow.org/datasets/catalog/coco_captions. However, when
# generating the retrieval database later on, we will use all the images in both
# validation and training splits.
valid_dataset = prepare_dataset(
    valid_img_cap[:5000],
    VALID_IMAGE_PREFIX,
    VALID_IMAGE_DIR,
    VALID_IMAGE_URL)
train_dataset = prepare_dataset(
    train_img_cap,
    TRAIN_IMAGE_PREFIX,
    TRAIN_IMAGE_DIR,
    TRAIN_IMAGE_URL,
    shuffle=True)

Downloading data from http://images.cocodataset.org/zips/train2014.zip
[1m13510573713/13510573713[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m266s[0m 0us/step


## Define models

The image encoder and text encoder may not output the embeddings with the same amount of dimensions. We need to project them into the same embedding space

In [54]:
def project_embeddings(embeddings, num_projection_layers, projection_dims, dropout_rate):
    projected_embeddings = layers.Dense(units=projection_dims)(embeddings)

    for _ in range(num_projection_layers):
        x = layers.ReLU()(projected_embeddings) # Replaced tf.nn.relu
        x = layers.Dense(projection_dims)(x)
        x = layers.Dropout(dropout_rate)(x)
        x = layers.Add()([projected_embeddings, x])
        projected_embeddings = layers.LayerNormalization()(x)

    # Wrap math operation to prevent KerasTensor error
    projected_embeddings = layers.Lambda(
        lambda x: tf.math.l2_normalize(x, axis=1),
        name='l2_norm'
    )(projected_embeddings)

    return projected_embeddings

In [55]:
class EfficientNetLayer(layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Load the Hub layer once
        self.hub_layer = hub.KerasLayer(EFFICIENT_NET_URL, trainable=False)

    def call(self, inputs):
        # We explicitly enforce training=False to prevent smart_cond errors
        return self.hub_layer(inputs, training=False)

# 3. Create the Image Encoder Model using the wrapper
def create_image_encoder(num_projection_layers, projection_dims, dropout_rate):
    # Define Input
    inputs = layers.Input(shape=IMAGE_SIZE + (3,), name='image_input')

    # FIX: Use our custom Wrapper Layer instead of calling hub.KerasLayer directly
    embeddings = EfficientNetLayer(name='efficientnet_lite')(inputs)

    # Project embeddings
    outputs = project_embeddings(embeddings, num_projection_layers,
                                 projection_dims, dropout_rate)

    return keras.Model(inputs, outputs, name='image_encoder')

We use [Universal Sentence Encoder](https://tfhub.dev/google/universal-sentence-encoder-lite/2), a SOTA sentence embedding model, as the text encoder base model. The TFHub lite version is a TF1 saved model. To make it work well in TF2 and later TFLite conversion, we create two models, one is the frozen universal sentence encoder, and the other is the trainable projection layer.

In [56]:
# Define a custom Keras Layer to handle the Hub interaction safely
class USELiteLayer(layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Initialize the Hub layer once
        self.hub_layer = hub.KerasLayer(UNIVERSAL_SENTENCE_ENCODER_URL, signature='default', trainable=False)

    def call(self, inputs):
        # 1. Unpack the SparseTensor
        # 2. EXPLICITLY cast to int64. The error happened because Keras was defaulting to float32,
        #    but the Hub module strictly requires int64.
        values = tf.cast(inputs.values, tf.int64)
        indices = tf.cast(inputs.indices, tf.int64)
        dense_shape = tf.cast(inputs.dense_shape, tf.int64)

        # 3. Pass the dictionary to the Hub layer
        return self.hub_layer(dict(
            values=values,
            indices=indices,
            dense_shape=dense_shape
        ))

def create_text_encoder():
  # 1. Define the input (Sparse Tensor)
  inputs = tf.keras.Input(shape=(None,), dtype=tf.int64, sparse=True, name='text_input')

  # 2. Use our custom layer to process the input
  embeddings = USELiteLayer(name='use_lite_layer')(inputs)

  # 3. Return the model
  return keras.Model(inputs=inputs, outputs=embeddings, name='text_encoder')


def create_text_embedder_projection(input_dim, num_projection_layers,
                                    projection_dims, dropout_rate):
    inputs = layers.Input(shape=(input_dim,), dtype=tf.float32, name='text_input')
    outputs = project_embeddings(inputs, num_projection_layers, projection_dims,
                                 dropout_rate)
    return keras.Model(inputs, outputs, name='projection_layers')

This dual encoder model is derived from this [Keras post](https://keras.io/examples/nlp/nl_image_search/)

In [57]:
class DualEncoder(keras.Model):

  def __init__(self,
               text_encoder,
               text_encoder_projection,
               image_encoder,
               temperature,
               **kwargs):
    super(DualEncoder, self).__init__(**kwargs)
    self.text_encoder = text_encoder
    self.text_encoder_projection = text_encoder_projection
    self.image_encoder = image_encoder

    # Temperature controls the contrast of softmax output. In general, a low
    # temperature increases the contrast and a high temperature decreases it.
    self.temperature = temperature
    self.loss_tracker = keras.metrics.Mean(name='loss')

  @property
  def metrics(self):
    return [self.loss_tracker]

  def call(self, features, training=False):
    # If there are two GPUs present, we use one of them for image encoder and
    # one for text encoder. If there's only one GPU then they will be trained on
    # the same GPU.
    with tf.device('/gpu:0'):
      caption_embeddings = self.text_encoder(
          features['caption'], training=False)
      caption_embeddings = self.text_encoder_projection(
          caption_embeddings, training=training)
    with tf.device('/gpu:1'):
      image_embeddings = self.image_encoder(
          features['image'], training=training)
    return caption_embeddings, image_embeddings

  def compute_loss(self, caption_embeddings, image_embeddings):
    # Computing the loss with dot product similarity between image and text
    # embeddings.
    logits = (
        tf.matmul(caption_embeddings, image_embeddings, transpose_b=True) /
        self.temperature)
    images_similarity = tf.matmul(
        image_embeddings, image_embeddings, transpose_b=True)
    captions_similarity = tf.matmul(
        caption_embeddings, caption_embeddings, transpose_b=True)

    # The targets is the mean of the self-similarity of the captions and images.
    # This is more lenient to the similar examples appeared in the same batch.
    targets = keras.activations.softmax(
        (captions_similarity + images_similarity) / (2 * self.temperature))
    captions_loss = keras.losses.categorical_crossentropy(
        y_true=targets, y_pred=logits, from_logits=True)
    images_loss = keras.losses.categorical_crossentropy(
        y_true=tf.transpose(targets),
        y_pred=tf.transpose(logits),
        from_logits=True)
    return (captions_loss + images_loss) / 2

  def train_step(self, features):
    with tf.GradientTape() as tape:
      # Forward pass
      caption_embeddings, image_embeddings = self(features, training=True)
      loss = self.compute_loss(caption_embeddings, image_embeddings)

    # Backward pass
    gradients = tape.gradient(loss, self.trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
    self.loss_tracker.update_state(loss)
    return {'loss': self.loss_tracker.result()}

  def test_step(self, features):
    caption_embeddings, image_embeddings = self(features, training=False)
    loss = self.compute_loss(caption_embeddings, image_embeddings)
    self.loss_tracker.update_state(loss)
    return {'loss': self.loss_tracker.result()}

## Train the Dual Encoder model

Load the models from Tensorflow Hub.

In [58]:
# The text embedder consists of two models. One is the frozen base universal
# sentence encoder, and the other is the trainable projection layer.
text_encoder = create_text_encoder()
projection_layers = create_text_embedder_projection(
    input_dim=512,  # Universal sentence encoder output has 512 dimensions
    num_projection_layers=1,
    projection_dims=EMB_SIZE,
    dropout_rate=0.1)

image_encoder = create_image_encoder(
    num_projection_layers=1, projection_dims=EMB_SIZE, dropout_rate=0.1)

dual_encoder = DualEncoder(
    text_encoder, projection_layers, image_encoder, temperature=0.05)

# FIX: Use tf.keras.optimizers instead of tfa.optimizers
dual_encoder.compile(
    optimizer=tf.keras.optimizers.AdamW(learning_rate=0.001, weight_decay=0.001))

Train the dual encoder model.

In [60]:
# We train the first three epochs with the learning rate of 0.001 and
# decrease it exponentially later on.
def lr_scheduler(epoch, lr):
  if epoch < 3:
    return lr
  else:
    return max(lr * tf.math.exp(-0.1), lr * 0.1)

# In colab, training takes roughly 4s per step, around 24 mins per epoch
# Create the early stopping callback
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=2,
    restore_best_weights=True
)

# Run the training
# FIX: Removed any legacy arguments like 'max_queue_size', 'workers', or 'use_multiprocessing'
# that might have been causing the crash in Keras 3.
history = dual_encoder.fit(
    train_dataset,
    epochs=NUM_EPOCHS,
    validation_data=valid_dataset,
    validation_steps=100,
    callbacks=[early_stopping]
)

# Save the models. We are not going to save the text_encoder since it's frozen
# and the TF2 saved model for text_encoder has problems converting to TFLite.
print('Training completed. Saving image and text encoders.')
dual_encoder.image_encoder.save('image_encoder')
dual_encoder.text_encoder_projection.save('text_encoder_projection')
print('Models are saved.')

Epoch 1/10


NotFoundError: Graph execution error:

Detected at node ReadFile defined at (most recent call last):
<stack traces unavailable>
Error in user-defined function passed to ParallelMapDatasetV2:5 transformation with iterator: Iterator::Root::Prefetch::BatchV2::Shuffle::ParallelMapV2::ParallelMapV2: datasets/train2014/COCO_train2014_000000526290.jpg; No such file or directory
	 [[{{node ReadFile}}]]
	 [[IteratorGetNext]] [Op:__inference_multi_step_on_iterator_186163]

## Create the text-to-image Searcher model using Model Maker

### Generate image embeddings

Load the valid and train dataset one more time. This time we are not going to shuffle the train split and we use the whole validataion split. Since images are not loaded until they are iterated, creating the datasets should be cheap.

In [None]:
combined_valid_dataset = prepare_dataset(
    valid_img_cap,
    VALID_IMAGE_PREFIX,
    VALID_IMAGE_DIR,
    VALID_IMAGE_URL)
deterministic_train_dataset = prepare_dataset(
    train_img_cap,
    TRAIN_IMAGE_PREFIX,
    TRAIN_IMAGE_DIR,
    TRAIN_IMAGE_URL)

all_combined = deterministic_train_dataset.concatenate(combined_valid_dataset)

Create the metadata (image file names and the flickr post id) from the dataset. This will later be packed into the TFLite model.

In [None]:
def create_metadata(image_file_prefix, image_dir):

  def _create_metadata(image_info):
    # This is the same way we generated the image paths in the prepare_dataset
    # function above
    coco_id = image_info[0]
    flickr_id = image_info[1][1]
    return ('%s_%s' %
            (flickr_id,
             os.path.join(image_dir, '%s%012d.jpg' %
                          (image_file_prefix, coco_id)))).encode('utf-8')

  return _create_metadata


# We don't store the images in the index file, as that would be too big. We only
# store the image path and the corresponding Flickr id.
metadata = list(
    map(create_metadata(TRAIN_IMAGE_PREFIX, TRAIN_IMAGE_DIR), train_img_cap))
metadata.extend(
    map(create_metadata(VALID_IMAGE_PREFIX, VALID_IMAGE_DIR), valid_img_cap))

Generate the embeddings for all the images we have. We do it in Tensorflow with GPU instead of Model Maker. Again, these will be packed into the TFLite model.

In [None]:
# Image encoder takes one input named `image_input` so we remove other values in
# the dataset.
image_dataset = all_combined.map(
    lambda example: {'image_input': example['image']})
image_embeddings = dual_encoder.image_encoder.predict(image_dataset, verbose=1)
print(f'Embedding matrix shape: {image_embeddings.shape}')

Embedding matrix shape: (123287, 128)


### Convert text embedder to TFLite

We need to convert the saved model to TF1 as the base Universal Sentence Encoder is a TF1 model. It'll create a saved model dir on disk called `converted_model`

In [None]:
#@title Prepare the saved model
!rm -rf converted_model

# This create a new TF1 SavedModel from 1). The tfhub USE, and 2). The
# projection layers trained and saved from TF2.
with tf1.Graph().as_default() as g:
  with tf1.Session() as sess:
    # Reload the Universal Sentence Encoder model from tfhub. We can't just save
    # the USE in TF2 as we did for the projection layers as that causes issues
    # in the TFLite converter.
    module = hub.Module(UNIVERSAL_SENTENCE_ENCODER_URL)
    spm_path = sess.run(module(signature='spm_path'))
    with tf1.io.gfile.GFile(spm_path, mode='rb') as f:
      serialized_spm = f.read()
    spm_path = sess.run(module(signature='spm_path'))
    input_str = tf1.placeholder(dtype=tf1.string, shape=[None])
    tokenizer = sentencepiece_tokenizer.FastSentencepieceTokenizer(
        model=serialized_spm)
    tokenized = tf1.sparse.from_dense(tokenizer.tokenize(input_str).to_tensor())
    tokenized = tf1.cast(tokenized, dtype=tf1.int64)
    encodings = module(
        inputs=dict(
            values=tokenized.values,
            indices=tokenized.indices,
            dense_shape=tokenized.dense_shape))

    # Then combine it with the trained projection layers
    projection_layers = tf1.keras.models.load_model('text_encoder_projection')
    encodings = projection_layers(encodings)

    sess.run([tf1.global_variables_initializer(), tf1.tables_initializer()])

    # Save with SavedModelBuilder
    builder = tf1.saved_model.Builder('converted_model')
    sig_def = tf1.saved_model.predict_signature_def(
        inputs={'input': input_str}, outputs={'output': encodings})
    builder.add_meta_graph_and_variables(
        sess,
        tags=['serve'],
        signature_def_map={
            tf1.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: sig_def
        },
        clear_devices=True)
    builder.save()
print('Model saved to converted_model/')

Model saved to converted_model/


Convert and save the TFLite model. Here the model only has the text encoder. We will add in the retrieval index in the following steps.

In [None]:
converter = tf.lite.TFLiteConverter.from_saved_model('converted_model')
converter.experimental_new_converter = True
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
converter.allow_custom_ops = True
converted_model_tflite = converter.convert()
with open('text_embedder.tflite', 'wb') as f:
  f.write(converted_model_tflite)

### Create TFLite Searcher model

In general see the documentation of [`ScaNNOptions`](https://www.tensorflow.org/lite/api_docs/python/tflite_model_maker/searcher/ScaNNOptions) for how to configure the searcher for your dataset.

In [None]:
import tflite_model_maker as mm

In [None]:
scann_options = mm.searcher.ScaNNOptions(
    # We use the dot product similarity as this is how the model is trained
    distance_measure='dot_product',
    # Enable space partitioning with K-Means tree
    tree=mm.searcher.Tree(
        # How many partitions to have. A rule of thumb is the square root of the
        # dataset size. In this case it's 351.
        num_leaves=int(math.sqrt(len(metadata))),
        # Searching 4 partitions seems to give reasonable result. Searching more
        # will definitely return better results, but it's more costly to run.
        num_leaves_to_search=4),
    # Compress each float to int8 in the embedding. See
    # https://www.tensorflow.org/lite/api_docs/python/tflite_model_maker/searcher/ScoreAH
    # for details
    score_ah=mm.searcher.ScoreAH(
        # Using 1 dimension per quantization block.
        1,
        # Generally 0.2 works pretty well.
        anisotropic_quantization_threshold=0.2))

data = mm.searcher.DataLoader(
    embedder_path='text_embedder.tflite',
    dataset=image_embeddings,
    metadata=metadata)

model = mm.searcher.Searcher.create_from_data(
    data=data, scann_options=scann_options)
model.export(
    export_filename='searcher_model.tflite',
    userinfo='',
    export_format=mm.searcher.ExportFormat.TFLITE)

## Run inference using Task Library

In [None]:
from tflite_support.task import text
from tflite_support.task import core

Configure the searcher to return 6 results per query and not to L2 normalize the query embeddings because the text encoder has already normalized them. See [source code](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/python/task/text/text_searcher.py) on how to configure the `TextSearcher`.

In [None]:
options = text.TextSearcherOptions(
    base_options=core.BaseOptions(
        file_name='searcher_model.tflite'))

# The searcher returns 6 results
options.search_options.max_results = 6

tflite_searcher = text.TextSearcher.create_from_options(options)

In [None]:
def search_image_with_text(query_str, show_images=False):
  neighbors = tflite_searcher.search(query_str)

  for i, neighbor in enumerate(neighbors.nearest_neighbors):
    metadata = neighbor.metadata.decode('utf-8').split('_')
    flickr_id = metadata[0]
    print('Flickr url for %d: http://flickr.com/photo.gne?id=%s' %
          (i + 1, flickr_id))

  if show_images:
    plt.figure(figsize=(20, 13))
    for i, neighbor in enumerate(neighbors.nearest_neighbors):
      ax = plt.subplot(2, 3, i + 1)

      # Using negative distance since on-device ScaNN returns negative
      # dot-product distance.
      ax.set_title('%d: Similarity: %.05f' % (i + 1, -neighbor.distance))
      metadata = neighbor.metadata.decode('utf-8').split('_')
      image_path = '_'.join(metadata[1:])
      image = tf.image.decode_jpeg(
          tf.io.read_file(image_path), channels=3) / 255
      plt.imshow(image)
      plt.axis('off')

We will not show the image here due to copyright issues. You can set `show_images=True` to display them (note that you can't set it to `True` unless you've downloaded the images at [this cell](#scrollTo=Ke6EeKAqj1vB&line=12&uniqifier=1)). Please check the post links for the license of each image.

In [None]:
search_image_with_text('A man riding on a bike')

Flickr url for 1: http://flickr.com/photo.gne?id=6388219123
Flickr url for 2: http://flickr.com/photo.gne?id=30100145
Flickr url for 3: http://flickr.com/photo.gne?id=3322126404
Flickr url for 4: http://flickr.com/photo.gne?id=4945223078
Flickr url for 5: http://flickr.com/photo.gne?id=120446248
Flickr url for 6: http://flickr.com/photo.gne?id=4807048033


Congratulations on finishing this colab! For next steps, you can try deploy the model on-device (inference + search on Pixel 6 is around 6 ms), or you can train the model with your own dataset. In the mean time, don't forget to checkout our documentations ([Model Maker](https://www.tensorflow.org/lite/guide/model_maker/), [Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/text_searcher/)) and the [reference app](https://github.com/tensorflow/examples/tree/master/lite/examples/text_searcher/android), which searches news articles in [CNN_DailyMail dataset](https://www.tensorflow.org/datasets/catalog/cnn_dailymail)