Skip to content

Commit

Permalink
Merge pull request #1262 from spcl/schedule-storage-lite
Browse files Browse the repository at this point in the history
Determine schedule type based on surrounding storage types
  • Loading branch information
tbennun committed Jun 5, 2023
2 parents 68b6449 + 016dafa commit fe68f39
Show file tree
Hide file tree
Showing 15 changed files with 491 additions and 166 deletions.
2 changes: 2 additions & 0 deletions dace/codegen/compiled_sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ def get_state_struct(self) -> ctypes.Structure:
:return: the ctypes.Structure representation of the state struct.
"""
if not self._libhandle:
raise ValueError('Library was not initialized')

return ctypes.cast(self._libhandle, ctypes.POINTER(self._try_parse_state_struct())).contents

Expand Down
8 changes: 2 additions & 6 deletions dace/codegen/targets/fpga.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,12 +1110,8 @@ def generate_nested_state(self, sdfg, state, nest_name, subgraphs, function_stre
def generate_scope(self, sdfg, dfg_scope, state_id, function_stream, callsite_stream):

if not self._in_device_code:
# If we're not already generating kernel code we need to set up the
# kernel launch
subgraphs = [dfg_scope]
return self.generate_kernel(sdfg, sdfg.node(state_id),
dfg_scope.source_nodes()[0].map.label.replace(" ", "_"), subgraphs,
function_stream, callsite_stream)
# If we're not already generating kernel code, fail
raise cgx.CodegenError('FPGA kernel needs to be generated inside a device state.')

self.generate_node(sdfg, dfg_scope, state_id, dfg_scope.source_nodes()[0], function_stream, callsite_stream)

Expand Down
9 changes: 7 additions & 2 deletions dace/codegen/targets/framecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from dace.sdfg import SDFG, ScopeSubgraphView, SDFGState, nodes
from dace.sdfg import scope as sdscope
from dace.sdfg import utils
from dace.sdfg.infer_types import set_default_schedule_and_storage_types
from dace.transformation.passes.analysis import StateReachability


Expand Down Expand Up @@ -425,7 +424,13 @@ def _get_schedule(self, scope: Union[nodes.EntryNode, SDFGState, SDFG]) -> dtype
sdfg: SDFG = (scope if isinstance(scope, SDFG) else scope.parent)
if sdfg.parent_nsdfg_node is None:
return TOP_SCHEDULE
return (sdfg.parent_nsdfg_node.schedule or TOP_SCHEDULE)

# Go one SDFG up
pstate = sdfg.parent
pscope = pstate.entry_node(sdfg.parent_nsdfg_node)
if pscope is not None:
return self._get_schedule(pscope)
return self._get_schedule(pstate)
else:
raise TypeError

Expand Down
12 changes: 12 additions & 0 deletions dace/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,18 @@ class TilingType(aenum.AutoNumberEnum):
ScheduleType.Snitch_Multicore: ScheduleType.Snitch_Multicore
}

# Maps from StorageType to a preferred ScheduleType for helping determine schedules.
# If mapped to None or does not exist in this dictionary, does not affect decision.
# Scalar data containers also do not affect this decision.
STORAGEDEFAULT_SCHEDULE = {
StorageType.CPU_Heap: ScheduleType.CPU_Multicore,
StorageType.CPU_ThreadLocal: ScheduleType.CPU_Multicore,
StorageType.GPU_Global: ScheduleType.GPU_Device,
StorageType.GPU_Shared: ScheduleType.GPU_ThreadBlock,
StorageType.FPGA_Global: ScheduleType.FPGA_Device,
StorageType.SVE_Register: ScheduleType.SVE_Map,
}

# Translation of types to C types
_CTYPES = {
None: "void",
Expand Down
4 changes: 2 additions & 2 deletions dace/libraries/blas/nodes/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def make_sdfg(node, parent_state, parent_sdfg):
init_state = sdfg.add_state(node.label + "_initstate")
state = sdfg.add_state_after(init_state, node.label + "_state")

if node.beta != 0:
if '_cin' in node.in_connectors:
sdfg.add_array("_cin", shape_c, dtype_c, strides=cdata[-1], storage=cdata[1].storage)

mul_out, mul_out_array = "_c", array_c
Expand Down Expand Up @@ -1050,7 +1050,7 @@ def validate(self, sdfg, state):
# Numpy replacement
@oprepo.replaces('dace.libraries.blas.gemm')
@oprepo.replaces('dace.libraries.blas.Gemm')
def gemv_libnode(pv: 'ProgramVisitor',
def gemm_libnode(pv: 'ProgramVisitor',
sdfg: SDFG,
state: SDFGState,
A,
Expand Down
2 changes: 2 additions & 0 deletions dace/libraries/blas/nodes/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ def expansion(node, state, sdfg):
from dace.libraries.blas.nodes.gemm import Gemm
beta = node.beta
cin = True
if '_cin' not in node.in_connectors:
cin = False
if c[0].data.wcr:
from dace.frontend import operations
redtype = operations.detect_reduction_type(c[0].data.wcr)
Expand Down

0 comments on commit fe68f39

Please sign in to comment.