# Energy-Based Model (In Progress)

**This notebook is adapted from the one accompanying David Foster's Generative Deep Learning, 2nd Edition.**

- Book: [Amazon](https://www.amazon.com/Generative-Deep-Learning-Teaching-Machines/dp/1098134184?keywords=generative+deep+learning,+2nd+edition&qid=1684708209&sprefix=generative+de,aps,93&sr=8-1)
- Original notebook (TensorFlow and Keras): [GitHub](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition/blob/main/notebooks/07_ebm/01_ebm/ebm.ipynb)

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms as Transforms
from torchinfo import summary

from matplotlib import pyplot as plt

## 0. Training Parameters

In [2]:
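# Roles as used in the original Keras notebook.
IMAGE_SIZE = 32        # target side length passed to Resize
CHANNELS = 1           # MNIST is grayscale
STEP_SIZE = 10         # Langevin sampling: step size along the gradient
STEPS = 60             # Langevin sampling: number of update steps
NOISE = 5e-3           # Langevin sampling: std of the Gaussian noise added each step
ALPHA = 0.1            # weight of the regularization term in the loss
GRADIENT_CLIP = 3e-2   # bound on the per-pixel gradient during sampling
BATCH_SIZE = 128
BUFFER_SIZE = 8192     # capacity of the buffer of previously generated samples
LEARNING_RATE = 1e-4
EPOCHS = 60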
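
`STEP_SIZE`, `STEPS`, `NOISE` and `GRADIENT_CLIP` are sampler settings. The sketch below shows one way they could drive Langevin-style sampling. It rests on assumptions: the function name `langevin_sample`, the toy quadratic energy, and the sign convention (samples move *down* the energy gradient) are illustrative and are not taken from this notebook, which does not define its sampler yet.

```python
import torch

def langevin_sample(energy_fn, x, steps=60, step_size=10.0,
                    noise=5e-3, grad_clip=3e-2):
    """Hypothetical sampler: noisy gradient descent on energy_fn.

    energy_fn maps a batch [N, C, H, W] to one energy value per sample.
    """
    x = x.clone().detach()
    for _ in range(steps):
        # Perturb the current samples and keep them in the pixel range [-1, 1].
        x = x + noise * torch.randn_like(x)
        x = x.clamp(-1.0, 1.0).requires_grad_(True)
        # Gradient of the summed energy with respect to the input pixels.
        (grad,) = torch.autograd.grad(energy_fn(x).sum(), x)
        # Bound each gradient entry so a single step cannot move a pixel too far.
        grad = grad.clamp(-grad_clip, grad_clip)
        # Assumed convention: lower energy means more plausible, so step downhill.
        x = (x.detach() - step_size * grad).clamp(-1.0, 1.0)
    return x

def toy_energy(x):
    # Stand-in energy (sum of squared pixels), used only to exercise the sampler.
    return (x ** 2).flatten(1).sum(dim=1)

torch.manual_seed(0)
start = torch.rand(16, 1, 36, 36) * 2 - 1   # uniform noise in [-1, 1]
samples = langevin_sample(toy_energy, start)
print(toy_energy(start).mean().item(), toy_energy(samples).mean().item())
```

With the clipping bound at `3e-2` and a step size of `10`, each pixel moves by at most `0.3` per step, which is why the sampler needs many steps to cross the `[-1, 1]` range.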

## 1. MNIST Dataset

In [5]:
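def get_dataloader():
    """Return (train_loader, test_loader) for MNIST with pixel values in [-1, 1]."""
    transform = Transforms.Compose([
        Transforms.ToTensor(),                          # image -> float tensor in [0, 1]
        Transforms.Resize(IMAGE_SIZE, antialias=True),  # 28x28 -> 32x32
        Transforms.Normalize((0.5,), (0.5,)),           # [0, 1] -> [-1, 1]
        # Pad(2) adds a 2-pixel zero border on each side,
        # so the final images are 36x36 rather than IMAGE_SIZE x IMAGE_SIZE.
        Transforms.Pad(2),
    ])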

    train_ds = torchvision.datasets.MNIST("./data", train=True,
                                          download=True,
                                          transform=transform)
    test_ds = torchvision.datasets.MNIST("./data", train=False,
                                          download=True,
                                          transform=transform)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=4)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False,
                              num_workers=4)

    return train_loader, test_loader

In [4]:
# check dataset and dataloader
temp_loader, _ = get_dataloader()
print(next(iter(temp_loader))[0].shape)

torch.Size([128, 1, 36, 36])


## 2. Building Energy Function $E(x)$

In [29]:
class Energy(nn.Module):

    conv_layers = 4
    channels:list = [1, 16, 32, 64, 64]
    kernels:list = [5, 3, 3, 3]
    
    def __init__(self):
        super().__init__()
        modules = []
        # Adding convolutional layers + Swish activations
        for i in range(self.conv_layers):
            modules.append(nn.Conv2d(in_channels=self.channels[i], 
                                     out_channels=self.channels[i+1], 
                                     kernel_size=self.kernels[i],
                                     stride=2,
                                     padding=self.kernels[i] // 2))
            modules.append(nn.SiLU())

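        # Flatten the 64 x 3 x 3 feature map to 576 features per image.
        # The linear head that would reduce these to a scalar energy is still
        # disabled, so forward() currently returns a [batch, 576] tensor.
        modules += [nn.Flatten(),]
                    #nn.Linear(in_features=576, out_features=64),  # 64 * 3 * 3 for 36x36 inputs
                    #nn.SiLU(),
                    #nn.Linear(64, 1)]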
        
        self.model = nn.Sequential(*modules)

    def forward(self, x):
        return self.model(x)
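
The spatial size after each stride-2 convolution follows the usual formula `out = floor((in + 2 * padding - kernel) / stride) + 1`, and each layer has `in_channels * out_channels * kernel^2 + out_channels` parameters. The plain-Python sketch below reproduces both by hand for a `36 x 36` input. The helper name `conv_out` is hypothetical, and the channel and kernel lists are copied from the `Energy` class.

```python
def conv_out(size, kernel, stride=2):
    """Output side length of a conv layer with padding = kernel // 2."""
    padding = kernel // 2
    return (size + 2 * padding - kernel) // stride + 1

channels = [1, 16, 32, 64, 64]
kernels = [5, 3, 3, 3]

size = 36  # input side length after Resize(32) + Pad(2)
sizes, params = [], []
for i, k in enumerate(kernels):
    size = conv_out(size, k)
    sizes.append(size)
    # weights (in * out * k * k) plus one bias per output channel
    params.append(channels[i] * channels[i + 1] * k * k + channels[i + 1])

flat_features = channels[-1] * size * size

print(sizes)          # [18, 9, 5, 3]
print(flat_features)  # 576
print(params)         # [416, 4640, 18496, 36928]
print(sum(params))    # 60480
```

A linear head placed after `nn.Flatten()` would therefore need `in_features=576` for this input size. The value `36_928` is the parameter count of the last convolution, not a feature count.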

In [30]:
summary(Energy(), input_size=(32, 1, 36, 36))

Layer (type:depth-idx)                   Output Shape              Param #
Energy                                   [32, 576]                 --
├─Sequential: 1-1                        [32, 576]                 --
│    └─Conv2d: 2-1                       [32, 16, 18, 18]          416
│    └─SiLU: 2-2                         [32, 16, 18, 18]          --
│    └─Conv2d: 2-3                       [32, 32, 9, 9]            4,640
│    └─SiLU: 2-4                         [32, 32, 9, 9]            --
│    └─Conv2d: 2-5                       [32, 64, 5, 5]            18,496
│    └─SiLU: 2-6                         [32, 64, 5, 5]            --
│    └─Conv2d: 2-7                       [32, 64, 3, 3]            36,928
│    └─SiLU: 2-8                         [32, 64, 3, 3]            --
│    └─Flatten: 2-9                      [32, 576]                 --
Total params: 60,480
Trainable params: 60,480
Non-trainable params: 0
Total mult-adds (M): 41.77
Input size (MB): 0.17
Forward/backward pass si