Skip to content

Commit

Permalink
MPS: Add multinomial op (#80760)
Browse files Browse the repository at this point in the history
Add multinomial with replacement

Pull Request resolved: #80760
Approved by: https://github.com/razarmehr, https://github.com/malfet
  • Loading branch information
kulinseth authored and pytorchmergebot committed Oct 3, 2022
1 parent 37013bb commit 6a842e3
Show file tree
Hide file tree
Showing 3 changed files with 253 additions and 0 deletions.
231 changes: 231 additions & 0 deletions aten/src/ATen/native/mps/operations/Distributions.mm
Original file line number Diff line number Diff line change
Expand Up @@ -350,5 +350,236 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional<Generator
"exponential_mps_:" + std::to_string(lambda), random_op_block);
}

// Samples `n_sample` category indices per distribution, with replacement,
// from the (unnormalized) probability rows of `self`, writing int64 indices
// into `result`. A 1-D input is treated as a single distribution.
//
// Implemented as one MPSGraph: normalize each row, build its inclusive
// prefix sums (the CDF) via a matmul with an upper-triangular ones matrix,
// draw uniform(0,1) samples, and bin each draw into the unique category
// whose (lower, upper) CDF interval contains it; the chosen index is
// recovered by masking a coordinate tensor and summing over the category
// axis.
Tensor& multinomial_with_replacement_mps_kernel(
    const Tensor& self,
    const int64_t n_sample,
    c10::optional<Generator> generator,
    Tensor& result) {

  using namespace mps;

  // 1-D input: one distribution over self.size(0) categories.
  int inputSize = self.dim();
  int numDist =
      inputSize == 1 ? 1 : self.size(0);
  int numCategories =
      inputSize == 1 ? self.size(0) : self.size(1);

  // Restructure data for 2d
  auto self_v = inputSize == 1 ? self.view({numDist, numCategories}) : self;
  auto result_v = inputSize == 1 ? result.view({numDist, n_sample}) : result;

  MPSStream* stream = getCurrentMPSStream();
  // NOTE(review): `generator` is unused — a fresh nondeterministic seed is
  // drawn on every call, so sampling is not reproducible via a torch
  // Generator; confirm whether Generator plumbing is intended here.
  uint64_t seed_ = c10::detail::getNonDeterministicRandom(true);

  @autoreleasepool {
    MPSShape* prob_shape = getMPSShape(self_v);
    MPSGraph* mpsGraph = make_mps_graph();

    auto prob_dtype = getMPSDataType(self_v.scalar_type());

    // This is probability weights
    MPSGraphTensor *probTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self_v.scalar_type()), prob_shape);

    // Per-row sum used for normalization: rows are weights and are not
    // required to sum to 1.
    MPSGraphTensor *sumProbs = [mpsGraph reductionSumWithTensor:probTensor
                                                           axis:-1
                                                           name:nil];

    MPSGraphTensor *normalizedProbs = [mpsGraph divisionWithPrimaryTensor:probTensor
                                                          secondaryTensor:sumProbs
                                                                     name:nil];

    auto ns_numCategories = [NSNumber numberWithInt:numCategories];
    auto ns_numDist = [NSNumber numberWithInt:numDist];
    auto ns_n_sample = [NSNumber numberWithInt:n_sample];

    // Upper-triangular (including diagonal, numLower:0 numUpper:-1) ones
    // matrix: row-vector x matrix yields inclusive prefix sums, i.e. the
    // per-category upper CDF bounds.
    MPSGraphTensor *ones = [mpsGraph constantWithScalar:1.0f
                                                  shape:@[ns_numCategories, ns_numCategories]
                                               dataType:prob_dtype];
    MPSGraphTensor *upperTriangle = [mpsGraph bandPartWithTensor:ones
                                                        numLower:0
                                                        numUpper:-1
                                                            name:nil];
    MPSGraphTensor *upperProbRange = [mpsGraph matrixMultiplicationWithPrimaryTensor:normalizedProbs
                                                                     secondaryTensor:upperTriangle
                                                                                name:nil];

    // Exclusive prefix sums: the lower CDF bound of each category interval.
    MPSGraphTensor *lowerProbRange = [mpsGraph subtractionWithPrimaryTensor:upperProbRange
                                                            secondaryTensor:normalizedProbs
                                                                       name:nil];

    // Reshape bounds to (numDist, 1, numCategories) so they broadcast
    // against the (numDist, n_sample, 1) uniform draws below.
    upperProbRange = [mpsGraph reshapeTensor:upperProbRange
                                   withShape:@[ns_numDist, @1, ns_numCategories]
                                        name:nil];
    lowerProbRange = [mpsGraph reshapeTensor:lowerProbRange
                                   withShape:@[ns_numDist, @1, ns_numCategories]
                                        name:nil];

    // Philox-based uniform draws, one per (distribution, sample) pair.
    MPSGraphTensor *stateTensor = [mpsGraph randomPhiloxStateTensorWithSeed:seed_
                                                                       name:nil];
    MPSGraphRandomOpDescriptor *descriptor = [MPSGraphRandomOpDescriptor descriptorWithDistribution:MPSGraphRandomDistributionUniform
                                                                                           dataType:prob_dtype];
    NSArray<MPSGraphTensor*> *generatorTensors = [mpsGraph randomTensorWithShape:@[ns_numDist, ns_n_sample, @1]
                                                                      descriptor:descriptor
                                                                     stateTensor:stateTensor
                                                                            name:nil];
    MPSGraphTensor *randomTensor = generatorTensors[0];

    auto broadcastShape = @[ns_numDist ,ns_n_sample, ns_numCategories];
    // NOTE(review): n_sample (int64_t) is narrowed to int here; values that
    // would overflow appear to be ruled out by the caller's checks — confirm.
    int broadcastShapeVals[3] = {numDist, n_sample, numCategories};
    MPSGraphTensor *broadcastShapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:broadcastShapeVals length:sizeof(int) * broadcastShape.count]
                                                                shape:@[[NSNumber numberWithUnsignedInteger:broadcastShape.count]]
                                                             dataType:MPSDataTypeUInt32];

    // A draw u selects category c iff lower[c] < u < upper[c]. The strict
    // comparisons only miss exact boundary hits (probability ~0 in float).
    MPSGraphTensor *samplesTensor = [mpsGraph broadcastTensor:randomTensor
                                                      toShape:broadcastShape
                                                         name:nil];
    MPSGraphTensor *sampleAbove = [mpsGraph greaterThanWithPrimaryTensor:samplesTensor
                                                         secondaryTensor:lowerProbRange
                                                                    name:nil];
    MPSGraphTensor *sampleBelow = [mpsGraph lessThanWithPrimaryTensor:samplesTensor
                                                      secondaryTensor:upperProbRange
                                                                 name:nil];
    MPSGraphTensor *sampleWithin = [mpsGraph logicalANDWithPrimaryTensor:sampleAbove
                                                         secondaryTensor:sampleBelow
                                                                    name:nil];
    MPSGraphTensor *sampleMask = [mpsGraph castTensor:sampleWithin
                                               toType:MPSDataTypeInt32
                                                 name:@"sampleMask"];
    // Coordinate along the last axis == the category index at each position;
    // masking it and summing over categories leaves only the selected index.
    MPSGraphTensor *categoriesTensor = [mpsGraph coordinateAlongAxis:-1
                                                     withShapeTensor:broadcastShapeTensor
                                                                name:nil];
    MPSGraphTensor *binnedSamplesTensor = [mpsGraph multiplicationWithPrimaryTensor:categoriesTensor
                                                                    secondaryTensor:sampleMask
                                                                               name:nil];
    MPSGraphTensor *reducedTensor = [mpsGraph reductionSumWithTensor:binnedSamplesTensor
                                                                axis:-1
                                                                name:nil];
    MPSGraphTensor *reshapeTensor = [mpsGraph reshapeTensor:reducedTensor
                                                  withShape:@[ns_numDist ,ns_n_sample]
                                                       name:nil];
    // Cast the int32 indices to the output dtype (Long for multinomial).
    MPSGraphTensor *resultTensor = [mpsGraph castTensor:reshapeTensor
                                                 toType:getMPSDataType(result.scalar_type())
                                                   name:@"resultTensor"];

    auto probPlaceholder = Placeholder(probTensor, self_v);
    auto outputPlaceholder = Placeholder(resultTensor, result_v);
    NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
      probPlaceholder.getMPSGraphTensor() : probPlaceholder.getMPSGraphTensorData()
    };
    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
      outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
    };

    runMPSGraph(stream, mpsGraph, feeds, results);
  }

  return result;

}

