diff --git a/scripts/utils.py b/scripts/utils.py index 6f5006c4c..48068789c 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -2,9 +2,35 @@ import os import shutil +from private_gpt.paths import local_data_path +from private_gpt.settings.settings import settings -def wipe(): - path = "local_data" + +def wipe() -> None: + WIPE_MAP = { + "simple": wipe_simple, # node store + "chroma": wipe_chroma, # vector store + "postgres": wipe_postgres, # node, index and vector store + } + for dbtype in ("nodestore", "vectorstore"): + database = getattr(settings(), dbtype).database + func = WIPE_MAP.get(database) + if func: + func(dbtype) + else: + print(f"Unable to wipe database '{database}' for '{dbtype}'") + + +def wipe_file(file: str) -> None: + if os.path.isfile(file): + os.remove(file) + print(f" - Deleted {file}") + + +def wipe_tree(path: str) -> None: + if not os.path.exists(path): + print(f"Warning: Path not found {path}") + return print(f"Wiping {path}...") all_files = os.listdir(path) @@ -24,6 +50,54 @@ def wipe(): continue +def wipe_simple(dbtype: str) -> None: + assert dbtype == "nodestore" + from llama_index.core.storage.docstore.types import ( + DEFAULT_PERSIST_FNAME as DOCSTORE, + ) + from llama_index.core.storage.index_store.types import ( + DEFAULT_PERSIST_FNAME as INDEXSTORE, + ) + + for store in (DOCSTORE, INDEXSTORE): + wipe_file(str((local_data_path / store).absolute())) + + +def wipe_postgres(dbtype: str) -> None: + try: + import psycopg2 + except ImportError as e: + raise ImportError("Postgres dependencies not found") from e + + cur = conn = None + try: + tables = { + "nodestore": ["data_docstore", "data_indexstore"], + "vectorstore": ["data_embeddings"], + }[dbtype] + connection = settings().postgres.model_dump(exclude_none=True) + schema = connection.pop("schema_name") + conn = psycopg2.connect(**connection) + cur = conn.cursor() + for table in tables: + sql = f"DROP TABLE IF EXISTS {schema}.{table}" + cur.execute(sql) + print(f"Table {schema}.{table} dropped.") + conn.commit() + except psycopg2.Error as e: + print("Error:", e) + finally: + if cur: + cur.close() + if conn: + conn.close() + + +def wipe_chroma(dbtype: str): + assert dbtype == "vectorstore" + wipe_tree(str((local_data_path / "chroma_db").absolute())) + + if __name__ == "__main__": commands = { "wipe": wipe,