# Setup

In [None]:
import itertools
import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm.auto as tqdm
%matplotlib widget

In [None]:
# Make float64 the default floating-point dtype for tensors created from here on.
torch.set_default_dtype(torch.float64)

In [None]:
def grab(x):
    """Convert a torch tensor to a NumPy array (detached from autograd, moved to CPU)."""
    host_tensor = x.detach().cpu()
    return host_tensor.numpy()
def wrap(theta):
    """Map angle(s) ``theta`` into the half-open interval [-pi, pi)."""
    full_turn = 2*np.pi
    shifted = theta + np.pi
    return shifted % full_turn - np.pi

In [None]:
def adjoint(U):
    """Return the conjugate transpose of ``U`` taken over its last two axes."""
    conjugated = U.conj()
    return conjugated.swapaxes(-2, -1)
def trace(U):
    """Batched matrix trace: sum the diagonal of the last two dimensions of ``U``."""
    diag = torch.diagonal(U, dim1=-2, dim2=-1)
    return diag.sum(-1)
def mat_angle(U):
    """Diagonalize ``U`` and return the phases of its eigenvalues.

    Returns:
        A tuple ``(th, V, Vinv)``: ``th`` is the complex argument of each
        eigenvalue returned by ``torch.linalg.eig``, ``V`` holds the
        corresponding eigenvectors and ``Vinv`` is the matrix inverse of ``V``.
    """
    eigvals, eigvecs = torch.linalg.eig(U)
    return torch.angle(eigvals), eigvecs, torch.linalg.inv(eigvecs)
def proj_trless(A):
    """Project ``A`` onto traceless matrices over its last two dimensions.

    Subtracts ``trace(A) / Nc`` times the identity from every matrix in the
    batch, where ``Nc`` is the size of the last dimension.

    Returns:
        A new tensor; the input ``A`` is left untouched, so the function is
        safe to call on tensors that are reused by the caller or that require
        gradients.
    """
    Nc = A.shape[-1]
    # build the identity with A's dtype/device so the subtraction never mixes
    # devices or relies on the global default dtype
    identity = torch.eye(Nc, dtype=A.dtype, device=A.device)
    return A - identity * (trace(A)[..., None, None] / Nc)
def proj_sun_algebra(A):
    """Project ``A`` onto traceless Hermitian matrices.

    Takes the Hermitian part ``(A + A^dagger)/2`` and then removes its trace
    with ``proj_trless``.
    """
    hermitian_part = (A + adjoint(A))/2
    return proj_trless(hermitian_part)
def sample_sun_gaussian(shape):
    """Draw random traceless Hermitian matrices with the given ``shape``.

    A complex matrix with independent standard-normal real and imaginary
    parts is projected with ``proj_sun_algebra``. The last two entries of
    ``shape`` must be equal (square matrices).
    """
    rows, cols = shape[-2:]
    assert rows == cols
    real_part = torch.randn(shape)
    imag_part = torch.randn(shape)
    return proj_sun_algebra(real_part + 1j*imag_part)

In [None]:
def canonicalize_su2(thW):
    """Canonicalize SU(2) eigen-angles in place and return the same tensor.

    The angle at index 0 of the last axis is wrapped and folded into
    [0, pi]; the angle at index 1 is then set to its negative.
    """
    # fold the leading angle into [0, pi]
    folded = wrap(thW[..., 0]).abs()
    thW[..., 0] = folded
    # the second angle mirrors the (already updated) first one
    thW[..., 1] = thW[..., 0].neg()
    return thW
