
Commit bb4c4e7

Update on "Fix pool2d_shape_check breakage as a result of API signature update."

[ghstack-poisoned]

Ashkan Aliabadi committed Nov 14, 2020
2 parents 8c7d669 + a9b6fa9

Showing 4 changed files with 95 additions and 24 deletions.
89 changes: 68 additions & 21 deletions aten/src/ATen/native/Distributions.cpp
@@ -456,11 +456,21 @@ Tensor _s_dirichlet_cpu(const Tensor& alpha, c10::optional<Generator> gen) {
/* The largest consecutive integer representable in float32 (2^24) */
constexpr int64_t FLOAT32_MAX_CONSECUTIVE_INT = 1 << (FLT_MANT_DIG);

Tensor& multinomial_out(
Tensor& result,
const Tensor& self,
int64_t n_sample,
bool with_replacement,
c10::optional<Generator> gen) {
TORCH_CHECK(
result.device() == self.device(),
"multinomial arguments must have the same device");
TORCH_CHECK(
self.dim() > 0 && self.dim() <= 2, "prob_dist must be 1 or 2 dim");
TORCH_CHECK(
at::isFloatingType(self.scalar_type()),
"multinomial only supports floating-point dtypes for input, got: ",
self.scalar_type());
TORCH_CHECK(result.scalar_type() == ScalarType::Long,
"multinomial expects Long tensor out, got: ", result.scalar_type());
TORCH_CHECK(n_sample > 0, "cannot sample n_sample <= 0 samples");
@@ -469,42 +479,79 @@ Tensor& multinomial_out(Tensor& result, const Tensor& self, int64_t n_sample, bo
"cannot sample n_sample > prob_dist.size(-1) samples without replacement");
// Since the index tensor is float, numCategories cannot exceed max
// float integer precision
TORCH_CHECK(
n_categories <= FLOAT32_MAX_CONSECUTIVE_INT,
"number of categories cannot exceed 2^24");

if (self.dim() == 1) {
result.resize_({n_sample});
} else {
const int64_t n_dist = self.size(0);
result.resize_({n_dist, n_sample});
}
// Nothing to sample; return the empty result as-is.
if (result.numel() == 0) {
return result;
}

// Fast-path for no replacement.
// Reference:
// https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503
// Half is not supported on CPU.
if (!with_replacement &&
!(self.device().is_cpu() && self.scalar_type() == ScalarType::Half)) {
// Sanity checks on `self`.
auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)).item();
TORCH_CHECK(
is_valid.to<bool>(),
"probability tensor contains either `inf`, `nan` or element < 0");
bool zero_prob_condition;
if (self.dim() == 1){
zero_prob_condition = (self.sum() == 0).item().to<bool>();
} else {
zero_prob_condition = (self.sum(1) == 0).sum().item().to<bool>();
}
TORCH_CHECK(
!zero_prob_condition,
"invalid multinomial distribution (sum of probabilities <= 0)");

// The algorithm is based on the Gumbel-softmax trick:
//   s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1).
// Applying exp to the formula does not change the result of argmax or
// topk, which gives
//   s = argmax( p / (-log(eps)) ) where eps ~ U(0, 1),
// and since -log(eps) ~ Exp(1), this simplifies to
//   s = argmax( p / q ) where q ~ Exp(1).
Tensor q = at::empty_like(self).exponential_(1, gen);
// In theory the probability of drawing an exact 0 from the exponential
// distribution is 0. The CUDA generator guards against returning 0, but on
// CPU exponential<double> can produce 0 with probability about
// 2^(-DBL_MANT_DIG). We ignore that case here, at a small risk of invalid
// output on CPU.
at::div_out(q, self, q);
if (n_sample == 1) {
at::argmax_out(result, q, /*dim=*/-1, /*keepdim=*/true);
} else {
Tensor vals = at::empty(result.sizes(), self.options());
at::topk_out(vals, result, q, n_sample);
}
return result;
}

multinomial_stub(
result.device().type(),
result,
self,
n_sample,
with_replacement,
gen);
return result;
}

Tensor multinomial(
const Tensor& self,
int64_t n_sample,
bool with_replacement,
c10::optional<Generator> gen) {
Tensor result = at::empty({0}, self.options().dtype(kLong));
native::multinomial_out(result, self, n_sample, with_replacement, gen);
return result;
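The fast-path above draws n_sample indices without replacement in a single pass. A minimal standalone sketch of the same exponential-race trick in plain PyTorch (the function name is hypothetical, not part of this commit):

import torch

def sample_without_replacement(probs: torch.Tensor, n_sample: int) -> torch.Tensor:
    # probs: 1-D or 2-D tensor of non-negative weights; rows need not sum to 1.
    q = torch.empty_like(probs).exponential_(1.0)  # q ~ Exp(1)
    # topk of p / q reproduces topk of logp - log(-log(eps)), eps ~ U(0, 1),
    # so the k largest ratios are k distinct draws from the distribution.
    return torch.topk(probs / q, n_sample, dim=-1).indices

probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
print(sample_without_replacement(probs, 2))  # two distinct indices, e.g. tensor([3, 2])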
8 changes: 8 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -530,12 +530,20 @@
dispatch:
CPU, CUDA: argmax

- func: argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: argmax_out

- func: argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor
use_c10_dispatcher: full
variants: function, method
dispatch:
CPU, CUDA: argmin

- func: argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: argmin_out

- func: acosh(Tensor self) -> Tensor
use_c10_dispatcher: full
variants: function, method
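The two new entries register out= overloads of argmax and argmin for CPU and CUDA; the multinomial fast-path above relies on at::argmax_out and at::topk_out writing straight into the preallocated result tensor. A hedged usage sketch from Python (assuming the generated binding exposes out=, as out-overloads in native_functions.yaml normally do):

import torch

x = torch.randn(3, 5)
out = torch.empty(0, dtype=torch.long)  # the out-variant resizes this as needed
torch.argmax(x, dim=1, out=out)  # row-wise argmax indices written into `out`
torch.argmin(x, dim=1, out=out)  # same contract for argmin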
16 changes: 16 additions & 0 deletions test/quantization/test_quantized_module.py
@@ -709,6 +709,21 @@ def _test_activation_module_impl(self, name, float_module_class, quantized_modul
msg="{} module API failed, qY_ref\n{} vs qY\n{}"
.format(name, qY_ref, qY))

def _test_leaky_relu_serialization(self):
scale_original = 10.0 / 256
zero_point_original = 1.0

quant_mod_original = nnq.LeakyReLU(scale_original, zero_point_original)
state_dict = quant_mod_original.state_dict()

scale_new = 5.0 / 256
zero_point_new = 2.0
quant_mod_new = nnq.LeakyReLU(scale_new, zero_point_new)
quant_mod_new.load_state_dict(state_dict)

self.assertEqual(quant_mod_original.scale, quant_mod_new.scale)
self.assertEqual(quant_mod_original.zero_point, quant_mod_new.zero_point)

def test_elu(self):
"""Tests the correctness of the ELU module.
The correctness is defined against the functional implementation.
@@ -717,6 +732,7 @@ def test_elu(self):

def test_leaky_relu(self):
self._test_activation_module_impl("LeakyReLU", nn.LeakyReLU, nnq.LeakyReLU, {"negative_slope": 0.2})
self._test_leaky_relu_serialization()

def test_sigmoid(self):
self._test_activation_module_impl("Sigmoid", nn.Sigmoid, nnq.Sigmoid, {})
6 changes: 3 additions & 3 deletions torch/nn/quantized/modules/activation.py
@@ -97,8 +97,8 @@ class LeakyReLU(torch.nn.LeakyReLU):
"""
def __init__(self, scale: float, zero_point: int, negative_slope: float = 1e-2, inplace: bool = False):
super().__init__(negative_slope, inplace)
self.register_buffer('scale', torch.tensor([scale]))
self.register_buffer('zero_point', torch.tensor([zero_point]))

def forward(self, input):
return torch.ops.quantized.leaky_relu(
@@ -113,7 +113,7 @@ def from_float(cls, mod):
return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)

class Sigmoid(torch.nn.Sigmoid):
r"""This is the quantized equivalent of :class:`~torch.nn.LeakyReLU`.
r"""This is the quantized equivalent of :class:`~torch.nn.Sigmoid`.
Args:
scale: quantization scale of the output tensor
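The switch from plain attributes to register_buffer is what makes the new serialization test pass: buffers are captured by state_dict(), plain Python attributes are not. A minimal sketch with a toy module (not the actual nnq.LeakyReLU):

import torch

class PlainAttr(torch.nn.Module):
    def __init__(self, scale: float):
        super().__init__()
        self.scale = scale  # plain attribute: invisible to state_dict()

class Buffered(torch.nn.Module):
    def __init__(self, scale: float):
        super().__init__()
        self.register_buffer('scale', torch.tensor([scale]))  # tracked by state_dict()

print(PlainAttr(0.5).state_dict())  # OrderedDict() -- the scale would be lost on save/load
print(Buffered(0.5).state_dict())   # OrderedDict([('scale', tensor([0.5000]))])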
