In [26]:
import sqlite3
import sqlite_vec

In [21]:
import struct

In [24]:
# Serialize a list of floats as raw float32 bytes
def serialize_f32(vector: list[float]) -> bytes:
    """Pack a list of floats into one contiguous bytes object.

    Each element is encoded with the struct format code ``f`` (a C float)
    in native byte order, so the format string is ``"<len(vector)>f"``.
    An empty list yields ``b""``.
    """
    float_count = len(vector)
    return struct.pack(f"{float_count}f", *vector)

In [27]:
# ":memory:" opens a private in-memory database: no file is created on disk,
# and the data disappears when the connection is closed.
db = sqlite3.connect(":memory:")

In [28]:
# Allow this connection to load SQLite extensions (disabled by default);
# this must be on before sqlite_vec.load(db) can load the extension.
db.enable_load_extension(True)

In [29]:
# Load the sqlite-vec extension into this connection. It provides the
# vec_version() SQL function and the vec0 virtual-table module used below.
sqlite_vec.load(db)

In [30]:
# Turn extension loading back off now that sqlite-vec is loaded. This is a
# security precaution (it stops further extensions being loaded through this
# connection), not a performance setting.
db.enable_load_extension(False)

In [31]:
# Sanity check that the extension is available: ask SQLite for its own
# version and for the sqlite-vec version in a single one-row query.
sqlite_version, vec_version = db.execute(
    "select sqlite_version(), vec_version()"
).fetchone()

In [32]:
# Record which SQLite / sqlite-vec builds produced the results in this notebook.
print(f"sqlite version: {sqlite_version}, vec_version: {vec_version}")

sqlite version: 3.41.2, vec_version: v0.1.1


In [33]:
# Dummy data: five (rowid, embedding) pairs. Row i holds the 4-dimensional
# vector [i/10, i/10, i/10, i/10], i.e. [0.1]*4 for rowid 1 up to [0.5]*4
# for rowid 5.
items = [(row_id, [row_id / 10] * 4) for row_id in range(1, 6)]

# Query vector; identical to the embedding stored under rowid 3.
query = [0.3] * 4

In [34]:
# Create the (still empty) vector table. vec0 is the virtual-table module
# provided by sqlite-vec; the table has a single column, `embedding`, declared
# as a 4-element float vector. The dummy items are inserted in a later cell.
db.execute("""CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4])""")

<sqlite3.Cursor at 0x121b7cfc0>

In [35]:
# Insert every (rowid, embedding) pair inside one transaction: `with db`
# commits when the block finishes and rolls back if an insert raises.
with db:
    for row_id, embedding in items:
        db.execute(
            """
                INSERT INTO vec_items(rowid, embedding) VALUES (?, ?)
            """,
            (row_id, serialize_f32(embedding)),
        )


In [36]:
# Query the vector table for the nearest neighbours of `query`.
# MATCH is given the serialized query vector; ORDER BY distance LIMIT 3 keeps
# the three closest rows, returned as (rowid, distance) tuples.
# NOTE(review): the printed distances look like Euclidean (L2) distances --
# confirm against the sqlite-vec documentation.

rows = db.execute(
    """
    SELECT rowid, distance FROM vec_items
    WHERE embedding MATCH ? ORDER BY distance LIMIT 3
    """,
    [serialize_f32(query)]
).fetchall()

In [37]:
# Rowid 3 should come first with distance 0.0, since `query` equals its embedding.
print(rows)

[(3, 0.0), (4, 0.19999998807907104), (2, 0.20000001788139343)]


In [48]:
# Let's wrap the query in a reusable function and add some asserts to check that it returns the expected rows

def get_top_k(k: int, query: list[float]) -> list[tuple[int, float]]:
    """Return the ``k`` rows of ``vec_items`` closest to ``query``.

    Args:
        k: maximum number of rows to return (bound to the SQL LIMIT).
        query: query vector as a list of floats; it is serialized with
            ``serialize_f32`` before being bound to the MATCH parameter.

    Returns:
        A list of ``(rowid, distance)`` rows ordered by ascending distance.

    Note:
        Uses the module-level ``db`` connection.
    """
    knn_sql = """
    SELECT rowid, distance FROM vec_items
    WHERE embedding MATCH ? ORDER BY distance LIMIT ?
    """
    bound_params = (serialize_f32(query), k)
    return db.execute(knn_sql, bound_params).fetchall()

In [51]:
# Each probe vector is identical to one stored embedding, so the single
# nearest neighbour must be that embedding's own rowid.
for expected_rowid in (2, 3, 4, 5, 1):
    probe = [expected_rowid / 10] * 4  # e.g. 2 -> [0.2, 0.2, 0.2, 0.2]
    assert get_top_k(1, probe)[0][0] == expected_rowid