# Testing GaugeModel

## Imports

In [1]:
import os
import sys
import pandas as pd
import numpy as np
import tensorflow as tf
import time
import pickle

# Put the parent of the current working directory on the import path so the
# project packages imported below resolve when the notebook is run from its
# own sub-directory. The membership check keeps re-runs from adding duplicates.
module_path = os.path.abspath(os.pardir)
already_importable = module_path in sys.path
if not already_importable:
    sys.path.append(module_path)
    
from collections import namedtuple

#from models.model import GaugeModel
#from models.gauge_model import GaugeModel
from loggers.train_logger import TrainLogger
from loggers.run_logger import RunLogger
from trainers.trainer import GaugeModelTrainer
from plotters.gauge_model_plotter import GaugeModelPlotter
from plotters.leapfrog_plotters import LeapfrogPlotter
from runners.runner import GaugeModelRunner
from main import create_config, train_setup
from config import PARAMS
import utils.file_io as io

import matplotlib as mpl
import matplotlib.pyplot as plt

%autoreload 2
%matplotlib notebook
np.set_printoptions(precision=5)
np.set_printoptions(suppress=True)

In [2]:
from models.gauge_model import GaugeModel
from models.params import GAUGE_PARAMS
from loggers.train_logger import TrainLogger
from main import create_config
from config import NP_FLOAT

# Echo the default hyperparameters so the recorded output documents the run.
print('GAUGE_PARAMS:\n')
for key, val in GAUGE_PARAMS.items():
    print(f'  {key}: {val}')

# Build the run parameters / session hooks and the model from the defaults.
params, hooks = train_setup(GAUGE_PARAMS, log_file=None)
model = GaugeModel(params)

train_logger = TrainLogger(model, params['log_dir'],
                           logging_steps=10, summaries=params['summaries'])

# NOTE: `params` is rebound here, after the model and logger were built from
# the previous value.
config, params = create_config(params)

# Initial feed values handed to the trainer in the next cell.
# presumably one multiplier per network output -- TODO confirm with the trainer
net_weights_init = [1., 1., 1.]
# Flatten the lattice samples to shape (num_samples, x_dim).
samples_init = np.reshape(np.array(model.lattice.samples, dtype=NP_FLOAT),
                          (model.num_samples, model.x_dim))
beta_init = model.beta_init

checkpoint_dir = os.path.join(model.log_dir, 'checkpoints')
io.check_else_make_dir(checkpoint_dir)

# Both `save_summaries_*` options are None, which turns off the session's
# automatic summary saving.
sess_kwargs = {
    'checkpoint_dir': checkpoint_dir,
    'hooks': hooks,
    'config': config,
    'save_summaries_secs': None,
    'save_summaries_steps': None,
}

# These ops are not referenced again in the visible cells. They are kept
# because creating them adds nodes to the graph, which has to happen before
# the monitored session below is constructed (it finalizes the graph).
global_var_init = tf.global_variables_initializer()
local_var_init = tf.local_variables_initializer()
uninited = tf.report_uninitialized_variables()

sess = tf.train.MonitoredTrainingSession(**sess_kwargs)
# Make Keras layers use the same session as the rest of the graph.
tf.keras.backend.set_session(sess)

trainer = GaugeModelTrainer(sess, model, logger=train_logger)

GAUGE_PARAMS:

  space_size: 8
  time_size: 8
  link_type: U1
  dim: 2
  num_samples: 128
  rand: True
  num_steps: 5
  eps: 0.2
  fixed_beta: False
  hmc_beta: None
  hmc_eps: None
  beta_init: 2.0
  beta_final: 5.0
  inference: False
  beta_inference: None
  lr_init: 0.001
  lr_decay_steps: 1000
  lr_decay_rate: 0.96
  warmup_lr: False
  train_steps: 5000
  run_steps: 10000
  network_arch: generic
  num_hidden1: 50
  num_hidden2: 50
  use_bn: False
  dropout_prob: 0.5
  clip_value: 0.0
  summaries: True
  save_samples: False
  save_lf: True
  eps_fixed: False
  hmc: False
  loss_scale: 0.1
  std_weight: 1.0
  aux_weight: 1.0
  charge_weight: 0.0
  metric: cos_diff
  trace: False
  profiler: False
  gpu: False
  horovod: False
  comet: False
  restore: False
  theta: False
  num_intra_threads: 0
  float64: False
  using_hvd: False
--------------------------------------------------------------------------------
Starting training using L2HMC algorithm...
Creating directory: /Users/safor

In [3]:
# Keyword arguments forwarded to the trainer: the initial samples, the
# starting beta, the network weights, and the console print interval.
# NOTE(review): the output recorded for this cell ends with
# "TypeError: update() missing 1 required positional argument: 'net_weights'".
# The failing `update(...)` call is not in this notebook -- check the
# trainer / logger call sites.
train_kwargs = dict(
    samples_np=samples_init,
    beta_np=beta_init,
    net_weights=net_weights_init,
    print_steps=1.,
)

trainer.train(5000, **train_kwargs)

-----------------------------------------------------------------------------------------------------------------
    STEP       LOSS     t/STEP    % ACC      EPS       BETA     ACTION     PLAQ    (EXACT)      dQ        LR    
