# Data Loading

In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import scipy.io
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image

## Binary Alpha Digits

In [None]:
import os
# Root folder that holds the datasets used by this notebook.
data_folder = './data'
# Walk the folder recursively and print every file path, as a quick check of what is available.
for dirname, _, filenames in os.walk(data_folder):
    for filename in filenames:
        print(os.path.join(dirname, filename))

In [3]:
# Path of the Binary Alpha Digits file, then the dict returned by scipy.io.loadmat
# (its 'dat' entry is what the helpers below read).
data_binaryalphadigits = os.path.join(data_folder, 'binaryalphadigs.mat')
data = scipy.io.loadmat(data_binaryalphadigits)

In [4]:
def load_data(data_folder=data_folder, file='binaryalphadigs.mat'):
    """Load a Binary Alpha Digits style ``.mat`` file and flatten its images.

    Args:
        data_folder: Folder that contains the ``.mat`` file.
        file: Name of the ``.mat`` file inside ``data_folder``.

    Returns:
        A tuple ``(image_tensor, images, height, width)``:
        - ``image_tensor``: float32 tensor with one flattened image per row,
          ordered class by class.
        - ``images``: the raw ``'dat'`` array of the file (indexed by class,
          then by image).
        - ``height``, ``width``: shape of the first image.
    """
    data_file = os.path.join(data_folder, file)

    # Read the file that was actually requested. The previous version built
    # `data_file` but then loaded the notebook-level `data_binaryalphadigits`
    # path, silently ignoring both arguments.
    data = scipy.io.loadmat(data_file)
    images = data['dat']
    height, width = images[0][0].shape

    # Flatten all images across all classes (0-35: 10 digits + 26 letters)
    flattened_images = [
        img.flatten()
        for class_images in images
        for img in class_images
    ]

    image_tensor = torch.tensor(np.array(flattened_images), dtype=torch.float32)

    return image_tensor, images, height, width

In [None]:
# Load the dataset: flattened image tensor, raw image array and image size
image_tensor, images, height, width = load_data()

print(f"Tensor shape: {image_tensor.shape}")

In [None]:
# Shape of the raw 'dat' array returned by load_data (indexed by class, then by image).
images.shape

In [7]:
import random

def display_image(images, num_images=5):
    """Show ``num_images`` randomly chosen images side by side.

    Args:
        images: Indexable collection of images (must support ``len`` and
            integer indexing). Each image must hold 20 * 16 values because it
            is reshaped to 20x16 before being displayed.
        num_images: Number of images drawn without replacement; must not
            exceed ``len(images)``.
    """
    # Randomly select indices for the images
    random_indices = random.sample(range(len(images)), num_images)

    # squeeze=False always yields a 2-D array of Axes, so the loop below also
    # works for num_images == 1 (plt.subplots would otherwise return a single
    # Axes object that cannot be indexed).
    fig, axes = plt.subplots(1, num_images, figsize=(5, 5), squeeze=False)

    for ax, idx in zip(axes[0], random_indices):
        image = images[idx].reshape(20, 16)
        ax.imshow(image, cmap='gray')
        ax.set_title(f'Image {idx}')
        ax.axis('off')  # Hide the axes

    plt.tight_layout()
    plt.show()

In [None]:
# Shape of the flattened image tensor (one row per image).
image_tensor.shape

In [None]:
# Show a random sample of images for each of the first six classes.
for class_index in range(6):
    display_image(images[class_index])

#### Lire Alpha Digit

In [10]:
def lire_alpha_digit(data, list_):
    """Gather the images of the selected classes as flattened rows.

    Args:
        data: Mapping whose ``'dat'`` entry is indexed by class; each class
            entry is a sequence of equally sized image arrays (as in the dict
            loaded from ``binaryalphadigs.mat``).
        list_: Non-empty sequence of class indices into ``data['dat']``.

    Returns:
        A 2-D array with one flattened image per row, classes appearing in the
        order given by ``list_``.

    Raises:
        IndexError: If ``list_`` is empty.
    """
    if len(list_) == 0:
        raise IndexError("list_ must contain at least one class index")

    # Concatenate all requested classes in one call instead of growing the
    # array class by class.
    per_class = [data['dat'][class_index] for class_index in list_]
    all_images = np.concatenate(per_class, axis=0)

    # The row width is derived from the images themselves rather than being
    # hard-coded to 320 (20x16), so other image sizes work as well.
    return np.stack([np.asarray(img).ravel() for img in all_images])

