In [1]:
import numpy as np
import pdb
import time

from geometry_tools import representation, utils
from geometry_tools.automata import fsa

In [2]:
# Translation parameter shared by all four generators.
t = 2 * np.arccosh(1 / np.tan(np.pi / 8))

# Named cosh_t / sinh_t (rather than c / s) so that the scalar cosh(t) is not
# confused with, or clobbered by, the generator matrix `c` defined below.
cosh_t = np.cosh(t)
sinh_t = np.sinh(t)

# a is a hyperbolic boost by t in the (first, third) coordinate plane, and
# A is the boost by -t, i.e. the inverse of a.
a = np.array([[cosh_t, 0, sinh_t], [0, 1, 0], [sinh_t, 0, cosh_t]])
A = np.array([[cosh_t, 0, -sinh_t], [0, 1, 0], [-sinh_t, 0, cosh_t]])


def rotation_conjugate(matrix, theta):
    """Conjugate a 3x3 matrix by the rotation through `theta` about the third axis.

    Parameters
    ----------
    matrix : ndarray of shape (3, 3)
        Matrix to conjugate.
    theta : float
        Rotation angle in radians.

    Returns
    -------
    ndarray of shape (3, 3)
        ``rot @ matrix @ rot_inv``, where ``rot`` fixes the third coordinate
        and rotates the first two.
    """
    cos_theta = np.cos(theta)
    sin_theta = np.sin(theta)
    rot = np.array([[cos_theta, sin_theta, 0],
                    [-sin_theta, cos_theta, 0],
                    [0, 0, 1]])
    rot_inv = np.array([[cos_theta, -sin_theta, 0],
                        [sin_theta, cos_theta, 0],
                        [0, 0, 1]])
    return rot @ matrix @ rot_inv


# The remaining generators are a (and its inverse A) conjugated by rotations
# through pi/4, pi/2 and 3*pi/4.
b, B = rotation_conjugate(a, np.pi / 4), rotation_conjugate(A, np.pi / 4)
c, C = rotation_conjugate(a, np.pi / 2), rotation_conjugate(A, np.pi / 2)
d, D = rotation_conjugate(a, 3 * (np.pi / 4)), rotation_conjugate(A, 3 * (np.pi / 4))

In [3]:
# Build the representation and attach one matrix to each generator letter.
rep = representation.Representation()
rep["a"], rep["b"], rep["c"], rep["d"] = a, b, c, d

# Evaluate the representation on the relator word as a sanity check; the
# displayed matrix should be the identity up to floating-point error.
# (Uppercase letters presumably denote inverse generators -- confirm against
# the geometry_tools Representation documentation.)
relation = "adCbADcB"
rep[relation]

array([[ 1.00000000e+00,  4.09729211e-12,  1.09562396e-11],
       [ 5.21151958e-12,  1.00000000e+00, -5.60560825e-12],
       [-1.87139855e-11,  7.67681862e-12,  1.00000000e+00]])

