Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 51 additions & 9 deletions src/inline/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def __init__(self):
self.check_stmts = []
self.given_stmts = []
self.previous_stmts = []
self.import_stmts = []
self.prev_stmt_type = PrevStmtType.StmtExpr
# the line number of test statement
self.lineno = 0
Expand All @@ -174,10 +175,18 @@ def __init__(self):
self.devices = None
self.globs = {}

def write_imports(self):
import_str = ""
for n in self.import_stmts:
import_str += ExtractInlineTest.node_to_source_code(n) + "\n"
return import_str

def to_test(self):
prefix = "\n"

if self.prev_stmt_type == PrevStmtType.CondExpr:
if self.assume_stmts == []:
return "\n".join(
return prefix.join(
[ExtractInlineTest.node_to_source_code(n) for n in self.given_stmts]
+ [ExtractInlineTest.node_to_source_code(n) for n in self.check_stmts]
)
Expand All @@ -187,11 +196,11 @@ def to_test(self):
)
assume_statement = self.assume_stmts[0]
assume_node = self.build_assume_node(assume_statement, body_nodes)
return "\n".join(ExtractInlineTest.node_to_source_code(assume_node))
return prefix.join(ExtractInlineTest.node_to_source_code(assume_node))

else:
if self.assume_stmts is None or self.assume_stmts == []:
return "\n".join(
return prefix.join(
[ExtractInlineTest.node_to_source_code(n) for n in self.given_stmts]
+ [ExtractInlineTest.node_to_source_code(n) for n in self.previous_stmts]
+ [ExtractInlineTest.node_to_source_code(n) for n in self.check_stmts]
Expand All @@ -202,7 +211,7 @@ def to_test(self):
)
assume_statement = self.assume_stmts[0]
assume_node = self.build_assume_node(assume_statement, body_nodes)
return "\n".join([ExtractInlineTest.node_to_source_code(assume_node)])
return prefix.join([ExtractInlineTest.node_to_source_code(assume_node)])

def build_assume_node(self, assumption_node, body_nodes):
return ast.If(assumption_node, body_nodes, [])
Expand Down Expand Up @@ -296,6 +305,11 @@ class ExtractInlineTest(ast.NodeTransformer):
arg_timeout_str = "timeout"

assume = "assume"

import_str = "import"
from_str = "from"
as_str = "as"

inline_module_imported = False

def __init__(self):
Expand Down Expand Up @@ -360,6 +374,23 @@ def collect_inline_test_calls(self, node, inline_test_calls: List[ast.Call]):
inline_test_calls.append(node)
self.collect_inline_test_calls(node.func, inline_test_calls)

def collect_import_calls(self, node, import_calls: List[ast.Import], import_from_calls: List[ast.ImportFrom]):
"""
collect all import calls in the node (should be done first)
"""

while not isinstance(node, ast.Module) and node.parent != None:
node = node.parent

if not isinstance(node, ast.Module):
return

for child in node.children:
if isinstance(child, ast.Import):
import_calls.append(child)
elif isinstance(child, ast.ImportFrom):
import_from_calls.append(child)

def parse_constructor(self, node):
"""
Parse a constructor call.
Expand Down Expand Up @@ -931,8 +962,13 @@ def parse_parameterized_test(self):
parameterized_test.test_name = self.cur_inline_test.test_name + "_" + str(index)

def parse_inline_test(self, node):
inline_test_calls = []
import_calls = []
import_from_calls = []
inline_test_calls = []

self.collect_inline_test_calls(node, inline_test_calls)
self.collect_import_calls(node, import_calls, import_from_calls)

inline_test_calls.reverse()

if len(inline_test_calls) <= 1:
Expand All @@ -953,14 +989,20 @@ def parse_inline_test(self, node):
self.parse_assume(call)
inline_test_call_index += 1

# "given(a, 1)"
for call in inline_test_calls[inline_test_call_index:]:
if isinstance(call.func, ast.Attribute) and call.func.attr == self.given_str:
self.parse_given(call)
inline_test_call_index += 1
if isinstance(call.func, ast.Attribute):
if call.func.attr == self.given_str:
self.parse_given(call)
inline_test_call_index += 1
else:
break

for import_stmt in import_calls:
self.cur_inline_test.import_stmts.append(import_stmt)
for import_stmt in import_from_calls:
self.cur_inline_test.import_stmts.append(import_stmt)


# "check_eq" or "check_true" or "check_false" or "check_neq"
for call in inline_test_calls[inline_test_call_index:]:
# "check_eq(a, 1)"
Expand Down
70 changes: 70 additions & 0 deletions tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,76 @@ def m(a):
items, reprec = pytester.inline_genitems(x)
assert len(items) == 0

def test_inline_detects_imports(self, pytester: Pytester):
checkfile = pytester.makepyfile(
"""
from inline import itest
import datetime

def m(a):
b = a + datetime.timedelta(days=365)
itest().given(a, datetime.timedelta(days=1)).check_eq(b, datetime.timedelta(days=366))
"""
)
for x in (pytester.path, checkfile):
items, reprec = pytester.inline_genitems(x)
assert len(items) == 1
res = pytester.runpytest()
assert res.ret != 1

def test_inline_detects_import_alias(self, pytester: Pytester):
checkfile = pytester.makepyfile(
"""
from inline import itest
import datetime as dt

def m(a):
b = a + dt.timedelta(days=365)
itest().given(a, dt.timedelta(days=1)).check_eq(b, dt.timedelta(days=366))
"""
)
for x in (pytester.path, checkfile):
items, reprec = pytester.inline_genitems(x)
assert len(items) == 1
res = pytester.runpytest()
assert res.ret != 1

def test_inline_detects_from_imports(self, pytester: Pytester):
checkfile = pytester.makepyfile(
"""
from inline import itest
from enum import Enum

class Choice(Enum):
YES = 0
NO = 1

def m(a):
b = a
itest().given(a, Choice.YES).check_eq(b, Choice.YES)
"""
)
for x in (pytester.path, checkfile):
items, reprec = pytester.inline_genitems(x)
assert len(items) == 1
res = pytester.runpytest()
assert res.ret == 0

def test_fail_on_importing_missing_module(self, pytester: Pytester):
checkfile = pytester.makepyfile(
"""
from inline import itest
from scipy import owijef as st

def m(n, p):
b = st.binom(n, p)
itest().given(n, 100).given(p, 0.5).check_eq(b.mean(), n * p)
"""
)
for x in (pytester.path, checkfile):
items, reprec = pytester.inline_genitems(x)
assert len(items) == 0

def test_inline_malformed_given(self, pytester: Pytester):
checkfile = pytester.makepyfile(
"""
Expand Down
Loading