## l2hmc

In [1]:
# Execute the project's rich styling script inside the notebook namespace.
# NOTE(review): this is an absolute, user-specific path, so the cell fails on any
# other machine — consider a path relative to the notebook (TODO confirm layout).
%run /Users/saforem2/projects/l2hmc-qcd/src/l2hmc/utils/rich_style.py

## Imports

In [2]:
# Automatically reload edited modules before each cell executes.
%load_ext autoreload
%autoreload 2
# Autosave the notebook every 120 seconds.
%autosave 120
# Enable rich's IPython extension.
%load_ext rich
# Render matplotlib figures inline, using the 'retina' figure format.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Autosaving every 120 seconds


In [3]:
# `Console` is imported explicitly so this cell does not depend on a name that
# happens to be left in the namespace by another cell or by `%run`.
from rich.console import Console
from rich.theme import Theme

# Named styles available to the console; `inherit=False` means rich's default
# styles are not merged into this theme.
theme = Theme({
    "info": "dim cyan",
    "warning": "magenta",
    "danger": "bold red",
    "green": "#4CAF50",
    "yellow": "#FFEB3B",
}, inherit=False)

# Notebook console using the theme above, with truecolor output.
console = Console(theme=theme, color_system='truecolor', log_path=False, force_jupyter=True)

In [4]:
import matplotx
import matplotlib.pyplot as plt
import seaborn as sns

# Default figure size (width, height) used by the comparison plots below.
FIGSIZE = (7, 3)

# Reset to matplotlib's defaults before layering the custom style on top.
plt.style.use('default')
# Named colour palette; its values (in insertion order) become the seaborn palette.
colors = {
    'blue': '#007DFF',
    'red': '#FF5252',
    'yellow': '#FFFF00',
    'green': '#63FF5B',
    'purple': '#AE81FF',
    'orange': '#FD971F',
    'white': '#CFCFCF',
}

# Order matters here: each call below updates the global rcParams, so later
# calls win over earlier ones for any overlapping setting.
plt.style.use(matplotx.styles.dufte)
sns.set_palette(list(colors.values()))
sns.set_context('notebook', font_scale=0.8)
plt.rcParams.update({
    'image.cmap': 'viridis',
    # RGBA with alpha 0: fully transparent figure and axes backgrounds.
    'figure.facecolor': (1.0, 1.0, 1.0, 0.),
    'axes.facecolor': (1.0, 1.0, 1.0, 0.),
    'axes.grid': False,
    'grid.color': '#cfcfcf',
    # Restore matplotlib's default dpi and figure size.
    'figure.dpi': plt.rcParamsDefault['figure.dpi'],
    'figure.figsize': plt.rcParamsDefault['figure.figsize'],
})

In [5]:
import os
import sys
from pathlib import Path

# Put the directory three levels above the working directory on `sys.path`
# (presumably the repository root, so that the `src.l2hmc...` imports resolve —
# TODO confirm). The entry is appended only once.
modulepath = Path.cwd().parent.parent.parent
if not any(entry == modulepath.as_posix() for entry in sys.path):
    sys.path.append(modulepath.as_posix())

## Specify floating point precision to use for training

Training can be done in either:

 - `float32` (single precision)
 - `float64` (double precision)

In [6]:
# TensorFlow
import tensorflow as tf
tf.keras.backend.set_floatx('float32')  # or 'float64' for double precision
# Verify the setting took effect; a bare comparison here would compute a
# boolean and silently discard it.
assert tf.keras.backend.floatx() == 'float32'

# PyTorch
import torch
torch.set_default_dtype(torch.float32)  # or torch.float64 for double precision

In [7]:
import os
import tensorflow as tf

# Request AutoGraph verbosity 0 through the environment variable.
os.environ['AUTOGRAPH_VERBOSITY'] = '0'
# The environment variable now holds '0'

# Set the verbosity explicitly through the TensorFlow API as well.
tf.autograph.set_verbosity(0)
# Verbosity is now 0

## Remaining Imports

In [8]:
import numpy as np

from src.l2hmc.configs import DynamicsConfig
from src.l2hmc.lattice.pytorch.lattice import Lattice as ptLattice
from src.l2hmc.lattice.tensorflow.lattice import Lattice as tfLattice

from src.l2hmc.network.pytorch.network import NetworkFactory as ptNetworkFactory
from src.l2hmc.network.tensorflow.network import NetworkFactory as tfNetworkFactory

2022-01-26 00:13:58.104584: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
## Helper functions

In [9]:
from typing import Optional

def stack_history(history):
    """Stack each per-step list in ``history`` into a single NumPy array.

    Args:
        history: Mapping from metric name to a sequence of per-step values.
            The type of the first element decides how the sequence is stacked.

    Returns:
        dict mapping each kept key to an ``np.ndarray``. Keys whose sequence is
        empty, or whose first element is not a TensorFlow tensor, a PyTorch
        tensor, an ``np.ndarray`` or a ``float``, are skipped with a printed
        message.
    """
    history_ = {}
    for key, val in history.items():
        # An empty sequence has no first element to inspect; skip it instead
        # of raising IndexError.
        if len(val) == 0:
            print(f'Skipping key: {key}, no entries')
            continue
        first = val[0]
        if isinstance(first, tf.Tensor):
            history_[key] = tf.stack(val).numpy()
        elif isinstance(first, torch.Tensor):
            # `.cpu()` is a no-op on CPU tensors and is required before
            # `.numpy()` for tensors that live on an accelerator.
            history_[key] = torch.stack(val).detach().cpu().numpy()
        elif isinstance(first, (float, np.ndarray)):
            history_[key] = np.stack(val)
        else:
            # Report the element type, since that is what caused the skip.
            print(f'Skipping key: {key}, val.dtype: {type(first)}')
    return history_

def stack_directional_history_tf(history):
    """Stack a list of per-step dicts of TensorFlow values, key by key.

    The keys are taken from the first entry of ``history``; for each key the
    values from every entry are combined with ``tf.stack`` and converted with
    ``.numpy()``.
    """
    return {
        name: tf.stack([step.get(name) for step in history]).numpy()
        for name in history[0].keys()
    }

def stack_directional_history_pt(history):
    """Stack a list of per-step dicts of PyTorch tensors, key by key.

    Args:
        history: Non-empty sequence of dicts; the keys of the first entry
            determine which keys are stacked.

    Returns:
        dict mapping each key to an ``np.ndarray`` built by stacking that
        key's tensors from every entry.
    """
    hist = {}
    for key in history[0].keys():
        stacked = torch.stack([x.get(key) for x in history])
        # Move to host memory before converting: `.numpy()` raises for tensors
        # on a non-CPU device, and `.cpu()` is a no-op otherwise.
        hist[key] = stacked.detach().cpu().numpy()
    return hist

def stack(x):
    """Best-effort conversion of ``x`` into a single stacked NumPy array.

    A list whose first element is a TensorFlow or PyTorch tensor is stacked
    with the matching framework and converted to NumPy; anything else goes
    through ``np.stack``. If stacking fails (including for an empty list), an
    array of zeros shaped like ``x`` is returned instead.
    """
    try:
        if isinstance(x, list):
            first = x[0]
            if isinstance(first, tf.Tensor):
                return tf.stack(x).numpy()
            if isinstance(first, torch.Tensor):
                try:
                    stacked = torch.stack(x)
                except Exception:
                    # Fall back to building a tensor directly from the list.
                    stacked = torch.Tensor(x)
                # `.cpu()` makes the conversion work for accelerator tensors.
                return stacked.detach().cpu().numpy()
        return np.stack(x)
    except Exception:
        # Deliberate best-effort fallback. `Exception` (rather than a bare
        # `except:`) lets KeyboardInterrupt / SystemExit propagate.
        return np.zeros_like(x)

def dict_summary(d):
    """Summarise a (possibly nested) dict as a list of ``'key=mean'`` strings.

    Args:
        d: Mapping from name to either a nested dict or a value that
            ``np.mean`` (or the value's own ``.mean()``) can reduce.

    Returns:
        list of str, one per key, in iteration order. Means are formatted with
        ``.3g``; a nested dict is rendered as the string form of its own
        summary list.
    """
    strs = []
    for key, val in d.items():
        if isinstance(val, dict):
            strs.append(f'{key}={dict_summary(val)}')
            continue
        try:
            mean = np.mean(val)
        except Exception:
            # Some containers cannot be reduced by NumPy; use their own
            # `.mean()`. `Exception` keeps KeyboardInterrupt propagating.
            mean = val.mean()
        strs.append(f'{key}={mean:.3g}')
    return strs

## Define configurations

Note: We use a shared set of configuration objects for both the `pytorch` and `tensorflow` implementations

In [10]:
from src.l2hmc.configs import (
    Steps,
    InputSpec,
    LossConfig,
    NetworkConfig,
    NetWeight,
    DynamicsConfig,
    NetWeights,
)

beta = 2.
nleapfrog = 5           # trajectory length
eps_init = 0.005        # initial step size (trainable)
xshape = (64, 8, 8, 2)  # (nbatch, nt, nx, dim)

# Step counts for the trainers, built directly as a `Steps` object.
steps = Steps(log=5, nera=10, print=25, nepoch=500, test=0)

# scaling factors multiplying the (s, t, q) network functions
net_weights = NetWeights(
    x=NetWeight(1., 1., 1.),
    v=NetWeight(1., 1., 1.),
)

net_config = NetworkConfig(
    units=[8, 8, 8, 8],     # sizes of hidden layers
    dropout_prob=0.,        # dropout probability
    activation_fn='relu',   # activation fn
    use_batch_norm=False,   # use batch_norm
)

# Options shared by both dynamics configurations; the two objects below differ
# only in `merge_directions`.
shared_dynamics_kwargs = dict(
    xshape=xshape,
    eps=eps_init,
    nleapfrog=nleapfrog,
    use_ncp=True,
    verbose=True,
    eps_fixed=False,
    use_split_xnets=True,
    use_separate_networks=True,
)
dynamics_config_fb = DynamicsConfig(merge_directions=True, **shared_dynamics_kwargs)
dynamics_config = DynamicsConfig(merge_directions=False, **shared_dynamics_kwargs)

xdim = dynamics_config.xdim
input_spec = InputSpec(
    xshape=xshape,
    # note: we stack the input to the xNetwork
    # as [cos(x), sin(x)], hence (xdim, 2) below
    xnet={'x': (xdim, 2), 'v': (xdim,)},
    vnet={'x': (xdim,), 'v': (xdim,)},
)

loss_config = LossConfig(
    use_mixed_loss=True,
    plaq_weight=0.,
    charge_weight=0.01,
)

## Build `pytorch` Dynamics object

In [11]:
from src.l2hmc.lattice.pytorch.lattice import Lattice as ptLattice
from src.l2hmc.dynamics.pytorch.dynamics import Dynamics as ptDynamics
from src.l2hmc.network.pytorch.network import NetworkFactory as ptNetworkFactory
from src.l2hmc.loss.pytorch.loss import LatticeLoss as ptLatticeLoss

from accelerate import Accelerator

# Accelerator picks the device used for the PyTorch training cells below.
accelerator = Accelerator()
device = accelerator.device

optim = torch.optim

pt_lattice = ptLattice(xshape)
potential_fn_pt = pt_lattice.action

pt_net_factory = ptNetworkFactory(input_spec=input_spec,
                                  net_weights=net_weights,
                                  network_config=net_config)

# Two dynamics objects built from the same potential and network factory; they
# differ only in the configuration passed in (`dynamics_config` vs
# `dynamics_config_fb`).
dynamics_pt = ptDynamics(potential_fn=potential_fn_pt,
                         config=dynamics_config,
                         network_factory=pt_net_factory)

dynamics_pt_fb = ptDynamics(potential_fn=potential_fn_pt,
                            config=dynamics_config_fb,
                            network_factory=pt_net_factory)

# Loss object used by both PyTorch trainers (constructed once).
loss_pt = ptLatticeLoss(lattice=pt_lattice, loss_config=loss_config)

optimizer_pt = optim.Adam(dynamics_pt.parameters())
optimizer_pt_fb = optim.Adam(dynamics_pt_fb.parameters())

## Build `tensorflow` Dynamics object

In [12]:
from src.l2hmc.lattice.tensorflow.lattice import Lattice as tfLattice
from src.l2hmc.dynamics.tensorflow.dynamics import Dynamics as tfDynamics
from src.l2hmc.network.tensorflow.network import NetworkFactory as tfNetworkFactory
from src.l2hmc.loss.tensorflow.loss import LatticeLoss as tfLatticeLoss

# Lattice for the TensorFlow implementation; its `action` method is the
# potential function handed to the dynamics.
tf_lattice = tfLattice(xshape)
potential_fn_tf = tf_lattice.action

tf_net_factory = tfNetworkFactory(input_spec=input_spec,
                                  net_weights=net_weights,
                                  network_config=net_config)

# Same potential and network factory for both objects; only the dynamics
# configuration differs (`dynamics_config` vs `dynamics_config_fb`).
dynamics_tf = tfDynamics(potential_fn=potential_fn_tf,
                         config=dynamics_config,
                         network_factory=tf_net_factory)

dynamics_tf_fb = tfDynamics(potential_fn=potential_fn_tf,
                            config=dynamics_config_fb,
                            network_factory=tf_net_factory)

# Loss object passed to both TensorFlow trainers.
loss_tf = tfLatticeLoss(lattice=tf_lattice, loss_config=loss_config)

In [13]:
# NOTE(review): rebuilds `dynamics_tf` with exactly the same arguments as the
# previous cell — presumably to reset it before training; remove if redundant.
dynamics_tf = tfDynamics(potential_fn=potential_fn_tf,
                         config=dynamics_config,
                         network_factory=tf_net_factory)

## Test lattice methods

In [50]:
# One random initial configuration, drawn with TensorFlow and copied through
# NumPy so the PyTorch tensor holds the same values.
xinit_tf = tf.random.uniform(xshape, minval=-np.pi, maxval=np.pi)
xinit_np = xinit_tf.numpy()
xinit_pt = torch.tensor(xinit_np, requires_grad=True)

# Check that wilson loops agree between tensorflow and pytorch
wl_init_tf = tf_lattice.wilson_loops(x=xinit_tf)
wl_init_pt = pt_lattice.wilson_loops(x=xinit_pt)

# Elementwise differences between the two frameworks (inputs, then loops).
dxinit = xinit_np - xinit_pt.detach().numpy()
dwl_init = wl_init_tf.numpy() - wl_init_pt.detach().numpy()
dwl_init.sum()
dwl_init.mean()
dxinit.sum()

In [42]:
# Compute the lattice observables for the same initial configuration with the
# TensorFlow (`tlatm`) and PyTorch (`platm`) lattices.
tlatm = tf_lattice.observables(xinit_tf)
platm = pt_lattice.observables(xinit_pt)

In [43]:
# Sanity check: both frameworks start from elementwise-identical inputs.
(xinit_tf.numpy() == xinit_pt.detach().numpy()).all()

In [44]:
# Inspect the shapes of the TensorFlow observables.
# NOTE(review): by default only the last expression of a cell is displayed —
# confirm the notebook is configured to show all of them.
tlatm.plaqs.shape
tlatm.charges.intQ.shape
tlatm.charges.sinQ.shape
tlatm.p4x4.shape

In [45]:
# Exact elementwise equality between frameworks for plaquettes and integer
# charges.
# NOTE(review): these two results are not assigned or printed; and exact `==`
# on floating-point plaquettes is strict — consider `np.allclose`.
(tlatm.plaqs.numpy() == platm.plaqs.detach().numpy()).all()
(tlatm.charges.intQ.numpy() == platm.charges.intQ.detach().numpy()).all()
# Signed sums of the TensorFlow-minus-PyTorch differences for each observable
# (positive and negative deviations can cancel in a signed sum).
print(
    f'sum(plaq_tf - plaq_pt) = '
    f'{(tlatm.plaqs.numpy() -  platm.plaqs.detach().numpy()).sum()}'
)
print(
    f'sum(intQ_tf - intQ_pt) = '
    f'{(tlatm.charges.intQ.numpy() - platm.charges.intQ.detach().numpy()).sum()}'
)
print(
    f'sum(sinQ_tf - sinQ_pt) = '
    f'{(tlatm.charges.sinQ.numpy() - platm.charges.sinQ.detach().numpy()).sum()}'
)

# Training

In [46]:
# Print the AutoGraph verbosity environment variable as seen by the shell.
!echo $AUTOGRAPH_VERBOSITY
# NOTE(review): each `!` line runs in its own shell, so this `set -e ...` does
# not appear to have any lasting effect — likely leftover; confirm and remove.
!set -e $AUTOGRAPH_VERBOSITY

0


## Tensorflow

### Train with single forward/backward update

Explicitly, we can set:

```python
DynamicsConfig.merge_directions = True
```

In [15]:
from src.l2hmc.trainers.tensorflow.trainer import Trainer as tfTrainer

# Fresh Adam optimizer (default settings) for the merged-direction dynamics.
optimizer_tf_fb = tf.keras.optimizers.Adam()

trainer_tf_fb = tfTrainer(steps=steps,
                          dynamics=dynamics_tf_fb,
                          optimizer=optimizer_tf_fb,
                          loss_fn=loss_tf)

# Run training, then pull the history object out of the returned mapping and
# convert it to a dataset for plotting.
output_tf_fb = trainer_tf_fb.train(compile=True, jit_compile=False)
history_tf_fb = output_tf_fb['history']
dataset_tf_fb = history_tf_fb.get_dataset()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

### Train using separate forward/backward updates (as usual)

In [18]:
from src.l2hmc.trainers.tensorflow.trainer import Trainer as tfTrainer

# Fresh Adam optimizer (default settings) for the separate-direction dynamics.
optimizer_tf = tf.keras.optimizers.Adam()

trainer_tf = tfTrainer(steps=steps,
                       dynamics=dynamics_tf,
                       optimizer=optimizer_tf,
                       loss_fn=loss_tf)
# Run training, then pull the history object out of the returned mapping and
# convert it to a dataset for plotting.
output_tf = trainer_tf.train(compile=True, jit_compile=False)
history_tf = output_tf['history']
dataset_tf = history_tf.get_dataset()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

## Pytorch

In [None]:
import time
from rich import print
from torch import optim
from src.l2hmc.dynamics.pytorch.dynamics import random_angle
from rich.console import Console

# Put both PyTorch dynamics modules into training mode.
dynamics_pt.train()
dynamics_pt_fb.train()

# NOTE(review): these re-create the optimizers already built in the
# "Build `pytorch` Dynamics object" cell — confirm whether both are needed.
optimizer_pt = optim.Adam(dynamics_pt.parameters())
optimizer_pt_fb = optim.Adam(dynamics_pt_fb.parameters())

# Random angles with the same shape as `xinit_pt`, flattened to
# (first dimension, -1).
# NOTE(review): `xpt` is not referenced again in the visible cells.
xpt = random_angle(xinit_pt.shape)
xpt = xpt.reshape(xpt.shape[0], -1)


### Train with single forward/backward update (unified, `DynamicsConfig.merge_directions = True`)

In [None]:
from src.l2hmc.trainers.pytorch.trainer import Trainer as ptTrainer

# Move the merged-direction dynamics to the accelerator's device and let
# `accelerate` wrap the model and optimizer.
# NOTE(review): re-running this cell passes already-prepared objects through
# `accelerator.prepare` again — confirm that is safe.
dynamics_pt_fb = dynamics_pt_fb.to(accelerator.device)
dynamics_pt_fb, optimizer_pt_fb = accelerator.prepare(dynamics_pt_fb, optimizer_pt_fb)
trainer_pt_fb = ptTrainer(steps=steps,
                          dynamics=dynamics_pt_fb,
                          optimizer=optimizer_pt_fb,
                          loss_fn=loss_pt,
                          accelerator=accelerator)

# Run training, then extract the history and convert it to a dataset.
output_pt_fb = trainer_pt_fb.train()
history_pt_fb = output_pt_fb['history']
dataset_pt_fb = history_pt_fb.get_dataset()

### Train with separate forward/backward updates (as usual)

In [59]:
from src.l2hmc.trainers.pytorch.trainer import Trainer as ptTrainer

# Move the separate-direction dynamics to the accelerator's device and let
# `accelerate` wrap the model and optimizer.
# NOTE(review): re-running this cell passes already-prepared objects through
# `accelerator.prepare` again — confirm that is safe.
dynamics_pt = dynamics_pt.to(accelerator.device)
dynamics_pt, optimizer_pt = accelerator.prepare(dynamics_pt, optimizer_pt)
trainer_pt = ptTrainer(steps=steps,
                       dynamics=dynamics_pt,
                       optimizer=optimizer_pt,
                       loss_fn=loss_pt,
                       accelerator=accelerator)

# Run training, then extract the history and convert it to a dataset.
output_pt = trainer_pt.train()
history_pt = output_pt['history']
dataset_pt = history_pt.get_dataset()

[92m──────────────────────────────────── [0mERA: [1;36m0[0m[92m ────────────────────────────────────[0m


Output()

In [None]:
# NOTE(review): leftover post-mortem debugging. `%debug` opens an interactive
# debugger on the most recent traceback and can stall "Run All" — remove once
# the training error above is resolved.
%debug

## Compare

Explicitly, look at how training metrics compare from `TensorFlow` vs. `PyTorch` models.

First, we define a helper function for plotting both `tensorflow` and `pytorch` metrics (converted to `np.ndarray`s) on the same graph:

In [None]:
from __future__ import annotations
def plot_both(
        ytf: np.ndarray,
        ypt: np.ndarray,
        figsize: Optional[tuple[float, float]] = FIGSIZE,
        xlabel: Optional[str] = None,
        ylabel: Optional[str] = None,
) -> tuple[plt.Figure, plt.Axes]:
    """Plot a TensorFlow and a PyTorch metric on the same axes.

    Each series is plotted against its own index along the first axis, so the
    two inputs may have different lengths.

    Args:
        ytf: TensorFlow metric values (labelled 'Tensorflow').
        ypt: PyTorch metric values (labelled 'Pytorch').
        figsize: ``(width, height)`` passed to ``plt.subplots``; ``None`` uses
            matplotlib's configured default.
        xlabel: Optional x-axis label.
        ylabel: Optional y-axis label.

    Returns:
        The created ``(fig, ax)`` pair.
    """
    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(np.arange(ytf.shape[0]), ytf, label='Tensorflow')
    ax.plot(np.arange(ypt.shape[0]), ypt, label='Pytorch')
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(True, color='#252525')
    # Label each line directly (via matplotx) instead of using a legend box.
    matplotx.line_labels(ax=ax)
    return fig, ax

## Compare training loss:

In [None]:
plt.style.use(matplotx.styles.dufte)
plt.rcParams.update({'axes.grid': True})
# Use dedicated names for the loss curves. `loss_tf` / `loss_pt` hold the loss
# *functions* passed to the trainers; rebinding them to arrays here would break
# any later re-run of the training cells.
loss_history_tf = np.array(history_tf.history['loss'])
loss_history_pt = np.array(history_pt.history['loss'])

fig, ax = plot_both(ytf=loss_history_tf, ypt=loss_history_pt, figsize=(9, 4),
                    xlabel='Train Epoch', ylabel='Loss')
ax.grid(axis='y', color='#686868')

# Plot data

In [None]:
import warnings
# NOTE(review): this silences every warning for the rest of the session,
# including deprecation and numerical warnings — consider a narrower filter.
warnings.filterwarnings('ignore')

### Aggregate metrics to plot

1. Identify all keys across both histories
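2. For each key, collect the corresponding variable from the `TensorFlow` and `PyTorch` datasets, keeping track of the metrics that are present in both.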

In [None]:
# Union of the metric names from both histories, de-duplicated while keeping
# first-seen order, so metrics shared by both are only processed once.
keys = list(dict.fromkeys(
    list(history_tf.history.keys()) + list(history_pt.history.keys())
))

ytf_dict = {}    # metrics present in the TensorFlow dataset
ypt_dict = {}    # metrics present in the PyTorch dataset
yboth_dict = {}  # metrics present in both, as {'pt': ..., 'tf': ...}
for k in keys:
    ytf = dataset_tf.data_vars.get(k, None)
    ypt = dataset_pt.data_vars.get(k, None)
    if ytf is not None:
        ytf_dict[k] = ytf
    if ypt is not None:
        ypt_dict[k] = ypt
    if ytf is not None and ypt is not None:
        yboth_dict[k] = {
            'pt': ypt,
            'tf': ytf,
        }

In [None]:
# Inspect which metrics are available per framework and in common.
# NOTE(review): by default only the last expression of a cell is displayed.
ytf_dict.keys()
ypt_dict.keys()
yboth_dict.keys()

### Plot `TensorFlow` training data:

In [None]:
# Reset to matplotlib defaults, then apply the dufte style.
plt.style.use('default')
plt.style.use(matplotx.styles.dufte)
import src.l2hmc.utils.plot_helpers as hplt
sns.set_palette(list(colors.values()))
sns.set_context('paper')
plt.rcParams.update({
    'figure.dpi': plt.rcParamsDefault['figure.dpi'],
    'figure.figsize': plt.rcParamsDefault['figure.figsize'],
    # RGBA with alpha 0: transparent backgrounds and edges.
    'figure.facecolor': (0, 0, 0, 0.0),
    'axes.facecolor': (0, 0, 0, 0.0),
    'figure.edgecolor': (0, 0, 0, 0.0),
    'axes.edgecolor': (0, 0, 0, 0.0),
    'axes.spines.bottom': False,
    'axes.spines.left': False,
})
# NOTE(review): applying the style again after `rcParams.update` may overwrite
# any overlapping keys set just above — confirm this ordering is intended.
plt.style.use(matplotx.styles.dufte)
#plt.rcParams['axes.spines.bottom'] = False
# One figure per TensorFlow metric.
for key, val in ytf_dict.items():
    fig, subfigs, axes = hplt.plot_dataArray(val,
                                             key=key,
                                             num_chains=5,
                                             title='Tensorflow');

### Plot `PyTorch` training data:

In [None]:
# Switch seaborn back to the 'notebook' context (the TensorFlow plots above
# used 'paper').
sns.set_context('notebook')

In [None]:
plt.style.use(matplotx.styles.dufte)
# One figure per PyTorch metric, with the left and bottom spines removed.
# NOTE(review): `num_chains=6` here vs `num_chains=5` for the TensorFlow plots —
# confirm the difference is intentional.
for key, val in ypt_dict.items():
    fig, subfigs, axes = hplt.plot_dataArray(val,
                                             key=key,
                                             num_chains=6,
                                             title='Pytorch')
    sns.despine(fig=fig, left=True, bottom=True, trim=True)

## Plot metrics common to both training datasets for comparison:

In [None]:
for key, val in yboth_dict.items():
    with plt.style.context(matplotx.styles.dufte):
        ytf, ypt = val['tf'], val['pt']

        # Two-dimensional metrics are averaged over the 'chain' dimension so
        # that a single curve per framework is plotted.
        if ytf.ndim == 2:
            ytf, ypt = ytf.mean('chain'), ypt.mean('chain')

        fig, ax = plot_both(
            ytf=ytf,
            ypt=ypt,
            figsize=FIGSIZE,
            xlabel='Train Epoch',
            ylabel=key,
        )
        # Hide the vertical grid lines.
        _ = ax.grid(visible=False, axis='x')

## Plot **all** `TensorFlow` training metrics:

In [None]:
# Plot every recorded TensorFlow metric, with the dufte style applied only
# inside this block.
with plt.style.context(matplotx.styles.dufte):
    _ = history_tf.plot_all(num_chains=10, title="Tensorflow")

## Plot **all** `PyTorch` training metrics:

In [None]:
# Plot every recorded PyTorch metric, with the dufte style applied only inside
# this block.
with plt.style.context(matplotx.styles.dufte):
    _ = history_pt.plot_all(num_chains=10, title='Pytorch')

In [None]:
# Ridge plots of the PyTorch training dataset; the return value is discarded.
_ = hplt.make_ridgeplots(dataset_pt)