-
Notifications
You must be signed in to change notification settings - Fork 242
/
database.go
191 lines (175 loc) · 4.35 KB
/
database.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
package storenodes
import (
"bytes"
"database/sql"
"fmt"
"time"
"github.com/status-im/status-go/eth-node/types"
)
// Database gives access to the community_storenodes table through an
// underlying *sql.DB handle.
type Database struct {
	db *sql.DB
}

// NewDB wraps db in a Database. It performs no queries and no validation of
// db; a nil handle is stored as-is.
func NewDB(db *sql.DB) *Database {
	d := new(Database)
	d.db = db
	return d
}
// syncSave reconciles the storenodes stored for communityID with snode inside
// a single transaction:
//   - a stored, non-removed storenode that is absent from snode is soft-deleted,
//     unless clock is non-zero and the stored row's Clock is >= clock;
//   - a storenode in snode is inserted or replaced, unless a stored non-removed
//     copy exists and the incoming Clock is non-zero and not newer than it.
//
// Every entry of snode must carry a non-empty CommunityID equal to communityID,
// otherwise an error is returned. After the writes, at most one non-removed
// storenode may remain for the community (see TODO below); otherwise an error
// is returned.
//
// The transaction is committed only when the function returns a nil error;
// any error, or a panic, rolls it back. A Commit failure is returned as err.
func (d *Database) syncSave(communityID types.HexBytes, snode []Storenode, clock uint64) (err error) {
	var tx *sql.Tx
	tx, err = d.db.Begin()
	if err != nil {
		return err
	}
	defer func() {
		// A panic leaves the named err at nil; without this check the commit
		// below would persist a partially applied sync while the panic unwinds.
		if p := recover(); p != nil {
			_ = tx.Rollback()
			panic(p)
		}
		if err == nil {
			err = tx.Commit()
			return
		}
		// The original error is more useful to the caller than a rollback
		// failure, so the rollback result is deliberately discarded.
		_ = tx.Rollback()
	}()

	now := time.Now().Unix()

	// Only non-removed rows (removed = 0) are returned here.
	dbNodes, err := d.getByCommunityID(communityID, tx)
	if err != nil {
		return fmt.Errorf("getting storenodes by community id: %w", err)
	}

	// Soft-delete db nodes that are not in the provided list, unless the stored
	// row is at least as recent as the incoming clock.
	for _, dbN := range dbNodes {
		if find(dbN, snode) != nil {
			continue
		}
		if clock != 0 && dbN.Clock >= clock {
			continue
		}
		if err := d.softDelete(communityID, dbN.StorenodeID, now, tx); err != nil {
			return fmt.Errorf("soft deleting existing storenodes: %w", err)
		}
	}

	// Insert or update the nodes in the provided list.
	for _, n := range snode {
		// defensively validate the communityID
		if len(n.CommunityID) == 0 || !bytes.Equal(communityID, n.CommunityID) {
			return fmt.Errorf("communityID mismatch %v != %v", communityID, n.CommunityID)
		}
		// dbNodes holds only non-removed rows, so a previously soft-deleted
		// node is always upserted regardless of its clock.
		dbN := find(n, dbNodes)
		if dbN != nil && n.Clock != 0 && dbN.Clock >= n.Clock {
			continue
		}
		if err := d.upsert(n, tx); err != nil {
			return fmt.Errorf("upserting storenodes: %w", err)
		}
	}

	// TODO for now only allow one storenode per community
	count, err := d.countByCommunity(communityID, tx)
	if err != nil {
		return fmt.Errorf("counting storenodes: %w", err)
	}
	if count > 1 {
		return fmt.Errorf("only one storenode per community is allowed")
	}
	return nil
}
// getAll loads every storenode row that is not soft-deleted (removed = 0),
// across all communities. The result is nil when no such rows exist.
func (d *Database) getAll() ([]Storenode, error) {
	const selectActive = `
SELECT community_id, storenode_id, name, address, fleet, version, clock, removed, deleted_at
FROM community_storenodes
WHERE removed = 0
`
	result, err := d.db.Query(selectActive)
	if err != nil {
		return nil, err
	}
	defer result.Close()

	return toStorenodes(result)
}
// getByCommunityID loads the non-removed storenodes (removed = 0) of
// communityID. The variadic tx acts as an optional argument: when at least one
// transaction is passed, the first one runs the query; otherwise the
// underlying *sql.DB does.
func (d *Database) getByCommunityID(communityID types.HexBytes, tx ...*sql.Tx) ([]Storenode, error) {
	const selectByCommunity = `
SELECT community_id, storenode_id, name, address, fleet, version, clock, removed, deleted_at
FROM community_storenodes
WHERE community_id = ? AND removed = 0
`
	var (
		result *sql.Rows
		err    error
	)
	switch {
	case len(tx) > 0:
		result, err = tx[0].Query(selectByCommunity, communityID)
	default:
		result, err = d.db.Query(selectByCommunity, communityID)
	}
	if err != nil {
		return nil, err
	}
	defer result.Close()

	return toStorenodes(result)
}
// softDelete flags the storenode identified by (communityID, storenodeID) as
// removed (removed = 1) and stamps deleted_at with deletedAt, inside tx. The
// row itself is kept in the table.
func (d *Database) softDelete(communityID types.HexBytes, storenodeID string, deletedAt int64, tx *sql.Tx) error {
	const markRemoved = "UPDATE community_storenodes SET removed = 1, deleted_at = ? WHERE community_id = ? AND storenode_id = ?"
	_, err := tx.Exec(markRemoved, deletedAt, communityID, storenodeID)
	return err
}
// upsert writes every field of n into community_storenodes inside tx using
// SQLite's INSERT OR REPLACE, so a row conflicting on a uniqueness constraint
// is replaced. NOTE(review): presumably the key is (community_id,
// storenode_id); the schema is not visible here, so confirm.
// Removed and DeletedAt are stored exactly as they are set on n.
func (d *Database) upsert(n Storenode, tx *sql.Tx) error {
	const insertOrReplace = `INSERT OR REPLACE INTO community_storenodes(
community_id,
storenode_id,
name,
address,
fleet,
version,
clock,
removed,
deleted_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`

	// Order must match the column list above.
	args := []interface{}{
		n.CommunityID,
		n.StorenodeID,
		n.Name,
		n.Address,
		n.Fleet,
		n.Version,
		n.Clock,
		n.Removed,
		n.DeletedAt,
	}
	_, err := tx.Exec(insertOrReplace, args...)
	return err
}
// countByCommunity returns the number of non-removed storenodes (removed = 0)
// recorded for communityID, queried within tx. On error it returns 0.
func (d *Database) countByCommunity(communityID types.HexBytes, tx *sql.Tx) (int, error) {
	row := tx.QueryRow(`SELECT COUNT(*) FROM community_storenodes WHERE community_id = ? AND removed = 0`, communityID)

	var total int
	if err := row.Scan(&total); err != nil {
		return 0, err
	}
	return total, nil
}
// toStorenodes scans every remaining row of rows into a Storenode. The row's
// columns must be, in order: community_id, storenode_id, name, address, fleet,
// version, clock, removed, deleted_at (the SELECT lists used in this file).
//
// It does not close rows; callers are expected to do so. It returns nil (not
// an empty slice) when there are no rows. It returns an error if any row
// fails to scan or if iteration stopped because of an error rather than
// because the result set was exhausted.
func toStorenodes(rows *sql.Rows) ([]Storenode, error) {
	var result []Storenode
	for rows.Next() {
		var m Storenode
		if err := rows.Scan(
			&m.CommunityID,
			&m.StorenodeID,
			&m.Name,
			&m.Address,
			&m.Fleet,
			&m.Version,
			&m.Clock,
			&m.Removed,
			&m.DeletedAt,
		); err != nil {
			return nil, err
		}
		result = append(result, m)
	}
	// rows.Next returns false both at the end of the result set and on an
	// iteration error; only rows.Err tells the two apart. Without this check a
	// mid-iteration failure would silently yield a truncated result.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
// find returns a pointer to the first element of nodes whose StorenodeID and
// CommunityID both equal those of n, or nil when there is no match. The
// returned pointer aliases the backing array of nodes, not a copy.
func find(n Storenode, nodes []Storenode) *Storenode {
	for i := range nodes {
		candidate := &nodes[i]
		if candidate.StorenodeID != n.StorenodeID || !bytes.Equal(candidate.CommunityID, n.CommunityID) {
			continue
		}
		return candidate
	}
	return nil
}