# Denoising using the Analytical Score Function for the ${\rm SU}(N)$ Heat Kernel

In this notebook, we will show how knowledge of the analytical score function for the heat kernel can be used to "undo," or *denoise*, a forward diffusion process that takes initial data into corrupted noisy samples.

## Setup

In [None]:
# General imports
import numpy as np
import torch

import tqdm.auto as tqdm
import matplotlib as mpl
import matplotlib.pyplot as plt

In [None]:
# Imports from our repo
import sys
sys.path.insert(0, '..')  # repo source code

from src.linalg import trace, adjoint
from src.diffusion import VarianceExpandingDiffusion, VarianceExpandingDiffusionSUN
from src.sun import (
    proj_to_algebra, matrix_exp, matrix_log,
    random_sun_element, random_un_haar_element,
    group_to_coeffs, coeffs_to_group,
    extract_sun_algebra, embed_diag, mat_angle
)
from src.canon import canonicalize_sun
from src.heat import (
    eucl_score_hk,
    sun_score_hk, sample_sun_hk, sun_hk, _sun_hk_unwrapped
)
from src.utils import grab, wrap
from src.devices import set_device, get_device, summary

In [None]:
# Run the notebook on the CPU and print the device summary from `src.devices`.
set_device('cpu')
print(summary())

## Euclidean Diffusion

