Skip to content

Commit

Permalink
Implement copying rows from sqlite to splitgraph in batches
Browse files Browse the repository at this point in the history
  • Loading branch information
neumark committed Apr 1, 2023
1 parent c4b6b08 commit 00a3f45
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 11 deletions.
98 changes: 88 additions & 10 deletions splitgraph/ingestion/sqlite/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import contextlib
import itertools
import math
import os
import re
import sqlite3
import tempfile
from contextlib import contextmanager
from datetime import datetime
from numbers import Number
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union

import requests
Expand All @@ -19,6 +22,7 @@
TableColumn,
TableInfo,
TableParams,
TableSchema,
)
from splitgraph.engine.postgres.engine import _quote_ident
from splitgraph.hooks.data_source.base import LoadableDataSource, PreviewableDataSource
Expand Down Expand Up @@ -78,6 +82,7 @@ def query_connection(
def db_from_minio(url: str) -> Generator[sqlite3.Connection, None, None]:
    """Context manager yielding a sqlite3 connection to the database at *url*.

    The database file is materialized locally via ``minio_file`` (presumably a
    temp-file context manager — see its definition), opened with sqlite3, and
    closed again when the context exits.
    """
    with minio_file(url) as f:
        with contextlib.closing(sqlite3.connect(f)) as con:
            # Name-based row access is required by the preview and
            # batched-copy code paths (row[col_name], row.keys()).
            con.row_factory = sqlite3.Row
            yield con


Expand Down Expand Up @@ -137,6 +142,29 @@ def sqlite_connection_to_introspection_result(con: sqlite3.Connection) -> Intros
BINARY_DATA_MESSAGE = "[binary data]"


def sql_quote_str(s: str) -> str:
    """Escape *s* for embedding inside a single-quoted SQL string literal.

    Standard SQL escaping: every embedded single quote is doubled.
    """
    return "''".join(s.split("'"))


def emit_value(value: Any) -> str:
    """Render a Python value as a sqlite SQL literal.

    Used to embed primary-key boundary values into the keyset-pagination
    WHERE clause built by get_select_query, so the output must be a valid
    sqlite literal:

    - ``None`` and non-finite floats (NaN, +/-inf have no sqlite literal
      form) become ``NULL``
    - booleans become ``1``/``0``, matching sqlite's integer storage
    - other numbers are emitted verbatim (floats in fixed-point notation
      so no exponent form appears)
    - datetimes are emitted as quoted ISO-8601 strings
    - anything else is stringified, escaped and single-quoted
    """
    if value is None:
        return "NULL"

    # bool is a subclass of int; map it to sqlite's integer representation
    # instead of letting it fall through to the quoted-string branch
    # (which would emit 'True'/'False' and never match stored values).
    if isinstance(value, bool):
        return "1" if value else "0"

    if isinstance(value, float):
        # NaN *and* infinities cannot be written as sqlite literals
        # (":.20f" would render "inf", which is invalid SQL).
        if not math.isfinite(value):
            return "NULL"
        return f"{value:.20f}"

    if isinstance(value, Number):
        return str(value)

    if isinstance(value, datetime):
        return f"'{value.isoformat()}'"

    quoted = sql_quote_str(str(value))
    return f"'{quoted}'"


def sanitize_preview_row(row: sqlite3.Row) -> Dict[str, Any]:
    """Convert a sqlite3.Row into a plain dict suitable for a preview.

    BLOB values cannot be rendered in a preview payload, so any ``bytes``
    value is replaced with the ``BINARY_DATA_MESSAGE`` placeholder.
    Uses ``isinstance`` rather than an exact ``type(...) != bytes``
    comparison so bytes subclasses are masked as well.
    """
    return {
        k: BINARY_DATA_MESSAGE if isinstance(row[k], bytes) else row[k]
        for k in row.keys()
    }

Expand All @@ -153,6 +181,36 @@ def get_preview_rows(
]


# Implicit integer primary-key column that sqlite provides on ordinary
# (non-WITHOUT-ROWID) tables; used for keyset pagination when a table
# declares no explicit primary key.
SQLITE_IMPLICIT_ROWID_COLUMN_NAME = "ROWID"
# copy 1000 rows in a single iteration
DEFAULT_BATCH_SIZE = 1000


def get_select_query(
    table_name: str,
    primary_keys: List[str],
    end_of_last_batch: Optional[sqlite3.Row],
    batch_size: int,
):
    """Build the SELECT statement fetching the next batch of rows.

    Rows are ordered by the primary key and paginated with a keyset
    predicate: the WHERE clause skips every row up to and including the
    last row of the previous batch (``end_of_last_batch``, or ``None``
    for the first batch).
    """
    # Tables without an explicit PK are paginated on sqlite's implicit
    # rowid (https://www.sqlite.org/withoutrowid.html); it must then also
    # be selected so the next batch's predicate can reference it.
    if primary_keys:
        order_columns = primary_keys
        select_prefix = ""
    else:
        order_columns = [SQLITE_IMPLICIT_ROWID_COLUMN_NAME]
        select_prefix = f"{SQLITE_IMPLICIT_ROWID_COLUMN_NAME}, "

    quoted_pks = ", ".join(_quote_ident(column) for column in order_columns)

    if end_of_last_batch is None:
        predicate = "true"
    else:
        boundary = ", ".join(
            emit_value(end_of_last_batch[column]) for column in order_columns
        )
        predicate = f"({quoted_pks}) > ({boundary})"

    return (  # nosec
        f"SELECT {select_prefix}* FROM {_quote_ident(table_name)}"
        f" WHERE {predicate} ORDER BY {quoted_pks} ASC LIMIT {batch_size}"
    )


class SQLiteDataSource(LoadableDataSource, PreviewableDataSource):

table_params_schema: Dict[str, Any] = {
Expand Down Expand Up @@ -190,6 +248,35 @@ def _get_url(self, tables: Optional[TableInfo] = None):
url = table_params.get("url", url)
return url

def _batched_copy(
    self,
    con: sqlite3.Connection,
    schema: str,
    table_name: str,
    schema_spec: TableSchema,
    batch_size: int,
) -> int:
    """Copy every row of a sqlite table into ``schema.table_name`` in batches.

    Uses keyset pagination on the table's primary key (or sqlite's implicit
    ROWID for tables without one) so arbitrarily large tables can be copied
    without holding all rows in memory.

    Returns the total number of rows copied.
    """
    primary_keys = [col.name for col in schema_spec if col.is_pk]
    end_of_last_batch: Optional[sqlite3.Row] = None
    total_row_count = 0
    while True:
        table_contents = query_connection(
            con, get_select_query(table_name, primary_keys, end_of_last_batch, batch_size)
        )
        if not table_contents:
            # Empty batch: table is empty or the previous batch drained it.
            # (The original always issued one trailing run_sql_batch with an
            # empty row list here; break before the INSERT instead.)
            break
        end_of_last_batch = table_contents[-1]
        total_row_count += len(table_contents)
        # When paginating on the implicit rowid, the SELECT prepends a ROWID
        # column which must be stripped before inserting downstream.
        insert_table_contents = (
            table_contents if primary_keys else [row[1:] for row in table_contents]
        )
        self.engine.run_sql_batch(
            SQL("INSERT INTO {0}.{1} ").format(Identifier(schema), Identifier(table_name))
            + SQL(" VALUES (" + ",".join(itertools.repeat("%s", len(schema_spec))) + ")"),
            insert_table_contents,
        )  # nosec
        if len(table_contents) < batch_size:
            # Short batch: no further rows can exist; skip the extra query.
            break
    return total_row_count

def _load(self, schema: str, tables: Optional[TableInfo] = None):
with db_from_minio(self._get_url(tables)) as con:
introspection_result = sqlite_connection_to_introspection_result(con)
Expand All @@ -201,15 +288,7 @@ def _load(self, schema: str, tables: Optional[TableInfo] = None):
table=table_name,
schema_spec=schema_spec,
)
table_contents = query_connection(
con, "SELECT * FROM {}".format(_quote_ident(table_name)) # nosec
)
self.engine.run_sql_batch(
SQL("INSERT INTO {0}.{1} ").format(Identifier(schema), Identifier(table_name))
+ SQL(" VALUES (" + ",".join(itertools.repeat("%s", len(schema_spec))) + ")"),
# TODO: break this up into multiple batches for larger sqlite files
table_contents,
) # nosec
self._batched_copy(con, schema, table_name, schema_spec, DEFAULT_BATCH_SIZE)

def introspect(self) -> IntrospectionResult:
with db_from_minio(str(self._get_url())) as con:
Expand Down Expand Up @@ -241,7 +320,6 @@ def preview(self, tables: Optional[TableInfo]) -> PreviewResult:
if type(tables) == dict:
assert isinstance(tables, dict)
with db_from_minio(self._get_url(tables)) as con:
con.row_factory = sqlite3.Row
result = PreviewResult(
{table_name: get_preview_rows(con, table_name) for table_name in tables.keys()}
)
Expand Down
25 changes: 24 additions & 1 deletion test/splitgraph/ingestion/test_sqlite.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from splitgraph.ingestion.sqlite import sqlite_to_postgres_type
from splitgraph.ingestion.sqlite import get_select_query, sqlite_to_postgres_type


def test_type_mapping():
Expand All @@ -15,3 +15,26 @@ def test_type_mapping():
assert sqlite_to_postgres_type("decimal(2, 20)") == "DECIMAL(2,20)"
assert sqlite_to_postgres_type("NATIVE CHARACTER(70)") == "VARCHAR(70)"
assert sqlite_to_postgres_type("NVARCHAR(160)") == "VARCHAR(160)"


def test_sqlite_select_query():
    """get_select_query: explicit/implicit PKs, with and without a previous batch."""
    cases = [
        # explicit pk, no previous batch
        (
            ("my_table", ["id"], None, 10),
            'SELECT * FROM "my_table" WHERE true ORDER BY "id" ASC LIMIT 10',
        ),
        # implicit pk, no previous batch
        (
            ("my_table", [], None, 10),
            'SELECT ROWID, * FROM "my_table" WHERE true ORDER BY "ROWID" ASC LIMIT 10',
        ),
        # explicit composite pk, has previous batch
        (
            ("my_table", ["id1", "id2"], {"id1": 0, "id2": "a"}, 10),
            'SELECT * FROM "my_table" WHERE ("id1", "id2") > (0, \'a\') ORDER BY "id1", "id2" ASC LIMIT 10',
        ),
        # implicit pk, has previous batch
        (
            ("my_table", [], {"ROWID": 500}, 10),
            'SELECT ROWID, * FROM "my_table" WHERE ("ROWID") > (500) ORDER BY "ROWID" ASC LIMIT 10',
        ),
    ]
    for args, expected in cases:
        assert get_select_query(*args) == expected

0 comments on commit 00a3f45

Please sign in to comment.