/* The largest consecutive integer representable in float32 (2^24).
   Category indices travel through float intermediates, so indices above
   2^FLT_MANT_DIG could no longer be represented exactly; the caller
   rejects such inputs up front. */
constexpr int64_t FLOAT32_MAX_CONSECUTIVE_INT = 1 << (FLT_MANT_DIG);

// out-variant of torch.multinomial for MPS.
//
// Validates arguments, resizes `result` to (n_sample) for 1-D input or
// (n_dist, n_sample) for 2-D input, then dispatches:
//   - without replacement: Gumbel-softmax trick (argmax/topk of p / Exp(1)),
//     which needs no dedicated kernel;
//   - with replacement: the MPSGraph kernel above.
//
// self:   1-D or 2-D floating-point tensor of (unnormalized) weights.
// result: Long tensor on the same device; resized and filled in place.
// Returns `result`.
Tensor& multinomial_out_mps(const Tensor& self,
    int64_t n_sample,
    bool with_replacement,
    c10::optional<Generator> gen,
    Tensor& result) {

  TORCH_CHECK(
      result.device() == self.device(),
      "multinomial arguments must have the same device");
  TORCH_CHECK(
      self.dim() > 0 && self.dim() <= 2, "prob_dist must be 1 or 2 dim");
  TORCH_CHECK(
      at::isFloatingType(self.scalar_type()),
      "multinomial only supports floating-point dtypes for input, got: ",
      self.scalar_type());
  TORCH_CHECK(result.scalar_type() == ScalarType::Long,
      "multinomial expects Long tensor out, got: ", result.scalar_type());
  TORCH_CHECK(n_sample > 0, "cannot sample n_sample <= 0 samples");
  int64_t n_categories = self.size(-1);
  TORCH_CHECK(with_replacement || (n_sample <= n_categories),
      "cannot sample n_sample > prob_dist.size(-1) samples without replacement");
  // Since the index tensor is float, numCategories cannot exceed max
  // float integer precision
  TORCH_CHECK(
      n_categories <= FLOAT32_MAX_CONSECUTIVE_INT,
      "number of categories cannot exceed 2^24");

  if (self.dim() == 1) {
    result.resize_({n_sample});
  } else {
    const int64_t n_dist = self.size(0);
    result.resize_({n_dist, n_sample});
  }
  // Nothing to sample into an empty output.
  if (result.numel() == 0) {
    return result;
  }

  // Fast-path for no replacement.
  // Reference:
  // https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503
  // Half is not supported on CPU.
  TORCH_CHECK(
      !(self.device().is_cpu() && self.scalar_type() == ScalarType::Half),
      "multinomial is not implemented for half on CPU");
  if (!with_replacement) {
    // Sanity checks on `self`: finite, non-negative weights only.
    auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)).item();
    TORCH_CHECK(
        is_valid.to<bool>(),
        "probability tensor contains either `inf`, `nan` or element < 0");
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    bool zero_prob_condition;
    if (self.dim() == 1){
      zero_prob_condition = (self.sum() == 0).item().to<bool>();
    } else {
      zero_prob_condition = (self.sum(1) == 0).sum().item().to<bool>();
    }
    TORCH_CHECK(
        !zero_prob_condition,
        "invalid multinomial distribution (sum of probabilities <= 0)");

    // The algorithm is from gumbel softmax.
    // s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1)
    // Here we can apply exp to the formula which will not affect result of
    // argmax or topk. Then we have
    // s = argmax( p / (-log(eps)) ) where eps ~ U(0, 1).
    // We can also simplify the formula above by
    // s = argmax( p / q ) where q ~ Exp(1)
    Tensor q = at::empty_like(self).exponential_(1, gen);
    // In theory the probability to generate 0 from exponential distribution is
    // 0. However, on CUDA side there is a protection to avoid 0s, but on CPU
    // side, there is a very low probability to generate 0 from
    // exponential<double>. The probability is about 2^(-DBL_MANT_DIG). We just
    // ignore it here, but there may be some risk to get invalid output on CPU.
    at::div_out(q, self, q);
    if (n_sample == 1) {
      at::argmax_out(result, q, /*dim=*/-1, /*keepdim=*/true);
    } else {
      Tensor vals = at::empty(result.sizes(), self.options());
      at::topk_out(vals, result, q, n_sample);
    }
    return result;
  }

  // Sampling with replacement runs on the GPU via the MPSGraph kernel.
  // (`self` is already const Tensor&, so no const_cast is needed, and the
  // kernel returns `result` itself.)
  return multinomial_with_replacement_mps_kernel(self, n_sample, gen, result);
}

