# Energy-based Model (In-progress)

**This notebook is adapted from the notebook accompanying David Foster's *Generative Deep Learning, 2nd Edition*.**
- Book: [Amazon](https://www.amazon.com/Generative-Deep-Learning-Teaching-Machines/dp/1098134184?keywords=generative+deep+learning,+2nd+edition&qid=1684708209&sprefix=generative+de,aps,93&sr=8-1)
- Original notebook (TensorFlow and Keras): [GitHub](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition/blob/main/notebooks/07_ebm/01_ebm/ebm.ipynb)

In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

import random
import collections
from IPython import display
from tqdm.notebook import tqdm
import numpy as np

import tensorflow as tf
from tensorflow.keras import datasets

import jax
from jax import numpy as jnp
import flax.linen as nn

from matplotlib import pyplot as plt

## 0. Training Parameters

In [2]:
# ---- data / model shape -------------------------------------------------
CHANNELS = 1        # single channel: preprocess() appends one channel axis
BATCH_SIZE = 128    # batch size of the tf.data pipelines (and Buffer.sample_size)

# ---- sampling (presumably the Langevin sampler settings; TODO confirm) ----
STEP_SIZE = 10
STEPS = 60
NOISE = 0.005
GRADIENT_CLIP = 0.03  # elementwise bound applied to sampler gradients

# ---- training (usage not yet visible in this notebook) -------------------
ALPHA = 0.1
BUFFER_SIZE = 8_192
LEARNING_RATE = 0.0001
EPOCHS = 60

# Keyword arguments forwarded to the rich console used by Module.tabulate.
console_kwargs = dict(
    width=100,
    force_terminal=False,
    force_jupyter=True,
    soft_wrap=True,
)

## 1. Preparing MNIST Dataset

In [3]:
def preprocess(imgs):
    """Map pixel values from [0, 255] to [-1, 1], pad a 2-pixel border, add a channel axis.

    ``imgs`` must be a 3-D array ``(N, H, W)``; the padding value is -1.0 (i.e. the
    scaled value of a black pixel), so MNIST's 28x28 digits become 32x32.

    Returns:
        float32 array of shape ``(N, H + 4, W + 4, 1)``.
    """
    scaled = (imgs.astype(np.float32) - 127.5) / 127.5
    border = ((0, 0), (2, 2), (2, 2))  # no padding on the batch axis
    padded = np.pad(scaled, border, constant_values=-1.0)
    return padded[..., np.newaxis]

def get_dataset():
    """Load MNIST through Keras and return ``(train_ds, test_ds)`` tf.data pipelines.

    Labels are discarded and images go through ``preprocess``. The training
    pipeline is shuffled with a 1024-element buffer before batching; the test
    pipeline is only batched. Both use ``BATCH_SIZE``.
    """
    (train_imgs, _), (test_imgs, _) = datasets.mnist.load_data()

    def to_dataset(raw_imgs):
        # One dataset element per preprocessed image.
        return tf.data.Dataset.from_tensor_slices(preprocess(raw_imgs))

    train_ds = to_dataset(train_imgs).shuffle(1024).batch(BATCH_SIZE)
    test_ds = to_dataset(test_imgs).batch(BATCH_SIZE)
    return train_ds, test_ds

In [4]:
# Pull one preprocessed training batch and show its shape as a sanity check.
sample_ds, _ = get_dataset()
sample_batch = next(iter(sample_ds.as_numpy_iterator()))
print(sample_batch.shape)

(128, 32, 32, 1)


## 2. Build Energy Function $E(x)$

In [5]:
class Energy(nn.Module):
    """Convolutional network mapping a batch of images to one scalar per image.

    Architecture: ``num_conv_layers`` stages of (stride-2, "same"-padded Conv ->
    swish), then a flatten and a Dense(64) -> swish -> Dense(1) head, giving an
    output of shape ``(batch, 1)``.

    Attributes:
        num_conv_layers: number of Conv + swish stages to build; must be between 1
            and ``min(len(channels), len(kernels))``.
        channels: output feature count of each conv stage.
        kernels: square kernel size of each conv stage.
    """

    num_conv_layers: int = 4
    # Flax modules are dataclasses, which reject mutable (list) defaults,
    # so the per-layer settings are stored as tuples.
    channels: tuple = (16, 32, 64, 64)
    kernels: tuple = (5, 3, 3, 3)

    def setup(self):
        n = self.num_conv_layers
        max_layers = min(len(self.channels), len(self.kernels))
        if not 1 <= n <= max_layers:
            # Fail with a clear message instead of an IndexError deep inside setup.
            raise ValueError(
                f"num_conv_layers must be between 1 and {max_layers} "
                f"(len(channels)={len(self.channels)}, len(kernels)={len(self.kernels)}), "
                f"got {n}"
            )

        layers = []
        for features, k in zip(self.channels[:n], self.kernels[:n]):
            layers.append(nn.Conv(features=features,
                                  kernel_size=(k, k),
                                  strides=2,
                                  padding="same"))
            layers.append(nn.activation.swish)

        self.conv_layers = nn.Sequential(layers)
        # NOTE: "liner" (sic) is kept because a setup() attribute name becomes the
        # key of this submodule's parameters; renaming it would change the
        # parameter tree and break existing checkpoints.
        self.liner_layers = nn.Sequential([nn.Dense(64),
                                           nn.activation.swish,
                                           nn.Dense(1)])

    def __call__(self, x):
        """Return one scalar per image in ``x`` as an array of shape ``(batch, 1)``.

        Flax convolutions are channel-last, so ``x`` is expected as
        ``(batch, height, width, channels)``, e.g. ``(1, 32, 32, 1)``.
        """
        x = self.conv_layers(x)
        x = x.reshape(x.shape[0], -1)  # flatten everything but the batch axis
        return self.liner_layers(x)

In [6]:
# Per-layer summary of Energy for a single 32x32, 1-channel input.
print(
    Energy().tabulate(
        jax.random.PRNGKey(0),
        jnp.ones((1, 32, 32, 1)),
        console_kwargs=console_kwargs,
    )
)






## 3. Setting Up Langevin Sampler

In [7]:
def generate_samples(state, rng, input_imgs, steps, step_size, return_img_per_step=False,
                     noise=None, grad_clip=None):
    """Refine ``input_imgs`` with ``steps`` noisy gradient (Langevin-style) updates.

    Each step:
      1. adds Gaussian noise with standard deviation ``noise`` and clips to [-1, 1];
      2. takes the gradient of the summed model output w.r.t. the images and clips
         it elementwise to [-grad_clip, grad_clip];
      3. moves the images by ``step_size * grad`` (uphill on the model output) and
         clips to [-1, 1] again.

    Args:
        state: object exposing ``apply_fn`` and ``params`` (e.g. a Flax
            ``TrainState``); called as ``state.apply_fn({"params": state.params}, imgs)``.
        rng: JAX PRNG key used to draw the per-step noise.
        input_imgs: starting images. They are converted to a JAX array, so the
            caller's array is never modified in place.
        steps: number of update steps.
        step_size: multiplier applied to the clipped gradient.
        return_img_per_step: if True, return every intermediate batch.
        noise: per-step noise standard deviation; defaults to ``NOISE``.
        grad_clip: elementwise gradient bound; defaults to ``GRADIENT_CLIP``.

    Returns:
        ``np.ndarray`` with the final images or, if ``return_img_per_step`` is
        True, a ``np.ndarray`` of shape ``(steps, *input_imgs.shape)``.
    """
    if noise is None:
        noise = NOISE
    if grad_clip is None:
        grad_clip = GRADIENT_CLIP

    @jax.grad
    def grad_fn(imgs):
        # jax.grad needs a scalar output; the model returns one value per image,
        # and summing gives each image (for a per-image model) the gradient of
        # its own output.
        return state.apply_fn({"params": state.params}, imgs).sum()

    imgs = jnp.asarray(input_imgs)
    imgs_per_step = []

    for _ in range(steps):
        # step 1: perturb with scaled Gaussian noise
        key, rng = jax.random.split(rng, 2)
        imgs = imgs + noise * jax.random.normal(key, shape=imgs.shape)
        imgs = jnp.clip(imgs, -1.0, 1.0)

        # step 2: clipped gradient of the model output w.r.t. the images
        grads = jnp.clip(grad_fn(imgs), -grad_clip, grad_clip)

        # step 3: gradient step scaled by step_size
        imgs = jnp.clip(imgs + step_size * grads, -1.0, 1.0)

        if return_img_per_step:
            imgs_per_step.append(np.asarray(imgs))

    if return_img_per_step:
        return np.stack(imgs_per_step, axis=0)
    return np.asarray(imgs)

## 4. Setting Up a Buffer to Store Examples

In [8]:
class Buffer:
    """Buffer for storing examples (work in progress).

    NOTE(review): not implemented yet; constructing a Buffer raises
    NotImplementedError. Presumably it will hold previously generated samples
    for the sampler. TODO confirm once implemented.
    """

    # Class-level defaults copied from the module-level constants.
    sample_size = BATCH_SIZE
    channels = CHANNELS

    def __init__(self):
        # Placeholder until the buffer is implemented.
        raise NotImplementedError