
Support for adaptive average pooling? #121

Closed
paganpasta opened this issue Jun 23, 2022 · 8 comments
Labels: feature (New feature)

Comments

@paganpasta
Contributor

Hi again,

I was trying out equinox for some computer vision experiments and found that, at the moment, there is no support for adaptive average pooling. Similar functionality exists in PyTorch, so I just wanted to check whether it is something you intend to add later.
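For reference, a minimal illustration of the PyTorch layer in question (arbitrary example sizes):

import torch

# PyTorch's adaptive average pooling: the output length is fixed (here 4),
# and the window sizes are derived from the input length automatically.
pool = torch.nn.AdaptiveAvgPool1d(4)
x = torch.rand(2, 3, 10)   # (batch, channels, length)
y = pool(x)
print(y.shape)             # torch.Size([2, 3, 4])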

Thanks.

@patrick-kidger
Owner

This is something I'd be happy to accept a PR on.

@paganpasta
Contributor Author

@patrick-kidger I'll definitely take a look at this. I'll first try to work on AdaptiveAvgPool1d and then make my way to 2d and 3d.

Do you have any resources at hand regarding this? If not, don't worry, I found a Stack Overflow discussion (link) and will use it as a starting point.

@patrick-kidger
Owner

I don't know of any resources I'm afraid. Good luck!

@patrick-kidger added the feature (New feature) label on Jun 28, 2022
@paganpasta
Contributor Author

Hi @patrick-kidger

I tried out a simple implementation of AdaptiveAvgPool1d to make sure the overall approach is correct.

import jax
import jax.numpy as jnp
import equinox as eqx

# for comparing results against PyTorch
import torch
import random


class AdaptiveAvgPool1d(eqx.Module):
    target_size: int

    def __init__(self, *,target_size):
        self.target_size = target_size
    
    def __call__(self, x):
        #Assertion based on https://pytorch.org/docs/stable/generated/torch.nn.AdaptiveAvgPool1d.html
        assert x.ndim == 2 or x.ndim==1, 'Only supports 2D:[row, cols] or 1D:[cols] input'  

        input_shape = x.shape
        # PyTorch-style adaptive windows: window i covers
        # [floor(i * in/out), ceil((i + 1) * in/out)).
        start_points = (jnp.arange(self.target_size, dtype='float32')*(input_shape[-1]/self.target_size)).astype(dtype='int32')
        end_points = jnp.ceil((jnp.arange(self.target_size, dtype='float32')+1)*(input_shape[-1]/self.target_size)).astype(dtype='int32')
        pooled  = []
        for idx in range(self.target_size):
            if x.ndim == 1:
                pooled.append(jnp.mean(x[start_points[idx]:end_points[idx]], axis=-1, keepdims=False))
            else:
                pooled.append(jnp.mean(x[:,start_points[idx]:end_points[idx]], axis=-1, keepdims=False))
        
        x = jnp.asarray(pooled)
        if x.ndim == 2:
            x = jnp.swapaxes(x, -1, -2)
        return x


#### Test NxCxK ####
for i in range(100):

    n,c,k = random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)
    pool_size = random.randint(1, 10)
    arr = torch.rand(n,c,k)
    
    torch_ans = torch.nn.AdaptiveAvgPool1d(pool_size)(arr)

    jnp_arr = jnp.asarray(arr.numpy().tolist())
    eq1d = AdaptiveAvgPool1d(target_size=pool_size)
    j1d = jax.vmap(eq1d)(jnp_arr)
    assert jnp.isclose(jnp.asarray(torch_ans.numpy().tolist()), j1d).all() == True, f'output mismatch for {n}x{c}x{k} with kernel {pool_size}' 


#### Test CxK ####
for i in range(100):

    c,k = random.randint(1, 10), random.randint(1, 10)
    pool_size = random.randint(1, 10)
    arr = torch.rand(c,k)
    
    torch_ans = torch.nn.AdaptiveAvgPool1d(pool_size)(arr)

    jnp_arr = jnp.asarray(arr.numpy().tolist())
    eq1d = AdaptiveAvgPool1d(target_size=pool_size)
    j1d = jax.vmap(eq1d)(jnp_arr)
    assert jnp.isclose(jnp.asarray(torch_ans.numpy().tolist()), j1d).all() == True, f'output mismatch for {c}x{k} with kernel {pool_size}'

At the moment I don't see how eqx.nn.Pool can be used as a base class for AdaptiveAvgPool1d.
Also, could the iterative computation of the averages be optimised further?

@patrick-kidger
Owner

At first glance this looks mostly reasonable!

Indeed, the iteration looks quite expensive; I suspect that will be quite slow at runtime. Perhaps it's possible to group x into pieces of different lengths, and then vmap the pooling operation over all pieces of the same length?

If necessary, I don't think it's important that we exactly match PyTorch here; for example, we could put all the shorter pieces on the left of x and all the longer pieces on the right, so that separating them out is itself an efficient operation. (I'm not sure what PyTorch does here.)

A few nits: (1) it'd be best to only support a single dimensionality for x; courtesy of vmap etc. we don't need our operations to handle optional batch dimensions like PyTorch does. (2) The start_points and end_points code could probably be done with a jnp.linspace (thus avoiding the messy dtype stuff); see the sketch below. (3) Don't use input_shape[-1]; it's much better to unpack, e.g. _, channels = input_shape, which is more readable.
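For instance, a rough, untested sketch of what nit (2) could look like (the helper name is just for illustration, and it assumes the input length is at least target_size):

import jax.numpy as jnp

def adaptive_window_bounds(in_size, target_size):
    # target_size + 1 evenly spaced cut points over [0, in_size]; window i
    # covers [floor(edges[i]), ceil(edges[i + 1])), matching the arange/ceil
    # arithmetic above without the manual dtype juggling.
    edges = jnp.linspace(0, in_size, target_size + 1)
    start_points = jnp.floor(edges[:-1]).astype(jnp.int32)
    end_points = jnp.ceil(edges[1:]).astype(jnp.int32)
    return start_points, end_points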

@paganpasta
Contributor Author

paganpasta commented Jul 8, 2022

import jax
import jax.numpy as jnp
import equinox as eqx


class AdaptiveAvgPool1d(eqx.Module):
    target_size: int

    def __init__(self, *, target_size):
        self.target_size = target_size

    def __call__(self, x):
        assert x.ndim == 1, 'Only supports 1D input'
        channels = jnp.size(x)
        assert channels >= self.target_size, f'Final pooled size {self.target_size} cannot be greater than input size {channels}'

        # array_split puts the longer pieces first: the first channels % target_size
        # pieces have one extra element, the remaining pieces are one element shorter.
        splits = jnp.array_split(x, self.target_size)
        num_head_arrays = channels % self.target_size
        if num_head_arrays:
            # Pool the two groups of equal-length pieces separately, each with a
            # single vmapped mean, then stitch the results back together.
            head_mean = jax.vmap(jnp.mean)(jnp.asarray(splits[:num_head_arrays]))
            tail_mean = jax.vmap(jnp.mean)(jnp.asarray(splits[num_head_arrays:]))
            mean = jnp.concatenate([head_mean, tail_mean])
        else:
            mean = jax.vmap(jnp.mean)(jnp.asarray(splits))
        return mean

Agreed, there's no requirement to closely follow the PyTorch implementation. I have also added an extra check to make sure target_size does not exceed the input size.
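Since this version only accepts 1D input, a (channels, length) array would presumably be handled via vmap; a quick usage sketch (example sizes are arbitrary):

import jax
import jax.numpy as jnp

pool = AdaptiveAvgPool1d(target_size=4)
x = jnp.arange(30, dtype=jnp.float32).reshape(3, 10)  # (channels, length)
y = jax.vmap(pool)(x)  # pool each channel independently
print(y.shape)  # (3, 4)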

@patrick-kidger
Owner

This looks pretty good. Do you want to open a PR for this?

(There's a few nits I'll comment on, but I'll do that in the PR rather than here.)

@paganpasta
Contributor Author

Merged in v0.5.6
