[MPS] adding weight_norm_interface support for mps (#108008)
Fixes #104513

Adds support for aten::_weight_norm_interface (and its backward) to the MPS backend.

Also adds a consistency test against the CPU implementation for both the output and the gradients.
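For reference, the fused op computes w = g * v / ||v||, where the norm is taken over every dimension of v except dim. Below is a minimal, illustrative sketch of that computation in plain PyTorch (not part of this PR; the function name and example shapes are mine), handy for seeing what the MPS kernel has to produce:

import torch

def weight_norm_reference(v: torch.Tensor, g: torch.Tensor, dim: int):
    # Illustrative only: reduce over every dimension except `dim`,
    # keeping dims so the norms broadcast back against v.
    dim = dim % v.dim()
    reduce_dims = [d for d in range(v.dim()) if d != dim]
    norms = v.square().sum(dim=reduce_dims, keepdim=True).sqrt()
    w = g * (v / norms)  # w = g * v / ||v||
    return w, norms

# e.g. for a Linear weight of shape (10, 5) with dim=0, g has shape (10, 1)

The new test can be run on an Apple-silicon machine with, for example, pytest test/test_mps.py -k test_weight_norm.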
Pull Request resolved: #108008
Approved by: https://github.com/kulinseth
igm503 authored and pytorchmergebot committed Sep 20, 2023
1 parent 1b3e5b5 commit 0317626
Showing 3 changed files with 274 additions and 0 deletions.
192 changes: 192 additions & 0 deletions aten/src/ATen/native/mps/operations/WeightNorm.mm
@@ -0,0 +1,192 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/mps/OperationUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_weight_norm_interface_backward_native.h>
#include <ATen/ops/_weight_norm_interface_native.h>
#endif

namespace at::native {

using namespace at::native::mps;

// Derive from MPSCachedGraph
struct WeightNormCachedGraph : public MPSCachedGraph {
  WeightNormCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor* g_ = nil;
  MPSGraphTensor* v_ = nil;
  MPSGraphTensor* norms_ = nil;
  MPSGraphTensor* w_ = nil;
};

struct WeightNormBackwardCachedGraph : public MPSCachedGraph {
  WeightNormBackwardCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor* grad_w = nil;
  MPSGraphTensor* saved_v = nil;
  MPSGraphTensor* saved_g = nil;
  MPSGraphTensor* saved_norms = nil;
  MPSGraphTensor* grad_g = nil;
  MPSGraphTensor* grad_v = nil;
};

std::tuple<Tensor, Tensor> weight_norm_mps(const Tensor& v, const Tensor& g, int64_t dim) {
  TORCH_CHECK(dim == 0 || dim == v.dim() - 1, "fused kernels can only be applied for first or last dim");

  MPSStream* mpsStream = getCurrentMPSStream();

  auto w = at::empty_like(v, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto norms = at::empty_like(g, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  string key = "weight_norm_mps_" + std::to_string(dim) + getTensorsStringKey({v, g});

  NSMutableArray* reduction_dims = [NSMutableArray array];
  for (int i = 0; i < v.dim(); ++i) {
    if (i != dim) {
      [reduction_dims addObject:@(i)];
    }
  }

  @autoreleasepool {
    auto cachedGraph = LookUpOrCreateCachedGraph<WeightNormCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      // Placeholders
      newCachedGraph->v_ = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(v.scalar_type()), getMPSShape(v));
      newCachedGraph->g_ = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(g.scalar_type()), getMPSShape(g));

      // Compute the L2 norm of v over every dim except `dim`
      MPSGraphTensor* squared = [mpsGraph squareWithTensor:newCachedGraph->v_ name:nil];
      MPSGraphTensor* sum_squared = [mpsGraph reductionSumWithTensor:squared axes:reduction_dims name:nil];
      newCachedGraph->norms_ = [mpsGraph squareRootWithTensor:sum_squared name:nil];

      // Divide each slice of v along `dim` by its L2 norm
      MPSGraphTensor* unit_v = [mpsGraph divisionWithPrimaryTensor:newCachedGraph->v_
                                                   secondaryTensor:newCachedGraph->norms_
                                                              name:nil];

      // Multiply each slice of unit_v by the corresponding element of g
      newCachedGraph->w_ = [mpsGraph multiplicationWithPrimaryTensor:unit_v
                                                     secondaryTensor:newCachedGraph->g_
                                                                name:nil];
    });

    Placeholder v_placeholder = Placeholder(cachedGraph->v_, v, nil, true);
    Placeholder g_placeholder = Placeholder(cachedGraph->g_, g, nil, true);
    Placeholder norms_placeholder = Placeholder(cachedGraph->norms_, norms);
    Placeholder w_placeholder = Placeholder(cachedGraph->w_, w);

    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
      v_placeholder.getMPSGraphTensor() : v_placeholder.getMPSGraphTensorData(),
      g_placeholder.getMPSGraphTensor() : g_placeholder.getMPSGraphTensorData()
    };

    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
      norms_placeholder.getMPSGraphTensor() : norms_placeholder.getMPSGraphTensorData(),
      w_placeholder.getMPSGraphTensor() : w_placeholder.getMPSGraphTensorData()
    };

    runMPSGraph(mpsStream, cachedGraph->graph(), feeds, results);
  }

  return std::tuple<Tensor, Tensor>{w, norms};
}

std::tuple<Tensor, Tensor> weight_norm_backward_mps(const Tensor& grad_w,
const Tensor& saved_v,
const Tensor& saved_g,
const Tensor& saved_norms,
int64_t dim) {
  // These checks should always succeed, because weight_norm_fused_backward should only
  // ever be recorded in the autograd graph via weight_norm, which passes contiguous v and g.
  TORCH_CHECK(saved_v.is_contiguous(), "saved_v must be contiguous");
  TORCH_CHECK(saved_g.is_contiguous(), "saved_g must be contiguous");
  TORCH_CHECK(saved_norms.is_contiguous(), "saved_norms must be contiguous");
  TORCH_CHECK(dim == 0 || dim == saved_v.dim() - 1, "fused kernels can only be applied for first or last dim");

  MPSStream* mpsStream = getCurrentMPSStream();

  auto grad_v = at::empty_like(saved_v, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto grad_g = at::empty_like(saved_g, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  string key =
      "weight_norm_backward_mps_" + std::to_string(dim) + getTensorsStringKey({grad_w, saved_v, saved_g, saved_norms});

  NSMutableArray* reduction_dims = [NSMutableArray array];
  for (int i = 0; i < saved_v.dim(); ++i) {
    if (i != dim) {
      [reduction_dims addObject:@(i)];
    }
  }

  @autoreleasepool {
    auto cachedGraph =
        LookUpOrCreateCachedGraph<WeightNormBackwardCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
          // Placeholders
          newCachedGraph->grad_w =
              mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(grad_w.scalar_type()), getMPSShape(grad_w));
          newCachedGraph->saved_v =
              mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(saved_v.scalar_type()), getMPSShape(saved_v));
          newCachedGraph->saved_g =
              mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(saved_g.scalar_type()), getMPSShape(saved_g));
          newCachedGraph->saved_norms = mpsGraphRankedPlaceHolder(
              mpsGraph, getMPSScalarType(saved_norms.scalar_type()), getMPSShape(saved_norms));

          // Compute graph: grad_g = sum(grad_w * v) / norm
          MPSGraphTensor* grad_w_v = [mpsGraph multiplicationWithPrimaryTensor:newCachedGraph->grad_w
                                                               secondaryTensor:newCachedGraph->saved_v
                                                                          name:nil];
          MPSGraphTensor* result = [mpsGraph reductionSumWithTensor:grad_w_v axes:reduction_dims name:nil];

          newCachedGraph->grad_g = [mpsGraph divisionWithPrimaryTensor:result
                                                       secondaryTensor:newCachedGraph->saved_norms
                                                                  name:nil];

          // grad_v = g * (grad_w / norm - v * sum(grad_w * v) / norm^3)
          MPSGraphTensor* grad_w_divided_by_norm = [mpsGraph divisionWithPrimaryTensor:newCachedGraph->grad_w
                                                                       secondaryTensor:newCachedGraph->saved_norms
                                                                                  name:nil];
          MPSGraphTensor* three = [mpsGraph constantWithScalar:3.0 dataType:newCachedGraph->saved_norms.dataType];
          MPSGraphTensor* norm_cubed = [mpsGraph powerWithPrimaryTensor:newCachedGraph->saved_norms
                                                        secondaryTensor:three
                                                                   name:nil];
          MPSGraphTensor* v_result = [mpsGraph multiplicationWithPrimaryTensor:newCachedGraph->saved_v
                                                               secondaryTensor:result
                                                                          name:nil];
          MPSGraphTensor* v_result_divided_by_norm_cubed = [mpsGraph divisionWithPrimaryTensor:v_result
                                                                               secondaryTensor:norm_cubed
                                                                                          name:nil];
          MPSGraphTensor* diff = [mpsGraph subtractionWithPrimaryTensor:grad_w_divided_by_norm
                                                        secondaryTensor:v_result_divided_by_norm_cubed
                                                                   name:nil];
          newCachedGraph->grad_v = [mpsGraph multiplicationWithPrimaryTensor:diff
                                                             secondaryTensor:newCachedGraph->saved_g
                                                                        name:nil];
        });

    Placeholder grad_w_placeholder = Placeholder(cachedGraph->grad_w, grad_w, nil, true);
    Placeholder v_placeholder = Placeholder(cachedGraph->saved_v, saved_v, nil, true);
    Placeholder g_placeholder = Placeholder(cachedGraph->saved_g, saved_g, nil, true);
    Placeholder norms_placeholder = Placeholder(cachedGraph->saved_norms, saved_norms, nil, true);

    Placeholder grad_g_placeholder = Placeholder(cachedGraph->grad_g, grad_g);
    Placeholder grad_v_placeholder = Placeholder(cachedGraph->grad_v, grad_v);

    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
      grad_w_placeholder.getMPSGraphTensor() : grad_w_placeholder.getMPSGraphTensorData(),
      norms_placeholder.getMPSGraphTensor() : norms_placeholder.getMPSGraphTensorData(),
      v_placeholder.getMPSGraphTensor() : v_placeholder.getMPSGraphTensorData(),
      g_placeholder.getMPSGraphTensor() : g_placeholder.getMPSGraphTensorData()
    };

    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
      grad_g_placeholder.getMPSGraphTensor() : grad_g_placeholder.getMPSGraphTensorData(),
      grad_v_placeholder.getMPSGraphTensor() : grad_v_placeholder.getMPSGraphTensorData()
    };

    runMPSGraph(mpsStream, cachedGraph->graph(), feeds, results);
  }

  return std::tuple<Tensor, Tensor>{grad_v, grad_g};
}

} // namespace at::native
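The backward graph above implements grad_g = sum(grad_w * v) / ||v|| and grad_v = g * (grad_w / ||v|| - v * sum(grad_w * v) / ||v||^3), with the sums taken over every dimension except dim. The following is a small reference sketch of the same formulas in plain PyTorch (illustrative only; the function name is mine and not part of this PR), useful when comparing the MPS results against the CPU kernel:

import torch

def weight_norm_backward_reference(grad_w, saved_v, saved_g, saved_norms, dim):
    # Illustrative only: sum over every dimension except `dim`, keeping dims
    # so the per-slice quantities broadcast back against saved_v.
    reduce_dims = [d for d in range(saved_v.dim()) if d != dim]
    per_slice = (grad_w * saved_v).sum(dim=reduce_dims, keepdim=True)
    grad_g = per_slice / saved_norms
    grad_v = saved_g * (grad_w / saved_norms - saved_v * per_slice / saved_norms.pow(3))
    return grad_v, grad_g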
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -6265,13 +6265,15 @@
  dispatch:
    CPU: weight_norm_cpu
    CUDA: weight_norm_cuda
    MPS: weight_norm_mps
  autogen: _weight_norm_interface.out

- func: _weight_norm_interface_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor)
  variants: function
  dispatch:
    CPU: weight_norm_backward_cpu
    CUDA: weight_norm_backward_cuda
    MPS: weight_norm_backward_mps
  autogen: _weight_norm_interface_backward.out

- func: _weight_norm_differentiable_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor)
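With these dispatch entries in place, calls that reach aten::_weight_norm_interface on MPS tensors should be routed to the new kernels; from Python this normally happens through torch._weight_norm, which torch.nn.utils.weight_norm uses and which takes the fused path when dim is the first or last dimension. A quick, hedged consistency check (requires an MPS-capable machine; the shapes and tolerances are my choices, not from this PR):

import torch

if torch.backends.mps.is_available():
    v = torch.randn(10, 5)
    g = torch.randn(10, 1)
    # dim=0 (first dimension) is eligible for the fused path
    w_cpu = torch._weight_norm(v, g, 0)
    w_mps = torch._weight_norm(v.to("mps"), g.to("mps"), 0)
    torch.testing.assert_close(w_mps.cpu(), w_cpu, rtol=1e-5, atol=1e-5)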
80 changes: 80 additions & 0 deletions test/test_mps.py
@@ -2538,6 +2538,86 @@ def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_run
helper(shape, eps=3, momentum=0.67, wts=True, channels_last=channels_last,
track_running_stats=track_running_stats, test_module=test_module)

    def test_weight_norm(self):
        def helper(dim, layer='linear', dtype=torch.float32):
            # linear layer
            if layer == 'linear':
                cpu_x = torch.randn((2, 5), device='cpu', dtype=dtype, requires_grad=True)
                x = cpu_x.detach().clone().to('mps').requires_grad_()

                cpu_weight = torch.randn(10, 5, device='cpu', dtype=dtype, requires_grad=True)
                weight = cpu_weight.detach().clone().to('mps').requires_grad_()

                cpu_bias = torch.randn(10, device='cpu', dtype=dtype, requires_grad=True)
                bias = cpu_bias.detach().clone().to('mps').requires_grad_()

                cpu_linear = torch.nn.Linear(5, 10, device='cpu')
                linear = torch.nn.Linear(5, 10, device='mps')

                with torch.no_grad():
                    cpu_linear.weight.copy_(cpu_weight)
                    cpu_linear.bias.copy_(cpu_bias)
                    linear.weight.copy_(weight)
                    linear.bias.copy_(bias)

                cpu_norm = torch.nn.utils.weight_norm(cpu_linear, dim=dim)
                norm = torch.nn.utils.weight_norm(linear, dim=dim)

                cpu_out = cpu_norm(cpu_x)
                out = norm(x)

                self.assertEqual(cpu_out, out)

                cpu_grad = torch.randn(cpu_out.shape)
                grad = cpu_grad.to('mps')
                cpu_out.backward(gradient=cpu_grad)
                out.backward(gradient=grad)

                self.assertEqual(cpu_linear.weight_g.grad, linear.weight_g.grad)
                self.assertEqual(cpu_linear.weight_v.grad, linear.weight_v.grad)

                self.assertEqual(x.grad, cpu_x.grad)

            # conv layer
            if layer == 'conv':
                cpu_x = torch.randn((3, 5, 5), device='cpu', dtype=dtype, requires_grad=True)
                x = cpu_x.detach().clone().to('mps').requires_grad_()

                cpu_conv = torch.nn.Conv2d(3, 3, 3, device='cpu')
                conv = torch.nn.Conv2d(3, 3, 3, device='mps')

                with torch.no_grad():
                    conv.weight.copy_(cpu_conv.weight)
                    conv.bias.copy_(cpu_conv.bias)

                cpu_norm = torch.nn.utils.weight_norm(cpu_conv, dim=dim)
                norm = torch.nn.utils.weight_norm(conv, dim=dim)

                cpu_out = cpu_conv(cpu_x)
                out = conv(x)

                self.assertEqual(cpu_out, out)

                cpu_grad = torch.randn(cpu_out.shape)
                grad = cpu_grad.to('mps')
                cpu_out.backward(gradient=cpu_grad)
                out.backward(gradient=grad)

                self.assertEqual(cpu_conv.weight_g.grad, conv.weight_g.grad)
                self.assertEqual(cpu_conv.weight_v.grad, conv.weight_v.grad)

                self.assertEqual(x.grad, cpu_x.grad)

        helper(0, layer='linear')
        helper(1, layer='linear')
        helper(-1, layer='linear')

        helper(0, layer='conv')
        helper(1, layer='conv')
        helper(2, layer='conv')
        helper(3, layer='conv')
        helper(-1, layer='conv')

    # Test conv2d
    def test_conv2d_unit(self):
        def helper(input_shape, wt_shape,