Skip to content

Commit

Permalink
satellite/payments/stripe: avoid full table scan while listing records
Browse files Browse the repository at this point in the history
Stripe invoice project records while listing are causing full table scan
because of OFFSET caluse. This change is refactoring query to list using
cursor.

Change-Id: I6b73b9b2815173d7ef02cf615408778476eb3b7b
  • Loading branch information
mniewrzal authored and Storj Robot committed May 9, 2023
1 parent c64f3f3 commit 87d0789
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 188 deletions.
11 changes: 6 additions & 5 deletions satellite/payments/stripe/projectrecords.go
Expand Up @@ -26,7 +26,8 @@ type ProjectRecordsDB interface {
// Consume consumes invoice project record.
Consume(ctx context.Context, id uuid.UUID) error
// ListUnapplied returns project records page with unapplied project records.
ListUnapplied(ctx context.Context, offset int64, limit int, start, end time.Time) (ProjectRecordsPage, error)
// Cursor is not included into listing results.
ListUnapplied(ctx context.Context, cursor uuid.UUID, limit int, start, end time.Time) (ProjectRecordsPage, error)
}

// CreateProjectRecord holds info needed for creation new invoice
Expand All @@ -52,9 +53,9 @@ type ProjectRecord struct {

// ProjectRecordsPage holds project records and
// indicates if there is more data available
// and provides next offset.
// and provides cursor for next listing.
type ProjectRecordsPage struct {
Records []ProjectRecord
Next bool
NextOffset int64
Records []ProjectRecord
Next bool
Cursor uuid.UUID
}
67 changes: 36 additions & 31 deletions satellite/payments/stripe/projectrecords_test.go
Expand Up @@ -4,6 +4,7 @@
package stripe_test

import (
"fmt"
"testing"
"time"

Expand Down Expand Up @@ -50,7 +51,7 @@ func TestProjectRecords(t *testing.T) {
assert.Equal(t, stripe.ErrProjectRecordExists, err)
})

page, err := projectRecordsDB.ListUnapplied(ctx, 0, 1, start, end)
page, err := projectRecordsDB.ListUnapplied(ctx, uuid.UUID{}, 1, start, end)
require.NoError(t, err)
require.Equal(t, 1, len(page.Records))

Expand All @@ -59,7 +60,7 @@ func TestProjectRecords(t *testing.T) {
require.NoError(t, err)
})

