Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: All non-native types must have a reference id #966

Merged
merged 1 commit into from Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions tests/codegen/handlers/test_process_attributes_types.py
Expand Up @@ -259,6 +259,8 @@ def test_process_inner_type_with_circular_reference(
attr_type = AttrTypeFactory.create(circular=True)

self.processor.process_inner_type(target, attr, attr_type)

self.assertEqual(target.ref, attr_type.reference)
self.assertEqual(0, mock_copy_attribute_properties.call_count)
self.assertEqual(0, mock_update_restrictions.call_count)

Expand Down
7 changes: 5 additions & 2 deletions tests/codegen/handlers/test_unnest_inner_classes.py
Expand Up @@ -77,11 +77,14 @@ def test_update_types(self):
AttrTypeFactory.create(qname="b", forward=False),
]
)
source = ClassFactory.create()

self.processor.update_types(attr, "a", "c")
self.processor.update_types(attr, "a", source)

self.assertEqual("c", attr.types[0].qname)
self.assertEqual(source.qname, attr.types[0].qname)
self.assertFalse(attr.types[0].forward)
self.assertEqual(source.ref, attr.types[0].reference)

self.assertEqual("a", attr.types[1].qname)
self.assertFalse(attr.types[1].forward)
self.assertEqual("b", attr.types[2].qname)
Expand Down
33 changes: 11 additions & 22 deletions tests/codegen/test_analyzer.py
Expand Up @@ -8,7 +8,6 @@
from xsdata.utils.testing import (
AttrFactory,
ClassFactory,
ExtensionFactory,
FactoryTestCase,
)