-----------------------------------------------------------------------------------------------------------------
    0/5000     -1054     13.26    0.6431     0.201       2       63.57   0.006792   0.6978    0.4141     0.001   
    0/5000     -1054     13.26    0.6431     0.201       2       63.57   0.006792   0.6978    0.4141     0.001   
    1/5000    -983.8    0.3945    0.5453     0.202       2       47.89    0.2517    0.6978    0.2266     0.001   
    1/5000    -983.8    0.3945    0.5453     0.202       2       47.89    0.2517    0.6978    0.2266     0.001   
    2/5000     -1011    0.2527     0.62      0.203       2       39.62     0.381    0.6979    0.2031   0.0009999 
    2/5000     -1011    0.2527     0.62      0.203       2       39.62     0.381    0.697

   35/5000     -1184    0.5968     0.727    0.2353     2.008     18.73    0.7073    0.6992    0.1484   0.0009986 
   35/5000     -1184    0.5968     0.727    0.2353     2.008     18.73    0.7073    0.6992    0.1484   0.0009986 
   36/5000     -1202    0.2676    0.7204    0.2363     2.009     18.74    0.7072    0.6992     0.125   0.0009985 
   36/5000     -1202    0.2676    0.7204    0.2363     2.009     18.74    0.7072    0.6992     0.125   0.0009985 
   37/5000     -1155    0.3122    0.6443    0.2372     2.009     18.73    0.7073    0.6992    0.07031  0.0009985 
   37/5000     -1155    0.3122    0.6443    0.2372     2.009     18.73    0.7073    0.6992    0.07031  0.0009985 
   38/5000     -1194    0.3708    0.6974    0.2381     2.009     18.61    0.7092    0.6993    0.09375  0.0009984 
   38/5000     -1194    0.3708    0.6974    0.2381     2.009     18.61    0.7092    0.6993    0.09375  0.0009984 
   39/5000     -1207    0.5805    0.6731     0.239     2.009     18.8     0.7063    0.69

   71/5000     -1347    0.4293    0.7048    0.2645     2.017     19.23    0.6996    0.7006    0.1406   0.0009971 
   71/5000     -1347    0.4293    0.7048    0.2645     2.017     19.23    0.6996    0.7006    0.1406   0.0009971 
   72/5000     -1327    0.3504    0.6752    0.2651     2.017     19.51    0.6951    0.7006    0.1641   0.0009971 
   72/5000     -1327    0.3504    0.6752    0.2651     2.017     19.51    0.6951    0.7006    0.1641   0.0009971 
   73/5000     -1322    0.1694    0.6436    0.2658     2.018     19.39    0.6971    0.7007    0.08594  0.000997  
   73/5000     -1322    0.1694    0.6436    0.2658     2.018     19.39    0.6971    0.7007    0.08594  0.000997  
   74/5000     -1349    0.5036    0.6909    0.2665     2.018     19.47    0.6958    0.7007    0.1016   0.000997  
   74/5000     -1349    0.5036    0.6909    0.2665     2.018     19.47    0.6958    0.7007    0.1016   0.000997  
   75/5000     -1353     0.249    0.7049    0.2672     2.018     19.29    0.6986    0.70

  104/5000     -1409    0.3891    0.6534     0.284     2.025     18.98    0.7034    0.7019    0.1172   0.0009958 
  104/5000     -1409    0.3891    0.6534     0.284     2.025     18.98    0.7034    0.7019    0.1172   0.0009958 
  105/5000     -1404    0.3192    0.6145    0.2845     2.026     19.2     0.6999    0.7019    0.1016   0.0009957 
  105/5000     -1404    0.3192    0.6145    0.2845     2.026     19.2     0.6999    0.7019    0.1016   0.0009957 
  106/5000     -1409    0.2813    0.6285    0.2848     2.026     19.28    0.6988     0.702    0.09375  0.0009957 
  106/5000     -1409    0.2813    0.6285    0.2848     2.026     19.28    0.6988     0.702    0.09375  0.0009957 
  107/5000     -1404    0.3844    0.6191    0.2852     2.026     19.36    0.6974     0.702     0.125   0.0009956 
  107/5000     -1404    0.3844    0.6191    0.2852     2.026     19.36    0.6974     0.702     0.125   0.0009956 
  108/5000     -1478    0.3658    0.6444    0.2856     2.026     19.63    0.6933     0.7

  140/5000     -1481    0.4101    0.6265    0.2954     2.034     19.71     0.692    0.7033     0.125   0.0009943 
  140/5000     -1481    0.4101    0.6265    0.2954     2.034     19.71     0.692    0.7033     0.125   0.0009943 
  141/5000     -1488    0.3274    0.6595    0.2956     2.034     19.8     0.6907    0.7033    0.1328   0.0009943 
  141/5000     -1488    0.3274    0.6595    0.2956     2.034     19.8     0.6907    0.7033    0.1328   0.0009943 
  142/5000     -1455    0.3998    0.5964    0.2957     2.035     19.79    0.6908    0.7034    0.1328   0.0009942 
  142/5000     -1455    0.3998    0.5964    0.2957     2.035     19.79    0.6908    0.7034    0.1328   0.0009942 
  143/5000     -1464    0.2793    0.5922    0.2959     2.035     20.04    0.6868    0.7034    0.1719   0.0009942 
  143/5000     -1464    0.2793    0.5922    0.2959     2.035     20.04    0.6868    0.7034    0.1719   0.0009942 
  144/5000     -1411    0.2212    0.6209    0.2961     2.035     20.05    0.6867    0.70

  176/5000     -1472    0.2118    0.6873    0.2966     2.043     19.31    0.6983    0.7047    0.09375  0.0009928 
  176/5000     -1472    0.2118    0.6873    0.2966     2.043     19.31    0.6983    0.7047    0.09375  0.0009928 
  177/5000     -1464    0.3385     0.649    0.2966     2.043     19.29    0.6986    0.7048    0.1172   0.0009928 
  177/5000     -1464    0.3385     0.649    0.2966     2.043     19.29    0.6986    0.7048    0.1172   0.0009928 
  178/5000     -1493     0.32     0.7031    0.2967     2.044     19.32    0.6982    0.7048    0.1328   0.0009928 
  178/5000     -1493     0.32     0.7031    0.2967     2.044     19.32    0.6982    0.7048    0.1328   0.0009928 
  179/5000     -1502     0.501    0.6815    0.2968     2.044     19.24    0.6994    0.7049    0.1562   0.0009927 
  179/5000     -1502     0.501    0.6815    0.2968     2.044     19.24    0.6994    0.7049    0.1562   0.0009927 
  180/5000     -1521    0.2914    0.6586    0.2969     2.044     19.03    0.7026    0.70

  209/5000     -1539    0.1913     0.674    0.2965     2.051     19.09    0.7017     0.706    0.1328   0.0009915 
  209/5000     -1539    0.1913     0.674    0.2965     2.051     19.09    0.7017     0.706    0.1328   0.0009915 
  210/5000     -1532    0.2733    0.6573    0.2967     2.052     19.18    0.7004    0.7061    0.0625   0.0009915 
  210/5000     -1532    0.2733    0.6573    0.2967     2.052     19.18    0.7004    0.7061    0.0625   0.0009915 
  211/5000     -1507    0.2118    0.6521    0.2968     2.052     19.06    0.7021    0.7061    0.1406   0.0009914 
  211/5000     -1507    0.2118    0.6521    0.2968     2.052     19.06    0.7021    0.7061    0.1406   0.0009914 
  212/5000     -1497    0.2278     0.67     0.2972     2.052     19.28    0.6988    0.7062    0.1406   0.0009914 
  212/5000     -1497    0.2278     0.67     0.2972     2.052     19.28    0.6988    0.7062    0.1406   0.0009914 
  213/5000     -1519    0.2566    0.6917    0.2975     2.052     19.18    0.7004    0.70

  245/5000     -1511    0.2822     0.682    0.2891     2.061     18.87    0.7051    0.7075    0.08594   0.00099  
  245/5000     -1511    0.2822     0.682    0.2891     2.061     18.87    0.7051    0.7075    0.08594   0.00099  
  246/5000     -1546    0.2277    0.6993     0.289     2.061     18.65    0.7085    0.7075    0.1094    0.00099  
  246/5000     -1546    0.2277    0.6993     0.289     2.061     18.65    0.7085    0.7075    0.1094    0.00099  
  247/5000     -1506    0.2643    0.6923    0.2889     2.061     19.04    0.7025    0.7076    0.1719    0.00099  
  247/5000     -1506    0.2643    0.6923    0.2889     2.061     19.04    0.7025    0.7076    0.1719    0.00099  
  248/5000     -1513    0.1849    0.7559    0.2889     2.061     19.1     0.7015    0.7076    0.1406   0.0009899 
  248/5000     -1513    0.1849    0.7559    0.2889     2.061     19.1     0.7015    0.7076    0.1406   0.0009899 
  249/5000     -1568    0.2684    0.7291    0.2889     2.062     19.04    0.7025    0.70

  281/5000     -1509    0.3972    0.7433    0.2879     2.07      19.01    0.7029    0.7089    0.1094   0.0009886 
  281/5000     -1509    0.3972    0.7433    0.2879     2.07      19.01    0.7029    0.7089    0.1094   0.0009886 
  282/5000     -1535     0.419    0.7275    0.2881     2.07      18.72    0.7075    0.7089    0.07812  0.0009886 
  282/5000     -1535     0.419    0.7275    0.2881     2.07      18.72    0.7075    0.7089    0.07812  0.0009886 
  283/5000     -1526    0.1786    0.6774    0.2883     2.07      18.94     0.704     0.709    0.1172   0.0009885 
  283/5000     -1526    0.1786    0.6774    0.2883     2.07      18.94     0.704     0.709    0.1172   0.0009885 
  284/5000     -1512    0.1695     0.671    0.2884     2.071     18.79    0.7064     0.709    0.1328   0.0009885 
  284/5000     -1512    0.1695     0.671    0.2884     2.071     18.79    0.7064     0.709    0.1328   0.0009885 
  285/5000     -1535    0.1596    0.7469    0.2886     2.071     18.91    0.7046    0.70

  314/5000     -1569    0.3111    0.7111    0.2979     2.078     19.02    0.7028    0.7102    0.1172   0.0009873 
  314/5000     -1569    0.3111    0.7111    0.2979     2.078     19.02    0.7028    0.7102    0.1172   0.0009873 
  315/5000     -1559    0.3095    0.6958     0.298     2.079     19.1     0.7016    0.7103    0.07812  0.0009872 
  315/5000     -1559    0.3095    0.6958     0.298     2.079     19.1     0.7016    0.7103    0.07812  0.0009872 
  316/5000     -1535     0.332    0.7002    0.2982     2.079      19      0.7031    0.7103    0.08594  0.0009872 
  316/5000     -1535     0.332    0.7002    0.2982     2.079      19      0.7031    0.7103    0.08594  0.0009872 
  317/5000     -1600    0.3294    0.7275    0.2984     2.079     18.72    0.7075    0.7103    0.1641   0.0009871 
  317/5000     -1600    0.3294    0.7275    0.2984     2.079     18.72    0.7075    0.7103    0.1641   0.0009871 
  318/5000     -1565    0.5219    0.6788    0.2986     2.079     18.93    0.7042    0.71

  350/5000     -1603    0.5841    0.7418    0.2965     2.088     18.58    0.7098    0.7117     0.125   0.0009858 
  350/5000     -1603    0.5841    0.7418    0.2965     2.088     18.58    0.7098    0.7117     0.125   0.0009858 
  351/5000     -1552    0.3386    0.7162    0.2966     2.088     18.5     0.7109    0.7117    0.1406   0.0009858 
  351/5000     -1552    0.3386    0.7162    0.2966     2.088     18.5     0.7109    0.7117    0.1406   0.0009858 
  352/5000     -1568    0.4156    0.6966    0.2971     2.088     18.31    0.7138    0.7117    0.1641   0.0009857 
  352/5000     -1568    0.4156    0.6966    0.2971     2.088     18.31    0.7138    0.7117    0.1641   0.0009857 
  353/5000     -1531    0.4262    0.6703    0.2976     2.088     18.19    0.7157    0.7118    0.1406   0.0009857 
  353/5000     -1531    0.4262    0.6703    0.2976     2.088     18.19    0.7157    0.7118    0.1406   0.0009857 
  354/5000     -1523    0.4363    0.6596    0.2981     2.089     18.27    0.7145    0.71

  386/5000     -1561    0.5259    0.7342    0.2967     2.097     18.93    0.7043    0.7131    0.1484   0.0009844 
  386/5000     -1561    0.5259    0.7342    0.2967     2.097     18.93    0.7043    0.7131    0.1484   0.0009844 
  387/5000     -1572    0.4611     0.693    0.2966     2.097     19.08    0.7019    0.7131    0.1562   0.0009843 
  387/5000     -1572    0.4611     0.693    0.2966     2.097     19.08    0.7019    0.7131    0.1562   0.0009843 
  388/5000     -1584    0.5382     0.725    0.2964     2.098     19.09    0.7017    0.7132    0.08594  0.0009843 
  388/5000     -1584    0.5382     0.725    0.2964     2.098     19.09    0.7017    0.7132    0.08594  0.0009843 
  389/5000     -1542    0.6352    0.7036    0.2963     2.098     19.01    0.7029    0.7132    0.1797   0.0009842 
  389/5000     -1542    0.6352    0.7036    0.2963     2.098     19.01    0.7029    0.7132    0.1797   0.0009842 
  390/5000     -1552    0.7124    0.6773    0.2962     2.098     19.03    0.7027    0.71

  420/5000     -1558    0.1829     0.732    0.2974     2.106     18.55    0.7102    0.7145    0.09375  0.000983  
  420/5000     -1558    0.1829     0.732    0.2974     2.106     18.55    0.7102    0.7145    0.09375  0.000983  
  421/5000     -1570    0.2055    0.7349    0.2977     2.106     18.42    0.7122    0.7145     0.125   0.000983  
  421/5000     -1570    0.2055    0.7349    0.2977     2.106     18.42    0.7122    0.7145     0.125   0.000983  
  422/5000     -1639    0.2117    0.7666     0.298     2.107     18.61    0.7092    0.7145     0.125   0.0009829 
  422/5000     -1639    0.2117    0.7666     0.298     2.107     18.61    0.7092    0.7145     0.125   0.0009829 
  423/5000     -1568    0.2154    0.7133    0.2984     2.107     18.51    0.7108    0.7146    0.1562   0.0009829 
  423/5000     -1568    0.2154    0.7133    0.2984     2.107     18.51    0.7108    0.7146    0.1562   0.0009829 
  424/5000     -1534    0.2264    0.6838    0.2986     2.107     18.5      0.711    0.71

  456/5000     -1573    0.4686    0.7238    0.2996     2.116     18.45    0.7118    0.7159     0.125   0.0009816 
  456/5000     -1573    0.4686    0.7238    0.2996     2.116     18.45    0.7118    0.7159     0.125   0.0009816 
  457/5000     -1547    0.5179    0.6522    0.2995     2.116     18.66    0.7084    0.7159    0.09375  0.0009815 
  457/5000     -1547    0.5179    0.6522    0.2995     2.116     18.66    0.7084    0.7159    0.09375  0.0009815 
  458/5000     -1602     0.464     0.752    0.2991     2.116     18.56     0.71      0.716    0.1172   0.0009815 
  458/5000     -1602     0.464     0.752    0.2991     2.116     18.56     0.71      0.716    0.1172   0.0009815 
  459/5000     -1634     0.477    0.7783    0.2988     2.117     18.75     0.707     0.716    0.1328   0.0009814 
  459/5000     -1634     0.477    0.7783    0.2988     2.117     18.75     0.707     0.716    0.1328   0.0009814 
  460/5000     -1560    0.2708    0.6821    0.2986     2.117     18.35    0.7133    0.71

  492/5000     -1585    0.5519    0.7328    0.3012     2.125     18.12    0.7169    0.7173    0.09375  0.0009801 
  492/5000     -1585    0.5519    0.7328    0.3012     2.125     18.12    0.7169    0.7173    0.09375  0.0009801 
  493/5000     -1542    0.4212    0.7155    0.3011     2.126     17.91    0.7202    0.7174    0.07812  0.0009801 
  493/5000     -1542    0.4212    0.7155    0.3011     2.126     17.91    0.7202    0.7174    0.07812  0.0009801 
  494/5000     -1505    0.8159    0.6399     0.301     2.126     17.85    0.7211    0.7174    0.1172    0.00098  
  494/5000     -1505    0.8159    0.6399     0.301     2.126     17.85    0.7211    0.7174    0.1172    0.00098  
  495/5000     -1579    0.8723    0.7154     0.301     2.126     18.3      0.714    0.7175    0.1016    0.00098  
  495/5000     -1579    0.8723    0.7154     0.301     2.126     18.3      0.714    0.7175    0.1016    0.00098  
  496/5000     -1563    0.5456    0.7721    0.3009     2.127     18.15    0.7164    0.71

  525/5000     -1526    0.5044    0.7228    0.3013     2.134     18.21    0.7155    0.7187    0.1094   0.0009788 
  525/5000     -1526    0.5044    0.7228    0.3013     2.134     18.21    0.7155    0.7187    0.1094   0.0009788 
  526/5000     -1533    0.3391    0.7321    0.3014     2.135     18.38    0.7127    0.7187    0.1094   0.0009788 
  526/5000     -1533    0.3391    0.7321    0.3014     2.135     18.38    0.7127    0.7187    0.1094   0.0009788 
  527/5000     -1524    0.3451    0.7253    0.3012     2.135     18.49     0.711    0.7187    0.1875   0.0009787 
  527/5000     -1524    0.3451    0.7253    0.3012     2.135     18.49     0.711    0.7187    0.1875   0.0009787 
  528/5000     -1503    0.3906    0.6879     0.301     2.135     18.58    0.7098    0.7188    0.1016   0.0009787 
  528/5000     -1503    0.3906    0.6879     0.301     2.135     18.58    0.7098    0.7188    0.1016   0.0009787 
  529/5000     -1527    0.4623    0.6952     0.301     2.136     18.3     0.7141    0.71

  561/5000     -1530    0.2885    0.7528    0.3059     2.144     18.3      0.714    0.7201    0.1562   0.0009774 
  561/5000     -1530    0.2885    0.7528    0.3059     2.144     18.3      0.714    0.7201    0.1562   0.0009774 
  562/5000     -1546    0.3617    0.7312     0.306     2.145     18.52    0.7106    0.7202    0.1016   0.0009773 
  562/5000     -1546    0.3617    0.7312     0.306     2.145     18.52    0.7106    0.7202    0.1016   0.0009773 
  563/5000     -1490     0.447    0.6646    0.3061     2.145     18.28    0.7144    0.7202     0.125   0.0009773 
  563/5000     -1490     0.447    0.6646    0.3061     2.145     18.28    0.7144    0.7202     0.125   0.0009773 
  564/5000     -1556    0.4528    0.7146    0.3062     2.145     18.28    0.7144    0.7202    0.1172   0.0009772 
  564/5000     -1556    0.4528    0.7146    0.3062     2.145     18.28    0.7144    0.7202    0.1172   0.0009772 
  565/5000     -1529    0.3141    0.7042    0.3068     2.145     18.41    0.7123    0.72

  597/5000     -1501    0.6518     0.705    0.3141     2.154     18.35    0.7133    0.7216    0.1328   0.0009759 
  597/5000     -1501    0.6518     0.705    0.3141     2.154     18.35    0.7133    0.7216    0.1328   0.0009759 
  598/5000     -1567    0.3078    0.7224    0.3139     2.155     18.17    0.7161    0.7216    0.1406   0.0009759 
  598/5000     -1567    0.3078    0.7224    0.3139     2.155     18.17    0.7161    0.7216    0.1406   0.0009759 
