# Spectral Diffusion for a Toy ${\rm SU}(2)$ Model using Group-Valued Score Matching

In this notebook, we apply the machinery developed in the companion notebook, where we trained a score network to learn the heat kernel and denoise a variance-expanding diffusion process. Here the target is a toy theory of a single ${\rm SU}(2)$ matrix $U$, which has one independent eigenangle $\theta$ (its eigenvalues are $e^{\pm i\theta}$).

## Setup

In [None]:
# General imports
import math
import numpy as np
import torch

import tqdm.auto as tqdm
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib widget

In [None]:
# Imports from our repo
import sys
sys.path.insert(0, '..')  # repo source code
import analysis as al

from src.linalg import trace, adjoint
from src.diffusion import VarianceExpandingDiffusion, VarianceExpandingDiffusionSUN, PowerDiffusionSUN
from src.sun import (
    proj_to_algebra, matrix_exp, matrix_log,
    random_sun_element, random_un_haar_element,
    group_to_coeffs, coeffs_to_group,
    extract_sun_algebra, embed_diag, mat_angle, extract_diag
)
from src.canon import canonicalize_sun
from src.heat import (
    #eucl_score_hk,
    log_sun_hk, sun_hk,
    sun_score_hk_old, sun_score_hk, 
    sample_sun_hk
)
from src.utils import grab, wrap
from src.devices import set_device, get_device, summary
from src.integrate import estimate_divergence

In [None]:
# Re-import library modules that were edited on disk without restarting the kernel.
# NOTE: importlib.reload() refreshes the module objects only. Names that were bound
# earlier with `from src.heat import ...` / `from src.diffusion import ...` still point
# at the pre-reload objects; re-run that import cell afterwards to pick up the changes.
import src
import importlib
importlib.reload(src.heat)
importlib.reload(src.diffusion)

In [None]:
# Configure the repo's device helper (presumably CPU, index 0 -- see src.devices)
set_device('cpu', 0)
# Make float64 the default dtype for newly created floating-point tensors
torch.set_default_dtype(torch.float64)
print(summary())

## Define a Target Theory and Generate Training Data

The target theory whose distribution we will try to reproduce is specified by a toy action of our choosing, which we define to be 
$$S[U] = -\beta \; {\rm Re}{\rm Tr}(U).$$

In [None]:
from src.action import SUNToyAction

# Smoke test: evaluate the toy SU(2) matrix action on a few random configs
def _test_action():
    """Print the toy action evaluated on a small batch of random matrices."""
    n_cfgs, n_colors = 3, 2
    links = random_sun_element(n_cfgs, Nc=n_colors)

    toy_action = SUNToyAction(beta=1.0)
    values = grab(toy_action(links))
    print('Action evaluated on configs:', values)


_test_action()

In [None]:
class SUNToyPolynomialAction:
    """Toy single-matrix action that is a polynomial in the matrix U.

    S[U] = -beta * sum_k c_k * Re Tr(U^k) / Nc,   k = 1, ..., len(coeffs),

    where c_k = coeffs[k-1] and Nc is the matrix dimension.

    Args:
        beta: overall coupling multiplying the action.
        coeffs: sequence of polynomial coefficients; entry i multiplies the
            (i+1)-th matrix power.
    """

    def __init__(self, beta, coeffs=(1.0, 0.0, 0.0)):
        self.beta = beta
        # The default is an immutable tuple and the coefficients are copied into
        # a fresh list, so mutating `self.coeffs` on one instance can never leak
        # into another instance (or into the caller's sequence).
        self.coeffs = list(coeffs)

    def __call__(self, U):
        """Evaluate the action on matrices U of shape [..., Nc, Nc]."""
        Nc = U.size(-1)
        action_density = 0
        for power, c in enumerate(self.coeffs, start=1):
            action_density += c * torch.matrix_power(U, power)
        return -self.beta * trace(action_density).real / Nc

    def value_eigs(self, th):
        """Evaluate the action from eigenangles th of shape [..., Nc].

        Uses Re Tr(U^k) = sum_j cos(k * th_j) for a unitary U with
        eigenvalues exp(i * th_j).
        """
        Nc = th.size(-1)
        S = 0
        for power, c in enumerate(self.coeffs, start=1):
            S += c * torch.cos(power*th).sum(-1) / Nc
        return -self.beta * S


def _test_toy_polynomial_action():
    """Check that the matrix and eigenangle forms of the action agree."""
    n_cfgs, n_colors = 5, 2
    links = random_sun_element(n_cfgs, Nc=n_colors)
    angles, _, _ = mat_angle(links)

    poly_action = SUNToyPolynomialAction(1.0, [1.0, 1.0, 1.0])

    from_matrix = poly_action(links)
    from_angles = poly_action.value_eigs(angles)
    assert torch.allclose(from_matrix, from_angles)
    print('Action evaluated on cfgs:', grab(from_matrix), grab(from_angles))


_test_toy_polynomial_action()

To generate configurations, we will use the Metropolis algorithm for simplicity.

