# Restricted Boltzmann Machine (RBM)

We model MNIST images with stochastic binary visible variables $x\in\{0,1\}^D$ and stochastic binary hidden variables $h \in \{0,1\}^H$. That is, we train a Bernoulli-Bernoulli RBM and use it to generate images. The joint distribution is

$$P(x, h) = \frac{exp(-E(x, h))}{Z}$$

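where $E(x,h)$ is the energy associated with the joint configuration $(x,h)$. In the Bernoulli-Bernoulli RBM, $E(x, h) := -h^TWx - c^Tx - b^Th$, with weight matrix $W \in \mathbb{R}^{H \times D}$, visible bias $c \in \mathbb{R}^D$ and hidden bias $b \in \mathbb{R}^H$.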

Using the above, we can determine the following:

* Marginal of x: $P(x) = \frac{exp(c^Tx + \sum_{j=1}^H softplus{(b_j + W_{j.}x)} )}{Z} = \frac{exp(-F(x))}{Z}$, where $F(x)$ is called the "free energy".
* Conditionals: $p(x|h)$ and $p(h|x)$.
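
For reference, both conditionals factorize over units because the energy contains no visible-visible or hidden-hidden terms. A sketch of the resulting expressions, where $\sigma$ is the logistic sigmoid and $W_{.i}$ (notation introduced here, by analogy with $W_{j.}$) is the $i$-th column of $W$:

$$
\begin{align}
p(h|x) &= \prod_{j=1}^H p(h_j|x), & p(h_j=1|x) &= \sigma(b_j + W_{j.}x) \\
p(x|h) &= \prod_{i=1}^D p(x_i|h), & p(x_i=1|h) &= \sigma(c_i + h^TW_{.i})
\end{align}
$$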

Note: $Z = \sum_{x,h} exp(-E(x,h)) = \sum_x exp(-F(x)) $
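
The identity in the note can be checked by brute force on a toy model. The sketch below is not part of the notebook's training code: it uses plain Python, and the sizes ($D=3$, $H=2$) and parameter values are arbitrary assumptions chosen only so that every configuration can be enumerated.

```python
import itertools
import math

def softplus(z):
    return math.log1p(math.exp(z))

def energy(x, h, W, b, c):
    # E(x, h) = -h^T W x - c^T x - b^T h
    hWx = sum(h[j] * W[j][i] * x[i] for j in range(len(h)) for i in range(len(x)))
    cx = sum(ci * xi for ci, xi in zip(c, x))
    bh = sum(bj * hj for bj, hj in zip(b, h))
    return -hWx - cx - bh

def free_energy(x, W, b, c):
    # F(x) = -c^T x - sum_j softplus(b_j + W_j. x)
    cx = sum(ci * xi for ci, xi in zip(c, x))
    hidden = sum(
        softplus(b[j] + sum(W[j][i] * x[i] for i in range(len(x))))
        for j in range(len(b))
    )
    return -cx - hidden

# Toy parameters (illustrative values only): W is H x D, b has size H, c has size D
W = [[0.5, -1.0, 0.25], [1.5, 0.0, -0.75]]
b = [0.1, -0.2]
c = [0.3, 0.0, -0.4]
D, H = 3, 2

xs = list(itertools.product([0, 1], repeat=D))  # all 2^D visible configurations
hs = list(itertools.product([0, 1], repeat=H))  # all 2^H hidden configurations

# Partition function computed two ways
Z_joint = sum(math.exp(-energy(x, h, W, b, c)) for x in xs for h in hs)
Z_free = sum(math.exp(-free_energy(x, W, b, c)) for x in xs)
print(Z_joint, Z_free)  # the two values should agree up to rounding
```

For MNIST-sized models ($D=784$) this enumeration is infeasible, which is why training relies on sampling instead.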

Objective: minimize the negative log-likelihood $-ln(P(x))$ of the training samples. With $\theta$ denoting any of the parameters $W$, $b$, $c$:

$$
\begin{align}
-lnP(x) &= F(x) + ln(Z) \\ \\
% -\frac{\partial lnP(x)}{\partial \theta} &= \frac{\partial F(x)}{\partial \theta} + \frac{\partial lnZ}{\partial \theta} \\
-\frac{\partial lnP(x)}{\partial \theta} &= \frac{\partial F(x)}{\partial \theta} - \sum_{\tilde{x}} p(\tilde{x}) \frac{\partial F(\tilde{x})}{\partial \theta} 
\end{align}
$$

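where the sum runs over all visible configurations $\tilde{x}$. This sum is intractable, so we approximate it with negative samples $\tilde{x}$ drawn from the model using MCMC (Gibbs sampling).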
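
The second term of the gradient comes from differentiating $ln(Z)$. A short derivation sketch, using $Z = \sum_{\tilde{x}} exp(-F(\tilde{x}))$ from the note above:

$$
\begin{align}
\frac{\partial ln(Z)}{\partial \theta} &= \frac{1}{Z} \sum_{\tilde{x}} \frac{\partial}{\partial \theta} exp(-F(\tilde{x})) \\
&= -\sum_{\tilde{x}} \frac{exp(-F(\tilde{x}))}{Z} \frac{\partial F(\tilde{x})}{\partial \theta} \\
&= -\sum_{\tilde{x}} p(\tilde{x}) \frac{\partial F(\tilde{x})}{\partial \theta}
\end{align}
$$

This is why the training code uses the surrogate loss `model.free_energy(data) - model.free_energy(x_tilde)` with `x_tilde` detached: its gradient is a sample-based estimate of the expression above.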

**Contrastive Divergence (CD-k)**

Contrastive Divergence uses two tricks to speed up the sampling process:

* We initialize the Markov chain with a training example.
* CD does not wait for the chain to converge. Samples are obtained after only k-steps of Gibbs sampling.
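
The two tricks above can be sketched in plain Python as follows. This is an illustration under assumptions, not the notebook's implementation: the helper names (`sample_h_given_x`, `sample_x_given_h`, `cd_k_negative_sample`) and the toy parameters are hypothetical, and a single example is processed instead of a batch.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def sample_h_given_x(x, W, b, rng):
    # p(h_j = 1 | x) = sigmoid(b_j + W_j. x)
    return [
        int(rng.random() < sigmoid(b[j] + sum(W[j][i] * x[i] for i in range(len(x)))))
        for j in range(len(b))
    ]

def sample_x_given_h(h, W, c, rng):
    # p(x_i = 1 | h) = sigmoid(c_i + sum_j h_j W_ji)
    return [
        int(rng.random() < sigmoid(c[i] + sum(h[j] * W[j][i] for j in range(len(h)))))
        for i in range(len(c))
    ]

def cd_k_negative_sample(x_train, W, b, c, k, rng):
    """Run k steps of block Gibbs sampling starting from a training example."""
    x_neg = list(x_train)  # trick 1: the chain starts at the data
    for _ in range(k):     # trick 2: stop after k steps, without waiting for convergence
        h = sample_h_given_x(x_neg, W, b, rng)
        x_neg = sample_x_given_h(h, W, c, rng)
    return x_neg

# Toy parameters (illustrative values only): H = 2 hidden units, D = 3 visible units
rng = random.Random(0)
W = [[0.5, -1.0, 0.25], [1.5, 0.0, -0.75]]
b = [0.1, -0.2]
c = [0.3, 0.0, -0.4]

x_tilde = cd_k_negative_sample([1, 0, 1], W, b, c, k=5, rng=rng)
print(x_tilde)  # a binary vector of length 3
```

The `forward` method of the `RBM` class below follows the same loop, but on whole batches with PyTorch tensors.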

### Train Bernoulli-Bernoulli RBM

In [1]:
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
from torchvision.utils import make_grid

import matplotlib.pyplot as plt
import numpy as np

## Setup dataloader

In [2]:
batch_size = 128
train_set, test_set, train_loader, test_loader = {},{},{},{}
transform = transforms.Compose(
    [transforms.ToTensor()])

train_set['mnist'] = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_set['mnist'] = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader['mnist'] = torch.utils.data.DataLoader(train_set['mnist'], batch_size=batch_size, shuffle=True, num_workers=0)
test_loader['mnist'] = torch.utils.data.DataLoader(test_set['mnist'], batch_size=batch_size, shuffle=False, num_workers=0)

# Fall back to the CPU when no GPU is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'


In [3]:
class RBM(nn.Module):
    """Restricted Boltzmann Machine for generating MNIST images."""
    
    def __init__(self, D: int, F: int, k: int):
        """Creates an RBM module instance.
            
            Args:
                D: Size of the input data (number of visible units).
                F: Size of the hidden variable (number of hidden units, H above).
                k: Number of MCMC iterations for negative sampling.
                
            The function initializes the weight (W) and biases (c & b).
        """
        
        super(RBM, self).__init__()
        # Note: inside __init__ the argument F shadows torch.nn.functional (imported as F).
        self.W = nn.Parameter(torch.randn(F, D)* 1e-2) # Initialized from Normal(mean=0.0, variance=1e-4)
        self.c = nn.Parameter(torch.zeros(D)) # Visible bias, initialized as 0.0
        self.b = nn.Parameter(torch.zeros(F)) # Hidden bias, initialized as 0.0
        self.k = k
    
    def sample(self, p):
        """Sample from a bernoulli distribution defined by a given parameter.
        
           Args:
                p: Parameter of the bernoulli distribution.
           
           Returns:
               bern_sample: Sample from Bernoulli(p)
        """
        
        bern_sample = p.bernoulli()
        return bern_sample
    
    def P_h_x(self, x):
        """Returns the conditional P(h=1|x). (Slide 9, Lecture 14)
        
        Args:
            x: Batch of visible vectors, shape n_batch x D.
        
        Returns:
            ph_x: Probability of each hidden unit being 1, shape n_batch x F.
        """
        ph_x = torch.sigmoid(F.linear(x, self.W, self.b)) # n_batch x F
        return ph_x
    
    def P_x_h(self, h):
        """Returns the conditional P(x=1|h). (Slide 9, Lecture 14)
        
        Args:
            h: Batch of hidden vectors, shape n_batch x F.
        
        Returns:
            px_h: Probability of each visible unit being 1, shape n_batch x D.
        """
        # Bernoulli visibles: the output must be a probability in [0, 1] so that
        # sample() can draw from it. A Gaussian draw such as torch.normal(mean, 1)
        # falls outside that range and makes p.bernoulli() fail.
        px_h = torch.sigmoid(F.linear(h, self.W.t(), self.c)) # n_batch x D
        return px_h

    def free_energy(self, x):
        """Returns the free energy F(x), averaged over the batch. (Slide 11, Lecture 14)

        F(x) = -c^T x - sum_j softplus(b_j + W_j. x)
        """
        vbias_term = x.mv(self.c) # n_batch
        wx_b = F.linear(x, self.W, self.b) # n_batch x F
        hidden_term = F.softplus(wx_b).sum(dim=1) # n_batch
        return (-hidden_term - vbias_term).mean() # scalar
    
    def forward(self, x):
        """Generates x_negative using MCMC Gibbs sampling starting from x."""
        x_negative = x
        for _ in range(self.k):
            ## Step 1: Sample h given the current x.
            # Get the conditional prob h|x
            phx_k = self.P_h_x(x_negative)
            # Sample from h|x
            h_negative = self.sample(phx_k)
            
            ## Step 2: Sample x using h from step 1.
            # Get the conditional prob x|h
            pxh_k = self.P_x_h(h_negative)
            # Sample from x|h
            x_negative = self.sample(pxh_k)
        return x_negative, pxh_k

In [4]:
# Scratch check: F.linear(x, W) computes x @ W.T
DD = 2
FF = 3
WW = torch.randint(10,(FF, DD)) # Random integers in [0, 10), shape FF x DD
cc = torch.randint(10,(DD,)) # Random integer vector of size DD
bb = torch.randint(10,(FF,)) # Random integer vector of size FF

In [5]:
F.linear(cc,WW)

tensor([84, 95, 54])

In [6]:
WW.matmul(cc.t())

tensor([84, 95, 54])

In [7]:
bb

tensor([2, 8, 5])

In [8]:
a=torch.arange(0,10).view(5,2)
b=torch.arange(0,20).view(2,10)

In [9]:
a.matmul(b)

tensor([[ 10,  11,  12,  13,  14,  15,  16,  17,  18,  19],
        [ 30,  35,  40,  45,  50,  55,  60,  65,  70,  75],
        [ 50,  59,  68,  77,  86,  95, 104, 113, 122, 131],
        [ 70,  83,  96, 109, 122, 135, 148, 161, 174, 187],
        [ 90, 107, 124, 141, 158, 175, 192, 209, 226, 243]])

In [10]:
F.linear(a,b.t())

tensor([[ 10,  11,  12,  13,  14,  15,  16,  17,  18,  19],
        [ 30,  35,  40,  45,  50,  55,  60,  65,  70,  75],
        [ 50,  59,  68,  77,  86,  95, 104, 113, 122, 131],
        [ 70,  83,  96, 109, 122, 135, 148, 161, 174, 187],
        [ 90, 107, 124, 141, 158, 175, 192, 209, 226, 243]])

## Define train and test functions

In [11]:
def train(model, device, train_loader, optimizer, epoch):
    
    train_loss = 0
    model.train()
    
    for batch_idx, (data, target) in enumerate(train_loader):
        
        # ToTensor() scales the pixel intensities to [0,1]
        data = data.view(data.size(0),-1) # flatten the array: Converts n_batchx1x28x28 to n_batchx784
        data = data.bernoulli() # binarize: each intensity is used as P(pixel = 1)
        data = data.to(device)
        
        optimizer.zero_grad()
        
        # Negative samples from k steps of Gibbs sampling; detached so that no
        # gradient flows through the sampling chain
        x_tilde, _ = model(data)
        x_tilde = x_tilde.detach()

        # CD loss: its gradient is dF(data)/dtheta - dF(x_tilde)/dtheta
        loss = model.free_energy(data) - model.free_energy(x_tilde)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        
        if (batch_idx+1) % (len(train_loader)//2) == 0:
            print('Train({})[{:.0f}%]: Loss: {:.4f}'.format(
                epoch, 100. * batch_idx / len(train_loader), train_loss/(batch_idx+1)))

def test(model, device, test_loader, epoch):
    
    model.eval()
    test_loss = 0
    
    with torch.no_grad():
        for data, target in test_loader:
            data = data.view(data.size(0),-1)
            data = data.bernoulli()
            data = data.to(device)
            xh_k,_ = model(data)
            loss = model.free_energy(data) - model.free_energy(xh_k)
            test_loss += loss.item() * data.size(0) # sum up batch loss, weighted by batch size
    
    test_loss = test_loss/len(test_loader.dataset) # per-sample average, exact even if the last batch is smaller
    print('Test({}): Loss: {:.4f}'.format(epoch, test_loss))

## Define make_optimizer and make_scheduler

In [12]:
def make_optimizer(optimizer_name, model, **kwargs):
    if optimizer_name=='Adam':
        optimizer = optim.Adam(model.parameters(),lr=kwargs['lr'])
    elif optimizer_name=='SGD':
        optimizer = optim.SGD(model.parameters(),lr=kwargs['lr'],momentum=kwargs.get('momentum', 0.), 
                              weight_decay=kwargs.get('weight_decay', 0.))
    else:
        raise ValueError('Not valid optimizer name')
    return optimizer
    
def make_scheduler(scheduler_name, optimizer, **kwargs):
    if scheduler_name=='MultiStepLR':
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,milestones=kwargs['milestones'],gamma=kwargs['factor'])
    else:
        raise ValueError('Not valid scheduler name')
    return scheduler

In [13]:
# General variables

seed = 1
data_name = 'mnist'
optimizer_name = 'Adam'
scheduler_name = 'MultiStepLR'
num_epochs = 10
lr = 0.001


In [14]:
device = torch.device(device)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

rbm = RBM(D=28*28, F=400, k=5).to(device)
optimizer = make_optimizer(optimizer_name, rbm, lr=lr)
scheduler = make_scheduler(scheduler_name, optimizer, milestones=[5], factor=0.1)

for epoch in range(1, num_epochs + 1):
    
    train(rbm, device, train_loader[data_name], optimizer, epoch)
    test(rbm, device, test_loader[data_name], epoch)
    scheduler.step()
    
    print('Optimizer Learning rate: {0:.4f}\n'.format(optimizer.param_groups[0]['lr']))


## Plot

In [None]:
def show(img1, img2):
    npimg1 = img1.cpu().numpy()
    npimg2 = img2.cpu().numpy()
    
    fig, axes = plt.subplots(1,2, figsize=(20,10))
    axes[0].imshow(np.transpose(npimg1, (1,2,0)), interpolation='nearest')
    axes[1].imshow(np.transpose(npimg2, (1,2,0)), interpolation='nearest')
    plt.show()

## Plot original images and reconstructed images using the test dataset

In [None]:
data,_ = next(iter(test_loader[data_name]))
data = data[:32]
data_size = data.size()
data = data.view(data.size(0),-1)
bdata = data.bernoulli().to(device)
vh_k, pvh_k = rbm(bdata)
vh_k, pvh_k = vh_k.detach(), pvh_k.detach()

In [None]:
show(make_grid(data.reshape(data_size), padding=0), make_grid(pvh_k.reshape(data_size), padding=0))