# Autoencoder + MNIST + JAX

Run:

1. `uv init`
2. `uv add "jax[cuda12]" numpy tensorflow_datasets flax tensorflow clu optax matplotlib`
3. `uv run main.py`

A minimal example of an autoencoder trained on MNIST using JAX. Both the encoder and the decoder are dense layers.

Figure: MNIST digits, "Ground truth" vs. "Reconstructed".
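To show what "dense encoder and dense decoder" means in practice, here is a minimal sketch of such an autoencoder and a gradient-descent training step in plain JAX. It is not the code in `main.py`. The layer sizes (784 to 32 to 784), the ReLU/sigmoid activations, the mean-squared-error loss, plain SGD in place of an Optax optimizer, and the random stand-in batch used instead of real MNIST images are all assumptions made to keep the example self-contained. The helper `demo` is hypothetical.

```python
import jax
import jax.numpy as jnp

# Assumed sizes: 28x28 MNIST images flattened to 784 pixels, 32-dim latent code.
INPUT_DIM = 784
LATENT_DIM = 32
LEARNING_RATE = 1.0  # assumed step size for plain SGD


def init_params(key, input_dim=INPUT_DIM, latent_dim=LATENT_DIM):
    """Initialise one dense encoder layer and one dense decoder layer."""
    k_enc, k_dec = jax.random.split(key)
    return {
        "enc_w": jax.random.normal(k_enc, (input_dim, latent_dim)) * input_dim ** -0.5,
        "enc_b": jnp.zeros(latent_dim),
        "dec_w": jax.random.normal(k_dec, (latent_dim, input_dim)) * latent_dim ** -0.5,
        "dec_b": jnp.zeros(input_dim),
    }


def encode(params, x):
    # Dense layer + ReLU: flattened pixels -> latent code.
    return jax.nn.relu(x @ params["enc_w"] + params["enc_b"])


def decode(params, z):
    # Dense layer + sigmoid: latent code -> pixel intensities in [0, 1].
    return jax.nn.sigmoid(z @ params["dec_w"] + params["dec_b"])


def loss_fn(params, x):
    # Mean squared reconstruction error over all pixels in the batch.
    x_hat = decode(params, encode(params, x))
    return jnp.mean((x_hat - x) ** 2)


@jax.jit
def train_step(params, x):
    # One plain SGD step; returns updated params and the pre-update loss.
    loss, grads = jax.value_and_grad(loss_fn)(params, x)
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


def demo(steps=200):
    """Hypothetical helper: train on a random batch, return (initial, final) loss."""
    params = init_params(jax.random.PRNGKey(0))
    # Random stand-in for a batch of 64 MNIST images scaled to [0, 1].
    x = jax.random.uniform(jax.random.PRNGKey(1), (64, INPUT_DIM))
    initial = float(loss_fn(params, x))
    for _ in range(steps):
        params, _ = train_step(params, x)
    final = float(loss_fn(params, x))
    return initial, final
```

`demo()` returns the reconstruction loss before and after training. Feeding real MNIST pixels (for example loaded through `tensorflow_datasets`) and swapping the hand-written SGD update for an Optax optimizer would bring the sketch closer to the dependency list in the install command.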