def canonicalize_su3(thW):
    """Map SU(3) eigen-angle triples to a canonical representative.

    The last axis of ``thW`` holds the three angles. The last angle is first
    overwritten so that every triple sums to zero. The triples are then
    expressed in coordinates ``(a, b, c)`` and shifted by integers, which
    changes the angles only by integer combinations of the columns of ``U``
    (multiples of 2*pi with zero total). Per the inline comments this wraps
    the angles into the canonical hexagon centered on the identity.

    NOTE: the first statement modifies the caller's ``thW`` in place; the
    value returned is a separate tensor with the same shape as ``thW``.
    NOTE(review): ``Uinv @ v`` presumably requires ``thW`` to use the default
    floating dtype -- confirm against callers.
    """
    # project onto hyperplane defined by sum_i theta_i = 0
    thW[...,-1] -= torch.sum(thW, dim=-1)
    v = thW.reshape(-1, 3)
    
    # wrap into canonical hexagon centered on I
    # map from (a,b,c) into v = (th1, th2, th3)
    # (every row of U sums to zero, so adding the same constant to a, b and c
    # does not change U @ (a,b,c))
    U = 2*np.pi * torch.tensor([
        [1, 0, -1],
        [0, -1, 1],
        [-1, 1, 0],
    ])
    # map from v into (a,b,c)
    # (U is singular; U @ Uinv is the projector onto the sum-zero plane, not
    # the identity)
    Uinv = torch.tensor([
        [1, 0, -1],
        [0, -1, 1],
        [-1, 1, 0]
    ]) / (6*np.pi)
    kappa = Uinv @ torch.transpose(v, 0, 1) # ij,jb -> ib
    # a, b, c are row views of kappa and are updated in place below; the
    # statement order matters because later shifts read earlier results
    a, b, c = kappa[0], kappa[1], kappa[2]
    # common shift of all three coordinates (no effect on the angles)
    k = (b+c)/2
    a -= k
    b -= k
    c -= k
    # integer shifts: each changes the angles by a lattice vector only
    a -= torch.round(a)
    k = torch.round(b)
    b -= k
    c += k
    b -= torch.round(b - (a+c)/2)
    # second pass: re-center, then integer shifts of a and c
    k = (b+c)/2
    a -= k
    b -= k
    c -= k
    a -= torch.round(a)
    c -= torch.round(c - (a+b)/2)
    # map back to angles and restore the input's shape
    return torch.transpose(U @ torch.stack([a,b,c], dim=0), 0, 1).reshape(thW.shape)

In [None]:
def canonicalize_sun(thW):
    """Canonicalize eigen-angles, dispatching on the size of the last axis.

    Only sizes 2 and 3 are handled.

    Raises:
        ValueError: for any other size of the last axis.
    """
    Nc = thW.shape[-1]
    if Nc not in (2, 3):
        raise ValueError(f'not supported {Nc=}')
    if Nc == 2:
        return canonicalize_su2(thW)
    # NOTE(review): `better_canonicalize_su3` is not defined in the visible
    # part of this file (only `canonicalize_su3` is) -- confirm it exists.
    return better_canonicalize_su3(thW)

In [None]:
# linear map (theta1, theta2, theta3) -> (alpha, beta) plane, one row per
# output coordinate
_su3_A = np.stack([
    np.array([0, 1, -1]) / np.sqrt(2),
    np.array([2, -1, -1]) / np.sqrt(6),
])
# right inverse of _su3_A, assembled column by column
_su3_Ainv = np.column_stack([
    [0, 1 / np.sqrt(2), -1 / np.sqrt(2)],
    [np.sqrt(2/3), -1 / np.sqrt(6), -1 / np.sqrt(6)],
])
assert np.allclose(_su3_A @ _su3_Ainv, np.eye(2))


# closed hexagon outline (first vertex repeated at the end), used for nice
# plots in the (alpha, beta) plane
_su3_hex = np.array([
    [1, 1, -2],
    [2, -1, -1],
    [1, -2, 1],
    [-1, -1, 2],
    [-2, 1, 1],
    [-1, 2, -1],
    [1, 1, -2],
]) * (2*np.pi/3)

# Score and Heat kernel

Evolving samples $x_t \sim p$ with a drift equal to half the **score** $\nabla \log{p}(x)$ of any distribution $p$ should exactly cancel diffusion (with coefficient $1$), i.e. leave $p$ stationary, as follows from the corresponding Fokker-Planck equation:
$$
dx_t = \tfrac{1}{2} \nabla \log{p}(x_t) \, dt + dW_t
\quad \iff \quad
\frac{\partial p}{\partial t} = -\tfrac{1}{2} \nabla \cdot (p \, \nabla \log{p}) + \tfrac{1}{2} \Delta p = -\tfrac{1}{2} \Delta p + \tfrac{1}{2} \Delta p = 0.
$$
This allows us to check numerically that we have evaluated the score correctly for the heat kernel, which we need to evaluate and train score-based diffusion models.

Alternatively, we can check that the score is correctly evaluated using auto-differentiation.

We check both for the **Euclidean case** and the **SU(N) case** below.

