# `l2hmc-qcd`

This notebook contains a minimal working example for the 4D SU(3) model.

It uses `torch.complex128` by default.

## Setup

In [None]:
! nvidia-smi | tail --lines -7

In [None]:
# automatically detect and reload local changes to modules
%load_ext autoreload
%autoreload 2

In [None]:
import os
# --------------------------------------
# BE SURE TO GRAB A FRESH GPU !
#os.environ['CUDA_VISIBLE_DEVICES'] = '2'
#!echo $CUDA_VISIBLE_DEVICES
# --------------------------------------

In [None]:
#devices = os.environ.get('CUDA_VISIBLE_DEVICES', None)
#print(devices)
!getconf _NPROCESSORS_ONLN

In [None]:
from __future__ import absolute_import, print_function, annotations, division

import warnings

import torch
import numpy as np

from hydra.core.global_hydra import GlobalHydra

warnings.filterwarnings('ignore')

os.environ['MASTER_PORT'] = '12345'
#os.environ['OMP_NUM_THREADS'] = '256'

np.set_printoptions(threshold=5)
torch.set_printoptions(threshold=5, precision=5)

In [None]:
from l2hmc.utils.dist import setup_torch

# Call the project's torch setup helper with float64 precision, the DDP
# backend and a fixed seed. Its return value is not used in this notebook.
_ = setup_torch(precision='float64', backend='DDP', seed=1234)

In [None]:
# `log` is used by the cells in the "Train" section, so it has to exist on a
# fresh kernel. The project logger was previously commented out here:
#   from l2hmc import get_logger
#   log = get_logger(level='INFO')
# A standard-library logger is used instead so this cell has no dependency on
# the project's logging helpers.
import logging

log = logging.getLogger('l2hmc-notebook')
log.setLevel(logging.INFO)
if not log.handlers:
    # Guarded so that re-running the cell does not attach duplicate handlers.
    log.addHandler(logging.StreamHandler())
# Keep records on this logger's own handler only, so they are not emitted a
# second time by whatever handlers are installed on the root logger.
log.propagate = False

In [None]:
from l2hmc.experiment.pytorch.experiment import Experiment as ptExperiment

import l2hmc.group.su3.pytorch.group as gpt
ptsu3 = gpt.SU3()

import l2hmc
l2hmc.__file__

In [None]:
%matplotlib widget
import matplotlib.pyplot as plt

from l2hmc.common import grab_tensor, print_dict
from l2hmc.utils.plot_helpers import set_plot_style

set_plot_style()

from l2hmc.utils.plot_helpers import (  # noqa
    set_plot_style,
    plot_scalar,
    plot_chains,
    plot_leapfrogs
)


def plot_metrics(metrics: dict, title: Optional[str] = None):
    for key, val in metrics.items():
        fig, ax = plot_metric(val, name=key)
        if title is not None:
            ax.set_title(title)


def plot_metric(
        metric: torch.Tensor,
        name: Optional[str] = None,
):
    assert len(metric) > 0
    if isinstance(metric[0], (int, float, bool, np.floating)):
        y = np.stack(metric)
        return plot_scalar(y, ylabel=name)
    element_shape = metric[0].shape
    if len(element_shape) == 2:
        y = grab_tensor(torch.stack(metric))
        return plot_leapfrogs(y, ylabel=name)
    if len(element_shape) == 1:
        y = grab_tensor(torch.stack(metric))
        return plot_chains(y, ylabel=name)
    if len(element_shape) == 0:
        y = grab_tensor(torch.stack(metric))
        return plot_scalar(y, ylabel=name)
    raise ValueError

## Specify defaults for building Experiment

In [None]:
GlobalHydra.instance().clear()

