# In [None]:
from __future__ import annotations
from typing import Any, cast
from dataclasses import dataclass
from functools import partial
import jax
import jax.numpy as jnp
from jax.tree_util import register_dataclass
import flax.linen as nn

import hodel

class Encoder(nn.Module):
    """Encode a flattened pixel vector into a prototype-space latent.

    A two-layer MLP: Dense -> leaky ReLU -> Dense (linear output).

    Attributes:
        hidden_size: width of the hidden layer.
        prototype_size: dimensionality of the returned latent vector.
    """

    hidden_size: int
    prototype_size: int

    @nn.compact
    def __call__(self, pixels):
        # Layers are created in this order so Flax auto-names them
        # Dense_0 (hidden) and Dense_1 (output).
        hidden = nn.Dense(self.hidden_size)(pixels)
        hidden = nn.leaky_relu(hidden)
        latent = nn.Dense(self.prototype_size)(hidden)
        return latent


class Classifier(nn.Module):
    """Map a prototype-space latent to class logits.

    A two-layer MLP: Dense -> leaky ReLU -> Dense. The output layer has no
    activation, so the result is unnormalized scores, not a one-hot vector.

    Attributes:
        hidden_size: width of the hidden layer.
        class_size: number of output logits.
    """

    hidden_size: int
    class_size: int

    @nn.compact
    def __call__(self, prototype):
        # Creation order fixes the auto-generated names Dense_0 / Dense_1.
        features = nn.Dense(self.hidden_size)(prototype)
        features = nn.leaky_relu(features)
        logits = nn.Dense(self.class_size)(features)
        return logits


class Decoder(nn.Module):
    """Decode a prototype-space latent back into a pixel vector.

    A two-layer MLP: Dense -> leaky ReLU -> Dense -> sigmoid, so every
    output pixel lies in (0, 1).

    Attributes:
        hidden_size: width of the hidden layer.
        pixel_size: length of the returned pixel vector.
    """

    hidden_size: int
    pixel_size: int

    @nn.compact
    def __call__(self, prototype):
        # Creation order fixes the auto-generated names Dense_0 / Dense_1.
        features = nn.leaky_relu(nn.Dense(self.hidden_size)(prototype))
        raw_pixels = nn.Dense(self.pixel_size)(features)
        return nn.sigmoid(raw_pixels)


@register_dataclass
@dataclass
class MNISTParams:
    """Parameters of the MNIST energy model, registered as a JAX pytree.

    Built by ``create_mnist_params`` and passed to ``MNIST.get_energy`` as
    ``Theta``.
    """

    encoder_params: Any  # Flax variables dict returned by Encoder.init
    classifier_params: Any  # Flax variables dict returned by Classifier.init
    decoder_params: Any  # Flax variables dict returned by Decoder.init
    # Latent prototypes, one row per class: (num_classes, prototype_size);
    # mixed by class probabilities on the generation path.
    prototypes: jax.Array  # Latent


def create_mnist_params(
    key, hidden_size=256, prototype_size=64, *, num_classes=10, pixel_size=784
):
    """Initialize all parameters of the MNIST energy model.

    Args:
        key: JAX PRNG key; split into sub-keys, one per component.
        hidden_size: hidden width of the encoder, classifier and decoder MLPs.
        prototype_size: dimensionality of the latent / prototype space.
        num_classes: number of classes, i.e. classifier outputs and prototype
            rows. Defaults to 10.
        pixel_size: length of the flattened image vector produced by the
            decoder and consumed by the encoder. Defaults to 784 (28x28).

    Returns:
        MNISTParams holding the variables returned by each module's ``init``
        and a ``(num_classes, prototype_size)`` prototype matrix.
    """
    # Split into 5 even though only 4 sub-keys are used (keys[4] is spare):
    # changing the split count can change every derived key and therefore
    # the initial weights of existing runs.
    keys = jax.random.split(key, 5)

    # Flax modules are stateless descriptions; the weights come from `init`.
    encoder_module = Encoder(hidden_size, prototype_size)
    classifier_module = Classifier(hidden_size, num_classes)
    decoder_module = Decoder(hidden_size, pixel_size)

    # Dummy inputs only fix the input shapes for initialization.
    dummy_pixels = jnp.zeros(pixel_size)
    dummy_prototype = jnp.zeros(prototype_size)

    encoder_params = encoder_module.init(keys[0], dummy_pixels)
    classifier_params = classifier_module.init(keys[1], dummy_prototype)
    decoder_params = decoder_module.init(keys[2], dummy_prototype)

    # Small random prototypes, one row per class: (num_classes, prototype_size).
    prototypes = jax.random.normal(keys[3], (num_classes, prototype_size)) * 0.1

    return MNISTParams(
        encoder_params=encoder_params,
        classifier_params=classifier_params,
        decoder_params=decoder_params,
        prototypes=prototypes,
    )


