## Testing `pytorch` network

## Imports

In [None]:
%load_ext rich

In [None]:
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import tensorflow as tf
#os.environ['COLUMNS'] = '150'
        
#%load_ext autoreload
#%autoreload 2
#%matplotlib widget

In [None]:
from pathlib import Path

# Make the parent of the notebook's working directory importable.
project_dir = Path().cwd().parent
# `sys.path` holds strings, so the membership test must use the string form;
# testing the `Path` object itself never matches and would append a duplicate
# entry every time this cell is re-run.
if str(project_dir) not in sys.path:
    sys.path.append(str(project_dir))

# Baseline plotting appearance for the notebook.
plt.style.use('default')
sns.set_context('talk')
sns.set_style('whitegrid')
sns.set_palette('bright')

In [None]:
import datetime
import time
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# `torch.set_default_tensor_type` is deprecated; setting the default dtype is
# the supported way to make newly created floating-point tensors float32.
torch.set_default_dtype(torch.float32)

In [None]:
# Locate the custom style sheet under the current user's home directory rather
# than through a machine-specific absolute path.
plt.style.use(Path.home().joinpath('.matplotlib', 'stylelib', 'molokai.mplstyle'))

## Specify `NetworkConfig` and `DynamicsConfig` for the `GaugeDynamics` object

**Note**:
 - `lx` is the linear size of the 2D _square_ lattice $\ell_{x}\times \ell_{x}$
 - `nb` is the _batch_ size, i.e. the number of chains updated in parallel
 - $\Longrightarrow$ `x.shape = (nb, lx, lx, 2)`

<!--<div class="alert alert-info" role="alert">-->
<!--</div>-->

In [None]:
from dynamics.pytorch.dynamics import DynamicsConfig, GaugeDynamics, NetWeights
from network.pytorch.network import (
    GaugeNetwork,
    LearningRateConfig,
    NetworkConfig,
    State,
    xy_repr,
)

# Lattice size (`lx`) and number of chains updated in parallel (`nb`).
lx, nb = 8, 16

# Network hyper-parameters, passed as plain keyword arguments.
net_config = NetworkConfig(
    units=[16, 16, 16, 16],
    dropout_prob=0.0,
    use_batch_norm=False,
    activation_fn=torch.nn.LeakyReLU(),
)

lr_config = LearningRateConfig(lr_init=0.001)

# All six weights set to one.
net_weights = NetWeights(1., 1., 1., 1., 1., 1.)

# Keyword arguments for `DynamicsConfig`; kept as a dict because later cells
# copy and modify it to build other configurations.
dc = {
    'eps': 0.1,
    'num_steps': 10,
    'x_shape': (nb, lx, lx, 2),
    'hmc': False,
    'separate_networks': True,
    'eps_fixed': False,
    'aux_weight': 0.,
    'use_mixed_loss': False,
    'plaq_weight': 0.,
    'charge_weight': 0.001,
    'optimizer': 'adam',
    'clip_val': 0.,
    'net_weights': net_weights,
}

dynamics_config = DynamicsConfig(**dc)
dynamics_pt = GaugeDynamics(dynamics_config, net_config, lr_config)

In [None]:
import json
from rich import print_json
import logging
from lattice.pytorch.lattice import area_law, plaq_exact

#logger = logging.getLogger('jupyter')
# Evaluate `plaq_exact` at the integers 0..9, keyed by the integer as a string,
# and pretty-print the table as JSON.
jdict = dict(
    (str(int(val)), plaq_exact(val)) for val in np.arange(10)
)
print_json(json.dumps(jdict, indent=4))

In [None]:
PI = np.pi
TWO_PI = 2. * PI

# Draw uniform values in [-pi, pi) with the configured shape, then flatten
# everything after the leading (batch) dimension.
x = (
    TWO_PI * torch.rand(dynamics_pt.config.x_shape, requires_grad=True) - PI
)
x = x.reshape(x.shape[0], -1)

## HMC:

### Training:

In [None]:
from copy import deepcopy
from dataclasses import asdict, dataclass

from dynamics.pytorch.dynamics import Steps, to_u1, train_and_test, train_step
from utils.data_containers import DataContainer
from utils.history import History
from utils.step_timer import StepTimer
from utils.logger import Logger
#from utils.pytorch.io import Logger

# Independent copy of the base settings with the `hmc` flag switched on and a
# different step size.
dchmc = deepcopy(dc)
dchmc.update(hmc=True, eps=0.2)