In [None]:
# Baseline settings for the experiment. The cells below convert this dict into
# a list of override strings with `dict_to_list_of_overrides` and pass them to
# `get_experiment`.
DEFAULTS = {
    'save': False,
    'restore': False,
    'init_aim': False,
    'init_wandb': False,
    'backend': 'DDP',
    'framework': 'pytorch',
    'conv': None,
    'network': {
        'use_batch_norm': True,
        'activation_fn': 'tanh',
        'dropout_prob': 0.0,
        'units': [8, 8],
    },
    'dynamics': {
        'group': 'SU3',
        'eps': 0.1,
        'nchains': 2,
        'nleapfrog': 2,
        'verbose': True,
        'latvolume': [4, 4, 4, 4],
        'eps_fixed': False,
    },
    'net_weights': {
        'x': {
            's': 0.0,
            't': 1.0,
            'q': 1.0,
        },
        'v': {
            's': 1.0,
            't': 1.0,
            'q': 1.0,
        },
    },
    'loss': {
        'rmse_weight': 1.0,
        'plaq_weight': 0.0,
        'charge_weight': 0.0,
        'use_mixed_loss': False,
    },
    'steps': {
        'nera': 1,
        'nepoch': 10,
        'test': 50,
        'print': 1,
        'log': 1,
    },
    'learning_rate': {
        'lr_init': 0.00001,
        # NOTE: the build cell also passes 'learning_rate.clip_norm=1.0' after
        # these defaults.
        'clip_norm': 1e-8,
    },
    'annealing_schedule': {
        # Same initial and final value; the evaluation and training cells
        # below also use beta=6.0.
        'beta_init': 6.0,
        'beta_final': 6.0,
    },
}

In [None]:
import gc

# Release Python garbage and torch's cached memory unconditionally. These
# steps must not depend on the optional GPUtil package being installed.
gc.collect()
torch.cuda.empty_cache()
torch.clear_autocast_cache()

# GPUtil is only needed for the utilization report, so a missing install
# skips just the report.
try:
    import GPUtil
except ImportError:
    pass
else:
    GPUtil.showUtilization()

## Build Experiment

In [None]:
from l2hmc.configs import dict_to_list_of_overrides
OVERRIDES = dict_to_list_of_overrides(DEFAULTS)
OVERRIDES

In [None]:
from l2hmc.configs import get_experiment  # noqa
from l2hmc.configs import dict_to_list_of_overrides

# Clear the global Hydra instance before building the experiment
# (presumably so this cell can be re-run in the same kernel - confirm).
GlobalHydra.instance().clear()

# Recomputed here so this cell does not depend on the previous display cell.
OVERRIDES = dict_to_list_of_overrides(DEFAULTS)

# NOTE(review): the explicit entries after `*OVERRIDES` repeat keys that are
# already present in DEFAULTS; 'learning_rate.clip_norm=1.0' differs from the
# DEFAULTS value (1e-8). Which value takes effect depends on how
# `get_experiment` resolves repeated overrides - TODO confirm.
ptExpSU3 = get_experiment(
    overrides=[
        *OVERRIDES,
        'framework=pytorch',
        'backend=DDP',
        'init_wandb=False',
        'init_aim=False',
        'learning_rate.clip_norm=1.0',
        'restore=false',
        'save=false',
    ],
    build_networks=True,
)

## Evaluate

In [None]:
# Ask the dynamics for a state at beta=6.0 (presumably randomly initialised,
# per the method name) and keep its configuration `x`. `x0` is the shared
# starting point for the HMC and eval runs below.
state = ptExpSU3.trainer.dynamics.random_state(6.0)
x0 = state.x

### HMC

In [None]:
from l2hmc.experiment.pytorch.experiment import evaluate  # noqa

# Run 100 steps with job_type='hmc' starting from x0. Note that eps and
# nleapfrog are passed explicitly here (nleapfrog=4, versus 2 in DEFAULTS).
# `history_hmc` is indexed by metric name in later cells
# (e.g. history_hmc['plaqs']) and plotted with `plot_metrics`.
xhmc, history_hmc = evaluate(
    nsteps=100,
    exp=ptExpSU3,
    beta=6.0,
    x=x0,
    eps=0.1,
    nleapfrog=4,
    job_type='hmc',
    nlog=1,
    nprint=1,
    grab=True
)

In [None]:
plot_metrics(history_hmc, title='HMC')

### Inference

In [None]:
# Re-initialise the dynamics' weights with method='uniform' in the range
# [-1e-16, 1e-16] (bias=True), i.e. effectively zero.
ptExpSU3.trainer.dynamics.init_weights(
    method='uniform',
    min=-1e-16,
    max=1e-16,
    bias=True,
)