@partial(
    register_dataclass,
    data_fields=[],
    meta_fields=["encoder", "classifier", "decoder"],
)
@dataclass
class MNIST:
    """
    MNIST Energy

    Scores how consistent a pixel vector is with a class assignment. It
    combines a recognition path (pixels -> latent -> class) with a generation
    path (class -> latent -> pixels).

    Classification:
    xb: pixels (flattened image vector)
    xf: class logits; a one-hot vector is accepted but is treated as logits
        (it is passed through a softmax)

    The Flax modules are registered as static pytree metadata. Their weights
    are supplied separately through ``Theta``.
    """

    encoder: nn.Module
    classifier: nn.Module
    decoder: nn.Module

    def get_energy(
        self,
        xf: jax.Array,
        xb: Any = None,
        Theta: MNISTParams | None = None,
        aux: Any = None,
    ) -> jax.Array:
        """Return the scalar energy of the pair ``(xb, xf)`` under ``Theta``.

        Args:
            xf: Class logits. Expected to have one entry per prototype row
                (10 for the default parameters).
            xb: Pixel vector fed to the encoder. Required despite the default.
            Theta: Network parameters and class prototypes. Required.
            aux: Unused here.

        Raises:
            ValueError: If ``Theta`` or ``xb`` is missing.
        """
        if Theta is None:
            raise ValueError("MNIST requires parameters")
        if xb is None:
            # Fail early with a clear message instead of deep inside Flax.
            raise ValueError("MNIST requires pixels (xb)")
        return self._joint_energy(xb, xf, Theta)

    def _joint_energy(
        self, pixels: jax.Array, class_logits: jax.Array, Theta: MNISTParams
    ) -> jax.Array:
        """Compute energy measuring pixel-class consistency.

        energy = generation_error + recognition_error
                 + 0.5 * latent_consistency + 0.1 * entropy
        """
        # Class distribution, shared by the generation path and the prior.
        class_probs = jax.nn.softmax(class_logits)
        # Exact log-probabilities; avoids the bias of log(softmax + eps).
        class_log_probs = jax.nn.log_softmax(class_logits)

        # RECOGNITION PATH: pixels -> z_enc -> class_pred
        z_enc = cast(jax.Array, self.encoder.apply(Theta.encoder_params, pixels))
        class_pred = cast(
            jax.Array, self.classifier.apply(Theta.classifier_params, z_enc)
        )
        recognition_error = jnp.sum((class_pred - class_logits) ** 2)

        # GENERATION PATH: class -> z_dec -> pixels_pred
        # Probability-weighted mix of prototype rows, e.g. (10,) @ (10, 64) -> (64,).
        z_dec = class_probs @ Theta.prototypes
        pixels_pred = cast(jax.Array, self.decoder.apply(Theta.decoder_params, z_dec))
        generation_error = jnp.sum((pixels_pred - pixels) ** 2)

        # LATENT CONSISTENCY: z_enc should match z_dec
        latent_consistency = jnp.sum((z_enc - z_dec) ** 2)

        # CLASS PRIOR: low entropy encourages a discrete class assignment
        entropy = -jnp.sum(class_probs * class_log_probs)

        # Weighted sum of terms
        return (
            generation_error
            + recognition_error
            + 0.5 * latent_consistency
            + 0.1 * entropy
        )


# Build the energy model and a randomly initialized parameter set.
hidden_size, prototype_size = 256, 64
key = jax.random.PRNGKey(0)

# Stateless Flax module definitions; their weights live in `theta`.
encoder_module = Encoder(hidden_size=hidden_size, prototype_size=prototype_size)
classifier_module = Classifier(hidden_size=hidden_size, class_size=10)
decoder_module = Decoder(hidden_size=hidden_size, pixel_size=784)
energy_model = MNIST(
    encoder=encoder_module,
    classifier=classifier_module,
    decoder=decoder_module,
)

theta = create_mnist_params(key, hidden_size, prototype_size)

# Out: Array(206.05923, dtype=float32)