INFO:tensorflow:global_step/global_step/sec: 1.85293
  599/5000     -1530     0.511    0.7175    0.3137     2.155     18.25    0.7149    0.7216    0.1484   0.0009758 
-----------------------------------------------------------------------------------------------------------------
    STEP       LOSS     t/STEP    % ACC      EPS       BETA     ACTION     PLAQ    (EXACT)      dQ        LR    
-----------------------------------------------------------------------------------------------------------------
  599/5000     -1530     0.511    0.

  630/5000     -1444    0.7945     0.733     0.286     2.164     17.67    0.7238    0.7229     0.125   0.0009746 
  630/5000     -1444    0.7945     0.733     0.286     2.164     17.67    0.7238    0.7229     0.125   0.0009746 
  631/5000     -1503     0.276    0.7438    0.2857     2.164     17.84    0.7212    0.7229    0.09375  0.0009746 
  631/5000     -1503     0.276    0.7438    0.2857     2.164     17.84    0.7212    0.7229    0.09375  0.0009746 
  632/5000     -1513    0.2178    0.7591    0.2852     2.164     18.12    0.7168     0.723    0.1641   0.0009745 
  632/5000     -1513    0.2178    0.7591    0.2852     2.164     18.12    0.7168     0.723    0.1641   0.0009745 
  633/5000     -1488    0.2016    0.7449    0.2848     2.164     17.86     0.721     0.723    0.1328   0.0009745 
  633/5000     -1488    0.2016    0.7449    0.2848     2.164     17.86     0.721     0.723    0.1328   0.0009745 
  634/5000     -1485    0.6921    0.7542    0.2844     2.165     17.99    0.7189     0.7

  666/5000     -1500    0.6312    0.7637     0.286     2.174     17.3     0.7296    0.7243     0.125   0.0009732 
  666/5000     -1500    0.6312    0.7637     0.286     2.174     17.3     0.7296    0.7243     0.125   0.0009732 
  667/5000     -1506    0.8367     0.763    0.2862     2.174     17.47    0.7271    0.7244    0.08594  0.0009731 
  667/5000     -1506    0.8367     0.763    0.2862     2.174     17.47    0.7271    0.7244    0.08594  0.0009731 
  668/5000     -1591    0.6063    0.7923    0.2863     2.174     17.32    0.7294    0.7244     0.125   0.0009731 
  668/5000     -1591    0.6063    0.7923    0.2863     2.174     17.32    0.7294    0.7244     0.125   0.0009731 
  669/5000     -1517    0.7802    0.7468    0.2864     2.175     17.43    0.7277    0.7245    0.0625   0.0009731 
  669/5000     -1517    0.7802    0.7468    0.2864     2.175     17.43    0.7277    0.7245    0.0625   0.0009731 
  670/5000     -1508    0.3862    0.7541    0.2865     2.175     17.63    0.7245    0.72

  700/5000     -1517    0.3985    0.7576    0.2896     2.183     17.97    0.7192    0.7257    0.08594  0.0009718 
  700/5000     -1517    0.3985    0.7576    0.2896     2.183     17.97    0.7192    0.7257    0.08594  0.0009718 
  701/5000     -1545    0.4999    0.7686    0.2897     2.184     18.35    0.7133    0.7257    0.0625   0.0009718 
  701/5000     -1545    0.4999    0.7686    0.2897     2.184     18.35    0.7133    0.7257    0.0625   0.0009718 
  702/5000     -1515    0.2917    0.7855    0.2898     2.184     17.95    0.7196    0.7258    0.09375  0.0009717 
  702/5000     -1515    0.2917    0.7855    0.2898     2.184     17.95    0.7196    0.7258    0.09375  0.0009717 
  703/5000     -1522    0.4147    0.7402     0.29      2.184     17.82    0.7216    0.7258    0.07812  0.0009717 
  703/5000     -1522    0.4147    0.7402     0.29      2.184     17.82    0.7216    0.7258    0.07812  0.0009717 
  704/5000     -1545    0.5147     0.752    0.2901     2.185     17.59    0.7252    0.72

  736/5000     -1572    0.2995    0.7634     0.293     2.194     17.41    0.7279    0.7272    0.1094   0.0009704 
  736/5000     -1572    0.2995    0.7634     0.293     2.194     17.41    0.7279    0.7272    0.1094   0.0009704 
  737/5000     -1594    0.2845    0.7725    0.2931     2.194     17.29    0.7298    0.7272    0.09375  0.0009704 
  737/5000     -1594    0.2845    0.7725    0.2931     2.194     17.29    0.7298    0.7272    0.09375  0.0009704 
  738/5000     -1584    0.2101    0.7686    0.2932     2.194     17.34     0.729    0.7272    0.09375  0.0009703 
  738/5000     -1584    0.2101    0.7686    0.2932     2.194     17.34     0.729    0.7272    0.09375  0.0009703 
  739/5000     -1555    0.3422     0.765    0.2934     2.195     17.72    0.7231    0.7273    0.1016   0.0009703 
  739/5000     -1555    0.3422     0.765    0.2934     2.195     17.72    0.7231    0.7273    0.1016   0.0009703 
  740/5000     -1563     0.27     0.7326    0.2935     2.195     17.62    0.7247    0.72

  772/5000     -1578    0.3647     0.769    0.2956     2.204     17.37    0.7285    0.7286    0.08594  0.000969  
  772/5000     -1578    0.3647     0.769    0.2956     2.204     17.37    0.7285    0.7286    0.08594  0.000969  
  773/5000     -1559    0.3414    0.7125    0.2957     2.204     17.29    0.7299    0.7287    0.1172   0.0009689 
  773/5000     -1559    0.3414    0.7125    0.2957     2.204     17.29    0.7299    0.7287    0.1172   0.0009689 
  774/5000     -1514    0.2547    0.7271    0.2957     2.205     17.6      0.725    0.7287    0.08594  0.0009689 
  774/5000     -1514    0.2547    0.7271    0.2957     2.205     17.6      0.725    0.7287    0.08594  0.0009689 
  775/5000     -1516    0.3228    0.7158    0.2957     2.205     17.29    0.7298    0.7287    0.1328   0.0009689 
  775/5000     -1516    0.3228    0.7158    0.2957     2.205     17.29    0.7298    0.7287    0.1328   0.0009689 
  776/5000     -1588     0.317    0.7958    0.2957     2.205     17.3     0.7296    0.72

  805/5000     -1551    0.2197     0.746    0.2987     2.214     18.03    0.7182    0.7299    0.1641   0.0009677 
  805/5000     -1551    0.2197     0.746    0.2987     2.214     18.03    0.7182    0.7299    0.1641   0.0009677 
  806/5000     -1607    0.2523    0.7953    0.2988     2.214     18.13    0.7167     0.73     0.1172   0.0009676 
  806/5000     -1607    0.2523    0.7953    0.2988     2.214     18.13    0.7167     0.73     0.1172   0.0009676 
  807/5000     -1611    0.2919    0.7526    0.2989     2.214     18.29    0.7143     0.73     0.1094   0.0009676 
  807/5000     -1611    0.2919    0.7526    0.2989     2.214     18.29    0.7143     0.73     0.1094   0.0009676 
  808/5000     -1577    0.3081    0.7631    0.2989     2.215     18.25    0.7148    0.7301    0.09375  0.0009676 
  808/5000     -1577    0.3081    0.7631    0.2989     2.215     18.25    0.7148    0.7301    0.09375  0.0009676 
  809/5000     -1578    0.3962    0.7689     0.299     2.215     18.15    0.7164    0.73

  841/5000     -1606    0.2704    0.7293    0.3013     2.224     17.63    0.7245    0.7314    0.1562   0.0009663 
  841/5000     -1606    0.2704    0.7293    0.3013     2.224     17.63    0.7245    0.7314    0.1562   0.0009663 
  842/5000     -1613    0.2209    0.7777    0.3014     2.225     17.41     0.728    0.7314    0.1016   0.0009662 
  842/5000     -1613    0.2209    0.7777    0.3014     2.225     17.41     0.728    0.7314    0.1016   0.0009662 
  843/5000     -1584    0.6506    0.7004    0.3015     2.225     17.33    0.7292    0.7315    0.07031  0.0009662 
  843/5000     -1584    0.6506    0.7004    0.3015     2.225     17.33    0.7292    0.7315    0.07031  0.0009662 
  844/5000     -1576    0.6159    0.7429    0.3016     2.225     17.48    0.7269    0.7315    0.0625   0.0009661 
  844/5000     -1576    0.6159    0.7429    0.3016     2.225     17.48    0.7269    0.7315    0.0625   0.0009661 
  845/5000     -1585    0.5425     0.751    0.3017     2.226     17.79     0.722    0.73

  877/5000     -1597    0.4363    0.7076    0.3038     2.235     17.23    0.7307    0.7329    0.1094   0.0009648 
  877/5000     -1597    0.4363    0.7076    0.3038     2.235     17.23    0.7307    0.7329    0.1094   0.0009648 
  878/5000     -1531    0.2935    0.6888    0.3039     2.236     17.3     0.7297    0.7329    0.1094   0.0009648 
  878/5000     -1531    0.2935    0.6888    0.3039     2.236     17.3     0.7297    0.7329    0.1094   0.0009648 
  879/5000     -1595    0.6241    0.7087    0.3039     2.236     17.12    0.7325    0.7329    0.07812  0.0009648 
  879/5000     -1595    0.6241    0.7087    0.3039     2.236     17.12    0.7325    0.7329    0.07812  0.0009648 
  880/5000     -1565    0.3437    0.7649     0.304     2.236     17.36    0.7288     0.733    0.09375  0.0009647 
  880/5000     -1565    0.3437    0.7649     0.304     2.236     17.36    0.7288     0.733    0.09375  0.0009647 
  881/5000     -1612    0.3851    0.7575    0.3039     2.236     17.25    0.7304     0.7

  910/5000     -1618     0.324    0.7842     0.304     2.245     17.47     0.727    0.7342    0.1484   0.0009635 
  910/5000     -1618     0.324    0.7842     0.304     2.245     17.47     0.727    0.7342    0.1484   0.0009635 
  911/5000     -1544    0.3987    0.7104     0.304     2.245     17.54     0.726    0.7342    0.09375  0.0009635 
  911/5000     -1544    0.3987    0.7104     0.304     2.245     17.54     0.726    0.7342    0.09375  0.0009635 
  912/5000     -1602    0.5954    0.7111     0.304     2.246     17.66    0.7241    0.7343    0.1328   0.0009635 
  912/5000     -1602    0.5954    0.7111     0.304     2.246     17.66    0.7241    0.7343    0.1328   0.0009635 
  913/5000     -1626    0.4787    0.7649    0.3038     2.246     17.63    0.7245    0.7343    0.1172   0.0009634 
  913/5000     -1626    0.4787    0.7649    0.3038     2.246     17.63    0.7245    0.7343    0.1172   0.0009634 
  914/5000     -1552    0.5298    0.7025    0.3037     2.246     17.14    0.7322    0.73

  947/5000     -1567    0.2661     0.67     0.3042     2.256     17.25    0.7304    0.7357    0.09375  0.0009621 
  947/5000     -1567    0.2661     0.67     0.3042     2.256     17.25    0.7304    0.7357    0.09375  0.0009621 
  948/5000     -1618    0.4197    0.7523    0.3043     2.257     17.52    0.7263    0.7357    0.09375  0.000962  
  948/5000     -1618    0.4197    0.7523    0.3043     2.257     17.52    0.7263    0.7357    0.09375  0.000962  
  949/5000     -1579    0.4763     0.705    0.3044     2.257     16.98    0.7347    0.7358    0.09375  0.000962  
  949/5000     -1579    0.4763     0.705    0.3044     2.257     16.98    0.7347    0.7358    0.09375  0.000962  
  950/5000     -1581    0.2028    0.7206    0.3043     2.257      17      0.7343    0.7358    0.1484   0.000962  
  950/5000     -1581    0.2028    0.7206    0.3043     2.257      17      0.7343    0.7358    0.1484   0.000962  
  951/5000     -1592    0.2219    0.7522    0.3042     2.258     17.09    0.7329    0.73

