-
Notifications
You must be signed in to change notification settings - Fork 1
/
database.go
181 lines (164 loc) · 4.33 KB
/
database.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
package main
import (
"database/sql"
"fmt"
"log"
_ "github.com/mattn/go-sqlite3"
)
var (
// dbSetup holds the SQL statements that initDB executes, one after another,
// against a freshly opened database. Every statement is guarded
// (IF NOT EXISTS / WHERE NOT EXISTS) so that running the list against a
// database that was already set up does not create duplicates.
//
// First statement: the state table, holding current_term and voted_for.
dbSetup = []string{`
CREATE TABLE IF NOT EXISTS state (
current_term integer not null,
voted_for text
);
`,
// Seed the state table with a single row (term 0, no vote), but only when
// the table contains no rows yet.
`
INSERT INTO state
SELECT 0, NULL WHERE NOT EXISTS (SELECT 1 FROM state);
`,
// The log table: one row per log entry, keyed by id, carrying the entry's
// term and the fields of the client request it records.
`
CREATE TABLE IF NOT EXISTS log (
id integer primary key,
term integer not null,
client_id text not null,
client_serial integer not null,
operation text not null,
key text not null,
value text not null
);
`,
}
)
// initDB opens the sqlite3 database named "raft-<id>.db" and runs every
// statement in dbSetup against it, in order.
//
// On success it returns the open handle. If opening fails, or any setup
// statement fails, the error is logged and returned together with a nil
// handle; in the setup-failure case the handle is closed first.
func initDB(id string) (*sql.DB, error) {
	path := fmt.Sprintf("raft-%s.db", id)

	handle, openErr := sql.Open("sqlite3", path)
	if openErr != nil {
		log.Printf("error opening db %q: %v", path, openErr)
		return nil, openErr
	}

	for i := range dbSetup {
		if _, execErr := handle.Exec(dbSetup[i]); execErr != nil {
			log.Printf("db setup error: %v", execErr)
			handle.Close()
			return nil, execErr
		}
	}

	return handle, nil
}
// getPersistent loads the persistent state of a server from the database.
func (s *Server) getPersistent(db *sql.DB) error {
var currentTerm int
var votedFor *string
err := db.QueryRow(`SELECT current_term, voted_for FROM state`).Scan(¤tTerm, &votedFor)
if err != nil {
log.Printf("db error getting state: %v", err)
return err
}
s.CurrentTerm = currentTerm
if votedFor == nil {
s.VotedFor = ""
} else {
s.VotedFor = *votedFor
}
return nil
}
// putPersistent saves the persistent state of a server to the database.
//
// Inside tx it overwrites current_term and voted_for in the state table with
// s.CurrentTerm and s.VotedFor. An empty s.VotedFor is bound as a nil
// *string rather than as an empty string. A failed update is logged and its
// error returned.
func (s *Server) putPersistent(tx *sql.Tx) error {
	// Stay nil unless a vote has actually been recorded.
	var vote *string
	if len(s.VotedFor) != 0 {
		vote = &s.VotedFor
	}

	if _, err := tx.Exec(`UPDATE state SET current_term = ?, voted_for = ?`, s.CurrentTerm, vote); err != nil {
		log.Printf("db error putting state: %v", err)
		return err
	}
	return nil
}
// verifyLogAt confirms the existence of a log entry with the given index and term.
//
// The pair index == -1, term == -1 stands for a leader whose log is empty; it
// is reported as present without touching the database. Otherwise the log
// table is searched for a row with that id and term: true is returned when
// one exists, false when none does. Any other query error is logged and
// returned alongside false.
func verifyLogAt(tx *sql.Tx, index, term int) (bool, error) {
	if index == -1 && term == -1 {
		return true, nil
	}

	// Only the presence of a row matters; the scanned value itself is unused.
	var hit bool
	row := tx.QueryRow(`SELECT 1 FROM log WHERE id = ? AND term = ?`, index, term)

	switch err := row.Scan(&hit); {
	case err == nil:
		return true, nil
	case err == sql.ErrNoRows:
		return false, nil
	default:
		log.Printf("db error checking log entry %d: %v", index, err)
		return false, err
	}
}
// getLastLogEntry returns the index and term of the last entry in the log.
//
// The last entry is the row of the log table with the greatest id. When the
// log table is empty it returns (-1, -1, nil). Any other query or scan error
// is logged and returned with index and term both 0.
func getLastLogEntry(tx *sql.Tx) (index, term int, err error) {
	// The log table's key column is named id (there is no "index" column, and
	// INDEX is an SQL keyword). Sorting by id in descending order with LIMIT 1
	// selects the newest entry rather than the oldest.
	err = tx.QueryRow(`SELECT id, term FROM log ORDER BY id DESC LIMIT 1`).Scan(&index, &term)
	switch {
	case err == sql.ErrNoRows:
		return -1, -1, nil
	case err != nil:
		log.Printf("db error checking last log entry index and term: %v", err)
		return 0, 0, err
	}
	return index, term, nil
}
// getLogEntries retrieves log entries in the range [from, to).
//
// It selects the rows of the log table whose id satisfies from <= id < to,
// in ascending id order, and returns one *LogEntry per row. When no rows
// match, the returned slice is nil. A query, scan, or iteration error is
// logged and returned together with a nil slice.
func getLogEntries(tx *sql.Tx, from, to int) ([]*LogEntry, error) {
	rows, err := tx.Query(`SELECT id, term, client_id, client_serial, operation, key, value `+
		`FROM log WHERE id >= ? AND id < ? ORDER BY id ASC`, from, to)
	if err != nil {
		log.Printf("db error loading log entries [%d,%d): %v", from, to, err)
		return nil, err
	}
	// Release the result set on every return path, including the early return
	// on a scan error, which would otherwise leave it open.
	defer rows.Close()

	var out []*LogEntry
	for rows.Next() {
		l := new(LogEntry)
		err := rows.Scan(
			&l.ID,
			&l.Term,
			&l.ClientRequest.ClientID,
			&l.ClientRequest.ClientSerial,
			&l.ClientRequest.Operation,
			&l.ClientRequest.Key,
			&l.ClientRequest.Value)
		if err != nil {
			log.Printf("db error scanning log entry: %v", err)
			return nil, err
		}
		// Only fully scanned entries are collected.
		out = append(out, l)
	}
	// rows.Next returning false may mean an iteration error rather than the
	// end of the result set.
	if err := rows.Err(); err != nil {
		log.Printf("db error reading log entries: %v", err)
		return nil, err
	}
	return out, nil
}
// saveLogEntries saves a slice of log entries, which must be in order by index.
//
// An empty slice is a no-op. Otherwise every stored row whose id is greater
// than or equal to entries[0].ID is deleted, and then each supplied entry is
// inserted in slice order. The first failing statement is logged and its
// error returned; statements after it are not executed.
func saveLogEntries(tx *sql.Tx, entries []*LogEntry) error {
	if len(entries) == 0 {
		return nil
	}

	// Drop the stored suffix that the incoming entries replace.
	firstID := entries[0].ID
	if _, err := tx.Exec(`DELETE FROM log WHERE id >= ?`, firstID); err != nil {
		log.Printf("db error truncating log: %v", err)
		return err
	}

	const insertEntry = `INSERT INTO log (id, term, client_id, client_serial, operation, key, value) ` +
		`VALUES (?,?,?,?,?,?,?)`

	for i := 0; i < len(entries); i++ {
		e := entries[i]
		if _, err := tx.Exec(insertEntry,
			e.ID,
			e.Term,
			e.ClientRequest.ClientID,
			e.ClientRequest.ClientSerial,
			e.ClientRequest.Operation,
			e.ClientRequest.Key,
			e.ClientRequest.Value,
		); err != nil {
			log.Printf("db error inserting log entry: %v", err)
			return err
		}
	}
	return nil
}