page, err = projectRecordsDB.ListUnapplied(ctx, 0, 1, start, end)
page, err = projectRecordsDB.ListUnapplied(ctx, uuid.UUID{}, 1, start, end)
require.NoError(t, err)
require.Equal(t, 0, len(page.Records))
})
Expand All @@ -74,8 +75,7 @@ func TestProjectRecordsList(t *testing.T) {

projectRecordsDB := db.StripeCoinPayments().ProjectRecords()

const limit = 5
const recordsLen = limit * 4
const recordsLen = 20

var createProjectRecords []stripe.CreateProjectRecord
for i := 0; i < recordsLen; i++ {
Expand All @@ -95,37 +95,42 @@ func TestProjectRecordsList(t *testing.T) {
err := projectRecordsDB.Create(ctx, createProjectRecords, start, end)
require.NoError(t, err)

page, err := projectRecordsDB.ListUnapplied(ctx, 0, limit, start, end)
require.NoError(t, err)

records := page.Records
for _, limit := range []int{1, 3, 5, 30} {
t.Run(fmt.Sprintf("limit-%d", limit), func(t *testing.T) {
records := []stripe.ProjectRecord{}

for page.Next {
page, err = projectRecordsDB.ListUnapplied(ctx, page.NextOffset, limit, start, end)
require.NoError(t, err)
var page stripe.ProjectRecordsPage
for {
page, err = projectRecordsDB.ListUnapplied(ctx, page.Cursor, limit, start, end)
require.NoError(t, err)

records = append(records, page.Records...)
}

require.Equal(t, recordsLen, len(records))
assert.False(t, page.Next)
assert.Equal(t, int64(0), page.NextOffset)

for _, record := range page.Records {
for _, createRecord := range createProjectRecords {
if record.ProjectID != createRecord.ProjectID {
continue
records = append(records, page.Records...)
if !page.Next {
break
}
}

assert.NotNil(t, record.ID)
assert.Equal(t, 16, len(record.ID))
assert.Equal(t, createRecord.ProjectID, record.ProjectID)
assert.Equal(t, createRecord.Storage, record.Storage)
assert.Equal(t, createRecord.Egress, record.Egress)
assert.Equal(t, createRecord.Segments, record.Segments)
assert.True(t, start.Equal(record.PeriodStart))
assert.True(t, end.Equal(record.PeriodEnd))
}
require.Equal(t, recordsLen, len(records))
assert.False(t, page.Next)
assert.Equal(t, uuid.UUID{}, page.Cursor)

for _, record := range page.Records {
for _, createRecord := range createProjectRecords {
if record.ProjectID != createRecord.ProjectID {
continue
}

assert.NotNil(t, record.ID)
assert.Equal(t, 16, len(record.ID))
assert.Equal(t, createRecord.ProjectID, record.ProjectID)
assert.Equal(t, createRecord.Storage, record.Storage)
assert.Equal(t, createRecord.Egress, record.Egress)
assert.Equal(t, createRecord.Segments, record.Segments)
assert.True(t, start.Equal(record.PeriodStart))
assert.True(t, end.Equal(record.PeriodEnd))
}
}
})
}
})
}
2 changes: 1 addition & 1 deletion satellite/payments/stripe/service.go
Expand Up @@ -243,7 +243,7 @@ func (service *Service) InvoiceApplyProjectRecords(ctx context.Context, period t
}

// we are always starting from offset 0 because applyProjectRecords is changing project record state to applied
recordsPage, err := service.db.ProjectRecords().ListUnapplied(ctx, 0, service.listingLimit, start, end)
recordsPage, err := service.db.ProjectRecords().ListUnapplied(ctx, uuid.UUID{}, service.listingLimit, start, end)
if err != nil {
return Error.Wrap(err)
}
Expand Down
6 changes: 3 additions & 3 deletions satellite/payments/stripe/service_test.go
Expand Up @@ -139,15 +139,15 @@ func TestService_InvoiceElementsProcessing(t *testing.T) {
end := time.Date(period.Year(), period.Month()+1, 1, 0, 0, 0, 0, time.UTC)

// check if we have project record for each project
recordsPage, err := satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, 0, 40, start, end)
recordsPage, err := satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, uuid.UUID{}, 40, start, end)
require.NoError(t, err)
require.Equal(t, numberOfProjects, len(recordsPage.Records))

err = satellite.API.Payments.StripeService.InvoiceApplyProjectRecords(ctx, period)
require.NoError(t, err)

// verify that we applied all unapplied project records
recordsPage, err = satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, 0, 40, start, end)
recordsPage, err = satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, uuid.UUID{}, 40, start, end)
require.NoError(t, err)
require.Equal(t, 0, len(recordsPage.Records))
})
Expand Down Expand Up @@ -284,7 +284,7 @@ func TestService_ProjectsWithMembers(t *testing.T) {
start := time.Date(period.Year(), period.Month(), 1, 0, 0, 0, 0, time.UTC)
end := time.Date(period.Year(), period.Month()+1, 1, 0, 0, 0, 0, time.UTC)

recordsPage, err := satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, 0, 40, start, end)
recordsPage, err := satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, uuid.UUID{}, 40, start, end)
require.NoError(t, err)
require.Equal(t, len(projects), len(recordsPage.Records))
})
Expand Down
10 changes: 10 additions & 0 deletions satellite/satellitedb/database.go
Expand Up @@ -398,3 +398,13 @@ func (dbc *satelliteDBCollectionTesting) ProductionMigration() *migrate.Migratio
func (dbc *satelliteDBCollectionTesting) TestMigration() *migrate.Migration {
return dbc.getByName("").TestMigration()
}

