-
Notifications
You must be signed in to change notification settings - Fork 0
/
claimable_balances.go
351 lines (294 loc) · 11.7 KB
/
claimable_balances.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
package history
import (
"context"
"database/sql/driver"
"encoding/json"
"fmt"
"strconv"
"strings"
sq "github.com/Masterminds/squirrel"
"github.com/guregu/null"
"github.com/shantanu-hashcash/go/services/aurora/internal/db2"
"github.com/shantanu-hashcash/go/support/errors"
"github.com/shantanu-hashcash/go/xdr"
)
// ClaimableBalancesQuery is a helper struct to configure queries to claimable balances.
// It is consumed by GetClaimableBalances; nil filter fields are simply not applied.
type ClaimableBalancesQuery struct {
// PageQuery supplies the cursor (parsed by Cursor), the sort order and the limit.
PageQuery db2.PageQuery
// Asset, when non-nil, restricts results to balances whose asset column equals it.
Asset *xdr.Asset
// Sponsor, when non-nil, restricts results to balances whose sponsor is this account's address.
Sponsor *xdr.AccountId
// Claimant, when non-nil, restricts results to balances that list this account
// as a destination in claimable_balance_claimants.
Claimant *xdr.AccountId
}
// Cursor validates and returns the query page cursor.
//
// A non-empty cursor must have the form "<ledger>-<balance id>". The text before
// the first "-" is parsed as a base-10 int64 and must not be negative. The text
// after it must decode (via xdr.SafeUnmarshalHex) as a ClaimableBalanceId.
// On success the parsed ledger is returned first and the balance id text,
// unchanged, second. An empty cursor yields (0, "", nil). On error the zero
// values (0, "") are returned together with the error.
func (cbq ClaimableBalancesQuery) Cursor() (int64, string, error) {
	cursor := cbq.PageQuery.Cursor
	if cursor == "" {
		return 0, "", nil
	}

	// Only the first "-" separates the two parts.
	parts := strings.SplitN(cursor, "-", 2)
	if len(parts) != 2 {
		return 0, "", errors.New("Invalid cursor")
	}

	ledger, err := strconv.ParseInt(parts[0], 10, 64)
	if err != nil {
		// The value did not parse as an int64 at all, so say so instead of
		// reporting a sign/range problem.
		return 0, "", errors.Wrap(err, "Invalid cursor - first value should be a valid integer")
	}
	// 0 is accepted here; applyClaimableBalancesQueriesCursor ignores it.
	if ledger < 0 {
		return 0, "", errors.New("invalid cursor - first value should not be negative")
	}

	var balanceID xdr.ClaimableBalanceId
	if err := xdr.SafeUnmarshalHex(parts[1], &balanceID); err != nil {
		return 0, "", errors.Wrap(err, "Invalid cursor - second value should be a valid claimable balance id")
	}

	return ledger, parts[1], nil
}
// applyClaimableBalancesQueriesCursor adds keyset pagination over the
// (last_modified_ledger, id) pair of tableName to sql, and orders rows by that
// same pair in the requested direction.
//
// The cursor condition is added only when lCursor > 0 and rCursor is non-empty;
// otherwise only the ORDER BY is applied. An unknown order returns sql
// unchanged together with an error. For performance reasons the limit is not
// applied here. Callers apply it afterwards, which lets the planner be hinted
// to use the right indexes.
func applyClaimableBalancesQueriesCursor(sql sq.SelectBuilder, tableName string, lCursor int64, rCursor string, order string) (sq.SelectBuilder, error) {
	var comparison, direction string
	switch order {
	case db2.OrderAscending:
		comparison, direction = ">", "asc"
	case db2.OrderDescending:
		comparison, direction = "<", "desc"
	default:
		return sql, errors.Errorf("invalid order: %s", order)
	}

	if lCursor > 0 && rCursor != "" {
		condition := fmt.Sprintf("(%s.last_modified_ledger, %s.id) %s (?, ?)", tableName, tableName, comparison)
		sql = sql.Where(sq.Expr(condition, lCursor, rCursor))
	}

	ordering := fmt.Sprintf("%s.last_modified_ledger %s, %s.id %s", tableName, direction, tableName, direction)
	return sql.OrderBy(ordering), nil
}
// ClaimableBalanceClaimant is a row of data from the `claimable_balance_claimants` table.
// This table exists to allow faster querying for claimable balances for a specific claimant.
// Rows are keyed by (id, destination), the conflict target used when they are upserted.
type ClaimableBalanceClaimant struct {
// BalanceID is the id of the claimable balance this claimant belongs to.
BalanceID string `db:"id"`
// Destination is the claimant's account address.
Destination string `db:"destination"`
// LastModifiedLedger is copied from the parent ClaimableBalance on upsert,
// so claimant rows can be paginated with the same (ledger, id) cursor.
LastModifiedLedger uint32 `db:"last_modified_ledger"`
}
// ClaimableBalance is a row of data from the `claimable_balances` table.
type ClaimableBalance struct {
// BalanceID is the primary key; presumably the hex-encoded XDR ClaimableBalanceId
// (Cursor validates its cursor counterpart that way) — TODO confirm.
BalanceID string `db:"id"`
// Claimants is persisted in a jsonb column through Claimants.Value / Claimants.Scan.
Claimants Claimants `db:"claimants"`
Asset xdr.Asset `db:"asset"`
Amount xdr.Int64 `db:"amount"`
// Sponsor is nullable; presumably NULL when the balance has no sponsor — TODO confirm.
Sponsor null.String `db:"sponsor"`
LastModifiedLedger uint32 `db:"last_modified_ledger"`
Flags uint32 `db:"flags"`
}
// Claimants is the list of claimants of a claimable balance. It is stored as a
// JSON document in the `claimants` column of the `claimable_balances` table.
type Claimants []Claimant

