# EBM Intro

This notebook shows the basic methods of an energy-based model (EBM): how to get the probabilities, which inputs are required, and so on.
It is aimed at future employees who need to get familiar with energax, not necessarily at the end user.

In [None]:
from energax.sampling.discrete import CRBMGibbsSampler
import jax
from jax import numpy as jnp
from energax.ebms.rbms import CategoricalRBM, get_random_crbm_params

### Define a categorical RBM

The model has 4 visible units and 5 hidden units. We set the number of categories to `dim=3` for all visible nodes. In the current implementation of CRBMs the hidden nodes are always binary. We also explicitly set `generate_bitstrings=True` so that we can call `crbm.probability_vector()` in the next step. Set this to `False` for large systems, because the number of visible states grows as `dim ** vis`.

In [13]:
vis = 4
hid = 5
dim = 3
structure = jnp.array([dim] * vis)
key = jax.random.PRNGKey(0)
params = get_random_crbm_params(key, num_visible=vis, num_hidden=hid, max_dim=dim)
crbm = CategoricalRBM(vis, hid, theta=params, structure=structure, generate_bitstrings=True)

Return the probabilities of all possible visible states of the RBM.

For `dim = 3` and `vis = 4` there are `3 ** 4 = 81` possible visible states.

In [14]:
probs = crbm.probability_vector()
probs.shape

(81,)
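The count of 81 can be checked by enumerating the visible configurations directly. The sketch below is independent of energax: it only lists every assignment of one category per visible unit. The order in which `probability_vector()` arranges its entries is not stated in this notebook, so the row order here is an assumption and should not be used to index `probs` without checking.

```python
import itertools

from jax import numpy as jnp

vis, dim = 4, 3

# Every visible configuration is one choice of category per visible unit.
states = jnp.array(list(itertools.product(range(dim), repeat=vis)))
print(states.shape)  # (81, 4)

# The same count for a larger system shows why full enumeration does not scale.
print(dim ** 20)  # 3486784401
```

With 20 visible units the vector would already have about 3.5 billion entries, which is why enumeration is only practical for small models.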

### Calculate the conditional probabilities for the hidden nodes given the visible nodes

- The visible vectors must be *one-hot encoded*!
- `compute_ph_given_v` is not vectorized. It takes a single sample as input, e.g. the one-hot encoding of `jnp.array([0,0,0,0])`. It does not work for a batch such as `jnp.array([[0,0,0,0], [1,1,1,1]])`.
- The output is a vector holding, for each hidden node, the probability that it is 1 given the visible nodes.


In [15]:
visible_vectors = jnp.array([0, 0, 0, 0])
visible_vectors_oh = jax.nn.one_hot(visible_vectors, dim)
print(visible_vectors_oh)

crbm.compute_ph_given_v(visible_vectors_oh)

[[1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]]


Array([0.50380164, 0.5055251 , 0.48830503, 0.5088422 , 0.4791705 ],      dtype=float32)
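When a function only accepts a single sample, a common JAX pattern is to wrap it in `jax.vmap`. The sketch below uses a hypothetical stand-in `ph_given_v` with made-up parameters `W` and `b` and the textbook sigmoid form of the RBM conditional. It is not energax's implementation, and whether `crbm.compute_ph_given_v` itself can be passed to `jax.vmap` is not confirmed by this notebook.

```python
import jax
from jax import numpy as jnp

vis, hid, dim = 4, 5, 3
key_w, key_b = jax.random.split(jax.random.PRNGKey(0))

# Hypothetical parameters for illustration; not energax's internal layout.
W = 0.01 * jax.random.normal(key_w, (vis, dim, hid))
b = 0.01 * jax.random.normal(key_b, (hid,))


def ph_given_v(v_onehot):
    """P(h_j = 1 | v) for ONE one-hot visible state of shape (vis, dim)."""
    return jax.nn.sigmoid(b + jnp.einsum("id,idj->j", v_onehot, W))


# A batch of two visible states, one-hot encoded: shape (2, vis, dim).
batch = jax.nn.one_hot(jnp.array([[0, 0, 0, 0], [1, 1, 1, 1]]), dim)

# vmap maps the single-sample function over the leading batch axis.
batched_probs = jax.vmap(ph_given_v)(batch)
print(batched_probs.shape)  # (2, 5)
```

Each row of `batched_probs` equals the result of calling the single-sample function on the corresponding batch element.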

### Calculate the conditional probabilities for the visible nodes given the hidden nodes

- The hidden vectors are **NOT** *one-hot encoded*!
- `compute_pv_given_h` is not vectorized. It takes a single sample as input, e.g. `jnp.array([0,0,0,0,0])`. It does not work for a batch such as `jnp.array([[0,0,0,0,0], [1,1,1,1,1]])`.
- The output is an array of shape `(vis, dim)`: each row holds the probabilities of the `dim` categories for one visible node, given the hidden nodes.


In [17]:
hidden_vectors = jnp.array([0, 0, 0, 0, 0])
crbm.compute_pv_given_h(hidden_vectors)

Array([[0.3300435 , 0.33413428, 0.33582225],
       [0.33285573, 0.33148077, 0.33566347],
       [0.3356481 , 0.33374014, 0.33061177],
       [0.3334181 , 0.33038253, 0.33619934]], dtype=float32)
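Because each row is a distribution over the categories of one visible node, the rows should sum to 1. The short check below copies the printed values above into an array (so it runs without energax) and also reads off the most likely category per node.

```python
from jax import numpy as jnp

# Values copied from the output of compute_pv_given_h above.
pv = jnp.array([
    [0.3300435, 0.33413428, 0.33582225],
    [0.33285573, 0.33148077, 0.33566347],
    [0.3356481, 0.33374014, 0.33061177],
    [0.3334181, 0.33038253, 0.33619934],
])

# Each row is a categorical distribution over dim = 3 categories.
row_sums = pv.sum(axis=-1)
print(row_sums)

# Index of the most likely category for each visible node.
most_likely = pv.argmax(axis=-1)
print(most_likely)
```

With these randomly initialized parameters the rows are close to uniform, so the most likely category only wins by a small margin.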

## Sample

- Below we provide example code for sampling from a CRBM.
- The `sampler` object is a `CRBMGibbsSampler`, which implements the Gibbs sampling algorithm for CRBMs.
- The `sample_chains` function samples from the CRBM. It takes a batch of one-hot encoded visible states as input (in the examples below, one initial state per chain), samples from the CRBM conditional on those states, and returns the sampled states.
- The number of samples per chain and the number of chains must be passed to the `CRBMGibbsSampler` constructor. Judging from the shape of the output, these are the last two positional arguments in the calls below, in that order.
- The result `r` is a dictionary containing the sampled visible and hidden states, as well as the energy of the sampled states.
- These samples can then be used to estimate probabilities, compute energies, etc.
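For readers new to Gibbs sampling, the alternating update behind such a sampler can be sketched in a few lines of JAX. Everything below is a hypothetical stand-in: the parameters `W`, `a`, `b`, the sign convention, and the `gibbs_step` helper are assumptions for illustration and do not reproduce the internals of `CRBMGibbsSampler`.

```python
import jax
from jax import numpy as jnp

vis, hid, dim = 4, 5, 3
k_w, k_a, k_b, k_run = jax.random.split(jax.random.PRNGKey(0), 4)

# Hypothetical parameters: couplings, visible biases, hidden biases.
W = 0.01 * jax.random.normal(k_w, (vis, dim, hid))
a = 0.01 * jax.random.normal(k_a, (vis, dim))
b = 0.01 * jax.random.normal(k_b, (hid,))


def gibbs_step(v_onehot, key):
    """One block-Gibbs update: sample h given v, then v given h."""
    key_h, key_v = jax.random.split(key)
    # Binary hidden units: independent Bernoulli draws from P(h | v).
    p_h = jax.nn.sigmoid(b + jnp.einsum("id,idj->j", v_onehot, W))
    h = jax.random.bernoulli(key_h, p_h).astype(jnp.int8)
    # Categorical visible units: one softmax draw per visible node from P(v | h).
    logits = a + jnp.einsum("idj,j->id", W, h.astype(jnp.float32))
    v_idx = jax.random.categorical(key_v, logits, axis=-1)
    v_new = jax.nn.one_hot(v_idx, dim)
    return v_new, (v_new, h)


# Start one chain from the all-zeros category assignment, one-hot encoded.
v0 = jax.nn.one_hot(jnp.array([0, 0, 0, 0]), dim)
num_samples = 10
keys = jax.random.split(k_run, num_samples)

# lax.scan threads the visible state through the chain and stacks the samples.
_, (v_samples, h_samples) = jax.lax.scan(gibbs_step, v0, keys)
print(v_samples.shape, h_samples.shape)  # (10, 4, 3) (10, 5)
```

Running several chains in this sketch would amount to applying `jax.vmap` over a batch of initial states and keys.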

In [18]:
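# One chain, started from a batch containing a single one-hot visible state.
sampler = CRBMGibbsSampler(None, 0, 1000, 1)
r = sampler.sample_chains(crbm, jnp.expand_dims(visible_vectors_oh, 0), key)

# Two chains, started from a batch of two one-hot visible states.
# Note that `r` is overwritten by this second call.
visible_vectors_batch = jnp.array([[0, 0, 0, 0], [1, 1, 1, 1]])
visible_vectors_batch_oh = jax.nn.one_hot(visible_vectors_batch, dim)
sampler = CRBMGibbsSampler(None, 0, 10, 2)
r = sampler.sample_chains(crbm, visible_vectors_batch_oh, key)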

Below we print the result of the second call: a dictionary with the `energy` of each sample and, under `position`, the sampled hidden (`h`) and visible (`v`) states.

In [21]:
print(r)

{'energy': Array([[-0.16913292, -0.0085949 ,  0.04470699,  0.0292549 ,  0.1170828 ,
        -0.06824619, -0.05426199, -0.07182784, -0.05006632, -0.166739  ],
       [ 0.00611992, -0.07670631, -0.15397528, -0.0365875 , -0.02762325,
        -0.08484934,  0.00225205,  0.04843603, -0.12010484,  0.14609396]],      dtype=float32), 'position': {'h': Array([[[1, 1, 0, 1, 1],
        [0, 1, 1, 1, 1],
        [1, 1, 0, 0, 0],
        [0, 0, 1, 1, 0],
        [1, 1, 0, 1, 0],
        [0, 1, 1, 0, 0],
        [1, 1, 0, 0, 0],
        [0, 0, 1, 1, 0],
        [0, 0, 1, 0, 1],
        [1, 1, 0, 1, 1]],

       [[0, 0, 0, 0, 1],
        [0, 0, 0, 1, 0],
        [1, 1, 1, 0, 0],
        [1, 0, 0, 0, 1],
        [0, 1, 0, 0, 1],
        [0, 1, 0, 0, 0],
        [1, 1, 0, 0, 1],
        [1, 1, 1, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 0, 1, 0]]], dtype=int8), 'v': Array([[[[0., 0., 1.],
         [0., 1., 0.],
         [0., 1., 0.],
         [1., 0., 0.]],

        [[0., 0., 1.],
         [1., 0.