-
Notifications
You must be signed in to change notification settings - Fork 211
/
certs.go
149 lines (137 loc) · 4.53 KB
/
certs.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
package certificates
import (
"fmt"
"github.com/spacemeshos/go-spacemesh/codec"
"github.com/spacemeshos/go-spacemesh/common/types"
"github.com/spacemeshos/go-spacemesh/sql"
)
// SetHareOutput records bid as the hare output for layer lid and marks it valid.
// No certificate is stored. If the insert conflicts with an existing
// certificates row, the existing row is left unchanged.
func SetHareOutput(db sql.Executor, lid types.LayerID, bid types.BlockID) error {
	const markValid = true
	return setHareOutput(db, lid, bid, markValid)
}
// SetHareOutputInvalid records bid as the hare output for layer lid but marks
// it invalid. No certificate is stored. If the insert conflicts with an
// existing certificates row, the existing row is left unchanged.
func SetHareOutputInvalid(db sql.Executor, lid types.LayerID, bid types.BlockID) error {
	const markValid = false
	return setHareOutput(db, lid, bid, markValid)
}
// setHareOutput inserts a certificates row for (lid, bid) carrying the given
// validity flag, without setting the cert column.
//
// The insert uses "on conflict do nothing", so a row that conflicts with it
// keeps its current values. Which columns form the conflict target is defined
// by the table schema, presumably (layer, block); verify against the migration.
func setHareOutput(db sql.Executor, lid types.LayerID, bid types.BlockID, valid bool) error {
	const query = `insert into certificates (layer, block, valid) values (?1, ?2, ?3)
on conflict do nothing;`
	bind := func(stmt *sql.Statement) {
		stmt.BindInt64(1, int64(lid))
		stmt.BindBytes(2, bid[:])
		stmt.BindBool(3, valid)
	}
	_, err := db.Exec(query, bind, nil)
	if err != nil {
		return fmt.Errorf("add wo cert %s: %w", lid, err)
	}
	return nil
}
// GetHareOutput returns the block marked valid as hare output for layer lid.
//
// Results:
//   - exactly one valid block: that block, nil error;
//   - no valid block: types.EmptyBlockID and an error wrapping sql.ErrNotFound;
//   - more than one valid block: types.EmptyBlockID and a nil error;
//   - query failure: types.EmptyBlockID and the wrapped database error.
func GetHareOutput(db sql.Executor, lid types.LayerID) (types.BlockID, error) {
	var bid types.BlockID
	n, err := db.Exec("select block from certificates where layer = ?1 and valid = 1;",
		func(stmt *sql.Statement) {
			stmt.BindInt64(1, int64(lid))
		},
		func(stmt *sql.Statement) bool {
			// Every row overwrites bid; only the single-row case returns it.
			stmt.ColumnBytes(0, bid[:])
			return true
		})
	switch {
	case err != nil:
		return types.EmptyBlockID, fmt.Errorf("get certs %s: %w", lid, err)
	case n == 0:
		return types.EmptyBlockID, fmt.Errorf("get certs %s: %w", lid, sql.ErrNotFound)
	case n > 1:
		return types.EmptyBlockID, nil
	}
	return bid, nil
}
// FirstInEpoch returns the block of the lowest layer in epoch that is marked
// valid and has a certificate stored (cert is not null).
//
// The searched layers run from epoch.FirstLayer() through
// (epoch+1).FirstLayer()-1, inclusive. If no such block exists, the returned
// error wraps sql.ErrNotFound.
func FirstInEpoch(db sql.Executor, epoch types.EpochID) (types.BlockID, error) {
	first := epoch.FirstLayer()
	last := (epoch + 1).FirstLayer() - 1
	var bid types.BlockID
	n, err := db.Exec(`
select block from certificates where layer between ?1 and ?2 and valid = 1 and cert is not null
order by layer asc limit 1;`, func(stmt *sql.Statement) {
		stmt.BindInt64(1, int64(first))
		stmt.BindInt64(2, int64(last))
	}, func(stmt *sql.Statement) bool {
		stmt.ColumnBytes(0, bid[:])
		return true
	})
	if err != nil {
		return types.EmptyBlockID, fmt.Errorf("FirstInEpoch %s: %w", epoch, err)
	}
	if n == 0 {
		return types.EmptyBlockID, fmt.Errorf("FirstInEpoch %s: %w", epoch, sql.ErrNotFound)
	}
	return bid, nil
}
// Add stores cert as the certificate for block cert.BlockID in layer lid and
// marks that block valid.
//
// If the insert conflicts with an existing row, that row's cert is replaced and
// its valid flag is set to 1. The conflict target comes from the table schema,
// presumably (layer, block); verify against the migration.
func Add(db sql.Executor, lid types.LayerID, cert *types.Certificate) error {
	data, err := codec.Encode(cert)
	if err != nil {
		// Same "<op> <layer>: <cause>" shape as the other wrapped errors here.
		return fmt.Errorf("encode cert %s: %w", lid, err)
	}
	if _, err := db.Exec(`insert into certificates (layer, block, cert, valid) values (?1, ?2, ?3, 1)
on conflict do update set cert = ?3, valid = 1;`,
		func(stmt *sql.Statement) {
			stmt.BindInt64(1, int64(lid))
			stmt.BindBytes(2, cert.BlockID[:])
			stmt.BindBytes(3, data)
		}, nil); err != nil {
		return fmt.Errorf("add cert %s: %w", lid, err)
	}
	return nil
}
// CertValidity describes one certificates row for a layer, as returned by Get.
type CertValidity struct {
	// Block is the block the row refers to.
	Block types.BlockID
	// Cert is the decoded certificate, or nil when the row stores none.
	Cert *types.Certificate
	// Valid reports whether the row's valid column is greater than zero.
	Valid bool
}
// Get returns every certificates row recorded for layer lid, ordered by the
// length of the stored certificate, descending.
//
// Rows whose cert column is empty have a nil Cert. The function returns an
// error wrapping sql.ErrNotFound when the layer has no rows. If a stored
// certificate fails to decode, it returns a wrapped decode error instead of a
// partial result.
func Get(db sql.Executor, lid types.LayerID) ([]CertValidity, error) {
	var (
		result    []CertValidity
		decodeErr error
	)
	rows, err := db.Exec("select block, cert, valid from certificates where layer = ?1 order by length(cert) desc;", func(stmt *sql.Statement) {
		stmt.BindInt64(1, int64(lid))
	}, func(stmt *sql.Statement) bool {
		var cv CertValidity
		stmt.ColumnBytes(0, cv.Block[:])
		if n := stmt.ColumnLen(1); n > 0 {
			data := make([]byte, n)
			stmt.ColumnBytes(1, data)
			var cert types.Certificate
			if err := codec.Decode(data, &cert); err != nil {
				// Returning false only stops iteration, and Exec does not report
				// why. Record the cause so the caller does not get a silently
				// truncated (or empty) result.
				decodeErr = fmt.Errorf("decode cert for block %v: %w", cv.Block, err)
				return false
			}
			cv.Cert = &cert
		}
		cv.Valid = stmt.ColumnInt(2) > 0
		result = append(result, cv)
		return true
	})
	switch {
	case err != nil:
		return nil, fmt.Errorf("get certs %s: %w", lid, err)
	case decodeErr != nil:
		return nil, fmt.Errorf("get certs %s: %w", lid, decodeErr)
	case rows == 0:
		return nil, fmt.Errorf("get certs %s: %w", lid, sql.ErrNotFound)
	}
	return result, nil
}
// SetValid sets the valid flag to 1 on the certificates row for block bid in
// layer lid.
//
// The number of affected rows is not checked, so a missing row is not reported
// as an error.
func SetValid(db sql.Executor, lid types.LayerID, bid types.BlockID) error {
	if _, err := db.Exec(`update certificates set valid = 1 where layer = ?1 and block = ?2;`,
		func(stmt *sql.Statement) {
			stmt.BindInt64(1, int64(lid))
			stmt.BindBytes(2, bid[:])
		}, nil); err != nil {
		// The message previously said "invalidate" (copy-paste from SetInvalid),
		// which misreports what failed.
		return fmt.Errorf("validate %s: %w", lid, err)
	}
	return nil
}
// SetInvalid clears the valid flag (sets it to 0) on the certificates row for
// block bid in layer lid.
//
// The number of affected rows is not checked, so a missing row is not reported
// as an error.
func SetInvalid(db sql.Executor, lid types.LayerID, bid types.BlockID) error {
	_, err := db.Exec(`update certificates set valid = 0 where layer = ?1 and block = ?2;`,
		func(stmt *sql.Statement) {
			stmt.BindInt64(1, int64(lid))
			stmt.BindBytes(2, bid[:])
		}, nil)
	if err == nil {
		return nil
	}
	return fmt.Errorf("invalidate %s: %w", lid, err)
}