In [37]:
class RBM(nn.Module):
    """Bernoulli-Bernoulli Restricted Boltzmann Machine as a Flax linen module.

    Attributes:
        n_visible: number of visible units.
        n_hidden: number of hidden units.

    The sampling methods draw randomness from the ``"sample"`` RNG stream, so the
    module must be applied/bound with ``rngs={"sample": key}`` to use them.
    """

    n_visible: int
    n_hidden: int

    def setup(self):
        # W couples visible and hidden units; b / c are the visible / hidden biases.
        self.W = self.param("W", nn.initializers.normal(0.01), (self.n_visible, self.n_hidden))
        self.b = self.param("b", nn.initializers.zeros, (self.n_visible,))
        self.c = self.param("c", nn.initializers.zeros, (self.n_hidden,))

    def _sample_hidden(self, v, T=1.0):
        """Sample hidden units given visible units at temperature ``T``.

        Returns:
            ``(h_sample, h_probs)`` where ``h_probs = sigmoid((v @ W + c) / T)`` and
            ``h_sample`` is a boolean Bernoulli draw with those probabilities.
        """
        key = self.make_rng("sample")
        logits = (v @ self.W + self.c) / T
        h_probs = jax.nn.sigmoid(logits)
        h_sample = jax.random.bernoulli(key, h_probs)
        return h_sample, h_probs

    def _sample_visible(self, h, T=1.0):
        """Sample visible units given hidden units at temperature ``T``.

        Returns:
            ``(v_sample, v_probs)`` where ``v_probs = sigmoid((h @ W.T + b) / T)`` and
            ``v_sample`` is a boolean Bernoulli draw with those probabilities.
        """
        key = self.make_rng("sample")
        logits = (h @ self.W.T + self.b) / T
        v_probs = jax.nn.sigmoid(logits)
        v_sample = jax.random.bernoulli(key, v_probs)
        return v_sample, v_probs

    def sample_gibbs(self, v0_sample, k=1, T=1.0):
        """Run ``k`` steps of block Gibbs sampling (v -> h -> v) at temperature ``T``.

        Returns the final visible sample. After at least one step it is the boolean
        output of ``jax.random.bernoulli``. With ``k=0`` the input is returned unchanged.
        """
        v = v0_sample
        for _ in range(k):
            h, _ = self._sample_hidden(v, T)
            v, _ = self._sample_visible(h, T)
        return v

    def free_energy(self, v):
        """Return the free energy ``F(v) = -v.b - sum_j softplus((v @ W + c)_j)``.

        Reducing over the last axis accepts both a single visible vector (giving a
        scalar) and a batch of shape ``(batch, n_visible)`` (giving ``(batch,)``).
        """
        visible_term = jnp.dot(v, self.b)
        hidden_term = jnp.sum(jax.nn.softplus(v @ self.W + self.c), axis=-1)
        return -visible_term - hidden_term

    def generate(self, params, n_samples=16, T_schedule=None, seed=0):
        """Generate samples by running one Gibbs step per temperature in ``T_schedule``.

        Args:
            params: the ``"params"`` collection, e.g. ``variables["params"]`` from ``init``.
            n_samples: number of independent chains, started from uniform random bits.
            T_schedule: iterable of temperatures, with one Gibbs step at each (for
                example a decreasing annealing schedule). ``None`` means a single
                step at ``T=1.0``.
            seed: integer seed for ``jax.random.PRNGKey``.

        Returns:
            Array of shape ``(n_samples, n_visible)``. After at least one step it is
            the boolean Gibbs sample. With an empty schedule it is the float32
            initial state.
        """
        if T_schedule is None:
            T_schedule = (1.0,)

        # Separate keys for the initial state and the sampling stream, so no key is reused.
        init_key, sample_key = jax.random.split(jax.random.PRNGKey(seed))
        # Bind once. On a bound module, every make_rng("sample") call returns a new
        # key, so each Gibbs step below gets independent randomness. (Flax Modules
        # have no `replace_rngs`, so re-keying is neither possible nor needed.)
        rbm = self.bind({"params": params}, rngs={"sample": sample_key})
        v = jax.random.bernoulli(init_key, shape=(n_samples, self.n_visible)).astype(jnp.float32)

        for T in T_schedule:
            v = rbm.sample_gibbs(v, k=1, T=T)

        return v

    def __call__(self, v):
        """Identity pass-through that serves as the default entry point for ``init``/``apply``."""
        return v  # flax linen requires a __call__ method