diff --git a/experimental/cuda2/CMakeLists.txt b/experimental/cuda2/CMakeLists.txt index ca1959516917..d06868660ac4 100644 --- a/experimental/cuda2/CMakeLists.txt +++ b/experimental/cuda2/CMakeLists.txt @@ -21,15 +21,20 @@ iree_cc_library( "cuda_allocator.h" "cuda_buffer.c" "cuda_buffer.h" + "cuda_device.c" + "cuda_device.h" "cuda_driver.c" "memory_pools.c" "memory_pools.h" DEPS ::dynamic_symbols iree::base + iree::base::internal + iree::base::internal::arena iree::base::core_headers iree::base::tracing iree::hal + iree::hal::utils::buffer_transfer iree::schemas::cuda_executable_def_c_fbs PUBLIC ) diff --git a/experimental/cuda2/api.h b/experimental/cuda2/api.h index a56b656b2b7c..777d6ef238d1 100644 --- a/experimental/cuda2/api.h +++ b/experimental/cuda2/api.h @@ -17,7 +17,7 @@ extern "C" { #endif // __cplusplus //===----------------------------------------------------------------------===// -// iree_hal_cuda_device_t +// iree_hal_cuda2_device_t //===----------------------------------------------------------------------===// // Parameters defining a CUmemoryPool. @@ -40,6 +40,32 @@ typedef struct iree_hal_cuda2_memory_pooling_params_t { iree_hal_cuda2_memory_pool_params_t other; } iree_hal_cuda2_memory_pooling_params_t; +// Parameters configuring an iree_hal_cuda2_device_t. +// Must be initialized with iree_hal_cuda2_device_params_initialize prior to +// use. +typedef struct iree_hal_cuda2_device_params_t { + // Number of queues exposed on the device. + // Each queue acts as a separate synchronization scope where all work executes + // concurrently unless prohibited by semaphores. + iree_host_size_t queue_count; + + // Total size of each block in the device shared block pool. + // Larger sizes will lower overhead and ensure the heap isn't hit for + // transient allocations while also increasing memory consumption. + iree_host_size_t arena_block_size; + + // Whether to use async allocations even if reported as available by the + // device. Defaults to true when the device supports it. + bool async_allocations; + + // Parameters for each CUmemoryPool used for queue-ordered allocations. + iree_hal_cuda2_memory_pooling_params_t memory_pools; +} iree_hal_cuda2_device_params_t; + +// Initializes |out_params| to default values. +IREE_API_EXPORT void iree_hal_cuda2_device_params_initialize( + iree_hal_cuda2_device_params_t* out_params); + //===----------------------------------------------------------------------===// // iree_hal_cuda2_driver_t //===----------------------------------------------------------------------===// @@ -62,6 +88,7 @@ IREE_API_EXPORT void iree_hal_cuda2_driver_options_initialize( IREE_API_EXPORT iree_status_t iree_hal_cuda2_driver_create( iree_string_view_t identifier, const iree_hal_cuda2_driver_options_t* options, + const iree_hal_cuda2_device_params_t* default_params, iree_allocator_t host_allocator, iree_hal_driver_t** out_driver); #ifdef __cplusplus diff --git a/experimental/cuda2/cts/CMakeLists.txt b/experimental/cuda2/cts/CMakeLists.txt new file mode 100644 index 000000000000..7c04876827ee --- /dev/null +++ b/experimental/cuda2/cts/CMakeLists.txt @@ -0,0 +1,24 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_hal_cts_test_suite( + DRIVER_NAME + cuda2 + DRIVER_REGISTRATION_HDR + "experimental/cuda2/registration/driver_module.h" + DRIVER_REGISTRATION_FN + "iree_hal_cuda2_driver_module_register" + COMPILER_TARGET_BACKEND + "cuda" + EXECUTABLE_FORMAT + "\"PTXE\"" + DEPS + iree::experimental::cuda2::registration + INCLUDED_TESTS + "allocator" + "buffer_mapping" + "driver" +) diff --git a/experimental/cuda2/cuda_device.c b/experimental/cuda2/cuda_device.c new file mode 100644 index 000000000000..9ea9ddce0e3f --- /dev/null +++ b/experimental/cuda2/cuda_device.c @@ -0,0 +1,481 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "experimental/cuda2/cuda_device.h" + +#include +#include +#include +#include + +#include "experimental/cuda2/cuda_allocator.h" +#include "experimental/cuda2/cuda_buffer.h" +#include "experimental/cuda2/cuda_dynamic_symbols.h" +#include "experimental/cuda2/cuda_status_util.h" +#include "experimental/cuda2/memory_pools.h" +#include "iree/base/internal/arena.h" +#include "iree/base/internal/math.h" +#include "iree/base/tracing.h" +#include "iree/hal/utils/buffer_transfer.h" +#include "iree/hal/utils/deferred_command_buffer.h" + +//===----------------------------------------------------------------------===// +// iree_hal_cuda2_device_t +//===----------------------------------------------------------------------===// + +typedef struct iree_hal_cuda2_device_t { + // Abstract resource used for injecting reference counting and vtable; + // must be at offset 0. + iree_hal_resource_t resource; + iree_string_view_t identifier; + + // Block pool used for command buffers with a larger block size (as command + // buffers can contain inlined data uploads). + iree_arena_block_pool_t block_pool; + + // Optional driver that owns the CUDA symbols. We retain it for our lifetime + // to ensure the symbols remain valid. + iree_hal_driver_t* driver; + + const iree_hal_cuda2_dynamic_symbols_t* cuda_symbols; + + // Parameters used to control device behavior. + iree_hal_cuda2_device_params_t params; + + CUcontext cu_context; + CUdevice cu_device; + // TODO: support multiple streams. + CUstream cu_stream; + + iree_allocator_t host_allocator; + + // Device memory pools and allocators.
+ bool supports_memory_pools; + iree_hal_cuda2_memory_pools_t memory_pools; + iree_hal_allocator_t* device_allocator; +} iree_hal_cuda2_device_t; + +static const iree_hal_device_vtable_t iree_hal_cuda2_device_vtable; + +static iree_hal_cuda2_device_t* iree_hal_cuda2_device_cast( + iree_hal_device_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda2_device_vtable); + return (iree_hal_cuda2_device_t*)base_value; +} + +static iree_hal_cuda2_device_t* iree_hal_cuda2_device_cast_unsafe( + iree_hal_device_t* base_value) { + return (iree_hal_cuda2_device_t*)base_value; +} + +IREE_API_EXPORT void iree_hal_cuda2_device_params_initialize( + iree_hal_cuda2_device_params_t* out_params) { + memset(out_params, 0, sizeof(*out_params)); + out_params->arena_block_size = 32 * 1024; + out_params->queue_count = 1; + out_params->async_allocations = true; +} + +static iree_status_t iree_hal_cuda2_device_check_params( + const iree_hal_cuda2_device_params_t* params) { + if (params->arena_block_size < 4096) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "arena block size too small (< 4096 bytes)"); + } + if (params->queue_count == 0) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "at least one queue is required"); + } + return iree_ok_status(); +} + +static iree_status_t iree_hal_cuda2_device_create_internal( + iree_hal_driver_t* driver, iree_string_view_t identifier, + const iree_hal_cuda2_device_params_t* params, CUdevice cu_device, + CUstream stream, CUcontext context, + const iree_hal_cuda2_dynamic_symbols_t* symbols, + iree_allocator_t host_allocator, iree_hal_device_t** out_device) { + iree_hal_cuda2_device_t* device = NULL; + iree_host_size_t total_size = iree_sizeof_struct(*device) + identifier.size; + IREE_RETURN_IF_ERROR( + iree_allocator_malloc(host_allocator, total_size, (void**)&device)); + memset(device, 0, total_size); + + iree_hal_resource_initialize(&iree_hal_cuda2_device_vtable, + &device->resource); + iree_string_view_append_to_buffer( + identifier, &device->identifier, + (char*)device + iree_sizeof_struct(*device)); + iree_arena_block_pool_initialize(params->arena_block_size, host_allocator, + &device->block_pool); + device->driver = driver; + iree_hal_driver_retain(device->driver); + device->cuda_symbols = symbols; + device->params = *params; + device->cu_context = context; + device->cu_device = cu_device; + device->cu_stream = stream; + device->host_allocator = host_allocator; + + iree_status_t status = iree_ok_status(); + + // Memory pool support is conditional. + if (iree_status_is_ok(status) && params->async_allocations) { + int supports_memory_pools = 0; + status = IREE_CURESULT_TO_STATUS( + symbols, + cuDeviceGetAttribute(&supports_memory_pools, + CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, + cu_device), + "cuDeviceGetAttribute"); + device->supports_memory_pools = supports_memory_pools != 0; + } + + // Create memory pools first so that we can share them with the allocator. + if (iree_status_is_ok(status) && device->supports_memory_pools) { + status = iree_hal_cuda2_memory_pools_initialize( + symbols, cu_device, ¶ms->memory_pools, host_allocator, + &device->memory_pools); + } + + if (iree_status_is_ok(status)) { + status = iree_hal_cuda2_allocator_create( + (iree_hal_device_t*)device, symbols, cu_device, stream, + device->supports_memory_pools ? 
&device->memory_pools : NULL, + host_allocator, &device->device_allocator); + } + + if (iree_status_is_ok(status)) { + *out_device = (iree_hal_device_t*)device; + } else { + iree_hal_device_release((iree_hal_device_t*)device); + } + return status; +} + +iree_status_t iree_hal_cuda2_device_create( + iree_hal_driver_t* driver, iree_string_view_t identifier, + const iree_hal_cuda2_device_params_t* params, + const iree_hal_cuda2_dynamic_symbols_t* symbols, CUdevice device, + iree_allocator_t host_allocator, iree_hal_device_t** out_device) { + IREE_ASSERT_ARGUMENT(driver); + IREE_ASSERT_ARGUMENT(params); + IREE_ASSERT_ARGUMENT(symbols); + IREE_ASSERT_ARGUMENT(out_device); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_t status = iree_hal_cuda2_device_check_params(params); + + // Get the main context for the device. + CUcontext context = NULL; + if (iree_status_is_ok(status)) { + status = IREE_CURESULT_TO_STATUS( + symbols, cuDevicePrimaryCtxRetain(&context, device)); + } + if (iree_status_is_ok(status)) { + status = IREE_CURESULT_TO_STATUS(symbols, cuCtxSetCurrent(context)); + } + + // Create the default stream for the device. + CUstream stream = NULL; + if (iree_status_is_ok(status)) { + status = IREE_CURESULT_TO_STATUS( + symbols, cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING)); + } + + if (iree_status_is_ok(status)) { + status = iree_hal_cuda2_device_create_internal( + driver, identifier, params, device, stream, context, symbols, + host_allocator, out_device); + } + if (!iree_status_is_ok(status)) { + if (stream) symbols->cuStreamDestroy(stream); + if (context) symbols->cuDevicePrimaryCtxRelease(device); + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +CUcontext iree_hal_cuda2_device_context(iree_hal_device_t* base_device) { + iree_hal_cuda2_device_t* device = + iree_hal_cuda2_device_cast_unsafe(base_device); + return device->cu_context; +} + +const iree_hal_cuda2_dynamic_symbols_t* iree_hal_cuda2_device_dynamic_symbols( + iree_hal_device_t* base_device) { + iree_hal_cuda2_device_t* device = + iree_hal_cuda2_device_cast_unsafe(base_device); + return device->cuda_symbols; +} + +static void iree_hal_cuda2_device_destroy(iree_hal_device_t* base_device) { + iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); + iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device); + IREE_TRACE_ZONE_BEGIN(z0); + + // There should be no more buffers live that use the allocator. + iree_hal_allocator_release(device->device_allocator); + + // Destroy memory pools that hold on to reserved memory. + iree_hal_cuda2_memory_pools_deinitialize(&device->memory_pools); + + // TODO: support multiple streams. + IREE_CUDA_IGNORE_ERROR(device->cuda_symbols, + cuStreamDestroy(device->cu_stream)); + + IREE_CUDA_IGNORE_ERROR(device->cuda_symbols, + cuDevicePrimaryCtxRelease(device->cu_device)); + + iree_arena_block_pool_deinitialize(&device->block_pool); + + // Finally, destroy the device. 
+ iree_hal_driver_release(device->driver); + + iree_allocator_free(host_allocator, device); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_string_view_t iree_hal_cuda2_device_id( + iree_hal_device_t* base_device) { + iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); + return device->identifier; +} + +static iree_allocator_t iree_hal_cuda2_device_host_allocator( + iree_hal_device_t* base_device) { + iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); + return device->host_allocator; +} + +static iree_hal_allocator_t* iree_hal_cuda2_device_allocator( + iree_hal_device_t* base_device) { + iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); + return device->device_allocator; +} + +static void iree_hal_cuda2_replace_device_allocator( + iree_hal_device_t* base_device, iree_hal_allocator_t* new_allocator) { + iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); + iree_hal_allocator_retain(new_allocator); + iree_hal_allocator_release(device->device_allocator); + device->device_allocator = new_allocator; +} + +static void iree_hal_cuda2_replace_channel_provider( + iree_hal_device_t* base_device, iree_hal_channel_provider_t* new_provider) { + // TODO: implement this together with channel support. +} + +static iree_status_t iree_hal_cuda2_device_trim( + iree_hal_device_t* base_device) { + iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); + iree_arena_block_pool_trim(&device->block_pool); + IREE_RETURN_IF_ERROR(iree_hal_allocator_trim(device->device_allocator)); + if (device->supports_memory_pools) { + IREE_RETURN_IF_ERROR(iree_hal_cuda2_memory_pools_trim( + &device->memory_pools, &device->params.memory_pools)); + } + return iree_ok_status(); +} + +static iree_status_t iree_hal_cuda2_device_query_attribute( + iree_hal_cuda2_device_t* device, CUdevice_attribute attribute, + int64_t* out_value) { + int value = 0; + IREE_CUDA_RETURN_IF_ERROR( + device->cuda_symbols, + cuDeviceGetAttribute(&value, attribute, device->cu_device), + "cuDeviceGetAttribute"); + *out_value = value; + return iree_ok_status(); +} + +static iree_status_t iree_hal_cuda2_device_query_i64( + iree_hal_device_t* base_device, iree_string_view_t category, + iree_string_view_t key, int64_t* out_value) { + iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); + *out_value = 0; + + if (iree_string_view_equal(category, IREE_SV("hal.executable.format"))) { + *out_value = iree_string_view_equal(key, IREE_SV("cuda-nvptx-fb")) ? 
1 : 0; + return iree_ok_status(); + } + + if (iree_string_view_equal(category, IREE_SV("cuda.device"))) { + if (iree_string_view_equal(key, IREE_SV("compute_capability_major"))) { + return iree_hal_cuda2_device_query_attribute( + device, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, out_value); + } else if (iree_string_view_equal(key, + IREE_SV("compute_capability_minor"))) { + return iree_hal_cuda2_device_query_attribute( + device, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, out_value); + } + } + + return iree_make_status( + IREE_STATUS_NOT_FOUND, + "unknown device configuration key value '%.*s :: %.*s'", + (int)category.size, category.data, (int)key.size, key.data); +} + +static iree_status_t iree_hal_cuda2_device_create_channel( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "channel not yet implemented"); +} + +static iree_status_t iree_hal_cuda2_device_create_command_buffer( + iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity, + iree_hal_command_buffer_t** out_command_buffer) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "command buffer not yet implemented"); +} + +static iree_status_t iree_hal_cuda2_device_create_descriptor_set_layout( + iree_hal_device_t* base_device, + iree_hal_descriptor_set_layout_flags_t flags, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_layout_binding_t* bindings, + iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "descriptor set layout not yet implemented"); +} + +static iree_status_t iree_hal_cuda2_device_create_event( + iree_hal_device_t* base_device, iree_hal_event_t** out_event) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "event not yet implemented"); +} + +static iree_status_t iree_hal_cuda2_device_create_executable_cache( + iree_hal_device_t* base_device, iree_string_view_t identifier, + iree_loop_t loop, iree_hal_executable_cache_t** out_executable_cache) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "executable cache not yet implemented"); +} + +static iree_status_t iree_hal_cuda2_device_create_pipeline_layout( + iree_hal_device_t* base_device, iree_host_size_t push_constants, + iree_host_size_t set_layout_count, + iree_hal_descriptor_set_layout_t* const* set_layouts, + iree_hal_pipeline_layout_t** out_pipeline_layout) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "pipeline layout not yet implemented"); +} + +static iree_status_t iree_hal_cuda2_device_create_semaphore( + iree_hal_device_t* base_device, uint64_t initial_value, + iree_hal_semaphore_t** out_semaphore) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "semaphore not yet implemented"); +} + +static iree_hal_semaphore_compatibility_t +iree_hal_cuda2_device_query_semaphore_compatibility( + iree_hal_device_t* base_device, iree_hal_semaphore_t* semaphore) { + // TODO: implement CUDA semaphores. + return IREE_HAL_SEMAPHORE_COMPATIBILITY_NONE; +} + +// TODO: implement multiple streams; today we only have one and queue_affinity +// is ignored. +// TODO: implement proper semaphores in CUDA to ensure ordering and avoid +// the barrier here.
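The `hal.executable.format` and `cuda.device` keys handled by iree_hal_cuda2_device_query_i64 above are reachable through the public HAL API. A minimal caller-side sketch, assuming only that `device` was created by this driver (errors ignored for brevity):

    int64_t cc_major = 0;
    int64_t cc_minor = 0;
    IREE_IGNORE_ERROR(iree_hal_device_query_i64(
        device, IREE_SV("cuda.device"), IREE_SV("compute_capability_major"),
        &cc_major));
    IREE_IGNORE_ERROR(iree_hal_device_query_i64(
        device, IREE_SV("cuda.device"), IREE_SV("compute_capability_minor"),
        &cc_minor));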
+static iree_status_t iree_hal_cuda2_device_queue_alloca( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_allocator_pool_t pool, iree_hal_buffer_params_t params, + iree_device_size_t allocation_size, + iree_hal_buffer_t** IREE_RESTRICT out_buffer) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "queue alloca not yet implemented"); +} + +// TODO: implement multiple streams; today we only have one and queue_affinity +// is ignored. +// TODO: implement proper semaphores in CUDA to ensure ordering and avoid +// the barrier here. +static iree_status_t iree_hal_cuda2_device_queue_dealloca( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_buffer_t* buffer) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "queue dealloca not yet implemented"); +} + +static iree_status_t iree_hal_cuda2_device_queue_execute( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_host_size_t command_buffer_count, + iree_hal_command_buffer_t* const* command_buffers) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "queue execution not yet implemented"); +} + +static iree_status_t iree_hal_cuda2_device_queue_flush( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity) { + // Currently unused; we flush as submissions are made. + return iree_ok_status(); +} + +static iree_status_t iree_hal_cuda2_device_wait_semaphores( + iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode, + const iree_hal_semaphore_list_t semaphore_list, iree_timeout_t timeout) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "semaphore not yet implemented"); +} + +static iree_status_t iree_hal_cuda2_device_profiling_begin( + iree_hal_device_t* base_device, + const iree_hal_device_profiling_options_t* options) { + // Unimplemented (and that's ok). + // We could hook into CUPTI here or use the much simpler cuProfilerStart API. + return iree_ok_status(); +} + +static iree_status_t iree_hal_cuda2_device_profiling_end( + iree_hal_device_t* base_device) { + // Unimplemented (and that's ok).
+ return iree_ok_status(); +} + +static const iree_hal_device_vtable_t iree_hal_cuda2_device_vtable = { + .destroy = iree_hal_cuda2_device_destroy, + .id = iree_hal_cuda2_device_id, + .host_allocator = iree_hal_cuda2_device_host_allocator, + .device_allocator = iree_hal_cuda2_device_allocator, + .replace_device_allocator = iree_hal_cuda2_replace_device_allocator, + .replace_channel_provider = iree_hal_cuda2_replace_channel_provider, + .trim = iree_hal_cuda2_device_trim, + .query_i64 = iree_hal_cuda2_device_query_i64, + .create_channel = iree_hal_cuda2_device_create_channel, + .create_command_buffer = iree_hal_cuda2_device_create_command_buffer, + .create_descriptor_set_layout = + iree_hal_cuda2_device_create_descriptor_set_layout, + .create_event = iree_hal_cuda2_device_create_event, + .create_executable_cache = iree_hal_cuda2_device_create_executable_cache, + .create_pipeline_layout = iree_hal_cuda2_device_create_pipeline_layout, + .create_semaphore = iree_hal_cuda2_device_create_semaphore, + .query_semaphore_compatibility = + iree_hal_cuda2_device_query_semaphore_compatibility, + .transfer_range = iree_hal_device_submit_transfer_range_and_wait, + .queue_alloca = iree_hal_cuda2_device_queue_alloca, + .queue_dealloca = iree_hal_cuda2_device_queue_dealloca, + .queue_execute = iree_hal_cuda2_device_queue_execute, + .queue_flush = iree_hal_cuda2_device_queue_flush, + .wait_semaphores = iree_hal_cuda2_device_wait_semaphores, + .profiling_begin = iree_hal_cuda2_device_profiling_begin, + .profiling_end = iree_hal_cuda2_device_profiling_end, +}; diff --git a/experimental/cuda2/cuda_device.h b/experimental/cuda2/cuda_device.h new file mode 100644 index 000000000000..70dbd3694cd4 --- /dev/null +++ b/experimental/cuda2/cuda_device.h @@ -0,0 +1,48 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef EXPERIMENTAL_CUDA2_CUDA_DEVICE_H_ +#define EXPERIMENTAL_CUDA2_CUDA_DEVICE_H_ + +#include "experimental/cuda2/api.h" +#include "experimental/cuda2/cuda_dynamic_symbols.h" +#include "iree/base/api.h" +#include "iree/hal/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates a device that owns and manages its own CUcontext. +iree_status_t iree_hal_cuda2_device_create( + iree_hal_driver_t* driver, iree_string_view_t identifier, + const iree_hal_cuda2_device_params_t* params, + const iree_hal_cuda2_dynamic_symbols_t* symbols, CUdevice device, + iree_allocator_t host_allocator, iree_hal_device_t** out_device); + +// Returns the CUDA context bound to the given |device| if it is a CUDA device +// and otherwise returns NULL. +// +// WARNING: this API is unsafe and unstable. HAL devices may have any number of +// contexts and the context may be in use on other threads. +CUcontext iree_hal_cuda2_device_context(iree_hal_device_t* device); + +// Returns the dynamic symbol table from the |device| if it is a CUDA device +// and otherwise returns NULL. +// +// WARNING: the symbols are only valid for as long as the device is. Hosting +// libraries and applications should prefer to either link against CUDA +// themselves or maintain their own dynamic linking support: the IREE runtime +// only provides the symbols required by the HAL driver and not the entirety of +// the API. 
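// For illustration, a minimal sketch (assuming |device| was created by this
// driver) of how a hosting library might bind the device context on its own
// thread using both accessors:
//
//   CUcontext ctx = iree_hal_cuda2_device_context(device);
//   const iree_hal_cuda2_dynamic_symbols_t* syms =
//       iree_hal_cuda2_device_dynamic_symbols(device);
//   if (ctx && syms) syms->cuCtxSetCurrent(ctx);
//
// Both accessors return NULL when |device| is not a CUDA device, so callers
// should check before use.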
+const iree_hal_cuda2_dynamic_symbols_t* iree_hal_cuda2_device_dynamic_symbols( + iree_hal_device_t* device); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // EXPERIMENTAL_CUDA2_CUDA_DEVICE_H_ diff --git a/experimental/cuda2/cuda_driver.c b/experimental/cuda2/cuda_driver.c index 8d1590f1f403..8fee8b9af881 100644 --- a/experimental/cuda2/cuda_driver.c +++ b/experimental/cuda2/cuda_driver.c @@ -8,6 +8,7 @@ #include #include "experimental/cuda2/api.h" +#include "experimental/cuda2/cuda_device.h" #include "experimental/cuda2/cuda_dynamic_symbols.h" #include "experimental/cuda2/cuda_status_util.h" #include "experimental/cuda2/nccl_dynamic_symbols.h" @@ -38,6 +39,9 @@ typedef struct iree_hal_cuda2_driver_t { // NCCL API dynamic symbols to interact with the CUDA system. iree_hal_cuda2_nccl_dynamic_symbols_t nccl_symbols; + // The default parameters for creating devices using this driver. + iree_hal_cuda2_device_params_t device_params; + // The index of the default CUDA device to use if multiple ones are available. int default_device_index; } iree_hal_cuda2_driver_t; @@ -60,6 +64,7 @@ IREE_API_EXPORT void iree_hal_cuda2_driver_options_initialize( static iree_status_t iree_hal_cuda2_driver_create_internal( iree_string_view_t identifier, const iree_hal_cuda2_driver_options_t* options, + const iree_hal_cuda2_device_params_t* device_params, iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { iree_hal_cuda2_driver_t* driver = NULL; iree_host_size_t total_size = iree_sizeof_struct(*driver) + identifier.size; @@ -86,6 +91,8 @@ static iree_status_t iree_hal_cuda2_driver_create_internal( if (iree_status_is_unavailable(status)) status = iree_status_ignore(status); } + memcpy(&driver->device_params, device_params, sizeof(driver->device_params)); + if (iree_status_is_ok(status)) { *out_driver = (iree_hal_driver_t*)driver; } else { @@ -97,13 +104,15 @@ static iree_status_t iree_hal_cuda2_driver_create_internal( IREE_API_EXPORT iree_status_t iree_hal_cuda2_driver_create( iree_string_view_t identifier, const iree_hal_cuda2_driver_options_t* options, + const iree_hal_cuda2_device_params_t* device_params, iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { IREE_ASSERT_ARGUMENT(options); + IREE_ASSERT_ARGUMENT(device_params); IREE_ASSERT_ARGUMENT(out_driver); IREE_TRACE_ZONE_BEGIN(z0); iree_status_t status = iree_hal_cuda2_driver_create_internal( - identifier, options, host_allocator, out_driver); + identifier, options, device_params, host_allocator, out_driver); IREE_TRACE_ZONE_END(z0); return status; @@ -438,7 +447,6 @@ static iree_status_t iree_hal_cuda2_driver_create_device_by_id( iree_host_size_t param_count, const iree_string_pair_t* params, iree_allocator_t host_allocator, iree_hal_device_t** out_device) { IREE_ASSERT_ARGUMENT(base_driver); - IREE_ASSERT_ARGUMENT(params); IREE_ASSERT_ARGUMENT(out_device); iree_hal_cuda2_driver_t* driver = iree_hal_cuda2_driver_cast(base_driver); @@ -458,10 +466,16 @@ static iree_status_t iree_hal_cuda2_driver_create_device_by_id( } else { device = IREE_DEVICE_ID_TO_CUDEVICE(device_id); } - (void)device; + + iree_string_view_t device_name = iree_make_cstring_view("cuda2"); + + // Attempt to create the device now. 
+ iree_status_t status = iree_hal_cuda2_device_create( + base_driver, device_name, &driver->device_params, &driver->cuda_symbols, + device, host_allocator, out_device); IREE_TRACE_ZONE_END(z0); - return iree_status_from_code(IREE_STATUS_UNIMPLEMENTED); + return status; } static iree_status_t iree_hal_cuda2_driver_create_device_by_uuid( @@ -560,7 +574,6 @@ static iree_status_t iree_hal_cuda2_driver_create_device_by_path( const iree_string_pair_t* params, iree_allocator_t host_allocator, iree_hal_device_t** out_device) { IREE_ASSERT_ARGUMENT(base_driver); - IREE_ASSERT_ARGUMENT(params); IREE_ASSERT_ARGUMENT(out_device); if (iree_string_view_is_empty(device_path)) { diff --git a/experimental/cuda2/memory_pools.c b/experimental/cuda2/memory_pools.c index b7b7baad9e59..c20c9ebe69e4 100644 --- a/experimental/cuda2/memory_pools.c +++ b/experimental/cuda2/memory_pools.c @@ -59,9 +59,9 @@ static iree_status_t iree_hal_cuda2_create_memory_pool( } iree_status_t iree_hal_cuda2_memory_pools_initialize( - iree_allocator_t host_allocator, const iree_hal_cuda2_dynamic_symbols_t* cuda_symbols, CUdevice cu_device, const iree_hal_cuda2_memory_pooling_params_t* pooling_params, + iree_allocator_t host_allocator, iree_hal_cuda2_memory_pools_t* IREE_RESTRICT out_pools) { IREE_ASSERT_ARGUMENT(cuda_symbols); IREE_ASSERT_ARGUMENT(pooling_params); diff --git a/experimental/cuda2/memory_pools.h b/experimental/cuda2/memory_pools.h index e88e05fd63d8..9c1e59b0c0d0 100644 --- a/experimental/cuda2/memory_pools.h +++ b/experimental/cuda2/memory_pools.h @@ -38,9 +38,9 @@ typedef struct iree_hal_cuda2_memory_pools_t { // Initializes |out_pools| by configuring new CUDA memory pools. iree_status_t iree_hal_cuda2_memory_pools_initialize( - iree_allocator_t host_allocator, const iree_hal_cuda2_dynamic_symbols_t* cuda_symbols, CUdevice cu_device, const iree_hal_cuda2_memory_pooling_params_t* pooling_params, + iree_allocator_t host_allocator, iree_hal_cuda2_memory_pools_t* IREE_RESTRICT out_pools); // Deinitializes the |pools| and releases the underlying CUDA resources. diff --git a/experimental/cuda2/registration/driver_module.c b/experimental/cuda2/registration/driver_module.c index 4089ed0d1165..b805bd5ae323 100644 --- a/experimental/cuda2/registration/driver_module.c +++ b/experimental/cuda2/registration/driver_module.c @@ -15,6 +15,10 @@ #include "iree/base/status.h" #include "iree/base/tracing.h" +IREE_FLAG( + bool, cuda_async_allocations, true, + "Enables CUDA asynchronous stream-ordered allocations when supported."); + IREE_FLAG(int32_t, cuda2_default_index, 0, "Specifies the index of the default CUDA device to use"); @@ -80,6 +84,10 @@ static iree_status_t iree_hal_cuda2_driver_factory_try_create( iree_hal_cuda2_driver_options_t driver_options; iree_hal_cuda2_driver_options_initialize(&driver_options); + iree_hal_cuda2_device_params_t device_params; + iree_hal_cuda2_device_params_initialize(&device_params); + device_params.async_allocations = FLAG_cuda_async_allocations; + driver_options.default_device_index = FLAG_cuda2_default_index; if (FLAG_cuda2_default_index_from_mpi) { driver_options.default_device_index = @@ -88,7 +96,7 @@ static iree_status_t iree_hal_cuda2_driver_factory_try_create( } iree_status_t status = iree_hal_cuda2_driver_create( - driver_name, &driver_options, host_allocator, out_driver); + driver_name, &driver_options, &device_params, host_allocator, out_driver); IREE_TRACE_ZONE_END(z0);
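Taken together, the api.h and registration changes above define the new driver-creation flow: device parameters are now initialized separately and handed to iree_hal_cuda2_driver_create alongside the driver options. A minimal sketch of a hosting application following that flow (assuming the system allocator and the generic iree_hal_driver_create_default_device helper from the HAL driver API; error handling beyond status propagation elided):

    #include "experimental/cuda2/api.h"
    #include "iree/base/api.h"
    #include "iree/hal/api.h"

    static iree_status_t create_cuda2_device(iree_hal_device_t** out_device) {
      // Driver options control device enumeration (e.g. default_device_index).
      iree_hal_cuda2_driver_options_t options;
      iree_hal_cuda2_driver_options_initialize(&options);

      // Device params apply to every device the driver creates; the defaults
      // enable async (stream-ordered) allocations when the device supports
      // them.
      iree_hal_cuda2_device_params_t params;
      iree_hal_cuda2_device_params_initialize(&params);

      iree_hal_driver_t* driver = NULL;
      IREE_RETURN_IF_ERROR(iree_hal_cuda2_driver_create(
          IREE_SV("cuda2"), &options, &params, iree_allocator_system(),
          &driver));

      // The device retains the driver, so the local reference can be released
      // once creation completes.
      iree_status_t status = iree_hal_driver_create_default_device(
          driver, iree_allocator_system(), out_device);
      iree_hal_driver_release(driver);
      return status;
    }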