diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp
index b8e3dc8e21b..6fe315ba8ee 100644
--- a/backends/cuda/runtime/shims/memory.cpp
+++ b/backends/cuda/runtime/shims/memory.cpp
@@ -225,7 +225,7 @@ AOTITorchError aoti_torch_empty_strided(
   if (device_type == static_cast<int32_t>(SupportedDevices::CUDA)) {
     ET_CUDA_CHECK_OR_RETURN_ERROR(
-        cudaMallocManaged(&ptr, static_cast<size_t>(nbytes)));
+        cudaMallocAsync(&ptr, static_cast<size_t>(nbytes), cudaStreamDefault));
   } else if (device_type == static_cast<int32_t>(SupportedDevices::CPU)) {
     // Ensure 16-byte alignment for CPU memory to match CUDA requirements
     int result = posix_memalign(&ptr, 16, nbytes);
@@ -328,11 +328,14 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
     ET_CUDA_CHECK_OR_RETURN_ERROR(
         cudaPointerGetAttributes(&attributes, data_ptr));

-    if (attributes.type == cudaMemoryTypeManaged) {
-      // This is CUDA managed memory - free with proper synchronization
-      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaDeviceSynchronize());
-      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaFree(data_ptr));
+    if (attributes.type == cudaMemoryTypeDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(
+          cudaFreeAsync(data_ptr, cudaStreamDefault));
     } else {
+      ET_CHECK_OR_RETURN_ERROR(
+          attributes.type != cudaMemoryTypeManaged,
+          Internal,
+          "Expected host memory but got managed!");
       // This is CPU memory - free immediately
       free(data_ptr);
       data_ptr = nullptr;
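
For context, below is a minimal standalone sketch (not part of the patch) of the stream-ordered allocation pattern the diff moves to: cudaMallocAsync / cudaFreeAsync on the default stream, plus the cudaPointerGetAttributes check the deletion path uses to tell device memory apart from host memory. It assumes CUDA 11.2+ (where the stream-ordered allocator was introduced); the CHECK_CUDA macro is a hypothetical stand-in for the ET_CUDA_CHECK_OR_RETURN_ERROR shim macro.

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Hypothetical stand-in for ET_CUDA_CHECK_OR_RETURN_ERROR: abort on failure.
#define CHECK_CUDA(expr)                                            \
  do {                                                              \
    cudaError_t err__ = (expr);                                     \
    if (err__ != cudaSuccess) {                                     \
      std::fprintf(stderr, "CUDA error %s at %s:%d\n",              \
                   cudaGetErrorString(err__), __FILE__, __LINE__);  \
      std::exit(EXIT_FAILURE);                                      \
    }                                                               \
  } while (0)

int main() {
  void* ptr = nullptr;
  size_t nbytes = 1024;

  // Stream-ordered allocation: the pointer becomes usable in stream order,
  // so work enqueued on the same stream after this call may touch it.
  CHECK_CUDA(cudaMallocAsync(&ptr, nbytes, cudaStreamDefault));

  // Classify the pointer the same way the deletion path in the diff does.
  cudaPointerAttributes attributes{};
  CHECK_CUDA(cudaPointerGetAttributes(&attributes, ptr));

  if (attributes.type == cudaMemoryTypeDevice) {
    // Stream-ordered free: deferred until prior work on the stream finishes,
    // so no blocking cudaDeviceSynchronize() is needed, unlike the
    // managed-memory path the diff removes.
    CHECK_CUDA(cudaFreeAsync(ptr, cudaStreamDefault));
  } else {
    // Host allocations (e.g. from posix_memalign) are freed immediately.
    free(ptr);
  }

  CHECK_CUDA(cudaStreamSynchronize(cudaStreamDefault));
  std::printf("done\n");
  return 0;
}

The design point this illustrates: because cudaFreeAsync is ordered on the stream, the allocator can reclaim memory without the whole-device stall that cudaDeviceSynchronize() imposed on the old managed-memory path, and the explicit cudaMemoryTypeManaged check in the patch turns any leftover managed allocation into a hard error instead of a silent mis-free.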