In [None]:
%pip install chromadb
%pip install onnxruntime-gpu
%pip install sentence-transformers

In [None]:
import onnxruntime
print(onnxruntime.get_available_providers())

In [8]:
import chromadb
import sqlite3
import pandas as pd
import time
import sqlite3
import pandas as pd
import time
import multiprocessing as mp
import os

In [None]:
from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2

cuda_ef = ONNXMiniLM_L6_V2(preferred_providers=['CUDAExecutionProvider'])
ef = ONNXMiniLM_L6_V2()

In [9]:
import onnxruntime
print(onnxruntime.get_available_providers())

['CoreMLExecutionProvider', 'AzureExecutionProvider', 'CPUExecutionProvider']


In [11]:
db_path = os.getcwd() + "/CompFood.sqlite"
print(db_path)
conn = sqlite3.connect(db_path)
batch_size = 1000

/Users/scmitton/Documents/Dev/protein-count/fast-api/db/CompFood.sqlite


In [12]:
print(conn.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall())

[('usda_non_branded_column',), ('usda_branded_column',), ('menustat',)]


In [13]:
def producer(batch_size, queue):
    '''
    Iterates throught the rows of the usda_non_branded_column, usda_branded_column, and menustat table
    and adds them to the queue in batches of size batch_size
    '''
    documents = []
    ids = []
    metadatas = []

    # -------------------------- Embed Non-Branded Foods ------------------------- #

    df = pd.read_sql_query("SELECT * FROM usda_non_branded_column;", conn)

    # Process and transform
    df = df.drop_duplicates(subset=['fdc_id']).dropna(subset=['description'])
    df['fdc_id'] = df['fdc_id'].astype(str) # Change id to str
    meta_columns = [
        'protein_amount', 'energy_amount', 'carb_amount',
        'fat_amount', 'serving_size'
    ]
    df = df.dropna(subset=meta_columns) # Clear rows with null values in any meta_column

    for _, row in df.iterrows():
        documents.append(row['description'])
        ids.append(row['fdc_id'])
        metadatas.append(row[meta_columns].to_dict())
        # Empty buffer
        if len(ids) > batch_size:
            queue.put((ids, documents, metadatas))
            ids, documents, metadatas = [], [], []

    # ---------------------------- Embed Branded Foods --------------------------- #

    df = pd.read_sql_query("SELECT * FROM usda_branded_column;", conn)

    # Process and transform
    df = df.drop_duplicates(subset=['fdc_id']).dropna(subset=['description'])
    df['fdc_id'] = df['fdc_id'].astype(str) # Change id to str
    meta_columns += ['brand_name']
    df = df.dropna(subset=meta_columns) # Clear rows with null values in any meta_column

    for _, row in df.iterrows():
        documents.append(row['description'])
        ids.append(row['fdc_id'])
        metadatas.append(row[meta_columns].to_dict())
        # Empty buffer
        if len(ids) > batch_size:
            queue.put((ids, documents, metadatas))
            ids, documents, metadatas = [], [], []

    if len(ids) > 0:
        queue.put((ids, documents, metadatas))
        ids, documents, metadatas = [], [], []

    # ---------------------------- Embed Menu Foods ---------------------------- #

    df = pd.read_sql_query("SELECT * FROM menustat;", conn)
    df = df.drop_duplicates(subset=['menustat_id']).dropna(
        subset=['description', 'restaurant'])
    df['menustat_id'] = df['menustat_id'].astype(str) # Change id to str
    df = df[df['food_category'] != 'Beverages'] # Filter beverages
    df['description'] = df['restaurant'] + ' ' + df['item_description']

    columns = ['protein', 'energy', 'carbs', 'fat']
    for col in columns:
        amount_column = col.strip('s') + '_amount'
        grams_column = col + '_per_100g'

        # Filter nulls
        df = df.dropna(subset=[grams_column, 'serving_size'])

        # Convert to int
        df[grams_column] = df[grams_column].astype(float)
        df['serving_size'] = df['serving_size'].astype(float)

        # Calculate amount
        df[amount_column] = df[grams_column] * df['serving_size'] / 100.0

    meta_columns = ['protein_amount', 'energy_amount', 'carb_amount',
                    'fat_amount', 'serving_size', 'restaurant']

    # Clear rows with null values in any meta_column
    df = df.dropna(subset=meta_columns)

    for _, row in df.iterrows():
        documents.append(row['description'])
        ids.append(row['menustat_id'])
        metadatas.append(row[meta_columns].to_dict())
        # Empty buffer
        if len(ids) > batch_size:
            queue.put((ids, documents, metadatas))
            ids, documents, metadatas = [], [], []

    if len(ids) > 0:
        queue.put((ids, documents, metadatas))

