From 038dcfeaba98fbfe1e38febfa24c9c46ad59e48f Mon Sep 17 00:00:00 2001 From: Artjoms Iskovs Date: Thu, 4 Nov 2021 12:04:15 +0000 Subject: [PATCH] Fix a query performance issue with tables without a PK. When partitioning data, if the table doesn't have a primary key, we use a "surrogate" primary key by concatenating the whole row as a string on the PG side (this is because the whole row can sometimes contain NULLs which we can't compare in PG). The query planner used to use the non-textual composite PK to figure out which chunks overlapped in order to materialize them. This meant false positives where something like (2, 'orange') would force a materialization because numerically, it's greater than 10 and lexicographically, it isn't. To fix this, regenerate the textual surrogate PK at query plan time and use it to figure out when chunks overlap. --- splitgraph/core/fragment_manager.py | 41 +++++++++++++++++++ splitgraph/core/table.py | 5 +++ .../commands/test_layered_querying.py | 34 ++++++++++++++- 3 files changed, 79 insertions(+), 1 deletion(-) diff --git a/splitgraph/core/fragment_manager.py b/splitgraph/core/fragment_manager.py index 1f723086..aad670cf 100644 --- a/splitgraph/core/fragment_manager.py +++ b/splitgraph/core/fragment_manager.py @@ -26,6 +26,7 @@ from splitgraph.engine.postgres.engine import ( SG_UD_FLAG, add_ud_flag_column, + chunk, get_change_key, ) from splitgraph.exceptions import SplitGraphError @@ -1141,6 +1142,42 @@ def filter_fragments(self, object_ids: List[str], table: "Table", quals: Any) -> # Preserve original object order. return [r for r in object_ids if r in objects_to_scan] + def generate_surrogate_pk( + self, table: "Table", object_pks: List[Tuple[Any, Any]] + ) -> List[Tuple[Any, Any]]: + """ + When partitioning data, if the table doesn't have a primary key, we use a "surrogate" + primary key by concatenating the whole row as a string on the PG side (this is because + the whole row can sometimes contain NULLs which we can't compare in PG). + + We need to mimic this when calculating if the objects we're about to scan through + overlap: e.g. using string comparison, "(some_country, 100)" < "(some_country, 20)", + whereas using typed comparison, (some_country, 100) > (some_country, 20). + + To do this, we use a similar hack from when calculating changeset hashes: to avoid having + to reproduce how PG's ::text works, we give it back the rows and get it to cast them + to text for us. + """ + inner_tuple = "(" + ",".join("%s::" + c.pg_type for c in table.table_schema) + ")" + rows = [r for o in object_pks for r in o] + + result = [] + for batch in chunk(rows, 1000): + query = ( # nosec + "SELECT o::text FROM (VALUES " + + ",".join(itertools.repeat(inner_tuple, len(batch))) + + ") o" + ) + result.extend( + self.object_engine.run_sql( + query, + [o if not isinstance(o, dict) else Json(o) for row in batch for o in row], + return_shape=ResultShape.MANY_ONE, + ) + ) + object_pks = list(zip(result[::2], result[1::2])) + return object_pks + def _add_overlapping_objects( self, table: "Table", all_objects: List[str], filtered_objects: List[str] ) -> Set[str]: @@ -1151,6 +1188,10 @@ def _add_overlapping_objects( table_pk = get_change_key(table.table_schema) object_pks = self.get_min_max_pks(all_objects, table_pk) + surrogate_pk = not any(t.is_pk for t in table.table_schema) + if surrogate_pk: + object_pks = self.generate_surrogate_pk(table, object_pks) + # Go through all objects and see if they 1) come after any of our chosen objects and 2) # overlap those objects' PKs (if they come after them) original_order = {object_id: i for i, object_id in enumerate(all_objects)} diff --git a/splitgraph/core/table.py b/splitgraph/core/table.py index be9c23d7..74ae8887 100644 --- a/splitgraph/core/table.py +++ b/splitgraph/core/table.py @@ -183,6 +183,11 @@ def _extract_singleton_fragments(self) -> Tuple[List[str], List[str]]: # Get fragment boundaries (min-max PKs of every fragment). table_pk = get_change_key(self.table.table_schema) object_pks = self.object_manager.get_min_max_pks(self.filtered_objects, table_pk) + + surrogate_pk = not any(t.is_pk for t in self.table.table_schema) + if surrogate_pk: + object_pks = self.table.repository.objects.generate_surrogate_pk(self.table, object_pks) + # Group fragments into non-overlapping groups: those can be applied independently of each other. object_groups = get_chunk_groups( [ diff --git a/test/splitgraph/commands/test_layered_querying.py b/test/splitgraph/commands/test_layered_querying.py index e593d00c..8ac75a7e 100644 --- a/test/splitgraph/commands/test_layered_querying.py +++ b/test/splitgraph/commands/test_layered_querying.py @@ -15,7 +15,7 @@ from splitgraph.core.repository import Repository, clone from splitgraph.core.table import _generate_select_query from splitgraph.engine import ResultShape, _prepare_engine_config -from splitgraph.engine.postgres.engine import PostgresEngine +from splitgraph.engine.postgres.engine import PostgresEngine, get_change_key from splitgraph.exceptions import ObjectNotFoundError _DT = dt(2019, 1, 1, 12) @@ -587,6 +587,38 @@ def test_disjoint_table_lq_two_singletons_one_overwritten(pg_repo_local): _assert_fragments_applied(_gsc, apply_fragments, pg_repo_local) +def test_query_plan_surrogate_pk_grouping(local_engine_empty): + # If a table doesn't have a PK, we use whole_row::text to chunk it. + OUTPUT.init() + OUTPUT.run_sql("CREATE TABLE test (key INTEGER, value_1 VARCHAR)") + OUTPUT.run_sql("INSERT INTO test VALUES (1, 'apple')") + OUTPUT.run_sql("INSERT INTO test VALUES (10, 'banana')") + OUTPUT.run_sql("INSERT INTO test VALUES (2, 'orange')") + + head = OUTPUT.commit(chunk_size=2) + table = head.get_table("test") + + # This makes two partitions, spanning (1, 'apple') -> (10, 'banana') + # and (2, 'orange') -> (2, 'orange'): this is because we use lexicographical order in this case. + assert sorted(table.objects) == [ + "o03227f37c3f7c1ebabaaf87253dbe4f146a11ae47499a2578876b0ff03ce48", + "oea90d83872b91bfd77d16d5c0f1209cb892cb80c73eda1117ba8d9f8399692", + ] + + # Check the raw get_min_max_pks output for this table + assert table.repository.objects.get_min_max_pks( + table.objects, get_change_key(table.table_schema) + ) == [((1, "apple"), (10, "banana")), ((2, "orange"), (2, "orange"))] + + # Check that the query plan still treats the two objects as disjoint + plan = table.get_query_plan( + quals=None, columns=[c.name for c in table.table_schema], use_cache=False + ) + + assert sorted(plan.filtered_objects) == sorted(table.objects) + assert plan.non_singletons == [] + + def _assert_fragments_applied(_gsc, apply_fragments, pg_repo_local): apply_fragments.assert_called_once_with( [