Skip to content

Commit

Permalink
[feat] Add bfloat16 value type support to the HKV for being enhanced …
Browse files Browse the repository at this point in the history
…by Ampere GPU BF16 training feature.
  • Loading branch information
MoFHeka committed Jan 31, 2024
1 parent 35d445c commit a4b7a80
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -1048,6 +1048,7 @@ REGISTER_KERNEL(int64, int8);
REGISTER_KERNEL(int64, int32);
REGISTER_KERNEL(int64, int64);
REGISTER_KERNEL(int64, Eigen::half);
REGISTER_KERNEL(int64, Eigen::bfloat16);

#undef REGISTER_KERNEL

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ DEFINE_PURE_GPU_HASHTABLE(int64, int8);
DEFINE_PURE_GPU_HASHTABLE(int64, int32);
DEFINE_PURE_GPU_HASHTABLE(int64, int64);
DEFINE_PURE_GPU_HASHTABLE(int64, Eigen::half);
DEFINE_PURE_GPU_HASHTABLE(int64, Eigen::bfloat16);

#undef DEFINE_PURE_GPU_HASHTABLE

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def test_variable(self):
dim_list = [1, 2, 4, 8, 10, 16, 32, 64, 100, 200]
kv_list = [[dtypes.int64, dtypes.float32], [dtypes.int64, dtypes.int32],
[dtypes.int64, dtypes.half], [dtypes.int64, dtypes.int8],
[dtypes.int64, dtypes.int64]]
[dtypes.int64, dtypes.int64], [dtypes.int64, dtypes.bfloat16]]
else:
dim_list = [1, 8, 16, 128]
kv_list = [[dtypes.int32, dtypes.double], [dtypes.int32, dtypes.float32],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,7 @@ def _get_default_devices():
[dtypes.int64, dtypes.int32],
[dtypes.int64, dtypes.int64],
[dtypes.int64, dtypes.half],
[dtypes.int64, dtypes.bfloat16],
]
if is_macos() and is_arm64():
if value_dtype == dtypes.half or value_dtype == dtypes.bfloat16:
Expand Down

0 comments on commit a4b7a80

Please sign in to comment.