# Example script for training the L2HMC sampler

 - **NOTE**:
   - The following results were generated on my local MacBook Pro with a 2.3 GHz 8-Core Intel Core i9 CPU
   - The following notebook can be run WITHOUT having `Horovod` installed

## Imports / setup for training

In [None]:
# IPython setup: load the `autoreload` extension and (mode 2) reload imported
# modules before each cell runs, so edits to the project code are picked up
# without restarting the kernel; render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os
# TensorFlow runtime switches, exported here (before `import tensorflow` in
# the next cell) so TF sees them when it reads its flags:
#   - XLA auto-clustering (JIT) enabled for XLA devices,
#   - automatic mixed precision enabled.
os.environ.update({
    'TF_XLA_FLAGS': '--tf_xla_auto_jit=2 --tf_xla_enable_xla_devices',
    'TF_ENABLE_AUTO_MIXED_PRECISION': "True",
})

In [None]:
import sys
import time
import json
import logging
import datetime

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import tensorflow as tf
# Put the parent of the current working directory on the import path (only if
# it is not already there) so the project packages imported below
# (`utils`, `config`) can be found.
sys.path.extend(
    [path for path in [os.path.abspath('..')] if path not in sys.path]
)
    
from utils.hvd_init import RANK, SIZE

from rich.theme import Theme
from rich import print, get_console

from config import PROJECT_DIR, BIN_DIR
from utils.logger import Logger, print_dict

# Global plotting defaults. `plt.style.use('default')` resets matplotlib's
# rcParams (including the color cycle), so any palette set before it would be
# discarded -- the palette is therefore applied only afterwards.
plt.style.use('default')
sns.set_context('talk')
sns.set_style('whitegrid')
sns.set_palette('bright')

# Render text with matplotlib's built-in mathtext rather than a LaTeX install.
plt.rc('text', usetex=False)

logger = Logger()
console = get_console()
# Use rich's public `width` setter rather than the private `_width` attribute.
console.width = 180

In [None]:
# Compact notebook look: matplotlib defaults, seaborn "ticks" style, slightly
# smaller fonts, and a custom nine-color palette.
colors = ('#228BE6 #FA5252 #40C057 '
          '#FF920B #BE4BDB #FAB005 '
          '#E64980 #6A777E #4C6EF5').split()

plt.style.use('default')
sns.set_style('ticks')
sns.set_context('notebook', font_scale=0.8)
sns.set_palette(colors)


## Load configs from `BIN_DIR/train_configs.json`:

In [None]:
import json
from config import BIN_DIR

from utils.logger import Logger, print_dict

logger = Logger()

# Start from the default training configuration stored in
# `BIN_DIR/train_configs.json`, then apply this notebook's overrides.
train_configs_file = os.path.join(BIN_DIR, 'train_configs.json')
with open(train_configs_file, 'rt') as f:
    configs = json.load(f)

# Log directory of an earlier run, passed as `restore_from` -- presumably
# training resumes from it (TODO confirm). NOTE(review): this is a cluster
# file-system path and will not exist on a local machine.
restore_from = os.path.abspath(
    '/lus/grand/projects/DLHMC/l2hmc-qcd/logs/GaugeModel_logs/2021_08/'
    'L16_b256_lf10_actswish_bi1_bf2_dp0025_nh16161616_sepNets_NCProj_ConvNets_bNorm_nw111111/'
)

# Top-level training schedule, logging cadence and beta annealing range.
# (Keyword-argument `update` keeps the same key insertion order.)
configs.update(
    ensure_new=False,
    train_steps=100000,
    debug=False,
    run_steps=20000,
    save_steps=20000,
    steps_per_epoch=5000,
    patience=2,
    min_lr=1e-4,
    logging_steps=1000,
    print_steps=1000,
    beta_init=1.,
    beta_final=3.,
    restore_from=restore_from,
)

# Sampler settings. `x_shape` is presumably (chains, L, L, link direction)
# -- confirm against the dynamics implementation.
configs['dynamics_config'].update(
    use_conv_net=True,
    use_mixed_loss=False,
    aux_weight=0.0,
    num_steps=10,
    x_shape=[256, 16, 16, 2],
)

configs['network_config'].update(units=[16, 16, 16, 16], dropout_prob=0.025)

# Conv-net input is `x_shape` without its leading axis (presumably the batch).
configs['conv_config'].update(
    input_shape=configs['dynamics_config']['x_shape'][1:],
)

logger.log(print_dict(configs, name='Configs'))

In [None]:
import warnings
import logging

# Silence every Python warning for the rest of the session.
#
# NOTE: the second positional argument of `warnings.filterwarnings` is a
# *message* regex, not a category name; filtering by class needs
# `category=...`. The blanket filter below already matches every `Warning`
# subclass, so no per-class filters are needed.
warnings.filterwarnings('ignore')

