
Preprocessing for pretrained models? #39

Closed · jcjohnson opened this issue Jan 22, 2017 · 15 comments

jcjohnson commented Jan 22, 2017

What kind of image preprocessing is expected for the pretrained models? I couldn't find this documented anywhere.

If I had to guess, I would assume they expect RGB images with the mean/std normalization used in fb.resnet.torch and pytorch/examples/imagenet. Is this correct?

soumith commented Jan 24, 2017

Yes, the mean/std normalization used in pytorch/examples/imagenet is what is expected. I'll document it now.

Atcold commented Feb 8, 2017

@soumith, are you referring to this documentation -> http://pytorch.org/docs/torchvision/models.html
I cannot find any reference to preprocessing the images there.
I think the network objects should have a preprocessing attribute where those values are stored. Moreover, they should also have a classes attribute that lets you go from the output's max index to the class name. As they stand, they are hardly usable.
Finally, most of the time these nets are retrained, so it would also be nice to have a method that lets you replace the final classifier (see the sketch below).

Here is a link to the required preprocessing -> https://github.com/pytorch/examples/blob/master/imagenet/main.py#L92-L93
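
For what it's worth, the final classifier can already be replaced by hand. A minimal sketch, assuming a ResNet-style model whose head is the fc attribute (num_classes is a hypothetical fine-tuning target):

import torch.nn as nn
from torchvision import models

model = models.resnet18(pretrained=True)
num_classes = 10  # hypothetical target task
# Swap the 1000-way ImageNet head for a fresh classifier
model.fc = nn.Linear(model.fc.in_features, num_classes)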

soumith commented Mar 18, 2017

Documented in the README of vision now.

https://github.com/pytorch/vision/blob/master/README.rst#models

@soumith soumith closed this as completed Mar 18, 2017
jianchao-li commented Jul 10, 2018

Replying here for easy reference:

from torchvision import transforms

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

  • For training images:

preprocessing = transforms.Compose([
    transforms.RandomResizedCrop(224),  # RandomSizedCrop in older torchvision
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

  • For validation images:

preprocessing = transforms.Compose([
    transforms.Resize(256),  # Scale in older torchvision
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])
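
A quick usage sketch for the validation pipeline above, assuming a pretrained ResNet-18 and a local image file (path hypothetical):

import torch
from PIL import Image
from torchvision import models

model = models.resnet18(pretrained=True)
model.eval()  # inference mode for batch-norm/dropout

img = Image.open("example.jpg").convert("RGB")  # hypothetical input image
batch = preprocessing(img).unsqueeze(0)  # add the batch dimension
with torch.no_grad():
    class_index = model(batch).argmax(dim=1)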

youkaichao commented

Would it be better to keep the mean and std inside the torchvision models? It is annoying to keep such magic numbers in the code.

fmassa commented Jul 30, 2018

@youkaichao this is a good point, and the pre-trained models should have something like that.
But that's not all of it: there are other underlying assumptions that should also be made explicit (e.g. the image is RGB in the 0-1 range, even though that's the current default in PyTorch).
But I'm open to suggestions. I'm not sure where we should include such information: should it be in the state_dict of the serialized models (read by some special mechanism)? Should it be hard-coded in the model implementation?

youkaichao commented

@fmassa how about registering the mean and std as buffers?
As for the input range, you could print a line saying "accepted images are in range [0, 1]" at initialization.
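
A minimal sketch of the buffer idea (NormalizedModel is a hypothetical wrapper, not actual torchvision code):

import torch
import torch.nn as nn

class NormalizedModel(nn.Module):  # hypothetical wrapper
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        # Buffers are saved in the state_dict and follow .to(device),
        # but are not updated by the optimizer.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):  # expects RGB input in the [0, 1] range
        return self.backbone((x - self.mean) / self.std)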

fmassa commented Jul 30, 2018

Registering them as a buffer is an option, but that also means that we would either need to change the way we do image normalization (which is currently handled in a transform) and do it in the model, or find a way of loading the state dict into a transform.

Both solutions are backwards-incompatible, so I'm not very happy with them...

youkaichao commented

@fmassa you could add a parameter to __init__ like pre_process=False, with that default kept for backwards compatibility; if pre_process==True, use the registered buffers. This way, users get the pre-defined preprocessing just by setting a boolean flag, which seems much better than hunting for the exact mean and std values everywhere.
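
Sketching that on top of the wrapper above (pre_process is a hypothetical argument, not an existing torchvision parameter):

import torch
import torch.nn as nn
from torchvision import models

class ResNetWithPreprocess(nn.Module):  # hypothetical
    def __init__(self, pre_process=False):
        super().__init__()
        self.pre_process = pre_process  # False preserves the old behaviour
        self.backbone = models.resnet18(pretrained=True)
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        if self.pre_process:  # opt-in normalization inside the model
            x = (x - self.mean) / self.std
        return self.backbone(x)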

fmassa commented Jul 30, 2018

Well, the good thing about torchvision models is that (almost) all of them share the same pre-processing values.

Also, it's a bit more involved than that: before, one could just load the model using load_state_dict, but if we add extra buffers, old users might need to load with strict=False, or their loading will crash.
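
Concretely, if the models gained "mean"/"std" buffers, loading an old (pre-buffer) checkpoint would look like this sketch (checkpoint path hypothetical):

import torch
from torchvision import models

model = models.resnet18()
old_state = torch.load("old_checkpoint.pth")  # hypothetical checkpoint saved before the buffer change
# strict=False tolerates the checkpoint's missing "mean"/"std" buffer keys;
# strict loading would raise a KeyError-style mismatch instead.
model.load_state_dict(old_state, strict=False)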

gursimar commented

Hi, I want to extract features from a pre-trained ResNet's pool5 and res5c layers.
I'm using frames (RGB values) extracted from the TGIF-QA dataset (GIFs).

  1. Should I transform my images using the values specified above?
  2. I'm using the following preprocessing. Is this okay for my purpose?
from torchvision import transforms

loader = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(IMAGE_SIZE),  # IMAGE_SIZE defined elsewhere; 224 matches the models above
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

fmassa commented Oct 30, 2018

@gursimar yes, it should be fine
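
A minimal sketch of the extraction itself with forward hooks, assuming torchvision's resnet50 (layer4's output corresponds to res5c, avgpool's to pool5; the frame path is hypothetical):

import torch
from PIL import Image
from torchvision import models

model = models.resnet50(pretrained=True)
model.eval()

features = {}

def save_to(name):
    # Returns a forward hook that stores the layer's output.
    def hook(module, inputs, output):
        features[name] = output.detach()
    return hook

# In torchvision's ResNet, layer4 ends with the res5c block
# and avgpool is the global pool5.
model.layer4.register_forward_hook(save_to("res5c"))
model.avgpool.register_forward_hook(save_to("pool5"))

frame = Image.open("frame_0001.jpg").convert("RGB")  # hypothetical extracted frame
with torch.no_grad():
    model(loader(frame).unsqueeze(0))  # loader as defined above, with IMAGE_SIZE = 224
# features["res5c"]: 1x2048x7x7, features["pool5"]: 1x2048x1x1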

yashrathi-git commented Mar 2, 2022

Hey @Atcold, the link no longer works. I still cannot find the documentation for the pre-processing transforms used for the various pre-trained models in torchvision. I think the transforms should be included with the model.
Would I get better performance if I used the same transforms while fine-tuning, or does it not matter?

Do all pretrained models in torchvision use the same pre-processing transforms as described by jianchao-li?

Atcold commented Mar 7, 2022

The new link -> https://pytorch.org/vision/stable/models.html

datumbox commented Mar 7, 2022

"I think the transforms should be included with the model."

They are, in the new Multi-weights API. It is currently in prototype, and you can read more here:

transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),

We plan to roll it out to main TorchVision within the next couple of weeks. We have a dedicated issue for feedback.
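
For reference, a usage sketch of the API as it later stabilized (torchvision >= 0.13; names in the prototype may differ):

from torchvision.models import resnet50, ResNet50_Weights

weights = ResNet50_Weights.IMAGENET1K_V2
model = resnet50(weights=weights)
model.eval()
preprocess = weights.transforms()  # bundled resize/crop/normalize preset
class_names = weights.meta["categories"]  # index -> class-name mapping, as requested above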
