Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scala sealed trait ability (mostly) in dataconf using dataclass #10

Merged
merged 17 commits into from
Sep 20, 2021
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
36 changes: 36 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,42 @@ class Config:

print(dataconf.load(conf, Config))
# TestConf(test='pc.home', float=2.1, default='hello', list=['a', 'b'], nested=Nested(a='test'), nested_list=[Nested(a='test1')], duration=relativedelta(seconds=+2), default_factory={}, union=1)

# Replicating pureconfig Scala sealed trait case class behavior
# https://pureconfig.github.io/docs/overriding-behavior-for-sealed-families.html
class InputType:
@dataclass(init=True, repr=True)
class StringImpl:
name: Text
age: Text

def test_method(self):
print(f"{self.name} is {self.age} years old.")

@dataclass(init=True, repr=True)
class IntImpl:
area_code: int
phone_num: Text

def test_method(self):
print(f"The area code for {self.phone_num} is {str(self.area_code)}")

@dataclass
class Base:
location: Text
input_source: Union[InputType.StringImpl, InputType.IntImpl]

str_conf = """
{
location: Europe
input_source {
name: Thailand
age: "12"
}
}
"""

conf = dataconf.loads(str_conf, Base)
```

```python
Expand Down
3 changes: 2 additions & 1 deletion dataconf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
from dataconf.utils import dumps
from dataconf.utils import load
from dataconf.utils import loads
from dataconf.version import __version__

__all__ = ["load", "loads", "dump", "dumps"]
__all__ = ["load", "loads", "dump", "dumps", "__version__"]
20 changes: 20 additions & 0 deletions dataconf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,26 @@ def __parse(value: any, clazz, path):
if clazz is ConfigTree:
return __parse_type(value, clazz, path, isinstance(value, ConfigTree))

# Todo: this should be cleaner
# iterates through class dict to check for subclasses
# if subclasses are a dataclass parse values
# the idea here is to replicate parsing of a sealed trait in Scala
# when using pureconfig
child_failures = []
for child_clazz in sorted(clazz.__subclasses__(), key=lambda c: c.__name__):
if is_dataclass(child_clazz):
try:
return __parse(value, child_clazz, path)
except TypeConfigException as f:
child_failures.append(str(f))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zifeo also changed f to str(f) for the join. We will need a test for failure too to make sure all is working.


# no need to check length; false if empty
if child_failures:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I removed the length check @zifeo

fails = "\n- ".join(child_failures)
raise TypeConfigException(
f"expected type {clazz} at {path}, failed subclasses:{fails}"
)

raise TypeConfigException(f"expected type {clazz} at {path}, got {type(value)}")


Expand Down
3 changes: 3 additions & 0 deletions dataconf/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import importlib.metadata

__version__ = importlib.metadata.version("dataconf")
dwsmith1983 marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "dataconf"
version = "0.1.5"
version = "0.1.6"
description = "Lightweight configuration with automatic dataclasses parsing (HOCON/JSON/YAML/PROPERTIES)"
authors = []
license = "Apache2"
Expand Down
30 changes: 30 additions & 0 deletions tests/scala_sealed_trait.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from dataclasses import dataclass
from typing import Text


class InputType:
pass


@dataclass(init=True, repr=True)
class StringImpl(InputType):
name: Text
age: Text

def test_method(self):
return f"{self.name} is {self.age} years old."

def test_complex(self):
return int(self.age) * 3


@dataclass(init=True, repr=True)
class IntImpl(InputType):
area_code: int
phone_num: Text

def test_method(self):
return f"The area code for {self.phone_num} is {str(self.area_code)}"

def test_complex(self):
return self.area_code - 10
51 changes: 51 additions & 0 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from dataconf.exceptions import UnexpectedKeysException
from dateutil.relativedelta import relativedelta
import pytest
from tests.scala_sealed_trait import InputType
from tests.scala_sealed_trait import IntImpl
from tests.scala_sealed_trait import StringImpl


PARENT_DIR = os.path.normpath(
Expand Down Expand Up @@ -241,3 +244,51 @@ class Base:
production=True,
conn=Conn(host="test.server.io", port=443),
)

def test_traits_string_impl(self) -> None:
@dataclass
class Base:
location: Text
input_source: InputType

str_conf = """
{
location: Europe
input_source {
name: Thailand
age: "12"
}
}
"""

conf = loads(str_conf, Base)
assert conf == Base(
location="Europe",
input_source=StringImpl(name="Thailand", age="12"),
)
assert conf.input_source.test_method() == "Thailand is 12 years old."
assert conf.input_source.test_complex() == 36

def test_traits_int_impl(self) -> None:
@dataclass
class Base:
location: Text
input_source: InputType

str_conf = """
{
location: Europe
input_source {
area_code: 94
phone_num: "1234567"
}
}
"""

conf = loads(str_conf, Base)
assert conf == Base(
location="Europe",
input_source=IntImpl(area_code=94, phone_num="1234567"),
)
assert conf.input_source.test_method() == "The area code for 1234567 is 94"
assert conf.input_source.test_complex() == 84