# `l2hmc-qcd`

This notebook contains a minimal working example for the 4D $SU(3)$ model

Uses `torch.complex128` by default

## Setup

In [1]:
# %matplotlib inline
# import matplotlib_inline
# matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import lovely_tensors as lt
lt.monkey_patch()
lt.set_config(color=False)
# automatically detect and reload local changes to modules
%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt
from l2hmc.utils.plot_helpers import FigAxes, set_plot_style
set_plot_style()

Using device: cpu


In [2]:
import os
from pathlib import Path
from typing import Optional
from rich import print

import lovely_tensors as lt
import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml

from l2hmc.utils.dist import setup_torch
seed = np.random.randint(2 ** 32)
print(f"seed: {seed}")
_ = setup_torch(precision='float64', backend='DDP', seed=seed)

import l2hmc.group.su3.pytorch.group as g
from l2hmc.utils.rich import get_console
from l2hmc.common import grab_tensor, print_dict
from l2hmc.configs import dict_to_list_of_overrides, get_experiment
from l2hmc.experiment.pytorch.experiment import Experiment, evaluate  # noqa  # noqa
from l2hmc.utils.plot_helpers import set_plot_style

os.environ['COLORTERM'] = 'truecolor'
os.environ['MASTER_PORT'] = '5439'
# os.environ['MPLBACKEND'] = 'module://matplotlib-backend-kitty'
# plt.switch_backend('module://matplotlib-backend-kitty')
console = get_console()


set_plot_style()

from l2hmc.utils.plot_helpers import (  # noqa
    set_plot_style,
    plot_scalar,
    plot_chains,
    plot_leapfrogs
)

def savefig(fig: plt.Figure, fname: str, outdir: os.PathLike):
    pngfile = Path(outdir).joinpath(f"pngs/{fname}.png")
    svgfile = Path(outdir).joinpath(f"svgs/{fname}.svg")
    pngfile.parent.mkdir(exist_ok=True, parents=True)
    svgfile.parent.mkdir(exist_ok=True, parents=True)
    fig.savefig(svgfile, transparent=True, bbox_inches='tight')
    fig.savefig(pngfile, transparent=True, bbox_inches='tight', dpi=300)

def plot_metrics(metrics: dict, title: Optional[str] = None, **kwargs):
    outdir = Path(f"./plots-4dSU3/{title}")
    outdir.mkdir(exist_ok=True, parents=True)
    for key, val in metrics.items():
        fig, ax = plot_metric(val, name=key, **kwargs)
        if title is not None:
            ax.set_title(title)
        console.log(f"Saving {key} to {outdir}")
        savefig(fig, f"{key}", outdir=outdir)
        plt.show()

def plot_metric(
        metric: torch.Tensor,
        name: Optional[str] = None,
        **kwargs,
):
    assert len(metric) > 0
    if isinstance(metric[0], (int, float, bool, np.floating)):
        y = np.stack(metric)
        return plot_scalar(y, ylabel=name, **kwargs)
    element_shape = metric[0].shape
    if len(element_shape) == 2:
        y = grab_tensor(torch.stack(metric))
        return plot_leapfrogs(y, ylabel=name)
    if len(element_shape) == 1:
        y = grab_tensor(torch.stack(metric))
        return plot_chains(y, ylabel=name, **kwargs)
    if len(element_shape) == 0:
        y = grab_tensor(torch.stack(metric))
        return plot_scalar(y, ylabel=name, **kwargs)
    raise ValueError