TypeError: update() missing 1 required positional argument: 'net_weights'

## Optimize for inference

In [None]:
# NOTE: this rebinds `create_config`, which the setup cells imported from `main`.
from inference import create_config

# Log directory of the trained run, written POSIX-style and rebuilt with
# os.path.join so the separator matches the host platform.
ld1 = ('../../logs/2019_08_27/2019_08_27_0457/'
       'lattice8_batch100_lf10_qw00_aw10_generic_dp00')
log_dir1 = os.path.join(*ld1.split('/'))

# Parameters pickled next to the logs. Only unpickle files you trust:
# pickle.load can execute arbitrary code.
params_file = os.path.join(log_dir1, 'parameters.pkl')
with open(params_file, 'rb') as f:
    params = pickle.load(f)

# Most recent checkpoint prefix in `<log_dir1>/checkpoints`.
# NOTE(review): tf.train.latest_checkpoint returns None when no checkpoint is
# found, which would make the '.meta' path below invalid -- confirm the run has one.
checkpoint_dir = os.path.join(log_dir1, 'checkpoints')
checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)

# Rebuild the graph from the saved meta graph and load the trained variables
# into a fresh session; `create_config` also returns a (possibly updated) `params`.
config, params = create_config(params)
sess = tf.Session(config=config)
saver = tf.train.import_meta_graph(f'{checkpoint_file}.meta')
saver.restore(sess, checkpoint_file)

In [None]:
run_ops = tf.get_collection('run_ops', )
inputs = tf.get_collection('inputs')


In [None]:
# Names of every node in the current default graph.
# (An unfinished `run_nodes = ...` statement that made this cell a SyntaxError
# was dropped; nothing else in the notebook reads `run_nodes`.)
node_names = [node.name for node in tf.get_default_graph().as_graph_def().node]


In [None]:
# Names of the run ops whose producing node exists in the default graph.
# The graph definition holds `NodeDef` protos while `run_ops` holds graph
# objects, so membership has to be tested on names. A tensor name carries an
# ':<output index>' suffix that the node name lacks, hence the split.
graph_node_names = {
    node.name for node in tf.get_default_graph().as_graph_def().node
}
run_node_names = [op.name for op in run_ops
                  if op.name.split(':')[0] in graph_node_names]

In [None]:
run_node_names

In [None]:
# Scope path of every run op: its name with the last '/'-separated component
# removed (an op name without any '/' maps to the empty string).
run_op_names = [op.name.rpartition('/')[0] for op in run_ops]

# Full, unmodified names of the same ops.
run_op_names1 = [op.name for op in run_ops]

# One name per line for readable printing.
separator = ', \n'
_run_op_names = separator.join(run_op_names)
_run_op_names1 = separator.join(run_op_names1)
print(_run_op_names)

In [None]:
print(_run_op_names1)

In [None]:
_names = 'sampler/x_update/x_out,sampler/x_update/accept_prob,observables/actions,observables/plaqs'

In [None]:
     checkpoint_file = kwargs.get('checkpoint_file', None)
    restore_op_name = kwargs.get('restore_op_name', 'save/restore_all')
    filename_tensor = kwargs.get('filename_tensor', 'save/Const:0')

    if checkpoint_file is None:
        checkpoint_dir = os.path.dirname(input_graph)
        checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)

    #  freeze_graph.freeze_graph('tensorflowModel.pbtxt', "", False,
    #                            './tensorflowModel.ckpt', "output/softmax",
    #                             "save/restore_all", "save/Const:0",
    #                             'frozentensorflowModel.pb', True, "")

    freeze_graph.freeze_graph(
        input_graph=input_graph,
        input_saver="",
        input_binary=False,  # True: `.pb` file; False: `.pbtxt` file
        input_checkpoint=checkpoint_file,
        output_node_names=output_node_names,
        restore_op_name=restore_op_name,
        filename_tensor=filename_tensor,
        output_graph=output_graph,
        clear_devices=True,
        initializer_nodes=""
    )


In [None]:
from tensorflow.python.tools import optimize_for_inference_lib, freeze_graph

In [None]:
from tensorflow.python.tools import optimize_for_inference_lib, freeze_graph

# Text-format graph read from the checkpoint directory, and the frozen graph
# written back into the same directory.
input_graph = os.path.join(checkpoint_dir, 'graph.pbtxt')
output_graph = os.path.join(checkpoint_dir, 'frozen_graph.pb')

# Saver op / tensor names handed to freeze_graph. They are defined once here
# and passed by name below instead of repeating the string literals.
restore_op_name = 'save/restore_all'
filename_tensor = 'save/Const:0'

freeze_graph.freeze_graph(
    input_graph=input_graph,
    input_saver='',
    input_binary=False,               # input is a text `.pbtxt` graph
    input_checkpoint=checkpoint_file,
    output_node_names=_names,         # comma-separated output node names
    restore_op_name=restore_op_name,
    filename_tensor_name=filename_tensor,
    output_graph=output_graph,
    clear_devices=True,
    initializer_nodes=''
)

In [None]:
keys = ['x_out', 'px', 'actions_op', 'plaqs_op', 'avg_plaqs_op', 'charges_op', 'charge_diffs_op', 'eps']
_keys = (',').join(keys)
node_names = [node.name for node in tf.get_default_graph().as_graph_def().node]

In [None]:
len(node_names)

In [None]:
node_names[:10]

In [None]:
%debug

In [None]:
# Substring a node name must contain to be kept. The empty string matches every
# name; set it to a scope such as 'sampler' to narrow the list.
# (The previous condition `if ''` was constant-false and always gave [].)
name_filter = ''
_node_names = [i for i in node_names if name_filter in i]

In [None]:
_node_names

In [None]:
input_names = [op.name for op in inputs]
input_names

In [None]:
# Reference shell command for freezing a graph. It is not Python, so it is kept
# as a comment; the file and node names look like placeholders.
#   python -m tensorflow.python.tools.freeze_graph --input_graph graph.pb --input_checkpoint test_model --output_graph graph_frozen.pb --output_node_names=y

In [None]:
_in_names = (', ').join(input_names)
_in_names

In [None]:
input_graph = os.path.join(checkpoint_dir, 'graph.pbtxt')
input_checkpoint = os.path.join(checkpoint_dir, 'model')
output_graph = os.path.join(checkpoint_dir, 'graph_frozen.pb')
output_node_names = (', ').join(run_op_names)

In [None]:
import subprocess

# Run the freeze_graph tool with the kernel's own interpreter. A `%%python3`
# cell would execute its body as a Python script and would pass the variable
# *names* rather than their values, so the command line is built explicitly.
freeze_cmd = [
    sys.executable, '-m', 'tensorflow.python.tools.freeze_graph',
    f'--input_graph={input_graph}',
    f'--input_checkpoint={input_checkpoint}',
    f'--output_graph={output_graph}',
    # The node list is comma-separated; drop the blanks of the ', ' separator.
    f'--output_node_names={output_node_names.replace(" ", "")}',
]
# check=True raises CalledProcessError if the tool exits with a non-zero status.
subprocess.run(freeze_cmd, check=True)

In [None]:
optimized_graph = os.path.join(checkpoint_dir, 'optimized_graph.pb')

In [None]:
print((',').join(run_op_names))

In [None]:
(', \n').join(input_names)

In [None]:
import subprocess

# Run the optimize_for_inference tool with the kernel's own interpreter; a
# `%%python3` cell cannot interpolate notebook variables into a command line.
# NOTE(review): the tool is fed `output_graph` (the frozen graph from the
# freezing step) rather than the text `input_graph`, matching the reference
# command in this notebook, which optimizes `graph_frozen.pb` -- confirm.
optimize_cmd = [
    sys.executable, '-m', 'tensorflow.python.tools.optimize_for_inference',
    f'--input={output_graph}',
    f'--output={optimized_graph}',      # the tool's flag is --output
    f'--input_names={",".join(input_names)}',
    f'--output_names={",".join(run_op_names)}',
]
# check=True raises CalledProcessError if the tool exits with a non-zero status.
subprocess.run(optimize_cmd, check=True)

In [None]:
# Reference shell command for optimizing a frozen graph. It is not Python, so
# it is kept as a comment; the file and node names look like placeholders.
#   python -m tensorflow.python.tools.optimize_for_inference --input graph_frozen.pb --output graph_optimized.pb --input_names=x --output_names=y

In [None]:
!spython3 -m tensorflow.python.tools.optimizer_for_inference --input 

## OLD

### Plot $\langle \delta_{\phi_{P}}\rangle$ vs. net_weights_arr

In [None]:
ld1 = ('../../logs/2019_08_27/2019_08_27_0457/'
       'lattice8_batch100_lf10_qw00_aw10_generic_dp00')
log_dir1 = os.path.join(*ld1.split('/'))

In [None]:
ld2 = ('../../logs/2019_08_27/2019_08_27_1314/'
       'lattice8_batch100_lf16_qw00_aw10_generic_dp00')
log_dir2 = os.path.join(*ld2.split('/'))

In [None]:
ld3 = ('../../logs/cooley_logs/2019_08_28/2019_08_28_0213/'
       'lattice8_batch128_lf16_qw00_aw10_generic_dp00')
log_dir3 = os.path.join(*ld3.split('/'))

In [None]:
ld4 = ('../../logs/2019_08_28/2019_08_28_1244/'
       'lattice8_batch100_lf16_qw00_aw10_generic_dp00')
log_dir4 = os.path.join(*ld4.split('/'))

In [None]:
fig4, ax4 = plot_plaq_diffs_vs_net_weights(log_dir4, **kwargs)

In [None]:
import pandas as pd
plaq_diff_txt_file = os.path.join(log_dir4, 'plaq_diffs_data.txt')
_data = pd.read_csv(plaq_diff_txt_file, header=None)
pdd = _data.values

In [None]:
log_dir4

In [None]:
# First and last rows are taken as the all-zero and all-one weight runs
# (they are plotted elsewhere as 'S, T, Q = 0' and 'S, T, Q = 1').
zero_data = pdd[0]
stq_data = pdd[-1]

# Rows in between are the single-weight sweeps. A new name is used instead of
# overwriting `pdd`, so re-running this cell does not trim two more rows.
pdd_sweeps = pdd[1:-1]

# Select each sweep by the column that is non-zero: 2 -> Q, 1 -> T, 0 -> S.
q_data = pdd_sweeps[pdd_sweeps[:, 2] > 0.]
t_data = pdd_sweeps[pdd_sweeps[:, 1] > 0.]
s_data = pdd_sweeps[pdd_sweeps[:, 0] > 0.]

print(q_data)

In [None]:
print(q_data[-1])

In [None]:
from plotters.plot_utils import plot_plaq_diffs_vs_net_weights
plt.close('all')
kwargs = {'ext': 'png'}
fig1, ax1 = plot_plaq_diffs_vs_net_weights(log_dir1, **kwargs)

In [None]:
fig2, ax2 = plot_plaq_diffs_vs_net_weights(log_dir2, **kwargs)

In [None]:
fig3, ax3 = plot_plaq_diffs_vs_net_weights(log_dir3, **kwargs)

In [None]:
plaq_diff_txt_file = os.path.join(log_dir, 'plaq_diffs_data.txt')
pdd = pd.read_csv(plaq_diff_txt_file, header=None)
plaq_diff_data = pdd.values

In [None]:
x = np.arange(25)

In [None]:
x

In [None]:
x[0]
x[1:9]
x[9:16]

In [None]:
# Magnitudes swept for each individual network weight.
sweep_values = [0.10, 0.25, 0.50, 0.75, 1.00, 1.50, 2.00, 5.00]
n_sweep = len(sweep_values)

zero_weights = np.zeros(3)   # set weights to 0.

