From bff9e518d03eb9b6d73d1fc6ecbb53adc12788a4 Mon Sep 17 00:00:00 2001 From: Dominic Hofer Date: Fri, 1 Nov 2019 11:28:00 +0100 Subject: [PATCH 1/2] Passes all arguments through 'compile(..)'. --- dace/sdfg.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dace/sdfg.py b/dace/sdfg.py index a69ff94afb..edbb524a22 100644 --- a/dace/sdfg.py +++ b/dace/sdfg.py @@ -4012,16 +4012,16 @@ def local_transients(sdfg, dfg, entry_node): return transients -def compile(function_or_sdfg, *args, specialize=None): +def compile(function_or_sdfg, *args, **kwargs): """ Obtain a runnable binary from a Python (@dace.program) function. """ if isinstance(function_or_sdfg, dace.frontend.python.parser.DaceProgram): sdfg = dace.frontend.python.parser.parse_from_function( - function_or_sdfg, *args) + function_or_sdfg, *args, **kwargs) elif isinstance(function_or_sdfg, SDFG): sdfg = function_or_sdfg else: raise TypeError("Unsupported function type") - return sdfg.compile(specialize=specialize) + return sdfg.compile(**kwargs) def is_devicelevel(sdfg: SDFG, state: SDFGState, node: dace.graph.nodes.Node): From ebc7feb199badc1a66002d1305a202eb1319fc10 Mon Sep 17 00:00:00 2001 From: Dominic Hofer Date: Fri, 1 Nov 2019 11:29:31 +0100 Subject: [PATCH 2/2] Extends 'compile(..)' with an output filename. --- dace/codegen/compiler.py | 4 ++++ dace/sdfg.py | 11 ++++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/dace/codegen/compiler.py b/dace/codegen/compiler.py index 786adf431a..947977967f 100644 --- a/dace/codegen/compiler.py +++ b/dace/codegen/compiler.py @@ -146,6 +146,10 @@ def __init__(self, sdfg, lib: ReloadableDLL): self._exit = lib.get_symbol('__dace_exit') self._cfunc = lib.get_symbol('__program_{}'.format(sdfg.name)) + @property + def filename(self): + return self._lib._library_filename + @property def sdfg(self): return self._sdfg diff --git a/dace/sdfg.py b/dace/sdfg.py index edbb524a22..794aa86e5c 100644 --- a/dace/sdfg.py +++ b/dace/sdfg.py @@ -8,6 +8,7 @@ import pickle, json from pydoc import locate import random +import shutil import sys from typing import Any, Dict, Set, Tuple, List, Union import warnings @@ -1528,7 +1529,7 @@ def specialize(self, additional_symbols=None, specialize_all_symbols=True): # Update constants self.constants_prop.update(syms) - def compile(self, specialize=None, optimizer=None): + def compile(self, specialize=None, optimizer=None, output_file=None): """ Compiles a runnable binary from this SDFG. @param specialize: If True, specializes all symbols to their @@ -1537,6 +1538,8 @@ def compile(self, specialize=None, optimizer=None): @param optimizer: If defines a valid class name, it will be called during compilation to transform the SDFG as necessary. If None, uses configuration setting. + @param output_file: If not None, copies the output library file to + the specified path. @return: A callable CompiledSDFG object. """ @@ -1587,6 +1590,12 @@ def compile(self, specialize=None, optimizer=None): # Compile the code and get the shared library path shared_library = compiler.configure_and_compile(program_folder) + # If provided, save output to path or filename + if output_file is not None: + if os.path.isdir(output_file): + output_file = os.path.join(output_file, os.path.basename(shared_library)) + shutil.copyfile(shared_library, output_file) + # Get the function handle return compiler.get_program_handle(shared_library, sdfg)