MPS convolution is sometimes returning NaNs for valid inputs. #84138

Open
albanD opened this issue Aug 26, 2022 · 9 comments
Labels
module: mps (Related to Apple Metal Performance Shaders framework), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@albanD
Collaborator

albanD commented Aug 26, 2022

Splitting #81185 into two. This one focuses on the MPS side of the problem with the following repro:

# bug_demo.py

import torch

n_trials = 100
for ii in range(n_trials):
    a = torch.randn(1024, device='mps')
    b = torch.randn(499, device='mps')
    c = torch.nn.functional.conv1d(a.view(1, 1, -1), b.view(1, 1, -1))
    if torch.isnan(torch.sum(c)):
        print(f'mps: trial {ii}, nan elements {torch.isnan(c.squeeze()).nonzero().view(-1).cpu().numpy()}')

cc @kulinseth

@albanD added the module: mps label Aug 26, 2022
@leovinus2001

While I hesitate to ask, may I have a moment of your time for a simple sanity check, please?

When working on the CPU and MPS test case for repeated torch::mm() in issue #81185, #81185 (comment), I noticed yesterday that a C++ printf() of a torch::tensor residing on the Intel Mac MPS GPU gave weird numbers like 1e-34,0.,0. instead of the "proper" data like 1,2,3,4,5. When the data was copied back to the CPU, the tensor data was correct. That behavior puzzled me. Therefore, on a hunch, could someone humor me and try the Python code snippet below on an Arm Mac with torch and MPS enabled?

For reference, on a Linux machine with CUDA, creating the tensor on the CPU, copying it to the GPU, and copying it back to the CPU, you see e.g.

Torch version: 1.11.0+cu102
on cpu 1, ptr = 0x51d7600, device= cpu, type= torch.FloatTensor, data= tensor([1., 2., 3., 4., 5.])
on gpu 2, ptr = 0x7fbfa6c00000, device= cuda:0, type= torch.cuda.FloatTensor, data= tensor([1., 2., 3., 4., 5.], device='cuda:0')
on cpu 3, ptr = 0x5914240, device= cpu, type= torch.FloatTensor, data= tensor([1., 2., 3., 4., 5.])

and as you can see, the "on gpu 2" line has a pointer into the GPU memory and prints correct data.

Yesterday I wondered for a moment whether this behavior is different with Mac+MPS. I am probably wrong, but can someone confirm that we get output similar to the CPU+CUDA output above on an Arm Mac with MPS, and NOT a broken line like

on gpu 2, ptr = 0x7fbfa6c00000, device= mps, type= torch.cuda.FloatTensor, data= tensor([1.e-34,0.111111, 0., 0., 0., 0.], device='mps')

Rationale: if we DO see something like [1.e-34, 0.111111, 0., 0., 0., 0.], then that might explain issue #81185 as well as the occasional NaNs seen here. The code checking for NaNs would be looking at different memory than expected (as the CUDA vs MPS behavior would differ). But if that were the case, I think a lot of other things would go wrong as well. If we see the correct output [1., 2., 3., 4., 5.] from MPS, then that is a relief and I will keep checking why the C++ printf() via data_ptr() gave a different result for "on GPU" vs "on CPU". Sorry for the noise, thank you.

import torch

cpu_device = torch.device("cpu")
#gpu_device = torch.device("cuda:0")
gpu_device = torch.device("mps")

def tensorPrint(prefix_, data_):
    print(f"{prefix_}, ptr = {hex(data_.data_ptr())}, device= {data_.device}, type= {data_.type()}, data= {data_}")

if __name__ == '__main__':
    print("Torch version: ", torch.__version__)
    x = torch.tensor([1., 2., 3., 4., 5.], dtype=torch.float32, device=cpu_device)
    tensorPrint("on cpu 1", x)

    x2 = x.to(gpu_device)
    tensorPrint("on gpu 2", x2)

    x3 = x2.to(cpu_device)
    tensorPrint("on cpu 3", x3)

@bdhirsh added the triaged label Aug 29, 2022
@Vargol

Vargol commented Aug 30, 2022

(diffuse) xxxx@Vargols-Mac-mini diffuse % python test.py 
Torch version:  1.13.0.dev20220828
on cpu 1, ptr = 0x1597299c0, device= cpu, type= torch.FloatTensor, data= tensor([1., 2., 3., 4., 5.])
/Volumes/Sabrent Media/Documents/Source/Python/diffuse/lib/python3.10/site-packages/torch/_tensor_str.py:114: UserWarning: The operator 'aten::masked_select' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:11.)
  nonzero_finite_vals = torch.masked_select(
on gpu 2, ptr = 0x12b267d60, device= mps:0, type= torch.mps.FloatTensor, data= tensor([1., 2., 3., 4., 5.], device='mps:0')
on cpu 3, ptr = 0x159643f00, device= cpu, type= torch.FloatTensor, data= tensor([1., 2., 3., 4., 5.])

@leovinus2001

(diffuse) xxxx@Vargols-Mac-mini diffuse % python test.py
Torch version: 1.13.0.dev20220828
on cpu 1, ptr = 0x1597299c0, device= cpu, type= torch.FloatTensor, data= tensor([1., 2., 3., 4., 5.])
/Volumes/Sabrent Media/Documents/Source/Python/diffuse/lib/python3.10/site-packages/torch/_tensor_str.py:114: UserWarning: The operator 'aten::masked_select' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:11.)
nonzero_finite_vals = torch.masked_select(
on gpu 2, ptr = 0x12b267d60, device= mps:0, type= torch.mps.FloatTensor, data= tensor([1., 2., 3., 4., 5.], device='mps:0')
on cpu 3, ptr = 0x159643f00, device= cpu, type= torch.FloatTensor, data= tensor([1., 2., 3., 4., 5.])

Thank you! This tells me a few things:

  • Basically, all is well with print and access to MPS-located tensors, nice. I was too paranoid ;)
  • The field "type= torch.mps.FloatTensor" tells me that the MPS-based data has more internal structure than the CPU-side torch.FloatTensor. That explains, on the C++ side, why (float*)(tensor.data_ptr())[index] works for CPU data but not for GPU data. It also explains that the odd print results in [Mac M1] torch.mm sometimes produces incorrect results #81185 (comment) are not related to the question "Why NaNs here?".
  • I briefly looked for the Metal/MPS class/struct behind torch.mps.FloatTensor so the C++ side can take this into account, but no luck. Examples are welcome while the checking continues.

@leovinus2001

leovinus2001 commented Aug 30, 2022

TL;DR: On the same Intel iMac machine, the same MPS matrix multiply via the latest Torch C++ does not generate NaNs, but the Python test case on the same Torch DOES generate occasional NaNs.

On an Intel iMac with Torch 1.13.0a0+gitf4f54c7 (git from yesterday, locally built) and macOS Monterey 12.5, I see

trial 14, elements [225]
trial 23, elements [282]
trial 60, elements [70]
trial 90, elements [8]
etc

When a NaN is found, I check the first few elements of the input and output matrices a, b, c, and those look normal to me.

PS: on first invocation of "python3 bug.py" I see the warning

/bug_demo.py:9: UserWarning: The operator 'aten::nonzero' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at aten/src/ATen/mps/MPSFallback.mm:11.)

So where do we go from here?

@leovinus2001

And the NaN is indeed also generated for the convolution case in the first post of this thread, i.e. both torch.nn.functional.conv1d and torch.mm exhibit the NaN behavior on this setup and machine.

@leovinus2001

leovinus2001 commented Aug 31, 2022

I am starting to wonder whether the NaNs are related to torch.mm or torch.conv at all. The code snippet below, adapted from the earlier Python test case, generates NaNs as well, which seems to be the reason for the NaNs in the output matrix C after torch.mm: the NaN was in the input already ;) The code regularly generates NaNs on an MPS device (on an Intel iMac, 1.13.0a0+gitf4f54c7) but not on the CPU. It seems to be a randn()-specific issue, as torch.zeros(), ones(), and full((499, 526), 3.14, device='mps') work fine.

Even better, a functionally equivalent test case in C++ generating data on MPS via torch::rand() does not generate any NaN, but the Python snippet below does.

Reminds me of another Torch MPS randn issue #84288

I am happy to follow up further and put breakpoints and debug code in the MPS version of randn() or the Python rand() bindings, but I could use a few pointers and filenames to get started, please :)

import torch
    
for ii in range(10000):
    fa = torch.randn(499, 526, device='mps') # NaNs with randn()
    #fa = torch.ones(499, 526, device='mps') # No NaN with zeros(), ones(), full( (499, 526), 3.14, device='mps')

    if torch.isnan(torch.sum(fa)):
        myindex = torch.isnan(fa.cpu().squeeze()).nonzero().view(-1)        
        print(f'trial A {ii}, elements {myindex}')
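
In case it helps others hitting this in the meantime, here is a hedged workaround sketch (my own suggestion, not verified beyond the reasoning above): if the problem is confined to the MPS randn() kernel, sampling on the CPU and copying to the device should avoid the NaNs, since the copy round trip itself looked fine in the earlier test.

import torch

# Hedged workaround sketch: sample on the CPU, then move the tensor to MPS.
# This assumes only the MPS-side randn() kernel is affected.
for ii in range(10000):
    fa = torch.randn(499, 526, device='cpu').to('mps')
    if torch.isnan(fa).any():
        print(f'trial {ii}: unexpected NaN even with CPU-side randn')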

@Vargol

Vargol commented Sep 3, 2022

Did you find the code?

I found
https://github.com/pytorch/pytorch/blob/ad44670fa1ce2dad7e2cdc3f90d27668e88e9548/aten/src/ATen/native/TensorFactories.cpp
but that just calls
auto result = at::empty(size, options);
return result.normal_(0, 1, generator);

I haven't had time to track down Tensor::normal_ yet

I also found a partial MPS Generator implementation at
https://github.com/pytorch/pytorch/blob/ad44670fa1ce2dad7e2cdc3f90d27668e88e9548/aten/src/ATen/native/mps/OperationUtils.mm

which I think is a thin inheritance wrapper around the c10 generator (it's been a couple of decades since I wrote any C++)

https://github.com/pytorch/pytorch/blob/master/c10/core/GeneratorImpl.cpp

which just calls /dev/urandom
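
If randn() really does route through at::empty() + normal_() as traced above, a hedged way to test that hypothesis (my own sketch, not something run in this thread) would be to call normal_() directly on an MPS tensor and see whether the same occasional NaNs appear:

import torch

# Hedged check: if the empty() + normal_() path is the culprit, this loop
# should show NaNs at roughly the same rate as torch.randn(..., device='mps').
for ii in range(10000):
    fa = torch.empty(499, 526, device='mps').normal_(0, 1)
    if torch.isnan(fa).any():
        print(f'trial {ii}: NaN via normal_()')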

@leovinus2001

leovinus2001 commented Sep 3, 2022

Did you find the code?

Sorry, no. I was hoping for some pointers into at::mps:: and randn, like the ones given here for a similar issue:
#84229 (comment)

PS: Thanks for your link to the MPS Generator.

@xght

xght commented Oct 9, 2022

I tried to rewrite random_mps_impl of https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/mps/operations/Distributions.mm and I generated NaN with this simple Swift code:

import Foundation
import MetalPerformanceShadersGraph

var gDevice = MTLCreateSystemDefaultDevice()!
var gCommandQueue = gDevice.makeCommandQueue()!

let randomSize: Int = 5000000

let desc = MPSGraphRandomOpDescriptor(distribution: MPSGraphRandomDistribution.normal, dataType: MPSDataType.float32)
desc?.mean = 0.0
desc?.standardDeviation = 1.0
let graph = MPSGraph()
let randomTensor = graph.randomTensor(withShape: [randomSize as NSNumber], descriptor: desc!, seed: Int.random(in:0...256*256*256), name: nil)

let sum = graph.reductionSum(with: randomTensor, axes: [0], name: nil)

var feeds: [MPSGraphTensor:MPSGraphTensorData] = [:]

let fetch = graph.run(with: gCommandQueue,
                      feeds: feeds,
                      targetTensors: [sum],
                      targetOperations: [])

let output = fetch[sum]!
var sumResult: Float32 = 0.0
output.mpsndarray().readBytes(&sumResult, strideBytes: nil)
print(sumResult)

The frequency of NaN seems to be the same as with the PyTorch code

torch.sum(torch.randn((5000000), device='mps'))

This issue comes from MPS.
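
To make the frequency comparison concrete, here is a hedged Python sketch for counting how often the sum comes back as NaN (the trial count is my own choice, not from the measurement above):

import torch

# Hedged sketch: count NaN sums over repeated large randn() draws on MPS,
# mirroring the one-liner above.
n_trials = 1000
nan_count = 0
for ii in range(n_trials):
    if torch.isnan(torch.sum(torch.randn((5000000), device='mps'))):
        nan_count += 1
print(f'NaN in {nan_count}/{n_trials} trials')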
