# Train and Evaluate L2HMC dynamics

We consider the case of a 2D $U(1)$ model on a square lattice with periodic boundary conditions.

The Wilson action is given by

$$ S_{\beta}(x) = \beta \sum_{P} \left(1 - \cos x_{P}\right) $$

where $x_{P}$ is the sum of the gauge links around the elementary plaquette.

## Imports / setup

In [None]:
%load_ext autoreload
%autoreload 2
%autosave 120
%load_ext rich
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [None]:
#import matplotlib.pyplot as plt
#import mpld3
#mpld3.enable_notebook()

In [None]:
import matplotx
import matplotlib.pyplot as plt
import seaborn as sns

# Default (width, height) in inches for figures drawn in this notebook.
FIGSIZE = (7, 3)

plt.style.use('default')
# Named hex colors available to later plotting cells.
colors = {
    'blue': '#007DFF',
    'red': '#FF5252',
    'yellow': '#FFFF00',
    'green': '#63FF5B',
    'purple': '#AE81FF',
    'orange': '#FD971F',
    'white': '#CFCFCF',
}

#plt.style.use(matplotx.styles.dufte)
# sns.set_palette(list(colors.values()))
# The molokai style sheet is a file in one specific home directory, so it is
# applied on a best-effort basis: when it cannot be loaded, the 'default'
# style selected above simply stays in effect instead of aborting the cell.
try:
    plt.style.use('/Users/saforem2/.matplotlib/stylelib/molokai.mplstyle')
except OSError as err:
    print(f'Skipping custom matplotlib style: {err}')
sns.set_context('notebook', font_scale=0.8)
plt.rcParams.update({
    'image.cmap': 'viridis',
    # Alpha of 0 everywhere: figure/axes backgrounds and edges are transparent
    'figure.facecolor': (1.0, 1.0, 1.0, 0.),
    'axes.facecolor': (1.0, 1.0, 1.0, 0.),
    'axes.edgecolor': (0, 0, 0, 0.0),
    'figure.edgecolor': (0, 0, 0, 0.0),
    # Restore matplotlib's own defaults for resolution and figure size
    'figure.dpi': plt.rcParamsDefault['figure.dpi'],
    'figure.figsize': plt.rcParamsDefault['figure.figsize'],
    # Tick marks are transparent; only the tick labels are drawn (in gray)
    'xtick.color': (0, 0, 0, 0.0),
    'ytick.color': (0, 0, 0, 0.0),
    'xtick.labelcolor': '#666666',
    'ytick.labelcolor': '#666666',
})

In [None]:
import os
import sys
from pathlib import Path

# Directory three levels above the notebook's working directory; it is put on
# `sys.path` (once) so that imports from that directory resolve.
modulepath = Path(os.getcwd()).parent.parent.parent
if sys.path.count(modulepath.as_posix()) == 0:
    sys.path.append(modulepath.as_posix())

## Set floating point precision

In [None]:
import tensorflow as tf
from l2hmc.utils.hvd_init import HAS_HOROVOD, IS_CHIEF, RANK
# Select the default float type used by Keras
tf.keras.backend.set_floatx('float32')  # or 'float64' for double precision
# Sanity check: compare the configured float type against tf.float32
tf.keras.backend.floatx() == tf.float32

In [None]:
# PyTorch
import torch
torch.set_default_dtype(torch.float32)

In [None]:
import os
os.environ['AUTOGRAPH_VERBOSITY'] = '0'
tf.autograph.set_verbosity(0)

In [None]:
from l2hmc.configs import PROJECT_DIR, HERE
PROJECT_DIR
HERE

