-
Notifications
You must be signed in to change notification settings - Fork 211
/
ballots.go
336 lines (322 loc) · 9.35 KB
/
ballots.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
package ballots
import (
"bytes"
"errors"
"fmt"
"io"
"github.com/spacemeshos/go-spacemesh/codec"
"github.com/spacemeshos/go-spacemesh/common/types"
"github.com/spacemeshos/go-spacemesh/sql"
)
// decodeBallot rebuilds a ballot from the raw pubkey and ballot columns of a row.
//
// The ballot ID is set to id and the smesher ID is read from pubkey rather than
// from the encoded body. If reading pubkey yields io.EOF the smesher ID is left
// zero-valued; a short read is an error. io.EOF from decoding body is likewise
// tolerated, while a successful decode that consumed zero bytes is an error.
// When malicious is true the ballot is flagged via SetMalicious.
func decodeBallot(id types.BallotID, pubkey, body *bytes.Reader, malicious bool) (*types.Ballot, error) {
	var smesher types.NodeID
	read, err := pubkey.Read(smesher[:])
	switch {
	case err == io.EOF:
		// nothing to read: keep the zero NodeID
	case err != nil:
		return nil, fmt.Errorf("copy pubkey: %w", err)
	case read != types.NodeIDSize:
		return nil, errors.New("public key data missing")
	}

	var ballot types.Ballot
	decoded, err := codec.DecodeFrom(body, &ballot)
	switch {
	case err == io.EOF:
		// tolerated: proceed with whatever was decoded
	case err != nil:
		return nil, fmt.Errorf("decode body of the %s: %w", id, err)
	case decoded == 0:
		return nil, errors.New("ballot data missing")
	}

	ballot.SetID(id)
	ballot.SmesherID = smesher
	if malicious {
		ballot.SetMalicious()
	}
	return &ballot, nil
}
// Add ballot to the database.
//
// The ballot is encoded with codec.Encode and inserted together with its ID,
// ATX ID, layer and smesher ID (stored in the pubkey column). Encoding and
// insert failures are returned wrapped with the ballot ID.
func Add(db sql.Executor, ballot *types.Ballot) error {
	// named "encoded" so it does not shadow the imported bytes package
	encoded, err := codec.Encode(ballot)
	if err != nil {
		return fmt.Errorf("encode ballot %s: %w", ballot.ID(), err)
	}
	if _, err := db.Exec(`insert into ballots
(id, atx, layer, pubkey, ballot)
values (?1, ?2, ?3, ?4, ?5);`,
		func(stmt *sql.Statement) {
			stmt.BindBytes(1, ballot.ID().Bytes())
			stmt.BindBytes(2, ballot.AtxID.Bytes())
			stmt.BindInt64(3, int64(ballot.Layer))
			stmt.BindBytes(4, ballot.SmesherID.Bytes())
			stmt.BindBytes(5, encoded)
		}, nil); err != nil {
		return fmt.Errorf("insert ballot %s: %w", ballot.ID(), err)
	}
	return nil
}
// Has reports whether a ballot with the given id is stored in the database.
// Any query failure is returned wrapped with the ballot id.
func Has(db sql.Executor, id types.BallotID) (bool, error) {
	bind := func(stmt *sql.Statement) {
		stmt.BindBytes(1, id.Bytes())
	}
	count, err := db.Exec("select 1 from ballots where id = ?1;", bind, nil)
	if err != nil {
		return false, fmt.Errorf("has ballot %s: %w", id, err)
	}
	return count > 0, nil
}
// Get ballot with id from database.
//
// The smesher ID comes from the pubkey column, and the ballot is flagged
// malicious when the smesher has a non-empty proof in identities.
//
// Returns an error wrapping sql.ErrNotFound when no row matches. Returns the
// wrapped decode error when the stored ballot cannot be decoded.
func Get(db sql.Executor, id types.BallotID) (rst *types.Ballot, err error) {
	// The decoder callback cannot return an error through Exec, so it is
	// captured separately; reusing err here would be overwritten by Exec's
	// own (nil) result and the failure silently lost.
	var decodeErr error
	rows, err := db.Exec(`select pubkey, ballot, length(identities.proof)
from ballots left join identities using(pubkey)
where id = ?1;`,
		func(stmt *sql.Statement) {
			stmt.BindBytes(1, id.Bytes())
		}, func(stmt *sql.Statement) bool {
			rst, decodeErr = decodeBallot(id,
				stmt.ColumnReader(0),
				stmt.ColumnReader(1),
				stmt.ColumnInt(2) > 0,
			)
			return true
		})
	if err != nil {
		return nil, fmt.Errorf("get %s: %w", id, err)
	}
	if rows == 0 {
		return nil, fmt.Errorf("%w ballot %s", sql.ErrNotFound, id)
	}
	if decodeErr != nil {
		return nil, fmt.Errorf("get %s: %w", id, decodeErr)
	}
	return rst, nil
}
// Layer returns full body ballot for layer.
//
// Each ballot's smesher ID comes from the pubkey column, and a ballot is
// flagged malicious when its smesher has a non-empty proof in identities.
// If any ballot fails to decode, iteration stops and that error is returned
// instead of a partial result.
func Layer(db sql.Executor, lid types.LayerID) (rst []*types.Ballot, err error) {
	// Captured separately: Exec's return value would otherwise overwrite it.
	var decodeErr error
	if _, err = db.Exec(`select id, pubkey, ballot, length(identities.proof)
from ballots left join identities using(pubkey)
where layer = ?1;`, func(stmt *sql.Statement) {
		stmt.BindInt64(1, int64(lid))
	}, func(stmt *sql.Statement) bool {
		id := types.BallotID{}
		stmt.ColumnBytes(0, id[:])
		var ballot *types.Ballot
		ballot, decodeErr = decodeBallot(id,
			stmt.ColumnReader(1),
			stmt.ColumnReader(2),
			stmt.ColumnInt(3) > 0,
		)
		if decodeErr != nil {
			return false
		}
		rst = append(rst, ballot)
		return true
	}); err != nil {
		return nil, fmt.Errorf("ballots for layer %s: %w", lid, err)
	}
	if decodeErr != nil {
		return nil, fmt.Errorf("ballots for layer %s: %w", lid, decodeErr)
	}
	return rst, nil
}
// IDsInLayer returns the ids of all ballots stored for layer lid.
// The slice is nil when the layer has no ballots.
func IDsInLayer(db sql.Executor, lid types.LayerID) (rst []types.BallotID, err error) {
	bind := func(stmt *sql.Statement) {
		stmt.BindInt64(1, int64(lid.Uint32()))
	}
	collect := func(stmt *sql.Statement) bool {
		var id types.BallotID
		stmt.ColumnBytes(0, id[:])
		rst = append(rst, id)
		return true
	}
	if _, execErr := db.Exec("select id from ballots where layer = ?1;", bind, collect); execErr != nil {
		return nil, fmt.Errorf("ballots for layer %s: %w", lid, execErr)
	}
	return rst, nil
}
// CountByPubkeyLayer counts number of ballots in the layer for the nodeID.
// The count is the number of rows matching both layer and pubkey.
func CountByPubkeyLayer(db sql.Executor, lid types.LayerID, nodeID types.NodeID) (int, error) {
	const query = "select 1 from ballots where layer = ?1 and pubkey = ?2;"
	count, err := db.Exec(query, func(stmt *sql.Statement) {
		stmt.BindInt64(1, int64(lid))
		stmt.BindBytes(2, nodeID.Bytes())
	}, nil)
	if err == nil {
		return count, nil
	}
	return 0, fmt.Errorf("counting layer %s: %w", lid, err)
}
// LayerBallotByNodeID returns any ballot by the specified NodeID in a given layer.
//
// At most one row is read (limit 1). The returned ballot's smesher ID is set
// to nodeID; this query does not populate the malicious flag.
//
// Returns sql.ErrNotFound when the node has no ballot in the layer. Returns a
// nil ballot together with the error when decoding fails. io.EOF from the
// decoder is treated as success and is not reported.
func LayerBallotByNodeID(db sql.Executor, lid types.LayerID, nodeID types.NodeID) (*types.Ballot, error) {
	var (
		ballot    types.Ballot
		decodeErr error
	)
	enc := func(stmt *sql.Statement) {
		stmt.BindInt64(1, int64(lid))
		stmt.BindBytes(2, nodeID.Bytes())
	}
	dec := func(stmt *sql.Statement) bool {
		var bid types.BallotID
		stmt.ColumnBytes(0, bid[:])
		n, err := codec.DecodeFrom(stmt.ColumnReader(1), &ballot)
		if err != nil && err != io.EOF {
			decodeErr = fmt.Errorf("ballot data layer %v, nodeID %v: %w", lid, nodeID, err)
			return false
		}
		if err == nil && n == 0 {
			decodeErr = fmt.Errorf("ballot data missing layer %v, nodeID %v", lid, nodeID)
			return false
		}
		ballot.SetID(bid)
		ballot.SmesherID = nodeID
		return true
	}
	rows, err := db.Exec(`
select id, ballot from ballots
where layer = ?1 and pubkey = ?2
limit 1;`, enc, dec)
	if err != nil {
		return nil, fmt.Errorf("same layer ballot %v: %w", lid, err)
	}
	if rows == 0 {
		return nil, sql.ErrNotFound
	}
	if decodeErr != nil {
		return nil, decodeErr
	}
	return &ballot, nil
}
// RefBallot gets a ref ballot for a layer and a nodeID.
//
// It scans the node's ballots in the layers of epoch in ascending layer order
// and stops at the first one carrying EpochData (only the ref ballot has it).
// The returned ballot's smesher ID is set to nodeID, and it is flagged
// malicious when the node has a non-empty proof in identities.
//
// Returns an error wrapping sql.ErrNotFound when the node has no ballots in the
// epoch. Returns the decode error, not a partially decoded ballot, when a
// stored ballot cannot be decoded.
func RefBallot(db sql.Executor, epoch types.EpochID, nodeID types.NodeID) (*types.Ballot, error) {
	firstLayer := epoch.FirstLayer()
	lastLayer := firstLayer.Add(types.GetLayersPerEpoch()).Sub(1)
	var (
		bid    types.BallotID
		ballot types.Ballot
		// Captured separately: assigning Exec's result to the same variable
		// would overwrite a decode failure with nil.
		decodeErr error
	)
	dec := func(stmt *sql.Statement) bool {
		stmt.ColumnBytes(0, bid[:])
		n, err := codec.DecodeFrom(stmt.ColumnReader(1), &ballot)
		if err != nil && err != io.EOF {
			decodeErr = fmt.Errorf("ref ballot %s/%d: %w", nodeID.ShortString(), epoch, err)
			return false
		}
		if err == nil && n == 0 {
			decodeErr = fmt.Errorf("ref ballot missing data %s/%d", nodeID.ShortString(), epoch)
			return false
		}
		ballot.SetID(bid)
		ballot.SmesherID = nodeID
		if stmt.ColumnInt(2) > 0 {
			ballot.SetMalicious()
		}
		// only ref ballot has valid EpochData: stop scanning once it is found
		return ballot.EpochData == nil
	}
	rows, err := db.Exec(`
select id, ballot, length(identities.proof) from ballots
left join identities using(pubkey)
where layer between ?1 and ?2 and pubkey = ?3
order by layer asc;`,
		func(stmt *sql.Statement) {
			stmt.BindInt64(1, int64(firstLayer))
			stmt.BindInt64(2, int64(lastLayer))
			stmt.BindBytes(3, nodeID.Bytes())
		}, dec)
	if err != nil {
		return nil, fmt.Errorf("ref ballot %s/%d: %w", nodeID.ShortString(), epoch, err)
	}
	if rows == 0 {
		return nil, fmt.Errorf("%w ref ballot %s/%d", sql.ErrNotFound, nodeID.ShortString(), epoch)
	}
	if decodeErr != nil {
		return nil, decodeErr
	}
	// NOTE(review): if none of the node's ballots in the epoch carries
	// EpochData, this returns the last ballot scanned, which is not a ref
	// ballot. Confirm callers tolerate that, or return sql.ErrNotFound instead.
	return &ballot, nil
}
// LatestLayer gets the highest layer with ballots.
//
// The value of max(layer) is read as an int64 and truncated to uint32. With no
// ballots, max(layer) is NULL, which is presumably read as 0 (TODO: confirm
// ColumnInt64 behavior). On error the layer read so far is returned alongside
// the wrapped error.
func LatestLayer(db sql.Executor) (types.LayerID, error) {
	var latest types.LayerID
	read := func(stmt *sql.Statement) bool {
		latest = types.LayerID(uint32(stmt.ColumnInt64(0)))
		return true
	}
	_, err := db.Exec("select max(layer) from ballots;", nil, read)
	if err != nil {
		return latest, fmt.Errorf("latest layer: %w", err)
	}
	return latest, nil
}
// FirstInEpoch returns the first ballot referencing atx in epoch that carries
// EpochData.
//
// Ballots referencing atx in the epoch's layers are scanned in ascending layer
// order, and scanning stops at the first one with EpochData. The smesher ID
// comes from the pubkey column, and the ballot is flagged malicious when the
// smesher has a non-empty proof in identities.
//
// Returns sql.ErrNotFound when no ballot references atx in the epoch. Returns
// the decode error, not a partially decoded ballot, when decoding fails.
func FirstInEpoch(db sql.Executor, atx types.ATXID, epoch types.EpochID) (*types.Ballot, error) {
	var (
		bid    types.BallotID
		ballot types.Ballot
		nodeID types.NodeID
		// Captured separately: assigning Exec's result to the same variable
		// would overwrite a decode failure with nil.
		decodeErr error
	)
	enc := func(stmt *sql.Statement) {
		stmt.BindBytes(1, atx.Bytes())
		stmt.BindInt64(2, int64(epoch.FirstLayer()))
		stmt.BindInt64(3, int64((epoch+1).FirstLayer()-1))
	}
	dec := func(stmt *sql.Statement) bool {
		stmt.ColumnBytes(0, bid[:])
		stmt.ColumnBytes(1, nodeID[:])
		n, err := codec.DecodeFrom(stmt.ColumnReader(2), &ballot)
		if err != nil && err != io.EOF {
			decodeErr = fmt.Errorf("ballot by atx %s: %w", atx, err)
			return false
		}
		if err == nil && n == 0 {
			decodeErr = fmt.Errorf("ballot by atx missing data %s", atx)
			return false
		}
		ballot.SetID(bid)
		ballot.SmesherID = nodeID
		if stmt.ColumnInt(3) > 0 {
			ballot.SetMalicious()
		}
		// only ref ballot has valid EpochData: stop scanning once it is found
		return ballot.EpochData == nil
	}
	rows, err := db.Exec(`
select id, pubkey, ballot, length(identities.proof) from ballots
left join identities using(pubkey)
where atx = ?1 and layer between ?2 and ?3
order by layer asc;`, enc, dec)
	if err != nil {
		return nil, fmt.Errorf("ballot by atx %s: %w", atx, err)
	}
	if rows == 0 {
		return nil, sql.ErrNotFound
	}
	if decodeErr != nil {
		return nil, decodeErr
	}
	// NOTE(review): if no scanned ballot carries EpochData, this returns the
	// last one scanned. Confirm that callers expect this fallback.
	return &ballot, nil
}
// AllFirstInEpoch returns, for every smesher with ballots in epoch, the ballot
// from that smesher's lowest layer within the epoch.
//
// It relies on SQLite's bare-column semantics for min(): with "group by
// pubkey", the id, pubkey and ballot columns come from the row holding
// min(layer). Each ballot gets its ID from the id column and its smesher ID
// from the pubkey column. The malicious flag is not populated.
//
// A decode failure, other than io.EOF, stops iteration and is returned.
func AllFirstInEpoch(db sql.Executor, epoch types.EpochID) ([]*types.Ballot, error) {
	var (
		decodeErr error
		rst       []*types.Ballot
	)
	enc := func(stmt *sql.Statement) {
		stmt.BindInt64(1, int64(epoch.FirstLayer()))
		stmt.BindInt64(2, int64((epoch+1).FirstLayer()-1))
	}
	dec := func(stmt *sql.Statement) bool {
		var (
			bid    types.BallotID
			nodeID types.NodeID
			ballot types.Ballot
		)
		stmt.ColumnBytes(0, bid[:])
		stmt.ColumnBytes(1, nodeID[:])
		if _, err := codec.DecodeFrom(stmt.ColumnReader(2), &ballot); err != nil && err != io.EOF {
			decodeErr = fmt.Errorf("decode ballot: %w", err)
			return false
		}
		ballot.SetID(bid)
		// set after decoding so the pubkey column is authoritative, as in the
		// other readers of this table
		ballot.SmesherID = nodeID
		rst = append(rst, &ballot)
		return true
	}
	if _, err := db.Exec(`
select id, pubkey, ballot, min(layer) from ballots where layer between ?1 and ?2
group by pubkey;`, enc, dec); err != nil {
		return nil, fmt.Errorf("query first ballots in epoch %d: %w", epoch, err)
	}
	if decodeErr != nil {
		return nil, decodeErr
	}
	return rst, nil
}