Skip to content

Commit

Permalink
[SobolEngine] Update direction numbers to 21201 dims (#49710)
Browse files Browse the repository at this point in the history
Summary:
Performs the update that was suggested in #41489

Adjust the functionality to largely match that pf the scipy companion PR scipy/scipy#10844, including
- a new `draw_base2` method
- include zero as the first point in the (unscrambled) Sobol sequence

The scipy PR is also quite opinionated if the `draw` method doesn't get called with a base 2 number (for which the resulting sequence has nice properties, see the scipy PR for a comprehensive discussion of this).

Note that this update is a **breaking change** in the sense that sequences generated with the same parameters after as before will not be identical! They will have the same (better, arguably) distributional properties, but calling the engine with the same seed will result in different numbers in the sequence.

Pull Request resolved: #49710

Test Plan:
```
from torch.quasirandom import SobolEngine

sobol = SobolEngine(3)
sobol.draw(4)

sobol = SobolEngine(4, scramble=True)
sobol.draw(5)

sobol = SobolEngine(4, scramble=True)
sobol.draw_base2(2)
```

Reviewed By: malfet

Differential Revision: D25657233

Pulled By: Balandat

fbshipit-source-id: ce8ccafe90fc968ed811634b5f37c2eb208af985
  • Loading branch information
Balandat authored and facebook-github-bot committed Jan 30, 2021
1 parent e26fccc commit 36b8cb2
Show file tree
Hide file tree
Showing 5 changed files with 42,675 additions and 1,350 deletions.
19 changes: 15 additions & 4 deletions aten/src/ATen/native/SobolEngineOps.cpp
Expand Up @@ -126,19 +126,28 @@ Tensor& _sobol_engine_initialize_state_(Tensor& sobolstate, int64_t dimension) {
TORCH_CHECK(sobolstate.dtype() == at::kLong,
"sobolstate needs to be of type ", at::kLong);

/// First row of `sobolstate` is 1
sobolstate.select(0, 0).fill_(1);

/// Use a tensor accessor for `sobolstate`
auto ss_a = sobolstate.accessor<int64_t, 2>();
for (int64_t d = 0; d < dimension; ++d) {

/// First row of `sobolstate` is all 1s
for (int64_t m = 0; m < MAXBIT; ++m) {
ss_a[0][m] = 1;
}

/// Remaining rows of sobolstate (row 2 through dim, indexed by [1:dim])
for (int64_t d = 1; d < dimension; ++d) {
int64_t p = poly[d];
int64_t m = bit_length(p) - 1;

// First m elements of row d comes from initsobolstate
for (int64_t i = 0; i < m; ++i) {
ss_a[d][i] = initsobolstate[d][i];
}

// Fill in remaining elements of v as in Section 2 (top of pg. 90) of:
// P. Bratley and B. L. Fox. Algorithm 659: Implementing sobol's
// quasirandom sequence generator. ACM Trans.
// Math. Softw., 14(1):88-100, Mar. 1988.
for (int64_t j = m; j < MAXBIT; ++j) {
int64_t newv = ss_a[d][j - m];
int64_t pow2 = 1;
Expand All @@ -152,6 +161,8 @@ Tensor& _sobol_engine_initialize_state_(Tensor& sobolstate, int64_t dimension) {
}
}

/// Multiply each column of sobolstate by power of 2:
/// sobolstate * [2^(maxbit-1), 2^(maxbit-2),..., 2, 1]
Tensor pow2s = at::pow(2, at::native::arange((MAXBIT - 1), -1, -1, sobolstate.options()));
sobolstate.mul_(pow2s);
return sobolstate;
Expand Down

0 comments on commit 36b8cb2

Please sign in to comment.