In [11]:
# Flattened images of the classes at indices 1, 2 and 3.
output = lire_alpha_digit(data, [1, 2, 3])

In [None]:
# One row per image, one column per pixel.
output.shape

# Model

## Definition

In [13]:
import copy 
from tqdm import tqdm

class RBM(nn.Module):

    def __init__(self, p: int, q:int, img_size: tuple[int, int]):
        super(RBM, self).__init__()
        self.p = p
        self.q = q
        self.height= img_size[0]
        self.width= img_size[1]

        # Notes:
        # nn.Parameters automatically adds the variable to the list of the model's parameters
        # nn.Parameter tells Pytorch to include this tensor in the computation graph and compute gradients for it during backprop
        # nn.Parameter allow params to move to the right devide when appluing .to(device)

        # Parameters
        self.W = nn.Parameter(torch.randn(q, p)*1e-2)

        # Bias - initialised at 0
        self.a = nn.Parameter(torch.zeros(p))
        self.b = nn.Parameter(torch.zeros(q))

    def entree_sortie(self, v):
        # F.linear performs: v.W + b
        sigm = torch.sigmoid(F.linear(v, self.W, self.b))
        return sigm

    def sortie_entree(self, h):
        # F.linear performs: h.W(transpose) + a
        sigm = torch.sigmoid(F.linear(h, self.W.t(), self.a))
        return sigm

    def forward(self, v):
        raise NotImplementedError("Use the train method for training the RBM.")

    def train(self, V, nb_epoch, batch_size, eps=0.001, verbose=True):
        """
        Train the RBM using Contrastive Divergence

        Args:
        - V: input data
        - nb_epoch: Number of epoch
        - batch_size: Batch size
        - eps: Learning rate
        """
        n = V.size(0)
        p, q = self.p, self.q
        losses = []

        for epoch in range(nb_epoch):
            # Shuffle dataset
            V = V[torch.randperm(n)]
            
            # Iterate with batch_size step
            for j in range(0, n, batch_size):
                V_batch = V[j:min(j + batch_size, n)]
                batch_size_actual = V_batch.size(0)

                v_0 = copy.deepcopy(V_batch)
                p_h_v_0 = self.entree_sortie(v_0)
                # Sample
                h_0 = (torch.rand(batch_size_actual, q) < p_h_v_0).float()
                
                p_v_h_0 = self.sortie_entree(h_0)
                v_1 = (torch.rand(batch_size_actual, p) < p_v_h_0).float()
                
                p_h_v_1 = self.entree_sortie(v_1)

                # grad
                grad_a = torch.sum(v_0 - v_1, dim=0)
                grad_b = torch.sum(p_h_v_0 - p_h_v_1, dim=0)
                grad_W = torch.matmul(v_0.t(), p_h_v_0) - torch.matmul(v_1.t(), p_h_v_1)

                # Update params - Normalise to batch size
                # Note: We bypass the pytorch computation graph with .data to avoid accumulating gradients
                self.W.data += eps * grad_W.t() / batch_size_actual
                self.b.data += eps * grad_b / batch_size_actual
                self.a.data += eps * grad_a / batch_size_actual

            quad_error = self.reconstruction_error(V)
            
            losses.append(quad_error.item())
            if verbose and epoch%10==0:
                print(f"Epoch {epoch+1}/{nb_epoch}, Reconstruction Error (EQ): {quad_error.item():.6f}")
        
        self.losses = losses
        self.nb_epoch = nb_epoch

        if verbose:
            self.plot_loss()
            print(f"Final loss: {self.losses[-1]}")

        return losses

    def generer_image_RBM(self, num_iterations, num_images):
        """
        Generate samples from an RBM using Gibbs sampling.

        Args:
            rbm (RBM): The RBM object.
            num_iterations (int): Number of Gibbs sampling steps to use.
            num_images (int): Number of images to generate.

        Returns:
            torch.Tensor: A tensor of generated images.
        """
        # Initialize random visible states
        generated_images = torch.bernoulli(torch.rand(num_images, self.p))  # Random binary states

        # Perform Gibbs sampling
        for _ in range(num_iterations):
            # Sample hidden states given visible states
            p_h_given_v = self.entree_sortie(generated_images)
            h = (torch.rand_like(p_h_given_v) < p_h_given_v).float()

            # Sample visible states given hidden states
            p_v_given_h = self.sortie_entree(h)
            generated_images = (torch.rand_like(p_v_given_h) < p_v_given_h).float()

        print(f"Infos on generated_images:")
        print(f"type: {type(generated_images)}")
        print(f"size: {generated_images.size()}")

        # Plot the generated images
        _, axes = plt.subplots(1, num_images, figsize=(num_images * 2, 2))
        for i, ax in enumerate(axes):
            ax.imshow(generated_images[i].view(self.height, self.width), cmap="gray")
            ax.axis("off")

        plt.show()

        return generated_images
    
    def save_generated_images(self, generated_images: torch.Tensor, path: str, title=None) -> None:

        num_images = generated_images.shape[0]

        # Plot the generated images
        _, axes = plt.subplots(1, num_images, figsize=(num_images * 2, 2))
        for i, ax in enumerate(axes):
            ax.imshow(generated_images[i].view(self.height, self.width), cmap="gray")
            ax.axis("off")

        if title is not None:
            plt.suptitle(title)

        plt.savefig(path)
    
    def plot_loss(self):

        if not hasattr(self, "losses"):
            raise AttributeError("The attribute 'losses' is missing. Make sure to initialize it in runing the train() method before calling plot().")
        
        plt.plot(self.losses)
        plt.xlabel('epochs')
        plt.ylabel('loss')
        plt.title(f'Loss | {self.nb_epoch} epochs')
        plt.show()

    def reconstruction_error(self, V: torch.Tensor, return_float=False) -> torch.Tensor | float:
        """Compute the reconstruction error on a set of tensor images"""
        n = V.size(0)
        H = self.entree_sortie(V)
        V_rec = self.sortie_entree(H)
        torch_sum = torch.sum((V - V_rec)**2) / (n*self.p)
        if return_float:
            return float(torch_sum.item())
        return torch_sum




