MPS bug on torch.transpose and torch.log #89673

Open
hiro4bbh opened this issue Nov 25, 2022 · 2 comments
Labels
module: mps (Related to Apple Metal Performance Shaders framework)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@hiro4bbh

hiro4bbh commented Nov 25, 2022

🐛 Describe the bug

The following code uses torch.transpose and torch.log to compute the loss value.
However, the order in which these two functions are applied seems to produce different results on MPS.

#!/usr/bin/env python3
import argparse
import torch
torch.random.manual_seed(31337)

parser = argparse.ArgumentParser()
parser.add_argument("--bug", help="use buggy path", action="store_true")
parser.add_argument("--device", help="specify device", type=str, default="cpu")
args = parser.parse_args()

class Diff:
  def __init__(self, model, device, use_bug):
    self.model = model
    self.device = device
    self.use_bug = use_bug
    self.lossfn = torch.nn.NLLLoss(reduction="sum")
  def forward(self, x0_indices):
    x_indices = torch.zeros((1+1, x0_indices.shape[0], self.model.length), dtype=torch.long).to(self.device)
    q = torch.zeros((1+1, x0_indices.shape[0], self.model.length, 2)).to(self.device)
    x_indices[0,] = x0_indices
    q[1,] = 0.5*torch.ones(q[1,].shape)
    x_indices[1,] = torch.distributions.Categorical(q[1,]).sample()
    return x_indices, q
  def loss(self, x0_indices):
    x_indices, q = self.forward(x0_indices)
    if self.use_bug:
      # buggy path: transpose first, then log
      pt = torch.log(torch.transpose(self.model(x_indices[1,], 1), -2, -1))
    else:
      # expected path: log first, then transpose
      pt = torch.transpose(torch.log(self.model(x_indices[1,], 1)), -2, -1)
    qt = torch.log(torch.transpose(q[1,], -2, -1))
    return self.lossfn(pt, x_indices[0,])

class MLP(torch.nn.Module):
  def __init__(self, length):
    super().__init__()
    self.length = length
    self.embed_input = torch.nn.Embedding(2, 50, padding_idx=0)
    self.readouts = torch.nn.Linear(50, 2)
    self.softmax = torch.nn.Softmax(dim=-1)
  def forward(self, x_indices, t):
    x = self.embed_input(x_indices)
    x = x.reshape((x.shape[0], self.length, -1))
    return self.softmax(self.readouts(x))

x0_indices = torch.zeros((200, 20))
for i in range(x0_indices.shape[0]):
  for j in range(i%5, x0_indices.shape[1], 5):
    x0_indices[i, j] = 1

model = MLP(x0_indices.shape[1]).to(args.device)
diff = Diff(model, args.device, args.bug)

optim = torch.optim.Adam(diff.model.parameters())
for epoch in range(10000):
  loss = diff.loss(x0_indices)
  print(f"[*] epoch={epoch}: loss={loss.item():.3f}")
  if loss < 0.0:
    print(f"[-] loss is not positive")
    break
  optim.zero_grad()
  loss.backward()
  optim.step()

Here is my result:

% python3 pytorch_mps_bug.py --device cpu      
[*] epoch=0: loss=2608.710
...
[*] epoch=9999: loss=2001.556
% python3 pytorch_mps_bug.py --device cpu --bug
[*] epoch=0: loss=2608.710
...
[*] epoch=9999: loss=2001.556
% python3 pytorch_mps_bug.py --device mps      
[*] epoch=0: loss=3261.913
...
[*] epoch=9999: loss=0.016
% python3 pytorch_mps_bug.py --device mps --bug
[*] epoch=0: loss=105.605
[*] epoch=1: loss=66.850
[*] epoch=2: loss=28.175
[*] epoch=3: loss=-10.516
[-] loss is not positive

At the very least, the loss values from torch.nn.NLLLoss should not be negative, because torch.nn.Softmax is applied before the log, so the log-probabilities are at most zero.
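
For reference, here is a minimal sketch (CPU only, with made-up logits and targets) checking that torch.nn.NLLLoss over log-softmax outputs cannot go negative:

import torch

# Log-probabilities are <= 0, and NLLLoss negates the picked entries,
# so the summed loss is >= 0.
torch.manual_seed(0)
logits = torch.randn(8, 2)
logp = torch.log(torch.softmax(logits, dim=-1))
target = torch.randint(0, 2, (8,))
loss = torch.nn.NLLLoss(reduction="sum")(logp, target)
print(loss.item() >= 0.0)  # expected: True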

In addition, the loss values after 10,000 epochs differ between CPU and MPS even when the buggy path is avoided.
I wonder why this difference happens.
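
A direct check of the two orderings, independent of the training loop, should show whether torch.log and torch.transpose really commute on MPS. This is a minimal sketch, assuming an MPS device is available; on a correct backend both results should be identical:

import torch

# Elementwise log and transpose should commute; compare the two orderings.
x = torch.rand(4, 20, 2, device="mps")
a = torch.log(torch.transpose(x, -2, -1))
b = torch.transpose(torch.log(x), -2, -1)
print(torch.allclose(a.cpu(), b.cpu()))  # expected: True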

Versions

On my Apple M2 MacBook Air:

% curl https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py > ./collect_env.py
% python3 ./collect_env.py 
Collecting environment information...
PyTorch version: 1.13.0a0+git6dc8fba
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.0.1 (arm64)
GCC version: Could not collect
Clang version: 14.0.0 (clang-1400.0.29.202)
CMake version: version 3.24.3
Libc version: N/A

Python version: 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:38:29) [Clang 13.0.1 ] (64-bit runtime)
Python platform: macOS-13.0.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.4
[pip3] torch==1.13.0a0+git7c98e70
[conda] numpy                     1.23.4          py310h5d7c261_1    conda-forge
[conda] torch                     1.13.0a0+git7c98e70           dev_0    <develop>

cc @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev

@zou3519 added the triaged and module: mps labels on Nov 28, 2022
@tvercaut

I am observing a very similar issue (probably the same root cause) with torch.moveaxis and torch.clamp. Here is a minimal example:

import torch
print(f'Running PyTorch version: {torch.__version__}')

dtype = torch.float32

devices = [torch.device("mps"), torch.device("cpu")]

for device in devices:
    print(f"Using device: {device}")

    source = torch.randn(3, 1088, 2048, dtype=dtype, device=device)
    print("source: ", source.shape, source.dtype, source.device,
              source.cpu().numpy().flatten().min(), source.cpu().numpy().flatten().max())

    target = torch.clamp(torch.moveaxis(source, 0, -1), 0.0, 1.0)
    print("clamp(moveaxis(source)): ", target.shape, target.dtype, target.device,
              target.cpu().numpy().flatten().min(), target.cpu().numpy().flatten().max())

    target = torch.moveaxis(torch.clamp(source, 0.0, 1.0), 0, -1)
    print("moveaxis(clamp(source)): ", target.shape, target.dtype, target.device,
              target.cpu().numpy().flatten().min(), target.cpu().numpy().flatten().max())

which leads to completely wrong results on MPS when clamp is applied after moveaxis:

Running PyTorch version: 2.0.0
Using device: mps
source:  torch.Size([3, 1088, 2048]) torch.float32 mps:0 -14.372121 5.0354056
clamp(moveaxis(source)):  torch.Size([1088, 2048, 3]) torch.float32 mps:0 0.0 0.0
moveaxis(clamp(source)):  torch.Size([1088, 2048, 3]) torch.float32 mps:0 0.0 1.0
Using device: cpu
source:  torch.Size([3, 1088, 2048]) torch.float32 cpu -5.206722 5.2658024
clamp(moveaxis(source)):  torch.Size([1088, 2048, 3]) torch.float32 cpu 0.0 1.0
moveaxis(clamp(source)):  torch.Size([1088, 2048, 3]) torch.float32 cpu 0.0 1.0
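
If the problem is triggered by elementwise ops on non-contiguous MPS tensors (an assumption, not a confirmed root cause), forcing a contiguous copy before clamp might serve as a workaround; a minimal sketch:

import torch

# Hypothetical workaround: materialize the permuted view before clamp,
# assuming the issue comes from elementwise ops on non-contiguous MPS tensors.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
source = torch.randn(3, 1088, 2048, device=device)
target = torch.clamp(torch.moveaxis(source, 0, -1).contiguous(), 0.0, 1.0)
print(target.min().item(), target.max().item())  # expected to stay within [0.0, 1.0]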

@hiro4bbh
Author

hiro4bbh commented Apr 4, 2023

As of version 2.0.0, my issue is fixed.
(To be honest, the newer MPS backend seems to have become slower, but this is a small, unoptimized example, so performance is not a concern here.)

% python3
Python 3.10.10 | packaged by conda-forge | (main, Mar 24 2023, 20:12:31) [Clang 14.0.6 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.__version__
'2.0.0'
% python3 pytorch_mps_bug.py --device cpu
[*] epoch=0: loss=2592.284
...
[*] epoch=9999: loss=2001.773
% python3 pytorch_mps_bug.py --device cpu --bug
[*] epoch=0: loss=2592.284
...
[*] epoch=9999: loss=2001.773
% python3 pytorch_mps_bug.py --device mps      
[*] epoch=0: loss=2617.115
...
[*] epoch=9999: loss=2001.842
% python3 pytorch_mps_bug.py --device mps --bug
[*] epoch=0: loss=2617.115
...
[*] epoch=9999: loss=2001.842

There are some other, possibly related, issues, so I won't close this one.
Thank you!