dynamics_config_hmc = DynamicsConfig(**dchmc)
dynamics_hmc_pt = GaugeDynamics(dynamics_config_hmc, net_config, lr_config)
optimizer_hmc = optim.Adam(dynamics_hmc_pt.parameters(), lr=0.001)

beta = 3.
ntrain = 1000
ntest = 1000

steps = Steps(train=ntrain, test=ntest)

In [None]:
from dynamics.pytorch.dynamics import train_and_test

# Call `.train()` on the dynamics object before optimizing
# (presumably `nn.Module.train()` -- confirm against `GaugeDynamics`).
dynamics_hmc_pt.train()

# Metric names handed to `train_and_test` through its `skip` argument.
skip = ['logdets', 'px', 'Qi', 'Qs', 'p4x4']
logger = Logger()

hmc_outputs = train_and_test(
    dynamics_hmc_pt,
    optimizer_hmc,
    beta=beta,
    steps=steps,
    skip=skip,
    window=0,
)

In [None]:
%matplotlib inline

In [None]:
import matplotx

# Load the custom style sheet once, via the home-relative path only; a
# machine-specific absolute path would fail on any other machine.
plt.style.use('~/.matplotlib/stylelib/molokai.mplstyle')
sns.set_context('paper')
with plt.style.context(matplotx.styles.dufte_bar):
    plt.rcParams['figure.dpi'] = 150
    # Also reused by the cell that plots the HMC test history.
    subplots_kwargs = {
        'figsize': (7, 3),
        'constrained_layout': True,
    }

    # Plot the HMC training history; nothing is discarded as thermalization.
    dataset_hmc_pt = hmc_outputs['train'].plot_all(
        num_chains=16, therm_frac=0.0,
        subplots_kwargs=subplots_kwargs,
    )

In [None]:
import warnings
# Install a blanket "ignore" filter for warnings from here on.
warnings.filterwarnings(action='ignore')
plt.style.use('~/.matplotlib/stylelib/molokai.mplstyle')
with plt.style.context(matplotx.styles.dufte_bar):
    # Plot the HMC test history, passing a thermalization fraction of 0.1.
    dataset_hmc_pt_inf = hmc_outputs['test'].plot_all(
        subplots_kwargs=subplots_kwargs,
        therm_frac=0.1,
        num_chains=16,
    )

## L2HMC

In [None]:
from __future__ import annotations

from typing import Union

import torch.nn

Scalar = Union[float, int]


def rescale_eps(dynamics: GaugeDynamics, xfrac: Scalar, vfrac: Scalar):
    """Rescale every step size of ``dynamics`` in place and return it.

    Each entry of ``dynamics.xeps`` is multiplied by ``xfrac`` when
    ``xfrac < 1`` and divided by it otherwise; ``dynamics.veps`` is treated the
    same way with ``vfrac``. For positive fractions the step sizes therefore
    never grow.

    Args:
        dynamics: Object exposing ``config.eps_fixed`` and the iterable
            parameter collections ``xeps`` and ``veps``.
        xfrac: Rescaling factor applied to the entries of ``xeps``.
        vfrac: Rescaling factor applied to the entries of ``veps``.

    Returns:
        The same ``dynamics`` object, whose ``xeps`` and ``veps`` have been
        replaced by new ``nn.ParameterList`` instances.

    Note:
        New ``nn.Parameter`` objects are created, so an optimizer built before
        this call still refers to the previous ones.
    """
    # A fixed step size must not be trained: gradients are enabled only when
    # `eps_fixed` is false.
    trainable = not dynamics.config.eps_fixed

    def _rescaled(eps, frac):
        # Detach first so the new parameter is built from plain data rather
        # than from a node of the old parameter's autograd graph.
        value = eps.detach()
        value = value * frac if frac < 1. else value / frac
        return nn.Parameter(value, requires_grad=trainable)

    xeps, veps = [], []
    for xe, ve in zip(dynamics.xeps, dynamics.veps):
        xeps.append(_rescaled(xe, xfrac))
        veps.append(_rescaled(ve, vfrac))

    dynamics.xeps = nn.ParameterList(xeps)
    dynamics.veps = nn.ParameterList(veps)

    return dynamics

In [None]:
# Rebuild the dynamics from the base settings `dc` and pair it with a new
# Adam optimizer over its parameters.
dynamics_config = DynamicsConfig(**dc)
dynamics_pt = GaugeDynamics(
    dynamics_config,
    net_config,
    lr_config,
)
optimizer_pt = optim.Adam(dynamics_pt.parameters(), lr=1e-3)

