Lets first explore sqlite-vec with some simulated data from scikit-learn.

In [10]:
from sklearn.datasets import make_blobs

In [22]:
n_feats = 400
X, y = make_blobs(n_samples=100_000, n_features=n_feats, centers=2000, random_state=42)

In [23]:
rm -rf local.sqlite

In [24]:
import sqlite3
import sqlite_vec

from typing import List
import struct


def serialize_f32(vector: List[float]) -> bytes:
    """serializes a list of floats into a compact "raw bytes" format"""
    return struct.pack("%sf" % len(vector), *vector)


db = sqlite3.connect("local.sqlite")
db.enable_load_extension(True)
sqlite_vec.load(db)
db.enable_load_extension(False)

sqlite_version, vec_version = db.execute(
    "select sqlite_version(), vec_version()"
).fetchone()
print(f"sqlite_version={sqlite_version}, vec_version={vec_version}")

sqlite_version=3.46.0, vec_version=v0.1.1


In [25]:
db.execute(f"CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[{n_feats}])")

<sqlite3.Cursor at 0x15f5f3e40>

In [26]:
db.execute('SELECT name from sqlite_master').fetchall()

[('vec_items',),
 ('vec_items_chunks',),
 ('sqlite_sequence',),
 ('vec_items_rowids',),
 ('vec_items_vector_chunks00',),
 ('sqlite_autoindex_vec_items_vector_chunks00_1',)]

In [27]:
with db:
    for i, item in enumerate([{"vector": x, "cluster": y[i]} for i, x in enumerate(X)]):
        db.execute(
            "INSERT INTO vec_items(rowid, embedding) VALUES (?, ?)",
            [i, serialize_f32(item["vector"])],
        )

In [42]:
db.row_factory = None

In [43]:
%%time

for i in range(10, 20):
    rows = db.execute(
        """
          SELECT
            rowid,
            distance
          FROM vec_items
          WHERE embedding MATCH ?
          ORDER BY distance
          LIMIT 25
        """,
        [serialize_f32(X[i])],
    ).fetchall()

# for i in rows:
#     print(i, y[i[0]])

CPU times: user 114 ms, sys: 339 ms, total: 453 ms
Wall time: 617 ms


In [45]:
query = X[0]

rows = db.execute(
    """
      SELECT
        rowid,
        distance
      FROM vec_items
      WHERE embedding MATCH ?
      ORDER BY distance
      LIMIT 25
    """,
    [serialize_f32(query)],
).fetchall()

[r for r in rows]

[(0, 0.0),
 (69166, 26.37257957458496),
 (28308, 26.60123634338379),
 (78824, 26.618066787719727),
 (62287, 26.834030151367188),
 (53374, 26.95392417907715),
 (22770, 27.02711296081543),
 (75846, 27.03619384765625),
 (99298, 27.093564987182617),
 (70053, 27.229528427124023),
 (71406, 27.265193939208984),
 (97010, 27.275630950927734),
 (88168, 27.286842346191406),
 (74349, 27.35616111755371),
 (63563, 27.394691467285156),
 (19480, 27.428003311157227),
 (25880, 27.4810733795166),
 (73005, 27.49576187133789),
 (87888, 27.54792594909668),
 (24674, 27.576269149780273),
 (69826, 27.629663467407227),
 (77301, 27.635019302368164),
 (2502, 27.780303955078125),
 (51854, 27.85087013244629),
 (55620, 27.900531768798828)]

In [46]:
db.execute(f"CREATE VIRTUAL TABLE bin_vec_items USING vec0(embedding bit[{n_feats}])")

<sqlite3.Cursor at 0x15fd0bc40>

In [47]:
db.execute("select vec_quantize_binary('[-0.73, -0.80, 0.12, -0.73, 0.79, -0.11, 0.23, 0.97]');").fetchall()

[(b'\xd4',)]

The single byte 0xd4 in hexadecimal is 11010100 in binary.

In [48]:
with db:
    for i, item in enumerate([{"vector": x, "cluster": y[i]} for i, x in enumerate(X)]):
        db.execute(
            "INSERT INTO bin_vec_items(rowid, embedding) VALUES (?, vec_quantize_binary(?))",
            [i, serialize_f32(item["vector"])],
        )

In [49]:
%%timeit

query = X[0]

rows = db.execute(
    """
      SELECT
        rowid,
        distance
      FROM bin_vec_items
      WHERE embedding MATCH vec_quantize_binary(?)
      ORDER BY distance
      LIMIT 200
    """,
    [serialize_f32(query)],
).fetchall()

rows

