# Setup

In [None]:
import analysis as al
import itertools
import numpy as np
import matplotlib.pyplot as plt
import torch
import tqdm.auto as tqdm
%matplotlib widget 
# requires ipympl package

In [None]:
def grab(x):
    """Detach a torch tensor from the autograd graph and return it as a CPU NumPy array."""
    detached = x.detach()
    return detached.cpu().numpy()

In [None]:
def wrap(x):
    """
    Wrap angle(s) into the half-open interval [-pi, pi).

    Only `+`, `%` and `-` are used, so scalars, NumPy arrays and torch
    tensors are all accepted.
    """
    shifted = x + np.pi
    return shifted % (2*np.pi) - np.pi

In [None]:
# Use float64 (and therefore complex128) as the default precision for new tensors.
torch.set_default_dtype(torch.float64)

# SU(N) utils

Almost entirely copied from StN project.

## Matrix ops

In [None]:
def adjoint(U):
    """Batched conjugate transpose over the last two axes (torch tensors or NumPy arrays)."""
    conjugated = U.conj()
    return conjugated.swapaxes(-2, -1)
def trace(U):
    """Batched trace over the last two axes of a torch tensor."""
    diagonal = torch.diagonal(U, dim1=-2, dim2=-1)
    return diagonal.sum(dim=-1)

In [None]:
def dist(U, V):
    """
    Distance between two SU(Nc) matrices, 1 - Re Tr(U V^dagger) / Nc.

    Equals zero when U == V.
    """
    n = U.shape[-1]
    overlap = trace(U @ adjoint(V)).real
    return 1 - (1/n)*overlap

def diag(U):
    """Return the diagonal entries of (a batch of) matrices U, shape [..., N]."""
    return U.diagonal(dim1=-2, dim2=-1)

def mat_angle(U):
    """
    Eigen-decompose an SU(Nc) matrix as

        U = V exp(i th) V^{-1}

    Returns:
        th: eigenvalue phases in (-pi, pi] (order as given by torch.linalg.eig)
        V: matrix whose columns are the eigenvectors
        Vinv: its inverse
    """
    eigvals, eigvecs = torch.linalg.eig(U)
    phases = torch.angle(eigvals)
    return phases, eigvecs, torch.linalg.inv(eigvecs)

In [None]:
def proj_trless(A):
    """
    Return the traceless part of A, i.e. A - (Tr A / Nc) * I.

    A new tensor is returned and the input is left untouched. This also
    keeps the function usable on leaf tensors that require grad.

    Args:
        A: Tensor of (batched) square matrices, shape [..., Nc, Nc].

    Returns:
        Tensor of the same shape whose trace over the last two dims is zero.
    """
    Nc = A.shape[-1]
    # identity built in A's dtype/device so no cross-device or dtype mix-ups
    eye = torch.eye(Nc, dtype=A.dtype, device=A.device)
    return A - (1/Nc)*eye*trace(A)[..., None, None]

def proj_sun_algebra(A):
    """
    Project a matrix A onto the Lie algebra su(Nc) (Hermitian convention).

    Two steps: take the Hermitian part (A + A^dagger) / 2, then remove its
    trace with `proj_trless`.
    """
    hermitian = (A + adjoint(A))/2
    return proj_trless(hermitian)

def sample_sun_gaussian(shape):
    """
    Draw a Gaussian random element of su(Nc) with matrix shape [..., Nc, Nc].

    A complex Gaussian matrix (independent standard normal real and imaginary
    parts) is projected onto Hermitian traceless matrices.
    """
    n_rows, n_cols = shape[-2:]
    assert n_rows == n_cols
    real_part = torch.randn(shape)
    imag_part = torch.randn(shape)
    return proj_sun_algebra(real_part + 1j*imag_part)

def sample_su2_haar(n):
    """
    Sample `n` SU(2) matrices exactly from the Haar measure.

    A uniform point x on the unit 3-sphere (a normalized 4D Gaussian) is mapped
    to the matrix
        [[ x0 + i x1,  x2 + i x3],
         [-x2 + i x3,  x0 - i x1]].

    Returns:
        Complex NumPy array of shape [n, 2, 2].
    """
    gauss = np.random.normal(size=(n, 4))
    unit = gauss / np.linalg.norm(gauss, axis=-1, keepdims=True)
    x0, x1, x2, x3 = unit.T
    entries = [x0 + 1j*x1, x2 + 1j*x3, -x2 + 1j*x3, x0 - 1j*x1]
    return np.stack(entries, axis=-1).reshape(n, 2, 2)

In [None]:
### GENERATORS
### Normalization Tr[T^a T^b] = delta^{ab}
_su2_gens = torch.stack([
    torch.tensor([[0, 1], [1, 0]]),
    torch.tensor([[0, -1j], [1j, 0]]),
    torch.tensor([[1, 0], [0, -1]]),
]) / np.sqrt(2)

def test_su2_gens():
    Delta = trace(_su2_gens[:,None] @ _su2_gens)
    assert torch.allclose(Delta, torch.eye(len(_su2_gens)).cdouble())
    print('Test SU(2) gens PASSED')

if __name__ == '__main__': test_su2_gens()

# Gell-Mann matrices lambda_1 ... lambda_8 divided by sqrt(2), so that
# Tr[T^a T^b] = delta^{ab} (same normalization as the SU(2) generators).
_su3_gens = torch.stack([
    torch.tensor([[0, 1, 0], [1, 0, 0], [0, 0, 0]]),
    torch.tensor([[0, -1j, 0], [1j, 0, 0], [0, 0, 0]]),
    torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, 0]]),
    torch.tensor([[0, 0, 1], [0, 0, 0], [1, 0, 0]]),
    torch.tensor([[0, 0, -1j], [0, 0, 0], [1j, 0, 0]]),
    torch.tensor([[0, 0, 0], [0, 0, 1], [0, 1, 0]]),
    torch.tensor([[0, 0, 0], [0, 0, -1j], [0, 1j, 0]]),
    # lambda_8 carries its own 1/sqrt(3) so all generators share one normalization
    torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, -2]])/np.sqrt(3),
]) / np.sqrt(2)

def test_su3_gens():
    """
    Checks whether the SU(3) generators satisfy

        Tr(ta @ tb) = delta_ab

    Raises AssertionError otherwise; prints a message on success.
    """
    # [8, 1, 3, 3] @ [8, 3, 3] broadcasts to all pairwise products t_a @ t_b
    Delta = trace(_su3_gens[:,None] @ _su3_gens)
    assert torch.allclose(Delta, torch.eye(len(_su3_gens)).cdouble())
    print('Test SU(3) gens PASSED')

if __name__ == '__main__': test_su3_gens()

def sun_gens(Nc):
    """
    Return the stacked, trace-orthonormal generators of su(Nc).

    Only Nc = 2 and Nc = 3 are available; any other value raises ValueError.
    """
    if Nc not in (2, 3):
        raise ValueError(f'{Nc=}')
    return _su2_gens if Nc == 2 else _su3_gens

def embed_sun_algebra(omega, Nc):
    """
    Map generator coefficients omega[..., a] to the algebra matrix
    sum_a omega[..., a] T^a (complex double), using the su(Nc) generators.
    """
    coeffs = omega.cdouble()
    return torch.einsum('...x,xab->...ab', coeffs, sun_gens(Nc))

In [None]:
def _su2_matrix_exp(A):
    """
    Computes the (complex) matrix exponential exp(A) in closed form, for
    A = i (omega . T) with T the normalized SU(2) generators (Pauli / sqrt(2)).

    1.) Recover the coefficients omega_x = Tr(-i A T^x) and keep their real
        part (the imaginary part is discarded, not checked).
    2.) Use exp(i omega.T) = cos(|omega|/sqrt(2)) I
            + i sin(|omega|/sqrt(2)) * sqrt(2)/|omega| * (omega.T);
        the sin(...)/|omega| ratio is written with the normalized torch.sinc
        so it stays finite.
    3.) Where |omega|^2 is numerically zero, the norm is computed from a dummy
        value |omega|^2 + 0.1 (avoiding the non-differentiable sqrt at 0) and
        the result is replaced by the first-order expansion I + A via
        torch.where.

    Args:
        A: complex tensor [..., 2, 2], assumed to satisfy -iA in su(2).

    Returns:
        complex tensor [..., 2, 2] approximating exp(A).
    """
    omega = torch.einsum('...ab,xba->...x', -1j*A, _su2_gens)
    # assert torch.allclose(omega.imag, torch.tensor(0.0))
    omega = omega.real
    norm_sq = (omega**2).sum(dim=-1)
    # need to be careful of sqrt branch cut at 0
    ident_inds = torch.isclose(norm_sq, torch.tensor(0.0))
    ident_inds_mat = torch.isclose(norm_sq[...,None,None], torch.tensor(0.0))
    norm_omega = torch.where(ident_inds, norm_sq+0.1, norm_sq).sqrt()  # where true yield input, else other (condition, input, other)
    omega_dot_sigma = embed_sun_algebra(omega, 2)
    # torch.sinc(x) = sin(pi x)/(pi x), so this factor is sqrt(2) sin(|w|/sqrt(2)) / |w|
    U = (
        torch.cos(norm_omega/np.sqrt(2))[...,None,None]*torch.eye(2)
        + 1j*torch.sinc(norm_omega/(np.pi*np.sqrt(2)))[...,None,None]*omega_dot_sigma
    )
    # near-zero coefficients: fall back to the first-order expansion I + A
    return torch.where(ident_inds_mat, torch.eye(2).cdouble()+A, U)

def test_su2_matrix_exp():
    """
    Compare `_su2_matrix_exp` against torch.matrix_exp, both in value and in
    the forward-mode Jacobian w.r.t. the generator coefficients. A zero
    coefficient vector is included to exercise the identity special case.
    """
    n_gens = len(_su2_gens)
    random_part = torch.randn((5, n_gens))
    omega = torch.cat([torch.zeros((1, n_gens)), random_part])

    def as_real_pair(exp_fn):
        # omega -> stacked (Re, Im) of exp_fn(i * omega.T)
        def wrapped(omega):
            U = exp_fn(1j*embed_sun_algebra(omega, Nc=2))
            return torch.stack([U.real, U.imag])
        return wrapped

    closed_form = as_real_pair(_su2_matrix_exp)
    reference = as_real_pair(torch.matrix_exp)
    assert torch.allclose(closed_form(omega), reference(omega))
    jac = torch.func.vmap(torch.func.jacfwd(closed_form))(omega)
    jac2 = torch.func.vmap(torch.func.jacfwd(reference))(omega)
    assert torch.allclose(jac, jac2), f'{jac=} {jac2=}'

if __name__ == '__main__':
    test_su2_matrix_exp()

def sun_matrix_exp(A):
    r"""
    Matrix exponential exp(A), assuming $-iA \in su(N)$.

    For N = 2 the closed form `_su2_matrix_exp` is used; otherwise
    torch.matrix_exp.
    """
    if A.shape[-1] != 2:
        return torch.matrix_exp(A)
    # return torch.matrix_exp(A)
    return _su2_matrix_exp(A)

## Autograd stuff

In [None]:
def grad_sun(f):
    """
    Wrap torch.func.grad so derivatives of `f` are taken along the su(Nc)
    generator directions.

    The returned function maps U [..., Nc, Nc] to the gradient of
    omega -> f(U + i (omega.T) U) at omega = 0, with one entry per generator.
    """
    def grad(U):
        Nc = U.shape[-1]
        n_gens = len(sun_gens(Nc))
        omega0 = torch.zeros(U.shape[:-2] + (n_gens,))

        def f_along_algebra(omega):
            # first-order left perturbation U -> (1 + iA) U
            step = 1j*embed_sun_algebra(omega, Nc)
            return f(U + step@U)

        return torch.func.grad(f_along_algebra)(omega0)
    return grad

In [None]:
def jacfwd_sun(f):
    """
    Wrap torch.func.jacfwd so forward-mode derivatives of `f` are taken along
    the su(Nc) generator directions.

    The returned function maps U [..., Nc, Nc] to the Jacobian of
    omega -> f(U + i (omega.T) U) at omega = 0.
    """
    def jacfwd(U):
        Nc = U.shape[-1]
        n_gens = len(sun_gens(Nc))
        omega0 = torch.zeros(U.shape[:-2] + (n_gens,))

        def f_along_algebra(omega):
            # first-order left perturbation U -> (1 + iA) U
            step = 1j*embed_sun_algebra(omega, Nc)
            return f(U + step@U)

        return torch.func.jacfwd(f_along_algebra)(omega0)
    return jacfwd

In [None]:
def hess_sun(f):
    """
    Wrap torch.func.hessian so second derivatives of `f` are taken along the
    su(Nc) generator directions.

    The returned function maps U [..., Nc, Nc] to the Hessian of
    omega -> f(exp(i omega.T) @ U) at omega = 0.
    """
    def hess(U):
        Nc = U.shape[-1]
        n_gens = len(sun_gens(Nc))
        omega0 = torch.zeros(U.shape[:-2] + (n_gens,))

        def f_on_orbit(omega):
            A = embed_sun_algebra(omega, Nc)
            # TODO: make sure this shortcut works
            # return f(U + 1j*A@U - 0.5*A@A@U)
            return f(sun_matrix_exp(1j*A) @ U)

        return torch.func.hessian(f_on_orbit)(omega0)
    return hess

## SU(3) canonical cell

In [None]:
def better_canonicalize_su3(thW):
    """
    Fold SU(3) eigenangles into the canonical hexagon centered on the identity.

    The angles are first projected onto the hyperplane sum_i theta_i = 0 by
    subtracting the total from the last angle. They are then expressed in
    lattice coordinates (a, b, c) and shifted by integers, which moves the
    angles only by vectors in 2*pi*Z^3. Each output angle therefore differs
    from the projected input angle by an integer multiple of 2*pi.

    Args:
        thW: Tensor of eigenangles with trailing dimension 3 and any leading
            batch shape. It is NOT modified.

    Returns:
        Tensor with the same shape, dtype and device as `thW`.
    """
    # work on a copy so the caller's tensor is not mutated
    thW = thW.clone()
    # project onto hyperplane defined by sum_i theta_i = 0
    thW[...,-1] -= torch.sum(thW, dim=-1)
    v = thW.reshape(-1, 3)

    # wrap into canonical hexagon centered on I
    # M is singular (M @ (1,1,1) = 0) and M @ M = 3*I - J, so M/3 inverts M
    # on the sum-zero hyperplane. Built in thW's dtype/device so the matmuls
    # below do not depend on the global default dtype.
    M = torch.tensor([
        [1, 0, -1],
        [0, -1, 1],
        [-1, 1, 0],
    ], dtype=thW.dtype, device=thW.device)
    # map from (a,b,c) into v = (th1, th2, th3)
    U = 2*np.pi * M
    # map from v into (a,b,c)
    Uinv = M / (6*np.pi)
    kappa = Uinv @ torch.transpose(v, 0, 1) # ij,jb -> ib
    a, b, c = kappa[0], kappa[1], kappa[2]
    # uniform shifts of (a,b,c) are in the null space of U (no effect on angles);
    # integer shifts of a single coordinate move the angles by 2*pi lattice vectors
    k = (b+c)/2
    a -= k
    b -= k
    c -= k
    a -= torch.round(a)
    k = torch.round(b)
    b -= k
    c += k
    b -= torch.round(b - (a+c)/2)
    k = (b+c)/2
    a -= k
    b -= k
    c -= k
    a -= torch.round(a)
    c -= torch.round(c - (a+b)/2)
    return torch.transpose(U @ torch.stack([a,b,c], dim=0), 0, 1).reshape(thW.shape)

In [None]:
# matrix to transform from (theta1, theta2, theta3) -> (alpha, beta) plane
_su3_A = np.array([
    [0, 1, -1] / np.sqrt(2),
    [2, -1, -1] / np.sqrt(6),
])
_su3_Ainv = np.array([
    [0, np.sqrt(2/3)],
    [1 / np.sqrt(2), -1 / np.sqrt(6)],
    [-1 / np.sqrt(2), -1 / np.sqrt(6)],
])
assert np.allclose(_su3_A @ _su3_Ainv, np.identity(2))


# for nice plots in (alpha, beta) plane
_su3_hex = (2*np.pi/3) * np.array([
    (1,1,-2),
    (2,-1,-1),
    (1,-2,1),
    (-1,-1,2),
    (-2,1,1),
    (-1,2,-1),
    (1,1,-2),
])

# SU(N) heat-kernel distribution

We use the standard definition of SU(N) Brownian motion as defined in stochastic quantization. For diffusion, we will want the forward process to be a variance-expanding scheme, so that for sufficiently large noise coefficient $\sigma$, the final distribution is very nearly uniform.

Unlike for elements of $\mathbb{R}^n$, we **cannot "fast-forward"** $SU(N)$ Brownian motion by simply setting the scale of the sampled Gaussian noise to $\sqrt{t}$. For example, two Brownian steps in the Euler-Heun scheme generate higher order terms, which play a role if $dt$ is not infinitesimal:
$$
\begin{aligned}
\exp(i \sqrt{dt} A) \exp(i \sqrt{dt} B) &= \exp(i \sqrt{dt}(A+B) - \frac{1}{2} dt [A, B]) + O(dt^{3/2}) \\
&= \exp(i \sqrt{2 dt} C - \frac{1}{2} dt [A, B]) + O(dt^{3/2}).
\end{aligned}
$$
Here $C = (A+B)/\sqrt{2} \sim \mathcal{N}(0, I)$ is correctly distributed, but the commutator term $-\frac{1}{2} dt [A, B]$ is a non-trivial next-order effect that a single step of size $2\,dt$ would not produce.

However, the **heat-kernel** $K(U, t)$ has been derived in the case of $U(N)$ and $SU(N)$ groups (see e.g. Menotti and Onofri). It is a class function, structured as a periodically wrapped function of the eigenvalues. We will include the relevant Haar measure factor on the eigenvalues,
$$
h(\theta_1, \dots, \theta_n) = \prod_{i < j} |e^{i \theta_i} - e^{i \theta_j}|^2 = \prod_{i < j} (2 - 2 \cos(\theta_i - \theta_j)),
$$
and will find it convenient to work in terms of the "unwrapped" eigenvalue angles $x_i \in \mathbb{R}$ from which $\lambda_i = e^{i x_i}$.
Including the Haar measure, the $U(N)$ heat kernel comes out to
$$
K(x, t) = \mathcal{N} \prod_{i < j} \left( \frac{2 - 2 \cos(x_i - x_j)}{|\mathrm{sinc}(\tfrac{1}{2}(x_i - x_j))|} \right)
\exp(-\frac{1}{2t} \sum_i x_i^2) = \mathcal{N} \prod_{i < j} \left( |x_i - x_j| \sqrt{2 - 2 \cos(x_i - x_j)} \right)
\exp(-\frac{1}{2t} \sum_i x_i^2).
$$
Going to the $SU(N)$ heat kernel just requires an additional $\delta(\sum_i x_i)$.

**TODO:** I don't understand the absolute value in the measure, need to work this out carefully at some point.

**TODO:** The normalization is non-trivial, and I haven't managed to line this up with references. It isn't needed in the following, so we proceed without it.

In [None]:
# SU(N) heat kernel stuff

# TODO(gkanwar): I think something is wrong with this measure term, given the
# SU(2) results just below. Gotta derive this properly to check... e.g. it
# might need to have the _periodic_ theta_i - theta_j, not x_i - x_j as the
# argument?
def _hk_meas_factor(delta):
    return (2 - 2*np.cos(delta)) / np.abs(np.sinc(delta / (2*np.pi)))
    # return np.sqrt(2 - 2*np.cos(delta)) * np.abs(delta)

def _eval_hk_su2(theta, t, *, n_max=3):
    """
    Hand-written, unnormalized SU(2) heat kernel at eigenangle `theta`
    (eigenvalue phases +theta, -theta), summed over windings
    n in range(-n_max, n_max).
    """
    # ignore ALL normalization factors for now
    total = 0
    for n in range(-n_max, n_max):
        x = theta + 2*np.pi*n
        total = total + _hk_meas_factor(2*theta + 4*np.pi*n) * np.exp(-(1/t) * x**2)
    return total


def _eval_hk_su3(theta1, theta2, t, *, n_max=3):
    """
    Hand-written, unnormalized SU(3) heat kernel at eigenangles
    (theta1, theta2, -theta1 - theta2), summed over winding shifts
    (n1, n2, -n1 - n2) with n1, n2 in range(-n_max, n_max).
    """
    # ignore ALL normalization factors for now
    theta3 = -theta1 - theta2  # sum to zero

    d12 = theta1 - theta2
    d23 = theta2 - theta3
    d31 = theta3 - theta1

    windings = range(-n_max, n_max)
    total = 0
    for n1 in windings:
        for n2 in windings:
            n3 = -n1 - n2  # winding shifts must also sum to zero
            meas = (
                _hk_meas_factor(d12 + 2*np.pi*(n1 - n2))
                * _hk_meas_factor(d23 + 2*np.pi*(n2 - n3))
                * _hk_meas_factor(d31 + 2*np.pi*(n3 - n1))
            )
            gauss = np.exp(
                -1/(2*t) * (theta1 + 2*np.pi*n1)**2
                -1/(2*t) * (theta2 + 2*np.pi*n2)**2
                -1/(2*t) * (theta3 + 2*np.pi*n3)**2
            )
            total = total + meas * gauss
    return total


def eval_hk_sun_unwrapped(xs, t, meas_only=False):
    """
    Evaluate the unnormalized SU(N) heat kernel at unwrapped eigenangles.

    The N-th angle is rebuilt from the SU(N) constraint x_N = -sum_{i<N} x_i.
    The result is the product of `_hk_meas_factor` over all pairs i < j. It is
    multiplied by the Gaussian weight exp(-sum_i x_i^2 / (2 t)) unless
    `meas_only` is True.

    Args:
        xs: Array-like of the first N-1 unwrapped angles, shape [batch, ..., N-1].
        t: Heat-kernel time. Either a scalar, or a 1D array with one entry per
            leading batch index (t.shape == (xs.shape[0],)). In the batched
            case t is aligned with the leading axis of xs.
        meas_only: If True, return only the pairwise measure factor.

    Returns:
        NumPy array of shape xs.shape[:-1].
    """
    # convert first so .shape is valid for list/tuple input
    xs = np.array(xs)
    t = np.asarray(t)
    if t.ndim != 0:
        assert len(t.shape) == 1 and t.shape[0] == xs.shape[0], 't should be batched'
        # align t with the leading batch axis rather than the trailing one
        t = t.reshape(t.shape + (1,) * (xs.ndim - 2))

    xn = -np.sum(xs, axis=-1, keepdims=True)
    xs = np.concatenate([xs, xn], axis=-1)

    # all pairwise differences x_i - x_j with i < j
    delta_x = np.stack([
        xs[..., i] - xs[..., j]
        for i, j in itertools.combinations(range(xs.shape[-1]), 2)
    ], axis=-1)

    meas = np.prod(_hk_meas_factor(delta_x), axis=-1)
    if meas_only:
        return meas
    weight = np.exp(-1/(2*t) * np.sum(xs**2, axis=-1))
    return meas * weight


def eval_hk_sun(thetas, t, *, n_max=3):
    """
    Evaluate the unnormalized SU(N) heat kernel at wrapped eigenangles.

    Sums `eval_hk_sun_unwrapped` over winding shifts 2*pi*n of each of the
    N-1 free angles.

    Args:
        thetas: Array-like of the first N-1 eigenangles, shape [batch, ..., N-1].
        t: Heat-kernel time, scalar or shape [batch] (a scalar is promoted).
        n_max: Each winding n runs over range(-n_max, n_max). The upper bound
            is exclusive, matching the hand-written SU(2)/SU(3) kernels.

    Returns:
        NumPy array of shape thetas.shape[:-1].
    """
    # convert first so .shape is valid for list/tuple input
    thetas = np.array(thetas)
    # promote t to batch, if needed
    t = np.array(t)
    if len(t.shape) == 0:
        t = t * np.ones(thetas.shape[0])
    assert t.shape == (thetas.shape[0],)
    total = 0
    for ns in itertools.product(range(-n_max, n_max), repeat=thetas.shape[-1]):
        xs = thetas + 2*np.pi*np.array(ns)
        total = total + eval_hk_sun_unwrapped(xs, t)
    return total

In [None]:
def _test_hk_sun():
    """
    Consistency check: the generic `eval_hk_sun` must agree with the
    hand-written kernels `_eval_hk_su2` and `_eval_hk_su3` for N = 2, 3.
    """
    np.random.seed(1234)
    t_test = 0.5

    # SU(2): a single free angle; the wide spread exercises several windings
    theta = 6*np.pi*np.random.normal(size=(1024,))
    expected = _eval_hk_su2(theta, t=t_test)
    got = eval_hk_sun(theta[..., None], t=t_test)
    assert np.allclose(expected, got), \
        '[FAILED SU(2): ]'

    # SU(3): two free angles (Nc - 1 = 2)
    thetas = 6*np.pi*np.random.normal(size=(1024, 2))
    expected = _eval_hk_su3(thetas[..., 0], thetas[..., 1], t=t_test)
    got = eval_hk_sun(thetas, t=t_test)
    assert np.allclose(expected, got)

    print('[PASSED test_hk_sun]')

if __name__ == '__main__':
    _test_hk_sun()

The measure factor over several $2\pi$ intervals is shown below. It grows roughly linearly in $|\Delta|$ because of the $1/|\mathrm{sinc}(\Delta/2)|$ factor (equivalently, it equals $2|\Delta|\,|\sin(\Delta/2)|$). This growth is always cut off by the Gaussian weights.

In [None]:
# Plot the pairwise measure factor over [-4 pi, 4 pi].
deltas = np.linspace(-4*np.pi, 4*np.pi, num=501)
fig, ax = plt.subplots(figsize=(3.5, 2.5))
ax.plot(deltas, _hk_meas_factor(deltas))
ax.set_xlabel(r'$\Delta$')
fig.set_tight_layout(True)
plt.show()

### SU(2)
Heat kernel distribution over the phase $\theta$ that sets the eigenvalue phases
$$\theta_1 = \theta, \qquad \theta_2 = -\theta.$$

In [None]:
def _check_hk_su2():
    """
    Simulates Brownian diffusion on SU(2). Collects histograms of the
    eigenangle at increasing times and compares them to the analytic
    solution from `_eval_hk_su2`.

    Forward diffusion process is performed as 
        U_{t+dt} = exp(iA*sqrt{dt}) @ U_t,
    where A is a Gaussian su(2)-valued noise matrix (`sample_sun_gaussian`).
    Starting from 16000 copies of the identity, 250 steps of dt = 0.01 are
    taken. A density histogram of |theta| on [0, pi] is recorded every 50
    steps. Each analytic curve is normalized numerically on the bin centers.
    """
    Nc = 2
    U0 = torch.stack([torch.eye(Nc)]*16000).cdouble()  # identity element of SU(2)
    
    dt = 0.01
    ts = []
    alpha = []
    bins = np.linspace(0, np.pi, num=51)
    
    U1 = U0.clone()
    for i in tqdm.tqdm(range(250)):
        A = sample_sun_gaussian(U1.shape)
        # assert torch.allclose(trace(A), torch.tensor(0.0).cdouble())
        U1 = sun_matrix_exp(1j*A*np.sqrt(dt)) @ U1
        if (i+1) % 50 == 0:
            ts.append((i+1)*dt)
            thetas, _, _ = mat_angle(U1)
            alpha_U1 = np.abs(grab(thetas[...,0]))  # theta equivalent to -theta by symmetry
            alpha.append(np.histogram(alpha_U1, bins=bins, density=True)[0])
    fig, ax = plt.subplots(1,1)
    cmap = plt.get_cmap('viridis')
    xs = (bins[1:]+bins[:-1])/2  # bin centers
    for i, (t, alpha_t) in enumerate(zip(ts, alpha)):
        color = cmap(i / len(ts))
        # data
        ax.plot(xs, alpha_t, label=f't = {t:.2f}', color=color)
        # analytical
        # TODO: the measure term of this result warps the answer so it is not
        # properly symmetric and centered at pi/2...
        ys = _eval_hk_su2(xs, t, n_max=10)
        # Riemann-sum normalization on the histogram grid
        norm = np.sum(ys * (bins[1]-bins[0]))
        print(f'{norm=}')
        ys /= norm
        ax.plot(xs, ys, color=color, linestyle='--')
    ax.legend()
    ax.set_title('SU(2) heat kernel (data vs analytic)')
    ax.set_xticks([0, np.pi/2, np.pi])
    ax.set_xticklabels(['0', r'$\pi/2$', r'$\pi$'])
    ax.set_ylabel(r'Density')
    ax.set_xlabel(r'Eig $\theta$')
_check_hk_su2()

### SU(3)
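Heat kernel distribution over the space $(\alpha, \beta)$ that parameterizes the constrained phases. It is related to the eigenvalue phases by the projection
$$
\begin{pmatrix} \alpha \\ \beta \end{pmatrix} = A \begin{pmatrix} \theta_1 \\ \theta_2 \\ \theta_3 \end{pmatrix},
\qquad
A = \begin{pmatrix} 0 & \tfrac{1}{\sqrt{2}} & -\tfrac{1}{\sqrt{2}} \\ \tfrac{2}{\sqrt{6}} & -\tfrac{1}{\sqrt{6}} & -\tfrac{1}{\sqrt{6}} \end{pmatrix},
$$
with $\theta_3 = -\theta_1 - \theta_2$. The rows of $A$ are orthonormal and orthogonal to $(1,1,1)$, so on the constrained plane the inverse map is $(\theta_1, \theta_2, \theta_3)^T = A^T (\alpha, \beta)^T$. Note that the Weyl group of permutations relates six triangular "chambers", which together form the hexagonal space of all possible eigenvalue phases obtained from diagonalization. This full hexagon is shown in the $(\alpha, \beta)$ plane for reference below.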

In [None]:
def _check_hk_su3():
    """
    Simulates (variance-expanding) Brownian diffusion on SU(3) and compares
    the eigenangle distribution to the analytic heat kernel `eval_hk_sun`.

    Forward diffusion process is performed as
        U_{t+dt} = exp(iA*sqrt{dt}) @ U_t,
    where A is a Gaussian su(3)-valued noise matrix. Every 40 of the 240 steps
    (dt = 0.01), the eigenangles are processed as follows:
      - extracted,
      - randomly permuted per sample (so all Weyl chambers are populated),
      - folded into the canonical hexagon,
      - projected onto the (alpha, beta) plane.
    The first figure shows 2D histograms of these samples. The second shows
    contours of the analytic kernel at the same times.
    """
    batch_size = 2**14
    Nc = 3
    U0 = torch.stack([torch.eye(Nc)] * batch_size).cdouble()

    dt = 0.01
    ts = []
    ab = []

    U1 = U0.clone()
    for i in tqdm.tqdm(range(240)):
        A = sample_sun_gaussian(U1.shape)
        U1 = sun_matrix_exp(1j*A*np.sqrt(dt)) @ U1
        if (i+1) % 40 == 0:
            ts.append((i+1)*dt)
            thetas, _, _ = mat_angle(U1)  # [batch_size, 3]

            # randomize order of thetas within each sample (rows shuffled in place)
            thetas = grab(thetas)
            for row in thetas:
                np.random.shuffle(row)
            thetas = torch.tensor(thetas)

            canon_th = better_canonicalize_su3(thetas)
            # project into Cartan subalgebra coordinate plane -> [2, batch_size]
            ab.append(np.einsum('ax,...x->a...', _su3_A, grab(canon_th)))

    bins = np.linspace(-5.5, 5.5, num=51)
    fig, axes = plt.subplots(2, 3, figsize=(6, 4))
    fig.suptitle('SU(3) heat kernel (data)')
    fig2, axes2 = plt.subplots(2, 3, figsize=(6, 4))
    fig2.suptitle('SU(3) heat kernel (analytic)')

    # time-independent evaluation grid: (alpha, beta) -> (th1, th2, th3),
    # folded into the fundamental hexagon
    a, b = np.meshgrid(bins, bins, indexing='ij')
    th = np.einsum('xa, a... -> x...', _su3_Ainv, np.stack((a, b), axis=0))
    th = grab(better_canonicalize_su3(torch.tensor(th).moveaxis(0, -1)))
    hex_outline = _su3_A @ np.transpose(_su3_hex)

    for t, ab_t, ax, ax2 in zip(ts, ab, axes.flatten(), axes2.flatten()):
        # data
        ax.hist2d(*ab_t, bins=bins)
        # analytical
        v = eval_hk_sun(th[..., :2], t)
        ax2.contourf(a, b, v)

        for panel in (ax, ax2):
            panel.plot(*hex_outline, color='w', linestyle='--')
            panel.set_aspect(1.0)
            # position the label in this panel's own axes coordinates
            panel.text(0.05, 0.95, f't={t:.2f}', fontsize=7, color='w',
                       transform=panel.transAxes, ha='left', va='top')
_check_hk_su3()

For comparison, the uniform Haar measure on the $(\alpha, \beta)$ plane of eigenvalues looks like:

In [None]:
def _plot_haar():
    """
    Contour plot of the SU(3) Haar density over the (alpha, beta) plane.

    The density is prod_{i<j} (2 - 2 cos(theta_i - theta_j)) evaluated on
    eigenangles folded into the canonical hexagon. This is the plane that
    parameterizes conjugacy classes (eigenangles up to permutation).
    """
    fig, ax = plt.subplots(figsize=(3.5, 3))
    grid = np.linspace(-5.5, 5.5, num=51)
    a, b = np.meshgrid(grid, grid, indexing='ij')

    # (alpha, beta) -> (th1, th2, th3), then fold into the hexagon
    plane = np.stack((a, b), axis=0)
    th = np.einsum('xa,a...->x...', _su3_Ainv, plane)
    th = grab(better_canonicalize_su3(torch.tensor(th).moveaxis(0, -1)))

    def pair_factor(i, j):
        # |e^{i th_i} - e^{i th_j}|^2 for one pair of eigenangles
        return 2 - 2*np.cos(th[..., i] - th[..., j])

    density = pair_factor(0, 1) * pair_factor(1, 2) * pair_factor(2, 0)

    ax.contourf(a, b, density)
    # outlines
    ax.plot(*(_su3_A @ np.transpose(_su3_hex)), color='w', linestyle='--')
    ax.set_aspect(1.0)
    ax.set_title('SU(3) Haar measure')
_plot_haar()

### Sampling
We will want to sample from the given heat-kernel distribution at arbitrary times $t$. Since it is straightforward to wrap the distribution _after_ sampling, we work in the space of $x_i \in \mathbb{R}$. Since the measure seems to be relatively mild (at least for small $N_c$), we can simply apply a few hits of Metropolis resampling.

In [None]:
def sample_hk_sun(batch_size, Nc, t, *, n_iter=3):
    """
    Draw approximate samples from the SU(N) heat kernel at time `t`.

    Eigenangles use an independence Metropolis sampler:
      - Proposals are isotropic Gaussians of variance t on the hyperplane
        sum_i x_i = 0, which is exactly the Gaussian factor
        exp(-sum_i x_i^2 / (2t)) of the heat kernel.
      - A proposal x' therefore replaces the current x with probability
        min(1, meas(x') / meas(x)), where meas is the pairwise measure factor.
      - The chain starts from a proposal and runs `n_iter` accept/reject
        steps. Angles are unwrapped (x_i in R, not reduced mod 2*pi).

    The eigenbasis V is the Q factor from a QR decomposition of a complex
    Gaussian matrix. Q is Haar only up to a diagonal phase matrix on the right,
    but that phase commutes with the diagonal angle matrix and cancels in
    V diag(x) V^dagger.

    Args:
        batch_size: Number of samples to generate
        Nc: Dimension of fundamental rep of SU(N)
        t: Diffusion time (width of heat kernel), batched or scalar
        n_iter: Number of Metropolis 'hits'

    Returns:
        xs: Unwrapped eigenangles, NumPy array [batch_size, Nc]; rows sum to zero.
        A: Hermitian traceless algebra elements V diag(xs) V^dagger,
           complex NumPy array [batch_size, Nc, Nc].
    """
    # promote t to batch, if needed
    t = np.array(t)
    if len(t.shape) == 0:
        t = t*np.ones((batch_size,))
    assert t.shape == (batch_size,)

    def propose():
        """Samples eigenangles from a centered Gaussian on the hyperplane."""
        xs = np.random.normal(size=(batch_size, Nc))
        xs -= np.mean(xs, axis=-1, keepdims=True)  # sum to zero
        xs *= np.sqrt(t)[...,None]  # match heat kernel width
        return xs

    def meas(x):
        # the last angle is implied by the sum-to-zero constraint
        return eval_hk_sun_unwrapped(x[..., :-1], t, meas_only=True)

    # sample eigenangles
    xs = propose()
    for _ in range(n_iter):
        xps = propose()
        p = meas(xps) / meas(xs)  # acceptance ratio new / old
        u = np.random.random(size=p.shape)
        accept = u < p
        xs[accept] = xps[accept]

    # sample eigenvectors and recompose A = V diag(xs) V^dagger
    V, _ = np.linalg.qr(np.random.randn(batch_size, Nc, Nc) + 1j * np.random.randn(batch_size, Nc, Nc))
    D = np.identity(xs.shape[-1]) * xs[...,None]  # embed diagonal
    A = V @ D @ adjoint(V)

    return xs, A

**OV:** I am wondering if the above sampling of eigenvectors (and eigen-recomposition by conjugation thereafter) is valid. My initial question is: can we treat the eigenvectors as Haar-random at each slice in diffusion time $t$? It certainly wouldn't make sense for the eigenvectors $V$ to remain constant over diffusion time... but does random sampling of $V$ at each $t$ suffice?

Recall that the heat kernel at time $t$ on ${\rm SU}(N)$ satisfies $$\partial_t K_t(U) = \Delta K_t(U)$$ subject to initial condition $K_0(U) = \delta(U)$, where $\Delta$ is the Laplace-Beltrami operator on the ${\rm SU}(N)$ manifold. 

We may always write an element $U \in {\rm SU}(N)$ as $U = V D V^\dagger$, with $D = {\rm diag}\left(e^{i\theta_1}, \cdots, e^{i\theta_N} \right)$, and WLOG set $\theta_N = -\sum_{j=1}^{N-1}\theta_j$. The Haar measure on ${\rm SU}(N)$ should then factorize into eigenangles (in an $N-1$-dimensional space) and eigenvectors $V$ uniformly distributed over ${\rm SU}(N) / \mathbb{T}$, where $\mathbb{T} \leq {\rm SU}(N)$ is the maximal torus in the group.

Note: the heat kernel is *conjugation-invariant*. One can show that $\tilde{K}_t(U) := K_t(V U V^\dagger)$ satisfies the same PDE and initial condition as $K_t(U)$: the Laplace-Beltrami operator commutes with conjugation, and $\delta(V U V^\dagger) = \delta(U)$. By uniqueness, conjugation-invariance follows, i.e. $\forall V \in {\rm SU}(N)$, $$K_t(U) = K_t(V U V^\dagger).$$ The heat kernel therefore depends only on the spectrum. The distribution of the eigenvalues $D$ evolves with $t$ on the torus, while the distribution of $V$ is Haar and independent of $t$. Simulating a diffusion trajectory will cause both eigenvalues AND eigenvectors to change. However, sampling a single point along the trajectory, i.e. sampling the marginal distribution of $U_t$, means the eigenvectors follow a stationary Haar distribution. We are not sampling a full path, just a single point (marginal sample), so I think this is fine (?) One further detail: the $Q$ factor from the QR decomposition of a complex Gaussian matrix is Haar only up to a diagonal phase matrix $\Lambda$ on the right. Since diagonal matrices commute, $Q \Lambda D \Lambda^\dagger Q^\dagger = Q D Q^\dagger$, so this phase does not affect the sampled $A$.

In [None]:
def _test_sample_hk_su3():
    """
    Visual check of `sample_hk_sun` for SU(3).

    For each time t, the first figure shows a 2D histogram of directly sampled
    eigenangles, folded into the canonical hexagon and projected onto the
    (alpha, beta) plane. The second figure shows contours of the analytic
    kernel `eval_hk_sun` at the same t.
    """
    batch_size = 16000
    Nc = 3
    ts = [0.4, 0.8, 1.2, 1.6, 2.0, 2.4]

    bins = np.linspace(-5.5, 5.5, num=51)
    fig, axes = plt.subplots(2, 3, figsize=(6,4))
    fig2, axes2 = plt.subplots(2, 3, figsize=(6,4))
    fig.suptitle('SU(3) heat kernel (direct sampling)')
    fig2.suptitle('SU(3) heat kernel (analytical)')

    # time-independent analytic grid, folded into the fundamental hexagon
    a, b = np.meshgrid(bins, bins, indexing='ij')
    th = np.einsum('xa,a...->x...', _su3_Ainv, np.stack((a,b), axis=0))
    th = grab(better_canonicalize_su3(torch.tensor(th).moveaxis(0, -1)))
    hex_outline = _su3_A @ np.transpose(_su3_hex)

    for t, ax, ax2 in zip(ts, axes.flatten(), axes2.flatten()):
        # samples
        xs, _ = sample_hk_sun(batch_size, Nc=Nc, t=t, n_iter=30)
        xs = grab(better_canonicalize_su3(torch.tensor(xs)))
        ab = np.einsum('ax,...x->...a', _su3_A, xs)
        ax.hist2d(ab[...,0], ab[...,1], bins=bins)

        # analytical
        v = eval_hk_sun(th[...,:2], t)
        ax2.contourf(a, b, v)

        for panel in (ax, ax2):
            panel.plot(*hex_outline, color='w', linestyle='--')
            panel.set_aspect(1.0)
            # position the label in this panel's own axes coordinates
            panel.text(0.05, 0.95, f't={t:.2f}', fontsize=7, color='w',
                       transform=panel.transAxes, ha='left', va='top')
    plt.show()

_test_sample_hk_su3()

# SU(N) diffusion

We use the standard definition of SU(N) Brownian motion as defined in stochastic quantization to run the forward process. This will use the equivalent of the variance expanding scheme, so that for sufficiently large noise coefficient $\sigma$, the final distribution is very nearly uniform. One should either use asymptotic times, or tricks of rescaling time evolution to make this exact -- let's not worry about these details for now.

Conventionally, $t = 0$ will be the target distribution and $t = 1$ will be the uniform pure-noise distribution.

In [None]:
import scipy

### Need a matrix log func since PyTorch doesn't have this natively

# scipy implementation - quite slow...
def log_sun_scipy(V):
    """
    Compute matrix log for SU(N) matrices using SciPy (`scipy.linalg.logm`).

    Slow: the batch is processed one matrix at a time through NumPy on the CPU.

    Args:
        V: Batch of unitary matrices [batch, N, N].

    Returns:
        complex128 tensor [batch, N, N] holding logm(V[b]) for each b.
    """
    # `import scipy` alone does not guarantee the `linalg` submodule is loaded
    import scipy.linalg
    V_np = grab(V)
    A_np = np.array([scipy.linalg.logm(v) for v in V_np])
    return torch.tensor(A_np, dtype=torch.cdouble)

def log_sun_torch(V):
    """
    Compute log(V) for V in SU(N) using eigendecomposition.

    The principal log of each eigenvalue is taken and the matrix is rebuilt.
    The result is then projected onto the Lie algebra: first its
    anti-Hermitian part (the log of a unitary is anti-Hermitian), then the
    trace is removed from the diagonal.

    NOTE: if eigenvalues sit on the branch cut (phase +-pi), the principal logs
    need not sum to zero. Removing the trace then makes the result differ from
    an exact log of V.

    Args:
        V: Input batch of unitary matrices [..., N, N]
    Returns:
        Anti-Hermitian traceless matrix A with exp(A) ~= V, same shape as V.
    """
    eigvals, eigvecs = torch.linalg.eig(V)
    log_eigvals = torch.log(eigvals)  # principal branch
    log_diag = torch.diag_embed(log_eigvals)

    Q = eigvecs
    Qinv = torch.linalg.inv(Q)
    A = Q @ log_diag @ Qinv
    A = (A - adjoint(A)) / 2  # anti-Hermitian part
    Nc = A.shape[-1]
    eye = torch.eye(Nc, dtype=A.dtype, device=A.device)
    A = A - (trace(A) / Nc)[..., None, None] * eye  # traceless (diagonal only)
    return A



def sun_fwd_diffusion(U, t, *, sigma):
    """
    Directly sample the forward diffusion process using the heat kernel,
    starting from initial data `U`.

    Steps:
        1.) Sample a Hermitian traceless algebra element A from the heat
            kernel at time s = sigma(t) (or s = sigma if not callable)
        2.) Exponentiate, V = exp(iA)
        3.) Apply the perturbation to the input, Up = V @ U

    Args:
        U: Initial SU(N) matrices at t=0, shape [batch_size, Nc, Nc]
        t: Diffusion time in [0, 1]
        sigma: Callable t -> s, or a fixed s. The value is passed as the
            heat-kernel time `t` of `sample_hk_sun`, which sets the Gaussian
            variance of the eigenangles.

    Returns:
        A: Hermitian traceless algebra element with V = exp(iA)
        Up: New SU(N) sample at time `t`
    """
    batch_size, Nc, _ = U.shape
    s = sigma(t) if callable(sigma) else sigma
    _, A_np = sample_hk_sun(batch_size, Nc=Nc, t=s, n_iter=3)
    A = torch.tensor(A_np, dtype=torch.cdouble)
    V = sun_matrix_exp(1j*A)
    return A, V @ U


def _test_sun_fwd_diffusion():
    """
    Checks that `sun_fwd_diffusion` applied to the identity returns a
    Hermitian, traceless A and a unitary Up with det(Up) = 1.
    """
    # sampling goes through NumPy (sample_hk_sun), so seed both RNGs
    torch.manual_seed(42)
    np.random.seed(42)

    batch_size = 2 ** 4
    Nc = 3
    t = 0.5

    sigma = lambda t: 5**t

    U0 = torch.eye(Nc, dtype=torch.cdouble).repeat(batch_size, 1, 1)
    A, Up = sun_fwd_diffusion(U0, t, sigma=sigma)

    # Check A is Hermitian
    assert torch.allclose(A, adjoint(A)), \
        '[FAILED: A is NOT Hermitian]'

    # Check A is traceless
    trA = trace(A)
    assert torch.allclose(trA, torch.zeros_like(trA)), \
        '[FAILED: A is NOT traceless]'

    # Check Up is unitary
    I = torch.eye(Nc, dtype=torch.cdouble).repeat(batch_size, 1, 1)
    assert torch.allclose(Up @ adjoint(Up), I), \
        '[FAILED: Up is NOT unitary]'

    # Check det(Up) = 1
    dets = torch.linalg.det(Up)
    assert torch.allclose(dets, torch.ones_like(dets)), \
        '[FAILED: det(Up) =/= 1]'

    print('[PASSED forward diffusion test]')


if __name__ == '__main__': _test_sun_fwd_diffusion()


In [None]:
class Score(torch.nn.Module):
    """
    Small MLP mapping (U, t) to coefficients of the su(Nc) generators.

    Input features are the flattened real and imaginary parts of U plus t
    (2*Nc*Nc + 1 values). The output has one entry per generator.
    """
    def __init__(self, Nc):
        super().__init__()
        n_in = 2*Nc*Nc + 1          # Re(U), Im(U), t
        n_hidden = 32
        n_out = len(sun_gens(Nc))   # Nc^2 - 1 generators

        # (U,t) -> A
        self.net = torch.nn.Sequential(
            torch.nn.Linear(n_in, n_hidden),
            torch.nn.SiLU(),
            torch.nn.Linear(n_hidden, n_hidden),
            torch.nn.SiLU(),
            torch.nn.Linear(n_hidden, n_out),
        )

    def forward(self, U, t):
        """
        Outputs a direction in the Lie algebra, i.e. the coefficients of a
        tangent vector at the identity.

        Args:
            U: Batch of SU(N) matrices [batch, Nc, Nc]
            t: Diffusion times, 1D tensor [batch]

        Returns:
            Generator coefficients [batch, Nc^2 - 1]
        """
        real_feats = U.real.flatten(1)
        imag_feats = U.imag.flatten(1)
        assert len(t.shape) == 1, 't should be scalar batch'
        features = torch.cat([real_feats, imag_feats, t[:, None]], dim=1)
        return self.net(features)
    

def _test_score():
    """
    Smoke test for `Score`. The output must have one coefficient per su(Nc)
    generator, and exponentiating the recomposed algebra element must give a
    unitary matrix.
    """
    batch_size = 1
    Nc = 3
    shape = (batch_size, Nc, Nc)
    A = sample_sun_gaussian(shape)
    U = sun_matrix_exp(1j * A)
    t = torch.randn((batch_size,))

    score_net = Score(Nc)
    out = score_net(U, t)
    print('Input U shape:', U.shape)
    print('Output coeffs shape:', out.shape)

    # Check that score maps to vector of algebra coeffs
    assert out.shape == (batch_size, Nc**2 - 1), \
        '[FAILED: score network does not produce correct shape]'

    # Recompose with the generators for this Nc to produce an algebra element
    Ap = embed_sun_algebra(out, Nc)
    Up = sun_matrix_exp(1j * Ap)
    I = torch.eye(Nc).repeat(batch_size, 1, 1).cdouble()
    assert torch.allclose(Up @ adjoint(Up), I), \
        '[FAILED: score net does not produce a unitary matrix]'

    print('[PASSED test simple score net]')

if __name__ == '__main__': _test_score()

Let's try training a simple example

In [None]:
# NOTE(gkanwar): This is identical to embed_sun_algebra() in utils above :)
def combine_sun_basis(coeffs, basis):
    """
    Creates an su(Nc) algebra element as the linear combination of the
    coefficients (`coeffs`) with the generators (`basis`).

    The number of coefficients should match the number of generators, Nc^2 - 1.

    Args:
        coeffs: Generator coefficients [..., Ng]. Any leading batch shape is
            allowed; typically [batch_size, Ng].
        basis: Tensor of Lie algebra generators/basis vectors [Ng, Nc, Nc]
            (complex double, since coeffs are cast to complex double).

    Returns:
        Algebra matrices [..., Nc, Nc].

    Raises:
        ValueError: if coeffs.shape[-1] != len(basis).
    """
    if coeffs.shape[-1] != len(basis):
        raise ValueError('Number of coefficients must match number of generators')
    return torch.einsum('...g,gij->...ij', coeffs.cdouble(), basis)


def get_batch(batch_size, Nc, t, *, sigma):
    """
    Samples a batch of su(Nc) algebra elements A at diffusion time(s) t.

    The caller forms the group elements as U = exp(iA). The heat-kernel
    time handed to `sample_hk_sun` is sigma(t)**2, the same convention as
    the other `sample_hk_sun` / `_eval_hk_su2` calls in this notebook.

    Args:
        batch_size: Number of samples.
        Nc: Matrix dimension of SU(Nc).
        t: Diffusion time(s), passed through `sigma`.
        sigma: Noise schedule, t -> sigma(t).

    Returns:
        Complex128 tensor of algebra elements A.
    """
    # Heat-kernel time is the variance sigma^2, not the width sigma.
    _, A_np = sample_hk_sun(batch_size, Nc, sigma(t)**2)
    return torch.tensor(A_np, dtype=torch.cdouble)


def target_score(A, t, *, sigma):
    """
    Regression target for the score network: A / (2 * sigma(t)**2).

    A is the algebra element with U = exp(iA), so this is -i log(U) scaled
    by 1 / (2 sigma^2), broadcast over any leading batch dimensions.

    Args:
        A: Algebra elements; [..., Nc, Nc].
        t: Diffusion time(s): a tensor, array or Python scalar.
        sigma: Noise schedule, t -> sigma(t).

    Returns:
        Tensor with the same shape as A.
    """
    # TODO(gkanwar): factor of 2 here?
    # as_tensor lets sigma(t) be a Python/NumPy scalar as well as a tensor.
    sig = torch.as_tensor(sigma(t), dtype=A.real.dtype)
    return A / (2 * sig[..., None, None]**2)


def score_matching_loss(A, U, score_net, t, *, sigma):
    """
    Mean squared Frobenius distance between predicted and target scores.

    The network output (generator coefficients) is turned into a matrix with
    `_su2_gens` and compared to `target_score(A, t, sigma=sigma)`; the
    per-sample squared norms are averaged over the batch.
    """
    predicted = combine_sun_basis(score_net(U, t), _su2_gens)
    residual = predicted - target_score(A, t, sigma=sigma)
    per_sample = (residual.conj() * residual).sum(dim=(1, 2)).real
    return per_sample.mean()

In [None]:
# Hyperparams
Nc = 2
batch_size = 2048
epochs = 5000
lr = 3e-4
# Noise schedule sigma(t) = 3 sqrt(t), so sigma(t)**2 = 9 t is linear in t.
sigma = lambda t: 3.0*t**0.5

# Init model and optimizer
score_net = Score(Nc)
optimizer = torch.optim.Adam(score_net.parameters(), lr=lr)

# Training loop: each "epoch" is a single optimizer step on a freshly sampled
# batch with freshly sampled diffusion times.
loss_hist = []
for epoch in range(epochs):
    # One random t per example; the 0.001 lower bound avoids t = 0, where
    # target_score would divide by sigma(0)**2 = 0.
    t = np.random.uniform(0.001, 1.0, size=batch_size)  # sample random t
    A = get_batch(batch_size, Nc, t, sigma=sigma)
    # in general U = V @ U1, in terms of data U1
    # here we train to U1 = Id, i.e. sampled from p1 = delta(U1) at t=1
    # (so U = exp(iA) is just the sampled heat-kernel noise).
    U = sun_matrix_exp(1j*A)

    optimizer.zero_grad()
    loss = score_matching_loss(A, U, score_net, torch.tensor(t), sigma=sigma)
    loss.backward()
    optimizer.step()
    loss_hist.append(grab(loss))

    if epoch % 250 == 0:
        print(f"[Epoch: {epoch}/{epochs}] Loss = {loss.item():.6f}")


In [None]:
# Training curve: one loss value per optimizer step.
fig, ax = plt.subplots()
ax.plot(loss_hist)
ax.set(xlabel='train step', ylabel='loss')
plt.show()

In [None]:
@torch.no_grad()
def sample_reverse(batch_size, score_net, steps=100, *, nc=None, sigma_fn=None,
                   gens=None):
    """
    Generates SU(Nc) samples by integrating the learned reverse diffusion
    from t = 1 down to t = 0 with an Euler scheme on the group manifold.

    Args:
        batch_size: Number of samples.
        score_net: Model mapping (U, t) -> generator coefficients.
        steps: Number of Euler steps; dt = 1 / steps.
        nc: Matrix dimension; defaults to the notebook-global `Nc`.
        sigma_fn: Noise schedule t -> sigma(t); defaults to the global `sigma`.
        gens: Generators used to turn coefficients into algebra matrices;
            defaults to `_su2_gens`, the generators used in the training loss.

    Returns:
        The final batch of group elements U (presumably [batch_size, Nc, Nc]).
    """
    # Fall back to the notebook globals used during training.
    if nc is None:
        nc = Nc
    if sigma_fn is None:
        sigma_fn = sigma
    if gens is None:
        gens = _su2_gens

    dt = 1.0 / steps
    tvals = torch.linspace(1.0, 0.0, steps+1)
    # TODO: we should probably actually sample from Haar uniform here. The assumption
    # is that we have diffused enough to forget the initial conditions... if we haven't
    # then in any case the correct dist would be U_0 = V U_1 , V ~ K(V) in terms of initial
    # data U_1 ~ p_1(U_1)
    _, A_np = sample_hk_sun(batch_size, nc, t=sigma_fn(1.0)**2)  # sample from prior p_1
    A = torch.tensor(A_np, dtype=torch.cdouble)
    U = sun_matrix_exp(1j * A)

    sqrt_2dt = torch.tensor(2.0 * dt)**0.5  # loop-invariant part of the noise scale

    # Step backwards in time; the final grid point t = 0 is never evaluated.
    for t in tvals[:-1]:
        noise_scale = sigma_fn(t) * sqrt_2dt

        score_coeffs = score_net(U, t * torch.ones((batch_size,)))
        score = combine_sun_basis(score_coeffs, gens)

        noise_coeffs = torch.randn_like(score_coeffs)
        noise = combine_sun_basis(noise_coeffs, gens)  # noise in Lie algebra

        # Euler step in su(N), applied by left multiplication on the group.
        delta = -dt * score + noise_scale * noise
        U = sun_matrix_exp(1j * delta) @ U  # random walk on group mfld

    return U
    

U_new = sample_reverse(batch_size, score_net, steps=500)
# Check that these are in the group: unitary ...
assert torch.allclose(U_new @ adjoint(U_new), torch.eye(Nc).repeat(batch_size, 1, 1).cdouble()), \
    'U_new NOT unitary'
# ... and with unit determinant (checked on the generated samples U_new).
dets = torch.linalg.det(U_new)
assert torch.allclose(dets, torch.ones_like(dets)), \
    'det(U_new) =/= 1'
print(f'Successfully generated new SU({Nc}) matrices')

In [None]:
def validate_su2(U_sampled, t_eval=1.0):
    """
    Compare sampled SU(2) eigenangles against the analytic heat kernel.

    Histograms |theta| of the first eigenangle of every matrix in
    `U_sampled` on [0, pi] and overlays `_eval_hk_su2` evaluated at heat
    kernel time sigma(t_eval)**2, normalized to unit area on the same bins.
    """
    # SU(2) has a single independent eigenangle, so column 0 suffices.
    abs_theta = np.abs(mat_angle(U_sampled)[0][:, 0])

    edges = np.linspace(0, np.pi, 100)
    centers = 0.5 * (edges[1:] + edges[:-1])
    width = edges[1] - edges[0]
    sampled_density, _ = np.histogram(abs_theta, bins=edges, density=True)

    # Analytic HK is unnormalized; rescale with a Riemann sum over the bins.
    analytic = _eval_hk_su2(centers, t=sigma(t_eval)**2)
    analytic /= np.sum(analytic * width)

    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    ax.plot(centers, analytic, label='Analytic HK', color='blue', linewidth=2)
    ax.plot(centers, sampled_density, label='Sampled', color='red', linewidth=2)
    ax.set_title(r'SU(2) Eigenangle Distribution')
    ax.set_xlabel(r'$|\theta|$')
    ax.set_ylabel('Density')
    ax.set_xticks([0, np.pi/2, np.pi])
    ax.set_xticklabels([r'$0$', r'$\pi/2$', r'$\pi$'])
    ax.legend()
    fig.tight_layout()
    plt.show()

validate_su2(U_new, t_eval=0.1)

In [None]:
# One histogram per eigenangle of the diffusion samples; for SU(2) the two
# angles satisfy th1 = -th2, so the panels should mirror each other.
fig, axes = plt.subplots(1, 2, figsize=(12, 4), sharey=True)
axes[0].set_ylabel('Density')
thetas_new, _, _ = mat_angle(U_new)
for idx, panel in enumerate(axes):
    panel.hist(thetas_new[:, idx], bins=50, density=True, label='Diffusion')
    panel.set_xlabel(rf'$\theta_{idx+1}$')
fig.tight_layout()
# should see mirror symmetry between the two histograms since th1 = -th2

In [None]:
batch_size = U_new.shape[0]
Nc = 2
t_eval_prior = 1.0
t_eval_posterior = 0.1

# Prior samples. Build U = exp(iA) from the sampled algebra element A (the same
# pattern sample_reverse uses for its prior). The returned eigenangles are not
# su(2) generator coefficients, so combining them with `_su2_gens` would not
# reproduce the prior matrices.
eigs_prior, A_prior = sample_hk_sun(batch_size, Nc, t=sigma(t_eval_prior)**2)
print('eigs prior shape:', np.shape(eigs_prior))
U_prior = sun_matrix_exp(1j * torch.tensor(A_prior, dtype=torch.cdouble))

# |theta| of the (single independent) SU(2) eigenangle for both sample sets.
theta_new, _, _ = mat_angle(U_new)
theta_prior, _, _ = mat_angle(U_prior)
theta_new = np.abs(theta_new[:, 0])
theta_prior = np.abs(theta_prior[:, 0])


bins = np.linspace(0, np.pi, 50)
xs = (bins[1:] + bins[:-1]) / 2
hist_new, _ = np.histogram(theta_new, bins=bins, density=True)
hist_prior, _ = np.histogram(theta_prior, bins=bins, density=True)

# Analytic heat kernels at time sigma(t)**2, normalized on the same bins.
hk_post = _eval_hk_su2(xs, t=sigma(t_eval_posterior)**2)
hk_prior = _eval_hk_su2(xs, t=sigma(t_eval_prior)**2)

dx = bins[1] - bins[0]
hk_post /= np.sum(hk_post * dx)
hk_prior /= np.sum(hk_prior * dx)

fig, ax = plt.subplots(1, 1, figsize=(6, 4))
ax.plot(xs, hist_new, label='Posterior (Diffusion)', color='blue', alpha=0.6)
ax.plot(xs, hist_prior, label='Prior (Noisy)', color='orange', alpha=0.6)
ax.plot(xs, hk_post, label='Posterior Analytic HK', color='darkblue', linewidth=2)
ax.plot(xs, hk_prior, label='Prior Analytic HK', color='darkorange', linewidth=2)

ax.set_xlabel(r'$|\theta|$')
ax.set_ylabel('Density')
ax.set_title(r'SU(2) Eigenangle Distributions')
ax.set_xticks([0, np.pi/2, np.pi])
ax.set_xticklabels([r'$0$', r'$\pi/2$', r'$\pi$'])
ax.legend()
fig.tight_layout()
plt.show()