In [None]:
def apply_metropolis(batch_size, Nc, action, num_therm, num_iters, step_size, save_freq=10):
    """Run `batch_size` independent Metropolis chains in parallel.

    Each step proposes U' = V @ U with V drawn from `random_sun_element(...,
    sigma=step_size)` and accepts it with probability min(1, exp(-dS)).

    Args:
        batch_size: number of parallel chains.
        Nc: matrix dimension passed to `random_sun_element`.
        action: callable mapping a batch of matrices to a batch of action values.
        num_therm: number of initial (thermalization) steps that are never saved.
        num_iters: number of steps after thermalization.
        step_size: proposal width forwarded as `sigma`.
        save_freq: after thermalization, keep the chain state every
            `save_freq`-th step.

    Returns:
        (ensemble, action_vals, accept_rates): the saved configurations
        concatenated along the batch axis, plus the per-step batch-mean action
        and acceptance rate (thermalization steps included in both lists).
    """
    current = random_sun_element(batch_size, Nc=Nc)
    action_vals, accept_rates, saved = [], [], []

    for it in tqdm.tqdm(range(-num_therm, num_iters)):
        # Proposal: left-multiply by a random group element
        kick = random_sun_element(batch_size, Nc=Nc, sigma=step_size)
        proposal = kick @ current
        delta_S = action(proposal) - action(current)

        # Accept / reject each chain independently w/ prob = exp(-dS)
        uniform = torch.rand((batch_size,))
        accept_mask = (uniform < torch.exp(-delta_S))[:, None, None]
        current = torch.where(accept_mask, proposal, current)

        # Per-step diagnostics
        action_vals.append(grab(action(current).mean()))
        accept_rates.append(grab(accept_mask.sum() / batch_size))

        # Only keep every save_freq-th configuration once thermalized
        if i_is_saved := (it >= 0 and (it + 1) % save_freq == 0):
            saved.append(current)

    return torch.cat(saved), action_vals, accept_rates

In [None]:
# Define physical theory
beta = -1.0
#action = SUNToyAction(beta)
#action = SUNToyPolynomialAction(beta, [1.0, 0., 0.])
action = SUNToyPolynomialAction(beta, [0.17, -0.65, 1.22])
#action = SUNToyPolynomialAction(beta, [0.98, -0.63, -0.21])


# Generate samples
batch_size = 128
num_therm = 1_000
num_iters = 10_000
save_freq = 10
step_size = 0.9