# Output of the form `WARNING:matplotlib:...` is the stdlib `logging` default
# format (`LEVEL:logger-name:message`), i.e. it comes from the `matplotlib`
# logger, which `warnings` filters cannot affect. Raise that logger's threshold.
logging.getLogger('matplotlib').setLevel(logging.ERROR)

## Run training:

In [None]:
from rich.markdown import Markdown
from utils.training_utils import train 
import seaborn as sns

plt.style.use('default')
sns.set_style('whitegrid')
sns.set_context('notebook')
colors = ['#228BE6', '#FA5252', '#40C057',
          '#FF920B', '#BE4BDB', '#FAB005',
          '#E64980', '#6A777E', '#4C6EF5']
sns.set_palette(colors)

# Number of chains drawn in the diagnostic plots (a subset, to speed up
# plotting).
num_chains = 16

# Draw the initial configuration uniformly from [-pi, pi) and flatten all
# axes after the first into a single vector per sample.
x_shape = configs['dynamics_config'].get('x_shape', None)
if x_shape is None:
    raise ValueError("configs['dynamics_config'] must define 'x_shape'")
x = tf.random.uniform(x_shape, minval=-np.pi, maxval=np.pi)
x = tf.reshape(x, (x.shape[0], -1))

# A space after '#' is required for Markdown (CommonMark) to render a heading.
logger.console.log(Markdown('# Training'))
train_outputs = train(configs, x=x,
                      make_plots=True,
                      num_chains=num_chains)

In [None]:
train_dataset = train_outputs.data.get_dataset()

# Place each recorded draw on the training-step axis, assuming the draws are
# evenly spaced `xscale` steps apart.
draws = len(train_dataset.loss)
xscale = configs['train_steps'] // draws
x = np.arange(configs['train_steps'])[:draws] * xscale

# Series for the training-summary plot; per-chain quantities are averaged
# over the `chain` dimension.
loss, beta = train_dataset.loss, train_dataset.beta
dq_int, dq_sin, px = (
    getattr(train_dataset, name).mean(dim='chain')
    for name in ('dq_int', 'dq_sin', 'accept_prob')
)

In [None]:
# Plot every THIN-th point to keep the figure light.
THIN = 15

fig, ax = plt.subplots(figsize=(9, 4), dpi=150,
                       constrained_layout=True)

# Left axis: training loss.
ax.plot(x[::THIN], loss[::THIN], color='C0', label='Loss')
ax.set_ylabel('Loss', color='C0')
ax.tick_params(axis='y', labelcolor='C0')
ax.set_xlabel('Training step')
ax.grid(False)

# Right axis (shared x): beta schedule, acceptance probability and the
# chain-averaged integer / real topological-charge differences.
ax1 = ax.twinx()
ax1.grid(False)
ax1.plot(x[::THIN], beta[::THIN], color='k', ls='--', label=r"$\beta$")
ax1.plot(x[::THIN], px[::THIN], color='C2', alpha=0.9, label=r"$A(\xi'|\xi)$")
ax1.plot(x[::THIN], dq_int[::THIN], color='C3', alpha=0.9, label=r"$\delta\mathcal{Q}_{\mathbb{Z}}$")
ax1.plot(x[::THIN], dq_sin[::THIN], color='C4', alpha=0.9, label=r"$\delta\mathcal{Q}_{\mathbb{R}}$")

# One legend for both axes, so the 'Loss' line drawn on `ax` is listed too.
handles, labels = ax.get_legend_handles_labels()
handles1, labels1 = ax1.get_legend_handles_labels()
ax1.legend(handles + handles1, labels + labels1, fontsize='small', ncol=2);


## Run inference on trained Model

In [None]:
from utils.inference_utils import run as run_inference

# Shorter evaluation run of the freshly trained sampler at `beta_final`.
configs.update(run_steps=5000)
# NOTE: this rebinds `beta` (the training beta series above) to the
# final-beta value from the configs.
log_dir, beta = (configs.get(key) for key in ('log_dir', 'beta_final'))
dynamics = train_outputs.dynamics

inference_results = run_inference(
    dynamics=dynamics,  # the dynamics object returned by `train`
    configs=configs,
    md_steps=0,
    beta=beta,
    make_plots=True,
    therm_frac=0.,
    num_chains=16,
    save_x=False,
)

### Make plots from training data

- **NOTE**: Primed quantities below refer to the (modified) proposal configurations at the end of each trajectory, i.e. for an initial configuration $\xi=(x, v, \pm)$, we generate a proposal configuration by passing $\xi$ through $N_{\mathrm{LF}}$ *leapfrog layers*, i.e.
$\xi\rightarrow \xi_1\rightarrow \cdots\rightarrow \xi_{N_{\mathrm{LF}}-1}\rightarrow \xi_{N_{\mathrm{LF}}} \equiv \xi'$

