In [None]:
"""GPU setup"""
import os

# Expose only physical GPU 2 to CUDA. This must be set before TensorFlow
# initializes the GPU, which is why it lives in the first cell, ahead of the
# TensorFlow import.
os.environ.update(CUDA_VISIBLE_DEVICES="2")

In [None]:
"""Imports, define RBM model"""
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
from matplotlib import pyplot as plt

from utils.data import tfr_dataset_eager, parse_img_label_tfr
from utils.rbm import repeated_gibbs, gibbs_update_brbm, energy_rbm
from utils.viz import random_sample_grid, img_grid_npy, imshow


# data
batch_size = 256
train_steps = 1500


def parse_fn(serialized):
    """Parse one serialized TFRecord example with image shape (32*32,).

    Produces an (image, label) pair, as unpacked by the loops further down.
    A named function is used instead of an assigned lambda (PEP 8 E731).
    """
    return parse_img_label_tfr(serialized, (32*32,))


# NOTE(review): shufrep=60000 presumably sets the shuffle(-and-repeat) buffer
# size; 60000 equals the MNIST training-set size, so the whole set would be
# shuffled. TODO confirm against utils.data.tfr_dataset_eager.
data = tfr_dataset_eager("/cache/tfrs/mnist_train.tfr", batch_size, parse_fn,
                         shufrep=60000)


# model
# Negative-phase sampling scheme used by `train`:
#   "cd"    - Gibbs chains start from the current data batch
#   "pcd"   - chains continue from the previous training step's samples
#   "naive" - chains restart from fresh noise (pixel marginals) every step
mode = "pcd"
if mode not in ("pcd", "cd", "naive"):
    raise ValueError(
        f"mode must be one of 'pcd', 'cd', 'naive'; got {mode!r}")

n_h = 1000  # number of hidden units
# Visible layer has one unit per pixel of a 32x32 image.
w_vh = tf.Variable(tf.random.uniform([32*32, n_h], -0.1, 0.1))  # (visible, hidden)
b_v = tf.Variable(tf.zeros([32*32]))  # visible biases
b_h = tf.Variable(tf.zeros([n_h]))    # hidden biases
weights = [w_vh, b_v, b_h]  # parameters differentiated and updated in `train`

# compute marginal for better sampling
def pixel_marginals(dataset):
    """Return the per-pixel mean over every image yielded by `dataset`.

    `dataset` must yield (images, labels) batches. The sum is accumulated
    batch by batch, so any number of batches (including a smaller final
    batch) works. Nothing is stacked, and the big batch tensor is not left
    behind as a notebook global.

    Raises:
        ValueError: if `dataset` yields no images (the mean would be NaN).
    """
    total, count = 0.0, 0
    for imgs, _ in dataset:
        total += tf.reduce_sum(imgs, axis=0)
        count += int(imgs.shape[0])
    if count == 0:
        raise ValueError("dataset yielded no images; cannot compute pixel marginals")
    return total / count


# Single pass over the training set (batch size = MNIST train size).
data_once = tfr_dataset_eager("/cache/tfrs/mnist_train.tfr", 60000, parse_fn)
marginals = pixel_marginals(data_once)


# Chain initializers. `marginal_sampler` draws each pixel independently with
# its training-set mean as the on-probability (meant as a better starting
# point for sampling). `start_sampler` is a fair coin per unit.
marginal_sampler = tfp.distributions.Bernoulli(probs=marginals, dtype=tf.float32)
start_sampler = tfp.distributions.Bernoulli(probs=0.5, dtype=tf.float32)

# Adam with default hyperparameters; tf.optimizers.SGD(0.1) is an alternative.
opt = tf.optimizers.Adam()

In [None]:
"""Train"""

