Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] Remove Program::current_ast_builder() #7075

Merged
merged 15 commits into from
Jan 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/taichi/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,10 @@ def _svd3d(A, dt, iters=None):
else:
iters = 8
if dt == f32:
rets = get_runtime().prog.current_ast_builder().sifakis_svd_f32(
rets = get_runtime().compiling_callable.ast_builder().sifakis_svd_f32(
A.ptr, iters)
else:
rets = get_runtime().prog.current_ast_builder().sifakis_svd_f64(
rets = get_runtime().compiling_callable.ast_builder().sifakis_svd_f64(
A.ptr, iters)
assert len(rets) == 21
U_entries = rets[:9]
Expand Down
8 changes: 4 additions & 4 deletions python/taichi/lang/_texture.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, ptr_expr, num_dims) -> None:

@taichi_scope
def sample_lod(self, uv, lod):
ast_builder = impl.get_runtime().prog.current_ast_builder()
ast_builder = impl.get_runtime().compiling_callable.ast_builder()
args_group = make_expr_group(*_get_entries(uv), lod)
v = ast_builder.make_texture_op_expr(_ti_core.TextureOpType.kSampleLod,
self.ptr_expr, args_group)
Expand All @@ -42,7 +42,7 @@ def sample_lod(self, uv, lod):

@taichi_scope
def fetch(self, index, lod):
ast_builder = impl.get_runtime().prog.current_ast_builder()
ast_builder = impl.get_runtime().compiling_callable.ast_builder()
args_group = make_expr_group(*_get_entries(index), lod)
v = ast_builder.make_texture_op_expr(
_ti_core.TextureOpType.kFetchTexel, self.ptr_expr, args_group)
Expand All @@ -69,7 +69,7 @@ def __init__(self, ptr_expr, num_dims) -> None:

@taichi_scope
def load(self, index):
ast_builder = impl.get_runtime().prog.current_ast_builder()
ast_builder = impl.get_runtime().compiling_callable.ast_builder()
args_group = make_expr_group(*_get_entries(index))
v = ast_builder.make_texture_op_expr(_ti_core.TextureOpType.kLoad,
self.ptr_expr, args_group)
Expand All @@ -89,7 +89,7 @@ def load(self, index):

@taichi_scope
def store(self, index, value):
ast_builder = impl.get_runtime().prog.current_ast_builder()
ast_builder = impl.get_runtime().compiling_callable.ast_builder()
args_group = make_expr_group(*_get_entries(index),
*_get_entries(value))
impl.expr_init(
Expand Down
2 changes: 1 addition & 1 deletion python/taichi/lang/any_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(self, arr, indices_first):

@taichi_scope
def subscript(self, i, j):
ast_builder = impl.get_runtime().prog.current_ast_builder()
ast_builder = impl.get_runtime().compiling_callable.ast_builder()

indices_second = (i, ) if len(self.arr.element_shape()) == 1 else (i,
j)
Expand Down
4 changes: 2 additions & 2 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ def transform_as_kernel():
if node.returns is not None:
kernel_arguments.decl_ret(ctx.func.return_type,
ctx.is_real_function)
impl.get_runtime().prog.finalize_rets()
impl.get_runtime().compiling_callable.finalize_rets()
for i, arg in enumerate(args.args):
if not isinstance(ctx.func.arguments[i].annotation,
primitive_types.RefType):
Expand Down Expand Up @@ -841,7 +841,7 @@ def build_Attribute(ctx, node):
attr_len = len(node.attr)
if attr_len == 1:
node.ptr = Expr(impl.get_runtime(
).prog.current_ast_builder().expr_subscript(
).compiling_callable.ast_builder().expr_subscript(
node.value.ptr.ptr,
make_expr_group(keygroup.index(node.attr)),
impl.get_runtime().get_current_src_info()))
Expand Down
18 changes: 10 additions & 8 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@

@taichi_scope
def expr_init_shared_array(shape, element_type):
return get_runtime().prog.current_ast_builder().expr_alloca_shared_array(
shape, element_type)
return get_runtime().compiling_callable.ast_builder(
).expr_alloca_shared_array(shape, element_type)


@taichi_scope
def expr_init(rhs):
if rhs is None:
return Expr(get_runtime().prog.current_ast_builder().expr_alloca())
return Expr(
get_runtime().compiling_callable.ast_builder().expr_alloca())
if isinstance(rhs, Matrix) and (hasattr(rhs, "_DIM")):
return Matrix(*rhs.to_list(), ndim=rhs.ndim)
if isinstance(rhs, Matrix):
Expand Down Expand Up @@ -69,7 +70,7 @@ def expr_init(rhs):
return rhs
if hasattr(rhs, '_data_oriented'):
return rhs
return Expr(get_runtime().prog.current_ast_builder().expr_var(
return Expr(get_runtime().compiling_callable.ast_builder().expr_var(
Expr(rhs).ptr,
get_runtime().get_current_src_info()))

Expand Down Expand Up @@ -131,7 +132,7 @@ def check_validity(x):

@taichi_scope
def subscript(ast_builder, value, *_indices, skip_reordered=False):
ast_builder = get_runtime().prog.current_ast_builder()
ast_builder = get_runtime().compiling_callable.ast_builder()
# Directly evaluate in Python for non-Taichi types
if not isinstance(
value,
Expand Down Expand Up @@ -282,6 +283,7 @@ def __init__(self, kernels=None):
self.compiled_functions = {}
self.src_info_stack = []
self.inside_kernel = False
self.compiling_callable = None # pointer to instance of lang::Kernel/Function
self.current_kernel = None
self.global_vars = []
self.grad_vars = []
Expand Down Expand Up @@ -842,7 +844,7 @@ def add_separators(_vars):

_vars = add_separators(_vars)
entries = ti_format_list_to_content_entries(_vars)
get_runtime().prog.current_ast_builder().create_print(entries)
get_runtime().compiling_callable.ast_builder().create_print(entries)


@taichi_scope
Expand Down Expand Up @@ -884,7 +886,7 @@ def ti_format(*args, **kwargs):
def ti_assert(cond, msg, extra_args):
# Mostly a wrapper to help us convert from Expr (defined in Python) to
# _ti_core.Expr (defined in C++)
get_runtime().prog.current_ast_builder().create_assert_stmt(
get_runtime().compiling_callable.ast_builder().create_assert_stmt(
Expr(cond).ptr, msg, extra_args)


Expand Down Expand Up @@ -1064,7 +1066,7 @@ def stop_grad(x):
Args:
x (:class:`~taichi.Field`): A field.
"""
get_runtime().prog.current_ast_builder().stop_grad(x.snode.ptr)
get_runtime().compiling_callable.ast_builder().stop_grad(x.snode.ptr)


def current_cfg():
Expand Down
13 changes: 7 additions & 6 deletions python/taichi/lang/kernel_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def decl_scalar_arg(dtype):
is_ref = True
dtype = dtype.tp
dtype = cook_dtype(dtype)
arg_id = impl.get_runtime().prog.decl_scalar_arg(dtype)
arg_id = impl.get_runtime().compiling_callable.insert_scalar_arg(dtype)
return Expr(_ti_core.make_arg_load_expr(arg_id, dtype, is_ref))


Expand All @@ -71,15 +71,16 @@ def decl_sparse_matrix(dtype):
value_type = cook_dtype(dtype)
ptr_type = cook_dtype(u64)
# Treat the sparse matrix argument as a scalar since we only need to pass in the base pointer
arg_id = impl.get_runtime().prog.decl_scalar_arg(ptr_type)
arg_id = impl.get_runtime().compiling_callable.insert_scalar_arg(ptr_type)
return SparseMatrixProxy(
_ti_core.make_arg_load_expr(arg_id, ptr_type, False), value_type)


def decl_ndarray_arg(dtype, dim, element_shape, layout):
dtype = cook_dtype(dtype)
element_dim = len(element_shape)
arg_id = impl.get_runtime().prog.decl_arr_arg(dtype, dim, element_shape)
arg_id = impl.get_runtime().compiling_callable.insert_arr_arg(
dtype, dim, element_shape)
if layout == Layout.AOS:
element_dim = -element_dim
return AnyArray(
Expand All @@ -89,14 +90,14 @@ def decl_ndarray_arg(dtype, dim, element_shape, layout):

def decl_texture_arg(num_dimensions):
# FIXME: texture_arg doesn't have element_shape so better separate them
arg_id = impl.get_runtime().prog.decl_texture_arg(f32)
arg_id = impl.get_runtime().compiling_callable.insert_texture_arg(f32)
return TextureSampler(
_ti_core.make_texture_ptr_expr(arg_id, num_dimensions), num_dimensions)


def decl_rw_texture_arg(num_dimensions, num_channels, channel_format, lod):
# FIXME: texture_arg doesn't have element_shape so better separate them
arg_id = impl.get_runtime().prog.decl_texture_arg(f32)
arg_id = impl.get_runtime().compiling_callable.insert_texture_arg(f32)
return RWTextureAccessor(
_ti_core.make_rw_texture_ptr_expr(arg_id, num_dimensions, num_channels,
channel_format, lod), num_dimensions)
Expand All @@ -114,4 +115,4 @@ def decl_ret(dtype, real_func=False):
[dtype.n, dtype.m], dtype.dtype)
else:
dtype = cook_dtype(dtype)
impl.get_runtime().prog.decl_ret(dtype)
impl.get_runtime().compiling_callable.insert_ret(dtype)
22 changes: 15 additions & 7 deletions python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def __call__(self, *args, **kwargs):
self,
is_kernel=False,
args=args,
ast_builder=impl.get_runtime().prog.current_ast_builder(),
ast_builder=impl.get_runtime().current_kernel.ast_builder(),
is_real_function=self.is_real_function)
ret = transform_tree(tree, ctx)
if not self.is_real_function:
Expand All @@ -261,14 +261,14 @@ def func_call_rvalue(self, key, args):
elif isinstance(args[i],
impl.Expr) and args[i].ptr.is_tensor():
non_template_args.extend([
Expr(x) for x in impl.get_runtime().prog.
current_ast_builder().expand_exprs([args[i].ptr])
Expr(x) for x in impl.get_runtime().compiling_callable.
ast_builder().expand_exprs([args[i].ptr])
])
else:
non_template_args.append(args[i])
non_template_args = impl.make_expr_group(non_template_args,
real_func_arg=True)
func_call = impl.get_runtime().prog.current_ast_builder(
func_call = impl.get_runtime().compiling_callable.ast_builder(
).insert_func_call(self.taichi_functions[key.instance_id],
non_template_args)
if self.return_type is None:
Expand All @@ -288,8 +288,11 @@ def do_compile(self, key, args):
fn = impl.get_runtime().prog.create_function(key)

def func_body():
old_callable = impl.get_runtime().compiling_callable
impl.get_runtime().compiling_callable = fn
ctx.ast_builder = fn.ast_builder()
transform_tree(tree, ctx)
impl.get_runtime().compiling_callable = old_callable

self.taichi_functions[key.instance_id] = fn
self.compiled[key.instance_id] = func_body
Expand Down Expand Up @@ -487,6 +490,10 @@ def __init__(self, _func, autodiff_mode, _classkernel=False):
self.compiled_kernels = {}
self.has_print = False

def ast_builder(self):
assert self.kernel_cpp is not None
return self.kernel_cpp.ast_builder()

def reset(self):
self.runtime = impl.get_runtime()

Expand Down Expand Up @@ -579,8 +586,11 @@ def taichi_ast_generator(kernel_cxx):
"Please check if you have direct/indirect invocation of kernels within kernels. "
"Note that some methods provided by the Taichi standard library may invoke kernels, "
"and please move their invocations to Python-scope.")
self.kernel_cpp = kernel_cxx
self.runtime.inside_kernel = True
self.runtime.current_kernel = self
assert self.runtime.compiling_callable is None
self.runtime.compiling_callable = kernel_cxx
try:
ctx.ast_builder = kernel_cxx.ast_builder()
transform_tree(tree, ctx)
Expand All @@ -592,12 +602,10 @@ def taichi_ast_generator(kernel_cxx):
finally:
self.runtime.inside_kernel = False
self.runtime.current_kernel = None
self.runtime.compiling_callable = None

taichi_kernel = impl.get_runtime().prog.create_kernel(
taichi_ast_generator, kernel_name, self.autodiff_mode)

self.kernel_cpp = taichi_kernel

assert key not in self.runtime.compiled_functions
self.runtime.compiled_functions[key] = self.get_function_body(
taichi_kernel)
Expand Down
12 changes: 7 additions & 5 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def make_matrix(arr, dt=None):
else:
dt = cook_dtype(dt)
return expr.Expr(
impl.get_runtime().prog.current_ast_builder().make_matrix_expr(
impl.get_runtime().compiling_callable.ast_builder().make_matrix_expr(
shape, dt, [expr.Expr(elt).ptr for elt in arr]))


Expand Down Expand Up @@ -1491,8 +1491,9 @@ def __call__(self, *args):
entries += list(x.ravel())
elif isinstance(x, impl.Expr) and x.ptr.is_tensor():
entries += [
impl.Expr(e) for e in impl.get_runtime().prog.
current_ast_builder().expand_exprs([x.ptr])
impl.Expr(e)
for e in impl.get_runtime().compiling_callable.ast_builder(
).expand_exprs([x.ptr])
]
elif isinstance(x, Matrix):
entries += x.entries
Expand Down Expand Up @@ -1605,8 +1606,9 @@ def __call__(self, *args):
entries += x.entries
elif isinstance(x, impl.Expr) and x.ptr.is_tensor():
entries += [
impl.Expr(e) for e in impl.get_runtime().prog.
current_ast_builder().expand_exprs([x.ptr])
impl.Expr(e)
for e in impl.get_runtime().compiling_callable.ast_builder(
).expand_exprs([x.ptr])
]
else:
entries.append(x)
Expand Down
4 changes: 2 additions & 2 deletions python/taichi/lang/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def _TetMesh():
class MeshElementFieldProxy:
def __init__(self, mesh: MeshInstance, element_type: MeshElementType,
entry_expr: impl.Expr):
ast_builder = impl.get_runtime().prog.current_ast_builder()
ast_builder = impl.get_runtime().compiling_callable.ast_builder()

self.mesh = mesh
self.element_type = element_type
Expand Down Expand Up @@ -653,7 +653,7 @@ def ptr(self):

@property
def id(self): # return the global non-reordered index
ast_builder = impl.get_runtime().prog.current_ast_builder()
ast_builder = impl.get_runtime().compiling_callable.ast_builder()
l2g_expr = impl.Expr(
ast_builder.mesh_index_conversion(self.mesh.mesh_ptr,
self.element_type,
Expand Down
Loading