In [None]:
def eucl_log_hk(x, *, sigma):
    """Unnormalized log of the Euclidean heat kernel (Gaussian) of width ``sigma``.

    The squared norm is taken over the last axis of ``x``.
    """
    sq_norm = (x**2).sum(-1)
    return -sq_norm/(2*sigma**2)

def eucl_score(x, *, sigma):
    """Score (gradient of the log-density) of the Euclidean heat kernel of width ``sigma``."""
    variance = sigma**2
    return -x/variance

In [None]:
def _test_eucl_hk():
    """Compare simulated Brownian motion from the origin with the Euclidean heat kernel.

    Samples are diffused up to time ``t`` with Gaussian increments and their
    histogram is plotted on top of the grid-normalized analytical density.
    """
    # simulate the forward diffusion
    batch_size = 16000
    t = 2.0
    steps = 100
    dt = t/steps
    sigma = np.sqrt(t)
    noise_scale = np.sqrt(dt)
    x = torch.zeros((batch_size, 1))
    for _ in tqdm.tqdm(range(steps)):
        x += noise_scale * torch.randn_like(x)
    samples = grab(x)

    # analytical density on a grid, normalized with a Riemann sum
    mesh = torch.linspace(-5, 5, steps=51)[...,None]
    density = grab(torch.exp(eucl_log_hk(mesh, sigma=sigma)))
    mesh = grab(mesh)
    spacing = mesh[1]-mesh[0]
    density /= np.sum(density) * spacing

    # histogram of the samples against the analytical curve
    fig, ax = plt.subplots(1, 1)
    edges = np.linspace(-5, 5, num=51)
    ax.hist(samples, bins=edges, density=True, label='data')
    ax.plot(mesh, density, color='k', label='analytical')
    ax.legend()
    ax.set_title(f'Eucl heat kernel at {t=}')
    plt.show()
_test_eucl_hk()

In [None]:
def _test_eucl_hk_evolution():
    """Plot Brownian-motion histograms at several times against the heat kernel.

    Samples start at the origin and receive Gaussian increments of variance
    ``dt``. One panel is drawn per entry of ``times_to_plot`` (assumed to be
    sorted in increasing order), showing the sample histogram and the
    grid-normalized analytical density at the elapsed time.
    """
    batch_size = 16000
    steps = 100
    t_final = 2.0
    dt = t_final / steps
    times_to_plot = [0.2, 0.5, 1.0, 1.5, 2.0]  # intermediate times
    # Number of completed steps at which each panel is drawn. Comparing
    # integer step counts (instead of a running float sum of dt) guarantees
    # that round-off can neither delay a snapshot nor skip the final one.
    snapshot_steps = [max(1, int(round(t_plot / dt))) for t_plot in times_to_plot]

    # prepare subplot
    fig, axes = plt.subplots(1, len(times_to_plot), figsize=(5 * len(times_to_plot), 4))
    axes = np.atleast_1d(axes)  # a single panel is returned as a bare Axes

    x = torch.zeros((batch_size, 1))  # initial data at t=0
    x_mesh = torch.linspace(-5, 5, 51)[..., None]
    # loop-invariant plotting inputs, kept as NumPy so the normalization
    # below never mixes torch and NumPy operands
    x_mesh_np = grab(x_mesh)
    dx = x_mesh_np[1] - x_mesh_np[0]
    bins = np.linspace(-5, 5, num=51)

    panel = 0
    for step in tqdm.tqdm(range(steps)):
        x += np.sqrt(dt) * torch.randn_like(x)
        n_done = step + 1
        while panel < len(times_to_plot) and n_done >= snapshot_steps[panel]:
            # width of the kernel at the time actually simulated so far
            sigma = np.sqrt(n_done * dt)

            # analytical
            hk = grab(torch.exp(eucl_log_hk(x_mesh, sigma=sigma)))
            hk /= np.sum(hk) * dx

            # plot (the histogram is computed immediately, so sharing memory
            # with x is harmless)
            ax = axes[panel]
            ax.hist(grab(x), bins=bins, density=True, label='data')
            ax.plot(x_mesh_np, hk, color='k', label='analytical')
            ax.set_title(f't = {times_to_plot[panel]:.2f}')
            ax.legend()

            panel += 1

    plt.tight_layout()
    plt.show()

_test_eucl_hk_evolution()