Expand Down Expand Up @@ -36,26 +35,7 @@ def test_process(
mock_container_process.assert_called_once_with()
mock_validate_references.assert_called_once_with(classes)

def test_class_references(self):
target = ClassFactory.elements(
2,
inner=ClassFactory.list(2, attrs=AttrFactory.list(1)),
extensions=ExtensionFactory.list(1),
)

actual = ClassAnalyzer.class_references(target)
# +1 target
# +2 attrs
# +2 attr types
# +1 extension
# +1 extension type
# +2 inner classes
# +2 inner classes attrs
# +2 inner classes attr types
self.assertEqual(13, len(actual))
self.assertEqual(id(target), actual[0])

def test_validate_references(self):
def test_validate_with_cross_references(self):
first = ClassFactory.elements(2)
second = ClassFactory.create(attrs=first.attrs)

Expand All @@ -64,4 +44,13 @@ def test_validate_references(self):
with self.assertRaises(AnalyzerValueError) as cm:
ClassAnalyzer.validate_references([first, second])

self.assertEqual("Cross references detected!", str(cm.exception))
self.assertEqual("Cross reference detected", str(cm.exception))

def test_validate_unresolved_references(self):
first = ClassFactory.create()
first.attrs.append(AttrFactory.reference("foo"))

with self.assertRaises(AnalyzerValueError) as cm:
ClassAnalyzer.validate_references([first])

self.assertEqual("Unresolved reference", str(cm.exception))
79 changes: 47 additions & 32 deletions xsdata/codegen/analyzer.py
@@ -1,7 +1,12 @@
from typing import List
from dataclasses import fields
from typing import Iterator, List, Tuple

from xsdata.codegen.container import ClassContainer
from xsdata.codegen.models import Class
from xsdata.codegen.models import (
AttrType,
Class,
CodegenModel,
)
from xsdata.codegen.validator import ClassValidator
from xsdata.exceptions import AnalyzerValueError

Expand Down Expand Up @@ -33,44 +38,54 @@ def process(cls, container: ClassContainer) -> List[Class]:
return classes

@classmethod
def class_references(cls, target: Class) -> List[int]:
"""Produce a list of instance references for the given class.
def validate_references(cls, classes: List[Class]):
"""Validate codegen object references.

#Todo - Add details on these exceptions

Collect the ids of the class, attr, extension and inner instances.
Rules:
1. No shared codegen objects between classes
2. All attr types must have a reference id, except natives

Args:
target: The target class instance
classes: The list of classes to be generated.

List:
The list of id references.
Raises:
AnalyzerValueError: If an object violates the rules.
"""
result = [id(target)]
for attr in target.attrs:
result.append(id(attr))
result.extend(id(attr_type) for attr_type in attr.types)

for extension in target.extensions:
result.append(id(extension))
result.append(id(extension.type))

for inner in target.inner:
result.extend(cls.class_references(inner))

return result
seen = set()
for target in classes:
for objects in cls.codegen_models(target):
child = objects[-1]
ref = id(child)
if ref in seen:
raise AnalyzerValueError("Cross reference detected")

if (
isinstance(child, AttrType)
and not child.reference
and not child.native
):
raise AnalyzerValueError("Unresolved reference")

seen.add(ref)

@classmethod
def validate_references(cls, classes: List[Class]):
"""Validate all codegen objects are not cross-referenced.

This validation ensures we never share any attr, or extension
between classes.
def codegen_models(cls, *args: CodegenModel) -> Iterator[Tuple[CodegenModel, ...]]:
"""Find and yield all children codegen models.

Args:
classes: The list of classes to be generated.
*args: The codegen objects path.

Raises:
AnalyzerValueError: If an object is shared between the classes.
Yields:
A tuple of codegen models like a path e.g. class, attr, attr_type
"""
references = [ref for obj in classes for ref in cls.class_references(obj)]
if len(references) != len(set(references)):
raise AnalyzerValueError("Cross references detected!")
yield args
model = args[-1]
for f in fields(model):
value = getattr(model, f.name)
if isinstance(value, list) and value and isinstance(value[0], CodegenModel):
for val in value:
yield from cls.codegen_models(*args, val)
elif isinstance(value, CodegenModel):
yield *args, value
6 changes: 4 additions & 2 deletions xsdata/codegen/handlers/create_compound_fields.py
Expand Up @@ -133,7 +133,9 @@ def group_fields(self, target: Class, attrs: List[Attr]):

min_occurs, max_occurs = self.sum_counters(counters)
name = self.choose_name(target, names, list(filter(None, substitutions)))
types = collections.unique_sequence(t for attr in attrs for t in attr.types)
types = collections.unique_sequence(
t.clone() for attr in attrs for t in attr.types
)

target.attrs.insert(
pos,
Expand Down Expand Up @@ -256,7 +258,7 @@ def build_attr_choice(cls, attr: Attr) -> Attr:
return Attr(
name=attr.local_name,
namespace=attr.namespace,
types=attr.types,
types=[x.clone() for x in attr.types],
tag=attr.tag,
help=attr.help,
restrictions=restrictions,
Expand Down
2 changes: 1 addition & 1 deletion xsdata/codegen/handlers/disambiguate_choices.py
Expand Up @@ -64,7 +64,7 @@ def process_compound_field(self, target: Class, attr: Attr):

if attr.tag == Tag.CHOICE:
types = (tp for choice in attr.choices for tp in choice.types)
attr.types = collections.unique_sequence(types)
attr.types = collections.unique_sequence(x.clone() for x in types)

@classmethod
def merge_wildcard_choices(cls, attr: Attr):
Expand Down
7 changes: 5 additions & 2 deletions xsdata/codegen/handlers/flatten_class_extensions.py
Expand Up @@ -109,8 +109,11 @@ def process_enum_extension(
# the target enumeration, mypy doesn't play nicely.
target.attrs.clear()

if extension and target.is_enumeration:
target.extensions.remove(extension)
if extension:
if target.is_enumeration:
target.extensions.remove(extension)
else:
extension.type.reference = source.ref

@classmethod
def merge_enumerations(cls, source: Class, target: Class):
Expand Down
3 changes: 3 additions & 0 deletions xsdata/codegen/handlers/process_attributes_types.py
Expand Up @@ -154,6 +154,7 @@ def process_inner_type(self, target: Class, attr: Attr, attr_type: AttrType):
attr_type: The attr type instance
"""
if attr_type.circular:
attr_type.reference = target.ref
return

inner = self.container.find_inner(target, attr_type.qname)
Expand All @@ -166,6 +167,8 @@ def process_inner_type(self, target: Class, attr: Attr, attr_type: AttrType):
):
self.copy_attribute_properties(inner, target, attr, attr_type)
target.inner.remove(inner)
else:
attr_type.reference = inner.ref

def process_dependency_type(self, target: Class, attr: Attr, attr_type: AttrType):
"""Process an attr type that depends on any global type.
Expand Down
9 changes: 5 additions & 4 deletions xsdata/codegen/handlers/unnest_inner_classes.py
Expand Up @@ -42,7 +42,7 @@ def promote(self, target: Class, inner: Class):
attr = self.find_forward_attr(target, inner.qname)
if attr:
clone = self.clone_class(inner, target.name)
self.update_types(attr, inner.qname, clone.qname)
self.update_types(attr, inner.qname, clone)
self.container.add(clone)

@classmethod
Expand All @@ -65,17 +65,18 @@ def clone_class(cls, inner: Class, name: str) -> Class:
return clone

@classmethod
def update_types(cls, attr: Attr, search: str, replace: str):
def update_types(cls, attr: Attr, search: str, source: Class):
"""Update the references from an inner to a global class.

Args:
attr: The target attr to inspect and update
search: The current inner class qname
replace: The new global class qname
source: The new global class qname
"""
for attr_type in attr.types:
if attr_type.qname == search and attr_type.forward:
attr_type.qname = replace
attr_type.qname = source.qname
attr_type.reference = source.ref
attr_type.forward = False

@classmethod
Expand Down
14 changes: 12 additions & 2 deletions xsdata/codegen/mappers/definitions.py
Expand Up @@ -145,7 +145,11 @@ def map_binding_operation(
# Only Envelope classes need to be added in service input/output
if message_class.meta_name:
message_type = message_class.name.split("_")[-1]
attrs.append(cls.build_attr(message_type, message_class.qname))
attrs.append(
cls.build_attr(
message_type, message_class.qname, reference=id(message_class)
)
)

assert binding_operation.location is not None

Expand Down Expand Up @@ -516,6 +520,7 @@ def build_attr(
forward: bool = False,
namespace: Optional[str] = None,
default: Optional[str] = None,
reference: int = 0,
) -> Attr:
"""Helper method to build an attr instance.

Expand All @@ -526,6 +531,7 @@ def build_attr(
forward: Whether the type is a forward reference
namespace: The attr namespace
default: The attr default value
reference: The class id reference, if any

Returns:
The new attr instance.
Expand All @@ -539,6 +545,10 @@ def build_attr(
name=name,
namespace=namespace,
default=default,
types=[AttrType(qname=qname, forward=forward, native=native)],
types=[
AttrType(
qname=qname, forward=forward, native=native, reference=reference
)
],
restrictions=Restrictions(min_occurs=occurs, max_occurs=occurs),
)
15 changes: 10 additions & 5 deletions xsdata/codegen/models.py
Expand Up @@ -29,7 +29,12 @@


@dataclass
class Restrictions:
class CodegenModel:
"""Base codegen model."""


@dataclass
class Restrictions(CodegenModel):
"""Class field validation restrictions.

Args:
Expand Down Expand Up @@ -196,7 +201,7 @@ def from_element(cls, element: ElementBase) -> "Restrictions":


@dataclass(unsafe_hash=True)
class AttrType:
class AttrType(CodegenModel):
"""Class field typing information.

Args:
Expand Down Expand Up @@ -249,7 +254,7 @@ def clone(self, **kwargs: Any) -> "AttrType":


@dataclass
class Attr:
class Attr(CodegenModel):
"""Class field model representation.

Args:
Expand Down Expand Up @@ -435,7 +440,7 @@ def can_be_restricted(self) -> bool:


@dataclass(unsafe_hash=True)
class Extension:
class Extension(CodegenModel):
"""Base class model representation.

Args:
Expand Down Expand Up @@ -474,7 +479,7 @@ class Status(IntEnum):


@dataclass
class Class:
class Class(CodegenModel):
"""Class model representation.

Args:
Expand Down
2 changes: 2 additions & 0 deletions xsdata/codegen/utils.py
Expand Up @@ -197,12 +197,14 @@ def copy_inner_class(cls, source: Class, target: Class, attr_type: AttrType):
inner = ClassUtils.find_inner(source, attr_type.qname)
if inner is target:
attr_type.circular = True
attr_type.reference = target.ref
else:
# In extreme cases this adds duplicate inner classes
clone = inner.clone()
clone.package = target.package
clone.module = target.module
clone.status = Status.RAW
attr_type.reference = clone.ref
target.inner.append(clone)

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion xsdata/exceptions.py
Expand Up @@ -47,7 +47,7 @@ class DefinitionsValueError(ValueError):


class AnalyzerValueError(ValueError):
"""Unhandled behaviour during class analyze process.."""
"""Unhandled behaviour during class analyze process."""


class ResolverValueError(ValueError):
Expand Down