## Training

In [14]:
# One visible unit per pixel of the flattened images, 64 hidden units.
rbm = RBM(p=image_tensor.size(1), q=64, img_size=(height, width))

In [17]:
def train_rbm_on_dataset(rbm: RBM, data_folder, dataset='binaryalphadigs.mat', nb_epoch=1000, batch_size=10, eps=1e-2, p=256, q=128, verbose=True):
    """Load a dataset of binary images and train the given RBM on it.

    Args:
    - rbm: RBM instance to train (trained in place).
    - data_folder: Path to the folder containing the dataset file.
    - dataset: Name of the ``.mat`` file inside ``data_folder``.
    - nb_epoch: Number of training epochs.
    - batch_size: Batch size for training.
    - eps: Learning rate for gradient update.
    - p: Unused; the number of visible units is fixed by ``rbm``.
    - q: Unused; the number of hidden units is fixed by ``rbm``.
    - verbose: Forwarded to ``rbm.train``.
    """
    # Only the flattened tensor is needed; the data is used as loaded, without rescaling.
    image_tensor, _images, _height, _width = load_data(data_folder, dataset)

    training_options = dict(nb_epoch=nb_epoch, batch_size=batch_size, eps=eps, verbose=verbose)
    rbm.train(image_tensor, **training_options)

    print("Training complete.")

In [None]:
# Train the RBM created above with the function's default hyperparameters.
train_rbm_on_dataset(rbm, data_folder=data_folder)

## Test

In [None]:
# Assume rbm is a trained RBM object.
# Draw 10 images after 100000 Gibbs steps each (also displays them).
generated_images = rbm.generer_image_RBM(num_iterations=100000, num_images=10)

# Study

### 1. hyperparameters variations

#### - Number of hidden units

In [None]:
image_tensor, images, height, width = load_data(data_folder, 'binaryalphadigs.mat')

# One freshly initialised RBM per hidden-layer size; all other settings are held fixed.
trained_rbm = []

for q in [16, 32, 64, 128, 256]:

    rbm = RBM(p=image_tensor.size(1), q=q, img_size=(height, width))
    rbm.train(image_tensor, nb_epoch=400, batch_size=10, eps=1e-2, verbose=True)
    trained_rbm.append({"q":q, "rbm": rbm})


