From 5edfe1a2260f1d354c79a4572b56de385fcc5591 Mon Sep 17 00:00:00 2001
From: Shintaro Iwasaki
Date: Wed, 27 Apr 2022 14:20:07 -0700
Subject: [PATCH] use shfl_sync instead of __shfl_sync

Summary:
This patch replaces CUDA-specific `__shfl_sync` used in D35758762 (https://github.com/pytorch/FBGEMM/commit/dfb36cdd5e80fa692baea688e2264b398963c88b) with `shfl_sync`, which is a wrapper that supports both NVIDIA and AMD GPUs (like D33231489 (https://github.com/pytorch/FBGEMM/commit/c6df576eb5c4c9ae6abefd0c564ed8e697969147)).

Differential Revision: D35980472

fbshipit-source-id: ae76d3c6303ddcfd345fdbb16cc9c69a5860a1f2
---
 .../codegen/embedding_backward_code_generator.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/fbgemm_gpu/codegen/embedding_backward_code_generator.py b/fbgemm_gpu/codegen/embedding_backward_code_generator.py
index e7f050274c..d43de70d91 100644
--- a/fbgemm_gpu/codegen/embedding_backward_code_generator.py
+++ b/fbgemm_gpu/codegen/embedding_backward_code_generator.py
@@ -643,8 +643,8 @@ def rowwise_adagrad_with_counter() -> None:
             l2_wd = 1.0;
         }
     }
-    freq = __shfl_sync(0xFFFFFFFF, freq, 0);
-    l2_wd = __shfl_sync(0xFFFFFFFF, l2_wd, 0);
+    freq = shfl_sync(freq, 0);
+    l2_wd = shfl_sync(l2_wd, 0);
 
     at::acc_type g_local_sum_square = 0.0;
 
@@ -703,9 +703,9 @@ def rowwise_adagrad_with_counter() -> None:
             }
         }
     }
-    multiplier = __shfl_sync(0xFFFFFFFF, multiplier, 0);
-    adjusted_multiplier = __shfl_sync(0xFFFFFFFF, adjusted_multiplier, 0);
-    exp_reg_correction = __shfl_sync(0xFFFFFFFF, exp_reg_correction, 0);
+    multiplier = shfl_sync(multiplier, 0);
+    adjusted_multiplier = shfl_sync(adjusted_multiplier, 0);
+    exp_reg_correction = shfl_sync(exp_reg_correction, 0);
     """
     split_weight_update_cpu = """
     at::acc_type g_local_sum_square = 0.0;