-
Notifications
You must be signed in to change notification settings - Fork 390
/
table.go
294 lines (249 loc) · 8.94 KB
/
table.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
// Copyright (C) 2020 Storj Labs, Inc.
// See LICENSE for copying information.
package usedserials
import (
"encoding/binary"
"math/rand"
"sort"
"sync"
"time"
"github.com/spacemonkeygo/monkit/v3"
"github.com/zeebo/errs"
"storj.io/common/memory"
"storj.io/common/storj"
)
var (
// ErrSerials defines the usedserials store error class.
// deleteRandomSerial returns an error of this class when it cannot find anything to evict.
ErrSerials = errs.Class("usedserials")
// ErrSerialAlreadyExists defines an error class for duplicate usedserials.
// insertPartial/insertSerial (and therefore Add) return it when the serial is already stored in its bucket.
ErrSerialAlreadyExists = errs.Class("used serial already exists in store")
// mon is this package's monkit scope; deleteRandomSerial marks the "delete_random_serial" meter on it.
mon = monkit.Package()
)
// Per-serial sizes that Add charges to, and DeleteExpired/deleteRandomSerial refund from, Table.memoryUsed.
const (
// PartialSize is the size of a partial serial number (the byte length of Partial).
PartialSize = memory.Size(len(Partial{}))
// FullSize is the size of a full serial number (the byte length of storj.SerialNumber).
FullSize = memory.Size(len(storj.SerialNumber{}))
)
// Partial holds the trailing 8 bytes of a serial number. A serial is stored in this compact
// form when its leading 8 bytes equal its expiration time (see tryTruncate), so only the
// trailing half needs to be remembered.
type Partial [8]byte

// Less reports whether a sorts before b, treating both as big-endian unsigned 64-bit integers.
func (a Partial) Less(b Partial) bool {
	left := binary.BigEndian.Uint64(a[:])
	right := binary.BigEndian.Uint64(b[:])
	return left < right
}
// Full is a copy of the SerialNumber type. It is necessary so we can define a Less function on it.
type Full storj.SerialNumber

