Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

eqx.filter_{vmap,pmap}(out=...) not experimental #124

Merged
merged 2 commits into from
Jul 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ name: Run tests

on:
pull_request:
schedule:
- cron: "0 2 * * 6"

jobs:
run-test:
Expand Down
4 changes: 4 additions & 0 deletions docs/api/filtering/filtered-transformations.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,7 @@ Practically speaking these are usually the only kind of filtering you ever have
---

::: equinox.filter_pmap

---

::: equinox.filter_eval_shape
3 changes: 2 additions & 1 deletion equinox/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from . import experimental, nn
from .eval_shape import filter_eval_shape
from .filters import (
combine,
filter,
Expand All @@ -18,4 +19,4 @@
from .vmap_pmap import filter_pmap, filter_vmap


__version__ = "0.5.3"
__version__ = "0.5.4"
14 changes: 11 additions & 3 deletions equinox/compile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,22 @@ class Static(Module):
value: Any = static_field()


def strip_wrapped_partial(fun):
def _strip_wrapped_partial(fun):
if hasattr(fun, "__wrapped__"): # ft.wraps
return strip_wrapped_partial(fun.__wrapped__)
return _strip_wrapped_partial(fun.__wrapped__)
if isinstance(fun, ft.partial):
return strip_wrapped_partial(fun.func)
return _strip_wrapped_partial(fun.func)
return fun


def get_fun_names(fun):
fun = _strip_wrapped_partial(fun)
try:
return fun.__name__, fun.__qualname__
except AttributeError:
return type(fun).__name__, type(fun).__qualname__


def compile_cache(fun):
@ft.lru_cache(maxsize=None)
def _cache(leaves, treedef):
Expand Down
28 changes: 28 additions & 0 deletions equinox/eval_shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import functools as ft
from typing import Callable

import jax

from .compile_utils import Static
from .filters import combine, is_array_like, partition


def _filter(x):
return isinstance(x, jax.ShapeDtypeStruct) or is_array_like(x)


def filter_eval_shape(fun: Callable, *args, **kwargs):
"""As `jax.eval_shape`, but allows any Python object as inputs and outputs.

(`jax.eval_shape` is constrained to only work with JAX arrays, Python float/int/etc.)
"""

def _fn(_static, _dynamic):
_args, _kwargs = combine(_static, _dynamic)
_out = fun(*_args, **_kwargs)
_dynamic_out, _static_out = partition(_out, _filter)
return _dynamic_out, Static(_static_out)

dynamic, static = partition((args, kwargs), _filter)
dynamic_out, static_out = jax.eval_shape(ft.partial(_fn, static), dynamic)
return combine(dynamic_out, static_out.value)
17 changes: 9 additions & 8 deletions equinox/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,19 @@

from .compile_utils import (
compile_cache,
get_fun_names,
hashable_combine,
hashable_partition,
Static,
strip_wrapped_partial,
)
from .custom_types import BoolAxisSpec, PyTree, sentinel, TreeDef
from .doc_utils import doc_strip_annotations
from .filters import combine, filter, is_array, partition
from .filters import combine, is_array, partition
from .module import Module, module_update_wrapper


@compile_cache
def _filter_jit_cache(unwrapped_fun, **jitkwargs):
@ft.partial(jax.jit, static_argnums=1, **jitkwargs)
@ft.wraps(unwrapped_fun)
def _filter_jit_cache(fun_names, **jitkwargs):
def fun_wrapped(dynamic, static):
dynamic_fun, dynamic_spec = dynamic
(
Expand All @@ -40,7 +38,11 @@ def fun_wrapped(dynamic, static):
dynamic_out, static_out = partition(out, filter_out)
return dynamic_out, Static(static_out)

return fun_wrapped
fun_name, fun_qualname = fun_names
fun_wrapped.__name__ = fun_name
fun_wrapped.__qualname__ = fun_qualname

return jax.jit(fun_wrapped, static_argnums=1, **jitkwargs)


class _JitWrapper(Module):
Expand Down Expand Up @@ -266,11 +268,10 @@ def apply(f, x):
new_style = True
# ~Backward compatibility

unwrapped_fun = filter(strip_wrapped_partial(fun), filter_fn, inverse=True)
dynamic_fun, static_fun_leaves, static_fun_treedef = hashable_partition(
fun, filter_fn
)
cached = _filter_jit_cache(unwrapped_fun, **jitkwargs)
cached = _filter_jit_cache(get_fun_names(fun), **jitkwargs)

jit_wrapper = _JitWrapper(
_new_style=new_style,
Expand Down