In [18]:
class SymmetricPoint:
    """A symmetric matrix (or stack of them) stored with its diagonalization.

    Attributes
    ----------
    matrix : ndarray
        The matrix itself; the last two axes are the matrix axes.
    a : ndarray
        Absolute values of the eigenvalues of `matrix` (last axis).
    k : ndarray
        Eigenvector matrices, so that `matrix` is ``k @ diag(a) @ k^T``
        (assuming `utils.construct_diagonal` builds diagonal matrices from
        the last axis of its argument -- confirm in geometry_tools.utils).
    """

    def __init__(self, matrix=None, a=None, k=None):
        """Create a point from a matrix, from a diagonalization, or from both.

        Parameters
        ----------
        matrix : ndarray, optional
            Symmetric matrix or stack of matrices. If `a` and `k` are not
            given they are computed from it.
        a, k : ndarray, optional
            Eigenvalues and eigenvectors; must be supplied together. If
            `matrix` is not given it is rebuilt from them.

        Raises
        ------
        ValueError
            If neither `matrix` nor the pair (`a`, `k`) is provided, or if
            only one of `a` and `k` is provided.
        """
        has_a = a is not None
        has_k = k is not None

        if matrix is None and not (has_a and has_k):
            raise ValueError("A matrix or its diagonalization must be specified to initialize a point")

        if has_a != has_k:
            raise ValueError("Both K and A must be specified to initialize a point via diagonalization")

        self.matrix = matrix
        self.k = k
        self.a = a

        # Fill in whichever description of the point was not supplied.
        if not (has_a and has_k):
            self._compute_coset(matrix)

        if matrix is None:
            self._compute_matrix(a, k)

    def _compute_matrix(self, a, k):
        """Set `self.matrix` to ``k @ diag(a) @ k^T``."""
        diagonal = utils.construct_diagonal(a)
        k_transpose = np.swapaxes(k, -1, -2)
        self.matrix = k @ diagonal @ k_transpose

    def _compute_coset(self, matrices):
        """Set `self.a` and `self.k` from the eigendecomposition of `matrices`."""
        eigenvalues, eigenvectors = np.linalg.eigh(matrices)
        self.k = eigenvectors
        # we take an abs here, which is probably bad
        self.a = np.abs(eigenvalues)

    def __getitem__(self, val):
        """Index `matrix`, `a` and `k` with `val` and wrap the result in a new point."""
        return SymmetricPoint(self.matrix[val], self.a[val], self.k[val])

    def to_origin(self):
        """Return the matrix ``k @ sqrt(|diag(1 / a)|) @ k^T``.

        Callers in this notebook pass the result to `apply_isometry` as an
        isometry; presumably it carries this point to the identity matrix.
        """
        inverse_diagonal = utils.construct_diagonal(1 / self.a)
        root = np.sqrt(np.abs(inverse_diagonal))
        return self.k @ root @ np.swapaxes(self.k, -1, -2)

    def origin_midpoint(self):
        """Return the point with the same `k` and eigenvalues ``sqrt(a)``."""
        return SymmetricPoint(k=self.k, a=np.sqrt(self.a))

    def distance_to_origin(self):
        """Return the Euclidean norm of ``log|a|`` taken along the last axis."""
        log_eigenvalues = np.log(np.abs(self.a))
        return np.linalg.norm(log_eigenvalues, axis=-1)

    def invert_by_transvection(self):
        """Return the point with eigenvalues ``1 / a`` in reversed order.

        The columns of `k` are multiplied by the order-reversing permutation
        matrix so that they stay matched to the flipped eigenvalues.
        """
        d = self.a.shape[-1]
        descending = list(reversed(range(d)))
        reverse = utils.permutation_matrix(descending)
        flipped = np.flip(1 / self.a, axis=-1)
        return SymmetricPoint(a=flipped, k=self.k @ reverse)

    @staticmethod
    def from_translation(matrix):
        """Return the point whose matrix is ``matrix @ matrix^T``."""
        gram = matrix @ np.swapaxes(matrix, -1, -2)
        return SymmetricPoint(gram)

In [19]:
def apply_isometry(isometry, point, broadcast="pairwise"):
    """Act on a point by ``g @ X @ g^T`` and return the result as a SymmetricPoint.

    Parameters
    ----------
    isometry : ndarray
        Matrix or stack of matrices ``g``; the last two axes are matrix axes.
    point : SymmetricPoint
        Point(s) ``X``; only ``point.matrix`` is read.
    broadcast : {"pairwise", "pairwise_reversed", "elementwise"}
        How the leading (batch) axes of `isometry` and `point.matrix` are
        combined. "elementwise" relies on ordinary numpy broadcasting; the
        two pairwise modes delegate the left product to
        `utils.matrix_product` with the same mode (presumably forming every
        isometry/point combination -- see the geometry_tools documentation
        for the exact output axis order).

    Returns
    -------
    SymmetricPoint
        Point(s) built from the transformed matrices.

    Raises
    ------
    ValueError
        If `broadcast` is not one of the supported modes. (Previously an
        unknown mode fell through every branch and failed with an unbound
        local variable.)
    """
    valid_modes = ("pairwise", "pairwise_reversed", "elementwise")
    if broadcast not in valid_modes:
        raise ValueError(
            "broadcast must be one of {}, got {!r}".format(valid_modes, broadcast)
        )

    isometry_t = np.swapaxes(isometry, -1, -2)

    if broadcast == "elementwise":
        result = isometry @ point.matrix @ isometry_t
    elif broadcast == "pairwise_reversed":
        left = utils.matrix_product(isometry, point.matrix, broadcast="pairwise_reversed")
        result = left @ isometry_t
    else:
        # "pairwise": insert one singleton axis per point batch axis right
        # after the isometry batch axes, so that the transpose broadcasts
        # against the pairwise product.
        iso_dims = len(isometry.shape) - 2
        pt_dims = len(point.matrix.shape) - 2
        left = utils.matrix_product(isometry, point.matrix, broadcast="pairwise")
        right = np.expand_dims(isometry_t,
                               axis=tuple(range(iso_dims, iso_dims + pt_dims)))
        result = left @ right

    return SymmetricPoint(result)

