-
Notifications
You must be signed in to change notification settings - Fork 172
/
keygen_relin.go
362 lines (291 loc) · 11.2 KB
/
keygen_relin.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
package drlwe
import (
"errors"
"github.com/tuneinsight/lattigo/v4/ring"
"github.com/tuneinsight/lattigo/v4/rlwe"
"github.com/tuneinsight/lattigo/v4/rlwe/ringqp"
"github.com/tuneinsight/lattigo/v4/utils"
)
// RKGProtocol is the structure storing the parameters and precomputations for the collective relinearization key generation protocol.
// The samplers and the two scratch polynomials are mutated by the share-generation methods, so one instance should not be
// used from several goroutines at once; ShallowCopy provides an independent instance for concurrent use.
type RKGProtocol struct {
params rlwe.Parameters
gaussianSamplerQ *ring.GaussianSampler
ternarySamplerQ *ring.TernarySampler // ternary sampler for ephemeral keys; the constructors build it with its last (montgomery) argument set to false, and GenShareRoundOne converts the sampled key with MFormLvl afterwards
tmpPoly1 ringqp.Poly
tmpPoly2 ringqp.Poly
}
// ShallowCopy returns a new RKGProtocol that shares the receiver's read-only
// parameters but owns its own samplers (seeded from a freshly created PRNG)
// and its own scratch polynomials. The receiver and the returned value can
// therefore be used concurrently.
//
// It panics if a new PRNG cannot be created.
func (ekg *RKGProtocol) ShallowCopy() *RKGProtocol {
	prng, err := utils.NewPRNG()
	if err != nil {
		panic(err)
	}

	p := ekg.params
	ringQ := p.RingQ()
	ringQP := p.RingQP()

	cpy := &RKGProtocol{params: p}
	cpy.gaussianSamplerQ = ring.NewGaussianSampler(prng, ringQ, p.Sigma(), int(6*p.Sigma()))
	cpy.ternarySamplerQ = ring.NewTernarySamplerWithHammingWeight(prng, ringQ, p.HammingWeight(), false)
	cpy.tmpPoly1 = ringQP.NewPoly()
	cpy.tmpPoly2 = ringQP.NewPoly()
	return cpy
}
// RKGShare is a share in the RKG protocol.
// Value[i][j] holds a pair of RingQP polynomials for RNS-decomposition index i and power-of-two-decomposition
// index j; AllocateShare sizes it DecompRNS x DecompPw2.
type RKGShare struct {
Value [][][2]ringqp.Poly
}
// RKGCRP is a type for common reference polynomials in the RKG protocol.
// crp[i][j] is the polynomial a_{i,j} for RNS-decomposition index i and power-of-two-decomposition index j,
// as produced by SampleCRP.
type RKGCRP [][]ringqp.Poly
// NewRKGProtocol builds an RKGProtocol for params. It seeds a fresh PRNG,
// creates the Gaussian sampler (bound 6*sigma) and the Hamming-weight ternary
// sampler over RingQ, and allocates two RingQP scratch polynomials.
//
// It panics if the PRNG cannot be created.
func NewRKGProtocol(params rlwe.Parameters) *RKGProtocol {
	prng, err := utils.NewPRNG()
	if err != nil {
		panic(err)
	}

	ringQ := params.RingQ()
	sigma := params.Sigma()

	return &RKGProtocol{
		params:           params,
		gaussianSamplerQ: ring.NewGaussianSampler(prng, ringQ, sigma, int(6*sigma)),
		ternarySamplerQ:  ring.NewTernarySamplerWithHammingWeight(prng, ringQ, params.HammingWeight(), false),
		tmpPoly1:         params.RingQP().NewPoly(),
		tmpPoly2:         params.RingQP().NewPoly(),
	}
}
// AllocateShare allocates everything a party produces in the RKG protocol.
// It returns an ephemeral secret key and two shares (one per round). Each
// share is a DecompRNS x DecompPw2 matrix of polynomial pairs, computed at
// the maximum Q and P levels.
func (ekg *RKGProtocol) AllocateShare() (ephSk *rlwe.SecretKey, r1 *RKGShare, r2 *RKGShare) {
	params := ekg.params
	ringQP := params.RingQP()

	levelQ, levelP := params.QCount()-1, params.PCount()-1
	nbRNS := params.DecompRNS(levelQ, levelP)
	nbPw2 := params.DecompPw2(levelQ, levelP)

	ephSk = rlwe.NewSecretKey(params)
	r1 = &RKGShare{Value: make([][][2]ringqp.Poly, nbRNS)}
	r2 = &RKGShare{Value: make([][][2]ringqp.Poly, nbRNS)}

	for i := range r1.Value {
		row1 := make([][2]ringqp.Poly, nbPw2)
		row2 := make([][2]ringqp.Poly, nbPw2)
		for j := range row1 {
			row1[j] = [2]ringqp.Poly{ringQP.NewPoly(), ringQP.NewPoly()}
			row2[j] = [2]ringqp.Poly{ringQP.NewPoly(), ringQP.NewPoly()}
		}
		r1.Value[i], r2.Value[i] = row1, row2
	}

	return ephSk, r1, r2
}
// SampleCRP expands the common reference string crs into the common reference
// polynomials a_{i,j} of the RKG protocol. The result is a DecompRNS x
// DecompPw2 matrix of RingQP polynomials, filled row by row from a uniform
// sampler keyed on crs.
func (ekg *RKGProtocol) SampleCRP(crs CRS) RKGCRP {
	p := ekg.params
	levelQ, levelP := p.QCount()-1, p.PCount()-1
	nbRNS, nbPw2 := p.DecompRNS(levelQ, levelP), p.DecompPw2(levelQ, levelP)

	sampler := ringqp.NewUniformSampler(crs, *p.RingQP())

	crp := make(RKGCRP, nbRNS)
	for i := 0; i < nbRNS; i++ {
		row := make([]ringqp.Poly, nbPw2)
		for j := 0; j < nbPw2; j++ {
			row[j] = p.RingQP().NewPoly()
			sampler.Read(row[j])
		}
		crp[i] = row
	}
	return crp
}
// GenShareRoundOne is the first of three rounds of the RKGProtocol protocol. Each party generates a pseudo encryption of
// its secret share of the key s_i under its ephemeral key u_i : [-u_i*a + s_i*w + e_i] and broadcasts it to the other
// j-1 parties.
//
// Parameters:
//   - sk: the party's secret key share s_i. Its Q/P levels select the levels used below. It is used directly as a
//     Montgomery-multiplication operand, and a copy of it goes through InvMForm, so it is presumably expected in NTT
//     and Montgomery form (NOTE(review): confirm against the key generator).
//   - crp: the common reference polynomials a_{i,j} (e.g. from SampleCRP), indexed like shareOut.Value.
//   - ephSkOut: overwritten with a freshly sampled ternary ephemeral key u_i, left in NTT and Montgomery form.
//   - shareOut: overwritten with the share. Its dimensions len(shareOut.Value) x len(shareOut.Value[0]) drive the
//     loops, so it should be allocated with AllocateShare.
//
// For every RNS index i and power-of-two index j the output is
//
//	shareOut.Value[i][j][0] = -u_i*a_{i,j} + e + 2^(w*j) * P * s_i
//	shareOut.Value[i][j][1] =  s_i*a_{i,j} + e'
//
// where the s_i term is added only on the Q limbs i*(levelP+1) .. i*(levelP+1)+levelP (capped at levelQ).
// e and e' are fresh Gaussian samples and w = params.Pow2Base(). P is the P modulus at levelP, or 1 when the key
// has no P part.
//
// The scratch buffer ekg.tmpPoly1 is overwritten.
func (ekg *RKGProtocol) GenShareRoundOne(sk *rlwe.SecretKey, crp RKGCRP, ephSkOut *rlwe.SecretKey, shareOut *RKGShare) {
// Given a base decomposition w_i (here the CRT decomposition)
// computes [-u*a_i + P*s_i + e_i]
// where a_i = crp_i
ringQ := ekg.params.RingQ()
ringQP := ekg.params.RingQP()
levelQ := sk.LevelQ()
levelP := sk.LevelP()
// A P level of -1 means the key has no P part.
hasModulusP := levelP > -1
if hasModulusP {
// Computes P * sk
ringQ.MulScalarBigintLvl(levelQ, sk.Value.Q, ringQP.RingP.ModulusAtLevel[levelP], ekg.tmpPoly1.Q)
} else {
// No P modulus: use sk as is, and force levelP to 0 so that each decomposition group below is a single Q limb.
levelP = 0
ring.CopyLvl(levelQ, sk.Value.Q, ekg.tmpPoly1.Q)
}
// Remove the Montgomery factor so tmpPoly1 can be added coefficient-wise to the Gaussian error below,
// which is NTT-transformed but never put in Montgomery form.
ringQ.InvMForm(ekg.tmpPoly1.Q, ekg.tmpPoly1.Q)
// u: sample the ternary ephemeral key in Q, extend it to P when present, then move it to NTT and Montgomery form
// so it can be used as an operand of the MulCoeffsMontgomery* calls below.
ekg.ternarySamplerQ.Read(ephSkOut.Value.Q)
if hasModulusP {
ringQP.ExtendBasisSmallNormAndCenter(ephSkOut.Value.Q, levelP, nil, ephSkOut.Value.P)
}
ringQP.NTTLvl(levelQ, levelP, ephSkOut.Value, ephSkOut.Value)
ringQP.MFormLvl(levelQ, levelP, ephSkOut.Value, ephSkOut.Value)
RNSDecomp := len(shareOut.Value)
BITDecomp := len(shareOut.Value[0])
var index int
// The power-of-two digit j is the outer loop so that tmpPoly1 can be scaled by 2^Pow2Base once per j (end of loop body).
for j := 0; j < BITDecomp; j++ {
for i := 0; i < RNSDecomp; i++ {
// h = e
ekg.gaussianSamplerQ.Read(shareOut.Value[i][j][0].Q)
if hasModulusP {
ringQP.ExtendBasisSmallNormAndCenter(shareOut.Value[i][j][0].Q, levelP, nil, shareOut.Value[i][j][0].P)
}
ringQP.NTTLvl(levelQ, levelP, shareOut.Value[i][j][0], shareOut.Value[i][j][0])
// h = sk*CrtBaseDecompQi + e
// The scaled key is added only on the levelP+1 Q limbs that form RNS group i.
for k := 0; k < levelP+1; k++ {
index = i*(levelP+1) + k
// Handles the case where the number of P primes does not divide the number of Q primes: the last group is shorter.
if index >= levelQ+1 {
break
}
qi := ringQ.Modulus[index]
skP := ekg.tmpPoly1.Q.Coeffs[index]
h := shareOut.Value[i][j][0].Q.Coeffs[index]
for w := 0; w < ringQ.N; w++ {
h[w] = ring.CRed(h[w]+skP[w], qi)
}
}
// h = sk*CrtBaseDecompQi + -u*a + e
ringQP.MulCoeffsMontgomeryAndSubLvl(levelQ, levelP, ephSkOut.Value, crp[i][j], shareOut.Value[i][j][0])
// Second Element
// e_2i
ekg.gaussianSamplerQ.Read(shareOut.Value[i][j][1].Q)
if hasModulusP {
ringQP.ExtendBasisSmallNormAndCenter(shareOut.Value[i][j][1].Q, levelP, nil, shareOut.Value[i][j][1].P)
}
ringQP.NTTLvl(levelQ, levelP, shareOut.Value[i][j][1], shareOut.Value[i][j][1])
// s*a + e_2i
ringQP.MulCoeffsMontgomeryAndAddLvl(levelQ, levelP, sk.Value, crp[i][j], shareOut.Value[i][j][1])
}
// Scale P*s by 2^Pow2Base in place, ready for the next power-of-two digit j+1.
ringQ.MulScalar(ekg.tmpPoly1.Q, 1<<ekg.params.Pow2Base(), ekg.tmpPoly1.Q)
}
}
// GenShareRoundTwo is the second of three rounds of the RKGProtocol protocol. Upon receiving the j-1 shares, each party computes :
//
// [s_i * sum([-u_j*a + s_j*w + e_j]) + e_i1, (u_i - s_i) * sum([s_j*a + e_j2]) + e_i2]
//
// = [s_i * (-u*a + s*w + e) + e_i1, (u_i - s_i) * (s*a + e_2) + e_i2]
//
// and broadcasts both values to the other j-1 parties.
//
// Parameters:
//   - ephSk: this party's ephemeral key u_i, as output by GenShareRoundOne.
//   - sk: this party's secret key share s_i; its levels select the Q/P levels used below.
//   - round1: presumably the sum of all parties' round-one shares (see AggregateShares). NOTE(review): confirm with
//     callers. round1.Value[i][j][0] is multiplied by s_i and round1.Value[i][j][1] by (u_i - s_i).
//   - shareOut: overwritten with the result; its dimensions drive the loops.
//
// The scratch buffers ekg.tmpPoly1 (holding u_i - s_i) and ekg.tmpPoly2 (holding the first-element error) are overwritten.
func (ekg *RKGProtocol) GenShareRoundTwo(ephSk, sk *rlwe.SecretKey, round1 *RKGShare, shareOut *RKGShare) {
ringQP := ekg.params.RingQP()
levelQ := sk.LevelQ()
// The P part is detected with a nil check; levelP stays 0 when there is none.
hasModulusP := sk.Value.P != nil
var levelP int
if hasModulusP {
levelP = sk.LevelP()
}
// (u_i - s_i)
ringQP.SubLvl(levelQ, levelP, ephSk.Value, sk.Value, ekg.tmpPoly1)
RNSDecomp := len(shareOut.Value)
BITDecomp := len(shareOut.Value[0])
// Each sample is of the form [-u*a_i + s*w_i + e_i]
// So for each element of the base decomposition w_i:
for i := 0; i < RNSDecomp; i++ {
for j := 0; j < BITDecomp; j++ {
// Computes [(sum samples)*s_i + e_1i, (u_i - s_i)*(sum second elements) + e_2i]
// (aggregated round-one first elements) * s_i
ringQP.MulCoeffsMontgomeryConstantLvl(levelQ, levelP, round1.Value[i][j][0], sk.Value, shareOut.Value[i][j][0])
// (aggregated round-one first elements) * s_i + e_1i
ekg.gaussianSamplerQ.Read(ekg.tmpPoly2.Q)
if hasModulusP {
ringQP.ExtendBasisSmallNormAndCenter(ekg.tmpPoly2.Q, levelP, nil, ekg.tmpPoly2.P)
}
ringQP.NTTLvl(levelQ, levelP, ekg.tmpPoly2, ekg.tmpPoly2)
ringQP.AddLvl(levelQ, levelP, shareOut.Value[i][j][0], ekg.tmpPoly2, shareOut.Value[i][j][0])
// second part
// (u - s) * (sum [x][s*a_i + e_2i]) + e3i
ekg.gaussianSamplerQ.Read(shareOut.Value[i][j][1].Q)
if hasModulusP {
ringQP.ExtendBasisSmallNormAndCenter(shareOut.Value[i][j][1].Q, levelP, nil, shareOut.Value[i][j][1].P)
}
ringQP.NTTLvl(levelQ, levelP, shareOut.Value[i][j][1], shareOut.Value[i][j][1])
ringQP.MulCoeffsMontgomeryAndAddLvl(levelQ, levelP, ekg.tmpPoly1, round1.Value[i][j][1], shareOut.Value[i][j][1])
}
}
}
// AggregateShares adds share1 and share2 element-wise and stores the sum in
// shareOut. The Q/P levels are taken from the first polynomial of share1, and
// the P level defaults to 0 when share1 carries no P part. The dimensions of
// shareOut bound the iteration.
func (ekg *RKGProtocol) AggregateShares(share1, share2, shareOut *RKGShare) {
	ringQP := ekg.params.RingQP()

	first := share1.Value[0][0][0]
	levelQ, levelP := first.Q.Level(), 0
	if first.P != nil {
		levelP = first.P.Level()
	}

	nbPw2 := len(shareOut.Value[0])
	for i := range shareOut.Value {
		for j := 0; j < nbPw2; j++ {
			for k := 0; k < 2; k++ {
				ringQP.AddLvl(levelQ, levelP, share1.Value[i][j][k], share2.Value[i][j][k], shareOut.Value[i][j][k])
			}
		}
	}
}
// GenRelinearizationKey assembles the relinearization key from the aggregated
// round-one and round-two shares and writes it into evalKeyOut.Keys[0]. For
// every decomposition index (i, j):
//
//	Value[0] = MForm(round2[i][j][0] + round2[i][j][1])
//	Value[1] = MForm(round1[i][j][1])
//
// Levels are read from round1's first polynomial; the P level defaults to 0
// when round1 has no P part.
func (ekg *RKGProtocol) GenRelinearizationKey(round1 *RKGShare, round2 *RKGShare, evalKeyOut *rlwe.RelinearizationKey) {
	ringQP := ekg.params.RingQP()

	ref := round1.Value[0][0][0]
	levelQ := ref.Q.Level()
	levelP := 0
	if ref.P != nil {
		levelP = ref.P.Level()
	}

	key := evalKeyOut.Keys[0]
	nbRNS, nbPw2 := len(round1.Value), len(round1.Value[0])
	for i := 0; i < nbRNS; i++ {
		for j := 0; j < nbPw2; j++ {
			dst0 := key.Value[i][j].Value[0]
			ringQP.AddLvl(levelQ, levelP, round2.Value[i][j][0], round2.Value[i][j][1], dst0)

			key.Value[i][j].Value[1].Copy(round1.Value[i][j][1])

			ringQP.MFormLvl(levelQ, levelP, dst0, dst0)
			dst1 := key.Value[i][j].Value[1]
			ringQP.MFormLvl(levelQ, levelP, dst1, dst1)
		}
	}
}
// MarshalBinary encodes the share into a byte slice.
//
// Layout: data[0] is the number of rows (RNS decomposition) and data[1] is
// the number of columns (power-of-two decomposition). These are followed by
// the two polynomials of every element in row-major order, each written with
// Encode64.
//
// An empty share encodes to []byte{0, 0}, which UnmarshalBinary decodes back
// to an empty share. An error is returned in three cases:
//   - either dimension does not fit in a uint8;
//   - the rows do not all have the same length (the header can only record one
//     column count);
//   - a polynomial fails to encode.
func (share *RKGShare) MarshalBinary() ([]byte, error) {
	rows := len(share.Value)
	if rows > 0xFF {
		return []byte{}, errors.New("RKGShare : uint8 overflow on length")
	}

	cols := 0
	if rows > 0 {
		cols = len(share.Value[0])
	}
	if cols > 0xFF {
		return []byte{}, errors.New("RKGShare : uint8 overflow on length")
	}

	// The header stores a single column count, so ragged rows cannot be
	// represented; reject them instead of emitting an undecodable buffer.
	for i := range share.Value {
		if len(share.Value[i]) != cols {
			return []byte{}, errors.New("RKGShare : rows have different lengths")
		}
	}

	// The buffer is sized from the first polynomial, assuming every element has
	// the same encoded size (AllocateShare builds them all from the same RingQP).
	size := 2
	if rows > 0 && cols > 0 {
		size += 2 * share.Value[0][0][0].MarshalBinarySize64() * rows * cols
	}

	data := make([]byte, size)
	data[0] = uint8(rows)
	data[1] = uint8(cols)

	ptr := 2
	for i := range share.Value {
		for j := range share.Value[i] {
			for k := range share.Value[i][j] {
				inc, err := share.Value[i][j][k].Encode64(data[ptr:])
				if err != nil {
					return []byte{}, err
				}
				ptr += inc
			}
		}
	}

	return data, nil
}
// UnmarshalBinary decodes a byte slice produced by MarshalBinary into the
// share, replacing any previous Value.
//
// data[0] gives the number of rows and data[1] the number of columns; the
// rows*cols polynomial pairs follow, each read with Decode64.
//
// It returns an error, instead of panicking, in three cases:
//   - data is shorter than the 2-byte header;
//   - data ends before all polynomials have been read;
//   - a polynomial fails to decode.
//
// On error, share.Value may be partially populated.
func (share *RKGShare) UnmarshalBinary(data []byte) (err error) {
	if len(data) < 2 {
		return errors.New("RKGShare : data too short")
	}

	rows, cols := int(data[0]), int(data[1])
	share.Value = make([][][2]ringqp.Poly, rows)

	ptr := 2
	var inc int
	for i := range share.Value {
		share.Value[i] = make([][2]ringqp.Poly, cols)
		for j := range share.Value[i] {
			for k := range share.Value[i][j] {
				// Guard against truncated input: slicing past the end (or
				// handing an empty slice to Decode64) would otherwise panic.
				if ptr >= len(data) {
					return errors.New("RKGShare : data too short")
				}
				if inc, err = share.Value[i][j][k].Decode64(data[ptr:]); err != nil {
					return err
				}
				ptr += inc
			}
		}
	}

	return nil
}