# NOTE: in the visible notebook, `nlf` and `eps` are only referenced by the
# commented-out xeps / veps assignments below.
nlf = ptExpSU3.trainer.dynamics.config.nleapfrog
eps = torch.tensor(0.10)

#ptExpSU3.trainer.dynamics.xeps = torch.stack(
#    [
#        eps for _ in range(nlf)
#    ]
#)
#ptExpSU3.trainer.dynamics.veps = torch.stack(
#    [
#        eps for _ in range(nlf)
#    ]
#)

In [None]:
np.set_printoptions(precision=5)
torch.set_printoptions(precision=5)

In [None]:
# Dump every named parameter of the dynamics. What `grab=True` does is decided
# inside `print_dict` (presumably it converts tensors for display - confirm).
_ = print_dict(dict(ptExpSU3.trainer.dynamics.named_parameters()), grab=True)

In [None]:
from l2hmc.experiment.pytorch.experiment import evaluate  # noqa

#ptExpSU3.trainer.dynamics.init_weights()
#    #constant=np.random.randn() / 1e10
#    constant=0.0,
#)

xeval, history_eval = evaluate(
    nsteps=100,
    exp=ptExpSU3,
    beta=6.0,
    x=x0,
    job_type='eval',
    nlog=1,
    nprint=1,
    grab=True,
)

In [None]:
plot_metrics(history_eval, title='Evaluate')

In [None]:
_ = plot_metric(history_hmc['plaqs'], name='hmc/plaqs')

In [None]:
# Step-by-step comparison of the plaquette histories from the HMC run (`ph`)
# and the eval run (`pe`). `zip` yields (hmc, eval) pairs, so the loop
# variables must be unpacked in that order for `pratio` to be hmc / eval, as
# it is labelled in the plot below.
pratio = [
    ph / pe for (ph, pe) in zip(history_hmc['plaqs'], history_eval['plaqs'])
]
# The absolute difference is symmetric, so it does not depend on the order.
pdiff = [
    (ph - pe).abs() for (ph, pe) in zip(history_hmc['plaqs'], history_eval['plaqs'])
]

In [None]:
_ = plot_metric(pratio, name='plaqs [hmc / eval]')

In [None]:
_ = plot_metric(pdiff, name='plaqs [abs(hmc - eval)]')

### Training

In [None]:
from l2hmc.trainers.pytorch.trainer import Trainer  # noqa


def train_step(
        x: torch.Tensor,
        beta: float | torch.Tensor,
        trainer: Trainer,
):
    """Run a single optimisation step by hand and return the updated configuration.

    The step zeroes the optimizer's gradients, runs ``trainer.dynamics_engine``
    on ``(x, beta)``, computes ``trainer.calc_loss`` from the initial and
    proposed configurations, backpropagates and calls ``optimizer.step()``.

    Args:
        x: Input configuration. ``requires_grad`` is switched on in place.
        beta: Passed to the dynamics; plain Python numbers are wrapped in a
            tensor first.
        trainer: Object providing ``optimizer``, ``dynamics_engine`` and
            ``calc_loss``.

    Returns:
        A tuple ``(xout, metrics)``: the detached output of the dynamics and
        the metrics dict with ``'loss'`` added and ``'mc_states'`` removed.
    """
    if isinstance(beta, (int, float)):
        beta = torch.tensor(float(beta))
    x.requires_grad_(True)
    trainer.optimizer.zero_grad()
    xout, metrics = trainer.dynamics_engine((x, beta))
    xprop = metrics.pop('mc_states').proposed.x
    # Alternative loss kept from the original cell for reference:
    #dx = (x - xprop).abs().flatten(1).mean(-1)
    #loss = (metrics['acc'] * dx).mean()
    loss = trainer.calc_loss(
        xinit=x,
        xprop=xprop,
        acc=metrics['acc']
    )
    loss.backward()
    trainer.optimizer.step()
    metrics = {'loss': loss.item(), **metrics}
    print_dict(metrics)
    # Return the output of the dynamics, not the input. Returning `x` would
    # hand the caller back the configuration it passed in.
    return xout.detach(), metrics

In [None]:
trainer = ptExpSU3.trainer
dynamics = trainer.dynamics

