In [0]:
%pip install faiss-cpu
%pip install -r ../requirements.txt
%pip install "ray[default]>=2.3.0"

In [0]:
# Explicit imports (rather than `import *`) so readers can see which names come from ray_search.
from ray_search import (
  Memory,
  RayCluster,
  SearchBuilder,
  SearchMatrixBuilder,
  SearchPipeline,
  Searchers,
  Vectorizers,
)

In [0]:
# Read the reviews table from the metastore and collect it to the driver as a pandas DataFrame.
REVIEWS_TABLE = "hive_metastore.sri_demo_catalog.amazon_luggage_reviews_cleaned"
reviews_sdf = spark.table(REVIEWS_TABLE)
data = reviews_sdf.toPandas()

In [0]:
# Preview a handful of rows instead of rendering the entire collected table.
data.head()

In [0]:
# Turn the id and (truncated) body columns into plain Python lists of str via a
# numpy unicode cast; these lists are what the search pipeline consumes.
ids, texts = (
  data[column].values.astype("U").tolist()
  for column in ("review_id", "review_body_trunc")
)

In [0]:
# Manual cleanup snippet, commented out so it does not run with the notebook.
# Uncomment to call ray.util.spark.shutdown_ray_cluster() and ray.shutdown(),
# presumably to clear a Ray-on-Spark cluster left over from an interrupted run.
# from ray.util.spark import shutdown_ray_cluster
# import ray
# ray.init()
# shutdown_ray_cluster()
# ray.shutdown()

In [0]:
import ray_search

# Start a Ray-on-Spark cluster and run the two-stage search pipeline inside it.
with RayCluster(
  num_worker_nodes=8,
  num_cpus_per_node=8,
  # Ship the local ray_search module to the Ray workers via the runtime env.
  runtime_env={"py_modules": [ray_search]},
) as cluster:
  # Query once and reuse; both stages are sized from this number.
  n_workers = cluster.max_parallel_workers()
  print(f"Using: {n_workers} workers")

  # Stage 1: build the search matrix from (ids, texts) with the dense
  # SentenceTransformer vectorizer, one worker per available slot.
  matrix_builder: SearchMatrixBuilder = (
    SearchMatrixBuilder()
    .with_content(ids, texts)
    .with_workers(n_workers)
    .with_worker_memory(Memory.in_mb(512))
    .as_512_chunk()
    .with_vectorizer(Vectorizers.SentenceTransformerDense)
  )

  # Stage 2: FAISS ANN search on a quarter of the workers, each with more memory.
  # max(1, ...) guarantees at least one search worker: with fewer than 4
  # available workers, n_workers // 4 would be 0.
  search_builder: SearchBuilder = (
    SearchBuilder()
    .with_workers(max(1, n_workers // 4))
    .with_searcher(Searchers.FaissANNSearch)
    .with_worker_memory(Memory.in_mb(2048))  # 2 GiB per search worker
    .with_chunk_size(32)
    .with_top_k_per_entity(5)
  )

  # Keep this inside the `with` block so the cluster is still available;
  # to_df_unnested() presumably triggers the distributed run.
  results = (
    SearchPipeline()
    .with_search_matrix_builder(matrix_builder)
    .with_search_builder(search_builder)
    .to_df_unnested()
  )

In [0]:
# Expose the search results to Spark SQL as the temp view `search_results` for the %sql cells.
results_sdf = spark.createDataFrame(results)
results_sdf.createOrReplaceTempView("search_results")

In [0]:
%sql
-- Sample of raw search result rows.
SELECT *
FROM search_results
LIMIT 10

In [0]:
%sql
-- Total number of search result rows.
SELECT
  count(1)
FROM search_results

In [0]:
%sql
-- Result rows per input_id, largest first; the ORDER BY makes the output
-- deterministic and surfaces inputs with unusual result counts at the top.
SELECT input_id, count(1) AS n_results
FROM search_results
GROUP BY input_id
ORDER BY n_results DESC, input_id