# Midterm 2, Assignment 3 - Gaetano Barresi [579102]

A Restricted Boltzmann Machine (RBM) is a generative stochastic neural network that learns a probability distribution over its inputs. It is widely used for unsupervised learning, dimensionality reduction, and feature extraction.
To implement from scratch a RBM we must define first its architecture. It is a simple two layer neural network, one input layer (visible states, our data) and one hidden layer (the hidden states, latent feature representation). For parameters, we have a set of weights and two set of bias, one for the visible units and one for the hidden units:


```python
self.W = torch.randn(hidden_dim, visible_dim, device=self.device) * 0.01
self.v_bias = torch.zeros(visible_dim, device=self.device)
self.h_bias = torch.zeros(hidden_dim, device=self.device)
```


Weights W are initialized with small values and biases with zeros.
Hidden units are conditionally independent given visible units and viceversa, due to not oriented edges and bipartition structure:


$$
\
\mathbb{P}(h_j \mid v) = \sigma \left( \sum_i M_{ij} v_i + c_j \right) \quad \forall j
\
$$


$$
\
\mathbb{P}(v_i \mid h) = \sigma \left( \sum_j M_{ij} h_j + b_i \right) \quad \forall i
\
$$


These are resepctively the forward pass (wake) and the backward pass (dream) and they can be implemented as:


```python
def visible_to_hidden(self, v):  # forward pass
    # compute probabilities of hidden units given visible units
    p_h = torch.sigmoid(F.linear(v, self.W, self.h_bias))
    return p_h
    
def hidden_to_visible(self, h):  # backward pass
    # compute probabilities of visible units given hidden units
    p_v = torch.sigmoid(F.linear(h, self.W.t(), self.v_bias))
    return p_v
```


In a RBM, sampling is a key step in the training process, particularly during Gibbs sampling. In order to generate binary samples (0 or 1) from a given probability distribution p, we use the following function. These binary samples represent the activation states of the visible or hidden units in the RBM.


```python
def sample_from_p(self, p):
    return F.relu(torch.sign(p - torch.rand_like(p, device=self.device)))
```


`sample_from_p` takes p, which is a tensor of probabilities. Each value in p represents the probability of a unit being activated (set to 1).
`torch.rand_like(p)` generates a tensor of random values, uniformly distributed between 0 and 1, with the same shape as p. `p - torch.rand_like(p)` computes the difference between the probabilities and the random values.
`torch.sign(p - torch.rand_like(p))` produces a tensor where values greater than 0 are set to 1, and values less than 0 are set to -1. This effectively performs a thresholding operation to decide whether each unit is activated (1) or not (-1).
`F.relu(...)` ensures that all negative values are clamped to 0. This step converts the -1 values to 0, resulting in a binary tensor of 0s and 1s.
All these pieces are used inside the generalized version of Contrastive Divergence (CD) learning algorithm. It is divided in positive phase (wake part), negative phase (dream part) and parameters update.

Positive phase computes the hidden probabilities (`p_h_given_v`) and sample hidden states (`h_sample`). It computes also the positive gradient as the outer product of `h_sample` and `v`.


```python
# positive phase
p_h_given_v = self.visible_to_hidden(v)
h_sample = self.sample_from_p(p_h_given_v)
positive_grad = torch.mm(h_sample.t(), v)
```


Negative phase performs k steps of Gibbs sampling to reconstruct the visible and hidden states and computes the negative gradient as the outer product of the reconstructed hidden probabilities and visible states.


```python
# gibbs sampling (negative phase)
v_sample = v
for _ in range(self.k):
    p_h_given_v = self.visible_to_hidden(v_sample)
    h_sample = self.sample_from_p(p_h_given_v)
    p_v_given_h = self.hidden_to_visible(h_sample)
    v_sample = self.sample_from_p(p_v_given_h)

# negative phase
p_h_given_v_sample = self.visible_to_hidden(v_sample)
negative_grad = torch.mm(p_h_given_v_sample.t(), v_sample)
```


Weights (`W`) and biases (`v_bias`, `h_bias`) are updated using the difference between the positive and negative gradients, normalized by the batch size.


```python
self.W += self.learning_rate * (positive_grad - negative_grad) / v.size(0)
self.v_bias += self.learning_rate * torch.sum(v - v_sample, dim=0) / v.size(0)
self.h_bias += self.learning_rate * torch.sum(p_h_given_v - p_h_given_v_sample, dim=0) / v.size(0)
```


We now pack all this stuff in a custom class `RBM`, provided with its own training method and train two different RBMs, with different values of k for CD algorithm: k=4 and k=8.

