# 01-6: ScaNN

In [3]:
!pip install scann pandas seaborn 
!pip install tensorflow==2.11.0 # Required for compatibility with tf_hub 

In [4]:
# Copy the sample JSONL file from Cloud Storage into the current working
# directory; a later cell opens it by its bare file name.
!gsutil cp gs://cloud-samples-data/vertex-ai/dataset-management/datasets/bert_finetuning/wide_and_deep_trainer_container_tests_input.jsonl .

In [None]:
import json
import time

import numpy as np
import pandas as pd
import scann
import tensorflow_hub as hub

In [None]:
# Load the Universal Sentence Encoder (version 4) from TF Hub and confirm
# which module was loaded.
module_url = "https://tfhub.dev/google/universal-sentence-encoder/4"
model = hub.load(module_url)
print("module %s loaded" % module_url)
def embed(input):
    """Run the loaded encoder on ``input`` and return its output unchanged.

    Args:
        input: Whatever the module-level ``model`` accepts; the callers in
            this notebook pass a list of strings.

    Returns:
        The value returned by ``model(input)``.
    """
    embeddings = model(input)
    return embeddings


records = []
# Open the file as UTF-8 explicitly instead of relying on the platform's
# locale-dependent default encoding.
with open("wide_and_deep_trainer_container_tests_input.jsonl", encoding="utf-8") as f:
    for line in f:
        # Skip blank lines (e.g. a stray empty line at the end of the file),
        # which json.loads would reject.
        if line.strip():
            records.append(json.loads(line))


# Peek at the data.
df = pd.DataFrame(records)
print(df.head(50))


def get_embedding(text):
    """Embed a single piece of text.

    The text is wrapped in a one-element list before being handed to
    ``embed``, and whatever ``embed`` returns for that one-item batch is
    passed back unchanged.

    Args:
        text: The text to embed.

    Returns:
        The result of ``embed([text])``.
    """
    single_item_batch = [text]
    return embed(single_item_batch)

import tensorflow as tf

# This may take several minutes to complete: each row's text is embedded
# with its own call to get_embedding.
df["embedding"] = df["textContent"].apply(get_embedding)

## Crear índice ScaNN

In [None]:
# Create an Index
record_count = len(records)
print(f"record_count: {record_count}")

# Read the embedding width from the data instead of hard-coding it, so the
# cell keeps working when a different encoder (with another vector size) is
# used to fill df["embedding"].
embedding_dim = np.asarray(df["embedding"].iloc[0]).shape[-1]
dataset = np.empty((record_count, embedding_dim))
for row, vector in enumerate(df["embedding"]):
    # get_embedding embeds a one-element list, so each stored value may carry
    # a leading batch axis; flatten it to a plain 1-D row.
    dataset[row] = np.asarray(vector).reshape(-1)

# Scale every row to unit length before indexing.
normalized_dataset = dataset / np.linalg.norm(dataset, axis=1)[:, np.newaxis]
# configure ScaNN as a tree - asymmetric hash hybrid with reordering
# anisotropic quantization as described in the paper; see README

# use scann.scann_ops.builder() to instead create a TensorFlow-compatible searcher
searcher = (
    scann.scann_ops_pybind.builder(normalized_dataset, 10, "dot_product")
    .tree(
        num_leaves=record_count,
        num_leaves_to_search=record_count,
        training_sample_size=record_count,
    )
    .score_ah(2, anisotropic_quantization_threshold=0.2)
    .reorder(100)
    .build()
)

## Lanzar consultas al índice

In [None]:
def search(query, num_neighbors=3):
    """Embed ``query``, look it up in the ScaNN index and print the hits.

    Args:
        query: Text to embed with ``get_embedding`` and search for.
        num_neighbors: Number of results requested from the searcher.
            Defaults to 3, the value that used to be hard-coded.

    Prints one line per hit (row id, the value the searcher returned for it,
    and the first 125 characters of the matching ``textContent``), followed
    by the elapsed time in milliseconds. The timing covers embedding the
    query plus the index lookup.
    """
    # perf_counter is the clock intended for measuring short intervals;
    # time.time() can jump if the system clock is adjusted.
    start = time.perf_counter()
    # Flatten the embedding to a 1-D NumPy vector for the searcher. The raw
    # embedding is no longer printed here: that was debug output inside the
    # timed region, which inflated the reported latency.
    query_embedding = np.asarray(get_embedding(query)).reshape(-1)
    neighbors, distances = searcher.search(
        query_embedding, final_num_neighbors=num_neighbors
    )
    elapsed_ms = 1000 * (time.perf_counter() - start)

    for doc_id, dist in zip(neighbors, distances):
        print(f"[docid:{doc_id}] [{dist}] -- {df.textContent[int(doc_id)][:125]}...")
    print("Latency (ms):", elapsed_ms)


# Try the index with a couple of sample questions.
for sample_query in (
    "tell me about shark or animal",
    "tell me about an important moment or event in your life",
):
    search(sample_query)

## Visualización con seaborn (semantic search similarity)

In [None]:
import seaborn as sns
def plot_similarity(labels, features, rotation):
    """Draw a heatmap of the pairwise inner products of ``features``.

    Args:
        labels: Tick labels used on both axes.
        features: Vectors whose pairwise inner products are plotted.
        rotation: Rotation applied to the x-axis tick labels.
    """
    similarity = np.inner(features, features)
    sns.set(font_scale=1.2)
    # Colour scale is pinned to the [0, 1] range.
    heatmap_options = dict(
        xticklabels=labels,
        yticklabels=labels,
        vmin=0,
        vmax=1,
        cmap="YlOrRd",
    )
    ax = sns.heatmap(similarity, **heatmap_options)
    ax.set_xticklabels(labels, rotation=rotation)
    ax.set_title("Semantic Textual Similarity")

def run_and_plot(messages_):
    """Embed ``messages_`` and plot their pairwise similarity heatmap.

    Args:
        messages_: Sequence of strings to compare. When it is empty there is
            nothing to embed or plot, so a hint is printed and the function
            returns without calling the encoder.

    Returns:
        None.
    """
    if len(messages_) == 0:
        print("No messages to compare: add some sentences before plotting.")
        return
    message_embeddings_ = embed(messages_)
    # x-axis labels are rotated 90 degrees.
    plot_similarity(messages_, message_embeddings_, 90)


# TODO: Fill `messages` with at least 10 sentences to compare; the list is
# passed to run_and_plot below.
# NOTE(review): as written the list is empty, so there is nothing to embed or
# plot until it is filled in.

messages = []

run_and_plot(messages)