## Invariant-potential CNF approach for learning the $\mathbb{CP}^1$ model

In [1]:
!nvidia-smi

Mon Sep  9 08:58:15 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.171.04             Driver Version: 535.171.04   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |


|   0  Tesla T4                       Off | 00000000:00:07.0 Off |                    0 |
| N/A   42C    P0              33W /  70W |      2MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+


In [2]:
import optax
import flax
import flax.linen as nn
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp

from functools import partial
from typing import Callable
import chex

from coset_flow.glutils import sample_haar, sample_haar_lattice, SU2_GEN, liegrad, roll_lattice
from coset_flow.jaxcg import CG3, crouch_grossmann

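Implementation outline:
- A potential that is a function of the pairwise products (overlaps) of neighbouring $\mathbb{CP}^1$ vectors
- The 'h map' (Hopf fibration) mapping SU(2) matrices to $\mathbb{CP}^1$ vectors, keeping the U(1) phase separately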

In [4]:
class PRNGSequence:
    """Infinite iterator of fresh JAX PRNG keys obtained by repeated splitting.

    Every ``next()`` splits the stored key into (new internal state, returned
    key). The returned key is therefore never reused as the internal state.

    Args:
        rng: initial JAX PRNG key.
    """

    def __init__(self, rng):
        self._rng = rng

    def __iter__(self):
        # Required by the iterator protocol; lets the sequence be used in
        # ``for``/``zip``/``iter()`` in addition to ``next()``.
        return self

    def __next__(self):
        self._rng, rng = jax.random.split(self._rng)
        return rng

# Global seed for the notebook.
rng = jax.random.PRNGKey(40)

We define the Hopf map:
$$SU(2) \ni \left( \begin{matrix} z_{1} & -z_{2}^{*} \\ z_2 & z_{1}^{*} \end{matrix} \right) \rightarrow \left( \begin{matrix} \cos \phi \\ -\sin\phi e^{-i(\alpha + \beta)} \end{matrix}\right) e^{i\alpha} \in \mathbb{CP}^1$$

Mapping on the lattice

In [5]:
def su2_to_cp1(samples):
    """Map one SU(2) matrix to a CP^1 vector with its U(1) phase stripped.

    Only the first column (z1, z2) of the matrix is read. The phase of z1 is
    divided out of both components, so the first entry of the result is
    real (|z1| up to rounding).

    Args:
        samples: a single 2x2 matrix.

    Returns:
        Length-2 array ``[Re(z1 / e^{i arg z1}), z2 * e^{-i arg z1}]``.
    """
    z1, z2 = samples[0, 0], samples[1, 0]
    unit_phase = jnp.exp(1j * jnp.angle(z1))
    magnitude = jnp.real(z1 / unit_phase)
    rotated_z2 = z2 * jnp.conjugate(unit_phase)
    return jnp.array([magnitude, rotated_z2])

def cp1_to_su2(cp1_vec, phase):
    """Rebuild a batch of SU(2) matrices from phase-stripped CP^1 vectors.

    Args:
        cp1_vec: array of shape (Counts, 2) holding (z1, z2) without phase.
        phase: array of shape (Counts, 1), multiplied back onto both entries.

    Returns:
        Array of shape (Counts, 2, 2) with rows ``[z1, -conj(z2)]`` and
        ``[z2, conj(z1)]``.
    """
    restored = cp1_vec * phase
    first, second = restored[:, 0], restored[:, 1]

    top_row = jnp.stack([first, -jnp.conjugate(second)], axis=-1)
    bottom_row = jnp.stack([second, jnp.conjugate(first)], axis=-1)
    return jnp.stack([top_row, bottom_row], axis=-2)

In [6]:
def su2_to_cp1_lattice(lattice):
    """Apply the SU(2) -> CP^1 map to every matrix of a (batched) lattice.

    Args:
        lattice: array of shape (..., 2, 2); in this notebook
            (sample size, N, N, 2, 2).

    Returns:
        (cp1_vector, phase):
            cp1_vector -- shape (..., 2), entries ``[Re(z1 e^{-i arg z1}), z2 e^{-i arg z1}]``;
            phase      -- shape (..., 1), the stripped unit phase ``e^{i arg z1}``.
    """
    # First column of each matrix: (z1, z2).
    z1 = lattice[..., 0, 0]
    z2 = lattice[..., 1, 0]

    unit_phase = jnp.exp(1j * jnp.angle(z1))
    components = [jnp.real(z1 / unit_phase), z2 * jnp.conjugate(unit_phase)]
    cp1_vector = jnp.stack(components, axis=-1)

    return cp1_vector, jnp.expand_dims(unit_phase, axis=-1)


def cp1_to_su2_lattice(cp1_vec, phase):
    """Inverse of ``su2_to_cp1_lattice``: rebuild SU(2) matrices on a lattice.

    Args:
        cp1_vec: array of shape (..., 2) of phase-stripped (z1, z2) vectors.
        phase: array broadcastable against ``cp1_vec`` (e.g. (..., 1)),
            multiplied back onto both components.

    Returns:
        Array of shape (..., 2, 2) with rows ``[z1, -conj(z2)]`` and
        ``[z2, conj(z1)]``; in this notebook (sample size, N, N, 2, 2).
    """
    restored = cp1_vec * phase
    a = restored[..., 0]
    b = restored[..., 1]

    # Build the matrix with the 2x2 indices in front, then move them to the
    # last two axes.
    entries = jnp.array([
        [a, -jnp.conjugate(b)],
        [b, jnp.conjugate(a)],
    ])
    return jnp.moveaxis(entries, (0, 1), (-2, -1))

Potential 

In [7]:
class MLP_local(nn.Module):
    """Scalar-valued MLP of a feature vector and a time value.

    The time ``t`` is appended to ``x`` as one extra trailing feature. The
    result then goes through:

    - a Dense(10) projection with no activation,
    - ``num_layers`` blocks of Dense(hidden_features) + ``activation``,
    - a Dense(1) head.

    The output is reshaped to a 0-d scalar, so the input must describe
    exactly one sample (``x.shape[:-1]`` of total size 1).

    Attributes:
        hidden_features: width of each hidden layer.
        num_layers: number of hidden Dense + activation blocks.
        activation: elementwise nonlinearity applied after each hidden layer.
    """
    hidden_features: int = 128
    num_layers: int = 30
    activation: Callable = jax.nn.tanh

    @nn.compact
    def __call__(self, t, x):
        # One initializer shared by every layer (initializers are pure functions).
        xavier = nn.initializers.xavier_normal()

        time_feature = jnp.broadcast_to(
            jnp.expand_dims(t, axis=-1), x.shape[:-1] + (1,)
        )
        inputs = jnp.concatenate([x, time_feature], axis=-1)

        # Layers are created in the original order, so Flax's automatic
        # submodule names (and hence the parameter tree) stay the same.
        # NOTE(review): Dense(10) feeds straight into Dense(hidden) with no
        # nonlinearity in between, i.e. two stacked linear maps.
        h = nn.Dense(10, kernel_init=xavier)(inputs)
        for _ in range(self.num_layers):
            h = self.activation(nn.Dense(self.hidden_features, kernel_init=xavier)(h))

        return nn.Dense(1, kernel_init=xavier)(h).reshape(())

In [8]:
def get_neighbors(U):
    """Get neighboring SU(2) matrices for each point in the lattice.

    Args:
        U: lattice of SU(2) matrices; expected shape (N, N, 2, 2).

    Returns:
        Tuple ``(U_up, U_down, U_left, U_right)``, each produced by
        ``roll_lattice`` with the shift/invert pair listed below.
        NOTE(review): which physical neighbour each roll yields depends on
        roll_lattice's sign convention; confirm before relying on the names.
    """
    # (shift, invert) in the order up, down, left, right.
    rolls = (
        ((1, 0), False),
        ((1, 0), True),
        ((0, 1), False),
        ((0, 1), True),
    )
    return tuple(roll_lattice(U, shift, invert=inverted) for shift, inverted in rolls)

def pairwise_products(cp1, cp1_up, cp1_down, cp1_left, cp1_right):
    """Performs pairwise products at a single lattice site.

    Computes ``vdot(cp1, other)`` (the first argument is conjugated) for
    ``other`` in (cp1, cp1_up, cp1_down, cp1_left, cp1_right). The first
    entry is the self-overlap, which is 1 for normalised vectors.

    Returns:
        Real array of length 10: ``[Re d0, Im d0, Re d1, Im d1, ..., Re d4, Im d4]``.
    """
    partners = (cp1, cp1_up, cp1_down, cp1_left, cp1_right)
    overlaps = [jnp.vdot(cp1, partner) for partner in partners]
    return jnp.array([part for d in overlaps for part in (d.real, d.imag)])

class InvPotential(nn.Module):
    """Site-local scalar potential built from CP^1 overlaps with four neighbours.

    Each SU(2) input is mapped to a CP^1 vector with ``su2_to_cp1``. The 10
    real overlap features from ``pairwise_products`` are fed, together with
    ``t``, to an ``MLP_local``.

    Attributes:
        N: lattice extent (stored but not used by this module).
        hidden_features: hidden width of the MLP head.
    """
    N: int
    hidden_features: int = 64

    def setup(self):
        # The attribute name `MLP` fixes the parameter-tree key; keep it.
        self.MLP = MLP_local(hidden_features=self.hidden_features)

    def __call__(self, t, U_up, U_down, U_left, U_right, U):
        """Evaluate the potential at one site given its four neighbour matrices.

        Args:
            t: time value passed to the MLP.
            U_up, U_down, U_left, U_right: neighbouring 2x2 SU(2) matrices.
            U: the 2x2 SU(2) matrix at the site itself.

        Returns:
            Scalar potential value.
        """
        site_first = (U, U_up, U_down, U_left, U_right)
        cp1_vectors = [su2_to_cp1(matrix) for matrix in site_first]
        features = pairwise_products(*cp1_vectors)
        return self.MLP(t, features)
    
class Equiv_VF(nn.Module):
    """Per-site Lie-algebra gradient of ``InvPotential`` over a whole lattice.

    Attributes:
        N: lattice extent, forwarded to ``InvPotential``.
        hidden_features: hidden width of the potential's MLP.
    """
    N: int
    hidden_features: int = 128

    def setup(self):
        self.potential = InvPotential(N=self.N, hidden_features=self.hidden_features)

    def vector_field_ij(self, params, t, U_up, U_down, U_left, U_right, U):
        """Lie-algebra gradient of the potential w.r.t. the site matrix ``U``.

        ``potential.apply`` is called as (params, t, U_up, U_down, U_left,
        U_right, U), so ``U`` is positional argument 6. The previous
        ``argnum=2`` differentiated w.r.t. ``U_up`` instead. ``CNF.forward``
        differentiates w.r.t. the site matrix ``U``.

        NOTE(review): assumes liegrad.grad counts ``argnum`` 0-based, like
        ``jax.grad``'s ``argnums``; confirm against coset_flow.glutils.
        """
        grad_potential = liegrad.grad(self.potential.apply, argnum=6, algebra=SU2_GEN)
        return grad_potential(params, t, U_up, U_down, U_left, U_right, U)

    def __call__(self, params, t, U):
        """Evaluate the vector field at every lattice site.

        Args:
            params: parameters of ``InvPotential``.
            t: time value.
            U: lattice of SU(2) matrices; presumably (N, N, 2, 2), matching
                ``get_neighbors``. TODO confirm.
        """
        U_up, U_down, U_left, U_right = get_neighbors(U)

        # Map over the two lattice axes; params and t are shared.
        site_axes = (None, None, 0, 0, 0, 0, 0)
        vector_field = jax.vmap(jax.vmap(self.vector_field_ij, in_axes=site_axes),
                                in_axes=site_axes)

        return vector_field(params, t, U_up, U_down, U_left, U_right, U)

In [9]:
class CNF(nn.Module):
    """Continuous normalizing flow driven by the gradient of an invariant potential.

    Integrates ``dU/dt = grad V(U)`` together with ``d log q/dt = -div`` using
    ``int_step`` explicit Euler steps on [t0, t1]. The potential's parameters
    are not stored in this module's variables. They are passed explicitly to
    ``forward`` and used via ``self.potential.apply``.

    Attributes:
        N: lattice extent, forwarded to ``InvPotential``.
        t0, t1: integration interval.
        hidden_features: hidden width of the potential's MLP.
        steps: currently unused (integration uses ``int_step``).
    """
    N: int
    t0: float
    t1: float
    hidden_features: int = 64
    steps: int = 100
    int_step = 10  # number of Euler steps (plain class attribute, not a field)

    def setup(self):
        self.potential = InvPotential(N=self.N, hidden_features=self.hidden_features)

    def forward(self, U, p_params):
        """Push a batch of lattices through the flow.

        Args:
            U: batch of SU(2) lattices; the triple vmap maps over the first
                three axes, presumably (batch, N, N, 2, 2). TODO confirm.
            p_params: parameters of ``InvPotential``.

        Returns:
            ``(x, logp)``:
            - ``x`` is the transformed configurations (previously the untouched
              input ``U`` was returned, so the flow had no effect on samples).
            - ``logp`` is the accumulated ``-int div dt`` per batch element.
        """
        dt = (self.t1 - self.t0) / self.int_step

        def site_grad_div(t, U_up, U_down, U_left, U_right, U_site):
            # The potential at one site, as a function of that site's matrix only.
            potential_at_site = partial(self.potential.apply, p_params, t,
                                        U_up, U_down, U_left, U_right)
            _, grad, div = liegrad.value_grad_divergence(potential_at_site, U_site, SU2_GEN)
            return grad, div

        def val_grad_div(t, U):
            # Neighbours must be taken on each full (N, N, 2, 2) lattice. The
            # previous version called get_neighbors on a single 2x2 matrix
            # inside the per-site vmap, which only permuted that matrix.
            neighbors = jax.vmap(get_neighbors)(U)

            per_site = partial(site_grad_div, t)
            for _ in range(3):  # batch, x, y
                per_site = jax.vmap(per_site)
            grad, div = per_site(*neighbors, U)

            return grad, -div.sum(axis=(1, 2))

        x = U
        logp = 0.
        for step in range(self.int_step):
            t = self.t0 + (self.t1 - self.t0) * step / self.int_step
            # Explicit Euler for the joint (x, log q) ODE: both the velocity and
            # the divergence are evaluated at the current state. The position
            # step is scaled by dt, consistent with the log-density update.
            # NOTE(review): the additive step x + dt*grad leaves SU(2). A group
            # integrator (exp-map / Crouch-Grossmann; `crouch_grossmann` is
            # already imported) is likely intended. Confirm its API.
            grad, neg_div = val_grad_div(t, x)
            x = x + dt * grad
            logp += dt * neg_div

        return x, logp

    def __call__(self, U, p_params):
        return self.forward(U, p_params)

The $\mathbb{CP}^1$ Action 
$$S = \frac{1}{g}\sum_{n,\mu} (1 - \left| \bar{z}_{n+ae_{\mu}} \cdot z_n\right|^2)$$
where $a$ is the lattice spacing and $e_\mu$ is the unit vector in the $\mu$ direction.

In [10]:
def compute_action(cp1_lattice, g):
    """Lattice CP^1 action S = (1/g) * sum_{n, mu} (1 - |conj(z_{n+mu}) . z_n|^2).

    Every axis except the last is treated as a periodic lattice direction, so
    1D, 2D (the original case) and higher-dimensional lattices all return a
    scalar. The old version only rolled and summed over axes (0, 1).

    Args:
        cp1_lattice: complex array of shape (*lattice_shape, 2) holding one
            CP^1 vector per site; the last axis is the component axis.
        g: coupling constant.

    Returns:
        Scalar action of the configuration.
    """
    action = 0.0

    for axis in range(cp1_lattice.ndim - 1):
        # Periodic shift by one site along this direction. Summing over all
        # sites counts every link once, whichever shift sign is used.
        shifted_lattice = jnp.roll(cp1_lattice, shift=1, axis=axis)
        z_dot = jnp.sum(jnp.conj(cp1_lattice) * shifted_lattice, axis=-1)
        norm_z_dot = jnp.abs(z_dot) ** 2
        action += jnp.sum(1 - norm_z_dot)

    return action / g

@chex.dataclass
class CP1Theory:
    """CP^1 model theory: lattice geometry, coupling, and action evaluation."""
    shape: tuple[int, ...]  # Lattice shape (e.g., (N, N) for 2D lattice)
    g: chex.Scalar          # Coupling constant

    @property
    def lattice_size(self):
        """Total number of sites (product of ``shape``), as a JAX scalar."""
        return jnp.prod(jnp.asarray(self.shape))

    @property
    def dim(self):
        """Number of lattice dimensions."""
        return len(self.shape)

    def action(self, cp1_lattice: jnp.ndarray, *, g: chex.Scalar = None) -> jnp.ndarray:
        """Evaluate the action of one configuration or a batch of them.

        Args:
            cp1_lattice: either a single configuration of shape ``shape + (2,)``
                or a batch of shape ``(batch_size,) + shape + (2,)``.
            g: optional coupling overriding ``self.g``.

        Returns:
            Scalar action for a single configuration, or one action per batch
            element.
        """
        coupling = g if g is not None else self.g
        is_single = cp1_lattice.ndim == self.dim + 1
        config_shape = self.shape + (2,)

        if is_single:
            chex.assert_shape(cp1_lattice, config_shape)
            return compute_action(cp1_lattice, coupling)

        # Batched input: validate one element, then vectorise over axis 0.
        chex.assert_shape(cp1_lattice[0], config_shape)
        return jax.vmap(partial(compute_action, g=coupling))(cp1_lattice)
    

# NOTE(review): unused earlier draft of a vectorised `compute_action`, kept only
# as a string literal, so it is never executed. As written, its einsum
# 'dnij,nij->dnij' multiplies componentwise without summing over the CP^1
# component axis. |z_dot|^2 would therefore not be the squared overlap used by
# the action. Fix this before reviving the draft, or delete it.
"def compute_action(cp1_lattice, g):\n    \n    shifts = jnp.array([[1, 0], [0, 1]])\n\n    shifted_lattices = jax.vmap(lambda shift: jnp.roll(cp1_lattice, shift=shift, axis=(0, 1)))(shifts)\n    z_dot = jnp.einsum('dnij,nij->dnij', shifted_lattices, jnp.conj(cp1_lattice))\n    norm_z_dot = jnp.abs(z_dot) ** 2\n    \n    action = jnp.sum(1 - norm_z_dot)\n    \n    return action / g"

## Plotting functions, loss functions, and ESS

In [11]:
@jax.jit
def reverse_dkl(logp: jnp.ndarray, logq: jnp.ndarray) -> jnp.ndarray:
    """Monte-Carlo estimate of the reverse KL divergence, mean(log q - log p).

    With unnormalised ``logp`` the result is the KL up to an additive constant.
    """
    log_ratio = logq - logp
    return log_ratio.mean()

@jax.jit
def effective_sample_size(logp: jnp.ndarray, logq: jnp.ndarray) -> jnp.ndarray:
    """Normalised ESS, (sum w)^2 / (n * sum w^2), with w = p/q, computed in log space.

    Returns a value in (0, 1]; it equals 1 when all importance weights are equal.
    """
    log_weights = logp - logq
    log_sum_w = jax.nn.logsumexp(log_weights, axis=0)
    log_sum_w_sq = jax.nn.logsumexp(2 * log_weights, axis=0)
    n_samples = log_weights.shape[0]
    return jnp.exp(2 * log_sum_w - log_sum_w_sq) / n_samples

def moving_average(x: jnp.ndarray, window: int = 10):
    """Trailing-window mean of ``x`` ('valid' convolution).

    If ``x`` is shorter than ``window``, returns the overall mean as a
    length-1 array instead.
    """
    if len(x) >= window:
        return jnp.convolve(x, jnp.ones(window), 'valid') / window
    return jnp.mean(x, keepdims=True)

In [12]:
def init_live_plot(figsize=(8, 4), logit_scale=True, **kwargs):
    """Create a live-updating figure: ESS on the left axis, KL loss on the right.

    Args:
        figsize: figure size passed to ``plt.subplots``.
        logit_scale: if True, use a logit y-scale for ESS; otherwise a linear
            scale fixed to [0, 1].
        **kwargs: forwarded to ``plt.subplots``.

    Returns:
        dict with keys ``fig``, ``ax_ess``, ``ax_loss``, ``ess_line``,
        ``loss_line``, ``logit`` and ``display_id`` (the handle whose
        ``update`` redraws the figure), as consumed by ``update_plots``.
    """
    fig, ax_ess = plt.subplots(1, 1, figsize=figsize, **kwargs)

    ess_line = ax_ess.plot([0], [0.5], color='C0', label='ESS')
    ax_ess.grid(False)
    ax_ess.set_ylabel('ESS')
    # `twinx` hides the twin's own x-axis, so the shared x label must be set
    # on the primary axis. Previously it was set on the twin and never shown.
    ax_ess.set_xlabel('Steps')
    if logit_scale:
        ax_ess.set_yscale('logit')
    else:
        ax_ess.set_ylim(0, 1)

    ax_loss = ax_ess.twinx()
    loss_line = ax_loss.plot([0], [1], color='C1', label='KL Loss')
    ax_loss.grid(False)
    ax_loss.set_ylabel('Loss')

    # A single combined legend for both axes. An earlier per-axis legend call
    # was immediately replaced by this one, so it has been dropped.
    lines = ess_line + loss_line
    ax_loss.legend(lines, [line.get_label() for line in lines], loc='upper center', ncol=2)

    setup = dict(
        fig=fig, ax_ess=ax_ess, ax_loss=ax_loss,
        ess_line=ess_line, loss_line=loss_line, logit=logit_scale)

    # `display` is the IPython/Jupyter builtin; display_id=True returns a
    # handle that can redraw this output in place.
    setup['display_id'] = display(fig, display_id=True)

    return setup


def update_plots(history, setup, window_size=15):
    """Refresh the live ESS / loss figure from the training history.

    Both series are smoothed with ``moving_average``. The x positions for
    both lines come from the length of the smoothed ESS series. The ESS axis
    is rescaled only in logit mode; the loss axis is rescaled whenever it has
    more than one point.

    Args:
        history: dict with lists under 'ess' and 'loss'.
        setup: dict returned by ``init_live_plot``.
        window_size: moving-average window.
    """
    ess_artist = setup['ess_line'][0]
    loss_artist = setup['loss_line'][0]
    ax_loss = setup['ax_loss']
    ax_ess = setup['ax_ess']
    fig = setup['fig']

    smoothed_ess = moving_average(np.array(history['ess']), window=window_size)
    x_steps = np.arange(len(smoothed_ess))
    ess_artist.set_ydata(smoothed_ess)
    ess_artist.set_xdata(x_steps)
    if setup['logit'] and len(smoothed_ess) > 1:
        ax_ess.relim()
        ax_ess.autoscale_view()

    smoothed_loss = moving_average(np.array(history['loss']), window=window_size)
    loss_artist.set_ydata(smoothed_loss)
    loss_artist.set_xdata(x_steps)
    if len(smoothed_loss) > 1:
        ax_loss.relim()
        ax_loss.autoscale_view()

    setup['display_id'].update(fig)

## Implementation: training setup

In [13]:
# Lattice geometry and batch size used by the rest of the notebook.
lattice_shape = (6, 6)
batch_size = 128

# Split the key before drawing anything, so the Haar sample and the carried
# `rng` are independent. The old cell sampled with `rng` and then split that
# same key, i.e. it reused a consumed key.
rng_seq = PRNGSequence(rng)
U_init = sample_haar_lattice(next(rng_seq), 1, lattice_shape)[0]
# Example configuration and its neighbours; used to initialise the potential.
U_1, U_2, U_3, U_4 = get_neighbors(U_init)
rng = next(rng_seq)

In [14]:
# Model / theory hyperparameters.
N = lattice_shape[0]
g = 2.5
t0 = 0
t1 = 1
hidden_features = 128

# Initialise the potential with its own key; `rng` stays fresh for training.
# Previously `rng` was consumed by `init` and then split again (key reuse).
rng, init_rng = jax.random.split(rng)
potential = InvPotential(N=N, hidden_features=hidden_features)
p_params = potential.init(init_rng, t0, U_1, U_2, U_3, U_4, U_init)
# Optional: cast parameters to float64 (only effective with jax_enable_x64).
# p_params = jax.tree.map(lambda x: x.astype(jnp.float64), p_params)

model = CNF(N=N, t0=t0, t1=t1, hidden_features=hidden_features)
# Passed to model.apply; the potential's parameters are supplied separately as p_params.
model_params = {}

theory = CP1Theory(shape=lattice_shape, g=g)

def _loss_fn(p_params, rng):
    """Reverse-KL training loss for the flow.

    Steps:
    1. Draw a batch of Haar-random SU(2) lattices.
    2. Push them through the CNF.
    3. Map the result to CP^1 vectors, discarding the U(1) phases.
    4. Compare the flow log-density with the unnormalised target log-density -S.

    Assumes the Haar base density only contributes a constant to ``logq``.

    Args:
        p_params: parameters of the invariant potential driving the flow.
        rng: JAX PRNG key for the Haar samples.

    Returns:
        ``(loss, (logq, logp, cp1_lattice))``, the layout expected by
        ``jax.value_and_grad(..., has_aux=True)``.
    """
    haar_batch = sample_haar_lattice(rng, batch_size, lattice_shape)
    flowed, logq = model.apply(model_params, haar_batch, p_params)
    cp1_lattice = su2_to_cp1_lattice(flowed)[0]
    logp = -theory.action(cp1_lattice)
    loss = reverse_dkl(logp, logq)
    return loss, (logq, logp, cp1_lattice)


# Gradient of the loss w.r.t. p_params (argument 0); the auxiliary outputs are passed through.
value_and_grad = jax.value_and_grad(_loss_fn, argnums=0, has_aux=True)

In [15]:
# Adam (b1=0.8, b2=0.9) with a learning rate that starts at 5e-5 and decays
# exponentially by a factor of 10 every 8000 steps.
lr = optax.exponential_decay(init_value=0.00005, transition_steps=8000, decay_rate=1e-1)
opt = optax.adam(learning_rate=lr, b1=.8, b2=.9)
opt_state = opt.init(p_params)

@jax.jit
def _update_step(rng, p_params, opt_state):
    """One jitted optimisation step.

    Returns:
        ``(new_params, new_opt_state, loss, ess, cp1)``, where ``cp1`` is the
        batch of CP^1 lattices produced by the loss evaluation.
    """
    (loss, aux), grads = value_and_grad(p_params, rng)
    logq, logp, cp1 = aux
    updates, new_opt_state = opt.update(grads, opt_state)
    new_params = optax.apply_updates(p_params, updates)
    ess = effective_sample_size(logp, logq)
    return new_params, new_opt_state, loss, ess, cp1

def update_step(rng, p_params, opt_state, history):
    """Run one training step and append its loss and ESS to ``history``.

    Args:
        history: dict with lists under 'loss' and 'ess'; mutated in place.

    Returns:
        ``(new_params, new_opt_state, ess, loss)``.
    """
    new_params, new_opt_state, loss, ess, _cp1 = _update_step(rng, p_params, opt_state)
    for key, value in (('loss', loss), ('ess', ess)):
        history[key].append(value)
    return new_params, new_opt_state, ess, loss


In [16]:
# Training length: `epochs` rounds of `epoch_size` optimisation steps each;
# the live plot is refreshed once per round.
epochs = 50
epoch_size = 10

# Per-step training history: appended to by update_step, read by update_plots.
history = {key: [] for key in ('loss', 'ess')}

In [17]:
# Training loop: `epochs` rounds of `epoch_size` optimisation steps each; the
# live plot is refreshed after every round.
plot_config = init_live_plot()
counter = 0
for era in range(epochs):
    for epoch in range(epoch_size):
        counter += 1
        # Split before use: `step_rng` is consumed by this step only and `rng`
        # is carried forward untouched. The old loop sampled with `rng` and
        # then split that same key, i.e. it reused a consumed key.
        rng, step_rng = jax.random.split(rng)
        p_params, opt_state, ess, loss = update_step(step_rng, p_params, opt_state, history)
        print(f"Epoch: {counter}, ESS: {ess}, KL: {loss}")
    update_plots(history, plot_config)
plt.close(plot_config['fig'])