Skip to content

Commit

Permalink
[ring] : Gaussian & Ternary sampler have deterministic tests for coef…
Browse files Browse the repository at this point in the history
…ficient bounds and will only print a warning if the distribution is not as expected.
  • Loading branch information
Pro7ech committed Jan 14, 2021
1 parent 7b5d6d3 commit 1668c96
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
26 changes: 20 additions & 6 deletions ring/ring_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,11 +333,18 @@ func testGaussianSampler(testContext *testParams, t *testing.T) {
coeffs[i] = float64(c.Int64())
}

// Checks that the coefficient are within the bounds
for _, c := range coeffs {
require.True(t, math.Abs(c) <= float64(DefaultBound))
}

// Computes the standard deviation
std := StandardDeviation(coeffs, 1.0)
sigma := StandardDeviation(coeffs, 1.0)

// Checks it stays within small bounds
require.True(t, (std < DefaultSigma+0.1) && (std > DefaultSigma-0.1))
// Print a warning the distance is too large
if math.Abs(sigma-DefaultSigma) > DefaultSigma*0.05 {
t.Log("warning : |(sigma - DefaultSigma)| > DefaultSigma * 0.05")
}
})
}

Expand Down Expand Up @@ -376,14 +383,21 @@ func testTernarySampler(testContext *testParams, t *testing.T) {
coeffs[i] = float64(c.Int64())
}

// Checks that the coefficient are within the bounds
for _, c := range coeffs {
require.True(t, math.Abs(c) <= 1)
}

// Computes the standard deviation
std := StandardDeviation(coeffs, 1.0)
sigma := StandardDeviation(coeffs, 1.0)

// sqrt(1/N * (N * 1 * (1-p))
stdWant := math.Sqrt(1 - p)

// Checks it stays within small bounds
require.True(t, (std < stdWant+0.1) && (std > stdWant-0.1))
// Print a warning the distance is too large
if math.Abs(sigma-stdWant) > stdWant*0.05 {
t.Log("warning : |(sigma - sqrt(1-p))| > sqrt(1-p) * 0.05")
}

})
}
Expand Down
3 changes: 2 additions & 1 deletion ring/utils.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package ring

import (
"math"
"math/bits"
"math"
)


// StandardDeviation computes the scaled standard deviation of the input vector.
func StandardDeviation(vec []float64, scale float64) (std float64) {
// We assume that the error is centered around zero
Expand Down

0 comments on commit 1668c96

Please sign in to comment.