In [None]:
import os
from pathlib import Path
from omegaconf import OmegaConf
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
from l2hmc.common import get_timestamp
from l2hmc.common import train as train_common
from l2hmc.configs import CONF_DIR
def train(framework: str = 'tensorflow', overrides: list[str] = None):
    """Compose a Hydra config for `framework`, run training, and save the config.

    A timestamped directory `outputs/jupyter/<framework>/<day>/<time>` is
    created under the current working directory and passed to the config as
    `+outdir=...`. The resolved config is printed, training is run through
    `train_common(cfg)`, and the resolved config is then written to
    `<outdir>/config.yaml`.

    Args:
        framework: Value used for the `framework=` config override.
        overrides: Additional Hydra override strings. The list is copied,
            never modified in place; `None` means no extra overrides.

    Returns:
        A `(cfg, output)` tuple: the composed config and whatever
        `train_common(cfg)` returned.
    """
    # Work on a private copy: the original appended to the caller's list and
    # crashed on the `None` default.
    overrides = [] if overrides is None else list(overrides)

    conf_dir = Path(CONF_DIR).resolve().absolute().as_posix()
    day = get_timestamp('%Y-%m-%d')
    time = get_timestamp('%H-%M-%S')
    outdir = Path(os.getcwd()).joinpath('outputs', 'jupyter',
                                        f'{framework}', day, time)
    outdir.mkdir(exist_ok=True, parents=True)

    overrides.append(f'framework={framework}')
    overrides.append(f'+outdir={outdir.as_posix()}')

    with initialize_config_dir(config_dir=conf_dir):
        cfg = compose(
            config_name="config.yaml",
            overrides=overrides,
        )
        print(OmegaConf.to_yaml(cfg, resolve=True))
        cfile = outdir.joinpath('config.yaml')
        output = train_common(cfg)

        print(f'Saving config to: {cfile}')
        with open(cfile, 'w') as f:
            f.write(OmegaConf.to_yaml(cfg, resolve=True))

    return cfg, output

In [None]:
import torch

def plot_plaqs_diffs(cfg, output, xarr=None):
    """Plot `lattice.plaqs_diff` for every entry of `xarr`, per chain and averaged.

    Args:
        cfg: Config object; `cfg.get('beta')` is forwarded to
            `lattice.plaqs_diff`.
        output: Training output mapping; `output['setup']['lattice']` supplies
            the lattice, and `output['train_output']['xarr']` is the default
            for `xarr`.
        xarr: Optional iterable of configurations (torch tensors or anything
            else `lattice.plaqs_diff` accepts).

    The thick line is the mean over chains; up to 16 individual chains are
    drawn as thin lines underneath.
    """
    # Local import: no visible cell of this notebook imports numpy as `np`.
    import numpy as np

    lattice = output['setup']['lattice']
    if xarr is None:
        # NOTE(review): other cells read `output['train']['output']['xarr']`;
        # confirm that 'train_output' is still a valid key.
        xarr = output['train_output']['xarr']

    beta = cfg.get('beta')
    plaqs = []
    for x in xarr:
        dplaq = lattice.plaqs_diff(beta=beta, x=x)
        if isinstance(x, torch.Tensor):
            # `.cpu()` is a no-op on CPU tensors and makes `.numpy()` valid
            # for tensors living on another device.
            dplaq = dplaq.detach().cpu().numpy()
        plaqs.append(dplaq)

    plaqs = np.array(plaqs)
    ndraws, nchains = plaqs.shape
    xplot = np.arange(ndraws)
    with plt.style.context(matplotx.styles.dufte):
        fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)
        _ = ax.plot(xplot, plaqs.mean(-1), label='avg', lw=2.0, color='C0')
        # Never index past the available chains (the original assumed >= 16)
        for idx in range(min(16, nchains)):
            _ = ax.plot(xplot, plaqs[:, idx], lw=1.0, alpha=0.5, color='C0')

        _ = ax.set_ylabel(r'$\delta x_{P}$')
        _ = ax.set_xlabel('Train Epoch')
        _ = ax.grid(True, alpha=0.4)

In [None]:
import xarray as xr
from l2hmc.utils.plot_helpers import plot_dataArray, make_ridgeplots