func withRows(rows tagsql.Rows, err error) func(func(tagsql.Rows) error) error {
return func(callback func(tagsql.Rows) error) error {
if err != nil {
return err
}
err := callback(rows)
return errs.Combine(rows.Err(), rows.Close(), err)
}
}
6 changes: 0 additions & 6 deletions satellite/satellitedb/dbx/billing.dbx
Expand Up @@ -232,12 +232,6 @@ read one (
where stripecoinpayments_invoice_project_record.period_start = ?
where stripecoinpayments_invoice_project_record.period_end = ?
)
read limitoffset (
select stripecoinpayments_invoice_project_record
where stripecoinpayments_invoice_project_record.period_start = ?
where stripecoinpayments_invoice_project_record.period_end = ?
where stripecoinpayments_invoice_project_record.state = ?
)

// stripecoinpayments_tx_conversion_rate contains information about a conversion-rate that was used in a transaction.
model stripecoinpayments_tx_conversion_rate (
Expand Down
122 changes: 0 additions & 122 deletions satellite/satellitedb/dbx/satellitedb.dbx.go
Expand Up @@ -14567,57 +14567,6 @@ func (obj *pgxImpl) Get_StripecoinpaymentsInvoiceProjectRecord_By_ProjectId_And_

}

func (obj *pgxImpl) Limited_StripecoinpaymentsInvoiceProjectRecord_By_PeriodStart_And_PeriodEnd_And_State(ctx context.Context,
stripecoinpayments_invoice_project_record_period_start StripecoinpaymentsInvoiceProjectRecord_PeriodStart_Field,
stripecoinpayments_invoice_project_record_period_end StripecoinpaymentsInvoiceProjectRecord_PeriodEnd_Field,
stripecoinpayments_invoice_project_record_state StripecoinpaymentsInvoiceProjectRecord_State_Field,
limit int, offset int64) (
rows []*StripecoinpaymentsInvoiceProjectRecord, err error) {
defer mon.Task()(&ctx)(&err)

var __embed_stmt = __sqlbundle_Literal("SELECT stripecoinpayments_invoice_project_records.id, stripecoinpayments_invoice_project_records.project_id, stripecoinpayments_invoice_project_records.storage, stripecoinpayments_invoice_project_records.egress, stripecoinpayments_invoice_project_records.objects, stripecoinpayments_invoice_project_records.segments, stripecoinpayments_invoice_project_records.period_start, stripecoinpayments_invoice_project_records.period_end, stripecoinpayments_invoice_project_records.state, stripecoinpayments_invoice_project_records.created_at FROM stripecoinpayments_invoice_project_records WHERE stripecoinpayments_invoice_project_records.period_start = ? AND stripecoinpayments_invoice_project_records.period_end = ? AND stripecoinpayments_invoice_project_records.state = ? LIMIT ? OFFSET ?")

var __values []interface{}
__values = append(__values, stripecoinpayments_invoice_project_record_period_start.value(), stripecoinpayments_invoice_project_record_period_end.value(), stripecoinpayments_invoice_project_record_state.value())

__values = append(__values, limit, offset)

var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt)
obj.logStmt(__stmt, __values...)