- Mostly, we are interested in the following quantities:

    - Various **stepsizes**, $\varepsilon^{k}_{x},\, \varepsilon^{k}_{v}$, for $k = 0, 1, \ldots, N_{\mathrm{LF}}$

    - **Error in the average plaquette**: $\langle\delta x_{P}\rangle \equiv x_{P}^{*} - \langle x_{P}\rangle$, where $x_{P}^{*}$ is the exact result in the infinite-volume limit, given by
      $x_{P}^{*}(\beta) = \tfrac{I_{1}(\beta)}{I_{0}(\beta)}$, with $I_{n}$ the modified Bessel functions of the first kind

    - **Error in the average $4\times4$ Wilson loop** $\langle\mathcal{W}_{4\times4}\rangle$

    - **log Jacobian factor**, denoted by `sumlogdet`, i.e. the total log determinant of the Jacobian of transformation $\xi\rightarrow\xi'$, given by:
      $\sum\left|\mathcal{J}\right| \equiv \log\left|\tfrac{\partial\xi'}{\partial\xi^{T}}\right|$

    - **Effective energy**, $\tilde{\mathcal{H}} = \mathcal{H} - \sum\left|\mathcal{J}\right|$

    - **Acceptance probability** $A(\xi'|\xi) = \min\left\{1, \tfrac{p(\xi')}{p(\xi)}\left|\tfrac{\partial\xi'}{\partial\xi^{T}}\right|\right\}$

    - *Integer-valued* **topological charge** $\mathcal{Q}_{\mathbb{Z}} \equiv \tfrac{1}{2\pi}\sum_{P}\left\lfloor x_{P}\right\rfloor$, where 
      $\left\lfloor x_{P}\right\rfloor \equiv  x_{P} - 2\pi\left\lfloor\tfrac{ x_{P}+\pi}{2\pi}\right\rfloor$

    - *Real-valued* **topological charge** $\mathcal{Q}_{\mathbb{R}} \equiv \tfrac{1}{2\pi}\sum_{P}\sin x_{P}$

    - **Tunneling rates** $\delta\mathcal{Q}_{\mathcal{X}}(x', x) \equiv |\mathcal{Q}_{\mathcal{X}}' -  \mathcal{Q}_{\mathcal{X}}|$ for $\mathcal{X} \in \left\{\mathbb{R}, \mathbb{Z}\right\}$.

    - **Loss** $\mathcal{L}_{\theta}\left(x', x, A(\xi'|\xi)\right) = -A(\xi'|\xi)\cdot \delta\mathcal{Q}_{\mathbb{R}}$
    
    - **Integrated autocorrelation time** $\tau_{\mathcal{Q}_{\mathbb{Z}}}$, which serves as a useful metric for quantifying the model's improvement over HMC.

In [None]:
import utils.file_io as io
from utils.plotting_utils import set_size, make_ridgeplots, plot_data

logger.rule('Plotting training dataset')

# Single source of truth for how many chains are drawn in the per-chain plots
# (previously a separate, unused variable disagreed with the value passed).
num_chains_to_plot = 16

plt.style.use('default')
sns.set_style('ticks')
sns.set_context('notebook', font_scale=0.8)
colors = ['#228BE6', '#FA5252', '#40C057',
          '#FF920B', '#BE4BDB', '#FAB005',
          '#E64980', '#6A777E', '#4C6EF5']
sns.set_palette(colors)

output = plot_data(train_outputs.data,
                   out_dir=None,
                   configs=configs,
                   therm_frac=0.,
                   num_chains=num_chains_to_plot,
                   cmap='viridis_r',
                   logging_steps=configs['logging_steps'])

## Plot inference results

In [None]:
dynamics = inference_results.dynamics
run_data = inference_results.run_data

# Mean x- and v-update step sizes, and the mean of the two.
xeps_avg, veps_avg = (
    tf.reduce_mean(eps) for eps in (dynamics.xeps, dynamics.veps)
)
eps_avg = (xeps_avg + veps_avg) / 2.

# Run metadata forwarded to `plot_data` (keyword form keeps the same key order).
run_params = dict(
    hmc=dynamics.config.hmc,
    beta=beta,
    run_steps=configs['run_steps'],
    plaq_weight=dynamics.plaq_weight,
    charge_weight=dynamics.charge_weight,
    x_shape=dynamics.x_shape,
    num_steps=dynamics.config.num_steps,
    net_weights=dynamics.net_weights,
    input_shape=dynamics.x_shape,  # same value as `x_shape`
    xeps=dynamics.xeps,
    veps=dynamics.veps,
    xeps_avg=xeps_avg,
    veps_avg=veps_avg,
    eps_avg=eps_avg,
    traj_len=tf.reduce_sum(dynamics.xeps),  # sum of the x step sizes
)

output = plot_data(
    run_data,
    out_dir=None,
    configs=configs,
    therm_frac=0.,
    num_chains=16,
    params=run_params,
    hmc=dynamics.config.hmc,
    cmap='crest',
    logging_steps=1,
)