In [None]:
def _test_eucl_score():
    """Check that score-driven Langevin dynamics keeps the Euclidean heat kernel stationary.

    Samples are first diffused freely up to time ``t``; afterwards each step
    adds noise plus half the score of the width-``sigma`` kernel. Histograms
    before and after the second phase are plotted with the analytical density.
    """
    batch_size = 16000
    t = 2.0
    steps = 1000
    dt = t/steps
    sigma = np.sqrt(t)
    noise_scale = np.sqrt(dt)

    # phase 1: free diffusion from the origin
    x = torch.zeros((batch_size, 1))
    for _ in tqdm.tqdm(range(steps)):
        x += noise_scale * torch.randn_like(x)
    # copy, because x keeps being updated in place below
    x_pre = np.copy(grab(x))

    # phase 2: noise plus half-score drift
    for _ in tqdm.tqdm(range(steps)):
        drift = eucl_score(x, sigma=sigma)
        x += noise_scale * torch.randn_like(x) + 0.5*dt * drift
    x_post = grab(x)

    # analytical density on a grid, normalized with a Riemann sum
    mesh = torch.linspace(-5, 5, steps=51)[...,None]
    density = grab(torch.exp(eucl_log_hk(mesh, sigma=sigma)))
    mesh = grab(mesh)
    spacing = mesh[1]-mesh[0]
    density /= np.sum(density) * spacing

    # overlay both histograms and the analytical curve
    fig, ax = plt.subplots(1, 1)
    edges = np.linspace(-5, 5, num=51)
    initial_style = dict(histtype='step', color='xkcd:red', linestyle='--')
    ax.hist(x_pre, bins=edges, density=True, label='data (initial)', **initial_style)
    ax.hist(x_post, bins=edges, density=True, label='data (final)')
    ax.plot(mesh, density, color='k', label='analytical')
    ax.legend()
    ax.set_title(f'Eucl heat kernel at {t=}')
    plt.show()
_test_eucl_score()

In [None]:
def _test_eucl_score_evolution():
    """Plot score-corrected Brownian-motion histograms at several times.

    Samples diffuse freely from the origin. At each entry of
    ``times_to_plot`` (assumed sorted in increasing order) a single drift
    step along the Euclidean score is applied to a copy of the samples, and
    the resulting histogram is compared with the analytical heat kernel at
    the elapsed time. The drifted copy is used for plotting only; the
    simulated samples themselves are never modified by the drift.
    """
    batch_size = 16000
    steps = 1000
    t_final = 2.0
    dt = t_final / steps
    times_to_plot = [0.2, 0.5, 1.0, 1.5, 2.0]
    # Number of completed steps at which each panel is drawn. Comparing
    # integer step counts (instead of a running float sum of dt) guarantees
    # that round-off can neither delay a snapshot nor skip the final one.
    snapshot_steps = [max(1, int(round(t_plot / dt))) for t_plot in times_to_plot]

    # prepare subplot
    fig, axes = plt.subplots(1, len(times_to_plot), figsize=(5 * len(times_to_plot), 4))
    axes = np.atleast_1d(axes)  # a single panel is returned as a bare Axes

    x = torch.zeros((batch_size, 1))
    x_mesh = torch.linspace(-5, 5, 51)[..., None]
    # loop-invariant plotting inputs, kept as NumPy so the normalization
    # below never mixes torch and NumPy operands
    x_mesh_np = grab(x_mesh)
    dx = x_mesh_np[1] - x_mesh_np[0]
    bins = np.linspace(-5, 5, num=51)

    panel = 0
    for step in tqdm.tqdm(range(steps)):
        # forward diffusion
        x += np.sqrt(dt) * torch.randn_like(x)

        n_done = step + 1
        while panel < len(times_to_plot) and n_done >= snapshot_steps[panel]:
            # width of the kernel at the time actually simulated so far
            sigma = np.sqrt(n_done * dt)

            # stationary diffusion step (Euler-Maruyama using score)
            g = eucl_score(x, sigma=sigma)
            x_step = x + 0.5 * dt * g  # only drift, not adding extra noise for clarity
            x_plot = grab(x_step)

            # analytical density
            hk = grab(torch.exp(eucl_log_hk(x_mesh, sigma=sigma)))
            hk /= np.sum(hk) * dx

            # plot
            ax = axes[panel]
            ax.hist(x_plot, bins=bins, density=True, label='data', histtype='step', color='xkcd:red')
            ax.plot(x_mesh_np, hk, color='k', label='analytical')
            ax.set_title(f't = {times_to_plot[panel]:.2f}')
            ax.legend()

            panel += 1

    plt.tight_layout()
    plt.show()