# Number of saved configurations: one batch every save_freq steps
num_train = batch_size * (num_iters // save_freq)
print(f'{num_train=}')

# save_freq is forwarded explicitly so that the sampler and the num_train
# bookkeeping above cannot drift apart when save_freq is changed.
U_train, action_vals, accept_rates = apply_metropolis(
    batch_size=batch_size,
    Nc=2,
    action=action,
    num_therm=num_therm,
    num_iters=num_iters,
    step_size=step_size,
    save_freq=save_freq,
)

In [None]:
# Visualize Metropolis diagnostics: binned history of the batch-mean action and
# of the acceptance rate. Both lists hold one entry per Metropolis step,
# including the thermalization steps.
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for ax in axes:
    ax.set_xlabel('Metropolis Steps')

# al.bin_data is unpacked straight into plot(); presumably it returns (xs, ys)
axes[0].plot(*al.bin_data(action_vals, binsize=100))
axes[0].set_ylabel('Average Action')

axes[1].plot(*al.bin_data(accept_rates, binsize=100))
axes[1].set_ylabel('Acceptance Rate')

fig.tight_layout()
plt.show()

In [None]:
# Compare the sampled eigenangle histogram with the exact target density.
fig, ax = plt.subplots(1, 1)
# Histogram of the first eigenangle of every 10th training configuration
ax.hist(grab(mat_angle(U_train[::10])[0][:, 0]), bins=50, density=True)
# Grid of eigenangle pairs (theta, -theta) on [-pi, pi]
ths = torch.linspace(-np.pi, np.pi, steps=101)
dth = grab(ths[1]-ths[0])
ths = torch.stack([ths, -ths], dim=-1)
# Unnormalized density: Boltzmann weight exp(-S) times the sin^2(theta) measure factor
ps = grab((-action.value_eigs(ths)).exp()) * np.sin(grab(ths[:,0]))**2
# Normalize with a Riemann sum over the grid
ps /= np.sum(ps, axis=-1, keepdims=True) * dth
print(f'{ps.shape=}')
# NOTE(review): both columns of ths (theta and -theta) are passed as x, so the
# curve is drawn twice, once mirrored -- confirm this is intended.
ax.plot(grab(ths), ps, color='k', linestyle='--')
ax.set_xlabel(r'$\theta$')
ax.set_ylabel('Density')
#ax.set_yscale('log')
plt.show()

In [None]:
# Visualize the forward (noising) process: histogram the first eigenangle of the
# diffused training matrices at several times, and store reference histograms
# for later comparison with the reverse process.
sigma = 3.0
# diffuser = VarianceExpandingDiffusionSUN(sigma)
diffuser = PowerDiffusionSUN(sigma, alpha=1)

times = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
cmap = mpl.colormaps.get_cmap('viridis')
fig, axes = plt.subplots(1, len(times), figsize=(4*len(times), 4), sharey=True)
axes[0].set_ylabel('Density')
#axes[0].set_yscale('log')

# 50 equal-width bins on [-pi, pi], shared by the plots and the saved histograms
bins = np.linspace(-np.pi, np.pi, num=51)

# Work with every 10th training configuration
U_0 = U_train.clone()[::10]
for t, ax in zip(times, axes):
    if t == 0:  # avoid sampling from HK at t=0
        x_t, _, _ = mat_angle(U_0)
    else:
        # diffuse() returns a 3-tuple; only the diffused matrices are needed here
        U_t, _, _ = diffuser.diffuse(U_0, t*torch.ones(U_0.size(0)), n_iter=20)
        x_t, _, _ = mat_angle(U_t)
    ax.hist(grab(x_t[:, 0]), bins=bins, density=True, color=cmap(t))
    ax.set_xlabel(r'$\theta_t$')
    ax.set_title(f'$t = {t}$')

# also save histograms for future comparisons
true_times = [1.0, 0.75, 0.5, 0.25, 0.15, 0.10, 0.05, 0.01]
true_hists = []
for t in tqdm.tqdm(true_times):
    hist_xt = np.zeros(len(bins)-1, dtype=np.float64)
    # Diffuse in 10 chunks and accumulate raw bin counts
    for chunk_U in torch.chunk(U_0, 10):
        U_t, _, _ = diffuser.diffuse(chunk_U, t*torch.ones(chunk_U.size(0)), n_iter=100)
        x_t, _, _ = mat_angle(U_t)
        hist_xt += np.histogram(grab(x_t[:, 0]), bins=bins)[0]
    # Convert counts to a density: divide by (number of samples * bin width)
    hist_xt /= U_0.size(0) * (bins[1]-bins[0])
    # Stored as (bin edges, density values)
    true_hists.append((bins, hist_xt))

# Plot uniform SU(2) Haar measure for comparison
xs = torch.linspace(-np.pi, np.pi, 100)
haar = (1 / np.pi) * torch.sin(xs)**2
axes[-1].plot(grab(xs), grab(haar), ls='--', color='red', label='Haar Uniform')
axes[-1].legend(frameon=False)

plt.show()

## Train a Score Network

Now we must construct a score network that will take as input the eigenangle $\theta$ and time $t$.

In [None]:
class SU2ScoreNet(torch.nn.Module):
    """MLP score model for a single SU(2) eigenangle.

    The network input is the eigenangle concatenated with a sinusoidal
    encoding of time (`nk` cosines and `nk` sines), hence `input_dim` must be
    odd: input_dim = 1 + 2*nk.
    """

    def __init__(self, input_dim=3, hidden_dim=8):
        super().__init__()
        assert input_dim % 2 == 1
        self.nk = (input_dim - 1)//2
        # input layer, two hidden layers, scalar output; SiLU in between
        layers = [torch.nn.Linear(input_dim, hidden_dim), torch.nn.SiLU()]
        for _ in range(2):
            layers.extend([torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.SiLU()])
        layers.append(torch.nn.Linear(hidden_dim, 1))
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x_t, t):
        """Evaluate the model at eigenangles `x_t` ([batch, 1]) and times `t` ([batch])."""
        assert x_t.dim() == 2, \
            'input eigenangles shape should be [batch_size, Nc-1]'
        assert t.dim() == 1, \
            'times should only have a batch dimension'
        # Geometric ladder of time frequencies, shared by the cos and sin features
        ladder = torch.arange(1, self.nk+1)/self.nk
        freqs = torch.exp(-10.0*(ladder - 0.5))
        phases = freqs*t.unsqueeze(-1)
        features = torch.cat([x_t, torch.cos(phases), torch.sin(phases)], dim=-1)
        # the sin(x_t) envelope enforces score -> 0 at the endpoints
        return self.net(features) * torch.sin(x_t)


def _test_su2_score_net():
    """Shape check: the score net output must match the shape of its angle input."""
    n_samples = 10
    # angles uniform in [-pi, pi), times uniform in [0, 1)
    angles = 2*np.pi*torch.rand((n_samples, 1)) - np.pi
    times = torch.rand((n_samples,))
    model = SU2ScoreNet()
    s_t = model(angles, times)

    assert s_t.shape == angles.shape, \
        '[FAILED: score net output must have same shape as input data]'
    print('[PASSED]')


_test_su2_score_net()

