# Example Databases

In [None]:
#| default_exp db

In [None]:
# | hide
%load_ext autoreload
%autoreload 2

In [None]:
# | hide
from stringdale.core import get_git_root, load_env, checkLogs
import pytest

In [None]:
#| hide
# NOTE(review): presumably loads environment variables (e.g. OPENAI_API_KEY, checked by
# check_openai_key below) for this session — confirm in stringdale.core.load_env.
load_env()

True

In [None]:
# | export
import os
from pathlib import Path
from copy import deepcopy
from fastcore.foundation import L
import json

from collections import defaultdict
from singleton_decorator import singleton

import logging

logger = logging.getLogger(__name__)

from openai import OpenAI
import base64
from typing import Optional

from joblib import Memory
from pathlib import Path

from stringdale.core import disk_cache



## Chroma

In [None]:
#| export
import chromadb
from chromadb.config import Settings
from typing import List, Dict, Any, Optional,Literal
import uuid


In [None]:
#| export
import chromadb
from chromadb.config import Settings
from typing import List, Dict, Any, Optional
from openai import OpenAI,AsyncOpenAI
from copy import copy,deepcopy

In [None]:
#| export
import nest_asyncio
import numpy as np


In [None]:
# Patch asyncio so asyncio.run can be called while Jupyter's own event loop is running;
# CachedEmbeddingFunction.__call__ uses asyncio.run. This cell is not exported, so library
# code importing this module does not get this patch.
nest_asyncio.apply()


In [None]:
#| export
def check_openai_key():
    """Ensure the OpenAI API key is available.

    Raises:
        ValueError: if the OPENAI_API_KEY environment variable is unset or empty.
    """
    if os.getenv('OPENAI_API_KEY'):
        return
    raise ValueError('OPENAI_API_KEY is not set')

@singleton
def openai_client():
    """Return the shared synchronous OpenAI client.

    Wrapped in @singleton, so the client is presumably constructed once and reused.

    Raises:
        ValueError: if OPENAI_API_KEY is not set (raised by check_openai_key).
    """
    check_openai_key()
    client = OpenAI()
    return client

@singleton
def async_openai_client():
    """Return the shared asynchronous OpenAI client.

    Wrapped in @singleton, so the client is presumably constructed once and reused.

    Raises:
        ValueError: if OPENAI_API_KEY is not set (raised by check_openai_key).
    """
    check_openai_key()
    client = AsyncOpenAI()
    return client
        

In [None]:
#| export

@disk_cache.cache
async def openai_embed(text, model='text-embedding-3-small'):
    """Embed `text` with the OpenAI embeddings endpoint and return a numpy vector.

    Results are wrapped by `disk_cache.cache` (presumably memoized on disk per
    (text, model)). Only the first embedding of the response is returned, so
    `text` is expected to be a single string.
    """
    embeddings_api = async_openai_client().embeddings
    response = await embeddings_api.create(input=text, model=model)
    first_item = response.data[0]
    return np.array(first_item.embedding)


class OpenAIEmbed():
    """Async callable that embeds one text with a fixed OpenAI embedding model."""

    def __init__(self,model='text-embedding-3-small'):
        self.model = model

    async def __call__(self,text):
        """Return the embedding of `text` as produced by `openai_embed`."""
        return await openai_embed(text, model=self.model)

    def __repr__(self):
        return 'OpenAIEmbed(model={})'.format(self.model)

    def __str__(self):
        return self.__repr__()

class CachedEmbeddingFunction(chromadb.utils.embedding_functions.EmbeddingFunction):
    """Chroma embedding function backed by the cached `openai_embed` helper.

    NOTE(review): `__call__` uses `asyncio.run`, which refuses to run while another
    event loop is running in the same thread (e.g. inside Jupyter or async code)
    unless that has been patched, e.g. with `nest_asyncio.apply()`.
    """

    def __init__(self,model='text-embedding-3-small'):
        self.model = model

    async def _async_call(self, texts):
        """Embed all `texts` concurrently; results keep the input order."""
        import asyncio
        pending = (openai_embed(text, model=self.model) for text in texts)
        return await asyncio.gather(*pending)

    def __call__(self, texts):
        """Synchronously embed `texts` by driving `_async_call` with `asyncio.run`."""
        import asyncio
        coroutine = self._async_call(texts)
        return asyncio.run(coroutine)