// Less returns true if full serial a sorts before full serial b.
//
// Every byte takes part in the comparison, in lexicographic order (equivalently, both values
// are compared as big-endian unsigned integers of their full length). Comparing only a prefix
// would make serials that differ only in their trailing bytes neither less nor greater than
// each other, which breaks sorting and binary search.
func (a Full) Less(b Full) bool {
	for i := range a {
		if a[i] != b[i] {
			return a[i] < b[i]
		}
	}
	return false
}
// serialsList is a structure that contains a list of partial serials and a list of full serials.
//
// For serials whose first 8 bytes are the big-endian Unix expiration time (see tryTruncate),
// it uses partialSerials. It uses fullSerials otherwise.
//
// Both slices are kept in ascending order: insertPartial/insertSerial insert at the sorted
// position, and deleteRandomSerial removes by shifting later elements down, which preserves order.
type serialsList struct {
// partialSerials holds the trailing 8 bytes of truncatable serials, ordered by Partial.Less.
partialSerials []Partial
// fullSerials holds serials that could not be truncated, ordered by storj.SerialNumber.Less.
fullSerials []storj.SerialNumber
}
// Table is an in-memory store for serial numbers.
//
// Serials are grouped first by satellite and then by expiration hour. The exported methods
// (Add, DeleteExpired, Exists, Count) all acquire mu themselves before touching the fields below.
type Table struct {
mu sync.Mutex
// key 1: satellite ID, key 2: expiration hour (in unix time), value: a list of serial numbers
// The hour key is ceilExpirationHour(expiration), i.e. the expiration rounded up to a whole hour.
serials map[storj.NodeID]map[int64]serialsList
// maxMemory is the limit Add enforces by evicting serials while memoryUsed exceeds it.
maxMemory memory.Size
// memoryUsed is the running sum of PartialSize/FullSize over stored serials; it counts only
// serial payloads, not map or slice overhead.
memoryUsed memory.Size
}
// NewTable returns an empty usedserials in-memory store whose accounted serial size is
// capped at maxMemory. It panics when maxMemory is not positive, since such a table could
// never hold a serial.
func NewTable(maxMemory memory.Size) *Table {
	if maxMemory > 0 {
		return &Table{
			maxMemory: maxMemory,
			serials:   make(map[storj.NodeID]map[int64]serialsList),
		}
	}
	panic("max memory for usedserials store is 0")
}
// Add stores serialNumber for satelliteID in the bucket for its expiration hour (rounded up).
// The compact Partial form is used when tryTruncate allows it; otherwise the full serial is kept.
// It returns an ErrSerialAlreadyExists error if the serial is already present in that bucket.
//
// After a successful insert, serials are evicted through deleteRandomSerial until memoryUsed
// no longer exceeds maxMemory; the serial that was just added may itself be evicted.
func (table *Table) Add(satelliteID storj.NodeID, serialNumber storj.SerialNumber, expiration time.Time) error {
	table.mu.Lock()
	defer table.mu.Unlock()

	hours, found := table.serials[satelliteID]
	if !found {
		hours = make(map[int64]serialsList)
		table.serials[satelliteID] = hours
	}

	hour := ceilExpirationHour(expiration)
	// A missing bucket reads as the zero serialsList. Inserting into an empty list cannot
	// fail, so the bucket only needs to be written back after a successful insert.
	bucket := hours[hour]

	if partial, truncated := tryTruncate(serialNumber, expiration); truncated {
		updated, err := insertPartial(bucket.partialSerials, partial)
		if err != nil {
			return err
		}
		bucket.partialSerials = updated
		table.memoryUsed += PartialSize
	} else {
		updated, err := insertSerial(bucket.fullSerials, serialNumber)
		if err != nil {
			return err
		}
		bucket.fullSerials = updated
		table.memoryUsed += FullSize
	}
	hours[hour] = bucket

	// Evict until the accounted size fits the configured limit again.
	for table.memoryUsed > table.maxMemory {
		if err := table.deleteRandomSerial(); err != nil {
			return err
		}
	}
	return nil
}
// DeleteExpired deletes expired serial numbers if their expiration hour has passed.
//
// A bucket is dropped when its hour key (the expiration rounded up to the hour) is strictly
// before now, so every serial in it expired before now. memoryUsed is reduced by the
// accounted size of everything removed. Satellites left with no buckets are removed from
// the table as well, so the outer map does not keep dead keys indefinitely.
func (table *Table) DeleteExpired(now time.Time) {
	table.mu.Lock()
	defer table.mu.Unlock()

	cutoff := now.Unix()
	partialToDelete := 0
	fullToDelete := 0
	for satelliteID, satMap := range table.serials {
		for expirationHour, list := range satMap {
			if expirationHour < cutoff {
				partialToDelete += len(list.partialSerials)
				fullToDelete += len(list.fullSerials)
				// deleting during range is permitted by the Go spec
				delete(satMap, expirationHour)
			}
		}
		if len(satMap) == 0 {
			delete(table.serials, satelliteID)
		}
	}

	table.memoryUsed -= memory.Size(partialToDelete) * PartialSize
	table.memoryUsed -= memory.Size(fullToDelete) * FullSize
}
// Exists determines whether a serial number exists in the table.
//
// It looks only in the bucket for satelliteID and the rounded-up expiration hour, using the
// same partial/full decision as Add. Each list in a bucket is kept sorted by insertPartial /
// insertSerial (and deleteRandomSerial preserves that order), so the lookup is a binary search
// using the same Less ordering the inserts use.
func (table *Table) Exists(satelliteID storj.NodeID, serialNumber storj.SerialNumber, expiration time.Time) bool {
	table.mu.Lock()
	defer table.mu.Unlock()

	// Indexing a missing satellite yields a nil map, whose lookups return the zero serialsList.
	bucket := table.serials[satelliteID][ceilExpirationHour(expiration)]

	if partial, usePartial := tryTruncate(serialNumber, expiration); usePartial {
		serials := bucket.partialSerials
		// first index whose element is not less than partial
		i := sort.Search(len(serials), func(h int) bool {
			return !serials[h].Less(partial)
		})
		return i < len(serials) && serials[i] == partial
	}

	serials := bucket.fullSerials
	i := sort.Search(len(serials), func(h int) bool {
		return !serials[h].Less(serialNumber)
	})
	return i < len(serials) && serials[i] == serialNumber
}
// Count returns the total number of serials currently stored, summing the partial and full
// lists of every bucket of every satellite.
func (table *Table) Count() int {
	table.mu.Lock()
	defer table.mu.Unlock()

	total := 0
	for _, hours := range table.serials {
		for _, bucket := range hours {
			total += len(bucket.partialSerials) + len(bucket.fullSerials)
		}
	}
	return total
}
// deleteRandomSerial evicts one serial from the table to free memory.
//
// It walks satellites and expiration-hour buckets in Go map iteration order, which the
// language leaves unspecified. In the first non-empty bucket it reaches, it removes a randomly
// chosen element: a partial serial if the bucket has any, otherwise a full one. The choice is
// therefore random within a bucket but not uniform across the whole table. Shifting the tail
// down keeps each list sorted.
//
// Buckets and satellite maps that end up empty are deleted, so they are not rescanned by later
// evictions and do not linger in the maps until expiry.
//
// It expects the mutex to be locked before being called. It returns an error only if no serial
// can be found, meaning memoryUsed accounts for space that no stored serial occupies.
func (table *Table) deleteRandomSerial() error {
	mon.Meter("delete_random_serial").Mark(1) //mon:locked
	for satelliteID, satMap := range table.serials {
		for expirationHour, serialList := range satMap {
			switch {
			case len(serialList.partialSerials) > 0:
				i := rand.Intn(len(serialList.partialSerials))
				// shift all elements after i once, to overwrite i, then drop the last slot
				copy(serialList.partialSerials[i:], serialList.partialSerials[i+1:])
				serialList.partialSerials = serialList.partialSerials[:len(serialList.partialSerials)-1]
				table.memoryUsed -= PartialSize
			case len(serialList.fullSerials) > 0:
				i := rand.Intn(len(serialList.fullSerials))
				// shift all elements after i once, to overwrite i, then drop the last slot
				copy(serialList.fullSerials[i:], serialList.fullSerials[i+1:])
				serialList.fullSerials = serialList.fullSerials[:len(serialList.fullSerials)-1]
				table.memoryUsed -= FullSize
			default:
				// empty bucket: nothing to evict here; drop it so it is not scanned again
				delete(satMap, expirationHour)
				continue
			}

			if len(serialList.partialSerials) == 0 && len(serialList.fullSerials) == 0 {
				delete(satMap, expirationHour)
			} else {
				satMap[expirationHour] = serialList
			}
			if len(satMap) == 0 {
				delete(table.serials, satelliteID)
			}
			return nil
		}
		// every bucket of this satellite was empty (and has been dropped above)
		if len(satMap) == 0 {
			delete(table.serials, satelliteID)
		}
	}
	// reachable only if memoryUsed is positive while the table holds no serials
	return ErrSerials.New("could not delete a random item")
}
// insertPartial returns list with serial inserted at its sorted position (ordering by
// Partial.Less). The backing array of list may be reused. If serial is already present it
// returns a nil slice together with an ErrSerialAlreadyExists error.
func insertPartial(list []Partial, serial Partial) ([]Partial, error) {
	sortsAfter := func(k int) bool { return serial.Less(list[k]) }
	// pos is the first index whose element sorts strictly after serial; an equal element,
	// if one exists, therefore sits immediately before pos.
	pos := sort.Search(len(list), sortsAfter)
	if pos > 0 && list[pos-1] == serial {
		return nil, ErrSerialAlreadyExists.New("")
	}

	// Grow by one zero element, slide the tail right to open a slot at pos, then fill it.
	var blank Partial
	list = append(list, blank)
	copy(list[pos+1:], list[pos:])
	list[pos] = serial
	return list, nil
}
// insertSerial returns list with serial inserted at its sorted position (ordering by
// storj.SerialNumber.Less). The backing array of list may be reused. If serial is already
// present it returns a nil slice together with an ErrSerialAlreadyExists error.
func insertSerial(list []storj.SerialNumber, serial storj.SerialNumber) ([]storj.SerialNumber, error) {
	sortsAfter := func(k int) bool { return serial.Less(list[k]) }
	// pos is the first index whose element sorts strictly after serial; an equal element,
	// if one exists, therefore sits immediately before pos.
	pos := sort.Search(len(list), sortsAfter)
	if pos > 0 && list[pos-1] == serial {
		return nil, ErrSerialAlreadyExists.New("")
	}

	// Grow by one zero element, slide the tail right to open a slot at pos, then fill it.
	var blank storj.SerialNumber
	list = append(list, blank)
	copy(list[pos+1:], list[pos:])
	list[pos] = serial
	return list, nil
}
// tryTruncate reports whether serial can be stored as a Partial and, if so, returns that form.
//
// If the first 8 bytes of the serial number are based on the expiration date, i.e. they
// equal expiration's Unix seconds read big-endian, then the last 8 bytes are enough to
// identify it. Otherwise the zero Partial and false are returned and the full serial must
// be used. See satellite/orders/service.go, createSerial() for how expiration date is used
// in the serial number.
func tryTruncate(serial storj.SerialNumber, expiration time.Time) (partial Partial, succeeded bool) {
	leading := binary.BigEndian.Uint64(serial[0:8])
	if leading != uint64(expiration.Unix()) {
		return Partial{}, false
	}
	copy(partial[:], serial[8:])
	return partial, true
}
// ceilExpirationHour returns the Unix timestamp (seconds) of expiration rounded UP to a whole
// hour. A time that already lies exactly on an hour boundary maps to itself; any later instant
// within that hour maps to the next boundary. Because time.Truncate works on absolute time,
// the result does not depend on expiration's location.
func ceilExpirationHour(expiration time.Time) int64 {
	// time.Truncate rounds down; adding (Hour-Nanosecond) first turns that into a round-up:
	// any non-zero remainder pushes past the next boundary, while an exact hour stays put.
	return expiration.Add(time.Hour - time.Nanosecond).Truncate(time.Hour).Unix()
}