# Energy-based Model (In-progress)

**This notebook is adapted from the one accompanying David Foster's Generative Deep Learning, 2nd Edition.**
- Book: [Amazon](https://www.amazon.com/Generative-Deep-Learning-Teaching-Machines/dp/1098134184?keywords=generative+deep+learning,+2nd+edition&qid=1684708209&sprefix=generative+de,aps,93&sr=8-1)
- Original notebook (TensorFlow and Keras): [GitHub](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition/blob/main/notebooks/07_ebm/01_ebm/ebm.ipynb)

In [1]:
import random
import collections
from IPython import display
from tqdm.notebook import tqdm

import numpy as np

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import tensorflow as tf
from tensorflow.keras import datasets

import jax
from jax import numpy as jnp

from matplotlib import pyplot as plt

## 0. Training Parameters

In [2]:
# Data pipeline.
CHANNELS = 1  # presumably the image channel count (grayscale) — confirm at use site
BATCH_SIZE = 128  # batch size applied by get_dataset() to both splits

# Optimisation schedule (consumed outside the visible cells — TODO confirm usage).
EPOCHS = 60
LEARNING_RATE = 1e-4

# Sampler / loss settings.
# NOTE(review): the names suggest Langevin-style sampling (step size, number
# of steps, injected noise, gradient clipping), a sample buffer and a
# regularisation weight, but none of them is read in the visible cells —
# verify against the code that uses them.
STEP_SIZE = 10
STEPS = 60
NOISE = 5e-3
GRADIENT_CLIP = 3e-2
BUFFER_SIZE = 8192
ALPHA = 0.1

## 1. Preparing the MNIST Dataset

In [5]:
def preprocess(imgs):
    """Rescale images, pad them by two pixels per side and add a channel axis.

    Values are cast to float32 and mapped linearly so that 0 becomes -1.0 and
    255 becomes 1.0.

    Args:
        imgs: Array with exactly three axes (the pad spec has three entries
            and leaves axis 0 untouched). Presumably uint8 pixels laid out as
            (N, H, W) — confirm against the caller.

    Returns:
        A float32 array in which axes 1 and 2 are each 4 longer than the
        input's and a trailing axis of length 1 has been appended.
    """
    scaled = (imgs.astype("float32") - 127.5) / 127.5

    # Pad with -1.0, the value a zero pixel maps to, so the added border is
    # indistinguishable from a zero-valued background.
    pad_width = [(0, 0), (2, 2), (2, 2)]
    padded = np.pad(scaled, pad_width, mode="constant", constant_values=-1.0)

    return padded[..., np.newaxis]

def get_dataset():
    """Load MNIST and return batched tf.data pipelines for both splits.

    Images come from ``datasets.mnist.load_data()``; the labels are discarded.
    Each split is passed through ``preprocess`` before being wrapped in a
    ``tf.data.Dataset``.

    Returns:
        A ``(train_ds, test_ds)`` tuple. The training pipeline is shuffled
        with a 1024-element buffer and then batched; the test pipeline is
        batched in its original order. Both use ``BATCH_SIZE``.
    """
    (train_imgs, _), (test_imgs, _) = datasets.mnist.load_data()

    train_ds = (
        tf.data.Dataset.from_tensor_slices(preprocess(train_imgs))
        .shuffle(1024)
        .batch(BATCH_SIZE)
    )
    test_ds = (
        tf.data.Dataset.from_tensor_slices(preprocess(test_imgs))
        .batch(BATCH_SIZE)
    )

    return train_ds, test_ds

In [7]:
# Build the pipelines, keep only the training split, and pull its first batch
# as a NumPy array for inspection.
sample_ds, _ = get_dataset()
sample_batch = next(sample_ds.as_numpy_iterator())

In [8]:
# Show the shape of the sampled batch (recorded output: (128, 32, 32, 1)).
print(sample_batch.shape)

(128, 32, 32, 1)
