Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Relaxed Shape/Size Equality Checks #1321

Merged
merged 6 commits into from
Jul 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 21 additions & 6 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -2143,10 +2143,13 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op
if len(arr1.shape) > 1 and len(arr2.shape) > 1: # matrix * matrix

if len(arr1.shape) > 3 or len(arr2.shape) > 3:
raise SyntaxError('Matrix multiplication of tensors of dimensions > 3 '
'not supported')
raise SyntaxError('Matrix multiplication of tensors of dimensions > 3 not supported')

if arr1.shape[-1] != arr2.shape[-2]:
res = symbolic.equal(arr1.shape[-1], arr2.shape[-2])
if res is None:
warnings.warn(f'Last mode of first tesnsor/matrix {arr1.shape[-1]} and second-last mode of '
f'second tensor/matrix {arr2.shape[-2]} may not match', UserWarning)
elif not res:
raise SyntaxError('Matrix dimension mismatch %s != %s' % (arr1.shape[-1], arr2.shape[-2]))

from dace.libraries.blas.nodes.matmul import _get_batchmm_opts
Expand All @@ -2160,23 +2163,35 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op

elif len(arr1.shape) == 2 and len(arr2.shape) == 1: # matrix * vector

if arr1.shape[1] != arr2.shape[0]:
res = symbolic.equal(arr1.shape[-1], arr2.shape[0])
if res is None:
warnings.warn(f'Number of matrix columns {arr1.shape[-1]} and length of vector {arr2.shape[0]} '
f'may not match', UserWarning)
elif not res:
raise SyntaxError("Number of matrix columns {} must match"
"size of vector {}.".format(arr1.shape[1], arr2.shape[0]))

output_shape = (arr1.shape[0], )

elif len(arr1.shape) == 1 and len(arr2.shape) == 2: # vector * matrix

if arr1.shape[0] != arr2.shape[0]:
res = symbolic.equal(arr1.shape[0], arr2.shape[0])
if res is None:
warnings.warn(f'Length of vector {arr1.shape[0]} and number of matrix rows {arr2.shape[0]} '
f'may not match', UserWarning)
elif not res:
raise SyntaxError("Size of vector {} must match number of matrix "
"rows {} must match".format(arr1.shape[0], arr2.shape[0]))

output_shape = (arr2.shape[1], )

elif len(arr1.shape) == 1 and len(arr2.shape) == 1: # vector * vector

if arr1.shape[0] != arr2.shape[0]:
res = symbolic.equal(arr1.shape[0], arr2.shape[0])
if res is None:
warnings.warn(f'Length of first vector {arr1.shape[0]} and length of second vector {arr2.shape[0]} '
f'may not match', UserWarning)
elif not res:
raise SyntaxError("Vectors in vector product must have same size: "
"{} vs. {}".format(arr1.shape[0], arr2.shape[0]))

Expand Down
23 changes: 15 additions & 8 deletions dace/libraries/blas/nodes/batched_matmul.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
from copy import deepcopy as dc
from dace import dtypes, memlet as mm, properties, data as dt
from typing import Any, Dict, Optional
from dace.symbolic import symstr
from dace.symbolic import symstr, equal
import dace.library
import dace.properties
from dace.frontend.common import op_repository as oprepo
Expand All @@ -12,6 +11,7 @@
to_cublas_computetype)
from dace.libraries.blas.nodes.matmul import (_get_matmul_operands, _get_batchmm_opts, _get_codegen_gemm_opts)
from .. import environments
import warnings


@dace.library.expansion
Expand All @@ -28,8 +28,12 @@ def make_sdfg(node, parent_state, parent_sdfg):
cdesc = parent_sdfg.arrays[outedge.data.data]
bopt = _get_batchmm_opts(shape_a, strides_a, shape_b, strides_b, cdesc.shape, cdesc.strides)

if shape_a[-1] != shape_b[-2]:
raise SyntaxError('Matrix sizes must match')
res = equal(shape_a[-1], shape_b[-2])
if res is None:
warnings.warn(f"First matrix columns {shape_a[-1]} may not match second matrix rows {shape_b[-2]}",
UserWarning)
elif not res:
raise SyntaxError("Matrix sizes must match")
if bopt:
shape_c = (bopt['b'], shape_a[-2], shape_b[-1])
else:
Expand Down Expand Up @@ -436,9 +440,12 @@ def validate(self, sdfg, state):
raise ValueError("Batched matrix-matrix product only supported on matrices")
if len(size1) != 3:
raise ValueError("Batched matrix-matrix product only supported on matrices")
if size0[-1] != size1[-2]:
raise ValueError("Inputs to matrix-matrix product "
"must agree in the k-dimension")
res = equal(size0[-1], size1[-2])
if res is None:
warnings.warn(f'First tensor\'s last mode {size0[-1]} and second tensor\'s second-last mode {size1[-2]} '
f'may not match', UserWarning)
elif not res:
raise ValueError("Inputs to matrix-matrix product must agree in the k-dimension")
out_subset = dc(out_memlet.subset)
out_subset.squeeze()
size2 = out_subset.size()
Expand Down
45 changes: 32 additions & 13 deletions dace/libraries/blas/nodes/gemm.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
from copy import deepcopy as dc
from typing import Any, Dict, Optional
from dace import dtypes, memlet as mm, properties, data as dt
from dace.symbolic import symstr
from dace.symbolic import symstr, equal
import dace.library
from dace import SDFG, SDFGState
from dace.frontend.common import op_repository as oprepo
Expand All @@ -13,7 +12,7 @@
from dace.libraries.blas.nodes.matmul import (_get_matmul_operands, _get_codegen_gemm_opts)
from .. import environments
import numpy as np
from numbers import Number
import warnings


