-
Notifications
You must be signed in to change notification settings - Fork 0
/
account_loader.go
236 lines (205 loc) · 6.23 KB
/
account_loader.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
package history
import (
"context"
"database/sql/driver"
"fmt"
"sort"
"strings"
"github.com/lib/pq"
"github.com/shantanu-hashcash/go/support/collections/set"
"github.com/shantanu-hashcash/go/support/db"
"github.com/shantanu-hashcash/go/support/errors"
"github.com/shantanu-hashcash/go/support/ordered"
)
// FutureAccountID represents a future history account.
// A FutureAccountID is created by an AccountLoader (see GetFuture) and
// the account id is available after calling Exec() on
// the AccountLoader. Until then, Value() returns an error.
type FutureAccountID struct {
// address is the account address whose history account id is requested.
address string
// loader is the AccountLoader which issued this future; Value() asks it
// for the id mapped to address.
loader *AccountLoader
}
// loaderLookupBatchSize is the maximum number of addresses that lookupKeys
// passes to a single AccountsByAddresses call.
const loaderLookupBatchSize = 50000
// Value implements the database/sql/driver Valuer interface.
//
// The lookup is delegated to GetNow on the AccountLoader which created
// the future, so it fails until that loader has been sealed (Exec() does
// this) and holds an id for the address.
//
// A FutureAccountID which was not obtained from AccountLoader.GetFuture
// (for example the zero value) has no loader to consult. In that case an
// error is returned instead of dereferencing a nil pointer.
func (a FutureAccountID) Value() (driver.Value, error) {
	if a.loader == nil {
		return nil, fmt.Errorf("future account id for address %q has no account loader", a.address)
	}
	return a.loader.GetNow(a.address)
}
// AccountLoader will map account addresses to their history
// account ids. If there is no existing mapping for a given address,
// the AccountLoader will insert into the history_accounts table to
// establish a mapping.
//
// Addresses are registered with GetFuture, resolved in bulk by Exec and
// read back with GetNow.
type AccountLoader struct {
// sealed is set to true by Exec (and by AccountLoaderStub.Insert). Once
// set, GetFuture panics and GetNow is allowed to answer.
sealed bool
// set holds every address registered through GetFuture.
set set.Set[string]
// ids maps an address to its history account id; it is filled by
// lookupKeys during Exec.
ids map[string]int64
// stats accumulates the counters returned by Stats.
stats LoaderStats
}
// errSealed is the value GetFuture panics with when it is called on a
// loader that has already been sealed.
var errSealed = errors.New("cannot register more entries to loader after calling Exec()")
// NewAccountLoader returns an empty, unsealed AccountLoader: no addresses
// are registered, no ids are known and all stats counters are zero.
func NewAccountLoader() *AccountLoader {
	// sealed and stats are left at their zero values (false and an empty
	// LoaderStats); only the two collections need to be allocated.
	loader := &AccountLoader{
		set: set.Set[string]{},
		ids: map[string]int64{},
	}
	return loader
}
// GetFuture adds the given account address to the set of addresses the
// loader has to resolve and hands back a FutureAccountID bound to this
// loader. The future yields the history account id for the address once
// Exec() has been called.
//
// GetFuture panics with errSealed when the loader is already sealed.
func (a *AccountLoader) GetFuture(address string) FutureAccountID {
	if a.sealed {
		panic(errSealed)
	}

	a.set.Add(address)

	var future FutureAccountID
	future.address = address
	future.loader = a
	return future
}
// GetNow looks up the history account id mapped to the given address.
//
// It returns an error when the loader has not been sealed yet (Exec()
// seals it), and also when the sealed loader holds no id for the address.
// Only addresses previously registered through GetFuture() are expected
// to resolve.
func (a *AccountLoader) GetNow(address string) (int64, error) {
	if !a.sealed {
		return 0, fmt.Errorf(`invalid account loader state,
Exec was not called yet to properly seal and resolve %v id`, address)
	}

	id, found := a.ids[address]
	if !found {
		return 0, fmt.Errorf(`account loader address %q was not found`, address)
	}
	return id, nil
}
// lookupKeys calls q.AccountsByAddresses for the given addresses in
// batches of at most loaderLookupBatchSize and stores the Address -> ID
// pair of every returned account in a.ids. The first failing query stops
// the lookup and its error is returned wrapped.
func (a *AccountLoader) lookupKeys(ctx context.Context, q *Q, addresses []string) error {
	total := len(addresses)
	for start := 0; start < total; start += loaderLookupBatchSize {
		stop := ordered.Min(total, start+loaderLookupBatchSize)
		batch := addresses[start:stop]

		var found []Account
		err := q.AccountsByAddresses(ctx, &found, batch)
		if err != nil {
			return errors.Wrap(err, "could not select accounts")
		}

		for idx := range found {
			a.ids[found[idx].Address] = found[idx].ID
		}
	}
	return nil
}
// LoaderStats describes the result of executing a history lookup id loader
type LoaderStats struct {
// Total is the number of elements registered to the loader
Total int
// Inserted is the number of elements inserted into the lookup table.
// AccountLoader.Exec adds the number of addresses it submitted for
// insertion.
Inserted int
}
// Exec seals the loader and resolves the history account id of every
// address registered in it.
//
// Existing ids are fetched first. Addresses which still have no id are
// then inserted into the history_accounts table and looked up again, so
// that a mapping between address and history account id exists for them.
// Stats().Total grows by the number of registered addresses and
// Stats().Inserted by the number of addresses submitted for insertion.
func (a *AccountLoader) Exec(ctx context.Context, session db.SessionInterface) error {
	a.sealed = true

	registered := len(a.set)
	if registered == 0 {
		return nil
	}

	q := &Q{session}
	addresses := make([]string, 0, registered)
	for address := range a.set {
		addresses = append(addresses, address)
	}

	if err := a.lookupKeys(ctx, q, addresses); err != nil {
		return err
	}
	a.stats.Total += len(addresses)

	// Keep only the addresses without an id, reusing the backing array of
	// addresses (the write position never overtakes the read position).
	missing := addresses[:0]
	for _, address := range addresses {
		if _, known := a.ids[address]; !known {
			missing = append(missing, address)
		}
	}
	if len(missing) == 0 {
		return nil
	}

	// sort entries before inserting rows to prevent deadlocks on acquiring a ShareLock
	// https://github.com/shantanu-hashcash/go/issues/2370
	sort.Strings(missing)

	addressColumn := bulkInsertField{
		name:    "address",
		dbType:  "character varying(64)",
		objects: missing,
	}
	if err := bulkInsert(
		ctx,
		q,
		"history_accounts",
		[]string{"address"},
		[]bulkInsertField{addressColumn},
	); err != nil {
		return err
	}
	a.stats.Inserted += len(missing)

	// Fetch the ids of the rows which now exist for the missing addresses.
	return a.lookupKeys(ctx, q, missing)
}
// Stats returns a copy of the loader's counters: the number of addresses
// registered in the loader and the number of addresses inserted into the
// history_accounts table.
func (a *AccountLoader) Stats() LoaderStats {
	snapshot := a.stats
	return snapshot
}
// Name returns the fixed, human readable name of this loader.
func (a *AccountLoader) Name() string {
	const loaderName = "AccountLoader"
	return loaderName
}
// bulkInsertField describes one column of a bulkInsert statement together
// with the values to insert into it.
type bulkInsertField struct {
// name is the column name; bulkInsert puts it in the INSERT column list.
name string
// dbType is the element type bulkInsert uses in the `?::<dbType>[]` cast
// of the unnest() argument.
dbType string
// objects are the column values, passed to the query via pq.Array.
objects []string
}
// bulkInsert inserts the values of the given fields into table with a
// single statement: every field is passed as one array argument which is
// expanded with unnest(), and rows conflicting on conflictFields are
// skipped (ON CONFLICT ... DO NOTHING).
//
// table, the field names, the field types and conflictFields are
// concatenated into the SQL text as-is; only the field values are bound
// as query arguments. The statement is executed with the query type in
// the context set to db.UpsertQueryType.
func bulkInsert(ctx context.Context, q *Q, table string, conflictFields []string, fields []bulkInsertField) error {
	var (
		selectExprs = make([]string, 0, len(fields))
		columns     = make([]string, 0, len(fields))
		args        = make([]interface{}, 0, len(fields))
	)
	for i := range fields {
		f := fields[i]
		selectExprs = append(selectExprs, fmt.Sprintf("unnest(?::%s[]) /* %s */", f.dbType, f.name))
		columns = append(columns, f.name)
		args = append(args, pq.Array(f.objects))
	}

	selectList := strings.Join(selectExprs, ",")
	columnList := strings.Join(columns, ",")
	conflictList := strings.Join(conflictFields, ",")

	query := `
WITH r AS
(SELECT ` + selectList + `)
INSERT INTO ` + table + `
(` + columnList + `)
SELECT * from r
ON CONFLICT (` + conflictList + `) DO NOTHING`

	upsertCtx := context.WithValue(ctx, &db.QueryTypeContextKey, db.UpsertQueryType)
	_, err := q.ExecRaw(upsertCtx, query, args...)
	return err
}
// AccountLoaderStub is a stub wrapper around AccountLoader which allows
// you to manually configure the mapping of addresses to history account ids
// (see Insert) without calling Exec().
type AccountLoaderStub struct {
// Loader is the wrapped AccountLoader whose mapping Insert configures.
Loader *AccountLoader
}
// NewAccountLoaderStub creates an AccountLoaderStub which wraps a freshly
// constructed AccountLoader.
func NewAccountLoaderStub() AccountLoaderStub {
	var stub AccountLoaderStub
	stub.Loader = NewAccountLoader()
	return stub
}
// Insert maps the given account address to the provided history account
// id in the wrapped AccountLoader. The wrapped loader is also marked as
// sealed, so GetNow can answer without Exec() having been called.
func (a AccountLoaderStub) Insert(address string, id int64) {
	wrapped := a.Loader
	wrapped.sealed = true
	wrapped.ids[address] = id
}