# Energy-Based Model (In Progress)

**This notebook is adapted from the notebook accompanying David Foster's *Generative Deep Learning, 2nd Edition*.**
- Book: [Amazon](https://www.amazon.com/Generative-Deep-Learning-Teaching-Machines/dp/1098134184)
- Original notebook (TensorFlow and Keras): [GitHub](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition/blob/main/notebooks/07_ebm/01_ebm/ebm.ipynb)

In [15]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

import random
import collections
from typing import Any
from IPython import display
from tqdm.notebook import tqdm
import numpy as np

import tensorflow as tf
from tensorflow.keras import datasets

import jax
from jax import numpy as jnp
from flax import struct
import flax.linen as nn
from flax.training import train_state

import optax
from clu import metrics

from matplotlib import pyplot as plt

## 0. Training Parameters

In [3]:
# Image geometry; Buffer draws its random samples with this side length and channel count.
IMG_SIZE = 32
CHANNELS = 1

# Langevin-sampling settings. Presumably passed to generate_samples /
# Buffer.sample_new_exmps by the training loop (not used in this cell) -- confirm.
STEP_SIZE = 10
STEPS = 60
NOISE = 0.005
# NOTE(review): presumably the weight of the regularisation loss (see reg_loss) -- confirm.
ALPHA = 0.1
# Element-wise bound applied to input-image gradients inside generate_samples.
GRADIENT_CLIP = 0.03

# Data / optimisation settings.
BATCH_SIZE = 128       # dataset batch size and number of samples Buffer produces per call
BUFFER_SIZE = 8_192    # maximum number of examples Buffer retains
LEARNING_RATE = 0.0001
EPOCHS = 60

# kwargs for model's tabulate function
console_kwargs = dict(
    width=100,
    force_terminal=False,
    force_jupyter=True,
    soft_wrap=True,
)

## 1. Preparing the MNIST Dataset

In [4]:
def preprocess(imgs):
    """Normalise a stack of grayscale images and reshape them for the conv net.

    Pixel values are mapped with ``(x - 127.5) / 127.5`` (so, assuming uint8
    input in [0, 255], the result lies in [-1, 1]). Each image gets a 2-pixel
    border on every side, and a trailing channel axis is appended.

    Args:
        imgs: array of shape (N, H, W).

    Returns:
        float32 array of shape (N, H + 4, W + 4, 1).
    """
    scaled = (imgs.astype("float32") - 127.5) / 127.5
    # The border is filled with -1.0, the value a 0 (black) pixel maps to after scaling.
    padded = np.pad(scaled, pad_width=((0, 0), (2, 2), (2, 2)), constant_values=-1.0)
    return padded[..., np.newaxis]

def get_dataset():
    """Load MNIST, preprocess both splits, and wrap them as batched tf.data pipelines.

    Labels are discarded. The training split is shuffled with a 1024-element
    buffer; the test split is not shuffled. Both are batched by ``BATCH_SIZE``.
    NOTE(review): ``batch`` keeps the final partial batch (drop_remainder is
    not set), so the last training batch can be smaller than BATCH_SIZE.

    Returns:
        (train_ds, test_ds) tuple of ``tf.data.Dataset`` objects.
    """
    (x_train, _), (x_test, _) = datasets.mnist.load_data()

    train_ds = (
        tf.data.Dataset.from_tensor_slices(preprocess(x_train))
        .shuffle(1024)
        .batch(BATCH_SIZE)
    )
    test_ds = tf.data.Dataset.from_tensor_slices(preprocess(x_test)).batch(BATCH_SIZE)

    return train_ds, test_ds

In [5]:
# Pull a single training batch to sanity-check the preprocessed shape.
sample_ds, _ = get_dataset()
sample_batch = next(iter(sample_ds.as_numpy_iterator()))
print(sample_batch.shape)

(128, 32, 32, 1)


## 2. Building the Energy Function $E(x)$

In [12]:
class Energy(nn.Module):
    """Convolutional network that maps a batch of images to one scalar each.

    ``num_conv_layers`` convolutions (stride 2, "same" padding), each followed
    by swish, feed a Dense(64) -> swish -> Dense(1) head. The output has shape
    ``(batch, 1)``.
    """

    num_conv_layers: int = 4
    # Flax modules are dataclasses, and dataclass fields may not have mutable
    # (e.g. list) defaults, so the per-layer settings are tuples.
    channels: tuple = (16, 32, 64, 64)
    kernels: tuple = (5, 3, 3, 3)

    def setup(self):
        conv_stack = []
        for idx in range(self.num_conv_layers):
            k = self.kernels[idx]
            conv_stack += [
                nn.Conv(features=self.channels[idx],
                        kernel_size=(k, k),
                        strides=2,
                        padding="same"),
                nn.activation.swish,
            ]
        self.conv_layers = nn.Sequential(conv_stack)

        # NOTE: "liner" (sic) is kept on purpose. The attribute name is the key
        # of this submodule in the params tree, so renaming it would invalidate
        # already-initialised params/checkpoints.
        self.liner_layers = nn.Sequential([
            nn.Dense(64),
            nn.activation.swish,
            nn.Dense(1),
        ])

    def __call__(self, x):
        features = self.conv_layers(x)
        # Flatten everything except the batch axis before the dense head.
        flat = features.reshape(features.shape[0], -1)
        return self.liner_layers(flat)

In [13]:
# Print a per-layer summary of the energy network, traced with a single dummy
# image whose shape comes from the shared IMG_SIZE / CHANNELS constants.
print(Energy().tabulate(jax.random.PRNGKey(0),
                        jnp.ones((1, IMG_SIZE, IMG_SIZE, CHANNELS)),
                        console_kwargs=console_kwargs))






## 3. Setting Up the Langevin Sampler

In [14]:
def generate_samples(state, rng, input_imgs, steps, step_size, return_img_per_step=False,
                     noise=None, gradient_clip=None):
    """Refine images by noisy, clipped gradient ascent on the model output (Langevin sampling).

    Each step does the following:
      1. add Gaussian noise with std ``noise`` and clip to [-1, 1];
      2. take the gradient of the model output w.r.t. the images and clip it
         element-wise to [-gradient_clip, gradient_clip];
      3. move the images by ``step_size * grads`` and clip to [-1, 1] again.

    Args:
        state: object exposing ``apply_fn`` and ``params`` (e.g. a flax TrainState).
        rng: JAX PRNG key used to draw the per-step noise.
        input_imgs: starting images (array-like). The caller's array is not modified.
        steps: number of sampling steps.
        step_size: multiplier applied to the clipped gradient at each step.
        return_img_per_step: if True, return every intermediate batch instead of
            only the final one.
        noise: std of the Gaussian noise; defaults to the module-level ``NOISE``.
        gradient_clip: element-wise gradient bound; defaults to the module-level
            ``GRADIENT_CLIP``.

    Returns:
        np.ndarray with the final images (same shape as ``input_imgs``), or, when
        ``return_img_per_step`` is True, an array of shape
        ``(steps, *input_imgs.shape)``.
    """
    if noise is None:
        noise = NOISE
    if gradient_clip is None:
        gradient_clip = GRADIENT_CLIP

    def total_score(imgs):
        # jax.grad requires a scalar output. Summing the per-sample outputs gives
        # each sample its own gradient, assuming apply_fn scores samples
        # independently (no batch-coupling layers).
        return jnp.sum(state.apply_fn({"params": state.params}, imgs))

    grad_fn = jax.grad(total_score)

    # Work on a JAX copy so the caller's (possibly numpy) array is never mutated in place.
    imgs = jnp.asarray(input_imgs)
    imgs_per_step = []

    for _ in range(steps):
        # step 1, add scaled noise to the current images
        key, rng = jax.random.split(rng, 2)
        imgs = imgs + noise * jax.random.normal(key, shape=imgs.shape)
        imgs = jnp.clip(imgs, -1.0, 1.0)

        # step 2, get clipped gradients for the current images
        grads = jnp.clip(grad_fn(imgs), -gradient_clip, gradient_clip)

        # step 3, take a gradient-ascent step of size step_size
        imgs = jnp.clip(imgs + step_size * grads, -1.0, 1.0)

        if return_img_per_step:
            imgs_per_step.append(np.asarray(imgs))

    if return_img_per_step:
        return np.stack(imgs_per_step, axis=0)
    return np.asarray(imgs)

## 4. Setting Up a Buffer to Store Examples

In [16]:
class Buffer:
    """Pool of previously generated samples that seeds each round of Langevin sampling.

    ``examples`` is a list of arrays, each of shape
    ``(1, img_size, img_size, channels)``. It starts with ``sample_size`` uniform
    [-1, 1] images. Every call to ``sample_new_exmps`` prepends the freshly
    refined batch and truncates the list to ``buffer_size`` entries, newest first.
    """

    sample_size:int = BATCH_SIZE
    channels:int = CHANNELS
    img_size:int = IMG_SIZE
    buffer_size:int = BUFFER_SIZE
    # Default starting key. The first sample_new_exmps call shadows it with a
    # per-instance key that is advanced on every call.
    rng:Any = jax.random.PRNGKey(0)

    def __init__(self):
        self.examples = [
            np.random.uniform(low=-1.0, high=1.0,
                              size=(1, self.img_size, self.img_size, self.channels))
            for _ in range(self.sample_size)
        ]

    def sample_new_exmps(self, state, steps, step_size):
        """Draw a batch of starting images, refine them with ``generate_samples``, and store them.

        About 5% of the batch (Binomial(sample_size, 0.05)) is fresh uniform noise.
        The rest is drawn with replacement from the buffer.

        Args:
            state: model state forwarded to ``generate_samples``.
            steps: number of Langevin steps.
            step_size: gradient step size forwarded to ``generate_samples``.

        Returns:
            The refined batch of ``sample_size`` images.
        """
        n_new = np.random.binomial(self.sample_size, 0.05)
        rand_imgs = np.random.uniform(low=-1.0,
                                      high=1.0,
                                      size=(n_new, self.img_size, self.img_size, self.channels))
        # rand_imgs is always present (possibly with 0 rows), so the concatenation
        # never receives an empty list, even when n_new == sample_size.
        old_imgs = random.choices(self.examples, k=self.sample_size - n_new)
        input_imgs = np.concatenate([rand_imgs, *old_imgs], axis=0)

        # Advance the key so each call draws fresh noise instead of replaying the same sequence.
        self.rng, sample_rng = jax.random.split(self.rng)
        input_imgs = generate_samples(state, sample_rng, input_imgs, steps, step_size)

        self.examples = np.split(input_imgs,
                                 indices_or_sections=self.sample_size,
                                 axis=0) + self.examples
        self.examples = self.examples[:self.buffer_size]
        return input_imgs

## 5. Setting Up the EBM `TrainState`

In [17]:
@struct.dataclass
class TrainMetrics(metrics.Collection):
    """Metrics averaged over training steps.

    Each field is a clu ``Average`` built with ``from_output(name)``, i.e. it
    averages the model-output value passed under that keyword when the
    collection is updated.
    """
    # Total training loss.
    loss: metrics.Average.from_output("loss")
    # NOTE(review): presumably the contrastive-divergence term (fake vs. real scores) -- confirm in the train step.
    cdiv_loss: metrics.Average.from_output("cdiv_loss")
    # NOTE(review): presumably the regularisation term (likely weighted by ALPHA) -- confirm.
    reg_loss: metrics.Average.from_output("reg_loss")
    # NOTE(review): presumably mean model score on real images.
    real: metrics.Average.from_output("real")
    # NOTE(review): presumably mean model score on generated (buffer) images.
    fake: metrics.Average.from_output("fake")

@struct.dataclass
class ValidMetrics(metrics.Collection):
    cdiv_loss: metrics.Average.from_output("cdiv_loss")
    real: metrics.Average.from_output("real")
    fake: metrics.Average.from_output("fake")

# EBM state class inherits flax's TrainState class
class EBM_state(train_state.TrainState):
    """Flax ``TrainState`` extended with metric collections.

    In addition to the inherited fields (e.g. ``params`` and ``apply_fn``, which
    the sampler reads), it carries the training and validation metric
    collections so they travel with the rest of the training state.
    """
    train_metrics: TrainMetrics
    valid_metrics: ValidMetrics

def create_ebm_state():
    """Placeholder for constructing the initial ``EBM_state`` (inferred from the name).

    Not implemented yet: every call raises ``NotImplementedError``.
    """
    # TODO(review): presumably should initialise Energy params, choose an optax
    # optimizer (LEARNING_RATE), and build EBM_state with empty metric collections.
    raise NotImplementedError