Variance-expanding diffusion is a stochastic process described by $$dx = g(t) dW$$ that corresponds to the heat equation: $$\partial_t u(x, t) = \frac{g(t)^2}{2} \Delta u(x, t),$$ where $g(t)^2$ is a positive, time-dependent scalar quantity we call the *diffusivity*. The differential operator $\Delta$ is the Laplace-Beltrami operator on whatever ambient space the samples $x$ occupy. In Euclidean space, $x \in \mathbb{R}^n$, $\Delta$ is just the familiar Laplacian given by $\Delta := \sum_{i=1}^n \partial_i^2$. In particular, for a single, real-valued degree of freedom (with constant unit diffusivity), the heat equation is $$\partial_t u = \frac{1}{2}\partial_x^2 u,$$ and the fundamental solution (Green's function / propagator) for this PDE is called the *Heat Kernel* and is given by $$K(x, t) = \frac{1}{\sqrt{2\pi\sigma(t)^2}}e^{-\frac{x^2}{2\sigma(t)^2}},$$ where $\sigma(t)$ is the marginal standard deviation; for unit diffusivity, $\sigma(t) = \sqrt{t}$.

In [None]:
def euclidean_heat_kernel(x, t, width=None):
    """Evaluate the 1D Euclidean (Gaussian) heat kernel density K(x, t).

    Args:
        x: Tensor of positions at which to evaluate the density.
        t: Diffusion time; only used to derive the width when `width` is None.
        width: Standard deviation of the kernel. Defaults to `sqrt(t)`, i.e.
            unit constant diffusivity.

    Returns:
        The zero-mean normal density with standard deviation `width`,
        evaluated elementwise at `x`.
    """
    sigma = t**0.5 if width is None else width
    variance = sigma**2
    gaussian = torch.exp(-x**2 / (2 * variance))
    prefactor = 1 / (2 * np.pi * variance)**0.5
    return prefactor * gaussian


def _visualize_heat_kernel():
    """Plot the Euclidean heat kernel on [-4, 4] for 20 times in [0.05, 1].

    Every curve is colored through the same colormap *and* normalization
    that the colorbar uses, so a curve's color can be read off the colorbar.
    """
    xs = torch.linspace(-4, 4, 100)
    times = np.linspace(0.05, 1, 20)

    cmap = mpl.colormaps.get_cmap('viridis')
    # Shared time -> [0, 1] mapping for both the lines and the colorbar
    norm = mpl.colors.Normalize(times[0], times[-1])

    fig, ax = plt.subplots(1, 1)
    ax.set_xlabel('$x$')
    ax.set_ylabel('Heat Kernel Density')
    for t in times:
        ax.plot(xs, euclidean_heat_kernel(xs, t), color=cmap(float(norm(t))))

    # The colormap must be attached to the mappable itself: a `cmap=` keyword
    # passed to `fig.colorbar` is not used when a mappable is supplied.
    time_colors = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
    fig.colorbar(time_colors, ax=ax, label='$t$')
    fig.show()


_visualize_heat_kernel()

One can easily sample from the Euclidean heat kernel at arbitrary time $t$; since $K(x, t)$ is always a normal distribution, you only need to know the marginal standard deviation at $t$.

In [None]:
def sample_euclid_hk(batch_size, t, width=None):
    """Draw `batch_size` samples from the 1D Euclidean heat kernel.

    The kernel is a zero-mean normal distribution whose standard deviation
    is `width`; when `width` is None it defaults to `sqrt(t)`.

    Returns:
        A tensor of shape `(batch_size,)`.
    """
    std = t**0.5 if width is None else width
    unit_noise = torch.randn((batch_size,))
    return std * unit_noise

def _test_sample_hk():
    """Overlay histograms of heat-kernel samples on the analytical density.

    For each of t = 0.1, 0.5, 1.0, draws samples with `sample_euclid_hk` and
    plots them against `euclidean_heat_kernel` on a shared axis.
    """
    n_samples = 16384
    grid = torch.linspace(-4, 4, 100)
    snapshot_times = [0.1, 0.5, 1.0]

    palette = plt.get_cmap('viridis')
    fig, ax = plt.subplots(1, 1)
    ax.set_xlabel('$x$')
    ax.set_ylabel('Density')
    for t in snapshot_times:
        shade = palette(t)
        samples = sample_euclid_hk(n_samples, t)
        ax.hist(
            samples, bins=50, density=True, color=shade, alpha=0.55,
            label=f'Heat kernel samples at $t = {t}$',
        )
        density = euclidean_heat_kernel(grid, t)
        ax.plot(grid, density, ls='--', color=shade, label=f'Analytical HK at $t = {t}$')
    fig.legend(loc='right', fontsize='small')
    ax.set_title('Heat Kernel Density')
    fig.show()


_test_sample_hk()

It is also easy to simulate the forward process and check that the diffused samples follow the analytical heat kernel at each time:

In [None]:
def visualize_euclid_fwd():
    """Simulate the forward Euclidean VE diffusion and plot it against the heat kernel.

    Starts from samples at the origin, diffuses them to several times, asserts
    that the sample standard deviation matches `diffuser.sigma_func(t)` (within
    5e-2), and draws one histogram panel per time with the analytical kernel.
    """
    n_samples = 4096
    origin = torch.zeros((n_samples, 1))
    diffuser = VarianceExpandingDiffusion(sigma=1.1)

    grid = torch.linspace(-4, 4, 100)
    snapshot_times = [0.01, 0.1, 0.5, 1.0]
    palette = mpl.colormaps.get_cmap('viridis')

    n_panels = len(snapshot_times)
    fig, axes = plt.subplots(1, n_panels, figsize=(4*n_panels, 4), sharey=True)
    fig.suptitle('Density During Forward Process')
    axes[0].set_ylabel('Density')
    for ax, t in zip(axes, snapshot_times):
        # Diffuse the whole batch to the same time t
        batch_times = torch.tensor(t).repeat(n_samples)
        x_t = diffuser(origin, batch_times)
        sigma_t = diffuser.sigma_func(t)

        # Sample statistics should agree with the analytical kernel's width
        assert torch.allclose(torch.std(x_t), sigma_t, atol=5e-2), \
            f'StDev of samples {torch.std(x_t).item():.4f} does not match marginal HK StDev {sigma_t:.4f}'

        # One panel per snapshot: histogram vs. analytical density
        shade = palette(t)
        ax.set_title(f'$t = {t}$')
        ax.set_xlabel('$x$')
        ax.hist(x_t, bins=50, density=True, color=shade, alpha=0.55, label=f'Diffused samples at $t = {t}$')
        analytic = euclidean_heat_kernel(grid, t, width=sigma_t)
        ax.plot(grid, analytic, ls='--', color=shade, label=f'Analytical HK at $t = {t}$')
        ax.legend()
    fig.show()


visualize_euclid_fwd()

It is easy to compute the analytical score function for the Euclidean heat kernel:

$$s(x, t) = \partial_x \log K(x, t) = -\frac{x}{\sigma(t)^2}$$

which can be used to solve the reverse SDE: $$d\tilde{x} = -g(t)^2 s(x, t)dt + g(t) dW$$

In [None]:
def denoise_backward(x_1, diffuser, num_steps=200, solver_type='ODE'):
    """Integrate the reverse (denoising) process from t = 1 towards t = 0.

    Each step evaluates the analytical Euclidean heat-kernel score
    `eucl_score_hk` at the current samples and takes one explicit Euler step
    backwards in time.

    Args:
        x_1: Tensor of noisy samples at t = 1. It is cloned, not modified.
        diffuser: Object providing `sigma_func(t)` (width passed to the score)
            and `noise_coeff(t)` (the coefficient g(t)).
        num_steps: Number of uniform steps; the step size is `1 / num_steps`.
        solver_type: 'ODE' for the probability-flow ODE
            (drift `0.5 * g^2 * score`), or 'SDE' for Euler-Maruyama on the
            reverse SDE (drift `g^2 * score` plus Gaussian noise of scale
            `g * sqrt(dt)`).

    Returns:
        `(x_0, trajectories)`, where `trajectories[k]` is the state after
        `k + 1` steps, i.e. at time `1 - (k + 1) / num_steps`.

    Raises:
        NotImplementedError: If `solver_type` is neither 'ODE' nor 'SDE'.
    """
    dt = 1 / num_steps
    x_t = x_1.clone()

    trajectories = []
    for step in range(num_steps):
        # Derive t from the step index so round-off does not accumulate
        t = 1.0 - step * dt
        sigma_t = diffuser.sigma_func(t)
        g_t = diffuser.noise_coeff(t)
        score = eucl_score_hk(x_t, width=sigma_t)

        # Integration step (backwards in time, with dt > 0)
        if solver_type == 'ODE':
            x_t = x_t + 0.5 * g_t**2 * score * dt
        elif solver_type == 'SDE':
            # The reverse SDE has the full g^2 drift, and the Wiener increment
            # is Gaussian (randn), not uniform (rand).
            noise = torch.randn_like(x_t)
            x_t = x_t + g_t**2 * score * dt + g_t * dt**0.5 * noise
        else:
            raise NotImplementedError(f'Integration method {solver_type} not supported')
        trajectories.append(x_t)
    return x_t, trajectories


def _test_denoise():
    """Denoise Euclidean heat-kernel samples and plot snapshots of the reverse process.

    Samples are drawn at t = 1 with the diffuser's own marginal width, pushed
    backwards with `denoise_backward` (default ODE solver), and histogrammed
    at a few steps against the analytical kernel of matching width.
    """
    batch_size = 8192
    num_steps = 200
    diffuser = VarianceExpandingDiffusion(sigma=1.1)

    # Start from the marginal that the reverse process assumes at t = 1;
    # the score inside `denoise_backward` is built from `diffuser.sigma_func`.
    x_1 = sample_euclid_hk(batch_size, t=1.0, width=diffuser.sigma_func(1.0))
    _, trajectories = denoise_backward(x_1, diffuser, num_steps)

    cmap = mpl.colormaps.get_cmap('viridis')
    xs = torch.linspace(-4, 4, 100)
    steps = [20, 50, 150, 195]

    fig, axes = plt.subplots(1, len(steps), figsize=(4*len(steps), 4), sharey=True)
    fig.suptitle('Density During Reverse Process')
    axes[0].set_ylabel('Density')
    for ax, step in zip(axes, steps):
        # trajectories[step] is the state after (step + 1) updates
        t = 1 - (step + 1) / num_steps
        x_t = trajectories[step]
        sigma_t = diffuser.sigma_func(t)
        # Disabled check that sample stats match the analytical HK width:
        #assert torch.allclose(torch.std(x_t), sigma_t, atol=5e-2), \
        #    f'StDev of samples {torch.std(x_t).item():.4f} does not match marginal HK StDev {sigma_t:.4f}'

        # Display reverse process plots
        ax.set_title(f'$t = {t:.3f}$')
        ax.set_xlabel('$x$')
        hk = euclidean_heat_kernel(xs, t, width=sigma_t)
        # Renormalize on the plotting grid (narrow kernels are coarsely resolved)
        hk /= hk.sum() * (xs[1] - xs[0])
        ax.plot(xs, hk, ls='--', color=cmap(t), label=f'Analytic HK')
        ax.hist(x_t, bins=50, color=cmap(t), alpha=0.55, density=True, label=f'Denoised samples')
        ax.legend(fontsize=7.5)
    fig.show()


_test_denoise()

## ${\rm SU}(N)$ Diffusion

We can do the same analytical noising-denoising procedure for an SU(2) variable on the group space in exact analogy to the previous results in Euclidean space.

In [None]:
def visualize_sun_hk():
    """Plot the SU(2) heat kernel `sun_hk` over angles in [-pi, pi] for 10 times in [0, 1].

    Each curve uses width `sqrt(t)` and is colored by `t`; the colorbar shares
    the same colormap and normalization as the curves.
    """
    xs = torch.linspace(-np.pi, np.pi, 100).unsqueeze(-1)
    # NOTE(review): the first time is t = 0, i.e. width = 0 -- confirm that
    # `sun_hk` handles a zero width gracefully.
    times = torch.linspace(0, 1, 10)

    cmap = mpl.colormaps.get_cmap('viridis')
    # Plain floats keep matplotlib from having to interpret torch tensors
    norm = mpl.colors.Normalize(float(times[0]), float(times[-1]))

    fig, ax = plt.subplots(1, 1)
    ax.set_xlabel(r'$\theta$')
    ax.set_ylabel('Density')
    ax.set_title('SU(2) Spectral Heat Kernel')
    for t in times:
        ax.plot(xs, sun_hk(xs, width=t**0.5, n_max=1), color=cmap(float(norm(float(t)))))

    # Attach the colormap to the mappable: a `cmap=` keyword passed to
    # `fig.colorbar` is not used when a mappable is supplied.
    time_colors = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
    fig.colorbar(time_colors, ax=ax, label='$t$')
    fig.show()


visualize_sun_hk()

In [None]:
def test_sample_sun_hk():
    """Compare samples from `sample_sun_hk` with the analytical SU(2) heat kernel.

    For several times, draws a batch of samples at width
    `diffuser.sigma_func(t)` and histograms the first sampled component
    against `sun_hk` (with `eig_meas=True`), renormalized on the plotting grid.
    """
    batch_size = 2048
    Nc = 2

    # Diffusion process (only used for its width schedule sigma_func)
    sigma = 1.1
    diffuser = VarianceExpandingDiffusionSUN(sigma)

    times = [0.1, 0.5, 0.9, 1.0]
    # `sample_sun_hk` canonicalizes angles, so look at [0, pi] rather than [-pi, pi]
    xs = torch.linspace(0, np.pi, 100).unsqueeze(-1)
    cmap = mpl.colormaps.get_cmap('viridis')

    # Compare samples from the heat kernel to the analytical heat kernel at each time
    fig, axes = plt.subplots(1, len(times), figsize=(4*len(times), 4), sharey=True)
    fig.suptitle('Samples from the SU(2) Heat Kernel over Time')
    axes[0].set_ylabel('Density')
    for t, ax in zip(times, axes):
        sigma_t = diffuser.sigma_func(t)
        # More importance-sampling iterations for better sample quality at small t
        x_t = sample_sun_hk(batch_size, Nc=Nc, width=sigma_t.repeat(batch_size), n_iter=25)
        hk = sun_hk(xs, width=sigma_t, n_max=3, eig_meas=True)
        hk /= hk.sum() * (xs[1] - xs[0])
        ax.hist(x_t[:, 0], bins=50, density=True, color=cmap(t), alpha=0.55, label='HK samples')
        ax.plot(xs, hk, color=cmap(t), ls='--', label='Analytic SU(2) HK')
        ax.set_xlabel(r'$\theta$')
        ax.set_title(f'$t = {t}$')
        ax.legend()
    fig.show()


test_sample_sun_hk()

Simulate the forward diffusion process

In [None]:
def visualize_sun_fwd_diffusion():
    """Simulate forward SU(2) diffusion from the identity and plot the angle density.

    For several times, diffuses a batch of identity matrices with
    `diffuser.diffuse`, extracts angles with `mat_angle`, and histograms the
    first angle against the analytical kernel `sun_hk` (with `eig_meas=True`)
    at width `diffuser.sigma_func(t)`.
    """
    # Create initial data (2x2 identity matrices)
    batch_size = 4096
    Nc = 2
    U_0 = torch.complex(
        torch.eye(Nc).repeat(batch_size, 1, 1),
        torch.zeros((batch_size, Nc, Nc)),
    )

    # Diffusion process
    sigma = 1.1
    diffuser = VarianceExpandingDiffusionSUN(sigma)

    times = [0.1, 0.5, 0.75, 0.99]
    xs = torch.linspace(-np.pi, np.pi, 100).unsqueeze(-1)
    cmap = mpl.colormaps.get_cmap('viridis')

    fig, axes = plt.subplots(1, len(times), figsize=(4*len(times), 4), sharey=True)
    fig.suptitle('SU(2) Angular Spectral Density over Forward Process')
    axes[0].set_ylabel('Density')
    for t, ax in zip(times, axes):
        # Forward diffusion of the whole batch from time 0 to time t
        U_t, _, _ = diffuser.diffuse(U_0, t=t*torch.ones((batch_size,)), n_iter=25)
        thetas, _, _ = mat_angle(U_t)
        sigma_t = diffuser.sigma_func(t)

        # Analytical HK, renormalized on the plotting grid
        hk = sun_hk(xs, width=sigma_t, n_max=5, eig_meas=True)
        hk /= hk.sum() * (xs[1] - xs[0])
        print(f'sigma({t}):', sigma_t.item())

        # Plot snapshots
        ax.set_title(f'$t = {t}$')
        ax.set_xlabel(r'$\theta$')
        ax.hist(thetas[:, 0], bins=50, density=True, color=cmap(t), alpha=0.65, label='Diffused Samples')
        ax.plot(xs, hk, color=cmap(t), ls='--', label='Analytical HK')
        # The artists carry labels, so draw the legend that displays them
        ax.legend()
    fig.tight_layout()
    fig.show()


visualize_sun_fwd_diffusion()

Simulate the reverse (denoising) process using the analytically known score for the SU(N) heat kernel

In [None]:
def sample_sun_gaussian(shape):
    """Sample complex Gaussian matrices of `shape` and pass them through `proj_to_algebra`.

    The last two entries of `shape` must be equal (square matrices). The real
    and imaginary parts are drawn as independent standard normals.
    """
    n_rows, n_cols = shape[-2:]
    assert n_rows == n_cols
    real_part = torch.randn(shape)
    imag_part = torch.randn(shape)
    return proj_to_algebra(real_part + 1j*imag_part)

In [None]:
def denoise_backwards(U_1, diffuser, num_steps=200, verbose=False, solve_type='SDE'):
    """Integrate the reverse (denoising) process on the group from t = 1 towards t = 0.

    At each step the matrices are decomposed with `mat_angle`, the angles are
    moved along the analytical heat-kernel score `sun_score_hk`, and the
    matrices are rebuilt as `V @ diag(exp(i * angles)) @ V_inv`.

    Args:
        U_1: Batch of noisy matrices at t = 1. It is cloned, not modified.
        diffuser: Object providing `sigma_func(t)` (width passed to the score)
            and `noise_coeff(t)` (the coefficient g(t)).
        num_steps: Number of uniform steps; the step size is `1 / num_steps`.
        verbose: If True, print the batch-mean normalized trace after each step.
        solve_type: 'ODE' (drift `0.5 * g^2 * score`) or 'SDE' (drift
            `g^2 * score` followed by a multiplicative noise step).

    Returns:
        `(U_0, trajectories)`, where `trajectories[k]` is the state after
        `k + 1` steps.

    Raises:
        NotImplementedError: If `solve_type` is neither 'ODE' nor 'SDE'.
    """
    trajectories = []
    dt = 1 / num_steps
    t = 1.0
    U_t = U_1.clone()
    # Matrix dimension for the normalized trace printed in verbose mode
    # (previously an undefined name inside this function).
    Nc = U_1.shape[-1]
    for step in tqdm.tqdm(range(num_steps)):
        # Eigendecompose
        x_t, V, V_inv = mat_angle(U_t)

        # Get SDE params
        sigma_t = diffuser.sigma_func(t)
        g_t = diffuser.noise_coeff(t)
        # Score is evaluated on all but the last angle -- presumably because the
        # last one is fixed by the unit-determinant constraint; TODO confirm.
        score = sun_score_hk(x_t[..., :-1], width=sigma_t)

        # Integration step in reverse time
        if solve_type == 'ODE':
            x_t = x_t + 0.5 * g_t**2 * score * dt  # ODE Euler step on spectra
            D = embed_diag(torch.exp(1j * x_t)).to(V)
            U_t = V @ D @ V_inv
        elif solve_type == 'SDE':
            x_t = x_t + g_t**2 * score * dt  # SDE Euler-Maruyama drift step on spectra
            D = embed_diag(torch.exp(1j * x_t)).to(V)
            U_t = V @ D @ V_inv
            U_t = matrix_exp(g_t * dt**0.5 * sample_sun_gaussian(U_t.shape)) @ U_t  # SDE noise step
        else:
            raise NotImplementedError(f'Integration method {solve_type} not implemented')
        t -= dt

        # Collect and print metrics
        trajectories.append(U_t)
        if verbose:
            print(f'Step {step}/{num_steps} | trace(U) = {trace(U_t/Nc).mean().item():.6f}')
            print()

    return U_t, trajectories


def _test_denoise_sun():
    """Diffuse SU(2) identity matrices to t = 1, denoise them, and plot the reverse process.

    Prints the batch-mean real and imaginary parts of the denoised matrices,
    then histograms the first angle of the stored trajectory at several times
    against the analytical kernel `sun_hk`.
    """
    # Initial data: 2x2 identity matrices
    batch_size = 1024
    Nc = 2
    U_0 = torch.complex(
        torch.eye(Nc).repeat(batch_size, 1, 1),
        torch.zeros((batch_size, Nc, Nc)),
    )

    # Diffuse: U_0 -> U_1
    sigma = 1.1
    diffuser = VarianceExpandingDiffusionSUN(sigma)
    U_1, _, _ = diffuser.diffuse(U_0, torch.ones(batch_size), n_iter=25)

    # Denoise: U_1 -> U_0' (kept under its own name so U_0 is not shadowed)
    U_0_denoised, trajectories = denoise_backwards(U_1, diffuser, num_steps=200, solve_type='ODE')
    print("Re[U_0'] =\n", grab(U_0_denoised.mean(0)).real)
    print("Im[U_0'] =\n", grab(U_0_denoised.mean(0)).imag)

    # Plot reverse trajectories
    times = [1.0, 0.75, 0.25, 0.05]
    cmap = mpl.colormaps.get_cmap('viridis')
    fig, axes = plt.subplots(1, len(times), figsize=(4*len(times), 4), sharey=True)
    fig.suptitle('SU(2) Angular Spectral Density During Reverse Process')
    axes[0].set_ylabel('Density')
    xs = torch.linspace(-np.pi, np.pi, 100).unsqueeze(-1)
    num_saved = len(trajectories)
    for t, ax in zip(times, axes):
        # trajectories[k] holds the state after k + 1 steps (time 1 - (k + 1)/N),
        # so time t sits at index round((1 - t) * N) - 1, clamped to the valid range.
        step = min(max(round((1 - t) * num_saved) - 1, 0), num_saved - 1)
        U_t = trajectories[step]
        x_t, _, _ = mat_angle(U_t)
        ax.hist(x_t[:, 0], bins=50, density=True, color=cmap(t), alpha=0.65, label='Denoised samples')

        # Analytical HK spectral density; `eig_meas=True` matches the other
        # histogram-vs-kernel comparisons in this notebook.
        hk = sun_hk(xs, width=diffuser.sigma_func(t), n_max=3, eig_meas=True)
        hk /= hk.sum() * (xs[1] - xs[0])

        # Plot
        ax.plot(xs, hk, color=cmap(t), ls='--', label='Analytical HK')
        ax.set_title(f'$t = {t}$')
        ax.set_xlabel(r'$\theta$')
        ax.legend()
    fig.show()


_test_denoise_sun()