# `l2hmc`: Example

This notebook attempts to walk through the steps needed to:

1. Initialize an `Experiment` from a specified `ExperimentConfig`
2. Successfully "run" an `Experiment`:
    1. Train the sampler via `Experiment.train()`
    2. Evaluate the trained sampler via `Experiment.evaluate(job_type='eval')`
    3. (Optionally) Run generic HMC to compare against, via `Experiment.evaluate(job_type='hmc')`.

## Imports

In [1]:
from __future__ import absolute_import, print_function, annotations, division

%load_ext autoreload
%autoreload 2

!unset TF_XLA_FLAGS
!unset KMP_AFFINITY KMP_SETTINGS

import os
import hydra

import tensorflow as tf
tf.keras.backend.set_floatx('float64')

import torch
torch.set_default_dtype(torch.float64)

import numpy as np
    

from l2hmc.main import setup, setup_tensorflow, setup_torch

import horovod.tensorflow as hvdtf
hvdtf.init()
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    gpu = gpus[hvd_tf.local_rank()]
    tf.config.experimental.set_visible_devices(gpu, 'GPU')

import horovod.torch as hvdpt
hvdpt.init()

os.environ['OMP_NUM_THREADS'] = '8'

from hydra.core.global_hydra import GlobalHydra
from l2hmc.utils.rich import print_config

In [2]:
os.environ['WIDTH'] = '220'

In [3]:
import rich
console = rich.get_console()
console.width = 220
console._width = 220

# Set some reasonable defaults for `ExperimentConfig`:

**Note**: For the purposes of demonstrating functionality, we only consider a very simple `debug` example here

In [4]:
GlobalHydra.instance().clear()

defaults = [
    'mode=debug',
    'steps.nera=5',
    'steps.nepoch=100',
    'steps.test=500',
    'steps.print=5',
    'steps.log=5',
]

defaults_u1 = [
    *defaults,
    'dynamics.nchains=16',
    'dynamics.nleapfrog=4',
    'dynamics.latvolume=[8, 8]',
]

defaults_su3 = [
    *defaults,
    'dynamics=su3',
    'dynamics.nchains=5',
    'dynamics.latvolume=[8, 8, 8, 16]',
    'annealing_schedule.beta_init=1.0',
    'annealing_schedule.beta_final=1.0',
]

In [5]:
outputs = {
    'pytorch': {
        'train': {},
        'eval': {},
        'hmc': {},
    },
    'tensorflow': {
        'train': {},
        'eval': {},
        'hmc': {},
    },
}

# Initialize and Build `Experiment` objects:

- The `l2hmc.configs` module provides a function `get_experiment`:

```python
def get_experiment(overrides: list[str]) -> Experiment:
    ...
```

which will:

1. Load the default options from `conf/config.yaml`
2. Override the default options with any values provided in `overrides`
3. Parse these options and build an `ExperimentConfig`, which uniquely defines an experiment
4. Instantiate and return an `Experiment` from the `ExperimentConfig`
    
**Note:** Before training can begin, the `Experiment` must be built by calling `Experiment.build()`:

```python
>>> experiment = get_experiment(overrides=['mode=debug'])
>>> _ = experiment.build(init_wandb=(RANK == 0), init_aim=(RANK == 0))
```

After which, the model can be trained and evaluated via:

```python
>>> train_output = experiment.train()
>>> eval_output = experiment.evaluate(job_type='eval')
>>> hmc_output = experiment.evaluate(job_type='hmc')
```

## PyTorch

We build models for both:

1. 2D $U(1)$ model
2. 4D $SU(3)$ model

In [7]:
from l2hmc.configs import get_experiment

ptExpU1 = get_experiment(
    overrides=[
        *defaults_u1,
        'framework=pytorch',
    ]
)

ptExpU1_ = get_experiment(
    overrides=[
        *defaults_u1,
        'framework=pytorch',
        'net_weights.x.s=0.0',
        'net_weights.x.t=0.0',
        'net_weights.x.q=0.0',
        'net_weights.v.s=0.0',
        'net_weights.v.t=0.0',
        'net_weights.v.q=0.0',
    ]
)

ptExpSU3 = get_experiment(
    overrides=[
        *defaults_su3,
        'framework=pytorch',
    ]
)

ptObjsU1 = ptExpU1.build(init_wandb=True, init_aim=True)
ptObjsU1_ = ptExpU1_.build(init_wandb=False, init_aim=True)
ptObjsSU3 = ptExpSU3.build(init_wandb=False, init_aim=False)

[34m[1mwandb[0m: Currently logged in as: [33msaforem2[0m ([33ml2hmc-qcd[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
from l2hmc.utils.rich import get_console
console = get_console(width=210)

In [None]:
%matplotlib inline

In [None]:
#ptExpU1.trainer.dynamics.init_weights('zeros')
#ptExpU1.trainer.reset_optimizer()

In [None]:
ptExpU1.trainer.dynamics.train()

In [None]:
outputs['pytorch']['train'] = ptExpU1.train()

In [None]:
output_ptU1 = ptExpU1_.train()

In [None]:
outputs['pytorch']['eval'] = ptExpU1.evaluate('eval')

In [None]:
from l2hmc.common import plot_dataset

plot_dataset(ptExpU1.trainer.histories['train'].get_dataset())

In [None]:
outputs['pytorch']['eval'] = ptExpU1.evaluate(job_type='eval')

In [None]:
outputs['pytorch']['hmc'] = ptExpU1.evaluate(job_type='hmc')

In [None]:
from l2hmc.utils.plot_helpers import plot_dataset

tdsetpt = ptExpU1.trainer.histories['train'].get_dataset()
_ = plot_dataset(tdsetpt)

In [None]:
import matplotlib.pyplot as plt

from l2hmc.common import plot_dataset

plt.rcParams['axes.labelcolor'] = '#FFFFFF'

plot_dataset(tdsetpt, nchains=4)

In [None]:
edsetpt = ptExpU1.trainer.histories['eval'].get_dataset()

In [None]:
plot_dataset(edsetpt, nchains=36)

## TensorFlow: 2D $U(1)$ and 4D $SU(3)$ models:

In [None]:
tfExpU1 = get_experiment(
    overrides=[
        *defaults_u1,
        'framework=tensorflow',
    ]
)

tfExpSU3 = get_experiment(
    overrides=[
        *defaults_su3,
        'framework=tensorflow',
    ]
)

tfObjsU1 = tfExpU1.build(init_wandb=False, init_aim=True)
tfObjsSU3 = tfExpSU3.build(init_wandb=False, init_aim=False)

In [None]:
outputs['tensorflow']['train'] = tfExpU1.train()

In [None]:
outputs['tensorflow']['eval'] = tfExpU1.evaluate(job_type='eval')

In [None]:
tdsettf = tfExpU1.trainer.histories['train'].get_dataset()
edsettf = tfExpU1.trainer.histories['eval'].get_dataset()

In [None]:
%matplotlib notebook

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
plt.rcParams['axes.labelcolor'] = '#bdbdbd'
sns.set_context('notebook', font_scale=0.8)
plot_dataset(tdsettf)

In [None]:
plt.rcParams['figure.dpi'] = 100
sns.set_context('notebook', font_scale=0.8)
plt.rcParams['axes.labelcolor'] = '#666666'
plot_dataset(edsettf, therm_frac=0.2)

In [None]:
import l2hmc.utils.plot_helpers as hplt

In [None]:
hplt.plot_dataset(tdsettf, therm_frac=0.2)

In [None]:
hplt.plot_dataset(edsettf, therm_frac=0.2)

## Debugging / Testing

In [None]:
from l2hmc.experiment.pytorch.experiment import Experiment as ptExperiment
from l2hmc.experiment.tensorflow.experiment import Experiment as tfExperiment

In [None]:
def get_lattice_metrics(experiment: ptExperiment | tfExperiment, beta: float = 1.) -> dict:
    """Compute lattice metrics for a freshly drawn random state.

    The experiment is first sanity-checked against its configured framework
    ('pytorch' -> ptExperiment, 'tensorflow' -> tfExperiment; any other value
    skips the check). A random state is then drawn from the trainer's dynamics
    at inverse coupling ``beta``, and ``experiment.lattice.calc_metrics`` is
    evaluated on its ``x`` field.

    Returns:
        Whatever ``experiment.lattice.calc_metrics`` returns (used as a dict).
    """
    expected_cls = {
        'pytorch': ptExperiment,
        'tensorflow': tfExperiment,
    }.get(experiment.cfg.framework)
    if expected_cls is not None:
        assert isinstance(experiment, expected_cls)

    x = experiment.trainer.dynamics.random_state(beta).x
    return experiment.lattice.calc_metrics(x)

In [None]:
from typing import Optional
import rich

from l2hmc.configs import State

console = rich.console.Console(log_path=False, force_jupyter=True)

def _as_numpy(value):
    """Return ``value`` as a NumPy array.

    A ``torch.Tensor`` is detached and moved to the CPU first (so GPU tensors
    work too). A ``tf.Tensor`` is converted with ``.numpy()``. Anything else
    goes through ``np.asarray``, which also turns Python scalars into 0-d
    arrays that support ``.sum()``, ``.std()``, etc.
    """
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    if isinstance(value, tf.Tensor):
        return value.numpy()
    return np.asarray(value)


def check_diff(x, y, name: Optional[str] = None):
    """Log summary statistics of the elementwise difference ``x - y``.

    Used to compare the PyTorch and TensorFlow implementations of the same
    quantity. Supported inputs:

    * ``State`` objects: compared field by field (``x``, ``v``, ``beta``).
    * dicts: values compared recursively, paired by key. A key present on
      only one side is logged and skipped.
    * anything array-like (``torch.Tensor``, ``tf.Tensor``, ``np.ndarray``,
      Python scalars): converted to NumPy, then sum/min/max/mean/std of the
      difference and ``np.allclose`` are logged via ``console``.

    Args:
        x: First value.
        y: Second value, expected to have the same structure as ``x``.
        name: Optional label printed above the statistics.
    """
    if isinstance(x, State):
        xd = {'x': x.x, 'v': x.v, 'beta': x.beta}
        yd = {'x': y.x, 'v': y.v, 'beta': y.beta}
        check_diff(xd, yd, name='State')
        # Without this return the State objects fall through to the array
        # branch below, where `x - y` raises a TypeError.
        return

    if isinstance(x, dict) and isinstance(y, dict):
        # Pair values by key, not by position: the two frameworks are not
        # guaranteed to insert entries in the same order.
        for key in x:
            if key in y:
                check_diff(x[key], y[key], name=key)
            else:
                console.log(f"'{key}': missing from second argument, skipped")
        for key in y:
            if key not in x:
                console.log(f"'{key}': missing from first argument, skipped")
        return

    xa = _as_numpy(x)
    ya = _as_numpy(y)
    diff = xa - ya

    dstr = []
    if name is not None:
        dstr.append(f"'{name}'")
    dstr.append(f'  sum(diff): {diff.sum()}')
    dstr.append(f'  min(diff): {diff.min()}')
    dstr.append(f'  max(diff): {diff.max()}')
    dstr.append(f'  mean(diff): {diff.mean()}')
    dstr.append(f'  std(diff): {diff.std()}')
    dstr.append(f'  np.allclose: {np.allclose(xa, ya)}')
    console.log('\n'.join(dstr))

# Check Lattice Methods

In [None]:
state_pt = ptExpU1.trainer.dynamics.random_state(1.0)
xpt = state_pt.x
xtf = tf.constant(xpt.detach().numpy())
check_diff(xpt, xtf, name='x')

In [None]:
x = np.random.randn(*xpt.detach().numpy().shape)

In [None]:
dpt = ptExpU1.trainer.dynamics
dtf = tfExpU1.trainer.dynamics

xcpt = dpt.g.compat_proj(torch.tensor(x))
xctf = dtf.g.compat_proj(tf.constant(x))

In [None]:
check_diff(xcpt, xctf)

In [None]:
xpt.shape

In [None]:
xtf.shape

In [None]:
lmetrics_tf = tfExpU1.trainer.lattice.calc_metrics(xtf)
lmetrics_pt = ptExpU1.trainer.lattice.calc_metrics(xpt)

In [None]:
check_diff(xtf, xpt)

In [None]:
wloopspt = ptExpU1.trainer.lattice.wilson_loops(xpt)
wloopstf = tfExpU1.trainer.lattice.wilson_loops(xtf)
check_diff(wloopspt, wloopstf, name='wloops')

In [None]:
from l2hmc.lattice.u1.tensorflow.lattice import project_angle as proj_tf
from l2hmc.lattice.u1.pytorch.lattice import project_angle as proj_pt
lattice_pt = ptExpU1.trainer.lattice
lattice_tf = tfExpU1.trainer.lattice

In [None]:
wprojpt = proj_pt(wloopspt)
wprojtf = proj_tf(tf.constant(wloopstf))
check_diff(wprojpt.detach().numpy(), wprojtf.numpy(), name='wproj')

In [None]:
qintpt = lattice_pt._int_charges(wloopspt)
qinttf = lattice_tf._int_charges(wloopstf)
check_diff(qintpt.detach().numpy(), qinttf.numpy(), name='qint')

In [None]:
qsinpt = lattice_pt._sin_charges(wloopspt)
qsintf = lattice_tf._sin_charges(wloopstf)
check_diff(qsinpt, qsintf, name='qsin')

In [None]:
wloopspt.shape

In [None]:
wloopstf.shape

In [None]:
# Integer topological charge: sum of the projected plaquette phases divided by
# 2*pi. The parentheses matter: `/ 2 * np.pi` parses as `(sum / 2) * pi`.
qint_pt = proj_pt(wloopspt).sum((1, 2)) / (2 * np.pi)
qint_tf = proj_tf(wloopstf).numpy().sum((1, 2)) / (2 * np.pi)

In [None]:
wproj_pt = proj_pt(wloopspt)
wproj_tf = proj_tf(wloopstf)

check_diff(wproj_pt, wproj_tf, name='wproj')

In [None]:
wproj_pt.shape

In [None]:
wproj_tf.shape

In [None]:
# Divide by 2*pi, not `(sum / 2) * pi` as `/ 2 * np.pi` would compute.
wprojsum_pt = wproj_pt.sum((1, 2)) / (2 * np.pi)
wprojsum_tf = tf.reduce_sum(wproj_tf, (1, 2)) / (2 * np.pi)

# Round to the nearest integer before casting. A bare `dtype=int` cast
# truncates toward zero, so e.g. 0.9999999 would become 0.
wprojsum_pt_ = np.rint(wprojsum_pt.detach().numpy()).astype(int)
wprojsum_tf_ = np.rint(wprojsum_tf.numpy()).astype(int)

check_diff(wprojsum_pt.detach().numpy().round(6), wprojsum_tf.numpy().round(6), name='wprojsum')
check_diff(wprojsum_pt_, wprojsum_tf_, name='wprojsum_')

In [None]:
check_diff(qint_pt, qint_tf, name='qint')

In [None]:
qpt = ptExpU1.trainer.lattice.charges(x=xpt)
qtf = tfExpU1.trainer.lattice.charges(x=xtf)
check_diff(qpt.intQ, qtf.intQ, name='intQ')
check_diff(qpt.sinQ, qtf.sinQ, name='sinQ')

In [None]:
# Only a cell's final bare expression is rendered, so the first three arrays
# used to be computed and silently discarded. Display all four explicitly.
display(
    np.array(lmetrics_pt['intQ'].detach().numpy()[:10]).round(5),
    np.array(lmetrics_tf['intQ'].numpy()[:10]).round(5),
    np.array(lmetrics_pt['sinQ'].detach().numpy()[:10]).round(5),
    np.array(lmetrics_tf['sinQ'].numpy()[:10]).round(5),
)

In [None]:
check_diff(lmetrics_tf, lmetrics_pt)

# Check `Dynamics`

## Check Action and Force term:

- $S(x)$
- $F = \partial_x S(x)$

In [None]:
from l2hmc.configs import State

dynamics_tf = tfExpU1.trainer.dynamics
dynamics_pt = ptExpU1.trainer.dynamics


state_pt = dynamics_pt.random_state(1.0)
state_tf = State(
    x=tf.constant(state_pt.x.detach().numpy()),
    v=tf.constant(state_pt.v.detach().numpy()),
    beta=tf.constant(state_pt.beta.detach().numpy())
)

stf = dynamics_tf.potential_energy(state_tf.x, state_tf.beta)
spt = dynamics_pt.potential_energy(state_pt.x, state_pt.beta)

In [None]:
check_diff(stf, spt, name='action')

In [None]:
TF_FLOAT = tf.keras.backend.floatx()
dstf = dynamics_tf.grad_potential(state_tf.x, tf.cast(state_tf.beta, TF_FLOAT))
dspt = dynamics_pt.grad_potential(state_pt.x, state_pt.beta)

In [None]:
check_diff(dstf, dspt)

In [None]:
state_tf_ = dynamics_tf.random_state(1.)
state_pt_ = State(
    x=torch.from_numpy(state_tf_.x.numpy()),
    v=torch.from_numpy(state_tf_.v.numpy()),
    beta=torch.tensor(state_tf_.beta.numpy())
)
check_diff(state_tf_.x.numpy(), state_pt_.x.detach().numpy())

In [None]:
xnet_tf = dynamics_tf.xnet
vnet_tf = dynamics_tf.vnet

xnet_pt = dynamics_pt.xnet
vnet_pt = dynamics_pt.vnet

In [None]:
from  l2hmc.configs import State

beta = 1.0
dynamics_pt = ptExpU1.trainer.dynamics
dynamics_tf = tfExpU1.trainer.dynamics

lattice_pt = ptExpU1.trainer.loss_fn.lattice
lattice_tf = tfExpU1.trainer.loss_fn.lattice

state_pt = dynamics_pt.random_state(beta)

beta_tf = tf.constant(beta)
beta_pt = torch.tensor(beta)

_x = tf.constant(state_pt.x.detach().numpy())
_v = tf.constant(state_pt.v.detach().numpy())
state_tf = State(_x, _v, beta_tf)

In [None]:
check_diff(state_tf.x, state_pt.x, name='x')
check_diff(state_tf.v, state_pt.v, name='v')

## Check `Dynamics.xNet`

In [None]:
def zero_weights(model):
    """Set every weight of a Keras ``model`` to zero, layer by layer.

    This is the TensorFlow/Keras counterpart of the PyTorch ``zero_weights``
    from ``l2hmc.network.pytorch.network``. It is used so the TF and PyTorch
    networks can be compared after both have been zeroed.

    Args:
        model: A Keras model exposing ``.layers``. Each layer must support
            ``get_weights()`` and ``set_weights()``.

    Returns:
        The same ``model`` object, modified in place.
    """
    for layer in model.layers:
        weights = layer.get_weights()
        if not weights:
            # This layer owns no weights; there is nothing to zero.
            continue
        # Log once per layer. Previously the message was repeated for every
        # weight array (e.g. once for the kernel and again for the bias).
        console.log(f'Zeroing layer for: {layer} in {model}')
        layer.set_weights([np.zeros_like(w) for w in weights])
    return model

def check_weights(mpt, mtf):
    """Compare a PyTorch layer's weight and bias with a Keras layer's.

    Args:
        mpt: PyTorch layer exposing ``.weight`` and ``.bias``.
        mtf: Keras layer whose ``get_weights()`` returns ``[kernel, bias]``.
    """
    wpt = mpt.weight
    bpt = mpt.bias
    wtf, btf = mtf.get_weights()
    # torch.nn.Linear stores its weight as (out, in), while a Keras Dense
    # kernel is (in, out). Decide on the transpose from the shapes rather than
    # waiting for a broadcasting ValueError. For square layers, or when one
    # dimension is 1, the untransposed arrays broadcast without error and the
    # wrong elements would be compared.
    if tuple(wpt.T.shape) == tuple(wtf.shape):
        check_diff(wpt.T, wtf)
    else:
        check_diff(wpt, wtf)

    check_diff(bpt, btf)

In [None]:
xnet0tf = dynamics_tf._get_xnet(0, first=True)
xnet0pt = dynamics_pt._get_xnet(0, first=True)

vnet0tf = dynamics_tf._get_vnet(0)
vnet0pt = dynamics_pt._get_vnet(0)

In [None]:
from l2hmc.network.pytorch.network import zero_weights as zero_weights_pt
xnet0pt.apply(zero_weights_pt)
vnet0pt.apply(zero_weights_pt)

xnet0tf = zero_weights(xnet0tf)
vnet0tf = zero_weights(vnet0tf)

In [None]:
vnetxl_tf = vnet0tf.get_layer('vnet_0_xLayer')
vnetxl_pt = vnet0pt.x_layer

vxw_tf, vxb_tf = vnetxl_tf.get_weights()
vxw_pt, vxb_pt = vnetxl_pt.weight, vnetxl_pt.bias

In [None]:
vxw_tf.shape

In [None]:
vxw_pt.shape

In [None]:
xnetxl_tf = xnet0tf.get_layer('xnet_0_first_xLayer')
xnetxl_pt = xnet0pt.x_layer

xxw_tf, xxb_tf = xnetxl_tf.get_weights()
xxw_pt, xxb_pt = xnetxl_pt.weight, xnetxl_pt.bias

In [None]:
xxw_tf.shape

In [None]:
xxw_pt.shape

In [None]:
# For nn.Linear / Dense layers, PyTorch stores the weight as (out, in) and
# Keras stores the kernel as (in, out), so the PyTorch weights are transposed
# before comparing. Biases are 1-D and need no transpose (`.T` on a 1-D tensor
# is a deprecated no-op in PyTorch).
check_diff(xxw_pt.T, xxw_tf, name='xnet x_layer weight')
check_diff(xxb_pt, xxb_tf, name='xnet x_layer bias')

check_diff(vxw_pt.T, vxw_tf, name='vnet x_layer weight')
check_diff(vxb_pt, vxb_tf, name='vnet x_layer bias')

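For an $m \times n$ matrix $A$ and an $n$-vector $\mathbf{x}$, the product is an $m$-vector:

$$
A \mathbf{x} = \begin{bmatrix}
a_{11} & a_{12} & \cdots & a_{1n} \\
a_{21} & a_{22} & \cdots & a_{2n} \\
\vdots & \vdots & \ddots & \vdots \\
a_{m1} & a_{m2} & \cdots & a_{mn}
\end{bmatrix}
\begin{bmatrix}
x_{1} \\
x_{2} \\
\vdots \\
x_{n}
\end{bmatrix}
= \begin{bmatrix}
\sum_{j} a_{1j} x_{j} \\
\sum_{j} a_{2j} x_{j} \\
\vdots \\
\sum_{j} a_{mj} x_{j}
\end{bmatrix}
$$

The element-wise (broadcast) product with entries $a_{ij} x_{j}$, i.e. NumPy's `A * x`, is instead the matrix $A \, \mathrm{diag}(\mathbf{x})$:

$$
A \, \mathrm{diag}(\mathbf{x}) = \begin{bmatrix}
a_{11} x_{1} & a_{12} x_{2} & \cdots & a_{1n} x_{n} \\
a_{21} x_{1} & a_{22} x_{2} & \cdots & a_{2n} x_{n} \\
\vdots & \vdots & \ddots & \vdots \\
a_{m1} x_{1} & a_{m2} x_{2} & \cdots & a_{mn} x_{n}
\end{bmatrix}
$$

PyTorch's `nn.Linear` stores its weight $W$ with shape `(out_features, in_features)` and computes $W\mathbf{x}$. A Keras `Dense` kernel $K$ has shape `(in_features, out_features)` and computes $\mathbf{x}^{T} K$. Hence $K = W^{T}$, which is why the PyTorch weights are transposed (`.T`) before being compared.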

In [None]:
# Compare both the vnet and the xnet x-layers. The second call previously
# repeated the vnet check, so the xnet layer was never compared.
check_weights(vnetxl_pt, vnetxl_tf)
check_weights(xnetxl_pt, xnetxl_tf)

In [None]:
xnet0tf_ = zero_weights(xnet0tf)

In [None]:
# Import under an alias. A bare `zero_weights` import would shadow the Keras
# helper of the same name defined earlier in this notebook, so re-running the
# cells that zero the TF networks would silently call the PyTorch version.
from l2hmc.network.pytorch.network import zero_weights as zero_weights_pt

dynamics_pt.networks['xnet'].apply(zero_weights_pt)
dynamics_pt.networks['vnet'].apply(zero_weights_pt)

In [None]:
xnet0pt = dynamics_pt._get_xnet(0, first=True)

In [None]:
s0pt, t0pt, q0pt = xnet0pt((dynamics_pt._stack_as_xy(state_pt.x), state_pt.v))

In [None]:
s0pt.shape

In [None]:
s0pt[0].detach().numpy().round(5)

In [None]:
xtf_ = dynamics_tf._stack_as_xy(tf.reshape(state_tf.x, (state_tf.x.shape[0], -1)))
s0tf, t0tf, q0tf = xnet0tf((xtf_, state_tf.v))

In [None]:
s0tf[0].numpy().round(5)

In [None]:
check_diff(s0tf, s0pt)

In [None]:
state1_pt = dynamics_pt.leapfrog_hmc(state_pt, eps=0.1)

In [None]:
state_tf = State(
    x=state_tf.x,
    v=state_tf.v,
    beta=tf.cast(state_tf.beta, tf.keras.backend.floatx()),
)
state1_tf = dynamics_tf.leapfrog_hmc(state_tf, eps=0.1)

In [None]:
check_diff(state1_pt.x.detach().numpy(), state1_tf.x.numpy(), name='test-dx-hmc')

In [None]:
check_diff(state1_pt.v.detach().numpy(), state1_tf.v.numpy(), name='test-dv-hmc')

In [None]:
state_pt = State(
    state_pt.x.flatten(1),
    state_pt.v.flatten(1),
    state_pt.beta
)

def flatten(x):
    """Reshape ``x`` to 2-D, keeping the leading dimension: ``(x.shape[0], -1)``."""
    return tf.reshape(x, (x.shape[0], -1))

state_tf = State(
    flatten(state_tf.x),
    flatten(state_tf.v),
    # tf.cast (as used earlier for beta) also works when beta is already a
    # tensor of a different dtype; tf.constant(tensor, dtype=...) raises then.
    tf.cast(state_tf.beta, tf.keras.backend.floatx()),
)

state2_pt, metrics2_pt = dynamics_pt.transition_kernel_fb(state_pt)
state2_tf, metrics2_tf = dynamics_tf.transition_kernel_fb(state_tf)

# Keep forward and backward metrics under separate names. Previously the
# backward call overwrote `metrics3_pt` / `metrics3_tf` from the forward call.
state3f_pt, metrics3f_pt = dynamics_pt.transition_kernel(state_pt, forward=True)
state3f_tf, metrics3f_tf = dynamics_tf.transition_kernel(state_tf, forward=True)

state3b_pt, metrics3b_pt = dynamics_pt.transition_kernel(state_pt, forward=False)
state3b_tf, metrics3b_tf = dynamics_tf.transition_kernel(state_tf, forward=False)

In [None]:
state1vf_pt, logdet1vf_pt = dynamics_pt._update_v_fwd(0, state_pt)
state1vf_tf, logdet1vf_tf = dynamics_tf._update_v_fwd(0, state_tf)

In [None]:
m, mb = dynamics_pt._get_mask(0)
m_ = tf.constant(m.detach().numpy(), dtype=TF_FLOAT)
mb_ = tf.constant(mb.detach().numpy(), dtype=TF_FLOAT)

state1xf_pt, logdet1xf_pt = dynamics_pt._update_x_fwd(0, state_pt, m=m, first=True)
state1xf_tf, logdet1xf_tf = dynamics_tf._update_x_fwd(0, state_tf, m=m_, first=True)

In [None]:
logdet1xf_pt.shape

In [None]:
logdet1xf_tf

In [None]:
check_diff(state1xf_pt.x, state1xf_tf.x)

In [None]:
from l2hmc.group.u1.pytorch.group import U1Phase as ptU1Phase
from l2hmc.group.u1.tensorflow.group import U1Phase as tfU1Phase

def update_v_fwd_pt(dynamics, step, state) -> tuple[State, torch.Tensor]:
    """Reference PyTorch re-implementation of the forward momentum (``v``) update.

    The update is written out explicitly so it can be compared against the
    TensorFlow version and against the library's ``_update_v_fwd``.

    Args:
        dynamics: PyTorch dynamics object providing ``veps`` (keyed by
            ``str(step)``), ``grad_potential`` and ``_call_vnet``.
        step: Leapfrog step index.
        state: Input ``State``; ``x`` and ``beta`` are passed through unchanged.

    Returns:
        ``(new_state, logdet)``, where ``logdet`` is ``eps * s / 2`` summed
        over dim 1.
    """
    step_size = dynamics.veps[str(step)]
    x, v, beta = state.x, state.v, state.beta

    grad = dynamics.grad_potential(x, beta)
    scale, transl, transf = dynamics._call_vnet(step, (x, grad))

    half_scaled = step_size * scale / 2.
    v_new = (
        torch.exp(half_scaled) * v
        - 0.5 * step_size * (grad * torch.exp(step_size * transf) + transl)
    )
    return State(x, v_new, beta), torch.sum(half_scaled, dim=1)

def update_x_fwd_pt(dynamics, step, state, first=True, m: Optional[torch.Tensor] = None):
    """Reference PyTorch re-implementation of the forward position (``x``) update.

    The update is written out step by step so the intermediates can be compared
    against the TensorFlow version and against ``dynamics._update_x_fwd``.

    Args:
        dynamics: PyTorch dynamics object providing ``xeps`` (keyed by
            ``str(step)``), ``_get_mask``, ``_call_xnet``, ``config.use_ncp``
            and ``g.compat_proj``.
        step: Leapfrog step index.
        state: Input ``State``. ``x`` is assumed to be flattened to
            ``(batch, -1)`` (``logdet`` sums over dim 1) -- TODO confirm.
        first: Forwarded to ``dynamics._call_xnet``.
        m: Optional mask. When ``None``, both the mask and its complement come
            from ``dynamics._get_mask(step)``; otherwise the complement is
            ``1 - m``.

    Returns:
        ``(new_state, logdet)``. ``new_state`` has the updated (and
        ``g.compat_proj``-projected) ``x`` with ``v`` and ``beta`` unchanged.
        ``logdet`` is the masked log-Jacobian summed over dim 1.
    """
    # Use the step size for *this* step. This was previously hard-coded to
    # `xeps['0']`, silently applying the first step's epsilon to every step.
    eps = dynamics.xeps[str(step)]
    if m is None:
        m, mb = dynamics._get_mask(step)
    else:
        mb = torch.ones_like(m) - m

    xm_init = m * state.x
    s, t, q = dynamics._call_xnet(step, (xm_init, state.v), first=first)
    s = eps * s
    q = eps * q
    exp_s = s.exp()
    exp_q = q.exp()
    if dynamics.config.use_ncp:
        # Non-compact projection: rescale tan(x/2) and map back with atan.
        halfx = state.x / 2.
        _x = 2. * (halfx.tan() * exp_s).atan()
        xp = _x + eps * (state.v * exp_q + t)
        xf = xm_init + (mb * xp)
        cterm = halfx.cos() ** 2
        sterm = (exp_s * halfx.sin()) ** 2
        logdet_ = (exp_s / (cterm + sterm)).log()
        logdet = (mb * logdet_).sum(dim=1)
    else:
        xp = state.x * exp_s + eps * (state.v * exp_q + t)
        xf = xm_init + (mb * xp)
        logdet = (mb * s).sum(dim=1)

    xf = dynamics.g.compat_proj(xf)
    return State(x=xf, v=state.v, beta=state.beta), logdet


def update_v_fwd_tf(dynamics, step, state) -> tuple[State, tf.Tensor]:
    """Reference TensorFlow re-implementation of the forward momentum (``v``) update.

    This mirrors the PyTorch reference version. Note that ``veps`` is indexed
    by the integer ``step`` here, and the network is called with
    ``training=False``.

    Args:
        dynamics: TF dynamics object providing ``veps``, ``grad_potential`` and
            ``_call_vnet``.
        step: Leapfrog step index.
        state: Input ``State``; ``x`` and ``beta`` are passed through unchanged.

    Returns:
        ``(new_state, logdet)``, where ``logdet`` is ``eps * s / 2`` summed
        over axis 1.
    """
    step_size = dynamics.veps[step]
    x, v, beta = state.x, state.v, state.beta

    grad = dynamics.grad_potential(x, beta)
    scale, transl, transf = dynamics._call_vnet(step, (x, grad), training=False)

    half_scaled = step_size * scale / 2.
    v_new = (
        tf.exp(half_scaled) * v
        - 0.5 * step_size * (grad * tf.exp(step_size * transf) + transl)
    )
    return State(x, v_new, beta), tf.reduce_sum(half_scaled, axis=1)


def update_x_fwd_tf(dynamics, step, state, first=True, m: Optional[tf.Tensor] = None):
    """Reference TensorFlow re-implementation of the forward position (``x``) update.

    This mirrors the PyTorch reference version step by step so the
    intermediates can be compared. ``xeps`` is indexed by the integer ``step``,
    and the network is called with ``training=False``.

    Args:
        dynamics: TF dynamics object providing ``xeps``, ``_get_mask``,
            ``_call_xnet``, ``config.use_ncp`` and ``g.compat_proj``.
        step: Leapfrog step index.
        state: Input ``State``. ``x`` is assumed to be flattened to
            ``(batch, -1)`` (``logdet`` reduces over axis 1) -- TODO confirm.
        first: Forwarded to ``dynamics._call_xnet``.
        m: Optional mask. When ``None``, both the mask and its complement come
            from ``dynamics._get_mask(step)``; otherwise the complement is
            ``1 - m``.

    Returns:
        ``(new_state, logdet)``. ``new_state`` has the updated (and
        ``g.compat_proj``-projected) ``x`` with ``v`` and ``beta`` unchanged.
        ``logdet`` is the masked log-Jacobian summed over axis 1.
    """
    eps = dynamics.xeps[step]
    if m is None:
        m, mb = dynamics._get_mask(step)
    else:
        mb = tf.ones_like(m) - m

    xm_init = tf.multiply(m, state.x)
    s, t, q = dynamics._call_xnet(step, (xm_init, state.v), first=first, training=False)
    s = eps * s
    q = eps * q
    exp_s = tf.exp(s)
    exp_q = tf.exp(q)
    # Match x's dtype directly instead of relying on the notebook-level
    # TF_FLOAT global, which is defined in a different cell.
    TWO = tf.constant(2.0, dtype=state.x.dtype)
    if dynamics.config.use_ncp:
        # Non-compact projection: rescale tan(x/2) and map back with atan.
        halfx = state.x / TWO
        _x = TWO * tf.math.atan(tf.math.tan(halfx) * exp_s)
        xp = _x + eps * (state.v * exp_q + t)
        xf = xm_init + (mb * xp)
        cterm = tf.math.square(tf.math.cos(halfx))
        sterm = (exp_s * tf.math.sin(halfx)) ** 2
        logdet_ = tf.math.log(exp_s / (cterm + sterm))
        logdet = tf.reduce_sum(mb * logdet_, axis=1)
    else:
        xp = state.x * exp_s + eps * (state.v * exp_q + t)
        xf = xm_init + (mb * xp)
        logdet = tf.reduce_sum((mb * s), axis=1)

    xf = dynamics.g.compat_proj(xf)
    return State(x=xf, v=state.v, beta=state.beta), logdet

In [None]:
m, mb = dynamics_tf._get_mask(0)
m_ = torch.tensor(m.numpy())
mb_ = torch.tensor(mb.numpy())

state1xf_pt_, logdet1xf_pt_ = update_x_fwd_pt(dynamics_pt, step=0, state=state_pt, first=True, m=m_)
state1xf_tf_, logdet1xf_tf_ = update_x_fwd_tf(dynamics_tf, step=0, state=state_tf, first=True, m=m)

In [None]:
state1vf_pt_, logdet1vf_pt_ = update_v_fwd_pt(dynamics_pt, step=0, state=state_pt)
state1vf_tf_, logdet1vf_tf_ = update_v_fwd_tf(dynamics_tf, step=0, state=state_tf)

In [None]:
state1vf_pt, logdet1vf_pt = dynamics_pt._update_v_bwd(0, state=state_pt)
state1vf_tf, logdet1vf_tf = dynamics_tf._update_v_bwd(0, state=state_tf)

In [None]:
state1xf_tf, logdet1xf_tf = dynamics_tf._update_x_fwd(0, state_tf, m=m, first=True)
state1xf_pt, logdet1xf_pt = dynamics_pt._update_x_fwd(0, state_pt, m=m_, first=True)

In [None]:
xm_init_tf = (m * state_tf.x)
xm_init_pt = (m_ * state_pt.x)

sxtf, txtf, qxtf = dynamics_tf._call_xnet(0, (xm_init_tf, state_tf.v), first=True, training=False)
sxpt, txpt, qxpt = dynamics_pt._call_xnet(0, (xm_init_pt, state_pt.v), first=True)

In [None]:
check_diff(
    sxtf, sxpt
)
check_diff(
    txtf, txpt
)
check_diff(
    qxtf, qxpt
)


In [None]:
exps_pt = sxpt.exp()
exps_tf = tf.exp(sxtf)

halfxpt = state_pt.x / 2.
halfxtf = state_tf.x / 2.

_xpt = 2. * (halfxpt.tan() * exps_pt).atan()
_xtf = 2. * tf.math.atan(tf.math.tan(halfxtf) * exps_tf)

In [None]:
check_diff(
    _xpt, _xtf
)

In [None]:
xeps_pt = dynamics_pt.xeps['0']
xeps_tf = dynamics_tf.xeps[0]

expq_pt = qxpt.exp()
expq_tf = tf.exp(qxtf)

xp_pt = _xpt + xeps_pt * (state_pt.v * expq_pt + txpt)
xp_tf = _xtf + xeps_tf * (state_tf.v * expq_tf + txtf)

check_diff(
    xp_pt, xp_tf
)

In [None]:
xf_pt = xm_init_pt + (mb_ * xp_pt)
xf_tf = xm_init_tf + (mb * xp_tf)

check_diff(
    xf_pt, xf_tf
)

In [None]:
cterm_tf = tf.math.square(tf.math.cos(halfxtf))
cterm_pt = halfxpt.cos() ** 2
sterm_tf = tf.math.square(exps_tf * tf.math.sin(halfxtf))
sterm_pt = (exps_pt * halfxpt.sin()) ** 2

check_diff(
    cterm_tf, cterm_pt
)
check_diff(
    sterm_tf, sterm_pt
)

In [None]:
logdetpt_ = (exps_pt / (cterm_pt + sterm_pt)).log()
logdettf_ = tf.math.log(exps_tf / (cterm_tf + sterm_tf))

check_diff(
    logdettf_, logdetpt_
)

In [None]:
logdet_pt = (mb_ * logdetpt_).sum(1)
logdet_tf = tf.reduce_sum((mb * logdettf_), axis=1)

check_diff(
    logdet_tf, logdet_pt
)

In [None]:
check_diff(
    tf.exp(sxtf), sxpt.exp()
)

In [None]:
check_diff(
    state1xf_pt_.v, state1xf_tf_.v
)

In [None]:
check_diff(
    m, m_
)

In [None]:
check_diff(
    logdet1xf_tf, logdet1xf_pt
)

In [None]:
check_diff(
    state1xf_tf.x, state1xf_pt.x
)

In [None]:
check_diff(
    state1vf_pt.v, state1vf_tf.v
)

In [None]:
check_diff(
    state1vf_pt_.v, state1vf_tf_.v
)

In [None]:
check_diff(
    dynamics_pt.veps['0'].detach().numpy(), dynamics_tf.veps[0].numpy()
)

In [None]:
force_pt = dynamics_pt.grad_potential(state_pt.x, state_pt.beta)
force_tf = dynamics_tf.grad_potential(state_tf.x, state_tf.beta)

check_diff(
    force_pt,
    force_tf
)

In [None]:
vnet0pt = dynamics_pt._get_vnet(0)
vnet0tf = dynamics_tf._get_vnet(0)

In [None]:
vnet0pt_xw = vnet0pt.x_layer.weight
vnet0pt_xb = vnet0pt.x_layer.bias
vnet0tf_xw, vnet0tf_xb = vnet0tf.get_layer('vnet_0_xLayer').get_weights()

In [None]:
vnet0pt_xw

In [None]:
vnet0tf_xw

In [None]:
vnet0pt_xw.shape
vnet0tf_xw.shape

In [None]:
check_diff(
    vnet0pt_xw.T, vnet0tf_xw
)

In [None]:
spt, tpt, qpt = dynamics_pt._call_vnet(0, (state_pt.x, force_pt))
stf, ttf, qtf = dynamics_tf._call_vnet(0, (state_tf.x, force_tf), training=False)

In [None]:
check_diff(
    spt, stf
)
check_diff(
    tpt, ttf
)
check_diff(
    qpt, qtf
)

In [None]:
check_diff(
   state1vf_pt_.v, state1vf_tf_.v
)

In [None]:
s1xf_pt = state1xf_pt.to_numpy()
s1xf_tf = state1xf_tf.to_numpy()

In [None]:
check_diff(
    m, m_
)

check_diff(
    mb, mb_
)

In [None]:
check_diff(
    logdet1xf_pt, logdet1xf_tf
)

In [None]:
check_diff(
    s1xf_pt, s1xf_tf
)

In [None]:
from l2hmc.utils.tests import check_diff as cd

s1vf_pt = state1vf_pt.to_numpy()
s1vf_tf = state1vf_tf.to_numpy()

In [None]:
check_diff(
    s1vf_pt, s1vf_tf
)

In [None]:
# NOTE(review): a stray `%debug` magic was left here. Under Restart & Run All
# it either opens an interactive post-mortem debugger (blocking execution) or
# reports that there is nothing to debug. Run `%debug` manually when needed.

In [None]:
check_diff(
    dynamics_pt.g.compat_proj(state_pt.x),
    dynamics_tf.g.compat_proj(state_tf.x),
)

In [None]:
check_diff(
    dynamics_tf.hamiltonian(state_tf),
    dynamics_pt.hamiltonian(state_pt),
)

In [None]:
# This cell previously called `check_diff()` with no arguments, which raises a
# TypeError. It now compares the raw (unprojected) positions after the
# forward/backward L2HMC kernel, as the counterpart of the 'test-dv-l2hmc'
# momentum check.
check_diff(
    state2_pt.x, state2_tf.x, name='test-dx-l2hmc'
)

In [None]:
check_diff(
    metrics2_pt, metrics2_tf
)

In [None]:
check_diff(dynamics_pt.g.compat_proj(state2_pt.x), dynamics_tf.g.compat_proj(state2_tf.x))

In [None]:
list(dynamics_pt.xeps.values())

In [None]:
# WARNING(review): this replaces the whole per-step step-size container with a
# plain float. Later cells that index `dynamics_tf.xeps[0]` (and any code doing
# `dynamics.xeps[step]`) will then fail, because a float is not subscriptable.
# Confirm this is intended before running.
dynamics_tf.xeps = 0.1

In [None]:
# WARNING(review): this replaces the whole per-step step-size container with a
# plain float, so any code doing `dynamics.veps[step]` will fail afterwards.
# Confirm this is intended before running.
dynamics_tf.veps = 0.1

In [None]:
# NOTE(review): this cell previously read `dynamics_pt.xeps =`, an incomplete
# assignment and therefore a SyntaxError. It is left as a no-op until the
# intended value is decided. Note that `dynamics_pt.xeps` is keyed by the step
# as a string (e.g. `dynamics_pt.xeps['0']`), so assigning a scalar here would
# break those lookups.

In [None]:
check_diff(state2_pt.v.detach().numpy(), state2_tf.v.numpy(), name='test-dv-l2hmc')

In [None]:
check_diff(state1_pt.x.detach().numpy(), state1_tf.x.numpy(), name='test-dx')

In [None]:
xtf_ = tf.reshape(xtf, (xtf.shape[0], -1))
xpt_ = xpt.flatten(1)
xtf1, metricstf1 = dynamics_tf((xtf_, beta_tf))
xpt1, metricspt1 = dynamics_pt((xpt_, beta_pt))

In [None]:
dynamics_tf.xeps[0]

In [None]:
dynamics_pt.xeps['0']

In [None]:
_ = metricstf1.pop('mc_states')
_ = metricspt1.pop('mc_states')
check_diff(metricstf1, metricspt1)

In [None]:
check_diff(xtf1, xpt1)

In [None]:
state_tf1, metricstf1 = dynamics_tf.transition_kernel_hmc(state_tf, eps=0.1, nleapfrog=10)
state_pt1, metricspt1 = dynamics_pt.transition_kernel_hmc(state_pt, eps=0.1, nleapfrog=10)

lmetricspt1 = lattice_pt.calc_metrics(state_pt1.x)
lmetricstf1 = lattice_tf.calc_metrics(state_tf1.x)

In [None]:
check_diff(metricstf1, metricspt1)

In [None]:
check_diff(lmetricspt1, lmetricstf1)

In [None]:
lmetrics = {
    'pt': get_lattice_metrics(ptExpU1, beta=1.0),
    'tf': get_lattice_metrics(tfExpU1, beta=1.0),
}

In [None]:
for (kp, vp), (kt, vt) in zip(lmetrics['pt'].items(), lmetrics['tf'].items()):