In [None]:
# Smoke test: embed a single text. The result is a list holding one numpy vector per input text.
c = CachedEmbeddingFunction()
x = c(['hello world'])
x

[array([-0.00676333, -0.03919632,  0.03417581, ..., -0.01964353,
        -0.01937133, -0.02247135])]

In [None]:
#| export
class ChromaClient:
    """Convenience wrapper around a ChromaDB client.

    Tracks collections by name in ``self.collections`` and attaches one shared
    ``CachedEmbeddingFunction`` to every collection, so documents and queries are
    embedded with the OpenAI model chosen at construction time. Documents are
    exchanged as plain dicts with 'id', 'text' and 'metadata' keys.
    """

    def __init__(self,persist_path=None,embed_model='text-embedding-3-small'):
        """Create the ChromaDB client and attach the collections it already holds.

        Args:
            persist_path: Directory to persist the database to. When falsy, an
                in-memory ``EphemeralClient`` is used instead.
            embed_model: OpenAI embedding model used for every collection
        """
        self.embed_func = CachedEmbeddingFunction(model=embed_model)

        # allow_reset=True so that reset() is permitted by Chroma.
        if persist_path:
            self.client = chromadb.PersistentClient(path=persist_path,settings=chromadb.Settings(allow_reset=True))
        else:
            self.client = chromadb.EphemeralClient(settings=chromadb.Settings(allow_reset=True))

        # Re-open collections that already exist (e.g. in a persisted database)
        # so they are bound to our embedding function.
        self.collections = {
            name: self.client.get_or_create_collection(name=name,embedding_function=self.embed_func)
            for name in self.list_collections()
        }

    def reset(self):
        """Delete every collection and document in the database."""
        self.client.reset()
        self.collections = {}

    def add_collection(self,name,distance:Literal['l2','ip','cosine']='l2',metadata=None,exists_ok=False):
        """Add a collection to the database.

        Args:
            name: Name of the collection to add
            distance: Distance metric, one of 'l2', 'ip', 'cosine'. Default is 'l2'
            metadata: Extra collection metadata (not modified by this call)
            exists_ok: If True, silently return when the collection is already tracked

        Raises:
            ValueError: if the collection is already tracked and `exists_ok` is False
        """
        if name in self.collections:
            if exists_ok:
                return
            raise ValueError(f'Collection {name} already exists')
        # Chroma reads the distance metric from the 'hnsw:space' metadata key
        # (the key must not contain quote characters, or it is silently ignored).
        metadata = {**(metadata or {}), 'hnsw:space': distance}
        self.collections[name] = self.client.get_or_create_collection(
            name=name,
            embedding_function=self.embed_func,
            metadata=metadata,
        )

    def delete_collection(self,name):
        """Delete a collection from the database.

        Args:
            name: Name of the collection to delete
        """
        self.client.delete_collection(name)
        # pop, not del: the collection may exist in Chroma without being tracked here
        # (e.g. created by another client after this one was initialised).
        self.collections.pop(name, None)

    def list_collections(self):
        """List all collections in the database.

        Returns:
            List of collection names
        """
        # Depending on the chromadb version, list_collections() yields Collection
        # objects or plain names (e.g. 0.6.x); accept both.
        return [getattr(col, 'name', col) for col in self.client.list_collections()]

    def embed_texts(self,texts:List[str]):
        """Embed a list of texts with this client's embedding function.

        Args:
            texts: List of texts to embed

        Returns:
            List of embeddings, one per text (an empty list for no texts)
        """
        if not texts:
            return []
        return self.embed_func(texts)

    def upsert(self,collection_name:str,docs):
        """Insert or update documents in a collection.

        Args:
            collection_name: Name of the collection to upsert into
            docs: List of dicts with a 'text' key and optional 'id' and 'metadata' keys.
                Documents without an 'id' (or with id None) get a random UUID4 string.

        Returns:
            New dicts (the input dicts are not mutated), one per document, each with
            its possibly generated 'id' filled in.
        """
        if not docs:
            return []
        ids = [str(uuid.uuid4()) if doc.get('id') is None else doc['id'] for doc in docs]
        texts = [doc['text'] for doc in docs]
        metadatas = [doc.get('metadata') for doc in docs]
        # upsert (not add) so that documents with existing ids are overwritten.
        self.collections[collection_name].upsert(
            ids=ids,
            embeddings=self.embed_texts(texts),
            documents=texts,
            metadatas=metadatas,
        )
        return [{**doc, 'id': doc_id} for doc, doc_id in zip(docs, ids)]

    def query(self,collection_name:str,query:str,k:int=10,threshold:Optional[float]=None,where:Optional[Dict[str,Any]]=None,where_document:Optional[Dict[str,Any]]=None):
        """Query a collection for documents similar to a query.

        Args:
            collection_name: Name of the collection to query
            query: Query text to search for
            k: Number of results to request from Chroma
            threshold: If given, drop results whose distance is greater than this
            where: Filter results by metadata
            where_document: Filter results by document text

        Returns:
            List of dicts with 'id', 'text', 'metadata' and 'distance' keys,
            in Chroma's result order
        """
        raw_results = self.collections[collection_name].query(
            query_texts=[query],
            n_results=k,
            where=where,
            where_document=where_document,
        )
        # Chroma returns one result list per query text; we sent exactly one.
        results = [
            {'id': doc_id, 'text': text, 'metadata': metadata, 'distance': distance}
            for doc_id, text, metadata, distance in zip(
                raw_results['ids'][0],
                raw_results['documents'][0],
                raw_results['metadatas'][0],
                raw_results['distances'][0],
            )
        ]
        if threshold is not None:
            # Smaller distance means more similar, so keep results at or below the threshold.
            results = [result for result in results if result['distance'] <= threshold]
        return results

    def get(self,collection_name:str,ids:List[str]):
        """Get documents from a collection by id.

        Args:
            collection_name: Name of the collection to get from
            ids: List of ids to get

        Returns:
            List of dicts with 'id', 'text' and 'metadata' keys, in the order Chroma
            returns them (which need not match the order of `ids`)
        """
        raw_results = self.collections[collection_name].get(ids=ids)
        return [
            {'id': doc_id, 'text': text, 'metadata': metadata}
            for doc_id, text, metadata in zip(raw_results['ids'], raw_results['documents'], raw_results['metadatas'])
        ]

    def delete(self,collection_name:str,ids:List[str]):
        """Delete documents from a collection.

        Args:
            collection_name: Name of the collection to delete from
            ids: List of ids to delete
        """
        self.collections[collection_name].delete(ids=ids)

    def list(self,collection_name:str,k:Optional[int]=None):
        """Return documents from a collection together with their embeddings.

        Args:
            collection_name: Name of the collection to list
            k: Maximum number of documents; forwarded unchanged as `limit` to Chroma's `peek`

        Returns:
            List of dicts with 'id', 'text', 'metadata' and 'embedding' keys
        """
        raw_results = self.collections[collection_name].peek(limit=k)
        return [
            {'id': doc_id, 'text': text, 'metadata': metadata, 'embedding': embedding}
            for doc_id, text, metadata, embedding in zip(
                raw_results['ids'],
                raw_results['documents'],
                raw_results['metadatas'],
                raw_results['embeddings'],
            )
        ]

    def __deepcopy__(self,memo):
        # Deliberately shallow: copies share the underlying Chroma client, the
        # embedding function and the `collections` dict rather than duplicating them.
        return copy(self)
        