# Each sweep varies a single column and leaves the other two at zero.
q_weights = np.zeros((n_sweep, 3))   # loop over Q weights (column 2)
q_weights[:, 2] = sweep_values

t_weights = np.zeros((n_sweep, 3))   # column 1
t_weights[:, 1] = sweep_values

s_weights = np.zeros((n_sweep, 3))   # column 0
s_weights[:, 0] = sweep_values

stq_weights = np.ones(3)     # all three weights set to 1.

# Stack in run order: zeros, Q sweep, T sweep, S sweep, ones.
net_weights_arr = np.vstack([zero_weights, q_weights, t_weights,
                             s_weights, stq_weights])
net_weights_arr


In [None]:
s_weights.tolist()

In [None]:
# Magnitudes swept for each individual network weight.
sweep_values = [0.10, 0.25, 0.50, 0.75, 1.00, 1.50, 2.00, 5.00]

zero_weights = [0.00, 0.00, 0.00]   # set weights to 0.

# Each sweep varies one position and keeps the other two at zero.
q_weights = [[0.00, 0.00, value] for value in sweep_values]   # loop over Q weights
t_weights = [[0.00, value, 0.00] for value in sweep_values]
s_weights = [[value, 0.00, 0.00] for value in sweep_values]

stq_weights = [1.00, 1.00, 1.00]

# Run order: zeros, Q sweep, T sweep, S sweep, ones.
net_weights_arr = ([zero_weights] + q_weights + t_weights
                   + s_weights + [stq_weights])

In [None]:
net_weights_arr = np.array(net_weights_arr)
net_weights_arr

In [None]:
# Split the stacked weights back into their groups by row position.
nw = net_weights_arr
zero, q, t, s, stq = nw[0], nw[1:9], nw[9:17], nw[17:25], nw[25]

# Print every group, separated by blank lines; each sweep is followed by its
# row count.
print(zero)
for group in (q, t, s):
    print('\n')
    print(group)
    print(len(group))
print('\n')
print(stq)


In [None]:
print(plaq_diff_data[11:16])
print('\n')
print(plaq_diff_data[16])
print('\n')
print(plaq_diff_data[17])
print('\n')
print(plaq_diff_data[17:])

In [None]:
zero_data = plaq_diff_data[0]
q_data = plaq_diff_data[1:6]
t_data = plaq_diff_data[6:11]
s_data = plaq_diff_data[11:16]
tq_data = plaq_diff_data[16]
sq_data = plaq_diff_data[17]
st_data = plaq_diff_data[18]
rand_data = plaq_diff_data[19]
stq_data = plaq_diff_data[20]

In [None]:
qx, qy = q_data[:, 2], q_data[:, -1]
tx, ty = t_data[:, 1], t_data[:, -1]
sx, sy = s_data[:, 0], s_data[:, -1]

In [None]:
xlabel = 'Net weight'
ylabel = 'Avg. plaq. difference'
fig, ax = plt.subplots()

# One curve per single-weight sweep.
ax.plot(qx, qy, marker='.', label='Transformation (Q) fn')
ax.plot(tx, ty, marker='.', label='Translation (T) fn')
ax.plot(sx, sy, marker='.', label='Scale (S) fn')

# Single reference points for the all-zero and all-one weight runs.
ax.plot(0, zero_data[-1], marker='s', label='S, T, Q = 0')
ax.plot(1, stq_data[-1], marker='v', label='S, T, Q = 1')

ax.set_xlabel(xlabel, fontsize=14)
ax.set_ylabel(ylabel, fontsize=14)
ax.legend(loc='best')
plt.tight_layout()

figs_dir = os.path.join(log_dir, 'figures')
# savefig does not create missing directories; make sure the target exists.
os.makedirs(figs_dir, exist_ok=True)
out_file = os.path.join(figs_dir, 'plaq_diff_vs_net_weights.pdf')
plt.savefig(out_file, dpi=400, bbox_inches='tight')

In [None]:
# NOTE(review): this cell repeats the preceding plotting cell line for line;
# consider keeping only one of them (or a shared plotting function).
xlabel = 'Net weight'
ylabel = 'Avg. plaq. difference'
fig, ax = plt.subplots()
# One curve per single-weight sweep.
ax.plot(qx, qy, marker='.', label='Transformation (Q) fn')
ax.plot(tx, ty, marker='.', label='Translation (T) fn')
ax.plot(sx, sy, marker='.', label='Scale (S) fn')
# Single reference points for the all-zero and all-one weight runs.
ax.plot(0, zero_data[-1], marker='s', label='S, T, Q = 0')
ax.plot(1, stq_data[-1], marker='v', label='S, T, Q = 1')
ax.set_xlabel(xlabel, fontsize=14)
ax.set_ylabel(ylabel, fontsize=14)
ax.legend(loc='best')
plt.tight_layout()
# The figure is written to `<log_dir>/figures`; savefig does not create the
# directory, so it must already exist.
figs_dir = os.path.join(log_dir, 'figures')
out_file = os.path.join(figs_dir, 'plaq_diff_vs_net_weights.pdf')
plt.savefig(out_file, dpi=400, bbox_inches='tight')

In [None]:
# Parse the comma-separated plaquette-difference file by hand: one row per run.
plaq_diff_txt_file = os.path.join(log_dir, 'plaq_diffs_data.txt')
net_weights = []
avg_plaq_diffs = []
with open(plaq_diff_txt_file, 'r') as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank or trailing empty lines
        vals = [float(i) for i in line.split(',')]
        # NOTE(review): the first three columns are taken as the (S, T, Q) net
        # weights, matching how columns 0-2 are used when this file is read
        # with pandas elsewhere in the notebook -- confirm the file layout.
        net_weights.append(vals[:3])
        # The last column is the average plaquette difference for that run.
        avg_plaq_diffs.append(vals[-1])
    

## OLDER

In [None]:
diff_dict = {}

In [None]:
transl_weights = [0., 0.1, 0.25, 0.5, 0.75, 1.]

In [None]:
transl_weights[0:2]

In [None]:
transl_weights[:3]

In [None]:
# Translation weights the offsets below are plotted against; every offsets
# list has one entry per weight.
transl_weights = [0., 0.1, 0.25, 0.5, 0.75, 1.]

# Hand-entered offsets, keyed as 'lf<N>_<run timestamp>'. The 'lf<N>' prefix
# is parsed back into an integer by the plotting cell that consumes `diff_dict`.
offsets10_17_1036 = [0.00136, -0.00323, -0.00947, -0.01574, -0.01882, -0.02056]
diff_dict['lf10_2019_08_17_1036'] = offsets10_17_1036

offsets10_19_1946 = [0.00221, -0.00221, -0.00795, -0.01353, -0.01626, -0.01760]
diff_dict['lf10_2019_08_19_1946'] = offsets10_19_1946

offsets12_19_1713 = [-0.00047, -0.00502, -0.01164, -0.01796, -0.02116, -0.02228]
diff_dict['lf12_2019_08_19_1713'] = offsets12_19_1713

offsets15_19_2332 = [0.00043, -0.00424, -0.01042, -0.01727, -0.02072, -0.02231]
diff_dict['lf15_2019_08_19_2332'] = offsets15_19_2332

offsets16_15_1915 = [-0.00180, -0.00452, -0.00811, -0.01249, -0.01485, -0.01638]
diff_dict['lf16_2019_08_15_1915'] = offsets16_15_1915

offsets16_16_1735 = [-0.00520, -0.00819, -0.01191, -0.01593, -0.01787, -0.01887]
diff_dict['lf16_2019_08_16_1735'] = offsets16_16_1735

offsets16_17_0745 = [-0.00221, -0.00472, -0.00854, -0.01300, -0.01553, -0.01703]
diff_dict['lf16_2019_08_17_0745'] = offsets16_17_0745

offsets16_20_0052 = [0.00028, -0.00276, -0.00700, -0.01213, -0.01484, -0.01620]
diff_dict['lf16_2019_08_20_0052'] = offsets16_20_0052

offsets16_20_0504 = [0.00165, -0.00614, -0.01324, -0.01775, -0.01889, -0.01986]
diff_dict['lf16_2019_08_20_0504'] = offsets16_20_0504

# NOTE(review): `mean_diff_dict` is only created (as an empty dict) in a cell
# further down, and the one statement that would fill it is commented out, so
# this lookup fails on a fresh kernel unless the key was set by hand.
offsets16_13_1546 = mean_diff_dict['2019_08_13_1546']
diff_dict['lf16_2019_08_13_1546'] = mean_diff_dict['2019_08_13_1546']

In [None]:
diff_dict

In [None]:
fig, ax = plt.subplots()
for key, val in diff_dict.items():
    lf_steps = int(key.split('_')[0].lstrip('lf'))
    if lf_steps == 16:
        ax.plot(transl_weights, val,
                label=r'$N_{\mathrm{LF}} = 16$')
        
ax.set_xlabel('Translation weight', fontsize=14)
ax.set_ylabel(r"$\langle\delta_{\phi_{P}}^{(\mathrm{obs})}\rangle$",
              #r"- \delta_{\phi_{P}}^{(\mathrm{exp})}$",
              fontsize=14)
plt.tight_layout()
ax.legend(loc='best')

In [None]:
fig, ax = plt.subplots()
ax.plot(transl_weights, offsets10, marker='s', ls=':', fillstyle='none',
        label=r'$N_{\mathrm{LF}} = 10$')
ax.plot(transl_weights, offsets12, marker='d', ls='-.', fillstyle='none',
        label=r'$N_{\mathrm{LF}} = 12$')
ax.plot(transl_weights, offsets15, marker='v', ls='--', fillstyle='none',
        label=r'$N_{\mathrm{LF}} = 15$')
ax.plot(transl_weights, offsets16, marker='.', ls='-', fillstyle='none',
        label=r'$N_{\mathrm{LF}} = 16$')
ax.plot(transl_weights, offsets16_1, marker='>', ls='-', fillstyle='none',
        label=r'$N_{\mathrm{LF}} = 16$')
ax.plot(transl_weights, offsets16_2, marker='H', ls='-', fillstyle='none',
        label=r'$N_{\mathrm{LF}} = 16$')
ax.plot(transl_weights, offsets16_3, marker='^', ls='-', fillstyle='none',
        label=r'$N_{\mathrm{LF}} = 16$')
ax.set_xlabel('Translation weight', fontsize=14)
ax.set_ylabel(r"$\langle\delta_{\phi_{P}}^{(\mathrm{obs})}\rangle$",
              #r"- \delta_{\phi_{P}}^{(\mathrm{exp})}$",
              fontsize=14)
plt.tight_layout()
ax.legend(loc='best')

of = '../../logs/figures/avg_plaq_diff_vs_num_lf2.pdf'
out_file = os.path.join(*of.split('/'))
plt.savefig(out_file, dpi=400, bbox_inches='tight')

In [None]:
transl_weights = [0., 0.1, 0.25, 0.5, 0.75, 1.]
offsets16 = [0.00028, -0.00276, -0.00700, -0.01213, -0.01484, -0.01620]
offsets16_1 = [0.00165, -0.00614, -0.01324, -0.01775, -0.01889, -0.01986]
offsets16_2 = [-0.00180, -0.00452, -0.00811, -0.01249, -0.01485, -0.01638]
offsets16_3 = [-0.00221, -0.00472, -0.00854, -0.01300, -0.01553, -0.01703]

fig, ax = plt.subplots()
ax.plot(transl_weights, offsets16, #marker='.', ls='-', fillstyle='none',
        label=r'$N_{\mathrm{LF}} = 16$')
ax.plot(transl_weights, offsets16_1, #marker='>', ls='-', fillstyle='none',
        label=r'$N_{\mathrm{LF}} = 16$')
ax.plot(transl_weights, offsets16_2, #marker='H', ls='-', fillstyle='none',
        label=r'$N_{\mathrm{LF}} = 16$')
ax.plot(transl_weights, offsets16_3, #marker='^', ls='-', fillstyle='none',
        label=r'$N_{\mathrm{LF}} = 16$')
ax.set_xlabel('Translation weight', fontsize=14)
ax.set_ylabel(r"$\langle\delta_{\phi_{P}}^{(\mathrm{obs})}\rangle$",
              #r"- \delta_{\phi_{P}}^{(\mathrm{exp})}$",
              fontsize=14)