In [None]:
# For each RBM of the hidden-size sweep: sample images, score them and save the figure.
for q_rbm in trained_rbm:
    q = q_rbm["q"]
    rbm: RBM = q_rbm["rbm"]

    print(f">>>>> q = {q}")

    generated_images = rbm.generer_image_RBM(num_iterations=100000, num_images=10)
    # The error is computed on the generated samples themselves, not on the training images.
    error = rbm.reconstruction_error(generated_images, return_float=True)
    rbm.save_generated_images(generated_images, path=f"assets/rbm_q_{q}.png", title=f"Nombre d'unités cachées : {q} | Erreur de reconstruction : {round(error, 3)}")

#### - batch size

In [None]:
image_tensor, images, height, width = load_data(data_folder, 'binaryalphadigs.mat')

# One freshly initialised RBM (128 hidden units) per batch size; all other settings are held fixed.
trained_rbm = []

for batch_size in [1, 5, 10, 50, 100]:

    rbm = RBM(p=image_tensor.size(1), q=128, img_size=(height, width))
    rbm.train(image_tensor, nb_epoch=400, batch_size=batch_size, eps=1e-2, verbose=True)
    trained_rbm.append({"batch_size":batch_size, "rbm": rbm})

In [23]:
# For each RBM of the batch-size sweep: sample images, score them and save the figure.
for rbm_ in trained_rbm:
    batch_size = rbm_["batch_size"]
    rbm: RBM = rbm_["rbm"]

    print(f">>>>> batch_size = {batch_size}")

    generated_images = rbm.generer_image_RBM(num_iterations=100000, num_images=10)
    # The error is computed on the generated samples themselves, not on the training images.
    error = rbm.reconstruction_error(generated_images, return_float=True)
    rbm.save_generated_images(generated_images, path=f"assets/rbm_batchsize_{batch_size}.png", title=f"Batch size : {batch_size} | Erreur de reconstruction : {round(error, 3)}")

### 2. Number of Characters to Learn

In [24]:
def image_to_tensor(images: np.ndarray) -> torch.Tensor:
    """Flatten every image of ``images`` and stack the rows into a float32 tensor."""
    rows = np.array(list(map(lambda img: img.flatten(), images)))
    return torch.tensor(rows, dtype=torch.float32)


#### Only A

In [None]:
# Class index 10 is the first letter (indices 0-9 are the digits).
images = lire_alpha_digit(data, [10]) # Only A
image_tensor = image_to_tensor(images)

# Images of this dataset are 20x16, hence img_size=(20, 16).
rbm = RBM(p=image_tensor.size(1), q=128, img_size=(20, 16))

# The returned loss list is discarded; it stays available as rbm.losses.
_ = rbm.train(image_tensor, nb_epoch=1300, batch_size=10, eps=1e-2, verbose=True)

In [None]:
# Sample 10 images from the "A"-only model, score them and save the figure.
generated_images = rbm.generer_image_RBM(num_iterations=1000, num_images=10)
error = rbm.reconstruction_error(generated_images, return_float=True)
rbm.save_generated_images(generated_images, path=f"assets/BALD_A.png", title=f"Dataset : Binary Alpha Digits - [A] | q : {128} | Erreur de reconstruction : {round(error, 3)}")

#### Try A, AB, ABC, ABCD... A-Z

In [None]:
alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

# Error measured on the generated samples, one entry per training set A, A-B, ..., A-Z.
list_error = []

for i in range(26):

    # Class indices 10 .. 10+i, i.e. the letters A up to alphabet[i].
    images = lire_alpha_digit(data, list(range(10, 11+i)))
    image_tensor = image_to_tensor(images)

    # A new RBM is trained from scratch for every character set.
    rbm = RBM(p=image_tensor.size(1), q=128, img_size=(20, 16))

    rbm.train(image_tensor, nb_epoch=1000, batch_size=10, eps=1e-2, verbose=False)

    generated_images = rbm.generer_image_RBM(num_iterations=1000, num_images=10)
    error = rbm.reconstruction_error(generated_images, return_float=True)
    rbm.save_generated_images(generated_images, path=f"assets/BALD_A-{alphabet[i]}.png", title=f"Dataset : Binary Alpha Digits - [A-{alphabet[i]}] | q : {128} | Erreur de reconstruction : {round(error, 3)}")

    list_error.append(error)

- MNIST

In [28]:
from pathlib import Path

