From 1d3ed6e66cf9ca697a0b3f53d106595f68d68ac6 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Thu, 6 Nov 2025 12:19:12 -0800 Subject: [PATCH] Fix workspace allocation for f8f8bf16_rowwise_batched (#5098) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2105 X-link: https://github.com/meta-pytorch/MSLK/pull/6 This diff updates the workspace allocation for f8f8bf16_rowwise_batched to make sure it's on the proper device. Previously, it could default to using device 0 despite other inputs being on a different GPU. Reviewed By: q10 Differential Revision: D86439655 --- .../f8f8bf16_rowwise_batched_common.cuh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_common.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_common.cuh index 339d26ccba..6987f1dc9b 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_common.cuh +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_common.cuh @@ -274,7 +274,8 @@ at::Tensor f8f8bf16_rowwise_batched_impl( size_t workspace_size = Gemm::get_workspace_size(arguments); // Allocate workspace memory - cutlass::device_memory::allocation workspace(workspace_size); + at::Tensor workspace = + at::empty(workspace_size, XQ.options().dtype(at::kByte)); // Check the problem size is supported or not cutlass::Status status = gemm.can_implement(arguments); @@ -283,7 +284,7 @@ at::Tensor f8f8bf16_rowwise_batched_impl( } // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm.initialize(arguments, workspace.get()); + status = gemm.initialize(arguments, workspace.data_ptr()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot initialize"); }