_test_eucl_score_evolution()

In [None]:
def _sun_hk_meas_J(delta):
    """Measure factor J_ij = 2 sin(delta_ij / 2) for SU(N) eigen-angle differences."""
    half_angle = delta/2
    return 2 * torch.sin(half_angle)
    
def _sun_hk_meas_D(delta):
    """Measure factor D_ij for Hermitian-matrix eigenvalue differences.

    This is the difference itself, so ``delta`` is returned unchanged.
    """
    return delta

def _sun_hk_unwrapped(xs, *, sigma, eig_meas=True):
    """Evaluate one image term of the SU(N) heat kernel at unwrapped angles.

    ``xs`` holds the first Nc-1 angles on its last axis; the remaining angle
    is set to minus their sum. The result is a product over all angle pairs
    of the D and J measure factors (multiplied when ``eig_meas`` is true,
    divided otherwise) times a Gaussian weight of width ``sigma`` in all Nc
    angles.
    """
    # append the dependent angle so that all Nc angles sum to zero
    last = -torch.sum(xs, dim=-1, keepdims=True)
    xs = torch.cat([xs, last], dim=-1)
    # differences x_i - x_j for every pair i < j
    n_angles = xs.shape[-1]
    delta_x = torch.stack([
        xs[..., i] - xs[..., j]
        for i, j in itertools.combinations(range(n_angles), 2)
    ], dim=-1)
    d_term = _sun_hk_meas_D(delta_x)
    j_term = _sun_hk_meas_J(delta_x)
    # include/exclude eigenvalue Haar measure J^2
    pair_factor = d_term * j_term if eig_meas else d_term / j_term
    meas = torch.prod(pair_factor, axis=-1)
    gauss = torch.exp(-1/(2*sigma**2) * (xs**2).sum(-1))
    return meas * gauss

def _sun_score_unwrapped(xs, *, sigma):
    """Gradient of one unwrapped heat-kernel image term w.r.t. all Nc angles.

    ``xs`` holds the first Nc-1 angles on its last axis. The log-derivatives
    of the pairwise ``D/J`` factors and of the Gaussian weight are summed and
    multiplied by the kernel value (``eig_meas=False``), which yields the
    gradient of the kernel term itself, with Nc entries on the last axis.
    """
    # kernel value at the Nc-1 free angles
    K = _sun_hk_unwrapped(xs, sigma=sigma, eig_meas=False)
    # rebuild the full zero-sum angle vector
    last = -torch.sum(xs, dim=-1, keepdims=True)
    full = torch.cat([xs, last], dim=-1)
    Nc = full.shape[-1]
    eye = torch.eye(Nc).to(full)
    # pairwise differences; the 0.1 on the diagonal only avoids dividing by
    # zero there (those entries are masked out below)
    delta_mat = full[..., :, None] - full[..., None, :] + 0.1*eye
    # derivative of log(delta / (2 sin(delta/2))) for each pair
    pair_grad = 1/delta_mat - 0.5/torch.tan(0.5*delta_mat)
    pair_grad = pair_grad * (1 - eye)  # mask diagonal
    grad_log_meas = pair_grad.sum(-1)
    grad_log_weight = -full/sigma**2
    return (grad_log_meas + grad_log_weight) * K[..., None]
    
def sun_hk(thetas, *, sigma, n_max=3, eig_meas=True):
    """SU(N) heat kernel with width sigma, as a truncated sum over 2*pi images.

    Given as a density with respect to natural measure on eigenvalues.

    Each of the Nc-1 angles on the last axis of ``thetas`` is shifted by
    ``2*pi*n`` with ``n`` running over ``-n_max, ..., n_max - 1``, and the
    unwrapped kernel terms are added up.
    """
    n_free = thetas.shape[-1]
    windings = itertools.product(range(-n_max, n_max), repeat=n_free)
    return sum(
        _sun_hk_unwrapped(
            thetas + 2*np.pi*torch.tensor(ns), sigma=sigma, eig_meas=eig_meas)
        for ns in windings
    )

