diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_common.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_common.cuh
index 339d26ccba..6987f1dc9b 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_common.cuh
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_common.cuh
@@ -274,7 +274,8 @@ at::Tensor f8f8bf16_rowwise_batched_impl(
   size_t workspace_size = Gemm::get_workspace_size(arguments);
 
   // Allocate workspace memory
-  cutlass::device_memory::allocation workspace(workspace_size);
+  at::Tensor workspace =
+      at::empty(workspace_size, XQ.options().dtype(at::kByte));
 
   // Check the problem size is supported or not
   cutlass::Status status = gemm.can_implement(arguments);
@@ -283,7 +284,7 @@ at::Tensor f8f8bf16_rowwise_batched_impl(
   }
 
   // Initialize CUTLASS kernel with arguments and workspace pointer
-  status = gemm.initialize(arguments, workspace.get());
+  status = gemm.initialize(arguments, workspace.data_ptr());
   if (status != cutlass::Status::kSuccess) {
     throw std::runtime_error("cutlass cannot initialize");
   }