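"""
Dynamically generate unittest test cases from the docstrings of the classes,
functions, and methods in the dffml package: a doctest-based test for
docstrings that contain interactive examples, and a consoletest-based test
for docstrings that contain testable rst nodes.
"""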
import os
import io
import sys
import doctest
import inspect
import pathlib
import asyncio
import logging
import unittest
import unittest.mock
import tempfile
import platform
import importlib
import contextlib
from typing import Optional, Callable
from dffml import Features, Feature
from dffml.df.types import DataFlow, Input
from dffml.df.memory import MemoryOrchestrator
from dffml.noasync import train
from dffml.model.slr import SLRModel
from dffml.util.asynctestcase import AsyncTestCase
from dffml.util.testing.docs import run_consoletest
from dffml.util.testing.consoletest.parser import parse_nodes
from dffml.db.sqlite import SqliteDatabase, SqliteDatabaseConfig
from dffml.operation.db import db_query_create_table, DatabaseQueryConfig
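
# Yield (import name, module object) pairs for every *.py file under root,
# importing each one. The optional skip callback can veto a file by its
# import name or path.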
def modules(
root: pathlib.Path,
package_name: str,
*,
skip: Optional[Callable[[str, pathlib.Path], bool]] = None,
):
for path in root.rglob("*.py"):
# Figure out name
import_name = pathlib.Path(str(path)[len(str(root)) :]).parts[1:]
import_name = (
package_name
+ "."
+ ".".join(
list(import_name[:-1]) + [import_name[-1].replace(".py", "")]
)
)
# Check if we should skip importing this file
if skip and skip(import_name, path):
continue
# Import module
yield import_name, importlib.import_module(import_name)
root = pathlib.Path(__file__).parent.parent / "dffml"
skel = root / "skel"
package_name = "dffml"
# Skip any files in skel and __main__.py and __init__.py
skip = lambda _import_name, path: skel in path.parents or path.name.startswith(
"__"
)
# All classes to test
to_test = {}
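
# Run the body of the with block from inside a fresh temporary directory,
# restoring the original working directory on exit.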
@contextlib.contextmanager
def tempdir(state):
with tempfile.TemporaryDirectory() as new_cwd:
try:
orig_cwd = os.getcwd()
os.chdir(new_cwd)
yield
finally:
os.chdir(orig_cwd)
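
# wrap_* generators provide extra per-test setup and teardown; mktestcase()
# looks them up by the dotted name of the object under test. This one patches
# builtins.input so the AcceptUserInput doctest gets canned input instead of
# blocking on stdin.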
def wrap_operation_io_AcceptUserInput(state):
with unittest.mock.patch(
"builtins.input", return_value="Data flow is awesome"
):
yield
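
# Train a simple SLR model ahead of time so the high_level accuracy and
# predict doctests have a trained model saved under "tempdir" to work with.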
def wrap_high_level_accuracy(state):
model = SLRModel(
features=Features(Feature("Years", int, 1),),
predict=Feature("Salary", int, 1),
location="tempdir",
)
train(
model,
{"Years": 0, "Salary": 10},
{"Years": 1, "Salary": 20},
{"Years": 2, "Salary": 30},
{"Years": 3, "Salary": 40},
)
yield
wrap_high_level_predict = wrap_high_level_accuracy
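
# Same pre-trained model setup for the noasync accuracy and predict doctests.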
def wrap_noasync_accuracy(state):
model = SLRModel(
features=Features(Feature("Years", int, 1),),
predict=Feature("Salary", int, 1),
location="tempdir",
)
train(
model,
{"Years": 0, "Salary": 10},
{"Years": 1, "Salary": 20},
{"Years": 2, "Salary": 30},
{"Years": 3, "Salary": 40},
)
yield
wrap_noasync_predict = wrap_noasync_accuracy
async def operation_db():
"""
Create the database and table (myTable) for the db operations
"""
sdb = SqliteDatabase(SqliteDatabaseConfig(filename="examples.db"))
dataflow = DataFlow(
operations={"db_query_create": db_query_create_table.op},
configs={"db_query_create": DatabaseQueryConfig(database=sdb)},
seed=[],
)
inputs = [
Input(
value="myTable",
definition=db_query_create_table.op.inputs["table_name"],
),
Input(
value={
"key": "INTEGER NOT NULL PRIMARY KEY",
"firstName": "text",
"lastName": "text",
"age": "int",
},
definition=db_query_create_table.op.inputs["cols"],
),
]
async for ctx, result in MemoryOrchestrator.run(dataflow, inputs):
pass
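
# Create examples.db and myTable before running the db operation doctests.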
def wrap_operation_db(state):
asyncio.run(operation_db())
yield
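
# The lookup doctest needs rows to query, so run the insert doctests first to
# populate the table (failures from insert_or_update are ignored).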
def wrap_operation_db_db_query_lookup(state):
run_doctest(operation_db_db_query_insert.obj, state)
run_doctest(operation_db_db_query_insert_or_update.obj, state, check=False)
yield
def wrap_operation_db_db_query_update(state):
run_doctest(operation_db_db_query_insert.obj, state)
yield
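
# Find and run every doctest attached to obj using the shared globals; when
# check is True, raise with the captured output if any example fails.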
def run_doctest(obj, state, check=True):
finder = doctest.DocTestFinder(verbose=state["verbose"], recurse=False)
runner = doctest.DocTestRunner(verbose=state["verbose"])
for test in finder.find(obj, obj.__qualname__, globs=state["globs"]):
output = io.StringIO()
results = runner.run(test, out=output.write)
if results.failed and check:
raise Exception(output.getvalue())
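
# Build the test_docstring method for a single discovered object.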
def mktestcase(name, import_name, module, obj):
# Global variables for the doctest
state = {
"globs": {},
"name": name,
"obj": obj,
"import_name": import_name,
"module": module,
"verbose": os.environ.get("LOGGING", "").lower() == "debug",
}
    # Check if there is a function within this file which will be used to do
    # extra setup and teardown for the test. It's the same name as the test but
    # prefixed with wrap_. Also look all the way up the path for wrap_ functions
name = name.split(".")
extra_context = []
for i in range(0, len(name)):
wrapper_name = "wrap_" + "_".join(name[: i + 1])
wrapper = sys.modules[__name__].__dict__.get(wrapper_name, False)
if wrapper:
extra_context.append(contextlib.contextmanager(wrapper))
# The test case itself, assigned to test_doctest of each class
def testcase(self):
if state["verbose"]:
logging.basicConfig(level=logging.DEBUG)
with contextlib.ExitStack() as stack:
# Create tempdir for the test
stack.enter_context(tempdir(state))
# Do all test specific setup
for wrapper in extra_context:
stack.enter_context(wrapper(state))
# Run the doctest
run_doctest(obj, state)
return testcase
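
# Build the test_consoletest method, which runs the object's documentation
# through consoletest relative to the docs/ directory.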
def mkconsoletest(_name, _import_name, _module, obj):
async def test_consoletest(self):
await run_consoletest(
obj, docs_root_dir=pathlib.Path(__file__).parents[1] / "docs",
)
return test_consoletest
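
# Recursively record every class, function, and method defined inside the
# package, keyed by its dotted path so duplicates collapse to one entry.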
def recurse_properties(discovered, import_name, module, prefix, parent):
if inspect.ismodule(parent) or inspect.isclass(parent):
for name, obj in inspect.getmembers(parent):
if inspect.ismodule(parent):
# Skip if not a class or function
if (
not hasattr(obj, "__module__")
or not obj.__module__.startswith(import_name)
or (
not inspect.isclass(obj)
and not inspect.isfunction(obj)
)
):
continue
# Add to dict to ensure no duplicates
discovered[".".join([prefix, obj.__qualname__])] = (
prefix,
module,
obj,
)
recurse_properties(
discovered,
import_name,
module,
".".join([prefix, obj.__qualname__]),
obj,
)
if inspect.isclass(parent):
# Skip if not a class or function
if (
not hasattr(obj, "__module__")
or obj.__module__ is None
or not obj.__module__.startswith(import_name)
or (
not inspect.isclass(obj)
and not inspect.isfunction(obj)
and not inspect.ismethod(obj)
)
):
continue
# Add to dict to ensure no duplicates
discovered[".".join([prefix, obj.__qualname__])] = (
prefix,
module,
obj,
)
recurse_properties(
discovered,
import_name,
module,
".".join([prefix, obj.__qualname__]),
obj,
)
# Iterate over all of the objects in the module
for import_name, module in modules(root, package_name, skip=skip):
recurse_properties(to_test, import_name, module, import_name, module)
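
# For each discovered object, generate an AsyncTestCase subclass containing a
# doctest and/or consoletest method, depending on what its docstring contains.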
for name, (import_name, module, obj) in to_test.items():
# Check that class or function has an example that could be doctested
docstring = inspect.getdoc(obj)
# Remove the package name from the Python style path to the object
name = name[len(package_name) + 1 :]
# Create a dictionary to hold the test case functions of the AsyncTestCase
# class we're going to create
test_cases = {}
# Add a doctest testcase if there are any lines to doctest
if docstring is not None and ">>>" in docstring:
test_cases["test_docstring"] = mktestcase(
name, import_name, module, obj
)
# Add a consoletest testcase if there are any testable rst nodes
if docstring is not None and [
node for node in parse_nodes(docstring) if "test" in node.options
]:
test_cases["test_consoletest"] = mkconsoletest(
name, import_name, module, obj
)
    # Only create the AsyncTestCase subclass if the object's docstring holds
    # anything that could have a testcase made out of it
if not test_cases:
continue
# Create the test case class with the object as a property and test cases
testcase = type(
name.replace(".", "_"), (AsyncTestCase,), {"obj": obj, **test_cases}
)
    # Add the class, named after the dotted path to the object, to this file's
    # globals so the unittest loader can discover it
setattr(sys.modules[__name__], testcase.__qualname__, testcase)
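
# cli_cli_Version_Version_git_hash is one of the classes generated above; its
# doctest is skipped on Windows because test cleanup does not work reliably
# there.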
cli_cli_Version_Version_git_hash.test_docstring = unittest.skipIf(
platform.system() == "Windows",
"Test cleanup doesn't seem to work on Windows",
)(cli_cli_Version_Version_git_hash.test_docstring)