### Tests

In [None]:
# Test ChromaClient
# No persist_path is given, so this is backed by an in-memory EphemeralClient.
client = ChromaClient()


In [None]:
# Start from an empty database, then check that a newly added collection is listed.
client.reset()
client.add_collection("test_collection")
names_now = client.list_collections()
assert "test_collection" in names_now, f"Collection creation failed, {names_now}"


In [None]:

# Test document operations
# Two near-duplicate fox sentences plus one unrelated weather sentence, each tagged
# with a 'type' metadata field that the filter tests in later cells rely on.
test_docs = [
    {
        'id': 'doc1',
        'text': 'The quick brown fox jumps over the lazy dog',
        'metadata': {'type': 'pangram'}
    },
    {
        'id': 'doc2',
        'text': 'A quick brown fox jumped over the lazy dogs',
        'metadata': {'type': 'variant'}
    },
    {
        'id': 'doc3',
        'text': 'The weather is sunny today',
        'metadata': {'type': 'weather'}
    }
]

# Test upsert
# upsert's return value is displayed as this cell's output.
client.upsert("test_collection", test_docs)



[{'id': 'doc1', 'text': 'The quick brown fox jumps over the lazy dog'},
 {'id': 'doc2', 'text': 'A quick brown fox jumped over the lazy dogs'},
 {'id': 'doc3', 'text': 'The weather is sunny today'}]

