-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_conditional_parser.py
217 lines (166 loc) · 7.19 KB
/
test_conditional_parser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import pytest
from conditional_parser import ConditionalArgumentParser
import sys
def test_basic_functionality():
    """Plain ArgumentParser features keep working on ConditionalArgumentParser."""
    parser = ConditionalArgumentParser()
    parser.add_argument("--flag", action="store_true")
    parser.add_argument("--value", type=int, default=42)

    explicit = parser.parse_args(["--flag", "--value", "123"])
    assert explicit.flag is True
    assert explicit.value == 123

    # With no arguments, the store_true flag is False and --value uses its default.
    defaults = parser.parse_args([])
    assert defaults.flag is False
    assert defaults.value == 42
def test_simple_conditional():
    """A conditional keyed on a literal parent value appears only for that value."""
    parser = ConditionalArgumentParser()
    parser.add_argument("--format", choices=["json", "csv"], default="json")
    parser.add_conditional("format", "csv", "--delimiter", default=",")

    # "csv" activates --delimiter; an explicit value overrides its default.
    csv_args = parser.parse_args(["--format", "csv", "--delimiter", "|"])
    assert csv_args.format == "csv"
    assert csv_args.delimiter == "|"

    # "json" does not activate the conditional, so no attribute is created.
    json_args = parser.parse_args(["--format", "json"])
    assert json_args.format == "json"
    assert not hasattr(json_args, "delimiter")

    # "csv" without --delimiter falls back to the conditional's default.
    csv_default_args = parser.parse_args(["--format", "csv"])
    assert csv_default_args.delimiter == ","
def test_callable_conditional():
    """Test conditional arguments whose condition is a callable.

    A lambda and a named function are both registered against the same
    parent ("add_conditional"). Each returns True when the parent's string
    value equals "true" case-insensitively, which activates its conditional
    flag.
    """
    parser = ConditionalArgumentParser()
    parser.add_argument("--add_conditional", type=str, default="False")
    parser.add_conditional(
        "add_conditional",
        lambda x: x.lower() == "true",
        "--extra-arg",
        action="store_true",
    )

    def condition(x):
        return x.lower() == "true"

    parser.add_conditional(
        "add_conditional",
        condition,
        "--another-arg",
        action="store_true",
    )

    # Condition satisfied: each conditional flag is accepted on its own.
    args = parser.parse_args(["--add_conditional", "True", "--extra-arg"])
    assert args.extra_arg
    args = parser.parse_args(["--add_conditional", "True", "--another-arg"])
    assert args.another_arg

    # Both conditionals share one parent, so they should be usable together.
    args = parser.parse_args(
        ["--add_conditional", "True", "--extra-arg", "--another-arg"]
    )
    assert args.extra_arg
    assert args.another_arg

    # Condition not satisfied: the conditional flag is not registered, so
    # passing it should make parsing exit with an error.
    with pytest.raises(SystemExit):
        parser.parse_args(["--add_conditional", "False", "--extra-arg"])
    # The parent defaults to "False", so omitting it must behave the same way.
    with pytest.raises(SystemExit):
        parser.parse_args(["--extra-arg"])
def test_hierarchical_conditionals():
    """Conditionals can themselves act as parents of further conditionals."""
    parser = ConditionalArgumentParser()
    parser.add_argument("--use-model", action="store_true")
    # Level 1: --model-type is registered (as required) for use_model == True.
    parser.add_conditional(
        "use_model", True, "--model-type", choices=["cnn", "rnn"], required=True
    )
    # Level 2: each model type unlocks its own hyperparameter.
    parser.add_conditional("model_type", "cnn", "--kernel-size", type=int, default=3)
    parser.add_conditional("model_type", "rnn", "--hidden-size", type=int, default=128)

    cnn_explicit = parser.parse_args(
        ["--use-model", "--model-type", "cnn", "--kernel-size", "5"]
    )
    assert cnn_explicit.use_model is True
    assert cnn_explicit.model_type == "cnn"
    assert cnn_explicit.kernel_size == 5

    # cnn branch with defaults; the rnn-only argument must be absent.
    cnn_default = parser.parse_args(["--use-model", "--model-type", "cnn"])
    assert cnn_default.kernel_size == 3
    assert not hasattr(cnn_default, "hidden_size")

    # rnn branch: --hidden-size gets its default and --kernel-size is absent.
    rnn_default = parser.parse_args(["--use-model", "--model-type", "rnn"])
    assert rnn_default.model_type == "rnn"
    assert rnn_default.hidden_size == 128
    assert not hasattr(rnn_default, "kernel_size")
