In [None]:
import os
import numpy as np

import jax
import jax.lax
from jax.random import PRNGKey
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state

import functools

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import matplotlib.patches as mpatches

data_dir = "./data"
device = jax.devices(backend="cpu")[0]

# Report where the data is read from and which JAX device is used.
print(
    f"Data resides in        : {data_dir}",
    f"Training model on      : {device!s}",
    sep="\n",
)

In [None]:
from typing import Tuple, Dict, Any
from typing import Optional


class DoubleRBM(nn.Module):
    """Two binary restricted Boltzmann machines with the same layer sizes.

    The ``*_amp`` parameters form the RBM that is trained by ``__call__``.
    Training uses a persistent-contrastive-divergence (PCD) surrogate loss on
    data measured in the computational basis ``'Z' * n_visible``. The same
    RBM is used for sampling in ``generate``. The ``*_pha`` parameters are
    registered with identical shapes but are not read by any method here.

    Attributes:
        n_visible: Number of visible units.
        n_hidden: Number of hidden units.
        k: Number of Gibbs steps run on the model chains per loss evaluation.
        n_chains: Number of chains drawn when no persistent chains are supplied.
    """

    n_visible: int
    n_hidden: int
    k: int = 1
    n_chains: int = 100

    def setup(self):
        # Bind each parameter to an attribute so the helpers can read it.
        # Flax modules expose no ``self.params`` mapping. The registered names,
        # and therefore the parameter tree, are unchanged.
        self.W_amp = self.param("W_amp", nn.initializers.normal(0.01), (self.n_visible, self.n_hidden))
        self.b_amp = self.param("b_amp", nn.initializers.zeros, (self.n_visible,))
        self.c_amp = self.param("c_amp", nn.initializers.zeros, (self.n_hidden,))

        self.W_pha = self.param("W_pha", nn.initializers.normal(0.01), (self.n_visible, self.n_hidden))
        self.b_pha = self.param("b_pha", nn.initializers.zeros, (self.n_visible,))
        self.c_pha = self.param("c_pha", nn.initializers.zeros, (self.n_hidden,))

    def __call__(self, data_dict: Dict[str, jnp.ndarray], aux_vars: Dict[str, Any]) -> Tuple[jnp.ndarray, Dict[str, Any]]:
        """Compute the PCD loss of the amplitude RBM for one batch.

        Args:
            data_dict: Mapping from measurement-basis string to a batch of
                visible configurations. Only a single entry keyed by the
                computational basis ``'Z' * n_visible`` is supported.
            aux_vars: Must contain ``"random_key"`` (a JAX PRNG key). May
                contain ``"v_persistent"``, the chains returned by a previous
                call. Any other keys are ignored.

        Returns:
            ``(loss, new_aux_vars)``. ``new_aux_vars`` holds the updated
            ``"v_persistent"`` chains and the advanced ``"random_key"``. Feed
            it back in as ``aux_vars`` on the next call.

        Raises:
            ValueError: If ``aux_vars`` has no ``"random_key"``, or if the batch
                is not a single computational-basis batch.
        """
        random_key = aux_vars.get("random_key")
        if random_key is None:
            raise ValueError("aux_vars must contain a 'random_key' entry.")
        persistent_chains = aux_vars.get("v_persistent", None)

        comp_basis = "Z" * self.n_visible
        if len(data_dict) != 1 or comp_basis not in data_dict:
            raise ValueError("Encountered batch of unsupported basis.")

        pcd_loss, persistent_chains, random_key = self._loss_amp(
            data_dict[comp_basis], persistent_chains, random_key
        )
        return pcd_loss, {"v_persistent": persistent_chains, "random_key": random_key}

    def _loss_amp(self, data_batch, persistent_chains, random_key):
        """Return ``(loss, chains, random_key)`` for the amplitude RBM.

        Chain handling:
        - If ``persistent_chains`` is None, ``self.n_chains`` chains are first
          initialised with independent fair-coin visible units.
        - The chains are then advanced ``self.k`` Gibbs steps.
        - The resulting samples are returned as the persistent chains for the
          next call.

        The loss is ``mean F(data) - mean F(model)``, where ``F`` is the free
        energy. Model samples are wrapped in ``stop_gradient`` so only the
        free-energy terms are differentiated.
        """
        W, b, c = self.W_amp, self.b_amp, self.c_amp

        if persistent_chains is None:
            random_key, init_key = jax.random.split(random_key)
            persistent_chains = jax.random.bernoulli(
                init_key, p=0.5, shape=(self.n_chains, self.n_visible)
            )

        model_batch, random_key = self._gibbs_sample(W, b, c, persistent_chains, random_key, k=self.k)
        # Samples are constants w.r.t. the parameters; stop before computing the loss.
        model_batch = jax.lax.stop_gradient(model_batch)

        free_energy_data = self._free_energy(W, b, c, data_batch)
        free_energy_model = self._free_energy(W, b, c, model_batch)
        pcd_loss = jnp.mean(free_energy_data) - jnp.mean(free_energy_model)

        return pcd_loss, model_batch, random_key

    @staticmethod
    def _free_energy(W, b, c, v):
        """Free energy ``F(v) = -v.b - sum_j softplus((v W + c)_j)``.

        Computed per row of ``v``.
        """
        visible_term = jnp.dot(v, b)
        hidden_term = jnp.sum(jax.nn.softplus(v @ W + c), axis=-1)
        return -visible_term - hidden_term

    @staticmethod
    def _gibbs_step(i, state, W, b, c, T=1.0):
        """One block-Gibbs sweep v -> h -> v at temperature ``T``.

        The signature matches a ``lax.fori_loop`` body; ``i`` is unused.
        ``state`` is ``(v, key)``. Returns ``(v_new, key)``, where ``v_new``
        holds float32 0/1 samples.
        """
        v, key = state

        # Fresh subkeys for each draw; the first one is carried to the next step.
        key, h_key, v_key = jax.random.split(key, 3)

        h_probs = jax.nn.sigmoid((v @ W + c) / T)
        h = jax.random.bernoulli(h_key, h_probs).astype(jnp.float32)

        v_probs = jax.nn.sigmoid((h @ W.T + b) / T)
        v = jax.random.bernoulli(v_key, v_probs).astype(jnp.float32)
        return v, key

    @staticmethod
    def _gibbs_sample(W, b, c, v_init, rng, k=1, T=1.0):
        """Run ``k`` Gibbs steps from ``v_init`` at temperature ``T``.

        Returns ``(v_final, rng)``.
        """
        # lax.fori_loop requires the loop carry to keep its dtype. Each step
        # yields float32, so the initial state is cast to match (it may be bool).
        v_init = jnp.asarray(v_init, dtype=jnp.float32)
        # fori_loop traces the body once instead of unrolling a Python loop,
        # so compile cost does not grow with k.
        body_fun = lambda i, state: DoubleRBM._gibbs_step(i, state, W, b, c, T)
        v_final, key = jax.lax.fori_loop(0, k, body_fun, (v_init, rng))
        return v_final, key

    @staticmethod
    def _amp_params(params):
        """Return ``(W_amp, b_amp, c_amp)`` from a parameter mapping.

        Accepts either the bare parameter dict or a ``{"params": ...}``
        variables dict, as returned by Flax ``Module.init``.
        """
        if "params" in params:
            params = params["params"]
        return params["W_amp"], params["b_amp"], params["c_amp"]

    @staticmethod
    def _annealing_step(i, state, params, T_schedule):
        """``fori_loop`` body: one Gibbs step at temperature ``T_schedule[i]``.

        ``state`` is ``(v, rng)``. ``params`` is any mapping accepted by
        ``_amp_params``.
        """
        W, b, c = DoubleRBM._amp_params(params)
        return DoubleRBM._gibbs_step(i, state, W, b, c, T_schedule[i])

    # nn.nowrap stops Flax from wrapping this method (no implicit scope/params injection);
    # the caller passes the parameters explicitly.
    @nn.nowrap
    def generate(self, params: dict, n_samples: int, T_schedule: jnp.ndarray, rng: PRNGKey) -> jnp.ndarray:
        """Draw samples from the amplitude RBM by annealed Gibbs sampling.

        Args:
            params: Mapping containing ``"W_amp"``, ``"b_amp"`` and ``"c_amp"``,
                either directly or nested under ``"params"``.
            n_samples: Number of independent chains, i.e. returned samples.
            T_schedule: 1-D sequence of temperatures. One Gibbs step is run
                per entry, in order.
            rng: JAX PRNG key.

        Returns:
            Array of shape ``(n_samples, n_visible)`` with float32 0/1 entries.
        """
        # A jnp array can be indexed with the traced loop counter; a list cannot.
        T_schedule = jnp.asarray(T_schedule)

        rng, init_key = jax.random.split(rng)
        v = jax.random.bernoulli(init_key, p=0.5, shape=(n_samples, self.n_visible)).astype(jnp.float32)

        body_fun = functools.partial(DoubleRBM._annealing_step, params=params, T_schedule=T_schedule)

        # fori_loop keeps this traceable when generate itself is jit-compiled.
        v_final, _ = jax.lax.fori_loop(0, T_schedule.shape[0], body_fun, (v, rng))
        return v_final