Skip to content

Commit

Permalink
Merge pull request #1181 from spcl/usrs/lukas/bmm
Browse files Browse the repository at this point in the history
BatchedMatMul: MKL gemm_batch support
  • Loading branch information
tbennun committed Mar 16, 2023
2 parents a46a5f9 + bd55c53 commit dd69131
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 12 deletions.
86 changes: 74 additions & 12 deletions dace/libraries/blas/nodes/batched_matmul.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.
from copy import deepcopy as dc
from dace import dtypes, memlet as mm, properties, data as dt
from typing import Any, Dict, Optional
Expand Down Expand Up @@ -84,6 +84,75 @@ class ExpandBatchedMatMulMKL(ExpandTransformation):

environments = [environments.intel_mkl.IntelMKL]

@staticmethod
def expansion(node, state, sdfg):
node.validate(sdfg, state)
(_, adesc, ashape, astrides), (_, bdesc, bshape, bstrides), _ = _get_matmul_operands(node, state, sdfg)
cdesc: dt.Array = sdfg.arrays[state.out_edges(node)[0].data.data]
check_access(dtypes.ScheduleType.CPU_Multicore, adesc, bdesc, cdesc)
dtype = cdesc.dtype.base_type
func = to_blastype(dtype.type).lower() + 'gemm'
if dtype == dace.float32:
alpha = "1.0f"
beta = "0.0f"
prefix = "s"
elif dtype == dace.float64:
alpha = "1.0"
beta = "0.0"
prefix = "d"
elif dtype == dace.complex64:
alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
prefix = "c"
elif dtype == dace.complex128:
alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
prefix = "z"
else:
raise ValueError("Unsupported type for BLAS dot product: " + str(dtype))
opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdesc.dtype.ctype, func)

opt['prefix'] = prefix
opt['dtype'] = cdesc.dtype.ctype

code = '''
const MKL_INT group_count = 1;
MKL_INT group_sizes[group_count] = {{ {BATCH} }};
MKL_INT m_array[group_count] = {{ {M} }};
MKL_INT n_array[group_count] = {{ {N} }};
MKL_INT k_array[group_count] = {{ {K} }};
char transa[group_count] = {{ '{ta}' }};
char transb[group_count] = {{ '{tb}' }};
{dtype} alpha_array[group_count] = {{ {alpha} }};
{dtype} beta_array[group_count] = {{ {beta} }};
MKL_INT lda_array[group_count] = {{ {lda} }};
MKL_INT ldb_array[group_count] = {{ {ldb} }};
MKL_INT ldc_array[group_count] = {{ {ldc} }};
const {dtype}** A = new const {dtype}*[{BATCH}];
const {dtype}** B = new const {dtype}*[{BATCH}];
{dtype}** C = new {dtype}*[{BATCH}];
for (int __ib = 0; __ib < {BATCH}; __ib++) {{
A[__ib] = (({dtype}*){x}) + __ib*{stride_a};
B[__ib] = (({dtype}*){y}) + __ib*{stride_b};
C[__ib] = (({dtype}*)_c) + __ib*{stride_c};
}}
{prefix}gemm_batch(transa, transb, m_array, n_array, k_array, alpha_array, A, lda_array, B, ldb_array, beta_array, C, ldc_array, &group_count, group_sizes);'''.format_map(
opt)

tasklet = dace.sdfg.nodes.Tasklet(node.name,
node.in_connectors,
node.out_connectors,
code,
language=dace.dtypes.Language.CPP)
return tasklet


@dace.library.expansion
class ExpandBatchedMatMulOpenBLAS(ExpandTransformation):
environments = [environments.openblas.OpenBLAS]

@staticmethod
def expansion(node, state, sdfg):
node.validate(sdfg, state)
Expand Down Expand Up @@ -129,15 +198,6 @@ def expansion(node, state, sdfg):
return tasklet


@dace.library.expansion
class ExpandBatchedMatMulOpenBLAS(ExpandTransformation):
environments = [environments.openblas.OpenBLAS]

@staticmethod
def expansion(*args, **kwargs):
return ExpandBatchedMatMulMKL.expansion(*args, **kwargs)


@dace.library.expansion
class ExpandBatchedMatMulCuBLAS(ExpandTransformation):

Expand Down Expand Up @@ -368,15 +428,17 @@ def validate(self, sdfg, state):
size1 = subset.size()
out_edges = state.out_edges(self)
if len(out_edges) != 1:
raise ValueError("Expected exactly one output from " "batched matrix-matrix product")
raise ValueError("Expected exactly one output from "
"batched matrix-matrix product")
out_memlet = out_edges[0].data
# Function is symmetric, edge order does not matter
if len(size0) not in [2, 3]:
raise ValueError("Batched matrix-matrix product only supported on matrices")
if len(size1) != 3:
raise ValueError("Batched matrix-matrix product only supported on matrices")
if size0[-1] != size1[-2]:
raise ValueError("Inputs to matrix-matrix product " "must agree in the k-dimension")
raise ValueError("Inputs to matrix-matrix product "
"must agree in the k-dimension")
out_subset = dc(out_memlet.subset)
out_subset.squeeze()
size2 = out_subset.size()
Expand Down
51 changes: 51 additions & 0 deletions tests/library/batched_matmul_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.
import pytest
import numpy as np

import dace
import dace.libraries.blas as blas

from dace.library import change_default


@pytest.mark.parametrize("implementation, dtype", [
pytest.param("pure", dace.float32),
pytest.param("pure", dace.float64),
pytest.param("MKL", dace.float32, marks=pytest.mark.mkl),
pytest.param("MKL", dace.float64, marks=pytest.mark.mkl),
pytest.param("cuBLAS", dace.float32, marks=pytest.mark.gpu),
pytest.param("cuBLAS", dace.float64, marks=pytest.mark.gpu)
])
def test_batchmm(implementation: str, dtype):
b, m, n, k = tuple(dace.symbol(k) for k in 'bmnk')

@dace.program
def bmm(A: dtype[b, m, k], B: dtype[b, k, n], C: dtype[b, m, n]):
C[:] = A @ B

with change_default(blas, implementation):
sdfg = bmm.to_sdfg()
sdfg.simplify()
sdfg.expand_library_nodes()

b, m, n, k = 3, 32, 31, 30

x = np.random.rand(b, m, k).astype(dtype.as_numpy_dtype())
y = np.random.rand(b, k, n).astype(dtype.as_numpy_dtype())
z = np.zeros([b, m, n]).astype(dtype.as_numpy_dtype())

csdfg = sdfg.compile()
csdfg(A=x, B=y, C=z, b=b, m=m, n=n, k=k)

ref = x @ y

assert np.allclose(ref, z)


if __name__ == "__main__":
test_batchmm("pure", dace.float32)
test_batchmm("pure", dace.float64)
test_batchmm("MKL", dace.float32)
test_batchmm("MKL", dace.float64)
test_batchmm("cuBLAS", dace.float32)
test_batchmm("cuBLAS", dace.float64)

0 comments on commit dd69131

Please sign in to comment.