### Imports

In [50]:
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
from mog_model import *
from matplotlib import gridspec
from utils.plot_helper import errorbar_plot
plt.style.use('/Users/saforem2/.config/matplotlib/stylelib/dark_jupyter.mplstyle')

plt.rc('text', usetex=False)
plt.rcParams['errorbar.capsize'] = 0
%matplotlib notebook
%autoreload 2

run_dir = '../log_mog_tf/run_251/'
data_dir = run_dir + 'run_info/'
figs_dir = run_dir + 'figures1/'
if not os.path.exists(figs_dir):
    os.makedirs(figs_dir)

#### Helper functions

In [51]:
def pkl_loader(_dir, _file):
    with open(_dir + _file, 'rb') as f:
        _data = pickle.load(f)
    return _data

def load_data(data_dir):
    _d = {}
    data = {}
    for file in os.listdir(data_dir):
        if file.endswith('.pkl'):
            key = file[:-4]
            _d[key] = pkl_loader(data_dir, file)
        for key, val in _d.items():
            data[key] = val
    return data

In [52]:
def fix_legends(axes):
    if isinstance(axes, (np.ndarray, list)):
        legends = [ax.get_legend() for ax in axes]
        for leg in legends:
            leg.texts[0].set_color('w')
    else:
        legend = axes.get_legend()
        for idx in range(len(legend.texts)):
            legend.texts[idx].set_color('w')
    return axes

In [53]:
def add_vline(axes, x, **kwargs):
    if isinstance(axes, (np.ndarray, list)):
        for ax in axes:
            ax.axvline(x, **kwargs)
                       #, color='C3', ls=':', lw=2.)
    else:
        axes.axvline(x, **kwargs)
    return axes

In [54]:
def fix_ticks(axes):
    if isinstance(axes, (np.ndarray, list)):
        for idx in range(len(axes)):
            axes[idx].tick_params(which='both', color='#474747', labelcolor='k')
    else:
        axes.tick_params(which='both', color='#474747', labelcolor='k')

### Define data_dir and load data

In [55]:
data = load_data(data_dir)

get_vals_as_arr = lambda _dict: np.array(list(_dict.values()))

tr0 = get_vals_as_arr(data['tunneling_rates'])
tr1 = get_vals_as_arr(data['tunneling_rates_highT'])
ar0 = get_vals_as_arr(data['acceptance_rates'])
ar1 = get_vals_as_arr(data['acceptance_rates_highT'])
d0 = get_vals_as_arr(data['distances'])
d1 = get_vals_as_arr(data['distances_highT'])

steps_arr = []
temp_arr = []
for key in data['tunneling_rates_highT'].keys():
    steps_arr.append(key[0]+1)
    temp_arr.append(key[1])

In [56]:
x_steps = 3 * [steps_arr]
x_temps = 3 * [temp_arr]

y_data = [tr0[:, 0], ar0[:, 0], d0[:, 0]]
y_err = [tr0[:, 1], ar1[:, 1], d1[:, 1]]

y_data_highT = [tr1[:, 0], ar1[:, 0], d1[:, 0]]

y_err_highT = [tr1[:, 1], ar1[:, 1], d1[:, 1]]

str0 = (f"{data['_params']['num_distributions']} in {data['_params']['x_dim']} dims; ")
str1 = (r'$\mathcal{N}_{\hat \mu}(\sqrt{2}\hat \mu;$'
        + r'${{{0}}}),$'.format(data['_params']['sigma']))
title = str0 + str1 + r'$T_{trajectory} = 1$'
title_highT = str0 + str1 + r'$T_{trajectory} > 1$'
def out_file(f): return figs_dir + f'{f}.pdf'

kwargs = {
    'fillstyle': 'full',
    'markersize': 3,
    'alpha': 1.,
    'capsize': 0,
    'capthick': 0,
    'x_label': 'Training step',
    'y_label': '',
    'legend_labels': ['Tunneling rate',
                      'Acceptance rate',
                      'Distance / step'],
    'title': title,
    'grid': True,
    'reverse_x': False,
}

### $(T = 1)$ Tunneling rate, Acceptance Rate and Avg. Distance vs. Training Step 

In [57]:
%matplotlib notebook

In [58]:
out_file0 = out_file('tr_ar_dist_steps_lowT')#, step)
fig, axes = errorbar_plot(x_steps, y_data, y_err, out_file=out_file0, **kwargs)
_ = fix_legends(axes)
_ = fix_ticks(axes)
sfig = fig.savefig(out_file0, dpi=400, bbox_inches='tight')

<IPython.core.display.Javascript object>

Saving figure to: ../log_mog_tf/run_251/figures1/tr_ar_dist_steps_lowT.pdf


### $(T > 1)$ Tunneling rate, Acceptance Rate and Avg. Distance vs. Training Step 

In [59]:
# for trajectories with temperature > 1 vs. STEP
out_file1 = out_file('tr_ar_dist_steps_highT')#, step)
kwargs1 = kwargs.copy()
kwargs1['title'] = title_highT
fig, axes = errorbar_plot(x_steps, y_data_highT, y_err_highT,
              out_file=out_file1, **kwargs1)
axes = fix_legends(axes)
_ = fix_ticks(axes)
fig.savefig(out_file1, dpi=400, bbox_inches='tight')

<IPython.core.display.Javascript object>

Saving figure to: ../log_mog_tf/run_251/figures1/tr_ar_dist_steps_highT.pdf


### $(T = 1)$ Tunneling rate, Acceptance Rate and Avg. Distance vs. Temperature

