From 0317626df59b4cf3229712affb6a8bbb666a50e0 Mon Sep 17 00:00:00 2001
From: igm503
Date: Wed, 20 Sep 2023 02:18:24 +0000
Subject: [PATCH] [MPS] adding weight_norm_interface support for mps (#108008)

Fixes #104513

Adds support for aten::_weight_norm_interface to the MPS backend. Also adds
a consistency test for the output and the grad.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108008
Approved by: https://github.com/kulinseth
---
 .../ATen/native/mps/operations/WeightNorm.mm | 192 ++++++++++++++++++
 aten/src/ATen/native/native_functions.yaml   |   2 +
 test/test_mps.py                             |  80 ++++++++
 3 files changed, 274 insertions(+)
 create mode 100644 aten/src/ATen/native/mps/operations/WeightNorm.mm

diff --git a/aten/src/ATen/native/mps/operations/WeightNorm.mm b/aten/src/ATen/native/mps/operations/WeightNorm.mm
new file mode 100644
index 0000000000000..7ca63533ed19b
--- /dev/null
+++ b/aten/src/ATen/native/mps/operations/WeightNorm.mm
@@ -0,0 +1,192 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include <ATen/native/mps/OperationUtils.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#else
+#include <ATen/ops/_weight_norm_interface_backward_native.h>
+#include <ATen/ops/_weight_norm_interface_native.h>
+#endif
+
+namespace at::native {
+
+using namespace at::native::mps;
+
+// Derive from MPSCachedGraph
+struct WeightNormCachedGraph : public MPSCachedGraph {
+  WeightNormCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
+  MPSGraphTensor* g_ = nil;
+  MPSGraphTensor* v_ = nil;
+  MPSGraphTensor* norms_ = nil;
+  MPSGraphTensor* w_ = nil;
+};
+
+struct WeightNormBackwardCachedGraph : public MPSCachedGraph {
+  WeightNormBackwardCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
+  MPSGraphTensor* grad_w = nil;
+  MPSGraphTensor* saved_v = nil;
+  MPSGraphTensor* saved_g = nil;
+  MPSGraphTensor* saved_norms = nil;
+  MPSGraphTensor* grad_g = nil;
+  MPSGraphTensor* grad_v = nil;
+};
+
+std::tuple<Tensor, Tensor> weight_norm_mps(const Tensor& v, const Tensor& g, int64_t dim) {
+  TORCH_CHECK(dim == 0 || dim == v.dim() - 1, "fused kernels can only be applied for first or last dim");
+
+  MPSStream* mpsStream = getCurrentMPSStream();
+
+  auto w = at::empty_like(v, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  auto norms = at::empty_like(g, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+
+  string key = "weight_norm_mps_" + std::to_string(dim) + getTensorsStringKey({v, g});
+
+  NSMutableArray<NSNumber*>* reduction_dims = [NSMutableArray array];
+  for (int i = 0; i < v.dim(); ++i) {
+    if (i != dim) {
+      [reduction_dims addObject:@(i)];
+    }
+  }
+
+  @autoreleasepool {
+    auto cachedGraph = LookUpOrCreateCachedGraph<WeightNormCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
+      // Placeholders
+      newCachedGraph->v_ = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(v.scalar_type()), getMPSShape(v));
+      newCachedGraph->g_ = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(g.scalar_type()), getMPSShape(g));
+
+      // Compute the L2 norm of v over every dim except `dim`
+      MPSGraphTensor* squared = [mpsGraph squareWithTensor:newCachedGraph->v_ name:nil];
+      MPSGraphTensor* sum_squared = [mpsGraph reductionSumWithTensor:squared axes:reduction_dims name:nil];
+      newCachedGraph->norms_ = [mpsGraph squareRootWithTensor:sum_squared name:nil];
+
+      // Divide each slice of v along `dim` by its L2 norm
+      MPSGraphTensor* unit_v = [mpsGraph divisionWithPrimaryTensor:newCachedGraph->v_
+                                                   secondaryTensor:newCachedGraph->norms_
+                                                              name:nil];
+
+      // Multiply each slice of unit_v by the corresponding element of g
+      newCachedGraph->w_ = [mpsGraph multiplicationWithPrimaryTensor:unit_v
+                                                     secondaryTensor:newCachedGraph->g_
+                                                                name:nil];
+    });
+
+    Placeholder v_placeholder = Placeholder(cachedGraph->v_, v, nil, true);
+    Placeholder g_placeholder = Placeholder(cachedGraph->g_, g, nil, true);
+    Placeholder norms_placeholder = Placeholder(cachedGraph->norms_, norms);
+    Placeholder w_placeholder = Placeholder(cachedGraph->w_, w);
+
+    NSDictionary* feeds = @{
+      v_placeholder.getMPSGraphTensor() : v_placeholder.getMPSGraphTensorData(),
+      g_placeholder.getMPSGraphTensor() : g_placeholder.getMPSGraphTensorData()
+    };
+
+    NSDictionary* results = @{
+      norms_placeholder.getMPSGraphTensor() : norms_placeholder.getMPSGraphTensorData(),
+      w_placeholder.getMPSGraphTensor() : w_placeholder.getMPSGraphTensorData()
+    };
+
+    runMPSGraph(mpsStream, cachedGraph->graph(), feeds, results);
+  }
+
+  return std::tuple<Tensor, Tensor>{w, norms};
+}
+
+std::tuple<Tensor, Tensor> weight_norm_backward_mps(const Tensor& grad_w,
+                                                    const Tensor& saved_v,
+                                                    const Tensor& saved_g,
+                                                    const Tensor& saved_norms,
+                                                    int64_t dim) {
+  // These checks should always succeed, because weight_norm_fused_backward should only
+  // ever be recorded in the autograd graph via weight_norm, which passes contiguous v and g.
+  TORCH_CHECK(saved_v.is_contiguous(), "saved_v must be contiguous");
+  TORCH_CHECK(saved_g.is_contiguous(), "saved_g must be contiguous");
+  TORCH_CHECK(saved_norms.is_contiguous(), "saved_norms must be contiguous");
+  TORCH_CHECK(dim == 0 || dim == saved_v.dim() - 1, "fused kernels can only be applied for first or last dim");
+
+  MPSStream* mpsStream = getCurrentMPSStream();
+
+  auto grad_v = at::empty_like(saved_v, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  auto grad_g = at::empty_like(saved_g, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+
+  string key =
+      "weight_norm_backward_mps_" + std::to_string(dim) + getTensorsStringKey({grad_w, saved_v, saved_g, saved_norms});
+
+  NSMutableArray<NSNumber*>* reduction_dims = [NSMutableArray array];
+  for (int i = 0; i < saved_v.dim(); ++i) {
+    if (i != dim) {
+      [reduction_dims addObject:@(i)];
+    }
+  }
+
+  @autoreleasepool {
+    auto cachedGraph =
+        LookUpOrCreateCachedGraph<WeightNormBackwardCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
+          // Placeholders
+          newCachedGraph->grad_w =
+              mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(grad_w.scalar_type()), getMPSShape(grad_w));
+          newCachedGraph->saved_v =
+              mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(saved_v.scalar_type()), getMPSShape(saved_v));
+          newCachedGraph->saved_g =
+              mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(saved_g.scalar_type()), getMPSShape(saved_g));
+          newCachedGraph->saved_norms = mpsGraphRankedPlaceHolder(
+              mpsGraph, getMPSScalarType(saved_norms.scalar_type()), getMPSShape(saved_norms));
+
+          // grad_g = sum(grad_w * v, reduction_dims) / norms;  grad_v = g * (grad_w - v * grad_g / norms) / norms
+          MPSGraphTensor* grad_w_v = [mpsGraph multiplicationWithPrimaryTensor:newCachedGraph->grad_w
+                                                               secondaryTensor:newCachedGraph->saved_v
+                                                                          name:nil];
+          MPSGraphTensor* result = [mpsGraph reductionSumWithTensor:grad_w_v axes:reduction_dims name:nil];
+
+          newCachedGraph->grad_g = [mpsGraph divisionWithPrimaryTensor:result
+                                                       secondaryTensor:newCachedGraph->saved_norms
+                                                                  name:nil];
+
+          MPSGraphTensor* grad_w_divided_by_norm = [mpsGraph divisionWithPrimaryTensor:newCachedGraph->grad_w
+                                                                       secondaryTensor:newCachedGraph->saved_norms
+                                                                                  name:nil];
+          MPSGraphTensor* three = [mpsGraph constantWithScalar:3.0 dataType:newCachedGraph->saved_norms.dataType];
+          MPSGraphTensor* norm_cubed = [mpsGraph powerWithPrimaryTensor:newCachedGraph->saved_norms
+                                                        secondaryTensor:three
+                                                                   name:nil];
+          MPSGraphTensor* v_result = [mpsGraph multiplicationWithPrimaryTensor:newCachedGraph->saved_v
+                                                               secondaryTensor:result
+                                                                          name:nil];
+          MPSGraphTensor* v_result_divided_by_norm_cubed = [mpsGraph divisionWithPrimaryTensor:v_result
+                                                                                secondaryTensor:norm_cubed
+                                                                                           name:nil];
+          MPSGraphTensor* diff = [mpsGraph subtractionWithPrimaryTensor:grad_w_divided_by_norm
+                                                        secondaryTensor:v_result_divided_by_norm_cubed
+                                                                   name:nil];
+          newCachedGraph->grad_v = [mpsGraph multiplicationWithPrimaryTensor:diff
+                                                             secondaryTensor:newCachedGraph->saved_g
+                                                                        name:nil];
+        });
+
+    Placeholder grad_w_placeholder = Placeholder(cachedGraph->grad_w, grad_w, nil, true);
+    Placeholder v_placeholder = Placeholder(cachedGraph->saved_v, saved_v, nil, true);
+    Placeholder g_placeholder = Placeholder(cachedGraph->saved_g, saved_g, nil, true);
+    Placeholder norms_placeholder = Placeholder(cachedGraph->saved_norms, saved_norms, nil, true);
+
+    Placeholder grad_g_placeholder = Placeholder(cachedGraph->grad_g, grad_g);
+    Placeholder grad_v_placeholder = Placeholder(cachedGraph->grad_v, grad_v);
+
+    NSDictionary* feeds = @{
+      grad_w_placeholder.getMPSGraphTensor() : grad_w_placeholder.getMPSGraphTensorData(),
+      norms_placeholder.getMPSGraphTensor() : norms_placeholder.getMPSGraphTensorData(),
+      v_placeholder.getMPSGraphTensor() : v_placeholder.getMPSGraphTensorData(),
+      g_placeholder.getMPSGraphTensor() : g_placeholder.getMPSGraphTensorData()
+    };
+
+    NSDictionary* results = @{
+      grad_g_placeholder.getMPSGraphTensor() : grad_g_placeholder.getMPSGraphTensorData(),
+      grad_v_placeholder.getMPSGraphTensor() : grad_v_placeholder.getMPSGraphTensorData()
+    };
+
+    runMPSGraph(mpsStream, cachedGraph->graph(), feeds, results);
+  }
+
+  return std::tuple<Tensor, Tensor>{grad_v, grad_g};
+}
+
+} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 10502c01b2155..046cb82e2c715 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6265,6 +6265,7 @@
   dispatch:
     CPU: weight_norm_cpu
     CUDA: weight_norm_cuda
+    MPS: weight_norm_mps
   autogen: _weight_norm_interface.out

 - func: _weight_norm_interface_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor)
@@ -6272,6 +6273,7 @@
   dispatch:
     CPU: weight_norm_backward_cpu
     CUDA: weight_norm_backward_cuda
+    MPS: weight_norm_backward_mps
   autogen: _weight_norm_interface_backward.out

 - func: _weight_norm_differentiable_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor)
diff --git a/test/test_mps.py b/test/test_mps.py
index 59c73e20ea70e..f0c2d6c756ab3 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2538,6 +2538,86 @@ def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_run
                         helper(shape, eps=3, momentum=0.67, wts=True, channels_last=channels_last,
                                track_running_stats=track_running_stats, test_module=test_module)

+    def test_weight_norm(self):
+        def helper(dim, layer='linear', dtype=torch.float32):
+            # linear layer
+            if layer == 'linear':
+                cpu_x = torch.randn((2, 5), device='cpu', dtype=dtype, requires_grad=True)
+                x = cpu_x.detach().clone().to('mps').requires_grad_()
+
+                cpu_weight = torch.randn(10, 5, device='cpu', dtype=dtype, requires_grad=True)
+                weight = cpu_weight.detach().clone().to('mps').requires_grad_()
+
+                cpu_bias = torch.randn(10, device='cpu', dtype=dtype, requires_grad=True)
+                bias = cpu_bias.detach().clone().to('mps').requires_grad_()
+
+                cpu_linear = torch.nn.Linear(5, 10, device='cpu')
+                linear = torch.nn.Linear(5, 10, device='mps')
+
+                with torch.no_grad():
+                    cpu_linear.weight.copy_(cpu_weight)
+                    cpu_linear.bias.copy_(cpu_bias)
+                    linear.weight.copy_(weight)
+                    linear.bias.copy_(bias)
+
+                cpu_norm = torch.nn.utils.weight_norm(cpu_linear, dim=dim)
+                norm = torch.nn.utils.weight_norm(linear, dim=dim)
+
+                cpu_out = cpu_norm(cpu_x)
+                out = norm(x)
+
+                self.assertEqual(cpu_out, out)
+
+                cpu_grad = torch.randn(cpu_out.shape)
+                grad = cpu_grad.to('mps')
+                cpu_out.backward(gradient=cpu_grad)
+                out.backward(gradient=grad)
+
+                self.assertEqual(cpu_linear.weight_g.grad, linear.weight_g.grad)
+                self.assertEqual(cpu_linear.weight_v.grad, linear.weight_v.grad)
+
+                self.assertEqual(x.grad, cpu_x.grad)
+
+            # conv layer
+            if layer == 'conv':
+                cpu_x = torch.randn((3, 5, 5), device='cpu', dtype=dtype, requires_grad=True)
+                x = cpu_x.detach().clone().to('mps').requires_grad_()
+
+                cpu_conv = torch.nn.Conv2d(3, 3, 3, device='cpu')
+                conv = torch.nn.Conv2d(3, 3, 3, device='mps')
+
+                with torch.no_grad():
+                    conv.weight.copy_(cpu_conv.weight)
+                    conv.bias.copy_(cpu_conv.bias)
+
+                cpu_norm = torch.nn.utils.weight_norm(cpu_conv, dim=dim)
+                norm = torch.nn.utils.weight_norm(conv, dim=dim)
+
+                cpu_out = cpu_conv(cpu_x)
+                out = conv(x)
+
+                self.assertEqual(cpu_out, out)
+
+                cpu_grad = torch.randn(cpu_out.shape)
+                grad = cpu_grad.to('mps')
+                cpu_out.backward(gradient=cpu_grad)
+                out.backward(gradient=grad)
+
+                self.assertEqual(cpu_conv.weight_g.grad, conv.weight_g.grad)
+                self.assertEqual(cpu_conv.weight_v.grad, conv.weight_v.grad)
+
+                self.assertEqual(x.grad, cpu_x.grad)
+
+        helper(0, layer='linear')
+        helper(1, layer='linear')
+        helper(-1, layer='linear')
+
+        helper(0, layer='conv')
+        helper(1, layer='conv')
+        helper(2, layer='conv')
+        helper(3, layer='conv')
+        helper(-1, layer='conv')
+
     # Test conv2d
     def test_conv2d_unit(self):
         def helper(input_shape, wt_shape,
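
Editor's note: for context, a minimal usage sketch (not part of the patch) of the path these kernels serve. On a build with this patch, torch.nn.utils.weight_norm applied to a module on the 'mps' device routes its fused forward/backward through _weight_norm_interface, which is what the consistency test above compares against CPU. The shapes and tolerances below are illustrative assumptions, not taken from the PR.

    import torch

    # weight_norm reparametrizes weight as w = g * v / ||v||; with this patch the
    # fused forward/backward for that reparametrization runs on the MPS backend.
    cpu_linear = torch.nn.Linear(5, 10)
    mps_linear = torch.nn.Linear(5, 10).to('mps')
    with torch.no_grad():
        mps_linear.weight.copy_(cpu_linear.weight)
        mps_linear.bias.copy_(cpu_linear.bias)

    cpu_norm = torch.nn.utils.weight_norm(cpu_linear, dim=0)
    mps_norm = torch.nn.utils.weight_norm(mps_linear, dim=0)

    x = torch.randn(2, 5)
    cpu_out = cpu_norm(x)
    mps_out = mps_norm(x.to('mps'))

    cpu_out.sum().backward()
    mps_out.sum().backward()

    # the outputs and the grads of the g/v parametrization should match CPU
    print(torch.allclose(cpu_out, mps_out.cpu(), atol=1e-5))
    print(torch.allclose(cpu_norm.weight_g.grad, mps_norm.weight_g.grad.cpu(), atol=1e-5))
    print(torch.allclose(cpu_norm.weight_v.grad, mps_norm.weight_v.grad.cpu(), atol=1e-5))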