Remove unused global data descriptor shapes from arguments #1338

Merged (12 commits) on Aug 3, 2023
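As a hedged illustration of the intent of this change (the program below is illustrative and not taken from the PR): a symbol that appears only in the shape of a global, non-transient array should no longer be reported as a required SDFG argument, since the generated code never reads it.

import dace

N = dace.symbol('N')

@dace.program
def copy_first(A: dace.float64[N], out: dace.float64[1]):
    out[0] = A[0]

sdfg = copy_first.to_sdfg()
# 'N' is used only in the shape of the global array A, so after this change it
# should no longer appear among the required arguments (this is the expected
# behavior, not a snippet from the PR's test suite).
print('N' in sdfg.arglist())  # expected: False (previously True)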
dace/codegen/targets/framecode.py (4 additions, 1 deletion)
@@ -82,7 +82,10 @@ def free_symbols(self, obj: Any):
         k = id(obj)
         if k in self.fsyms:
             return self.fsyms[k]
-        result = obj.free_symbols
+        if hasattr(obj, 'used_symbols'):
+            result = obj.used_symbols(all_symbols=False)
+        else:
+            result = obj.free_symbols
         self.fsyms[k] = result
         return result

dace/data.py (33 additions, 15 deletions)
@@ -243,14 +243,26 @@ def as_arg(self, with_types=True, for_call=False, name=None):
"""Returns a string for a C++ function signature (e.g., `int *A`). """
raise NotImplementedError

def used_symbols(self, all_symbols: bool) -> Set[symbolic.SymbolicType]:
"""
Returns a set of symbols that are used by this data descriptor.

:param all_symbols: Include not-strictly-free symbols that are used by this data descriptor,
e.g., shape and size of a global array.
:return: A set of symbols that are used by this data descriptor. NOTE: The results are symbolic
rather than a set of strings.
"""
result = set()
if self.transient or all_symbols:
for s in self.shape:
if isinstance(s, sp.Basic):
result |= set(s.free_symbols)
return result

@property
def free_symbols(self) -> Set[symbolic.SymbolicType]:
""" Returns a set of undefined symbols in this data descriptor. """
result = set()
for s in self.shape:
if isinstance(s, sp.Basic):
result |= set(s.free_symbols)
return result
return self.used_symbols(all_symbols=True)

def __repr__(self):
return 'Abstract Data Container, DO NOT USE'
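A minimal usage sketch of the new Data.used_symbols API (illustrative, assuming a plain Array constructed as shown): for a non-transient descriptor, the shape symbol is still reported by free_symbols, but not by used_symbols(all_symbols=False), because the shape of a global array is not needed by the generated code.

import dace

N = dace.symbol('N')
desc = dace.data.Array(dace.float64, shape=[N], transient=False)

print(desc.free_symbols)                     # {N}: all symbols, including the shape
print(desc.used_symbols(all_symbols=False))  # set(): shape of a global array is skipped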
@@ -689,20 +701,23 @@ def as_arg(self, with_types=True, for_call=False, name=None):
     def sizes(self):
         return [d.name if isinstance(d, symbolic.symbol) else str(d) for d in self.shape]
 
-    @property
-    def free_symbols(self):
-        result = super().free_symbols
+    def used_symbols(self, all_symbols: bool) -> Set[symbolic.SymbolicType]:
+        result = super().used_symbols(all_symbols)
         for s in self.strides:
             if isinstance(s, sp.Expr):
                 result |= set(s.free_symbols)
-        if isinstance(self.total_size, sp.Expr):
-            result |= set(self.total_size.free_symbols)
         for o in self.offset:
             if isinstance(o, sp.Expr):
                 result |= set(o.free_symbols)
 
+        if self.transient or all_symbols:
[Review comment, Contributor] Why is the transient check only here? What about strides and offset just above?

[Reply, Collaborator (PR author)] Strides and offset are always necessary if you have any memlet (a[i*stride] will be generated in the code).

+            if isinstance(self.total_size, sp.Expr):
+                result |= set(self.total_size.free_symbols)
         return result
 
+    @property
+    def free_symbols(self):
+        return self.used_symbols(all_symbols=True)
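A hedged sketch of the distinction raised in the review thread above (symbol names and constructor keywords are illustrative): stride symbols are always reported because they appear in generated indexing expressions such as A[i*S], whereas shape and total-size symbols of a global array are reported only with all_symbols=True.

import dace

N = dace.symbol('N')
S = dace.symbol('S')
desc = dace.data.Array(dace.float64, shape=[N], strides=[S], total_size=N * S)

print(desc.used_symbols(all_symbols=False))  # {S}: needed to index the array in generated code
print(desc.used_symbols(all_symbols=True))   # {N, S}: shape and total size included as well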

     def _set_shape_dependent_properties(self, shape, strides, total_size, offset):
         """
         Used to set properties which depend on the shape of the array
@@ -890,17 +905,20 @@ def covers_range(self, rng):
 
         return True
 
-    @property
-    def free_symbols(self):
-        result = super().free_symbols
-        if isinstance(self.buffer_size, sp.Expr):
+    def used_symbols(self, all_symbols: bool) -> Set[symbolic.SymbolicType]:
+        result = super().used_symbols(all_symbols)
+        if (self.transient or all_symbols) and isinstance(self.buffer_size, sp.Expr):
             result |= set(self.buffer_size.free_symbols)
         for o in self.offset:
             if isinstance(o, sp.Expr):
                 result |= set(o.free_symbols)
 
         return result
 
+    @property
+    def free_symbols(self):
+        return self.used_symbols(all_symbols=True)
 
 
 @make_properties
 class View(Array):
dace/memlet.py (25 additions, 10 deletions)
@@ -17,6 +17,7 @@
 if TYPE_CHECKING:
     import dace.sdfg.graph
 
+
 @make_properties
 class Memlet(object):
     """ Data movement object. Represents the data, the subset moved, and the
@@ -176,15 +177,16 @@ def to_json(self):
     @staticmethod
     def from_json(json_obj, context=None):
         ret = Memlet()
-        dace.serialize.set_properties_from_json(ret,
-                                                json_obj,
-                                                context=context,
-                                                ignore_properties={'src_subset', 'dst_subset', 'num_accesses', 'is_data_src'})
+        dace.serialize.set_properties_from_json(
[Review comment, Contributor] Unrelated to the PR, but why is YAPF constantly switching the formatting of such lines?

[Reply, Collaborator (PR author)] I think it depends on the version of yapf being used.

+            ret,
+            json_obj,
+            context=context,
+            ignore_properties={'src_subset', 'dst_subset', 'num_accesses', 'is_data_src'})
 
         # Allow serialized memlet to override src/dst_subset to disambiguate self-copies
         if 'is_data_src' in json_obj['attributes']:
             ret._is_data_src = json_obj['attributes']['is_data_src']
 
         if context:
             ret._sdfg = context['sdfg']
             ret._state = context['sdfg_state']
@@ -510,18 +512,30 @@ def validate(self, sdfg, state):
         if self.data is not None and self.data not in sdfg.arrays:
             raise KeyError('Array "%s" not found in SDFG' % self.data)
 
-    @property
-    def free_symbols(self) -> Set[str]:
-        """ Returns a set of symbols used in this edge's properties. """
+    def used_symbols(self, all_symbols: bool) -> Set[str]:
+        """
+        Returns a set of symbols used in this edge's properties.
+
+        :param all_symbols: If False, only returns the set of symbols that will be used
+                            in the generated code and are needed as arguments.
+        """
         # Symbolic properties are in volume, and the two subsets
         result = set()
-        result |= set(map(str, self.volume.free_symbols))
+        if all_symbols:
+            result |= set(map(str, self.volume.free_symbols))
         if self.src_subset:
             result |= self.src_subset.free_symbols
 
         if self.dst_subset:
             result |= self.dst_subset.free_symbols
 
         return result
 
+    @property
+    def free_symbols(self) -> Set[str]:
+        """ Returns a set of symbols used in this edge's properties. """
+        return self.used_symbols(all_symbols=True)
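A brief hedged sketch of the memlet behavior documented above (the memlet expression and symbol names are illustrative): subset symbols are always reported because they appear in generated index expressions, while the symbolic volume is reported only with all_symbols=True.

import dace

# A memlet covering A[i:i+K] with a symbolic, dynamic volume 'V'
m = dace.Memlet('A[i:i+K]')
m.dynamic = True
m.volume = dace.symbolic.pystr_to_symbolic('V')

print(m.used_symbols(all_symbols=True))   # {'i', 'K', 'V'}
print(m.used_symbols(all_symbols=False))  # {'i', 'K'}: the volume symbol is dropped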

     def get_free_symbols_by_indices(self, indices_src: List[int], indices_dst: List[int]) -> Set[str]:
         """
         Returns set of free symbols used in this edges properties but only taking certain indices of the src and dst
@@ -640,6 +654,7 @@ class MemletTree(object):
         all siblings of the same edge and their children, for instance if
         multiple inputs from the same access node are used.
     """
+
     def __init__(self,
                  edge: 'dace.sdfg.graph.MultiConnectorEdge[Memlet]',
                  downwards: bool = True,
dace/sdfg/nodes.py (28 additions, 7 deletions)
@@ -580,12 +580,22 @@ def from_json(json_obj, context=None):

         return ret
 
+    def used_symbols(self, all_symbols: bool) -> Set[str]:
+        free_syms = set().union(*(map(str,
+                                      pystr_to_symbolic(v).free_symbols) for v in self.symbol_mapping.values()),
+                                *(map(str,
+                                      pystr_to_symbolic(v).free_symbols) for v in self.location.values()))
+
+        # Filter out unused internal symbols from symbol mapping
+        if not all_symbols:
+            internally_used_symbols = self.sdfg.used_symbols(all_symbols=False)
+            free_syms &= internally_used_symbols
+
+        return free_syms
+
     @property
     def free_symbols(self) -> Set[str]:
-        return set().union(*(map(str,
-                                 pystr_to_symbolic(v).free_symbols) for v in self.symbol_mapping.values()),
-                           *(map(str,
-                                 pystr_to_symbolic(v).free_symbols) for v in self.location.values()))
+        return self.used_symbols(all_symbols=True)

     def infer_connector_types(self, sdfg, state):
         # Avoid import loop
@@ -673,6 +683,7 @@ def validate(self, sdfg, state, references: Optional[Set[int]] = None, **context
 # Scope entry class
 class EntryNode(Node):
     """ A type of node that opens a scope (e.g., Map or Consume). """
+
     def validate(self, sdfg, state):
         self.map.validate(sdfg, state, self)
 
@@ -683,6 +694,7 @@ def validate(self, sdfg, state):
 # Scope exit class
 class ExitNode(Node):
     """ A type of node that closes a scope (e.g., Map or Consume). """
+
     def validate(self, sdfg, state):
         self.map.validate(sdfg, state, self)
 
@@ -696,6 +708,7 @@ class MapEntry(EntryNode):
 
     :see: Map
     """
+
     def __init__(self, map: 'Map', dynamic_inputs=None):
         super(MapEntry, self).__init__(dynamic_inputs or set())
         if map is None:
@@ -772,6 +785,7 @@ class MapExit(ExitNode):
 
     :see: Map
     """
+
     def __init__(self, map: 'Map'):
         super(MapExit, self).__init__()
         if map is None:
@@ -851,17 +865,20 @@ class Map(object):
                               default=0,
                               desc="Number of OpenMP threads executing the Map",
                               optional=True,
-                              optional_condition=lambda m: m.schedule in (dtypes.ScheduleType.CPU_Multicore, dtypes.ScheduleType.CPU_Persistent))
+                              optional_condition=lambda m: m.schedule in
+                              (dtypes.ScheduleType.CPU_Multicore, dtypes.ScheduleType.CPU_Persistent))
     omp_schedule = EnumProperty(dtype=dtypes.OMPScheduleType,
                                 default=dtypes.OMPScheduleType.Default,
                                 desc="OpenMP schedule {static, dynamic, guided}",
                                 optional=True,
-                                optional_condition=lambda m: m.schedule in (dtypes.ScheduleType.CPU_Multicore, dtypes.ScheduleType.CPU_Persistent))
+                                optional_condition=lambda m: m.schedule in
+                                (dtypes.ScheduleType.CPU_Multicore, dtypes.ScheduleType.CPU_Persistent))
     omp_chunk_size = Property(dtype=int,
                               default=0,
                               desc="OpenMP schedule chunk size",
                               optional=True,
-                              optional_condition=lambda m: m.schedule in (dtypes.ScheduleType.CPU_Multicore, dtypes.ScheduleType.CPU_Persistent))
+                              optional_condition=lambda m: m.schedule in
+                              (dtypes.ScheduleType.CPU_Multicore, dtypes.ScheduleType.CPU_Persistent))
 
     gpu_block_size = ListProperty(element_type=int,
                                   default=None,
@@ -928,6 +945,7 @@ class ConsumeEntry(EntryNode):
 
     :see: Consume
     """
+
     def __init__(self, consume: 'Consume', dynamic_inputs=None):
         super(ConsumeEntry, self).__init__(dynamic_inputs or set())
         if consume is None:
@@ -1006,6 +1024,7 @@ class ConsumeExit(ExitNode):
 
     :see: Consume
     """
+
     def __init__(self, consume: 'Consume'):
         super(ConsumeExit, self).__init__()
         if consume is None:
@@ -1117,6 +1136,7 @@ def get_param_num(self):
 
 @dace.serialize.serializable
 class PipelineEntry(MapEntry):
+
     @staticmethod
     def map_type():
         return PipelineScope
@@ -1149,6 +1169,7 @@ def new_symbols(self, sdfg, state, symbols) -> Dict[str, dtypes.typeclass]:
 
 @dace.serialize.serializable
 class PipelineExit(MapExit):
+
     @staticmethod
     def map_type():
         return PipelineScope