# Diffusion on SU(N) from the identity
As a sanity check of our score and training, we attempt to learn the analytical score for diffusion starting from the identity configuration (delta distribution) at $t = 0$. We explore several noise schedules, which just modify the time dependence of the learned score.

In [None]:
# General imports
import numpy as np
import torch
import tqdm.auto as tqdm

import matplotlib.pyplot as plt
%matplotlib widget

In [None]:
# Imports from our repo
import sys
sys.path.insert(0, '..')  # repo source code

from src.linalg import trace, adjoint
from src.action import SUNToyAction
from src.diffusion import VarianceExpandingDiffusionSUN
from src.sun import (
    proj_to_algebra, matrix_log,
    random_sun_element, random_un_haar_element,
    group_to_coeffs, coeffs_to_group,
    extract_sun_algebra, embed_diag, mat_angle,
)
from src.heat import sun_score_hk, sample_sun_hk, sun_hk

from src.utils import grab, wrap
from src.devices import set_device, get_device, summary

In [None]:
# Re-execute the local `src` package and selected submodules in the running
# kernel (presumably so on-disk edits are picked up without a restart).
# NOTE: importlib.reload does not rebind names that were previously brought in
# with `from src.<module> import ...`; re-run the import cell afterwards if
# those names should point at the reloaded definitions.
import importlib
import src
importlib.reload(src)
importlib.reload(src.diffusion)
importlib.reload(src.heat)
# NOTE(review): src.canon is not imported explicitly in the cells visible
# here; this line presumably relies on it having been loaded as a side effect
# of another import -- TODO confirm, otherwise this raises AttributeError.
importlib.reload(src.canon)

In [None]:
# Select the device through our device helper module and report the result
from src import devices  # from our src code

devices.set_device('cpu')
print(devices.summary())

# Score network

In [None]:
class TimeDependentScoreNetSUN(torch.nn.Module):
    """Time-conditioned MLP mapping a complex Nc x Nc matrix to real coefficients.

    The real and imaginary parts of the matrix entries are flattened into
    ``2*Nc*Nc`` real features, the time ``t`` is appended as one extra
    feature, and a SiLU MLP with three hidden layers produces ``Nc**2 - 1``
    real outputs (one per generator).

    Args:
        Nc: Matrix dimension.
        n_hidden: Width of each hidden layer.
    """

    def __init__(self, Nc, n_hidden=8):
        super().__init__()
        n_in = 2*Nc*Nc + 1  # NxN complex elts + 1 time
        n_out = Nc**2 - 1  # number of generators
        self.net = torch.nn.Sequential(
            torch.nn.Linear(n_in, n_hidden),
            torch.nn.SiLU(),
            torch.nn.Linear(n_hidden, n_hidden),
            torch.nn.SiLU(),
            torch.nn.Linear(n_hidden, n_hidden),
            torch.nn.SiLU(),
            torch.nn.Linear(n_hidden, n_out)
        )

    def forward(self, x, t):
        """Evaluate the network on matrices ``x`` at times ``t``.

        Args:
            x: Complex tensor whose last two dimensions are the Nc x Nc
                matrix; leading dimensions are batch dimensions.
            t: Real tensor with the batch shape of ``x`` (one time per matrix).

        Returns:
            Real tensor with the batch shape of ``x`` and a trailing dimension
            of size ``Nc**2 - 1``, in the dtype of the network parameters.
        """
        # Match the parameters' dtype/device: a complex128 input yields float64
        # features, which a float32 Linear layer would otherwise reject. This
        # is a no-op when the dtypes already agree.
        ref = self.net[0].weight
        feats = torch.view_as_real(x).flatten(-3).to(ref)
        time_feat = t.unsqueeze(-1).to(ref)
        return self.net(torch.cat([feats, time_feat], dim=-1))

