diff --git a/src/inline/plugin.py b/src/inline/plugin.py index f8ddfc1..f2e3189 100644 --- a/src/inline/plugin.py +++ b/src/inline/plugin.py @@ -159,6 +159,7 @@ def __init__(self): self.check_stmts = [] self.given_stmts = [] self.previous_stmts = [] + self.import_stmts = [] self.prev_stmt_type = PrevStmtType.StmtExpr # the line number of test statement self.lineno = 0 @@ -174,10 +175,18 @@ def __init__(self): self.devices = None self.globs = {} + def write_imports(self): + import_str = "" + for n in self.import_stmts: + import_str += ExtractInlineTest.node_to_source_code(n) + "\n" + return import_str + def to_test(self): + prefix = "\n" + if self.prev_stmt_type == PrevStmtType.CondExpr: if self.assume_stmts == []: - return "\n".join( + return prefix.join( [ExtractInlineTest.node_to_source_code(n) for n in self.given_stmts] + [ExtractInlineTest.node_to_source_code(n) for n in self.check_stmts] ) @@ -187,11 +196,11 @@ def to_test(self): ) assume_statement = self.assume_stmts[0] assume_node = self.build_assume_node(assume_statement, body_nodes) - return "\n".join(ExtractInlineTest.node_to_source_code(assume_node)) + return prefix.join(ExtractInlineTest.node_to_source_code(assume_node)) else: if self.assume_stmts is None or self.assume_stmts == []: - return "\n".join( + return prefix.join( [ExtractInlineTest.node_to_source_code(n) for n in self.given_stmts] + [ExtractInlineTest.node_to_source_code(n) for n in self.previous_stmts] + [ExtractInlineTest.node_to_source_code(n) for n in self.check_stmts] @@ -202,7 +211,7 @@ def to_test(self): ) assume_statement = self.assume_stmts[0] assume_node = self.build_assume_node(assume_statement, body_nodes) - return "\n".join([ExtractInlineTest.node_to_source_code(assume_node)]) + return prefix.join([ExtractInlineTest.node_to_source_code(assume_node)]) def build_assume_node(self, assumption_node, body_nodes): return ast.If(assumption_node, body_nodes, []) @@ -296,6 +305,11 @@ class ExtractInlineTest(ast.NodeTransformer): arg_timeout_str = "timeout" assume = "assume" + + import_str = "import" + from_str = "from" + as_str = "as" + inline_module_imported = False def __init__(self): @@ -360,6 +374,23 @@ def collect_inline_test_calls(self, node, inline_test_calls: List[ast.Call]): inline_test_calls.append(node) self.collect_inline_test_calls(node.func, inline_test_calls) + def collect_import_calls(self, node, import_calls: List[ast.Import], import_from_calls: List[ast.ImportFrom]): + """ + collect all import calls in the node (should be done first) + """ + + while not isinstance(node, ast.Module) and node.parent != None: + node = node.parent + + if not isinstance(node, ast.Module): + return + + for child in node.children: + if isinstance(child, ast.Import): + import_calls.append(child) + elif isinstance(child, ast.ImportFrom): + import_from_calls.append(child) + def parse_constructor(self, node): """ Parse a constructor call. @@ -931,8 +962,13 @@ def parse_parameterized_test(self): parameterized_test.test_name = self.cur_inline_test.test_name + "_" + str(index) def parse_inline_test(self, node): - inline_test_calls = [] + import_calls = [] + import_from_calls = [] + inline_test_calls = [] + self.collect_inline_test_calls(node, inline_test_calls) + self.collect_import_calls(node, import_calls, import_from_calls) + inline_test_calls.reverse() if len(inline_test_calls) <= 1: @@ -953,14 +989,20 @@ def parse_inline_test(self, node): self.parse_assume(call) inline_test_call_index += 1 - # "given(a, 1)" for call in inline_test_calls[inline_test_call_index:]: - if isinstance(call.func, ast.Attribute) and call.func.attr == self.given_str: - self.parse_given(call) - inline_test_call_index += 1 + if isinstance(call.func, ast.Attribute): + if call.func.attr == self.given_str: + self.parse_given(call) + inline_test_call_index += 1 else: break + for import_stmt in import_calls: + self.cur_inline_test.import_stmts.append(import_stmt) + for import_stmt in import_from_calls: + self.cur_inline_test.import_stmts.append(import_stmt) + + # "check_eq" or "check_true" or "check_false" or "check_neq" for call in inline_test_calls[inline_test_call_index:]: # "check_eq(a, 1)" diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 40c3096..953b61f 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -31,6 +31,76 @@ def m(a): items, reprec = pytester.inline_genitems(x) assert len(items) == 0 + def test_inline_detects_imports(self, pytester: Pytester): + checkfile = pytester.makepyfile( + """ + from inline import itest + import datetime + + def m(a): + b = a + datetime.timedelta(days=365) + itest().given(a, datetime.timedelta(days=1)).check_eq(b, datetime.timedelta(days=366)) + """ + ) + for x in (pytester.path, checkfile): + items, reprec = pytester.inline_genitems(x) + assert len(items) == 1 + res = pytester.runpytest() + assert res.ret != 1 + + def test_inline_detects_import_alias(self, pytester: Pytester): + checkfile = pytester.makepyfile( + """ + from inline import itest + import datetime as dt + + def m(a): + b = a + dt.timedelta(days=365) + itest().given(a, dt.timedelta(days=1)).check_eq(b, dt.timedelta(days=366)) + """ + ) + for x in (pytester.path, checkfile): + items, reprec = pytester.inline_genitems(x) + assert len(items) == 1 + res = pytester.runpytest() + assert res.ret != 1 + + def test_inline_detects_from_imports(self, pytester: Pytester): + checkfile = pytester.makepyfile( + """ + from inline import itest + from enum import Enum + + class Choice(Enum): + YES = 0 + NO = 1 + + def m(a): + b = a + itest().given(a, Choice.YES).check_eq(b, Choice.YES) + """ + ) + for x in (pytester.path, checkfile): + items, reprec = pytester.inline_genitems(x) + assert len(items) == 1 + res = pytester.runpytest() + assert res.ret == 0 + + def test_fail_on_importing_missing_module(self, pytester: Pytester): + checkfile = pytester.makepyfile( + """ + from inline import itest + from scipy import owijef as st + + def m(n, p): + b = st.binom(n, p) + itest().given(n, 100).given(p, 0.5).check_eq(b.mean(), n * p) + """ + ) + for x in (pytester.path, checkfile): + items, reprec = pytester.inline_genitems(x) + assert len(items) == 0 + def test_inline_malformed_given(self, pytester: Pytester): checkfile = pytester.makepyfile( """