
New improved Conv3D implementation for MPS and support for ConvTranspose3D #116580

Open: mattiaspaul wants to merge 3 commits into main
Conversation

mattiaspaul

I noticed that the native Conv3D code has severe performance issues on Mac GPUs. This improved implementation replaces the native Conv3D with two operations: an unfold of the depth dimension followed by a Conv2D (details below). It is up to 600% faster depending on the kernel shapes (see the table further down). It also enables ConvTranspose3D, which was not possible before, hence fixing #77818 and allowing architectures such as 3D UNets to work out of the box. It also circumvents the macOS 13 requirement.

The equivalent PyTorch/Python code for the new implementation is given below for reference (for the MPSGraph details see the code):

    @staticmethod
    def forward(ctx, x, weight, shapes):
        B, in_C, in_D, in_H, in_W = x.shape
        out_C, _, k_D, k_H, k_W = weight.shape
        p_D, p_H, p_W = shapes[0].tolist()        # padding
        s_D, s_H, s_W = shapes[1].tolist()        # stride
        d_D, d_H, d_W = shapes[2].tolist()        # dilation
        out_D, out_H, out_W = shapes[3].tolist()  # output shape
        groups, _, _ = shapes[4].tolist()
        # fold the depth taps of the 3D kernel into extra 2D input channels
        weight2d = weight.view(out_C, -1, k_H, k_W)
        # identity filter bank that unfolds the depth dimension via a Conv2D
        unfold_weight = torch.eye(k_D).to(x.device).view(k_D, 1, k_D, 1)
        x2d = F.conv2d(x.view(-1, 1, in_D, in_H * in_W), unfold_weight,
                       padding=(p_D, 0), stride=(s_D, 1), dilation=(d_D, 1))
        x2d_ = (x2d.view(B, in_C, k_D, out_D, in_H, in_W)
                   .permute(0, 3, 1, 2, 4, 5)
                   .reshape(B * out_D, in_C * k_D, in_H, in_W))
        out = (F.conv2d(x2d_, weight2d, padding=(p_H, p_W), stride=(s_H, s_W),
                        dilation=(d_H, d_W), groups=groups)
                 .view(B, out_D, out_C, out_H, out_W)
                 .permute(0, 2, 1, 3, 4))
        ctx.save_for_backward(x2d_, weight2d, unfold_weight, shapes)
        # x and weight themselves are not saved, so keep their shapes for backward
        ctx.x_shape, ctx.weight_shape = x.shape, weight.shape
        return out
    @staticmethod
    def backward(ctx, gradient):
        # `jacobian` is torch.autograd.functional.jacobian
        x2d_, weight2d, unfold_weight, shapes = ctx.saved_tensors
        B, in_C, in_D, in_H, in_W = ctx.x_shape
        out_C, _, k_D, k_H, k_W = ctx.weight_shape
        p_D, p_H, p_W = shapes[0].tolist()        # padding
        s_D, s_H, s_W = shapes[1].tolist()        # stride
        d_D, d_H, d_W = shapes[2].tolist()        # dilation
        out_D, out_H, out_W = shapes[3].tolist()  # output shape
        groups, _, _ = shapes[4].tolist()

        outback = gradient.permute(0, 2, 1, 3, 4).reshape(B * out_D, out_C, out_H, out_W)
        # trick: the gradient of 0.5*||conv2d(x) - g||^2 at x = 0 is -conv2d^T(g),
        # so the negated Jacobian at zero yields the transposed convolution of g
        x2d_grad_ = -jacobian(lambda x: (F.conv2d(x, weight2d, padding=(p_H, p_W),
                                                  dilation=(d_H, d_W), stride=(s_H, s_W),
                                                  groups=groups) - outback)
                              .pow(2).mul(0.5).sum(),
                              torch.zeros(B * out_D, in_C * k_D, in_H, in_W))
        x2d_grad = (x2d_grad_.reshape(B, out_D, in_C, k_D, in_H, in_W)
                             .permute(0, 2, 3, 1, 4, 5)
                             .reshape(B * in_C, k_D, out_D, in_H * in_W))
        # fold the depth dimension back by transposing the unfold convolution
        x_grad_ = -jacobian(lambda x: (F.conv2d(x, unfold_weight, padding=(p_D, 0),
                                                dilation=(d_D, 1), stride=(s_D, 1))
                                       - x2d_grad)
                            .pow(2).mul(0.5).sum(),
                            torch.zeros(B * in_C, 1, in_D, in_H * in_W))
        x_grad = x_grad_.view(B, in_C, in_D, in_H, in_W)
        w_grad = -jacobian(lambda w: (F.conv2d(x2d_, w, padding=(p_H, p_W),
                                               dilation=(d_H, d_W), stride=(s_H, s_W),
                                               groups=groups) - outback)
                           .pow(2).mul(0.5).sum(),
                           torch.zeros(out_C, in_C * k_D // groups, k_H, k_W)
                           ).view(out_C, in_C // groups, k_D, k_H, k_W)

        return x_grad, w_grad, None  # shapes has no grad
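The forward decomposition above can be sanity-checked against F.conv3d on CPU. The sketch below is not part of the PR: `conv3d_via_conv2d` is a name introduced here for illustration, and it assumes groups=1 with a single scalar stride/padding/dilation for brevity.

```python
import torch
import torch.nn.functional as F

def conv3d_via_conv2d(x, weight, stride=1, padding=0, dilation=1):
    """Conv3D as depth-unfold (Conv2D with an identity kernel) followed by
    a single Conv2D, mirroring the decomposition used by this PR."""
    B, in_C, in_D, in_H, in_W = x.shape
    out_C, _, k_D, k_H, k_W = weight.shape
    s, p, d = stride, padding, dilation
    out_D = (in_D + 2 * p - d * (k_D - 1) - 1) // s + 1
    # fold the depth taps of the 3D kernel into extra 2D input channels
    weight2d = weight.reshape(out_C, in_C * k_D, k_H, k_W)
    # identity filter bank: output channel j selects depth offset j
    unfold_weight = torch.eye(k_D, dtype=x.dtype).view(k_D, 1, k_D, 1)
    x2d = F.conv2d(x.reshape(B * in_C, 1, in_D, in_H * in_W), unfold_weight,
                   padding=(p, 0), stride=(s, 1), dilation=(d, 1))
    x2d = (x2d.view(B, in_C, k_D, out_D, in_H, in_W)
              .permute(0, 3, 1, 2, 4, 5)
              .reshape(B * out_D, in_C * k_D, in_H, in_W))
    out = F.conv2d(x2d, weight2d, padding=(p, p), stride=(s, s),
                   dilation=(d, d))
    out_H, out_W = out.shape[-2:]
    return out.view(B, out_D, out_C, out_H, out_W).permute(0, 2, 1, 3, 4)

x = torch.randn(2, 3, 6, 8, 8)
w = torch.randn(4, 3, 3, 3, 3)
assert torch.allclose(conv3d_via_conv2d(x, w, padding=1),
                      F.conv3d(x, w, padding=1), atol=1e-5)
```

The two results agree to floating-point tolerance because the unfold step reproduces exactly the zero-padded, strided, dilated depth patches that a direct Conv3D would read.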
| Shapes | new GPU (fwd) | new GPU (fwd+bwd) | CPU (fwd+bwd) | Old-Master GPU (fwd) | Old-Master GPU (fwd+bwd) |
| --- | --- | --- | --- | --- | --- |
| ch=64, kernel=3, b=32 | 5831 | 4276 | 210 | 1515 | 1430 |
| ch=128, kernel=3, b=32 | 8723 | 6371 | 389 | 1511 | 1545 |
| ch=256, kernel=3, b=16 | 10045 | 6701 | 1126 | 1511 | 1579 |
| ch=512, kernel=3, b=8 | 10243 | 7875 | 1718 | 1496 | 1575 |
| ch=64, kernel=1, b=128 | 2451 | 1904 | 264 | 1865 | 203 |
| ch=128, kernel=1, b=128 | 3858 | 3084 | 541 | 2003 | 652 |
| ch=256, kernel=1, b=128 | 5210 | 4512 | 736 | 2101 | 1368 |
| ch=512, kernel=1, b=64 | 6220 | 5084 | 1204 | 2149 | 1707 |

new implementation of Conv3D that addresses severe performance issues of native MPSGraph code and adds support for ConvTranspose3d
@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category labels Dec 31, 2023

pytorch-bot bot commented Dec 31, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/116580

Note: Links to docs will display an error until the docs builds have been completed.

❌ 5 New Failures

As of commit 7e2ec42 with merge base a919742 (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@mattiaspaul
Author

Thanks to @LucasSte for providing the fixes for the original pull request that enabled a great start to working with Conv3D on Mac GPUs. I'm cc'ing @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr for potential reviews.
PS: the experiments above were performed on an M2 Max with a 30-core GPU, which has a theoretical throughput of 11 TFLOPs; hence the new forward Conv3D performance for large channel sizes comes close to peak.
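To relate the benchmark numbers to the 11 TFLOPs figure: a Conv3D forward pass costs roughly one multiply-add (two FLOPs) per output element per kernel tap. A minimal sketch follows; the 16x16x16 output volume is an assumption for illustration, since the benchmark's spatial sizes are not stated in this thread.

```python
def conv3d_gflops(batch, in_c, out_c, out_shape, kernel, groups=1):
    """Approximate Conv3D forward cost: 2 FLOPs (multiply + add) per
    output element per kernel tap."""
    out_d, out_h, out_w = out_shape
    k_d, k_h, k_w = kernel
    flops = (2 * batch * out_c * out_d * out_h * out_w
             * (in_c // groups) * k_d * k_h * k_w)
    return flops / 1e9

# e.g. ch=512, kernel=3, batch=8 on an assumed 16^3 output volume
print(conv3d_gflops(8, 512, 512, (16, 16, 16), (3, 3, 3)))  # ≈ 463.9 GFLOPs
```

Dividing such a count by the measured runtime gives achieved GFLOP/s, which is how throughput figures like those in the table are typically derived.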

@cpuhrsch cpuhrsch requested a review from albanD January 3, 2024 08:54
@cpuhrsch cpuhrsch added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jan 3, 2024
@QianMuXiao

Will the update adding ConvTranspose3D functionality to MPS be merged soon?

Contributor

@malfet malfet left a comment

Thank you for your work. The results indeed look impressive, but please fix the lint issues and add a unit test to test_mps.py.

aten/src/ATen/native/mps/operations/Convolution.mm (outdated review thread; resolved)
auto output_t =
mps_convolution_transpose_forward(input_t, weight_t, padding, output_padding, stride, dilation, groups);
return output_t;
if(is3DConv){
Contributor

Suggested change:
if(is3DConv){
if (is3DConv) {

return nil;
}
MPSGraphTensor* outputTensor = inputTensor;
outputTensor = [graph transposeTensor:outputTensor permutation:permuteOrder name:nil];
Contributor

Shouldn't this function return something?

static MPSGraphTensor* reshapePermuteReshape(MPSGraph* mpsGraph, MPSGraphTensor* tensor__, MPSShape* reshape1, MPSShape* permutation, MPSShape* reshape2) {
MPSGraphTensor *tensor_ = [mpsGraph reshapeTensor:tensor__ withShape:reshape1 name:nil];
MPSGraphTensor *tensor;
if (@available(macOS 13.0, *)) {
Contributor

Don't use @available (it returns false when executed from a Python runtime built for an older macOS); use is_macos_13_or_newer() instead.

aten/src/ATen/native/mps/operations/Convolution.mm (two further outdated review threads; resolved)
@tasansal

Hi guys, this is great. Any ETA on merging this? Thanks!

@blasscoc

I compiled and ran this on my Mac; it's about 8-10x faster than the current PyTorch Conv3D, so thanks for your work implementing this.

@Datamance

Interested to see how this compares to manual convolution with something like taichi.

@francescopisu

I thought it would be useful for other folks to have a quick reference on how to build PyTorch from source on an Apple Silicon Mac at this PR's state.

@francescopisu

francescopisu commented May 9, 2024

> I compiled and ran this on my mac, it's like x8-x10 faster than the current pytorch application using Conv3D, so thanks for your work implementing this.

Has anyone tried backward() on a network with ConvTranspose3d? I'm getting this error at loss.backward(); no errors on CPU.
I don't understand it very well, but I see a mismatch in dimensions and channels: 64 vs 3, and 1024, which is 32*32.

(mpsFileLoc): /AppleInternal/Library/BuildRoots/0032d1ee-80fd-11ee-8227-6aecfccc70fe/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:303:0: error: 'mps.reshape' op the result shape is not compatible with the input shape
(mpsFileLoc): /AppleInternal/Library/BuildRoots/0032d1ee-80fd-11ee-8227-6aecfccc70fe/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:303:0: note: see current operation: %5 = "mps.reshape"(%arg1, %4) : (tensor<8x64x32x32x32xf32>, tensor<4xsi32>) -> tensor<8x3x32x1024xf32>
/AppleInternal/Library/BuildRoots/0032d1ee-80fd-11ee-8227-6aecfccc70fe/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphComputePackage.mm:180: failed assertion `expected a valid model URL'

P.s. the forward pass is OK
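The 'mps.reshape' failure above is an element-count mismatch: a tensor of shape (8, 64, 32, 32, 32) holds 8*64*32^3 = 16,777,216 elements, while (8, 3, 32, 1024) holds only 786,432, so the graph is trying to view the gradient with the wrong channel count. A CPU-side sketch (shapes taken from the log; purely illustrative) reproduces the same class of error:

```python
import torch

src = torch.empty(8, 64, 32, 32, 32)   # 16,777,216 elements
try:
    src.reshape(8, 3, 32, 1024)        # would require only 786,432 elements
except RuntimeError as e:
    print("reshape failed:", e)
```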

@QianMuXiao

> (quoting @francescopisu's 'mps.reshape' error report above)

@francescopisu Yeah, I've got this same 'mps.reshape' error on my Mac (after your help building PyTorch) when training a 3D-CycleGAN model; the code runs perfectly on CUDA or CPU but hits this error on MPS.

@jbrown81

Any progress getting this merged?

@timoyang

timoyang commented Jul 22, 2024

Any progress getting this merged? @mattiaspaul

@frxderic

frxderic commented Aug 8, 2024

Are there any updates on when this is getting merged? I tried to compile it from source following @francescopisu's guide but repeatedly ran into this error:

pytorch/torch/csrc/utils/tensor_numpy.cpp:404:34: error: no member named 'elsize' in '_PyArray_Descr'
    dtype_size_in_bytes = descr->elsize;
                          ~~~~~ ^
1 error generated.
[6959/6962] Building CXX object functorch/CMakeFiles/functorch.dir/csrc/init_dim_only.cpp.o
ninja: build stopped: subcommand failed.

Any help is greatly appreciated!

@shuuul

shuuul commented Aug 11, 2024

I read this issue: MIC-DKFZ/nnUNet#2435. There is a fork: https://github.com/LalithShiyam/pytorch-mps. The 'elsize' problem comes from the NumPy version; you can refer to mattiaspaul@ffda73c.

Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Oct 10, 2024
@bghira

bghira commented Oct 10, 2024

no stale

@malfet
Contributor

malfet commented Oct 10, 2024

Let me try to rebase it today and see if it still works...

@malfet
Contributor

malfet commented Oct 11, 2024

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict pull/116580/head returned non-zero exit code 1

Rebasing (1/26)
Auto-merging aten/src/ATen/native/mps/operations/Convolution.mm
CONFLICT (content): Merge conflict in aten/src/ATen/native/mps/operations/Convolution.mm
error: could not apply 3c61c52569... Add files via upload
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
hint: Disable this message with "git config advice.mergeConflict false"
Could not apply 3c61c52569... Add files via upload

Raised by https://github.com/pytorch/pytorch/actions/runs/11284391883

@malfet
Contributor

malfet commented Oct 11, 2024

@bghira do you want to take over this PR and try rebasing it against latest trunk?

Labels
ciflow/mps Run MPS tests (subset of trunk) open source release notes: mps Release notes category Stale triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

torch.nn.Conv3D on MPS backend