Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[lang] Hide internal apis about Fields #4302

Merged
merged 25 commits into from
Feb 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions python/taichi/lang/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ def __init__(self, _vars):
def snode(self):
"""Gets representative SNode for info purposes.

Returns:
SNode: Representative SNode (SNode of first field member).
"""
return self._snode

@property
def _snode(self):
"""Gets representative SNode for info purposes.

Returns:
SNode: Representative SNode (SNode of first field member).
"""
Expand All @@ -35,7 +44,7 @@ def shape(self):
Returns:
Tuple[Int]: Field shape.
"""
return self.snode.shape
return self._snode.shape

@property
def dtype(self):
Expand All @@ -44,16 +53,16 @@ def dtype(self):
Returns:
DataType: Data type of each individual value.
"""
return self.snode._dtype
return self._snode._dtype

@property
def name(self):
def _name(self):
"""Gets field name.

Returns:
str: Field name.
"""
return self.snode._name
return self._snode._name

def parent(self, n=1):
"""Gets an ancestor of the representative SNode in the SNode tree.
Expand All @@ -66,7 +75,7 @@ def parent(self, n=1):
"""
return self.snode.parent(n)

def get_field_members(self):
def _get_field_members(self):
"""Gets field members.

Returns:
Expand All @@ -82,7 +91,7 @@ def _loop_range(self):
"""
return self.vars[0].ptr