for {
rows, err = func() (rows []*StripecoinpaymentsInvoiceProjectRecord, err error) {
__rows, err := obj.driver.QueryContext(ctx, __stmt, __values...)
if err != nil {
return nil, err
}
defer __rows.Close()

for __rows.Next() {
stripecoinpayments_invoice_project_record := &StripecoinpaymentsInvoiceProjectRecord{}
err = __rows.Scan(&stripecoinpayments_invoice_project_record.Id, &stripecoinpayments_invoice_project_record.ProjectId, &stripecoinpayments_invoice_project_record.Storage, &stripecoinpayments_invoice_project_record.Egress, &stripecoinpayments_invoice_project_record.Objects, &stripecoinpayments_invoice_project_record.Segments, &stripecoinpayments_invoice_project_record.PeriodStart, &stripecoinpayments_invoice_project_record.PeriodEnd, &stripecoinpayments_invoice_project_record.State, &stripecoinpayments_invoice_project_record.CreatedAt)
if err != nil {
return nil, err
}
rows = append(rows, stripecoinpayments_invoice_project_record)
}
err = __rows.Err()
if err != nil {
return nil, err
}
return rows, nil
}()
if err != nil {
if obj.shouldRetry(err) {
continue
}
return nil, obj.makeErr(err)
}
return rows, nil
}

}

func (obj *pgxImpl) Get_StripecoinpaymentsTxConversionRate_By_TxId(ctx context.Context,
stripecoinpayments_tx_conversion_rate_tx_id StripecoinpaymentsTxConversionRate_TxId_Field) (
stripecoinpayments_tx_conversion_rate *StripecoinpaymentsTxConversionRate, err error) {
Expand Down Expand Up @@ -22416,57 +22365,6 @@ func (obj *pgxcockroachImpl) Get_StripecoinpaymentsInvoiceProjectRecord_By_Proje

}

func (obj *pgxcockroachImpl) Limited_StripecoinpaymentsInvoiceProjectRecord_By_PeriodStart_And_PeriodEnd_And_State(ctx context.Context,
stripecoinpayments_invoice_project_record_period_start StripecoinpaymentsInvoiceProjectRecord_PeriodStart_Field,
stripecoinpayments_invoice_project_record_period_end StripecoinpaymentsInvoiceProjectRecord_PeriodEnd_Field,
stripecoinpayments_invoice_project_record_state StripecoinpaymentsInvoiceProjectRecord_State_Field,
limit int, offset int64) (
rows []*StripecoinpaymentsInvoiceProjectRecord, err error) {
defer mon.Task()(&ctx)(&err)

var __embed_stmt = __sqlbundle_Literal("SELECT stripecoinpayments_invoice_project_records.id, stripecoinpayments_invoice_project_records.project_id, stripecoinpayments_invoice_project_records.storage, stripecoinpayments_invoice_project_records.egress, stripecoinpayments_invoice_project_records.objects, stripecoinpayments_invoice_project_records.segments, stripecoinpayments_invoice_project_records.period_start, stripecoinpayments_invoice_project_records.period_end, stripecoinpayments_invoice_project_records.state, stripecoinpayments_invoice_project_records.created_at FROM stripecoinpayments_invoice_project_records WHERE stripecoinpayments_invoice_project_records.period_start = ? AND stripecoinpayments_invoice_project_records.period_end = ? AND stripecoinpayments_invoice_project_records.state = ? LIMIT ? OFFSET ?")

var __values []interface{}
__values = append(__values, stripecoinpayments_invoice_project_record_period_start.value(), stripecoinpayments_invoice_project_record_period_end.value(), stripecoinpayments_invoice_project_record_state.value())

__values = append(__values, limit, offset)

var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt)
obj.logStmt(__stmt, __values...)

for {
rows, err = func() (rows []*StripecoinpaymentsInvoiceProjectRecord, err error) {
__rows, err := obj.driver.QueryContext(ctx, __stmt, __values...)
if err != nil {
return nil, err
}
defer __rows.Close()

for __rows.Next() {
stripecoinpayments_invoice_project_record := &StripecoinpaymentsInvoiceProjectRecord{}
err = __rows.Scan(&stripecoinpayments_invoice_project_record.Id, &stripecoinpayments_invoice_project_record.ProjectId, &stripecoinpayments_invoice_project_record.Storage, &stripecoinpayments_invoice_project_record.Egress, &stripecoinpayments_invoice_project_record.Objects, &stripecoinpayments_invoice_project_record.Segments, &stripecoinpayments_invoice_project_record.PeriodStart, &stripecoinpayments_invoice_project_record.PeriodEnd, &stripecoinpayments_invoice_project_record.State, &stripecoinpayments_invoice_project_record.CreatedAt)
if err != nil {
return nil, err
}
rows = append(rows, stripecoinpayments_invoice_project_record)
}
err = __rows.Err()
if err != nil {
return nil, err
}
return rows, nil
}()
if err != nil {
if obj.shouldRetry(err) {
continue
}
return nil, obj.makeErr(err)
}
return rows, nil
}

}

