Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] Add any()/all() functions for Matrix, move casting ops to common_ops #1064

Merged
merged 8 commits into from May 26, 2020
Merged
4 changes: 3 additions & 1 deletion docs/matrix.rst
Expand Up @@ -13,9 +13,11 @@ Matrices
- ``ti.tr(A)``
- ``ti.determinant(A, type)``
- ``ti.cross(a, b)``, where ``a`` and ``b`` are 3D vectors (i.e. ``3x1`` matrices)
- ``A.cast(type)``
- ``A.cast(type)`` or simply ``int(A)`` and ``float(A)``
- ``R, S = ti.polar_decompose(A, ti.f32)``
- ``U, sigma, V = ti.svd(A, ti.f32)`` (Note that ``sigma`` is a ``3x3`` diagonal matrix)
- ``any(A)``
- ``all(A)``

TODO: doc here better like Vector. WIP

Expand Down
8 changes: 8 additions & 0 deletions python/taichi/lang/common_ops.py
Expand Up @@ -110,3 +110,11 @@ def __invert__(self): # ~a => a.__invert__()
def __not__(self): # not a => a.__not__()
import taichi as ti
return ti.logical_not(self)

def __ti_int__(self):
import taichi as ti
return ti.cast(self, ti.get_runtime().default_ip)

def __ti_float__(self):
import taichi as ti
return ti.cast(self, ti.get_runtime().default_fp)
8 changes: 0 additions & 8 deletions python/taichi/lang/expr.py
Expand Up @@ -223,14 +223,6 @@ def fill(self, val):
from .meta import fill_tensor
fill_tensor(self, val)

def __ti_int__(self):
import taichi as ti
return ti.cast(self, ti.get_runtime().default_ip)

def __ti_float__(self):
import taichi as ti
return ti.cast(self, ti.get_runtime().default_fp)

def parent(self, n=1):
import taichi as ti
p = self.ptr.snode()
Expand Down
14 changes: 14 additions & 0 deletions python/taichi/lang/matrix.py
Expand Up @@ -482,6 +482,20 @@ def min(self):
ret = impl.min(ret, self.entries[i])
return ret

def any(self):
import taichi as ti
ret = (self.entries[0] != ti.expr_init(0))
for i in range(1, len(self.entries)):
ret = ret + (self.entries[i] != ti.expr_init(0))
return -(ret < ti.expr_init(0))

def all(self):
import taichi as ti
ret = self.entries[0] != ti.expr_init(0)
for i in range(1, len(self.entries)):
ret = ret + (self.entries[i] != ti.expr_init(0))
return -(ret == ti.expr_init(-len(self.entries)))

def dot(self, other):
assert self.m == 1 and other.m == 1
return (self.transposed(self) @ other).subscript(0, 0)
Expand Down
10 changes: 10 additions & 0 deletions python/taichi/lang/ops.py
Expand Up @@ -342,6 +342,16 @@ def ti_min(*args):
return ti_min(args[0], ti_min(*args[1:]))


def ti_any(a):
assert hasattr(a, 'any')
return a.any()


def ti_all(a):
assert hasattr(a, 'all')
return a.all()


def append(l, indices, val):
import taichi as ti
a = ti.expr_init(
Expand Down
4 changes: 4 additions & 0 deletions python/taichi/lang/transformer.py
Expand Up @@ -554,6 +554,10 @@ def visit_Call(self, node):
node.func = self.parse_expr('ti.ti_int')
elif func_name == 'float':
node.func = self.parse_expr('ti.ti_float')
elif func_name == 'any':
node.func = self.parse_expr('ti.ti_any')
elif func_name == 'all':
node.func = self.parse_expr('ti.ti_all')
else:
pass
return node
Expand Down
33 changes: 33 additions & 0 deletions tests/python/test_linalg.py
Expand Up @@ -173,3 +173,36 @@ def fill():
assert m2[0][j, i] == int(i + 3 * j + 1)
assert m3[0][i, j] == int(i + 3 * j + 1)
assert m4[0][j, i] == int(i + 3 * j + 1)


@ti.all_archs
def test_any_all():
a = ti.Matrix(2, 2, dt=ti.i32, shape=())
b = ti.var(dt=ti.i32, shape=())

@ti.kernel
def func_any():
b[None] = any(a[None])

@ti.kernel
def func_all():
b[None] = all(a[None])

for i in range(2):
for j in range(2):
a[None][0, 0] = i
a[None][1, 0] = j
a[None][1, 1] = i
a[None][0, 1] = j

func_any()
if i == 1 or j == 1:
assert b[None] == 1
else:
assert b[None] == 0

func_all()
if i == 1 and j == 1:
assert b[None] == 1
else:
assert b[None] == 0