def set_grad(self, grad):
def _set_grad(self, grad):
"""Sets corresponding gradient field.

Args:
Expand Down Expand Up @@ -187,27 +196,27 @@ def __getitem__(self, key):
def __str__(self):
if taichi.lang.impl.inside_kernel():
return self.__repr__() # make pybind11 happy, see Matrix.__str__
if self.snode.ptr is None:
if self._snode.ptr is None:
return '<Field: Definition of this field is incomplete>'
return str(self.to_numpy())

def pad_key(self, key):
def _pad_key(self, key):
if key is None:
key = ()
if not isinstance(key, (tuple, list)):
key = (key, )
assert len(key) == len(self.shape)
return key + ((0, ) * (_ti_core.get_max_num_indices() - len(key)))

def initialize_host_accessors(self):
def _initialize_host_accessors(self):
if self.host_accessors:
return
taichi.lang.impl.get_runtime().materialize()
self.host_accessors = [
SNodeHostAccessor(e.ptr.snode()) for e in self.vars
]

def host_access(self, key):
def _host_access(self, key):
return [SNodeHostAccess(e, key) for e in self.host_accessors]


Expand Down Expand Up @@ -266,13 +275,13 @@ def from_numpy(self, arr):

@python_scope
def __setitem__(self, key, value):
self.initialize_host_accessors()
self.host_accessors[0].setter(value, *self.pad_key(key))
self._initialize_host_accessors()
self.host_accessors[0].setter(value, *self._pad_key(key))

@python_scope
def __getitem__(self, key):
self.initialize_host_accessors()
return self.host_accessors[0].getter(*self.pad_key(key))
self._initialize_host_accessors()
return self.host_accessors[0].getter(*self._pad_key(key))

def __repr__(self):
# make interactive shell happy, prevent materialization
Expand Down
10 changes: 5 additions & 5 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def subscript(value, *_indices, skip_reordered=False):
if isinstance(value, SparseMatrixProxy):
return value.subscript(*_indices)
if isinstance(value, Field):
_var = value.get_field_members()[0].ptr
_var = value._get_field_members()[0].ptr
if _var.snode() is None:
if _var.is_primal():
raise RuntimeError(
Expand All @@ -186,7 +186,7 @@ def subscript(value, *_indices, skip_reordered=False):
if isinstance(value, StructField):
return _IntermediateStruct(
{k: subscript(v, *_indices)
for k, v in value.items})
for k, v in value._items})
return Expr(_ti_core.subscript(_var, indices_expr_group))
if isinstance(value, AnyArray):
# TODO: deprecate using get_attribute to get dim
Expand Down Expand Up @@ -339,12 +339,12 @@ def _check_matrix_field_member_shape(self):
if any(shape != shapes[0] for shape in shapes):
raise RuntimeError(
'Members of the following field have different shapes ' +
f'{shapes}:\n{self._get_tb(_field.get_field_members()[0])}'
f'{shapes}:\n{self._get_tb(_field._get_field_members()[0])}'
)

def _calc_matrix_field_dynamic_index_stride(self):
for _field in self.matrix_fields:
_field.calc_dynamic_index_stride()
_field._calc_dynamic_index_stride()

def materialize(self):
self.materialize_root_fb(not self.materialized)
Expand Down Expand Up @@ -605,7 +605,7 @@ def field(dtype, shape=None, name="", offset=None, needs_grad=False):

x, x_grad = create_field_member(dtype, name)
x, x_grad = ScalarField(x), ScalarField(x_grad)
x.set_grad(x_grad)
x._set_grad(x_grad)

if shape is not None:
dim = len(shape)
Expand Down
20 changes: 10 additions & 10 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,7 @@ def field(cls,
entries, entries_grad = zip(*entries)
entries, entries_grad = MatrixField(entries, n, m), MatrixField(
entries_grad, n, m)
entries.set_grad(entries_grad)
entries._set_grad(entries_grad)
impl.get_runtime().matrix_fields.append(entries)

if shape is None:
Expand All @@ -853,11 +853,11 @@ def field(cls,

dim = len(shape)
if layout == Layout.SOA:
for e in entries.get_field_members():
for e in entries._get_field_members():
impl.root.dense(impl.index_nd(dim),
shape).place(ScalarField(e), offset=offset)
if needs_grad:
for e in entries_grad.get_field_members():
for e in entries_grad._get_field_members():
impl.root.dense(impl.index_nd(dim),
shape).place(ScalarField(e),
offset=offset)
Expand Down Expand Up @@ -1080,7 +1080,7 @@ class _MatrixFieldElement(_IntermediateMatrix):
def __init__(self, field, indices):
super().__init__(field.n, field.m, [
expr.Expr(ti_core.subscript(e.ptr, indices))
for e in field.get_field_members()
for e in field._get_field_members()
])
self.dynamic_index_stride = field.dynamic_index_stride

Expand Down Expand Up @@ -1114,7 +1114,7 @@ def get_scalar_field(self, *indices):
j = 0 if len(indices) == 1 else indices[1]
return ScalarField(self.vars[i * self.m + j])

def calc_dynamic_index_stride(self):
def _calc_dynamic_index_stride(self):
# Algorithm: https://github.com/taichi-dev/taichi/issues/3810
paths = [ScalarField(var).snode._path_from_root() for var in self.vars]
num_members = len(paths)
Expand Down Expand Up @@ -1236,15 +1236,15 @@ def from_numpy(self, arr):

@python_scope
def __setitem__(self, key, value):
self.initialize_host_accessors()
self._initialize_host_accessors()
self[key]._set_entries(value)

@python_scope
def __getitem__(self, key):
self.initialize_host_accessors()
key = self.pad_key(key)
host_access = self.host_access(key)
return Matrix([[host_access[i * self.m + j] for j in range(self.m)]
self._initialize_host_accessors()
key = self._pad_key(key)
_host_access = self._host_access(key)
return Matrix([[_host_access[i * self.m + j] for j in range(self.m)]
for i in range(self.n)])

def __repr__(self):
Expand Down
52 changes: 26 additions & 26 deletions python/taichi/lang/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ def __init__(self, field: ScalarField, mesh_ptr: _ti_core.MeshPtr,

@python_scope
def __setitem__(self, key, value):
self.initialize_host_accessors()
self._initialize_host_accessors()
key = self.g2r_field[key]
self.host_accessors[0].setter(value, *self.pad_key(key))
self.host_accessors[0].setter(value, *self._pad_key(key))

@python_scope
def __getitem__(self, key):
self.initialize_host_accessors()
self._initialize_host_accessors()
key = self.g2r_field[key]
return self.host_accessors[0].getter(*self.pad_key(key))
return self.host_accessors[0].getter(*self._pad_key(key))


class MeshReorderedMatrixFieldProxy(MatrixField):
Expand All @@ -75,15 +75,15 @@ def __init__(self, field: MatrixField, mesh_ptr: _ti_core.MeshPtr,

@python_scope
def __setitem__(self, key, value):
self.initialize_host_accessors()
self._initialize_host_accessors()
self[key]._set_entries(value)

@python_scope
def __getitem__(self, key):
self.initialize_host_accessors()
self._initialize_host_accessors()
key = self.g2r_field[key]
key = self.pad_key(key)
return _IntermediateMatrix(self.n, self.m, self.host_access(key))
key = self._pad_key(key)
return _IntermediateMatrix(self.n, self.m, self._host_access(key))


class MeshElementField:
Expand All @@ -94,22 +94,22 @@ def __init__(self, mesh_instance, _type, attr_dict, field_dict, g2r_field):
self.field_dict = field_dict
self.g2r_field = g2r_field

self.register_fields()
self._register_fields()

@property
def keys(self):
return list(self.field_dict.keys())

@property
def members(self):
def _members(self):
return list(self.field_dict.values())

@property
def items(self):
def _items(self):
return self.field_dict.items()

@staticmethod
def make_getter(key):
def _make_getter(key):
def getter(self):
if key not in self.getter_dict:
if self.attr_dict[key].reorder:
Expand All @@ -128,17 +128,17 @@ def getter(self):

return getter

def register_fields(self):
def _register_fields(self):
self.getter_dict = {}
for k in self.keys:
setattr(MeshElementField, k,
property(fget=MeshElementField.make_getter(k)))
property(fget=MeshElementField._make_getter(k)))

def get_field_members(self):
def _get_field_members(self):
field_members = []
for m in self.members:
for m in self._members:
assert isinstance(m, Field)
field_members += m.get_field_members()
field_members += m._get_field_members()
return field_members

@python_scope
Expand All @@ -150,33 +150,33 @@ def copy_from(self, other):

@python_scope
def fill(self, val):
for v in self.members:
for v in self._members:
v.fill(val)

def initialize_host_accessors(self):
for v in self.members:
v.initialize_host_accessors()
def _initialize_host_accessors(self):
for v in self._members:
v._initialize_host_accessors()

def get_member_field(self, key):
return self.field_dict[key]

@python_scope
def from_numpy(self, array_dict):
for k, v in self.items:
for k, v in self._items:
v.from_numpy(array_dict[k])

@python_scope
def from_torch(self, array_dict):
for k, v in self.items:
for k, v in self._items:
v.from_torch(array_dict[k])

@python_scope
def to_numpy(self):
return {k: v.to_numpy() for k, v in self.items}
return {k: v.to_numpy() for k, v in self._items}

@python_scope
def to_torch(self, device=None):
return {k: v.to_torch(device=device) for k, v in self.items}
return {k: v.to_torch(device=device) for k, v in self._items}

@python_scope
def __len__(self):
Expand Down Expand Up @@ -484,7 +484,7 @@ def __init__(self, mesh: MeshInstance, element_type: MeshElementType,
elif isinstance(attr, StructField):
raise RuntimeError('ti.Mesh has not support StructField yet')
else: # isinstance(attr, Field)
var = attr.get_field_members()[0].ptr
var = attr._get_field_members()[0].ptr
setattr(
self, key,
impl.Expr(_ti_core.subscript(var,
Expand Down
8 changes: 4 additions & 4 deletions python/taichi/lang/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def init(arch=None,

def no_activate(*args):
for v in args:
get_runtime().prog.no_activate(v.snode.ptr)
get_runtime().prog.no_activate(v._snode.ptr)


def block_local(*args):
Expand All @@ -382,21 +382,21 @@ def block_local(*args):
_logging.warn("""opt_level = 1 is enforced to enable bls analysis.""")
impl.current_cfg().opt_level = 1
for a in args:
for v in a.get_field_members():
for v in a._get_field_members():
get_runtime().prog.current_ast_builder().insert_snode_access_flag(
_ti_core.SNodeAccessFlag.block_local, v.ptr)


def mesh_local(*args):
for a in args:
for v in a.get_field_members():
for v in a._get_field_members():
get_runtime().prog.current_ast_builder().insert_snode_access_flag(
_ti_core.SNodeAccessFlag.mesh_local, v.ptr)


def cache_read_only(*args):
for a in args:
for v in a.get_field_members():
for v in a._get_field_members():
get_runtime().prog.current_ast_builder().insert_snode_access_flag(
_ti_core.SNodeAccessFlag.read_only, v.ptr)

Expand Down
Loading