// Value implements driver.Valuer by encoding the claimants as JSON.
// On a marshalling error it returns a nil value and the error.
func (c Claimants) Value() (driver.Value, error) {
	val, err := json.Marshal(c)
	if err != nil {
		return nil, err
	}
	// Convert the byte array into a string as a workaround to bypass buggy encoding in the pq driver
	// (More info about this bug here https://github.com/shantanu-hashcash/go/issues/5086#issuecomment-1773215436).
	// By doing so, the data will be written as a string rather than hex encoded bytes.
	return string(val), nil
}

// Scan implements sql.Scanner by decoding a JSON document into c.
// Both []byte and string sources are accepted, because Value writes the data
// as a string and a driver may hand it back in either form. Any other source
// type (including nil) is rejected.
func (c *Claimants) Scan(value interface{}) error {
	var b []byte
	switch v := value.(type) {
	case []byte:
		b = v
	case string:
		b = []byte(v)
	default:
		return errors.New("type assertion to []byte failed")
	}
	return json.Unmarshal(b, c)
}
// Claimant is one entry of Claimants, serialized to and from JSON inside the
// claimable_balances.claimants column.
type Claimant struct {
// Destination is the claimant's account address; it is also written to
// claimable_balance_claimants.destination on upsert.
Destination string `json:"destination"`
// Predicate presumably describes when Destination may claim the balance
// (semantics are defined by xdr.ClaimPredicate) — TODO confirm.
Predicate xdr.ClaimPredicate `json:"predicate"`
}
// QClaimableBalances defines claimable-balance-related queries.
// *Q implements every method listed here in this file.
type QClaimableBalances interface {
// UpsertClaimableBalances upserts balances and their claimant rows.
UpsertClaimableBalances(ctx context.Context, cb []ClaimableBalance) error
// RemoveClaimableBalances deletes balances by id and returns the rows affected.
RemoveClaimableBalances(ctx context.Context, ids []string) (int64, error)
// RemoveClaimableBalanceClaimants deletes claimant rows by balance id and returns the rows affected.
RemoveClaimableBalanceClaimants(ctx context.Context, ids []string) (int64, error)
GetClaimableBalancesByID(ctx context.Context, ids []string) ([]ClaimableBalance, error)
CountClaimableBalances(ctx context.Context) (int, error)
NewClaimableBalanceClaimantBatchInsertBuilder() ClaimableBalanceClaimantBatchInsertBuilder
NewClaimableBalanceBatchInsertBuilder() ClaimableBalanceBatchInsertBuilder
// GetClaimantsByClaimableBalances returns claimant rows grouped by balance id.
GetClaimantsByClaimableBalances(ctx context.Context, ids []string) (map[string][]ClaimableBalanceClaimant, error)
}
// CountClaimableBalances reports how many rows the claimable_balances table
// holds. Any query failure is wrapped and returned with a zero count.
func (q *Q) CountClaimableBalances(ctx context.Context) (int, error) {
	var total int
	countQuery := sq.Select("count(*)").From("claimable_balances")

	err := q.Get(ctx, &total, countQuery)
	if err != nil {
		return 0, errors.Wrap(err, "could not run select query")
	}
	return total, nil
}
// GetClaimableBalancesByID loads every claimable balance whose id is in ids.
// Whatever q.Select produced is returned together with its error.
func (q *Q) GetClaimableBalancesByID(ctx context.Context, ids []string) ([]ClaimableBalance, error) {
	byID := selectClaimableBalances.Where(sq.Eq{"cb.id": ids})

	var balances []ClaimableBalance
	err := q.Select(ctx, &balances, byID)
	return balances, err
}
// GetClaimantsByClaimableBalances finds all claimant rows for the given
// claimable balance ids and groups them by balance id.
//
// Within each group, claimants are sorted by destination in ascending order.
// Ids without any claimant rows are absent from the map. If the query fails,
// a nil map is returned with the error.
func (q *Q) GetClaimantsByClaimableBalances(ctx context.Context, ids []string) (map[string][]ClaimableBalanceClaimant, error) {
	var claimants []ClaimableBalanceClaimant
	sql := sq.Select("*").From("claimable_balance_claimants cbc").
		Where(map[string]interface{}{"cbc.id": ids}).
		OrderBy("id asc, destination asc")
	if err := q.Select(ctx, &claimants, sql); err != nil {
		return nil, err
	}

	// Rows arrive ordered by (id, destination), so appending in order keeps
	// each balance's claimants sorted by destination.
	claimantsMap := make(map[string][]ClaimableBalanceClaimant, len(ids))
	for _, claimant := range claimants {
		claimantsMap[claimant.BalanceID] = append(claimantsMap[claimant.BalanceID], claimant)
	}
	return claimantsMap, nil
}
// UpsertClaimableBalances writes a batch of claimable balances into the
// claimable_balances table first. It then writes their claimants into the
// claimable_balance_claimants table. The first failing step aborts the call
// and its error is wrapped with a step-specific message.
func (q *Q) UpsertClaimableBalances(ctx context.Context, cbs []ClaimableBalance) error {
	err := q.upsertCBs(ctx, cbs)
	if err != nil {
		return errors.Wrap(err, "could not upsert claimable balances")
	}

	err = q.upsertCBClaimants(ctx, cbs)
	if err != nil {
		return errors.Wrap(err, "could not upsert claimable balance claimants")
	}

	return nil
}
// upsertCBClaimants writes one claimable_balance_claimants row per
// (balance, claimant) pair, resolving conflicts on (id, destination). Each row
// carries its parent balance's last_modified_ledger. The columns are handed to
// upsertRows as parallel slices.
func (q *Q) upsertCBClaimants(ctx context.Context, cbs []ClaimableBalance) error {
	var balanceIDs, destinations, ledgers []interface{}

	for i := range cbs {
		balance := &cbs[i]
		for _, c := range balance.Claimants {
			balanceIDs = append(balanceIDs, balance.BalanceID)
			ledgers = append(ledgers, balance.LastModifiedLedger)
			destinations = append(destinations, c.Destination)
		}
	}

	return q.upsertRows(ctx, "claimable_balance_claimants", "id, destination", []upsertField{
		{"id", "text", balanceIDs},
		{"destination", "text", destinations},
		{"last_modified_ledger", "integer", ledgers},
	})
}
// upsertCBs writes one claimable_balances row per element of cbs, resolving
// conflicts on id. Each column is passed to upsertRows as a parallel slice
// together with its Postgres type; claimants go in as jsonb.
func (q *Q) upsertCBs(ctx context.Context, cbs []ClaimableBalance) error {
	var (
		ids, claimantLists, assets, amounts []interface{}
		sponsors, ledgers, flagValues       []interface{}
	)

	for i := range cbs {
		balance := &cbs[i]
		ids = append(ids, balance.BalanceID)
		claimantLists = append(claimantLists, balance.Claimants)
		assets = append(assets, balance.Asset)
		amounts = append(amounts, balance.Amount)
		sponsors = append(sponsors, balance.Sponsor)
		ledgers = append(ledgers, balance.LastModifiedLedger)
		flagValues = append(flagValues, balance.Flags)
	}

	return q.upsertRows(ctx, "claimable_balances", "id", []upsertField{
		{"id", "text", ids},
		{"claimants", "jsonb", claimantLists},
		{"asset", "text", assets},
		{"amount", "bigint", amounts},
		{"sponsor", "text", sponsors},
		{"last_modified_ledger", "integer", ledgers},
		{"flags", "int", flagValues},
	})
}
// RemoveClaimableBalances deletes the rows of the claimable_balances table
// whose id is in ids. It returns the number of rows affected; on failure it
// returns 0 and the Exec error.
func (q *Q) RemoveClaimableBalances(ctx context.Context, ids []string) (int64, error) {
	deleteStmt := sq.Delete("claimable_balances").Where(sq.Eq{"id": ids})

	res, err := q.Exec(ctx, deleteStmt)
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}
// RemoveClaimableBalanceClaimants deletes every claimable_balance_claimants row
// whose balance id is in ids. It returns the number of rows affected; on
// failure it returns 0 and the Exec error.
func (q *Q) RemoveClaimableBalanceClaimants(ctx context.Context, ids []string) (int64, error) {
	deleteStmt := sq.Delete("claimable_balance_claimants").Where(sq.Eq{"id": ids})

	res, err := q.Exec(ctx, deleteStmt)
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}
// FindClaimableBalanceByID loads the single claimable balance with the given
// id. Any error from q.Get is returned as-is (presumably including a not-found
// error when no row matches — TODO confirm against db2.Session.Get).
func (q *Q) FindClaimableBalanceByID(ctx context.Context, balanceID string) (ClaimableBalance, error) {
	byID := selectClaimableBalances.Limit(1).Where("cb.id = ?", balanceID)

	var balance ClaimableBalance
	err := q.Get(ctx, &balance, byID)
	return balance, err
}
// GetClaimableBalances returns one page of claimable balances that match the
// filters set in query (asset, sponsor and/or claimant). Results are paginated
// by the query's cursor, order and limit.
//
// When Asset or Sponsor is set, a Claimant filter is applied by joining
// claimable_balance_claimants. When only Claimant is set, a limited,
// cursor-paginated subquery over claimable_balance_claimants selects the
// balance ids instead.
func (q *Q) GetClaimableBalances(ctx context.Context, query ClaimableBalancesQuery) ([]ClaimableBalance, error) {
	ledgerCursor, idCursor, err := query.Cursor()
	if err != nil {
		return nil, errors.Wrap(err, "error getting cursor")
	}

	sql, err := applyClaimableBalancesQueriesCursor(selectClaimableBalances, "cb", ledgerCursor, idCursor, query.PageQuery.Order)
	if err != nil {
		return nil, errors.Wrap(err, "could not apply query to page")
	}

	switch {
	case query.Asset != nil || query.Sponsor != nil:
		// JOIN with claimable_balance_claimants table to query by claimants
		if query.Claimant != nil {
			sql = sql.
				Join("claimable_balance_claimants on claimable_balance_claimants.id = cb.id").
				Where("claimable_balance_claimants.destination = ?", query.Claimant.Address())
		}
		if query.Asset != nil {
			sql = sql.Where("cb.asset = ?", query.Asset)
		}
		if query.Sponsor != nil {
			sql = sql.Where("cb.sponsor = ?", query.Sponsor.Address())
		}

	case query.Claimant != nil:
		// With only the claimant provided, a JOIN with claimable_balance_claimants
		// does not perform efficiently. Instead, a subquery (with LIMIT) retrieves
		// the claimable balance ids for the claimant's address.
		claimantBalanceIDs := sq.Select("claimable_balance_claimants.id").
			From("claimable_balance_claimants").
			Where("claimable_balance_claimants.destination = ?", query.Claimant.Address()).
			Limit(query.PageQuery.Limit)

		claimantBalanceIDs, err = applyClaimableBalancesQueriesCursor(claimantBalanceIDs, "claimable_balance_claimants", ledgerCursor, idCursor, query.PageQuery.Order)
		if err != nil {
			return nil, errors.Wrap(err, "could not apply subquery to page")
		}

		var subSQL string
		var subArgs []interface{}
		if subSQL, subArgs, err = claimantBalanceIDs.ToSql(); err != nil {
			return nil, errors.Wrap(err, "could not build subquery")
		}

		sql = sql.Where(fmt.Sprintf("cb.id IN (%s)", subSQL), subArgs...)
	}

	sql = sql.Limit(query.PageQuery.Limit)

	var results []ClaimableBalance
	if err = q.Select(ctx, &results, sql); err != nil {
		return nil, errors.Wrap(err, "could not run select query")
	}
	return results, nil
}
// claimableBalancesSelectStatement is the comma-separated column list read into
// ClaimableBalance, qualified with the "cb" table alias.
var claimableBalancesSelectStatement = strings.Join([]string{
	"cb.id",
	"cb.claimants",
	"cb.asset",
	"cb.amount",
	"cb.sponsor",
	"cb.last_modified_ledger",
	"cb.flags",
}, ", ")

// selectClaimableBalances is the base SELECT over claimable_balances (aliased
// "cb") that the query functions in this file refine with filters, ordering and limits.
var selectClaimableBalances = sq.Select(claimableBalancesSelectStatement).From("claimable_balances cb")