diff --git a/tests/test_collapse_literals.py b/tests/test_collapse_literals.py index aae7068..31e0e27 100644 --- a/tests/test_collapse_literals.py +++ b/tests/test_collapse_literals.py @@ -30,43 +30,49 @@ def f(y): self.assertEqual(f(-1), deco_f(-1)) def test_basic(self): - @pragma.collapse_literals(return_source=True) + @pragma.collapse_literals def f(): return 1 + 1 - result = dedent(''' + result = ''' def f(): return 2 - ''') - self.assertEqual(f.strip(), result.strip()) + ''' + + self.assertSourceEqual(f, result) + self.assertEqual(f(), 2) def test_vars(self): - @pragma.collapse_literals(return_source=True) + @pragma.collapse_literals def f(): x = 3 y = 2 return x + y - result = dedent(''' + result = ''' def f(): x = 3 y = 2 return 5 - ''') - self.assertEqual(f.strip(), result.strip()) + ''' + + self.assertSourceEqual(f, result) + self.assertEqual(f(), 5) def test_partial(self): - @pragma.collapse_literals(return_source=True) + @pragma.collapse_literals def f(y): x = 3 return x + 2 + y - result = dedent(''' + result = ''' def f(y): x = 3 return 5 + y - ''') - self.assertEqual(f.strip(), result.strip()) + ''' + + self.assertSourceEqual(f, result) + self.assertEqual(f(5), 10) def test_constant_index(self): @pragma.collapse_literals @@ -79,23 +85,25 @@ def f(): x = [1, 2, 3] return 1 ''' + self.assertSourceEqual(f, result) self.assertEqual(f(), 1) def test_with_unroll(self): - @pragma.collapse_literals(return_source=True) + @pragma.collapse_literals @pragma.unroll def f(): for i in range(3): print(i + 2) - result = dedent(''' + result = ''' def f(): print(2) print(3) print(4) - ''') - self.assertEqual(f.strip(), result.strip()) + ''' + + self.assertSourceEqual(f, result) # # TODO: Figure out variable levels of specificity... # def test_with_objects(self): @@ -170,20 +178,22 @@ def f(y): # self.assertEqual(f.strip(), result.strip()) def test_constant_conditional_erasure(self): - @pragma.collapse_literals(return_source=True) - def f(y): + @pragma.collapse_literals + def f(): x = 0 if x <= 0: x = 1 return x - result = dedent(''' - def f(y): + result = ''' + def f(): x = 0 x = 1 return 1 - ''') - self.assertEqual(f.strip(), result.strip()) + ''' + + self.assertSourceEqual(f, result) + self.assertEqual(f(), 1) def fn(): if x == 0: @@ -209,31 +219,36 @@ def fn(): x = 'c' return 'c' ''') - self.assertEqual(pragma.collapse_literals(return_source=True, x=0)(fn).strip(), result0.strip()) - self.assertEqual(pragma.collapse_literals(return_source=True, x=1)(fn).strip(), result1.strip()) - self.assertEqual(pragma.collapse_literals(return_source=True, x=2)(fn).strip(), result2.strip()) + + self.assertSourceEqual(pragma.collapse_literals(x=0)(fn), result0) + self.assertSourceEqual(pragma.collapse_literals(x=1)(fn), result1) + self.assertSourceEqual(pragma.collapse_literals(x=2)(fn), result2) def test_unary(self): - @pragma.collapse_literals(return_source=True) + @pragma.collapse_literals def f(): return 1 + -5 - result = dedent(''' + result = ''' def f(): return -4 - ''') - self.assertEqual(f.strip(), result.strip()) + ''' + + self.assertSourceEqual(f, result) + self.assertEqual(f(), -4) def test_funcs(self): - @pragma.collapse_literals(return_source=True) + @pragma.collapse_literals def f(): return sum(range(5)) - result = dedent(''' + result = ''' def f(): return 10 - ''') - self.assertEqual(f.strip(), result.strip()) + ''' + + self.assertSourceEqual(f, result) + self.assertEqual(f(), 10) def test_funcs2(self): my_list = [1, 2, 3] @@ -272,12 +287,13 @@ def f(): ((y, x), z) = ((1, 2), 3) return x - result = dedent(''' + result = ''' def f(): x = 3 (y, x), z = (1, 2), 3 return 2 - ''') + ''' + self.assertSourceEqual(f, result) self.assertEqual(f(), 2) diff --git a/tests/test_deindex.py b/tests/test_deindex.py index 00d5206..323ec42 100644 --- a/tests/test_deindex.py +++ b/tests/test_deindex.py @@ -8,6 +8,7 @@ class TestDeindex(PragmaTest): def test_with_literals(self): v = [1, 2, 3] + @pragma.collapse_literals @pragma.deindex(v, 'v') def f(): @@ -23,6 +24,7 @@ def f(): def test_with_objects(self): v = [object(), object(), object()] + @pragma.deindex(v, 'v') def f(): return v[0] + v[1] + v[2] @@ -34,13 +36,28 @@ def f(): self.assertSourceEqual(f, result) + def test_with_objects_same_instance(self): + v = [object(), object(), object()] + + @pragma.deindex(v, 'v') + def f(): + return v[0] + + result = ''' + def f(): + return v_0 + ''' + + self.assertSourceEqual(f, result) + self.assertIs(f(), v[0]) + def test_with_unroll(self): v = [None, None, None] @pragma.deindex(v, 'v', return_source=True) - @pragma.unroll(lv=len(v)) + @pragma.unroll def f(): - for i in range(lv): + for i in range(len(v)): yield v[i] result = dedent(''' @@ -49,29 +66,24 @@ def f(): yield v_1 yield v_2 ''') - self.assertEqual(f.strip(), result.strip()) - def test_with_objects_run(self): - v = [object(), object(), object()] - @pragma.deindex(v, 'v') - def f(): - return v[0] - - self.assertEqual(f(), v[0]) + self.assertSourceEqual(f, result) def test_with_variable_indices(self): v = [object(), object(), object()] - @pragma.deindex(v, 'v', return_source=True) + + @pragma.deindex(v, 'v') def f(x): yield v[0] yield v[x] - result = dedent(''' + result = ''' def f(x): yield v_0 yield v[x] - ''') - self.assertEqual(f.strip(), result.strip()) + ''' + + self.assertSourceEqual(f, result) def test_dict(self): d = {'a': 1, 'b': 2} @@ -108,7 +120,7 @@ def run_func(i, x): self.assertEqual(run_func(1, 5), 25) self.assertEqual(run_func(2, 5), 125) - result = dedent(''' + result = ''' def run_func(i, x): if i == 0: return funcs_0(x) @@ -116,8 +128,9 @@ def run_func(i, x): return funcs_1(x) if i == 2: return funcs_2(x) - ''') - self.assertEqual(inspect.getsource(run_func).strip(), result.strip()) + ''' + + self.assertSourceEqual(run_func, result) def test_len(self): a = ['a', 'b', 'c'] diff --git a/tests/test_inline.py b/tests/test_inline.py index 09b5099..6c4ed7c 100644 --- a/tests/test_inline.py +++ b/tests/test_inline.py @@ -60,7 +60,9 @@ def g(x, *args, y, **kwargs): def f(): g(1, 2, 3, 4, y=5, z=6, w=7) - result1 = dedent(''' + inline_f = pragma.inline(g)(f) + + result1 = ''' def f(): _g_0 = dict(x=1, args=(2, 3, 4), y=5, kwargs={'z': 6, 'w': 7}) for ____ in [None]: @@ -72,8 +74,8 @@ def f(): print('{} = {}'.format(k, v)) del _g_0 None - ''') - result2 = dedent(''' + ''' + result2 = ''' def f(): _g_0 = dict(x=1, args=(2, 3, 4), y=5, kwargs={'w': 7, 'z': 6}) for ____ in [None]: @@ -85,11 +87,10 @@ def f(): print('{} = {}'.format(k, v)) del _g_0 None - ''') - self.assertIn(pragma.inline(g, return_source=True)(f).strip(), - [result1.strip(), result2.strip()]) + ''' - self.assertEqual(f(), pragma.inline(g)(f)()) + self.assertSourceIn(inline_f, result1, result2) + self.assertEqual(f(), inline_f()) def test_recursive(self): def fib(n): @@ -101,28 +102,21 @@ def fib(n): return fib(n-1) + fib(n-2) from miniutils import tic + known_fibs = { + 0: 1, + 1: 1, + 2: 2, + 3: 3, + 4: 5, + 5: 8, + } toc = tic() - fib_code = pragma.inline(fib, max_depth=1, return_source=True)(fib) - toc("Inlined recursive function to depth 1") - print(fib_code) - # fib_code = pragma.inline(fib, max_depth=3, return_source=True)(fib) - # toc("Inlined recursive function to depth 3") - # print(fib_code) - - fib = pragma.inline(fib, max_depth=2)(fib) - toc("Inlined executable function") - self.assertEqual(fib(0), 1) - toc("Ran fib(0)") - self.assertEqual(fib(1), 1) - toc("Ran fib(1)") - self.assertEqual(fib(2), 2) - toc("Ran fib(2)") - self.assertEqual(fib(3), 3) - toc("Ran fib(3)") - self.assertEqual(fib(4), 5) - toc("Ran fib(4)") - self.assertEqual(fib(5), 8) - toc("Ran fib(5)") + for depth in range(1, 4): + inline_fib = pragma.inline(fib, max_depth=depth)(fib) + toc('Inlined fibonacci function to depth of {}'.format(inline_fib)) + for k, v in known_fibs.items(): + self.assertEqual(fib(k), v) + toc("Ran fib_{}({})=={}".format(depth, k, v)) # def test_failure_cases(self): # def g_for(x): @@ -143,9 +137,9 @@ def f(y): return 0 return g(y - 1) - f_code = pragma.inline(g, return_source=True)(f) + f_code = pragma.inline(g)(f) - result = dedent(''' + result = ''' def f(y): if y <= 0: return 0 @@ -156,12 +150,13 @@ def f(y): _g_return_0 = _g_0.get('return', None) del _g_0 return _g_return_0 - ''') - self.assertEqual(f_code.strip(), result.strip()) + ''' + + self.assertSourceEqual(f_code, result) - f_unroll_code = pragma.unroll(return_source=True)(pragma.inline(g)(f)) + f_unroll_code = pragma.unroll(pragma.inline(g)(f)) - result_unroll = dedent(''' + result_unroll = ''' def f(y): if y <= 0: return 0 @@ -170,10 +165,11 @@ def f(y): _g_return_0 = _g_0.get('return', None) del _g_0 return _g_return_0 - ''') - self.assertEqual(f_unroll_code.strip(), result_unroll.strip()) + ''' + + self.assertSourceEqual(f_unroll_code, result_unroll) - f2_code = pragma.inline(f, g, return_source=True, f=f)(f) + f2_code = pragma.inline(f, g, f=f)(f) result2 = dedent(''' def f(y): @@ -196,8 +192,8 @@ def f(y): del _g_0 return _g_return_0 ''') - print(f2_code) - self.assertEqual(f2_code.strip(), result2.strip()) + + self.assertSourceEqual(f2_code, result2) def test_generator(self): def g(y): @@ -205,11 +201,11 @@ def g(y): yield i yield from range(y) - @pragma.inline(g, return_source=True) + @pragma.inline(g) def f(x): return sum(g(x)) - result = dedent(''' + result = ''' def f(x): _g_0 = dict([('yield', [])], y=x) for ____ in [None]: @@ -219,22 +215,24 @@ def f(x): _g_return_0 = _g_0['yield'] del _g_0 return sum(_g_return_0) - ''') - self.assertEqual(f.strip(), result.strip()) + ''' + + self.assertSourceEqual(f, result) def test_variable_starargs(self): def g(a, b, c): return a + b + c - @pragma.inline(g, return_source=True) + @pragma.inline(g) def f(x): return g(*x) - result = dedent(''' + result = ''' def f(x): return g(*x) - ''') - self.assertEqual(f.strip(), result.strip()) + ''' + + self.assertSourceEqual(f, result) def test_multiple_inline(self): def a(x): @@ -243,12 +241,12 @@ def a(x): def b(x): return x + 2 - @pragma.unroll(return_source=True) + @pragma.unroll @pragma.inline(a, b) def f(x): return a(x) + b(x) - result = dedent(''' + result = ''' def f(x): _a_0 = dict(x=x) _a_0['return'] = _a_0['x'] ** 2 @@ -259,8 +257,9 @@ def f(x): _b_return_0 = _b_0.get('return', None) del _b_0 return _a_return_0 + _b_return_0 - ''') - self.assertEqual(f.strip(), result.strip()) + ''' + + self.assertSourceEqual(f, result) def test_coverage(self): def g(y): @@ -273,7 +272,7 @@ def f(): except: raise - print(pragma.inline(g, return_source=True)(f)) + print(pragma.inline(g)(f)) self.assertEqual(f(), pragma.inline(g)(f)()) def test_bug_my_range(self): @@ -288,7 +287,7 @@ def my_range(x): def test_my_range(): return list(my_range(5)) - result = dedent(''' + result = ''' def test_my_range(): _my_range_0 = dict([('yield', [])], x=5) i = 0 @@ -298,7 +297,7 @@ def test_my_range(): _my_range_return_0 = _my_range_0['yield'] del _my_range_0 return list(_my_range_return_0) - ''') + ''' self.assertSourceEqual(test_my_range, result) self.assertEqual(test_my_range(), [0, 1, 2, 3, 4]) diff --git a/tests/test_lambda_lift.py b/tests/test_lambda_lift.py index 92e87a4..864ce95 100644 --- a/tests/test_lambda_lift.py +++ b/tests/test_lambda_lift.py @@ -65,16 +65,16 @@ def test_not_all_locals(self): x = 1 y = 2 - @pragma.lift(return_source=True, imports=False) + @pragma.lift(imports=False) def f(z): return z + x - result = dedent(''' + result = ''' def f(z, *, x): return z + x - ''') + ''' - self.assertEqual(f.strip(), result.strip()) + self.assertSourceEqual(f, result) def test_defaults_thoroughly(self): x = 1 @@ -105,14 +105,16 @@ def f(z, *, o, x: 'number', y=5): ''') def test_no_closure(self): - @pragma.lift(return_source=True, imports=False) + @pragma.lift(imports=False) def f(x): return x - self.assertEqual(f.strip(), dedent(''' + result = ''' def f(x): return x - ''').strip()) + ''' + + self.assertSourceEqual(f, result) def test_method(self): class A: @@ -131,12 +133,14 @@ def f(self, x): self.assertEqual(A.f(something_else, 1), 3) def test_global(self): - global_g = pragma.lift(return_source=True, lift_globals=['global_x'], defaults=True, imports=False)(global_f) + global_g = pragma.lift(lift_globals=['global_x'], defaults=True, imports=False)(global_f) - self.assertEqual(global_g.strip(), dedent(''' + result = ''' def global_f(y, *, global_x=10): return global_x + y - ''').strip()) + ''' + + self.assertSourceEqual(global_g, result) def test_imports(self): import sys @@ -146,13 +150,13 @@ def f(): self.assertEqual(f(), sys.version_info) self.assertEqual(pragma.lift(f)(), sys.version_info) - self.assertSourceEqual(pragma.lift(return_source=True, imports=True)(f), ''' + self.assertSourceEqual(pragma.lift(imports=True)(f), ''' def f(): import pragma import sys return sys.version_info ''') - self.assertSourceEqual(pragma.lift(return_source=True, imports=['sys'])(f), ''' + self.assertSourceEqual(pragma.lift(imports=['sys'])(f), ''' def f(): import sys return sys.version_info @@ -163,7 +167,7 @@ def f(): def g(): return pseudo_sys.version_info - self.assertSourceEqual(pragma.lift(return_source=True, imports=True)(g), ''' + self.assertSourceEqual(pragma.lift(imports=True)(g), ''' def g(): import pragma import sys as pseudo_sys @@ -171,16 +175,18 @@ def g(): ''') def test_docstring(self): - @pragma.lift(return_source=True, imports=True) + @pragma.lift(imports=True) def f(x): 'some docstring' return x + 1 - self.assertSourceEqual(f, ''' + result = ''' def f(x): """some docstring""" import pragma return x + 1 - ''') + ''' + + self.assertSourceEqual(f, result) diff --git a/tests/test_unroll.py b/tests/test_unroll.py index b2c8c79..4e9a2a5 100644 --- a/tests/test_unroll.py +++ b/tests/test_unroll.py @@ -26,7 +26,7 @@ def test_unroll_various(self): g.a = [1, 2, 3] g.b = 6 - @pragma.unroll(return_source=True) + @pragma.unroll def f(x): y = 5 a = range(3) @@ -52,7 +52,7 @@ def f(x): for i in [g.b + 0, g.b + 1, g.b + 2]: yield i - result = dedent(''' + result = ''' def f(x): y = 5 a = range(3) @@ -85,8 +85,9 @@ def f(x): yield 6 yield 7 yield 8 - ''') - self.assertEqual(f.strip(), result.strip()) + ''' + + self.assertSourceEqual(f, result) def test_unroll_const_list(self): @pragma.unroll @@ -113,7 +114,7 @@ def f(): self.assertEqual(list(f()), [1, 2, 4]) def test_unroll_dyn_list_source(self): - @pragma.unroll(return_source=True) + @pragma.unroll def f(): x = 3 a = [x, x, x] @@ -124,7 +125,7 @@ def f(): for i in a: yield i - result = dedent(''' + result = ''' def f(): x = 3 a = [x, x, x] @@ -136,10 +137,12 @@ def f(): yield 4 yield 4 yield 4 - ''') - self.assertEqual(f.strip(), result.strip()) + ''' + + self.assertSourceEqual(f, result) def test_unroll_dyn_list(self): + @pragma.unroll def summation(x=0): a = [x, x, x] v = 0 @@ -147,10 +150,8 @@ def summation(x=0): v += _a return v - summation_source = pragma.unroll(return_source=True)(summation) - summation = pragma.unroll(summation) - code = dedent(''' + result = ''' def summation(x=0): a = [x, x, x] v = 0 @@ -158,14 +159,15 @@ def summation(x=0): v += x v += x return v - ''') - self.assertEqual(summation_source.strip(), code.strip()) + ''' + + self.assertSourceEqual(summation, result) self.assertEqual(summation(), 0) self.assertEqual(summation(1), 3) self.assertEqual(summation(5), 15) def test_unroll_dyn_list_const(self): - @pragma.collapse_literals(return_source=True) + @pragma.collapse_literals @pragma.unroll(x=3) def summation(): a = [x, x, x] @@ -174,7 +176,7 @@ def summation(): v += _a return v - code = dedent(''' + result = ''' def summation(): a = [x, x, x] v = 0 @@ -182,17 +184,18 @@ def summation(): v += 3 v += 3 return 9 - ''') - self.assertEqual(summation.strip(), code.strip()) + ''' + + self.assertSourceEqual(summation, result) def test_unroll_2range_source(self): - @pragma.unroll(return_source=True) + @pragma.unroll def f(): for i in range(3): for j in range(3): yield i + j - result = dedent(''' + result = ''' def f(): yield 0 + 0 yield 0 + 1 @@ -203,17 +206,18 @@ def f(): yield 2 + 0 yield 2 + 1 yield 2 + 2 - ''') - self.assertEqual(f.strip(), result.strip()) + ''' + + self.assertSourceEqual(f, result) def test_unroll_2list_source(self): - @pragma.unroll(return_source=True) + @pragma.unroll def f(): for i in [[1, 2, 3], [4, 5], [6]]: for j in i: yield j - result = dedent(''' + result = ''' def f(): yield 1 yield 2 @@ -221,44 +225,45 @@ def f(): yield 4 yield 5 yield 6 - ''') - self.assertEqual(f.strip(), result.strip()) + ''' + + self.assertSourceEqual(f, result) def test_external_definition(self): # Known bug: this works when defined as a kwarg, but not as an external variable, but ONLY in unittests... # External variables work in practice - @pragma.unroll(return_source=True, a=range) + @pragma.unroll(a=range) def f(): for i in a(3): print(i) - result = dedent(''' + result = ''' def f(): print(0) print(1) print(2) - ''') - self.assertEqual(f.strip(), result.strip()) + ''' + + self.assertSourceEqual(f, result) def test_tuple_assign(self): - # This is still early code, so just make sure that it recognizes when a name is assigned to... we don't get values yet - # TODO: Implement tuple assignment - @pragma.unroll(return_source=True) + @pragma.unroll def f(): x = 3 ((y, x), z) = ((1, 2), 3) for i in [x, x, x]: print(i) - result = dedent(''' + result = ''' def f(): x = 3 (y, x), z = (1, 2), 3 print(2) print(2) print(2) - ''') - self.assertEqual(f.strip(), result.strip()) + ''' + + self.assertSourceEqual(f, result) def test_tuple_loop(self): @pragma.unroll @@ -277,7 +282,7 @@ def f(): self.assertListEqual(list(f()), [6, 8, 10]) def test_top_break(self): - @pragma.unroll(return_source=True) + @pragma.unroll def f(): for i in range(10): print(i) @@ -287,24 +292,26 @@ def f(): def f(): print(0) ''') - self.assertEqual(f.strip(), result.strip()) + + self.assertSourceEqual(f, result) def test_inner_break(self): - @pragma.unroll(return_source=True) + @pragma.unroll def f(y): for i in range(10): print(i) if i == y: break - result = dedent(''' + result = ''' def f(y): for i in range(10): print(i) if i == y: break - ''') - self.assertEqual(f.strip(), result.strip()) + ''' + + self.assertSourceEqual(f, result) def test_nonliteral_iterable(self): def g(x):