def _is_complex(dtype):
Expand Down Expand Up @@ -65,7 +64,13 @@ def make_sdfg(node, parent_state, parent_sdfg):
else:
trans_shape_b = shape_b

if (len(trans_shape_a) != 2 or len(trans_shape_b) != 2 or trans_shape_a[1] != trans_shape_b[0]):
if len(trans_shape_a) != 2 or len(trans_shape_b) != 2:
raise SyntaxError("Matrix sizes must match")
res = equal(trans_shape_a[1], trans_shape_b[0])
if res is None:
warnings.warn(f"First matrix columns {trans_shape_a[1]} may not match "
f"second matrix rows {trans_shape_b[0]}", UserWarning)
elif not res:
raise SyntaxError("Matrix sizes must match")
M, K, N = trans_shape_a[0], trans_shape_a[1], trans_shape_b[1]
shape_c = (M, N)
Expand Down Expand Up @@ -1032,19 +1037,33 @@ def validate(self, sdfg, state):
# Function is symmetric, edge order does not matter
if len(size0) != 2 or len(size1) != 2:
raise ValueError("matrix-matrix product only supported on matrices")
if size0[1] != size1[0]:
raise ValueError("Inputs to matrix-matrix product "
"must agree in the k-dimension")
res = equal(size0[1], size1[0])
if res is None:
warnings.warn(f'First matrix columns {size0[1]} and second matrix rows {size1[0]} may not match',
UserWarning)
elif not res:
raise ValueError("Inputs to matrix-matrix product must agree in the k-dimension")
out_subset = dc(out_memlet.subset)
out_subset.squeeze()
size3 = out_subset.size()
if size2 is not None and size2 != size3:
raise ValueError("Input C matrix must match output matrix.")
if size2 is not None:
res = [equal(s0, s1) for s0, s1 in zip(size2, size3)]
fail = any([r is False for r in res])
success = all([r is True for r in res])
if fail:
raise ValueError("Input C matrix must match output matrix.")
elif not success:
warnings.warn(f"Size of input C matrix {size2} may not match output matrix size {size3}", UserWarning)
if len(size3) != 2:
raise ValueError("matrix-matrix product only supported on matrices")
if len(size3) == 2 and list(size3) != [size0[-2], size1[-1]]:
raise ValueError("Output to matrix-matrix product must agree in the m and n "
"dimensions")
if len(size3) == 2:
res = [equal(s0, s1) for s0, s1 in zip(size3, [size0[-2], size1[-1]])]
fail = any([r is False for r in res])
success = all([r is True for r in res])
if fail:
raise ValueError("Output to matrix-matrix product must agree in the m and n dimensions")
elif not success:
warnings.warn(f'Size of output {size3} may not match input {size0} @ {size1}', UserWarning)


# Numpy replacement
Expand Down
15 changes: 10 additions & 5 deletions dace/libraries/blas/nodes/matmul.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
import dace
from dace import properties
from dace import properties, symbolic
from copy import deepcopy as dc
from typing import Any, Dict, Optional
from typing import Any, Dict
import warnings


