Rollup: No-batch-dim support for torch.nn modules #60585
What do the semantics of no-batch-dim look like for a criterion when there is a reduction? Consider:

```python
import torch
import torch.nn as nn

loss_mean = nn.MSELoss(reduction='mean')
input = torch.randn(3, 5, 5, requires_grad=True)
target = torch.randn(3, 5, 5)

torch.isclose(
    loss_mean(input, target),
    # Need to scale the sum down by the number of batches
    sum(loss_mean(input[i], target[i]) for i in range(3)) / 3
)
```
Great question - I'd say that "reduction" only has non-trivial meaning when being applied over a batch. If there's only a single item in an implicit "batch" (i.e. the no-batch-dim case), I don't see a problem with the output being equivalent across all reduction types. Note that you could tweak your example to pass 3 inputs of shape
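To make the scaling in the question concrete: with equal-sized samples, the batched `'mean'` reduction equals the average of the per-sample `'mean'` losses (hence the `/ 3`), while `'sum'` needs no scaling. A small sketch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input = torch.randn(3, 5, 5)
target = torch.randn(3, 5, 5)

# 'mean' averages over *all* elements, so the batched result equals the
# average of the per-sample means (the samples are equal-sized here)
loss_mean = nn.MSELoss(reduction='mean')
assert torch.isclose(
    loss_mean(input, target),
    sum(loss_mean(input[i], target[i]) for i in range(3)) / 3,
)

# 'sum' needs no scaling: the batched sum is the sum of per-sample sums
loss_sum = nn.MSELoss(reduction='sum')
assert torch.isclose(
    loss_sum(input, target),
    sum(loss_sum(input[i], target[i]) for i in range(3)),
)
```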
…t no-batch-dims (#61461)
Summary: Towards #60585. This PR does not use `check_sum_reduction` because I wanted to test every reduction option.
Pull Request resolved: #61461
Reviewed By: suo
Differential Revision: D29883744
Pulled By: jbschlosser
fbshipit-source-id: cdad0effb41f0484938caad0d4c9d6d83e2aec07
Oops! Had incorrectly marked the PR (GRU and RNN) with Fixes
Wrongly closed again
Summary: Reference: pytorch/pytorch#60585
TODO:
* [x] Update docs
Pull Request resolved: pytorch/pytorch#71056
Reviewed By: samdow
Differential Revision: D33638643
Pulled By: jbschlosser
fbshipit-source-id: c0949829de8a8e6e7b2873f459a8d7da597a3be3
(cherry picked from commit f94d584)
Background
A previous issue (#47149) requested support for arbitrary batch dimensions across module inputs. This rollup issue details the `torch.nn` module updates required to address a subset of this functionality: specifically, the case of no batch dimensions. This particular case is useful for composability with a future `vmap`.

Module support
- `nn.AdaptiveAvgPool1d` - Should be updated to support 2D inputs
- `nn.AdaptiveAvgPool2d` - Supports 3D inputs (no batch), but docs should make this clear
- `nn.AdaptiveAvgPool3d` - Supports 4D inputs (no batch), but docs should make this clear
- `nn.AdaptiveLogSoftmaxWithLoss` - Should be updated to support 1D input and scalar target / output (@george-qi)
- `nn.AdaptiveMaxPool1d` - Should be updated to support 2D inputs
- `nn.AdaptiveMaxPool2d` - Supports 3D inputs (no batch), but docs should make this clear
- `nn.AdaptiveMaxPool3d` - Supports 4D inputs (no batch), but docs should make this clear
- `nn.AlphaDropout`
- `nn.AvgPool1d` - Should be updated to support 2D inputs
- `nn.AvgPool2d` - Supports 3D inputs (no batch), but docs should make this clear
- `nn.AvgPool3d` - Supports 4D inputs (no batch), but docs should make this clear
- `nn.BCELoss`
- `nn.BCEWithLogitsLoss`
- `nn.Bilinear` - Already supports no batch dims, but docs should make this clear (@george-qi)
- `nn.CELU`
- `nn.CTCLoss` - Should be updated to support a single (source, target) sequence (@george-qi)
- `nn.ConstantPad1d`
- `nn.ConstantPad2d`
- `nn.ConstantPad3d`
- `nn.Conv1d` - Should be updated to support 2D input (@jbschlosser)
- `nn.Conv2d` - Should be updated to support 3D input (@jbschlosser)
- `nn.Conv3d` - Should be updated to support 4D input (@jbschlosser)
- `nn.ConvTranspose1d` - Should be updated to support 2D input (@jbschlosser)
- `nn.ConvTranspose2d` - Should be updated to support 3D input (@jbschlosser)
- `nn.ConvTranspose3d` - Should be updated to support 4D input (@jbschlosser)
- `nn.CosineEmbeddingLoss` - Should be updated to support 1D input with scalar target (@kshitij12345)
- `nn.CosineSimilarity` (pass `dim=0` for no-batch-dim support)
- `nn.CrossEntropyLoss` - Should be updated to support 1D input with scalar target (@kshitij12345)
- `nn.Dropout`
- `nn.Dropout2d`
- `nn.Dropout3d`
- `nn.ELU`
- `nn.Embedding`
- `nn.FeatureAlphaDropout` - Already supports no batch dims, but docs are missing
- `nn.Flatten` (pass `start_dim=0` for no-batch-dim support)
- `nn.Fold` - Should be updated to support 3D inputs (@kshitij12345)
- `nn.FractionalMaxPool2d` - Should be updated to support 3D inputs
- `nn.FractionalMaxPool3d` - Should be updated to support 4D inputs (@george-qi)
- `nn.GELU`
- `nn.GLU`
- `nn.GRU` - Should be updated to support a single input sequence / hidden state; `batch_first` arg is meaningless in this case (@kshitij12345)
- `nn.GRUCell` - Should be updated to support 1D input / hidden state (@kshitij12345)
- `nn.GaussianNLLLoss` - Already supports no batch dims, but docs should make this clear (@george-qi)
- `nn.Hardshrink`
- `nn.Hardsigmoid`
- `nn.Hardswish`
- `nn.Hardtanh`
- `nn.HingeEmbeddingLoss`
- `nn.HuberLoss`
- `nn.Identity`
- `nn.InstanceNorm1d` - Should be updated to support 1D or 2D inputs (@kshitij12345)
- `nn.InstanceNorm2d` - Should be updated to support 2D or 3D inputs (@kshitij12345)
- `nn.InstanceNorm3d` - Should be updated to support 3D or 4D inputs (@kshitij12345)
- `nn.KLDivLoss`
- `nn.L1Loss`
- `nn.LPPool1d` - Should be updated to support 2D inputs
- `nn.LPPool2d`
- `nn.LSTM` - Should be updated to support single input, hidden state, cell state; `batch_first` arg is meaningless in this case (@kshitij12345)
- `nn.LSTMCell` - Should be updated to support 1D input, hidden state, cell state (@kshitij12345)
- `nn.LayerNorm` (pass correctly-set `normalized_shape` for no-batch-dim support)
- `nn.LazyConv1d` - Should be updated to support 2D input (@jbschlosser)
- `nn.LazyConv2d` - Should be updated to support 3D input (@jbschlosser)
- `nn.LazyConv3d` - Should be updated to support 4D input (@jbschlosser)
- `nn.LazyConvTranspose1d` - Should be updated to support 2D input (@jbschlosser)
- `nn.LazyConvTranspose2d` - Should be updated to support 3D input (@jbschlosser)
- `nn.LazyConvTranspose3d` - Should be updated to support 4D input (@jbschlosser)
- `nn.LazyLinear` - Already supports no batch dims, but docs should make this clear
- `nn.LeakyReLU`
- `nn.Linear` - Already supports no batch dims, but docs should make this clear
- `nn.LogSigmoid`
- `nn.LogSoftmax`
- `nn.MSELoss` - Already supports no batch dims, but docs should make this clear
- `nn.MarginRankingLoss` - Should be updated to support scalars (@kshitij12345)
- `nn.MaxPool1d` - Already supports no batch dims when `return_indices=False`, but need to support `return_indices=True`
- `nn.MaxPool2d` - Already supports no batch dims when `return_indices=False`, but need to support `return_indices=True`
- `nn.MaxPool3d` - Already supports no batch dims when `return_indices=False`, but need to support `return_indices=True`
- `nn.MaxUnpool1d` - Should be updated to support 2D input / indices
- `nn.MaxUnpool2d` - Should be updated to support 3D input / indices
- `nn.MaxUnpool3d` - Should be updated to support 4D input / indices
- `nn.Mish`
- `nn.MultiLabelMarginLoss`
- `nn.MultiLabelSoftMarginLoss` - Should be updated to support 1D input / target
- `nn.MultiMarginLoss` - Already supports no batch dims, but docs should make this clear
- `nn.MultiheadAttention` - Should be updated to support unbatched query, key, values, and masks; `batch_first` arg is meaningless in this case
- `nn.NLLLoss` - Should be updated to support 1D input and scalar target
- `nn.PairwiseDistance` - Should be updated to support 1D inputs with scalar output
- `nn.PixelShuffle`
- `nn.PixelUnshuffle`
- `nn.PoissonNLLLoss` - Already supports no batch dims, but docs should make this clear
- `nn.RNN` - Should be updated to support a single input sequence / hidden state; `batch_first` arg is meaningless in this case (@kshitij12345)
- `nn.RNNCell` - Should be updated to support 1D input and hidden state (@kshitij12345)
- `nn.RReLU`
- `nn.ReLU`
- `nn.ReLU6`
- `nn.ReflectionPad1d` - Should be updated to support 2D inputs
- `nn.ReflectionPad2d`
- `nn.ReflectionPad3d`
- `nn.ReplicationPad1d` - Should be updated to support 2D inputs
- `nn.ReplicationPad2d`
- `nn.ReplicationPad3d`
- `nn.SELU`
- `nn.SiLU`
- `nn.Sigmoid`
- `nn.SmoothL1Loss` - Already supports no batch dims, but docs should make this clear
- `nn.SoftMarginLoss`
- `nn.Softmax`
- `nn.Softmax2d` - Should be updated to support 3D input
- `nn.Softmin`
- `nn.Softplus`
- `nn.Softshrink`
- `nn.Softsign`
- `nn.Tanh`
- `nn.Tanhshrink`
- `nn.Threshold`
- `nn.Transformer` - Should be updated to support unbatched source / target sequences and masks (@kshitij12345)
- `nn.TransformerDecoderLayer` - Should be updated to support unbatched inputs (@kshitij12345)
- `nn.TransformerEncoderLayer` - Should be updated to support unbatched inputs (@kshitij12345)
- `nn.TripletMarginLoss` - Should be updated to support 1D anchor, positive, and negative (@kshitij12345)
- `nn.TripletMarginWithDistanceLoss` (support is based on the choice of distance function)
- `nn.Unflatten` - Already supports no batch dims, but docs should make this clear
- `nn.ZeroPad2d`
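The composability with `vmap` mentioned in the Background can be sketched as follows: once a module accepts unbatched inputs, batching can be reintroduced externally. This is a minimal sketch, assuming a PyTorch build that ships `torch.vmap` (2.0+) and using `nn.Linear`, which already supports unbatched inputs:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
linear = nn.Linear(5, 2)

# The computation is written against a single unbatched sample of shape (5,)
single = torch.randn(5)
out_single = linear(single)
assert out_single.shape == (2,)

# vmap lifts the same per-sample computation over a leading batch dim;
# the module's parameters are treated as unbatched constants
batch = torch.randn(8, 5)
out_vmapped = torch.vmap(linear)(batch)
assert out_vmapped.shape == (8, 2)

# Matches the module's own native batched behavior
assert torch.allclose(out_vmapped, linear(batch), atol=1e-6)
```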
Semantically incompatible modules
- `nn.BatchNorm1d` - Only defined over batch
- `nn.BatchNorm2d` - Only defined over batch
- `nn.BatchNorm3d` - Only defined over batch
- `nn.LazyBatchNorm1d` - Only defined over batch
- `nn.LazyBatchNorm2d` - Only defined over batch
- `nn.LazyBatchNorm3d` - Only defined over batch
- `nn.SyncBatchNorm` - Only defined over batch

Modules that require BC-breaking changes
- `nn.ChannelShuffle` - While the docs say that the module accepts shape (*, C, H, W), the implementation assumes shape (N, C, *) with * being 1 or more dimensions; switching to (C, *) is BC-breaking because it would reinterpret dims
- `nn.EmbeddingBag` - Already supports 1D inputs with different semantics
- `nn.GroupNorm` - Supports arbitrary spatial dims; switching to (C, *) instead of (N, C, *) is BC-breaking since it would reinterpret dims
- `nn.LocalResponseNorm` - Supports (N, C, *) shape now; switching to (C, *) is BC-breaking since it would reinterpret dims
- `nn.PReLU` - Always assumes the 2nd dim is the channels dim
- `nn.Unfold` - Supports 4D inputs; switching to (C, *) instead of (N, C, *) is BC-breaking since it would reinterpret dims
- `nn.Upsample` - Supporting unbatched inputs would be BC-breaking because it would reinterpret dims

Irrelevant modules
- `nn.Module`
- `nn.ModuleDict`
- `nn.ModuleList`
- `nn.ParameterDict`
- `nn.ParameterList`
- `nn.Sequential`
- `nn.TransformerDecoder` - see `nn.TransformerDecoderLayer` instead
- `nn.TransformerEncoder` - see `nn.TransformerEncoderLayer` instead

cc @albanD @mruberry @jbschlosser