def make_plots(dataset: xr.Dataset, title: str = None, **kwargs):
    """Call `plot_dataArray` once for every data variable in `dataset`.

    Each call receives the variable's name as `key`, `num_chains=10`, the
    given `title`, and any extra keyword arguments unchanged.
    """
    for var_name, data_array in dataset.data_vars.items():
        _ = plot_dataArray(
            data_array,
            key=var_name,
            num_chains=10,
            title=title,
            **kwargs,
        )

In [None]:
# Hydra override strings shared by every training run in this notebook.
OPTIONS = [
    'beta_init=1.0',
    'beta_final=3.0',
    'mode=cpu',
    '+width=150',
    'dynamics.nleapfrog=5',
    'dynamics.xshape=[128, 8, 8, 2]',
    'conv=none',
    'loss.aux_weight=1.0',
    'steps.nera=10',
    'steps.nepoch=100',
    'steps.print=20',
    'steps.log=10',
    'steps.test=500',
]
# Framework names used as the top-level keys of `outputs` and `configs`.
frameworks = {
    'pytorch': {},
    'tensorflow': {},
}
# Each container gets its OWN empty dict per framework. Previously both were
# filled via `.update(frameworks)`, so `outputs[fw]`, `configs[fw]` and
# `frameworks[fw]` were the same object and a write to one showed up in all.
outputs = {name: {} for name in frameworks}
configs = {name: {} for name in frameworks}

# TensorFlow: `merge_directions = True`

In [None]:
# Train with TensorFlow and `dynamics.merge_directions=true`, then show
# ridgeplots of the full training history.
sns.set_context('notebook', font_scale=0.8)
overrides = ['dynamics.merge_directions=true', *OPTIONS]
config, output = train(framework='tensorflow', overrides=overrides)
_ = make_ridgeplots(output['train']['history'].get_dataset(), num_chains=64)

# Keep this run's config and results for the comparison cells further down.
outputs['tensorflow']['merge'] = {'config': config, 'output': output}

In [None]:
from l2hmc.common import analyze_dataset
# Analyze the training history of the stored TensorFlow `merge` run.
# Everything is read from `outputs[...]` rather than the `config` / `output`
# globals, which hold whichever run was trained most recently.
run = outputs['tensorflow']['merge']
train_dataset = run['output']['train']['history'].get_dataset(therm_frac=0.25)
title = 'Training: Tensorflow, merge_directions=True'
_ = analyze_dataset(dataset=train_dataset,
                    outdir=Path(run['config'].outdir).joinpath('train'),
                    lattice=run['output']['setup']['lattice'],
                    xarr=run['output']['train']['output']['xarr'],
                    name='train', title=title, save=False)
_ = make_ridgeplots(train_dataset, num_chains=64)

In [None]:
# Analyze the evaluation history of the stored TensorFlow `merge` run.
# Everything is read from `outputs[...]` rather than the `config` / `output`
# globals, which hold whichever run was trained most recently.
run = outputs['tensorflow']['merge']
eval_dataset = run['output']['eval']['history'].get_dataset(therm_frac=0.25)
title = 'Evaluation: Tensorflow, merge_directions=True'
_ = analyze_dataset(dataset=eval_dataset,
                    outdir=Path(run['config'].outdir).joinpath('eval'),
                    lattice=run['output']['setup']['lattice'],
                    xarr=run['output']['eval']['output']['xarr'],
                    name='eval', title=title, save=False)
_ = make_ridgeplots(eval_dataset, num_chains=64)

# PyTorch: `merge_directions = True`

In [None]:
# Train with PyTorch and `dynamics.merge_directions=true`, then show
# ridgeplots of the full training history.
overrides = ['dynamics.merge_directions=true', *OPTIONS]
config, output = train(framework='pytorch', overrides=overrides)
_ = make_ridgeplots(output['train']['history'].get_dataset(), num_chains=64)

# Keep this run's config and results for the comparison cells further down.
outputs['pytorch']['merge'] = {'config': config, 'output': output}

