Skip to content

Commit

Permalink
950 ModelBuilder and recursive foreign keys (#951)
Browse files Browse the repository at this point in the history
* fix `ModelBuilder` when using recursive foreign keys

* fix linter warnings
  • Loading branch information
dantownsend committed Mar 13, 2024
1 parent c55f8cf commit 0711d3a
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 10 deletions.
26 changes: 18 additions & 8 deletions piccolo/testing/model_builder.py
Expand Up @@ -122,14 +122,24 @@ async def _build(
continue # Column value exists

if isinstance(column, ForeignKey) and persist:
reference_model = await cls._build(
column._foreign_key_meta.resolved_references,
persist=True,
)
random_value = getattr(
reference_model,
reference_model._meta.primary_key._meta.name,
)
# Check for recursion
if column._foreign_key_meta.references is table_class:
if column._meta.null is True:
# We can avoid this problem entirely by setting it to
# None.
random_value = None
else:
# There's no way to avoid recursion in the situation.
raise ValueError("Recursive foreign key detected")
else:
reference_model = await cls._build(
column._foreign_key_meta.resolved_references,
persist=True,
)
random_value = getattr(
reference_model,
reference_model._meta.primary_key._meta.name,
)
else:
random_value = cls._randomize_attribute(column)

Expand Down
4 changes: 2 additions & 2 deletions tests/apps/migrations/auto/integration/test_migrations.py
Expand Up @@ -1041,10 +1041,10 @@ def test_target_column(self):

@engines_only("postgres", "cockroach")
class TestForeignKeySelf(MigrationTestCase):
def setUp(self):
def setUp(self) -> None:
class TableA(Table):
id = UUID(primary_key=True)
table_a = ForeignKey("self")
table_a: ForeignKey[TableA] = ForeignKey("self")

self.table_classes: t.List[t.Type[Table]] = [TableA]

Expand Down
15 changes: 15 additions & 0 deletions tests/testing/test_model_builder.py
Expand Up @@ -51,6 +51,10 @@ class BandWithLazyReference(Table):
)


class BandWithRecursiveReference(Table):
manager: ForeignKey["Manager"] = ForeignKey("self")


TABLES = (
Manager,
Band,
Expand All @@ -63,6 +67,7 @@ class BandWithLazyReference(Table):
TableWithArrayField,
TableWithDecimal,
BandWithLazyReference,
BandWithRecursiveReference,
)


Expand Down Expand Up @@ -133,6 +138,16 @@ def test_lazy_foreign_key(self):
Manager.exists().where(Manager.id == model.manager).run_sync()
)

def test_recursive_foreign_key(self):
"""
Make sure no infinite loops are created with recursive foreign keys.
"""
model = ModelBuilder.build_sync(
BandWithRecursiveReference, persist=True
)
# It should be set to None, as this foreign key is nullable.
self.assertIsNone(model.manager)

def test_invalid_column(self):
with self.assertRaises(ValueError):
ModelBuilder.build_sync(Band, defaults={"X": 1})
Expand Down

0 comments on commit 0711d3a

Please sign in to comment.