Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (linalg) Add linalg methods to determine loop range #1279

Merged
merged 6 commits into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions tests/dialects/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from xdsl.dialects import arith, func, linalg, memref
from xdsl.dialects.builtin import AffineMapAttr, f32
from xdsl.ir.affine import AffineExpr, AffineMap
from xdsl.printer import Printer


def test_linalg_on_memrefs():
Expand Down Expand Up @@ -34,10 +33,37 @@ def body(args: tuple[Any, ...]):

func.Return()

foo = func.FuncOp("foo", ([], []), funcBody)
func.FuncOp("foo", ([], []), funcBody)

printer = Printer()
printer.print(foo)

def test_loop_range_methods():
A = memref.Alloc.get(f32, shape=[100, 50])
B = memref.Alloc.get(f32, shape=[50, 100])
C = memref.Alloc.get(f32, shape=[100, 100])

test_linalg_on_memrefs()
@Builder.implicit_region((f32, f32, f32))
def body(args: tuple[Any, ...]):
a, b, c = args
linalg.Yield(arith.Addf(arith.Mulf(a, b), c))

i = AffineExpr.dimension(0)
j = AffineExpr.dimension(1)
k = AffineExpr.dimension(2)

indexing_maps = [
AffineMapAttr(AffineMap(3, 0, [i, k])),
AffineMapAttr(AffineMap(3, 0, [k, j])),
AffineMapAttr(AffineMap(3, 0, [i, j])),
]
iterators = [
linalg.IteratorTypeAttr(linalg.IteratorType.PARALLEL),
linalg.IteratorTypeAttr(linalg.IteratorType.PARALLEL),
linalg.IteratorTypeAttr(linalg.IteratorType.PARALLEL),
]

op = linalg.Generic(
[A.results[0], B.results[0]], [C.results[0]], body, indexing_maps, iterators
)

loops = op.get_static_loop_ranges()
assert loops == [100, 50, 50]
65 changes: 65 additions & 0 deletions xdsl/dialects/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
AnyShapedType,
AnyTensorType,
ArrayAttr,
ShapedType,
StringAttr,
)
from xdsl.ir import Attribute, Data, Dialect, Operation, Region, SSAValue
from xdsl.ir.affine import AffineMap
from xdsl.irdl import (
AttrSizedOperandSegments,
IRDLOperation,
Expand Down Expand Up @@ -103,6 +105,69 @@ def __init__(
regions=[body],
)

def get_indexing_maps(self) -> list[AffineMap]:
return [attr.data for attr in self.indexing_maps]

def get_num_loops(self) -> int:
return self.indexing_maps.data[0].data.num_dims

def get_loops_to_shapes_map(self) -> AffineMap:
"""
Returns a map to answer the question: "given an iteration space over
the codomain, what are the subshapes of the operands involved in the
computation".
The default behavior is to just concatenate all the indexing maps.
"""
result_exprs = [res for map in self.get_indexing_maps() for res in map.results]

dims = self.get_num_loops()

# FIXME: Support symbols.
for map in self.get_indexing_maps():
if map.num_symbols != 0:
raise NotImplementedError(
"Indexing maps with symbols not supported for now."
)

syms = 0
return AffineMap(dims, syms, result_exprs)

def get_shapes_to_loops_map(self) -> AffineMap:
"""
Returns a map to answer the question: "Given a list of operand ranges,
what is the subportion of the iteration space involved in the
computation". This is the inverse problem of `get_loops_to_shapes_map`.
Return the empty AffineMap when such an AffineMap cannot be
constructed. The default behavior is based on a very simple inference
procedure that only works with permutation affine maps. A more advanced
Tensor-Comprehension like inference is possible but has proven to be
ambiguous in unfavorable case. A safer and more robust alternative is
to allow each op to define its own AffineMap.
"""
loops_to_shapes = self.get_loops_to_shapes_map()
inverse = loops_to_shapes.inverse_permutation()
if not inverse:
raise NotImplementedError(
"Non-invertible maps need dynamic shapes, which are not implemented."
)
return inverse

def get_static_shapes(self) -> list[int]:
sizes: list[int] = []
for input in self.inputs:
if isinstance(input.type, ShapedType):
for dim in input.type.get_shape():
sizes.append(dim)
for output in self.outputs:
if isinstance(output.type, ShapedType):
for dim in output.type.get_shape():
sizes.append(dim)
return sizes
Groverkss marked this conversation as resolved.
Show resolved Hide resolved

def get_static_loop_ranges(self) -> list[int]:
shapes_to_loops = self.get_shapes_to_loops_map()
return shapes_to_loops.eval(self.get_static_shapes(), [])


@irdl_op_definition
class Yield(IRDLOperation):
Expand Down
41 changes: 40 additions & 1 deletion xdsl/ir/affine/affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass

from xdsl.ir.affine import AffineExpr
from xdsl.ir.affine import AffineDimExpr, AffineExpr


@dataclass
Expand Down Expand Up @@ -47,6 +47,45 @@ def compose(self, map: AffineMap) -> AffineMap:
results=results,
)

def inverse_permutation(self) -> AffineMap | None:
Groverkss marked this conversation as resolved.
Show resolved Hide resolved
"""
Returns a map of codomain to domain dimensions such that the first
codomain dimension for a particular domain dimension is selected.
Returns an empty map if the input map is empty. Returns null map (not
empty map) if the map is not invertible (i.e. the map does not contain
a subset that is a permutation of full domain rank).

Prerequisites: The map should have no symbols.

Example:
(d0, d1, d2) -> (d1, d1, d0, d2, d1, d2, d1, d0)
0 2 3
returns:
(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d0, d3)
"""
if self.num_symbols != 0:
raise ValueError(
f"Cannot invert AffineMap with symbols: {self.num_symbols}"
)
found_dims = [-1] * self.num_dims

for i, expr in enumerate(self.results):
match expr:
case AffineDimExpr():
found_dims[expr.position] = i
case _:
continue

if -1 in found_dims:
return None

results = [self.results[i] for i in found_dims]
return AffineMap(
num_dims=len(self.results),
num_symbols=0,
results=results,
)

def eval(self, dims: list[int], symbols: list[int]) -> list[int]:
"""Evaluate the AffineMap given the values of dimensions and symbols."""
assert len(dims) == self.num_dims
Expand Down