diff --git a/docs/cli-reference.rst b/docs/cli-reference.rst index b717461a..55517c92 100644 --- a/docs/cli-reference.rst +++ b/docs/cli-reference.rst @@ -289,6 +289,7 @@ See :ref:`cli_inserting_data`, :ref:`cli_insert_csv_tsv`, :ref:`cli_insert_unstr --analyze Run ANALYZE at the end of this operation --load-extension TEXT Path to SQLite extension, with optional :entrypoint --silent Do not show progress bar + --strict Apply STRICT mode to created table --ignore Ignore records if pk already exists --replace Replace records if pk already exists --truncate Truncate table before inserting records, if table @@ -345,6 +346,7 @@ See :ref:`cli_upsert`. --analyze Run ANALYZE at the end of this operation --load-extension TEXT Path to SQLite extension, with optional :entrypoint --silent Do not show progress bar + --strict Apply STRICT mode to created table -h, --help Show this message and exit. @@ -920,6 +922,7 @@ See :ref:`cli_create_table`. --replace If table already exists, replace it --transform If table already exists, try to transform the schema --load-extension TEXT Path to SQLite extension, with optional :entrypoint + --strict Apply STRICT mode to created table -h, --help Show this message and exit. diff --git a/docs/cli.rst b/docs/cli.rst index 736e7d6c..ed605445 100644 --- a/docs/cli.rst +++ b/docs/cli.rst @@ -1972,6 +1972,25 @@ You can specify foreign key relationships between the tables you are creating us [author_id] INTEGER REFERENCES [authors]([id]) ) +You can create a table in `SQLite STRICT mode `__ using ``--strict``: + +.. code-block:: bash + + sqlite-utils create-table mydb.db mytable id integer name text --strict + +.. code-block:: bash + + sqlite-utils tables mydb.db --schema -t + +.. code-block:: output + + table schema + ------- ------------------------ + mytable CREATE TABLE [mytable] ( + [id] INTEGER, + [name] TEXT + ) STRICT + If a table with the same name already exists, you will get an error. You can choose to silently ignore this error with ``--ignore``, or you can replace the existing table with a new, empty table using ``--replace``. You can also pass ``--transform`` to transform the existing table to match the new schema. See :ref:`python_api_explicit_create` in the Python library documentation for details of how this option works. @@ -2018,7 +2037,7 @@ Use ``--ignore`` to ignore the error if the table does not exist. Transforming tables =================== -The ``transform`` command allows you to apply complex transformations to a table that cannot be implemented using a regular SQLite ``ALTER TABLE`` command. See :ref:`python_api_transform` for details of how this works. +The ``transform`` command allows you to apply complex transformations to a table that cannot be implemented using a regular SQLite ``ALTER TABLE`` command. See :ref:`python_api_transform` for details of how this works. The ``transform`` command preserves a table's ``STRICT`` mode. .. code-block:: bash diff --git a/docs/python-api.rst b/docs/python-api.rst index 9d396c65..b2356843 100644 --- a/docs/python-api.rst +++ b/docs/python-api.rst @@ -117,6 +117,12 @@ By default, any :ref:`sqlite-utils plugins ` that implement the :ref:`p db = Database(memory=True, execute_plugins=False) +You can pass ``strict=True`` to enable `SQLite STRICT mode `__ for all tables created using this database object: + +.. code-block:: python + + db = Database("my_database.db", strict=True) + .. _python_api_attach: Attaching additional databases @@ -581,6 +587,15 @@ The ``transform=True`` option will update the table schema if any of the followi Changes to ``foreign_keys=`` are not currently detected and applied by ``transform=True``. +You can pass ``strict=True`` to create a table in ``STRICT`` mode: + +.. code-block:: python + + db["cats"].create({ + "id": int, + "name": str, + }, strict=True) + .. _python_api_compound_primary_keys: Compound primary keys @@ -661,7 +676,7 @@ You can set default values for these methods by accessing the table through the # Now you can call .insert() like so: table.insert({"id": 1, "name": "Tracy", "score": 5}) -The configuration options that can be specified in this way are ``pk``, ``foreign_keys``, ``column_order``, ``not_null``, ``defaults``, ``batch_size``, ``hash_id``, ``hash_id_columns``, ``alter``, ``ignore``, ``replace``, ``extracts``, ``conversions``, ``columns``. These are all documented below. +The configuration options that can be specified in this way are ``pk``, ``foreign_keys``, ``column_order``, ``not_null``, ``defaults``, ``batch_size``, ``hash_id``, ``hash_id_columns``, ``alter``, ``ignore``, ``replace``, ``extracts``, ``conversions``, ``columns``, ``strict``. These are all documented below. .. _python_api_defaults_not_null: @@ -1011,6 +1026,7 @@ The first time this is called the record will be created for ``name="Palm"``. An - ``extracts`` - ``conversions`` - ``columns`` +- ``strict`` .. _python_api_extracts: diff --git a/sqlite_utils/cli.py b/sqlite_utils/cli.py index 47da407b..0ab9132c 100644 --- a/sqlite_utils/cli.py +++ b/sqlite_utils/cli.py @@ -909,6 +909,12 @@ def inner(fn): ), load_extension_option, click.option("--silent", is_flag=True, help="Do not show progress bar"), + click.option( + "--strict", + is_flag=True, + default=False, + help="Apply STRICT mode to created table", + ), ) ): fn = decorator(fn) @@ -951,6 +957,7 @@ def insert_upsert_implementation( silent=False, bulk_sql=None, functions=None, + strict=False, ): db = sqlite_utils.Database(path) _load_extensions(db, load_extension) @@ -1066,6 +1073,7 @@ def insert_upsert_implementation( "replace": replace, "truncate": truncate, "analyze": analyze, + "strict": strict, } if not_null: extra_kwargs["not_null"] = set(not_null) @@ -1186,6 +1194,7 @@ def insert( truncate, not_null, default, + strict, ): """ Insert records from FILE into a table, creating the table if it @@ -1264,6 +1273,7 @@ def insert( silent=silent, not_null=not_null, default=default, + strict=strict, ) except UnicodeDecodeError as ex: raise click.ClickException(UNICODE_ERROR.format(ex)) @@ -1299,6 +1309,7 @@ def upsert( analyze, load_extension, silent, + strict, ): """ Upsert records based on their primary key. Works like 'insert' but if @@ -1343,6 +1354,7 @@ def upsert( analyze=analyze, load_extension=load_extension, silent=silent, + strict=strict, ) except UnicodeDecodeError as ex: raise click.ClickException(UNICODE_ERROR.format(ex)) @@ -1511,6 +1523,11 @@ def create_database(path, enable_wal, init_spatialite, load_extension): help="If table already exists, try to transform the schema", ) @load_extension_option +@click.option( + "--strict", + is_flag=True, + help="Apply STRICT mode to created table", +) def create_table( path, table, @@ -1523,6 +1540,7 @@ def create_table( replace, transform, load_extension, + strict, ): """ Add a table with the specified columns. Columns should be specified using @@ -1570,6 +1588,7 @@ def create_table( ignore=ignore, replace=replace, transform=transform, + strict=strict, ) diff --git a/sqlite_utils/db.py b/sqlite_utils/db.py index 50e26e13..371eed9c 100644 --- a/sqlite_utils/db.py +++ b/sqlite_utils/db.py @@ -303,6 +303,7 @@ class Database: ``sql, parameters`` every time a SQL query is executed :param use_counts_table: set to ``True`` to use a cached counts table, if available. See :ref:`python_api_cached_table_counts` + :param strict: Apply STRICT mode to all created tables (unless overridden) """ _counts_table_name = "_counts" @@ -318,6 +319,7 @@ def __init__( tracer: Optional[Callable] = None, use_counts_table: bool = False, execute_plugins: bool = True, + strict: bool = False, ): assert (filename_or_conn is not None and (not memory and not memory_name)) or ( filename_or_conn is None and (memory or memory_name) @@ -351,6 +353,7 @@ def __init__( self.use_counts_table = use_counts_table if execute_plugins: pm.hook.prepare_connection(conn=self.conn) + self.strict = strict def close(self): "Close the SQLite connection, and the underlying database file" @@ -537,8 +540,11 @@ def table(self, table_name: str, **kwargs) -> Union["Table", "View"]: :param table_name: Name of the table """ - klass = View if table_name in self.view_names() else Table - return klass(self, table_name, **kwargs) + if table_name in self.view_names(): + return View(self, table_name, **kwargs) + else: + kwargs.setdefault("strict", self.strict) + return Table(self, table_name, **kwargs) def quote(self, value: str) -> str: """ @@ -824,6 +830,7 @@ def create_table_sql( hash_id_columns: Optional[Iterable[str]] = None, extracts: Optional[Union[Dict[str, str], List[str]]] = None, if_not_exists: bool = False, + strict: bool = False, ) -> str: """ Returns the SQL ``CREATE TABLE`` statement for creating the specified table. @@ -839,6 +846,7 @@ def create_table_sql( :param hash_id_columns: List of columns to be used when calculating the hash ID for a row :param extracts: List or dictionary of columns to be extracted during inserts, see :ref:`python_api_extracts` :param if_not_exists: Use ``CREATE TABLE IF NOT EXISTS`` + :param strict: Apply STRICT mode to table """ if hash_id_columns and (hash_id is None): hash_id = "id" @@ -935,12 +943,13 @@ def sort_key(p): columns_sql = ",\n".join(column_defs) sql = """CREATE TABLE {if_not_exists}[{table}] ( {columns_sql}{extra_pk} -); +){strict}; """.format( if_not_exists="IF NOT EXISTS " if if_not_exists else "", table=name, columns_sql=columns_sql, extra_pk=extra_pk, + strict=" STRICT" if strict and self.supports_strict else "", ) return sql @@ -960,6 +969,7 @@ def create_table( replace: bool = False, ignore: bool = False, transform: bool = False, + strict: bool = False, ) -> "Table": """ Create a table with the specified name and the specified ``{column_name: type}`` columns. @@ -980,6 +990,7 @@ def create_table( :param replace: Drop and replace table if it already exists :param ignore: Silently do nothing if table already exists :param transform: If table already exists transform it to fit the specified schema + :param strict: Apply STRICT mode to table """ # Transform table to match the new definition if table already exists: if self[name].exists(): @@ -1051,6 +1062,7 @@ def create_table( hash_id_columns=hash_id_columns, extracts=extracts, if_not_exists=if_not_exists, + strict=strict, ) self.execute(sql) created_table = self.table( @@ -1419,6 +1431,7 @@ class Table(Queryable): :param extracts: Dictionary or list of column names to extract into a separate table on inserts :param conversions: Dictionary of column names and conversion functions :param columns: Dictionary of column names to column types + :param strict: If True, apply STRICT mode to table """ #: The ``rowid`` of the last inserted, updated or selected row. @@ -1444,6 +1457,7 @@ def __init__( extracts: Optional[Union[Dict[str, str], List[str]]] = None, conversions: Optional[dict] = None, columns: Optional[Dict[str, Any]] = None, + strict: bool = False, ): super().__init__(db, name) self._defaults = dict( @@ -1461,6 +1475,7 @@ def __init__( extracts=extracts, conversions=conversions or {}, columns=columns, + strict=strict, ) def __repr__(self) -> str: @@ -1642,6 +1657,7 @@ def create( replace: bool = False, ignore: bool = False, transform: bool = False, + strict: bool = False, ) -> "Table": """ Create a table with the specified columns. @@ -1661,6 +1677,7 @@ def create( :param replace: Drop and replace table if it already exists :param ignore: Silently do nothing if table already exists :param transform: If table already exists transform it to fit the specified schema + :param strict: Apply STRICT mode to table """ columns = {name: value for (name, value) in columns.items()} with self.db.conn: @@ -1679,6 +1696,7 @@ def create( replace=replace, ignore=ignore, transform=transform, + strict=strict, ) return self @@ -1912,6 +1930,7 @@ def transform_sql( defaults=create_table_defaults, foreign_keys=create_table_foreign_keys, column_order=column_order, + strict=self.strict, ).strip() ) @@ -3114,6 +3133,7 @@ def insert( extracts: Optional[Union[Dict[str, str], List[str], Default]] = DEFAULT, conversions: Optional[Union[Dict[str, str], Default]] = DEFAULT, columns: Optional[Union[Dict[str, Any], Default]] = DEFAULT, + strict: Optional[Union[bool, Default]] = DEFAULT, ) -> "Table": """ Insert a single record into the table. The table will be created with a schema that matches @@ -3146,6 +3166,7 @@ def insert( is being inserted, for example ``{"name": "upper(?)"}``. See :ref:`python_api_conversions`. :param columns: Dictionary over-riding the detected types used for the columns, for example ``{"age": int, "weight": float}``. + :param strict: Boolean, apply STRICT mode if creating the table. """ return self.insert_all( [record], @@ -3162,6 +3183,7 @@ def insert( extracts=extracts, conversions=conversions, columns=columns, + strict=strict, ) def insert_all( @@ -3184,6 +3206,7 @@ def insert_all( columns=DEFAULT, upsert=False, analyze=False, + strict=DEFAULT, ) -> "Table": """ Like ``.insert()`` but takes a list of records and ensures that the table @@ -3205,6 +3228,7 @@ def insert_all( extracts = self.value_or_default("extracts", extracts) conversions = self.value_or_default("conversions", conversions) or {} columns = self.value_or_default("columns", columns) + strict = self.value_or_default("strict", strict) if hash_id_columns and hash_id is None: hash_id = "id" @@ -3260,6 +3284,7 @@ def insert_all( hash_id=hash_id, hash_id_columns=hash_id_columns, extracts=extracts, + strict=strict, ) all_columns_set = set() for record in chunk: @@ -3310,6 +3335,7 @@ def upsert( extracts=DEFAULT, conversions=DEFAULT, columns=DEFAULT, + strict=DEFAULT, ) -> "Table": """ Like ``.insert()`` but performs an ``UPSERT``, where records are inserted if they do @@ -3330,6 +3356,7 @@ def upsert( extracts=extracts, conversions=conversions, columns=columns, + strict=strict, ) def upsert_all( @@ -3348,6 +3375,7 @@ def upsert_all( conversions=DEFAULT, columns=DEFAULT, analyze=False, + strict=DEFAULT, ) -> "Table": """ Like ``.upsert()`` but can be applied to a list of records. @@ -3368,6 +3396,7 @@ def upsert_all( columns=columns, upsert=True, analyze=analyze, + strict=strict, ) def add_missing_columns(self, records: Iterable[Dict[str, Any]]) -> "Table": @@ -3390,6 +3419,7 @@ def lookup( extracts: Optional[Union[Dict[str, str], List[str]]] = None, conversions: Optional[Dict[str, str]] = None, columns: Optional[Dict[str, Any]] = None, + strict: Optional[bool] = False, ): """ Create or populate a lookup table with the specified values. @@ -3412,6 +3442,7 @@ def lookup( :param lookup_values: Dictionary specifying column names and values to use for the lookup :param extra_values: Additional column values to be used only if creating a new record + :param strict: Boolean, apply STRICT mode if creating the table. """ assert isinstance(lookup_values, dict) if extra_values is not None: @@ -3443,6 +3474,7 @@ def lookup( extracts=extracts, conversions=conversions, columns=columns, + strict=strict, ).last_pk else: pk = self.insert( @@ -3455,6 +3487,7 @@ def lookup( extracts=extracts, conversions=conversions, columns=columns, + strict=strict, ).last_pk self.create_index(lookup_values.keys(), unique=True) return pk diff --git a/tests/test_cli.py b/tests/test_cli.py index 53135891..770b02bf 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -2401,3 +2401,32 @@ def test_load_extension(entrypoint, should_pass, should_fail): catch_exceptions=False, ) assert result.exit_code == 1 + + +@pytest.mark.parametrize("strict", (False, True)) +def test_create_table_strict(strict): + runner = CliRunner() + with runner.isolated_filesystem(): + db = Database("test.db") + result = runner.invoke( + cli.cli, + ["create-table", "test.db", "items", "id", "integer"] + + (["--strict"] if strict else []), + ) + assert result.exit_code == 0 + assert db["items"].strict == strict or not db.supports_strict + + +@pytest.mark.parametrize("method", ("insert", "upsert")) +@pytest.mark.parametrize("strict", (False, True)) +def test_insert_upsert_strict(tmpdir, method, strict): + db_path = str(tmpdir / "test.db") + result = CliRunner().invoke( + cli.cli, + [method, db_path, "items", "-", "--csv", "--pk", "id"] + + (["--strict"] if strict else []), + input="id\n1", + ) + assert result.exit_code == 0 + db = Database(db_path) + assert db["items"].strict == strict or not db.supports_strict diff --git a/tests/test_create.py b/tests/test_create.py index a88374f6..3bf1004a 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -1316,3 +1316,46 @@ def test_rename_table(fresh_db): # Should error if table does not exist: with pytest.raises(sqlite3.OperationalError): fresh_db.rename_table("does_not_exist", "renamed") + + +@pytest.mark.parametrize("strict", (False, True)) +def test_database_strict(strict): + db = Database(memory=True, strict=strict) + table = db.table("t", columns={"id": int}) + table.insert({"id": 1}) + assert table.strict == strict or not db.supports_strict + + +@pytest.mark.parametrize("strict", (False, True)) +def test_database_strict_override(strict): + db = Database(memory=True, strict=strict) + table = db.table("t", columns={"id": int}, strict=not strict) + table.insert({"id": 1}) + assert table.strict != strict or not db.supports_strict + + +@pytest.mark.parametrize( + "method_name", ("insert", "upsert", "insert_all", "upsert_all") +) +@pytest.mark.parametrize("strict", (False, True)) +def test_insert_upsert_strict(fresh_db, method_name, strict): + table = fresh_db["t"] + method = getattr(table, method_name) + record = {"id": 1} + if method_name.endswith("_all"): + record = [record] + method(record, pk="id", strict=strict) + assert table.strict == strict or not fresh_db.supports_strict + + +@pytest.mark.parametrize("strict", (False, True)) +def test_create_table_strict(fresh_db, strict): + table = fresh_db.create_table("t", {"id": int}, strict=strict) + assert table.strict == strict or not fresh_db.supports_strict + + +@pytest.mark.parametrize("strict", (False, True)) +def test_create_strict(fresh_db, strict): + table = fresh_db["t"] + table.create({"id": int}, strict=strict) + assert table.strict == strict or not fresh_db.supports_strict diff --git a/tests/test_lookup.py b/tests/test_lookup.py index 31be414c..b5ae61fa 100644 --- a/tests/test_lookup.py +++ b/tests/test_lookup.py @@ -151,3 +151,9 @@ def test_lookup_with_extra_insert_parameters(fresh_db): columns=["name", "type"], ) ] + + +@pytest.mark.parametrize("strict", (False, True)) +def test_lookup_new_table_strict(fresh_db, strict): + fresh_db["species"].lookup({"name": "Palm"}, strict=strict) + assert fresh_db["species"].strict == strict or not fresh_db.supports_strict diff --git a/tests/test_transform.py b/tests/test_transform.py index 1894494a..111236dd 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -530,3 +530,12 @@ def test_transform_preserves_rowids(fresh_db, table_type): tuple(row) for row in fresh_db.execute("select rowid, id, name from places") ) assert previous_rows == next_rows + + +@pytest.mark.parametrize("strict", (False, True)) +def test_transform_strict(fresh_db, strict): + dogs = fresh_db.table("dogs", strict=strict) + dogs.insert({"id": 1, "name": "Cleo"}) + assert dogs.strict == strict or not fresh_db.supports_strict + dogs.transform(not_null={"name"}) + assert dogs.strict == strict or not fresh_db.supports_strict