In [None]:
"""GPU setup"""
import os

# Restrict the process to physical GPU index 2. This is set before TensorFlow
# is imported in the next cell; CUDA typically reads this variable when it
# initialises, so setting it later may have no effect.
os.environ.update({"CUDA_VISIBLE_DEVICES": "2"})

In [None]:
"""Imports, define RBM model"""
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
# Eager mode has to be switched on before any TensorFlow op or variable exists.
tf.enable_eager_execution()

from utils.data import mnist_eager
from utils.rbm import repeated_gibbs, gibbs_update_brbm, energy_rbm
from utils.viz import random_sample_grid, img_grid_npy, imshow


# data
batch_size = 256
train_steps = 1500  # training-step limit checked in the training loop

data = mnist_eager("data/mnist_train", batch_size)


# model
# How the training cell produces negative-phase samples:
#   "cd"    - Gibbs chains start from the current data batch every step
#   "pcd"   - persistent Gibbs chains carried over from the previous step
#   "naive" - fresh Gibbs chains started from the pixel marginals every step
VALID_MODES = ("pcd", "cd", "naive")
mode = "pcd"
if mode not in VALID_MODES:
    raise ValueError(
        "mode must be one of {}, got {!r}".format(VALID_MODES, mode))

n_h = 1000  # number of hidden units
# The visible layer is the 32*32 = 1024 flattened pixels of an image.
w_vh = tf.get_variable("w_vh", [32*32, n_h], tf.float32)
b_v = tf.get_variable("b_v", [32*32], tf.float32,
                      initializer=tf.zeros_initializer)
b_h = tf.get_variable("b_h", [n_h], tf.float32,
                      initializer=tf.zeros_initializer)
weights = [w_vh, b_v, b_h]

# TODO adapt to data having image dimensions
# Per-pixel mean over the training set. Visible chains are initialised from
# it, which is a better starting point than uniform noise.
data_once = mnist_eager("data/mnist_train", 60000, train=False)
# Concatenate all batches along the example axis before averaging. Stacking a
# list of batches (the implicit conversion of a Python list) only works when
# every batch has the same size.
marginals = tf.reduce_mean(
    tf.concat([img for (img, _) in data_once], axis=0), axis=0)
print(marginals.shape)


# Chain initialisers: hidden units start as fair coin flips, and visible units
# are drawn independently per pixel with the marginal "on" probability.
start_sampler = tf.distributions.Bernoulli(probs=0.5, dtype=tf.float32)
marginal_sampler = tf.distributions.Bernoulli(probs=marginals, dtype=tf.float32)

# Alternative optimizer: tf.train.GradientDescentOptimizer(0.1)
opt = tf.train.AdamOptimizer()

In [None]:
"""Train"""
# Gibbs sweeps per parameter update: short chains for "cd"/"pcd", and long
# chains started from scratch for "naive".
n_gibbs_steps = 20
n_gibbs_steps_naive = 200

if mode == "pcd":
    # Persistent chains are initialised once here and carried across steps.
    v_sampled = marginal_sampler.sample(batch_size)
    h_sampled = start_sampler.sample([batch_size, n_h])
for step, (img_batch, _) in enumerate(data):
    # `step` counts from 0, so `>=` stops after exactly `train_steps` updates.
    if step >= train_steps:
        break

    # Positive phase: data on the visible units, hidden-unit probabilities.
    v_data = img_batch
    h_data = tf.nn.sigmoid(tf.matmul(v_data, w_vh) + b_h)

    # Negative phase: model samples from Gibbs chains. These are computed
    # before the tape opens, so no gradient flows through the sampling.
    if mode == "cd":
        v_sampled, h_sampled = repeated_gibbs(
            (v_data, h_data), n_gibbs_steps, gibbs_update_brbm,
            w_vh=w_vh, b_v=b_v, b_h=b_h)
    elif mode == "pcd":
        v_sampled, h_sampled = repeated_gibbs(
            (v_sampled, h_sampled), n_gibbs_steps, gibbs_update_brbm,
            w_vh=w_vh, b_v=b_v, b_h=b_h)
    else:
        cur_batch = int(img_batch.shape[0])
        v_random = marginal_sampler.sample(cur_batch)
        # this is just a dummy hidden start state
        h_random = start_sampler.sample([cur_batch, n_h])
        v_sampled, h_sampled = repeated_gibbs(
            (v_random, h_random), n_gibbs_steps_naive, gibbs_update_brbm,
            w_vh=w_vh, b_v=b_v, b_h=b_h)

    # loss = mean E(data) - mean E(samples). Minimising it lowers the energy
    # of data configurations and raises the energy of model samples.
    with tf.GradientTape() as tape:
        logits_pos = tf.reduce_mean(-energy_rbm(v_data, h_data, w_vh, b_v, b_h))
        logits_neg = tf.reduce_mean(
            -energy_rbm(v_sampled, h_sampled, w_vh, b_v, b_h))
        loss = -(logits_pos - logits_neg)
    grads = tape.gradient(loss, weights)
    opt.apply_gradients(zip(grads, weights))
    if not step % 50:
        print("Step", step)
        print("Loss:", loss.numpy())

In [None]:
print("weights...")
# Each column of w_vh holds one hidden unit's weights to the 32*32 visible
# pixels. Each is shown as an image with a symmetric colour range, so a zero
# weight maps to the neutral middle of the diverging colormap.
# One figure per hidden unit would emit n_h separate figures, so only the
# first `n_filters_shown` units are drawn, together in a single grid.
n_filters_shown = 64
n_grid_cols = 8
filters = w_vh.numpy().T[:n_filters_shown]
n_grid_rows = max(1, -(-len(filters) // n_grid_cols))  # ceiling division
fig, axes = plt.subplots(n_grid_rows, n_grid_cols,
                         figsize=(1.5 * n_grid_cols, 1.5 * n_grid_rows),
                         squeeze=False)
for ax in axes.flat:
    ax.axis("off")
for ax, img in zip(axes.flat, filters):
    absmax = np.max(np.abs(img))
    ax.imshow(img.reshape((32, 32)), vmin=-absmax, vmax=absmax, cmap="RdBu_r")
fig.suptitle("Weights of the first {} hidden units".format(len(filters)))
plt.show()
plt.close(fig)

In [None]:
"""random samples"""
# Draw a square grid of model samples. Fresh Gibbs chains start with visible
# units drawn from the pixel marginals and hidden units as fair coin flips.
# They run for 200 sweeps, and the final visible states are shown.
n_side = 4
n_chains = n_side * n_side

v_random = marginal_sampler.sample([n_chains])
h_random = start_sampler.sample([n_chains, n_h])
img_sample, _ = repeated_gibbs((v_random, h_random), 200, gibbs_update_brbm,
                               w_vh=w_vh, b_v=b_v, b_h=b_h)

# One flat 1024-pixel row per chain -> list of 32x32 arrays for the grid.
img_sample = list(map(lambda row: row.numpy().reshape((32, 32)), img_sample))
grid = img_grid_npy(img_sample, n_side, n_side, normalize=False)
imshow(grid)