In [20]:
def riemannian_distance(p1, p2, broadcast="pairwise"):
    """Distance between points, measured after translating by ``p1.to_origin()``.

    Parameters
    ----------
    p1, p2 : SymmetricPoint
        Points (or stacks of points) to compare.
    broadcast : str
        Broadcast mode forwarded to `apply_isometry`.

    Returns
    -------
    ndarray
        Euclidean norm of the log absolute eigenvalues of the translated
        copy of `p2`, taken along the last axis.
    """
    g = p1.to_origin()

    g_translate = apply_isometry(g, p2, broadcast=broadcast)

    # apply_isometry returns a SymmetricPoint, not an array, so it cannot be
    # handed to np.linalg.eigh. Its eigenvalues were already computed when it
    # was constructed, and distance_to_origin() is exactly the norm of their
    # logarithms.
    return g_translate.distance_to_origin()

In [21]:
def z_angle_from_origin(p1, p2, zeta, broadcast="pairwise"):
    """Normalized trace pairing of `zeta` conjugated by the frames of two points.

    Computes ``trace((k1 zeta k1^T)(k2 zeta k2^T)) / trace(zeta zeta)`` where
    ``k1 = p1.k`` and ``k2 = p2.k``. The product of the two conjugates is
    formed by `utils.matrix_product` using the given `broadcast` mode.

    Parameters
    ----------
    p1, p2 : objects with a `k` attribute (SymmetricPoint in this notebook)
        Only the eigenvector matrices `k` are read.
    zeta : ndarray
        Square matrix conjugated by each frame.
    broadcast : str
        Broadcast mode forwarded to `utils.matrix_product`.

    Returns
    -------
    ndarray
        The normalized traces, one per combined batch entry.
    """
    normalizer = np.trace(zeta @ zeta)

    def conjugate(frame):
        # frame @ zeta @ frame^T over the last two axes
        return frame @ zeta @ np.swapaxes(frame, -1, -2)

    paired = utils.matrix_product(conjugate(p1.k), conjugate(p2.k),
                                  broadcast=broadcast)

    return np.trace(paired, axis1=-1, axis2=-2) / normalizer

In [22]:
# Diagonal matrices passed as the `zeta` argument of z_angle_from_origin.
zeta = np.diag([2, -1, -1])
i_zeta = np.diag([1, 1, -2])

# trace(zeta^2); the same normalizer z_angle_from_origin computes internally.
trz2 = (zeta @ zeta).trace()

In [23]:
# Load the finite-state automaton from a kbmag file in the working directory.
# NOTE(review): the ".wa" suffix suggests a kbmag word acceptor for the
# genus-2 surface group -- confirm the provenance of this file.
automaton = fsa.load_kbmag_file("genus2_surf_max.wa")

In [28]:
# Half of the word length to check: the main loop asks the automaton for
# prefixes and postfixes using this `length` and pairs them up.
length = 5

# Maximum number of prefixes (and of postfixes) processed in one chunk, so we
# work with at most 1000 * 1000 group elements at once.
ARR_SIZE_CAP = 1000

In [29]:
# Dynamic-programming caches handed to rep.automaton_accepted through its
# `precomputed` argument: one for the prefix queries, one for the postfix
# queries. Presumably the library fills them in so that later vertices can
# reuse earlier work -- confirm against the geometry_tools documentation.
pre_memos = {}
post_memos = {}