In [None]:
# Analyze the training history of the stored PyTorch `merge` run.
# The dataset, lattice, xarr and outdir all come from the same stored run;
# the `config` / `output` globals hold whichever run was trained last.
run = outputs['pytorch']['merge']
train_dataset = run['output']['train']['history'].get_dataset(therm_frac=0.2)
title = 'Training: PyTorch, merge_directions=True'
_ = analyze_dataset(dataset=train_dataset,
                    outdir=Path(run['config'].outdir).joinpath('train'),
                    lattice=run['output']['setup']['lattice'],
                    xarr=run['output']['train']['output']['xarr'],
                    name='train', title=title, save=False)
_ = make_ridgeplots(train_dataset, num_chains=64)

In [None]:
# Analyze the evaluation history of the stored PyTorch `merge` run.
# The dataset, lattice, xarr and outdir all come from the same stored run;
# the `config` / `output` globals hold whichever run was trained last.
run = outputs['pytorch']['merge']
eval_dataset = run['output']['eval']['history'].get_dataset(therm_frac=0.2)
title = 'Evaluation: PyTorch, merge_directions=True'
_ = analyze_dataset(dataset=eval_dataset,
                    outdir=Path(run['config'].outdir).joinpath('eval'),
                    lattice=run['output']['setup']['lattice'],
                    xarr=run['output']['eval']['output']['xarr'],
                    name='eval', title=title, save=False)
_ = make_ridgeplots(eval_dataset, num_chains=64)

# TensorFlow: `merge_directions = False`

In [None]:
# Train with TensorFlow and `dynamics.merge_directions=false`, then show
# ridgeplots of the full training history.
overrides = ['dynamics.merge_directions=false', *OPTIONS]
config, output = train(framework='tensorflow', overrides=overrides)
_ = make_ridgeplots(output['train']['history'].get_dataset(), num_chains=64)
# Keep this run's config and results for the comparison cells further down.
outputs['tensorflow']['no_merge'] = {'config': config, 'output': output}

In [None]:
# Analyze the training history of the stored TensorFlow `no_merge` run.
# The dataset, lattice, xarr and outdir all come from the same stored run;
# the `config` / `output` globals hold whichever run was trained last.
run = outputs['tensorflow']['no_merge']
train_dataset = run['output']['train']['history'].get_dataset(therm_frac=0.2)
title = 'Training: Tensorflow, merge_directions=False'
_ = analyze_dataset(dataset=train_dataset,
                    outdir=Path(run['config'].outdir).joinpath('train'),
                    lattice=run['output']['setup']['lattice'],
                    xarr=run['output']['train']['output']['xarr'],
                    name='train', title=title, save=False)
_ = make_ridgeplots(train_dataset, num_chains=64)

In [None]:
# Analyze the evaluation history of the stored TensorFlow `no_merge` run.
# The dataset, lattice, xarr and outdir all come from the same stored run;
# the `config` / `output` globals hold whichever run was trained last.
run = outputs['tensorflow']['no_merge']
eval_dataset = run['output']['eval']['history'].get_dataset(therm_frac=0.25)
title = 'Evaluation: Tensorflow, merge_directions=False'
_ = analyze_dataset(dataset=eval_dataset,
                    outdir=Path(run['config'].outdir).joinpath('eval'),
                    lattice=run['output']['setup']['lattice'],
                    xarr=run['output']['eval']['output']['xarr'],
                    name='eval', title=title, save=False)
_ = make_ridgeplots(eval_dataset, num_chains=64)

# PyTorch: `merge_directions = False`

In [None]:
# Train with PyTorch and `dynamics.merge_directions=false`, then show
# ridgeplots of the full training history.
overrides = ['dynamics.merge_directions=false', *OPTIONS]
config, output = train(framework='pytorch', overrides=overrides)

_ = make_ridgeplots(output['train']['history'].get_dataset(), num_chains=64)

# Keep this run's config and results for the comparison cells further down.
outputs['pytorch']['no_merge'] = {'config': config, 'output': output}

