Skip to content

Commit

Permalink
tools.g3c: Add jit decorators to all of the val_ function wrappers
Browse files Browse the repository at this point in the history
This means there is no longer any reason to call the `val` versions from other code.
  • Loading branch information
eric-wieser committed Jun 11, 2020
1 parent 0457c8b commit 2803619
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions clifford/tools/g3c/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ def point_beyond_plane(point, plane):
return (point|(I5*plane))[0] < 0


@numba.njit
def unsign_sphere(S):
"""
Normalises the sign of a sphere
Expand Down Expand Up @@ -662,6 +663,7 @@ def val_get_line_intersection(L3_val, Ldd_val):
return project_val(P/P_denominator, 1)


@numba.njit
def get_line_intersection(L3, Ldd):
"""
Gets the point of intersection of two orthogonal lines that meet
Expand All @@ -685,6 +687,7 @@ def val_midpoint_between_lines(L1_val, L2_val):
return val_normalise_n_minus_1(project_val(gmt_func(S, gmt_func(ninf.value, S)), 1))


@numba.njit
def midpoint_between_lines(L1, L2):
"""
Gets the point that is maximally close to both lines
Expand All @@ -693,6 +696,7 @@ def midpoint_between_lines(L1, L2):
return layout.MultiVector(val_midpoint_between_lines(L1.value, L2.value))


@numba.njit
def midpoint_of_line_cluster(line_cluster):
"""
Gets a center point of a line cluster
Expand Down Expand Up @@ -886,6 +890,7 @@ def generate_dilation_rotor(scale):
return math.cosh(gamma/2) + math.sinh(gamma/2)*(ninf^no)


@numba.njit
def generate_translation_rotor(euc_vector_a):
"""
Generates a rotor that translates objects along the euclidean vector euc_vector_a
Expand All @@ -912,6 +917,7 @@ def meet_val(a_val, b_val):
return dual_func(omt_func(dual_func(a_val), dual_func(b_val)))


@numba.njit
def meet(A, B):
"""
The meet algorithm as described in "A Covariant Approach to Geometry"
Expand Down Expand Up @@ -942,6 +948,7 @@ def val_intersect_line_and_plane_to_point(line_val, plane_val):
return output


@numba.njit
def intersect_line_and_plane_to_point(line, plane):
"""
Returns the point at the intersection of a line and plane
Expand All @@ -965,6 +972,7 @@ def val_normalise_n_minus_1(mv_val):
raise ZeroDivisionError('Multivector has 0 einf component')


@numba.njit
def normalise_n_minus_1(mv):
"""
Normalises a conformal point so that it has an inner product of -1 with einf
Expand Down Expand Up @@ -1019,6 +1027,7 @@ def val_point_pair_to_end_points(T):
return output


@numba.njit
def point_pair_to_end_points(T):
"""
Extracts the end points of a point pair bivector
Expand Down Expand Up @@ -1052,6 +1061,7 @@ def check_sigma_for_positive_root_val(sigma_val):
return (sigma_val[0] + dorst_norm_val(sigma_val)) > 0


@numba.njit
def check_sigma_for_positive_root(sigma):
""" Square Root of Rotors - Checks for a positive root """
return check_sigma_for_positive_root_val(sigma.value)
Expand All @@ -1063,6 +1073,7 @@ def check_sigma_for_negative_root_val(sigma_value):
return (sigma_value[0] - dorst_norm_val(sigma_value)) > 0


@numba.njit
def check_sigma_for_negative_root(sigma):
""" Square Root of Rotors - Checks for a negative root """
return check_sigma_for_negative_root_val(sigma.value)
Expand All @@ -1074,6 +1085,7 @@ def check_infinite_roots_val(sigma_value):
return (sigma_value[0] + dorst_norm_val(sigma_value)) < 0.0000000001


@numba.njit
def check_infinite_roots(sigma):
""" Square Root of Rotors - Checks for a infinite roots """
return check_infinite_roots_val(sigma.value)
Expand Down Expand Up @@ -1104,6 +1116,7 @@ def negative_root_val(sigma_val):
return result


@numba.njit
def positive_root(sigma):
"""
Square Root of Rotors - Evaluates the positive root
Expand All @@ -1112,6 +1125,7 @@ def positive_root(sigma):
return layout.MultiVector(res_val)


@numba.njit
def negative_root(sigma):
""" Square Root of Rotors - Evaluates the negative root """
res_val = negative_root_val(sigma.value)
Expand Down Expand Up @@ -1146,6 +1160,7 @@ def general_root_val(sigma_value):
raise ValueError('No root exists')


@numba.njit
def general_root(sigma):
""" The general case of the root of a grade 0, 4 multivector """
output = general_root_val(sigma.value)
Expand All @@ -1160,6 +1175,7 @@ def val_annihilate_k(K_val, C_val):
return val_normalised(gmt_func(k_4, C_val))