Expand Down Expand Up @@ -58,8 +58,13 @@ def _get_batchmm_opts(a_shape, a_strides, b_shape, b_strides, c_shape, c_strides
batch = a_shape[0]
stride_a = a_strides[0]
if len(b_shape) == 3:
if batch and batch != b_shape[0]:
raise ValueError('Batch size mismatch for matrix multiplication')
if batch is not None:
res = symbolic.equal(batch, b_shape[0])
if res is None:
warnings.warn(f'Batch size of first tensor ({batch}) may not match second tensor ({b_shape[0]})',
UserWarning)
elif not res:
raise ValueError('Batch size mismatch for matrix multiplication')
batch = b_shape[0]
stride_b = b_strides[0]
if c_shape and len(c_shape) == 3:
Expand Down
24 changes: 24 additions & 0 deletions dace/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1337,3 +1337,27 @@ def inequal_symbols(a: Union[sympy.Expr, Any], b: Union[sympy.Expr, Any]) -> boo
# We subtract and compare to zero according to the SymPy documentation
# (https://docs.sympy.org/latest/tutorial/gotchas.html).
return (a - b).simplify() != 0


def equal(a: SymbolicType, b: SymbolicType, is_length: bool = True) -> Union[bool, None]:
"""
Compares 2 symbolic expressions and returns True if they are equal, False if they are inequal,
and None if the comparison is inconclusive.

:param a: First symbolic expression.
:param b: Second symbolic expression.
:param is_length: If True, the assumptions that a, b are integers and positive are made.
"""

args = [arg.expr if isinstance(arg, SymExpr) else arg for arg in (a, b)]

if any([args is None for args in args]):
return False

facts = []
if is_length:
for arg in args:
facts += [sympy.Q.integer(arg), sympy.Q.positive(arg)]

with sympy.assuming(*facts):
return sympy.ask(sympy.Q.is_true(sympy.Eq(*args)))
53 changes: 51 additions & 2 deletions tests/library/gemm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
M = dace.symbol('M')
K = dace.symbol('K')
N = dace.symbol('N')
L = dace.symbol('L')
O = dace.symbol('O')


@pytest.mark.parametrize(
Expand Down Expand Up @@ -171,8 +173,55 @@ def params_generator(grid):
"misconfigured, skipping test for {}.".format(implementation))


def test_gemm_symbolic():
sdfg = dace.SDFG("gemm")
state = sdfg.add_state()
A, A_arr = sdfg.add_array("A", [M, K], dace.float64)
B, B_arr = sdfg.add_array("B", [L, N], dace.float64)
C, C_arr = sdfg.add_array("C", [O, N], dace.float64)

rA = state.add_read("A")
rB = state.add_read("B")
wC = state.add_write("C")

libnode = Gemm('_Gemm_', transA=False, transB=False, alpha=1.0, beta=0.0)
state.add_node(libnode)

state.add_edge(rA, None, libnode, '_a', dace.Memlet.from_array(A, A_arr))
state.add_edge(rB, None, libnode, '_b', dace.Memlet.from_array(B, B_arr))
state.add_edge(libnode, '_c', wC, None, dace.Memlet.from_array(C, C_arr))

sdfg.validate()


def test_gemm_symbolic_1():
sdfg = dace.SDFG("gemm")
state = sdfg.add_state()
A, A_arr = sdfg.add_array("A", [M, K], dace.float64)
B, B_arr = sdfg.add_array("B", [K + 2, N], dace.float64)
C, C_arr = sdfg.add_array("C", [M, N], dace.float64)

rA = state.add_read("A")
rB = state.add_read("B")
wC = state.add_write("C")

libnode = Gemm('_Gemm_', transA=False, transB=False, alpha=1.0, beta=0.0)
state.add_node(libnode)

state.add_edge(rA, None, libnode, '_a', dace.Memlet.from_array(A, A_arr))
state.add_edge(rB, None, libnode, '_b', dace.Memlet.from_array(B, B_arr))
state.add_edge(libnode, '_c', wC, None, dace.Memlet.from_array(C, C_arr))

try:
sdfg.validate()
except dace.sdfg.InvalidSDFGError:
pass


if __name__ == "__main__":
if len(sys.argv) > 1 and sys.argv[1] == 'gpu':
test_library_gemm('cuBLAS')
test_library_gemm('pure')
test_library_gemm('MKL')
# test_library_gemm('pure')
# test_library_gemm('MKL')
test_gemm_symbolic()
test_gemm_symbolic_1()
26 changes: 24 additions & 2 deletions tests/numpy/matrix_multiplication_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
import unittest
import dace
import numpy as np

B, M, N, K = tuple(dace.symbol(k) for k in 'BMNK')
B, M, N, K, L, O = tuple(dace.symbol(k) for k in 'BMNKLO')


class MatrixMultiplication(unittest.TestCase):
Expand Down Expand Up @@ -39,6 +39,28 @@ def mmmtest(a: dace.float64[M, K], b: dace.float64[B, K, N]):
c = mmmtest(a, b)
self.assertEqual(list(c.shape), [3, 34, 31])
self.assertTrue(np.allclose(c, a @ b))

def test_mm_symbolic(self):
@dace.program
def mmtest_symbolic(a: dace.float64[M, K], b: dace.float64[O, N]):
return a @ b

a = np.random.rand(32, 33)
b = np.random.rand(33, 34)
c = mmtest_symbolic(a, b)
self.assertEqual(list(c.shape), [32, 34])
self.assertTrue(np.allclose(c, a @ b))

def test_mmm_batch_symbolic(self):
@dace.program
def mmmtest_symbolic(a: dace.float64[B, M, K], b: dace.float64[L, O, N]):
return a @ b

a = np.random.rand(3, 34, 32)
b = np.random.rand(3, 32, 31)
c = mmmtest_symbolic(a, b)
self.assertEqual(list(c.shape), [3, 34, 31])
self.assertTrue(np.allclose(c, a @ b))


if __name__ == '__main__':
Expand Down