-
Notifications
You must be signed in to change notification settings - Fork 56
/
Copy pathtest_optional_subparsers.py
140 lines (102 loc) · 4.39 KB
/
test_optional_subparsers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import collections
import functools
import random
from dataclasses import dataclass
from typing import Union
import pytest
from simple_parsing.helpers.fields import field, subparsers
from simple_parsing.helpers.hparams import uniform
from simple_parsing.helpers.hparams.hyperparameters import HyperParameters
from .testutils import TestSetup
@dataclass()
class A:
    """First subcommand choice: carries a single integer option."""

    # Integer option; 123 unless overridden.
    foo: int = 123
@dataclass()
class B:
    """Second subcommand choice: carries a single float option."""

    # Float option; 4.56 unless overridden.
    bar: float = 4.56
class TestWithDefault:
    """A `subparsers()` union field that has a default can be left off the command line."""

    # NOTE(review): this declaration is currently identical to the one in
    # `TestWithDefaultFactory` (both pass `default_factory=`). If this class is
    # meant to cover a plain `default=` value instead, the declaration needs
    # changing — confirm intent.
    @dataclass
    class Options(TestSetup):
        config: Union[A, B] = subparsers(
            {"a": A, "b": B},
            default_factory=functools.partial(A, foo=0),
        )

    def test_default_is_used_when_no_args_passed(self):
        """An empty command line yields the factory-produced `A(foo=0)`."""
        parsed = self.Options.setup("")
        assert parsed.config == A(foo=0)

    def test_subparsers_work(self):
        """Each subcommand name selects, and fills in, its own dataclass."""
        cases = (
            ("a --foo 456", A(foo=456)),
            ("b --bar 1.23", B(bar=1.23)),
        )
        for argv, expected in cases:
            assert self.Options.setup(argv).config == expected
class TestWithDefaultFactory:
    """A `subparsers()` union field whose default comes from a `functools.partial` factory."""

    @dataclass
    class Options(TestSetup):
        config: Union[A, B] = subparsers(
            {"a": A, "b": B},
            default_factory=functools.partial(A, foo=0),
        )

    def test_default_is_used_when_no_args_passed(self):
        """Without a subcommand, the factory result `A(foo=0)` is used."""
        parsed = self.Options.setup("")
        assert parsed.config == A(foo=0)

    def test_subparsers_work(self):
        """Explicit subcommands override the default with the chosen dataclass."""
        cases = (
            ("a --foo 456", A(foo=456)),
            ("b --bar 1.23", B(bar=1.23)),
        )
        for argv, expected in cases:
            assert self.Options.setup(argv).config == expected
class TestWithoutSubparsersField:
    """A plain `field()` of union-of-dataclasses type, declared without `subparsers()`.

    The assertions below expect the union members to be reachable as the
    subcommands "a" and "b", with the `functools.partial` factory as default.
    """

    @dataclass
    class Options(TestSetup):
        config: Union[A, B] = field(default_factory=functools.partial(A, foo=0))

    def test_default_is_used_when_no_args_passed(self):
        """No subcommand given -> the factory's `A(foo=0)`."""
        parsed = self.Options.setup("")
        assert parsed.config == A(foo=0)

    def test_subparsers_work(self):
        """Lower-cased class names work as subcommands for the union members."""
        cases = (
            ("a --foo 456", A(foo=456)),
            ("b --bar 1.23", B(bar=1.23)),
        )
        for argv, expected in cases:
            assert self.Options.setup(argv).config == expected