In [30]:
starttime = time.time()

# Running minima over every (prefix, postfix) pair examined so far.
worst_z_angle = 1.0
worst_iz_angle = 1.0
worst_dist = 1e20

for vertex in automaton.vertices():

    # for each vertex v of the automaton:
    # 1. get an array of elements of length l coming from paths ending at v
    # 2. get an array of elements of length l coming from paths starting at v
    prefix = rep.automaton_accepted(automaton, length, end_state=vertex,
                                    maxlen=False, precomputed=pre_memos)
    postfix = rep.automaton_accepted(automaton, length, start_state=vertex,
                                      maxlen=False, precomputed=post_memos)

    if len(prefix) == 0 or len(postfix) == 0:
        continue

    # cut up the prefix and postfix lists into fixed-length chunks,
    # and run the straight-and-spaced check on them
    n = len(prefix)
    m = len(postfix)
    # Ceiling division. The previous int(n / cap) + 1 produced an extra,
    # empty chunk whenever n was an exact multiple of the cap, and np.min
    # raised on the resulting empty arrays.
    nsegs = -(-n // ARR_SIZE_CAP)
    msegs = -(-m // ARR_SIZE_CAP)
    for i in range(nsegs):

        # Everything derived from the prefix chunk is independent of j, so it
        # is computed once per prefix chunk rather than once per pair of chunks.
        prefix_elts = np.linalg.inv(
                prefix[i * ARR_SIZE_CAP:min((i + 1) * ARR_SIZE_CAP, n)]
        )
        p1 = SymmetricPoint.from_translation(prefix_elts)
        m1 = p1.origin_midpoint()

        # filter out points where an eigenvalue is zero since they cause numerical issues / crashes
        m1_nonzero = np.all(m1.a != 0.0, axis=-1)
        m1 = m1[m1_nonzero]

        m1_usable = len(m1.a) > 0
        if m1_usable:
            m1_inv = m1.invert_by_transvection()
            m1_to_origin = m1.to_origin()

        for j in range(msegs):

            postfix_elts = postfix[j * ARR_SIZE_CAP:min((j + 1) * ARR_SIZE_CAP, m)]
            p2 = SymmetricPoint.from_translation(postfix_elts)
            m2 = p2.origin_midpoint()

            m2_nonzero = np.all(m2.a != 0.0, axis=-1)
            m2 = m2[m2_nonzero]

            # If the filter removed every point on either side there is
            # nothing to compare, and np.min would raise on an empty array.
            if m1_usable and len(m2.a) > 0:
                m2_inv = m2.invert_by_transvection()

                m1_inv_m2 = apply_isometry(m1_to_origin, m2, broadcast="pairwise_reversed")
                m2_inv_m1 = apply_isometry(m2.to_origin(), m1, broadcast="pairwise_reversed")

                dists = m1_inv_m2.distance_to_origin()
                zeta_angle = z_angle_from_origin(m2_inv, m2_inv_m1, zeta,
                                                 broadcast="elementwise")
                iota_zeta_angle = z_angle_from_origin(m1_inv, m1_inv_m2, i_zeta,
                                                      broadcast="elementwise")

                worst_dist = min(worst_dist, np.min(dists))
                worst_z_angle = min(worst_z_angle, np.min(zeta_angle))
                worst_iz_angle = min(worst_iz_angle, np.min(iota_zeta_angle))

            print(
                "vertex: {} / {}, chunk {} / {} ({} group elements)".format(
                    vertex, len(automaton.vertices()),
                    i * msegs + j + 1, nsegs * msegs,
                    len(prefix_elts) * len(postfix_elts)
            ))

timedelta = time.time() - starttime
print(f"Complete, time was {timedelta} seconds.")
print(f"Worst spacing: {worst_dist}")
print(f"Worst zeta_angle: {worst_z_angle}")
print(f"Worst iota_zeta angle: {worst_iz_angle}")

vertex: 2 / 37, chunk 1 / 34 (1000000 group elements)
vertex: 2 / 37, chunk 2 / 34 (1000000 group elements)
vertex: 2 / 37, chunk 3 / 34 (1000000 group elements)
vertex: 2 / 37, chunk 4 / 34 (1000000 group elements)
vertex: 2 / 37, chunk 5 / 34 (1000000 group elements)
vertex: 2 / 37, chunk 6 / 34 (1000000 group elements)
vertex: 2 / 37, chunk 7 / 34 (1000000 group elements)
vertex: 2 / 37, chunk 8 / 34 (1000000 group elements)
vertex: 2 / 37, chunk 9 / 34 (1000000 group elements)
vertex: 2 / 37, chunk 10 / 34 (1000000 group elements)
vertex: 2 / 37, chunk 11 / 34 (1000000 group elements)
vertex: 2 / 37, chunk 12 / 34 (1000000 group elements)
vertex: 2 / 37, chunk 13 / 34 (1000000 group elements)
vertex: 2 / 37, chunk 14 / 34 (1000000 group elements)
vertex: 2 / 37, chunk 15 / 34 (1000000 group elements)
vertex: 2 / 37, chunk 16 / 34 (1000000 group elements)
vertex: 2 / 37, chunk 17 / 34 (703000 group elements)
vertex: 2 / 37, chunk 18 / 34 (709000 group elements)
vertex: 2 / 37, chunk

vertex: 5 / 37, chunk 34 / 51 (661000 group elements)
vertex: 5 / 37, chunk 35 / 51 (3000 group elements)
vertex: 5 / 37, chunk 36 / 51 (3000 group elements)
vertex: 5 / 37, chunk 37 / 51 (3000 group elements)
vertex: 5 / 37, chunk 38 / 51 (3000 group elements)
vertex: 5 / 37, chunk 39 / 51 (3000 group elements)
vertex: 5 / 37, chunk 40 / 51 (3000 group elements)
vertex: 5 / 37, chunk 41 / 51 (3000 group elements)
vertex: 5 / 37, chunk 42 / 51 (3000 group elements)
vertex: 5 / 37, chunk 43 / 51 (3000 group elements)
vertex: 5 / 37, chunk 44 / 51 (3000 group elements)
vertex: 5 / 37, chunk 45 / 51 (3000 group elements)
vertex: 5 / 37, chunk 46 / 51 (3000 group elements)
vertex: 5 / 37, chunk 47 / 51 (3000 group elements)
vertex: 5 / 37, chunk 48 / 51 (3000 group elements)
vertex: 5 / 37, chunk 49 / 51 (3000 group elements)
vertex: 5 / 37, chunk 50 / 51 (3000 group elements)
vertex: 5 / 37, chunk 51 / 51 (1983 group elements)
vertex: 6 / 37, chunk 1 / 34 (1000000 group elements)
vertex: 

vertex: 9 / 37, chunk 17 / 51 (619000 group elements)
vertex: 9 / 37, chunk 18 / 51 (1000000 group elements)
vertex: 9 / 37, chunk 19 / 51 (1000000 group elements)
vertex: 9 / 37, chunk 20 / 51 (1000000 group elements)
vertex: 9 / 37, chunk 21 / 51 (1000000 group elements)
vertex: 9 / 37, chunk 22 / 51 (1000000 group elements)
vertex: 9 / 37, chunk 23 / 51 (1000000 group elements)
vertex: 9 / 37, chunk 24 / 51 (1000000 group elements)
vertex: 9 / 37, chunk 25 / 51 (1000000 group elements)
vertex: 9 / 37, chunk 26 / 51 (1000000 group elements)
vertex: 9 / 37, chunk 27 / 51 (1000000 group elements)
vertex: 9 / 37, chunk 28 / 51 (1000000 group elements)
vertex: 9 / 37, chunk 29 / 51 (1000000 group elements)
vertex: 9 / 37, chunk 30 / 51 (1000000 group elements)
vertex: 9 / 37, chunk 31 / 51 (1000000 group elements)
vertex: 9 / 37, chunk 32 / 51 (1000000 group elements)
vertex: 9 / 37, chunk 33 / 51 (1000000 group elements)
vertex: 9 / 37, chunk 34 / 51 (619000 group elements)
vertex: 9 / 

vertex: 16 / 37, chunk 16 / 17 (294000 group elements)
vertex: 16 / 37, chunk 17 / 17 (181986 group elements)
vertex: 17 / 37, chunk 1 / 17 (294000 group elements)
vertex: 17 / 37, chunk 2 / 17 (294000 group elements)
vertex: 17 / 37, chunk 3 / 17 (294000 group elements)
vertex: 17 / 37, chunk 4 / 17 (294000 group elements)
vertex: 17 / 37, chunk 5 / 17 (294000 group elements)
vertex: 17 / 37, chunk 6 / 17 (294000 group elements)
vertex: 17 / 37, chunk 7 / 17 (294000 group elements)
vertex: 17 / 37, chunk 8 / 17 (294000 group elements)
vertex: 17 / 37, chunk 9 / 17 (294000 group elements)
vertex: 17 / 37, chunk 10 / 17 (294000 group elements)
vertex: 17 / 37, chunk 11 / 17 (294000 group elements)
vertex: 17 / 37, chunk 12 / 17 (294000 group elements)
vertex: 17 / 37, chunk 13 / 17 (294000 group elements)
vertex: 17 / 37, chunk 14 / 17 (294000 group elements)
vertex: 17 / 37, chunk 15 / 17 (294000 group elements)
vertex: 17 / 37, chunk 16 / 17 (294000 group elements)
vertex: 17 / 37, ch

vertex: 26 / 37, chunk 2 / 17 (42000 group elements)
vertex: 26 / 37, chunk 3 / 17 (42000 group elements)
vertex: 26 / 37, chunk 4 / 17 (42000 group elements)
vertex: 26 / 37, chunk 5 / 17 (42000 group elements)
vertex: 26 / 37, chunk 6 / 17 (42000 group elements)
vertex: 26 / 37, chunk 7 / 17 (42000 group elements)
vertex: 26 / 37, chunk 8 / 17 (42000 group elements)
vertex: 26 / 37, chunk 9 / 17 (42000 group elements)
vertex: 26 / 37, chunk 10 / 17 (42000 group elements)
vertex: 26 / 37, chunk 11 / 17 (42000 group elements)
vertex: 26 / 37, chunk 12 / 17 (42000 group elements)
vertex: 26 / 37, chunk 13 / 17 (42000 group elements)
vertex: 26 / 37, chunk 14 / 17 (42000 group elements)
vertex: 26 / 37, chunk 15 / 17 (42000 group elements)
vertex: 26 / 37, chunk 16 / 17 (42000 group elements)
vertex: 26 / 37, chunk 17 / 17 (15414 group elements)
vertex: 27 / 37, chunk 1 / 15 (42000 group elements)
vertex: 27 / 37, chunk 2 / 15 (42000 group elements)
vertex: 27 / 37, chunk 3 / 15 (42000 g

vertex: 36 / 37, chunk 6 / 15 (6000 group elements)
vertex: 36 / 37, chunk 7 / 15 (6000 group elements)
vertex: 36 / 37, chunk 8 / 15 (6000 group elements)
vertex: 36 / 37, chunk 9 / 15 (6000 group elements)
vertex: 36 / 37, chunk 10 / 15 (6000 group elements)
vertex: 36 / 37, chunk 11 / 15 (6000 group elements)
vertex: 36 / 37, chunk 12 / 15 (6000 group elements)
vertex: 36 / 37, chunk 13 / 15 (6000 group elements)
vertex: 36 / 37, chunk 14 / 15 (6000 group elements)
vertex: 36 / 37, chunk 15 / 15 (1644 group elements)
vertex: 37 / 37, chunk 1 / 15 (6000 group elements)
vertex: 37 / 37, chunk 2 / 15 (6000 group elements)
vertex: 37 / 37, chunk 3 / 15 (6000 group elements)
vertex: 37 / 37, chunk 4 / 15 (6000 group elements)
vertex: 37 / 37, chunk 5 / 15 (6000 group elements)
vertex: 37 / 37, chunk 6 / 15 (6000 group elements)
vertex: 37 / 37, chunk 7 / 15 (6000 group elements)
vertex: 37 / 37, chunk 8 / 15 (6000 group elements)
vertex: 37 / 37, chunk 9 / 15 (6000 group elements)
vertex