Skip to content

Commit

Permalink
Merge pull request #316 from hugohadfield/dcga_dpga_test_speedup
Browse files Browse the repository at this point in the history
Reduces the no-JIT run time of dpga and dg3c tests
  • Loading branch information
hugohadfield committed May 21, 2020
2 parents 0feacbf + e077f9f commit b385669
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
18 changes: 15 additions & 3 deletions clifford/test/test_dg3c.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import unittest
import pytest
import numpy as np
import numba

from ..dg3c import *

"""
Expand All @@ -11,6 +13,10 @@
https://doi.org/10.1007/s00006-017-0784-0
"""

too_slow_without_jit = pytest.mark.skipif(
numba.config.DISABLE_JIT, reason="test is too slow without JIT"
)


class BasicTests(unittest.TestCase):
def test_metric(self):
Expand All @@ -35,7 +41,7 @@ def test_up_down(self):
Test that we can map points up and down into the dpga
"""
rng = np.random.RandomState()
for i in range(100):
for i in range(1 if numba.config.DISABLE_JIT else 100):
pnt_vector = rng.randn(3)
pnt = up(pnt_vector)
res = down(100*pnt)
Expand All @@ -51,7 +57,7 @@ def test_up_down_cga1(self):
"""
rng = np.random.RandomState()
pnt_vector = rng.randn(3)
for i in range(100):
for i in range(10 if numba.config.DISABLE_JIT else 100):
pnt = up_cga1(pnt_vector)
res = down_cga1(100*pnt)
np.testing.assert_allclose(res, pnt_vector)
Expand All @@ -62,6 +68,7 @@ def test_up_down_cga1(self):


class GeometricPrimitiveTests(unittest.TestCase):
@too_slow_without_jit
def test_reciprocality(self):
"""
Ensure that the cyclide ops and the reciprocal frame are
Expand Down Expand Up @@ -108,6 +115,7 @@ def test_line(self):
assert Ldcga | up(pnt_vec_b) == 0 * eo
assert Ldcga | up(0.5*pnt_vec_a + 0.5*pnt_vec_b) == 0 * eo

@too_slow_without_jit
def test_translation(self):
rng = np.random.RandomState()
# Make a dcga line
Expand Down Expand Up @@ -168,7 +176,8 @@ def test_translation(self):
Tdcga = (Tc1 * Tc2).normal()
assert (Tdcga * E * ~Tdcga) | eo == 0 * e1

def test_rotation(self):
@too_slow_without_jit
def test_line_rotation(self):
theta = np.pi/2
RC1 = np.e ** (-0.5*theta*e12)
RC2 = np.e ** (-0.5*theta*e67)
Expand All @@ -192,6 +201,8 @@ def test_rotation(self):
assert (Rdcga * Ldcga * ~Rdcga)|up(pnt_vec_rotated) == 0*e1
np.testing.assert_allclose((Rdcga * Ldcga * ~Rdcga).value, Ldcga_rotated.value, rtol=1E-4, atol=1E-6)

@too_slow_without_jit
def test_quadric_rotation(self):
# Construct and ellipsoid
px = 0
py = 2.5
Expand Down Expand Up @@ -227,6 +238,7 @@ def test_rotation(self):

assert Erot|eo == 0*eo

@too_slow_without_jit
def test_bivector_orthogonality(self):
"""
Rotors in each algebra should be orthogonal
Expand Down
9 changes: 5 additions & 4 deletions clifford/test/test_dpga.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

from clifford.dpga import *
import numba


class TestBasicDPGA:
Expand Down Expand Up @@ -44,15 +45,15 @@ def test_bivector_identities(self):

def test_up_down(self):
rng = np.random.RandomState() # can pass a seed here later
for i in range(1000):
for i in range(10 if numba.config.DISABLE_JIT else 1000):
p = rng.standard_normal(3)
dpga_pnt = up(p)
pnt_down = down(np.random.rand()*dpga_pnt)
np.testing.assert_allclose(pnt_down, p)

def test_translate(self):
rng = np.random.RandomState() # can pass a seed here later
for i in range(100):
for i in range(10 if numba.config.DISABLE_JIT else 100):
tvec = rng.standard_normal(3)
wt = tvec[0]*w1 + tvec[1]*w2 + tvec[2]*w3
biv = w0s*wt
Expand All @@ -76,7 +77,7 @@ def test_translate(self):

def test_rotate(self):
rng = np.random.RandomState() # can pass a seed here later
for i in range(100):
for i in range(10 if numba.config.DISABLE_JIT else 100):
mvec = rng.standard_normal(3)
nvec = rng.standard_normal(3)
m = mvec[0] * w1 + mvec[1] * w2 + mvec[2] * w3
Expand Down Expand Up @@ -106,7 +107,7 @@ def test_rotate(self):

def test_line(self):
rng = np.random.RandomState() # can pass a seed here later
for i in range(100):
for i in range(5 if numba.config.DISABLE_JIT else 100):
p1vec = rng.standard_normal(3)
p2vec = rng.standard_normal(3)
p1 = up(p1vec)
Expand Down

0 comments on commit b385669

Please sign in to comment.