Skip to content

Commit

Permalink
[aot] Test for AOT device capability (#6618)
Browse files Browse the repository at this point in the history
Issue: #

### Brief Summary

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PENGUINLIONG and pre-commit-ci[bot] committed Nov 16, 2022
1 parent 73a08ac commit a0227ca
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 15 deletions.
2 changes: 1 addition & 1 deletion python/taichi/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from taichi.lang._ndarray import *
from taichi.lang._ndrange import ndrange
from taichi.lang._texture import Texture
from taichi.lang.enums import Format, Layout
from taichi.lang.enums import DeviceCapability, Format, Layout
from taichi.lang.exception import *
from taichi.lang.field import *
from taichi.lang.impl import *
Expand Down
34 changes: 33 additions & 1 deletion python/taichi/lang/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,36 @@
SNodeGradType = _ti_core.SNodeGradType
Format = _ti_core.Format

__all__ = ['Layout', 'AutodiffMode', 'SNodeGradType', 'Format']

class DeviceCapability:
spirv_version_1_3 = "spirv_version=1.3"
spirv_version_1_4 = "spirv_version=1.4"
spirv_version_1_5 = "spirv_version=1.5"
spirv_has_int8 = "spirv_has_int8"
spirv_has_int16 = "spirv_has_int16"
spirv_has_int64 = "spirv_has_int64"
spirv_has_float16 = "spirv_has_float16"
spirv_has_float64 = "spirv_has_float64"
spirv_has_atomic_i64 = "spirv_has_atomic_i64"
spirv_has_atomic_float16 = "spirv_has_atomic_float16"
spirv_has_atomic_float16_add = "spirv_has_atomic_float16_add"
spirv_has_atomic_float16_minmax = "spirv_has_atomic_float16_minmax"
spirv_has_atomic_float = "spirv_has_atomic_float"
spirv_has_atomic_float_add = "spirv_has_atomic_float_add"
spirv_has_atomic_float_minmax = "spirv_has_atomic_float_minmax"
spirv_has_atomic_float64 = "spirv_has_atomic_float64"
spirv_has_atomic_float64_add = "spirv_has_atomic_float64_add"
spirv_has_atomic_float64_minmax = "spirv_has_atomic_float64_minmax"
spirv_has_variable_ptr = "spirv_has_variable_ptr"
spirv_has_physical_storage_buffer = "spirv_has_physical_storage_buffer"
spirv_has_subgroup_basic = "spirv_has_subgroup_basic"
spirv_has_subgroup_vote = "spirv_has_subgroup_vote"
spirv_has_subgroup_arithmetic = "spirv_has_subgroup_arithmetic"
spirv_has_subgroup_ballot = "spirv_has_subgroup_ballot"
spirv_has_non_semantic_info = "spirv_has_non_semantic_info"
spirv_has_no_integer_wrap_decoration = "spirv_has_no_integer_wrap_decoration"


__all__ = [
'Layout', 'AutodiffMode', 'SNodeGradType', 'Format', 'DeviceCapability'
]
22 changes: 21 additions & 1 deletion tests/python/test_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,12 +589,32 @@ def test(a: ti.types.ndarray(), c: ti.u8):

g.run({'a': a, 'c': c})

m = ti.aot.Module(caps=['spirv_has_int8'])
m = ti.aot.Module(caps=[ti.DeviceCapability.spirv_has_int8])
m.add_graph('g_init', g)
with tempfile.TemporaryDirectory() as tmpdir:
m.save(tmpdir)


@test_utils.test(arch=[ti.vulkan])
def test_devcap():
module = ti.aot.Module(
ti.vulkan,
caps=[
ti.DeviceCapability.spirv_has_float16,
ti.DeviceCapability.spirv_has_atomic_float16_minmax
])

with tempfile.TemporaryDirectory() as tmpdir:
module.save(tmpdir)

with open(tmpdir + "/metadata.json") as f:
j = json.load(f)
caps = j["aot_data"]["required_caps"]
assert caps["spirv_version"] == 0x10300
assert caps["spirv_has_float16"] == 1
assert caps["spirv_has_atomic_float16_minmax"] == 1


@test_utils.test(arch=[ti.vulkan])
def test_module_arch_fallback():
with pytest.warns(
Expand Down
24 changes: 12 additions & 12 deletions tests/python/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,18 @@ def _get_expected_matrix_apis():

user_api = {}
user_api[ti] = [
'BitpackedFields', 'CRITICAL', 'DEBUG', 'ERROR', 'Field', 'FieldsBuilder',
'Format', 'GUI', 'INFO', 'Layout', 'Matrix', 'MatrixField',
'MatrixNdarray', 'Mesh', 'MeshInstance', 'Ndarray', 'SNode', 'ScalarField',
'ScalarNdarray', 'Struct', 'StructField', 'TRACE', 'TaichiAssertionError',
'TaichiCompilationError', 'TaichiNameError', 'TaichiRuntimeError',
'TaichiRuntimeTypeError', 'TaichiSyntaxError', 'TaichiTypeError',
'Texture', 'Vector', 'VectorNdarray', 'WARN', 'abs', 'acos', 'activate',
'ad', 'algorithms', 'aot', 'append', 'arm64', 'asin', 'assume_in_range',
'atan2', 'atomic_add', 'atomic_and', 'atomic_max', 'atomic_min',
'atomic_or', 'atomic_sub', 'atomic_xor', 'axes', 'bit_cast', 'bit_shr',
'block_local', 'cache_read_only', 'cast', 'cc', 'ceil', 'cos', 'cpu',
'cuda', 'data_oriented', 'dataclass', 'deactivate',
'BitpackedFields', 'CRITICAL', 'DEBUG', "DeviceCapability", 'ERROR',
'Field', 'FieldsBuilder', 'Format', 'GUI', 'INFO', 'Layout', 'Matrix',
'MatrixField', 'MatrixNdarray', 'Mesh', 'MeshInstance', 'Ndarray', 'SNode',
'ScalarField', 'ScalarNdarray', 'Struct', 'StructField', 'TRACE',
'TaichiAssertionError', 'TaichiCompilationError', 'TaichiNameError',
'TaichiRuntimeError', 'TaichiRuntimeTypeError', 'TaichiSyntaxError',
'TaichiTypeError', 'Texture', 'Vector', 'VectorNdarray', 'WARN', 'abs',
'acos', 'activate', 'ad', 'algorithms', 'aot', 'append', 'arm64', 'asin',
'assume_in_range', 'atan2', 'atomic_add', 'atomic_and', 'atomic_max',
'atomic_min', 'atomic_or', 'atomic_sub', 'atomic_xor', 'axes', 'bit_cast',
'bit_shr', 'block_local', 'cache_read_only', 'cast', 'cc', 'ceil', 'cos',
'cpu', 'cuda', 'data_oriented', 'dataclass', 'deactivate',
'deactivate_all_snodes', 'dx11', 'dx12', 'eig', 'exp', 'experimental',
'extension', 'f16', 'f32', 'f64', 'field', 'float16', 'float32', 'float64',
'floor', 'func', 'get_addr', 'get_compute_stream_device_time_elapsed_us',
Expand Down

0 comments on commit a0227ca

Please sign in to comment.