19.2 ms ± 612 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [14]:
limit = 1000

rows_orig = db.execute(
    f"""
      SELECT
        rowid,
        distance
      FROM vec_items
      WHERE embedding MATCH ?
      ORDER BY distance
      LIMIT {limit}
    """,
    [serialize_f32(query)],
).fetchall()

rows_bin = db.execute(
    f"""
      SELECT
        rowid,
        distance
      FROM bin_vec_items
      WHERE embedding MATCH vec_quantize_binary(?)
      ORDER BY distance
      LIMIT {limit}
    """,
    [serialize_f32(query)],
).fetchall()

In [15]:
set_orig, set_bin = set(i[0] for i in rows_orig), set(i[0] for i in rows_bin)
overlap = set_orig.intersection(set_bin)
len(overlap)

237

This overlap feels very very small! But then again ... think about how this dataset was generated!

Lets now move on to a dataset that actually uses embeddings!

In [53]:
from datasets import load_dataset

dataset = load_dataset("m3hrdadfi/recipe_nlg_lite")

  from .autonotebook import tqdm as notebook_tqdm
Repo card metadata block was not found. Setting CardData to empty.


In [54]:
recipe_names = dataset['train']['name']

In [55]:
recipe_names

['pork chop noodle soup',
 '5 ingredient almond cake with fresh berries',
 'shrimp cakes',
 'chili roasted okra',
 'slow cooker chicken chili',
 'fall superfood beef stew',
 'rotisserie turkey breast tenderloin with pesto',
 "j . alexander's mac and cheese",
 'spinach stuffed pork roast',
 'lasagna soup',
 'mexican wedding cookies',
 'chicken and rice jambalaya',
 "lipton's carne asada ole",
 'oven braised beef brisket',
 'sweet and sour eggplant',
 'sausage spinach quiche',
 'red kidney beans dumpling in rich tomato gravy',
 'israeli style fish cakes',
 'rum raisin cake',
 'shrimp and fried tofu rad nah',
 'rum and coffee brisket',
 'pb and j waffle sandwich',
 'shredded carrot, arugula, and wild rice salad',
 'slow cooker balsamic red wine pot roast',
 'peach and prosciutto flatbread',
 'protein egg and quinoa salad jars',
 'eye opening spicy bloody mary ham steaks',
 'chili a la jimmy fallon',
 'oyakudon',
 "maddy's favorite brownie sundaes",
 'gluten free matzo balls potato knaidel

In [56]:
from sentence_transformers import SentenceTransformer

tfm_base = SentenceTransformer("all-MiniLM-L6-v2")



In [57]:
X_tfm = tfm_base.encode(recipe_names)

In [58]:
n_feats = X_tfm.shape[1]
db.execute(f"CREATE VIRTUAL TABLE vec_sents USING vec0(embedding float[{n_feats}])")
db.execute(f"CREATE VIRTUAL TABLE bin_vec_sents USING vec0(embedding bit[{n_feats}])")

<sqlite3.Cursor at 0x29426cbc0>

In [59]:
with db:
    for i, item in enumerate([{"vector": x} for i, x in enumerate(X_tfm)]):
        db.execute(
            "INSERT INTO vec_sents(rowid, embedding) VALUES (?, ?)",
            [i, serialize_f32(item["vector"])],
        )
        db.execute(
            "INSERT INTO bin_vec_sents(rowid, embedding) VALUES (?, vec_quantize_binary(?))",
            [i, serialize_f32(item["vector"])],
        )

In [62]:
limit = 1000
query = tfm_base.encode(["I would like to have some vegetable soup"])[0]

rows_orig = db.execute(
    f"""
      SELECT
        rowid,
        distance
      FROM vec_sents
      WHERE embedding MATCH ?
      ORDER BY distance
      LIMIT {limit}
    """,
    [serialize_f32(query)],
).fetchall()

rows_bin = db.execute(
    f"""
      SELECT
        rowid,
        distance
      FROM bin_vec_sents
      WHERE embedding MATCH vec_quantize_binary(?)
      ORDER BY distance
      LIMIT {limit}
    """,
    [serialize_f32(query)],
).fetchall()

In [63]:
set_orig, set_bin = set(i[0] for i in rows_orig[1:]), set(i[0] for i in rows_bin[1:])
overlap = set_orig.intersection(set_bin)
len(overlap)

733

Thats a whole lot better!

In [64]:
import polars as pl

df1 = pl.DataFrame([{'i': i[0], 'dist_orig': i[1]} for i in rows_orig if i[0] in overlap])
df2 = pl.DataFrame([{'i': i[0], 'dist_bin': i[1]} for i in rows_bin if i[0] in overlap])

In [65]:
import altair as alt

pltr = df1.join(df2, left_on='i', right_on='i')

alt.Chart(pltr).mark_point().encode(x='dist_orig', y='dist_bin').interactive()

In [66]:
import numpy as np

pltr = (
    pl.DataFrame({'hits':[item in set_bin for item in set_orig]})
    .with_columns(
        cs=pl.col('hits').cum_sum(), 
        r=pl.lit(1)
    ).with_columns(
        k=pl.col("r").cum_sum()
    ).with_columns(
        precision_at_k=pl.col('cs')/pl.col('k')
    )
)

In [67]:
alt.Chart(pltr).mark_line().encode(x='k', y='precision_at_k')

There is also another alternative, you can choose to store float16 values instead as well. This should yield a 2x improvement to storage and I would be surpised if it has a real downside to retreival performance.

## One last ... super cool demo

In [68]:
import pandas as pd

df = pd.read_csv("car_prices.csv").dropna()
df.head(3)

Unnamed: 0,year,make,model,trim,body,transmission,vin,state,condition,odometer,color,interior,seller,mmr,sellingprice,saledate
0,2015,Kia,Sorento,LX,SUV,automatic,5xyktca69fg566472,ca,5.0,16639.0,white,black,kia motors america inc,20500.0,21500.0,Tue Dec 16 2014 12:30:00 GMT-0800 (PST)
1,2015,Kia,Sorento,LX,SUV,automatic,5xyktca69fg561319,ca,5.0,9393.0,white,beige,kia motors america inc,20800.0,21500.0,Tue Dec 16 2014 12:30:00 GMT-0800 (PST)
2,2014,BMW,3 Series,328i SULEV,Sedan,automatic,wba3c1c51ek116351,ca,45.0,1331.0,gray,black,financial services remarketing (lease),31900.0,30000.0,Thu Jan 15 2015 04:30:00 GMT-0800 (PST)


In [69]:
from sklearn.pipeline import make_union, make_pipeline
from sklearn.compose import make_column_selector
from sklearn.preprocessing import StandardScaler, OneHotEncoder, FunctionTransformer
from sklearn.feature_extraction.text import HashingVectorizer, CountVectorizer
from skrub import SelectCols

pipe = make_union(
    SelectCols(["mmr"]),
    make_pipeline(
        SelectCols(["year", "condition", "odometer"]),
        StandardScaler()
    ),
    make_pipeline(
        SelectCols(["make", "model", "body", "transmission", "color"]),
        OneHotEncoder(sparse_output=False)
    )
)

X_demo = pipe.fit_transform(df)

In [70]:
X_demo.shape

(472325, 932)

In [71]:
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.linear_model import Ridge


class RidgeEmbedder(BaseEstimator, TransformerMixin):
    def __init__(self):
        pass

    def fit(self, X, y):
        self.mod_ = Ridge(fit_intercept=False).fit(X, y)
        return self

    def transform(self, X):
        return X * self.mod_.coef_

In [72]:
pipe = make_pipeline(
    make_union(
        SelectCols(["mmr"]),
        make_pipeline(
            SelectCols(["year", "condition", "odometer"]),
            StandardScaler()
        ),
        make_pipeline(
            SelectCols(["make", "model", "body", "transmission", "color"]),
            OneHotEncoder(sparse_output=False)
        )
    ),
    RidgeEmbedder()
)

X_demo = pipe.fit_transform(df, df["sellingprice"])

In [73]:
X_demo.shape

(472325, 932)

In [75]:
db.execute(f"CREATE VIRTUAL TABLE carvec USING vec0(embedding float[{932}])")

<sqlite3.Cursor at 0x37d8495c0>

In [76]:
with db:
    for i, item in enumerate([{"vector": x} for i, x in enumerate(X_demo)]):
        db.execute(
            "INSERT INTO carvec(rowid, embedding) VALUES (?, ?)",
            [i, serialize_f32(item["vector"])],
        )

In [77]:
%%time

rows_orig = db.execute(
    f"""
      SELECT
        rowid,
        distance
      FROM carvec
      WHERE embedding MATCH ?
      ORDER BY distance
      LIMIT {limit}
    """,
    [serialize_f32(X_demo[5])]).fetchall()

CPU times: user 529 ms, sys: 604 ms, total: 1.13 s
Wall time: 1.95 s


So we would like to have some more speed, so we will need to wait for the indices to be implemented. 
That said ... this is cool!

Work for later, can we batch these predictions?