In [None]:
# Analyze the training history of the stored PyTorch `no_merge` run.
# The dataset, lattice, xarr and outdir all come from the same stored run;
# the `config` / `output` globals hold whichever run was trained last.
run = outputs['pytorch']['no_merge']
train_dataset = run['output']['train']['history'].get_dataset(therm_frac=0.2)
title = 'Training: PyTorch, merge_directions=False'
_ = analyze_dataset(dataset=train_dataset,
                    outdir=Path(run['config'].outdir).joinpath('train'),
                    lattice=run['output']['setup']['lattice'],
                    xarr=run['output']['train']['output']['xarr'],
                    name='train', title=title, save=False)
_ = make_ridgeplots(train_dataset, num_chains=64)

In [None]:
# Analyze the evaluation history of the stored PyTorch `no_merge` run.
# The dataset, lattice, xarr and outdir all come from the same stored run;
# the `config` / `output` globals hold whichever run was trained last.
run = outputs['pytorch']['no_merge']
eval_dataset = run['output']['eval']['history'].get_dataset(therm_frac=0.2)
title = 'Evaluation: PyTorch, merge_directions=False'
_ = analyze_dataset(dataset=eval_dataset,
                    outdir=Path(run['config'].outdir).joinpath('eval'),
                    lattice=run['output']['setup']['lattice'],
                    xarr=run['output']['eval']['output']['xarr'],
                    name='eval', title=title, save=False)
_ = make_ridgeplots(eval_dataset, num_chains=64)

# Test Reversibility

In [None]:
# Pull the dynamics object out of each of the four stored runs
# (framework x merge / no_merge), PyTorch first, then TensorFlow.
dynamics_merge_pt, dynamics_no_merge_pt = (
    outputs['pytorch'][variant]['output']['setup']['dynamics']
    for variant in ('merge', 'no_merge')
)
dynamics_merge_tf, dynamics_no_merge_tf = (
    outputs['tensorflow'][variant]['output']['setup']['dynamics']
    for variant in ('merge', 'no_merge')
)

In [None]:
def test_reversibility(dynamics, name: str = None):
    """Run `dynamics.test_reversibility()` and print the averages of `dx` and `dv`.

    Args:
        dynamics: Object exposing a `test_reversibility()` method that returns
            a mapping with `'dx'` and `'dv'` entries.
        name: Optional label printed in front of the averages; omitted from
            the output when `None`.

    Returns:
        The mapping returned by `dynamics.test_reversibility()`, unchanged.
    """
    def _avg(values):
        # Use the value's own `.mean()` when it has one. Values without it are
        # averaged with `tf.reduce_mean`, as another cell of this notebook does
        # for the TensorFlow result.
        mean = getattr(values, 'mean', None)
        if callable(mean):
            return mean()
        import tensorflow as tf
        return tf.reduce_mean(values)

    diff = dynamics.test_reversibility()
    summary = ', '.join([
        f'avg(dx): {_avg(diff["dx"])}',
        f'avg(dv): {_avg(diff["dv"])}',
    ])
    # Avoid printing a literal "None" prefix when no label was given
    print(summary if name is None else f'{name} {summary}')
    return diff

In [None]:
# Reversibility check for the two models taken from the `merge` runs
diff_merge_pt = test_reversibility(dynamics_merge_pt, name='Pytorch, merge')
diff_merge_tf = test_reversibility(dynamics_merge_tf, name='Tensorflow, merge')

# ... and for the two models taken from the `no_merge` runs
diff_no_merge_pt = test_reversibility(dynamics_no_merge_pt, name='Pytorch, no merge')
diff_no_merge_tf = test_reversibility(dynamics_no_merge_tf, name='Tensorflow, no merge')

In [None]:
state_tf = dynamics_merge_tf.random_state()

_ = dynamics_merge_tf((state_tf.x, tf.constant(1.)))

In [None]:
dynamics_merge_tf.summary()

In [None]:
dynamics_merge_tf.xnet['0']['first'].build(input_shape=[(512,), (512,2)])

