diff --git a/3rdparty/dlpack b/3rdparty/dlpack index 0acb731e0e43..3ec04430e89a 160000 --- a/3rdparty/dlpack +++ b/3rdparty/dlpack @@ -1 +1 @@ -Subproject commit 0acb731e0e43d15deee27b66f10e4c5b4e667913 +Subproject commit 3ec04430e89a6834e5a1b99471f415fa939bf642 diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index bb38ad8a84df..be865635456d 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -87,12 +87,14 @@ typedef enum { } TVMDeviceExtType; /*! - * \brief The type code in used in the TVM FFI. + * \brief The type code in used in the TVM FFI for argument passing. */ typedef enum { // The type code of other types are compatible with DLPack. // The next few fields are extension types // that is used by TVM API calls. + kTVMArgInt = kDLInt, + kTVMArgFloat = kDLFloat, kTVMOpaqueHandle = 3U, kTVMNullptr = 4U, kTVMDataType = 5U, @@ -115,9 +117,7 @@ typedef enum { // The following section of code is used for non-reserved types. kTVMExtReserveEnd = 64U, kTVMExtEnd = 128U, - // The rest of the space is used for custom, user-supplied datatypes - kTVMCustomBegin = 129U, -} TVMTypeCode; +} TVMArgTypeCode; /*! * \brief The Device information, abstract away common device types. diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index a10b83fd321b..1d538105f130 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -45,7 +45,8 @@ class DataType { kInt = kDLInt, kUInt = kDLUInt, kFloat = kDLFloat, - kHandle = TVMTypeCode::kTVMOpaqueHandle, + kHandle = TVMArgTypeCode::kTVMOpaqueHandle, + kCustomBegin = 129 }; /*! \brief default constructor */ DataType() {} @@ -248,7 +249,7 @@ TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan); * \param type_code The type code . * \return The name of type code. */ -inline const char* TypeCode2Str(int type_code); +inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code); /*! * \brief convert a string to TVM type. @@ -265,38 +266,16 @@ inline DLDataType String2DLDataType(std::string s); inline std::string DLDataType2String(DLDataType t); // implementation details -inline const char* TypeCode2Str(int type_code) { - switch (type_code) { +inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) { + switch (static_cast(type_code)) { case kDLInt: return "int"; case kDLUInt: return "uint"; case kDLFloat: return "float"; - case kTVMStr: - return "str"; - case kTVMBytes: - return "bytes"; - case kTVMOpaqueHandle: + case DataType::kHandle: return "handle"; - case kTVMNullptr: - return "NULL"; - case kTVMDLTensorHandle: - return "ArrayHandle"; - case kTVMDataType: - return "DLDataType"; - case kTVMContext: - return "TVMContext"; - case kTVMPackedFuncHandle: - return "FunctionHandle"; - case kTVMModuleHandle: - return "ModuleHandle"; - case kTVMNDArrayHandle: - return "NDArrayContainer"; - case kTVMObjectHandle: - return "Object"; - case kTVMObjectRValueRefArg: - return "ObjectRValueRefArg"; default: LOG(FATAL) << "unknown type_code=" << static_cast(type_code); return ""; @@ -311,8 +290,8 @@ inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*) if (DataType(t).is_void()) { return os << "void"; } - if (t.code < kTVMCustomBegin) { - os << TypeCode2Str(t.code); + if (t.code < DataType::kCustomBegin) { + os << DLDataTypeCode2Str(static_cast(t.code)); } else { os << "custom[" << GetCustomTypeName(t.code) << "]"; } diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 01f8e994347a..e82b97a5a2d4 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -327,9 +327,16 @@ class TVMArgs { inline TVMArgValue operator[](int i) const; }; +/*! + * \brief Convert argument type code to string. + * \param type_code The input type code. + * \return The corresponding string repr. + */ +inline const char* ArgTypeCode2Str(int type_code); + // macro to check type code. #define TVM_CHECK_TYPE_CODE(CODE, T) \ - CHECK_EQ(CODE, T) << " expected " << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) + CHECK_EQ(CODE, T) << " expected " << ArgTypeCode2Str(T) << " but get " << ArgTypeCode2Str(CODE) /*! * \brief Type traits for runtime type check during FFI conversion. @@ -394,7 +401,7 @@ class TVMPODValue_ { } else { if (type_code_ == kTVMNullptr) return nullptr; LOG(FATAL) << "Expect " - << "DLTensor* or NDArray but get " << TypeCode2Str(type_code_); + << "DLTensor* or NDArray but get " << ArgTypeCode2Str(type_code_); return nullptr; } } @@ -982,6 +989,44 @@ inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { body_( inline PackedFunc::FType PackedFunc::body() const { return body_; } // internal namespace +inline const char* ArgTypeCode2Str(int type_code) { + switch (type_code) { + case kDLInt: + return "int"; + case kDLUInt: + return "uint"; + case kDLFloat: + return "float"; + case kTVMStr: + return "str"; + case kTVMBytes: + return "bytes"; + case kTVMOpaqueHandle: + return "handle"; + case kTVMNullptr: + return "NULL"; + case kTVMDLTensorHandle: + return "ArrayHandle"; + case kTVMDataType: + return "DLDataType"; + case kTVMContext: + return "TVMContext"; + case kTVMPackedFuncHandle: + return "FunctionHandle"; + case kTVMModuleHandle: + return "ModuleHandle"; + case kTVMNDArrayHandle: + return "NDArrayContainer"; + case kTVMObjectHandle: + return "Object"; + case kTVMObjectRValueRefArg: + return "ObjectRValueRefArg"; + default: + LOG(FATAL) << "unknown type_code=" << static_cast(type_code); + return ""; + } +} + namespace detail { template diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 5884942ebef1..a4748d525829 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -740,7 +740,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) { // datatypes lowering pass, we will lower the value to its true representation in the format // specified by the datatype. // TODO(gus) when do we need to start worrying about doubles not being precise enough? - if (static_cast(t.code()) >= static_cast(kTVMCustomBegin)) { + if (static_cast(t.code()) >= static_cast(DataType::kCustomBegin)) { return FloatImm(t, static_cast(value)); } LOG(FATAL) << "cannot make const for type " << t; diff --git a/jvm/core/src/main/java/org/apache/tvm/TypeCode.java b/jvm/core/src/main/java/org/apache/tvm/ArgTypeCode.java similarity index 95% rename from jvm/core/src/main/java/org/apache/tvm/TypeCode.java rename to jvm/core/src/main/java/org/apache/tvm/ArgTypeCode.java index 2d21e4afa6b4..b3b3da56e72f 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TypeCode.java +++ b/jvm/core/src/main/java/org/apache/tvm/ArgTypeCode.java @@ -18,14 +18,14 @@ package org.apache.tvm; // Type code used in API calls -public enum TypeCode { +public enum ArgTypeCode { INT(0), UINT(1), FLOAT(2), HANDLE(3), NULL(4), TVM_TYPE(5), TVM_CONTEXT(6), ARRAY_HANDLE(7), NODE_HANDLE(8), MODULE_HANDLE(9), FUNC_HANDLE(10), STR(11), BYTES(12), NDARRAY_CONTAINER(13); public final int id; - private TypeCode(int id) { + private ArgTypeCode(int id) { this.id = id; } diff --git a/jvm/core/src/main/java/org/apache/tvm/Function.java b/jvm/core/src/main/java/org/apache/tvm/Function.java index a9ac70722410..df535a87aa85 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Function.java +++ b/jvm/core/src/main/java/org/apache/tvm/Function.java @@ -80,7 +80,7 @@ private static Function getGlobalFunc(String name, boolean isResident, boolean a * @param isResident Whether this is a resident function in jvm */ Function(long handle, boolean isResident) { - super(TypeCode.FUNC_HANDLE); + super(ArgTypeCode.FUNC_HANDLE); this.handle = handle; this.isResident = isResident; } @@ -187,7 +187,7 @@ public Function pushArg(String arg) { * @return this */ public Function pushArg(NDArrayBase arg) { - int id = arg.isView ? TypeCode.ARRAY_HANDLE.id : TypeCode.NDARRAY_CONTAINER.id; + int id = arg.isView ? ArgTypeCode.ARRAY_HANDLE.id : ArgTypeCode.NDARRAY_CONTAINER.id; Base._LIB.tvmFuncPushArgHandle(arg.handle, id); return this; } @@ -198,7 +198,7 @@ public Function pushArg(NDArrayBase arg) { * @return this */ public Function pushArg(Module arg) { - Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.MODULE_HANDLE.id); + Base._LIB.tvmFuncPushArgHandle(arg.handle, ArgTypeCode.MODULE_HANDLE.id); return this; } @@ -208,7 +208,7 @@ public Function pushArg(Module arg) { * @return this */ public Function pushArg(Function arg) { - Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.FUNC_HANDLE.id); + Base._LIB.tvmFuncPushArgHandle(arg.handle, ArgTypeCode.FUNC_HANDLE.id); return this; } @@ -249,12 +249,12 @@ private static void pushArgToStack(Object arg) { Base._LIB.tvmFuncPushArgBytes((byte[]) arg); } else if (arg instanceof NDArrayBase) { NDArrayBase nd = (NDArrayBase) arg; - int id = nd.isView ? TypeCode.ARRAY_HANDLE.id : TypeCode.NDARRAY_CONTAINER.id; + int id = nd.isView ? ArgTypeCode.ARRAY_HANDLE.id : ArgTypeCode.NDARRAY_CONTAINER.id; Base._LIB.tvmFuncPushArgHandle(nd.handle, id); } else if (arg instanceof Module) { - Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, TypeCode.MODULE_HANDLE.id); + Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, ArgTypeCode.MODULE_HANDLE.id); } else if (arg instanceof Function) { - Base._LIB.tvmFuncPushArgHandle(((Function) arg).handle, TypeCode.FUNC_HANDLE.id); + Base._LIB.tvmFuncPushArgHandle(((Function) arg).handle, ArgTypeCode.FUNC_HANDLE.id); } else if (arg instanceof TVMValue) { TVMValue tvmArg = (TVMValue) arg; switch (tvmArg.typeCode) { diff --git a/jvm/core/src/main/java/org/apache/tvm/Module.java b/jvm/core/src/main/java/org/apache/tvm/Module.java index 1656f8dee6fa..874daa4029dc 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Module.java +++ b/jvm/core/src/main/java/org/apache/tvm/Module.java @@ -45,7 +45,7 @@ private static Function getApi(String name) { } Module(long handle) { - super(TypeCode.MODULE_HANDLE); + super(ArgTypeCode.MODULE_HANDLE); this.handle = handle; } @@ -138,7 +138,7 @@ public String typeKey() { */ public static Module load(String path, String fmt) { TVMValue ret = getApi("ModuleLoadFromFile").pushArg(path).pushArg(fmt).invoke(); - assert ret.typeCode == TypeCode.MODULE_HANDLE; + assert ret.typeCode == ArgTypeCode.MODULE_HANDLE; return ret.asModule(); } diff --git a/jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java b/jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java index 5ac630d3a668..26bb735e1a5b 100644 --- a/jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java +++ b/jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java @@ -27,7 +27,7 @@ public class NDArrayBase extends TVMValue { private boolean isReleased = false; NDArrayBase(long handle, boolean isView) { - super(TypeCode.ARRAY_HANDLE); + super(ArgTypeCode.ARRAY_HANDLE); this.handle = handle; this.isView = isView; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValue.java b/jvm/core/src/main/java/org/apache/tvm/TVMValue.java index 92c7623b2dc1..d30cfcc4f30a 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValue.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValue.java @@ -18,9 +18,9 @@ package org.apache.tvm; public class TVMValue { - public final TypeCode typeCode; + public final ArgTypeCode typeCode; - public TVMValue(TypeCode tc) { + public TVMValue(ArgTypeCode tc) { typeCode = tc; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueBytes.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueBytes.java index 6c7c1c892747..132d88f7622b 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueBytes.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueBytes.java @@ -21,7 +21,7 @@ public class TVMValueBytes extends TVMValue { public final byte[] value; public TVMValueBytes(byte[] value) { - super(TypeCode.BYTES); + super(ArgTypeCode.BYTES); this.value = value; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueDouble.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueDouble.java index d94b011d7e10..9db4c3bb0e8c 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueDouble.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueDouble.java @@ -21,7 +21,7 @@ public class TVMValueDouble extends TVMValue { public final double value; public TVMValueDouble(double value) { - super(TypeCode.FLOAT); + super(ArgTypeCode.FLOAT); this.value = value; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueHandle.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueHandle.java index 8ab7572d1cfd..b91f55e2f59b 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueHandle.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueHandle.java @@ -18,13 +18,13 @@ package org.apache.tvm; /** - * Java class related to TVM handles (TypeCode.HANDLE) + * Java class related to TVM handles (ArgTypeCode.HANDLE) */ public class TVMValueHandle extends TVMValue { public final long value; public TVMValueHandle(long value) { - super(TypeCode.HANDLE); + super(ArgTypeCode.HANDLE); this.value = value; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueLong.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueLong.java index 5dba2fd459f6..8a9b157d3961 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueLong.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueLong.java @@ -21,7 +21,7 @@ public class TVMValueLong extends TVMValue { public final long value; public TVMValueLong(long value) { - super(TypeCode.INT); + super(ArgTypeCode.INT); this.value = value; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueNull.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueNull.java index 03c0ea0dbcd4..8c49ee5b3df5 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueNull.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueNull.java @@ -19,6 +19,6 @@ public class TVMValueNull extends TVMValue { public TVMValueNull() { - super(TypeCode.NULL); + super(ArgTypeCode.NULL); } } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueString.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueString.java index 260803e8e897..46926e7d3fc6 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueString.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueString.java @@ -21,7 +21,7 @@ public class TVMValueString extends TVMValue { public final String value; public TVMValueString(String value) { - super(TypeCode.STR); + super(ArgTypeCode.STR); this.value = value; } diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 6db86555278f..6cbc6d2288ac 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -23,7 +23,7 @@ # top-level alias # tvm._ffi from ._ffi.base import TVMError, __version__ -from ._ffi.runtime_ctypes import TypeCode, DataType +from ._ffi.runtime_ctypes import DataTypeCode, DataType from ._ffi import register_object, register_func, register_extension, get_global_func # top-level alias diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py index 3dbb60715703..359b018f0431 100644 --- a/python/tvm/_ffi/_ctypes/object.py +++ b/python/tvm/_ffi/_ctypes/object.py @@ -18,7 +18,7 @@ """Runtime Object api""" import ctypes from ..base import _LIB, check_call -from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func +from .types import ArgTypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func from .ndarray import _register_ndarray, NDArrayBase @@ -60,12 +60,12 @@ def _return_object(x): obj.handle = handle return obj -RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object -C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func( - _return_object, TypeCode.OBJECT_HANDLE) +RETURN_SWITCH[ArgTypeCode.OBJECT_HANDLE] = _return_object +C_TO_PY_ARG_SWITCH[ArgTypeCode.OBJECT_HANDLE] = _wrap_arg_func( + _return_object, ArgTypeCode.OBJECT_HANDLE) -C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_RVALUE_REF_ARG] = _wrap_arg_func( - _return_object, TypeCode.OBJECT_RVALUE_REF_ARG) +C_TO_PY_ARG_SWITCH[ArgTypeCode.OBJECT_RVALUE_REF_ARG] = _wrap_arg_func( + _return_object, ArgTypeCode.OBJECT_RVALUE_REF_ARG) class PyNativeObject: diff --git a/python/tvm/_ffi/_ctypes/packed_func.py b/python/tvm/_ffi/_ctypes/packed_func.py index b17174a7c6bf..8a2f49a7e6b6 100644 --- a/python/tvm/_ffi/_ctypes/packed_func.py +++ b/python/tvm/_ffi/_ctypes/packed_func.py @@ -26,7 +26,7 @@ from ..runtime_ctypes import DataType, TVMByteArray, TVMContext, ObjectRValueRef from . import ndarray as _nd from .ndarray import NDArrayBase, _make_array -from .types import TVMValue, TypeCode +from .types import TVMValue, ArgTypeCode from .types import TVMPackedCFunc, TVMCFuncFinalizer from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64 from .object import ObjectBase, PyNativeObject, _set_class_object @@ -115,32 +115,32 @@ def _make_tvm_args(args, temp_args): for i, arg in enumerate(args): if isinstance(arg, ObjectBase): values[i].v_handle = arg.handle - type_codes[i] = TypeCode.OBJECT_HANDLE + type_codes[i] = ArgTypeCode.OBJECT_HANDLE elif arg is None: values[i].v_handle = None - type_codes[i] = TypeCode.NULL + type_codes[i] = ArgTypeCode.NULL elif isinstance(arg, NDArrayBase): values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p) - type_codes[i] = (TypeCode.NDARRAY_HANDLE - if not arg.is_view else TypeCode.DLTENSOR_HANDLE) + type_codes[i] = (ArgTypeCode.NDARRAY_HANDLE + if not arg.is_view else ArgTypeCode.DLTENSOR_HANDLE) elif isinstance(arg, PyNativeObject): values[i].v_handle = arg.__tvm_object__.handle - type_codes[i] = TypeCode.OBJECT_HANDLE + type_codes[i] = ArgTypeCode.OBJECT_HANDLE elif isinstance(arg, _nd._TVM_COMPATS): values[i].v_handle = ctypes.c_void_p(arg._tvm_handle) type_codes[i] = arg.__class__._tvm_tcode elif isinstance(arg, Integral): values[i].v_int64 = arg - type_codes[i] = TypeCode.INT + type_codes[i] = ArgTypeCode.INT elif isinstance(arg, Number): values[i].v_float64 = arg - type_codes[i] = TypeCode.FLOAT + type_codes[i] = ArgTypeCode.FLOAT elif isinstance(arg, DataType): values[i].v_str = c_str(str(arg)) - type_codes[i] = TypeCode.STR + type_codes[i] = ArgTypeCode.STR elif isinstance(arg, TVMContext): values[i].v_int64 = _ctx_to_int64(arg) - type_codes[i] = TypeCode.TVM_CONTEXT + type_codes[i] = ArgTypeCode.TVM_CONTEXT elif isinstance(arg, (bytearray, bytes)): # from_buffer only taeks in bytearray. if isinstance(arg, bytes): @@ -155,31 +155,31 @@ def _make_tvm_args(args, temp_args): arr.size = len(arg) values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr)) temp_args.append(arr) - type_codes[i] = TypeCode.BYTES + type_codes[i] = ArgTypeCode.BYTES elif isinstance(arg, string_types): values[i].v_str = c_str(arg) - type_codes[i] = TypeCode.STR + type_codes[i] = ArgTypeCode.STR elif isinstance(arg, (list, tuple, dict, _CLASS_OBJECT_GENERIC)): arg = _FUNC_CONVERT_TO_OBJECT(arg) values[i].v_handle = arg.handle - type_codes[i] = TypeCode.OBJECT_HANDLE + type_codes[i] = ArgTypeCode.OBJECT_HANDLE temp_args.append(arg) elif isinstance(arg, _CLASS_MODULE): values[i].v_handle = arg.handle - type_codes[i] = TypeCode.MODULE_HANDLE + type_codes[i] = ArgTypeCode.MODULE_HANDLE elif isinstance(arg, PackedFuncBase): values[i].v_handle = arg.handle - type_codes[i] = TypeCode.PACKED_FUNC_HANDLE + type_codes[i] = ArgTypeCode.PACKED_FUNC_HANDLE elif isinstance(arg, ctypes.c_void_p): values[i].v_handle = arg - type_codes[i] = TypeCode.HANDLE + type_codes[i] = ArgTypeCode.HANDLE elif isinstance(arg, ObjectRValueRef): values[i].v_handle = ctypes.cast(ctypes.byref(arg.obj.handle), ctypes.c_void_p) - type_codes[i] = TypeCode.OBJECT_RVALUE_REF_ARG + type_codes[i] = ArgTypeCode.OBJECT_RVALUE_REF_ARG elif callable(arg): arg = convert_to_tvm_func(arg) values[i].v_handle = arg.handle - type_codes[i] = TypeCode.PACKED_FUNC_HANDLE + type_codes[i] = ArgTypeCode.PACKED_FUNC_HANDLE temp_args.append(arg) else: raise TypeError("Don't know how to handle type %s" % type(arg)) @@ -240,7 +240,7 @@ def __init_handle_by_constructor__(fconstructor, args): raise get_last_ffi_error() _ = temp_args _ = args - assert ret_tcode.value == TypeCode.OBJECT_HANDLE + assert ret_tcode.value == ArgTypeCode.OBJECT_HANDLE handle = ret_val.v_handle return handle @@ -275,15 +275,15 @@ def _get_global_func(name, allow_missing=False): # setup return handle for function type _object.__init_by_constructor__ = __init_handle_by_constructor__ -RETURN_SWITCH[TypeCode.PACKED_FUNC_HANDLE] = _handle_return_func -RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module -RETURN_SWITCH[TypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True) -C_TO_PY_ARG_SWITCH[TypeCode.PACKED_FUNC_HANDLE] = _wrap_arg_func( - _handle_return_func, TypeCode.PACKED_FUNC_HANDLE) -C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func( - _return_module, TypeCode.MODULE_HANDLE) -C_TO_PY_ARG_SWITCH[TypeCode.DLTENSOR_HANDLE] = lambda x: _make_array(x.v_handle, True, False) -C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True) +RETURN_SWITCH[ArgTypeCode.PACKED_FUNC_HANDLE] = _handle_return_func +RETURN_SWITCH[ArgTypeCode.MODULE_HANDLE] = _return_module +RETURN_SWITCH[ArgTypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True) +C_TO_PY_ARG_SWITCH[ArgTypeCode.PACKED_FUNC_HANDLE] = _wrap_arg_func( + _handle_return_func, ArgTypeCode.PACKED_FUNC_HANDLE) +C_TO_PY_ARG_SWITCH[ArgTypeCode.MODULE_HANDLE] = _wrap_arg_func( + _return_module, ArgTypeCode.MODULE_HANDLE) +C_TO_PY_ARG_SWITCH[ArgTypeCode.DLTENSOR_HANDLE] = lambda x: _make_array(x.v_handle, True, False) +C_TO_PY_ARG_SWITCH[ArgTypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True) _CLASS_MODULE = None _CLASS_PACKED_FUNC = None diff --git a/python/tvm/_ffi/_ctypes/types.py b/python/tvm/_ffi/_ctypes/types.py index 20be30a59b2f..d4e7b362cbe9 100644 --- a/python/tvm/_ffi/_ctypes/types.py +++ b/python/tvm/_ffi/_ctypes/types.py @@ -19,7 +19,7 @@ import ctypes import struct from ..base import py_str, check_call, _LIB -from ..runtime_ctypes import TVMByteArray, TypeCode, TVMContext +from ..runtime_ctypes import TVMByteArray, ArgTypeCode, TVMContext class TVMValue(ctypes.Union): """TVMValue in C API""" @@ -86,21 +86,21 @@ def _ctx_to_int64(ctx): RETURN_SWITCH = { - TypeCode.INT: lambda x: x.v_int64, - TypeCode.FLOAT: lambda x: x.v_float64, - TypeCode.HANDLE: _return_handle, - TypeCode.NULL: lambda x: None, - TypeCode.STR: lambda x: py_str(x.v_str), - TypeCode.BYTES: _return_bytes, - TypeCode.TVM_CONTEXT: _return_context + ArgTypeCode.INT: lambda x: x.v_int64, + ArgTypeCode.FLOAT: lambda x: x.v_float64, + ArgTypeCode.HANDLE: _return_handle, + ArgTypeCode.NULL: lambda x: None, + ArgTypeCode.STR: lambda x: py_str(x.v_str), + ArgTypeCode.BYTES: _return_bytes, + ArgTypeCode.TVM_CONTEXT: _return_context } C_TO_PY_ARG_SWITCH = { - TypeCode.INT: lambda x: x.v_int64, - TypeCode.FLOAT: lambda x: x.v_float64, - TypeCode.HANDLE: _return_handle, - TypeCode.NULL: lambda x: None, - TypeCode.STR: lambda x: py_str(x.v_str), - TypeCode.BYTES: _return_bytes, - TypeCode.TVM_CONTEXT: _return_context + ArgTypeCode.INT: lambda x: x.v_int64, + ArgTypeCode.FLOAT: lambda x: x.v_float64, + ArgTypeCode.HANDLE: _return_handle, + ArgTypeCode.NULL: lambda x: None, + ArgTypeCode.STR: lambda x: py_str(x.v_str), + ArgTypeCode.BYTES: _return_bytes, + ArgTypeCode.TVM_CONTEXT: _return_context } diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index 0da66ac2e034..8c9e413813b9 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -22,7 +22,7 @@ from cpython cimport pycapsule from libc.stdint cimport int32_t, int64_t, uint64_t, uint32_t, uint8_t, uint16_t import ctypes -cdef enum TVMTypeCode: +cdef enum TVMArgTypeCode: kInt = 0 kUInt = 1 kFloat = 2 diff --git a/python/tvm/_ffi/registry.py b/python/tvm/_ffi/registry.py index e4b8b18b4805..0942ccb277a6 100644 --- a/python/tvm/_ffi/registry.py +++ b/python/tvm/_ffi/registry.py @@ -122,7 +122,7 @@ def register_extension(cls, fcreate=None): @tvm.register_extension class MyTensor(object): - _tvm_tcode = tvm.TypeCode.ARRAY_HANDLE + _tvm_tcode = tvm.ArgTypeCode.ARRAY_HANDLE def __init__(self): self.handle = _LIB.NewDLTensor() @@ -132,8 +132,8 @@ def _tvm_handle(self): return self.handle.value """ assert hasattr(cls, "_tvm_tcode") - if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN: - raise ValueError("Cannot register create when extension tcode is same as buildin") + if fcreate: + raise ValueError("Extension with fcreate is no longer supported") _reg_extension(cls, fcreate) return cls diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index db89854a8604..2e498e38cce8 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -23,7 +23,7 @@ tvm_shape_index_t = ctypes.c_int64 -class TypeCode(object): +class ArgTypeCode(object): """Type code used in API calls""" INT = 0 UINT = 1 @@ -42,23 +42,30 @@ class TypeCode(object): OBJECT_RVALUE_REF_ARG = 14 EXT_BEGIN = 15 - class TVMByteArray(ctypes.Structure): """Temp data structure for byte array.""" _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)), ("size", ctypes.c_size_t)] +class DataTypeCode(object): + """DataType code in DLTensor.""" + INT = 0 + UINT = 1 + FLOAT = 2 + HANDLE = 3 + + class DataType(ctypes.Structure): """TVM datatype structure""" _fields_ = [("type_code", ctypes.c_uint8), ("bits", ctypes.c_uint8), ("lanes", ctypes.c_uint16)] CODE2STR = { - 0 : 'int', - 1 : 'uint', - 2 : 'float', - 4 : 'handle' + DataTypeCode.INT : 'int', + DataTypeCode.UINT : 'uint', + DataTypeCode.FLOAT : 'float', + DataTypeCode.HANDLE : 'handle' } def __init__(self, type_str): super(DataType, self).__init__() @@ -67,7 +74,7 @@ def __init__(self, type_str): if type_str == "bool": self.bits = 1 - self.type_code = 1 + self.type_code = DataTypeCode.UINT self.lanes = 1 return @@ -77,16 +84,16 @@ def __init__(self, type_str): bits = 32 if head.startswith("int"): - self.type_code = 0 + self.type_code = DataTypeCode.INT head = head[3:] elif head.startswith("uint"): - self.type_code = 1 + self.type_code = DataTypeCode.UINT head = head[4:] elif head.startswith("float"): - self.type_code = 2 + self.type_code = DataTypeCode.FLOAT head = head[5:] elif head.startswith("handle"): - self.type_code = 4 + self.type_code = DataTypeCode.HANDLE bits = 64 head = "" elif head.startswith("custom"): diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index 59574e68f5c6..21c06c517bd7 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -20,7 +20,7 @@ from .packed_func import PackedFunc from .object import Object from .object_generic import ObjectGeneric, ObjectTypes -from .ndarray import NDArray, DataType, TypeCode, TVMContext +from .ndarray import NDArray, DataType, DataTypeCode, TVMContext from .module import Module # function exposures diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 6629cc6e612c..060673dc19c6 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -22,7 +22,7 @@ from tvm._ffi.base import _LIB, check_call, c_array, string_types, _FFI_MODE from tvm._ffi.runtime_ctypes import DataType, TVMContext, TVMArray, TVMArrayHandle -from tvm._ffi.runtime_ctypes import TypeCode, tvm_shape_index_t +from tvm._ffi.runtime_ctypes import DataTypeCode, tvm_shape_index_t try: # pylint: disable=wrong-import-position diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 4cbece363f71..aca5e5a377fb 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -29,7 +29,7 @@ """ import tvm._ffi -from tvm.runtime import Object, ObjectGeneric, DataType, TypeCode, const +from tvm.runtime import Object, ObjectGeneric, DataType, DataTypeCode, const from tvm.ir import PrimExpr import tvm.ir._ffi_api from . import generic as _generic @@ -47,13 +47,13 @@ def _dtype_is_int(value): if isinstance(value, int): return True return (isinstance(value, ExprOp) and - DataType(value.dtype).type_code == TypeCode.INT) + DataType(value.dtype).type_code == DataTypeCode.INT) def _dtype_is_float(value): if isinstance(value, float): return True return (isinstance(value, ExprOp) and - DataType(value.dtype).type_code == TypeCode.FLOAT) + DataType(value.dtype).type_code == DataTypeCode.FLOAT) class ExprOp(object): """Operator overloading for Expr like expressions.""" diff --git a/rust/common/src/packed_func.rs b/rust/common/src/packed_func.rs index f3bac39b6a10..65434b928269 100644 --- a/rust/common/src/packed_func.rs +++ b/rust/common/src/packed_func.rs @@ -94,52 +94,52 @@ macro_rules! TVMPODValue { DLDataTypeCode_kDLInt => Int($value.v_int64), DLDataTypeCode_kDLUInt => UInt($value.v_int64), DLDataTypeCode_kDLFloat => Float($value.v_float64), - TVMTypeCode_kTVMNullptr => Null, - TVMTypeCode_kTVMDataType => DataType($value.v_type), - TVMTypeCode_kTVMContext => Context($value.v_ctx), - TVMTypeCode_kTVMOpaqueHandle => Handle($value.v_handle), - TVMTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle), - TVMTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle), - TVMTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle), - TVMTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle), - TVMTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle), + TVMArgTypeCode_kTVMNullptr => Null, + TVMArgTypeCode_kTVMDataType => DataType($value.v_type), + TVMArgTypeCode_kTVMContext => Context($value.v_ctx), + TVMArgTypeCode_kTVMOpaqueHandle => Handle($value.v_handle), + TVMArgTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle), + TVMArgTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle), + TVMArgTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle), + TVMArgTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle), + TVMArgTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle), $( $tvm_type => { $from_tvm_type } ),+ _ => unimplemented!("{}", type_code), } } } - pub fn to_tvm_value(&self) -> (TVMValue, TVMTypeCode) { + pub fn to_tvm_value(&self) -> (TVMValue, TVMArgTypeCode) { use $name::*; match self { Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt), UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt), Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat), - Null => (TVMValue{ v_int64: 0 },TVMTypeCode_kTVMNullptr), - DataType(val) => (TVMValue { v_type: *val }, TVMTypeCode_kTVMDataType), - Context(val) => (TVMValue { v_ctx: val.clone() }, TVMTypeCode_kTVMContext), + Null => (TVMValue{ v_int64: 0 },TVMArgTypeCode_kTVMNullptr), + DataType(val) => (TVMValue { v_type: *val }, TVMArgTypeCode_kTVMDataType), + Context(val) => (TVMValue { v_ctx: val.clone() }, TVMArgTypeCode_kTVMContext), String(val) => { ( TVMValue { v_handle: val.as_ptr() as *mut c_void }, - TVMTypeCode_kTVMStr, + TVMArgTypeCode_kTVMStr, ) } - Handle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMOpaqueHandle), + Handle(val) => (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMOpaqueHandle), ArrayHandle(val) => { ( TVMValue { v_handle: *val as *const _ as *mut c_void }, - TVMTypeCode_kTVMNDArrayHandle, + TVMArgTypeCode_kTVMNDArrayHandle, ) }, - ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMObjectHandle), + ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMObjectHandle), ModuleHandle(val) => - (TVMValue { v_handle: *val }, TVMTypeCode_kTVMModuleHandle), + (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMModuleHandle), FuncHandle(val) => ( TVMValue { v_handle: *val }, - TVMTypeCode_kTVMPackedFuncHandle + TVMArgTypeCode_kTVMPackedFuncHandle ), NDArrayHandle(val) => - (TVMValue { v_handle: *val }, TVMTypeCode_kTVMNDArrayHandle), + (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMNDArrayHandle), $( $self_type($val) => { $from_self_type } ),+ } } @@ -155,14 +155,14 @@ TVMPODValue! { Str(&'a CStr), }, match value { - TVMTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) } - TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) } + TVMArgTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) } + TVMArgTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) } }, match &self { Bytes(val) => { - (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes) + (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMArgTypeCode_kTVMBytes) } - Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kTVMStr) } + Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMArgTypeCode_kTVMStr) } } } @@ -188,14 +188,14 @@ TVMPODValue! { Str(&'static CStr), }, match value { - TVMTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) } - TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) } + TVMArgTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) } + TVMArgTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) } }, match &self { Bytes(val) => - { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes ) } + { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMArgTypeCode_kTVMBytes ) } Str(val) => - { (TVMValue { v_str: val.as_ptr() }, TVMTypeCode_kTVMStr ) } + { (TVMValue { v_str: val.as_ptr() }, TVMArgTypeCode_kTVMStr ) } } } diff --git a/rust/frontend/src/function.rs b/rust/frontend/src/function.rs index 8411b03592d1..88d6cc80fe1c 100644 --- a/rust/frontend/src/function.rs +++ b/rust/frontend/src/function.rs @@ -204,7 +204,7 @@ impl<'a, 'm> Builder<'a, 'm> { ensure!(self.func.is_some(), errors::FunctionNotFoundError); let num_args = self.arg_buf.len(); - let (mut values, mut type_codes): (Vec, Vec) = + let (mut values, mut type_codes): (Vec, Vec) = self.arg_buf.iter().map(|arg| arg.to_tvm_value()).unzip(); let mut ret_val = unsafe { MaybeUninit::uninit().assume_init() }; @@ -257,9 +257,9 @@ unsafe extern "C" fn tvm_callback( for i in 0..len { value = args_list[i]; tcode = type_codes_list[i]; - if tcode == ffi::TVMTypeCode_kTVMObjectHandle as c_int - || tcode == ffi::TVMTypeCode_kTVMPackedFuncHandle as c_int - || tcode == ffi::TVMTypeCode_kTVMModuleHandle as c_int + if tcode == ffi::TVMArgTypeCode_kTVMObjectHandle as c_int + || tcode == ffi::TVMArgTypeCode_kTVMPackedFuncHandle as c_int + || tcode == ffi::TVMArgTypeCode_kTVMModuleHandle as c_int { check_call!(ffi::TVMCbArgToReturn( &mut value as *mut _, diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs index e4b27397900c..a326aa1b8fdf 100644 --- a/rust/tvm-sys/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -95,52 +95,52 @@ macro_rules! TVMPODValue { DLDataTypeCode_kDLInt => Int($value.v_int64), DLDataTypeCode_kDLUInt => UInt($value.v_int64), DLDataTypeCode_kDLFloat => Float($value.v_float64), - TVMTypeCode_kTVMNullptr => Null, - TVMTypeCode_kTVMDataType => DataType($value.v_type), - TVMTypeCode_kTVMContext => Context($value.v_ctx), - TVMTypeCode_kTVMOpaqueHandle => Handle($value.v_handle), - TVMTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle), - TVMTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle), - TVMTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle), - TVMTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle), - TVMTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle), + TVMArgTypeCode_kTVMNullptr => Null, + TVMArgTypeCode_kTVMDataType => DataType($value.v_type), + TVMArgTypeCode_kTVMContext => Context($value.v_ctx), + TVMArgTypeCode_kTVMOpaqueHandle => Handle($value.v_handle), + TVMArgTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle), + TVMArgTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle), + TVMArgTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle), + TVMArgTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle), + TVMArgTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle), $( $tvm_type => { $from_tvm_type } ),+ _ => unimplemented!("{}", type_code), } } } - pub fn to_tvm_value(&self) -> (TVMValue, TVMTypeCode) { + pub fn to_tvm_value(&self) -> (TVMValue, TVMArgTypeCode) { use $name::*; match self { Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt), UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt), Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat), - Null => (TVMValue{ v_int64: 0 },TVMTypeCode_kTVMNullptr), - DataType(val) => (TVMValue { v_type: *val }, TVMTypeCode_kTVMDataType), - Context(val) => (TVMValue { v_ctx: val.clone() }, TVMTypeCode_kTVMContext), + Null => (TVMValue{ v_int64: 0 },TVMArgTypeCode_kTVMNullptr), + DataType(val) => (TVMValue { v_type: *val }, TVMArgTypeCode_kTVMDataType), + Context(val) => (TVMValue { v_ctx: val.clone() }, TVMArgTypeCode_kTVMContext), String(val) => { ( TVMValue { v_handle: val.as_ptr() as *mut c_void }, - TVMTypeCode_kTVMStr, + TVMArgTypeCode_kTVMStr, ) } - Handle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMOpaqueHandle), + Handle(val) => (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMOpaqueHandle), ArrayHandle(val) => { ( TVMValue { v_handle: *val as *const _ as *mut c_void }, - TVMTypeCode_kTVMNDArrayHandle, + TVMArgTypeCode_kTVMNDArrayHandle, ) }, - ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMObjectHandle), + ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMObjectHandle), ModuleHandle(val) => - (TVMValue { v_handle: *val }, TVMTypeCode_kTVMModuleHandle), + (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMModuleHandle), FuncHandle(val) => ( TVMValue { v_handle: *val }, - TVMTypeCode_kTVMPackedFuncHandle + TVMArgTypeCode_kTVMPackedFuncHandle ), NDArrayHandle(val) => - (TVMValue { v_handle: *val }, TVMTypeCode_kTVMNDArrayHandle), + (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMNDArrayHandle), $( $self_type($val) => { $from_self_type } ),+ } } @@ -156,14 +156,14 @@ TVMPODValue! { Str(&'a CStr), }, match value { - TVMTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) } - TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) } + TVMArgTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) } + TVMArgTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) } }, match &self { Bytes(val) => { - (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes) + (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMArgTypeCode_kTVMBytes) } - Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kTVMStr) } + Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMArgTypeCode_kTVMStr) } } } @@ -189,14 +189,14 @@ TVMPODValue! { Str(&'static CStr), }, match value { - TVMTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) } - TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) } + TVMArgTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) } + TVMArgTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) } }, match &self { Bytes(val) => - { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes ) } + { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMArgTypeCode_kTVMBytes ) } Str(val) => - { (TVMValue { v_str: val.as_ptr() }, TVMTypeCode_kTVMStr ) } + { (TVMValue { v_str: val.as_ptr() }, TVMArgTypeCode_kTVMStr ) } } } diff --git a/src/runtime/micro/standalone/utvm_graph_runtime.cc b/src/runtime/micro/standalone/utvm_graph_runtime.cc index db55634de66a..e19ee347a45e 100644 --- a/src/runtime/micro/standalone/utvm_graph_runtime.cc +++ b/src/runtime/micro/standalone/utvm_graph_runtime.cc @@ -327,7 +327,7 @@ std::function CreateTVMOp(const DSOModule& module, const TVMOpParam& par } TVMValue; /*typedef*/ enum { kTVMDLTensorHandle = 7U, - } /*TVMTypeCode*/; + } /*TVMArgTypeCode*/; struct OpArgs { DynArray args; DynArray arg_values; diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 8c462698f648..b9cdc2cf82ad 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -251,7 +251,7 @@ void* RPCWrappedFunc::UnwrapRemoteValueToHandle(const TVMArgValue& arg) const { << "ValueError: Cannot pass in module into a different remote session"; return rmod->module_handle(); } else { - LOG(FATAL) << "ValueError: Cannot pass type " << runtime::TypeCode2Str(arg.type_code()) + LOG(FATAL) << "ValueError: Cannot pass type " << runtime::ArgTypeCode2Str(arg.type_code()) << " as an argument to the remote"; return nullptr; } diff --git a/src/target/datatype/registry.cc b/src/target/datatype/registry.cc index 99d6bee60975..5ed3ce4f7c03 100644 --- a/src/target/datatype/registry.cc +++ b/src/target/datatype/registry.cc @@ -49,8 +49,8 @@ Registry* Registry::Global() { } void Registry::Register(const std::string& type_name, uint8_t type_code) { - CHECK(type_code >= kTVMCustomBegin) - << "Please choose a type code >= kTVMCustomBegin for custom types"; + CHECK(type_code >= DataType::kCustomBegin) + << "Please choose a type code >= DataType::kCustomBegin for custom types"; code_to_name_[type_code] = type_name; name_to_code_[type_name] = type_code; } @@ -78,7 +78,7 @@ const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t t if (datatype::Registry::Global()->GetTypeRegistered(type_code)) { ss << datatype::Registry::Global()->GetTypeName(type_code); } else { - ss << runtime::TypeCode2Str(type_code); + ss << runtime::DLDataTypeCode2Str(static_cast(type_code)); } ss << "."; @@ -86,7 +86,7 @@ const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t t if (datatype::Registry::Global()->GetTypeRegistered(src_type_code)) { ss << datatype::Registry::Global()->GetTypeName(src_type_code); } else { - ss << runtime::TypeCode2Str(src_type_code); + ss << runtime::DLDataTypeCode2Str(static_cast(src_type_code)); } return runtime::Registry::Get(ss.str()); } diff --git a/src/target/datatype/registry.h b/src/target/datatype/registry.h index c04359208e64..5df8ef8164db 100644 --- a/src/target/datatype/registry.h +++ b/src/target/datatype/registry.h @@ -61,7 +61,7 @@ class Registry { * same code. Generally, this should be straightforward, as the user will be manually registering * all of their custom types. * \param type_name The name of the type, e.g. "bfloat" - * \param type_code The type code, which should be greater than TVMTypeCode::kTVMExtEnd + * \param type_code The type code, which should be greater than TVMArgTypeCode::kTVMExtEnd */ void Register(const std::string& type_name, uint8_t type_code); diff --git a/tests/python/unittest/test_runtime_extension.py b/tests/python/unittest/test_runtime_extension.py index 48eaf7dd306b..2207eb3a73fa 100644 --- a/tests/python/unittest/test_runtime_extension.py +++ b/tests/python/unittest/test_runtime_extension.py @@ -18,9 +18,10 @@ from tvm import te import numpy as np + @tvm.register_extension class MyTensorView(object): - _tvm_tcode = tvm.TypeCode.DLTENSOR_HANDLE + _tvm_tcode = tvm._ffi.runtime_ctypes.ArgTypeCode.DLTENSOR_HANDLE def __init__(self, arr): self.arr = arr diff --git a/tests/python/unittest/test_runtime_ndarray.py b/tests/python/unittest/test_runtime_ndarray.py index e3143794cc34..36312959da3d 100644 --- a/tests/python/unittest/test_runtime_ndarray.py +++ b/tests/python/unittest/test_runtime_ndarray.py @@ -72,6 +72,13 @@ def test_fp16_conversion(): tvm.testing.assert_allclose(expected, real) + +def test_dtype(): + dtype = tvm.DataType("handle") + assert dtype.type_code == tvm.DataTypeCode.HANDLE + + if __name__ == "__main__": test_nd_create() test_fp16_conversion() + test_dtype() diff --git a/web/src/ctypes.ts b/web/src/ctypes.ts index f533b4e491a6..66c46fe7ed91 100644 --- a/web/src/ctypes.ts +++ b/web/src/ctypes.ts @@ -208,9 +208,9 @@ export const enum SizeOf { } /** - * Type code in TVM FFI. + * Argument Type code in TVM FFI. */ -export const enum TypeCode { +export const enum ArgTypeCode { Int = 0, UInt = 1, Float = 2, @@ -226,4 +226,4 @@ export const enum TypeCode { TVMBytes = 12, TVMNDArrayHandle = 13, TVMObjectRValueRefArg = 14 -} \ No newline at end of file +} diff --git a/web/src/rpc_server.ts b/web/src/rpc_server.ts index 50227dc79281..542558aa157f 100644 --- a/web/src/rpc_server.ts +++ b/web/src/rpc_server.ts @@ -17,7 +17,7 @@ * under the License. */ -import { SizeOf, TypeCode } from "./ctypes"; +import { SizeOf, ArgTypeCode } from "./ctypes"; import { assert, StringToUint8Array, Uint8ArrayToString } from "./support"; import { detectGPUDevice } from "./webgpu"; import * as compact from "./compact"; @@ -216,10 +216,10 @@ export class RPCServer { for (let i = 0; i < nargs; ++i) { const tcode = tcodes[i]; - if (tcode == TypeCode.TVMStr) { + if (tcode == ArgTypeCode.TVMStr) { const str = Uint8ArrayToString(reader.readByteArray()); args.push(str); - } else if (tcode == TypeCode.TVMBytes) { + } else if (tcode == ArgTypeCode.TVMBytes) { args.push(reader.readByteArray()); } else { throw new Error("cannot support type code " + tcode); diff --git a/web/src/runtime.ts b/web/src/runtime.ts index bcf7be7d5544..5c9b9d8181d7 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -20,7 +20,7 @@ /** * TVM JS Wasm Runtime library. */ -import { Pointer, PtrOffset, SizeOf, TypeCode } from "./ctypes"; +import { Pointer, PtrOffset, SizeOf, ArgTypeCode } from "./ctypes"; import { Disposable } from "./types"; import { Memory, CachedCallStack } from "./memory"; import { assert, StringToUint8Array } from "./support"; @@ -234,12 +234,21 @@ export class DLContext { ); } } +/** + * The data type code in DLDataType + */ +export const enum DLDataTypeCode { + Int = 0, + UInt = 1, + Float = 2, + OpaqueHandle = 3 +} const DLDataTypeCodeToStr: Record = { 0: "int", 1: "uint", 2: "float", - 4: "handle", + 3: "handle", }; /** @@ -866,16 +875,16 @@ export class Instance implements Disposable { lanes = 1; if (pattern.substring(0, 5) == "float") { pattern = pattern.substring(5, pattern.length); - code = TypeCode.Float; + code = DLDataTypeCode.Float; } else if (pattern.substring(0, 3) == "int") { pattern = pattern.substring(3, pattern.length); - code = TypeCode.Int; + code = DLDataTypeCode.Int; } else if (pattern.substring(0, 4) == "uint") { pattern = pattern.substring(4, pattern.length); - code = TypeCode.UInt; + code = DLDataTypeCode.UInt; } else if (pattern.substring(0, 6) == "handle") { pattern = pattern.substring(5, pattern.length); - code = TypeCode.TVMOpaqueHandle; + code = DLDataTypeCode.OpaqueHandle; bits = 64; } else { throw new Error("Unknown dtype " + dtype); @@ -1140,47 +1149,47 @@ export class Instance implements Disposable { const codeOffset = argsCode + i * SizeOf.I32; if (val instanceof NDArray) { stack.storePtr(valueOffset, val.handle); - stack.storeI32(codeOffset, TypeCode.TVMNDArrayHandle); + stack.storeI32(codeOffset, ArgTypeCode.TVMNDArrayHandle); } else if (val instanceof Scalar) { if (val.dtype.startsWith("int") || val.dtype.startsWith("uint")) { stack.storeI64(valueOffset, val.value); - stack.storeI32(codeOffset, TypeCode.Int); + stack.storeI32(codeOffset, ArgTypeCode.Int); } else if (val.dtype.startsWith("float")) { stack.storeF64(valueOffset, val.value); - stack.storeI32(codeOffset, TypeCode.Float); + stack.storeI32(codeOffset, ArgTypeCode.Float); } else { assert(val.dtype == "handle", "Expect handle"); stack.storePtr(valueOffset, val.value); - stack.storeI32(codeOffset, TypeCode.TVMOpaqueHandle); + stack.storeI32(codeOffset, ArgTypeCode.TVMOpaqueHandle); } } else if (val instanceof DLContext) { stack.storeI32(valueOffset, val.deviceType); stack.storeI32(valueOffset + SizeOf.I32, val.deviceType); - stack.storeI32(codeOffset, TypeCode.TVMContext); + stack.storeI32(codeOffset, ArgTypeCode.TVMContext); } else if (tp == "number") { stack.storeF64(valueOffset, val); - stack.storeI32(codeOffset, TypeCode.Float); + stack.storeI32(codeOffset, ArgTypeCode.Float); // eslint-disable-next-line no-prototype-builtins } else if (tp == "function" && val.hasOwnProperty("_tvmPackedCell")) { stack.storePtr(valueOffset, val._tvmPackedCell.handle); - stack.storeI32(codeOffset, TypeCode.TVMPackedFuncHandle); + stack.storeI32(codeOffset, ArgTypeCode.TVMPackedFuncHandle); } else if (val === null || val == undefined) { stack.storePtr(valueOffset, 0); - stack.storeI32(codeOffset, TypeCode.Null); + stack.storeI32(codeOffset, ArgTypeCode.Null); } else if (tp == "string") { stack.allocThenSetArgString(valueOffset, val); - stack.storeI32(codeOffset, TypeCode.TVMStr); + stack.storeI32(codeOffset, ArgTypeCode.TVMStr); } else if (val instanceof Uint8Array) { stack.allocThenSetArgBytes(valueOffset, val); - stack.storeI32(codeOffset, TypeCode.TVMBytes); + stack.storeI32(codeOffset, ArgTypeCode.TVMBytes); } else if (val instanceof Function) { val = this.toPackedFunc(val); stack.tempArgs.push(val); stack.storePtr(valueOffset, val._tvmPackedCell.handle); - stack.storeI32(codeOffset, TypeCode.TVMPackedFuncHandle); + stack.storeI32(codeOffset, ArgTypeCode.TVMPackedFuncHandle); } else if (val instanceof Module) { stack.storePtr(valueOffset, val.handle); - stack.storeI32(codeOffset, TypeCode.TVMModuleHandle); + stack.storeI32(codeOffset, ArgTypeCode.TVMModuleHandle); } else { throw new Error("Unsupported argument type " + tp); } @@ -1204,10 +1213,10 @@ export class Instance implements Disposable { let tcode = lib.memory.loadI32(codePtr); if ( - tcode == TypeCode.TVMObjectHandle || - tcode == TypeCode.TVMObjectRValueRefArg || - tcode == TypeCode.TVMPackedFuncHandle || - tcode == TypeCode.TVMModuleHandle + tcode == ArgTypeCode.TVMObjectHandle || + tcode == ArgTypeCode.TVMObjectRValueRefArg || + tcode == ArgTypeCode.TVMPackedFuncHandle || + tcode == ArgTypeCode.TVMModuleHandle ) { lib.checkCall( (lib.exports.TVMCbArgToReturn as ctypes.FTVMCbArgToReturn)( @@ -1290,25 +1299,25 @@ export class Instance implements Disposable { private retValueToJS(rvaluePtr: Pointer, tcode: number, callbackArg: boolean): any { switch (tcode) { - case TypeCode.Int: - case TypeCode.UInt: + case ArgTypeCode.Int: + case ArgTypeCode.UInt: return this.memory.loadI64(rvaluePtr); - case TypeCode.Float: + case ArgTypeCode.Float: return this.memory.loadF64(rvaluePtr); - case TypeCode.TVMOpaqueHandle: { + case ArgTypeCode.TVMOpaqueHandle: { return this.memory.loadPointer(rvaluePtr); } - case TypeCode.TVMNDArrayHandle: { + case ArgTypeCode.TVMNDArrayHandle: { return new NDArray(this.memory.loadPointer(rvaluePtr), false, this.lib); } - case TypeCode.TVMDLTensorHandle: { + case ArgTypeCode.TVMDLTensorHandle: { assert(callbackArg); return new NDArray(this.memory.loadPointer(rvaluePtr), true, this.lib); } - case TypeCode.TVMPackedFuncHandle: { + case ArgTypeCode.TVMPackedFuncHandle: { return this.makePackedFunc(this.memory.loadPointer(rvaluePtr)); } - case TypeCode.TVMModuleHandle: { + case ArgTypeCode.TVMModuleHandle: { return new Module( this.memory.loadPointer(rvaluePtr), this.lib, @@ -1317,17 +1326,17 @@ export class Instance implements Disposable { } ); } - case TypeCode.Null: return undefined; - case TypeCode.TVMContext: { + case ArgTypeCode.Null: return undefined; + case ArgTypeCode.TVMContext: { const deviceType = this.memory.loadI32(rvaluePtr); const deviceId = this.memory.loadI32(rvaluePtr + SizeOf.I32); return this.context(deviceType, deviceId); } - case TypeCode.TVMStr: { + case ArgTypeCode.TVMStr: { const ret = this.memory.loadCString(this.memory.loadPointer(rvaluePtr)); return ret; } - case TypeCode.TVMBytes: { + case ArgTypeCode.TVMBytes: { return this.memory.loadTVMBytes(this.memory.loadPointer(rvaluePtr)); } default: