# Local Test Index

### Load model

In [None]:
# Ask tensorflow_hub to fetch compressed model archives.
os.environ.update({'TFHUB_MODEL_LOAD_FORMAT': 'COMPRESSED'})

# Input size (height, width, channels) used for every image fed to the embedding model.
IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS = 224, 224, 3

# prepare images for expected input
def read_and_decode(filename, reshape_dims=[IMG_HEIGHT, IMG_WIDTH]):
  """Load one JPEG file as a resized float32 batch containing a single image.

  Args:
    filename: path or URI readable by `tf.io.read_file` (local or gs://).
    reshape_dims: target [height, width] handed to `tf.image.resize`.

  Returns:
    A float32 tensor of shape (1, height, width, IMG_CHANNELS); pixel values
    are scaled to [0, 1] by `convert_image_dtype` before resizing.
  """
  encoded = tf.io.read_file(filename)
  decoded = tf.image.decode_jpeg(encoded, channels=IMG_CHANNELS)
  # Add a leading batch axis so the result can be passed straight to a Keras model.
  batched = tf.expand_dims(tf.image.convert_image_dtype(decoded, tf.float32), axis=0)
  return tf.image.resize(batched, reshape_dims)

# Download model from TF Hub: a frozen MobileNetV2 feature extractor followed by
# Flatten, producing one embedding vector per input image.
layers = [
      hub.KerasLayer(
          "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4",
          input_shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS),
          trainable=False,
          name='mobilenet_embedding'),
      tf.keras.layers.Flatten()
]
model = tf.keras.Sequential(layers, name='z_embedding')
# summary() prints the table itself and returns None; wrapping it in print()
# would emit a stray "None" line.
model.summary()

### Vector attributes / labels

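Each vector can carry `restricts` (a namespace plus its allowed tokens) that are used to filter matches at query time; see [specify namespaces and tokens](https://cloud.google.com/vertex-ai/docs/matching-engine/filtering). The cell below shows example records.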

In [None]:
# test = {"id": "42", 
#         "embedding": [0.5, 1.0], 
#         "restricts": [
#                       {
#                           "namespace": "class","allow": ["cat", "pet"]
#                       },
#                       {
#                           "namespace": "category", "allow": ["feline"]
#                        }
#                       ]
#         }

# v_attr = {
#     "id": "43", 
#     "embedding": [0.6, 1.0], 
#     "restricts": [
#                   {"namespace":"class", "allow": ["dog", "pet"]},
#                   {"namespace": "category", "allow":["canine"]}
#     ]
# }

### Create Query embeddings

In [None]:
NUM_TEST_SAMPLES = 50  # number of query images embedded from EVAL_IMG_PATH
# Directory of evaluation (query) images. The next cells read EVAL_IMG_PATH, so it
# must exist on a fresh kernel; keep any value already set earlier in the notebook.
EVAL_IMG_PATH = globals().get('EVAL_IMG_PATH', 'gs://retail-product-img-kaggle/dataset/test/test')

In [None]:
def create_query_embeddings(embedder, img_path, num_test_samples=None):
  """Embed up to `num_test_samples` images found in a directory.

  Args:
    embedder: callable applied to each decoded image batch (as returned by
      `read_and_decode`); every row it returns is kept as one embedding.
    img_path: local directory or gs:// prefix holding the images.
    num_test_samples: maximum number of images to embed; None embeds all.

  Returns:
    (filenames, embeddings): the full path of each embedded image, and a
    tensor stacking the embedding rows in the same order.
  """
  # tf.io.gfile.listdir returns entries in arbitrary order; sort so the same
  # images are selected, in the same order, on every run.
  file_names = sorted(tf.io.gfile.listdir(img_path))[:num_test_samples]
  # Avoid "dir//file" when img_path already ends with a slash.
  base_path = img_path.rstrip('/')

  dataset_filenames = []
  dataset_embeddings = []
  for file in file_names:
    full_path = base_path + "/" + file
    # tf.image.resize expects [height, width].
    img_tensor = read_and_decode(full_path, [IMG_HEIGHT, IMG_WIDTH])
    dataset_filenames.append(full_path)
    dataset_embeddings.extend(embedder(img_tensor))

  return dataset_filenames, tf.convert_to_tensor(dataset_embeddings)

In [None]:
# Embed the evaluation images with the TF Hub model.
query_filenames, query_embeddings = create_query_embeddings(
    model.predict, EVAL_IMG_PATH, NUM_TEST_SAMPLES)

# One NumPy vector per query image; this list is what gets sent to the index endpoint.
vector_list = [row.numpy() for row in query_embeddings]

print("query_filenames:", query_filenames)
print("query_embeddings shape:", query_embeddings.shape)  # should be (NUM_TEST_SAMPLES, 1280)
vector_list[0]

In [None]:
# Inspect the first query embedding (the same value the previous cell displays).
vector_list[0]

### Query the Matching Engine (ME) indexes

In [None]:
# List the index endpoints in us-central1 to find the endpoint resource name used in the next cell.
# NOTE(review): the project id is hardcoded here; consider interpolating {PROJECT_ID} instead.
!gcloud beta ai index-endpoints list --project="jtotten-project" --region=us-central1

In [None]:
# Handle to the deployed index endpoint; both the ANN and brute-force indexes below are queried through it.
index_endpoint_resource_uri = 'projects/163017677720/locations/us-central1/indexEndpoints/5129564791202906112'
index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
    index_endpoint_name=index_endpoint_resource_uri,
)

NUM_NEIGH = 3  # neighbors requested per query

In [None]:
# Query the deployed tree-AH (ANN) index with every query embedding.
deployed_ann_index_name = 'ann_1280_deployed_index_kg_retail-20220223223806'
# Resource path of the ANN index itself (the match call below addresses it via
# deployed_index_id and does not use this path).
ann_index_resource_path = 'projects/163017677720/locations/us-central1/indexes/6397380862865833984'

ann_response = index_endpoint.match(
    deployed_index_id=deployed_ann_index_name,
    queries=vector_list,
    num_neighbors=NUM_NEIGH,
)

print("ann_response:", ann_response)

In [None]:
# Exact (brute-force) index on the same endpoint; its results act as ground truth.
deployed_brute_index_name = 'brute_force_1280_deployed_index_kg_retail-20220223222939'
brute_index_resource_path = 'projects/163017677720/locations/us-central1/indexes/1062867104245481472'

brute_force_response = index_endpoint.match(
    num_neighbors=NUM_NEIGH,
    queries=vector_list,
    deployed_index_id=deployed_brute_index_name,
)
print("brute_force_response:", brute_force_response)

### Visualize Matches

In [None]:
# Visualization settings; NUM_NEIGH is typically 3, 10 or 20.
BATCH_SIZE, NUM_IMAGES, NUM_NEIGH = 32, 510, 3

In [None]:
def decode_to_plot(filename, reshape_dims=[IMG_HEIGHT, IMG_WIDTH]):
  """Load one JPEG as a resized float32 image without a batch axis, for imshow.

  Unlike `read_and_decode`, no leading batch dimension is added, so the result
  has shape (height, width, IMG_CHANNELS).
  """
  pixels = tf.image.convert_image_dtype(
      tf.image.decode_jpeg(tf.io.read_file(filename), channels=IMG_CHANNELS),
      tf.float32,
  )
  return tf.image.resize(pixels, reshape_dims)


# One row per query: the query image followed by its nearest neighbors from the
# ANN index (`ann_response` holds one list of neighbors per query).
f, ax = plt.subplots(len(query_filenames), NUM_NEIGH + 1,
                     figsize=(5 * (1 + NUM_NEIGH), 5 * len(query_filenames)),
                     squeeze=False)  # keep `ax` 2-D even when there is a single query
# Hide every axis up front so panels left empty (fewer neighbors than NUM_NEIGH) stay blank.
for panel in ax.flat:
  panel.axis('off')

for rowno, (query_filename, matches) in enumerate(zip(query_filenames, ann_response)):
  ax[rowno][0].imshow(decode_to_plot(query_filename).numpy())
  ax[rowno][0].set_title('query')
  for colno, match in enumerate(matches[:NUM_NEIGH]):
    # NOTE(review): assumes each index datapoint id is the URI of the image it was
    # embedded from (e.g. an entry of dataset_filenames); confirm against the index
    # build step and map ids to paths here if not.
    ax[rowno][colno + 1].imshow(decode_to_plot(match.id).numpy())
    ax[rowno][colno + 1].set_title('dist={:.1f}'.format(match.distance))

### Compute Recall

Use the deployed brute-force index as ground truth to calculate the recall of the ANN (tree-AH) index, i.e. the fraction of exact nearest neighbors that the ANN index also returns:

In [None]:
# Neighbors per query for the recall computation (recall@NUM_NEIGH); replaces the value 3 used for plotting.
NUM_NEIGH = 10

In [None]:
# Retrieve NUM_NEIGH nearest neighbors per query from both the tree-AH (ANN)
# index and the brute-force index deployed on the same endpoint.

# Tree-AH (ANN) index
deployed_ann_index_name = 'ann_1280_deployed_index_kg_retail-20220223223806'
ann_index_resource_path = 'projects/163017677720/locations/us-central1/indexes/6397380862865833984'

ann_response_test = index_endpoint.match(
    deployed_index_id=deployed_ann_index_name,
    queries=vector_list,
    num_neighbors=NUM_NEIGH,
)

# Brute Force Index (ground truth)
deployed_brute_index_name = 'brute_force_1280_deployed_index_kg_retail-20220223222939'
brute_index_resource_path = 'projects/163017677720/locations/us-central1/indexes/1062867104245481472'

brute_force_response_test = index_endpoint.match(
    deployed_index_id=deployed_brute_index_name,
    queries=vector_list,
    num_neighbors=NUM_NEIGH,
)

In [None]:
# Recall of the ANN (tree-AH) index: the fraction of brute-force (exact) neighbors
# that the ANN index also returned, pooled over all queries.
assert len(ann_response_test) == len(brute_force_response_test), \
    "ANN and brute-force responses must cover the same queries"

correct_neighbors = 0
ground_truth_neighbors = 0

for tree_ah_neighbors, brute_force_neighbors in zip(ann_response_test, brute_force_response_test):
    tree_ah_neighbor_ids = {neighbor.id for neighbor in tree_ah_neighbors}
    brute_force_neighbor_ids = {neighbor.id for neighbor in brute_force_neighbors}

    correct_neighbors += len(tree_ah_neighbor_ids & brute_force_neighbor_ids)
    # Normalize by what the brute-force index actually returned, which can be
    # fewer than NUM_NEIGH (e.g. if the index holds fewer datapoints).
    ground_truth_neighbors += len(brute_force_neighbor_ids)

# NaN instead of ZeroDivisionError when there were no queries / no ground truth.
recall = correct_neighbors / ground_truth_neighbors if ground_truth_neighbors else float('nan')

print("Recall: {}".format(recall))

### Create a local model for testing

In [None]:
# Point the Vertex AI SDK at this notebook's project and region.
# NOTE(review): BUCKET (read by later cells) must already be defined, e.g. "retail-product-kaggle".
aiplatform.init(project=PROJECT_ID, location=LOCATION)

In [None]:
# Self-contained setup for the local tests: constants plus the same frozen
# MobileNetV2 + Flatten embedding model built at the top of the notebook.
os.environ.update({'TFHUB_MODEL_LOAD_FORMAT': 'COMPRESSED'})

IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS = 224, 224, 3
BATCH_SIZE, NUM_IMAGES = 32, 510
NUM_NEIGH = 3  # top 3

layers = [
    hub.KerasLayer(
        "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4",
        input_shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS),
        trainable=False,
        name='mobilenet_embedding',
    ),
    tf.keras.layers.Flatten(),
]
local_model = tf.keras.Sequential(layers, name='z_embedding')

In [None]:
def read_and_decode(filename, reshape_dims=[IMG_HEIGHT, IMG_WIDTH]):
  """Read one JPEG and return it as a resized float32 batch of size one.

  The result has shape (1, reshape_dims[0], reshape_dims[1], IMG_CHANNELS).
  """
  image = tf.image.convert_image_dtype(
      tf.image.decode_jpeg(tf.io.read_file(filename), channels=IMG_CHANNELS),
      tf.float32,
  )
  return tf.image.resize(image[tf.newaxis, ...], reshape_dims)

def create_embeddings_dataset(embedder, img_path):
  """Embed every image in a directory.

  Args:
    embedder: callable applied to each decoded image batch (as returned by
      `read_and_decode`); every row it returns is kept as one embedding.
    img_path: local directory or gs:// prefix holding the images.

  Returns:
    (filenames, embeddings): the full path of each image, and a tensor
    stacking the embedding rows in the same order.
  """
  # Avoid "dir//file" when img_path already ends with a slash.
  base_path = img_path.rstrip('/')
  dataset_filenames = []
  dataset_embeddings = []
  # tf.io.gfile.listdir returns entries in arbitrary order; sort them so the
  # filename/embedding order is reproducible across runs.
  for file in sorted(tf.io.gfile.listdir(img_path)):
    full_path = base_path + "/" + file
    # tf.image.resize expects [height, width].
    img_tensor = read_and_decode(full_path, [IMG_HEIGHT, IMG_WIDTH])
    dataset_filenames.append(full_path)
    dataset_embeddings.extend(embedder(img_tensor))

  return dataset_filenames, tf.convert_to_tensor(dataset_embeddings)

In [None]:
# Smoke test: decode a single extracted image with read_and_decode.
IMG_PATH = 'gs://{}/extract/image_data_500_images_500_data_100230683.0.jpg'.format(BUCKET)
read_and_decode(IMG_PATH)

In [None]:
# Embed every image under the bucket's extract/ prefix with the local model.
IMG_PATH = f'gs://{BUCKET}/extract'

dataset_filenames, dataset_embeddings = create_embeddings_dataset(local_model.predict, IMG_PATH)

print(dataset_filenames[:3])
print(dataset_embeddings.shape)  # should be (NUM_IMAGES, 1280)

In [None]:
# Decode a single image as a model-ready batch. IMG_PATH was reassigned to the
# extract/ directory in the embedding cell, so reading it as a file would fail;
# use the first embedded image file instead.
img_tensor = read_and_decode(dataset_filenames[0])
# json_string