In [60]:
out_file2 = out_file('tr_ar_dist_temps_lowT')#, step)
# for trajectories with temperature = 1. vs TEMP
kwargs2 = kwargs.copy()
kwargs2['x_label'] = 'Temperature'
kwargs2['title'] = title
kwargs2['reverse_x'] = True
fig, axes = errorbar_plot(x_temps, y_data, y_err,
                          out_file=out_file2, **kwargs2)
_ = fix_legends(axes)
_ = fix_ticks(axes)
_ = add_vline(axes, 1, **{'color': 'C6', 'ls': '-', 'lw': 2.})
#axes[-1].set_xlim((15, 0.5))
fig.savefig(out_file2, dpi=400, bbox_inches='tight')
#ax.set_xlim(15, 1.05)

<IPython.core.display.Javascript object>

Saving figure to: ../log_mog_tf/run_251/figures1/tr_ar_dist_temps_lowT.pdf


### $(T > 1)$ Tunneling rate, Acceptance Rate and Avg. Distance vs. Temperature

In [61]:
out_file3 = out_file('tr_ar_dist_temps_highT')#, step)
# for trajectories with temperature > 1. vs TEMP
kwargs3 = kwargs.copy()
kwargs3['title'] = title_highT
kwargs3['x_label'] = 'Temperature'
kwargs3['reverse_x'] = True
fig, axes = errorbar_plot(x_temps, y_data_highT, y_err_highT,
                        out_file=out_file3, **kwargs3)
_ = add_vline(axes, 1, **{'color': 'C6', 'ls': '-', 'lw': 2.})
    
axes = fix_legends(axes)
_ = fix_ticks(axes)
fig.savefig(out_file3, dpi=400, bbox_inches='tight')
#axes[-1].set_xlim(10, 0.9)

<IPython.core.display.Javascript object>

Saving figure to: ../log_mog_tf/run_251/figures1/tr_ar_dist_temps_highT.pdf


### Annealing Schedule

In [62]:
temp0 = data['_params']['temp_init']
#steps = np.arange(0, max(steps_arr))
#steps
annealing_factor = data['_params']['annealing_factor']
annealing_steps = data['_params']['annealing_steps']
tunneling_steps = data['_params']['tunneling_rate_steps']
fixed_temps = []
fixed_steps = []
temp = temp0
for step in range(max(steps_arr)):
    if step % annealing_steps == 0:
        tt  = temp * annealing_factor
        if tt > 1:
            temp = tt
    if (step+1) % tunneling_steps == 0:
        fixed_steps.append(step+2)
        fixed_temps.append(temp)

In [65]:
plt.style.use('/Users/saforem2/.config/matplotlib/stylelib/dark_jupyter.mplstyle')
fig, ax = plt.subplots()
pt = ax.plot(fixed_steps, fixed_temps, ls='-', label='Fixed schedule', lw=2)
pt = ax.plot(steps_arr, temp_arr, label='Dynamic schedule', lw=2., alpha=0.75)
hl = ax.axhline(y=1., color='C6', ls='-', lw=2., label='T=1')
xl = ax.set_xlabel('Training step')
yl = ax.set_ylabel('Temperature')
lg = ax.legend(loc='best')
_ = fix_legends(ax)
#ylabels = ax.get_yticklabels()
#xlabels = ax.get_xticklabels()
#ax.set_yticklabels(ylabels, {'color': 'k'})
#ax.set_xticklabels(xlabels, {'color': 'k'})
#ax.set_yticklabels(ax.get_yticklabels(), {'color': 'k'})
#ax.set_xticklabels(ax.get_xticklabels(), {'color': 'k'})
_ = fix_ticks(ax)

plt.savefig(figs_dir + 'annealing_schedule.pdf', dpi=400, bbox_inches='tight')

<IPython.core.display.Javascript object>

In [64]:
for i in range(len(fixed_steps)):
    print(f'({fixed_steps[i]}, {fixed_temps[i]:.3g})\t'
          f'({steps_arr[i]}, {temp_arr[i]:.3g})')

(501, 24.5)	(501, 24)
(1001, 24)	(1001, 23.1)
(1501, 24)	(1501, 22.1)
(2001, 23.5)	(2001, 21.3)
(2501, 23.1)	(2501, 20.4)
(3001, 23.1)	(3001, 20.4)
(3501, 22.6)	(3501, 19.6)
(4001, 22.6)	(4001, 19.6)
(4501, 22.1)	(4501, 19.2)
(5001, 21.7)	(5001, 18.5)
(5501, 21.7)	(5501, 18.1)
(6001, 21.3)	(6001, 17.4)
(6501, 20.8)	(6501, 17)
(7001, 20.8)	(7001, 16.7)
(7501, 20.4)	(7501, 16)
(8001, 20.4)	(8001, 15.7)
(8501, 20)	(8501, 15.1)
(9001, 19.6)	(9001, 14.8)
(9501, 19.6)	(9501, 14.2)
(10001, 19.2)	(10001, 13.9)
(10501, 18.8)	(10501, 13.6)
(11001, 18.8)	(11001, 13.1)
(11501, 18.5)	(11501, 12.8)
(12001, 18.5)	(12001, 12.3)
(12501, 18.1)	(12501, 12.1)
(13001, 17.7)	(13001, 11.6)
(13501, 17.7)	(13501, 11.4)
(14001, 17.4)	(14001, 11.1)
(14501, 17)	(14501, 10.7)
(15001, 17)	(15001, 10.5)
(15501, 16.7)	(15501, 10.1)
(16001, 16.7)	(16001, 9.87)
(16501, 16.4)	(16501, 9.48)
(17001, 16)	(17501, 9.29)
(17501, 16)	(18001, 9.1)
(18001, 15.7)	(18501, 8.74)
(18501, 15.4)	(19001, 8.57)
(19001, 15.4)	(19501, 8.4

IndexError: list index out of range