From 592c727ca8c35a99ed01dca7219dee1274cc7ee8 Mon Sep 17 00:00:00 2001 From: PENGUINLIONG Date: Thu, 14 Jul 2022 01:31:23 +0800 Subject: [PATCH 1/2] Adjusted C-API for nd-array type conformance --- c_api/include/taichi/taichi_core.h | 20 +++++++++++ c_api/src/taichi_core_impl.cpp | 44 ++++++++++++++++++++++++- c_api/taichi.json | 9 +++++ misc/generate_unity_language_binding.py | 4 +-- 4 files changed, 74 insertions(+), 3 deletions(-) diff --git a/c_api/include/taichi/taichi_core.h b/c_api/include/taichi/taichi_core.h index facbe63728947..537b5ac9ca3d8 100644 --- a/c_api/include/taichi/taichi_core.h +++ b/c_api/include/taichi/taichi_core.h @@ -52,6 +52,25 @@ typedef enum TiArch { TI_ARCH_MAX_ENUM = 0xffffffff, } TiArch; +// enumeration.data_type +typedef enum TiDataType { + TI_DATA_TYPE_F16 = 0, + TI_DATA_TYPE_F32 = 1, + TI_DATA_TYPE_F64 = 2, + TI_DATA_TYPE_I8 = 3, + TI_DATA_TYPE_I16 = 4, + TI_DATA_TYPE_I32 = 5, + TI_DATA_TYPE_I64 = 6, + TI_DATA_TYPE_U1 = 7, + TI_DATA_TYPE_U8 = 8, + TI_DATA_TYPE_U16 = 9, + TI_DATA_TYPE_U32 = 10, + TI_DATA_TYPE_U64 = 11, + TI_DATA_TYPE_GEN = 12, + TI_DATA_TYPE_UNKNOWN = 13, + TI_DATA_TYPE_MAX_ENUM = 0xffffffff, +} TiDataType; + // enumeration.argument_type typedef enum TiArgumentType { TI_ARGUMENT_TYPE_I32 = 0, @@ -96,6 +115,7 @@ typedef struct TiNdArray { TiMemory memory; TiNdShape shape; TiNdShape elem_shape; + TiDataType elem_type; } TiNdArray; // union.argument_value diff --git a/c_api/src/taichi_core_impl.cpp b/c_api/src/taichi_core_impl.cpp index eebfea2422ba3..9dad62a146b90 100644 --- a/c_api/src/taichi_core_impl.cpp +++ b/c_api/src/taichi_core_impl.cpp @@ -344,8 +344,50 @@ void ti_launch_compute_graph(TiRuntime runtime, ndarray.elem_shape.dims, ndarray.elem_shape.dims + ndarray.elem_shape.dim_count); + const taichi::lang::DataType* prim_ty; + switch (ndarray.elem_type) { + case TI_DATA_TYPE_F16: + prim_ty = &taichi::lang::PrimitiveType::f16; + break; + case TI_DATA_TYPE_F32: + prim_ty = &taichi::lang::PrimitiveType::f32; + break; + case TI_DATA_TYPE_F64: + prim_ty = &taichi::lang::PrimitiveType::f64; + break; + case TI_DATA_TYPE_I8: + prim_ty = &taichi::lang::PrimitiveType::i8; + break; + case TI_DATA_TYPE_I16: + prim_ty = &taichi::lang::PrimitiveType::i16; + break; + case TI_DATA_TYPE_I32: + prim_ty = &taichi::lang::PrimitiveType::i32; + break; + case TI_DATA_TYPE_I64: + prim_ty = &taichi::lang::PrimitiveType::i64; + break; + case TI_DATA_TYPE_U8: + prim_ty = &taichi::lang::PrimitiveType::u8; + break; + case TI_DATA_TYPE_U16: + prim_ty = &taichi::lang::PrimitiveType::u16; + break; + case TI_DATA_TYPE_U32: + prim_ty = &taichi::lang::PrimitiveType::u32; + break; + case TI_DATA_TYPE_U64: + prim_ty = &taichi::lang::PrimitiveType::u64; + break; + case TI_DATA_TYPE_GEN: + prim_ty = &taichi::lang::PrimitiveType::gen; + break; + default: + TI_ERROR("unexpected data type"); + } + ndarrays.emplace_back(taichi::lang::Ndarray( - devalloc, taichi::lang::PrimitiveType::f32, shape, elem_shape)); + devalloc, *prim_ty, shape, elem_shape)); arg_map.emplace(std::make_pair( arg.name, taichi::lang::aot::IValue::create(ndarrays.back()))); break; diff --git a/c_api/taichi.json b/c_api/taichi.json index c00423c4acdb1..c4418451534b8 100644 --- a/c_api/taichi.json +++ b/c_api/taichi.json @@ -70,6 +70,11 @@ "type": "enumeration", "inc_cases": "archs" }, + { + "name": "data_type", + "type": "enumeration", + "inc_cases": "data_type" + }, { "name": "argument_type", "type": "enumeration", @@ -162,6 +167,10 @@ { "name": "elem_shape", "type": "structure.nd_shape" + }, + { + "name": "elem_type", + "type": "enumeration.data_type" } ] }, diff --git a/misc/generate_unity_language_binding.py b/misc/generate_unity_language_binding.py index 0c6af3a17dd32..539bcfd203d79 100644 --- a/misc/generate_unity_language_binding.py +++ b/misc/generate_unity_language_binding.py @@ -239,7 +239,7 @@ def print_module_header(module): "using System.Runtime.InteropServices;", "using System.Collections.Generic;", "", - "namespace Taichi {", + "namespace Taichi.Generated {", ] for x in module.declr_reg: @@ -251,7 +251,7 @@ def print_module_header(module): out += [ "", - "} // namespace Taichi", + "} // namespace Taichi.Generated", "", ] From 0cbeb95eeb963b973779a5286fbe7d0a14e65ffe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Jul 2022 17:36:14 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- c_api/src/taichi_core_impl.cpp | 82 +++++++++++++++++----------------- 1 file changed, 41 insertions(+), 41 deletions(-) diff --git a/c_api/src/taichi_core_impl.cpp b/c_api/src/taichi_core_impl.cpp index 9dad62a146b90..029510a02f7ed 100644 --- a/c_api/src/taichi_core_impl.cpp +++ b/c_api/src/taichi_core_impl.cpp @@ -344,50 +344,50 @@ void ti_launch_compute_graph(TiRuntime runtime, ndarray.elem_shape.dims, ndarray.elem_shape.dims + ndarray.elem_shape.dim_count); - const taichi::lang::DataType* prim_ty; + const taichi::lang::DataType *prim_ty; switch (ndarray.elem_type) { - case TI_DATA_TYPE_F16: - prim_ty = &taichi::lang::PrimitiveType::f16; - break; - case TI_DATA_TYPE_F32: - prim_ty = &taichi::lang::PrimitiveType::f32; - break; - case TI_DATA_TYPE_F64: - prim_ty = &taichi::lang::PrimitiveType::f64; - break; - case TI_DATA_TYPE_I8: - prim_ty = &taichi::lang::PrimitiveType::i8; - break; - case TI_DATA_TYPE_I16: - prim_ty = &taichi::lang::PrimitiveType::i16; - break; - case TI_DATA_TYPE_I32: - prim_ty = &taichi::lang::PrimitiveType::i32; - break; - case TI_DATA_TYPE_I64: - prim_ty = &taichi::lang::PrimitiveType::i64; - break; - case TI_DATA_TYPE_U8: - prim_ty = &taichi::lang::PrimitiveType::u8; - break; - case TI_DATA_TYPE_U16: - prim_ty = &taichi::lang::PrimitiveType::u16; - break; - case TI_DATA_TYPE_U32: - prim_ty = &taichi::lang::PrimitiveType::u32; - break; - case TI_DATA_TYPE_U64: - prim_ty = &taichi::lang::PrimitiveType::u64; - break; - case TI_DATA_TYPE_GEN: - prim_ty = &taichi::lang::PrimitiveType::gen; - break; - default: - TI_ERROR("unexpected data type"); + case TI_DATA_TYPE_F16: + prim_ty = &taichi::lang::PrimitiveType::f16; + break; + case TI_DATA_TYPE_F32: + prim_ty = &taichi::lang::PrimitiveType::f32; + break; + case TI_DATA_TYPE_F64: + prim_ty = &taichi::lang::PrimitiveType::f64; + break; + case TI_DATA_TYPE_I8: + prim_ty = &taichi::lang::PrimitiveType::i8; + break; + case TI_DATA_TYPE_I16: + prim_ty = &taichi::lang::PrimitiveType::i16; + break; + case TI_DATA_TYPE_I32: + prim_ty = &taichi::lang::PrimitiveType::i32; + break; + case TI_DATA_TYPE_I64: + prim_ty = &taichi::lang::PrimitiveType::i64; + break; + case TI_DATA_TYPE_U8: + prim_ty = &taichi::lang::PrimitiveType::u8; + break; + case TI_DATA_TYPE_U16: + prim_ty = &taichi::lang::PrimitiveType::u16; + break; + case TI_DATA_TYPE_U32: + prim_ty = &taichi::lang::PrimitiveType::u32; + break; + case TI_DATA_TYPE_U64: + prim_ty = &taichi::lang::PrimitiveType::u64; + break; + case TI_DATA_TYPE_GEN: + prim_ty = &taichi::lang::PrimitiveType::gen; + break; + default: + TI_ERROR("unexpected data type"); } - ndarrays.emplace_back(taichi::lang::Ndarray( - devalloc, *prim_ty, shape, elem_shape)); + ndarrays.emplace_back( + taichi::lang::Ndarray(devalloc, *prim_ty, shape, elem_shape)); arg_map.emplace(std::make_pair( arg.name, taichi::lang::aot::IValue::create(ndarrays.back()))); break;