In [None]:
# If running on Colab, install the package first:
# !pip install git+https://github.com/shuiruge/energymodel

In [None]:
import tensorflow as tf
from tensorflow.keras import models, layers
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from tqdm import tqdm

from energymodel import (
    EnergyModel, random_uniform, LossMonitor, FantasyParticleMonitor,
    VectorFieldMonitor, LossGradientMonitor,
)

# Reset TensorFlow's global default graph (TF1-compat API) -- presumably so
# that re-running the notebook starts from a clean state.
tf.compat.v1.reset_default_graph()
# Register the `%tensorboard` magic that is used further down.
%load_ext tensorboard

## Dataset

In [None]:
# Load MNIST. With split=['train', 'test'] tfds returns one dataset per
# requested split, in that order; only the training split is kept here and
# the test split is discarded (bound to `_`).
(dataset, _), info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,  # elements are (image, label) tuples
    with_info=True,  # also return the dataset info, used for the shuffle buffer size
)

def filter_img(image, label):
    """Project an ``(image, label)`` pair onto its image; ``label`` is ignored."""
    _ = label  # intentionally unused
    return image

def normalize_img(image):
    """Cast pixels to float32 and rescale linearly so that 0 -> -1 and 255 -> 1."""
    as_float = tf.cast(image, 'float32')
    rescaled = 2 * as_float / 255 - 1
    return rescaled

def reshape_img(image):
    """Flatten a single image into a vector with 28*28 = 784 entries."""
    flat_shape = [28 * 28]
    return tf.reshape(image, flat_shape)

def preprocess_dataset(dataset, batch_size):
    """Turn a supervised image dataset into shuffled batches of flat images.

    Each element is passed through ``filter_img``, ``normalize_img`` and
    ``reshape_img`` (in that order); the result is cached, shuffled, batched
    and prefetched.

    Args:
        dataset: a ``tf.data.Dataset`` whose elements are ``(image, label)``.
        batch_size: number of images per batch.

    Returns:
        The transformed ``tf.data.Dataset``.

    Note:
        The shuffle buffer size is read from the module-level ``info``
        (number of examples in the 'train' split), whatever dataset is passed.
    """
    autotune = tf.data.AUTOTUNE
    for transform in (filter_img, normalize_img, reshape_img):
        dataset = dataset.map(transform, num_parallel_calls=autotune)
    # Cache after the deterministic per-image maps and before shuffling.
    dataset = dataset.cache()
    dataset = dataset.shuffle(info.splits['train'].num_examples)
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(autotune)

# Mini-batch size; also reused later when building the `resample` batch shape.
batch_size = 128
# Replace the raw (image, label) training split by batches of flattened,
# normalized images.
dataset = preprocess_dataset(dataset, batch_size)

## Model

In [None]:
# We employ a LeNet-like CNN as the network.
# The activation functions are changed from tanh to swish,
# and the top dense layers are slightly adjusted. The output layer
# has to be Dense(1, use_bias=False).
conv_stack = [
    layers.Reshape([28, 28, 1]),  # flat 784-vector -> 28x28x1 image

    layers.Conv2D(6, kernel_size=5, strides=1, padding='same'),  # -> 28x28x6
    layers.Activation('swish'),
    layers.AveragePooling2D(pool_size=2, strides=2, padding='valid'),  # -> 14x14x6

    layers.Conv2D(16, kernel_size=5, strides=1, padding='valid'),  # -> 10x10x16
    layers.Activation('swish'),
    layers.AveragePooling2D(pool_size=2, strides=2, padding='valid'),  # -> 5x5x16

    layers.Flatten(),  # -> 400
]
dense_head = [
    layers.Dense(256),
    layers.Activation('swish'),

    layers.Dense(64),
    layers.Activation('swish'),

    # Single scalar output per input, without bias.
    layers.Dense(1, use_bias=False),
]
network = models.Sequential(conv_stack + dense_head)

In [None]:
# Delete any TensorBoard logs left over from a previous run (the summary
# writer created below writes to ./logdir).
!rm -rf ./logdir

In [None]:
input_shape = [28*28]


def resample():
    """Call `random_uniform` with the shape argument [batch_size, *input_shape]."""
    return random_uniform([batch_size, *input_shape])


# Run the network once on a sampled batch so that its weights get built.
network(resample())

model = EnergyModel(network, resample, t=5.0, dt=0.1)
tf.print('T = ', model.T)

# Adam with element-wise gradient clipping to [-0.1, 0.1].
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3, clipvalue=0.1)

# All monitors share one summary writer pointing at ./logdir.
writer = tf.summary.create_file_writer('./logdir')
callbacks = [LossMonitor(writer, 5)] + [
    monitor_cls(writer, model, 5)
    for monitor_cls in (
        FantasyParticleMonitor,
        VectorFieldMonitor,
        LossGradientMonitor,
    )
]

# Obtain the optimization step from the model and compile it with tf.function.
train_step = tf.function(model.get_optimize_fn(optimizer, callbacks))

In [None]:
# Open TensorBoard inline, reading the summaries written to ./logdir.
%tensorboard --logdir logdir

In [None]:
# Two epochs are enough!
for epoch in range(2):
    # One call to the compiled train_step per mini-batch; tqdm shows progress
    # through the epoch.
    for batch in tqdm(dataset):
        train_step(batch)

## Evaluation

In [None]:
# Draw the evaluation batch from the held-out MNIST test split (not from the
# training data the model was fitted on), preprocessed exactly like the
# training set.
test_dataset = preprocess_dataset(
    tfds.load('mnist', split='test', as_supervised=True),
    batch_size,
)
# Take only the first batch instead of materializing every batch in a list.
test_X = next(iter(test_dataset))

In [None]:
# Denoising test: perturb the clean batch with truncated-normal noise
# scaled by 0.5.
noise = tf.random.truncated_normal(test_X.shape)
noised_X = test_X + 0.5 * noise

# Or, for a generation test, start from random inputs instead:
# noised_X = random_uniform(test_X.shape)

# Let the model evolve the perturbed batch.
relaxed_X = model.evolve(noised_X)

In [None]:
def display_image(x):
    """Show a single image tensor with 784 entries as a 28x28 matplotlib plot."""
    pixels = x.numpy()
    plt.imshow(pixels.reshape([28, 28]))
    plt.show()

In [None]:
i = 0  # index within the batch of the example to visualize
display_image(test_X[i, :])  # original image
# display_image(noised_X[i, :])
display_image(relaxed_X[i, :])  # result of model.evolve on the noised image

## Conclusions

### Clipping
We find that there's no need to apply any clipping to the model itself, neither to the vector field nor to the fantasy particles. The model can take care of itself! The only clipping that may be needed is in the optimizer.

### Instability
Instability may occur, but rarely. It comes from the optimizer rather than from the model itself. Indeed, when instability appears, the loss diverges quickly, while the vector field values and the fantasy particles remain stable.

### Persistence
Using a persistent random walk performs badly on both the denoising and generation tasks. In a high-dimensional phase space, the volume of the space is quite large. A persistent random walk cannot explore the phase space efficiently, so the fantasy particles dance in a limited subspace. By contrast, the non-persistent random walk explores the phase space with unreasonable efficiency, and performs very well on both the denoising and generation tasks.