
Extend functionality of torch.nn.functional.fold to apply to sequences (1D tensors + batch & channels) not just images (2D tensors + batch & channels) #119474

@tom-hehir

Description


🚀 The feature, motivation and pitch

torch.nn.functional.unfold can be applied to sequences (1D tensors + batch & channels), and it would be helpful to be able to reverse this operation with torch.nn.functional.fold (up to the summation of overlapping regions, which can be undone by a division; see the documentation for torch.nn.functional.fold). However, torch.nn.functional.fold only works on images (2D tensors + batch & channels), so this is not directly supported at present.
I suggest that torch.nn.functional.fold be generalised so that it can reverse torch.nn.functional.unfold in general (or at least for sequences), not just for images.
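Until such a generalisation exists, one possible workaround (a sketch, not the requested feature) is to treat each sequence as an image of height 1 and use the existing 2D `F.unfold`/`F.fold` with a `(1, kernel_size)` window. The variable names and sizes below are illustrative assumptions; dividing by the fold of an all-ones tensor undoes the summation in overlapping regions, as described in the `torch.nn.functional.fold` documentation.

```python
import torch
import torch.nn.functional as F

N, C, L = 2, 3, 17          # batch, channels, sequence length (illustrative values)
kernel_size, stride = 5, 2

x = torch.randn(N, C, L)

# Treat each sequence as a 1 x L image by inserting a height dimension
x_img = x.unsqueeze(2)                                   # (N, C, 1, L)

# Unfold with a 1 x kernel_size window -> (N, C * kernel_size, num_windows)
patches = F.unfold(x_img, kernel_size=(1, kernel_size), stride=(1, stride))

# Fold back to (N, C, 1, L); values in overlapping positions are summed
summed = F.fold(patches, output_size=(1, L),
                kernel_size=(1, kernel_size), stride=(1, stride))

# Number of windows covering each position, obtained by the same round trip on ones
ones = torch.ones_like(x_img)
counts = F.fold(F.unfold(ones, kernel_size=(1, kernel_size), stride=(1, stride)),
                output_size=(1, L), kernel_size=(1, kernel_size), stride=(1, stride))

# Undo the summation and drop the dummy height dimension -> (N, C, L)
x_rec = (summed / counts).squeeze(2)
print(torch.allclose(x_rec, x))
```

With these sizes every position is covered by at least one window, so `counts` has no zeros; for sizes where the last few positions are not covered, those entries of `counts` would be 0 and the division would produce NaN there.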

Alternatives

I have found a way to do this using Tensor.index_add (sketched out below):

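```python
import torch

seq_len, size, step = 17, 5, 2

a = torch.arange(1, seq_len + 1, dtype=torch.float32)
print(a)

# Unfold into overlapping windows: shape (num_windows, size)
b = a.unfold(0, size, step)

# Position in `a` of every element of `b`, obtained by unfolding the index range
# the same way (rather than relying on the values of `a` being 1-based indices)
indices = torch.arange(seq_len).unfold(0, size, step).flatten()

# "Fold": scatter-add each window element back to its original position
x = torch.zeros_like(a).index_add(0, indices, b.flatten())

# Count how many windows cover each position, so the overlaps can be divided out
y = torch.zeros_like(a).index_add(0, indices, torch.ones_like(b).flatten())

z = x / y
print(z)
```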

It would be nice to abstract these details away and just call torch.nn.functional.fold; I also expect that this kind of approach is far from optimal.
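As a stopgap, the `index_add` approach above can be wrapped in a reusable function that handles batch and channel dimensions. The helper name `fold1d` and its signature are hypothetical (not a PyTorch API); it assumes the input has the layout produced by `Tensor.unfold(-1, size, step)`, i.e. `(..., num_windows, size)`.

```python
import torch

def fold1d(windows: torch.Tensor, length: int, step: int) -> torch.Tensor:
    """Hypothetical helper: inverse of Tensor.unfold(-1, size, step), up to overlaps.

    windows: (..., num_windows, size), e.g. the output of x.unfold(-1, size, step).
    Returns a tensor of shape (..., length) in which overlapping
    contributions are summed (not averaged).
    """
    num_windows, size = windows.shape[-2], windows.shape[-1]
    # Index in the output of every window element: window start + offset in window
    starts = torch.arange(num_windows, device=windows.device) * step
    offsets = torch.arange(size, device=windows.device)
    indices = (starts[:, None] + offsets[None, :]).flatten()
    out = windows.new_zeros((*windows.shape[:-2], length))
    # Scatter-add all window elements back along the last output dimension
    return out.index_add(out.dim() - 1, indices, windows.flatten(-2))


# Usage: round trip for a batch of multi-channel sequences
x = torch.randn(2, 3, 17)                   # (N, C, L)
w = x.unfold(-1, 5, 2)                      # (N, C, num_windows=7, size=5)
summed = fold1d(w, 17, 2)                   # overlaps summed
counts = fold1d(torch.ones_like(w), 17, 2)  # windows covering each position
x_rec = summed / counts
print(torch.allclose(x_rec, x))             # True
```

This still materialises an explicit index tensor, so it is unlikely to match a native kernel in speed or memory; it only hides the bookkeeping behind a single call.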

Additional context

No response

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki

Metadata

Assignees: No one assigned
Labels: module: nn, triaged
Type: No type
Projects: Status: To pick up
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