In [None]:
# Test query: a semantic search for "fox jumping" should surface the two fox sentences.
fox_hits = client.query("test_collection", "fox jumping", k=2)
assert len(fox_hits) == 2, "Query should return 2 results"
assert all('fox' in hit['text'] for hit in fox_hits), "Query results should contain relevant documents"


In [None]:

# Metadata filter: only the 'pangram' document matches, even though k=2 was requested.
results = client.query("test_collection", "fox jumping",where={'type':'pangram'},k=2)
assert len(results) == 1, results
assert results[0]['text'] == 'The quick brown fox jumps over the lazy dog'

# Full-text filter: where_document restricts candidates to documents containing "fox",
# so even a "sunny" query returns only the two fox documents.
results = client.query("test_collection", "sunny",k=2,where_document={"$contains":"fox"})
assert len(results) == 2, results
assert all('fox' in doc['text'] for doc in results), "Query results should contain relevant documents"

In [None]:
# query with both filters
# where_document keeps the fox documents and where keeps 'weather' or 'variant';
# only doc2 satisfies both (see the output below).
results = client.query("test_collection", "sunny",k=2,where_document={"$contains":"fox"},where={'type':{'$in':['weather','variant']}})
results


[{'id': 'doc2',
  'text': 'A quick brown fox jumped over the lazy dogs',
  'metadata': {'type': 'variant'},
  'distance': 1.513525366783142}]

In [None]:
# Fetch by id. The output lists doc1 before doc2, so result order does not follow the requested order.
client.get("test_collection",["doc2","doc1"])

[{'id': 'doc1',
  'text': 'The quick brown fox jumps over the lazy dog',
  'metadata': {'type': 'pangram'}},
 {'id': 'doc2',
  'text': 'A quick brown fox jumped over the lazy dogs',
  'metadata': {'type': 'variant'}}]

In [None]:
# Peek at up to 3 documents, including their stored embedding vectors.
client.list("test_collection",k=3)

[{'id': 'doc1',
  'text': 'The quick brown fox jumps over the lazy dog',
  'metadata': {'type': 'pangram'},
  'embedding': array([-0.02083762, -0.01689642, -0.00453628, ...,  0.01019769,
         -0.01523149,  0.02468777])},
 {'id': 'doc2',
  'text': 'A quick brown fox jumped over the lazy dogs',
  'metadata': {'type': 'variant'},
  'embedding': array([-1.61350556e-02,  1.02180371e-03, -6.04663728e-05, ...,
          8.89423583e-03, -2.04253849e-02,  1.07899625e-02])},
 {'id': 'doc3',
  'text': 'The weather is sunny today',
  'metadata': {'type': 'weather'},
  'embedding': array([ 0.01581731, -0.03885713,  0.00716233, ..., -0.02583253,
          0.01166436,  0.0264344 ])}]

In [None]:
# Test get: fetching a single id returns that document with its original text.
fetched = client.get("test_collection", ["doc1"])
first_doc = fetched[0]
assert first_doc['id'] == 'doc1', "Get should return correct document"
assert first_doc['text'] == test_docs[0]['text'], "Document text should match"

# Test list: peek honours the k limit.
peeked = client.list("test_collection", k=2)
assert len(peeked) == 2, "List should return 2 documents"

