The following experiment is a good way to check whether the code for the Bernoulli Restricted Boltzmann Machine (RBM) works correctly. I saw this idea in a talk by Prof. Hinton (https://www.youtube.com/watch?v=AyzOUbkUf3M, from 12:55 onwards).

Here we train the RBM only on the MNIST digits labelled 2. We then test the RBM by supplying an image with a different label (one it was not trained on, in this case 3) and asking it to reconstruct that image using n Gibbs sampling steps. If the RBM has learned to model 2s well, it should turn the 3 into a 2. Because the RBM has only memorized 2s, any other image acts as a confabulation. The reconstruction process therefore tries to recover a true memory (i.e., an image of the digit 2) from the confabulation (i.e., the digit 3).

In [1]:
import torch
from torchvision import datasets
from BernoulliRBM import BernoulliRBM
import matplotlib.pyplot as plt

In [2]:
# Function to display digit
def show_digit(x):
    """Plot a flattened 784-pixel image as a 28x28 grayscale picture and show it.

    ``x`` must provide a ``reshape`` method; in this notebook it is always the
    NumPy array obtained from a tensor via ``.numpy()``.
    """
    picture = x.reshape((28, 28))
    plt.imshow(picture, cmap=plt.cm.gray)
    plt.show()

In [3]:
# Load mnist data
# Only the raw image/label tensors of the training split are used below.
# ``data`` and ``targets`` replace the deprecated ``train_data`` /
# ``train_labels`` attributes of torchvision's MNIST dataset.
# The DataLoader is kept so the ``train_loader`` name stays available.
# This notebook only reads its ``dataset`` and never iterates it.
train_dataset = datasets.MNIST('../data', train=True, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset)
train_data = train_dataset.data
train_labels = train_dataset.targets

In [4]:
# Filter images for the digit 2.
# Each selected image is converted to float, scaled from [0, 255] to [0, 1]
# and flattened into a (784, 1) column.
# Boolean-mask indexing selects all matching images in one step instead of a
# Python loop over every training image, and ``reshape`` replaces the
# in-place ``resize_``.
len_mnist_images = len(train_data)  # total number of training images
len_image = 784  # 28 * 28 pixels per flattened image
digit_2_images = train_data[train_labels == 2].float()
mnist_images_2 = list(digit_2_images.reshape(-1, len_image, 1) / 255)

In [5]:
# Convert to one large matrix for rbm training.
# The (784, 1) digit-2 columns are concatenated side by side into a
# (784, N) matrix. Every pixel intensity in [0, 1] is then binarized to 0/1
# with a single vectorized Bernoulli draw. This replaces the column-by-column
# fill of an uninitialized tensor.
len_mnist_images_2 = len(mnist_images_2)
if len_mnist_images_2 > 0:
    mnist_image_matrix = torch.cat(mnist_images_2, dim=1).bernoulli()
else:
    # torch.cat rejects an empty list; keep the (784, 0) shape the loop produced.
    mnist_image_matrix = torch.empty(len_image, 0)

In [6]:
# Take an example image of a different class, in this case, of digit 3.
# The first training image with this label is flattened to a (784, 1) column,
# scaled to [0, 1] and binarized with a Bernoulli draw, like the training
# images.
# If no image carries the label, raise instead of silently continuing with an
# uninitialized tensor.
example_class = 3
matching_indices = (train_labels == example_class).nonzero(as_tuple=True)[0]
if matching_indices.numel() == 0:
    raise ValueError(f"No training image found with label {example_class}")
first_match = int(matching_indices[0])
mnist_test_image = (
    train_data[first_match].float().reshape(len_image, 1) / 255
).bernoulli()

In [7]:
# Hyperparameters
# Visible units: one per entry of a training column (one per pixel).
n_vis = mnist_image_matrix[:, 0].numel()
n_hid = 500          # hidden units
init_wt_var = 0.01   # passed as init_weight_variance
l_rate = 0.01        # passed as learning_rate
n_itr = 50           # passed as n_epochs
bsz = 10             # passed as batch_size
verb = True          # passed as verbose
xv_init = True       # passed as xavier_init
lr_decay = True      # passed as learning_rate_decay
inc_cd_k = False     # passed as increase_to_cd_k
cdk = 5              # passed as k (contrastive-divergence steps)

# (label, value) pairs, flattened into a single print call.
_settings = (
    ("Setting n_v = ", n_vis),
    (", n_h = ", n_hid),
    (", init_wt_var = ", init_wt_var),
    (", lr = ", l_rate),
    (", n_itr = ", n_itr),
    (", bsz = ", bsz),
    (", verbose = ", verb),
    (", xv_init = ", xv_init),
    (", lr_decay = ", lr_decay),
    (", inc_cd_k = ", inc_cd_k),
    (", cdk = ", cdk),
)
print(*(part for pair in _settings for part in pair))

Setting n_v =  784 , n_h =  500 , init_wt_var =  0.01 , lr =  0.01 , n_itr =  50 , bsz =  10 , verbose =  True , xv_init =  True , lr_decay =  True , inc_cd_k =  False , cdk =  5


In [None]:
# Define rbm
# Collect the training options as keyword arguments, then build a model with
# n_vis visible and n_hid hidden units.
rbm_options = dict(
    init_weight_variance=init_wt_var,
    learning_rate=l_rate,
    n_epochs=n_itr,
    batch_size=bsz,
    verbose=verb,
    xavier_init=xv_init,
    learning_rate_decay=lr_decay,
    increase_to_cd_k=inc_cd_k,
    k=cdk,
)
rbm = BernoulliRBM(n_vis, n_hid, **rbm_options)

In [None]:
# Fit rbm to training data.
# mnist_image_matrix holds one binarized digit-2 image per column, so the
# network is trained on images of the digit 2 only.
# NOTE(review): assumes BernoulliRBM.fit expects samples as columns
# (n_vis x N) — confirm against BernoulliRBM.
rbm.fit(mnist_image_matrix)

(BernoulliRBM, fitting):   4%|4         | 26/596 [00:00<00:04, 122.58it/s]

  return self.add_(other)


(BernoulliRBM, fitting): 100%|##########| 596/596 [00:04<00:00, 136.68it/s]
epoch  1 , avg_cost =  62.04045783433338 , std_cost =  10.769680355068763 , avg_grad =  16156.153528405515 , std_grad =  4366.9376975637515 , time elapsed =  4.362382888793945
(BernoulliRBM, fitting): 100%|##########| 596/596 [00:04<00:00, 140.36it/s]
epoch  2 , avg_cost =  51.085630288860145 , std_cost =  4.082860842176768 , avg_grad =  13505.208546888109 , std_grad =  672.4660576162144 , time elapsed =  4.248440980911255
(BernoulliRBM, fitting): 100%|##########| 596/596 [00:04<00:00, 139.97it/s]
epoch  3 , avg_cost =  48.50140764889301 , std_cost =  3.873179595649698 , avg_grad =  13192.287432164954 , std_grad =  612.3063637240227 , time elapsed =  4.259500980377197
(BernoulliRBM, fitting): 100%|##########| 596/596 [00:04<00:00, 138.98it/s]
epoch  4 , avg_cost =  47.09526979843242 , std_cost =  3.7437956621321793 , avg_grad =  12951.841537987626 , std_grad =  594.2479285859704 , time elapsed =  4.289978981018

(BernoulliRBM, fitting): 100%|##########| 596/596 [00:04<00:00, 140.75it/s]
epoch  34 , avg_cost =  41.86089313430274 , std_cost =  3.4083227850797475 , avg_grad =  12040.304472852873 , std_grad =  548.1606028458759 , time elapsed =  4.236132860183716
(BernoulliRBM, fitting): 100%|##########| 596/596 [00:04<00:00, 140.41it/s]
epoch  35 , avg_cost =  41.78172485300359 , std_cost =  3.4686550176251028 , avg_grad =  12054.372422595952 , std_grad =  559.1449554594528 , time elapsed =  4.246268033981323
(BernoulliRBM, fitting): 100%|##########| 596/596 [00:04<00:00, 140.72it/s]
epoch  36 , avg_cost =  41.77291271030503 , std_cost =  3.451471145582472 , avg_grad =  12044.2913351379 , std_grad =  544.1846101033898 , time elapsed =  4.237365007400513
(BernoulliRBM, fitting): 100%|##########| 596/596 [00:04<00:00, 140.44it/s]
epoch  37 , avg_cost =  41.734356259339606 , std_cost =  3.3917200011534336 , avg_grad =  12025.051421914324 , std_grad =  543.355325139339 , time elapsed =  4.24548792839

In [None]:
# Sample image of the learned class.
# Column 10 of the training matrix is one binarized digit-2 image.
plt.figure(1)
image = mnist_image_matrix.select(1, 10)
show_digit(image.numpy())

In [None]:
# Reconstruction of the sample image.
# rbm.reconstruct returns a pair. Only its second element (the
# reconstructed image) is displayed.
plt.figure(2)
reconstruct_result = rbm.reconstruct(image)
_, image_reconst = reconstruct_result
show_digit(image_reconst.numpy())

In [None]:
# Get the weight matrix to visualize the feature detectors.
# NOTE(review): the visualization cell reads W[i, :] as the incoming weights of
# hidden unit i and reshapes it to 28x28. That assumes rbm.W is laid out as
# (n_hid, n_vis) and is convertible with .numpy() — confirm against
# BernoulliRBM.
W = rbm.W

In [None]:
# Visualize the first 25 feature detectors, i.e., the incoming weights to the first 25 hidden units.
# Row `unit` of W is drawn as a 28x28 image in a 5x5 grid of subplots.
# NOTE(review): assumes W is (n_hid, n_vis) — confirm.
fig = plt.figure(3, figsize=(10, 10))
for unit in range(25):
    axis = fig.add_subplot(5, 5, unit + 1)
    detector = W[unit].numpy().reshape(28, 28)
    axis.imshow(detector, cmap=plt.cm.gray)

In [None]:
# Display the test image of the different class (digit 3, in this case).
# mnist_test_image is the binarized (784, 1) column built from the first
# training image with label example_class. show_digit reshapes it to 28x28.
show_digit(mnist_test_image.numpy())

In [None]:
# Perform reconstruction with rbm and display the reconstructed image.
# The digit-3 image is passed through 100 Gibbs sampling steps. The RBM was
# trained only on images of the digit 2, so the reconstruction is expected to
# drift towards a 2.
image = mnist_test_image
reconstruct_result = rbm.reconstruct(image, n_gibbs=100)
_, image_reconst = reconstruct_result
show_digit(image_reconst.numpy())