Skip to content

Commit

Permalink
[ring] : improved tests for Gaussian and ternary sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
Pro7ech committed Jan 14, 2021
1 parent 05edc0d commit 7b5d6d3
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 43 deletions.
6 changes: 4 additions & 2 deletions ckks/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"fmt"
"math"
"sort"

"github.com/ldsec/lattigo/v2/ring"
)

// NoiseEstimator is a struct storing the necessary pre-computed
Expand Down Expand Up @@ -64,7 +66,7 @@ func (ne *NoiseEstimator) StandardDeviationSlotDomain(valuesWant, valuesHave []c
ne.valuesfloat[2*i+1] = imag(err)
}

return StandardDeviation(ne.valuesfloat[:len(valuesWant)*2], scale)
return ring.StandardDeviation(ne.valuesfloat[:len(valuesWant)*2], scale)
}

// StandardDeviationCoefDomain returns the scaled standard deviation of the [coefficient domain] of the difference between two complex vectors in the [slot domains].
Expand All @@ -81,7 +83,7 @@ func (ne *NoiseEstimator) StandardDeviationCoefDomain(valuesWant, valuesHave []c
ne.valuesfloat[2*i+1] = imag(ne.values[i])
}

return StandardDeviation(ne.valuesfloat[:len(valuesWant)*2], scale)
return ring.StandardDeviation(ne.valuesfloat[:len(valuesWant)*2], scale)
}

// PrecisionStats is a struct storing statistic about the precision of a CKKS plaintext
Expand Down
22 changes: 0 additions & 22 deletions ckks/utils.go
Original file line number Diff line number Diff line change
@@ -1,33 +1,11 @@
package ckks

import (
"math"
"math/big"

"github.com/ldsec/lattigo/v2/ring"
)

// StandardDeviation computes the scaled standard deviation of the input vector.
func StandardDeviation(vec []float64, scale float64) (std float64) {
// We assume that the error is centered around zero
var err, tmp, mean, n float64

n = float64(len(vec))

for _, c := range vec {
mean += c
}

mean /= n

for _, c := range vec {
tmp = c - mean
err += tmp * tmp
}

return math.Sqrt(err/n) * scale
}