def load_mnist_data(data_folder='./data/mnist-dataset') -> tuple[torch.Tensor, np.ndarray, int, int]:

    folder_path =Path(data_folder)
    train_image_path = folder_path / 'train-images.idx3-ubyte'

    with open(train_image_path, 'rb') as file: 
        data = np.frombuffer(file.read(), dtype = np.uint8)

    # values of pixels : [0-255] -> [0, 1]
    binarized_data = (data > 127).astype(int)

    height, width = 28, 28
    images = binarized_data[16:].reshape(-1, height, width)

    # Keep only 1404 images like in binaryalphadigs dataset
    images: list[np.ndarray] = images[np.random.choice(images.shape[0], 1404, replace=False)]

    flattened_images = [img.flatten() for img in images]

    # Convert to a tensor
    image_tensor = torch.tensor(np.array(flattened_images), dtype=torch.float32)

    return image_tensor, images, height, width


In [None]:
# Load the binarised MNIST subset and preview one of the selected images.
image_tensor, images, height, width = load_mnist_data()

plt.imshow(images[3], cmap='gray')
plt.show()

In [None]:
# Number of hidden units; reused in the figure title of the next cell.
q=128
rbm = RBM(p=image_tensor.size(1), q=q, img_size=(height, width))

# The returned loss list is discarded; it stays available as rbm.losses.
_ = rbm.train(image_tensor, nb_epoch=400, batch_size=10, eps=1e-2, verbose=True)

In [None]:
# Sample 10 images from the MNIST model, score them and save the figure.
generated_images = rbm.generer_image_RBM(num_iterations=100000, num_images=10)
error = rbm.reconstruction_error(generated_images, return_float=True)
rbm.save_generated_images(generated_images, path=f"assets/mnist.png", title=f"Dataset : MNIST | q : {q} | Erreur de reconstruction : {round(error, 3)}")

- Fashion-MNIST

In [32]:
from pathlib import Path

def load_fashion_mnist_data(data_folder='./data/fashionmnist-dataset') -> tuple[torch.Tensor, np.ndarray, int, int]:

    folder_path =Path(data_folder)
    train_image_path = folder_path / 'train-images-idx3-ubyte'

    with open(train_image_path, 'rb') as file: 
        data = np.frombuffer(file.read(), dtype = np.uint8)

    # values of pixels : [0-255] -> [0, 1]
    binarized_data = (data > 127).astype(int)

    height, width = 28, 28
    images = binarized_data[16:].reshape(-1, height, width)

    # Keep only 1404 images like in binaryalphadigs dataset
    images: list[np.ndarray] = images[np.random.choice(images.shape[0], 1404, replace=False)]

    flattened_images = [img.flatten() for img in images]

    # Convert to a tensor
    image_tensor = torch.tensor(np.array(flattened_images), dtype=torch.float32)

    return image_tensor, images, height, width

In [None]:
# Load the binarised Fashion-MNIST subset and preview one of the selected images.
image_tensor, images, height, width = load_fashion_mnist_data()

plt.imshow(images[1], cmap='gray')
plt.show()

In [None]:
# Number of hidden units; reused in the figure title of the next cell.
q=128
rbm = RBM(p=image_tensor.size(1), q=q, img_size=(height, width))

# The returned loss list is discarded; it stays available as rbm.losses.
_ = rbm.train(image_tensor, nb_epoch=400, batch_size=10, eps=1e-2, verbose=True)

In [None]:
# Sample 10 images from the Fashion-MNIST model, score them and save the figure.
generated_images = rbm.generer_image_RBM(num_iterations=100000, num_images=10)
error = rbm.reconstruction_error(generated_images, return_float=True)
rbm.save_generated_images(generated_images, path=f"assets/fashion_mnist.png", title=f"Dataset : Fashion MNIST | q : {q} | Erreur de reconstruction : {round(error, 3)}")

### 4. Try other models

- GAN (Not fully functional)

Based on the following GitHub repository: https://github.com/sssingh/mnist-digit-generation-gan.git

In [None]:
import math
import pickle as pkl
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [None]:
import os
import scipy

# NOTE(review): this rebinds the notebook-level `data_folder`, which the
# data-loading section sets to './data' - confirm which path is intended.
data_folder = "../data"