def sun_score(thetas, *, sigma, n_max=3):
    """SU(N) score for HK with width sigma, as a truncated sum over 2*pi images.

    Returns gradient nabla_x K(U), which must be embedded as
      P diag_embed(nabla_x K(U)) D P*
    in terms of the diagonalization U = P D P*, to obtain the full
    left-acting gradient.

    The image shifts ``2*pi*n`` use ``n`` in ``-n_max, ..., n_max - 1`` for
    each of the Nc-1 angles on the last axis of ``thetas``.
    """
    n_free = thetas.shape[-1]
    windings = itertools.product(range(-n_max, n_max), repeat=n_free)
    return sum(
        _sun_score_unwrapped(thetas + 2*np.pi*torch.tensor(ns), sigma=sigma)
        for ns in windings
    )

def sun_score_v2(thetas, *, sigma, n_max=3):
    """Autodiff variant of ``sun_score`` for a 2-D batch of angles.

    Differentiates ``sun_hk(..., eig_meas=False)`` with respect to the Nc-1
    free angles of each batch row and extends the result to Nc entries: the
    last entry is ``-sum(g)/Nc`` and the same value is added to the others.
    """
    assert thetas.ndim == 2, 'expects batched ths'
    Nc = thetas.shape[-1]+1

    def kernel(ths):
        return sun_hk(ths, sigma=sigma, n_max=n_max, eig_meas=False)

    grad_kernel = torch.func.grad(kernel)

    def embedded_grad(ths):
        g_free = grad_kernel(ths)
        g_last = -g_free.sum(-1) / Nc
        return torch.cat([g_free + g_last, g_last[..., None]], dim=-1)

    return torch.func.vmap(embedded_grad)(thetas)

In [None]:
def _test_sun_score():
    """Check the analytical ``sun_score`` against the autodiff ``sun_score_v2``.

    Uses a seeded batch of random angle pairs (Nc = 3) and a single image
    shell (``n_max=1``).
    """
    torch.manual_seed(1234)
    batch_size, Nc = 128, 3
    thetas = 3*np.pi*torch.rand((batch_size, Nc-1))
    settings = dict(sigma=1.0, n_max=1)
    a = sun_score(thetas, **settings)
    b = sun_score_v2(thetas, **settings)
    assert torch.allclose(a, b), f'{a=} {b=} {a/b=}'
    print('[PASSED test_sun_score]')
_test_sun_score()

In [None]:
def sun_sample_hk(batch_size, Nc, *, sigma, n_iter=3, n_max=3):
    """Sample from the SU(N) heat kernel with accept/reject steps on eigen-angles.

    Eigen-angles are proposed uniformly, canonicalized, and accepted with
    probability given by the ratio of ``sun_hk`` at the proposed and current
    points (``n_iter`` rounds). Random eigenvectors are then drawn from the
    QR decomposition of a complex Gaussian matrix.

    Args:
        batch_size: number of samples.
        Nc: matrix size (number of eigen-angles per sample).
        sigma: heat-kernel width passed to ``sun_hk``.
        n_iter: number of accept/reject rounds.
        n_max: image-sum cutoff passed to ``sun_hk``.

    Returns:
        ``(xs, A)`` as NumPy arrays: the eigen-angles with shape
        ``(batch_size, Nc)`` and the matrices ``V diag(xs) V^dagger`` with
        shape ``(batch_size, Nc, Nc)``.
    """
    def propose():
        """Samples eigenangles from uniform dist."""
        xs = 2*np.pi*np.random.random(size=(batch_size, Nc))
        # zero-sum constraint per sample: reduce over the angle axis only
        # (a plain np.sum would add up the whole batch)
        xs[...,-1] = -np.sum(xs[...,:-1], axis=-1)
        return grab(canonicalize_sun(torch.tensor(xs)))

    def density(angles):
        # sun_hk works on torch tensors and takes only the Nc-1 free angles
        free_angles = torch.tensor(angles[..., :-1])
        return grab(sun_hk(free_angles, sigma=sigma, n_max=n_max))

    # sample eigenangles
    xs = propose()
    for _ in range(n_iter):
        xps = propose()
        # ratio b/w new, old points; a vanishing current density gives
        # inf (accept) or nan (reject) instead of a warning
        with np.errstate(divide='ignore', invalid='ignore'):
            p = density(xps) / density(xs)
        u = np.random.random(size=p.shape)
        accept = u < p
        xs[accept] = xps[accept]  # accept / reject step

    # sample eigenvectors
    V, _ = np.linalg.qr(np.random.randn(batch_size, Nc, Nc) + 1j * np.random.randn(batch_size, Nc, Nc))
    D = np.identity(xs.shape[-1]) * xs[...,None] # embed diagonal
    A = V @ D @ adjoint(V)

    return xs, A

