Skip to content

An Toy Object Oriented Deep Learning framework based on Jax

License

Notifications You must be signed in to change notification settings

sullivancolin/lorax

Repository files navigation

Lorax: a Toy OOP Deep Learning Framework using Jax and Pydantic

Inspired by Joel Grus' Live Coding a Deep Learning Library Original Repo Here

Features

  • Replace numpy backend with Jax
  • Automatic calculation of gradients using Jax Autograd via jax.grad
  • Automatic Pytree class registration via inheritance
  • Allow for compiliation to GPU or TPU
  • layers are immutiable pydantic models with simple json definition
  • Seamlessly parallelize from single instance inference to batch inference with jax.vmap
  • Additional activation layers and loss funcitons
  • Track progress with wandb
  • Includes Dropout
  • LSTM and Bidirectional LSTM
  • Frozen Linear, Embedding, LSTMcell layers
  • Experiment Config system with json schema compliant

About

An Toy Object Oriented Deep Learning framework based on Jax

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published