def create_dataset(data_folder=data_folder, file='binaryalphadigs.mat'):
    """Build the image tensor and class targets from a ``.mat`` file.

    Args:
        data_folder: Folder containing the ``.mat`` file.
        file: Name of the ``.mat`` file; its ``'dat'`` entry is indexed by
            class, then by image.

    Returns:
        ``(image_tensor, targets)``: a float32 tensor stacking all images
        class by class (images keep their 2-D shape), and a tensor holding the
        class index of every image.
    """
    mat_content = scipy.io.loadmat(os.path.join(data_folder, file))
    images = mat_content['dat']

    # Stack every image (row-major over classes, then images) into one array.
    stacked = np.array(list(images.flatten()))
    image_tensor = torch.tensor(stacked, dtype=torch.float32)

    # Each class index is repeated once per image of that class.
    n_classes, n_per_class = images.shape[0], images.shape[1]
    targets = torch.tensor([label for label in range(n_classes) for _ in range(n_per_class)])

    return image_tensor, targets

In [None]:
from torch.utils.data import TensorDataset, DataLoader

# Image tensor and class targets, read with create_dataset's default folder and file.
image_tensor, targets = create_dataset()

# Pair every image with its class index.
dataset = TensorDataset(image_tensor, targets)

In [None]:
# Build dataloader: reshuffled batches of 64 (image, target) pairs
dl = DataLoader(dataset=dataset,
                shuffle=True,
                batch_size=64)

In [None]:
# Examine a sample batch from the dataloader:
# element 0 holds the images, element 1 the class targets.
image_batch = next(iter(dl))
print(len(image_batch), type(image_batch))
print(image_batch[0].shape)
print(image_batch[1].shape)

In [None]:
## ----------------------------------------------------------------------------
## Visualise a sample batch
## ----------------------------------------------------------------------------

def display_images(images, n_cols=4, figsize=(12, 6)):
    """
    Show a collection of images on a grid, one subplot per image.

    Parameters
    ----------
    images: Tensor
            batch of images; each element along the first dimension is
            handed as-is to ``imshow``
    n_cols: int
            number of columns in the grid
    figsize: tuple
            size of the figure, forwarded to matplotlib

    Returns
    -------
    None
    """
    plt.style.use('ggplot')
    n_images = len(images)
    # Enough rows to fit every image with n_cols images per row
    n_rows = math.ceil(n_images / n_cols)
    fig = plt.figure(figsize=figsize)
    for position, image in enumerate(images, start=1):
        ax = fig.add_subplot(n_rows, n_cols, position)
        ax.imshow(image, cmap='gray')
        # Hide the tick marks on both axes
        ax.set_xticks([])
        ax.set_yticks([])
    plt.tight_layout()
    plt.show()

# Show the images of the sample batch on a grid with 8 columns.
display_images(images=image_batch[0], n_cols=8)

In [None]:
class Discriminator(nn.Module):
    """Fully connected discriminator mapping an image batch to raw logits.

    Three hidden layers (128, 64 and 32 units), each followed by a
    LeakyReLU(0.2) and dropout (p=0.3), then a linear output layer with
    ``out_features`` units. No sigmoid is applied to the output.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Progressively narrower layers: in_features -> 128 -> 64 -> 32 -> out_features
        self.fc1 = nn.Linear(in_features, 128)
        self.leaky_relu1 = nn.LeakyReLU(0.2)
        self.fc2 = nn.Linear(128, 64)
        self.leaky_relu2 = nn.LeakyReLU(0.2)
        self.fc3 = nn.Linear(64, 32)
        self.leaky_relu3 = nn.LeakyReLU(0.2)
        self.fc4 = nn.Linear(32, out_features)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Flatten each sample of ``x`` and return the logits of the last layer."""
        x = x.view(x.shape[0], -1)

        hidden_stages = (
            (self.fc1, self.leaky_relu1),
            (self.fc2, self.leaky_relu2),
            (self.fc3, self.leaky_relu3),
        )
        for linear, activation in hidden_stages:
            x = self.dropout(activation(linear(x)))

        return self.fc4(x)