In [None]:
def _test_sun_hk():
    """Compare simulated Brownian motion on SU(2) with the analytical heat kernel.

    Matrices start at the identity and are multiplied from the left by
    exponentials of random algebra elements scaled by ``sqrt(dt)``. The
    histogram of the canonical eigen-angle is plotted against ``sun_hk``
    normalized on a grid.
    """
    # forward diffusion
    batch_size = 16000
    Nc = 2
    x = torch.stack([torch.eye(Nc)]*batch_size).cdouble()
    t = 2.0
    sigma = np.sqrt(t)
    steps = 500
    dt = t/steps
    for _ in tqdm.tqdm(range(steps)):
        x = torch.matrix_exp(1j * np.sqrt(dt) * sample_sun_gaussian(x.shape)) @ x
    # get canonical cell angle: tr(U)/2 = cos(theta) for SU(2). The trace is
    # stored as a complex number, so take its real part (a complex arccos
    # would hand complex "angles" to the histogram) and clamp away round-off
    # outside [-1, 1], which would otherwise produce NaN.
    cos_theta = torch.clamp(torch.real(trace(x))/2, -1.0, 1.0)
    x = grab(torch.arccos(cos_theta))
    
    # analytical
    x_mesh = torch.linspace(0, np.pi-1e-3, steps=51)[...,None]
    hk = sun_hk(x_mesh, sigma=sigma)
    print(f'{hk=}')
    x_mesh = grab(x_mesh)
    hk = grab(hk)
    hk /= np.sum(hk) * (x_mesh[1]-x_mesh[0])
    
    # plot
    fig, ax = plt.subplots(1, 1)
    bins = np.linspace(0, np.pi, num=51)
    ax.hist(x, bins=bins, density=True, label='data')
    ax.plot(x_mesh, hk, color='k', label='analytical')
    ax.legend()
    ax.set_title(f'SU(N) Heat kernel at {t=}')
    plt.show()
_test_sun_hk()

In [None]:
def _test_sun_hk_evolution():
    """Plot SU(2) Brownian-motion histograms at several times against the heat kernel.

    Matrices start at the identity and are multiplied from the left by
    exponentials of random algebra elements scaled by ``sqrt(dt)``. After the
    loop iterations listed in ``snapshot_steps`` the canonical eigen-angles
    are recorded together with the elapsed time, and each snapshot is drawn
    next to ``sun_hk`` evaluated at the matching width.
    """
    batch_size = 16000
    Nc = 2
    x = torch.stack([torch.eye(Nc)]*batch_size).cdouble()
    t = 2.0
    steps = 500
    dt = t / steps
    snapshot_steps = [0, int(steps*0.25), int(steps*0.5), int(steps*0.75), steps-1]
    snapshots = []  # (loop index, elapsed time, canonical angles)

    # forward diffusion
    for i in tqdm.tqdm(range(steps)):
        x = torch.matrix_exp(1j * np.sqrt(dt) * sample_sun_gaussian(x.shape)) @ x
        if i in snapshot_steps:
            # canonical eigenangle; clamp guards against round-off pushing
            # the cosine slightly outside [-1, 1] (likely right after the start)
            cos_theta = torch.clamp(torch.real(trace(x)/Nc), -1.0, 1.0)
            # iteration i has completed i+1 steps, hence t = dt*(i+1)
            snapshots.append((i, dt*(i+1), grab(torch.arccos(cos_theta))))

    # loop-invariant plotting inputs
    x_mesh_t = torch.linspace(0, np.pi-1e-3, steps=51)[...,None]
    x_mesh = grab(x_mesh_t)
    bins = np.linspace(0, np.pi, num=51)

    n_snap = len(snapshots)
    fig, axes = plt.subplots(1, n_snap, figsize=(4*n_snap, 4), sharey=True)
    for ax, (step, t_snap, angles) in zip(np.atleast_1d(axes), snapshots):
        # compute analytical heat kernel at this sigma
        hk = grab(sun_hk(x_mesh_t, sigma=np.sqrt(t_snap)))
        hk /= np.sum(hk) * (x_mesh[1]-x_mesh[0])

        ax.hist(angles, bins=bins, density=True, histtype='step', color='xkcd:red', label='simulation')
        ax.plot(x_mesh, hk, color='k', label='analytical')
        # label with the same elapsed time that was used for sigma
        ax.set_title(f'step {step}, t={t_snap:.2f}')
        ax.legend()

    fig.suptitle('SU(2) Heat Kernel Evolution')
    plt.show()