In [None]:
dynamics_merge_tf.vnet['0'].build(input_shape=[(512,), (512,)])

In [None]:
_ = dynamics_merge_tf.vnet['0']((tf.constant(state_tf.x), tf.constant(state_tf.v)))

In [None]:
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()

In [None]:
tf.keras.utils.plot_model(dynamics_merge_tf.vnet['0'])

In [None]:
%debug

In [None]:
#sum([sum(i) for i in dynamics_merge_pt.parameters()])
sum(p.numel() for p in dynamics_merge_pt.parameters()) # if p.requires_grad)
#len(list(dynamics_merge_pt.parameters()))

In [None]:
state_tf = dynamics_merge_tf.random_state()
#out = dynamics_merge_tf((state_tf.x, state_tf.beta))

In [None]:
#x = tf.reshape(dynamics_merge_tf._stack_as_xy(state_tf.x), (state_tf.x.shape[0], -1))
x = dynamics_merge_tf._stack_as_xy(state_tf.x)
x.shape

In [None]:
dynamics_merge_tf.compile()

In [None]:
out = dynamics_merge_tf((state_tf.x, state_tf.beta))

In [None]:
%debug

In [None]:
tf.keras.utils.plot_model(dynamics_merge_tf.xnet['0']['first'])

In [None]:
tf.keras.utils.plot_model(dynamics_merge_tf.vnet)

In [None]:
# Mean reversibility differences for the PyTorch `merge` dynamics, detached
# and converted to numpy so they can be formatted.
diff_merge_pt = dynamics_merge_pt.test_reversibility()
dx, dv = (
    diff_merge_pt[key].mean().detach().numpy()
    for key in ('dx', 'dv')
)
print(f'(dx, dv) = ({dx:.4g}, {dv:.4g})')

In [None]:
# Mean reversibility differences for the TensorFlow `merge` dynamics.
diff_merge_tf = dynamics_merge_tf.test_reversibility()
dx, dv = (
    tf.reduce_mean(diff_merge_tf[key])
    for key in ('dx', 'dv')
)
print(f'(dx, dv) = ({dx:.4g}, {dv:.4g})')

## Look at differences

In [None]:
outputs['tensorflow']['merge'].keys()

In [None]:
# Loss history of each stored run, keyed as '<framework>_<variant>'.
# NOTE(review): earlier cells reach the history through
# `['output']['train']['history']`; confirm `['output']['history']` is valid.
losses = {
    f'{framework}_{variant}':
        outputs[framework][variant]['output']['history'].history['loss']
    for framework in ('tensorflow', 'pytorch')
    for variant in ('merge', 'no_merge')
}

In [None]:
import matplotx
# Compare the training loss of the four runs on a single set of axes.
plt.rcParams.update({
    'figure.dpi': 150,
})
fig, ax = plt.subplots(figsize=(9, 5))

# Palette local to this figure: one base color and one lighter variant per
# framework.
COLORS = {
    'blue': '#03A9F4',
    'alt_blue': '#80D8FF',
    'red': '#F44336',
    'alt_red': '#FF8A80',
}

# Every entry now draws from `COLORS`; the `no_merge` entries previously
# pulled from the unrelated global `colors` dict, mixing two palettes.
styles = {
    'tensorflow_merge': {'color': COLORS['alt_blue'], 'ls': '--'},
    'tensorflow_no_merge': {'color': COLORS['blue'], 'ls': '-'},
    'pytorch_merge': {'color': COLORS['alt_red'], 'ls': '--'},
    'pytorch_no_merge': {'color': COLORS['red'], 'ls': '-'},
}
for key, val in losses.items():
    # Skip the first 5 entries and plot every second one after that
    _ = ax.plot(val[5::2], label=key, **styles[key])

matplotx.line_labels(ax=ax)
_ = ax.grid(alpha=0.2, axis='y')
_ = ax.set_ylabel(r'Loss')
_ = ax.set_xlabel('Train Epoch')