In [None]:
class ConditionalRBM(nn.Module):
    """RBM whose visible/hidden biases are modulated per sample by FiLM.

    A small MLP (``_conditioner``) maps each sample's flattened conditioning
    channels to a scale (gamma) and shift (beta) for the visible biases and
    the hidden biases: ``b = (1 + gamma_b) * b0 + beta_b`` and
    ``c = (1 + gamma_c) * c0 + beta_c``. The weight matrix ``W`` is shared
    and is not conditioned. ``__call__`` returns a contrastive free-energy
    loss.

    NOTE(review): the fields and helpers used below (``num_visible``,
    ``num_hidden``, ``W``, ``b``, ``c0``, ``T``, ``k``, ``_free_energy``,
    ``_gibbs_step``) come from the elided "same as before" part of this
    class. They must be restored there for the class to run.
    """

    # ... (same as before) ...
    film_width: int = 64  # hidden width of the FiLM conditioner MLP
    λ_l2: float = 1e-4    # weight of the L2 penalty on the FiLM shifts

    @nn.compact
    def _conditioner(
        self, cond_flat: jnp.ndarray
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Map flattened conditioning features to FiLM parameters.

        Args:
            cond_flat: per-sample flattened conditioning features
                (``(B, N*C)`` as built in ``__call__``).

        Returns:
            ``(gamma_b, beta_b, gamma_c, beta_c)``. Their last axes have sizes
            ``num_visible``, ``num_visible``, ``num_hidden`` and
            ``num_hidden``.

        NOTE: because this method is ``@nn.compact``, its Dense parameters are
        created the first time it runs. ``__call__`` only calls it outside
        ``RBM_ONLY`` mode. Initialise the module in a conditioned mode if
        conditioned modes will be used later.
        """
        # Keep creation order (Dense_0 then Dense_1) so the param tree matches.
        hidden = nn.tanh(nn.Dense(self.film_width)(cond_flat))
        film = nn.Dense(2 * (self.num_visible + self.num_hidden))(hidden)
        nv, nh = self.num_visible, self.num_hidden
        # Split points yield chunks of size nv, nv, nh, nh along the last axis.
        gamma_b, beta_b, gamma_c, beta_c = jnp.split(
            film, [nv, 2 * nv, 2 * nv + nh], axis=-1
        )
        return gamma_b, beta_b, gamma_c, beta_c

    def __call__(
        self, batch: jnp.ndarray, aux_vars: Dict[str, Any]
    ) -> Tuple[jnp.ndarray, Dict[str, Any]]:
        """Compute the (optionally conditioned) contrastive free-energy loss.

        Args:
            batch: channel 0 of the last axis holds the visible data and
                channels ``1:`` hold conditioning features. This is presumably
                ``(B, N, 1 + C)`` — TODO confirm against callers.
            aux_vars: dict with ``"mode"`` (a ``TrainingMode``) and ``"key"``
                (a JAX PRNG key).

        Modes:
            ``RBM_ONLY``: base parameters only; the conditioner is not called.
            ``COND_ONLY``: gradients into the base parameters are stopped, so
            only the conditioner is trained.
            any other mode: base parameters and conditioner are both trained.

        Returns:
            ``(loss, {"mode": mode, "key": new_key})``. Here ``loss`` is
            ``mean F(data) - mean F(v_k) + reg``, and ``new_key`` has not been
            consumed by this call.
        """
        mode, key = aux_vars["mode"], aux_vars["key"]
        data = batch[..., 0].astype(jnp.float32)           # (B, N)
        cond = batch[..., 1:].reshape(data.shape[0], -1)    # (B, N*C)

        # Base (unconditioned) parameters.
        W, b, c = self.W, self.b, self.c0
        reg = 0.0
        if mode != TrainingMode.RBM_ONLY:
            if mode == TrainingMode.COND_ONLY:
                # Freeze the base RBM: only the conditioner gets gradients.
                W = jax.lax.stop_gradient(W)
                b = jax.lax.stop_gradient(b)
                c = jax.lax.stop_gradient(c)
            gamma_b, beta_b, gamma_c, beta_c = self._conditioner(cond)
            b = (1 + gamma_b) * b + beta_b  # per-sample visible biases
            c = (1 + gamma_c) * c + beta_c  # per-sample hidden biases
            # L2 penalty on the shifts. This is the batch mean of each
            # sample's squared norm, so it scales like the batch-mean free
            # energies below and λ_l2 does not depend on the batch size.
            reg = self.λ_l2 * jnp.mean(
                jnp.sum(beta_b ** 2, axis=-1) + jnp.sum(beta_c ** 2, axis=-1)
            )

        # Positive phase: free energy of the observed data.
        f_pos = jnp.mean(self._free_energy(data, W, b, c))

        # Negative phase: k Gibbs sweeps starting from Bernoulli(0.5) noise.
        # NOTE(review): textbook CD-k starts the chain at the data instead.
        # Split three ways so the key handed back to the caller is never one
        # already consumed here (the init draw or the chain).
        key, k_init, k_chain = jax.random.split(key, 3)
        v0 = jax.random.bernoulli(k_init, 0.5, data.shape).astype(jnp.float32)

        def sweep(state, _):
            return self._gibbs_step(state, W, b, c, temp=self.T), None

        (v_k, _), _ = jax.lax.scan(sweep, (v0, k_chain), None, length=self.k)
        # Negative samples are constants in the contrastive-divergence gradient.
        v_k = jax.lax.stop_gradient(v_k)
        f_neg = jnp.mean(self._free_energy(v_k, W, b, c))

        loss = (f_pos - f_neg) + reg
        return loss, {"mode": mode, "key": key}