linear model cannot calculate grads correctly on device MPS #123259

Open
mehmetozsoy1 opened this issue Apr 3, 2024 · 3 comments
Labels
module: mps (Related to Apple Metal Performance Shaders framework) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


mehmetozsoy1 commented Apr 3, 2024

🐛 Describe the bug

A linear model does not compute gradients correctly on the MPS device.

```python
import torch


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(1, 1)

    def forward(self, x):
        x = self.linear1(x)
        return x


# Set device and dtype
device = torch.device("mps:0")
dtype = torch.float32

# set models
model = Model()
criterion = torch.nn.L1Loss()
model.to(device)

# set inputs (non-trainable)
data = torch.tensor([[0.1]], dtype=dtype, device=device)
target = torch.tensor([[0.1]], dtype=dtype, device=device)

state_dict = model.state_dict()

# take trainable params
weight = state_dict["linear1.weight"]
bias = state_dict["linear1.bias"]

# calculate output of the model manually
calculated_linear_output = weight * data + bias

# find output of the torch model
output = model(data)

# check if the two outputs are the same (passes on both MPS and CPU)
assert output == calculated_linear_output

# calculate loss and gradients
loss = criterion(output, target)
loss.backward()

# calculate expected weight and bias gradients manually
expected_gradient_b = torch.sign(output - target)
expected_gradient_w = torch.sign(output - target) * data
grad_list = [expected_gradient_w, expected_gradient_b]

# check if all calculated gradients are the same (fails on MPS, passes on CPU)
for param, my_param in zip(model.parameters(), grad_list):
    assert param.grad == my_param
```

The code above runs without problems when the device is set to CPU. When the device is set to MPS, however, the gradient check fails.
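For reference, the expected values in the gradient check follow from the chain rule. With a single element, `L1Loss` with its default `reduction='mean'` is just the absolute error, and PyTorch uses $\operatorname{sign}(y - t)$ as its derivative (zero when $y = t$). Writing $w$ for `linear1.weight`, $b$ for `linear1.bias`, $x$ for `data` and $t$ for `target`:

$$
\begin{aligned}
y &= w x + b, \qquad L = |y - t|,\\
\frac{\partial L}{\partial b} &= \operatorname{sign}(y - t),\\
\frac{\partial L}{\partial w} &= \operatorname{sign}(y - t)\, x.
\end{aligned}
$$

These correspond to `expected_gradient_b` and `expected_gradient_w` in the script. Because multiplying by $\pm 1$ introduces no rounding, exact `==` comparisons seem reasonable for this one-element case.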

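Here is another snippet that computes the same linear function with plain tensor operations (`a @ x + b`) instead of `torch.nn.Linear`, using different input values. It runs without problems on both CPU and MPS.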

```python
import torch

device = torch.device("mps:0")
dtype = torch.float32


def linear_1d(a, x, b):
    return a @ x + b


criterion = torch.nn.L1Loss()

data = torch.tensor([[2.0]], device=device, dtype=dtype)
weight = torch.tensor([[1.0]], device=device, dtype=dtype, requires_grad=True)
bias = torch.tensor([[0.7]], device=device, dtype=dtype, requires_grad=True)
target = torch.tensor([[0.3]], device=device, dtype=dtype)

output = linear_1d(weight, data, bias)
loss = criterion(output, target)
loss.backward()

expected_weight_gradients = torch.sign(output - target) * data
expected_bias_gradients = torch.sign(output - target)

assert weight.grad == expected_weight_gradients
assert bias.grad == expected_bias_gradients
```
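The failing `nn.Linear` case can also be checked against the CPU backend instead of against hand-derived values. A minimal sketch, assuming a PyTorch version that provides `torch.backends.mps` (1.12 or later): it reruns the same one-sample `nn.Linear` + `L1Loss` step from an identical seed on the CPU and on a target device, then compares the gradients with `torch.allclose`. `linear_grads` and `grads_match` are hypothetical helper names, and the `atol` default is an arbitrary choice.

```python
import torch


def linear_grads(device: torch.device, seed: int = 0) -> list:
    """One forward/backward pass of a 1x1 nn.Linear with L1Loss on `device`.

    Returns the parameter gradients (weight, bias) copied back to the CPU.
    """
    torch.manual_seed(seed)  # weights are initialised on the CPU, so every device starts identically
    model = torch.nn.Linear(1, 1).to(device)
    criterion = torch.nn.L1Loss()
    data = torch.tensor([[0.1]], device=device)
    target = torch.tensor([[0.1]], device=device)
    loss = criterion(model(data), target)
    loss.backward()
    return [p.grad.detach().cpu() for p in model.parameters()]


def grads_match(device: torch.device, atol: float = 1e-6) -> bool:
    """Compare gradients computed on `device` against a CPU reference run."""
    reference = linear_grads(torch.device("cpu"))
    candidate = linear_grads(device)
    return all(torch.allclose(r, c, atol=atol) for r, c in zip(reference, candidate))


if __name__ == "__main__":
    # Fall back to the CPU (a trivial self-comparison) when MPS is not available.
    use_mps = torch.backends.mps.is_available()
    device = torch.device("mps") if use_mps else torch.device("cpu")
    print(device, grads_match(device))
```

On an affected Mac, `grads_match(torch.device("mps"))` returning `False` would reproduce the report; on other machines the script only compares the CPU with itself.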

Versions

PyTorch version: 2.2.2
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.4.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: Could not collect
Libc version: N/A

Python version: 3.11.5 (main, Sep 11 2023, 08:31:25) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.4.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.2.2
[pip3] torchaudio==2.2.2
[pip3] torchvision==0.17.2
[conda] numpy 1.26.4 py311he598dae_0
[conda] numpy-base 1.26.4 py311hfbfe69c_0
[conda] pytorch 2.2.2 py3.11_0 pytorch
[conda] torchaudio 2.2.2 py311_cpu pytorch
[conda] torchvision 0.17.2 py311_cpu pytorch

cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

bdhirsh added the triaged and module: mps labels on Apr 3, 2024

siemens999 commented Apr 8, 2024

It appears that the issue has been resolved in the nightly build. I encountered the same problem on a MacBook Air with an M2 chip. The impact of this bug is significant because it affects model convergence during training: no warnings or errors are raised, yet the loss does not converge to the expected value. For instance, when I fine-tuned a pretrained OCR model on the MPS backend, it slowly converged to a loss ten times higher than when training on the CPU or on CUDA GPUs, resulting in a Character Error Rate (CER) roughly ten times worse. Even with this degradation, the fine-tuned model still outperformed the original pretrained model. Detecting such errors is exceedingly difficult without comparing the results of identical code across multiple machines. It's reassuring to see that the issue has been addressed in the nightly build.
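A silent convergence problem like this can be caught with a short parity run on a single machine: train the same tiny model for a few steps on the CPU and on MPS from the same seed and compare the loss curves. This is a sketch under assumptions: the synthetic data, `lr=0.05`, 20 steps and the tolerances are arbitrary, and `loss_curve` / `curves_agree` are hypothetical names. Small numerical differences between backends are normal, so the tolerances may need tuning for real models.

```python
import torch


def loss_curve(device: torch.device, steps: int = 20, seed: int = 0) -> list:
    """Train a tiny linear model with SGD and L1Loss on `device`; record the loss per step."""
    torch.manual_seed(seed)
    # Data and initial weights are generated on the CPU so every device starts from the same values.
    x = torch.randn(64, 4)
    y = x @ torch.tensor([[1.0], [-2.0], [0.5], [3.0]]) + 0.1
    model = torch.nn.Linear(4, 1)
    model.to(device)  # move parameters before building the optimizer
    x, y = x.to(device), y.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
    criterion = torch.nn.L1Loss()
    losses = []
    for _ in range(steps):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


def curves_agree(device: torch.device, rtol: float = 1e-3, atol: float = 1e-4) -> bool:
    """True if the loss curve on `device` matches the CPU curve within tolerance."""
    reference = torch.tensor(loss_curve(torch.device("cpu")))
    candidate = torch.tensor(loss_curve(device))
    return torch.allclose(reference, candidate, rtol=rtol, atol=atol)


if __name__ == "__main__":
    if torch.backends.mps.is_available():
        print("MPS matches CPU:", curves_agree(torch.device("mps")))
```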

jhavukainen (Collaborator) commented

Hi @siemens999 and @mehmetozsoy1! I didn't look too deeply into this particular case, as I didn't notice the issue was open at the time. However, since you've noted that it works in the nightly build, I assume the root cause was fixed in #123234.


siemens999 commented Apr 8, 2024

After testing on different Mac models, I've discovered that the issue I mentioned earlier and the bug reported here are not the same. I won't be submitting a new bug report since both issues have been addressed in the nightly release.

It seems sensible to prioritize getting this fix into a stable release sooner rather than later. It might also be worth alerting Mac users currently on PyTorch 2.2.2 to upgrade, as they may not realize there is an issue despite its significant impact.
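One way to alert affected users is a startup check. A hedged sketch: it flags only PyTorch 2.2.2 with MPS available, because that is the only affected version confirmed in this thread; which stable release first ships the fix is not stated here. `mps_grad_bug_suspected` and `warn_if_affected` are hypothetical helpers, not part of PyTorch.

```python
import warnings


def mps_grad_bug_suspected(torch_version: str, mps_available: bool) -> bool:
    """True for the configuration reported in issue #123259: PyTorch 2.2.2 with MPS.

    Hypothetical helper; other versions are deliberately not flagged because
    the thread does not confirm them as affected or fixed.
    """
    release = torch_version.split("+")[0]  # drop local build suffixes such as "+cpu"
    return mps_available and release == "2.2.2"


def warn_if_affected() -> None:
    import torch  # imported lazily so the pure check above does not need torch

    if mps_grad_bug_suspected(torch.__version__, torch.backends.mps.is_available()):
        warnings.warn(
            "PyTorch 2.2.2 on MPS was reported to compute incorrect gradients "
            "(issue #123259); consider upgrading and re-validating results on the CPU.",
            RuntimeWarning,
            stacklevel=2,
        )
```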

If necessary, I can create a notebook to demonstrate the behavior by tomorrow.
