diff --git a/splitgraph/core/fragment_manager.py b/splitgraph/core/fragment_manager.py index 1f723086..aad670cf 100644 --- a/splitgraph/core/fragment_manager.py +++ b/splitgraph/core/fragment_manager.py @@ -26,6 +26,7 @@ from splitgraph.engine.postgres.engine import ( SG_UD_FLAG, add_ud_flag_column, + chunk, get_change_key, ) from splitgraph.exceptions import SplitGraphError @@ -1141,6 +1142,42 @@ def filter_fragments(self, object_ids: List[str], table: "Table", quals: Any) -> # Preserve original object order. return [r for r in object_ids if r in objects_to_scan] + def generate_surrogate_pk( + self, table: "Table", object_pks: List[Tuple[Any, Any]] + ) -> List[Tuple[Any, Any]]: + """ + When partitioning data, if the table doesn't have a primary key, we use a "surrogate" + primary key by concatenating the whole row as a string on the PG side (this is because + the whole row can sometimes contain NULLs which we can't compare in PG). + + We need to mimic this when calculating if the objects we're about to scan through + overlap: e.g. using string comparison, "(some_country, 100)" < "(some_country, 20)", + whereas using typed comparison, (some_country, 100) > (some_country, 20). + + To do this, we use a similar hack from when calculating changeset hashes: to avoid having + to reproduce how PG's ::text works, we give it back the rows and get it to cast them + to text for us. + """ + inner_tuple = "(" + ",".join("%s::" + c.pg_type for c in table.table_schema) + ")" + rows = [r for o in object_pks for r in o] + + result = [] + for batch in chunk(rows, 1000): + query = ( # nosec + "SELECT o::text FROM (VALUES " + + ",".join(itertools.repeat(inner_tuple, len(batch))) + + ") o" + ) + result.extend( + self.object_engine.run_sql( + query, + [o if not isinstance(o, dict) else Json(o) for row in batch for o in row], + return_shape=ResultShape.MANY_ONE, + ) + ) + object_pks = list(zip(result[::2], result[1::2])) + return object_pks + def _add_overlapping_objects( self, table: "Table", all_objects: List[str], filtered_objects: List[str] ) -> Set[str]: @@ -1151,6 +1188,10 @@ def _add_overlapping_objects( table_pk = get_change_key(table.table_schema) object_pks = self.get_min_max_pks(all_objects, table_pk) + surrogate_pk = not any(t.is_pk for t in table.table_schema) + if surrogate_pk: + object_pks = self.generate_surrogate_pk(table, object_pks) + # Go through all objects and see if they 1) come after any of our chosen objects and 2) # overlap those objects' PKs (if they come after them) original_order = {object_id: i for i, object_id in enumerate(all_objects)} diff --git a/splitgraph/core/table.py b/splitgraph/core/table.py index be9c23d7..74ae8887 100644 --- a/splitgraph/core/table.py +++ b/splitgraph/core/table.py @@ -183,6 +183,11 @@ def _extract_singleton_fragments(self) -> Tuple[List[str], List[str]]: # Get fragment boundaries (min-max PKs of every fragment). table_pk = get_change_key(self.table.table_schema) object_pks = self.object_manager.get_min_max_pks(self.filtered_objects, table_pk) + + surrogate_pk = not any(t.is_pk for t in self.table.table_schema) + if surrogate_pk: + object_pks = self.table.repository.objects.generate_surrogate_pk(self.table, object_pks) + # Group fragments into non-overlapping groups: those can be applied independently of each other. object_groups = get_chunk_groups( [ diff --git a/test/splitgraph/commands/test_layered_querying.py b/test/splitgraph/commands/test_layered_querying.py index e593d00c..8ac75a7e 100644 --- a/test/splitgraph/commands/test_layered_querying.py +++ b/test/splitgraph/commands/test_layered_querying.py @@ -15,7 +15,7 @@ from splitgraph.core.repository import Repository, clone from splitgraph.core.table import _generate_select_query from splitgraph.engine import ResultShape, _prepare_engine_config -from splitgraph.engine.postgres.engine import PostgresEngine +from splitgraph.engine.postgres.engine import PostgresEngine, get_change_key from splitgraph.exceptions import ObjectNotFoundError _DT = dt(2019, 1, 1, 12) @@ -587,6 +587,38 @@ def test_disjoint_table_lq_two_singletons_one_overwritten(pg_repo_local): _assert_fragments_applied(_gsc, apply_fragments, pg_repo_local) +def test_query_plan_surrogate_pk_grouping(local_engine_empty): + # If a table doesn't have a PK, we use whole_row::text to chunk it. + OUTPUT.init() + OUTPUT.run_sql("CREATE TABLE test (key INTEGER, value_1 VARCHAR)") + OUTPUT.run_sql("INSERT INTO test VALUES (1, 'apple')") + OUTPUT.run_sql("INSERT INTO test VALUES (10, 'banana')") + OUTPUT.run_sql("INSERT INTO test VALUES (2, 'orange')") + + head = OUTPUT.commit(chunk_size=2) + table = head.get_table("test") + + # This makes two partitions, spanning (1, 'apple') -> (10, 'banana') + # and (2, 'orange') -> (2, 'orange'): this is because we use lexicographical order in this case. + assert sorted(table.objects) == [ + "o03227f37c3f7c1ebabaaf87253dbe4f146a11ae47499a2578876b0ff03ce48", + "oea90d83872b91bfd77d16d5c0f1209cb892cb80c73eda1117ba8d9f8399692", + ] + + # Check the raw get_min_max_pks output for this table + assert table.repository.objects.get_min_max_pks( + table.objects, get_change_key(table.table_schema) + ) == [((1, "apple"), (10, "banana")), ((2, "orange"), (2, "orange"))] + + # Check that the query plan still treats the two objects as disjoint + plan = table.get_query_plan( + quals=None, columns=[c.name for c in table.table_schema], use_cache=False + ) + + assert sorted(plan.filtered_objects) == sorted(table.objects) + assert plan.non_singletons == [] + + def _assert_fragments_applied(_gsc, apply_fragments, pg_repo_local): apply_fragments.assert_called_once_with( [