plt.tight_layout()
ax.legend(loc='best')

of = '../../logs/figures/avg_plaq_diff_vs_num_lf3.pdf'
out_file = os.path.join(*of.split('/'))
plt.savefig(out_file, dpi=400, bbox_inches='tight')

In [None]:
from scipy.stats import sem

# Four runs (rows), each with six offsets (columns).
run_a = [0.00028, -0.00276, -0.00700, -0.01213, -0.01484, -0.01620]
run_b = [0.00165, -0.00614, -0.01324, -0.01775, -0.01889, -0.01986]
run_c = [-0.00180, -0.00452, -0.00811, -0.01249, -0.01485, -0.01638]
run_d = [-0.00221, -0.00472, -0.00854, -0.01300, -0.01553, -0.01703]
offsets16 = np.vstack([run_a, run_b, run_c, run_d])

# Column-wise mean and standard error of the mean across the four runs.
avg16 = np.mean(offsets16, axis=0)
err16 = sem(offsets16, axis=0)

In [None]:
err16.shape

In [None]:
from scipy.stats import sem


transl_weights = [0., 0.1, 0.25, 0.5, 0.75, 1.]
offsets10 = [0.00221, -0.00221, -0.00795, -0.01353, -0.01626, -0.01760]
offsets12 = [-0.00047, -0.00502, -0.01164, -0.01796, -0.02116, -0.02228]
offsets15 = [0.00043, -0.00424, -0.01042, -0.01727, -0.02072, -0.02231]

fig, ax = plt.subplots()
ax.plot(transl_weights, offsets10, #marker='s',
        ls=':', fillstyle='none', label=r'$N_{\mathrm{LF}} = 10$')
ax.plot(transl_weights, offsets12, #marker='d',
        ls='-.', fillstyle='none', label=r'$N_{\mathrm{LF}} = 12$')
ax.plot(transl_weights, offsets15, #marker='v',
        ls='--', fillstyle='none', label=r'$N_{\mathrm{LF}} = 15$')
ax.errorbar(transl_weights, avg16, yerr=err16, marker='', ls='-', #fillstyle='none',
            #ecolor='k', color='k',
            label=r'$N_{\mathrm{LF}} = 16$')
ax.set_xlabel('Translation weight', fontsize=14)
ax.set_ylabel(r"$\langle\delta_{\phi_{P}}^{(\mathrm{obs})}\rangle$",
              #r"- \delta_{\phi_{P}}^{(\mathrm{exp})}$",
              fontsize=14)
plt.tight_layout()
ax.legend(loc='best')

of = '../../logs/figures/avg_plaq_diff_vs_num_lf1.pdf'
out_file = os.path.join(*of.split('/'))
plt.savefig(out_file, dpi=400, bbox_inches='tight')

In [None]:
mean_diff_dict = {}

In [None]:
# `OrderedDict` is used below; import it here so the cell runs on a fresh
# kernel (the notebook's only other import of it sits in a later cell).
from collections import OrderedDict

# Log directories to aggregate, written POSIX-style.
log_dirs = [
    ('../../logs/cooley_logs/2019_08_13/2019_08_13_1546/'
     'lattice8_batch128_lf16_qw10_aw10_generic_dp05/'),
    ('../../logs/cooley_logs/2019_08_15/2019_08_15_1915/'
     'lattice8_batch128_lf16_qw10_aw10_generic_dp05/'),
    ('../../logs/cooley_logs/2019_08_16/2019_08_16_1735/'
     'lattice8_batch128_lf16_qw10_aw10_generic_dp05/'),
    ('../../logs/cooley_logs/2019_08_17/2019_08_17_0745/'
     'lattice8_batch128_lf16_qw10_aw10_generic_dp05/'),
    ('../../logs/cooley_logs/2019_08_17/2019_08_17_1036/'
     'lattice8_batch128_lf10_qw10_aw10_generic_dp05/'),
    ('../../logs/2019_08_19/2019_08_19_1713/'
     'lattice8_batch100_lf12_qw10_aw10_generic_dp05/'),
    ('../../logs/2019_08_19/2019_08_19_1946/'
     'lattice8_batch100_lf10_qw00_aw10_generic_dp05/'),
    ('../../logs/2019_08_19/2019_08_19_2332/'
     'lattice8_batch100_lf15_qw00_aw10_generic_dp05/'),
    ('../../logs/cooley_logs/2019_08_20/2019_08_20_0052/'
     'lattice8_batch128_lf16_qw00_aw10_generic_dp05/'),
    ('../../logs/cooley_logs/2019_08_20/2019_08_20_0504/'
     'lattice8_batch128_lf16_qw00_aw00_generic_dp05/'),
]

# Maps 'lf<num_steps>_<timestamp dir>' -> {'data': ..., 'log_dir': ...}.
plaq_diff_dict = {}

for ld in log_dirs:
    log_dir = os.path.join(*ld.split('/'))
    # Only unpickle files you trust: pickle.load can execute arbitrary code.
    params_file = os.path.join(log_dir, 'parameters.pkl')
    with open(params_file, 'rb') as f:
        params = pickle.load(f)

    runs_dirs = os.path.join(log_dir, 'runs')
    figs_dir = os.path.join(log_dir, 'figures')
    # Every entry of `<log_dir>/runs` is treated as one inference run; the
    # entry name doubles as the run string.
    run_strs = os.listdir(runs_dirs)
    run_dirs = [os.path.join(runs_dirs, name) for name in run_strs]

    plotter = GaugeModelPlotter(params, figs_dir)

    mean_diff_arr = []
    transl_weights = []
    for d, s in zip(run_dirs, run_strs):
        run_data_file = os.path.join(d, 'run_data.pkl')
        run_params_file = os.path.join(d, 'parameters.pkl')
        with open(run_data_file, 'rb') as f:
            run_data = pickle.load(f)
        with open(run_params_file, 'rb') as f:
            run_params = pickle.load(f)

        weights = {
            'charge_weight': run_params['charge_weight'],
            'net_weights': run_params['net_weights']
        }
        # NOTE: this rebinds the notebook-level `kwargs`.
        xy_data, kwargs = plotter._plot_setup(run_data, 5., s, weights)
        x, y, yerr = xy_data['plaqs_diffs']
        mean_diff = np.mean(y)
        # Index 1 of `net_weights` is used as the translation weight.
        transl_weights.append(weights['net_weights'][1])
        mean_diff_arr.append(mean_diff)

    # Mean difference keyed by translation weight, ordered by weight. Runs
    # that share a weight overwrite each other (last one listed wins).
    wd_dict = dict(zip(transl_weights, mean_diff_arr))
    weights_diffs_dict = OrderedDict(sorted(wd_dict.items(),
                                            key=lambda k: k[0]))

    # The trailing '/' in `ld` survives os.path.join as a trailing separator;
    # drop it so the second-to-last component is the timestamp directory.
    log_dir = log_dir.rstrip('/')

    dict_str = log_dir.split('/')[-2]
    # `run_params` is whatever the last run of this log directory loaded.
    lf_steps = run_params['num_steps']
    plaq_diff_dict_key = f'lf{lf_steps}_' + dict_str
    # create key, value pair containing path to log dir
    plaq_diff_dict[plaq_diff_dict_key] = {
        'data': weights_diffs_dict,
        'log_dir': log_dir
    }

# Same entries, ordered by the number of leapfrog steps encoded in the key.
pd_dict = OrderedDict(sorted(plaq_diff_dict.items(),
                             key=lambda k: int(k[0].split('_')[0][len('lf'):])))

In [None]:
lf16_yarr

In [None]:
#lf16_xarr = []
lf10_yarr = []
lf12_yarr = []
lf15_yarr = []
lf16_yarr = []
fig, ax = plt.subplots()
for key, val in pd_dict.items():
    lf_steps = int(key.split('_')[0].lstrip('lf'))
    xy_data = val['data']
    x = list(xy_data.keys())
    y = list(xy_data.values())
    if lf_steps == 10:
        lf10_yarr.append(y)
    if lf_steps == 12:
        lf12_yarr.append(y)
    if lf_steps == 15:
        lf15_yarr.append(y)
    if lf_steps == 16:
        lf16_yarr.append(y)
        
# N_lf = 10
lf10_yavg = np.mean(np.array(lf10_yarr), axis=0)
lf10_yerr = sem(np.array(lf10_yarr), axis=0)
ax.errorbar(x, lf10_yavg, yerr=lf10_yerr,  capthick=2., capsize=2., 
            ls=':', label=r'$N_{\mathrm{LF}} = $' + '10')

# N_lf = 12
ax.plot(x, lf12_yarr[0], ls='-.', label=r'$N_{\mathrm{LF}} = $' + '12')

# N_lf = 15
ax.plot(x, lf15_yarr[0], ls='--', label=r'$N_{\mathrm{LF}} = $' + '15')
    
# N_lf = 16
lf16_yavg = np.mean(np.array(lf16_yarr), axis=0)
lf16_yerr = sem(np.array(lf16_yarr), axis=0)
ax.errorbar(x, lf16_yavg, yerr=lf16_yerr, capthick=1.5, capsize=1.5,
            label=r'$N_{\mathrm{LF}} = $' + '16')
        
xlabel = 'Translation weight'
ylabel = r"$\langle\delta_{\phi_{P}}^{(\mathrm{obs})}\rangle$"
ax.grid(True)
ax.set_xlabel(xlabel, fontsize=14)
ax.set_ylabel(ylabel, fontsize=14)
plt.tight_layout()
ax.legend(loc='best')

of = '../../logs/figures/avg_plaq_diff_vs_transl_weight_errs.pdf'
#of = '../../logs/figures/avg_plaq_diff_vs_transl_weight_lf10_12_15.pdf'
#of = '../../logs/figures/avg_plaq_diff_vs_transl_weight_lf16.pdf'
out_file = os.path.join(*of.split('/'))
plt.savefig(out_file, dpi=400, bbox_inches='tight')
        
#ax.set_xlabel('Translation weight', fontsize=14)
#ax.set_ylabel(r"$\langle\delta_{\phi_{P}}^{(\mathrm{obs})}\rangle$",
#              fontsize=14)
#plt.tight_layout()
#ax.legend(loc='best')

In [None]:
debug

In [None]:
ld = ('../../logs/cooley_logs/2019_08_13/2019_08_13_1546/'
      'lattice8_batch128_lf16_qw10_aw10_generic_dp05')
log_dir = os.path.join(*ld.split('/'))

params_file = os.path.join(log_dir, 'parameters.pkl')
with open(params_file, 'rb') as f:
    params = pickle.load(f)
    
runs_dirs = os.path.join(log_dir, 'runs')
figs_dir = os.path.join(log_dir, 'figures')
run_dirs = [os.path.join(runs_dirs, i) for i in os.listdir(runs_dirs)]
run_strs = [i.split('/')[-1] for i in run_dirs]
run_strs
run_dirs

In [None]:
run_params_files = [os.path.join(i, 'parameters.pkl') for i in run_dirs]
run_data_files = [os.path.join(i, 'run_data.pkl') for i in run_dirs]

In [None]:
with open(run_data_files[0], 'rb') as f:
    run_data = pickle.load(f)
with open(run_params_files[0], 'rb') as f:
    run_params = pickle.load(f)
run_params['net_weights']

In [None]:
plotter = GaugeModelPlotter(run_params, figs_dir)

In [None]:
dict_str = log_dir.split('/')[-2]
lf_steps = run_params['num_steps']
dict_key = f'lf{lf_steps}_' + dict_str

In [None]:
# `OrderedDict` is used below; import it here so the cell runs on a fresh
# kernel (the notebook's only other import of it sits in a later cell).
from collections import OrderedDict

mean_diff_arr = []
transl_weights = []
for d, s in zip(run_dirs, run_strs):
    # Only unpickle files you trust: pickle.load can execute arbitrary code.
    run_data_file = os.path.join(d, 'run_data.pkl')
    run_params_file = os.path.join(d, 'parameters.pkl')
    with open(run_data_file, 'rb') as f:
        run_data = pickle.load(f)
    with open(run_params_file, 'rb') as f:
        run_params = pickle.load(f)

    weights = {
        'charge_weight': run_params['charge_weight'],
        'net_weights': run_params['net_weights']
    }
    # NOTE: this rebinds the notebook-level `kwargs`.
    xy_data, kwargs = plotter._plot_setup(run_data, 5., s, weights)
    mean_diff = plotter._plot_plaqs_diffs(xy_data['plaqs_diffs'], **kwargs)
    # Index 1 of `net_weights` is used as the translation weight.
    transl_weights.append(weights['net_weights'][1])
    mean_diff_arr.append(mean_diff)

