In [None]:
import pickle
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from copy import deepcopy

from maxent.data.mnist import load_mnist
from maxent.boltzmann.base import train
from maxent.boltzmann.bernoulli import (
    BernoulliBoltzmannMachine, HintonInitializer, initialize_fantasy_state,
    get_reconstruction_error, LatentIncrementingInitializer, enlarge_latent)
from maxent.utils import History, ExponentialMovingAverage

In [None]:
# Global parameters shared by every cell below.

SEED = 42
tf.random.set_seed(SEED)  # seed TensorFlow's global RNG before any model is built

IMAGE_SIZE = (16, 16)  # image_size passed to load_mnist
LATENT_SIZE = 64       # latent_size of the base Boltzmann machine
BATCH_SIZE = 128

INCREMENT = 8          # amount passed to enlarge_latent at every enlargement step

In [None]:
# Data: load_mnist returns two splits; only the first one, (X, y), is kept.
(X, y), _ = load_mnist(
    image_size=IMAGE_SIZE,
    binarize=True,
    minval=0,
    maxval=1,
)

In [None]:
def _train_from_scratch(epochs: int):
    """Construct a fresh machine and train it for `epochs` shuffled passes over `X`."""
    ambient_size = IMAGE_SIZE[0] * IMAGE_SIZE[1]
    bm = BernoulliBoltzmannMachine(
        ambient_size=ambient_size,
        latent_size=LATENT_SIZE,
        initializer=HintonInitializer(X),
        max_step=100,
        tolerance=1e-1,
        connect_ambient_to_ambient=False,
        sync_ratio=0.25,
        seed=SEED,
    )
    dataset = (
        tf.data.Dataset.from_tensor_slices(X)
        .shuffle(10000, seed=SEED)
        .repeat(epochs)
        .batch(BATCH_SIZE)
    )
    fantasy_state = initialize_fantasy_state(bm, BATCH_SIZE, SEED)
    optimizer = tf.optimizers.Adam()
    fantasy_state = train(bm, optimizer, dataset, fantasy_state)
    return bm, fantasy_state