@tf.function
def train(batch, v_sampled=None, h_sampled=None):
    """Perform one gradient step on the RBM parameters.

    Args:
        batch: batch of visible vectors from the data (presumably
            (batch_size, 32*32) -- TODO confirm).
        v_sampled, h_sampled: persistent chain state. Only read when
            `mode == "pcd"`; callers pass them in every mode for simplicity.

    Returns:
        (loss, v, h): the scalar loss and the final visible/hidden states of
        the negative-phase Gibbs chains.
    """
    v_data = batch
    # Shortcut: use hidden activation probabilities instead of sampling h.
    # The fully sampled variant would be
    #   _, h_data = gibbs_update_brbm((v_data, h_random), w_vh, b_v, b_h)
    # with h_random as a placeholder.
    h_data = tf.nn.sigmoid(tf.matmul(v_data, w_vh) + b_h)

    # Choose where the negative-phase chains start and how long they run.
    # `mode` is a Python string, so this branch is resolved at trace time.
    if mode == "cd":
        chain_init, n_steps = (v_data, h_data), 20
    elif mode == "pcd":
        chain_init, n_steps = (v_sampled, h_sampled), 20
    else:
        # "naive": restart from noise every step, so run a longer chain.
        n_chains = tf.shape(batch)[0]
        v_init = marginal_sampler.sample(n_chains)
        h_init = start_sampler.sample([n_chains, n_h])  # dummy hidden state
        chain_init, n_steps = (v_init, h_init), 200

    v_model, h_model = repeated_gibbs(
        chain_init, n_steps, gibbs_update_brbm,
        w_vh=w_vh, b_v=b_v, b_h=b_h)

    # Chain states are produced outside the tape, so gradients flow only
    # through the parameters inside energy_rbm.
    with tf.GradientTape() as tape:
        pos = tf.reduce_mean(-energy_rbm(v_data, h_data, w_vh, b_v, b_h))
        neg = tf.reduce_mean(-energy_rbm(v_model, h_model, w_vh, b_v, b_h))
        loss = -(pos - neg)
    grads = tape.gradient(loss, weights)
    opt.apply_gradients(zip(grads, weights))

    return loss, v_model, h_model


# Initial persistent-chain state (only read by `train` in "pcd" mode):
# visible units from the pixel marginals, hidden units from fair coin flips.
v_samp = marginal_sampler.sample(batch_size)
h_samp = start_sampler.sample([batch_size, n_h])
for step, (img_batch, _) in enumerate(data):
    # `>=` so exactly `train_steps` updates run (steps 0 .. train_steps - 1).
    if step >= train_steps:
        break

    loss, v_samp, h_samp = train(img_batch, v_samp, h_samp)
    if step % 50 == 0:
        print("Step", step)
        # .numpy() prints the value rather than the full tf.Tensor repr.
        print("Loss:", loss.numpy())

In [None]:
# Show each hidden unit's incoming weights as a 32x32 image on its own
# symmetric color range (diverging RdBu_r map centered at 0). Units are tiled
# into pages of `units_per_page` panels, so the cell emits a handful of
# figures instead of one figure per hidden unit.
units_per_page = 100
cols = 10
filters = w_vh.numpy().T  # one row per hidden unit, length 32*32

print("weights...")
for page_start in range(0, len(filters), units_per_page):
    page = filters[page_start:page_start + units_per_page]
    rows = -(-len(page) // cols)  # ceil division
    fig, axes = plt.subplots(rows, cols, figsize=(cols, rows), squeeze=False)
    for ax in axes.flat:
        ax.axis("off")
    for ax, unit_w in zip(axes.flat, page):
        absmax = np.max(np.abs(unit_w))
        ax.imshow(unit_w.reshape((32, 32)), vmin=-absmax, vmax=absmax,
                  cmap="RdBu_r")
    fig.suptitle(f"hidden units {page_start}-{page_start + len(page) - 1}")
    plt.show()
    plt.close(fig)  # free figure memory when looping over many pages

In [None]:
"""random samples"""
# Start grid_side**2 chains from marginal noise, run 200 Gibbs steps, and
# tile the resulting visible states into a square image grid.
grid_side = 7
n_samples = grid_side * grid_side
v_random = marginal_sampler.sample([n_samples])
h_random = start_sampler.sample([n_samples, n_h])
img_sample, _ = repeated_gibbs(
    (v_random, h_random), 200, gibbs_update_brbm,
    w_vh=w_vh, b_v=b_v, b_h=b_h)

# One 32x32 array per sample.
img_sample = list(img_sample.numpy().reshape((-1, 32, 32)))
grid = img_grid_npy(img_sample, grid_side, grid_side, normalize=False)
imshow(grid)