_ = print_dict({
    k: v.grad for k, v in dynamics.named_parameters()
})

In [None]:
_ = print_dict({
    k: v for k, v in dynamics.named_parameters()
}, grab=False)

In [None]:
beta = 6.0
state = dynamics.random_state(beta)
xinit = state.x
xinit.requires_grad_(True)
xout, metrics = train_step(x=xinit, beta=beta, trainer=ptExpSU3.trainer)

In [None]:
x, metrics = train_step(x=x0, beta=beta, trainer=ptExpSU3.trainer)

In [None]:
print_dict??

In [None]:
_ = print_dict(
    {
        k: v.grad for k, v in dynamics.named_parameters()
    },
    grab=True,
)

In [None]:
dynamics

In [None]:
_ = print_dict(
    {
        k: v for k, v in dynamics.named_parameters()
    },
    grab=True,
)

In [None]:
ptExpSU3.trainer.dynamics.init_weights(constant=0.0)

In [None]:
ptExpSU3.trainer.optimizer.step()

In [None]:
ptExpSU3.trainer.optimizer.zero_grad()

In [None]:
_ = print_dict({
    k: torch.nan_to_num(v) for k, v in dynamics.named_parameters()
}, grab=True)

In [None]:
xout, metrics = train_step(x=xinit.requires_grad_(True), beta=beta, trainer=ptExpSU3.trainer)

In [None]:
# NOTE(review): this cell originally read `dynamics.networks.vnet.1.tran`,
# which is a SyntaxError (an attribute name cannot start with a digit) and
# looks like a truncated parameter path. Dotted paths with numeric components
# can be resolved with `get_submodule`; extend the path as needed.
dynamics.get_submodule('networks.vnet.1')

In [None]:
xout, metrics = train_step(x=xinit, beta=beta, trainer=ptExpSU3.trainer)

In [None]:
x = xout

In [None]:
trainer.optimizer.zero_grad()
# xout, metrics = trainer.dynamics_engine((xinit, beta))
xout, metrics = dynamics((x, torch.tensor(beta)))

In [None]:
xprop = dynamics.g.compat_proj(metrics.pop('mc_states').proposed.x)
loss = trainer.calc_loss(
    xinit=xinit,
    xprop=xprop,
    acc=metrics['acc']
)

In [None]:
loss.backward()
trainer.optimizer.step()
metrics = {
    'loss': loss.item(),
    **metrics,
}
_ = print_dict(metrics, grab=True)

In [None]:
loss

In [None]:
_ = print_dict(metrics, grab=True)

In [None]:
loss

In [None]:
beta = torch.tensor(6.0)
xout1, metrics = dynamics((xout, beta))

In [None]:
print_dict({
    k: v.grad for k, v in trainer.dynamics.named_parameters()
})
for name, param in trainer.dynamics.named_parameters():
    #print(name, torch.isnan(param.grad))
    console.print(f'{name}:\n{param.grad}')

In [None]:
x

In [None]:
loss

In [None]:
trainer = ptExpSU3.trainer
dynamics = trainer.dynamics
beta = torch.tensor(6.0).to(trainer.device)

x, metrics = train_step(x=xeval, beta=beta, trainer=trainer)

In [None]:
x, metrics = train_step(x=xhmc, beta=beta, trainer=trainer)

In [None]:
for name, param in ptExpSU3.trainer.dynamics.named_parameters():
    #print(name, torch.isnan(param.grad))
    console.print(f'{name}: {param.grad.sum()}')

In [None]:
loss.item()

In [None]:
# Manual training step: same sequence as `train_step` above, but with an
# additional term added to the loss before backpropagation.
x.requires_grad_(True)
trainer.optimizer.zero_grad()
xout, metrics = trainer.dynamics_engine((x, beta))
xprop = metrics.pop('mc_states').proposed.x
ploss = trainer.calc_loss(
    xinit=x,
    xprop=xprop,
    acc=metrics['acc']
)
# Sum of absolute differences between the proposed and the output
# configuration, taken over everything flattened from dim 1 onwards
# (assumes dim 0 indexes the chains - TODO confirm).
dx = (xprop.flatten(1) - xout.flatten(1)).abs().sum(-1)
# NOTE: despite the name, no square or square root is taken here; this is the
# mean of acc * dx.
rmse_loss = (metrics['acc'] * dx).mean()
loss = ploss + rmse_loss
loss.backward()
trainer.optimizer.step()

