In [0]:
%pip install faiss-cpu
%pip install -r ../requirements.txt
%pip install "ray[default]>=2.3.0"
# Environment setup for this notebook, via %pip so packages land in the
# kernel's own environment:
#   * faiss-cpu            - installed unpinned (presumably backs Searchers.FaissANNSearch below; confirm)
#   * ../requirements.txt  - the project's requirements file, one directory above this notebook
#   * ray[default]>=2.3.0  - Ray with the "default" extras, minimum version 2.3.0
# NOTE(review): faiss-cpu has no version pin, so re-runs may pick up a newer release.

In [0]:
from ray_search import *

In [0]:
from pyspark.sql.functions import monotonically_increasing_id, sha2, concat, concat_ws

# Load the StackOverflow Q&A table and reduce it to the two columns used
# downstream: a string `id` and the text `content` to be indexed/searched.
qadf = spark.sql("""SELECT * FROM stackoverflow.qa_table""").select(
    # SHA-256 hex digest of the question text; sha2() already returns a
    # string column, so no extra cast is required.
    # NOTE(review): the id is derived from the question alone - if the table
    # holds more than one row per question, those rows will share an id.
    # Confirm uniqueness against the table before relying on it as a key.
    sha2("question", 256).alias("id"),
    # concat_ws rather than concat, for two reasons:
    #   1. concat() returns NULL for the whole row when either input is NULL,
    #      which would silently drop the text of rows lacking an answer;
    #      concat_ws() skips NULL inputs instead.
    #   2. A separator keeps the last word of the question from being fused
    #      to the first word of the answer.
    concat_ws(" ", "question", "answer_body").alias("content"),
)

In [0]:
# Materialise the Spark DataFrame as a pandas DataFrame; this collects every
# row of `qadf` into local memory, so the table must fit on this machine.
df = qadf.toPandas()

In [0]:
# Sanity check: preview the first two rows of the collected frame.
df.head(2)

In [0]:
# (row count, column count) of the collected frame.
df.shape

In [0]:
# Split the frame into the two inputs handed to the matrix builder below:
# `ids` as a plain Python list of str, `texts` as the pandas Series of content.
ids = (
    df["id"]
    .values
    .astype("str")
    .tolist()
)
texts = df["content"]

In [0]:
import ray_search

# Start a Ray cluster for the duration of the `with` block and run the full
# build-then-search pipeline inside it. The `ray_search` module object is
# passed through `runtime_env["py_modules"]` (presumably so it is available to
# the Ray workers - confirm against RayCluster).
with RayCluster(
  num_worker_nodes=8,
  num_cpus_per_node=8,
  num_gpus_per_node=1,
  runtime_env={"py_modules": [ray_search]},
) as r:
  print(f"Using: {r.max_parallel_workers()} workers")

  # Stage 1: configure how the search matrix is built from (ids, texts).
  # NOTE(review): the worker count is a literal 8 (same value as
  # num_worker_nodes above) - presumably one GPU worker per node; confirm.
  matrix_builder: SearchMatrixBuilder = (
    SearchMatrixBuilder()
    .with_content(ids, texts)
    .with_workers(8)
    .with_worker_memory(Memory.in_mb(512))
    .with_text_chunk_size(1024 * 10)
    .with_gpu()
    .with_vectorizer(
      Vectorizers.SentenceTransformerDense.with_encoder_batch_size(1024 + 512)
    )
  )

  # Stage 2: configure the search itself, using as many workers as the
  # cluster reports it can run in parallel and keeping the top 5 hits per
  # entity.
  search_builder: SearchBuilder = (
    SearchBuilder()
    .with_workers(int(r.max_parallel_workers()))
    .with_searcher(Searchers.FaissANNSearch)
    .with_worker_memory(Memory.in_mb(512))
    .with_text_chunk_size(96)
    .with_top_k_per_entity(5)
  )

  # Stage 3: wire both stages into a pipeline and produce the results object
  # via to_df_unnested().
  results = (
    SearchPipeline()
    .with_search_matrix_builder(matrix_builder)
    .with_search_builder(search_builder)
    .to_df_unnested()
  )

In [0]:
# Display the pipeline output produced by to_df_unnested() above.
# NOTE(review): if this is a large DataFrame, consider showing only a preview.
results