# Denoising using a trained Score Network for the ${\rm SU}(N)$ Heat Kernel

In this notebook, we train a score model to perform the same denoising that we carried out analytically in the companion notebook. Because the score function of the heat kernel is known exactly, we can compare the learned results directly against the analytical ones. In more practical applications, however, we start from non-trivial 'training data' (i.e., non-trivial initial conditions for the heat equation), and then the true score function is not known at all times.

## Setup

In [None]:
# General imports
import numpy as np
import torch

import tqdm.auto as tqdm
import matplotlib as mpl
import matplotlib.pyplot as plt

In [None]:
# Imports from our repo
import sys
sys.path.insert(0, '..')  # repo source code

from src.linalg import trace, adjoint
from src.diffusion import VarianceExpandingDiffusion, VarianceExpandingDiffusionSUN
from src.sun import (
    proj_to_algebra, matrix_exp, matrix_log,
    random_sun_element, random_un_haar_element,
    group_to_coeffs, coeffs_to_group,
    extract_sun_algebra, embed_diag, mat_angle
)
from src.canon import canonicalize_sun
from src.heat import (
    eucl_score_hk,
    sun_score_hk, sample_sun_hk, sun_hk, _sun_hk_unwrapped
)
from src.utils import grab, wrap
from src.devices import set_device, get_device, summary

In [None]:
# Route computations to CUDA device index 1, then report the device setup.
# NOTE(review): the GPU index is hard-coded; adjust it on machines with fewer GPUs.
set_device('cuda', 1)
device_summary = summary()
print(device_summary)

## Euclidean Case

In [None]:
def euclidean_heat_kernel(x, t, width=None):
    """Evaluate the 1-D Euclidean heat kernel K(x, t) = N(x; 0, width**2).

    Args:
        x: tensor of evaluation points (fed to ``torch.exp``).
        t: diffusion time; only used when ``width`` is not supplied.
        width: standard deviation of the Gaussian. Defaults to ``sqrt(t)``,
            i.e. variance ``t``, which solves dK/dt = 1/2 d^2K/dx^2.

    Returns:
        Density values, broadcast over ``x`` (and ``width`` if it is a tensor).
    """
    std = t**0.5 if width is None else width
    var = std**2
    norm = 1 / (2 * np.pi * var)**0.5
    return norm * torch.exp(-x**2 / (2 * var))

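Build a small MLP score network. Its raw output is divided by $\sigma_t$ wherever it is used (in the loss and in the sampler), so the network effectively models $\sigma_t \nabla_x \log p_t(x_t)$.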

In [None]:
class ScoreNet(torch.nn.Module):
    """Small MLP score model for low-dimensional Euclidean data.

    The network takes the data ``x_t`` concatenated with the time ``t`` along the
    last axis. In this notebook its output is divided by ``sigma_t`` by the
    callers, so it effectively models ``sigma_t * score``.

    Args:
        input_dim: number of concatenated input features (data dims + 1 time dim).
        hidden_dim: width of each of the hidden layers.
        output_dim: number of outputs per sample (data dims).
    """

    def __init__(self, input_dim=2, hidden_dim=8, output_dim=1):
        super().__init__()
        # Layer order/indices are kept fixed so state_dict keys stay net.{0,2,4,6}.*
        self.net = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),  # data & time = 1 + 1 dims by default
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_dim, output_dim))

    def forward(self, x_t, t):
        """Return the raw network output for data ``x_t`` at times ``t``.

        ``x_t`` and ``t`` are concatenated along the last axis, so they must share
        their leading (batch) dimensions, e.g. [batch, 1] each.
        """
        return self.net(torch.cat([x_t, t], dim=-1))


def _test_score_net():
    """Smoke test: ScoreNet maps [batch, 1] data plus [batch, 1] times to a [batch, 1] output."""
    n = 100
    inputs = torch.randn((n, 1))
    times = torch.rand_like(inputs)
    outputs = ScoreNet()(inputs, times)
    for label, tensor in (('x', inputs), ('t', times), ('s_t', outputs)):
        print(f'{label} shape:', tensor.shape)
    assert outputs.shape == inputs.shape, \
        'Score output should have same shape as input'
    print('[PASSED]')


_test_score_net()

Define the denoising score-matching loss.

In [None]:
def score_matching_loss(x_0, diffuser, score_net, tol=1e-4):
    """Denoising score-matching loss for a variance-expanding diffusion.

    Draws one time ``t ~ U[tol, 1)`` per sample, diffuses ``x_0 -> x_t`` and
    regresses the network output onto ``-eps`` with ``eps = (x_t - x_0) / sigma_t``.
    Since the score model is ``score_net(x_t, t) / sigma_t``, this is the DSM
    objective for ``grad log N(x_t; x_0, sigma_t^2)`` weighted by ``sigma_t**2``.

    Args:
        x_0: clean samples of shape [batch, d] (d = 1 for ScoreNet).
        diffuser: provides ``sigma_func(t)`` and is callable as
            ``diffuser(x_0, t)`` with ``t`` of shape [batch].
        score_net: callable as ``score_net(x_t, t)`` with ``t`` of shape [batch, 1].
        tol: lower bound on t, for stability near the t = 0 endpoint.

    Returns:
        Scalar loss tensor.
    """
    # One time per sample, kept as [batch, 1] so it broadcasts against x_0
    t = torch.rand_like(x_0[..., :1])
    t = tol + (1 - tol) * t  # stability near endpoints
    sigma_t = diffuser.sigma_func(t)

    # x_0 -> x_t; squeeze(-1) (not squeeze()) keeps shape [batch] even when batch == 1
    x_t = diffuser(x_0, t.squeeze(-1))

    # The network output plays the role of sigma_t * score, whose DSM target is -eps
    eps = (x_t - x_0) / sigma_t
    return torch.mean((score_net(x_t, t) + eps)**2)

In [None]:
# Do the training: fit ScoreNet to a point mass at the origin under VE diffusion
sigma = 1.1
score_net = ScoreNet()
diffuser = VarianceExpandingDiffusion(sigma)
lr = 1e-3
epochs = 1000
batch_size = 1024
optimizer = torch.optim.Adam(params=score_net.parameters(), lr=lr)

x_0 = torch.zeros((batch_size, 1))  # all training data sits at x = 0
losses = []
score_net.train()
progress = tqdm.tqdm(range(epochs))
for epoch in progress:
    optimizer.zero_grad()
    loss = score_matching_loss(x_0, diffuser, score_net)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    # Show the running loss on the progress bar instead of printing one line per epoch
    progress.set_postfix(loss=f'{losses[-1]:.6f}')

In [None]:
# Training curve of the Euclidean score network
fig, ax = plt.subplots(1, 1)
ax.plot(losses, lw=0.75)
ax.set(xlabel='Epochs', ylabel='Loss')
fig.show()

In [None]:
@torch.no_grad()
def euler_sampler(x_1, score_net, diffuser, num_steps=200, solver_type='ODE', verbose=False):
    """Integrate the reverse-time VE diffusion from t = 1 down to t = 0.

    Uses an Euler step for the probability-flow ODE or an Euler-Maruyama step
    for the reverse SDE. The score is ``score_net(x_t, t) / sigma_t``.

    Args:
        x_1: samples at t = 1; ``x_1.size(0)`` is the batch size.
        score_net: module called as ``score_net(x_t, t)`` with ``t`` of shape [batch, 1].
        diffuser: provides ``sigma_func(t)`` and ``noise_coeff(t)`` for float ``t``.
        num_steps: number of uniform steps of size 1 / num_steps.
        solver_type: 'ODE' (probability flow) or 'SDE' (reverse SDE).
        verbose: print the batch mean after every step.

    Returns:
        ``(x_0, trajectories)`` where ``trajectories[k]`` is the state after
        ``k + 1`` steps, i.e. at time ``1 - (k + 1) / num_steps``.

    Raises:
        NotImplementedError: for an unknown ``solver_type``.
    """
    if solver_type not in ('ODE', 'SDE'):
        raise NotImplementedError(f'Integration method {solver_type} not supported')
    score_net.eval()
    batch_size = x_1.size(0)

    trajectories = []
    dt = 1 / num_steps
    x_t = x_1.clone()
    for step in tqdm.tqdm(range(num_steps)):
        # Recompute t from the step index instead of accumulating `t -= dt` round-off
        t = 1.0 - step * dt

        # Get ODE / SDE params; the time input lives on the same device as the data
        sigma_t = diffuser.sigma_func(t)
        g_t = diffuser.noise_coeff(t)
        t_batch = torch.full((batch_size, 1), t, dtype=x_t.dtype, device=x_t.device)
        score = score_net(x_t, t_batch) / sigma_t

        # Integration step in reverse time
        if solver_type == 'ODE':
            x_t = x_t + 0.5 * g_t**2 * score * dt  # ODE Euler step
        else:
            # Euler-Maruyama: Brownian increments are Gaussian -> randn_like, not rand_like
            x_t = x_t + g_t**2 * score * dt + g_t * torch.randn_like(x_t) * dt**0.5

        # Collect and print metrics
        trajectories.append(x_t)
        if verbose:
            print(f'Step {step}/{num_steps} | x_t = {x_t.mean().item():.6f}')
    return x_t, trajectories

In [None]:
def _test_euclidean_denoising():
    """Denoise a point mass at the origin and compare against the analytical heat kernel.

    Diffuses 4096 zeros to t = 1, integrates the probability-flow ODE back with
    the trained global ``score_net``, and overlays histograms of the reverse
    trajectory at selected times with the analytical Gaussian density.
    """
    # Initial data: zeros
    batch_size = 4096
    x_0 = torch.zeros((batch_size, 1))

    # Diffuse forward: x_0 -> x_1
    sigma = 1.1
    diffuser = VarianceExpandingDiffusion(sigma)
    x_1 = diffuser(x_0, t=torch.ones((batch_size,)))

    # Denoise backward: x_1 -> x_0'
    num_steps = 100
    x_0_denoised, trajectories = euler_sampler(x_1, score_net, diffuser, num_steps, solver_type='ODE')
    print("x_0':", x_0_denoised.mean().item())

    # trajectories[k] is the state after k + 1 steps; prepend x_1 so that
    # states[k] is the sample at time t = 1 - k / num_steps
    states = [x_1] + trajectories

    # Plot trajectories
    times = [1.0, 0.75, 0.5, 0.25, 0.05]
    xs = torch.linspace(-5, 5, 100)
    cmap = mpl.colormaps.get_cmap('viridis')

    fig, axes = plt.subplots(1, len(times), figsize=(4*len(times), 4), sharey=True)
    fig.suptitle('Euclidean Reverse Denoising Process')
    axes[0].set_ylabel('Density')
    for t, ax in zip(times, axes):
        # Denoised samples at time t (round() guards against float truncation)
        x_t = states[round(num_steps * (1 - t))]
        ax.hist(grab(x_t), bins=50, density=True, color=cmap(t), alpha=0.65, label='Denoised samples')

        # Analytical heat kernel, renormalized on the plotting grid
        hk = euclidean_heat_kernel(xs, t, width=diffuser.sigma_func(t))
        hk /= hk.sum() * (xs[1] - xs[0])
        ax.plot(grab(xs), grab(hk), color=cmap(t), ls='--', label='Analytic HK')
        ax.set_xlabel(r'$x_t$')
        ax.set_title(f'$t = {t}$')
        ax.legend()
    fig.show()


_test_euclidean_denoising()

## ${\rm SU}(2)$ case

In [None]:
class SU2ScoreNet(torch.nn.Module):
    """MLP score model acting on SU(2) eigenangle spectra.

    The input is the ``input_dim - 1`` eigenangle features concatenated with the
    diffusion time (1 + 1 = 2 by default). The output is a single value per sample.
    """

    def __init__(self, input_dim=2, hidden_dim=8):
        super().__init__()
        layers = [torch.nn.Linear(input_dim, hidden_dim), torch.nn.SiLU()]
        # two hidden -> hidden blocks
        for _ in range(2):
            layers.extend([torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.SiLU()])
        layers.append(torch.nn.Linear(hidden_dim, 1))
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x_t, t):
        """Network output for eigenangles ``x_t`` [batch, Nc-1] at times ``t`` [batch]."""
        assert x_t.ndim == 2, \
            'input eigenangles shape should be [batch_size, Nc-1]'
        assert t.ndim == 1, \
            'times should only have a batch dimension'
        features = torch.cat([x_t, t[:, None]], dim=-1)
        return self.net(features)


def _test_su2_score_net():
    """Smoke test: SU2ScoreNet maps [batch, 1] eigenangles and [batch] times to [batch, 1] outputs."""
    n_samples = 10
    angles = 2*np.pi*torch.rand((n_samples, 1)) - np.pi  # uniform in [-pi, pi)
    times = torch.rand((n_samples,))
    scores = SU2ScoreNet()(angles, times)

    assert scores.shape == angles.shape, \
        '[FAILED: score net output must have same shape as input data]'
    print('[PASSED]')

_test_su2_score_net()

In [None]:
def score_matching_loss_sun(U_0, diffuser, score_net, tol=1e-5, verbose=False):
    """Explicit score-matching loss on the SU(N) eigenangle spectrum.

    A single diffusion time ``t ~ U[tol, 1)`` is drawn and shared by the whole
    batch. ``U_0`` is diffused to time ``t``. The network prediction (divided by
    ``sigma_t``) on the first Nc-1 eigenangles is compared with the analytical
    heat-kernel score from ``sun_score_hk``, using a ``sigma_t**2`` weighting.

    Args:
        U_0: batch of initial matrices; ``U_0.size(0)`` is the batch size.
        diffuser: provides ``sigma_func(t)`` and ``diffuse(U_0, t, n_iter=...)``
            returning ``(U_t, x_t, V)``.
        score_net: called as ``score_net(x, t)`` with ``x`` of shape [batch, Nc-1]
            and ``t`` of shape [batch].
        tol: t is mapped into [tol, 1) for stability near the endpoints.
        verbose: if True, print per-call diagnostics (t, shapes, norms).

    Returns:
        Scalar loss tensor. Losses above 100 are always reported with a dump of
        the offending inputs and scores.
    """
    batch_size = U_0.size(0)
    # One diffusion time shared by the whole batch
    # (per-sample alternative: t = torch.rand((batch_size,)))
    t = torch.rand((1,)) * torch.ones((batch_size,))
    t = tol + (1 - tol) * t  # stability near endpoints
    sigma_t = diffuser.sigma_func(t)

    _, x_t, _ = diffuser.diffuse(U_0, t, n_iter=20)
    x_t = x_t.to(dtype=t.dtype)  # TODO: fix dtype issue
    # The net only sees the Nc-1 independent eigenangles
    score = score_net(x_t[..., :-1], t) / sigma_t.unsqueeze(-1)

    # Analytical target, sliced to the same Nc-1 components as the network output
    true_score = sun_score_hk(x_t[..., :-1], width=sigma_t)[..., :-1]
    # sigma_t**2 weighting balances the loss scale across t
    loss = torch.mean(sigma_t.unsqueeze(-1)**2 * (score - true_score)**2)

    if verbose:
        print()
        print('t:', grab(t))
        print('x_t shape:', x_t.shape)
        print('score / sigma_t shape:', score.shape)
        print('true score shape:', true_score.shape)
        print('score net norm:', torch.norm(score).item())
        print('true score norm:', torch.norm(true_score).item())
    if loss.item() > 100:
        # Always report outliers: these are the training instabilities to investigate
        print('=============================================================')
        print('x_t:', x_t)
        print('true_score:', true_score)
        print('score net:', score)
        print('=============================================================')
    return loss

In [None]:
# Make diffusion process
sigma = 1.1
score_net = SU2ScoreNet(input_dim=2, hidden_dim=64)
diffuser = VarianceExpandingDiffusionSUN(sigma)

# Make "training data": a batch of 2x2 identity matrices
batch_size = 16
Nc = 2
_U_0_re = torch.eye(Nc).repeat(batch_size, 1, 1)
_U_0_im = torch.zeros((batch_size, Nc, Nc))
U_0 = torch.complex(_U_0_re, _U_0_im)

# Setup training hyperparams
lr = 2e-4
epochs = 1000
max_grad_norm = None  # set to a positive float to clip the global gradient norm
optimizer = torch.optim.Adam(params=score_net.parameters(), lr=lr)

# Training loop
losses = []
score_net.train()
progress = tqdm.tqdm(range(epochs))
for epoch in progress:
    optimizer.zero_grad()
    loss = score_matching_loss_sun(U_0, diffuser, score_net)
    loss.backward()
    if max_grad_norm is not None:
        # Intended to keep occasional large-loss outliers from destabilizing Adam
        torch.nn.utils.clip_grad_norm_(score_net.parameters(), max_grad_norm)
    optimizer.step()
    losses.append(loss.item())
    # Show the running loss on the progress bar instead of printing one line per epoch
    progress.set_postfix(loss=f'{losses[-1]:.6f}')

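## TODO: gradient clipping

Add gradient clipping to tame the occasional loss spikes (the `loss > 100` outliers reported by `score_matching_loss_sun`).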

In [None]:
# Training curve of the SU(2) score network, on a logarithmic loss axis
fig, ax = plt.subplots(1, 1)
ax.plot(losses, lw=0.75)
ax.set(xlabel='Epochs', ylabel='Loss', yscale='log')
fig.show()

Run the denoising process

In [None]:
def sample_sun_gaussian(shape):
    """Draw Gaussian noise projected onto the Lie algebra via ``proj_to_algebra``.

    A complex matrix with independent standard-normal real and imaginary parts
    is sampled with the given shape and then projected.

    Args:
        shape: tensor shape whose last two dimensions are equal (square matrices).
    """
    *_, n_rows, n_cols = shape
    assert n_rows == n_cols
    real_part = torch.randn(shape)
    imag_part = torch.randn(shape)
    return proj_to_algebra(real_part + 1j * imag_part)

In [None]:
@torch.no_grad()
def denoise_backwards(U_1, score_net, diffuser, num_steps=200, verbose=False, solve_type='SDE'):
    """Run the reverse-time diffusion on the group from t = 1 down to t = 0.

    Each step eigendecomposes ``U_t = V diag(exp(i x_t)) V^{-1}`` via
    ``mat_angle``, moves the eigenangles ``x_t`` along the learned score, and
    reassembles the matrix. For ``solve_type='SDE'``, a noise step
    ``exp(g_t sqrt(dt) xi) @ U_t`` follows, with ``xi`` drawn from ``sample_sun_gaussian``.

    Args:
        U_1: batch of N x N matrices at t = 1, shape [batch, N, N].
        score_net: called as ``score_net(x, t)`` with ``x`` of shape [batch, N-1]
            and ``t`` of shape [batch]; expected to return [batch, N-1].
        diffuser: provides ``sigma_func(t)`` and ``noise_coeff(t)`` for float ``t``.
        num_steps: number of uniform steps of size 1 / num_steps.
        verbose: print the mean normalized trace after every step.
        solve_type: 'ODE' (probability flow) or 'SDE' (reverse SDE).

    Returns:
        ``(U_0, trajectories)`` where ``trajectories[k]`` is the ensemble after
        ``k + 1`` steps, i.e. at time ``1 - (k + 1) / num_steps``.

    Raises:
        NotImplementedError: for an unknown ``solve_type``.
    """
    if solve_type not in ('ODE', 'SDE'):
        raise NotImplementedError(f'Integration method {solve_type} not implemented')
    score_net.eval()
    batch_size, Nc = U_1.size(0), U_1.size(-1)

    trajectories = []
    dt = 1 / num_steps
    U_t = U_1.clone()
    for step in tqdm.tqdm(range(num_steps)):
        # Recompute t from the step index instead of accumulating `t -= dt` round-off
        t = 1.0 - step * dt

        # Eigendecompose: U_t = V diag(exp(i x_t)) V^{-1}
        x_t, V, V_inv = mat_angle(U_t)

        # Learned score on the N-1 independent eigenangles
        sigma_t = diffuser.sigma_func(t)
        g_t = diffuser.noise_coeff(t)
        t_batch = torch.full((batch_size,), t, device=x_t.device)
        score = score_net(x_t[..., :-1].to(dtype=t_batch.dtype), t_batch) / sigma_t
        # Eigenangles of an SU(N) element sum to zero (det U = 1), so the angle
        # update must be sum-free: complete the N-th component as minus the sum of
        # the others. Broadcasting a [batch, 1] score onto both SU(2) angles would
        # instead shift them equally, i.e. multiply U by a global phase and move
        # det U away from 1.
        score = torch.cat([score, -score.sum(dim=-1, keepdim=True)], dim=-1)

        # Drift step in reverse time on the spectrum (ODE: 1/2 g^2, SDE: g^2)
        drift_coeff = 0.5 if solve_type == 'ODE' else 1.0
        x_t = x_t + drift_coeff * g_t**2 * score * dt
        D = embed_diag(torch.exp(1j * x_t)).to(V)
        U_t = V @ D @ V_inv
        if solve_type == 'SDE':
            # SDE noise step: left-multiply by the exponential of algebra-valued noise
            U_t = matrix_exp(g_t * dt**0.5 * sample_sun_gaussian(U_t.shape)) @ U_t

        # Collect and print metrics
        trajectories.append(U_t)
        if verbose:
            print(f'Step {step}/{num_steps} | trace(U) = {trace(U_t/Nc).mean().item():.6f}')
            print()

    return U_t, trajectories

In [None]:
# Initial data: 2x2 identity matrices
batch_size = 4096
Nc = 2
_U_0_re = torch.eye(Nc).repeat(batch_size, 1, 1)
_U_0_im = torch.zeros((batch_size, Nc, Nc))
U_0 = torch.complex(_U_0_re, _U_0_im)

# Diffuse: U_0 -> U_1
U_1, _, _ = diffuser.diffuse(U_0, torch.ones(batch_size), n_iter=25)

# Denoise: U_1 -> U_0' (U_0 now holds the denoised samples inspected by the checks below)
U_0, trajectories = denoise_backwards(U_1, score_net, diffuser, num_steps=500, solve_type='ODE')
print("Re[U_0'] =\n", grab(U_0.mean(0)).real)
print("Im[U_0'] =\n", grab(U_0.mean(0)).imag)

# trajectories[k] is the ensemble after k + 1 steps; prepend U_1 so that
# states[k] is the ensemble at time t = 1 - k / len(trajectories)
states = [U_1] + trajectories

# Plot reverse trajectories
times = [1.0, 0.75, 0.5, 0.25, 0.15, 0.05]
cmap = mpl.colormaps.get_cmap('viridis')
fig, axes = plt.subplots(1, len(times), figsize=(4*len(times), 4), sharey=True)
fig.suptitle('SU(2) Angular Spectral Density During Reverse Process')
axes[0].set_ylabel('Density')
xs = torch.linspace(-np.pi, np.pi, 100).unsqueeze(-1)
for t, ax in zip(times, axes):
    # Histogram of the first eigenangle at time t (round() guards against float truncation)
    U_t = states[round((1 - t) * len(trajectories))]
    x_t, _, _ = mat_angle(U_t)
    ax.hist(grab(x_t[:, 0]), bins=50, density=True, color=cmap(t), alpha=0.65, label='Denoised samples')

    # Analytical HK spectral density, renormalized on the plotting grid
    hk = sun_hk(xs, width=diffuser.sigma_func(t), n_max=3)
    hk /= hk.sum() * (xs[1] - xs[0])

    # Plot
    ax.plot(grab(xs), grab(hk), color=cmap(t), ls='--', label='Analytical HK')
    ax.set_title(f'$t = {t}$')
    ax.set_xlabel(r'$\theta$')
    ax.legend()
fig.show()

In [None]:
# Unitarity check: largest entry-wise |U U^dagger - I| over the denoised batch (expected near zero)
(U_0 @ adjoint(U_0) - torch.eye(U_0.shape[-1], dtype=U_0.dtype, device=U_0.device)).abs().max()

In [None]:
# Determinant check: |det U|^2 = 1 holds for any unitary U, while det U = 1 is the
# stronger SU(N) condition. Shows the max deviation from each over the denoised batch.
dets = torch.linalg.det(U_0)
(dets.abs()**2 - 1).abs().max(), (dets - 1).abs().max()