def consumer(use_cuda, queue):
    client = chromadb.PersistentClient(path='./chroma')
    device = "cuda" if use_cuda else "cpu"
    print(f"Using device: {device}")
    current_batch = 0

    embedding_function = cuda_ef if use_cuda else ef

    collection = client.get_collection(
        name='foods', embedding_function=embedding_function)

    while True:
        # Check for items in queue, this process blocks until queue has items to process.
        if queue.empty():
            if current_batch > 1:
                print('\r', end='', flush=True)
                print('Consumer currently blocked', end='', flush=True)
        batch = queue.get()

        if batch is None:
            break

        if current_batch > 1:
            print(f"\rProcessing batch {current_batch} of {len(batch[0])} items", end="", flush=True)
        else:
            print(f"Processing batch {current_batch} of {len(batch[0])} items", end="", flush=True)

        collection.add(
            ids=batch[0],
            documents=batch[1],
            metadatas=batch[2]
        )
        current_batch += 1


In [15]:
client = chromadb.PersistentClient(path='./chroma')
use_cuda = onnxruntime.get_device() == "CUDA" # Check if cuda is available
print(f"Using device: {onnxruntime.get_device()}")
embedding_function = cuda_ef if use_cuda else ef

try:
    collection = client.get_collection(name="foods")
    if collection:
        print("Deleting collection...")
        client.delete_collection(name="foods")
except Exception as e:
    print(e)

print("Creating collection...")
client.create_collection(
    name="foods", embedding_function=embedding_function)

# For cleaner reloading, delete and re-create the collections
queue = mp.Queue()

# Create producer and consumer processes.
producer_process = mp.Process(target=producer, args=(batch_size, queue))
consumer_process = mp.Process(target=consumer, args=(True, queue))

# Start processes
print("Starting producer...")
producer_process.start()
print("Starting consumer...")
consumer_process.start()

tik = time.time()

# Wait for producer to finish producing
producer_process.join()

# Signal consumer to stop consuming by putting None into the queue. Need 2 None's to stop 2 consumers.
queue.put(None)

# Wait for consumer to finish consuming
consumer_process.join()

tok = time.time()
print('\nFinished!')
print(f"Time taken: {tok - tik} seconds")

Using device: CPU
Deleting collection...
Creating collection...
Starting producer...
Starting consumer...

Finished!
Time taken: 0.04242110252380371 seconds


Traceback (most recent call last):
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "<string>", line 1, in <module>
  File "/Users/scmitton/.pyenv/versions/3.11.8/lib/python3.11/multiprocessing/spawn.py", line 122, in spawn_main
  File "/Users/scmitton/.pyenv/versions/3.11.8/lib/python3.11/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
    exitcode = _main(fd, parent_sentinel)
                            ^ ^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^  File "/Users/scmitton/.pyenv/versions/3.11.8/lib/python3.11/multiprocessing/spawn.py", line 132, in _main
^^^^
  File "/Users/scmitton/.pyenv/versions/3.11.8/lib/python3.11/multiprocessing/spawn.py", line 132, in _main
    self = reduction.pickle.load(from_parent)
    self = reduction.pickle.load(from_parent)
                  ^ ^ ^ ^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^AttributeError: Can't get attribute 'producer' on <module '__main__' (built-in)>^
^