Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] Replace internal representation of Python-scope ti.Matrix with numpy arrays #7559

Merged
merged 7 commits into from
Mar 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections import ChainMap
from sys import version_info

import numpy as np
from taichi._lib import core as _ti_core
from taichi.lang import (_ndarray, any_array, expr, impl, kernel_arguments,
matrix, mesh)
Expand All @@ -18,7 +19,7 @@
TaichiTypeError, handle_exception_from_cpp)
from taichi.lang.expr import Expr, make_expr_group
from taichi.lang.field import Field
from taichi.lang.matrix import Matrix, MatrixType, Vector, is_vector
from taichi.lang.matrix import Matrix, MatrixType, Vector
from taichi.lang.snode import append, deactivate, length
from taichi.lang.struct import Struct, StructType
from taichi.lang.util import is_taichi_class, to_taichi_type
Expand Down Expand Up @@ -746,7 +747,7 @@ def build_Return(ctx, node):
values = node.value.ptr
if isinstance(values, Matrix):
values = itertools.chain.from_iterable(values.to_list()) if\
not is_vector(values) else iter(values.to_list())
values.ndim == 1 else iter(values.to_list())
else:
values = [values]
ctx.ast_builder.create_kernel_exprgroup_return(
Expand Down Expand Up @@ -1021,7 +1022,7 @@ def build_Compare(ctx, node):
f'"{type(node_op).__name__}" is not supported in Taichi kernels.'
)
val = ti_ops.bit_and(val, op(l, r))
if not isinstance(val, bool):
if not isinstance(val, (bool, np.bool_)):
val = ti_ops.cast(val, primitive_types.i32)
node.ptr = val
return node.ptr
Expand Down
10 changes: 4 additions & 6 deletions python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from taichi.lang.common_ops import TaichiOperations
from taichi.lang.exception import TaichiCompilationError, TaichiTypeError
from taichi.lang.matrix import make_matrix
from taichi.lang.util import is_taichi_class, to_numpy_type
from taichi.lang.util import is_matrix_class, is_taichi_class, to_numpy_type
from taichi.types import primitive_types
from taichi.types.primitive_types import integer_types, real_types

Expand All @@ -20,10 +20,8 @@ def __init__(self, *args, tb=None, dtype=None):
elif isinstance(args[0], Expr):
self.ptr = args[0].ptr
self.tb = args[0].tb
elif is_taichi_class(args[0]):
raise TaichiTypeError(
'Cannot initialize scalar expression from '
f'taichi class: {type(args[0])}')
elif is_matrix_class(args[0]):
self.ptr = make_matrix(args[0].to_list()).ptr
elif isinstance(args[0], (list, tuple)):
self.ptr = make_matrix(args[0]).ptr
else:
Expand Down Expand Up @@ -97,7 +95,7 @@ def _clamp_unsigned_to_range(npty, val):


def make_constant_expr(val, dtype):
if isinstance(val, bool):
if isinstance(val, (bool, np.bool_)):
constant_dtype = primitive_types.i32
return Expr(_ti_core.make_const_expr_int(constant_dtype, val))

Expand Down
13 changes: 6 additions & 7 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from types import FunctionType, MethodType
from typing import Any, Iterable, Sequence

import numpy as np
from taichi._lib import core as _ti_core
from taichi._snode.fields_builder import FieldsBuilder
from taichi.lang._ndarray import ScalarNdarray
Expand Down Expand Up @@ -43,11 +44,7 @@ def expr_init(rhs):
if isinstance(rhs, Matrix) and (hasattr(rhs, "_DIM")):
return Matrix(*rhs.to_list(), ndim=rhs.ndim)
if isinstance(rhs, Matrix):
if rhs.ndim == 1:
entries = [rhs(i) for i in range(rhs.n)]
else:
entries = [[rhs(i, j) for j in range(rhs.m)] for i in range(rhs.n)]
return make_matrix(entries)
return make_matrix(rhs.to_list())
if isinstance(rhs, SharedArray):
return rhs
if isinstance(rhs, Struct):
Expand Down Expand Up @@ -167,8 +164,8 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):

