Skip to content

Commit

Permalink
fix: Generate type hints for compound fields with token elements (#997)
Browse files Browse the repository at this point in the history
  • Loading branch information
tefra committed Mar 24, 2024
1 parent a15c3e7 commit ff01d9b
Show file tree
Hide file tree
Showing 19 changed files with 147 additions and 49 deletions.
2 changes: 1 addition & 1 deletion docs/codegen/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ xsdata relies on the field ordering for serialization. This process fails for re
choice or complex sequence elements. When you enable compound fields, these elements are
grouped into a single field.

```xml show_lines="2:9"
```xml show_lines="2:10"
--8<-- "tests/fixtures/compound/schema.xsd"
```

Expand Down
5 changes: 1 addition & 4 deletions tests/codegen/handlers/test_create_compound_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from xsdata.codegen.models import Restrictions
from xsdata.models.config import GeneratorConfig
from xsdata.models.enums import Tag
from xsdata.utils import collections
from xsdata.utils.testing import (
AttrFactory,
ClassFactory,
Expand Down Expand Up @@ -93,9 +92,7 @@ def test_group_fields(self):
name="choice",
tag="Choice",
index=0,
types=collections.unique_sequence(
t for attr in target.attrs for t in attr.types
),
types=[],
choices=[
AttrFactory.create(
tag=target.attrs[0].tag,
Expand Down
5 changes: 3 additions & 2 deletions tests/codegen/handlers/test_disambiguate_choices.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def test_process_with_duplicate_wildcards(self):
self.assertEqual(4, wildcard.restrictions.max_occurs)

def test_process_with_duplicate_simple_types(self):
compound = AttrFactory.create(tag=Tag.CHOICE, types=[])
compound = AttrFactory.create(tag=Tag.CHOICE)
compound.types.clear()
target = ClassFactory.create()
target.attrs.append(compound)
compound.choices.append(AttrFactory.native(DataType.STRING, name="a"))
Expand All @@ -70,7 +71,7 @@ def test_process_with_duplicate_simple_types(self):
self.assertEqual("a", target.inner[0].qname)
self.assertEqual("{xs}b", target.inner[1].qname)

self.assertEqual(["a", "{xs}b"], [x.qname for x in compound.types])
self.assertEqual([], [x.qname for x in compound.types])

def test_process_with_duplicate_any_types(self):
compound = AttrFactory.create(tag=Tag.CHOICE, types=[])
Expand Down
9 changes: 8 additions & 1 deletion tests/fixtures/compound/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Root:
class Meta:
name = "root"

alpha_or_bravo: List[Union[Alpha, Bravo]] = field(
alpha_or_bravo_or_charlie: List[Union[Alpha, Bravo, List[str]]] = field(
default_factory=list,
metadata={
"type": "Elements",
Expand All @@ -48,6 +48,13 @@ class Meta:
"name": "bravo",
"type": Bravo,
},
{
"name": "charlie",
"type": List[str],
"namespace": "",
"default_factory": list,
"tokens": True,
},
),
},
)
14 changes: 12 additions & 2 deletions tests/fixtures/compound/sample.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"alpha_or_bravo": [
"alpha_or_bravo_or_charlie": [
{
"a": true
},
Expand All @@ -21,8 +21,18 @@
{
"a": true
},
[
"a",
"b",
"c"
],
{
"b": true
}
},
[
"d",
"e",
"f"
]
]
}
12 changes: 11 additions & 1 deletion tests/fixtures/compound/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


obj = Root(
alpha_or_bravo=[
alpha_or_bravo_or_charlie=[
Alpha(

),
Expand All @@ -26,8 +26,18 @@
Alpha(

),
[
'a',
'b',
'c',
],
Bravo(

),
[
'd',
'e',
'f',
],
]
)
2 changes: 2 additions & 0 deletions tests/fixtures/compound/sample.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,7 @@
<alpha a="true" />
<bravo b="true" />
<alpha a="true" />
<charlie>a b c</charlie>
<bravo b="true" />
<charlie>d e f</charlie>
</root>
2 changes: 2 additions & 0 deletions tests/fixtures/compound/sample.xsdata.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@
<alpha a="true"/>
<bravo b="true"/>
<alpha a="true"/>
<charlie>a b c</charlie>
<bravo b="true"/>
<charlie>d e f</charlie>
</root>
40 changes: 22 additions & 18 deletions tests/fixtures/compound/schema.xsd
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
<xsd:schema xmlns:xsd="http://www.w3.org/2001/XMLSchema">
<xsd:element name="root">
<xsd:complexType>
<xsd:choice maxOccurs="unbounded">
<xsd:element ref="alpha" />
<xsd:element ref="bravo" />
</xsd:choice>
</xsd:complexType>
</xsd:element>
<xsd:element name="alpha">
<xsd:complexType>
<xsd:attribute name="a" type="xsd:boolean" fixed="true" />
</xsd:complexType>
</xsd:element>
<xsd:element name="bravo">
<xsd:complexType>
<xsd:attribute name="b" type="xsd:boolean" fixed="true" />
</xsd:complexType>
</xsd:element>
<xsd:element name="root">
<xsd:complexType>
<xsd:choice maxOccurs="unbounded">
<xsd:element ref="alpha"/>
<xsd:element ref="bravo"/>
<xsd:element name="charlie" type="charlie"/>
</xsd:choice>
</xsd:complexType>
</xsd:element>
<xsd:element name="alpha">
<xsd:complexType>
<xsd:attribute name="a" type="xsd:boolean" fixed="true"/>
</xsd:complexType>
</xsd:element>
<xsd:element name="bravo">
<xsd:complexType>
<xsd:attribute name="b" type="xsd:boolean" fixed="true"/>
</xsd:complexType>
</xsd:element>
<xsd:simpleType name="charlie">
<xsd:restriction base="xsd:NMTOKENS"/>
</xsd:simpleType>
</xsd:schema>
31 changes: 31 additions & 0 deletions tests/formats/dataclass/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,37 @@ def test_field_type_with_prohibited_attr(self):

self.assertEqual("Any", self.filters.field_type(attr, ["a", "b"]))

def test_field_type_with_compound_attr(self):
attr = AttrFactory.create(
tag=Tag.CHOICE,
choices=[
AttrFactory.create(
name="a", types=[AttrTypeFactory.native(DataType.STRING)]
),
AttrFactory.create(
name="b", types=[AttrTypeFactory.native(DataType.INT)]
),
AttrFactory.create(
name="c",
types=[AttrTypeFactory.native(DataType.DECIMAL)],
restrictions=Restrictions(tokens=True),
),
],
restrictions=Restrictions(min_occurs=0, max_occurs=1),
)

expected = "Optional[Union[str, int, List[Decimal]]]"
self.assertEqual(expected, self.filters.field_type(attr, []))

attr.restrictions.max_occurs = 2
expected = "List[Union[str, int, List[Decimal]]]"
self.assertEqual(expected, self.filters.field_type(attr, []))

attr.restrictions.min_occurs = attr.restrictions.max_occurs = 1
self.filters.format.kw_only = True
expected = "Union[str, int, List[Decimal]]"
self.assertEqual(expected, self.filters.field_type(attr, []))

def test_choice_type(self):
choice = AttrFactory.create(types=[AttrTypeFactory.create("foobar")])
actual = self.filters.choice_type(choice, ["a", "b"])
Expand Down
7 changes: 5 additions & 2 deletions tests/formats/dataclass/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,9 @@ def test_evaluate_union(self):
Optional[List[Union[int, float]]]: (list, int, float),
Optional[A]: (int, str),
Union[List[int], None]: (list, int),
Union[List[int], List[str]]: False,
Union[Tuple[int, ...], None]: (tuple, int),
Union[List[int], List[str]]: (list, int, list, str),
Union[List[Dict]]: False,
}

if sys.version_info[:2] >= (3, 10):
Expand All @@ -169,7 +171,8 @@ def test_evaluate_union(self):
None | List[int | float]: (list, int, float),
None | A: (int, str),
List[int] | None: (list, int),
List[int] | List[str]: False,
Tuple[int, ...] | None: (tuple, int),
List[int] | List[str]: (list, int, list, str),
}
)

Expand Down
1 change: 0 additions & 1 deletion xsdata/codegen/handlers/create_compound_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ def group_fields(self, target: Class, attrs: List[Attr]):
),
choices=choices,
)
ClassUtils.reset_choice_types(compound_attr)
target.attrs.insert(pos, compound_attr)

def sum_counters(self, counters: Dict) -> Tuple[List[int], List[int]]:
Expand Down
3 changes: 0 additions & 3 deletions xsdata/codegen/handlers/disambiguate_choices.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from xsdata.codegen.mixins import ContainerInterface, RelativeHandlerInterface
from xsdata.codegen.models import Attr, AttrType, Class, Extension, Restrictions
from xsdata.codegen.utils import ClassUtils
from xsdata.models.enums import DataType, Tag
from xsdata.utils import collections, text
from xsdata.utils.constants import DEFAULT_ATTR_NAME
Expand Down Expand Up @@ -63,8 +62,6 @@ def process_compound_field(self, target: Class, attr: Attr):
for choice in self.find_ambiguous_choices(attr):
self.disambiguate_choice(target, choice)

ClassUtils.reset_choice_types(attr)

@classmethod
def merge_wildcard_choices(cls, attr: Attr):
"""Merge choices derived from xs:any elements.
Expand Down
9 changes: 1 addition & 8 deletions xsdata/codegen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
get_qname,
get_slug,
)
from xsdata.models.enums import DataType, Tag
from xsdata.models.enums import DataType
from xsdata.utils import collections, namespaces, text
from xsdata.utils.constants import DEFAULT_ATTR_NAME

Expand Down Expand Up @@ -454,13 +454,6 @@ def unique_name(cls, name: str, reserved: Set[str]) -> str:

return name

@classmethod
def reset_choice_types(cls, attr: Attr):
"""Reset the choice types."""
if attr.tag == Tag.CHOICE:
types = (tp for choice in attr.choices for tp in choice.types)
attr.types = collections.unique_sequence(x.clone() for x in types)

@classmethod
def cleanup_class(cls, target: Class):
"""Go through the target class attrs and filter their types.
Expand Down
32 changes: 32 additions & 0 deletions xsdata/formats/dataclass/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
ObjectType,
OutputFormat,
)
from xsdata.models.enums import Tag
from xsdata.utils import collections, namespaces, text
from xsdata.utils.objects import literal_value

Expand Down Expand Up @@ -745,6 +746,9 @@ def field_type(self, attr: Attr, parents: List[str]) -> str:
if attr.is_prohibited:
return "Any"

if attr.tag == Tag.CHOICE:
return self.compound_field_types(attr, parents)

result = self._field_type_names(attr, parents, choice=False)

iterable_fmt = self._get_iterable_format()
Expand All @@ -767,6 +771,34 @@ def field_type(self, attr: Attr, parents: List[str]) -> str:

return result

def compound_field_types(self, attr: Attr, parents: List[str]):
"""Generate type hint for a compound field.
Args:
attr: The compound attr instance
parents: A list of the parent class names
Returns:
The string representation of the type hint.
"""
results = []
iterable_fmt = self._get_iterable_format()
for choice in attr.choices:
names = self._field_type_names(choice, parents, choice=False)
if choice.is_tokens:
names = iterable_fmt.format(names)
results.append(names)

result = self._join_type_names(results)

if attr.is_list:
return iterable_fmt.format(result)

if attr.is_optional or not self.format.kw_only:
return f"None | {result}" if self.union_type else f"Optional[{result}]"

return result

def choice_type(self, choice: Attr, parents: List[str]) -> str:
"""Generate type hints for the given choice.
Expand Down
1 change: 0 additions & 1 deletion xsdata/formats/dataclass/generator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import re
import subprocess
from pathlib import Path
from textwrap import indent
from typing import Iterator, List, Optional

from jinja2 import Environment, FileSystemLoader
Expand Down
13 changes: 12 additions & 1 deletion xsdata/formats/dataclass/models/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,10 @@ def build(
f"Xml {xml_type} does not support typing `{type_hint}`"
)

if xml_type == XmlType.ELEMENTS:
sub_origin = None
types = (object,)

local_name = local_name or self.build_local_name(xml_type, name)

if tokens and sub_origin is None:
Expand Down Expand Up @@ -586,7 +590,11 @@ def is_any_type(cls, types: Sequence[Type], xml_type: str) -> bool:

@classmethod
def analyze_types(
cls, model: Type, name: str, type_hint: Any, globalns: Any
cls,
model: Type,
name: str,
type_hint: Any,
globalns: Any,
) -> Tuple[Any, Any, Tuple[Type, ...]]:
"""Analyze a type hint and return the origin, sub origin and the type args.
Expand Down Expand Up @@ -647,6 +655,9 @@ def is_valid(
# Any type, secondary types are not allowed except for 'Elements' XML type
return len(types) == 1

if xml_type == XmlType.ELEMENTS:
return True

return self.is_typing_supported(types)

def is_typing_supported(self, types: Sequence[Type]) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion xsdata/formats/dataclass/models/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from xsdata.formats.converter import converter
from xsdata.models.enums import NamespaceType
from xsdata.utils import collections
from xsdata.utils.namespaces import build_qname, local_name, target_uri
from xsdata.utils.namespaces import build_qname, target_uri

NoneType = type(None)

Expand Down

0 comments on commit ff01d9b

Please sign in to comment.