Added max and avg pool, alongside tests and docs #77

Merged · 1 commit · May 6, 2022

Conversation

Benjamin-Walker (Contributor)

I currently use einops.rearrange so that the pooling layers and conv layers take the same tensor format as input, BCHW. For some reason lax.conv_general_dilated and lax.reduce_window use BCHW and BHWC respectively. I can do this without einops if you would prefer not to have the extra dependency.
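For concreteness, here is a minimal sketch of the approach described above (not the exact PR code; a batched BCHW array is used purely for illustration): rearrange into a channels-last layout for lax.reduce_window, pool, then rearrange back.

```python
import einops
import jax.numpy as jnp
from jax import lax

def max_pool_bchw(x, kernel_size=2, stride=2):
    # Move channels last for lax.reduce_window, pool, then move them back
    # so the output layout matches the conv layers.
    x = einops.rearrange(x, "b c h w -> b h w c")
    x = lax.reduce_window(
        x,
        -jnp.inf,
        lax.max,
        window_dimensions=(1, kernel_size, kernel_size, 1),
        window_strides=(1, stride, stride, 1),
        padding="VALID",
    )
    return einops.rearrange(x, "b h w c -> b c h w")

x = jnp.arange(1 * 3 * 4 * 4, dtype=jnp.float32).reshape(1, 3, 4, 4)
print(max_pool_bchw(x).shape)  # (1, 3, 2, 2)
```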

patrick-kidger (Owner) left a comment

Yes, I'd prefer to avoid the additional dependency on einops. Good thinking to have the conv/pooling layers use the same dimension ordering. (Honestly the lax.conv_{...} options are a mess of conflicting defaults.)

Can you add some tests for the backward pass? As per #66 it seems like that's a potential issue we may need to try and resolve.
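For reference, a hedged sketch of the kind of backward-pass test being asked for, written directly against lax.reduce_window (rather than the still-evolving layer API), since that is where the #66 differentiability concern lives:

```python
import jax
import jax.numpy as jnp
from jax import lax

def test_max_pool_backward():
    x = jnp.arange(16.0).reshape(4, 4)

    def loss(x):
        pooled = lax.reduce_window(
            x,
            -jnp.inf,
            lax.max,
            window_dimensions=(2, 2),
            window_strides=(2, 2),
            padding="VALID",
        )
        return pooled.sum()

    grads = jax.grad(loss)(x)
    # Each non-overlapping 2x2 window has exactly one maximum, so exactly
    # four entries receive a gradient of 1.
    assert grads.shape == (4, 4)
    assert float(grads.sum()) == 4.0

test_max_pool_backward()
```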

patrick-kidger (Owner)

Also: I think the new pooling documentation needs to be added to mkdocs.yml.

patrick-kidger (Owner) left a comment

See if you can get to the bottom of the backprop business.

Other than that, and the nits below, LGTM.

class Pool(Module):
    """General N-dimensional downsampling over a sliding window."""

    operation: Callable[[Array], Array] = static_field()
patrick-kidger (Owner)

This shouldn't be static. (In principle the operation may be learnt.) Also I think the annotation should be Callable[[Array, Array], Array].


    def __init__(
        self,
        operation: Callable[[Array], Array],
patrick-kidger (Owner)

Likewise for the annotation here.

        self.operation = operation
        self.init = 0.0
        if operation is lax.max:
            self.init = -jnp.inf
patrick-kidger (Owner)

Make init an input, don't try to guess based on special values. Essentially Pool should make sense as a standalone operation.

        if isinstance(padding, int):
            self.padding = tuple((padding, padding) for _ in range(num_spatial_dims))
        elif isinstance(padding, Sequence) and \
                all(isinstance(element, tuple) for element in padding):
patrick-kidger (Owner)

Change `isinstance(element, tuple)` to `isinstance(element, Sequence)`.
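To illustrate why (a hedged aside, not code from the PR): users may reasonably pass lists of padding pairs rather than tuples, and only the Sequence check accepts both:

```python
from collections.abc import Sequence

padding = [[0, 1], [0, 1]]  # a list of per-dimension padding pairs

print(all(isinstance(element, tuple) for element in padding))     # False
print(all(isinstance(element, Sequence) for element in padding))  # True
```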

Benjamin-Walker (Contributor, Author)

Made the changes. I think the formatting now satisfies the flake8 pre-commit hook, but I don't have approval to run the workflows.

Benjamin-Walker (Contributor, Author)

Also, with regard to the non-differentiability: my understanding is that there was a small mistake in #66 with the init for the maximum pooling. From the TensorFlow docs, "All initial values have to form an identity under computation" (https://www.tensorflow.org/xla/operation_semantics#reducewindow).

Hence, when using lax.add as your computation, you need to initialise the pooling at 0.0, as lax.add of 0 and your first element is just your first element. Similarly, when using lax.max, you need to initialise at -jnp.inf, as lax.max of -jnp.inf and your first element is always your first element.

Using a non-identity init is what gets flagged as the error "Differentiation rule for 'reduce_window' not implemented" in JAX; see google/jax#7718.
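A minimal sketch of the point above (assuming plain lax.reduce_window on a 1D input): with the identity init JAX differentiates max pooling fine, while a non-identity init falls back to the generic reduce_window primitive and raises that error.

```python
import jax
import jax.numpy as jnp
from jax import lax

def max_pool_1d(x, init):
    return lax.reduce_window(
        x, init, lax.max,
        window_dimensions=(2,), window_strides=(2,), padding="VALID",
    )

x = jnp.array([1.0, 3.0, 2.0, 5.0])

# -inf is an identity for lax.max, so this differentiates fine.
print(jax.grad(lambda x: max_pool_1d(x, -jnp.inf).sum())(x))  # [0. 1. 0. 1.]

# 0.0 is not an identity for lax.max; differentiating this raises
# "Differentiation rule for 'reduce_window' not implemented".
# jax.grad(lambda x: max_pool_1d(x, 0.0).sum())(x)
```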

        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[int, Sequence[int], Sequence[tuple[int, int]]] = 0,
        init: int = 0,
patrick-kidger (Owner)

I think the choice of init is intrinsically tied to the operation. So I wouldn't give this any default value at all, and instead demand that it be user-specified. (For consistency with lax.reduce_window I'd make the argument order init then operation, by the way.)
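For illustration, a hedged sketch of what a call would then look like, with init supplied explicitly and placed before operation (argument names are taken from the snippets in this review; the exact final signature may differ):

```python
import equinox as eqx
import jax.numpy as jnp
from jax import lax

# init comes first, then operation, mirroring lax.reduce_window.
pool = eqx.nn.Pool(
    init=-jnp.inf,
    operation=lax.max,
    num_spatial_dims=2,
    kernel_size=2,
    stride=2,
)
x = jnp.arange(1 * 4 * 4, dtype=jnp.float32).reshape(1, 4, 4)
print(pool(x).shape)  # (channels, new_dim_1, new_dim_2) == (1, 2, 2)
```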

        self.operation = operation
        self.init = init
        if operation is lax.max:
            self.init = -jnp.inf
patrick-kidger (Owner)

This line still needs to be removed.

        num_spatial_dims: int,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[int, Sequence[int], Sequence[tuple[int, int]]] = 0,
patrick-kidger (Owner)

This needs to be typing.Tuple, not tuple. The latter only works for Python 3.9+.
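That is, something like the following (a sketch of the fix, not the PR diff): typing.Tuple works across the supported Python versions, whereas the built-in tuple is only subscriptable from Python 3.9.

```python
from typing import Sequence, Tuple, Union

def example(
    padding: Union[int, Sequence[int], Sequence[Tuple[int, int]]] = 0,
) -> None:
    ...
```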

- `key`: Ignored; provided for compatibility with the rest of the Equinox API.
  (Keyword only argument.)

**Returns:**

A JAX array of shape `(in_channels, new_dim_1, ..., new_dim_N)`.
patrick-kidger (Owner)

Nit: rename both instances of in_channels to just channels, since there's no in/out distinction here.

tests/test_nn.py
@@ -6,6 +6,8 @@
import jax.random as jrandom
import pytest

import sys
sys.path.append('/Users/benwalker/PycharmProjects/equinox')
patrick-kidger (Owner)

Ahem.

Benjamin-Walker (Contributor, Author)

Hahahah oops, didn't mean to leave that one in

patrick-kidger (Owner) commented on May 6, 2022

Various nits above but I think we're nearly there.

Nice work on the nondifferentiability; I'm glad to hear that won't be an issue.

Benjamin-Walker (Contributor, Author)

Made those changes. Also, regarding the differentiability: do you think it is worth noting somewhere (maybe in the arguments docstring?) that if you use Pool by itself you need to make sure that operation(init, x) = x for all finite x?

patrick-kidger (Owner) left a comment

Yep, good thought; add that somewhere in the documentation for Pool. Preferably with a few references to the relevant JAX issues etc.


    def __init__(
        self,
        init: int,
patrick-kidger (Owner)

Nit: not necessarily an integer! e.g. jnp.NINF is a float.

Admittedly I'm not sure what the correct annotation is here. Union[int, float] maybe? Probably Arrays are valid too?

Benjamin-Walker (Contributor, Author)

I have gone for Union[int, float, Array]. I haven't seen anyone use an Array, but if the elements being compared were arrays then this would be necessary.

Benjamin-Walker (Contributor, Author)

Fixed the remaining typesetting mistake

patrick-kidger (Owner)

@Benjamin-Walker Checks have failed again!
It's actually a bit mysterious to me that the tests on 3.9 have failed as well, as list[eqx.nn.Conv2d | eqx.nn.MaxPool2D] should be valid in 3.9, I think. Any idea what's going on there?

By the way, to tidy up this branch before merging, can you squash all the commits together? I think the appropriate commands are something like (untested):

git reset --soft v050
git commit -m "Added max/average pooling."
git push -f

Noting that the last command will force-push and thus overwrite this branch. (If you're feeling uncertain about that then you can use a different branch name, and open a new PR, instead.)

Benjamin-Walker (Contributor, Author)

I think it was actually 3.10 where they introduced using | instead of Union. I think I have squashed all the commits together; let me know if it hasn't worked.
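For the record, a small sketch of the difference (class names as written in the comments above, kept as forward-reference strings so the snippet stands alone):

```python
from typing import List, Union

# Python 3.10+ only (PEP 604 unions in annotations):
#     layers: list[eqx.nn.Conv2d | eqx.nn.MaxPool2D]
#
# Works on Python 3.9, where list[...] is fine but `|` is not:
Layer = Union["eqx.nn.Conv2d", "eqx.nn.MaxPool2D"]
layers: List[Layer] = []
```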

@patrick-kidger patrick-kidger merged commit 8666507 into patrick-kidger:v050 May 6, 2022
patrick-kidger added a commit that referenced this pull request May 6, 2022
* Improved filtered transformations (#71)

* Updated filter_{jit,grad,value_and_grad}

* Documented that BatchNorm supports multiple axis_names

* Doc fix

* added tree_inference

* Added filter_vmap and filter_pmap

* Made filter_{vmap,pmp} non-experimental. Fixed doc issues.

* Added 'inference' call-time argument to 'MultiheadAttention', which was missed when switching over 'Dropout' to use 'inference'.

* Added test for tree_inference

* Doc improvements

* Doc tweaks

* Updated tests

* Two minor bugfixes for filter_{jit,grad,vmap,pmap,value_and_grad}.

1. Applying filter_{jit,vmap,pmap} to a function with *args and **kwargs should now work. Previously inspect.Signature.apply_defaults was not filling these in.
2. The output of all filtered transformations has become a Module. This means that e.g. calling filter_jit(filter_grad(f)) multiple times will not induce any recompilation, as all filter_grad(f)s are PyTrees of the same structure as each other.

* Crash and test fixes

* tidy up

* Fixes and test fixes

* Finished adding tests.

* added is_leaf argument to filter and partition (#72)

* Improved tree_at a lot: (#73)

- Can now substitute arbitrary nodes, not just leaves
- These substituted nodes can now have different structures to each other.
- Will raise an error if `where` depends on the values of the leaves of the PyTree.

In addition, `tree_equal` should now work under JIT.

* Lots of doc fixes and tweaks (#74)

* Added tree_{de,}serialise_leaves (#80)

* Added max/average pooling. (#77)

* Tidy up pooling implementation (#81)

* Version bump

* Fixed tree_at breaking when replacing Nones. (#83)

Co-authored-by: Ben Walker <38265558+Benjamin-Walker@users.noreply.github.com>