Add an aggregate function for querying a subset of rows #3

Closed
simonw opened this issue Jan 16, 2023 · 13 comments
Labels: enhancement, help wanted

simonw commented Jan 16, 2023

The current datasette_faiss() function always operates against the entire table - but what if you want to calculate the best scores for just a subset of rows, based on a join or filter?

I think a solution for this could be an aggregate function - something like this:

with filtered as (
  select id from blog_entry where created_at > '2019'
),
top_n_ids as (
  select faiss_agg(id, 'blog', 'blog_entry_embeddings', :embedding, 10) as faiss
  from filtered
),
top_ids as (
  select value as id from json_each(faiss), top_n_ids
)
select * from blog_entry where id in (select id from top_ids)
simonw added the enhancement label Jan 16, 2023

simonw commented Jan 16, 2023

I'm not sure how to actually implement this.

The FAISS docs aren't great - this looks like the most relevant section: https://faiss.ai/cpp_api/struct/structfaiss_1_1Index.html#structfaiss_1_1Index

I can't see a way to tell the .search() method to only consider a subset of IDs.

So maybe I need to calculate scores for the entire table and then filter? But that's ignoring one of the main reasons to use a FAISS index in the first place: optimized lookup of the first k matches.
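
A minimal sketch of the ID selector approach, assuming a recent FAISS release (1.7.3 or later, where SearchParameters and IDSelectorBatch are available; the API may differ across versions):

import faiss
import numpy as np

d = 64
index = faiss.IndexFlatL2(d)
index.add(np.random.rand(1000, d).astype("float32"))  # rows get ids 0..999

# Restrict the search to an allow-list of row ids
# (assumes faiss >= 1.7.3, which added search-time selectors)
allowed = np.array([3, 17, 42, 256], dtype="int64")
sel = faiss.IDSelectorBatch(allowed)
query = np.random.rand(1, d).astype("float32")
D, I = index.search(query, 3, params=faiss.SearchParameters(sel=sel))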

simonw added the help wanted label Jan 16, 2023

simonw commented Jan 16, 2023

I wonder what the overhead would be of constructing a brand new FAISS index inside the aggregation class, only to throw it away at the end of the query?


simonw commented Jan 16, 2023

Wrote this. Haven't fully confirmed it works correctly yet, but it seems to? Not very fast though - took 440ms against my blog, though that's still faster than the 1.3s it took for a brute force Python cosine similarity.

import json

import faiss
import numpy as np


class FaissAgg:
    def __init__(self):
        self.ids = []
        self.embeddings = []
        self.compare_embedding = None
        self.k = None
        self.first = True

    def step(self, id, embedding, compare_embedding, k):
        # compare_embedding and k are the same for every row, so only
        # decode them once, on the first call
        if self.first:
            self.first = False
            self.compare_embedding = decode(compare_embedding)
            self.k = k
        self.ids.append(id)
        self.embeddings.append(decode(embedding))

    def finalize(self):
        # Build a throwaway flat L2 index over just the aggregated rows
        index = faiss.IndexFlatL2(len(self.compare_embedding))
        # FAISS expects float32 input
        index.add(np.array(self.embeddings, dtype="float32"))
        _, I = index.search(
            np.array([self.compare_embedding], dtype="float32"), self.k
        )
        # Return the IDs of the k nearest rows as a JSON array
        return json.dumps([self.ids[i] for i in I[0]])


conn.create_aggregate("faiss_agg", 4, FaissAgg)
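
The decode() helper isn't shown here; a plausible sketch, assuming embeddings are stored as blobs of packed 32-bit floats (the package's actual helper may differ):

import struct

def decode(blob):
    # Unpack a binary blob of packed 32-bit floats into a tuple of floats
    return struct.unpack("f" * (len(blob) // 4), blob)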

Query I tested with (might not be correct):

with q as (
  select
    embedding as e
  from
    blog.embeddings
  order by
    id desc
  limit
    1
), id_string as (
  select
    faiss_agg(embeddings.id, embedding, q.e, 10) as s
  from
    embeddings,
    q
),
ids as (
  select
    value
  from
    json_each(id_string.s),
    id_string
)
select
  *
from
  simonwillisonblog.blog_entry
where
  id in (
    select
      value
    from
      ids
  )


simonw commented Jan 16, 2023

But... this query, which filters for blog entries created after 2019, returns in 129ms:

with q as (
  select
    embedding as e
  from
    blog.embeddings
  order by
    id desc
  limit
    1
), id_string as (
  select
    faiss_agg(embeddings.id, embedding, q.e, 10) as s
  from
    embeddings,
    q
  where embeddings.id in (
    select id from blog_entry where created > '2019'
  )
),
ids as (
  select
    value
  from
    json_each(id_string.s),
    id_string
)
select
  *
from
  simonwillisonblog.blog_entry
where
  id in (
    select
      value
    from
      ids
  )

Which is the goal of this function: it should take less time to run if you first filter down the set of embeddings it calculates against.

So this is actually a pretty good result.


simonw commented Jan 16, 2023

Should do a faiss_agg_with_scores() version too.
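
A minimal sketch of what that variant could look like (an assumption: the real implementation landed in the closing commit and may differ). It is the same aggregation, but finalize() returns [id, score] pairs, matching the json_extract(value, '$[0]') / json_extract(value, '$[1]') access pattern in the example query below:

class FaissAggWithScores(FaissAgg):
    def finalize(self):
        index = faiss.IndexFlatL2(len(self.compare_embedding))
        index.add(np.array(self.embeddings, dtype="float32"))
        D, I = index.search(
            np.array([self.compare_embedding], dtype="float32"), self.k
        )
        # Pair each matched ID with its L2 distance (lower = more similar)
        return json.dumps(
            [[self.ids[i], float(d)] for i, d in zip(I[0], D[0])]
        )

conn.create_aggregate("faiss_agg_with_scores", 4, FaissAggWithScores)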

simonw closed this as completed in b90d8d1 Jan 19, 2023
simonw added a commit that referenced this issue Jan 19, 2023
simonw added a commit to simonw/simonwillisonblog-backup that referenced this issue Jan 20, 2023

simonw commented Jan 20, 2023

Example query: https://datasette.simonwillison.net/simonwillisonblog?sql=with+last_500+as+%28%0D%0A++select%0D%0A++++id%2C%0D%0A++++embedding%0D%0A++from%0D%0A++++blog_entry_embeddings%0D%0A++order+by%0D%0A++++id+desc%0D%0A++limit%0D%0A++++500%0D%0A%29%2C+ids_and_scores+as+%28%0D%0A++select%0D%0A++++faiss_agg_with_scores%28%0D%0A++++++id%2C%0D%0A++++++embedding%2C%0D%0A++++++%28%0D%0A++++++++select%0D%0A++++++++++embedding%0D%0A++++++++from%0D%0A++++++++++blog_entry_embeddings%0D%0A++++++++where%0D%0A++++++++++id+%3D+%3Aid%0D%0A++++++%29%2C+10%0D%0A++++%29+as+s%0D%0A++from%0D%0A++++last_500%0D%0A%29%2C%0D%0Aresults+as+%28%0D%0A++select%0D%0A++++json_extract%28value%2C+%27%24%5B0%5D%27%29+as+id%2C%0D%0A++++json_extract%28value%2C+%27%24%5B1%5D%27%29+as+score%0D%0A++from%0D%0A++++json_each%28ids_and_scores.s%29%2C%0D%0A++++ids_and_scores%0D%0A%29%0D%0Aselect%0D%0A++results.score%2C%0D%0A++blog_entry.id%2C%0D%0A++blog_entry.title%2C%0D%0A++blog_entry.created%0D%0Afrom%0D%0A++results%0D%0A++join+blog_entry+on+results.id+%3D+blog_entry.id&id=8214

with last_500 as (
  select
    id,
    embedding
  from
    blog_entry_embeddings
  order by
    id desc
  limit
    500
), ids_and_scores as (
  select
    faiss_agg_with_scores(
      id,
      embedding,
      (
        select
          embedding
        from
          blog_entry_embeddings
        where
          id = :id
      ), 10
    ) as s
  from
    last_500
),
results as (
  select
    json_extract(value, '$[0]') as id,
    json_extract(value, '$[1]') as score
  from
    json_each(ids_and_scores.s),
    ids_and_scores
)
select
  results.score,
  blog_entry.id,
  blog_entry.title,
  blog_entry.created
from
  results
  join blog_entry on results.id = blog_entry.id

It grabs the embeddings from the most recent 500 entries and then finds the 10 items that are most similar to the provided item ID.

score | id | title | created
0.0 | 8214 | A new AI game: Give me ideas for crimes to do | 2022-12-04T15:11:31+00:00
0.23480114340782166 | 8215 | AI assisted learning: Learning Rust with ChatGPT, Copilot and Advent of Code | 2022-12-05T21:11:08+00:00
0.24925395846366882 | 8192 | You can't solve AI security problems with more AI | 2022-09-17T22:57:44+00:00
0.27296721935272217 | 8191 | I don't know how to solve prompt injection | 2022-09-16T16:28:53+00:00
0.28459107875823975 | 8189 | Prompt injection attacks against GPT-3 | 2022-09-12T22:20:19+00:00
0.2949942350387573 | 8170 | How to use the GPT-3 language model | 2022-06-05T17:28:33+00:00
0.30468830466270447 | 8197 | Is the AI spell-casting metaphor harmful or helpful? | 2022-10-05T20:40:16+00:00
0.3121495544910431 | 8169 | A Datasette tutorial written by GPT-3 | 2022-05-31T22:54:36+00:00
0.3125951290130615 | 8178 | Using GPT-3 to explain how code works | 2022-07-09T15:19:29+00:00
0.33426302671432495 | 8176 | First impressions of DALL-E, generating images from text | 2022-06-23T23:05:56+00:00


simonw commented Feb 10, 2023

This article is a useful explanation of why the underlying problem here is hard to solve: https://www.pinecone.io/learn/vector-search-filtering/


itissid commented Jul 10, 2023

Thanks for making this! Noob question: is there a way to run this for several values of :id at once? It's not very efficient to run the query for each :idX needle in a Python loop. Could those needles be passed through into the final result?


simonw commented Sep 5, 2023

@itissid I'm afraid I don't understand the question there, can you expand?


simonw commented Sep 5, 2023

In writing this up, I realized that one thing I never tested here was whether there is any benefit to building a scratch FAISS index, as opposed to just filtering to ~1,000 results and then doing a brute force score on every remaining result.

I would expect that a brute force score would still be a lot less expensive than a FAISS index build, since the FAISS index build has to visit every item anyway.

So this whole existing faiss_agg() function may not be a useful optimization at all!
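
A rough sketch of that brute force alternative, for comparison (assumptions: ids and embeddings hold the ~1,000 filtered rows, embeddings is an (n, d) float32 NumPy array, and k is smaller than n):

import numpy as np

def brute_force_top_k(ids, embeddings, query, k=10):
    # Squared L2 distance from the query to every remaining row
    distances = ((embeddings - query) ** 2).sum(axis=1)
    # argpartition selects the k smallest without fully sorting all n rows
    top = np.argpartition(distances, k)[:k]
    top = top[np.argsort(distances[top])]
    return [(ids[i], float(distances[i])) for i in top]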


itissid commented Sep 5, 2023

So this whole existing faiss_agg() function may not be a useful optimization at all!

So the logic for this statement is: for "small" doc-set sizes like 1,000, 5,000 or 10,000, and say 10 search terms, we have to do 10×1,000, 10×5,000 or 10×10,000 distance calculations and just sort them? Sounds reasonable.

Do you still see any special use cases where faiss_agg becomes useful? I could imagine that if the doc set is the result of a query that returns the same set over and over (say a geospatial query for a small fixed area of NYC), but with slightly different queries within that area, faiss_agg may be useful.