# Test query: every hit is a fox sentence carrying a float distance and dict metadata.
hits = client.query("test_collection", "fox jumping", k=2)
assert len(hits) == 2, "Query should return 2 results"
assert all('fox' in hit['text'] for hit in hits), "Query results should contain relevant documents"
assert all(isinstance(hit['distance'], float) for hit in hits), "Each result should have a distance score"
assert all(isinstance(hit['metadata'], dict) for hit in hits), "Each result should have metadata"

# Test delete: doc1 no longer appears when listing the collection.
client.delete("test_collection", ["doc1"])
remaining_ids = [entry['id'] for entry in client.list("test_collection")]
assert "doc1" not in remaining_ids, "Document should be deleted"

# Test collection deletion: the collection disappears from the listing.
client.delete_collection("test_collection")
assert "test_collection" not in client.list_collections(), "Collection deletion failed"


In [None]:

# Test error cases: adding an already-tracked collection raises unless exists_ok=True.
coll_name = "test_collection"
client.add_collection(coll_name)
with pytest.raises(ValueError, match=f"Collection {coll_name} already exists"):
    client.add_collection(coll_name)

# With exists_ok=True the duplicate add is a silent no-op.
client.add_collection(coll_name, exists_ok=True)
client.delete_collection(coll_name)


## SQL

Here we show how to create and use an in-memory SQL database and how to define its tables with [SQLModel](https://sqlmodel.tiangolo.com/) objects.

In [None]:
#| export
import sqlalchemy 
from sqlalchemy import create_engine
from sqlmodel import SQLModel, Session, select, Field
from typing import Optional
import sqlite3


In [None]:
#| export

def temp_sql_db(**kwargs):
    """
    Create an SQLAlchemy engine backed by a shared-cache, in-memory SQLite database.

    Every connection opens the URI ``file::memory:?cache=shared``, so all engines
    created by this function in the same process see the same database. SQLite
    discards that database once its last connection closes.
    Keyword arguments are forwarded to sqlalchemy's ``create_engine``.
    """
    def _connect():
        # The shared-cache URI requires uri=True to be interpreted as a URI.
        return sqlite3.connect('file::memory:?cache=shared', uri=True)

    return create_engine('sqlite:///:memory:', creator=_connect, **kwargs)

In [None]:
# Engine over the shared in-memory SQLite database from temp_sql_db; echo=False turns off SQL logging.
engine = temp_sql_db(echo=False)

# Forget table definitions registered on SQLModel's metadata (presumably from earlier
# runs of this cell) so that Hero can be redefined without a name clash.
SQLModel.metadata.clear()

# Example table model: table=True makes SQLModel map this class to a database table.
class Hero(SQLModel,table=True,extend_existing=True):
    id: Optional[int] = Field(default=None,primary_key=True)  # presumably None lets SQLite assign the key
    name: str
    secret_name: str
    age: Optional[int] = None

# Create every table registered on the metadata (here: Hero) in the database.
SQLModel.metadata.create_all(engine)

In [None]:
def merge_heros(heros:List[Hero]):
    """Merge each hero into the database (matched by primary key), then commit once."""
    with Session(engine) as session:
        for hero_obj in heros:
            session.merge(hero_obj)
        session.commit()

# Seed three heroes. Explicit ids plus merge() mean re-running this cell updates these rows instead of duplicating them.
merge_heros(
    [Hero(id=1,name="Deadpond", secret_name="Dive"),
    Hero(id=2,name="Spider-Boy", secret_name="Pedro"),
    Hero(id=3,name="Rusty-Man", secret_name="Tommy")])

def get_hero(name:str):
    """Return the Hero whose name equals `name`.

    Uses `.one()`, so it raises if there is not exactly one matching row.
    """
    statement = select(Hero).where(Hero.name == name)
    with Session(engine) as session:
        return session.exec(statement).one()

# Look up a seeded hero by name; the returned model is displayed as the cell output.
get_hero("Deadpond")

Hero(name='Deadpond', id=1, age=None, secret_name='Dive')

# Export

In [None]:
# |hide
import nbdev

# Regenerate library modules from `#| export` cells (this notebook targets the `db` module per `#| default_exp db`).
nbdev.nbdev_export()