func scaleUpExact(value float64, n float64, q uint64) (res uint64) {

var isNegative bool
Expand Down
3 changes: 2 additions & 1 deletion examples/ckks/bootstrapping/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ func printDebug(params *ckks.Parameters, ciphertext *ckks.Ciphertext, valuesWant
fmt.Printf("ValuesTest: %6.10f %6.10f %6.10f %6.10f...\n", valuesTest[0], valuesTest[1], valuesTest[2], valuesTest[3])
fmt.Printf("ValuesWant: %6.10f %6.10f %6.10f %6.10f...\n", valuesWant[0], valuesWant[1], valuesWant[2], valuesWant[3])

precStats := ckks.GetPrecisionStats(params, encoder, nil, valuesWant, valuesTest, math.Exp2(53))
metrics := ckks.NewNoiseEstimator(params)
precStats := metrics.PrecisionStats(valuesWant, valuesTest)

fmt.Println(precStats.String())
fmt.Println()
Expand Down
5 changes: 3 additions & 2 deletions examples/ckks/euler/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,12 @@ func printDebug(params *ckks.Parameters, ciphertext *ckks.Ciphertext, valuesWant
fmt.Printf("Scale: 2^%f\n", math.Log2(ciphertext.Scale()))
fmt.Printf("ValuesTest: %6.10f %6.10f %6.10f %6.10f...\n", valuesTest[0], valuesTest[1], valuesTest[2], valuesTest[3])
fmt.Printf("ValuesWant: %6.10f %6.10f %6.10f %6.10f...\n", valuesWant[0], valuesWant[1], valuesWant[2], valuesWant[3])
fmt.Println()

precStats := ckks.GetPrecisionStats(params, encoder, nil, valuesWant, valuesTest, math.Exp2(53))
metrics := ckks.NewNoiseEstimator(params)
precStats := metrics.PrecisionStats(valuesWant, valuesTest)

fmt.Println(precStats.String())
fmt.Println()

return
}
Expand Down
5 changes: 3 additions & 2 deletions examples/ckks/sigmoid/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,12 @@ func printDebug(params *ckks.Parameters, ciphertext *ckks.Ciphertext, valuesWant
fmt.Printf("Scale: 2^%f\n", math.Log2(ciphertext.Scale()))
fmt.Printf("ValuesTest: %6.10f %6.10f %6.10f %6.10f...\n", valuesTest[0], valuesTest[1], valuesTest[2], valuesTest[3])
fmt.Printf("ValuesWant: %6.10f %6.10f %6.10f %6.10f...\n", valuesWant[0], valuesWant[1], valuesWant[2], valuesWant[3])
fmt.Println()

precStats := ckks.GetPrecisionStats(params, encoder, nil, valuesWant, valuesTest, math.Exp2(53))
metrics := ckks.NewNoiseEstimator(params)
precStats := metrics.PrecisionStats(valuesWant, valuesTest)

fmt.Println(precStats.String())
fmt.Println()

return
}
Expand Down
4 changes: 2 additions & 2 deletions ring/ring_sampler_gaussian.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (gaussianSampler *GaussianSampler) ReadLvl(level uint64, pol *Poly, baseRin
for {
coeffFlo, sign = gaussianSampler.normFloat64()

if coeffInt = uint64(coeffFlo * sigma + 0.5); coeffInt <= bound {
if coeffInt = uint64(coeffFlo*sigma + 0.5); coeffInt <= bound {
break
}
}
Expand Down Expand Up @@ -91,7 +91,7 @@ func (gaussianSampler *GaussianSampler) ReadAndAddLvl(level uint64, pol *Poly, b
for {
coeffFlo, sign = gaussianSampler.normFloat64()

if coeffInt = uint64(coeffFlo * sigma + 0.5); coeffInt <= bound {
if coeffInt = uint64(coeffFlo*sigma + 0.5); coeffInt <= bound {
break
}
}
Expand Down
77 changes: 65 additions & 12 deletions ring/ring_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ring
import (
"flag"
"fmt"
"math"
"math/big"
"testing"

Expand Down Expand Up @@ -303,35 +304,87 @@ func testUniformSampler(testContext *testParams, t *testing.T) {
func testGaussianSampler(testContext *testParams, t *testing.T) {

t.Run(testString("GaussianSampler/", testContext.ringQ), func(t *testing.T) {

gaussianSampler := NewGaussianSampler(testContext.prng)

// Samples a poly
pol := gaussianSampler.ReadNew(testContext.ringQ, DefaultSigma, DefaultBound)

for i := uint64(0); i < testContext.ringQ.N; i++ {
for j, qi := range testContext.ringQ.Modulus {
require.False(t, uint64(DefaultBound) < pol.Coeffs[j][i] && pol.Coeffs[j][i] < (qi-uint64(DefaultBound)))
// RNS reconstruction mod Q
coeffsBigint := make([]*big.Int, testContext.ringQ.N)
testContext.ringQ.PolyToBigint(pol, coeffsBigint)

// Extract the coefficient to float64 (they should be low norm)
coeffs := make([]float64, len(coeffsBigint))

QBigint := testContext.ringQ.ModulusBigint
bigQHalf := new(big.Int)
bigQHalf.Set(QBigint)
bigQHalf.Rsh(bigQHalf, 1)

var sign int
for i, c := range coeffsBigint {

sign = c.Cmp(bigQHalf)
if sign == 1 || sign == 0 {
c.Sub(c, QBigint)
}

coeffs[i] = float64(c.Int64())
}

// Computes the standard deviation
std := StandardDeviation(coeffs, 1.0)

// Checks it stays within small bounds
require.True(t, (std < DefaultSigma+0.1) && (std > DefaultSigma-0.1))
})
}

func testTernarySampler(testContext *testParams, t *testing.T) {

// p is the probability that a coefficient will be zero
// Hence the standard deviation should be (1-p)^0.5
for _, p := range []float64{.5, 1. / 3., 128. / 65536.} {
t.Run(testString(fmt.Sprintf("TernarySampler/p=%1.2f/", p), testContext.ringQ), func(t *testing.T) {

prng, err := utils.NewPRNG()
if err != nil {
panic(err)
}
ternarySampler := NewTernarySampler(prng, testContext.ringQ, p, false)
ternarySampler := NewTernarySampler(testContext.prng, testContext.ringQ, p, false)

// Samples a poly
pol := ternarySampler.ReadNew()
for i, mod := range testContext.ringQ.Modulus {
minOne := mod - 1
for _, c := range pol.Coeffs[i] {
require.True(t, c == 0 || c == minOne || c == 1)

// RNS reconstruction mod Q
coeffsBigint := make([]*big.Int, testContext.ringQ.N)
testContext.ringQ.PolyToBigint(pol, coeffsBigint)

// Extract the coefficient to float64 (they should be low norm)
coeffs := make([]float64, len(coeffsBigint))

QBigint := testContext.ringQ.ModulusBigint
bigQHalf := new(big.Int)
bigQHalf.Set(QBigint)
bigQHalf.Rsh(bigQHalf, 1)

var sign int
for i, c := range coeffsBigint {

sign = c.Cmp(bigQHalf)
if sign == 1 || sign == 0 {
c.Sub(c, QBigint)
}

coeffs[i] = float64(c.Int64())
}

// Computes the standard deviation
std := StandardDeviation(coeffs, 1.0)

// sqrt(1/N * (N * 1 * (1-p))
stdWant := math.Sqrt(1 - p)

// Checks it stays within small bounds
require.True(t, (std < stdWant+0.1) && (std > stdWant-0.1))

})
}

Expand Down
22 changes: 22 additions & 0 deletions ring/utils.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,31 @@
package ring

import (
"math"
"math/bits"
)

// StandardDeviation computes the scaled standard deviation of the input vector.
func StandardDeviation(vec []float64, scale float64) (std float64) {
// We assume that the error is centered around zero
var err, tmp, mean, n float64

n = float64(len(vec))

for _, c := range vec {
mean += c
}

mean /= n

for _, c := range vec {
tmp = c - mean
err += tmp * tmp
}

return math.Sqrt(err/n) * scale
}

// Min returns the minimum between to int
func Min(x, y int) int {
if x > y {
Expand Down

0 comments on commit 7b5d6d3

Please sign in to comment.