In [None]:
def score_matching_loss_sun(U_0, diffuser, score_net, tol=1e-5):
    """Denoising score-matching loss for a batch of clean matrices.

    Draws one random time per sample, diffuses `U_0` to that time, and returns
    the mean squared difference between the network output and the
    sigma_t**2-scaled heat-kernel score of the applied noise, expressed in the
    eigenbasis of the diffused matrix.

    Args:
        U_0: batch of clean matrices; the batch size is read from dim 0.
        diffuser: object providing `sigma_func(t)` and `diffuse(U, t, n_iter=...)`.
        score_net: callable `(angles, t) -> score`; it receives all but the
            last eigenangle of the diffused matrices.
        tol: lower cutoff of the sampled times, t ~ Uniform[tol, 1).

    Returns:
        Scalar tensor with the batch-averaged squared error.
    """
    batch_size = U_0.size(0)
    t = tol + (1 - tol) * torch.rand((batch_size,))
    # t = torch.zeros((batch_size,))
    # t.exponential_(3) # exponential concentration for t
    # t.clamp_(max=1.0)
    sigma_t = diffuser.sigma_func(t)

    # diffuse() returns (U_t, xs, V); xs and V are presumably the eigenangles and
    # eigenvectors of the sampled noise -- confirm against src.diffusion
    U_t, xs, V = diffuser.diffuse(U_0, t, n_iter=50)
    # x_0, _, _ = mat_angle(U_0)
    x_t, P, Pinv = mat_angle(U_t)
    x_t = x_t.to(dtype=t.dtype)

    # The last eigenangle is dropped (presumably fixed by the others via det = 1)
    score = score_net(x_t[..., :-1], t)
    # NOTE(gkanwar): Passing xs instead of (x_t - x_0) is important here
    # NOTE(gkanwar): Scaling by sigma_t**2 results in a roughly scale-invariant true score,
    # making training much more stable.
    true_score_xs = sigma_t.unsqueeze(-1)**2 * sun_score_hk(xs[..., :-1], width=sigma_t)
    # Embed the target as a diagonal matrix, conjugate it with V and with the
    # (P, Pinv) returned by mat_angle(U_t), and keep the real diagonal entries
    true_score = extract_diag(Pinv @ V @ embed_diag(true_score_xs).to(V) @ adjoint(V) @ P).real[...,:-1]

    diff = score - true_score
    loss = torch.mean(diff**2)

    # Debugging aid (disabled): report samples whose error blows up.
    # if torch.isnan(loss) or loss.item() > 100:
    #     print('WARNING:', 'NaN' if torch.isnan(loss) else f'Loss blow-up! Loss = {loss.item()}')
    #     big_inds = torch.nonzero(torch.abs(diff) > 100)
    #     if big_inds.numel() > 0:
    #         for idx in big_inds:
    #             b, j = idx.tolist()
    #             print(f'  sample {b}, angle {j}:')
    #             print(f'    t = {t[b].item():.6f}')
    #             print(f'    x_t - x_0 = {wrap((x_t - x_0)[b, j]).item():.6f}')
    #             print(f'    score = {score[b, j].item():.6f}')
    #             print(f'    true_score = {true_score[b, j].item():.6f}')
    #             print(f'    diff = {diff[b, j].item():.6f}')
    #             print(f'    K(x_t, t) = {sun_hk(x_t, width=sigma_t)[b]}')
    #     else:
    #         print('  (no large diffs found)')
    #     print()

    return loss

In [None]:
from torch.utils.data import DataLoader, TensorDataset

# Make diffusion process
# sigma = 1.1
sigma = 3.0
score_net = SU2ScoreNet(input_dim=51, hidden_dim=64)
# diffuser = VarianceExpandingDiffusionSUN(sigma)
diffuser = PowerDiffusionSUN(sigma, alpha=1)

# Setup training hyperparams
lr = 1e-4
epochs = 100
optimizer = torch.optim.Adam(params=score_net.parameters(), lr=lr)

# Prepare dataloader
batch_size = 1024
# x_train, _, _ = mat_angle(U_train)
dataset = TensorDataset(U_train)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, generator=torch.Generator(device=get_device()))

# Training loop
losses = []  # per-iteration losses (binned later for plotting)
score_net.train()
for epoch in tqdm.tqdm(range(epochs)):
    total_loss = 0.0
    for (U_0_batch,) in dataloader:
        optimizer.zero_grad()
        # loss = score_matching_loss_eigs(x_0_batch, diffuser, score_net)
        loss = score_matching_loss_sun(U_0_batch, diffuser, score_net)
        # if loss.item() > 100:  # for now: skip the unstable outliers
        #     continue
        # if torch.isnan(loss):  # for now: skip NaNs
        #     continue
        loss.backward()
        # Clip AFTER backward(): clip_grad_norm_ rescales the .grad tensors in
        # place, so it must see the gradients of this step. Calling it before
        # backward() (right after zero_grad()) has no effect on the update.
        torch.nn.utils.clip_grad_norm_(score_net.parameters(), max_norm=10.0)
        optimizer.step()
        total_loss += grab(loss)
        losses.append(grab(loss))
    avg_loss = total_loss / len(dataloader)
    print(f'Epoch {epoch}/{epochs} | Loss = {avg_loss:.6f}')
    # losses.append(avg_loss)

In [None]:
# Plot training loss vs epochs
fig, ax = plt.subplots(1, 1, figsize=(6, 2.5))
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
#ax.set_yscale('log')
# `losses` holds one entry per optimizer iteration; average them in bins of 100
xs, ys = al.bin_data(losses, binsize=100)
# num_train/batch_size approximates the number of iterations per epoch
# (batch_size here is the training batch size set in the training cell)
xs = xs / (num_train/batch_size) # convert to epochs
ax.plot(xs, ys, lw=0.75)
fig.set_tight_layout(True)
plt.show()

