Feat: Utils add wipe qdrant and new stats command #1783

Merged (6 commits) on Apr 2, 2024
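This PR replaces the per-backend wipe functions in scripts/utils.py with small handler classes, adds a Qdrant handler so wipe can also drop the Qdrant collection, introduces a stats command that reports storage usage for the configured node store and vector store, and adds a matching stats target to the Makefile. Both commands run as "make wipe" / "make stats", or directly with "poetry run python scripts/utils.py wipe" and "poetry run python scripts/utils.py stats".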
Changes from all commits
3 changes: 3 additions & 0 deletions Makefile
@@ -51,6 +51,9 @@ api-docs:
 ingest:
 	@poetry run python scripts/ingest_folder.py $(call args)

+stats:
+	poetry run python scripts/utils.py stats
+
 wipe:
 	poetry run python scripts/utils.py wipe

193 changes: 133 additions & 60 deletions scripts/utils.py
@@ -1,26 +1,12 @@
 import argparse
 import os
 import shutil
+from typing import Any, ClassVar

 from private_gpt.paths import local_data_path
 from private_gpt.settings.settings import settings


-def wipe() -> None:
-    WIPE_MAP = {
-        "simple": wipe_simple,  # node store
-        "chroma": wipe_chroma,  # vector store
-        "postgres": wipe_postgres,  # node, index and vector store
-    }
-    for dbtype in ("nodestore", "vectorstore"):
-        database = getattr(settings(), dbtype).database
-        func = WIPE_MAP.get(database)
-        if func:
-            func(dbtype)
-        else:
-            print(f"Unable to wipe database '{database}' for '{dbtype}'")
-
-
 def wipe_file(file: str) -> None:
     if os.path.isfile(file):
         os.remove(file)
@@ -50,62 +36,149 @@ def wipe_tree(path: str) -> None:
         continue


-def wipe_simple(dbtype: str) -> None:
-    assert dbtype == "nodestore"
-    from llama_index.core.storage.docstore.types import (
-        DEFAULT_PERSIST_FNAME as DOCSTORE,
-    )
-    from llama_index.core.storage.index_store.types import (
-        DEFAULT_PERSIST_FNAME as INDEXSTORE,
-    )
-
-    for store in (DOCSTORE, INDEXSTORE):
-        wipe_file(str((local_data_path / store).absolute()))
-
-
-def wipe_postgres(dbtype: str) -> None:
-    try:
-        import psycopg2
-    except ImportError as e:
-        raise ImportError("Postgres dependencies not found") from e
-
-    cur = conn = None
-    try:
-        tables = {
-            "nodestore": ["data_docstore", "data_indexstore"],
-            "vectorstore": ["data_embeddings"],
-        }[dbtype]
-        connection = settings().postgres.model_dump(exclude_none=True)
-        schema = connection.pop("schema_name")
-        conn = psycopg2.connect(**connection)
-        cur = conn.cursor()
-        for table in tables:
-            sql = f"DROP TABLE IF EXISTS {schema}.{table}"
-            cur.execute(sql)
-            print(f"Table {schema}.{table} dropped.")
-        conn.commit()
-    except psycopg2.Error as e:
-        print("Error:", e)
-    finally:
-        if cur:
-            cur.close()
-        if conn:
-            conn.close()
-
-
-def wipe_chroma(dbtype: str):
-    assert dbtype == "vectorstore"
-    wipe_tree(str((local_data_path / "chroma_db").absolute()))
+class Postgres:
+    tables: ClassVar[dict[str, list[str]]] = {
+        "nodestore": ["data_docstore", "data_indexstore"],
+        "vectorstore": ["data_embeddings"],
+    }
+
+    def __init__(self) -> None:
+        try:
+            import psycopg2
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError("Postgres dependencies not found") from None
+
+        connection = settings().postgres.model_dump(exclude_none=True)
+        self.schema = connection.pop("schema_name")
+        self.conn = psycopg2.connect(**connection)
+
+    def wipe(self, storetype: str) -> None:
+        cur = self.conn.cursor()
+        try:
+            for table in self.tables[storetype]:
+                sql = f"DROP TABLE IF EXISTS {self.schema}.{table}"
+                cur.execute(sql)
+                print(f"Table {self.schema}.{table} dropped.")
+            self.conn.commit()
+        finally:
+            cur.close()
+
+    def stats(self, store_type: str) -> None:
+        template = "SELECT '{table}', COUNT(*), pg_size_pretty(pg_total_relation_size('{table}')) FROM {table}"
+        sql = " UNION ALL ".join(
+            template.format(table=tbl) for tbl in self.tables[store_type]
+        )
+
+        cur = self.conn.cursor()
+        try:
+            print(f"Storage for Postgres {store_type}.")
+            print("{:<15} | {:>15} | {:>9}".format("Table", "Rows", "Size"))
+            print("-" * 45)  # Print a line separator
+
+            cur.execute(sql)
+            for row in cur.fetchall():
+                formatted_row_count = f"{row[1]:,}"
+                print(f"{row[0]:<15} | {formatted_row_count:>15} | {row[2]:>9}")
+
+            print()
+        finally:
+            cur.close()
+
+    def __del__(self):
+        if hasattr(self, "conn") and self.conn:
+            self.conn.close()
+
+
+class Simple:
+    def wipe(self, store_type: str) -> None:
+        assert store_type == "nodestore"
+        from llama_index.core.storage.docstore.types import (
+            DEFAULT_PERSIST_FNAME as DOCSTORE,
+        )
+        from llama_index.core.storage.index_store.types import (
+            DEFAULT_PERSIST_FNAME as INDEXSTORE,
+        )
+
+        for store in (DOCSTORE, INDEXSTORE):
+            wipe_file(str((local_data_path / store).absolute()))
+
+
+class Chroma:
+    def wipe(self, store_type: str) -> None:
+        assert store_type == "vectorstore"
+        wipe_tree(str((local_data_path / "chroma_db").absolute()))
+
+
+class Qdrant:
+    COLLECTION = (
+        "make_this_parameterizable_per_api_call"  # ?! see vector_store_component.py
+    )
+
+    def __init__(self) -> None:
+        try:
+            from qdrant_client import QdrantClient  # type: ignore
+        except ImportError:
+            raise ImportError("Qdrant dependencies not found") from None
+        self.client = QdrantClient(**settings().qdrant.model_dump(exclude_none=True))
+
+    def wipe(self, store_type: str) -> None:
+        assert store_type == "vectorstore"
+        try:
+            self.client.delete_collection(self.COLLECTION)
+            print("Collection dropped successfully.")
+        except Exception as e:
+            print("Error dropping collection:", e)
+
+    def stats(self, store_type: str) -> None:
+        print(f"Storage for Qdrant {store_type}.")
+        try:
+            collection_data = self.client.get_collection(self.COLLECTION)
+            if collection_data:
+                # Collection Info
+                # https://qdrant.tech/documentation/concepts/collections/
+                print(f"\tPoints: {collection_data.points_count:,}")
+                print(f"\tVectors: {collection_data.vectors_count:,}")
+                print(f"\tIndex Vectors: {collection_data.indexed_vectors_count:,}")
+                return
+        except ValueError:
+            pass
+        print("\t- Qdrant collection not found or empty")
+
+
+class Command:
+    DB_HANDLERS: ClassVar[dict[str, Any]] = {
+        "simple": Simple,  # node store
+        "chroma": Chroma,  # vector store
+        "postgres": Postgres,  # node, index and vector store
+        "qdrant": Qdrant,  # vector store
+    }
+
+    def for_each_store(self, cmd: str):
+        for store_type in ("nodestore", "vectorstore"):
+            database = getattr(settings(), store_type).database
+            handler_class = self.DB_HANDLERS.get(database)
+            if handler_class is None:
+                print(f"No handler found for database '{database}'")
+                continue
+            handler_instance = handler_class()  # Instantiate the class
+            # If the DB can handle this cmd dispatch it.
+            if hasattr(handler_instance, cmd) and callable(
+                func := getattr(handler_instance, cmd)
+            ):
+                func(store_type)
+            else:
+                print(
+                    f"Unable to execute command '{cmd}' on '{store_type}' in database '{database}'"
+                )
+
+    def execute(self, cmd: str) -> None:
+        if cmd in ("wipe", "stats"):
+            self.for_each_store(cmd)


 if __name__ == "__main__":
-    commands = {
-        "wipe": wipe,
-    }
-
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "mode", help="select a mode to run", choices=list(commands.keys())
-    )
+    parser.add_argument("mode", help="select a mode to run", choices=["wipe", "stats"])
     args = parser.parse_args()
-    commands[args.mode.lower()]()
+
+    Command().execute(args.mode.lower())
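For readers skimming the diff: the new Command class uses plain duck typing to decide which handler supports which command, so a backend only has to implement the methods it can honour (Simple and Chroma implement wipe only, Postgres and Qdrant implement both). The sketch below is a self-contained illustration of that dispatch pattern; the Fake* handlers and the hard-coded store map are stand-ins for the project's real settings and are not part of the PR. It also uses getattr with a default instead of the PR's hasattr-then-getattr check; the effect is the same.

from typing import Any, ClassVar


class FakeSimple:
    """Stand-in node-store handler that supports both commands."""

    def wipe(self, store_type: str) -> None:
        print(f"wiped {store_type}")

    def stats(self, store_type: str) -> None:
        print(f"{store_type}: 0 documents")


class FakeChroma:
    """Stand-in vector-store handler that only supports wipe."""

    def wipe(self, store_type: str) -> None:
        print(f"wiped {store_type}")


class MiniCommand:
    # maps a configured database name to its handler class
    DB_HANDLERS: ClassVar[dict[str, Any]] = {"simple": FakeSimple, "chroma": FakeChroma}
    # stand-in for getattr(settings(), store_type).database
    CONFIGURED: ClassVar[dict[str, str]] = {"nodestore": "simple", "vectorstore": "chroma"}

    def for_each_store(self, cmd: str) -> None:
        for store_type, database in self.CONFIGURED.items():
            handler_class = self.DB_HANDLERS.get(database)
            if handler_class is None:
                print(f"No handler found for database '{database}'")
                continue
            handler = handler_class()
            # dispatch only if this handler actually implements the command
            if callable(func := getattr(handler, cmd, None)):
                func(store_type)
            else:
                print(f"'{cmd}' not supported on '{store_type}' ({database})")


MiniCommand().for_each_store("stats")
# prints:
#   nodestore: 0 documents
#   'stats' not supported on 'vectorstore' (chroma)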
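As a side note, the Postgres stats query is assembled by joining a per-table template with UNION ALL. A quick way to inspect the statement it produces for the nodestore tables, reusing the template and table map from the diff verbatim (no database connection required):

tables = {
    "nodestore": ["data_docstore", "data_indexstore"],
    "vectorstore": ["data_embeddings"],
}
template = "SELECT '{table}', COUNT(*), pg_size_pretty(pg_total_relation_size('{table}')) FROM {table}"

sql = " UNION ALL ".join(template.format(table=tbl) for tbl in tables["nodestore"])
print(sql)
# SELECT 'data_docstore', COUNT(*), pg_size_pretty(pg_total_relation_size('data_docstore')) FROM data_docstore
# UNION ALL SELECT 'data_indexstore', COUNT(*), pg_size_pretty(pg_total_relation_size('data_indexstore')) FROM data_indexstore
# (printed as a single line; wrapped here for readability)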