114 changes: 89 additions & 25 deletions aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -88,6 +88,43 @@ void set_axes(NSMutableArray<NSNumber *> * &axes,
}
}

// Helper function to prepare axes and tensor shapes
void set_axes_and_shapes(const Tensor& input_t,
IntArrayRef dims,
NSMutableArray<NSNumber*> * &axes,
NSMutableArray<NSNumber*> * &apparent_input_shape,
NSMutableArray<NSNumber*> * &apparent_output_shape,
NSMutableArray<NSNumber*> * &output_shape) {

IntArrayRef input_shape = input_t.sizes();

int64_t num_input_dims = input_shape.size();
int64_t num_reduce_dims = dims.size();
int64_t num_output_dims;

num_output_dims = num_reduce_dims == 0 ? 1 : num_input_dims;

// Reduction axes
set_axes(axes, num_reduce_dims, dims, input_shape.size());

// Shapes
set_apparent_shapes(apparent_output_shape,
apparent_input_shape,
num_reduce_dims,
num_input_dims,
num_output_dims,
input_shape,
axes);

// Squeeze dims for output shape
output_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:0];
for(int i=0; i < num_output_dims; i++) {
if([apparent_output_shape[i] longValue] != 1) {
[output_shape addObject:apparent_output_shape[i]];
}
}
}

void reduction_out_mps
(const Tensor& input_t,
IntArrayRef dim,
@@ -107,6 +144,13 @@ void set_axes(NSMutableArray<NSNumber *> * &axes,

namespace native_mps = at::native::mps;

NSMutableArray<NSNumber*> *axes = nil;
NSMutableArray<NSNumber*> *apparent_input_shape = nil;
NSMutableArray<NSNumber*> *apparent_output_shape = nil;
NSMutableArray<NSNumber*> *output_shape = nil;

set_axes_and_shapes(input_t, dim, axes, apparent_input_shape, apparent_output_shape, output_shape);

// Derive from MPSCachedGraph
struct CachedGraph : public native_mps::MPSCachedGraph
{
@@ -117,27 +161,6 @@ void set_axes(NSMutableArray<NSNumber *> * &axes,

native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance();

-  int64_t num_input_dims = input_shape.size();
-  int64_t num_reduce_dims = dim.size();
-  int64_t num_output_dims;
-
-  // For output shape calculation, assume that keepdim is true
-  num_output_dims = num_input_dims;
-  NSMutableArray<NSNumber*> *apparent_output_shape = nil;
-  NSMutableArray<NSNumber*> *apparent_input_shape = nil;
-
-  // Reduction axes
-  NSMutableArray<NSNumber *> *axes;
-  set_axes(axes, num_reduce_dims, dim, input_shape.size());
-
-  set_apparent_shapes(apparent_output_shape,
-                      apparent_input_shape,
-                      num_reduce_dims,
-                      num_input_dims,
-                      num_output_dims,
-                      input_shape,
-                      axes);

if (output_t.numel() == 0 || input_t.numel() == 0) {
return;
}
@@ -173,22 +196,34 @@ void set_axes(NSMutableArray<NSNumber *> * &axes,

MPSGraphTensor* castOutputTensor = nil;

- if(reduction_type == "sum")
+ if(reduction_type == "sum") {
castOutputTensor = [mpsGraph reductionSumWithTensor:castInputTensor
axes:axes
name:nil];
- else if(reduction_type == "prod")
+ } else if(reduction_type == "prod") {
castOutputTensor = [mpsGraph reductionProductWithTensor:castInputTensor
axes:axes
name:nil];
- else if(reduction_type == "mean")
+ } else if(reduction_type == "mean") {
castOutputTensor = [mpsGraph meanOfTensor:inputTensor
axes:axes
name:nil];
} else if(reduction_type == "count_nonzero") {
MPSGraphTensor* zeros = [mpsGraph constantWithScalar:0
dataType:castInputTensor.dataType];

MPSGraphTensor* nonZeros = [mpsGraph notEqualWithPrimaryTensor:castInputTensor
secondaryTensor:zeros
name:nil];

castOutputTensor = [mpsGraph reductionSumWithTensor:nonZeros
axes:axes
name:nil];
}

MPSGraphTensor* outputTensor = nil;

- if(input_t.scalar_type() != ScalarType::Float)
+ if(output_t.scalar_type() != ScalarType::Float)
outputTensor = [mpsGraph castTensor:castOutputTensor
toType:(native_mps::getMPSDataType(output_t.scalar_type()))
name:@"outputTensor"];
@@ -281,6 +316,35 @@ Tensor prod_mps(const Tensor &self, c10::optional<ScalarType> opt_dtype) {
return output_t;
}


Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims){
NSMutableArray<NSNumber*> *axes = nil;
NSMutableArray<NSNumber*> *apparent_input_shape = nil;
NSMutableArray<NSNumber*> *apparent_output_shape = nil;
NSMutableArray<NSNumber*> *output_shape = nil;

set_axes_and_shapes(self, dims, axes, apparent_input_shape, apparent_output_shape, output_shape);

int64_t* raw_output_shape = (int64_t *)malloc([output_shape count] * sizeof(int64_t));
for(int i=0; i < [output_shape count]; i++) {
raw_output_shape[i] = [output_shape[i] longValue];
}

Tensor output_t = at::native::empty_mps(
IntArrayRef(raw_output_shape, [output_shape count]),
ScalarType::Long,
c10::nullopt,
kMPS,
c10::nullopt,
c10::nullopt);

reduction_out_mps(self, dims, false, self.scalar_type(), const_cast<Tensor&>(output_t), "count_nonzero", "count_nonzero_mps");

free(raw_output_shape);

return output_t;
}

TORCH_IMPL_FUNC(mean_out_mps)
(const Tensor& input_t,
IntArrayRef dim,
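The count_nonzero branch above lowers the operation onto existing reduction machinery: compare the input against zero with notEqualWithPrimaryTensor, then sum the resulting mask with reductionSumWithTensor. The same identity in plain torch terms (a CPU sketch of the semantics, not the MPS graph itself):

import torch

x = torch.tensor([[1, 0, 2], [3, 0, 2], [7, 9, -4]])

# count_nonzero == "compare against zero, then sum the mask along the
# reduced axes" -- the notEqual + reductionSum pattern used in the graph.
mask = x != 0
print(mask.sum(dim=1))                # tensor([2, 2, 3])
print(torch.count_nonzero(x, dim=1))  # tensor([2, 2, 3])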
1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -1489,6 +1489,7 @@
dispatch:
CPU: count_nonzero_cpu
CUDA: count_nonzero_cuda
MPS: count_nonzero_mps

- func: count_nonzero(Tensor self, int? dim=None) -> Tensor
variants: function, method
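With the MPS dispatch entry registered, torch.count_nonzero routes to count_nonzero_mps for tensors on the MPS device. A hypothetical usage sketch (assumes a macOS build where the MPS backend is available):

import torch

if torch.backends.mps.is_available():
    x = torch.tensor([[1, 0, 2], [3, 0, 2]], device="mps")
    print(torch.count_nonzero(x))         # 4 nonzero elements in total
    print(torch.count_nonzero(x, dim=0))  # per-column counts: [2, 0, 2]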
31 changes: 31 additions & 0 deletions test/test_mps.py
@@ -1318,6 +1318,37 @@ def helper(shape, repeats):
helper((3, 4, 5), (2, 3, 4, 5))
helper((3, 4, 5), (2, 2, 2))

def test_count_nonzero(self):
def helper(dtype):
n = [
[[1, 0, 2], [3, 0, 2], [7, 9, -4]],
[[0, 2, 3], [3, 2, 1], [2, 0, 0]],
]
cpu_x = torch.tensor(n, dtype=dtype)
mps_x = torch.tensor(n, dtype=dtype).to('mps')

# All non-zeros
self.assertEqual(
torch.count_nonzero(cpu_x),
torch.count_nonzero(mps_x)
)

# dim=1
self.assertEqual(
torch.count_nonzero(cpu_x, dim=1),
torch.count_nonzero(mps_x, dim=1)
)

# dim=(0, 1)
self.assertEqual(
torch.count_nonzero(cpu_x, dim=(0, 1)),
torch.count_nonzero(mps_x, dim=(0, 1))
)
helper(torch.int32)
helper(torch.int64)
helper(torch.float16)
helper(torch.float32)

def _test_module_empty_input(self, module, inp, check_size=True):
inp.requires_grad_(True)
out = module(inp)
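For reference, the shape bookkeeping in set_axes_and_shapes can be sketched in plain Python. This is an approximation built on one assumption: that set_apparent_shapes produces the keepdim-style output shape (reduced axes set to 1), which the squeeze loop then strips of every size-1 entry.

def sketch_output_shape(input_shape, dims):
    # Empty dims means reduce over all axes; the apparent shape is then [1].
    if len(dims) == 0:
        apparent = [1]
    else:
        apparent = [1 if i in dims else s for i, s in enumerate(input_shape)]
    # Mirror the squeeze loop: keep only the non-unit entries.
    return [s for s in apparent if s != 1]

print(sketch_output_shape([2, 3, 3], [1]))     # [2, 3]
print(sketch_output_shape([2, 3, 3], [0, 1]))  # [3]
print(sketch_output_shape([2, 3, 3], []))      # []  (0-d scalar result)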