# Data Loading

# Model

In [1]:
import numpy as np
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import make_grid , save_image

## Definition

In [None]:
class RBM(nn.Module):
    """Bernoulli Restricted Boltzmann Machine trained with Contrastive Divergence (CD-1).

    Attributes:
        p: Number of visible units.
        q: Number of hidden units.
        W: Weight matrix, shape ``(q, p)``.
        a: Visible-unit bias, shape ``(p,)``.
        b: Hidden-unit bias, shape ``(q,)``.
    """

    def __init__(self, p: int, q: int):
        super().__init__()
        self.p = p
        self.q = q

        # Notes:
        # nn.Parameter registers the tensor in the module's parameter list, so it
        # shows up in .parameters()/.state_dict() and follows the module when
        # .to(device) is applied.

        # Weights - small random initialisation
        self.W = nn.Parameter(torch.randn(q, p) * 1e-2)

        # Biases - initialised at 0
        self.a = nn.Parameter(torch.zeros(p))
        self.b = nn.Parameter(torch.zeros(q))

    def entree_sortie(self, v: torch.Tensor) -> torch.Tensor:
        """Return the hidden activation probabilities for visible batch ``v``.

        Args:
            v: Visible units, shape ``(batch, p)``.

        Returns:
            Tensor of shape ``(batch, q)`` equal to ``sigmoid(v @ W.T + b)``.
        """
        # F.linear(x, weight, bias) computes x @ weight.T + bias
        return torch.sigmoid(F.linear(v, self.W, self.b))

    def sortie_entree(self, h: torch.Tensor) -> torch.Tensor:
        """Return the visible activation probabilities for hidden batch ``h``.

        Args:
            h: Hidden units, shape ``(batch, q)``.

        Returns:
            Tensor of shape ``(batch, p)`` equal to ``sigmoid(h @ W + a)``.
        """
        # Passing W.t() as the weight makes F.linear compute h @ W + a
        return torch.sigmoid(F.linear(h, self.W.t(), self.a))

    def forward(self, v):
        """Not supported: the RBM is trained through :meth:`train`."""
        raise NotImplementedError("Use the train method for training the RBM.")

    def train(self, V=True, nb_epoch=None, batch_size=None, eps=0.001):
        """
        Train the RBM using Contrastive Divergence (one Gibbs step).

        This method shadows ``nn.Module.train(mode)``. To keep ``eval()`` and
        parent containers working, a call with a single ``bool`` argument is
        forwarded to ``nn.Module.train`` and returns the module.

        Args:
        - V: input data, shape (n, p); cast to the dtype of the weights
        - nb_epoch: Number of epoch
        - batch_size: Batch size
        - eps: Learning rate

        Raises:
        - TypeError: if data is given without nb_epoch and batch_size
        """
        # nn.Module.eval() and container modules call train(mode) with a bool.
        if isinstance(V, bool):
            return super().train(V)
        if nb_epoch is None or batch_size is None:
            raise TypeError("train() requires V, nb_epoch and batch_size")

        # Integer/bool datasets would otherwise fail in F.linear against float weights.
        V = V.to(dtype=self.W.dtype)
        n = V.size(0)
        p = self.p

        # CD uses hand-computed updates: no autograd graph is needed.
        with torch.no_grad():
            for epoch in range(nb_epoch):
                # Shuffle dataset
                V = V[torch.randperm(n, device=V.device)]

                # Iterate with batch_size step (slicing clamps the last batch)
                for j in range(0, n, batch_size):
                    v_0 = V[j:j + batch_size]
                    batch_size_actual = v_0.size(0)

                    # Positive phase
                    p_h_v_0 = self.entree_sortie(v_0)
                    # Bernoulli sample; rand_like keeps device/dtype, and the
                    # cast back to float keeps the sample usable in F.linear.
                    h_0 = (torch.rand_like(p_h_v_0) < p_h_v_0).to(p_h_v_0.dtype)

                    # Negative phase (reconstruction)
                    p_v_h_0 = self.sortie_entree(h_0)
                    v_1 = (torch.rand_like(p_v_h_0) < p_v_h_0).to(p_v_h_0.dtype)
                    p_h_v_1 = self.entree_sortie(v_1)

                    # Gradient estimates; grad_W is built directly in W's (q, p) layout
                    grad_a = torch.sum(v_0 - v_1, dim=0)
                    grad_b = torch.sum(p_h_v_0 - p_h_v_1, dim=0)
                    grad_W = torch.matmul(p_h_v_0.t(), v_0) - torch.matmul(p_h_v_1.t(), v_1)

                    # Update params - normalise to batch size
                    step = eps / batch_size_actual
                    self.W.add_(grad_W, alpha=step)
                    self.b.add_(grad_b, alpha=step)
                    self.a.add_(grad_a, alpha=step)

                # Mean squared reconstruction error over the whole dataset
                H = self.entree_sortie(V)
                V_rec = self.sortie_entree(H)
                quad_error = torch.sum((V - V_rec) ** 2) / (n * p)
                print(f"Epoch {epoch+1}/{nb_epoch}, Reconstruction Error (EQ): {quad_error.item():.6f}")


## Training

## Test