seed: [1;36m1444224244[0m
[38;2;105;105;105m[07/20/23 09:22:51][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mdist.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m338[0m[2;38;2;144;144;144m][0m - Global Rank: [38;2;32;148;243m0[0m [32m/[0m [38;2;32;148;243m0[0m


## Load config + build Experiment

In [3]:
from rich import print
set_plot_style()

from l2hmc.configs import CONF_DIR
su3conf = Path(f"{CONF_DIR}/su3test.yaml")
with su3conf.open('r') as stream:
    conf = dict(yaml.safe_load(stream))
# overrides = {
#     'backend': 'DDP',
#     'dynamics': {
#         'eps': 0.15,
#         'merge_directions': True,
#     },
#     'network': {
#         'use_batch_norm': True,
#     },
#     'loss': {
#         'use_mixed_loss': False,
#     },
#     'net_weights': {
#         'x': {
#             's': 0.0,
#             't': 0.0,
#             'q': 0.0,
#         },
#         'v': {
#             's': 1.0,
#             't': 1.0,
#             'q': 1.0,
#         },
#     }
# }
# conf |= overrides
console.print(conf)

[1m{[0m
    [38;2;80;161;79m'annealing_schedule'[0m: [1m{[0m[38;2;80;161;79m'beta_final'[0m: [38;2;32;148;243m6.0[0m, [38;2;80;161;79m'beta_init'[0m: [38;2;32;148;243m6.0[0m[1m}[0m,
    [38;2;80;161;79m'backend'[0m: [38;2;80;161;79m'DDP'[0m,
    [38;2;80;161;79m'conv'[0m: [38;2;80;161;79m'none'[0m,
    [38;2;80;161;79m'dynamics'[0m: [1m{[0m
        [38;2;80;161;79m'eps'[0m: [38;2;32;148;243m0.01[0m,
        [38;2;80;161;79m'eps_fixed'[0m: [3;91mFalse[0m,
        [38;2;80;161;79m'group'[0m: [38;2;80;161;79m'SU3'[0m,
        [38;2;80;161;79m'latvolume'[0m: [1m[[0m[38;2;32;148;243m4[0m, [38;2;32;148;243m4[0m, [38;2;32;148;243m4[0m, [38;2;32;148;243m8[0m[1m][0m,
        [38;2;80;161;79m'nchains'[0m: [38;2;32;148;243m4[0m,
        [38;2;80;161;79m'nleapfrog'[0m: [38;2;32;148;243m4[0m,
        [38;2;80;161;79m'use_separate_networks'[0m: [3;92mTrue[0m,
        [38;2;80;161;79m'use_split_xnets'[0m: [3;92mTrue[0m,
        [3

In [4]:
overrides = dict_to_list_of_overrides(conf)
ptExpSU3 = get_experiment(overrides=[*overrides], build_networks=True)
console.print(ptExpSU3.config)
state = ptExpSU3.trainer.dynamics.random_state(6.0)
console.print(f"checkSU(state.x): {g.checkSU(state.x)}")
console.print(f"checkSU(state.x): {g.checkSU(g.projectSU(state.x))}")
assert isinstance(state.x, torch.Tensor)
assert isinstance(state.beta, torch.Tensor)
assert isinstance(ptExpSU3, Experiment)

[38;2;105;105;105m[07/20/23 09:23:05][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mdist.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m226[0m[2;38;2;144;144;144m][0m - Caught MASTER_PORT:[38;2;32;148;243m5439[0m from environment!
[38;2;105;105;105m[07/20/23 09:23:05][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mdist.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m226[0m[2;38;2;144;144;144m][0m - Caught MASTER_PORT:[38;2;32;148;243m5439[0m from environment!
[38;2;105;105;105m[07/20/23 09:23:05][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mtrainer.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m285[0m[2;38;2;144;144;144m][0m - num_params in model: [38;2;32;148;243m89762056[0m
[1;38;2;255;0;255mExperimentConfig[0m[1m([0m
    [1m{[0m[38;2;80;161;79m'setup'[0m: [1m{[0m[38;2;80;161;79m'id'[0m: [3;38;2;255;0;255mNone[0m, [38;2;80;161;79m'group'[0m: [3;38;2;255;0;255mNone[0m, [38;2;

## HMC

In [None]:
from l2hmc.utils.plot_helpers import set_plot_style
set_plot_style()

In [None]:
from l2hmc.common import get_timestamp
TSTAMP = get_timestamp()
OUTPUT_DIR = Path(f"./outputs/pt4dSU3/{TSTAMP}")
HMC_DIR = OUTPUT_DIR.joinpath('hmc')
EVAL_DIR = OUTPUT_DIR.joinpath('eval')
TRAIN_DIR = OUTPUT_DIR.joinpath('train')
HMC_DIR.mkdir(exist_ok=True, parents=True)
EVAL_DIR.mkdir(exist_ok=True, parents=True)
TRAIN_DIR.mkdir(exist_ok=True, parents=True)

In [5]:
xhmc, history_hmc = evaluate(
    nsteps=100,
    exp=ptExpSU3,
    beta=state.beta,
    x=state.x,
    eps=0.1,
    nleapfrog=4,
    job_type='hmc',
    nlog=5,
    nprint=10,
    grab=True
)

[38;2;105;105;105m[07/20/23 09:28:52][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mexperiment.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m117[0m[2;38;2;144;144;144m][0m - Running [38;2;32;148;243m100[0m steps of hmc at [38;2;125;134;151mbeta[0m=[38;2;32;148;243m6[0m[38;2;32;148;243m.0000[0m
[38;2;105;105;105m[07/20/23 09:28:52][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mexperiment.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m117[0m[2;38;2;144;144;144m][0m - Running [38;2;32;148;243m100[0m steps of hmc at [38;2;125;134;151mbeta[0m=[38;2;32;148;243m6[0m[38;2;32;148;243m.0000[0m
[38;2;105;105;105m[07/20/23 09:28:52][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mexperiment.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m121[0m[2;38;2;144;144;144m][0m - STEP: [38;2;32;148;243m0[0m
[38;2;105;105;105m[07/20/23 09:28:52][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;

In [6]:
import l2hmc.utils.plot_helpers as ph
xhmc = ptExpSU3.trainer.dynamics.unflatten(xhmc)
console.log(f"checkSU(x_hmc): {g.checkSU(xhmc)}")
dataset_hmc = history_hmc.get_dataset()
ph.plot_dataset(dataset_hmc, outdir=HMC_DIR)
# plot_metrics(history_hmc, title='HMC', marker='.')

[38;2;105;105;105m[09:34:20][0m[38;2;105;105;105m [0m[1;38;2;255;0;255mcheckSU[0m[1m([0mx_hmc[1m)[0m: [1m([0mtensor[1m[[0m[38;2;32;148;243m4[0m[1m][0m f64 x∈[1m[[0m[38;2;32;148;243m1.535e-16[0m, [38;2;32;148;243m3.394e-16[0m[1m][0m [38;2;125;134;151mμ[0m=[38;2;32;148;243m2[0m[38;2;32;148;243m.424e-16[0m [38;2;125;134;151mσ[0m=[38;2;32;148;243m7[0m[38;2;32;148;243m.754e-17[0m [1m[[0m[38;2;32;148;243m1.535e-16[0m, [38;2;32;148;243m2.570e-16[0m,  
[38;2;105;105;105m           [0m[38;2;32;148;243m2.197e-16[0m, [38;2;32;148;243m3.394e-16[0m[1m][0m, tensor[1m[[0m[38;2;32;148;243m4[0m[1m][0m f64 x∈[1m[[0m[38;2;32;148;243m7.452e-16[0m, [38;2;32;148;243m9.438e-16[0m[1m][0m [38;2;125;134;151mμ[0m=[38;2;32;148;243m8[0m[38;2;32;148;243m.309e-16[0m [38;2;125;134;151mσ[0m=[38;2;32;148;243m8[0m[38;2;32;148;243m.260e-17[0m [1m[[0m[38;2;32;148;243m7.452e-16[0m,       
[38;2;105;105;105m           [0m[38;2;32;148;243m8.

<xarray.Dataset>
Dimensions:    (chain: 4, leapfrog: 5, draw: 19)
Coordinates:
  * chain      (chain) int64 0 1 2 3
  * leapfrog   (leapfrog) int64 0 1 2 3 4
  * draw       (draw) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
Data variables:
    energy     (chain, leapfrog, draw) float64 -7.282e+03 ... -1.033e+04
    logprob    (chain, leapfrog, draw) float64 -7.282e+03 ... -1.033e+04
    logdet     (chain, leapfrog, draw) float64 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0
    acc        (chain, draw) float64 1.0 1.0 1.0 ... 1.0 0.0003556 0.006828
    sumlogdet  (chain, draw) float64 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0
    acc_mask   (chain, draw) float32 1.0 1.0 1.0 0.0 0.0 ... 0.0 1.0 1.0 0.0 0.0
    plaqs      (chain, draw) float64 0.3914 0.4412 0.4666 ... 0.5545 0.5545
    sinQ       (chain, draw) float64 -0.002635 -3.447e-05 ... 0.002555 0.002555
    intQ       (chain, draw) float64 -0.07688 -0.001006 ... 0.07456 0.07456
    dQint      (chain, draw) float64 0.02149 0.03452 0.0

In [7]:
history_hmc.plot_all(title="HMC", outdir=HMC_DIR)

[38;2;105;105;105m[07/20/23 10:25:53][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mplot_helpers.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m1028[0m[2;38;2;144;144;144m][0m - Saving figure to: outputs/pt4dSU3/[38;2;32;148;243m2023[0m-[38;2;32;148;243m07[0m-[38;2;32;148;243m20[0m-[38;2;32;148;243m092918[0m/hmc/ridgeplots/svgs/energy_ridgeplot.svg
[38;2;105;105;105m[07/20/23 10:25:53][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mplot_helpers.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m1028[0m[2;38;2;144;144;144m][0m - Saving figure to: outputs/pt4dSU3/[38;2;32;148;243m2023[0m-[38;2;32;148;243m07[0m-[38;2;32;148;243m20[0m-[38;2;32;148;243m092918[0m/hmc/ridgeplots/svgs/energy_ridgeplot.svg
[38;2;105;105;105m[07/20/23 10:25:53][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mplot_helpers.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m1028[0m[2;38;2;144;144;144m][0m - Saving figure to:

<xarray.Dataset>
Dimensions:    (chain: 4, leapfrog: 5, draw: 19)
Coordinates:
  * chain      (chain) int64 0 1 2 3
  * leapfrog   (leapfrog) int64 0 1 2 3 4
  * draw       (draw) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
Data variables:
    energy     (chain, leapfrog, draw) float64 -7.282e+03 ... -1.033e+04
    logprob    (chain, leapfrog, draw) float64 -7.282e+03 ... -1.033e+04
    logdet     (chain, leapfrog, draw) float64 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0
    acc        (chain, draw) float64 1.0 1.0 1.0 ... 1.0 0.0003556 0.006828
    sumlogdet  (chain, draw) float64 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0
    acc_mask   (chain, draw) float32 1.0 1.0 1.0 0.0 0.0 ... 0.0 1.0 1.0 0.0 0.0
    plaqs      (chain, draw) float64 0.3914 0.4412 0.4666 ... 0.5545 0.5545
    sinQ       (chain, draw) float64 -0.002635 -3.447e-05 ... 0.002555 0.002555
    intQ       (chain, draw) float64 -0.07688 -0.001006 ... 0.07456 0.07456
    dQint      (chain, draw) float64 0.02149 0.03452 0.0

In [8]:
# history_hmc.plot_dataArray1(dataset_hmc.plaqs, key='Plaqs (HMC)')
# ph.plot_array(dataset_hmc.plaqs.values, key='Plaqs (HMC)')

(<Figure size 1240x480 with 1 Axes>, <Axes: >)


## Evaluation

In [9]:
state = ptExpSU3.trainer.dynamics.random_state(6.0)
ptExpSU3.trainer.dynamics.init_weights(
    # method='zeros',
    # constant=0.0,
    # method='uniform',
    # min=-1e-3,
    # max=1e-3,
    # bias=True,
    xeps=0.05,
    veps=0.05,
)
xeval, history_eval = evaluate(
    nsteps=100,
    exp=ptExpSU3,
    beta=6.0,
    x=state.x,
    job_type='eval',
    nlog=5,
    nprint=10,
    grab=True,
)

[38;2;105;105;105m[07/20/23 10:47:00][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mexperiment.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m117[0m[2;38;2;144;144;144m][0m - Running [38;2;32;148;243m100[0m steps of eval at [38;2;125;134;151mbeta[0m=[38;2;32;148;243m6[0m[38;2;32;148;243m.0000[0m
[38;2;105;105;105m[07/20/23 10:47:00][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mexperiment.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m117[0m[2;38;2;144;144;144m][0m - Running [38;2;32;148;243m100[0m steps of eval at [38;2;125;134;151mbeta[0m=[38;2;32;148;243m6[0m[38;2;32;148;243m.0000[0m
[38;2;105;105;105m[07/20/23 10:47:00][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mexperiment.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m121[0m[2;38;2;144;144;144m][0m - STEP: [38;2;32;148;243m0[0m
[38;2;105;105;105m[07/20/23 10:47:00][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;14

In [None]:
dataset_eval = history_eval.get_dataset()

In [10]:
# plot_metrics(history_eval, title='Evaluate', marker='.')
history_eval.plot_all(outdir=EVAL_DIR, title='Eval')

xeval = ptExpSU3.trainer.dynamics.unflatten(xeval)
console.log(f"checkSU(x_eval): {g.checkSU(xeval)}")
console.log(f"checkSU(x_eval): {g.checkSU(g.projectSU(xeval))}")

[38;2;105;105;105m[07/20/23 10:58:26][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mplot_helpers.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m1028[0m[2;38;2;144;144;144m][0m - Saving figure to: outputs/pt4dSU3/[38;2;32;148;243m2023[0m-[38;2;32;148;243m07[0m-[38;2;32;148;243m20[0m-[38;2;32;148;243m092918[0m/eval/ridgeplots/svgs/energy_ridgeplot.svg
[38;2;105;105;105m[07/20/23 10:58:26][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mplot_helpers.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m1028[0m[2;38;2;144;144;144m][0m - Saving figure to: outputs/pt4dSU3/[38;2;32;148;243m2023[0m-[38;2;32;148;243m07[0m-[38;2;32;148;243m20[0m-[38;2;32;148;243m092918[0m/eval/ridgeplots/svgs/energy_ridgeplot.svg
[38;2;105;105;105m[07/20/23 10:58:26][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mplot_helpers.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m1028[0m[2;38;2;144;144;144m][0m - Saving figure t

In [11]:
# plt.rcParams['figure.dpi'] = 300
# plot_metric(np.stack(history_eval.history['plaqs']), name='Plaqs (Eval)', marker='.')

[0;31m---------------------------------------------------------------------------[0m
[0;31mTypeError[0m                                 Traceback (most recent call last)
Cell [0;32mIn[74], line 2[0m
[1;32m      1[0m plt[38;5;241m.[39mrcParams[[38;5;124m'[39m[38;5;124mfigure.dpi[39m[38;5;124m'[39m] [38;5;241m=[39m [38;5;241m300[39m
[0;32m----> 2[0m [43mplot_metric[49m[43m([49m[43mnp[49m[38;5;241;43m.[39;49m[43mstack[49m[43m([49m[43mhistory_eval[49m[38;5;241;43m.[39;49m[43mhistory[49m[43m[[49m[38;5;124;43m'[39;49m[38;5;124;43mplaqs[39;49m[38;5;124;43m'[39;49m[43m][49m[43m)[49m[43m,[49m[43m [49m[43mname[49m[38;5;241;43m=[39;49m[38;5;124;43m'[39;49m[38;5;124;43mPlaqs (Eval)[39;49m[38;5;124;43m'[39;49m[43m,[49m[43m [49m[43mmarker[49m[38;5;241;43m=[39;49m[38;5;124;43m'[39;49m[38;5;124;43m.[39;49m[38;5;124;43m'[39;49m[43m)[49m

Cell [0;32mIn[4], line 73[0m, in [0;36mplot_metric[0;34m(metric, name, **kwargs)

## Training

In [12]:
ptExpSU3.trainer.dynamics.init_weights(
    # method='xavier_uniform',
    # constant=0.0,
    # method='uniform',
    # min=-1e-6,
    # max=1e-6,
    # bias=True,
    xeps=0.05,
    veps=0.05,
)
# ptExpSU3.trainer.optimizer.zero_grad()
ptExpSU3.trainer.print_grads_and_weights()

[38;2;105;105;105m[07/20/23 11:16:26][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mtrainer.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m2121[0m[2;38;2;144;144;144m][0m - --------------------------------------------------------------------------------
[38;2;105;105;105m[07/20/23 11:16:26][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mtrainer.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m2122[0m[2;38;2;144;144;144m][0m - GRADS:
[38;2;105;105;105m[07/20/23 11:16:26][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mcommon.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m97[0m[2;38;2;144;144;144m][0m - networks.xnet.[38;2;32;148;243m0.[0mfirst.input_layer.xlayer.weight: [3;38;2;255;0;255mNone[0m [3;38;2;255;0;255mNone[0m 
[3;38;2;255;0;255mNone[0m
networks.xnet.[38;2;32;148;243m0.[0mfirst.input_layer.xlayer.bias: [3;38;2;255;0;255mNone[0m [3;38;2;255;0;255mNone[0m 
[3;38;2;255;0;255mNone[0m

In [13]:
from l2hmc.utils.history import BaseHistory
# history = {}
history: BaseHistory = BaseHistory()
# state = ptExpSU3.trainer.dynamics.random_state(6.0)
# x = state.x
freq = {'print': 5, 'save': 5}
for step in range(100):
    console.print(f'TRAIN STEP: {step}')
    x, metrics = ptExpSU3.trainer.train_step((x, state.beta))
    if (step > 0 and step % freq['print'] == 0):
        print_dict(metrics, grab=True)
    if (step > 0 and step % freq['save'] == 0):
        history.update(metrics)
        # for key, val in metrics.items():
        #     try:
        #         history[key].append(val)
        #     except KeyError:
        #         history[key] = [val]
# plot_metrics(history, title='train', marker='.')

TRAIN STEP: [38;2;32;148;243m0[0m

TRAIN STEP: [38;2;32;148;243m1[0m

TRAIN STEP: [38;2;32;148;243m2[0m

TRAIN STEP: [38;2;32;148;243m3[0m

TRAIN STEP: [38;2;32;148;243m4[0m

TRAIN STEP: [38;2;32;148;243m5[0m

[38;2;105;105;105m[07/20/23 11:14:17][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mcommon.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m97[0m[2;38;2;144;144;144m][0m - energy: [1;38;2;255;0;255mtorch.Size[0m[1m([0m[1m[[0m[38;2;32;148;243m9[0m, [38;2;32;148;243m4[0m[1m][0m[1m)[0m torch.float64 
[1m[[0m[1m[[0m[38;2;32;148;243m-9171.37842821[0m [38;2;32;148;243m-5353.76897792[0m [38;2;32;148;243m-9731.61150597[0m [38;2;32;148;243m-8698.55691276[0m[1m][0m
 [1m[[0m[38;2;32;148;243m-9045.99306112[0m [38;2;32;148;243m-5384.02122852[0m [38;2;32;148;243m-9571.10269974[0m [38;2;32;148;243m-8574.71252864[0m[1m][0m
 [1m[[0m[38;2;32;148;243m-8743.96534469[0m [38;2;32;148;243m-5462.76512169[0m [38;2;32;14

In [14]:
dataset = history.get_dataset()

[38;2;105;105;105m[07/20/23 11:13:25][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mplot_helpers.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m1028[0m[2;38;2;144;144;144m][0m - Saving figure to: outputs/pt4dSU3/[38;2;32;148;243m2023[0m-[38;2;32;148;243m07[0m-[38;2;32;148;243m20[0m-[38;2;32;148;243m092918[0m/train/ridgeplots/svgs/energy_ridgeplot.svg
[38;2;105;105;105m[07/20/23 11:13:25][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mplot_helpers.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m1028[0m[2;38;2;144;144;144m][0m - Saving figure to: outputs/pt4dSU3/[38;2;32;148;243m2023[0m-[38;2;32;148;243m07[0m-[38;2;32;148;243m20[0m-[38;2;32;148;243m092918[0m/train/ridgeplots/svgs/energy_ridgeplot.svg
[38;2;105;105;105m[07/20/23 11:13:25][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mplot_helpers.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m1028[0m[2;38;2;144;144;144m][0m - Saving figure

In [15]:
history.plot_all(outdir=TRAIN_DIR, title='Train')

xeval = ptExpSU3.trainer.dynamics.unflatten(xeval)
console.log(f"checkSU(x_train): {g.checkSU(x)}")
console.log(f"checkSU(x_train): {g.checkSU(g.projectSU(x))}")

[38;2;105;105;105m[07/20/23 11:09:31][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mplot_helpers.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m1028[0m[2;38;2;144;144;144m][0m - Saving figure to: outputs/pt4dSU3/[38;2;32;148;243m2023[0m-[38;2;32;148;243m07[0m-[38;2;32;148;243m20[0m-[38;2;32;148;243m092918[0m/train/ridgeplots/svgs/energy_ridgeplot.svg
[38;2;105;105;105m[07/20/23 11:09:31][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mplot_helpers.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m1028[0m[2;38;2;144;144;144m][0m - Saving figure to: outputs/pt4dSU3/[38;2;32;148;243m2023[0m-[38;2;32;148;243m07[0m-[38;2;32;148;243m20[0m-[38;2;32;148;243m092918[0m/train/ridgeplots/svgs/energy_ridgeplot.svg
[38;2;105;105;105m[07/20/23 11:09:31][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mplot_helpers.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m1028[0m[2;38;2;144;144;144m][0m - Saving figure

In [16]:
x = ptExpSU3.trainer.dynamics.unflatten(x)
console.print(f"checkSU(x_train): {g.checkSU(x)}")
dataset = history.get_dataset()

[1;38;2;255;0;255mcheckSU[0m[1m([0mx_train[1m)[0m: [1m([0mtensor[1m[[0m[38;2;32;148;243m4[0m[1m][0m f64 x∈[1m[[0m[38;2;32;148;243m1.439e-16[0m, [38;2;32;148;243m1.487e-16[0m[1m][0m [38;2;125;134;151mμ[0m=[38;2;32;148;243m1[0m[38;2;32;148;243m.461e-16[0m [38;2;125;134;151mσ[0m=[38;2;32;148;243m2[0m[38;2;32;148;243m.449e-18[0m [1m[[0m[38;2;32;148;243m1.441e-16[0m, [38;2;32;148;243m1.476e-16[0m, [38;2;32;148;243m1.439e-16[0m, [38;2;32;148;243m1.487e-16[0m[1m][0m, tensor[1m[[0m[38;2;32;148;243m4[0m[1m][0m f64 x∈[1m[[0m[38;2;32;148;243m7.127e-16[0m, [38;2;32;148;243m8.391e-16[0m[1m][0m [38;2;125;134;151mμ[0m=[38;2;32;148;243m7[0m[38;2;32;148;243m.605e-16[0m [38;2;125;134;151mσ[0m=[38;2;32;148;243m5[0m[38;2;32;148;243m.458e-17[0m [1m[[0m[38;2;32;148;243m7.451e-16[0m, [38;2;32;148;243m8.391e-16[0m, [38;2;32;148;243m7.127e-16[0m, [38;2;32;148;243m7.451e-16[0m[1m][0m[1m)[0m



In [17]:
# matplotlib.use('module://matplotlib-kitty')
import l2hmc.utils.plot_helpers as ph
from pathlib import Path
from l2hmc.common import get_timestamp

tstamp = get_timestamp()
outdir = Path(f"./outputs/pt4dsu3/2023-07-19/{tstamp}")
ph.plot_dataset(dataset, outdir=outdir, title='Training')

No event loop hook running.
Using matplotlib backend: <object object at 0x1073d22d0>
[38;2;105;105;105m[07/19/23 22:09:42][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mplot_helpers.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m1006[0m[2;38;2;144;144;144m][0m - Saving figure to: outputs/pt4dsu3/[38;2;32;148;243m2023[0m-[38;2;32;148;243m07[0m-[38;2;32;148;243m19[0m/[38;2;32;148;243m2023[0m-[38;2;32;148;243m07[0m-[38;2;32;148;243m19[0m-[38;2;32;148;243m220942[0m/ridgeplots/svgs/energy_ridgeplot.svg
[38;2;105;105;105m[07/19/23 22:09:43][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mplot_helpers.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m1006[0m[2;38;2;144;144;144m][0m - Saving figure to: outputs/pt4dsu3/[38;2;32;148;243m2023[0m-[38;2;32;148;243m07[0m-[38;2;32;148;243m19[0m/[38;2;32;148;243m2023[0m-[38;2;32;148;243m07[0m-[38;2;32;148;243m19[0m-[38;2;32;148;243m220942[0m/ridgeplots/svgs/logprob_ridgep

<xarray.Dataset>
Dimensions:    (chain: 9, leapfrog: 9, draw: 19)
Coordinates:
  * chain      (chain) int64 0 1 2 3 4 5 6 7 8
  * leapfrog   (leapfrog) int64 0 1 2 3 4 5 6 7 8
  * draw       (draw) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
Data variables: (12/18)
    energy     (chain, leapfrog, draw) float64 -4.774e+03 -4.831e+03 ... nan nan
    logprob    (chain, leapfrog, draw) float64 -4.774e+03 -4.831e+03 ... nan nan
    logdet     (chain, leapfrog, draw) float64 0.0 0.0 0.0 0.0 ... nan nan nan
    sldf       (chain, leapfrog, draw) float64 0.0 0.0 0.0 0.0 ... nan nan nan
    sldb       (chain, leapfrog, draw) float64 0.0 0.0 0.0 0.0 ... nan nan nan
    sld        (chain, leapfrog, draw) float64 0.0 0.0 0.0 0.0 ... nan nan nan
    ...         ...
    loss       (draw) float64 -0.7719 -0.7836 -0.7867 -0.8042 ... 0.0 0.0 0.0
    plaqs      (chain, draw) float64 0.2526 0.2629 0.2779 0.2908 ... nan nan nan
    sinQ       (chain, draw) float64 0.002607 -0.0002384 0.000401 ..

In [None]:
import l2hmc.utils.plot_helpers as ph
from l2hmc.common import di

In [None]:
# fig, ax = plt.subplots()
figax = plot_metric(
    history['plaqs'],
    name='Plaqs (Training)',
    marker='.'
)
fig, ax = figax
fig.savefig('4dSU3-train-plaqs-2023-07-19.svg', bbox_inches='tight')

In [18]:
state = ptExpSU3.trainer.dynamics.random_state(6.0)
acc = torch.ones(state.x.shape[0])
m, mb = ptExpSU3.trainer.dynamics._get_mask(0)
loss = torch.tensor([0.])
with torch.autograd.detect_anomaly(check_nan=True):  # flake8: noqa  pyright:ignore
    ptExpSU3.trainer.optimizer.zero_grad()
    # state_vb, logdet_vf = ptExpSU3.trainer.dynamics._update_v_bwd(
    #     step=0, state=state_vf
    # )
    # console.log(f'TRAIN STEP: {step}')
    x, metrics = ptExpSU3.trainer.train_step((state.x, state.beta))
    loss = metrics['loss']
    # loss_xb = ptExpSU3.trainer.calc_loss(state_xb.x, state.x, acc=acc)
    # loss_xf = ptExpSU3.trainer.calc_loss(state_xf.x, state.x, acc=acc)
    # loss_ = loss_xb + loss_xf
    # loss = ptExpSU3.trainer.backward_step(loss_)
    console.print(f"loss: {loss:.5f}")

loss: [38;2;32;148;243m-0.04936[0m



In [19]:
state = ptExpSU3.trainer.dynamics.random_state(6.0)
acc = torch.ones(state.x.shape[0])
m, mb = ptExpSU3.trainer.dynamics._get_mask(0)
loss = torch.tensor([0.])
step = 0
from l2hmc.configs import State
with torch.autograd.detect_anomaly(check_nan=True):  # flake8: noqa  pyright:ignore
    ptExpSU3.trainer.optimizer.zero_grad()
    # sumlogdet = torch.zeros(state.x.shape[0], device=self.device)
    state_vf1, logdet = ptExpSU3.trainer.dynamics._update_v_fwd(step, state)
    sumlogdet = logdet
    # state_ = State(state.x, vf1, state.beta)
    state_xf1, logdet = ptExpSU3.trainer.dynamics._update_x_fwd(step, state_vf1, m, first=True)
    sumlogdet = sumlogdet + logdet
    state_xf2, logdet = ptExpSU3.trainer.dynamics._update_x_fwd(step, state_xf1, mb, first=False)
    sumlogdet = sumlogdet + logdet
    state_vf2, logdet = ptExpSU3.trainer.dynamics._update_v_fwd(step, state_xf2)
    sumlogdet = sumlogdet + logdet
    # state_, logdet = ptExpSU3.trainer.dynamics._forward_lf(step=0, state=state)
    # state_, logdet = ptExpSU3.trainer.dynamics._update_x_fwd(
    #     step=0, state=state, m=m, first=True
    # )
    # state_xb, logdet_xb = ptExpSU3.trainer.dynamics._update_x_bwd(
    #     step=0, state=state_xf, m=m, first=True
    # )
    # state_vb, logdet_vf = ptExpSU3.trainer.dynamics._update_v_bwd(
    #     step=0, state=state_vf
    # )
    # loss_xb = ptExpSU3.trainer.calc_loss(state_xb.x, state.x, acc=acc)
    loss_ = ptExpSU3.trainer.calc_loss(state_xf1.x, state.x, acc=acc)
    # loss_ = loss_xb + loss_xf
    loss = ptExpSU3.trainer.backward_step(loss_)
    console.print(f"loss: {loss.item():.5f}")

loss: [38;2;32;148;243m-0.00020[0m



In [20]:
state = ptExpSU3.trainer.dynamics.random_state(6.0)
acc = torch.ones(state.x.shape[0])
m, mb = ptExpSU3.trainer.dynamics._get_mask(0)
loss = torch.tensor([0.])
with torch.autograd.detect_anomaly(check_nan=True):  # flake8: noqa  pyright:ignore
    ptExpSU3.trainer.optimizer.zero_grad()
    state_vf, logdet_vf = ptExpSU3.trainer.dynamics._update_v_fwd(
        step=0, state=state
    )
    state_vb, logdet_vf = ptExpSU3.trainer.dynamics._update_v_bwd(
        step=0, state=state_vf
    )
    loss_ = ptExpSU3.trainer.calc_loss(state.x, state_vf.x, acc=acc)
    loss = ptExpSU3.trainer.backward_step(loss_)
    console.print(f"loss: {loss.item():.5f}")

loss: [38;2;32;148;243m0.00000[0m



In [21]:
from torch import autograd
state = ptExpSU3.trainer.dynamics.random_state(6.0)
acc = torch.ones(state.x.shape[0])
m, mb = ptExpSU3.trainer.dynamics._get_mask(0)
loss = torch.tensor([0.])
# ptExpSU3.trainer.dynamics.init_weights(
#     method='uniform',
#     min=-1e-32,
#     max=1e-32,
#     bias=True,
#     xeps=0.001,
#     veps=0.001,
# )

with autograd.detect_anomaly(check_nan=True):
    ptExpSU3.trainer.optimizer.zero_grad()
    state_vf, logdet_vf = ptExpSU3.trainer.dynamics._update_v_fwd(
        step=0, state=state
    )
    state_xf, logdet_xf = ptExpSU3.trainer.dynamics._update_x_fwd(
        step=0, state=state_vf, m=m, first=True
    )
    console.print(f"state_xf.x.shape: {state_xf.x.shape}")
    avg, diff = ptExpSU3.trainer.g.checkSU(state_xf.x)
    console.print(f"avg: {avg}, diff: {diff}")
    loss_ = ptExpSU3.trainer.calc_loss(state.x, state_xf.x, acc=acc)
    loss = ptExpSU3.trainer.backward_step(loss_)
    console.print(f"loss: {loss.item():.5f}")

state_xf.x.shape: [1;38;2;255;0;255mtorch.Size[0m[1m([0m[1m[[0m[38;2;32;148;243m4[0m, [38;2;32;148;243m4[0m, [38;2;32;148;243m4[0m, [38;2;32;148;243m4[0m, [38;2;32;148;243m4[0m, [38;2;32;148;243m8[0m, [38;2;32;148;243m3[0m, [38;2;32;148;243m3[0m[1m][0m[1m)[0m

avg: tensor[1m[[0m[38;2;32;148;243m4[0m[1m][0m f64 x∈[1m[[0m[38;2;32;148;243m8.420e-08[0m, [38;2;32;148;243m1.064e-07[0m[1m][0m [38;2;125;134;151mμ[0m=[38;2;32;148;243m9[0m[38;2;32;148;243m.412e-08[0m [38;2;125;134;151mσ[0m=[38;2;32;148;243m1[0m[38;2;32;148;243m.068e-08[0m grad SqrtBackward0 [1m[[0m[38;2;32;148;243m8.420e-08[0m, [38;2;32;148;243m1.064e-07[0m, [38;2;32;148;243m8.622e-08[0m, [38;2;32;148;243m9.967e-08[0m[1m][0m, diff: tensor[1m[[0m[38;2;32;148;243m4[0m[1m][0m f64 x∈[1m[[0m[38;2;32;148;243m1.942e-07[0m, [38;2;32;148;243m2.979e-07[0m[1m][0m [38;2;125;134;151mμ[0m=[38;2;32;148;243m2[0m[38;2;32;148;243m.382e-07[0m [38;2;125;134;151mσ[

In [22]:
from torch import autograd
from l2hmc.dynamics.pytorch.dynamics import sigmoid

step = 0
dynamics = ptExpSU3.trainer.dynamics
state = ptExpSU3.trainer.dynamics.random_state(6.0)
acc = torch.ones(state.x.shape[0])
m, mb = ptExpSU3.trainer.dynamics._get_mask(0)
loss = torch.tensor([0.])
ptExpSU3.trainer.dynamics.init_weights(
    method='uniform',
    min=-1e-32,
    max=1e-32,
    bias=True,
    xeps=0.001,
    veps=0.001,
)

with autograd.detect_anomaly(check_nan=True):
    eps = sigmoid(dynamics.veps[step].log())
    force = dynamics.grad_potential(state.x, state.beta)
    s, t, q = dynamics._call_vnet(step, (state.x, force))
    logjac = eps * s / 2.  # jacobian factor, also used in exp_s below
    logdet = dynamics.flatten(logjac).sum(1)
    force = force.reshape_as(state.v)
    exp_s = (logjac.exp()).reshape_as(state.v)
    exp_q = (eps * q).exp().reshape_as(state.v)
    t = t.reshape_as(state.v)
    vf = (exp_s * state.v) - (0.5 * eps * (force * exp_q + t))
    if dynamics.config.group == 'SU3':
        vf = dynamics.g.projectTAH(vf)
    loss_ = ptExpSU3.trainer.calc_loss(state.v, vf, acc=acc)
    loss_.register_hook(lambda grad: grad.clamp_(max=1.0))
    loss_.register_hook(lambda grad: console.print(f"grad: {grad}"))
    ptExpSU3.trainer.optimizer.zero_grad()
    loss_.backward()
    torch.nn.utils.clip_grad.clip_grad_norm(dynamics.parameters(), max_norm=1.0)
    ptExpSU3.trainer.optimizer.step()
    # loss = ptExpSU3.trainer.backward_step(loss_)
    console.print(f"loss: {loss_.item():.5f}")

grad: [38;2;32;148;243m1.0[0m

[0;31m---------------------------------------------------------------------------[0m
[0;31mRuntimeError[0m                              Traceback (most recent call last)
Cell [0;32mIn[80], line 36[0m
[1;32m     34[0m loss_[38;5;241m.[39mregister_hook([38;5;28;01mlambda[39;00m grad: console[38;5;241m.[39mprint([38;5;124mf[39m[38;5;124m"[39m[38;5;124mgrad: [39m[38;5;132;01m{[39;00mgrad[38;5;132;01m}[39;00m[38;5;124m"[39m))
[1;32m     35[0m ptExpSU3[38;5;241m.[39mtrainer[38;5;241m.[39moptimizer[38;5;241m.[39mzero_grad()
[0;32m---> 36[0m [43mloss_[49m[38;5;241;43m.[39;49m[43mbackward[49m[43m([49m[43m)[49m
[1;32m     37[0m torch[38;5;241m.[39mnn[38;5;241m.[39mutils[38;5;241m.[39mclip_grad[38;5;241m.[39mclip_grad_norm(dynamics[38;5;241m.[39mparameters(), max_norm[38;5;241m=[39m[38;5;241m1.0[39m)
[1;32m     38[0m ptExpSU3[38;5;241m.[39mtrainer[38;5;241m.[39moptimizer[38;5;241m.[39mstep()

Fi

In [23]:
history = {}
x = state.x
for step in range(50):
    console.log(f'TRAIN STEP: {step}')
    x, metrics = ptExpSU3.trainer.train_step((x, state.beta))
    if (step > 0 and step % 2 == 0):
        print_dict(metrics, grab=True)
    if (step > 0 and step % 1 == 0):
        for key, val in metrics.items():
            try:
                history[key].append(val)
            except KeyError:
                history[key] = [val]

x = ptExpSU3.trainer.dynamics.unflatten(x)
console.log(f"checkSU(x_train): {g.checkSU(x)}")
plot_metrics(history, title='train', marker='.')



[38;2;105;105;105m[07/17/23 11:09:40][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mcommon.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m97[0m[2;38;2;144;144;144m][0m - energy: [1;38;2;255;0;255mtorch.Size[0m[1m([0m[1m[[0m[38;2;32;148;243m9[0m, [38;2;32;148;243m4[0m[1m][0m[1m)[0m torch.float64 
[1m[[0m[1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m[1m][0m
logprob: [1;38;2;255;0;255mtorch.Size[0m[1m([0m[1m[[0m[38;2;32;148;243m9[0m, [38;2;32;148;243m4[0m[1m][0m[1m)[0m torch.float64 
[1m[[0m[1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [

[38;2;105;105;105m[07/17/23 11:09:42][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mcommon.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m97[0m[2;38;2;144;144;144m][0m - energy: [1;38;2;255;0;255mtorch.Size[0m[1m([0m[1m[[0m[38;2;32;148;243m9[0m, [38;2;32;148;243m4[0m[1m][0m[1m)[0m torch.float64 
[1m[[0m[1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m[1m][0m
logprob: [1;38;2;255;0;255mtorch.Size[0m[1m([0m[1m[[0m[38;2;32;148;243m9[0m, [38;2;32;148;243m4[0m[1m][0m[1m)[0m torch.float64 
[1m[[0m[1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [

[38;2;105;105;105m[07/17/23 11:09:44][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mcommon.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m97[0m[2;38;2;144;144;144m][0m - energy: [1;38;2;255;0;255mtorch.Size[0m[1m([0m[1m[[0m[38;2;32;148;243m9[0m, [38;2;32;148;243m4[0m[1m][0m[1m)[0m torch.float64 
[1m[[0m[1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m[1m][0m
logprob: [1;38;2;255;0;255mtorch.Size[0m[1m([0m[1m[[0m[38;2;32;148;243m9[0m, [38;2;32;148;243m4[0m[1m][0m[1m)[0m torch.float64 
[1m[[0m[1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [

[38;2;105;105;105m[07/17/23 11:09:45][0m[34m[INFO][0m[2;38;2;144;144;144m[[0m[2;38;2;144;144;144mcommon.py[0m[2;38;2;144;144;144m:[0m[2;38;2;144;144;144m97[0m[2;38;2;144;144;144m][0m - energy: [1;38;2;255;0;255mtorch.Size[0m[1m([0m[1m[[0m[38;2;32;148;243m9[0m, [38;2;32;148;243m4[0m[1m][0m[1m)[0m torch.float64 
[1m[[0m[1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m[1m][0m
logprob: [1;38;2;255;0;255mtorch.Size[0m[1m([0m[1m[[0m[38;2;32;148;243m9[0m, [38;2;32;148;243m4[0m[1m][0m[1m)[0m torch.float64 
[1m[[0m[1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [1m[[0mnan nan nan nan[1m][0m
 [

KeyboardInterrupt: 