In [None]:
xout, metrics = trainer.forward_step(x, beta)
xprop = metrics.pop('mc_states').proposed.x
loss = trainer.calc_loss(xinit=x, xprop=xprop, acc=metrics['acc'])
#loss = ((x - xprop) ** 2).sum()
loss

In [None]:
x, metrics = ptExpSU3.trainer.train_step((x, state.beta))
print_dict(metrics)

In [None]:
from l2hmc.utils.rich import get_console
console = get_console()
for name, param in ptExpSU3.trainer.dynamics.named_parameters():
    #print(name, torch.isnan(param.grad))
    console.print(f'{name}: {param.grad}')

In [None]:
ptExpSU3.trainer.optimizer.zero_grad()

In [None]:
x, metrics = ptExpSU3.trainer.train_step((x, state.beta))
print_dict(metrics)

In [None]:
x, metrics = ptExpSU3.trainer.train_step((x, state.beta))
print_dict(metrics)

In [None]:
ptExpSU3.trainer.optimizer.zero_grad()

In [None]:
# Import rich's print under an alias so the builtin `print` is not shadowed
# for every cell that runs after this one.
from rich import print as rprint  # noqa

for name, param in ptExpSU3.trainer.dynamics.named_parameters():
    #print(name, torch.isnan(param.grad))
    rprint(f'{name}: {param.grad}')

In [None]:
loss

In [None]:
loss = trainer.backward_step(loss)

In [None]:
loss

In [None]:
from l2hmc.common import print_dict  # noqa

xinit = x.detach()
xout, metrics = trainer.forward_step(xinit, beta)
xprop = metrics.pop('mc_states').proposed.x
loss = trainer.calc_loss(xinit=xinit, xprop=xprop, acc=metrics['acc'])
loss = trainer.backward_step(loss)
print_dict(metrics)

In [None]:
x, metrics = ptExpSU3.trainer.train_step((x, state.beta))
print_dict(metrics)

In [None]:
ptExpSU3.trainer.grad_scaler

## Train

In [None]:
for step in range(10):
    log.info(f'Train step: {step}')
    x, metrics = ptExpSU3.trainer.train_step((x, state.beta))
    print_dict(metrics)

In [None]:
log.info('\n'.join([f'{k}={grab_tensor(v)}' for k, v in metrics.items()]))

In [None]:
state = ptExpSU3.trainer.dynamics.random_state(6.0)
x, metrics = ptExpSU3.trainer.train_step_detailed(x=state.x)

In [None]:
state = ptExpSU3.trainer.dynamics.random_state(6.0)
#x, metrics = ptExpSU3.trainer.train_step_detailed(x=state.x)
x = state.x
for _ in range(10):
    x, metrics = ptExpSU3.trainer.train_step_detailed(x=x)

In [None]:
%matplotlib widget
import seaborn as sns
from l2hmc.utils.plot_helpers import set_plot_style

set_plot_style()
sns.set(rc={"figure.dpi":100, 'savefig.dpi':300})
sns.set_context('notebook')
sns.set_style("ticks")
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('retina')

train_out = ptExpSU3.train(beta=6.0)

## Evaluation

In [None]:
from l2hmc.utils.plot_helpers import set_plot_style

# Re-apply the project's plot style; the training cell above changed the
# seaborn settings after calling it.
set_plot_style()

# Evaluate through the experiment's own `evaluate` method with
# job_type='eval' for 100 steps at beta=6.0.
eval_out = ptExpSU3.evaluate(
    job_type='eval',
    beta=6.0,
    eval_steps=100,
    nprint=1,
)

## Generic HMC

In [None]:
# Same entry point as the evaluation above, but with job_type='hmc'. It runs
# 10 steps and passes an explicit step size eps=0.075.
hmc_out = ptExpSU3.evaluate(
    job_type='hmc',
    beta=6.0,
    eval_steps=10,
    nprint=1,
    eps=0.075,
)