Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AOT] Adjusted C-API for nd-array type conformance #5417

Merged
merged 2 commits into from
Jul 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions c_api/include/taichi/taichi_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,25 @@ typedef enum TiArch {
TI_ARCH_MAX_ENUM = 0xffffffff,
} TiArch;

// enumeration.data_type
typedef enum TiDataType {
TI_DATA_TYPE_F16 = 0,
TI_DATA_TYPE_F32 = 1,
TI_DATA_TYPE_F64 = 2,
TI_DATA_TYPE_I8 = 3,
TI_DATA_TYPE_I16 = 4,
TI_DATA_TYPE_I32 = 5,
TI_DATA_TYPE_I64 = 6,
TI_DATA_TYPE_U1 = 7,
TI_DATA_TYPE_U8 = 8,
TI_DATA_TYPE_U16 = 9,
TI_DATA_TYPE_U32 = 10,
TI_DATA_TYPE_U64 = 11,
TI_DATA_TYPE_GEN = 12,
TI_DATA_TYPE_UNKNOWN = 13,
TI_DATA_TYPE_MAX_ENUM = 0xffffffff,
} TiDataType;

// enumeration.argument_type
typedef enum TiArgumentType {
TI_ARGUMENT_TYPE_I32 = 0,
Expand Down Expand Up @@ -96,6 +115,7 @@ typedef struct TiNdArray {
TiMemory memory;
TiNdShape shape;
TiNdShape elem_shape;
TiDataType elem_type;
} TiNdArray;

// union.argument_value
Expand Down
46 changes: 44 additions & 2 deletions c_api/src/taichi_core_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,50 @@ void ti_launch_compute_graph(TiRuntime runtime,
ndarray.elem_shape.dims,
ndarray.elem_shape.dims + ndarray.elem_shape.dim_count);

ndarrays.emplace_back(taichi::lang::Ndarray(
devalloc, taichi::lang::PrimitiveType::f32, shape, elem_shape));
const taichi::lang::DataType *prim_ty;
switch (ndarray.elem_type) {
case TI_DATA_TYPE_F16:
prim_ty = &taichi::lang::PrimitiveType::f16;
break;
case TI_DATA_TYPE_F32:
prim_ty = &taichi::lang::PrimitiveType::f32;
break;
case TI_DATA_TYPE_F64:
prim_ty = &taichi::lang::PrimitiveType::f64;
break;
case TI_DATA_TYPE_I8:
prim_ty = &taichi::lang::PrimitiveType::i8;
break;
case TI_DATA_TYPE_I16:
prim_ty = &taichi::lang::PrimitiveType::i16;
break;
case TI_DATA_TYPE_I32:
prim_ty = &taichi::lang::PrimitiveType::i32;
break;
case TI_DATA_TYPE_I64:
prim_ty = &taichi::lang::PrimitiveType::i64;
break;
case TI_DATA_TYPE_U8:
prim_ty = &taichi::lang::PrimitiveType::u8;
break;
case TI_DATA_TYPE_U16:
prim_ty = &taichi::lang::PrimitiveType::u16;
break;
case TI_DATA_TYPE_U32:
prim_ty = &taichi::lang::PrimitiveType::u32;
break;
case TI_DATA_TYPE_U64:
prim_ty = &taichi::lang::PrimitiveType::u64;
break;
case TI_DATA_TYPE_GEN:
prim_ty = &taichi::lang::PrimitiveType::gen;
break;
default:
TI_ERROR("unexpected data type");
}

ndarrays.emplace_back(
taichi::lang::Ndarray(devalloc, *prim_ty, shape, elem_shape));
arg_map.emplace(std::make_pair(
arg.name, taichi::lang::aot::IValue::create(ndarrays.back())));
break;
Expand Down
9 changes: 9 additions & 0 deletions c_api/taichi.json
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@
"type": "enumeration",
"inc_cases": "archs"
},
{
"name": "data_type",
"type": "enumeration",
"inc_cases": "data_type"
},
{
"name": "argument_type",
"type": "enumeration",
Expand Down Expand Up @@ -162,6 +167,10 @@
{
"name": "elem_shape",
"type": "structure.nd_shape"
},
{
"name": "elem_type",
"type": "enumeration.data_type"
}
]
},
Expand Down
4 changes: 2 additions & 2 deletions misc/generate_unity_language_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def print_module_header(module):
"using System.Runtime.InteropServices;",
"using System.Collections.Generic;",
"",
"namespace Taichi {",
"namespace Taichi.Generated {",
]

for x in module.declr_reg:
Expand All @@ -251,7 +251,7 @@ def print_module_header(module):

out += [
"",
"} // namespace Taichi",
"} // namespace Taichi.Generated",
"",
]

Expand Down