#dynamics_pt.train()

In [None]:
from dynamics.pytorch.dynamics import train_and_test

from dynamics.pytorch.dynamics import Steps

steps = Steps(train=5000, test=1000, log=10, save=250)
skip = ['logdets', 'px', 'Qi', 'Qs', 'p4x4', 'p4']

# Alternative: a per-step list of beta values, e.g.
#beta = np.array(np.linspace(3, 4, steps.train), dtype=np.float32).tolist()
beta = 3.

# On a fresh kernel `outputs` has not been assigned yet, so referencing it
# unconditionally raises NameError. Pass the earlier histories only when this
# cell has already produced them (i.e. when it is being re-run).
try:
    _resume_kwargs = {
        'train_history': outputs['train'],
        'test_history': outputs['test'],
    }
except NameError:
    _resume_kwargs = {}

outputs = train_and_test(dynamics_pt,
                         optimizer_pt,
                         beta=beta,
                         skip=skip,
                         steps=steps,
                         window=50,
                         **_resume_kwargs)

In [None]:
import matplotx

# Custom style sheet from the user's home directory, rendered at 150 dpi.
plt.style.use(Path.home().joinpath('.matplotlib', 'stylelib', 'molokai.mplstyle'))
plt.rcParams['figure.dpi'] = 150

# Plot settings, also used by the cell that plots the test history.
defaults = {
    'num_chains': 16,
    'therm_frac': 0.0,
    'subplots_kwargs': {
        'figsize': (7, 3.5),
        'constrained_layout': True,
    },
}

with plt.style.context(matplotx.styles.dufte):
    # Store the training dataset under its own name (mirroring
    # `dataset_hmc_pt` / `dataset_hmc_pt_inf`) so that the test-history plot,
    # which assigns `dataset_pt_inf`, does not overwrite it.
    dataset_pt = outputs['train'].plot_all(**defaults)

In [None]:
# Reset to matplotlib's defaults, then layer the custom style sheet on top;
# the styles in the list are applied in order.
plt.style.use(['default', '~/.matplotlib/stylelib/molokai.mplstyle'])

with plt.style.context(matplotx.styles.dufte):
    # Plot the test history with the shared `defaults` settings.
    dataset_pt_inf = outputs['test'].plot_all(**defaults)

# TensorFlow

In [None]:
import datetime
import json
import logging
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from rich import get_console, print
from rich.theme import Theme

import tensorflow as tf
from config import BIN_DIR, PROJECT_DIR
from utils.hvd_init import RANK, SIZE
from utils.logger import Logger, print_dict

# Add the parent directory to the import path unless it is already listed.
_parent_dir = os.path.abspath('..')
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

# Same sequence of style calls as before, in the same order.
sns.set_palette('bright')
plt.style.use('default')
sns.set_context('talk')
sns.set_style('whitegrid')
sns.set_palette('bright')

# Render text without LaTeX.
plt.rcParams['text.usetex'] = False

#logger = Logger()
#console = get_console()
#console._width = 180
#console.use_theme(theme)

In [None]:
from rich import print

# Show the configuration of the PyTorch dynamics object; `print` here is the
# `rich` version imported in this cell.
print(dynamics_pt.config)

In [None]:
import json

from config import BIN_DIR

#from utils.logger import Logger, print_dict

# Load the base configuration from `test_configs.json` in `BIN_DIR`.
train_configs_file = os.path.join(BIN_DIR, 'test_configs.json')
with open(train_configs_file, mode='rt') as f:
    configs = json.load(f)

# Top-level overrides for this run.
configs.update(
    ensure_new=False,
    run_steps=1000,
    save_steps=500,
    steps_per_epoch=100,
    patience=2,
    min_lr=1e-4,
    logging_steps=100,
    print_steps=10,
    beta_init=3.,
    beta_final=3.,
)

# Overrides for the dynamics settings.
configs['dynamics_config'].update(
    use_conv_net=False,
    separate_networks=True,
    use_mixed_loss=False,
    aux_weight=0.0,
    num_steps=5,
    x_shape=[256, 16, 16, 2],
)

# Overrides for the network settings.
configs['network_config'].update(
    units=[16, 16, 16, 16],
    use_batch_norm=False,
    dropout_prob=0.025,
)

# The conv input shape is `x_shape` without its leading entry.
configs['conv_config'].update(
    input_shape=configs['dynamics_config']['x_shape'][1:],
)

logger.log(print_dict(configs, name='Configs'))

