Skip to content

Commit

Permalink
Hacking on bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed Mar 25, 2019
1 parent 697b092 commit 39ffc35
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 20 deletions.
1 change: 1 addition & 0 deletions Pipfile
Expand Up @@ -4,6 +4,7 @@ verify_ssl = true
name = "pypi"

[packages]
numpy = "*"

[dev-packages]

Expand Down
20 changes: 20 additions & 0 deletions Pipfile.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 19 additions & 6 deletions aot/__init__.py
Expand Up @@ -17,6 +17,7 @@
TVM_PATH = os.environ['TVM_HOME']

def compile_cpp(source, lib_name, flags=None, lib_path=None):
print(f"lib_name={lib_name}, flags={flags}, lib_path={lib_path}")
if flags is None:
flags = []

Expand Down Expand Up @@ -300,18 +301,30 @@ def visit_ref_write(self, r):
_LIB_COUNTER = 1
_LIB = []

def compile(mod, func, *, ctx, tgt, use_gpu, name='default'):
global _LIB, _LIB_COUNTER
def lib_and_func_name(name):
global _LIB_COUNTER
packed_name = f'relay.aot.{name}.{_LIB_COUNTER}'
lib_name = f"librelay_aot_{_LIB_COUNTER}.so"
_LIB_COUNTER += 1
return lib_name, packed_name

def _mk_wrapper(fn, ctx, params):
def _wrapper(*args):
return fn(*[convert(a, ctx) for a in params], *[convert(a, ctx) for a in args])
return _wrapper

def compile(mod, func, ctx, tgt, use_gpu, name='default'):
global _LIB
compiler = AoTCompiler(mod, tgt)
func = compiler.optimize(func)
func = compiler.visit(func)
lib_name, packed_name = lib_and_func_name(name)
params, source_code = to_source.to_source(mod, compiler.gv_map, use_gpu, packed_name, func)
print(source_code)
lib_name = f"librelay_aot_{_LIB_COUNTER}.so"
library_path = compile_cpp(source_code, lib_name, flags=["-O3"])
_LIB_COUNTER += 1
_LIB.append(load_lib(library_path))
print(f"Getting packed_name={packed_name}")
fn = get_global_func(packed_name)
def wrap(*args):
return fn(*[convert(a, ctx) for a in params], *[convert(a, ctx) for a in args])
return wrap
print(fn)
return _mk_wrapper(fn, ctx, params)
28 changes: 14 additions & 14 deletions test/test_aot.py
Expand Up @@ -194,18 +194,18 @@ def test_compose():
# np.testing.assert_allclose(output.asnumpy(), np.array(55, dtype='int32'))

if __name__ == "__main__":
#test_identity()
#test_add()
#test_mult_op()
#test_double()
#test_42()
#test_add_42()
#test_int_mult_3()
#test_abs()
#test_recur_sum_global()
#test_nat_3()
#test_nat_add()
#test_add_convert()
#test_ref()
#test_tuple()
test_identity()
test_add()
test_mult_op()
test_double()
test_42()
test_add_42()
test_int_mult_3()
test_abs()
test_recur_sum_global()
test_nat_3()
test_nat_add()
test_add_convert()
test_ref()
test_tuple()
test_compose()

0 comments on commit 39ffc35

Please sign in to comment.