# Mean difference keyed by translation weight, ordered by weight. The sort
# must read `wd_dict` (just built); `weights_diffs_dict` does not exist yet.
wd_dict = dict(zip(transl_weights, mean_diff_arr))
weights_diffs_dict = OrderedDict(sorted(wd_dict.items(),
                                        key=lambda k: k[0]))

# Key the result as 'lf<num_steps>_<timestamp dir>'.
dict_str = log_dir.split('/')[-2]
# `run_params` is whatever the last run in the loop loaded.
lf_steps = run_params['num_steps']
dict_key = f'lf{lf_steps}_' + dict_str
# NOTE(review): the stored list is in directory-listing order, not sorted by
# translation weight -- confirm that is what the consumers of `diff_dict` expect.
diff_dict[dict_key] = mean_diff_arr

In [None]:
from collections import OrderedDict
#OrderedDict(sorted(d.items(), key=lambda t: t[1]))
weights_diffs_dict =  dict(zip(transl_weights, mean_diff_arr))
weights_diffs_dict = OrderedDict(sorted(weights_diffs_dict.items(),
                                        key=lambda k: k[0]))
weights_diffs_dict

In [None]:
transl_weights
mean_diff_arr

In [None]:
plotter.plot_observables(run_data, 5., run_strs[0], weights)

In [None]:
xy_data = plotter._parse_data(run_data, 5.)

In [None]:
kwargs = {
    'markers': False,
    'lines': True,
    'alpha': 0.6,
    'legend': False,
    'ret': False,
    'out_file': []
}

In [None]:
plotter._plot_plaqs_diffs(xy_data['plaqs_diffs'], **kwargs)

In [None]:
'../args/params.pkl'
params_file = os.path.join('..', 'args', 'params.pkl')
with open(params_file, 'rb') as f:
    params = pickle.load(f)

In [None]:
from inference import create_config

checkpoint_dir = os.path.join(params['log_dir'], 'checkpoints/')
checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)

config, params = create_config(params)
sess = tf.Session(config=config)
saver = tf.train.import_meta_graph(f'{checkpoint_file}.meta')
saver.restore(sess, checkpoint_file)

In [None]:
autocorrs = np.array(run_data['charges_autocorrs'])
autocorrs_avg = np.mean(autocorrs, axis=0)
num_steps = autocorrs.shape[1]
mid = num_steps // 2
lower = int(mid - num_steps * 0.1)
upper = int(mid + num_steps * 0.1)

fig, (ax0, ax1) = plt.subplots(nrows=2, ncols=1, sharex=False)
ax0.plot(np.arange(len(autocorrs_avg)), autocorrs_avg)
ax1.plot(np.arange(lower, upper), autocorrs_avg[lower:upper])

In [None]:
from plotters.gauge_model_plotter import GaugeModelPlotter
import utils.file_io as io

figs_dir = os.path.join(log_dir, 'figures')
io.check_else_make_dir(figs_dir)
run_str = 'steps_10000_beta_50_eps_016_qw_10_0'

weights = {
    'charge_weight': params['charge_weight'],
    'net_weights': [1., 1., 1.]
}

plotter = GaugeModelPlotter(params, figs_dir)
#plotter.plot_observables(run_data, beta=6., run_str=run_str, weights=weights)

In [None]:
xy_data = plotter._parse_data(run_data, beta=6.)

In [None]:
plotter._plot_plaqs(xy_data['plaqs'], beta=6., save=False)

In [None]:
plotter.plot_observables(run_data, beta=6., run_str=run_str, weights=weights)

In [None]:
plt.errorbar()

In [None]:
from utils.attr_dict import AttrDict
#ld = ('../../logs/2019_07_23/2019_07_23_1349/'
#      'lattice8_batch128_lf6_qw10_aw10_conv2D_dp00_bn/')
#log_dir = os.path.join(*ld.split('/'))
#params_file = os.path.join(log_dir, 'parameters.pkl')
pf = '../args/params.pkl'
params_file = os.path.join(*pf.split('/'))
#params_file = os.path.join('..', '', 'params.pkl')
with open(params_file, 'rb') as f:
    params = pickle.load(f)
params    
params['data_format'] = 'channels_last'

#for key, val in params.items():
#    FLAGS.__dict__[key] = val
    
#FLAGS = AttrDict(params.items())

In [None]:
import utils.file_io as io

# Clear any path carried over from the pickled parameters before creating a
# new log directory.
params['log_dir'] = None

# Create the log directory once and store that same path in `params`, so
# `log_dir` and `params['log_dir']` cannot point at different directories.
log_dir = io.create_log_dir(params)
params['log_dir'] = log_dir

In [None]:
checkpoint_dir = os.path.join(params['log_dir'], 'checkpoints/')
io.check_else_make_dir(checkpoint_dir)

model = GaugeModel(params=params)

In [None]:
from variables import TF_FLOAT, NP_FLOAT, GLOBAL_SEED
from loggers.train_logger import TrainLogger
from inference import create_config

# `summaries` is passed by keyword. The setup cell at the top of this notebook
# calls TrainLogger(model, log_dir, logging_steps=..., summaries=...), so a
# third positional argument would presumably bind to `logging_steps`.
train_logger = TrainLogger(model, log_dir, summaries=params['summaries'])
config, params = create_config(params)

# Initial values fed to the model's placeholders.
charge_weight_init = params['charge_weight']
net_weights_init = [1., 1., 1.]
# Flatten the lattice samples to (num_samples, x_dim).
samples_init = np.reshape(np.array(model.lattice.samples, dtype=NP_FLOAT),
                          (model.num_samples, model.x_dim))
beta_init = model.beta_init

# Index meaning follows the rest of the notebook, where position 0 is the
# scale (S) weight, 1 the translation (T) weight and 2 the transformation (Q)
# weight.
init_feed_dict = {
    model.x: samples_init,
    model.beta: beta_init,
    model.charge_weight: charge_weight_init,
    model.net_weights[0]: net_weights_init[0],  # scale_weight
    model.net_weights[1]: net_weights_init[1],  # translation_weight
    model.net_weights[2]: net_weights_init[2],  # transformation_weight
    model.train_phase: True,
}

# Local variables (plus any extra targets) initialised by the scaffold.
target_collection = []
collection = tf.local_variables() + target_collection

local_init_op = tf.variables_initializer(collection)
ready_for_local_init_op = tf.report_uninitialized_variables(collection)
init_op = tf.global_variables_initializer()

scaffold = tf.train.Scaffold(
    init_feed_dict=init_feed_dict,
    local_init_op=local_init_op,
    ready_for_local_init_op=ready_for_local_init_op
)

In [None]:
# ----------------------------------------------------------
#                       TRAINING
# ----------------------------------------------------------
hooks = []
sess_kwargs = {
    'checkpoint_dir': checkpoint_dir,
    'scaffold': scaffold,
    'hooks': hooks,
    'config': config,
    'save_summaries_secs': None,
    'save_summaries_steps': None
}

sess = tf.train.MonitoredTrainingSession(**sess_kwargs)
tf.keras.backend.set_session(sess)
sess.run(init_op)

trainer = GaugeModelTrainer(sess, model, train_logger)
train_kwargs = {
    'samples_np': samples_init,
    'beta_np': beta_init,
    'net_weights': net_weights_init
}

trainer.train(20, **train_kwargs)

sess.close()
tf.reset_default_graph()

In [None]:
from inference import create_config
from loggers.run_logger import RunLogger
from plotters.gauge_model_plotter import GaugeModelPlotter

checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
config, params = create_config(params)

sess = tf.Session(config=config)

saver = tf.train.import_meta_graph(f'{checkpoint_file}.meta')
saver.restore(sess, checkpoint_file)

run_ops = tf.get_collection('run_ops')
inputs = tf.get_collection('inputs')

run_logger  = RunLogger(params, inputs, run_ops, save_lf_data=False)
plotter = GaugeModelPlotter(params, run_logger.figs_dir)

In [None]:
from inference import inference_setup, run_inference
from runners.runner import GaugeModelRunner

# Settings for this inference run; merged into the shared `params`.
kwargs = dict(
    run_steps=100,
    loop_net_weights=True,
    plot_lf=True,
)
params.update(kwargs)

inference_dict = inference_setup(params)

# Run inference with the restored session and the collections fetched earlier.
runner = GaugeModelRunner(sess, params, inputs, run_ops, run_logger)
run_inference(inference_dict, runner, run_logger, plotter)

In [None]:
# Display the runner's `eps` attribute (presumably the step size -- confirm).
runner.eps

In [None]:
# Display the global variables currently registered in the default graph.
tf.global_variables()

In [None]:
# Display whatever is stored in the graph's UPDATE_OPS collection.
tf.get_collection(tf.GraphKeys.UPDATE_OPS)

In [None]:
# Build a Saver over the current graph's variables and restore the most recent
# checkpoint from `checkpoint_dir` into `sess`.
# NOTE(review): `tf.train.latest_checkpoint` returns None when no checkpoint
# exists, which would make `restore` fail -- confirm a checkpoint is present.
saver = tf.train.Saver()
saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))

In [None]:
# Open a fresh session, restore the latest checkpoint under
# `<log_dir>/checkpoints`, then create the run logger and plotter.
# NOTE(review): the previous `sess` is rebound here without being closed.
checkpoint_dir = os.path.join(params['log_dir'], 'checkpoints')
sess = tf.Session(config=config)
saver = tf.train.Saver()
saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))
# NOTE(review): an earlier cell calls RunLogger(params, inputs, run_ops, ...)
# and GaugeModelPlotter(params, ...); the arguments here differ -- confirm
# which signature is current.
run_logger = RunLogger(model, params['log_dir'], save_lf_data=False)
plotter = GaugeModelPlotter(model, run_logger.figs_dir)

In [None]:
# Unpack the layers of `x_fn`; this requires it to expose exactly three.
xnet_x, xnet_v, generic_net = model.dynamics.x_fn.layers

In [None]:
# First layer of the `xnet_x` block.
xconv_x1 = xnet_x.layers[0]

In [None]:
# Replace the two weight arrays of `xconv_x1` with standard-normal samples of
# matching shapes.
k, b = xconv_x1.get_weights()
k_rand, b_rand = (np.random.randn(*w.shape) for w in (k, b))

xconv_x1.set_weights([k_rand, b_rand])

In [None]:
# Bare attribute reference: only displays the bound method, it does not call it.
xconv_x1.set_weights

In [None]:
# Overwrite the two weight arrays of every paired layer in `x_fn` and `v_fn`
# with standard-normal samples of matching shapes.
xnet = model.dynamics.x_fn
vnet = model.dynamics.v_fn

for xblock, vblock in zip(xnet.layers, vnet.layers):
    for xlayer, vlayer in zip(xblock.layers, vblock.layers):
        try:
            print(f'xlayer.name: {xlayer.name}')
            print(f'vlayer.name: {vlayer.name}')
            # Unpacking raises ValueError for layers that do not hold exactly
            # two weight arrays; those pairs are reported and skipped.
            kx, bx = xlayer.get_weights()
            kv, bv = vlayer.get_weights()
            kx_rand = np.random.randn(*kx.shape)
            bx_rand = np.random.randn(*bx.shape)
            kv_rand = np.random.randn(*kv.shape)
            bv_rand = np.random.randn(*bv.shape)

            # Assign the random arrays; writing back the values just read
            # would leave the layers unchanged.
            xlayer.set_weights([kx_rand, bx_rand])
            vlayer.set_weights([kv_rand, bv_rand])
        except ValueError:
            print(f'Unable to set weights for: {xlayer.name}')
            print(f'Unable to set weights for: {vlayer.name}')
        

In [None]:
# Read back the two weight arrays currently held by `xconv_x1`.
xconv_x1w, xconv_x1b = xconv_x1.get_weights()

In [None]:
# Standard-normal array with the same shape as the lattice samples.
random_sample = np.random.randn(*model.lattice.samples.shape)