In [None]:
def _test_score_net():
    """Smoke-test the score network by regressing onto the heat-kernel score.

    A fresh TimeDependentScoreNetSUN (Nc=2, 64 hidden units) is trained with
    Adam (lr=3e-4) for 5000 steps on batches of 16 samples drawn with
    sample_sun_hk at fixed width 0.1. The target is sun_score_hk conjugated by
    a random unitary and projected with extract_sun_algebra; the time input is
    fixed to one. The loss history is plotted on a log scale.
    """
    # Disabled alternative: regression to eigenangles
    # def loss(xs, score_net):
    #     batch_size, Nc = xs.shape
    #     V = random_sun_haar_element(batch_size, Nc=Nc)
    #     U = V @ embed_diag(torch.exp(1j*xs)).to(V) @ adjoint(V)
    #     return ((score_net(U, torch.ones(batch_size))[:,:Nc] - xs)**2).mean()
    torch.manual_seed(1235)
    batch_size = 16
    Nc = 2

    def loss(xs, score_net):
        # regression to true score
        n_batch, n_color = xs.shape
        # Build the group element: diagonal phases exp(1j*xs) conjugated by V
        V = random_un_haar_element(n_batch, Nc=n_color)
        phases = embed_diag(torch.exp(1j*xs)).to(V)
        U = V @ phases @ adjoint(V)
        # Target score, conjugated by the same V and projected to the algebra
        hk_width = 0.1*torch.ones(xs.shape[0])
        diag_score = sun_score_hk(xs[..., :-1], width=hk_width)
        score_matrix = V @ embed_diag(diag_score).to(V) @ adjoint(V)
        true_score = extract_sun_algebra(score_matrix).real
        assert torch.all(torch.isfinite(true_score)), f'{true_score=}'
        predicted = score_net(U, torch.ones(n_batch))
        assert true_score.shape == predicted.shape
        residual = predicted - true_score
        return (residual**2).mean()

    # train
    score_net = TimeDependentScoreNetSUN(Nc=Nc, n_hidden=64)
    optimizer = torch.optim.Adam(score_net.parameters(), lr=3e-4)
    loss_history = []
    for _ in tqdm.tqdm(range(5000)):
        optimizer.zero_grad()
        width = 0.1*torch.ones(batch_size)
        # NOTE: important to avoid near-zeros of heat kernel, so sampling should be high quality
        samples = sample_sun_hk(batch_size, Nc, width=width, n_iter=25)
        xs = torch.tensor(samples)
        batch_loss = loss(xs, score_net)
        loss_history.append(grab(batch_loss))
        batch_loss.backward()
        optimizer.step()

    # Loss curve on a logarithmic y-axis
    fig, ax = plt.subplots(1, 1)
    ax.plot(loss_history)
    ax.set_yscale('log')
    plt.show()
_test_score_net()

# Train diffusion model

**GK:** This is only partially complete and needs further exploration.

In [None]:
def score_matching_loss(x_0, score_net, diffuser, tol=1e-5):
    """Batch-averaged squared error between the network and heat-kernel scores.

    Args:
        x_0: Batch of initial configurations; only ``x_0.shape[0]`` is read
            here before it is handed to ``diffuser``.
        score_net: Callable evaluated as ``score_net(x_t, t)``.
        diffuser: Callable returning ``(x_t, ths, V)`` for ``(x_0, t)`` and
            exposing ``sigma_func(t)``.
        tol: Currently unused; only referenced by the disabled random-time
            sampling below.

    Returns:
        Scalar tensor: squared error summed over the last dimension and
        averaged over the batch.
    """
    n_batch = x_0.shape[0]
    # TODO(gkanwar): Go back to random selection of t
    # t = torch.rand((x_0.shape[0],))
    # t = (1 - tol) * t + tol  # avoid endpoints where score can become unstable
    t = 0.5*torch.ones((n_batch,))
    x_t, ths, V = diffuser(x_0, t)

    sigma_t = diffuser.sigma_func(t)
    predicted = score_net(x_t, t)

    # Heat-kernel score on all but the last entry of ths, conjugated by V and
    # projected onto real algebra coefficients.
    # NOTE(review): ths is presumably the eigenangles from the diffuser -- confirm.
    angles = torch.tensor(ths)[..., :-1]
    target = sun_score_hk(angles, width=sigma_t)
    target = V @ embed_diag(target).to(V) @ adjoint(V)
    target = extract_sun_algebra(target).real
    assert target.shape == predicted.shape

    per_sample = torch.sum((predicted - target)**2, dim=-1)
    return per_sample.mean()

In [None]:
# --- hyperparameters ---
Nc = 2  # physics
sigma = 1.1 # TODO(gkanwar): update to larger sigma
num_epochs = 1_000
lr = 1e-3
batch_size = 32

# --- diffusion process, score network and optimizer ---
diffuser = VarianceExpandingDiffusionSUN(sigma)
score_net = TimeDependentScoreNetSUN(Nc=Nc)
optimizer = torch.optim.Adam(score_net.parameters(), lr=lr)

# --- training data: every sample starts at the Nc x Nc identity ---
x_0 = torch.stack([torch.eye(Nc).cdouble() for _ in range(batch_size)])

# --- training loop: one optimizer step per epoch on the same x_0 batch ---
losses = []
for epoch in tqdm.tqdm(range(num_epochs)):
    optimizer.zero_grad()

    loss = score_matching_loss(x_0, score_net, diffuser)
    loss.backward()
    optimizer.step()

    # Read the scalar once, then log and record it
    loss_value = loss.item()
    print(f'Epoch {epoch} / {num_epochs} | Loss = {loss_value:.6f}')
    losses.append(loss_value)

In [None]:
# Training-loss history on a logarithmic y-axis
fig, ax = plt.subplots(nrows=1, ncols=1)
ax.plot(losses)
ax.set_yscale('log')
plt.show()