Skip to content

Commit

Permalink
Add tests for complex number type support in algebras (#355)
Browse files Browse the repository at this point in the history
These tests ensure that complex number type support behaves as one would expect and type info is not lost.

Note that this also tweaks printing of multivector coefficients, since `round(complex)` is not allowed, but `np.round(complex)` is.
To prevent breakage on object arrays, we only perform rounding if the array is floating-point (`np.inexact`).

Co-authored-by: Eric Wieser <wieser.eric@gmail.com>
  • Loading branch information
hugohadfield and eric-wieser committed Oct 19, 2020
1 parent a1b1ba3 commit ce0bc74
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 3 deletions.
9 changes: 6 additions & 3 deletions clifford/_multivector.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ def __and__(self, other) -> 'MultiVector':

def __mul__(self, other) -> 'MultiVector':
""" ``self * other``, the geometric product :math:`MN` """

other, mv = self._checkOther(other, coerce=False)

if mv:
Expand Down Expand Up @@ -496,10 +495,14 @@ def __str__(self) -> str:
else:
if coeff < 0:
sep = seps[1]
abs_coeff = -round(coeff, p)
sign = -1
else:
sep = seps[0]
abs_coeff = round(coeff, p)
sign = 1
if np.issubdtype(self.value.dtype, np.inexact):
abs_coeff = sign*np.round(coeff, p)
else:
abs_coeff = sign*coeff

if grade == 0:
# scalar
Expand Down
107 changes: 107 additions & 0 deletions clifford/test/test_complex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""
This file tests the behaviour of clifford with complex numbers. The tests
below are based on an informed guess about the correct behavior, and we
should not be afraid of changing them if the things we're testing for turn
out to not be the conventional definitions.
"""

import pytest
import numpy as np
import clifford as cf

from clifford import Cl, conformalize


# using fixtures here results in them only being created if needed
@pytest.fixture(scope='module')
def g2():
return Cl(2)[0]


@pytest.fixture(scope='module')
def g3():
return Cl(3)[0]


@pytest.fixture(scope='module')
def g4():
return Cl(4)[0]


@pytest.fixture(scope='module')
def g5():
return Cl(5)[0]


@pytest.fixture(scope='module')
def g3c():
return conformalize(Cl(3)[0])[0]


@pytest.fixture(scope='module')
def pga():
from clifford.pga import layout
return layout


class TestCliffordComplex:
@pytest.fixture(params=[3, 4, 5, 'g3c', (3, 0, 1)], ids='Cl({})'.format)
def algebra(self, request, g3, g4, g5, g3c, pga):
return {3: g3, 4: g4, 5: g5, 'g3c': g3c, (3, 0, 1): pga}[request.param]

def test_addition(self, algebra):
A = algebra.randomMV()
B = algebra.randomMV()
res = (A + 1j*B).value
res2 = A.value + 1j*B.value
np.testing.assert_array_equal(res, res2)

def test_subtraction(self, algebra):
A = algebra.randomMV()
B = algebra.randomMV()
res = (A - 1j*B).value
res2 = A.value - 1j*B.value
np.testing.assert_array_equal(res, res2)

@pytest.mark.parametrize('p', [cf.operator.gp, cf.operator.op, cf.operator.ip,
cf.MultiVector.lc, cf.MultiVector.vee])
def test_prod(self, algebra, p):
A = algebra.randomMV()
B = algebra.randomMV()
C = algebra.randomMV()
D = algebra.randomMV()
res = (p(A + 1j*B, C + 1j*D)).value
res2 = p(A, C).value + 1j*p(B, C).value + 1j*p(A, D).value - p(B, D).value
np.testing.assert_allclose(res, res2)

def test_reverse(self, algebra):
A = algebra.randomMV()
B = algebra.randomMV()
res = (~(A + 1j*B)).value
res2 = (~A).value + 1j*(~B).value
np.testing.assert_array_equal(res, res2)

def test_grade_selection(self, algebra):
A = algebra.randomMV()
B = algebra.randomMV()
res = ((A + 1j*B)(2)).value
res2 = A(2).value + 1j*B(2).value
np.testing.assert_array_equal(res, res2)

def test_dual(self, algebra):
A = algebra.randomMV()
B = algebra.randomMV()
res = (A + 1j*B).dual().value
res2 = A.dual().value + 1j*B.dual().value
np.testing.assert_array_equal(res, res2)

def test_inverse(self, algebra):
if 0 in algebra.sig:
pytest.xfail("The inverse in degenerate metrics is known to fail")
A = algebra.randomMV()
B = algebra.randomMV()
original = (A + 1j*B)
res = algebra.scalar + 0j
res2 = original*original.inv()
np.testing.assert_almost_equal(res2.value.real, res.value.real)
np.testing.assert_almost_equal(res2.value.imag, res.value.imag)

0 comments on commit ce0bc74

Please sign in to comment.