Skip to content

patil-suraj/stable-diffusion-jax

Repository files navigation

TODOs:

  • Finish implementing the UNet2D model in modeling_unte2d.py. Port weights of any existing LDM unet from diffusers and verify equivalence. I've added the skleton of modules that we need to implement in the file.
  • Adapt the PNDMScheduler from diffusers for JAX: Use jnp arrays and make it stateless.
  • Add the KL module from (here)[https://github.dev/CompVis/stable-diffusion] in modeling_vae.py file. For inference we don't really need it, but would be nice to have for completeness. Port the weights of any existing KL VAE and verify equivalence.
  • Add an inference loop in pipeline_stabel_diffusion. We should able to jit/pmap the loop to deploy on TPUs.

About

No description, website, or topics provided.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published