@numba.njit
def annihilate_k(K, C):
""" Removes K from C = KX via (K[0] - K[4])*C """
return layout.MultiVector(val_annihilate_k(K.value, C.value))
Expand Down Expand Up @@ -1197,6 +1213,7 @@ def neg_twiddle_root_val(C_value):
return output


@numba.njit
def pos_twiddle_root(C):
"""
Square Root and Logarithm of Rotors
Expand All @@ -1208,6 +1225,7 @@ def pos_twiddle_root(C):
return [layout.MultiVector(output[0, :]), layout.MultiVector(output[1, :])]


@numba.njit
def neg_twiddle_root(C):
"""
Square Root and Logarithm of Rotors
Expand Down Expand Up @@ -1342,6 +1360,7 @@ def TRS_between_rounds(X1, X2):
return normalised((~T2)*S*Rc*T1)


@numba.njit
def motor_between_rounds(X1, X2):
"""
Calculate the motor between any pair of rounds of the same grade
Expand Down Expand Up @@ -1401,6 +1420,7 @@ def val_motor_between_objects(X1, X2):
return val_motor_between_rounds(X1, X2)


@numba.njit
def motor_between_objects(X1, X2):
"""
Calculates a motor that takes X1 to X2
Expand Down Expand Up @@ -1551,6 +1571,7 @@ def val_norm(mv_val):
return np.sqrt(np.abs(gmt_func(adjoint_func(mv_val), mv_val)[0]))


@numba.njit
def norm(mv):
""" Returns sqrt(abs(~A*A)) """
return val_norm(mv.value)
Expand All @@ -1562,6 +1583,7 @@ def val_normalised(mv_val):
return mv_val/val_norm(mv_val)


@numba.njit
def normalised(mv):
""" fast version of the normal() function """
return layout.MultiVector(val_normalised(mv.value))
Expand All @@ -1587,11 +1609,13 @@ def val_rotor_between_lines(L1_val, L2_val):
return gmt_func(normalisation_val, output_val)


@numba.njit
def rotor_between_lines(L1, L2):
""" return the rotor between two lines """
return layout.MultiVector(val_rotor_between_lines(L1.value, L2.value))


@numba.njit
def rotor_between_planes(P1, P2):
""" return the rotor between two planes """
return layout.MultiVector(val_rotor_rotor_between_planes(P1.value, P2.value))
Expand Down Expand Up @@ -1734,6 +1758,7 @@ def val_apply_rotor(mv_val, rotor_val):
return gmt_func(rotor_val, gmt_func(mv_val, adjoint_func(rotor_val)))


@numba.njit
def apply_rotor(mv_in, rotor):
""" Applies rotor to multivector in a fast way """
return layout.MultiVector(val_apply_rotor(mv_in.value, rotor.value))
Expand All @@ -1745,6 +1770,7 @@ def val_apply_rotor_inv(mv_val, rotor_val, rotor_val_inv):
return gmt_func(rotor_val, gmt_func(mv_val, rotor_val_inv))


@numba.njit
def apply_rotor_inv(mv_in, rotor, rotor_inv):
""" Applies rotor to multivector in a fast way takes pre computed adjoint"""
return layout.MultiVector(val_apply_rotor_inv(mv_in.value, rotor.value, rotor_inv.value))
Expand Down Expand Up @@ -1774,6 +1800,7 @@ def val_convert_2D_polar_line_to_conformal_line(rho, theta):
return line_val


@numba.njit
def convert_2D_polar_line_to_conformal_line(rho, theta):
""" Converts a 2D polar line to a conformal line """
line_val = val_convert_2D_polar_line_to_conformal_line(rho, theta)
Expand All @@ -1788,6 +1815,7 @@ def val_up(mv_val):
return mv_val - no.value + omt_func(temp, gmt_func(gmt_func(mv_val, mv_val), ninf.value))


@numba.njit
def fast_up(mv):
""" Fast up mapping """
return layout.MultiVector(val_up(mv.value))
Expand All @@ -1813,6 +1841,7 @@ def val_down(mv_val):
return gmt_func(omt_func(val_homo(mv_val), E0.value), E0.value)


@numba.njit
def fast_down(mv):
""" A fast version of down() """
return layout.MultiVector(val_down(mv.value))
Expand All @@ -1834,6 +1863,7 @@ def val_convert_2D_point_to_conformal(x, y):
return val_up(mv_val)


@numba.njit
def convert_2D_point_to_conformal(x, y):
""" Convert a 2D point to conformal """
return layout.MultiVector(val_convert_2D_point_to_conformal(x, y))
Expand All @@ -1857,6 +1887,7 @@ def dual_func(a_val):
return dual_gmt_func(I5.value, a_val)


@numba.njit
def fast_dual(a):
"""
Fast dual
Expand Down

0 comments on commit 2803619

Please sign in to comment.