tomondev/autoencoder-jax-mnist

Autoencoder + MNIST + JAX

Run:

uv init
uv add "jax[cuda12]" numpy tensorflow_datasets flax tensorflow clu optax matplotlib
uv run main.py

Minimal example of an autoencoder trained on MNIST using JAX. Both the encoder and the decoder are built from dense (fully connected) layers.
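The sketch below shows the kind of dense autoencoder described above. It is not the repo's `main.py`: it uses plain JAX instead of the Flax and Optax libraries the project installs, so it stays self-contained. The layer sizes (784 → 128 → 32 → 128 → 784), the ReLU/sigmoid activations, the mean-squared-error loss, plain SGD with a learning rate of 1.0, and the random stand-in batch are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Hypothetical layer sizes: 784 = 28 * 28 flattened MNIST pixels, 32-d latent code.
ENC_SIZES = [784, 128, 32]
DEC_SIZES = [32, 128, 784]


def init_dense_stack(key, sizes):
    """Initialise a stack of dense layers as a list of (W, b) pairs."""
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        # He-style scaling keeps ReLU activations at a reasonable magnitude.
        w = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        params.append((w, jnp.zeros(n_out)))
    return params


def apply_dense_stack(params, x):
    """Dense layers with ReLU in between; the last layer is left linear."""
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def init_autoencoder(key):
    k_enc, k_dec = jax.random.split(key)
    return {"enc": init_dense_stack(k_enc, ENC_SIZES),
            "dec": init_dense_stack(k_dec, DEC_SIZES)}


def reconstruct(params, x):
    z = apply_dense_stack(params["enc"], x)       # encoder: 784 -> 32 latent code
    logits = apply_dense_stack(params["dec"], z)  # decoder: 32 -> 784
    return jax.nn.sigmoid(logits)                 # pixel intensities in [0, 1]


def loss_fn(params, x):
    # Mean squared reconstruction error (an assumed loss choice).
    return jnp.mean((reconstruct(params, x) - x) ** 2)


@jax.jit
def train_step(params, x, lr=1.0):
    loss, grads = jax.value_and_grad(loss_fn)(params, x)
    # Plain SGD: update every leaf of the parameter pytree.
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss


# Stand-in batch of 64 random "images" in [0, 1); real training would use MNIST.
params = init_autoencoder(jax.random.PRNGKey(0))
x = jax.random.uniform(jax.random.PRNGKey(1), (64, 784))
losses = []
for step in range(200):
    params, loss = train_step(params, x)
    losses.append(float(loss))
```

For a real training loop, the more idiomatic approach is to replace the hand-written layers with Flax modules and the SGD update with an Optax optimizer such as `optax.adam`. The overall structure (encode, decode, compare the output against the input) stays the same.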

Figures (images not preserved): "Ground truth" MNIST digits and the corresponding "Reconstructed" digits.

About

Autoencoder trained on MNIST in JAX
