Add an aggregate function for querying a subset of rows #3
I'm not sure how to actually implement this. The FAISS docs aren't great - this looks like the most relevant section: https://faiss.ai/cpp_api/struct/structfaiss_1_1Index.html#structfaiss_1_1Index

I can't see a way to tell the index to only search against a subset of the rows. So maybe I need to calculate scores for the entire table and then filter? But that's ignoring one of the main reasons to use a FAISS index in the first place: optimized lookup of the first k matches.
I wonder what the overhead would be of constructing a brand new FAISS index inside the aggregation class, only to throw it away at the end of the query?
Wrote this. Haven't fully confirmed it works correctly yet, but it seems to? Not very fast though - took 440ms against my blog, though that's still faster than the 1.3s it took for a brute force Python cosine similarity.

```python
import json

import faiss
import numpy as np


class FaissAgg:
    def __init__(self):
        self.ids = []
        self.embeddings = []
        self.compare_embedding = None
        self.k = None
        self.first = True

    def step(self, id, embedding, compare_embedding, k):
        if self.first:
            # compare_embedding and k are the same for every row,
            # so they only need decoding once
            self.first = False
            self.compare_embedding = decode(compare_embedding)
            self.k = k
        self.ids.append(id)
        self.embeddings.append(decode(embedding))

    def finalize(self):
        # Build a throwaway FAISS index over just the rows this
        # aggregation saw, search it, return the top k IDs as JSON
        index = faiss.IndexFlatL2(len(self.compare_embedding))
        index.add(np.array(self.embeddings))
        _, I = index.search(np.array([self.compare_embedding]), self.k)
        return json.dumps([self.ids[i] for i in I[0]])


# conn is the sqlite3 connection
conn.create_aggregate("faiss_agg", 4, FaissAgg)
```

Query I tested with (might not be correct):

```sql
with q as (
  select
    embedding as e
  from
    blog.embeddings
  order by
    id desc
  limit
    1
),
id_string as (
  select
    faiss_agg(embeddings.id, embedding, q.e, 10) as s
  from
    embeddings,
    q
),
ids as (
  select
    value
  from
    json_each(id_string.s),
    id_string
)
select
  *
from
  simonwillisonblog.blog_entry
where
  id in (
    select
      value
    from
      ids
  )
```
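The `decode()` helper used in the Python snippet above isn't shown in this thread. A minimal sketch, assuming the embeddings are stored as packed little-endian float32 blobs (both the name and the storage format are assumptions, not confirmed by the issue):

```python
import numpy as np


def decode(blob):
    # Assumed format: a raw buffer of little-endian float32 values.
    # Returns a float32 numpy array, which is what faiss.IndexFlatL2
    # expects from index.add() and index.search().
    return np.frombuffer(blob, dtype="<f4")
```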
But... this query, which filters for blog entries created after 2019, returns in 129ms:

```sql
with q as (
  select
    embedding as e
  from
    blog.embeddings
  order by
    id desc
  limit
    1
),
id_string as (
  select
    faiss_agg(embeddings.id, embedding, q.e, 10) as s
  from
    embeddings,
    q
  where
    embeddings.id in (
      select id from blog_entry where created > '2019'
    )
),
ids as (
  select
    value
  from
    json_each(id_string.s),
    id_string
)
select
  *
from
  simonwillisonblog.blog_entry
where
  id in (
    select
      value
    from
      ids
  )
```

Which is the goal of this function - it should take less time to run if you first filter the embeddings you are calculating against. So this is actually a pretty good result.
Should do a `faiss_agg_with_scores()` variant as well.
```sql
with last_500 as (
  select
    id,
    embedding
  from
    blog_entry_embeddings
  order by
    id desc
  limit
    500
),
ids_and_scores as (
  select
    faiss_agg_with_scores(
      id,
      embedding,
      (
        select
          embedding
        from
          blog_entry_embeddings
        where
          id = :id
      ),
      10
    ) as s
  from
    last_500
),
results as (
  select
    json_extract(value, '$[0]') as id,
    json_extract(value, '$[1]') as score
  from
    json_each(ids_and_scores.s),
    ids_and_scores
)
select
  results.score,
  blog_entry.id,
  blog_entry.title,
  blog_entry.created
from
  results
  join blog_entry on results.id = blog_entry.id
```

It grabs the embeddings from the most recent 500 entries and then finds the 10 items that are most similar to the provided item ID.
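The `faiss_agg_with_scores` implementation isn't shown in this thread, but the `json_extract(value, '$[0]')` / `'$[1]'` calls above imply its `finalize()` returns a JSON array of `[id, score]` pairs. A minimal sketch, reusing the `FaissAgg` class from earlier (the subclass name and the `float()` conversion are assumptions):

```python
class FaissAggWithScores(FaissAgg):
    def finalize(self):
        index = faiss.IndexFlatL2(len(self.compare_embedding))
        index.add(np.array(self.embeddings))
        # search() returns distances D as well as index positions I
        D, I = index.search(np.array([self.compare_embedding]), self.k)
        # float() converts numpy float32 scores so json.dumps accepts them
        return json.dumps(
            [[self.ids[i], float(score)] for i, score in zip(I[0], D[0])]
        )


conn.create_aggregate("faiss_agg_with_scores", 4, FaissAggWithScores)
```

Note that `IndexFlatL2` scores are distances, so lower means more similar.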
```sql
with last_500 as (
  select
    id,
    embedding
  from
    blog_entry_embeddings
  order by
    id desc
  limit
    500
),
faiss as (
  select
    faiss_agg(
      id,
      embedding,
      (
        select
          embedding
        from
          blog_entry_embeddings
        where
          id = :id
      ),
      10
    ) as results
  from
    last_500
),
ids as (
  select
    value as id
  from
    json_each(faiss.results),
    faiss
)
select
  blog_entry.id,
  blog_entry.title,
  blog_entry.created
from
  ids
  join blog_entry on ids.id = blog_entry.id
```
This article is a useful explanation of why the underlying problem here is hard to solve: https://www.pinecone.io/learn/vector-search-filtering/
Thanks for making this! Noob question: what about a couple of values of `:id`? For example, since it's not too efficient to run it for every `:idX` needle in a Python loop, could one pass those needles through in the final result?
@itissid I'm afraid I don't understand the question there - can you expand?
In writing this up I realized that one thing I never tested here was whether there is any benefit to building a scratch FAISS index, as opposed to just filtering down to ~1,000 results and then doing a brute force score on every remaining result. I would expect that brute force scoring would still be a lot less expensive than a FAISS index build, since the index build has to visit every item anyway. So this whole existing approach may turn out to be unnecessary.
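A sketch of that brute force alternative, written as a subclass of the `FaissAgg` class above (the class and function names here are invented for illustration; `IndexFlatL2` ranks by L2 distance, which this reproduces with NumPy):

```python
class FaissAggBruteForce(FaissAgg):
    def finalize(self):
        # One vectorized pass: distance from the compare embedding to
        # every collected embedding, then keep the k smallest
        matrix = np.array(self.embeddings)
        distances = np.linalg.norm(matrix - self.compare_embedding, axis=1)
        top_k = np.argsort(distances)[: self.k]
        return json.dumps([self.ids[i] for i in top_k])


conn.create_aggregate("faiss_agg_brute_force", 4, FaissAggBruteForce)
```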
So the logic for this statement is: for, say, "small" doc set sizes like 1,000, 5,000 and 10,000, and for say 10 search terms, we have to do 10×1,000, 10×5,000 or 10×10,000 distance calculations and just sort them? Sounds reasonable. Do you (still) see any special use cases where building the scratch FAISS index would win out?
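For what it's worth, the multiple-needles case doesn't need a Python loop either - NumPy can compute every needle-to-document distance in one broadcast operation. A standalone sketch with made-up array shapes:

```python
import numpy as np

# Hypothetical sizes: 10 needles, 5,000 documents, 384 dimensions each
needles = np.random.rand(10, 384).astype(np.float32)
docs = np.random.rand(5000, 384).astype(np.float32)

# Broadcasting (10, 1, 384) against (1, 5000, 384) gives a (10, 5000)
# matrix of L2 distances - one row of distances per needle
distances = np.linalg.norm(needles[:, None, :] - docs[None, :, :], axis=2)

# For each needle, the indexes of its 10 nearest documents
k = 10
top_k = np.argsort(distances, axis=1)[:, :k]
```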
The current `datasette_faiss()` function always operates against the entire table - but what if you want to calculate the best scores for just a subset of rows, based on a join or filter? I think a solution for this could be an aggregate function - something like this:
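A sketch of what that aggregate call could look like, reconstructed from the working queries later in this thread (hypothetical, not the original example):

```sql
select faiss_agg(id, embedding, :compare_embedding, 10) as matching_ids
from embeddings
where id in (select id from blog_entry where created > '2019')
```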