Skip to content

Commit

Permalink
Revert "Implement copying rows from sqlite to splitgraph in batches (#798)"

Browse files Browse the repository at this point in the history

This reverts commit bb991a8.
  • Loading branch information
neumark committed Apr 3, 2023
1 parent bb991a8 commit bbe44bf
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 91 deletions.
78 changes: 11 additions & 67 deletions splitgraph/ingestion/sqlite/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import contextlib
import itertools
import math
import os
import re
import sqlite3
import tempfile
from contextlib import contextmanager
from datetime import datetime
from numbers import Number
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union

import requests
from psycopg2.sql import SQL, Identifier
Expand All @@ -22,7 +19,6 @@
TableColumn,
TableInfo,
TableParams,
TableSchema,
)
from splitgraph.engine.postgres.engine import _quote_ident
from splitgraph.hooks.data_source.base import LoadableDataSource, PreviewableDataSource
Expand Down Expand Up @@ -82,7 +78,6 @@ def query_connection(
def db_from_minio(url: str) -> Generator[sqlite3.Connection, None, None]:
    """Yield a SQLite connection to a database file downloaded from MinIO.

    The file is fetched to a local path via ``minio_file`` for the duration
    of the context; the connection is always closed on exit.
    """
    with minio_file(url) as local_path:
        connection = sqlite3.connect(local_path)
        try:
            # Expose rows as sqlite3.Row so callers can index columns by name.
            connection.row_factory = sqlite3.Row
            yield connection
        finally:
            connection.close()


Expand Down Expand Up @@ -158,36 +153,6 @@ def get_preview_rows(
]


SQLITE_IMPLICIT_ROWID_COLUMN_NAME = "ROWID"
# copy 1000 rows in a single iteration
DEFAULT_BATCH_SIZE = 1000


def get_select_query(
    table_name: str,
    primary_keys: List[str],
    end_of_last_batch: Optional[sqlite3.Row],
    batch_size: int,
) -> Tuple[str, Dict[str, Any]]:
    """Build the keyset-pagination SELECT for one batch of rows.

    :param table_name: Name of the SQLite table to read from.
    :param primary_keys: Declared PK column names; empty for rowid tables.
    :param end_of_last_batch: Last row of the previous batch, or None for
        the first batch; its PK values seed the WHERE predicate.
    :param batch_size: Maximum number of rows to fetch.
    :return: Tuple of (SQL text, bind parameters keyed by PK column name).
    """
    # Tables with no declared PK are paginated on SQLite's implicit ROWID
    # column, per https://www.sqlite.org/withoutrowid.html.
    order_columns = primary_keys or [SQLITE_IMPLICIT_ROWID_COLUMN_NAME]
    quoted_order_cols = ", ".join(_quote_ident(col) for col in order_columns)
    if end_of_last_batch is None:
        predicate = "true"
        bind_values: Dict[str, Any] = {}
    else:
        # NOTE(review): placeholders are "%s" while bind_values is a dict —
        # assumes query_connection translates this for sqlite3; verify there.
        placeholders = ", ".join(["%s"] * len(order_columns))
        predicate = f"({quoted_order_cols}) > ({placeholders})"
        bind_values = {col: end_of_last_batch[col] for col in order_columns}
    # Prepend the implicit ROWID to the projection when no explicit PK
    # exists, so the caller can resume pagination from the batch's last row.
    rowid_prefix = "" if primary_keys else f"{SQLITE_IMPLICIT_ROWID_COLUMN_NAME}, "
    query = "SELECT {}* FROM {} WHERE {} ORDER BY {} ASC LIMIT {}".format(  # nosec
        rowid_prefix,
        _quote_ident(table_name),
        predicate,
        quoted_order_cols,
        batch_size,
    )
    return (query, bind_values)


class SQLiteDataSource(LoadableDataSource, PreviewableDataSource):

table_params_schema: Dict[str, Any] = {
Expand Down Expand Up @@ -225,36 +190,6 @@ def _get_url(self, tables: Optional[TableInfo] = None):
url = table_params.get("url", url)
return url

def _batched_copy(
    self,
    con: sqlite3.Connection,
    schema: str,
    table_name: str,
    schema_spec: TableSchema,
    batch_size: int,
) -> int:
    """Copy one SQLite table into the target schema in fixed-size batches.

    Uses keyset pagination (via ``get_select_query``) ordered by the table's
    primary key — or SQLite's implicit ROWID when none is declared — and
    inserts each batch through the engine.

    :return: Total number of rows copied.
    """
    pk_columns = [column.name for column in schema_spec if column.is_pk]
    # The INSERT statement is invariant across batches, so build it once.
    insert_statement = SQL("INSERT INTO {0}.{1} ").format(
        Identifier(schema), Identifier(table_name)
    ) + SQL(" VALUES (" + ",".join(itertools.repeat("%s", len(schema_spec))) + ")")
    copied = 0
    last_row: Optional[sqlite3.Row] = None
    while True:
        query, params = get_select_query(table_name, pk_columns, last_row, batch_size)
        batch = query_connection(con, query, params)
        copied += len(batch)
        last_row = batch[-1] if batch else None
        # When the table has no declared PK, the SELECT prepends the implicit
        # ROWID column, which must be dropped before inserting downstream.
        rows_to_insert = batch if pk_columns else [row[1:] for row in batch]
        self.engine.run_sql_batch(insert_statement, rows_to_insert)  # nosec
        # A short batch means the table is exhausted.
        if len(batch) < batch_size:
            break
    return copied

def _load(self, schema: str, tables: Optional[TableInfo] = None):
with db_from_minio(self._get_url(tables)) as con:
introspection_result = sqlite_connection_to_introspection_result(con)
Expand All @@ -266,7 +201,15 @@ def _load(self, schema: str, tables: Optional[TableInfo] = None):
table=table_name,
schema_spec=schema_spec,
)
self._batched_copy(con, schema, table_name, schema_spec, DEFAULT_BATCH_SIZE)
table_contents = query_connection(
con, "SELECT * FROM {}".format(_quote_ident(table_name)) # nosec
)
self.engine.run_sql_batch(
SQL("INSERT INTO {0}.{1} ").format(Identifier(schema), Identifier(table_name))
+ SQL(" VALUES (" + ",".join(itertools.repeat("%s", len(schema_spec))) + ")"),
# TODO: break this up into multiple batches for larger sqlite files
table_contents,
) # nosec

def introspect(self) -> IntrospectionResult:
with db_from_minio(str(self._get_url())) as con:
Expand Down Expand Up @@ -298,6 +241,7 @@ def preview(self, tables: Optional[TableInfo]) -> PreviewResult:
if type(tables) == dict:
assert isinstance(tables, dict)
with db_from_minio(self._get_url(tables)) as con:
con.row_factory = sqlite3.Row
result = PreviewResult(
{table_name: get_preview_rows(con, table_name) for table_name in tables.keys()}
)
Expand Down
25 changes: 1 addition & 24 deletions test/splitgraph/ingestion/test_sqlite.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from splitgraph.ingestion.sqlite import get_select_query, sqlite_to_postgres_type
from splitgraph.ingestion.sqlite import sqlite_to_postgres_type


def test_type_mapping():
Expand All @@ -15,26 +15,3 @@ def test_type_mapping():
assert sqlite_to_postgres_type("decimal(2, 20)") == "DECIMAL(2,20)"
assert sqlite_to_postgres_type("NATIVE CHARACTER(70)") == "VARCHAR(70)"
assert sqlite_to_postgres_type("NVARCHAR(160)") == "VARCHAR(160)"


def test_sqlite_select_query():
    """get_select_query builds keyset-pagination SQL for explicit and implicit PKs."""
    # Each case: (primary_keys, end_of_last_batch, expected_sql, expected_params)
    cases = [
        # explicit pk, no previous batch
        (
            ["id"],
            None,
            'SELECT * FROM "my_table" WHERE true ORDER BY "id" ASC LIMIT 10',
            {},
        ),
        # implicit pk, no previous batch
        (
            [],
            None,
            'SELECT ROWID, * FROM "my_table" WHERE true ORDER BY "ROWID" ASC LIMIT 10',
            {},
        ),
        # explicit pk, has previous batch
        (
            ["id1", "id2"],
            {"id1": 0, "id2": "a"},
            'SELECT * FROM "my_table" WHERE ("id1", "id2") > (%s, %s) ORDER BY "id1", "id2" ASC LIMIT 10',
            {"id1": 0, "id2": "a"},
        ),
        # implicit pk, has previous batch
        (
            [],
            {"ROWID": 500},
            'SELECT ROWID, * FROM "my_table" WHERE ("ROWID") > (%s) ORDER BY "ROWID" ASC LIMIT 10',
            {"ROWID": 500},
        ),
    ]
    for primary_keys, end_of_last_batch, expected_sql, expected_params in cases:
        assert get_select_query("my_table", primary_keys, end_of_last_batch, 10) == (
            expected_sql,
            expected_params,
        )

0 comments on commit bbe44bf

Please sign in to comment.