Skip to content

Commit

Permalink
[refactor] Remove empty_copy() and copy() from Matrix/Struct (#3536)
Browse files Browse the repository at this point in the history
* [refactor] Remove empty_copy() and copy() from Matrix/Struct

* Auto Format

Co-authored-by: Taichi Gardener <taichigardener@gmail.com>
  • Loading branch information
strongoier and taichi-gardener committed Nov 17, 2021
1 parent 1b4fa21 commit c3f1102
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 47 deletions.
32 changes: 6 additions & 26 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
import numbers
from collections.abc import Iterable

Expand All @@ -13,7 +12,6 @@
from taichi.lang.enums import Layout
from taichi.lang.exception import TaichiSyntaxError
from taichi.lang.field import Field, ScalarField, SNodeHostAccess
from taichi.lang.ops import cast
from taichi.lang.types import CompoundType
from taichi.lang.util import (cook_dtype, in_python_scope, python_scope,
taichi_scope, to_numpy_type, to_pytorch_type)
Expand Down Expand Up @@ -356,16 +354,8 @@ def w(self, value):
@property
@python_scope
def value(self):
if isinstance(self.entries[0], SNodeHostAccess):
# fetch values from SNodeHostAccessor
ret = self.empty_copy()
for i in range(self.n):
for j in range(self.m):
ret.entries[i * self.m + j] = self(i, j)
else:
# is local python-scope matrix
ret = self.entries
return ret
return Matrix([[self(i, j) for j in range(self.m)]
for i in range(self.n)])

# host access & python scope operation
@python_scope
Expand Down Expand Up @@ -420,14 +410,6 @@ def set_entries(self, value):
for j in range(self.m):
self[i, j] = value[i][j]

def empty_copy(self):
return Matrix.empty(self.n, self.m)

def copy(self):
ret = self.empty_copy()
ret.entries = copy.copy(self.entries)
return ret

@taichi_scope
def cast(self, dtype):
"""Cast the matrix element data type.
Expand All @@ -440,10 +422,9 @@ def cast(self, dtype):
"""
_taichi_skip_traceback = 1
ret = self.copy()
for i, entry in enumerate(ret.entries):
ret.entries[i] = ops_mod.cast(entry, dtype)
return ret
return Matrix(
[[ops_mod.cast(self(i, j), dtype) for j in range(self.m)]
for i in range(self.n)])

def trace(self):
"""The sum of a matrix diagonal elements.
Expand Down Expand Up @@ -1352,8 +1333,7 @@ def cast(self, mat):
int(mat(i, j)) if self.dtype in ti.integer_types else float(
mat(i, j)) for j in range(self.m)
] for i in range(self.n)])
return Matrix([[cast(mat(i, j), self.dtype) for j in range(self.m)]
for i in range(self.n)])
return mat.cast(self.dtype)

def filled_with_scalar(self, value):
return Matrix([[value for _ in range(self.m)] for _ in range(self.n)])
Expand Down
8 changes: 4 additions & 4 deletions python/taichi/lang/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,13 +919,13 @@ def rescale_index(a, b, I):
assert isinstance(
I, matrix.Matrix
), f"The third argument must be an index (list or ti.Vector)"
Ib = I.copy()
entries = [I(i) for i in range(I.n)]
for n in range(min(I.n, min(len(a.shape), len(b.shape)))):
if a.shape[n] > b.shape[n]:
Ib.entries[n] = I.entries[n] // (a.shape[n] // b.shape[n])
entries[n] = I(n) // (a.shape[n] // b.shape[n])
if a.shape[n] < b.shape[n]:
Ib.entries[n] = I.entries[n] * (b.shape[n] // a.shape[n])
return Ib
entries[n] = I(n) * (b.shape[n] // a.shape[n])
return matrix.Vector(entries)


def get_addr(f, indices):
Expand Down
16 changes: 0 additions & 16 deletions python/taichi/lang/struct.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
import numbers

from taichi.lang import expr, impl
Expand Down Expand Up @@ -186,21 +185,6 @@ def assign_renamed(x, y):

return self.element_wise_writeback_binary(assign_renamed, val)

def empty_copy(self):
"""
Nested structs and matrices need to be recursively handled.
"""
struct = Struct.empty(self.keys)
for k, v in self.items:
if isinstance(v, (Struct, Matrix)):
struct.entries[k] = v.empty_copy()
return struct

def copy(self):
ret = self.empty_copy()
ret.entries = copy.copy(self.entries)
return ret

def __len__(self):
"""Get the number of entries in a custom struct"""
return len(self.entries)
Expand Down
5 changes: 4 additions & 1 deletion taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,10 @@ class RangeAssumptionExpression : public Expression {
const Expr &base,
int low,
int high)
: input(input), base(base), low(low), high(high) {
: input(load_if_ptr(input)),
base(load_if_ptr(base)),
low(low),
high(high) {
}

void type_check() override;
Expand Down

0 comments on commit c3f1102

Please sign in to comment.