In [None]:
class Generator(nn.Module):
    """Fully connected generator mapping latent vectors to flat images.

    Three hidden layers (32, 64 and 128 units), each followed by a
    LeakyReLU(0.2) and dropout (p=0.3), then a linear layer with
    ``out_features`` units squashed by tanh, so outputs lie in [-1, 1].
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Progressively wider layers: in_features -> 32 -> 64 -> 128 -> out_features
        self.fc1 = nn.Linear(in_features, 32)
        self.relu1 = nn.LeakyReLU(0.2)
        self.fc2 = nn.Linear(32, 64)
        self.relu2 = nn.LeakyReLU(0.2)
        self.fc3 = nn.Linear(64, 128)
        self.relu3 = nn.LeakyReLU(0.2)
        self.fc4 = nn.Linear(128, out_features)
        self.dropout = nn.Dropout(0.3)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Return the tanh-activated output for the latent batch ``x``."""
        hidden_stages = (
            (self.fc1, self.relu1),
            (self.fc2, self.relu2),
            (self.fc3, self.relu3),
        )
        for linear, activation in hidden_stages:
            x = self.dropout(activation(linear(x)))

        return self.tanh(self.fc4(x))

Loss 

In [None]:
def real_loss(predicted_outputs, loss_fn, device):
    """
    Compute the loss of predictions whose expected label is "real" (1).

    Parameters
    ----------
    predicted_outputs: Tensor
                       predicted logits, one value per sample
                       (e.g. shape (batch_size, 1))
    loss_fn: callable
             loss applied as ``loss_fn(logits, targets)``
    device: str or torch.device
            device on which the target tensor is created

    Returns
    -------
    loss: value returned by ``loss_fn``
    """
    batch_size = predicted_outputs.shape[0]
    # Targets are set to 1 here because we expect prediction to be
    # 1 (or near 1) since samples are drawn from real dataset
    targets = torch.ones(batch_size, device=device)
    # reshape(batch_size) keeps the batch dimension even for a batch of one;
    # .squeeze() would collapse it to a 0-d tensor that no longer matches
    # the shape of `targets`.
    loss = loss_fn(predicted_outputs.reshape(batch_size), targets)

    return loss


def fake_loss(predicted_outputs, loss_fn, device):
    """
    Compute the loss of predictions whose expected label is "fake" (0).

    Parameters
    ----------
    predicted_outputs: Tensor
                       predicted logits, one value per sample
                       (e.g. shape (batch_size, 1))
    loss_fn: callable
             loss applied as ``loss_fn(logits, targets)``
    device: str or torch.device
            device on which the target tensor is created

    Returns
    -------
    loss: value returned by ``loss_fn``
    """
    batch_size = predicted_outputs.shape[0]
    # Targets are set to 0 here because we expect prediction to be
    # 0 (or near 0) since samples are generated fake samples
    targets = torch.zeros(batch_size, device=device)
    # reshape(batch_size) keeps the batch dimension even for a batch of one;
    # .squeeze() would collapse it to a 0-d tensor that no longer matches
    # the shape of `targets`.
    loss = loss_fn(predicted_outputs.reshape(batch_size), targets)

    return loss

Training

In [None]:
# Training loop function

def _sample_latent(batch_size, z_size, device):
    """Draw a (batch_size, z_size) float tensor of latent codes on ``device``.

    NOTE(review): thresholding uniform(-1, 1) at 0.5 yields 0/1 entries
    (1 with probability 0.25); kept as in the original notebook - confirm
    that binary latent codes are intended.
    """
    z = (np.random.uniform(-1, 1, size=(batch_size, z_size)) > 0.5).astype(np.int32)
    return torch.from_numpy(z).float().to(device)


