
Rollup: No-batch-dim support for torch.nn modules #60585

Closed
jbschlosser opened this issue Jun 23, 2021 · 7 comments
Labels
module: batching, module: nn (Related to torch.nn), tracker (A tracking issue), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Projects

Comments

@jbschlosser
Contributor

jbschlosser commented Jun 23, 2021

Background

A previous issue (#47149) requested support for arbitrary batch dimensions across module inputs. This rollup issue details the torch.nn module updates required to address a subset of this functionality: specifically, the case of no batch dimensions. This particular case is useful for composability with a future vmap.
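To illustrate the motivation: a module that accepts unbatched inputs composes cleanly with `vmap`, since `vmap` can supply the batching externally. A minimal sketch, assuming a recent PyTorch release where `torch.vmap` is available at the top level, using `nn.Linear` (which already handles unbatched inputs):

```python
import torch
import torch.nn as nn

# A module that handles unbatched (no-batch-dim) inputs can have its
# batching supplied externally by vmap, rather than built in.
linear = nn.Linear(5, 3)

single = torch.randn(5)       # (in_features,) — no batch dim
batched = torch.randn(8, 5)   # batch of 8

out_single = linear(single)               # shape (3,)
out_vmap = torch.vmap(linear)(batched)    # vmap maps over dim 0 -> (8, 3)
out_direct = linear(batched)              # built-in batching -> (8, 3)

# External batching via vmap agrees with the module's own batching.
assert torch.allclose(out_vmap, out_direct, atol=1e-6)
```

This is the composability the issue refers to: once every module handles the no-batch-dim case, `vmap` can uniformly add (or nest) batch dimensions on top.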

Module support

  • nn.AdaptiveAvgPool1d - Should be updated to support 2D inputs
  • nn.AdaptiveAvgPool2d - Supports 3D inputs (no batch), but docs should make this clear
  • nn.AdaptiveAvgPool3d - Supports 4D inputs (no batch), but docs should make this clear
  • nn.AdaptiveLogSoftmaxWithLoss - Should be updated to support 1D input and scalar target / output1 (@george-qi)
  • nn.AdaptiveMaxPool1d - Should be updated to support 2D inputs
  • nn.AdaptiveMaxPool2d - Supports 3D inputs (no batch), but docs should make this clear
  • nn.AdaptiveMaxPool3d - Supports 4D inputs (no batch), but docs should make this clear
  • nn.AlphaDropout
  • nn.AvgPool1d - Should be updated to support 2D inputs
  • nn.AvgPool2d - Supports 3D inputs (no batch), but docs should make this clear
  • nn.AvgPool3d - Supports 4D inputs (no batch), but docs should make this clear
  • nn.BCELoss
  • nn.BCEWithLogitsLoss
  • nn.Bilinear - Already supports no batch dims, but docs should make this clear (@george-qi)
  • nn.CELU
  • nn.CTCLoss - Should be updated to support a single (source, target) sequence (@george-qi)
  • nn.ConstantPad1d
  • nn.ConstantPad2d
  • nn.ConstantPad3d
  • nn.Conv1d - Should be updated to support 2D input (@jbschlosser )
  • nn.Conv2d - Should be updated to support 3D input (@jbschlosser )
  • nn.Conv3d - Should be updated to support 4D input (@jbschlosser )
  • nn.ConvTranspose1d - Should be updated to support 2D input (@jbschlosser )
  • nn.ConvTranspose2d - Should be updated to support 3D input (@jbschlosser )
  • nn.ConvTranspose3d - Should be updated to support 4D input (@jbschlosser )
  • nn.CosineEmbeddingLoss - Should be updated to support 1D input with scalar target (@kshitij12345 )
  • nn.CosineSimilarity (pass dim=0 for no-batch-dim support)
  • nn.CrossEntropyLoss - Should be updated to support 1D input with scalar target (@kshitij12345 )
  • nn.Dropout
  • nn.Dropout2d
  • nn.Dropout3d
  • nn.ELU
  • nn.Embedding
  • nn.FeatureAlphaDropout - Already supports no batch dims, but docs are missing
  • nn.Flatten (pass start_dim=0 for no-batch-dim support)
  • nn.Fold - Should be updated to support 3D inputs (@kshitij12345)
  • nn.FractionalMaxPool2d - Should be updated to support 3D inputs
  • nn.FractionalMaxPool3d - Should be updated to support 4D inputs (@george-qi)
  • nn.GELU
  • nn.GLU
  • nn.GRU - Should be updated to support a single input sequence / hidden state; batch_first arg is meaningless in this case (@kshitij12345)
  • nn.GRUCell - Should be updated to support 1D input / hidden state (@kshitij12345)
  • nn.GaussianNLLLoss - Already supports no batch dims, but docs should make this clear (@george-qi)
  • nn.Hardshrink
  • nn.Hardsigmoid
  • nn.Hardswish
  • nn.Hardtanh
  • nn.HingeEmbeddingLoss
  • nn.HuberLoss
  • nn.Identity
  • nn.InstanceNorm1d - Should be updated to support 1D or 2D inputs (@kshitij12345)
  • nn.InstanceNorm2d - Should be updated to support 2D or 3D inputs (@kshitij12345)
  • nn.InstanceNorm3d - Should be updated to support 3D or 4D inputs (@kshitij12345)
  • nn.KLDivLoss
  • nn.L1Loss
  • nn.LPPool1d - Should be updated to support 2D inputs
  • nn.LPPool2d
  • nn.LSTM - Should be updated to support single input, hidden state, cell state; batch_first arg is meaningless in this case (@kshitij12345)
  • nn.LSTMCell - Should be updated to support 1D input, hidden state, cell state (@kshitij12345)
  • nn.LayerNorm (pass correctly-set normalized_shape for no-batch-dim support)
  • nn.LazyConv1d - Should be updated to support 2D input (@jbschlosser )
  • nn.LazyConv2d - Should be updated to support 3D input (@jbschlosser )
  • nn.LazyConv3d - Should be updated to support 4D input (@jbschlosser )
  • nn.LazyConvTranspose1d - Should be updated to support 2D input (@jbschlosser )
  • nn.LazyConvTranspose2d - Should be updated to support 3D input (@jbschlosser )
  • nn.LazyConvTranspose3d - Should be updated to support 4D input (@jbschlosser )
  • nn.LazyLinear - Already supports no batch dims, but docs should make this clear
  • nn.LeakyReLU
  • nn.Linear - Already supports no batch dims, but docs should make this clear
  • nn.LogSigmoid
  • nn.LogSoftmax
  • nn.MSELoss - Already supports no batch dims, but docs should make this clear
  • nn.MarginRankingLoss - Should be updated to support scalars (@kshitij12345)
  • nn.MaxPool1d - Already supports no batch dims when return_indices=False, but return_indices=True still needs support
  • nn.MaxPool2d - Already supports no batch dims when return_indices=False, but return_indices=True still needs support
  • nn.MaxPool3d - Already supports no batch dims when return_indices=False, but return_indices=True still needs support
  • nn.MaxUnpool1d - Should be updated to support 2D input / indices
  • nn.MaxUnpool2d - Should be updated to support 3D input / indices
  • nn.MaxUnpool3d - Should be updated to support 4D input / indices
  • nn.Mish
  • nn.MultiLabelMarginLoss
  • nn.MultiLabelSoftMarginLoss - Should be updated to support 1D input / target
  • nn.MultiMarginLoss - Already supports no batch dims, but docs should make this clear
  • nn.MultiheadAttention - Should be updated to support unbatched query, key, values, and masks; batch_first arg is meaningless in this case
  • nn.NLLLoss - Should be updated to support 1D input and scalar target
  • nn.PairwiseDistance - Should be updated to support 1D inputs with scalar output
  • nn.PixelShuffle
  • nn.PixelUnshuffle
  • nn.PoissonNLLLoss - Already supports no batch dims, but docs should make this clear
  • nn.RNN - Should be updated to support a single input sequence / hidden state; batch_first arg is meaningless in this case (@kshitij12345)
  • nn.RNNCell - Should be updated to support 1D input and hidden state (@kshitij12345)
  • nn.RReLU
  • nn.ReLU
  • nn.ReLU6
  • nn.ReflectionPad1d - Should be updated to support 2D inputs
  • nn.ReflectionPad2d
  • nn.ReflectionPad3d
  • nn.ReplicationPad1d - Should be updated to support 2D inputs
  • nn.ReplicationPad2d
  • nn.ReplicationPad3d
  • nn.SELU
  • nn.SiLU
  • nn.Sigmoid
  • nn.SmoothL1Loss - Already supports no batch dims, but docs should make this clear
  • nn.SoftMarginLoss
  • nn.Softmax
  • nn.Softmax2d - Should be updated to support 3D input
  • nn.Softmin
  • nn.Softplus
  • nn.Softshrink
  • nn.Softsign
  • nn.Tanh
  • nn.Tanhshrink
  • nn.Threshold
  • nn.Transformer - Should be updated to support unbatched source / target sequences and masks (@kshitij12345)
  • nn.TransformerDecoderLayer - Should be updated to support unbatched inputs (@kshitij12345)
  • nn.TransformerEncoderLayer - Should be updated to support unbatched inputs (@kshitij12345)
  • nn.TripletMarginLoss - Should be updated to support 1D anchor, positive, and negative (@kshitij12345)
  • nn.TripletMarginWithDistanceLoss (support is based on the choice of distance function)
  • nn.Unflatten - Already supports no batch dims, but docs should make this clear
  • nn.ZeroPad2d
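As a concrete example of what "support 3D input" means for an item above, here is the intended behavior for `nn.Conv2d` (this works in recent PyTorch versions, where the updates tracked here have landed): an unbatched `(C, H, W)` input behaves exactly like the corresponding batch of size 1, minus the batch dim.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)

unbatched = torch.randn(3, 32, 32)   # (C, H, W) — no batch dim
batched = unbatched.unsqueeze(0)     # (1, C, H, W) — explicit batch of 1

out_unbatched = conv(unbatched)      # shape (8, 30, 30)
out_batched = conv(batched)          # shape (1, 8, 30, 30)

# The unbatched result matches the size-1 batch with the batch dim removed.
assert torch.allclose(out_unbatched, out_batched.squeeze(0))
```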

Semantically incompatible modules

  • nn.BatchNorm1d - Only defined over batch
  • nn.BatchNorm2d - Only defined over batch
  • nn.BatchNorm3d - Only defined over batch
  • nn.LazyBatchNorm1d - Only defined over batch
  • nn.LazyBatchNorm2d - Only defined over batch
  • nn.LazyBatchNorm3d - Only defined over batch
  • nn.SyncBatchNorm - Only defined over batch
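The incompatibility is semantic, not an implementation gap: batch norm computes statistics *across* the batch dimension, so an unbatched input has no meaningful normalization. A quick sketch of the current behavior (error types may vary across PyTorch versions, hence the broad `except`):

```python
import torch
import torch.nn as nn

# In training mode, BatchNorm1d rejects a 1D (no-batch-dim) input; even an
# explicit batch of size 1 fails, since per-channel variance over a single
# sample is undefined.
bn = nn.BatchNorm1d(num_features=4)
bn.train()

rejected = False
try:
    bn(torch.randn(4))   # (C,) — no batch dim
except (ValueError, RuntimeError):
    rejected = True
```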

Modules that require BC-breaking changes

  • nn.ChannelShuffle - While the docs say that the module accepts shape (*, C, H, W), the implementation assumes shape (N, C, *) with * being 1 or more dimensions; switching to (C, *) is BC-breaking because it would reinterpret dims
  • nn.EmbeddingBag - Already supports 1D inputs with different semantics
  • nn.GroupNorm - Supports arbitrary spatial dims; switching to (C, *) instead of (N, C, *) is BC-breaking since it would reinterpret dims
  • nn.LocalResponseNorm - Supports (N, C, *) shape now; switching to (C, *) is BC-breaking since it would reinterpret dims
  • nn.PReLU - Always assumes the 2nd dim is the channels dim
  • nn.Unfold - Supports 4D inputs; switching to (C, *) instead of (N, C, *) is BC-breaking since it would reinterpret dims
  • nn.Upsample - Supporting unbatched inputs would be BC-breaking because it would reinterpret dims

Irrelevant modules

  • nn.Module
  • nn.ModuleDict
  • nn.ModuleList
  • nn.ParameterDict
  • nn.ParameterList
  • nn.Sequential
  • nn.TransformerDecoder - see nn.TransformerDecoderLayer instead
  • nn.TransformerEncoder - see nn.TransformerEncoderLayer instead

cc @albanD @mruberry @jbschlosser

@jbschlosser jbschlosser added this to To Do in torch.nn Jun 23, 2021
@VitalyFedyunin VitalyFedyunin added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Jun 24, 2021
facebook-github-bot pushed a commit that referenced this issue Jul 1, 2021
Summary:
Towards #60585

This PR updates docs for `Linear` and adds a non-batch test case to `common_nn.py`.

Pull Request resolved: #60992

Reviewed By: VitalyFedyunin

Differential Revision: D29518451

Pulled By: jbschlosser

fbshipit-source-id: 6dd79c0f21ac5b6f693e3e1ba954379d2606d4e0
@thomasjpfan
Contributor

What do the semantics of no-batch-dim support for a criterion look like when there is a reduction? Consider:

import torch
import torch.nn as nn

loss_mean = nn.MSELoss(reduction='mean')
input = torch.randn(3, 5, 5, requires_grad=True)
target = torch.randn(3, 5, 5)

torch.isclose(
    loss_mean(input, target),
    # Need to scale the sum down by the batch size
    sum(loss_mean(input[i], target[i]) for i in range(3)) / 3
)

@jbschlosser
Contributor Author

What do the semantics of no-batch-dim support for a criterion look like when there is a reduction? Consider:

Great question - I'd say that "reduction" only has non-trivial meaning when being applied over a batch. If there's only a single item in an implicit "batch" (i.e. the no batch dim case), I don't see a problem with the output being equivalent across all reduction types.

Note that you could tweak your example to pass 3 inputs of shape (1, 5, 5) and you'd run into the same thing.
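The tweak described above can be spelled out: with explicit size-1 batches, the same accounting question appears, and the per-slice losses must still be averaged rather than summed. A small check (using the same `MSELoss` setup as the question above):

```python
import torch
import torch.nn as nn

loss_mean = nn.MSELoss(reduction='mean')
input = torch.randn(3, 5, 5)
target = torch.randn(3, 5, 5)

full = loss_mean(input, target)

# Slice into three explicit (1, 5, 5) batches — same situation as the
# no-batch-dim case, just with a trivial batch dim attached.
per_slice = [loss_mean(input[i:i+1], target[i:i+1]) for i in range(3)]

# Each slice has the same element count, so the mean over the full batch
# equals the average (not the sum) of the per-slice means.
assert torch.isclose(full, sum(per_slice) / 3)
```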

facebook-github-bot pushed a commit that referenced this issue Jul 12, 2021
Summary:
Towards #60585

Pull Request resolved: #61264

Reviewed By: iramazanli

Differential Revision: D29615292

Pulled By: jbschlosser

fbshipit-source-id: 826d1c87d67261a7211270e90e3a1022bbbe37bd
facebook-github-bot pushed a commit that referenced this issue Jul 16, 2021
…61300)

Summary:
Towards #60585

This PR updates docs and tests for activation modules that already support no-batch dims.

Pull Request resolved: #61300

Reviewed By: heitorschueroff

Differential Revision: D29660543

Pulled By: jbschlosser

fbshipit-source-id: 5edad45f7e9995aca6c3403469668e6e1cbb94b6
facebook-github-bot pushed a commit that referenced this issue Jul 19, 2021
…atch (#61262)

Summary:
Toward #60585

Pull Request resolved: #61262

Reviewed By: mrshenli

Differential Revision: D29660554

Pulled By: jbschlosser

fbshipit-source-id: d5e3dc7096fcf8621bce4a1063d521b84092e0ca
facebook-github-bot pushed a commit that referenced this issue Jul 21, 2021
Summary:
Toward #60585

This PR adds a `single_batch_reference_fn` that uses the single batch implementation to check no-batch.

Pull Request resolved: #61060

Reviewed By: mrshenli

Differential Revision: D29739823

Pulled By: jbschlosser

fbshipit-source-id: d90d88a3671177a647171801cc6ec7aa3df35482
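The idea behind such a reference function can be sketched as follows: run the unbatched input through the module as an explicit batch of size 1 and compare. (This is a hypothetical `single_batch_reference` helper for illustration; the actual `single_batch_reference_fn` in `common_nn.py` may differ in details.)

```python
import torch
import torch.nn as nn

def single_batch_reference(module, unbatched_input):
    # Hypothetical helper: compute the expected no-batch-dim output by
    # adding a batch dim of 1, running the module, and removing it again.
    batched_out = module(unbatched_input.unsqueeze(0))
    return batched_out.squeeze(0)

conv = nn.Conv1d(in_channels=2, out_channels=4, kernel_size=3)
x = torch.randn(2, 10)   # (C, L) — no batch dim

# The module's direct unbatched output should match the size-1-batch reference.
assert torch.allclose(conv(x), single_batch_reference(conv, x))
```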
facebook-github-bot pushed a commit that referenced this issue Jul 21, 2021
Summary:
Towards #60585

I think `Dropout` is already tested in `test_Dropout` for no batch dims.

Pull Request resolved: #61911

Reviewed By: albanD

Differential Revision: D29810928

Pulled By: jbschlosser

fbshipit-source-id: 7716a1a808e9e34aae43573f38706212552afbb4
facebook-github-bot pushed a commit that referenced this issue Jul 22, 2021
Summary:
Towards #60585

Pull Request resolved: #61860

Reviewed By: albanD

Differential Revision: D29826382

Pulled By: jbschlosser

fbshipit-source-id: 47e12073d866f0604310fc1ff270cde9907e516d
facebook-github-bot pushed a commit that referenced this issue Jul 23, 2021
…61984)

Summary:
Towards #60585

(Interesting how the maxpool tests are currently in `test/test_nn.py`)

Pull Request resolved: #61984

Reviewed By: suo

Differential Revision: D29883846

Pulled By: jbschlosser

fbshipit-source-id: 1e0637c96f8fa442b4784a9865310c164cbf61c8
facebook-github-bot pushed a commit that referenced this issue Jul 23, 2021
…t no-batch-dims (#61461)

Summary:
Towards #60585

This PR does not use `check_sum_reduction` because I wanted to test every reduction option.

Pull Request resolved: #61461

Reviewed By: suo

Differential Revision: D29883744

Pulled By: jbschlosser

fbshipit-source-id: cdad0effb41f0484938caad0d4c9d6d83e2aec07
torch.nn automation moved this from Needs Triage to Done Jan 7, 2022
@kshitij12345
Collaborator

kshitij12345 commented Jan 7, 2022

Oops! I had incorrectly marked the PR (for GRU and RNN) with "Fixes".

@kshitij12345 kshitij12345 reopened this Jan 7, 2022
torch.nn automation moved this from Done to Needs Triage Jan 7, 2022
facebook-github-bot pushed a commit that referenced this issue Jan 7, 2022
Summary:
Reference #60585

Reland: #70442

Pull Request resolved: #70977

Reviewed By: dagitses, george-qi

Differential Revision: D33477256

Pulled By: jbschlosser

fbshipit-source-id: 2035c2d00b2f627c7046fd9b13c71b9360cd6fad
@github-actions github-actions bot closed this as completed Jan 7, 2022
torch.nn automation moved this from Needs Triage to Done Jan 7, 2022
@jbschlosser
Contributor Author

Wrongly closed again

@jbschlosser jbschlosser reopened this Jan 10, 2022
torch.nn automation moved this from Done to Needs Triage Jan 10, 2022
facebook-github-bot pushed a commit that referenced this issue Jan 13, 2022
Summary:
Reference: #60585

cc albanD mruberry jbschlosser walterddr kshitij12345

Pull Request resolved: #71055

Reviewed By: anjali411

Differential Revision: D33567403

Pulled By: jbschlosser

fbshipit-source-id: 4d0a311ad7419387c4547e43e533840c8b6d09d8
facebook-github-bot pushed a commit that referenced this issue Jan 24, 2022
Summary:
Reference: #60585

TODO:
* [x] Update docs

Pull Request resolved: #71056

Reviewed By: samdow

Differential Revision: D33638643

Pulled By: jbschlosser

fbshipit-source-id: c0949829de8a8e6e7b2873f459a8d7da597a3be3
@atalman atalman closed this as completed Jan 26, 2022
torch.nn automation moved this from Needs Triage to Done Jan 26, 2022
Projects
torch.nn (Done)
5 participants