.backward() on MSELoss fails with IndexError: Dimension out of range on MPS #79784

Closed
ilyarepko opened this issue Jun 17, 2022 · 2 comments
Labels
module: autograd (Related to torch.autograd and the autograd engine in general)
module: mps (Related to the Apple Metal Performance Shaders framework)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

ilyarepko commented Jun 17, 2022

🐛 Describe the bug

import torch
import torch.nn as nn

loss_fn = torch.nn.MSELoss(reduction="sum")
model = nn.Sequential(
    nn.Linear(2, 2)
)

# Works on CPU
x = torch.tensor([1.0, 1.0], dtype=torch.float32)
y = torch.tensor([2.0, 2.0], dtype=torch.float32)
model.zero_grad()
y_predicted = model(x)
loss = loss_fn(y_predicted, y)
loss.backward()

# Doesn't on MPS
device = torch.device("mps")
x_mps = x.to(device=device)
y_mps = y.to(device=device)
model_mps = model.to(device)
model_mps.zero_grad()
y_predicted = model_mps(x_mps)
loss = loss_fn(y_predicted, y_mps)
loss.backward()
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Input In [48], in <cell line: 25>()
     23 y_predicted = model_mps(x_mps)
     24 loss = loss_fn(y_predicted, y_mps)
---> 25 loss.backward()

File /opt/homebrew/lib/python3.9/site-packages/torch/_tensor.py:400, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    391 if has_torch_function_unary(self):
    392     return handle_torch_function(
    393         Tensor.backward,
    394         (self,),
   (...)
    398         create_graph=create_graph,
    399         inputs=inputs)
--> 400 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)

File /opt/homebrew/lib/python3.9/site-packages/torch/autograd/__init__.py:173, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    168     retain_graph = create_graph
    170 # The reason we repeat same the comment below is that
    171 # some Python versions print out the first line of a multi-line function
    172 # calls in the traceback and some print out the last line
--> 173 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    174     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    175     allow_unreachable=True, accumulate_grad=True)

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
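The error message follows PyTorch's usual dimension-wrapping rule: a tensor with `ndim` dimensions accepts `dim` in `[-ndim, ndim - 1]`, so a 1-D tensor such as `x` accepts only -1 or 0. That suggests some step of the MPS backward pass indexed dimension 1 of a 1-D tensor. The helper below is a minimal pure-Python sketch of that check, not PyTorch code; `wrap_dim` is a hypothetical name, and the sketch assumes `ndim >= 1` (0-d tensors are ignored).

```python
def wrap_dim(dim: int, ndim: int) -> int:
    """Map a possibly negative dim onto [0, ndim) or raise IndexError.

    Hypothetical sketch of a dim-wrapping check; assumes ndim >= 1.
    """
    low, high = -ndim, ndim - 1
    if not low <= dim <= high:
        raise IndexError(
            f"Dimension out of range (expected to be in range of "
            f"[{low}, {high}], but got {dim})"
        )
    # Negative dims count from the end: -1 is the last dimension.
    return dim + ndim if dim < 0 else dim


# A 1-D tensor (e.g. x with shape [2]) has only dims -1 and 0.
print(wrap_dim(0, 1))   # 0
print(wrap_dim(-1, 1))  # 0
try:
    wrap_dim(1, 1)
except IndexError as err:
    print(err)  # same wording as the last line of the traceback
```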

Versions

bash-5.1$ python3 collect_env.py 
Collecting environment information...
PyTorch version: 1.13.0.dev20220614
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 12.4 (arm64)
GCC version: Could not collect
Clang version: 13.1.6 (clang-1316.0.21.2.5)
CMake version: version 3.23.2
Libc version: N/A

Python version: 3.9.13 (main, May 24 2022, 21:13:51)  [Clang 13.1.6 (clang-1316.0.21.2)] (64-bit runtime)
Python platform: macOS-12.4-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.0rc2
[pip3] torch==1.13.0.dev20220614
[pip3] torchaudio==0.14.0.dev20220603
[pip3] torchvision==0.14.0.dev20220614
[conda] Could not collect

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @kulinseth

albanD added the module: autograd, triaged, and module: mps labels on Jun 21, 2022
pytorchmergebot pushed a commit that referenced this issue Jul 4, 2022
facebook-github-bot pushed a commit that referenced this issue Jul 6, 2022
Summary:
Fixes #79784

Pull Request resolved: #80759
Approved by: https://github.com/ezyang

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/0e3953fc52ecab510c17f6d6ecb46655174ab8f5

Reviewed By: mehtanirav

Differential Revision: D37604743

Pulled By: mehtanirav

fbshipit-source-id: 2b0771449666512a8b935d854879f08479fdf5fc

qqaatw (Collaborator) commented Jul 6, 2022

I can reproduce this. The problem stems from the backward pass of linear:

def test_79784():
    import torch
    import torch.nn as nn

    device = torch.device("mps")
    weight = torch.nn.Parameter(torch.randn(2, 2))

    # Reference result on CPU
    x = torch.tensor([1.0, 1.0], dtype=torch.float32)
    out = nn.functional.linear(x, weight)
    grad = torch.autograd.grad(out.sum(), weight)

    # Writing linear out as x @ W.T on MPS
    x_mps = x.to(device=device)
    weight_mps = weight.to(device)
    out_mps = x_mps @ weight_mps.T
    grad_mps = torch.autograd.grad(out_mps.sum(), weight_mps)

    torch.testing.assert_close(out, out_mps, check_device=False)
    torch.testing.assert_close(grad[0], grad_mps[0], check_device=False)

    # nn.functional.linear with the same 1-D input on MPS
    out_mps = nn.functional.linear(x_mps, weight_mps)
    grad_mps = torch.autograd.grad(out_mps.sum(), weight_mps)  # fails with IndexError

Edit: this seems to be already fixed by #80759.
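For a 1-D input, the weight gradient of `linear` is an outer product, `grad_W[i][j] = grad_out[i] * x[j]`, whereas the batched formula `grad_W = G^T @ X` expects 2-D operands. A backward that assumes a batch dimension would plausibly index dim 1 of a 1-D tensor, which matches the error. This is an inference from the traceback and the title of #80759 ("Fix handling of 1D tensors in linear backward"), not a description of that PR's code. The pure-Python sketch below uses lists instead of tensors and hypothetical helper names; it computes the gradient both ways and shows they agree once the 1-D vectors are promoted to shape `[1, n]`.

```python
def matmul(a, b):
    """Multiply 2-D lists a [n, k] and b [k, m]."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(a):
    """Swap the two axes of a 2-D list."""
    return [list(row) for row in zip(*a)]


def linear_grad_weight_1d(grad_out, x):
    """grad_W[i][j] = grad_out[i] * x[j] for out = W @ x with 1-D x."""
    return [[g * xj for xj in x] for g in grad_out]


def linear_grad_weight_2d(grad_out, x):
    """Batched formula G^T @ X after promoting both vectors to shape [1, n]."""
    return matmul(transpose([grad_out]), [x])


x = [1.0, 1.0]         # input from the test above
grad_out = [1.0, 1.0]  # d(out.sum())/d(out) is all ones
print(linear_grad_weight_1d(grad_out, x))  # [[1.0, 1.0], [1.0, 1.0]]
print(linear_grad_weight_2d(grad_out, x))  # [[1.0, 1.0], [1.0, 1.0]]
```

On builds that predate the fix, passing a batched input such as `x.unsqueeze(0)` should exercise the 2-D path instead; that workaround is untested here.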

ilyarepko (Author)

Yes, I tested it with torch==1.13.0.dev20220705, and now it works. Thanks a lot!

kulinseth pushed a commit to kulinseth/pytorch that referenced this issue Jul 9, 2022
atalman pushed a commit to atalman/pytorch that referenced this issue Jul 22, 2022
atalman added a commit that referenced this issue Jul 25, 2022
* MPS: Fixes (#78930)

Cast integer to float in UnaryOps
Add tensor dtype in key generation
Enable FP16 scalars and use placeholder for alpha tensor in add/sum ops

Fixes #ISSUE_NUMBER

Pull Request resolved: #78930
Approved by: https://github.com/albanD

* MPS: Binary cast fix by proper type promotion and remove spurious copy warning (#79185)

Fixes #78019, #78020
Fixes #79185
Pull Request resolved: #79185
Approved by: https://github.com/albanD, https://github.com/razarmehr

* MPS: add exponential op (#79188)

Add exponential distribution

Fixes #ISSUE_NUMBER

Pull Request resolved: #79188
Approved by: https://github.com/razarmehr, https://github.com/albanD

* [MPS] Delete unused vars from OperationUtils.mm

Pull Request resolved: #79514

Approved by: https://github.com/kulinseth, https://github.com/albanD

* [MPS] Fix getDefaultGenerator and copy_kernel_mps

Returning reference to stack memory is really bad

Pull Request resolved: #79515

Approved by: https://github.com/albanD

* [MPS][BE] Do not use `new/delete[]` in `chainViewOperation`

`std::array` will do just fine

Pull Request resolved: #79516

Approved by: https://github.com/albanD

* [MPS] Support stride of stride

Fixes #79181

Pull Request resolved: #79521

Approved by: https://github.com/kulinseth

* MPS: TopK raise an error if K>16 (#79677)

* Error out in TopK when k>16.
* Add a test case too.

Fixes #78915

Pull Request resolved: #79677
Approved by: https://github.com/albanD

* [MPS]: Add fix for squeezed input axes handling in BCE loss (#79676)

Fixes #79527

Pull Request resolved: #79676
Approved by: https://github.com/razarmehr, https://github.com/albanD

* MPS: Add amax and amin Ops with tests  (#79682)

* Add amax and amin with tests

Fixes #ISSUE_NUMBER

Pull Request resolved: #79682
Approved by: https://github.com/albanD

* [MPS] Fix torch.uint8 support (#80049)

`ScalarType.Byte` should be cast to `MPSDataTypeUInt8`
And support for `torch.int8` as well as test those conversions in `TestMPS.test_to`

Fixes #80006

Pull Request resolved: #80049
Approved by: https://github.com/albanD

* [MPS] Fix binary ops between int32 tensor with int64 scalar (#80220)

For some reason, tensor *op* scalar does not follow the normal binary promotion rules
So cast output tensor to expected type if needed
It seems that one should have cast input tensors to the expected output tensor type, but that does not really work for boolean binary ops, so...
Add output tensor type/shape to cached graph key
Extend `TestMPS.test_add_scalars` to test for this regression

Fixes #79835

Pull Request resolved: #80220
Approved by: https://github.com/albanD

* [MPS] Add equal operator (#80195)

Which is, in essence, a composite of `eq`->`all`->`item`
`native/mps/operators/Equal.cpp` is an almost verbatim copy of `native/cuda/Equal.cpp`

Fix codegen by generating MPSFunctions headers

Pull Request resolved: #80195
Approved by: https://github.com/albanD

* [MPS] add `aten::normal.Tensor_float` `aten::normal.float_Tensor` `aten::normal.Tensor_Tensor` (#80297)

Fixes #ISSUE_NUMBER

Pull Request resolved: #80297
Approved by: https://github.com/albanD, https://github.com/kulinseth

* [MPS] Add flip (#80214)

Fixes #ISSUE_NUMBER

Pull Request resolved: #80214
Approved by: https://github.com/DenisVieriu97, https://github.com/albanD

* [MPS] Add logical ops (#80216)

This PR adds `logical_not`, `logical_and`, `logical_or`, `logical_xor`.
Pull Request resolved: #80216
Approved by: https://github.com/albanD, https://github.com/kulinseth

* [MPS] Add glu (#79866)

Adds mps op for `aten::glu.out`.

Pull Request resolved: #79866
Approved by: https://github.com/kulinseth, https://github.com/albanD

* [MPS] Fix std/var cache issue (#80502)

Use `getTensorsStringKey`, which adds tensor shape info to the key, to prevent cache lookup issues when the shape of the input tensor changes.

Fixes #80499

Pull Request resolved: #80502
Approved by: https://github.com/malfet, https://github.com/kulinseth

* Add scatter support for view operations (#79939)

* Add scatter support for view operations; #78074, #78886, #79672
* Update test_slicing_replace_column to properly test different sizes
* Handle in-place changes for binary ops; add new testcase
* Add new view ops testing scatter; add MPSDebugConfig.h config file for debugging purposes
* Merge gatherViewTensor and scatterViewTensor into a generic function
* Add scatter on demand in scatterViewOperation instead of caching it into a generic graph
* Create separate graphs for scatter and gather;
* Create scatter graph at scatter time

Fixes #ISSUE_NUMBER

Pull Request resolved: #79939
Approved by: https://github.com/razarmehr

* MPS: Fix handling of 1D tensors in linear backward (#80759)

Fixes #79784

Pull Request resolved: #80759
Approved by: https://github.com/ezyang

* [MPS] Move the View ops to a separate file and reduce the number of graphs created (#80491)

This depends on #79939, which has to land first.

Remove the data_ptr from the View Graph key which reduces the number of
graphs created significantly.

Don't wait when copying from MPS to MPS tensors

Pull Request resolved: #80491
Approved by: https://github.com/malfet

* [MPS] Add softplus backward (#79873)

Fixes #ISSUE_NUMBER

Pull Request resolved: #79873
Approved by: https://github.com/malfet

* [MPS] Add argmin (#80828)

This PR

1. adds argmin
2. refactors `reduction_type` in `ReduceOps.mm` with enum.

Co-authored by Kulin Seth <kulinseth@gmail.com>
Pull Request resolved: #80828
Approved by: https://github.com/malfet

* [MPS] Fix LSTM batch_first output transposed (#80597)

The output of LSTM with `batch_first` should be transposed back to batch first format.

Fixes #80306

Pull Request resolved: #80597
Approved by: https://github.com/kulinseth

* [MPS][BE] Introduce MPSUnaryCachedGraph (#81033)

I.e. CachedGraph that has input and output tensors
Also, add `MPSGraphCache::LookUpAs` template, which combines LookUp with
static_cast to target type

Pull Request resolved: #81033
Approved by: https://github.com/kulinseth

* [MPS] Add test consistency from OpInfo based tests from PR 78504 (#79532)

Pull Request resolved: #79532
Approved by: https://github.com/albanD, https://github.com/malfet

* [MPS] Add huber loss (#80163)

Fixes #ISSUE_NUMBER

Pull Request resolved: #80163
Approved by: https://github.com/kulinseth, https://github.com/malfet

* Remove two tests dependent on the MPS serialization checkin.

* Fix lint error (FLAKE8) F401

* Remove the serialization test from test_mps, as MPS serialization is not supported in 1.12.1.

Co-authored-by: Kulin Seth <kulinseth@gmail.com>
Co-authored-by: Nikita Shulga <nikita.shulga@gmail.com>
Co-authored-by: Kulin Seth <kulin_seth@apple.com>
Co-authored-by: Abhishek Pathak <abhipathak97@gmail.com>
Co-authored-by: Nikita Shulga <nshulga@fb.com>
Co-authored-by: qqaatw <qqaatw@gmail.com>
Co-authored-by: Ramin Azarmehr <razarmehr@apple.com>
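Several entries in this commit (adding the tensor dtype to the key in #78930, adding shape information for std/var in #80502, and removing `data_ptr` from the view-graph key in #80491) concern how cached MPS graphs are keyed. The pure-Python sketch below only illustrates that idea and is not the MPS backend's code: `GraphCache`, `make_key`, and the string key format are hypothetical. A key that omits the shape can return a graph built for a different input shape, while a key that includes the data pointer builds a new graph for every distinct allocation.

```python
class GraphCache:
    """Hypothetical cache mapping a string key to a 'compiled graph'."""

    def __init__(self):
        self._graphs = {}
        self.builds = 0  # number of graphs actually created

    def lookup_or_build(self, key, build):
        if key not in self._graphs:
            self._graphs[key] = build()
            self.builds += 1
        return self._graphs[key]


def make_key(op, shape, dtype):
    # Shape and dtype are part of the key; the buffer address is not,
    # so tensors that differ only in storage share one graph.
    return f"{op}:{'x'.join(map(str, shape))}:{dtype}"


cache = GraphCache()
# data_ptr differs per tensor but is deliberately left out of the key.
for shape, data_ptr in [((2, 3), 0x1000), ((2, 3), 0x2000), ((4, 3), 0x3000)]:
    key = make_key("std", shape, "float32")
    graph = cache.lookup_or_build(key, lambda s=shape: f"graph for {s}")
    print(key, "->", graph)

print(cache.builds)  # 2: one graph per distinct (shape, dtype)
```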