In [6]:
import torch
import torch.nn.functional as F
from tqdm import tqdm

class RBM:
    def __init__(self, visible_dim, hidden_dim, k=1, lr=0.01, device='cpu'):
        self.visible_dim = visible_dim
        self.hidden_dim = hidden_dim
        self.k = k  # number of Gibbs sampling steps
        self.learning_rate = lr
        self.device = device  # Device to run the computations on (e.g., 'cuda' or 'cpu')
        
        # weights and biases initialization
        self.W = torch.randn(hidden_dim, visible_dim, device=self.device) * 0.01
        self.v_bias = torch.zeros(visible_dim, device=self.device)
        self.h_bias = torch.zeros(hidden_dim, device=self.device)
    
    def sample_from_p(self, p):
        # bernoulli sampling given probabilities
        return F.relu(torch.sign(p - torch.rand_like(p, device=self.device)))
    
    def visible_to_hidden(self, v):  # forward pass
        # compute probabilities of hidden units given visible units
        p_h = torch.sigmoid(F.linear(v, self.W, self.h_bias))
        return p_h
    
    def hidden_to_visible(self, h):  # backward pass
        # compute probabilities of visible units given hidden units
        p_v = torch.sigmoid(F.linear(h, self.W.t(), self.v_bias))
        return p_v

    def contrastive_divergence(self, v):
        # Move input to the correct device
        v = v.to(self.device)

        # positive phase
        p_h_given_v = self.visible_to_hidden(v)
        h_sample = self.sample_from_p(p_h_given_v)
        positive_grad = torch.mm(h_sample.t(), v)

        # gibbs sampling (negative phase)
        v_sample = v
        for _ in range(self.k):
            p_h_given_v = self.visible_to_hidden(v_sample)
            h_sample = self.sample_from_p(p_h_given_v)
            p_v_given_h = self.hidden_to_visible(h_sample)
            v_sample = self.sample_from_p(p_v_given_h)

        # negative phase
        p_h_given_v_sample = self.visible_to_hidden(v_sample)
        negative_grad = torch.mm(p_h_given_v_sample.t(), v_sample)

        # update weights and biases
        self.W += self.learning_rate * (positive_grad - negative_grad) / v.size(0)
        self.v_bias += self.learning_rate * torch.sum(v - v_sample, dim=0) / v.size(0)
        self.h_bias += self.learning_rate * torch.sum(p_h_given_v - p_h_given_v_sample, dim=0) / v.size(0)

    def train(self, data_loader, epochs=10):
        for epoch in range(epochs):
            epoch_error = 0
            for batch in tqdm(data_loader, desc="Training Batches", leave=False):
                # Extract data from batch (ignore labels)
                batch, _ = batch  # Unpack the tuple (data, labels)
                batch = batch.view(-1, self.visible_dim).to(self.device)  # Flatten input and move to device
                self.contrastive_divergence(batch)
            
            print(f"Epoch {epoch + 1}/{epochs}")

In [7]:
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader


visible_dim = 784  # For MNIST
hidden_dim = 256  # Number of hidden neurons in RBM
num_classes = 10  # Digits 0-9

# Use GPU if available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Initialize RBMs
rbm4 = RBM(visible_dim, hidden_dim, k=4, lr=0.01, device=device)
rbm8 = RBM(visible_dim, hidden_dim, k=8, lr=0.01, device=device)

# Load the MNIST training and test data
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])
mnist_train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(mnist_train_data, batch_size=10, shuffle=True)

print("Training RBM4...")
rbm4.train(train_loader, epochs=50)
print("Training RBM8...")
rbm8.train(train_loader, epochs=50)

Using device: cuda
Training RBM4...


                                                                      

Epoch 1/50


                                                                      

Epoch 2/50


                                                                      

Epoch 3/50


                                                                      

Epoch 4/50


                                                                      

Epoch 5/50


                                                                      

Epoch 6/50


                                                                      

Epoch 7/50


                                                                      

Epoch 8/50


                                                                      

Epoch 9/50


                                                                      

Epoch 10/50


                                                                      

Epoch 11/50


                                                                      

Epoch 12/50


                                                                      

Epoch 13/50


                                                                      

Epoch 14/50


                                                                      

Epoch 15/50


                                                                      

Epoch 16/50


                                                                      

Epoch 17/50


                                                                      

Epoch 18/50


                                                                      

Epoch 19/50


                                                                      

Epoch 20/50


                                                                      

Epoch 21/50


                                                                      

Epoch 22/50


                                                                      

Epoch 23/50


                                                                      

