add nn.CenterCrop1d and nn.CenterCrop2d #1331

Closed
bodokaiser opened this issue Apr 22, 2017 · 10 comments
Labels
feature A request for a proper, new feature.

Comments

@bodokaiser

There are some models which use "crop layers" (e.g. U-Net, ParseNet) so I think it wouldn't be bad to have a CenterCropNd layer.

Since F.pad supports negative padding, we only need to compute the padding offsets on top of it.

import torch

from torch.nn import functional as F

def center_crop(x, height, width):
    # Total number of rows/columns to remove from an (N, C, H, W) tensor.
    diff_h = x.size(2) - height
    diff_w = x.size(3) - width

    # Negative padding crops. F.pad expects (left, right, top, bottom);
    # for an odd difference the extra row/column is cut from the bottom/right.
    return F.pad(x, [
        -(diff_w // 2), -(diff_w - diff_w // 2),
        -(diff_h // 2), -(diff_h - diff_h // 2),
    ])

x = torch.randn(1, 3, 60, 40)

print(center_crop(x, 20, 20).size())
print(center_crop(x, 20, 40).size())

I can send a PR if there is interest.

@fmassa
Member

fmassa commented Apr 22, 2017

Hmm, I think the type of cropping one wants to apply may be problem-specific (not necessarily only a center crop).
Plus, cropping an image can be performed as efficiently as F.pad by indexing the tensor with slices like x[:, :, 3:7, 5:10], so I'm not sure a special layer is needed for that.
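
A minimal sketch of the slicing approach described in this comment, assuming a 4-D (N, C, H, W) tensor whose spatial sizes are at least the target sizes. The helper name `center_crop_slice` is hypothetical and not part of PyTorch. For an odd size difference it leaves the extra row or column on the bottom/right, which is only one possible convention.

```python
import torch

def center_crop_slice(x, height, width):
    # Hypothetical helper: crop the last two dims of an (N, C, H, W) tensor.
    top = (x.size(2) - height) // 2
    left = (x.size(3) - width) // 2
    return x[:, :, top:top + height, left:left + width]

# A small tensor with distinct values so the cropped window is easy to inspect.
x = torch.arange(1 * 1 * 6 * 5, dtype=torch.float32).reshape(1, 1, 6, 5)
y = center_crop_slice(x, 2, 3)
print(y.shape)
print(y)
```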

@bodokaiser
Author

What about Caffe's crop layer?

@fmassa
Member

fmassa commented Apr 22, 2017

Caffe doesn't give you the ability to easily manipulate the tensors, so there is a need for such a layer. The same applies to Lua Torch, where you need to define a module that computes the operation so that it can be differentiated. In PyTorch you can manipulate the tensors directly, and backprop will be performed through the indexing.
If you check Caffe's crop layer, it's almost the same as the snippet I sent in the previous message.
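
To illustrate the point about backprop, here is a small sketch (an illustration, not code from the thread) showing that a crop taken by plain indexing is differentiable without any dedicated layer. It reuses the slice `x[:, :, 3:7, 5:10]` from the earlier comment; the tensor shape is an arbitrary assumption.

```python
import torch

x = torch.zeros(1, 1, 8, 12, requires_grad=True)

# Cropping by indexing is an ordinary differentiable operation.
crop = x[:, :, 3:7, 5:10]
crop.sum().backward()

# Elements inside the crop receive a gradient of 1, all others stay at 0.
print(x.grad[0, 0])
```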

@apaszke
Contributor

apaszke commented Apr 24, 2017

Even though it's easy to do it with indexing, I think it'd be better to provide some standard crops in the library. It's much more convenient than messing with indices, and can save making a few errors along the way each time you need to crop.

@andrewgiessel
Contributor

+1 to this. I think standard cropping layers would be great!

@bodokaiser
Author

bodokaiser commented Apr 24, 2017

I could make a PR this week if there is consensus on the API.

@apaszke
Contributor

apaszke commented Apr 24, 2017

I think F.center_crop(x, *args) and nn.CenterCrop(*args) (where len(args) == x.dim() - 2) would be the best API for it.
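
One way the proposed signature could look, as a rough sketch only. Neither `center_crop` nor `CenterCrop` below is an existing PyTorch API; both are hypothetical stand-ins for the suggested `F.center_crop` and `nn.CenterCrop`. The sketch assumes a (N, C, ...) layout and target sizes no larger than the input sizes.

```python
import torch
from torch import nn

def center_crop(x, *sizes):
    # Hypothetical function following the proposed signature; not a PyTorch API.
    if len(sizes) != x.dim() - 2:
        raise ValueError("expected one target size per spatial dimension")
    index = [slice(None), slice(None)]  # keep batch and channel dims intact
    for dim, size in enumerate(sizes, start=2):
        start = (x.size(dim) - size) // 2
        index.append(slice(start, start + size))
    return x[tuple(index)]

class CenterCrop(nn.Module):
    # Hypothetical module wrapper around the function above.
    def __init__(self, *sizes):
        super().__init__()
        self.sizes = sizes

    def forward(self, x):
        return center_crop(x, *self.sizes)

print(center_crop(torch.randn(2, 3, 50), 20).shape)         # 1-D signal
print(CenterCrop(20, 30)(torch.randn(2, 3, 60, 40)).shape)  # 2-D image
```

Building the index as a tuple of slices lets the same code serve the 1-D and 2-D cases that the issue title asks for.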

@andrewgiessel
Contributor

Is it reasonable to include non-symmetric cropping in this API as well?

@bodokaiser
Author

@andrewgiessel See the current proposal at #1349. I think it would also be reasonable to specify the center crop in more detail. For example, if you crop 10 -> 3, 7 elements are removed, so one offset is 3 and the other is 4, but at the moment you can't specify which side gets which.
Is this what you had in mind?
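
The offset arithmetic in this example can be sketched in plain Python. The helper `crop_offsets` and its `extra_at_end` flag are hypothetical names used only to show how the order of the two offsets might be made configurable.

```python
def crop_offsets(in_size, out_size, extra_at_end=True):
    # Hypothetical helper: number of elements to drop at the start and at the end.
    total = in_size - out_size
    small, large = total // 2, total - total // 2
    return (small, large) if extra_at_end else (large, small)

print(crop_offsets(10, 3))                      # (3, 4)
print(crop_offsets(10, 3, extra_at_end=False))  # (4, 3)
```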

@soumith soumith added this to Medium Priority in Issue Status Aug 23, 2017
@soumith soumith added this to nn / autograd / torch in Issue Categories Sep 13, 2017
@ezyang ezyang added feature A request for a proper, new feature. and removed enhancement labels Apr 1, 2019
@fmassa
Member

fmassa commented Oct 23, 2019

This can be achieved by passing a negative pad in torch.nn.functional.pad.
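
A short sketch of what that looks like in practice, assuming the default constant padding mode and the same (1, 3, 60, 40) input used at the top of the thread. It also checks that the result matches the equivalent slice.

```python
import torch
from torch.nn import functional as F

x = torch.randn(1, 3, 60, 40)

# F.pad takes pads for the last dimension first: (left, right, top, bottom).
# Negative values remove elements instead of adding them.
cropped = F.pad(x, (-10, -10, -20, -20))

print(cropped.shape)
print(torch.equal(cropped, x[:, :, 20:40, 10:30]))
```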

@fmassa fmassa closed this as completed Oct 23, 2019
zasdfgbnm pushed a commit that referenced this issue Apr 26, 2022
Fixes #1331

Added an entry for device constant
jithunnair-amd pushed a commit that referenced this issue Mar 18, 2024
)

Co-authored-by: Pruthvi Madugundu <pruthvigithub@gmail.com>