_test_sun_hk_evolution()

In [None]:
def _test_sun_score():
    """Check visually that score-driven dynamics on SU(2) reproduces the heat kernel.

    Phase 1 diffuses matrices from the identity up to time ``t`` and records
    the canonical angle. Phase 2 restarts near the identity and alternates a
    drift of the eigen-angles along ``sun_score`` with a random left
    multiplication, recording the angle every 100 steps. All histograms are
    plotted together with the grid-normalized ``sun_hk``.

    NOTE: this definition reuses the name of the earlier gradient
    consistency check and therefore rebinds ``_test_sun_score``.
    NOTE(review): ``sun_score`` is documented as returning the gradient of K
    rather than of log K, while the drift below uses it like a score --
    confirm whether a division by K is intended.
    """
    def canonical_angle(mats):
        # tr(U)/2 = cos(theta) for SU(2). The trace is stored as a complex
        # number, so take its real part (a complex arccos would hand complex
        # "angles" to the histogram) and clamp away round-off outside [-1, 1].
        cos_theta = torch.clamp(torch.real(trace(mats))/2, -1.0, 1.0)
        return grab(torch.arccos(cos_theta))

    # forward diffusion
    batch_size = 4096
    Nc = 2
    x = torch.stack([torch.eye(Nc)]*batch_size).cdouble()
    t = 2.0
    sigma = np.sqrt(t)
    steps = 500
    dt = t/steps
    for _ in tqdm.tqdm(range(steps)):
        x = torch.matrix_exp(1j * np.sqrt(dt) * sample_sun_gaussian(x.shape)) @ x
    x_pre = canonical_angle(x)
    
    # stationary diffusion
    x = torch.matrix_exp(0.1j * sample_sun_gaussian(x.shape))
    t2 = 3.0
    steps = 500
    dt = t2/steps
    xt = []
    for i in tqdm.tqdm(range(steps)):
        thetas, V, Vinv = mat_angle(x)
        g = sun_score(thetas[...,:-1], sigma=sigma)
        # the drift must keep the sum of the angles unchanged
        assert torch.allclose(g.sum(-1), torch.tensor(0.0).to(g)), \
            f'{g.sum(-1)=}'
        thetas += 0.5 * dt * g
        # rebuild the matrix from the shifted eigen-angles
        D = torch.exp(1j * thetas)[...,None] * torch.eye(Nc).to(thetas)
        x = V @ D @ Vinv
        x = torch.matrix_exp(1j * np.sqrt(dt) * sample_sun_gaussian(x.shape)) @ x
        if (i+1) % 100 == 0:
            xt.append(canonical_angle(x))
    
    # analytical
    x_mesh = torch.linspace(0, np.pi-1e-3, steps=51)[...,None]
    hk = sun_hk(x_mesh, sigma=sigma)
    print(f'{hk=}')
    x_mesh = grab(x_mesh)
    hk = grab(hk)
    hk /= np.sum(hk) * (x_mesh[1] - x_mesh[0])
    
    # plot
    fig, ax = plt.subplots(1, 1)
    bins = np.linspace(0, np.pi, num=51)
    ax.hist(x_pre, bins=bins, density=True, label='data (initial)', histtype='step', color='xkcd:red', linestyle='--')
    cmap = plt.get_cmap('viridis')
    for i, angles in enumerate(xt):
        style = dict(histtype='step', color=cmap(i/len(xt)))
        if i == len(xt)-1:
            # emphasize the last recorded distribution
            style['label'] = 'data (final)'
            style['histtype'] = 'bar'
        ax.hist(angles, bins=bins, density=True, **style)
    ax.plot(x_mesh, hk, color='k', label='analytical')
    ax.legend()
    ax.set_title(f'SU(N) Heat kernel at {t=}')
    plt.show()
_test_sun_score()