Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions splitgraph/core/fragment_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from splitgraph.engine.postgres.engine import (
SG_UD_FLAG,
add_ud_flag_column,
chunk,
get_change_key,
)
from splitgraph.exceptions import SplitGraphError
Expand Down Expand Up @@ -1141,6 +1142,42 @@ def filter_fragments(self, object_ids: List[str], table: "Table", quals: Any) ->
# Preserve original object order.
return [r for r in object_ids if r in objects_to_scan]

def generate_surrogate_pk(
    self, table: "Table", object_pks: List[Tuple[Any, Any]]
) -> List[Tuple[Any, Any]]:
    """
    Convert typed min/max fragment boundaries into their PostgreSQL text
    representation so they compare the same way as the "surrogate" primary key.

    For tables without a real primary key, partitioning uses the whole row cast
    to text on the PG side (rows may contain NULLs, which can't be compared as
    typed tuples in PG). Overlap detection must therefore compare boundaries
    lexicographically too: as strings, "(some_country, 100)" < "(some_country, 20)",
    while as typed tuples (some_country, 100) > (some_country, 20).

    Rather than reimplementing PG's ::text rendering, we hand the rows back to
    PG and let it do the cast for us (same trick as changeset hashing).

    :param table: Table whose schema defines the column types for the cast.
    :param object_pks: List of (min_pk, max_pk) tuples, one per object.
    :return: List of (min_pk_text, max_pk_text) tuples in the same order.
    """
    # One VALUES row template, e.g. "(%s::integer,%s::character varying)".
    row_template = "(" + ",".join("%s::" + col.pg_type for col in table.table_schema) + ")"
    # Flatten [(min1, max1), (min2, max2), ...] -> [min1, max1, min2, max2, ...]
    # so every boundary row goes through the same cast query.
    flat_rows = [endpoint for pair in object_pks for endpoint in pair]

    casted: List[Any] = []
    # Batch to keep the number of bound parameters per statement reasonable.
    for batch in chunk(flat_rows, 1000):
        sql = (  # nosec
            "SELECT o::text FROM (VALUES "
            + ",".join(itertools.repeat(row_template, len(batch)))
            + ") o"
        )
        # Dict values must be wrapped in Json for psycopg2 to adapt them.
        args = [
            Json(value) if isinstance(value, dict) else value
            for row in batch
            for value in row
        ]
        casted.extend(
            self.object_engine.run_sql(sql, args, return_shape=ResultShape.MANY_ONE)
        )
    # Re-pair the flattened text values back into (min, max) tuples.
    return list(zip(casted[::2], casted[1::2]))

def _add_overlapping_objects(
self, table: "Table", all_objects: List[str], filtered_objects: List[str]
) -> Set[str]:
Expand All @@ -1151,6 +1188,10 @@ def _add_overlapping_objects(
table_pk = get_change_key(table.table_schema)
object_pks = self.get_min_max_pks(all_objects, table_pk)

surrogate_pk = not any(t.is_pk for t in table.table_schema)
if surrogate_pk:
object_pks = self.generate_surrogate_pk(table, object_pks)

# Go through all objects and see if they 1) come after any of our chosen objects and 2)
# overlap those objects' PKs (if they come after them)
original_order = {object_id: i for i, object_id in enumerate(all_objects)}
Expand Down
5 changes: 5 additions & 0 deletions splitgraph/core/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,11 @@ def _extract_singleton_fragments(self) -> Tuple[List[str], List[str]]:
# Get fragment boundaries (min-max PKs of every fragment).
table_pk = get_change_key(self.table.table_schema)
object_pks = self.object_manager.get_min_max_pks(self.filtered_objects, table_pk)

surrogate_pk = not any(t.is_pk for t in self.table.table_schema)
if surrogate_pk:
object_pks = self.table.repository.objects.generate_surrogate_pk(self.table, object_pks)

# Group fragments into non-overlapping groups: those can be applied independently of each other.
object_groups = get_chunk_groups(
[
Expand Down
34 changes: 33 additions & 1 deletion test/splitgraph/commands/test_layered_querying.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from splitgraph.core.repository import Repository, clone
from splitgraph.core.table import _generate_select_query
from splitgraph.engine import ResultShape, _prepare_engine_config
from splitgraph.engine.postgres.engine import PostgresEngine
from splitgraph.engine.postgres.engine import PostgresEngine, get_change_key
from splitgraph.exceptions import ObjectNotFoundError

_DT = dt(2019, 1, 1, 12)
Expand Down Expand Up @@ -587,6 +587,38 @@ def test_disjoint_table_lq_two_singletons_one_overwritten(pg_repo_local):
_assert_fragments_applied(_gsc, apply_fragments, pg_repo_local)


def test_query_plan_surrogate_pk_grouping(local_engine_empty):
    # Tables without a PK are chunked on whole_row::text (the "surrogate" PK);
    # check that the query planner groups fragments consistently with that.
    OUTPUT.init()
    for statement in (
        "CREATE TABLE test (key INTEGER, value_1 VARCHAR)",
        "INSERT INTO test VALUES (1, 'apple')",
        "INSERT INTO test VALUES (10, 'banana')",
        "INSERT INTO test VALUES (2, 'orange')",
    ):
        OUTPUT.run_sql(statement)

    head = OUTPUT.commit(chunk_size=2)
    table = head.get_table("test")

    # Lexicographic (text) ordering puts (10, 'banana') before (2, 'orange'),
    # so the two partitions span (1, 'apple') -> (10, 'banana')
    # and (2, 'orange') -> (2, 'orange').
    expected_objects = [
        "o03227f37c3f7c1ebabaaf87253dbe4f146a11ae47499a2578876b0ff03ce48",
        "oea90d83872b91bfd77d16d5c0f1209cb892cb80c73eda1117ba8d9f8399692",
    ]
    assert sorted(table.objects) == expected_objects

    # Raw min/max boundaries are still reported as typed tuples.
    boundaries = table.repository.objects.get_min_max_pks(
        table.objects, get_change_key(table.table_schema)
    )
    assert boundaries == [((1, "apple"), (10, "banana")), ((2, "orange"), (2, "orange"))]

    # The query plan must still see the two objects as disjoint singletons.
    plan = table.get_query_plan(
        quals=None, columns=[c.name for c in table.table_schema], use_cache=False
    )

    assert sorted(plan.filtered_objects) == sorted(table.objects)
    assert plan.non_singletons == []


def _assert_fragments_applied(_gsc, apply_fragments, pg_repo_local):
apply_fragments.assert_called_once_with(
[
Expand Down