Add upsample 3d, subsample 2d and 3d modules #1348

Closed
wants to merge 1 commit

Conversation

@lantiga
Contributor

lantiga commented Apr 25, 2017

As anticipated on Slack #general on Apr 20, this PR includes

  • THNN and THCUNN implementations of:
    • VolumetricUpsamplingNearest
    • VolumetricUpsamplingTrilinear
    • VolumetricSubsampling
  • nn (and functional) modules:
    • UpsamplingNearest2d, UpsamplingNearest3d (functional: upsample_nearest)
    • UpsamplingTrilinear3d (functional: upsample_trilinear)
    • Subsampling2d, Subsampling3d (functional: subsample)
  • docs and tests for all modules (including gradcheck)

Commits may need to be squashed, in which case I'd need some direction.
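
For orientation, a rough usage sketch of the new functional entry points (not code from this PR; the call patterns are inferred from the module list above and the docstrings further down, so treat the exact signatures as illustrative):

```
import torch
from torch import autograd
import torch.nn.functional as F

# 5D input for the new 3d modules: (N, C, D, H, W)
x3d = autograd.Variable(torch.randn(1, 8, 4, 4, 4))

up_nearest = F.upsample_nearest(x3d, size=8)              # VolumetricUpsamplingNearest
up_trilinear = F.upsample_trilinear(x3d, size=(8, 8, 8))  # VolumetricUpsamplingTrilinear

# Subsampling: learnable per-channel scale (weight) and shift (bias)
x2d = autograd.Variable(torch.randn(1, 10, 4, 4))
weight = autograd.Variable(torch.ones(10))
bias = autograd.Variable(torch.zeros(10))
sub = F.subsample(x2d, weight, bias, 2, stride=2)
```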

@lantiga
Contributor Author

lantiga commented Apr 25, 2017

Ok, autopep8 didn't solve the linting issues. I will revert the commit and fix issues by hand.

@andrewgiessel
Contributor

Just wanted to comment that I have an open PR that changes the inputs to bilinear 2d upsampling (#1317). I think the changes here are mostly independent, but you've touched some overlapping code. I think the main difference is that you change the size parameter to be a tuple in a slightly different place.

cc @apaszke for guidance

@andrewgiessel
Contributor

Also, this is cool work! I'm 👍 for sure! This is an informative PR for me to read, with .cu code related to what I've been doing, so thanks!

@lantiga
Contributor Author

lantiga commented Apr 25, 2017

Thank you Andrew! Sorry for not spotting your PR earlier.
I'll take a look at your code (feel free to point me to the relevant lines); it could be good to adopt the same strategy for the new modules.

@andrewgiessel
Contributor

@lantiga - Heya Luca. I looked at our branches.

I think the narrowest technical difference between our branches is in the constructors of _UpsamplingBase(Function) and _UpsamplingBase(Module): there, I set self.size to a tuple using _pair() if it isn't None. You seem to have done something similar, but at different points.

More generally, my PR assumes 2d upsampling: size and scale_factor can be single ints or pairs (a tuple). I think it makes sense to generalize.

@andrewgiessel
Contributor

@lantiga @apaszke I just pushed a change to #1317 in which I think I came up with a good way to reconcile these two PRs. I moved all casting of size to _pair(size) out of the parent class and into the UpsamplingBilinear2d and UpsamplingNearest2d classes themselves. There is no casting in the functions in functional.py. It makes sense to me for each subclass to do this casting, since the length of the tuple depends on the dimensionality of the data.
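
For illustration, a minimal sketch of that convention (not the actual code from either PR; it assumes the `_pair`/`_triple` helpers from `torch.nn.modules.utils`): the shared base class stores whatever it receives, and each concrete subclass casts `size` to a tuple of the length its input dimensionality requires.

```
# Minimal sketch, not code from either PR: the subclass, not the base class,
# normalizes `size` to a tuple matching its input dimensionality.
from torch.nn import Module
from torch.nn.modules.utils import _pair, _triple


class _UpsamplingBase(Module):
    def __init__(self, size=None, scale_factor=None):
        super(_UpsamplingBase, self).__init__()
        self.size = size                  # already cast by the subclass
        self.scale_factor = scale_factor


class UpsamplingNearest2d(_UpsamplingBase):
    def __init__(self, size=None, scale_factor=None):
        # 4D input (N, C, H, W): size becomes a 2-tuple
        super(UpsamplingNearest2d, self).__init__(_pair(size), scale_factor)


class UpsamplingNearest3d(_UpsamplingBase):
    def __init__(self, size=None, scale_factor=None):
        # 5D input (N, C, D, H, W): size becomes a 3-tuple
        super(UpsamplingNearest3d, self).__init__(_triple(size), scale_factor)
```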

@lantiga
Contributor Author

lantiga commented Apr 28, 2017

Thank you @andrewgiessel, looks good to me! In general, the flexible aspect ratio should be extended to the other dimensions. We could do that if/when this PR eventually gets merged. Unfortunately it's a big review already, so it might take a while.

@andrewgiessel
Contributor

andrewgiessel commented Apr 28, 2017

@lantiga no problem! I will comment in-line on the places in your PR that need adjusting to be congruent with my PR.

Edit: I guess I was a little prescriptive here. The convention is contingent upon maintainer approval, etc. of course.

```
return _functions.thnn.UpsamplingNearest2d(size, scale_factor)(input)
if input.dim() == 4:
    assert type(size) == int or len(size) == 2, '4D tensors expect size as int or Tuple[int, int]'
    return _functions.thnn.UpsamplingNearest2d(_pair(size), scale_factor)(input)
```

```
return _functions.thnn.UpsamplingNearest2d(_pair(size), scale_factor)(input)
elif input.dim() == 5:
    assert type(size) == int or len(size) == 3, '5D tensors expect size as int or Tuple[int, int, int]'
    return _functions.thnn.UpsamplingNearest3d(_triple(size), scale_factor)(input)
```

```
return _functions.thnn.UpsamplingBilinear2d(size, scale_factor)(input)
assert input.dim() == 4, "4D tensors expected in input"
assert type(size) == int or len(size) == 2, '4D tensors expect size as int or Tuple[int, int]'
return _functions.thnn.UpsamplingBilinear2d(_pair(size), scale_factor)(input)
```

```
"""
assert input.dim() == 5, "5D tensors expected in input"
assert type(size) == int or len(size) == 3, '5D tensors expect size as int or Tuple[int, int, int]'
return _functions.thnn.UpsamplingTrilinear3d(_triple(size), scale_factor)(input)
```

```
>>> inputs = autograd.Variable(torch.randn(1,10,4,4))
>>> F.subsample(inputs, weight, bias, 2, stride=2)
"""
if input.dim() == 4:
```

```
@@ -64,6 +64,8 @@ class UpsamplingNearest2d(_UpsamplingBase):
    [torch.FloatTensor of size 1x1x4x4]

    """
    def __init__(self, size=None, scale_factor=None):
        super(UpsamplingNearest2d, self).__init__(_pair(size), scale_factor)
```

```
@@ -109,6 +111,104 @@ class UpsamplingBilinear2d(_UpsamplingBase):
    [torch.FloatTensor of size 1x1x4x4]

    """
    def __init__(self, size=None, scale_factor=None):
        super(UpsamplingBilinear2d, self).__init__(_pair(size), scale_factor)
```

```
"""
def __init__(self, size=None, scale_factor=None):
    super(UpsamplingNearest3d, self).__init__(_triple(size), scale_factor)
```

```
"""
def __init__(self, size=None, scale_factor=None):
    super(UpsamplingTrilinear3d, self).__init__(_triple(size), scale_factor)
```

@andrewgiessel
Contributor

I also didn't do any commenting on the 3d upsampling code, but I think the pattern is clear enough.

@fmassa
Member

fmassa commented Apr 28, 2017

Hey, this is a big PR! :)
I quickly skimmed through it, and correct me if I'm wrong, but isn't SpatialSubSampling actually equivalent to Conv2d with groups == nInputChannels == nOutputChannels, just maybe more efficient?

@lantiga
Contributor Author

lantiga commented Apr 28, 2017

@andrewgiessel Thank you! This is super-helpful.

> Hey, this is a big PR! :)

@fmassa Yeah, I know and I feel bad about it :-) It rabbit-holed from needing those upsampling modules for 3d stuff.

> I quickly skimmed through it, and correct me if I'm wrong, but isn't SpatialSubSampling actually equivalent to Conv2d with groups == nInputChannels == nOutputChannels, just maybe more efficient?

Almost: it's like a Conv2d with groups == nInputChannels == nOutputChannels and kernels in which all weight and bias elements are equal. The weight and bias are 1-dimensional, with size (nInputChannels,).
It's really a local averaging + subsampling filter where you can learn the scale and bias of the average per channel.
SpatialSubSampling has been in torch for a long time (torch7 had this module). It's more memory efficient and has fewer parameters if you just need subsampling. Not sure about speed once CUDNN kicks in.
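
For what it's worth, here is a rough, hypothetical sketch (not code from this PR; the helper name `subsample2d_as_grouped_conv` is made up) of that near-equivalence: a grouped Conv2d whose k x k kernel for channel c is filled with the single scalar weight[c].

```
import torch
import torch.nn.functional as F


def subsample2d_as_grouped_conv(x, weight, bias, kernel_size, stride):
    # x: (N, C, H, W); weight, bias: 1-d tensors of size (C,)
    C = x.size(1)
    # One (1, k, k) kernel per channel, every element equal to weight[c]
    w = weight.view(C, 1, 1, 1).expand(C, 1, kernel_size, kernel_size).contiguous()
    # groups == C == nInputChannels == nOutputChannels
    return F.conv2d(x, w, bias=bias, stride=stride, groups=C)
```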

@lantiga
Contributor Author

lantiga commented Apr 28, 2017

@fmassa see also torch/nn#944

@ngimel
Collaborator

ngimel commented Apr 28, 2017

I'm sorry I'm late to the party, but isn't VolumetricUpsamplingNearest the same as the inverse of VolumetricAveragePooling (possibly with a multiplicative coefficient; cudnn has that as a parameter)? In the upsampling forward you put the same input value into a window of the output tensor, which is what updateGradInput of average pooling does, and in updateGradInput of upsampling you compute an average (or a sum) of gradOutput over some window and put it into gradInput, which is what the average pooling forward does. So you could probably just reuse the average pooling kernels?
And from your description, subsampling sounds like average pooling + per-channel learnable scale and shift, which also looks like it does not require special kernels: average pooling is there already, and the scale and shift can be implemented as autograd tensor operations.
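
For illustration, a rough sketch of that composition (a hypothetical `subsample2d_from_avg_pool`, not code from this PR; it assumes the Torch-style subsampling computes weight[c] * (sum over each k x k window) + bias[c], so the avg_pool2d output is scaled back by k*k to recover the windowed sum):

```
import torch.nn.functional as F


def subsample2d_from_avg_pool(x, weight, bias, kernel_size, stride):
    # avg_pool2d divides by the window size; multiply back to get the windowed sum
    pooled = F.avg_pool2d(x, kernel_size, stride) * (kernel_size * kernel_size)
    # per-channel learnable scale and shift as plain autograd tensor operations
    return pooled * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)
```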

@lantiga
Contributor Author

lantiga commented Apr 28, 2017

@ngimel Aha, I see.
I don't think we can directly use the average pooling kernels without refactoring average pooling modules to expose their kernels for use in forward and backward, am I right?
Maybe we could reuse average pooling only when using CUDNN?

@ngimel
Collaborator

ngimel commented Apr 28, 2017

I'm not closely familiar with the THCUNN structure, but it should definitely be possible to expose the average pooling kernels so that you could call them. I'm just in favor of avoiding code repetition wherever possible :-) You could also call cudnn, that's true.

@lantiga
Contributor Author

lantiga commented Apr 28, 2017

You've got a good point for sure.
Both SpatialUpsamplingNearest (which was there prior to this PR) and VolumetricUpsamplingNearest could undergo the same treatment, along with changes to THNN and THCUNN to take the kernels out of updateOutput and updateGradInput.
I'd go ahead with this work in a separate PR though; this one is already pretty large.

@lantiga
Contributor Author

lantiga commented Apr 28, 2017

@andrewgiessel I have a question on the changes: with your strategy, calling the function directly from functional - e.g. F.upsample_nearest(input, 10) - will fail, because the casting only happens at the nn module level.
I bring this up because tests are failing for this reason after the changes.

@andrewgiessel
Contributor

@lantiga - I followed the same strategy at the functional level, too. See here; in particular, look at the way I wrote the constructors for the parent and child classes.

You have better tests in this PR than I do, frankly, so I should probably copy some of yours.

@lantiga
Contributor Author

lantiga commented Apr 29, 2017

@andrewgiessel I see now, thanks for the pointer.

@andrewgiessel
Contributor

@lantiga Note that #1317 was merged into master, so hopefully the conflict resolution won't be too bad! Let me know if I can help.

@lantiga
Contributor Author

lantiga commented May 5, 2017

Resolving the merge conflicts left a few tests failing. I'll fix them ASAP.

@LukasMosser
Copy link

@lantiga any update on the remaining conflicts? Would be awesome to have this as part of standard PyTorch.

@lantiga
Contributor Author

lantiga commented May 27, 2017

I had to take a break from PRs due to work, but I'm back now.

@soumith are you willing to review this one after I fix the conflicts?

@soumith
Member

soumith commented May 27, 2017

I'll take a look at this today, make the necessary changes, and merge it in. You don't have to make any additional changes.

@lantiga
Contributor Author

lantiga commented May 27, 2017

Great! Thank you @soumith

@soumith
Member

soumith commented Jun 7, 2017

Superseded by #1676 and a yet-to-be-named PR for subsampling.

soumith closed this Jun 7, 2017