def sql_quote_str(s: str) -> str:
    """Escape *s* for use inside a single-quoted SQL string literal.

    Only embedded single quotes are doubled; the caller is responsible for
    adding the surrounding quotes.
    """
    return s.replace("'", "''")


def emit_value(value: Any) -> str:
    """Render a Python value as a SQLite literal that can be inlined in a query.

    Args:
        value: A value read back from a SQLite row (or any other Python object).

    Returns:
        ``NULL`` for ``None`` and NaN, a bare numeric literal for numbers, a
        hex BLOB literal for ``bytes``/``bytearray``, and a single-quoted,
        quote-escaped string for everything else (including ``bool`` and
        ``datetime``, the latter in ISO 8601 form).
    """
    if value is None:
        return "NULL"

    if isinstance(value, float):
        if math.isnan(value):
            return "NULL"
        if math.isinf(value):
            # "inf" is not a SQLite literal; an out-of-range exponent is read
            # back as +/- infinity.
            return "1e999" if value > 0 else "-1e999"
        # repr() is the shortest text that round-trips to the same double.
        # A fixed "%.20f" collapses small magnitudes (e.g. 1e-25) to zero.
        return repr(float(value))

    if isinstance(value, (bytes, bytearray)):
        # str(b"..") would produce the text "b'..'"; emit a real BLOB literal.
        return f"X'{bytes(value).hex()}'"

    # bool is a Number subclass but is deliberately excluded here, so it falls
    # through to the quoted text form below.
    if isinstance(value, Number) and not isinstance(value, bool):
        return str(value)

    if isinstance(value, datetime):
        return f"'{value.isoformat()}'"

    return f"'{sql_quote_str(str(value))}'"
def get_select_query(
    table_name: str,
    primary_keys: List[str],
    end_of_last_batch: Optional[sqlite3.Row],
    batch_size: int,
) -> str:
    """Build the keyset-paginated SELECT that fetches one batch of a table.

    Args:
        table_name: Name of the SQLite table to read from.
        primary_keys: Names of the primary key columns; when empty, the
            implicit ROWID column is used as the ordering key instead.
        end_of_last_batch: Last row of the previous batch (indexable by the
            key column names), or ``None`` for the first batch.
        batch_size: Value placed in the ``LIMIT`` clause.

    Returns:
        The SQL text, ordered by the key columns and restricted to rows whose
        key tuple is strictly greater than that of ``end_of_last_batch``.
    """
    has_explicit_pk = len(primary_keys) > 0
    if has_explicit_pk:
        key_columns = primary_keys
        select_prefix = ""
    else:
        # No explicit primary key: select the implicit rowid column as well,
        # based on: https://www.sqlite.org/withoutrowid.html
        key_columns = [SQLITE_IMPLICIT_ROWID_COLUMN_NAME]
        select_prefix = f"{SQLITE_IMPLICIT_ROWID_COLUMN_NAME}, "

    quoted_keys = ", ".join(_quote_ident(name) for name in key_columns)

    if end_of_last_batch is None:
        condition = "true"
    else:
        resume_values = ", ".join(emit_value(end_of_last_batch[name]) for name in key_columns)
        condition = f"({quoted_keys}) > ({resume_values})"

    return (  # nosec
        f"SELECT {select_prefix}* FROM {_quote_ident(table_name)} "
        f"WHERE {condition} ORDER BY {quoted_keys} ASC LIMIT {batch_size}"
    )  # nosec
def _batched_copy(
    self,
    con: sqlite3.Connection,
    schema: str,
    table_name: str,
    schema_spec: TableSchema,
    batch_size: int,
) -> int:
    """Copy one SQLite table into ``schema.table_name`` in fixed-size batches.

    Rows are read with ``get_select_query`` (ordered by the primary key, or by
    the implicit ROWID when the table has none) and written with
    ``self.engine.run_sql_batch``.

    Args:
        con: Open connection to the source SQLite database.
        schema: Target schema name.
        table_name: Name of the table, used for both source and target.
        schema_spec: Column specification of the table; its length determines
            the number of insert placeholders and ``is_pk`` marks key columns.
        batch_size: Maximum number of rows fetched and inserted per round trip.

    Returns:
        Total number of rows read from the source table.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1. ``LIMIT 0`` yields
            zero rows, which the previous loop condition never treated as the
            end of the table.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    primary_keys = [col.name for col in schema_spec if col.is_pk]

    # The statement does not depend on the batch, so build it once.
    insert_statement = SQL("INSERT INTO {0}.{1} ").format(
        Identifier(schema), Identifier(table_name)
    ) + SQL(
        " VALUES (" + ",".join(itertools.repeat("%s", len(schema_spec))) + ")"
    )  # nosec

    end_of_last_batch: Optional[sqlite3.Row] = None
    total_row_count = 0
    while True:
        batch = query_connection(
            con, get_select_query(table_name, primary_keys, end_of_last_batch, batch_size)
        )
        if not batch:
            # Nothing left to copy (empty table, or row count was an exact
            # multiple of batch_size).
            break
        total_row_count += len(batch)

        # Without explicit primary keys the query prepends ROWID as the first
        # column; drop it so the row matches the target table's columns.
        rows_to_insert = batch if primary_keys else [row[1:] for row in batch]
        self.engine.run_sql_batch(insert_statement, rows_to_insert)

        if len(batch) < batch_size:
            # A short batch means the source table is exhausted.
            break
        end_of_last_batch = batch[-1]
    return total_row_count
def test_sqlite_select_query():
    """Check the batch SELECT text for explicit and implicit keys, with and without a resume row."""
    cases = [
        # explicit pk, no previous batch
        (
            ("my_table", ["id"], None, 10),
            'SELECT * FROM "my_table" WHERE true ORDER BY "id" ASC LIMIT 10',
        ),
        # implicit pk, no previous batch
        (
            ("my_table", [], None, 10),
            'SELECT ROWID, * FROM "my_table" WHERE true ORDER BY "ROWID" ASC LIMIT 10',
        ),
        # explicit pk, has previous batch
        (
            ("my_table", ["id1", "id2"], {"id1": 0, "id2": "a"}, 10),
            'SELECT * FROM "my_table" WHERE ("id1", "id2") > (0, \'a\') ORDER BY "id1", "id2" ASC LIMIT 10',
        ),
        # implicit pk, has previous batch
        (
            ("my_table", [], {"ROWID": 500}, 10),
            'SELECT ROWID, * FROM "my_table" WHERE ("ROWID") > (500) ORDER BY "ROWID" ASC LIMIT 10',
        ),
    ]
    for arguments, expected_query in cases:
        assert get_select_query(*arguments) == expected_query