def test_required_conditionals():
    """A required conditional must be supplied once its condition is active."""
    parser = ConditionalArgumentParser()
    parser.add_argument("--use-auth", action="store_true")
    parser.add_conditional("use_auth", True, "--username", required=True)

    # Activating the condition without --username is a parse error.
    with pytest.raises(SystemExit):
        parser.parse_args(["--use-auth"])

    # Supplying --username satisfies the requirement.
    namespace = parser.parse_args(["--use-auth", "--username", "user123"])
    assert namespace.username == "user123"
def test_help_text():
    """Help output lists every conditional, whatever the parent's value."""
    parser = ConditionalArgumentParser()
    parser.add_argument("--mode", choices=["simple", "advanced"])
    parser.add_conditional("mode", "advanced", "--extra-param")

    # --extra-param is expected in the help even when its condition
    # ("advanced") is not met, or when --mode is omitted entirely.
    for argv in (
        ["--mode", "advanced", "--help"],
        ["--mode", "simple", "--help"],
        ["--help"],
    ):
        help_output = _capture_help_output(parser, argv)
        assert "--mode" in help_output
        assert "--extra-param" in help_output
def test_help_disabled():
    """Test that --help is rejected as an unknown option when add_help=False.

    Printing help and rejecting an option both end in SystemExit, so only the
    exit status tells them apart. argparse's help action exits with status 0;
    its error() path exits with a non-zero status. Checking only for
    SystemExit would also pass if help were (wrongly) still enabled.
    """
    parser = ConditionalArgumentParser(add_help=False)
    parser.add_argument("--mode", choices=["simple", "advanced"])
    parser.add_conditional("mode", "advanced", "--extra-param")
    with pytest.raises(SystemExit) as excinfo:
        parser.parse_args(["--mode", "advanced", "--help"])
    # A zero status would mean help was printed despite add_help=False.
    assert excinfo.value.code != 0
def _capture_help_output(parser, args):
    """Run ``parser.parse_args(args)`` and return what it printed to stdout.

    The call is expected to end in ``SystemExit``, as a ``--help`` request
    does. If parsing returns normally instead, an ``AssertionError`` is raised
    so the calling test fails with a descriptive message.

    Args:
        parser: The parser under test.
        args: Argument list to parse, normally containing ``--help``.

    Returns:
        The text written to ``sys.stdout`` while parsing.

    Raises:
        AssertionError: If ``parse_args`` does not raise ``SystemExit``.
    """
    import contextlib
    import io

    buffer = io.StringIO()
    # redirect_stdout restores the previous sys.stdout on exit, including
    # when parse_args raises something other than SystemExit.
    with contextlib.redirect_stdout(buffer):
        try:
            parser.parse_args(args)
        except SystemExit:
            return buffer.getvalue()
    raise AssertionError(f"parse_args({args!r}) did not raise SystemExit")
def test_invalid_condition():
    """Registering a two-argument callable condition is rejected with ValueError."""
    parser = ConditionalArgumentParser()
    parser.add_argument("--value", type=int)

    # Deliberately has the wrong arity for a condition callable.
    def takes_two_args(x, y):
        return x > y

    with pytest.raises(ValueError):
        parser.add_conditional("value", takes_two_args, "--flag")
def test_sys_argv_default():
    """parse_args() with no explicit list reads its arguments from sys.argv."""
    parser = ConditionalArgumentParser()
    parser.add_argument("--test-flag", action="store_true")

    # Each case is (value installed as sys.argv, expected args.test_flag).
    # The last case is a completely empty sys.argv, without even a program name.
    cases = [
        (["program_name", "--test-flag"], True),
        (["program_name"], False),
        ([], False),
    ]
    saved_argv = sys.argv
    try:
        for fake_argv, expected in cases:
            sys.argv = fake_argv
            args = parser.parse_args()  # deliberately no explicit argument list
            assert args.test_flag is expected
    finally:
        # Put the real argv back so later tests are unaffected.
        sys.argv = saved_argv
def test_bad_conditional():
    """add_conditional rejects invalid registrations with ValueError."""
    parser = ConditionalArgumentParser()
    parser.add_argument("--value", action="store_true")

    # The conditional's flag may not reuse an argument name that already exists.
    with pytest.raises(ValueError):
        parser.add_conditional("value", True, "--value")

    # The parent (first argument) must be given as a string dest name.
    with pytest.raises(ValueError):
        parser.add_conditional(42, True, "--not-a-real-arg")