In [None]:
# Check evaluated score vs noisy
def _check_score():
    """Plot the learned score against a binned estimate of the training target.

    For each of six fixed times, the full training set is diffused in 50
    chunks. The per-sample target score (same construction as in
    `score_matching_loss_sun`) is accumulated into histograms over the first
    eigenangle, and the ratio sum(target)/counts per bin is bootstrapped over
    chunks to estimate the conditional mean with error bars. The network output
    on a uniform angle grid is drawn on the same axes.

    Reads the globals `U_train`, `diffuser` and `score_net`.
    """
    U_0 = U_train.clone()
    batch_size = U_0.size(0)

    bins = np.linspace(-np.pi, np.pi, num=40)

    ts = [0.05, 0.1, 0.25, 0.5, 0.75, 1.0]
    fig, axes = plt.subplots(2, 3, sharey=True)
    for t_sc, ax in zip(tqdm.tqdm(ts), axes.flatten()):
        # t = tol + (1 - tol) * torch.rand((batch_size,))

        # One (weighted histogram, count histogram) pair per chunk
        true_chunks = []
        true_counts_chunks = []
        for U_chunk in tqdm.tqdm(torch.chunk(U_0, 50), leave=False):
            true = np.zeros(len(bins)-1, dtype=np.float64)
            true_counts = np.zeros(len(bins)-1)
            # single noise draw per chunk; raise the range for more statistics
            for _ in range(1):
                t = t_sc * torch.ones((U_chunk.size(0),))
                sigma_t = diffuser.sigma_func(t)
                U_t, xs, V = diffuser.diffuse(U_chunk, t, n_iter=50)
                # x_0, _, _ = mat_angle(U_0)
                x_t, P, Pinv = mat_angle(U_t)
                x_t = x_t.to(dtype=t.dtype)

                true_score_xs = sigma_t.unsqueeze(-1)**2 * sun_score_hk(xs[..., :-1], width=sigma_t)
                true_score = extract_diag(Pinv @ V @ embed_diag(true_score_xs).to(V) @ adjoint(V) @ P).real[...,:-1]
                # Sum of target scores per angle bin, and number of samples per bin
                true += np.histogram(grab(x_t[...,0]), weights=grab(true_score[...,0]), bins=bins)[0]
                true_counts += np.histogram(grab(x_t[...,0]), bins=bins)[0]
            true_chunks.append(true)
            true_counts_chunks.append(true_counts)

        # NN score
        xs_plot = np.linspace(-np.pi, np.pi, num=101)
        score = score_net(torch.tensor(xs_plot).unsqueeze(-1), t_sc * torch.ones((len(xs_plot),)))
        ax.plot(xs_plot, grab(score[...,0]))

        # est of true score
        true_chunks = np.stack(true_chunks)
        true_counts_chunks = np.stack(true_counts_chunks)
        # Ratio estimator: total weighted sum over total counts, resampled over chunks
        est_true = al.bootstrap(
            true_chunks, true_counts_chunks, Nboot=1000,
            f=lambda x,c: np.sum(x, axis=0)/np.sum(c, axis=0))
        # Bin centers
        xs = (bins[1:]+bins[:-1])/2
        # ax.plot(xs, true/true_counts, color='k')
        al.add_errorbar(est_true, ax=ax, xs=xs, color='k', linestyle='', marker='.')
        # ax.plot(grab(x_t[..., 0]), grab(true_score[...,0]), marker='x', linestyle='')

        ax.set_title(rf'$t = {t_sc:.02f}$')
    fig.set_tight_layout(True)
    fig.suptitle('Learned score vs data score estimates')
    plt.show()
_check_score()

## Denoising Process

Now we run the reverse process to generate new samples.

In [None]:
def sample_sun_gaussian(shape):
    """Draw complex Gaussian matrices of the given shape and project them with
    `proj_to_algebra`. The last two entries of `shape` must be equal."""
    n_rows, n_cols = shape[-2:]
    assert n_rows == n_cols
    real_part = torch.randn(shape)
    imag_part = torch.randn(shape)
    return proj_to_algebra(real_part + 1j*imag_part)

In [None]:
def compute_ess(logp, logq):
    """Effective sample size per sample, (sum w)^2 / (N * sum w^2), for
    importance weights w = exp(logp - logq). Computed in log space for stability."""
    log_weights = logp - logq
    log_sum_w = torch.logsumexp(log_weights, dim=0)
    log_sum_w2 = torch.logsumexp(2*log_weights, dim=0)
    ess = torch.exp(2*log_sum_w - log_sum_w2)
    return ess / len(log_weights)
def np_compute_ess(logp, logq):
    """NumPy version of `compute_ess`: (sum w)^2 / (N * sum w^2) with
    w = exp(logp - logq), evaluated in log space."""
    log_weights = logp - logq
    log_sum_w = np.logaddexp.reduce(log_weights, axis=0)
    log_sum_w2 = np.logaddexp.reduce(2*log_weights, axis=0)
    ess = np.exp(2*log_sum_w - log_sum_w2)
    return ess / len(log_weights)

def compute_kl_div(logp, logq):
    """Monte Carlo estimate of the reverse KL divergence D_KL(q || p).

    The samples are drawn from the model q, so the estimate is the sample mean
    of log q - log p.
    """
    log_ratio = logq - logp
    return torch.mean(log_ratio)
def np_compute_kl_div(logp, logq):
    """NumPy version of `compute_kl_div`: the sample mean of log q - log p."""
    log_ratio = logq - logp
    return np.mean(log_ratio)