def build_and_train(epochs: int, cache_path: str = None):
    """Build and train a Bernoulli Boltzmann machine on `X`, optionally via a pickle cache.

    Args:
        epochs: Number of passes over `X` used when training from scratch.
            Ignored when an existing cache file is loaded.
        cache_path: Optional path of a pickle file holding `(bm, fantasy_state)`.
            If the file exists it is loaded and returned without retraining and
            without being rewritten. If it does not exist, a model is trained and
            written there (creating the parent directory if needed). If `None`,
            the model is always trained and nothing is written.

    Returns:
        Tuple `(bm, fantasy_state)`: the machine and the fantasy state returned by
        `train` (or loaded from the cache).
    """
    import os  # local import: only needed for the cache-writing path

    if cache_path is None:
        return _train_from_scratch(epochs)

    try:
        cache_file = open(cache_path, 'rb')
    except FileNotFoundError:
        print(f'[WARNING]: Cannot find file "{cache_path}", create new file on that path.')
    else:
        # NOTE: unpickling can execute arbitrary code; only load caches this notebook wrote.
        with cache_file:
            bm, fantasy_state = pickle.load(cache_file)
        return bm, fantasy_state

    bm, fantasy_state = _train_from_scratch(epochs)

    cache_dir = os.path.dirname(cache_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    # Write to a temporary file first so an interrupted dump never leaves a truncated cache.
    tmp_path = cache_path + '.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump((bm, fantasy_state), f)
    os.replace(tmp_path, cache_path)
    return bm, fantasy_state

In [None]:
# Train for 20 epochs on a cache miss; an existing cache file is loaded as-is.
base_bm, base_fantasy_state = build_and_train(20, cache_path='../dat/base_bm_for_rg_flow.pkl')

In [None]:
# Sanity check: get_reconstruction_error of the base machine on the first 1000 samples of X.
get_reconstruction_error(base_bm, X[0:1000])

In [None]:
# Initialize the incremental-training state from deep copies of the base machine,
# so re-running this cell always restarts from the untouched base model.
history = History()
bm = deepcopy(base_bm)
fantasy_state = deepcopy(base_fantasy_state)
iter_step = 0


def log(iter_step):
    """Record the kernels and biases of the global `bm` in `history` under `iter_step`.

    The module-level `bm` is read at call time, so call this only after `bm` has
    been rebound to the machine you want to snapshot.
    """
    history.log(iter_step, 'ambient_latent_kernel', bm.ambient_latent_kernel.numpy())
    history.log(iter_step, 'latent_latent_kernel', bm.latent_latent_kernel.numpy())
    history.log(iter_step, 'ambient_bias', bm.ambient_bias.numpy())
    history.log(iter_step, 'latent_bias', bm.latent_bias.numpy())


# Snapshot the base machine as step 0 so the kernel change caused by the first
# enlargement is also part of the logged history.
log(iter_step)

In [None]:
# Repeatedly enlarge the latent layer via enlarge_latent(..., INCREMENT) and retrain,
# logging a snapshot after every step, until the latent size exceeds MAX_LATENT_SIZE.
# NOTE(review): because the test is `<=`, a machine whose latent size equals
# MAX_LATENT_SIZE is still enlarged once more; use `<` if the intended final size
# is exactly MAX_LATENT_SIZE.
MAX_LATENT_SIZE = 512
while bm.latent_size <= MAX_LATENT_SIZE:
    # Rebuilt each step from the same seed.
    epochs = 10  # enough epochs for ensuring the convergence of training.
    dataset = (
        tf.data.Dataset.from_tensor_slices(X)
        .shuffle(10000, seed=SEED)
        .repeat(epochs)
        .batch(BATCH_SIZE)
    )
    inc_bm, inc_fantasy_state = enlarge_latent(bm, fantasy_state, INCREMENT)
    print(f'Iteration {iter_step + 1}: latent size {bm.latent_size} -> {inc_bm.latent_size}......')
    optimizer = tf.optimizers.Adam()
    inc_fantasy_state = train(inc_bm, optimizer, dataset, inc_fantasy_state)

    bm, fantasy_state, iter_step = inc_bm, inc_fantasy_state, iter_step + 1

    log(iter_step)

In [None]:
# Latent size reached by the incremental-training loop.
print('Current latent size: ' + str(bm.latent_size))

In [None]:
# Change of the first LATENT_SIZE columns of the ambient-latent kernel between
# consecutive logged steps, smoothed over steps with an exponential moving average.
steps = sorted(history.logs.keys())
assert len(steps) >= 2, 'need at least two logged steps to compute kernel differences'

kernel_diff_hist = np.stack(
    [
        history.logs[j]['ambient_latent_kernel'][:, :LATENT_SIZE]
        - history.logs[i]['ambient_latent_kernel'][:, :LATENT_SIZE]
        for i, j in zip(steps[:-1], steps[1:])
    ],
    axis=0,
)
kernel_diff_hist = ExponentialMovingAverage(0.9)(kernel_diff_hist, axis=0).numpy()


def plot_confidence_region(confidence, ax=None, **plot_kwargs):
    """Shade the central `confidence` quantile band of the averaged kernel differences.

    For every step in `steps[1:]`, the band spans the (1 - confidence) / 2 and
    1 - (1 - confidence) / 2 quantiles taken over all entries of that step's
    averaged difference matrix. This is an empirical spread of kernel entries,
    not a statistical confidence region.

    Args:
        confidence: Fraction of entries inside the band, in (0, 1).
        ax: Axes to draw on; defaults to the current axes.
        **plot_kwargs: Forwarded to `Axes.fill_between` (e.g. `alpha`).
    """
    ax = plt.gca() if ax is None else ax
    tail = (1 - confidence) / 2
    flat = kernel_diff_hist.reshape(len(kernel_diff_hist), -1)  # one row per step
    lower = np.quantile(flat, tail, axis=1)
    upper = np.quantile(flat, 1 - tail, axis=1)
    ax.fill_between(steps[1:], lower, upper,
                    label=f'central {(confidence * 100):.2f}% of entries',
                    **plot_kwargs)


fig, ax = plt.subplots()
ax.plot(steps[1:], np.zeros(len(steps) - 1), '--', label='zero')

# 68.27 / 95.44 / 99.73 % are the Gaussian 1/2/3-sigma masses.
plot_confidence_region(0.6827, ax=ax, alpha=0.5)
plot_confidence_region(0.9544, ax=ax, alpha=0.25)
plot_confidence_region(0.9973, ax=ax, alpha=0.25)

ax.set_title('Averaged kernel difference history')
ax.set_xlabel('iteration step')
ax.set_ylabel('EMA of kernel difference')
ax.legend(loc='lower right')
plt.show()