Skip to content

Commit

Permalink
feat: support on-disk databases
Browse files Browse the repository at this point in the history
useful for using across multiple sequential processes
  • Loading branch information
tekumara committed Jan 26, 2024
1 parent 2f52f64 commit 6043f3d
Show file tree
Hide file tree
Showing 12 changed files with 160 additions and 54 deletions.
2 changes: 1 addition & 1 deletion .typos.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[files]
# ignore CHANGELOG because it contains commit SHAs
extend-exclude = ["CHANGELOG.md"]
extend-exclude = ["CHANGELOG.md", "notebooks/*"]
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ with fakesnow.patch("mymodule.write_pandas"):
...
```

By default databases are in-memory. To persist databases between processes, specify a databases path:

```python
with fakesnow.patch(db_path="databases/"):
...
```

### pytest fixtures

pytest [fixtures](fakesnow/fixtures.py) are provided for testing. Example _conftest.py_:
Expand Down
10 changes: 8 additions & 2 deletions fakesnow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import contextlib
import importlib
import os
import sys
import unittest.mock as mock
from collections.abc import Iterator, Sequence
Expand All @@ -19,6 +20,7 @@ def patch(
extra_targets: str | Sequence[str] = [],
create_database_on_connect: bool = True,
create_schema_on_connect: bool = True,
db_path: str | os.PathLike | None = None,
) -> Iterator[None]:
"""Patch snowflake targets with fakes.
Expand All @@ -28,12 +30,15 @@ def patch(
Args:
extra_targets (str | Sequence[str], optional): Extra targets to patch. Defaults to [].
create_database_on_connect (bool, optional): Create database if provided in connection. Defaults to True.
create_schema_on_connect (bool, optional): Create schema if provided in connection. Defaults to True.
Allows extra targets beyond the standard snowflake.connector targets to be patched. Needed because we cannot
patch definitions, only usages, see https://docs.python.org/3/library/unittest.mock.html#where-to-patch
create_database_on_connect (bool, optional): Create database if provided in connection. Defaults to True.
create_schema_on_connect (bool, optional): Create schema if provided in connection. Defaults to True.
db_path (str | os.PathLike | None, optional): _description_. Use existing database files from this path
or create them here if they don't already exist. If None databases are in-memory. Defaults to None.
Yields:
Iterator[None]: None.
"""
Expand All @@ -51,6 +56,7 @@ def patch(
duck_conn.cursor(),
create_database=create_database_on_connect,
create_schema=create_schema_on_connect,
db_path=db_path,
**kwargs,
),
snowflake.connector.pandas_tools.write_pandas: fakes.write_pandas,
Expand Down
44 changes: 33 additions & 11 deletions fakesnow/cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import argparse
import runpy
import sys
from collections.abc import Sequence
Expand All @@ -7,22 +8,43 @@
USAGE = "Usage: fakesnow <path> | -m <module> [<arg>]..."


def main(args: Sequence[str] = sys.argv) -> int:
if len(args) < 2 or (len(args) == 2 and args[1] == "-m"):
print(USAGE, file=sys.stderr)
return 42
def arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="""eg: fakesnow script.py OR fakesnow -m pytest""",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"-d",
"--db_path",
help="databases path. Use existing database files from this path or create them here if they don't already "
"exist. If None databases are in-memory.",
)
parser.add_argument("-m", "--module", help="module")
parser.add_argument("path", type=str, nargs="?", help="path")
parser.add_argument("args", nargs="*", help="args")
return parser

with fakesnow.patch():
if args[1] == "-m":
module = args[2]
sys.argv = args[2:]

def main(args: Sequence[str] = sys.argv[1:]) -> int:
parser = arg_parser()
pargs, remainder = parser.parse_known_args(args)

with fakesnow.patch(db_path=pargs.db_path):
if module := pargs.module:
if pargs.path:
sys.argv = [module, pargs.path, *pargs.args, *remainder]
else:
sys.argv = [module]

# add current directory to path to mimic python -m
sys.path.insert(0, "")
runpy.run_module(module, run_name="__main__", alter_sys=True)
else:
path = args[1]
sys.argv = args[1:]
elif path := pargs.path:
sys.argv = [path, *pargs.args, *remainder]

runpy.run_path(path, run_name="__main__")
else:
parser.print_usage()
return 42

return 0
11 changes: 8 additions & 3 deletions fakesnow/fakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import sys
from collections.abc import Iterable, Iterator, Sequence
from pathlib import Path
from string import Template
from types import TracebackType
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
Expand Down Expand Up @@ -160,7 +161,7 @@ def _execute(
transformed = (
expression.transform(transforms.upper_case_unquoted_identifiers)
.transform(transforms.set_schema, current_database=self._conn.database)
.transform(transforms.create_database)
.transform(transforms.create_database, db_path=self._conn.db_path)
.transform(transforms.extract_comment)
.transform(transforms.information_schema_fs_columns_snowflake)
.transform(transforms.information_schema_fs_tables_ext)
Expand Down Expand Up @@ -461,15 +462,18 @@ def __init__(
schema: str | None = None,
create_database: bool = True,
create_schema: bool = True,
db_path: str | os.PathLike | None = None,
*args: Any,
**kwargs: Any,
):
self._duck_conn = duck_conn
# upper case database and schema like snowflake
# upper case database and schema like snowflake unquoted identifiers
# NB: catalog names are not case-sensitive in duckdb but stored as cased in information_schema.schemata
self.database = database and database.upper()
self.schema = schema and schema.upper()
self.database_set = False
self.schema_set = False
self.db_path = db_path
self._paramstyle = "pyformat"

# create database if needed
Expand All @@ -481,7 +485,8 @@ def __init__(
where catalog_name = '{self.database}'"""
).fetchone()
):
duck_conn.execute(f"ATTACH DATABASE ':memory:' AS {self.database}")
db_file = f"{Path(db_path)/self.database}.db" if db_path else ":memory:"
duck_conn.execute(f"ATTACH DATABASE '{db_file}' AS {self.database}")
duck_conn.execute(info_schema.creation_sql(self.database))
duck_conn.execute(macros.creation_sql(self.database))

Expand Down
8 changes: 4 additions & 4 deletions fakesnow/info_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# use ext prefix in columns to disambiguate when joining with information_schema.tables
SQL_CREATE_INFORMATION_SCHEMA_TABLES_EXT = Template(
"""
create table ${catalog}.information_schema._fs_tables_ext (
create table if not exists ${catalog}.information_schema._fs_tables_ext (
ext_table_catalog varchar,
ext_table_schema varchar,
ext_table_name varchar,
Expand All @@ -18,7 +18,7 @@

SQL_CREATE_INFORMATION_SCHEMA_COLUMNS_EXT = Template(
"""
create table ${catalog}.information_schema._fs_columns_ext (
create table if not exists ${catalog}.information_schema._fs_columns_ext (
ext_table_catalog varchar,
ext_table_schema varchar,
ext_table_name varchar,
Expand All @@ -34,7 +34,7 @@
# snowflake integers are 38 digits, base 10, See https://docs.snowflake.com/en/sql-reference/data-types-numeric
SQL_CREATE_INFORMATION_SCHEMA_COLUMNS_VIEW = Template(
"""
create view ${catalog}.information_schema._fs_columns_snowflake AS
create view if not exists ${catalog}.information_schema._fs_columns_snowflake AS
select table_catalog, table_schema, table_name, column_name, ordinal_position, column_default, is_nullable,
case when starts_with(data_type, 'DECIMAL') or data_type='BIGINT' then 'NUMBER'
when data_type='VARCHAR' then 'TEXT'
Expand Down Expand Up @@ -62,7 +62,7 @@
# replicates https://docs.snowflake.com/sql-reference/info-schema/databases
SQL_CREATE_INFORMATION_SCHEMA_DATABASES_VIEW = Template(
"""
create view ${catalog}.information_schema.databases AS
create view if not exists ${catalog}.information_schema.databases AS
select
catalog_name as database_name,
'SYSADMIN' as database_owner,
Expand Down
2 changes: 1 addition & 1 deletion fakesnow/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

EQUAL_NULL = Template(
"""
CREATE MACRO ${catalog}.equal_null(a, b) AS a IS NOT DISTINCT FROM b;
CREATE MACRO IF NOT EXISTS ${catalog}.equal_null(a, b) AS a IS NOT DISTINCT FROM b;
"""
)

Expand Down
7 changes: 5 additions & 2 deletions fakesnow/transforms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from pathlib import Path
from typing import cast

import sqlglot
Expand All @@ -19,7 +20,7 @@ def array_size(expression: exp.Expression) -> exp.Expression:


# TODO: move this into a Dialect as a transpilation
def create_database(expression: exp.Expression) -> exp.Expression:
def create_database(expression: exp.Expression, db_path: Path | None = None) -> exp.Expression:
"""Transform create database to attach database.
Example:
Expand All @@ -36,9 +37,11 @@ def create_database(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Create) and str(expression.args.get("kind")).upper() == "DATABASE":
assert (ident := expression.find(exp.Identifier)), f"No identifier in {expression.sql}"
db_name = ident.this
db_file = f"{db_path/db_name}.db" if db_path else ":memory:"

return exp.Command(
this="ATTACH",
expression=exp.Literal(this=f"DATABASE ':memory:' AS {db_name}", is_string=True),
expression=exp.Literal(this=f"DATABASE '{db_file}' AS {db_name}", is_string=True),
create_db_name=db_name,
)

Expand Down
8 changes: 6 additions & 2 deletions tests/hello.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@

import snowflake.connector

name = sys.argv[1] if len(sys.argv) > 1 else "world"
names = sys.argv[1:] if len(sys.argv) > 1 else ["world"]

conn = snowflake.connector.connect()

print(conn.cursor().execute(f"SELECT 'Hello fake {name}!'").fetchone()) # pyright: ignore[reportOptionalMemberAccess]
print(
conn.cursor()
.execute(f"SELECT 'Hello fake {' '.join(names)}!'")
.fetchone() # pyright: ignore[reportOptionalMemberAccess]
)
25 changes: 21 additions & 4 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,31 @@


def test_run_module(capsys: CaptureFixture) -> None:
fakesnow.cli.main(["pytest", "-m", "tests.hello", "frobnitz"])
fakesnow.cli.main(["-m", "tests.hello"])

captured = capsys.readouterr()
assert captured.out == "('Hello fake frobnitz!',)\n"
assert captured.out == "('Hello fake world!',)\n"

fakesnow.cli.main(["-m", "tests.hello", "frobnitz", "--colour", "rainbow"])

captured = capsys.readouterr()
assert captured.out == "('Hello fake frobnitz --colour rainbow!',)\n"


def test_run_path(capsys: CaptureFixture) -> None:
fakesnow.cli.main(["pytest", "tests/hello.py", "frobnitz"])
fakesnow.cli.main(["tests/hello.py"])

captured = capsys.readouterr()
assert captured.out == "('Hello fake world!',)\n"

fakesnow.cli.main(["tests/hello.py", "frobnitz", "--colour", "rainbow"])

captured = capsys.readouterr()
assert captured.out == "('Hello fake frobnitz --colour rainbow!',)\n"


def test_run_no_args_shows_usage(capsys: CaptureFixture) -> None:
fakesnow.cli.main([])

captured = capsys.readouterr()
assert captured.out == "('Hello fake frobnitz!',)\n"
assert "usage" in captured.out
Loading

0 comments on commit 6043f3d

Please sign in to comment.