diff --git a/c_api/include/taichi/taichi_core.h b/c_api/include/taichi/taichi_core.h index 537b5ac9ca3d8..c674dbdb164d3 100644 --- a/c_api/include/taichi/taichi_core.h +++ b/c_api/include/taichi/taichi_core.h @@ -26,6 +26,9 @@ typedef struct TiRuntime_t *TiRuntime; // handle.aot_module typedef struct TiAotModule_t *TiAotModule; +// handle.event +typedef struct TiEvent_t *TiEvent; + // handle.memory typedef struct TiMemory_t *TiMemory; @@ -160,6 +163,12 @@ TI_DLL_EXPORT void *TI_API_CALL ti_map_memory(TiRuntime runtime, TI_DLL_EXPORT void TI_API_CALL ti_unmap_memory(TiRuntime runtime, TiMemory memory); +// function.create_event +TI_DLL_EXPORT TiEvent TI_API_CALL ti_create_event(TiRuntime runtime); + +// function.destroy_event +TI_DLL_EXPORT void TI_API_CALL ti_destroy_event(TiEvent event); + // function.copy_memory_device_to_device TI_DLL_EXPORT void TI_API_CALL ti_copy_memory_device_to_device(TiRuntime runtime, @@ -179,6 +188,13 @@ ti_launch_compute_graph(TiRuntime runtime, uint32_t arg_count, const TiNamedArgument *args); +// function.signal_event +TI_DLL_EXPORT void TI_API_CALL ti_signal_event(TiRuntime runtime, + TiEvent event); + +// function.reset_event +TI_DLL_EXPORT void TI_API_CALL ti_reset_event(TiRuntime runtime, TiEvent event); + // function.submit TI_DLL_EXPORT void TI_API_CALL ti_submit(TiRuntime runtime); diff --git a/c_api/include/taichi/taichi_vulkan.h b/c_api/include/taichi/taichi_vulkan.h index e019406e8f8f5..9f30955db1d10 100644 --- a/c_api/include/taichi/taichi_vulkan.h +++ b/c_api/include/taichi/taichi_vulkan.h @@ -25,6 +25,11 @@ typedef struct TiVulkanMemoryInteropInfo { VkBufferUsageFlags usage; } TiVulkanMemoryInteropInfo; +// structure.vulkan_event_interop_info +typedef struct TiVulkanEventInteropInfo { + VkEvent event; +} TiVulkanEventInteropInfo; + // function.create_vulkan_runtime TI_DLL_EXPORT TiRuntime TI_API_CALL ti_create_vulkan_runtime_ext(uint32_t api_version, @@ -53,6 +58,17 @@ ti_export_vulkan_memory(TiRuntime runtime, TiMemory memory, TiVulkanMemoryInteropInfo *interop_info); +// function.import_vulkan_event +TI_DLL_EXPORT TiEvent TI_API_CALL +ti_import_vulkan_event(TiRuntime runtime, + const TiVulkanEventInteropInfo *interop_info); + +// function.export_vulkan_event +TI_DLL_EXPORT void TI_API_CALL +ti_export_vulkan_event(TiRuntime runtime, + TiEvent event, + TiVulkanEventInteropInfo *interop_info); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/c_api/src/taichi_core_impl.cpp b/c_api/src/taichi_core_impl.cpp index d8064d3290ed6..094c5e1268e89 100644 --- a/c_api/src/taichi_core_impl.cpp +++ b/c_api/src/taichi_core_impl.cpp @@ -28,7 +28,7 @@ void Runtime::deallocate_memory(TiMemory devmem) { } AotModule::AotModule(Runtime &runtime, - std::unique_ptr &&aot_module) + std::unique_ptr aot_module) : runtime_(&runtime), aot_module_(std::move(aot_module)) { } @@ -50,6 +50,17 @@ Runtime &AotModule::runtime() { return *runtime_; } +Event::Event(Runtime &runtime, std::unique_ptr event) + : runtime_(&runtime), event_(std::move(event)) { +} + +taichi::lang::DeviceEvent &Event::get() { + return *event_; +} +Runtime &Event::runtime() { + return *runtime_; +} + // ----------------------------------------------------------------------------- TiRuntime ti_create_runtime(TiArch arch) { @@ -160,6 +171,22 @@ void ti_unmap_memory(TiRuntime runtime, TiMemory devmem) { runtime2->get().unmap(devmem2devalloc(*runtime2, devmem)); } +TiEvent ti_create_event(TiRuntime runtime) { + Runtime *runtime2 = (Runtime *)runtime; + std::unique_ptr event = + runtime2->get().create_event(); + Event *event2 = new Event(*runtime2, std::move(event)); + return (TiEvent)event2; +} +void ti_destroy_event(TiEvent event) { + if (event == nullptr) { + TI_WARN("ignored attempt to destroy event of null handle"); + return; + } + + delete (Event *)event; +} + void ti_copy_memory_device_to_device(TiRuntime runtime, const TiMemorySlice *dst_memory, const TiMemorySlice *src_memory) { @@ -407,6 +434,19 @@ void ti_launch_compute_graph(TiRuntime runtime, } ((taichi::lang::aot::CompiledGraph *)compute_graph)->run(arg_map); } + +void ti_signal_event(TiRuntime runtime, TiEvent event) { + ((Runtime *)runtime)->signal_event(&((Event *)event)->get()); +} + +void ti_reset_event(TiRuntime runtime, TiEvent event) { + ((Runtime *)runtime)->reset_event(&((Event *)event)->get()); +} + +void ti_wait_event(TiRuntime runtime, TiEvent event) { + ((Runtime *)runtime)->wait_event(&((Event *)event)->get()); +} + void ti_submit(TiRuntime runtime) { if (runtime == nullptr) { TI_WARN("ignored attempt to submit to runtime of null handle"); diff --git a/c_api/src/taichi_core_impl.h b/c_api/src/taichi_core_impl.h index da170062f5667..3f5d6ab908a17 100644 --- a/c_api/src/taichi_core_impl.h +++ b/c_api/src/taichi_core_impl.h @@ -34,6 +34,15 @@ class Runtime { virtual void buffer_copy(const taichi::lang::DevicePtr &dst, const taichi::lang::DevicePtr &src, size_t size) = 0; + virtual void signal_event(taichi::lang::DeviceEvent *event) { + TI_NOT_IMPLEMENTED + } + virtual void reset_event(taichi::lang::DeviceEvent *event) { + TI_NOT_IMPLEMENTED + } + virtual void wait_event(taichi::lang::DeviceEvent *event) { + TI_NOT_IMPLEMENTED + } virtual void submit() = 0; virtual void wait() = 0; @@ -49,13 +58,24 @@ class AotModule { public: AotModule(Runtime &runtime, - std::unique_ptr &&aot_module); + std::unique_ptr aot_module); taichi::lang::aot::CompiledGraph &get_cgraph(const std::string &name); taichi::lang::aot::Module &get(); Runtime &runtime(); }; +class Event { + Runtime *runtime_; + std::unique_ptr event_; + + public: + Event(Runtime &runtime, std::unique_ptr event); + + taichi::lang::DeviceEvent &get(); + Runtime &runtime(); +}; + namespace { [[maybe_unused]] taichi::lang::DeviceAllocation devmem2devalloc( diff --git a/c_api/src/taichi_vulkan_impl.cpp b/c_api/src/taichi_vulkan_impl.cpp index 0e0aac79b7ddc..a290a8d9bc2b0 100644 --- a/c_api/src/taichi_vulkan_impl.cpp +++ b/c_api/src/taichi_vulkan_impl.cpp @@ -126,6 +126,15 @@ void VulkanRuntime::buffer_copy(const taichi::lang::DevicePtr &dst, void VulkanRuntime::submit() { get_gfx_runtime().flush(); } +void VulkanRuntime::signal_event(taichi::lang::DeviceEvent *event) { + get_gfx_runtime().signal_event(event); +} +void VulkanRuntime::reset_event(taichi::lang::DeviceEvent *event) { + get_gfx_runtime().reset_event(event); +} +void VulkanRuntime::wait_event(taichi::lang::DeviceEvent *event) { + get_gfx_runtime().wait_event(event); +} void VulkanRuntime::wait() { // (penguinliong) It's currently waiting for the entire runtime to stop. // Should be simply waiting for its fence to finish. @@ -252,4 +261,43 @@ void ti_export_vulkan_memory(TiRuntime runtime, interop_info->usage = buffer.get()->usage; } +TiEvent ti_import_vulkan_event(TiRuntime runtime, + const TiVulkanEventInteropInfo *interop_info) { + if (runtime == nullptr) { + TI_WARN("ignored attempt to import vulkan event to runtime of null handle"); + return TI_NULL_HANDLE; + } + Runtime *runtime2 = (Runtime *)runtime; + if (runtime2->arch != taichi::Arch::vulkan) { + TI_WARN("ignored attempt to import vulkan memory to non-vulkan runtime"); + return TI_NULL_HANDLE; + } + + vkapi::IVkEvent event = std::make_unique(); + event->device = runtime2->as_vk()->get_vk().vk_device(); + event->event = interop_info->event; + event->external = true; + + std::unique_ptr event2( + new taichi::lang::vulkan::VulkanDeviceEvent(std::move(event))); + + return (TiEvent) new Event(*runtime2, std::move(event2)); +} +void ti_export_vulkan_event(TiRuntime runtime, + TiEvent event, + TiVulkanEventInteropInfo *interop_info) { + if (runtime == nullptr) { + TI_WARN( + "ignored attempt to export vulkan memory from runtime of null handle"); + return; + } + if (event == nullptr) { + TI_WARN("ignored attempt to export vulkan memory of null handle"); + return; + } + auto event2 = + (taichi::lang::vulkan::VulkanDeviceEvent *)(&((Event *)event)->get()); + interop_info->event = event2->vkapi_ref->event; +} + #endif // TI_WITH_VULKAN diff --git a/c_api/src/taichi_vulkan_impl.h b/c_api/src/taichi_vulkan_impl.h index a14e43c75b958..03136b373eb68 100644 --- a/c_api/src/taichi_vulkan_impl.h +++ b/c_api/src/taichi_vulkan_impl.h @@ -25,6 +25,9 @@ class VulkanRuntime : public Runtime { virtual void buffer_copy(const taichi::lang::DevicePtr &dst, const taichi::lang::DevicePtr &src, size_t size) override final; + virtual void signal_event(taichi::lang::DeviceEvent *event) override final; + virtual void reset_event(taichi::lang::DeviceEvent *event) override final; + virtual void wait_event(taichi::lang::DeviceEvent *event) override final; virtual void submit() override final; virtual void wait() override final; }; diff --git a/c_api/taichi.json b/c_api/taichi.json index c4418451534b8..c35ac259c53ad 100644 --- a/c_api/taichi.json +++ b/c_api/taichi.json @@ -50,6 +50,11 @@ "type": "handle", "is_dispatchable": true }, + { + "name": "event", + "type": "handle", + "is_dispatchable": true + }, { "name": "memory", "type": "handle", @@ -299,9 +304,32 @@ } ] }, + { + "name": "create_event", + "type": "function", + "parameters": [ + { + "name": "@return", + "type": "handle.event" + }, + { + "type": "handle.runtime" + } + ] + }, + { + "name": "destroy_event", + "type": "function", + "parameters": [ + { + "type": "handle.event" + } + ] + }, { "name": "copy_memory_device_to_device", "type": "function", + "is_device_command": true, "parameters": [ { "type": "handle.runtime" @@ -321,6 +349,7 @@ { "name": "launch_kernel", "type": "function", + "is_device_command": true, "parameters": [ { "type": "handle.runtime" @@ -342,6 +371,7 @@ { "name": "launch_compute_graph", "type": "function", + "is_device_command": true, "parameters": [ { "type": "handle.runtime" @@ -360,6 +390,45 @@ } ] }, + { + "name": "signal_event", + "type": "function", + "is_device_command": true, + "parameters": [ + { + "type": "handle.runtime" + }, + { + "type": "handle.event" + } + ] + }, + { + "name": "reset_event", + "type": "function", + "is_device_command": true, + "parameters": [ + { + "type": "handle.runtime" + }, + { + "type": "handle.event" + } + ] + }, + { + "name": "wait_event", + "type": "function", + "is_device_command": true, + "parameters": [ + { + "type": "handle.runtime" + }, + { + "type": "handle.event" + } + ] + }, { "name": "submit", "type": "function", @@ -503,6 +572,16 @@ } ] }, + { + "name": "vulkan_event_interop_info", + "type": "structure", + "fields": [ + { + "name": "event", + "type": "VkEvent" + } + ] + }, { "name": "create_vulkan_runtime", "type": "function", @@ -597,6 +676,41 @@ "by_mut": true } ] + }, + { + "name": "import_vulkan_event", + "type": "function", + "parameters": [ + { + "name": "@return", + "type": "handle.event" + }, + { + "type": "handle.runtime" + }, + { + "name": "interop_info", + "type": "structure.vulkan_event_interop_info", + "by_ref": true + } + ] + }, + { + "name": "export_vulkan_event", + "type": "function", + "parameters": [ + { + "type": "handle.runtime" + }, + { + "type": "handle.event" + }, + { + "name": "interop_info", + "type": "structure.vulkan_event_interop_info", + "by_mut": true + } + ] } ] }, diff --git a/misc/generate_c_api.py b/misc/generate_c_api.py index f57dbf1e6deea..5189107b3a459 100644 --- a/misc/generate_c_api.py +++ b/misc/generate_c_api.py @@ -158,6 +158,7 @@ def generate_module_header(module): BuiltInType("VkQueue", "VkQueue"), BuiltInType("VkBuffer", "VkBuffer"), BuiltInType("VkBufferUsageFlags", "VkBufferUsageFlags"), + BuiltInType("VkEvent", "VkEvent"), } for module in Module.load_all(builtin_tys): diff --git a/taichi/codegen/cuda/codegen_cuda.cpp b/taichi/codegen/cuda/codegen_cuda.cpp index 5f05377004788..3aff6484a6799 100644 --- a/taichi/codegen/cuda/codegen_cuda.cpp +++ b/taichi/codegen/cuda/codegen_cuda.cpp @@ -535,21 +535,14 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { return true; // on CUDA, pass the argument by value } - llvm::Value *create_intrinsic_load(const DataType &dtype, - llvm::Value *data_ptr) override { - // Issue an CUDA "__ldg" instruction so that data are cached in - // the CUDA read-only data cache. - auto llvm_dtype = llvm_type(dtype); - auto llvm_dtype_ptr = llvm::PointerType::get(llvm_type(dtype), 0); - llvm::Intrinsic::ID intrin; - if (is_real(dtype)) { - intrin = llvm::Intrinsic::nvvm_ldg_global_f; - } else { - intrin = llvm::Intrinsic::nvvm_ldg_global_i; - } + llvm::Value *create_intrinsic_load(llvm::Value *ptr, + llvm::Type *ty) override { + // Issue an "__ldg" instruction to cache data in the read-only data cache. + auto intrin = ty->isFloatingPointTy() ? llvm::Intrinsic::nvvm_ldg_global_f + : llvm::Intrinsic::nvvm_ldg_global_i; return builder->CreateIntrinsic( - intrin, {llvm_dtype, llvm_dtype_ptr}, - {data_ptr, tlctx->get_constant(data_type_size(dtype))}); + intrin, {ty, llvm::PointerType::get(ty, 0)}, + {ptr, tlctx->get_constant(ty->getScalarSizeInBits())}); } void visit(GlobalLoadStmt *stmt) override { diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index a342274a7c2aa..6a377ea74fadb 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -1416,8 +1416,8 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { } } -llvm::Value *CodeGenLLVM::create_intrinsic_load(const DataType &dtype, - llvm::Value *data_ptr) { +llvm::Value *CodeGenLLVM::create_intrinsic_load(llvm::Value *ptr, + llvm::Type *ty) { TI_NOT_IMPLEMENTED; } @@ -1428,24 +1428,28 @@ void CodeGenLLVM::create_global_load(GlobalLoadStmt *stmt, if (ptr_type->is_bit_pointer()) { auto val_type = ptr_type->get_pointee_type(); auto get_ch = stmt->src->as(); - auto physical_type = get_ch->input_snode->physical_type; + auto physical_type = llvm_type(get_ch->input_snode->physical_type); + auto [byte_ptr, bit_offset] = load_bit_ptr(ptr); + auto physical_value = should_cache_as_read_only + ? create_intrinsic_load(byte_ptr, physical_type) + : builder->CreateLoad(physical_type, byte_ptr); if (auto qit = val_type->cast()) { - llvm_val[stmt] = - load_quant_int(ptr, qit, physical_type, should_cache_as_read_only); + llvm_val[stmt] = extract_quant_int(physical_value, bit_offset, qit); } else if (auto qfxt = val_type->cast()) { - llvm_val[stmt] = - load_quant_fixed(ptr, qfxt, physical_type, should_cache_as_read_only); + qit = qfxt->get_digits_type()->as(); + auto digits = extract_quant_int(physical_value, bit_offset, qit); + llvm_val[stmt] = reconstruct_quant_fixed(digits, qfxt); } else { TI_ASSERT(val_type->is()); TI_ASSERT(get_ch->input_snode->dt->is()); - llvm_val[stmt] = load_quant_float( - ptr, get_ch->input_snode->dt->as(), - get_ch->output_snode->id_in_bit_struct, should_cache_as_read_only); + llvm_val[stmt] = extract_quant_float( + physical_value, get_ch->input_snode->dt->as(), + get_ch->output_snode->id_in_bit_struct); } } else { // Byte pointer case. if (should_cache_as_read_only) { - llvm_val[stmt] = create_intrinsic_load(stmt->ret_type, ptr); + llvm_val[stmt] = create_intrinsic_load(ptr, llvm_type(stmt->ret_type)); } else { llvm_val[stmt] = builder->CreateLoad(tlctx->get_data_type(stmt->ret_type), ptr); @@ -1610,14 +1614,6 @@ std::tuple CodeGenLLVM::load_bit_ptr( return std::make_tuple(byte_ptr, bit_offset); } -llvm::Value *CodeGenLLVM::offset_bit_ptr(llvm::Value *bit_ptr, - int bit_offset_delta) { - auto [byte_ptr, bit_offset] = load_bit_ptr(bit_ptr); - auto new_bit_offset = - builder->CreateAdd(bit_offset, tlctx->get_constant(bit_offset_delta)); - return create_bit_ptr(byte_ptr, new_bit_offset); -} - void CodeGenLLVM::visit(SNodeLookupStmt *stmt) { llvm::Value *parent = nullptr; parent = llvm_val[stmt->input_snode]; diff --git a/taichi/codegen/llvm/codegen_llvm.h b/taichi/codegen/llvm/codegen_llvm.h index 09588e85968d3..c4d052a220197 100644 --- a/taichi/codegen/llvm/codegen_llvm.h +++ b/taichi/codegen/llvm/codegen_llvm.h @@ -250,46 +250,24 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void store_quant_floats_with_shared_exponents(BitStructStoreStmt *stmt); - llvm::Value *extract_quant_float(llvm::Value *local_bit_struct, - SNode *digits_snode); - - virtual llvm::Value *create_intrinsic_load(const DataType &dtype, - llvm::Value *data_ptr); - - llvm::Value *load_quant_int(llvm::Value *ptr, - QuantIntType *qit, - Type *physical_type, - bool should_cache_as_read_only); + llvm::Value *extract_quant_float(llvm::Value *physical_value, + BitStructType *bit_struct, + int digits_id); llvm::Value *extract_quant_int(llvm::Value *physical_value, llvm::Value *bit_offset, QuantIntType *qit); - llvm::Value *load_quant_fixed(llvm::Value *ptr, - QuantFixedType *qfxt, - Type *physical_type, - bool should_cache_as_read_only); - llvm::Value *reconstruct_quant_fixed(llvm::Value *digits, QuantFixedType *qfxt); - llvm::Value *load_quant_float(llvm::Value *digits_ptr, - BitStructType *bit_struct, - int digits_id, - bool should_cache_as_read_only); - - llvm::Value *load_quant_float(llvm::Value *digits_ptr, - llvm::Value *exponent_ptr, - QuantFloatType *qflt, - Type *physical_type, - bool should_cache_as_read_only, - bool shared_exponent); - llvm::Value *reconstruct_quant_float(llvm::Value *input_digits, llvm::Value *input_exponent_val, QuantFloatType *qflt, bool shared_exponent); + virtual llvm::Value *create_intrinsic_load(llvm::Value *ptr, llvm::Type *ty); + void create_global_load(GlobalLoadStmt *stmt, bool should_cache_as_read_only); void visit(GlobalLoadStmt *stmt) override; @@ -308,8 +286,6 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { std::tuple load_bit_ptr(llvm::Value *bit_ptr); - llvm::Value *offset_bit_ptr(llvm::Value *bit_ptr, int bit_offset_delta); - void visit(SNodeLookupStmt *stmt) override; void visit(GetChStmt *stmt) override; diff --git a/taichi/codegen/llvm/codegen_llvm_quant.cpp b/taichi/codegen/llvm/codegen_llvm_quant.cpp index 57e39c9361197..6acb4e8753570 100644 --- a/taichi/codegen/llvm/codegen_llvm_quant.cpp +++ b/taichi/codegen/llvm/codegen_llvm_quant.cpp @@ -324,8 +324,8 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents( BitStructStoreStmt *stmt) { // handle each exponent separately auto snode = stmt->get_bit_struct_snode(); - auto bit_struct_physical_type = - snode->dt->as()->get_physical_type(); + auto bit_struct = snode->dt->as(); + auto bit_struct_physical_type = bit_struct->get_physical_type(); auto local_bit_struct = builder->CreateLoad( #ifdef TI_LLVM_15 llvm_type(bit_struct_physical_type), @@ -352,7 +352,8 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents( input != stmt->ch_ids.end()) { floats.push_back(llvm_val[stmt->values[input - stmt->ch_ids.begin()]]); } else { - floats.push_back(extract_quant_float(local_bit_struct, user)); + floats.push_back( + extract_quant_float(local_bit_struct, bit_struct, ch_id)); } } // convert to i32 for bit operations @@ -479,34 +480,21 @@ llvm::Value *CodeGenLLVM::extract_digits_from_f32_with_shared_exponent( return builder->CreateLShr(digits, exp_offset); } -llvm::Value *CodeGenLLVM::extract_quant_float(llvm::Value *local_bit_struct, - SNode *digits_snode) { - auto qflt = digits_snode->dt->as(); - auto exponent_type = qflt->get_exponent_type()->as(); - auto digits_type = qflt->get_digits_type()->as(); - auto digits = extract_quant_int(local_bit_struct, - tlctx->get_constant(digits_snode->bit_offset), - digits_type); +llvm::Value *CodeGenLLVM::extract_quant_float(llvm::Value *physical_value, + BitStructType *bit_struct, + int digits_id) { + auto qflt = bit_struct->get_member_type(digits_id)->as(); + auto exponent_id = bit_struct->get_member_exponent(digits_id); + auto exponent_bit_offset = bit_struct->get_member_bit_offset(exponent_id); + auto digits_bit_offset = bit_struct->get_member_bit_offset(digits_id); + auto shared_exponent = bit_struct->get_member_owns_shared_exponent(digits_id); + auto digits = + extract_quant_int(physical_value, tlctx->get_constant(digits_bit_offset), + qflt->get_digits_type()->as()); auto exponent = extract_quant_int( - local_bit_struct, - tlctx->get_constant(digits_snode->exp_snode->bit_offset), exponent_type); - return reconstruct_quant_float(digits, exponent, qflt, - digits_snode->owns_shared_exponent); -} - -llvm::Value *CodeGenLLVM::load_quant_int(llvm::Value *ptr, - QuantIntType *qit, - Type *physical_type, - bool should_cache_as_read_only) { - auto [byte_ptr, bit_offset] = load_bit_ptr(ptr); - auto physical_value = should_cache_as_read_only - ? create_intrinsic_load(physical_type, byte_ptr) - : builder->CreateLoad( -#ifdef TI_LLVM_15 - llvm_type(physical_type), -#endif - byte_ptr); - return extract_quant_int(physical_value, bit_offset, qit); + physical_value, tlctx->get_constant(exponent_bit_offset), + qflt->get_exponent_type()->as()); + return reconstruct_quant_float(digits, exponent, qflt, shared_exponent); } llvm::Value *CodeGenLLVM::extract_quant_int(llvm::Value *physical_value, @@ -537,15 +525,6 @@ llvm::Value *CodeGenLLVM::extract_quant_int(llvm::Value *physical_value, qit->get_is_signed()); } -llvm::Value *CodeGenLLVM::load_quant_fixed(llvm::Value *ptr, - QuantFixedType *qfxt, - Type *physical_type, - bool should_cache_as_read_only) { - auto digits = load_quant_int(ptr, qfxt->get_digits_type()->as(), - physical_type, should_cache_as_read_only); - return reconstruct_quant_fixed(digits, qfxt); -} - llvm::Value *CodeGenLLVM::reconstruct_quant_fixed(llvm::Value *digits, QuantFixedType *qfxt) { // Compute float(digits) * scale @@ -561,37 +540,6 @@ llvm::Value *CodeGenLLVM::reconstruct_quant_fixed(llvm::Value *digits, return builder->CreateFMul(cast, s); } -llvm::Value *CodeGenLLVM::load_quant_float(llvm::Value *digits_ptr, - BitStructType *bit_struct, - int digits_id, - bool should_cache_as_read_only) { - auto exponent_id = bit_struct->get_member_exponent(digits_id); - auto exponent_bit_offset = bit_struct->get_member_bit_offset(exponent_id); - auto digits_bit_offset = bit_struct->get_member_bit_offset(digits_id); - auto bit_offset_delta = exponent_bit_offset - digits_bit_offset; - auto exponent_ptr = offset_bit_ptr(digits_ptr, bit_offset_delta); - auto qflt = bit_struct->get_member_type(digits_id)->as(); - auto physical_type = bit_struct->get_physical_type(); - auto shared_exponent = bit_struct->get_member_owns_shared_exponent(digits_id); - return load_quant_float(digits_ptr, exponent_ptr, qflt, physical_type, - should_cache_as_read_only, shared_exponent); -} - -llvm::Value *CodeGenLLVM::load_quant_float(llvm::Value *digits_ptr, - llvm::Value *exponent_ptr, - QuantFloatType *qflt, - Type *physical_type, - bool should_cache_as_read_only, - bool shared_exponent) { - auto digits = - load_quant_int(digits_ptr, qflt->get_digits_type()->as(), - physical_type, should_cache_as_read_only); - auto exponent_val = load_quant_int( - exponent_ptr, qflt->get_exponent_type()->as(), - physical_type, should_cache_as_read_only); - return reconstruct_quant_float(digits, exponent_val, qflt, shared_exponent); -} - llvm::Value *CodeGenLLVM::reconstruct_quant_float( llvm::Value *input_digits, llvm::Value *input_exponent_val, diff --git a/taichi/codegen/llvm/llvm_codegen_utils.h b/taichi/codegen/llvm/llvm_codegen_utils.h index 5ab5b2cd759ce..c21c0295e24d9 100644 --- a/taichi/codegen/llvm/llvm_codegen_utils.h +++ b/taichi/codegen/llvm/llvm_codegen_utils.h @@ -126,7 +126,7 @@ class LLVMModuleBuilder { std::vector args = arglist; check_func_call_signature(func->getFunctionType(), func->getName(), args, builder); - return builder->CreateCall(func, arglist); + return builder->CreateCall(func, args); } template diff --git a/taichi/ir/snode.cpp b/taichi/ir/snode.cpp index 19da1b753b556..dd6065ed240ac 100644 --- a/taichi/ir/snode.cpp +++ b/taichi/ir/snode.cpp @@ -325,6 +325,9 @@ SNode *SNode::get_dual() const { void SNode::set_snode_tree_id(int id) { snode_tree_id_ = id; + for (auto &child : ch) { + child->set_snode_tree_id(id); + } } int SNode::get_snode_tree_id() const { diff --git a/taichi/rhi/device.h b/taichi/rhi/device.h index 3acb2db57b78a..ee378cf0648c5 100644 --- a/taichi/rhi/device.h +++ b/taichi/rhi/device.h @@ -295,6 +295,12 @@ struct ImageCopyParams { uint32_t depth{1}; }; +class DeviceEvent { + public: + virtual ~DeviceEvent() { + } +}; + class CommandList { public: virtual ~CommandList() { @@ -383,6 +389,15 @@ class CommandList { const ImageCopyParams ¶ms) { TI_NOT_IMPLEMENTED } + virtual void signal_event(DeviceEvent *event) { + TI_NOT_IMPLEMENTED + } + virtual void reset_event(DeviceEvent *event) { + TI_NOT_IMPLEMENTED + } + virtual void wait_event(DeviceEvent *event) { + TI_NOT_IMPLEMENTED + } }; struct PipelineSourceDesc { @@ -467,6 +482,8 @@ class Device { const PipelineSourceDesc &src, std::string name = "Pipeline") = 0; + virtual std::unique_ptr create_event(){TI_NOT_IMPLEMENTED} + std::unique_ptr allocate_memory_unique( const AllocParams ¶ms) { return std::make_unique( diff --git a/taichi/rhi/vulkan/vulkan_api.cpp b/taichi/rhi/vulkan/vulkan_api.cpp index 7066e96e4e064..ce8ffb947743f 100644 --- a/taichi/rhi/vulkan/vulkan_api.cpp +++ b/taichi/rhi/vulkan/vulkan_api.cpp @@ -54,6 +54,12 @@ DeviceObjVkFramebuffer::~DeviceObjVkFramebuffer() { vkDestroyFramebuffer(device, framebuffer, nullptr); } +DeviceObjVkEvent::~DeviceObjVkEvent() { + if (!external) { + vkDestroyEvent(device, event, nullptr); + } +} + DeviceObjVkSemaphore::~DeviceObjVkSemaphore() { vkDestroySemaphore(device, semaphore, nullptr); } @@ -91,6 +97,21 @@ IDeviceObj create_device_obj(VkDevice device) { return obj; } +IVkEvent create_event(VkDevice device, + VkSemaphoreCreateFlags flags, + void *pnext) { + IVkEvent obj = std::make_shared(); + obj->device = device; + + VkEventCreateInfo info{}; + info.sType = VK_STRUCTURE_TYPE_EVENT_CREATE_INFO; + info.pNext = pnext; + info.flags = flags; + + vkCreateEvent(device, &info, nullptr, &obj->event); + return obj; +} + IVkSemaphore create_semaphore(VkDevice device, VkSemaphoreCreateFlags flags, void *pnext) { diff --git a/taichi/rhi/vulkan/vulkan_api.h b/taichi/rhi/vulkan/vulkan_api.h index cc5fec8baf832..f6adc733e8236 100644 --- a/taichi/rhi/vulkan/vulkan_api.h +++ b/taichi/rhi/vulkan/vulkan_api.h @@ -18,6 +18,17 @@ struct DeviceObj { using IDeviceObj = std::shared_ptr; IDeviceObj create_device_obj(VkDevice device); +// VkEvent +struct DeviceObjVkEvent : public DeviceObj { + bool external{false}; + VkEvent event{VK_NULL_HANDLE}; + ~DeviceObjVkEvent() override; +}; +using IVkEvent = std::shared_ptr; +IVkEvent create_event(VkDevice device, + VkEventCreateFlags flags, + void *pnext = nullptr); + // VkSemaphore struct DeviceObjVkSemaphore : public DeviceObj { VkSemaphore semaphore{VK_NULL_HANDLE}; diff --git a/taichi/rhi/vulkan/vulkan_device.cpp b/taichi/rhi/vulkan/vulkan_device.cpp index a7de6a4981f66..c67cb6e624013 100644 --- a/taichi/rhi/vulkan/vulkan_device.cpp +++ b/taichi/rhi/vulkan/vulkan_device.cpp @@ -1244,6 +1244,24 @@ void VulkanCommandList::blit_image(DeviceAllocation dst_img, buffer_->refs.push_back(src_vk_image); } +void VulkanCommandList::signal_event(DeviceEvent *event) { + VulkanDeviceEvent *event2 = static_cast(event); + vkCmdSetEvent(buffer_->buffer, event2->vkapi_ref->event, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT); +} +void VulkanCommandList::reset_event(DeviceEvent *event) { + VulkanDeviceEvent *event2 = static_cast(event); + vkCmdResetEvent(buffer_->buffer, event2->vkapi_ref->event, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT); +} +void VulkanCommandList::wait_event(DeviceEvent *event) { + VulkanDeviceEvent *event2 = static_cast(event); + vkCmdWaitEvents(buffer_->buffer, 1, &event2->vkapi_ref->event, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, 0, nullptr, 0, nullptr, 0, + nullptr); +} + void VulkanCommandList::set_line_width(float width) { if (ti_device_->get_cap(DeviceCapability::wide_lines)) { vkCmdSetLineWidth(buffer_->buffer, width); @@ -1334,6 +1352,11 @@ std::unique_ptr VulkanDevice::create_pipeline( return std::make_unique(params); } +std::unique_ptr VulkanDevice::create_event() { + return std::unique_ptr( + new VulkanDeviceEvent(vkapi::create_event(device_, 0))); +} + // #define TI_VULKAN_DEBUG_ALLOCATIONS DeviceAllocation VulkanDevice::allocate_memory(const AllocParams ¶ms) { diff --git a/taichi/rhi/vulkan/vulkan_device.h b/taichi/rhi/vulkan/vulkan_device.h index bcfcfe61ee488..86ad7599a7b83 100644 --- a/taichi/rhi/vulkan/vulkan_device.h +++ b/taichi/rhi/vulkan/vulkan_device.h @@ -349,6 +349,16 @@ class VulkanPipeline : public Pipeline { vkapi::IVkPipelineLayout pipeline_layout_{VK_NULL_HANDLE}; }; +class VulkanDeviceEvent : public DeviceEvent { + public: + VulkanDeviceEvent(vkapi::IVkEvent event) : vkapi_ref(event) { + } + ~VulkanDeviceEvent() { + } + + vkapi::IVkEvent vkapi_ref{nullptr}; +}; + class VulkanCommandList : public CommandList { public: VulkanCommandList(VulkanDevice *ti_device, @@ -406,6 +416,10 @@ class VulkanCommandList : public CommandList { ImageLayout src_img_layout, const ImageCopyParams ¶ms) override; + void signal_event(DeviceEvent *event) override; + void reset_event(DeviceEvent *event) override; + void wait_event(DeviceEvent *event) override; + vkapi::IVkRenderPass current_renderpass(); // Vulkan specific functions @@ -548,6 +562,7 @@ class TI_DLL_EXPORT VulkanDevice : public GraphicsDevice { std::unique_ptr create_pipeline( const PipelineSourceDesc &src, std::string name = "Pipeline") override; + std::unique_ptr create_event() override; DeviceAllocation allocate_memory(const AllocParams ¶ms) override; void dealloc_memory(DeviceAllocation handle) override; diff --git a/taichi/runtime/gfx/runtime.cpp b/taichi/runtime/gfx/runtime.cpp index e742a36319790..f92a26061cf76 100644 --- a/taichi/runtime/gfx/runtime.cpp +++ b/taichi/runtime/gfx/runtime.cpp @@ -561,6 +561,22 @@ void GfxRuntime::buffer_copy(DevicePtr dst, DevicePtr src, size_t size) { submit_current_cmdlist_if_timeout(); } +void GfxRuntime::signal_event(DeviceEvent *event) { + ensure_current_cmdlist(); + current_cmdlist_->signal_event(event); + submit_current_cmdlist_if_timeout(); +} +void GfxRuntime::reset_event(DeviceEvent *event) { + ensure_current_cmdlist(); + current_cmdlist_->reset_event(event); + submit_current_cmdlist_if_timeout(); +} +void GfxRuntime::wait_event(DeviceEvent *event) { + ensure_current_cmdlist(); + current_cmdlist_->wait_event(event); + submit_current_cmdlist_if_timeout(); +} + void GfxRuntime::synchronize() { flush(); device_->wait_idle(); diff --git a/taichi/runtime/gfx/runtime.h b/taichi/runtime/gfx/runtime.h index 56dad9a0aab79..cabe742c530c3 100644 --- a/taichi/runtime/gfx/runtime.h +++ b/taichi/runtime/gfx/runtime.h @@ -102,6 +102,10 @@ class TI_DLL_EXPORT GfxRuntime { void buffer_copy(DevicePtr dst, DevicePtr src, size_t size); + void signal_event(DeviceEvent *event); + void reset_event(DeviceEvent *event); + void wait_event(DeviceEvent *event); + void synchronize(); StreamSemaphore flush();