In [None]:
@torch.no_grad()
def solve_reverse_ODE_eigs(U_1, logr, score_net, diffuser, num_steps=200, verbose=False):
    """Integrate the reverse-time flow in eigenangle space with Euler steps.

    The eigenvectors of `U_1` are computed once and kept fixed; only the
    eigenangles are evolved from t = 1 down to t = 0 and recomposed into
    matrices after every step. The log-Jacobian of the flow is accumulated
    from the exact trace of the score Jacobian.

    Args:
        U_1: batch of prior samples at t = 1.
        logr: log density of the prior samples; used as logq = logr - logJ.
        score_net: callable `(angles, t) -> score`. Its output is divided by
            sigma_t**2 here (the training loss fits the sigma_t**2-scaled score).
        diffuser: object providing `sigma_func(t)` and `noise_coeff(t)`.
        num_steps: number of Euler steps, with dt = 1/num_steps.
        verbose: if True, print the metrics after every step.

    Returns:
        (U_t, logJ, trajectories): the final matrices, the accumulated
        log-Jacobian, and per-step histories. The entries 'Z', 'kl_div', 'ess',
        'logp' and 'logq' are bootstrap outputs stacked along axis=1.

    Reads the globals `action` and `log_haar_su2` to evaluate the target density.
    """
    trajectories = {
        'U_t': [],
        'logp': [],
        'logq': [],
        'kl_div': [],
        'ess': [],
        't': [],
        'Z': [],
    }
    dt = 1 / num_steps
    t = 1.0
    batch_size = U_1.size(0)
    # Eigendecompose once; V and V_inv stay fixed for the whole trajectory
    x_1, V, V_inv = mat_angle(U_1)
    x_t = x_1.clone()

    logJ = 0.
    for step in tqdm.tqdm(range(num_steps)):
        # Get ODE params
        sigma_t = diffuser.sigma_func(t)
        g_t = diffuser.noise_coeff(t)
        score = score_net(x_t[..., :-1], t*torch.ones((batch_size,))) / sigma_t**2

        # Skilling-Hutchinson Divergence Estimation
        # func = lambda x: score_net(x, t*torch.ones((batch_size,))) / sigma_t**2
        # div = estimate_divergence(func, x_t[..., :-1], num_estimates=10)
        # Exact divergence instead: per-sample Jacobian via forward-mode AD, then its trace
        jac = torch.func.vmap(torch.func.jacfwd(lambda x: score_net(x[None], torch.tensor(t)[None])[0]))(x_t[..., :-1])
        div = torch.einsum('...ii->...', jac) / sigma_t**2
        
        # Integration step in reverse time
        # Append minus the sum as the last component so the angle update sums to zero
        score = torch.cat([score, -score.sum(-1, keepdim=True)], dim=-1)
        x_t = x_t + 0.5 * g_t**2 * score * dt
        # x_t = x_t + g_t**2 * score * dt
        logJ = logJ + 0.5 * g_t**2 * div * dt

        # Eigen-recomposition
        D = embed_diag(torch.exp(1j * x_t)).to(V)
        U_t = V @ D @ V_inv
        # FORNOW: SDE
        # U_t = random_sun_matrix(U_t.size(0), Nc=Nc, sigma=g_t * dt**0.5) @ U_t
        # x_t, V, Vinv = mat_angle(U_t)
        t -= dt

        # Collect and print metrics
        # Target log density in angle coordinates: Boltzmann weight plus Haar factor
        logp = -action(U_t) + log_haar_su2(x_t)
        logq = logr - logJ
        # Mean importance weight; compared against the true Z in a later cell
        Z = al.bootstrap(grab((logp - logq).exp()), Nboot=1000, f=al.rmean)
        kl_div = al.bootstrap(grab(logp), grab(logq), Nboot=1000, f=np_compute_kl_div)
        ess = al.bootstrap(grab(logp), grab(logq), Nboot=1000, f=np_compute_ess)
        if verbose:
            print(f'Step {step}/{num_steps}')
            print('logp =', logp.mean().item())
            print('logq =', logq.mean().item())
            print('Dkl =', kl_div)
            print('ESS =', ess)
            print()
        # t has already been decremented, so it labels the state after this step
        trajectories['t'].append(t)
        trajectories['U_t'].append(U_t)
        trajectories['logp'].append(al.bootstrap(grab(logp), Nboot=1000, f=al.rmean))
        trajectories['logq'].append(al.bootstrap(grab(logq), Nboot=1000, f=al.rmean))
        trajectories['kl_div'].append(kl_div)
        trajectories['ess'].append(ess)
        trajectories['Z'].append(Z)

    # Stack the bootstrap outputs so that index 0 / 1 along the first axis select
    # the two bootstrap components (plotted as value / error bar below)
    for key in ['Z', 'kl_div', 'ess', 'logp', 'logq']:
        trajectories[key] = np.stack(trajectories[key], axis=1)
    
    return U_t, logJ, trajectories

In [None]:
def log_haar_su2(x):
    """Log of the SU(2) Haar density in the first eigenangle.

    Evaluates log(sin^2(theta) / pi) with theta = x[:, 0], i.e. the density
    normalized over theta in [-pi, pi].
    """
    theta = x[:, 0]
    #log_norm = math.log(2 * np.pi**2)
    log_norm = math.log(np.pi)
    return 2*torch.log(torch.abs(torch.sin(theta))) - log_norm

