Skip to content

Commit

Permalink
transpose of vprim now handles symbolic zeros; see patrick-kidger/opt…
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Mar 2, 2024
1 parent c79c393 commit 46cd9dd
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions equinox/internal/_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,11 @@ def filter_primitive_bind(prim: jax.core.Primitive, *args) -> PyTree:


# Useful helper for JVP rules of higher-order primitives.
def materialise_zeros(primal, tangent):
if tangent is None and is_array_like(primal):
def materialise_zeros(primal, tangent, allow_struct=False):
arraylike = is_array_like(primal)
if allow_struct:
arraylike = arraylike or isinstance(primal, jax.ShapeDtypeStruct)
if tangent is None and arraylike:
tangent = _zero_from_primal(primal)
return ad.instantiate_zeros(tangent)
else:
Expand Down Expand Up @@ -410,6 +413,10 @@ def _vprim_transpose(
axis_size=__axis_size,
axis_name=__axis_name,
)
if prim.multiple_results:
cts = tuple(None if type(c) is ad.Zero else c for c in cts)
else:
cts = None if type(cts) is ad.Zero else cts
return transpose(cts, *inputs)


Expand Down

0 comments on commit 46cd9dd

Please sign in to comment.