Skip to content

Commit

Permalink
refactor: share logic between FK & O2O when init fields (#1618)
Browse files Browse the repository at this point in the history
* refactor: use for loop to set attr when init fk relations

* refactor: share logic between FK & O2O

* Make init_fk_o2o_field as closure function

* Use kwargs for clarity
  • Loading branch information
waketzheng committed May 22, 2024
1 parent 1f6e823 commit a191ea8
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 155 deletions.
4 changes: 2 additions & 2 deletions tests/model_setup/test_bad_relation_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async def test_wrong_model_init(self):

async def test_no_app_in_reference_init(self):
with self.assertRaisesRegex(
ConfigurationError, 'Foreign key accepts model name in format "app.Model"'
ConfigurationError, 'ForeignKeyField accepts model name in format "app.Model"'
):
await Tortoise.init(
{
Expand All @@ -80,7 +80,7 @@ async def test_no_app_in_reference_init(self):

async def test_more_than_two_dots_in_reference_init(self):
with self.assertRaisesRegex(
ConfigurationError, 'Foreign key accepts model name in format "app.Model"'
ConfigurationError, 'ForeignKeyField accepts model name in format "app.Model"'
):
await Tortoise.init(
{
Expand Down
4 changes: 2 additions & 2 deletions tests/schema/test_generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ async def test_create_index(self):

async def test_fk_bad_model_name(self):
with self.assertRaisesRegex(
ConfigurationError, 'Foreign key accepts model name in format "app.Model"'
ConfigurationError, 'ForeignKeyField accepts model name in format "app.Model"'
):
await self.init_for("tests.schema.models_fk_1")

Expand Down Expand Up @@ -205,7 +205,7 @@ async def test_o2o_bad_null(self):

async def test_m2m_bad_model_name(self):
with self.assertRaisesRegex(
ConfigurationError, 'Foreign key accepts model name in format "app.Model"'
ConfigurationError, 'ManyToManyField accepts model name in format "app.Model"'
):
await self.init_for("tests.schema.models_m2m_1")

Expand Down
225 changes: 80 additions & 145 deletions tortoise/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,23 +132,90 @@ def get_related_model(related_app_name: str, related_model_name: str) -> Type["M

def split_reference(reference: str) -> Tuple[str, str]:
"""
Test, if reference follow the official naming conventions. Throws a
Validate, if reference follow the official naming conventions. Throws a
ConfigurationError with a hopefully helpful message. If successful,
returns the app and the model name.
:raises ConfigurationError: If no model reference is invalid.
:raises ConfigurationError: If reference is invalid.
"""
items = reference.split(".")
if len(items) != 2: # pragma: nocoverage
if len(items := reference.split(".")) != 2: # pragma: nocoverage
raise ConfigurationError(
(
"'%s' is not a valid model reference Bad Reference."
" Should be something like <appname>.<modelname>."
)
% reference
f"'{reference}' is not a valid model reference Bad Reference."
" Should be something like '<appname>.<modelname>'."
)
return items[0], items[1]

return (items[0], items[1])
def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None:
if is_o2o:
fk_object: Union[OneToOneFieldInstance, ForeignKeyFieldInstance] = cast(
OneToOneFieldInstance, model._meta.fields_map[field]
)
else:
fk_object = cast(ForeignKeyFieldInstance, model._meta.fields_map[field])
related_app_name, related_model_name = split_reference(fk_object.model_name)
related_model = get_related_model(related_app_name, related_model_name)

if to_field := fk_object.to_field:
related_field = related_model._meta.fields_map.get(to_field)
if not related_field:
raise ConfigurationError(
f'there is no field named "{to_field}" in model "{related_model_name}"'
)
if not related_field.unique:
raise ConfigurationError(
f'field "{to_field}" in model "{related_model_name}" is not unique'
)
else:
fk_object.to_field = related_model._meta.pk_attr
related_field = related_model._meta.pk
key_fk_object = deepcopy(related_field)
fk_object.to_field_instance = related_field # type:ignore[arg-type,call-overload]
fk_object.field_type = fk_object.to_field_instance.field_type

key_field = f"{field}_id"
key_fk_object.reference = fk_object
key_fk_object.source_field = fk_object.source_field or key_field
for attr in ("index", "default", "null", "generated", "description"):
setattr(key_fk_object, attr, getattr(fk_object, attr))
if is_o2o:
key_fk_object.pk = fk_object.pk
key_fk_object.unique = fk_object.unique
else:
key_fk_object.pk = False
key_fk_object.unique = False
model._meta.add_field(key_field, key_fk_object)
fk_object.related_model = related_model
fk_object.source_field = key_field
if (backward_relation_name := fk_object.related_name) is not False:
if not backward_relation_name:
backward_relation_name = f"{model._meta.db_table}s"
if backward_relation_name in related_model._meta.fields:
raise ConfigurationError(
f'backward relation "{backward_relation_name}" duplicates in'
f" model {related_model_name}"
)
if is_o2o:
fk_relation: Union[BackwardOneToOneRelation, BackwardFKRelation] = (
BackwardOneToOneRelation(
model,
key_field,
key_fk_object.source_field,
null=True,
description=fk_object.description,
)
)
else:
fk_relation = BackwardFKRelation(
model,
key_field,
key_fk_object.source_field,
null=fk_object.null,
description=fk_object.description,
)
fk_relation.to_field_instance = fk_object.to_field_instance # type:ignore
related_model._meta.add_field(backward_relation_name, fk_relation)
if is_o2o and fk_object.pk:
model._meta.pk_attr = key_field

for app_name, app in cls.apps.items():
for model_name, model in app.items():
Expand All @@ -158,145 +225,16 @@ def split_reference(reference: str) -> Tuple[str, str]:
if not model._meta.db_table:
model._meta.db_table = model.__name__.lower()

# TODO: refactor to share logic between FK & O2O
for field in sorted(model._meta.fk_fields):
fk_object = cast(ForeignKeyFieldInstance, model._meta.fields_map[field])
reference = fk_object.model_name
related_app_name, related_model_name = split_reference(reference)
related_model = get_related_model(related_app_name, related_model_name)

if fk_object.to_field:
related_field = related_model._meta.fields_map.get(fk_object.to_field, None)
if related_field:
if related_field.unique:
key_fk_object = deepcopy(related_field)
fk_object.to_field_instance = related_field # type: ignore
else:
raise ConfigurationError(
f'field "{fk_object.to_field}" in model'
f' "{related_model_name}" is not unique'
)
else:
raise ConfigurationError(
f'there is no field named "{fk_object.to_field}"'
f' in model "{related_model_name}"'
)
else:
key_fk_object = deepcopy(related_model._meta.pk)
fk_object.to_field_instance = related_model._meta.pk # type: ignore
fk_object.to_field = related_model._meta.pk_attr
fk_object.field_type = fk_object.to_field_instance.field_type
key_field = f"{field}_id"
key_fk_object.pk = False
key_fk_object.unique = False
key_fk_object.index = fk_object.index
key_fk_object.default = fk_object.default
key_fk_object.null = fk_object.null
key_fk_object.generated = fk_object.generated
key_fk_object.reference = fk_object
key_fk_object.description = fk_object.description
if fk_object.source_field:
key_fk_object.source_field = fk_object.source_field
else:
key_fk_object.source_field = key_field
model._meta.add_field(key_field, key_fk_object)

fk_object.related_model = related_model
fk_object.source_field = key_field
backward_relation_name = fk_object.related_name
if backward_relation_name is not False:
if not backward_relation_name:
backward_relation_name = f"{model._meta.db_table}s"
if backward_relation_name in related_model._meta.fields:
raise ConfigurationError(
f'backward relation "{backward_relation_name}" duplicates in'
f" model {related_model_name}"
)
fk_relation = BackwardFKRelation(
model,
f"{field}_id",
key_fk_object.source_field,
fk_object.null,
fk_object.description,
)
fk_relation.to_field_instance = fk_object.to_field_instance # type: ignore
related_model._meta.add_field(backward_relation_name, fk_relation)
init_fk_o2o_field(model, field)

for field in model._meta.o2o_fields:
o2o_object = cast(OneToOneFieldInstance, model._meta.fields_map[field])
reference = o2o_object.model_name
related_app_name, related_model_name = split_reference(reference)
related_model = get_related_model(related_app_name, related_model_name)

if o2o_object.to_field:
related_field = related_model._meta.fields_map.get(
o2o_object.to_field, None
)
if related_field:
if related_field.unique:
key_o2o_object = deepcopy(related_field)
o2o_object.to_field_instance = related_field # type: ignore
else:
raise ConfigurationError(
f'field "{o2o_object.to_field}" in model'
f' "{related_model_name}" is not unique'
)
else:
raise ConfigurationError(
f'there is no field named "{o2o_object.to_field}"'
f' in model "{related_model_name}"'
)
else:
key_o2o_object = deepcopy(related_model._meta.pk)
o2o_object.to_field_instance = related_model._meta.pk # type: ignore
o2o_object.to_field = related_model._meta.pk_attr

o2o_object.field_type = o2o_object.to_field_instance.field_type

key_field = f"{field}_id"
key_o2o_object.pk = o2o_object.pk
key_o2o_object.index = o2o_object.index
key_o2o_object.default = o2o_object.default
key_o2o_object.null = o2o_object.null
key_o2o_object.unique = o2o_object.unique
key_o2o_object.generated = o2o_object.generated
key_o2o_object.reference = o2o_object
key_o2o_object.description = o2o_object.description
if o2o_object.source_field:
key_o2o_object.source_field = o2o_object.source_field
else:
key_o2o_object.source_field = key_field
model._meta.add_field(key_field, key_o2o_object)

o2o_object.related_model = related_model
o2o_object.source_field = key_field
backward_relation_name = o2o_object.related_name
if backward_relation_name is not False:
if not backward_relation_name:
backward_relation_name = f"{model._meta.db_table}"
if backward_relation_name in related_model._meta.fields:
raise ConfigurationError(
f'backward relation "{backward_relation_name}" duplicates in'
f" model {related_model_name}"
)
o2o_relation = BackwardOneToOneRelation(
model,
f"{field}_id",
key_o2o_object.source_field,
null=True,
description=o2o_object.description,
)
o2o_relation.to_field_instance = o2o_object.to_field_instance # type: ignore
related_model._meta.add_field(backward_relation_name, o2o_relation)

if o2o_object.pk:
model._meta.pk_attr = key_field
init_fk_o2o_field(model, field, is_o2o=True)

for field in list(model._meta.m2m_fields):
m2m_object = cast(ManyToManyFieldInstance, model._meta.fields_map[field])
if m2m_object._generated:
continue

backward_key = m2m_object.backward_key
if not backward_key:
backward_key = f"{model._meta.db_table}_id"
Expand All @@ -323,11 +261,8 @@ def split_reference(reference: str) -> Tuple[str, str]:

if not m2m_object.through:
related_model_table_name = (
related_model._meta.db_table
if related_model._meta.db_table
else related_model.__name__.lower()
related_model._meta.db_table or related_model.__name__.lower()
)

m2m_object.through = f"{model._meta.db_table}_{related_model_table_name}"

m2m_relation = ManyToManyFieldInstance(
Expand Down
15 changes: 9 additions & 6 deletions tortoise/fields/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,12 @@ def describe(self, serializable: bool) -> dict:
del desc["db_column"]
return desc

@classmethod
def validate_model_name(cls, model_name: str) -> None:
if len(model_name.split(".")) != 2:
field_type = cls.__name__.replace("Instance", "")
raise ConfigurationError(f'{field_type} accepts model name in format "app.Model"')


class ForeignKeyFieldInstance(RelationalField[MODEL]):
def __init__(
Expand All @@ -316,8 +322,7 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(None, **kwargs) # type: ignore
if len(model_name.split(".")) != 2:
raise ConfigurationError('Foreign key accepts model name in format "app.Model"')
self.validate_model_name(model_name)
self.model_name = model_name
self.related_name = related_name
if on_delete not in set(OnDelete):
Expand Down Expand Up @@ -359,8 +364,7 @@ def __init__(
on_delete: OnDelete = CASCADE,
**kwargs: Any,
) -> None:
if len(model_name.split(".")) != 2:
raise ConfigurationError('OneToOneField accepts model name in format "app.Model"')
self.validate_model_name(model_name)
super().__init__(model_name, related_name, on_delete, unique=True, **kwargs)


Expand All @@ -385,8 +389,7 @@ def __init__(
# TODO: rename through to through_table
# TODO: add through to use a Model
super().__init__(field_type, **kwargs)
if len(model_name.split(".")) != 2:
raise ConfigurationError('Foreign key accepts model name in format "app.Model"')
self.validate_model_name(model_name)
self.model_name: str = model_name
self.related_name: str = related_name
self.forward_key: str = forward_key or f"{model_name.split('.')[1].lower()}_id"
Expand Down

0 comments on commit a191ea8

Please sign in to comment.