flattened_indices = []
for _index in _indices:
if is_taichi_class(_index):
ind = _index.entries
if isinstance(_index, Matrix):
ind = _index.to_list()
elif isinstance(_index, slice):
ind = [_index]
has_slice = True
Expand Down Expand Up @@ -1068,6 +1065,8 @@ def static(x, *xs) -> Any:
(bool, int, float, range, list, tuple, enumerate,
GroupedNDRange, _Ndrange, zip, filter, map)) or x is None:
return x
if isinstance(x, (np.bool_, np.integer, np.floating)):
return x

if isinstance(x, AnyArray):
return x
Expand Down
38 changes: 21 additions & 17 deletions python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
TaichiTypeError, handle_exception_from_cpp)
from taichi.lang.expr import Expr
from taichi.lang.kernel_arguments import KernelArgument
from taichi.lang.matrix import Matrix, MatrixType
from taichi.lang.matrix import Matrix, MatrixType, Vector
from taichi.lang.shell import _shell_pop_print
from taichi.lang.struct import StructType
from taichi.lang.util import (cook_dtype, has_paddle, has_pytorch,
Expand Down Expand Up @@ -662,12 +662,13 @@ def func__(*args):
provided = type(v)
# Note: do not use sth like "needed == f32". That would be slow.
if id(needed) in primitive_types.real_type_ids:
if not isinstance(v, (float, int)):
if not isinstance(v,
(float, int, np.floating, np.integer)):
raise TaichiRuntimeTypeError.get(
i, needed.to_string(), provided)
launch_ctx.set_arg_float(actual_argument_slot, float(v))
elif id(needed) in primitive_types.integer_type_ids:
if not isinstance(v, int):
if not isinstance(v, (int, np.integer)):
raise TaichiRuntimeTypeError.get(
i, needed.to_string(), provided)
if is_signed(cook_dtype(needed)):
Expand Down Expand Up @@ -803,7 +804,9 @@ def call_back():
for a in range(needed.n):
for b in range(needed.m):
val = v[a, b] if needed.ndim == 2 else v[a]
if not isinstance(val, (int, float)):
if not isinstance(
val,
(int, float, np.integer, np.floating)):
raise TaichiRuntimeTypeError.get(
i, needed.dtype.to_string(), type(val))
launch_ctx.set_arg_float(
Expand All @@ -813,7 +816,7 @@ def call_back():
for a in range(needed.n):
for b in range(needed.m):
val = v[a, b] if needed.ndim == 2 else v[a]
if not isinstance(val, int):
if not isinstance(val, (int, np.integer)):
raise TaichiRuntimeTypeError.get(
i, needed.dtype.to_string(), type(val))
if is_signed(needed.dtype):
Expand Down Expand Up @@ -872,19 +875,20 @@ def call_back():
ret = t_kernel.get_ret_uint(0)
elif id(ret_dt) in primitive_types.real_type_ids:
ret = t_kernel.get_ret_float(0)
elif id(ret_dt.dtype) in primitive_types.integer_type_ids:
if is_signed(cook_dtype(ret_dt.dtype)):
it = iter(t_kernel.get_ret_int_tensor(0))
else:
it = iter(t_kernel.get_ret_uint_tensor(0))
ret = Matrix([[next(it) for _ in range(ret_dt.m)]
for _ in range(ret_dt.n)],
ndim=getattr(ret_dt, 'ndim', 2))
else:
it = iter(t_kernel.get_ret_float_tensor(0))
ret = Matrix([[next(it) for _ in range(ret_dt.m)]
for _ in range(ret_dt.n)],
ndim=getattr(ret_dt, 'ndim', 2))
if id(ret_dt.dtype
) in primitive_types.integer_type_ids:
if is_signed(cook_dtype(ret_dt.dtype)):
it = iter(t_kernel.get_ret_int_tensor(0))
else:
it = iter(t_kernel.get_ret_uint_tensor(0))
else:
it = iter(t_kernel.get_ret_float_tensor(0))
if ret_dt.ndim == 1:
ret = Vector([next(it) for _ in range(ret_dt.n)])
else:
ret = Matrix([[next(it) for _ in range(ret_dt.m)]
for _ in range(ret_dt.n)])
if callbacks:
for c in callbacks:
c()
Expand Down
Loading