In [None]:
# Draw a float32 standard-normal batch shaped like the lattice samples, view it
# as (-1, 8, 8, 2) and pass it through the first x-conv layer.
noise = np.random.randn(*model.lattice.samples.shape)
random_sample = np.array(noise, dtype=np.float32)
samples_reshaped = tf.reshape(
    tf.convert_to_tensor(random_sample),
    shape=(-1, 8, 8, 2),
)
conv1_layer = model.dynamics.x_fn.x_conv_net.conv1
xconv_out1 = conv1_layer(samples_reshaped)

In [None]:
# Display the static shape of the conv1 output tensor.
xconv_out1.shape

In [None]:
# Evaluate the conv1 output for the first sample once, then draw one figure per
# feature map. A single `sess.run` avoids re-running the graph (and adding a
# new slice op to it) on every loop iteration.
feature_maps = sess.run(xconv_out1[0])
for f in range(feature_maps.shape[-1]):
    fig, ax = plt.subplots()
    ax.imshow(feature_maps[:, :, f])

In [None]:
# Fetch the two weight arrays of the conv1 layer as NumPy arrays.
xconv1w, xconv1b = model.dynamics.x_fn.x_conv_net.conv1.get_weights()

In [None]:
import matplotlib.pyplot as plt


In [None]:
# One figure per (axis-2 index, last-axis index) slice of the conv1 weights,
# visited in the same row-major order as two nested loops.
for c, f in np.ndindex(xconv1w.shape[2], xconv1w.shape[-1]):
    fig, ax = plt.subplots()
    kernel_slice = xconv1w[:, :, c, f]
    ax.imshow(kernel_slice)

In [None]:
# Display the output tensor recorded at index 0 for the conv1 layer.
model.dynamics.x_fn.x_conv_net.conv1.get_output_at(0)

In [None]:
# Evaluate the conv1 output for the initial feed values. `sess.run` needs a
# tensor/op to fetch, so the layer's recorded output tensor is passed rather
# than the layer object itself.
sess.run(model.dynamics.x_fn.x_conv_net.conv1.get_output_at(0),
         feed_dict=init_feed_dict)

In [None]:
# Recreate the training logger for `model`.
# NOTE(review): `FLAGS` is not defined in the visible cells, and the first
# cells of the notebook pass `logging_steps=` / `summaries=` by keyword --
# confirm the third positional argument is the one intended here.
train_logger = TrainLogger(model, model.log_dir, FLAGS.summaries)

In [None]:
# NOTE(review): earlier cells import NP_FLOAT from `config`; confirm that a
# `globals` module still exists in the project.
from globals import TF_FLOAT, NP_FLOAT, GLOBAL_SEED

# Initial values fed to the model's placeholders.
# NOTE(review): `FLAGS` is not defined in the visible cells.
charge_weight_init = FLAGS.charge_weight
net_weights_init = [1., 1., 1.]
# Lattice samples flattened to (num_samples, x_dim).
samples_init = np.reshape(np.array(model.lattice.samples, dtype=NP_FLOAT),
                          (model.num_samples, model.x_dim))
beta_init = model.beta_init

# Maps each model input to its initial value.
init_feed_dict = {
    model.x: samples_init,
    model.beta: beta_init,
    model.charge_weight: charge_weight_init,
    model.net_weights[0]: net_weights_init[0],  # scale_weight
    model.net_weights[1]: net_weights_init[1],  # transformation_weight
    model.net_weights[2]: net_weights_init[2],  # translation_weight
    model.train_phase: True,
}

In [None]:
# Index the layers of `x_fn` by (position, name), and the layers of each of its
# three blocks by name.
x_fn = model.dynamics.x_fn

x_layers_dict = dict(
    ((position, layer.name), layer)
    for position, layer in enumerate(x_fn.layers)
)

xx_conv_layers = dict(
    (layer.name, layer) for layer in x_fn.x_conv_block.layers
)
xv_conv_layers = dict(
    (layer.name, layer) for layer in x_fn.v_conv_block.layers
)
x_generic_layers = dict(
    (layer.name, layer) for layer in x_fn.generic_block.layers
)
# Only the last bare expression is rendered as the cell output.
xx_conv_layers
xv_conv_layers
x_generic_layers

In [None]:
# Reshape the input placeholder to the per-sample lattice shape with a trailing
# axis of size 1, then ask the 'conv1' layer what output shape it would produce.
q = tf.reshape(model.x, (-1, *model.lattice.samples.shape[1:], 1))
xx_conv_layers['conv1'].compute_output_shape(q.shape)

In [None]:
# New placeholder with the same shape as `model.x`; it is reshaped to the
# network's `_input_shape` (minus its leading axis) when the two differ.
q = tf.placeholder(dtype=TF_FLOAT, shape=model.x.shape, name='q')
if model.x.shape != model.dynamics.x_fn._input_shape[1:]:
    q = tf.reshape(q, (-1, *model.dynamics.x_fn._input_shape[1:]))
# Display the resulting static shape.
q.shape

In [None]:
# Split the layers of `x_fn` by name: those containing 'conv_x' and those
# containing 'conv_v'. The two tests are independent, so a layer whose name
# contains both substrings lands in both lists.
# The unfinished trailing `for` statement was dropped: it was a syntax error
# that prevented the whole cell from running.
xnet_xlayers = [
    layer for layer in model.dynamics.x_fn.layers if 'conv_x' in layer.name
]
xnet_vlayers = [
    layer for layer in model.dynamics.x_fn.layers if 'conv_v' in layer.name
]

In [None]:
# All layers exposed by `x_fn` (not only x-specific ones, despite the name).
x_net_x_layers = model.dynamics.x_fn.layers

In [None]:
# Variables for the local-init step: all TF local variables plus any extras
# listed in `target_collection` (currently empty).
target_collection = []
collection = tf.local_variables() + target_collection

In [None]:
# Keras callback list holding a single ModelCheckpoint that targets
# './model.h5'; it is not passed to anything in the cells visible here.
callbacks = [
    tf.keras.callbacks.ModelCheckpoint('./model.h5', verbose=1)
]

In [None]:
# Op that initialises the variables in `collection`, and the op that reports
# which of them are still uninitialised.
local_init_op = tf.variables_initializer(collection)
ready_for_local_init_op = tf.report_uninitialized_variables(collection)

In [None]:
import utils.file_io as io

# Make sure the checkpoint directory exists under the model's log directory.
checkpoint_dir = os.path.join(model.log_dir, 'checkpoints')
io.check_else_make_dir(checkpoint_dir)

# Scaffold carrying the initial feed values and the local-init ops.
scaffold = tf.train.Scaffold(
    init_feed_dict=init_feed_dict,
    local_init_op=local_init_op,
    ready_for_local_init_op=ready_for_local_init_op,
)

# Keyword arguments for the monitored session: no extra hooks, and both
# summary-saving intervals set to None.
sess_kwargs = dict(
    checkpoint_dir=checkpoint_dir,
    scaffold=scaffold,
    hooks=[],
    config=config,
    save_summaries_secs=None,
    save_summaries_steps=None,
)

sess = tf.train.MonitoredTrainingSession(**sess_kwargs)

In [None]:
# Trainer bound to the current session, model and training logger.
trainer = GaugeModelTrainer(sess, model, train_logger)

In [None]:
# Set the Keras backend learning phase to True.
tf.keras.backend.set_learning_phase(True)

In [None]:
# Display the current Keras backend learning phase.
tf.keras.backend.learning_phase()

In [None]:
# Random-normal tensor with the same dynamic shape as `model.x`, seeded with
# GLOBAL_SEED.
v_rand = tf.random_normal(tf.shape(model.x), seed=GLOBAL_SEED)
# Time input for step 0, tiled to the batch dimension of `model.x`.
t = model.dynamics._get_time(0, tile=tf.shape(model.x)[0])

# Mask pair for step 0.
mask, mask_inv = model.dynamics._get_mask(0)

# Call `x_fn` directly on (v_rand, masked x, t) and unpack its three outputs.
x_scale, x_translation, x_transformation = model.dynamics.x_fn(
    (v_rand, mask * model.x, t), model.train_phase
)

# Full dynamics call on the model's own input placeholders.
dynamics_out = model.dynamics(model.x, model.beta, model.net_weights, model.train_phase)

In [None]:
# Feed values for the model inputs, inserted in the same key order as a single
# dict literal would give.
feed_dict = {
    model.x: samples_init,
    model.beta: beta_init,
    model.charge_weight: charge_weight_init,
}
# Indices 0, 1, 2: scale_weight, transformation_weight, translation_weight.
for weight_idx in range(3):
    feed_dict[model.net_weights[weight_idx]] = net_weights_init[weight_idx]
feed_dict[model.train_phase] = True

In [None]:
# Full dynamics call on the model's own input placeholders.
dynamics_out = model.dynamics(model.x, model.beta, 
                              model.net_weights, model.train_phase)

In [None]:
from network.network_utils import batch_norm

# Manually replay the `x_fn` forward pass one layer at a time so that every
# intermediate tensor can be inspected. `v_rand` and `t` must already exist
# (they are created in an earlier cell).
x_reshaped = model.dynamics.x_fn.reshape_5D(model.x)
v_reshaped = model.dynamics.v_fn.reshape_5D(v_rand)

# x branch: conv -> pool -> conv -> batch norm -> relu -> pool
conv_x1 = model.dynamics.x_fn.conv_x1(x_reshaped)
max_pool_x1 = model.dynamics.x_fn.max_pool_x1(conv_x1)
conv_x2 = model.dynamics.x_fn.conv_x2(max_pool_x1)
bn_x = batch_norm(conv_x2, model.train_phase,
                  axis=model.dynamics.x_fn.bn_axis,
                  internal_update=True)
max_pool_x2 = model.dynamics.x_fn.max_pool_x2(tf.nn.relu(bn_x))

# v branch: same sequence using the v-specific layers
conv_v1 = model.dynamics.x_fn.conv_v1(v_reshaped)
max_pool_v1 = model.dynamics.x_fn.max_pool_v1(conv_v1)
conv_v2 = model.dynamics.x_fn.conv_v2(max_pool_v1)
bn_v = batch_norm(conv_v2, model.train_phase,
                  axis=model.dynamics.x_fn.bn_axis,
                  internal_update=True)
max_pool_v2 = model.dynamics.x_fn.max_pool_v2(tf.nn.relu(bn_v))

x_flat = model.dynamics.x_fn.flatten(max_pool_x2)
v_flat = model.dynamics.x_fn.flatten(max_pool_v2)

x_out = tf.nn.relu(model.dynamics.x_fn.x_layer(x_flat))
v_out = tf.nn.relu(model.dynamics.x_fn.v_layer(v_flat))
t_out = tf.nn.relu(model.dynamics.x_fn.t_layer(t))

# Combine the three branches, then apply the hidden layer. The hidden layer
# consumes `h1` and the output heads consume `h2`; the previous code referred
# to an undefined name `h` in all four places.
h1 = tf.nn.relu(x_out + v_out + t_out)
h2 = tf.nn.relu(model.dynamics.x_fn.h_layer(h1))

translation = model.dynamics.x_fn.translation_layer(h2)
scale = (tf.nn.tanh(model.dynamics.x_fn.scale_layer(h2))
         * tf.exp(model.dynamics.x_fn.coeff_scale))

transformation = (model.dynamics.x_fn.transformation_layer(h2)
                  * tf.exp(model.dynamics.x_fn.coeff_transformation))

# Intermediate tensors whose shapes are printed below. The former 'conv_x12'
# entry pointed at a name that is never defined and has been removed.
layers = {
    'conv_x1': conv_x1,
    'max_pool_x1': max_pool_x1,
    'conv_x2': conv_x2,
    'bn_x': bn_x,
    'max_pool_x2': max_pool_x2,
    'x_flat': x_flat,
    'x_out': x_out,
    't_out': t_out,
    'h1': h1,
    'h2': h2,
}

print(f'model.lattice.samples.shape: {model.lattice.samples.shape}\n')
print(f'model.x.shape: {model.x.shape}\n')
print(f'x_reshaped.shape: {x_reshaped.shape}\n')
for name, layer in layers.items():
    print(f'{name}: {layer.shape}\n')