In [None]:
from dynamics.gauge_dynamics import build_dynamics

# Build the dynamics object for the TensorFlow section from the assembled
# `configs` dictionary.
dynamics_tf = build_dynamics(configs)

In [None]:
from utils.data_containers import DataContainer

#from dynamics.gauge_dynamics import convert_to_angle

# Initial input: uniform samples between -PI and PI with the shape reported by
# the TensorFlow dynamics object.
x = tf.random.uniform(
    dynamics_tf.x_shape,
    minval=-PI,
    maxval=PI,
)

keep = [
    'step', 'loss', 'accept_prob', 'beta',
    'dq_int', 'dq_sin', 'plaqs', 'p4x4',
]
train_data = DataContainer(configs['train_steps'])

In [None]:
from utils.data_containers import DataContainer
from utils.history import History
from utils.step_timer import StepTimer
#from utils.pytorch.io import Logger

import rich

from rich.console import Console

# Metric names passed to `metrics_summary` through its `keep` argument.
keep = [
    'loss', 'accept', 'step', 'px', 'acc',
    'plaqs', 'dq_int', 'dq_sin', 'dQint', 'dQsin',
]
console = Console()

timer = StepTimer()
history_tf = History()
ntrain = configs['train_steps']
for step in range(ntrain):
    # Time one training step; `x` is carried over to the next iteration.
    timer.start()
    x, metrics = dynamics_tf.train_step((x, beta))
    dt = timer.stop()

    # Prefix of the log line: step index and step duration.
    pre = f'step={step} dt={dt:.2g}'

    history_tf.update(metrics, step)
    mstr = history_tf.metrics_summary(window=0, skip=skip, keep=keep, pre=pre)
    console.log(f'{mstr}')

In [None]:
from utils.data_containers import DataContainer
# Create a `DataContainer` with no constructor arguments; its `data` attribute
# is assigned in a later cell.
dctf = DataContainer()

In [None]:
# Point the container at the metrics recorded by `history_tf`.
dctf.data = history_tf.data

In [None]:
# Convert the container's data into a dataset object via `get_dataset()`.
dctf_dataset = dctf.get_dataset()

In [None]:
from utils.history import History
# Second `History` instance, created with no constructor arguments.
history_tf_ = History()

In [None]:
# Give the second `History` instance the metrics recorded by `history_tf`.
history_tf_.data = history_tf.data

In [None]:
# Plot everything in the second `History` instance using `plot_all`'s default
# arguments.
history_tf_.plot_all()

In [None]:
# %debug  -- post-mortem debugger left over from development; it opens an
# interactive prompt, so it stays disabled to keep "Run All" non-interactive.

In [None]:
# Display the dataset itself. The `.pl` suffix looked like a truncated
# attribute lookup that would raise AttributeError -- NOTE(review): confirm
# that no `pl` attribute was intended.
dctf_dataset

In [None]:
# Plot the TensorFlow training history, passing 10 chains and a thermalization
# fraction of 0.1.
dataset_tf = history_tf.plot_all(
    num_chains=10,
    therm_frac=0.1,
    subplots_kwargs=dict(figsize=(5, 3), constrained_layout=True),
)

In [None]:
# %debug  -- post-mortem debugger left over from development; it opens an
# interactive prompt, so it stays disabled to keep "Run All" non-interactive.

In [None]:
# Continue training the TensorFlow dynamics at beta = 4 for 5000 more steps,
# starting from the current `x`.
beta = 4.
ntrain = 5000
for step in range(ntrain):
    # One training step; the returned `x` feeds the next iteration.
    x, metrics = dynamics_tf.train_step((x, beta))
    # Log this step's metrics, prefixed with "<step>/<ntrain>". `keep` is the
    # list defined in an earlier cell; the return value is not used in the
    # visible part of the loop.
    metrics_ = logger.print_metrics(metrics, skip=['logdets', 'px'], keep=keep, pre=[f'{step}/{ntrain}'])
    #data_str = train_data.print_metrics(metrics, window=0, keep=keep,
    #                                    pre=[f'{step}/{ntrain}'])
    #loss, mc_states, metrics = train_step((x, beta), dynamics, optimizer=optimizer)
    #metrics = logger.print_metrics(metrics, skip=['logdets', 'px'], pre=[f'{step}/{ntrain}'])
    #metrics_ = logger.print_metrics(metrics_, skip=['logdets', 'px'], pre=[f'{step}/{ntrain}'])
    #x = convert_to_angle(x)