func (obj *pgxcockroachImpl) Get_StripecoinpaymentsTxConversionRate_By_TxId(ctx context.Context,
stripecoinpayments_tx_conversion_rate_tx_id StripecoinpaymentsTxConversionRate_TxId_Field) (
stripecoinpayments_tx_conversion_rate *StripecoinpaymentsTxConversionRate, err error) {
Expand Down Expand Up @@ -29364,19 +29262,6 @@ func (rx *Rx) Limited_StorjscanPayment_By_ToAddress_OrderBy_Desc_BlockNumber_Des
return tx.Limited_StorjscanPayment_By_ToAddress_OrderBy_Desc_BlockNumber_Desc_LogIndex(ctx, storjscan_payment_to_address, limit, offset)
}

func (rx *Rx) Limited_StripecoinpaymentsInvoiceProjectRecord_By_PeriodStart_And_PeriodEnd_And_State(ctx context.Context,
stripecoinpayments_invoice_project_record_period_start StripecoinpaymentsInvoiceProjectRecord_PeriodStart_Field,
stripecoinpayments_invoice_project_record_period_end StripecoinpaymentsInvoiceProjectRecord_PeriodEnd_Field,
stripecoinpayments_invoice_project_record_state StripecoinpaymentsInvoiceProjectRecord_State_Field,
limit int, offset int64) (
rows []*StripecoinpaymentsInvoiceProjectRecord, err error) {
var tx *Tx
if tx, err = rx.getTx(ctx); err != nil {
return
}
return tx.Limited_StripecoinpaymentsInvoiceProjectRecord_By_PeriodStart_And_PeriodEnd_And_State(ctx, stripecoinpayments_invoice_project_record_period_start, stripecoinpayments_invoice_project_record_period_end, stripecoinpayments_invoice_project_record_state, limit, offset)
}

func (rx *Rx) Paged_BucketBandwidthRollupArchive_By_IntervalStart_GreaterOrEqual(ctx context.Context,
bucket_bandwidth_rollup_archive_interval_start_greater_or_equal BucketBandwidthRollupArchive_IntervalStart_Field,
limit int, start *Paged_BucketBandwidthRollupArchive_By_IntervalStart_GreaterOrEqual_Continuation) (
Expand Down Expand Up @@ -30513,13 +30398,6 @@ type Methods interface {
limit int, offset int64) (
rows []*StorjscanPayment, err error)

Limited_StripecoinpaymentsInvoiceProjectRecord_By_PeriodStart_And_PeriodEnd_And_State(ctx context.Context,
stripecoinpayments_invoice_project_record_period_start StripecoinpaymentsInvoiceProjectRecord_PeriodStart_Field,
stripecoinpayments_invoice_project_record_period_end StripecoinpaymentsInvoiceProjectRecord_PeriodEnd_Field,
stripecoinpayments_invoice_project_record_state StripecoinpaymentsInvoiceProjectRecord_State_Field,
limit int, offset int64) (
rows []*StripecoinpaymentsInvoiceProjectRecord, err error)

Paged_BucketBandwidthRollupArchive_By_IntervalStart_GreaterOrEqual(ctx context.Context,
bucket_bandwidth_rollup_archive_interval_start_greater_or_equal BucketBandwidthRollupArchive_IntervalStart_Field,
limit int, start *Paged_BucketBandwidthRollupArchive_By_IntervalStart_GreaterOrEqual_Continuation) (
Expand Down

0 comments on commit 87d0789

Please sign in to comment.