def train_gan(d, g, d_optim, g_optim, loss_fn, dl, n_epochs, device, verbose=False):
    """Train a discriminator/generator pair and record per-epoch losses.

    Args:
        d: Discriminator module returning one logit per sample.
        g: Generator module taking latent vectors of size 100.
        d_optim: Optimizer of the discriminator parameters.
        g_optim: Optimizer of the generator parameters.
        loss_fn: Loss used through ``real_loss`` / ``fake_loss``.
        dl: Loader yielding ``(images, labels)`` batches and supporting
            ``len``; images are assumed to be in [0, 1] and are rescaled to
            [-1, 1] to match the generator's tanh output.
        n_epochs: Number of epochs.
        device: Device on which the models and batches are placed.
        verbose: If True, print the batch losses every 400 batches.

    Returns:
        ``(d_losses, g_losses)``: per-epoch mean losses of the discriminator
        and of the generator.

    Side effects:
        Writes the images generated after each epoch from a fixed latent
        batch to ``fixed_samples.pkl`` in the current directory.
    """
    print(f'Training on [{device}]...')

    # A fixed batch of 16 latent vectors is fed to the generator after each
    # epoch, so the saved images show how the generator improves over time.
    z_size = 100
    fixed_z = _sample_latent(16, z_size, device)
    fixed_samples = []
    d_losses = []
    g_losses = []

    # Move discriminator and generator to available device
    d = d.to(device)
    g = g.to(device)

    for epoch in range(n_epochs):
        print(f'Epoch [{epoch+1}/{n_epochs}]:')
        # Switch the training mode on
        d.train()
        g.train()
        # Accumulate plain floats: summing the loss tensors themselves would
        # keep every batch's autograd graph alive until the end of the epoch.
        d_running_batch_loss = 0.0
        g_running_batch_loss = 0.0
        for curr_batch, (real_images, _) in enumerate(dl):
            # Move input batch to available device
            real_images = real_images.to(device)
            # Size fake batches after the real one (the last batch may be smaller).
            batch_size = real_images.size(0)

            ## ----------------------------------------------------------------
            ## Train discriminator on real then fake images and back-propagate
            ## the total loss
            ## ----------------------------------------------------------------
            d_optim.zero_grad()

            # Rescale real images from [0, 1] to [-1, 1], the range of the
            # generator's tanh output that the discriminator also consumes.
            real_images = (real_images * 2) - 1
            d_real_loss = real_loss(d(real_images), loss_fn, device)

            # The generator is not updated in this step, so no graph is needed for it.
            with torch.no_grad():
                fake_images = g(_sample_latent(batch_size, z_size, device))
            # Fake images should be classified as fake (target label = 0)
            d_fake_loss = fake_loss(d(fake_images), loss_fn, device)

            d_loss = d_real_loss + d_fake_loss
            d_loss.backward()
            d_optim.step()
            d_running_batch_loss += d_loss.item()

            ## ----------------------------------------------------------------
            ## Train generator: its loss measures how well it tricks the
            ## discriminator (fake images scored against target label = 1)
            ## ----------------------------------------------------------------
            g_optim.zero_grad()

            fake_images = g(_sample_latent(batch_size, z_size, device))
            g_loss = real_loss(d(fake_images), loss_fn, device)
            g_loss.backward()
            g_optim.step()
            g_running_batch_loss += g_loss.item()

            # Display training stats every 400 batches
            if curr_batch % 400 == 0 and verbose:
                print(f'\tBatch [{curr_batch:>4}/{len(dl):>4}] - d_batch_loss: {d_loss.item():.6f}\tg_batch_loss: {g_loss.item():.6f}')

        # Epoch losses are total_batch_loss / number_of_batches
        # (guarded so an empty loader does not divide by zero).
        n_batches = max(len(dl), 1)
        d_epoch_loss = d_running_batch_loss / n_batches
        g_epoch_loss = g_running_batch_loss / n_batches
        d_losses.append(d_epoch_loss)
        g_losses.append(g_epoch_loss)

        # Display training stats for the epoch
        print(f'epoch_d_loss: {d_epoch_loss:.6f} \tepoch_g_loss: {g_epoch_loss:.6f}')

        # Generate images from the fixed latent batch with the generator
        # trained so far and keep them for later viewing
        g.eval()
        with torch.no_grad():
            fixed_samples.append(g(fixed_z).cpu())

    # Finally write generated fake images from fixed latent vector to disk
    with open('fixed_samples.pkl', 'wb') as f:
        pkl.dump(fixed_samples, f)

    return d_losses, g_losses

In [None]:
##
## Prepare and start training
##

# Size the networks from the data served by the dataloader instead of the
# MNIST-specific 784: the discriminator flattens each image, so its input
# size (and the generator's output size) must equal the pixel count per image.
n_pixels = dl.dataset[0][0].numel()

# Instantiate Discriminator and Generator
# (train_gan draws latent vectors of size 100, hence in_features=100)
d = Discriminator(in_features=n_pixels, out_features=1)
g = Generator(in_features=100, out_features=n_pixels)
print(d)
print()
print(g)

# Instantiate optimizers
d_optim = optim.Adam(d.parameters(), lr=0.002)
g_optim = optim.Adam(g.parameters(), lr=0.002)

# Instantiate the loss function (the discriminator outputs raw logits)
loss_fn = nn.BCEWithLogitsLoss()

# Setup device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

#Train
n_epochs = 100
d_losses, g_losses = train_gan(d, g, d_optim, g_optim,
                                     loss_fn, dl, n_epochs, device,
                                     verbose=False)