// Functional variant of multinomial for MPS: allocates a fresh Long output
// tensor on the input's device and delegates to the out-variant above.
Tensor multinomial_mps(
    const Tensor& self,
    int64_t n_sample,
    bool with_replacement,
    c10::optional<Generator> gen) {
  auto output = at::empty({0}, self.options().dtype(kLong));
  return multinomial_out_mps(self, n_sample, with_replacement, gen, output);
}

} // namespace native
} // namespace at
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8289,11 +8289,13 @@
tags: nondeterministic_seeded
dispatch:
CPU, CUDA: multinomial_out
MPS: multinomial_out_mps

- func: multinomial(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor
variants: method, function
dispatch:
CPU, CUDA: multinomial
MPS: multinomial_mps
tags: nondeterministic_seeded

- func: lgamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
Expand Down
20 changes: 20 additions & 0 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -4877,6 +4877,26 @@ def helper(shape):
helper(10000)
helper((10000, 40))

def test_multinomial(self):
# Test with num_dist = 1
def helper(probs, compare_mean, compare_var, num_samples=5, replacement=True):
cpu_prob_tensor = torch.tensor(probs, device='cpu', dtype=torch.float, requires_grad=False)
prob_tensor = cpu_prob_tensor.detach().clone().to('mps')

mps_out = torch.multinomial(prob_tensor, num_samples, replacement=replacement)
if(not replacement):
print(mps_out.to('cpu'))
else:
# Compare "real" with theoretical values
print(mps_out.to('cpu').float().mean(), compare_mean)
print(mps_out.to('cpu').float().std() ** 2, compare_var)

# TODO: Add tests for data types
helper(np.array([[0., 0., 0., 0.5, 0.5]]), (3 + 4) / 2, (12.5 - 3.5 ** 2), 100000)
helper(np.array([[.2, .2, .2, .2, .2]]), (0 + 1 + 2 + 3 + 4) / 5, (6 - 2 * 2), 10000)
helper(np.array([[1, 1, 1, 1, 1]]), (0 + 1 + 2 + 3 + 4) / 5, (6 - 2 * 2), 10000)
helper(np.array([1, 1, 1, 1, 1]), (0 + 1 + 2 + 3 + 4) / 5, (6 - 2 * 2), 10000)
helper(np.array([[1, 1, 1, 1, 1, 1, 1]]), 0, 0, 7, False)

class TestNNMPS(NNTestCase):

Expand Down

0 comments on commit 6a842e3

Please sign in to comment.