Epoch 24/50


                                                                      

Epoch 25/50


                                                                      

Epoch 26/50


                                                                      

Epoch 27/50


                                                                      

Epoch 28/50


                                                                      

Epoch 29/50


                                                                      

Epoch 30/50


                                                                      

Epoch 31/50


                                                                      

Epoch 32/50


                                                                      

Epoch 33/50


                                                                      

Epoch 34/50


                                                                      

Epoch 35/50


                                                                      

Epoch 36/50


                                                                      

Epoch 37/50


                                                                      

Epoch 38/50


                                                                      

Epoch 39/50


                                                                      

Epoch 40/50


                                                                      

Epoch 41/50


                                                                      

Epoch 42/50


                                                                      

Epoch 43/50


                                                                      

Epoch 44/50


                                                                      

Epoch 45/50


                                                                      

Epoch 46/50


                                                                      

Epoch 47/50


                                                                      

Epoch 48/50


                                                                      

Epoch 49/50


                                                                      

Epoch 50/50
Training RBM8...


                                                                      

Epoch 1/50


                                                                      

Epoch 2/50


                                                                      

Epoch 3/50


                                                                      

Epoch 4/50


                                                                      

Epoch 5/50


                                                                      

Epoch 6/50


                                                                      

Epoch 7/50


                                                                      

Epoch 8/50


                                                                      

Epoch 9/50


                                                                      

Epoch 10/50


                                                                      

Epoch 11/50


                                                                      

Epoch 12/50


                                                                      

Epoch 13/50


                                                                      

Epoch 14/50


                                                                      

Epoch 15/50


                                                                      

Epoch 16/50


                                                                      

Epoch 17/50


                                                                      

Epoch 18/50


                                                                      

Epoch 19/50


                                                                      

Epoch 20/50


                                                                      

Epoch 21/50


                                                                      

Epoch 22/50


                                                                      

Epoch 23/50


                                                                      

Epoch 24/50


                                                                      

Epoch 25/50


                                                                      

Epoch 26/50


                                                                      

Epoch 27/50


                                                                      

Epoch 28/50


                                                                      

Epoch 29/50


                                                                      

Epoch 30/50


                                                                      

Epoch 31/50


                                                                      

Epoch 32/50


                                                                      

Epoch 33/50


                                                                      

Epoch 34/50


                                                                      

Epoch 35/50


                                                                      

Epoch 36/50


                                                                      

Epoch 37/50


                                                                      

Epoch 38/50


                                                                      

Epoch 39/50


                                                                      

Epoch 40/50


                                                                      

Epoch 41/50


                                                                      

Epoch 42/50


                                                                      

Epoch 43/50


                                                                      

Epoch 44/50


                                                                      

Epoch 45/50


                                                                      

Epoch 46/50


                                                                      

Epoch 47/50


                                                                      

Epoch 48/50


                                                                      

Epoch 49/50


                                                                      

Epoch 50/50




Once we have the trained RBMs we need to create an encoding of MNIST dataset using their hidden neurons. We encode the MNIST twice, one time for each RBM and we will use these encodings to train a couple of simple classifiers. Their performances will reflect the RBMs' encodings quality.

The following code block will show the implementation of the encoding function and its call. The code block after will show the implementation of the classifier and its training loop, with its training phase and evaluation phase.

In [8]:
def encode_dataset(rbm, data_loader, device):
    # Lists to store encodings and labels
    encodings = []
    labels = []

    # Iterate through the dataset
    for batch, batch_labels in data_loader:
        batch = batch.to(device)  # Move to the correct device
        # Compute hidden neuron activations for the entire batch
        hidden_activations = rbm.visible_to_hidden(batch)
        encodings.append(hidden_activations.cpu())  # Store encodings
        labels.append(batch_labels)  # Store labels

    # Concatenate all batches into a single tensor
    encodings = torch.cat(encodings, dim=0)
    labels = torch.cat(labels, dim=0)

    return encodings, labels

############################################################

mnist_test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_test_data, batch_size=10, shuffle=False)

rbm4_train_encodings, rbm4_train_labels = encode_dataset(rbm4, train_loader, device)
rbm4_test_encodings, rbm4_test_labels = encode_dataset(rbm4, test_loader, device)

rbm8_train_encodings, rbm8_train_labels = encode_dataset(rbm8, train_loader, device)
rbm8_test_encodings, rbm8_test_labels = encode_dataset(rbm8, test_loader, device)

In [9]:
import torch.nn as nn
import torch.optim as optim


class MNISTClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(MNISTClassifier, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
    

def train_classifier_on_mnist(train_loader, input_dim, num_classes, device, num_epochs=40, lr=0.001):
    # Define the classifier
    classifier = MNISTClassifier(input_dim, num_classes).to(device)

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(classifier.parameters(), lr=lr)

    # Training loop
    print("Training classifier...")
    for epoch in range(num_epochs):
        classifier.train()
        epoch_loss = 0
        for batch, labels in train_loader:
            batch, labels = batch.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = classifier(batch)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")

    return classifier


def evaluate_classifier(classifier, test_loader, device):
    classifier.eval()  # Set the classifier to evaluation mode
    correct = 0
    total = 0

    with torch.no_grad():  # Disable gradient computation for evaluation
        for batch, labels in test_loader:
            batch, labels = batch.to(device), labels.to(device)
            outputs = classifier(batch)
            _, predicted = torch.max(outputs, 1)  # Get the class with the highest score
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    return accuracy

#############################################################################

# Train classifiers on RBM-encoded datasets
rbm4_train_dataset = TensorDataset(rbm4_train_encodings, rbm4_train_labels)
rbm4_train_loader = DataLoader(rbm4_train_dataset, batch_size=100, shuffle=True)
classifier4 = train_classifier_on_mnist(rbm4_train_loader, hidden_dim, num_classes, device)

rbm8_train_dataset = TensorDataset(rbm8_train_encodings, rbm8_train_labels)
rbm8_train_loader = DataLoader(rbm8_train_dataset, batch_size=100, shuffle=True)
classifier8 = train_classifier_on_mnist(rbm8_train_loader, hidden_dim, num_classes, device)

# Evaluate the classifiers on the MNIST test data encodings
rbm4_test_dataset = TensorDataset(rbm4_test_encodings, rbm8_test_labels)
rbm4_test_loader = DataLoader(rbm4_test_dataset, batch_size=100, shuffle=False)
accuracy4 = evaluate_classifier(classifier4, rbm4_test_loader, device)
print(f"Classifier4 accuracy: {accuracy4:.4f}")

rbm8_test_dataset = TensorDataset(rbm8_test_encodings, rbm8_test_labels)
rbm8_test_loader = DataLoader(rbm8_test_dataset, batch_size=100, shuffle=False)
accuracy8 = evaluate_classifier(classifier8, rbm8_test_loader, device)
print(f"Classifier8 accuracy: {accuracy8:.4f}")

Training classifier...
Epoch [1/40], Loss: 208.2376
Epoch [2/40], Loss: 85.2166
Epoch [3/40], Loss: 66.9309
Epoch [4/40], Loss: 57.7720
Epoch [5/40], Loss: 49.6877
Epoch [6/40], Loss: 44.3337
Epoch [7/40], Loss: 38.9533
Epoch [8/40], Loss: 35.1131
Epoch [9/40], Loss: 31.8776
Epoch [10/40], Loss: 28.6864
Epoch [11/40], Loss: 25.8670
Epoch [12/40], Loss: 23.3815
Epoch [13/40], Loss: 21.1571
Epoch [14/40], Loss: 19.3428
Epoch [15/40], Loss: 17.7163
Epoch [16/40], Loss: 15.5124
Epoch [17/40], Loss: 14.0166
Epoch [18/40], Loss: 12.6812
Epoch [19/40], Loss: 11.3757
Epoch [20/40], Loss: 10.7367
Epoch [21/40], Loss: 9.0202
Epoch [22/40], Loss: 8.3779
Epoch [23/40], Loss: 7.6348
Epoch [24/40], Loss: 6.7234
Epoch [25/40], Loss: 6.0694
Epoch [26/40], Loss: 4.8651
Epoch [27/40], Loss: 5.1208
Epoch [28/40], Loss: 4.2707
Epoch [29/40], Loss: 4.2519
Epoch [30/40], Loss: 3.9209
Epoch [31/40], Loss: 3.2159
Epoch [32/40], Loss: 2.7018
Epoch [33/40], Loss: 2.4247
Epoch [34/40], Loss: 2.3134
Epoch [35/40]

After training and testing the classifiers, we can see that the accuracy obtained during the testing phase is about the same: 97.90% for RBM with CD-4 and 97.92% for RBM with CD-8. As explained during the lessons, in practice CD with values ​​of k different from 1 does not apport particular advantages and is rarely used, in fact, by training an RBM and using CD-1, we can notice how the final results obtained do not differ too much from the previous ones (accuracy 97.95%).

In [None]:
# Initialize RBM1
rbm1 = RBM(visible_dim, hidden_dim, k=1, lr=0.01, device=device)

print("RBM1 test...")
rbm1.train(train_loader, epochs=50)

rbm1_train_encodings, rbm1_train_labels = encode_dataset(rbm1, train_loader, device)
rbm1_test_encodings, rbm1_test_labels = encode_dataset(rbm1, test_loader, device)

rbm1_train_dataset = TensorDataset(rbm1_train_encodings, rbm1_train_labels)
rbm1_train_loader = DataLoader(rbm1_train_dataset, batch_size=100, shuffle=True)
classifier1 = train_classifier_on_mnist(rbm1_train_loader, hidden_dim, num_classes, device)

# Evaluate the classifiers on the MNIST test data encodings
rbm1_test_dataset = TensorDataset(rbm1_test_encodings, rbm1_test_labels)
rbm1_test_loader = DataLoader(rbm1_test_dataset, batch_size=100, shuffle=False)
accuracy1 = evaluate_classifier(classifier1, rbm1_test_loader, device)
print(f"Classifier1 accuracy: {accuracy1:.4f}")

RBM1 test...


                                                                       

Epoch 1/50


                                                                       

Epoch 2/50


                                                                      

Epoch 3/50


                                                                      

Epoch 4/50


                                                                      

Epoch 5/50


                                                                      

Epoch 6/50


                                                                      

Epoch 7/50


                                                                      

Epoch 8/50


                                                                      

Epoch 9/50


                                                                      

Epoch 10/50


                                                                      

Epoch 11/50


                                                                      

Epoch 12/50


                                                                      

Epoch 13/50


                                                                      

Epoch 14/50


                                                                      

Epoch 15/50


                                                                      

Epoch 16/50


                                                                      

Epoch 17/50


                                                                       

Epoch 18/50


                                                                       

Epoch 19/50


                                                                      

Epoch 20/50


                                                                      

Epoch 21/50


                                                                       

Epoch 22/50


                                                                      

Epoch 23/50


                                                                       

Epoch 24/50


                                                                      

Epoch 25/50


                                                                      

Epoch 26/50


                                                                       

Epoch 27/50


                                                                       

Epoch 28/50


                                                                       

Epoch 29/50


                                                                       

Epoch 30/50


                                                                       

Epoch 31/50


                                                                       

Epoch 32/50


                                                                       

Epoch 33/50


                                                                       

Epoch 34/50


                                                                      

Epoch 35/50


                                                                       

Epoch 36/50


                                                                       

Epoch 37/50


                                                                      

Epoch 38/50


                                                                       

Epoch 39/50


                                                                       

Epoch 40/50


                                                                      

Epoch 41/50


                                                                       

Epoch 42/50


                                                                       

Epoch 43/50


                                                                      

Epoch 44/50


                                                                       

Epoch 45/50


                                                                       

Epoch 46/50


                                                                       

Epoch 47/50


                                                                       

Epoch 48/50


                                                                       

Epoch 49/50


                                                                       

Epoch 50/50
Training classifier...
Epoch [1/40], Loss: 217.3591
Epoch [2/40], Loss: 97.1386
Epoch [3/40], Loss: 75.9968
Epoch [4/40], Loss: 63.2306
Epoch [5/40], Loss: 54.3591
Epoch [6/40], Loss: 48.2648
Epoch [7/40], Loss: 42.8074
Epoch [8/40], Loss: 38.5775
Epoch [9/40], Loss: 34.0832
Epoch [10/40], Loss: 31.0952
Epoch [11/40], Loss: 28.3885
Epoch [12/40], Loss: 25.4651
Epoch [13/40], Loss: 23.3352
Epoch [14/40], Loss: 20.6842
Epoch [15/40], Loss: 19.4512
Epoch [16/40], Loss: 17.4249
Epoch [17/40], Loss: 16.0055
Epoch [18/40], Loss: 14.2375
Epoch [19/40], Loss: 13.3275
Epoch [20/40], Loss: 12.3588
Epoch [21/40], Loss: 10.8668
Epoch [22/40], Loss: 10.6986
Epoch [23/40], Loss: 9.2658
Epoch [24/40], Loss: 8.5722
Epoch [25/40], Loss: 7.7471
Epoch [26/40], Loss: 6.8541
Epoch [27/40], Loss: 6.7314
Epoch [28/40], Loss: 5.9326
Epoch [29/40], Loss: 5.9945
Epoch [30/40], Loss: 5.0099
Epoch [31/40], Loss: 4.4233
Epoch [32/40], Loss: 4.6791
Epoch [33/40], Loss: 3.9259
Epoch [34/40], Loss: 3.8386