In [None]:
# Prior samples at t = 1: random unitary matrices projected to unit determinant
batch_size = 4096
Nc = 2
# U_0 = U_train.clone()

# Alternative (disabled): obtain U_1 by diffusing the training data to t = 1
# Diffuse: U_0 -> U_1
# U_1, _, _ = diffuser.diffuse(U_0, torch.ones((U_0.size(0),)), n_iter=25)
# U_1 ~ Haar
U_1 = random_un_haar_element(batch_size, Nc=Nc)
# Rescale by det(U)^(-1/Nc) so that det(U_1) = 1; the extra random Nc-th root of
# unity randomizes the choice of branch of the fractional power
U_1 *= (torch.linalg.det(U_1)**(-1/Nc) * torch.exp(2j*np.pi*torch.randint(Nc, size=(batch_size,))/Nc))[...,None,None]

In [None]:
# Get prior log likelihood: Haar density evaluated at the eigenangles of U_1
x_1, _, _ = mat_angle(U_1)
logr = log_haar_su2(x_1)
#logr = log_sun_hk(x_1, width=diffuser.sigma_func(1.0), eig_meas=False)
print('avg logr =', logr.mean().item())
print('std logr =', logr.std().item())

In [None]:
# Run the reverse process from the prior samples; `history` holds per-step metrics
U_0, logJ, history = solve_reverse_ODE_eigs(U_1, logr, score_net, diffuser, num_steps=200, verbose=True)
# Eigenangles of the generated samples
x_0, _, _ = mat_angle(U_0)

In [None]:
# Plot reverse trajectories: at each reference time, compare the denoised samples
# with the forward-process histogram saved earlier and with the density implied
# by integrating the learned score.
# times = [1.0, 0.75, 0.5, 0.25, 0.15, 0.01]
times = true_times
cmap = mpl.colormaps.get_cmap('viridis')
fig, axes = plt.subplots(1, len(times), figsize=(2*len(times), 4), sharey=True)
fig.suptitle('SU(2) Angular Spectral Density During Reverse Process')
axes[0].set_ylabel('Density')
xs = torch.linspace(-np.pi, np.pi, 100).unsqueeze(-1)
for t, hist, ax in zip(times, true_hists, axes):
    # Histogram denoised samples
    # Map time t to a stored step index.
    # NOTE(review): t = 0 would give len(history['U_t']), which is out of range.
    step = int((1 - t) * len(history['U_t']))
    U_t = history['U_t'][step]
    x_t, _, _ = mat_angle(U_t)
    ax.hist(grab(x_t[:, 0]), bins=50, density=True, color=cmap(t), alpha=0.65, label='Denoised samples')
    # (hist[1][1:]+hist[1][:-1])/2, hist[0], width=hist[1][1]-hist[1][0]
    # hist is (bin edges, density values); stairs() takes (values, edges)
    ax.stairs(hist[1], hist[0], ec='red', color='none', label='Target Samples') # width=hist[1][1]-hist[1][0]
    # ax.hist(grab(mat_angle(U_train)[0][:, 0]), bins=50, histtype='step', density=True, color='red', label='Target Samples')
    # integrate score to get model distribution
    # sigma_t is floored at 0.1 before dividing the network output by sigma_t**2
    sigma_t = max(0.1, diffuser.sigma_func(t))
    bx = score_net(xs, t*torch.ones((xs.size(0),)))[...,0] / sigma_t**2
    # Cumulative sum times the grid spacing approximates the integral of the score
    logp_hat = torch.cumsum(bx, dim=-1) * (xs[1,0]-xs[0,0])
    # Add the Haar factor and normalize on the grid
    p = grab((logp_hat + log_haar_su2(xs)).exp())
    p /= np.sum(p)*(xs[1,0]-xs[0,0])
    ax.plot(grab(xs), grab(p), color='k')
    # Plot
    ax.set_title(f'$t = {t}$')
    ax.set_xlabel(r'$\theta$')
    ax.legend() 
plt.show()

In [None]:
def _measure_Z():
    """Estimate the partition function by a Riemann sum over the eigenangle.

    Sums exp(-S) times the Haar density on a 501-point grid over [-pi, pi],
    using diagonal matrices with eigenangles (theta, -theta).
    """
    grid = torch.linspace(-np.pi, np.pi, steps=501)
    spacing = grid[1] - grid[0]
    angles = torch.stack([grid, -grid], dim=-1)
    diag_mats = embed_diag(torch.exp(1j*angles))
    log_density = -action(diag_mats) + log_haar_su2(angles)
    return spacing * log_density.exp().sum()
true_Z = _measure_Z()

In [None]:
# Plot the metrics recorded during the reverse process against time. Each history
# entry is indexed as [0] for the plotted value and [1] for its error bar.
fig, axes = plt.subplots(1, 5, figsize=(16, 4))

axes[0].errorbar(history['t'], history['logp'][0], yerr=history['logp'][1])
axes[0].set_ylabel(r'$\log p$')

axes[1].errorbar(history['t'], history['logq'][0], yerr=history['logq'][1])
axes[1].set_ylabel(r'$\log q$')

