Skip to content

Commit 0f885e5

Browse files
committed
xe: sdpa: clang-formatting
1 parent ec53bcc commit 0f885e5

File tree

3 files changed

+4
-5
lines changed

3 files changed

+4
-5
lines changed

include/oneapi/dnnl/dnnl.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,8 +1002,7 @@ dnnl_status_t DNNL_API dnnl_memory_desc_create_with_packed_encoding(
10021002
/// otherwise.
10031003
/// @sa @ref dev_guide_sparsity
10041004
dnnl_status_t DNNL_API dnnl_memory_desc_create_host_side_scalar(
1005-
dnnl_memory_desc_t *memory_desc,
1006-
dnnl_data_type_t data_type);
1005+
dnnl_memory_desc_t *memory_desc, dnnl_data_type_t data_type);
10071006

10081007
/// Creates a memory descriptor for a region inside an area
10091008
/// described by an existing memory descriptor.

src/gpu/intel/ocl/micro_sdpa.cl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,7 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
216216
#else
217217
const global SCALE_DATA_T *scale_ptr,
218218
#endif
219-
int d, int k, int q,
220-
const global KEY_ATTR_SCALES_DATA_T *K_scales,
219+
int d, int k, int q, const global KEY_ATTR_SCALES_DATA_T *K_scales,
221220
const global KEY_ATTR_ZP_DATA_T *K_zp,
222221
const global VAL_ATTR_SCALES_DATA_T *V_scales,
223222
const global VAL_ATTR_ZP_DATA_T *V_zp, const int attn_mask_type

tests/gtests/internals/test_sdpa.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1100,7 +1100,8 @@ std::pair<dnnl::reorder, memory> dequantize_prim(const engine &eng, mdt dt,
11001100
return std::make_pair(dnnl::reorder(dequantize_pd), dequantized_mem);
11011101
}
11021102

1103-
memory cpu_to_gpu(memory cpu_mem, dnnl::engine &gpu_eng, dnnl::stream gpu_strm) {
1103+
memory cpu_to_gpu(
1104+
memory cpu_mem, dnnl::engine &gpu_eng, dnnl::stream gpu_strm) {
11041105
auto gpu_md = memory::desc(cpu_mem.get_desc().get_dims(),
11051106
cpu_mem.get_desc().get_data_type(),
11061107
cpu_mem.get_desc().get_strides());

0 commit comments

Comments
 (0)