-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathsnn_matmul_provider.h
137 lines (129 loc) · 5.62 KB
/
snn_matmul_provider.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
/*
* Copyright Codeplay Software Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PORTDNN_INCLUDE_BACKEND_SNN_MATMUL_PROVIDER_H_
#define PORTDNN_INCLUDE_BACKEND_SNN_MATMUL_PROVIDER_H_
/**
* \file
* Contains the implementation of \ref sycldnn::backend::SNNMatmulProvider,
* which provides matmul and batch_matmul implementations using the internal
* portDNN matmul kernels.
*/
#include "portdnn/backend/backend_traits.h"
#include "portdnn/backend/internal_backend.h"
#include "portdnn/matmul/launch.h"
#include "portdnn/matmul/params.h"
namespace sycldnn {
namespace backend {
/**
* CRTP module to provide matmul and batch_matmul implementations using the
* internal portDNN kernels.
*/
template <typename Backend>
struct SNNMatmulProvider {
 private:
  /** The pointer representation required by the internal handler. */
  template <typename T>
  using internal_pointer_type =
      typename BackendTraits<Backend>::template internal_pointer_type<T>;

 public:
  /**
   * A wrapper around a call to GEMM.
   *
   * Perform the matrix multiply operation:
   * \code
   * output = lhs * rhs + beta * output
   * \endcode
   * where lhs is a [m x k] matrix, rhs is a [k x n] matrix. The `bool`
   * template parameters determine whether or not to transpose the matrices.
   * The matrices provided here are assumed to be in row-major ordering.
   *
   * \param [in] lhs Pointer to a buffer containing the LHS matrix.
   * \param [in] rhs Pointer to a buffer containing the RHS matrix.
   * \param [in,out] output Pointer to a buffer containing the output matrix.
   * \param [in] beta Scale multiplier for the output matrix.
   * \param [in] m Number of rows in the LHS matrix.
   * \param [in] k Number of columns in the LHS matrix and rows in the
   *                RHS matrix.
   * \param [in] n Number of columns in the RHS matrix.
   * \param [in] events (unnamed) Vector of dependency events. Accepted for
   *                    interface compatibility with other backends; not
   *                    forwarded to the internal launcher by this provider.
   *
   * \return A SYCL event corresponding to the matmul kernel launch.
   */
  template <bool TransposeLHS, bool TransposeRHS, typename T, typename Index>
  cl::sycl::event matmul(internal_pointer_type<const T> const lhs,
                         internal_pointer_type<const T> const rhs,
                         internal_pointer_type<T> const output, T const beta,
                         Index const m, Index const k, Index const n,
                         const std::vector<cl::sycl::event>& = {}) {
    auto& underlying_backend = static_cast<Backend&>(*this);
    internal::InternalBackend<Backend> internal_backend{underlying_backend};
    // A single matmul is expressed as a batched matmul with batch size 1.
    auto status = matmul::launch<T, TransposeLHS, TransposeRHS>(
        lhs, rhs, output, sycldnn::matmul::MatmulParams{1, m, k, n, beta},
        internal_backend);
    SNN_ASSERT(status.status == StatusCode::OK,
               "Error launching matmul kernel.");
    return status.event;
  }
  /**
   * Compute a batch of matrix multiplies.
   *
   * Perform the batched matrix multiply operation:
   * \code
   * output[i] = lhs[i] * rhs[i]
   * \endcode
   * for 0 <= i < batch, where lhs is a [batch x m x k] tensor and rhs is a
   * [batch x k x n] tensor. Each matrix is assumed to be contiguous in memory
   * and in row-major format. The `bool` template parameters determine whether
   * or not to transpose the matrices.
   *
   * Unlike \ref matmul, the previous contents of `output` are not
   * accumulated into the result (beta is fixed at zero).
   *
   * \param [in] lhs Pointer to a buffer containing the LHS matrix.
   * \param [in] rhs Pointer to a buffer containing the RHS matrix.
   * \param [in,out] output Pointer to a buffer containing the output
   *                         matrix.
   * \param [in] n_batches Number of matrices in the batch.
   * \param [in] m Number of rows in the LHS matrix.
   * \param [in] k Number of columns in the LHS matrix and rows in
   *                the RHS matrix.
   * \param [in] n Number of columns in the RHS matrix.
   * \param [in] batch_type Format indicating how the batches are laid out.
   * \param [in] events (unnamed) Vector of dependency events. Accepted for
   *                    interface compatibility with other backends; not
   *                    forwarded to the internal launcher by this provider.
   *
   * \return A SYCL event corresponding to the matmul kernel launch.
   *
   * \throws std::runtime_error if `batch_type` is anything other than
   *         sycldnn::BatchFormat::STRIDED, the only layout the internal
   *         kernels support.
   */
  template <bool TransposeLHS, bool TransposeRHS, typename T, typename Index>
  cl::sycl::event batch_matmul(
      internal_pointer_type<const T> const lhs,
      internal_pointer_type<const T> const rhs,
      internal_pointer_type<T> const output, Index const n_batches,
      Index const m, Index const k, Index const n,
      sycldnn::BatchFormat const batch_type = sycldnn::BatchFormat::STRIDED,
      const std::vector<cl::sycl::event>& = {}) {
    if (batch_type != sycldnn::BatchFormat::STRIDED) {
      throw std::runtime_error(
          "SNN batch matmul only supports strided batch format.");
    }
    auto& underlying_backend = static_cast<Backend&>(*this);
    internal::InternalBackend<Backend> internal_backend{underlying_backend};
    // beta is fixed at zero: output is overwritten, never accumulated into.
    auto status = matmul::launch<T, TransposeLHS, TransposeRHS>(
        lhs, rhs, output,
        sycldnn::matmul::MatmulParams{n_batches, m, k, n, T{0}},
        internal_backend);
    SNN_ASSERT(status.status == StatusCode::OK,
               "Error launching matmul kernel.");
    return status.event;
  }
};
} // namespace backend
} // namespace sycldnn
#endif // PORTDNN_INCLUDE_BACKEND_SNN_MATMUL_PROVIDER_H_