axes[2].errorbar(history['t'], history['kl_div'][0], yerr=history['kl_div'][1])
axes[2].set_ylabel(r'Reverse KL-divergence $D_{\rm KL}(q || p)$')

axes[3].errorbar(history['t'], history['ess'][0], yerr=history['ess'][1])
axes[3].set_ylabel('ESS')

# Mean importance weight; the dashed line is the grid estimate of Z from _measure_Z
axes[4].errorbar(history['t'], history['Z'][0], yerr=history['Z'][1])
axes[4].axhline(true_Z, color='k', linestyle='--')
axes[4].set_ylabel('Z')

fig.tight_layout()
plt.show()

In [None]:
# Model likelihood: prior log density minus the log-Jacobian accumulated by the flow
logq = logr - logJ
print('avg logq =', logq.mean().item())
print('std logq =', logq.std().item())

In [None]:
# TODO: check normalization of q over time
# NOTE(gkanwar): Done, see plot of Z which should remain constant
# bootstrap over samples -> errors
# NOTE(gkanwar): Done, see al.bootstrap calls
# calculate divergence explicitly
# NOTE(gkanwar): Done, see torch.func.jacfwd impl

In [None]:
# Target likelihood (unnormalized) of the generated samples: Boltzmann weight plus Haar factor
logp = -action(U_0) + log_haar_su2(x_0)
print('avg logp =', logp.mean().item())
print('std logp =', logp.std().item())

In [None]:
# Effective Sample Size of the generated samples (fraction of the batch; 1 is ideal)
print('ESS = ', compute_ess(logp, logq).item())

# Old code

**GK:** We will probably want some version of this when score_net accepts the matrices directly.

In [None]:
@torch.no_grad()
def solve_reverse_ODE(U_1, logr, score_net, diffuser, num_steps=200, verbose=False):
    """Integrate the reverse-time flow with Euler steps, re-diagonalizing each step.

    Unlike `solve_reverse_ODE_eigs`, the matrices are eigendecomposed again at
    every step, and the divergence is estimated stochastically with
    `estimate_divergence` instead of being computed exactly.

    Args:
        U_1: batch of prior samples at t = 1.
        logr: log density of the prior samples; used as logq = logr - logJ.
        score_net: callable `(angles, t) -> score` acting on all but the last
            eigenangle; its output is divided by sigma_t**2 here.
        diffuser: object providing `sigma_func(t)` and `noise_coeff(t)`.
        num_steps: number of Euler steps, with dt = 1/num_steps.
        verbose: if True, print the metrics after every step.

    Returns:
        (U_t, logJ, trajectories): the final matrices, the accumulated
        log-Jacobian, and the per-step metric histories.

    Reads the globals `action` and `log_haar_su2` to evaluate the target density.
    """
    trajectories = {
        'U_t': [],
        'logp': [],
        'logq': [],
        'kl_div': [],
        'ess': []
    }
    dt = 1 / num_steps
    t = 1.0
    U_t = U_1.clone()
    batch_size = U_t.size(0)

    logJ = 0.
    for step in tqdm.tqdm(range(num_steps)):
        # Eigendecompose
        x_t, V, V_inv = mat_angle(U_t)

        # Get ODE params
        sigma_t = diffuser.sigma_func(t)
        g_t = diffuser.noise_coeff(t)
        t_batch = t*torch.ones((batch_size,))
        score = score_net(x_t[..., :-1], t_batch) / sigma_t**2

        # Skilling-Hutchinson Divergence Estimation
        def func(x):
            return score_net(x, t_batch) / sigma_t**2
        div = estimate_divergence(func, x_t[..., :-1], num_estimates=10)

        # Integration step in reverse time.
        # The network only returns the score of the first Nc-1 angles. Append
        # minus their sum as the last component so that the update keeps
        # sum(x_t) fixed (det U = 1). Adding the [B, Nc-1] score directly would
        # broadcast the same shift onto every angle and leave the group.
        score = torch.cat([score, -score.sum(-1, keepdim=True)], dim=-1)
        x_t = x_t + 0.5 * g_t**2 * score * dt
        logJ = logJ + 0.5 * g_t**2 * div * dt

        # Eigen-recomposition
        D = embed_diag(torch.exp(1j * x_t)).to(V)
        U_t = V @ D @ V_inv
        t -= dt

        # Collect and print metrics
        # Include the Haar factor so that logp lives in the same angle
        # coordinates as logq (as in solve_reverse_ODE_eigs).
        logp = -action(U_t) + log_haar_su2(x_t)
        logq = logr - logJ
        kl_div = compute_kl_div(logp, logq).item()
        ess = compute_ess(logp, logq).item()
        if verbose:
            print(f'Step {step}/{num_steps}')
            print('logp =', logp.mean().item())
            print('logq =', logq.mean().item())
            print('Dkl =', kl_div)
            print('ESS =', ess)
            print()
        trajectories['U_t'].append(U_t)
        trajectories['logp'].append(logp.mean().item())
        trajectories['logq'].append(logq.mean().item())
        trajectories['kl_div'].append(kl_div)
        trajectories['ess'].append(ess)

    return U_t, logJ, trajectories