class TestWithoutSubparsersFieldNoPartial:
    """A `field()` union default built by a plain zero-argument callable.

    This mirrors `TestWithoutSubparsersField`, except that the default
    factory is a lambda rather than `functools.partial`. That way a
    non-partial `default_factory` is covered as well, instead of this class
    duplicating its sibling.
    """

    @dataclass
    class Options(TestSetup):
        config: Union[A, B] = field(
            default_factory=lambda: A(foo=0),
        )

    def test_default_is_used_when_no_args_passed(self):
        """An empty command line yields the lambda's `A(foo=0)`."""
        assert self.Options.setup("").config == A(foo=0)

    def test_subparsers_work(self):
        """Subcommands "a" and "b" select and populate A and B respectively."""
        assert self.Options.setup("a --foo 456").config == A(foo=456)
        assert self.Options.setup("b --bar 1.23").config == B(bar=1.23)
def test_nesting_of_optional_subparsers():
    """Optional subparsers nested inside another optional union field.

    Every command line in the table is parsed and compared against the
    expected `NestedOptions`; omitted subcommands fall back to defaults.
    """

    @dataclass
    class Bob:
        config: Union[A, B] = subparsers(
            {"a": A, "b": B}, default_factory=functools.partial(A, foo=0)
        )

    @dataclass
    class Clarice:
        config: Union[A, B] = subparsers(
            {"a": A, "b": B}, default_factory=functools.partial(A, foo=0)
        )

    @dataclass
    class NestedOptions(TestSetup):
        friend: Union[Bob, Clarice] = field(default_factory=Bob)

    expectations = [
        ("", NestedOptions()),
        ("bob", NestedOptions(friend=Bob())),
        ("bob a", NestedOptions(friend=Bob(config=A()))),
        ("bob a --foo 1", NestedOptions(friend=Bob(config=A(foo=1)))),
        ("bob b", NestedOptions(friend=Bob(config=B()))),
        ("bob b --bar 0.", NestedOptions(friend=Bob(config=B(bar=0.0)))),
        ("clarice", NestedOptions(friend=Clarice())),
        ("clarice a", NestedOptions(friend=Clarice(config=A()))),
        ("clarice a --foo 1", NestedOptions(friend=Clarice(config=A(foo=1)))),
        ("clarice b", NestedOptions(friend=Clarice(config=B()))),
        ("clarice b --bar 0.", NestedOptions(friend=Clarice(config=B(bar=0.0)))),
    ]
    for args, expected in expectations:
        assert NestedOptions.setup(args) == expected, args
@dataclass
class ModelA(HyperParameters):
    """Hyper-parameters of the first model variant.

    Decorated with `@dataclass` like every other field-bearing class in this
    module. Without the decorator, `foo` would stay a bare class attribute
    and would not be registered as a (sampleable) dataclass field.
    """

    # Drawn from uniform(0, 1) when sampled; 0.5 by default.
    foo: float = uniform(0, 1, default=0.5)
@dataclass
class ModelB(HyperParameters):
    """Hyper-parameters of the second model variant.

    Decorated with `@dataclass` like every other field-bearing class in this
    module. Without the decorator, `bar` would stay a bare class attribute
    and would not be registered as a (sampleable) dataclass field.
    """

    # Drawn from a discrete uniform(0, 10) when sampled; 5 by default.
    bar: int = uniform(0, 10, default=5, discrete=True)
@dataclass()
class Options(HyperParameters):
    """Top-level hyper-parameters whose `model` is one of two model variants."""

    # Model choice together with its own hyper-parameters; `ModelA()` by default.
    model: Union[ModelA, ModelB] = field(default_factory=ModelA)
@pytest.mark.parametrize("seed", [123, 456, 789])
def test_sample_with_subparsers_field(seed: int):
    """`Options.sample()` varies both the chosen model type and its values."""
    random.seed(seed)
    drawn = [Options.sample() for _ in range(10)]
    # At least one of the draws must differ from the default configuration.
    assert not all(sample == Options() for sample in drawn), drawn

    # NOTE(review): the 40..60 window is only deterministic if `sample()`
    # draws from the `random` module seeded above — confirm it does not use
    # a separately-seeded RNG.
    type_counts = collections.Counter(
        type(Options.sample().model) for _ in range(100)
    )
    assert len(type_counts) == 2
    assert 40 <= type_counts[ModelA] <= 60