From 87d0789691c2571daa5ead2344111f90a2281c41 Mon Sep 17 00:00:00 2001 From: Michal Niewrzal Date: Mon, 8 May 2023 13:15:09 +0200 Subject: [PATCH] satellite/payments/stripe: avoid full table scan while listing records Stripe invoice project records while listing are causing full table scan because of OFFSET caluse. This change is refactoring query to list using cursor. Change-Id: I6b73b9b2815173d7ef02cf615408778476eb3b7b --- satellite/payments/stripe/projectrecords.go | 11 +- .../payments/stripe/projectrecords_test.go | 67 +++++----- satellite/payments/stripe/service.go | 2 +- satellite/payments/stripe/service_test.go | 6 +- satellite/satellitedb/database.go | 10 ++ satellite/satellitedb/dbx/billing.dbx | 6 - satellite/satellitedb/dbx/satellitedb.dbx.go | 122 ------------------ .../satellitedb/invoiceprojectrecords.go | 45 ++++--- 8 files changed, 81 insertions(+), 188 deletions(-) diff --git a/satellite/payments/stripe/projectrecords.go b/satellite/payments/stripe/projectrecords.go index 8b16002aacfe..435487c4b9a6 100644 --- a/satellite/payments/stripe/projectrecords.go +++ b/satellite/payments/stripe/projectrecords.go @@ -26,7 +26,8 @@ type ProjectRecordsDB interface { // Consume consumes invoice project record. Consume(ctx context.Context, id uuid.UUID) error // ListUnapplied returns project records page with unapplied project records. - ListUnapplied(ctx context.Context, offset int64, limit int, start, end time.Time) (ProjectRecordsPage, error) + // Cursor is not included into listing results. + ListUnapplied(ctx context.Context, cursor uuid.UUID, limit int, start, end time.Time) (ProjectRecordsPage, error) } // CreateProjectRecord holds info needed for creation new invoice @@ -52,9 +53,9 @@ type ProjectRecord struct { // ProjectRecordsPage holds project records and // indicates if there is more data available -// and provides next offset. +// and provides cursor for next listing. type ProjectRecordsPage struct { - Records []ProjectRecord - Next bool - NextOffset int64 + Records []ProjectRecord + Next bool + Cursor uuid.UUID } diff --git a/satellite/payments/stripe/projectrecords_test.go b/satellite/payments/stripe/projectrecords_test.go index 4f3f7667cae9..86772493a88a 100644 --- a/satellite/payments/stripe/projectrecords_test.go +++ b/satellite/payments/stripe/projectrecords_test.go @@ -4,6 +4,7 @@ package stripe_test import ( + "fmt" "testing" "time" @@ -50,7 +51,7 @@ func TestProjectRecords(t *testing.T) { assert.Equal(t, stripe.ErrProjectRecordExists, err) }) - page, err := projectRecordsDB.ListUnapplied(ctx, 0, 1, start, end) + page, err := projectRecordsDB.ListUnapplied(ctx, uuid.UUID{}, 1, start, end) require.NoError(t, err) require.Equal(t, 1, len(page.Records)) @@ -59,7 +60,7 @@ func TestProjectRecords(t *testing.T) { require.NoError(t, err) }) - page, err = projectRecordsDB.ListUnapplied(ctx, 0, 1, start, end) + page, err = projectRecordsDB.ListUnapplied(ctx, uuid.UUID{}, 1, start, end) require.NoError(t, err) require.Equal(t, 0, len(page.Records)) }) @@ -74,8 +75,7 @@ func TestProjectRecordsList(t *testing.T) { projectRecordsDB := db.StripeCoinPayments().ProjectRecords() - const limit = 5 - const recordsLen = limit * 4 + const recordsLen = 20 var createProjectRecords []stripe.CreateProjectRecord for i := 0; i < recordsLen; i++ { @@ -95,37 +95,42 @@ func TestProjectRecordsList(t *testing.T) { err := projectRecordsDB.Create(ctx, createProjectRecords, start, end) require.NoError(t, err) - page, err := projectRecordsDB.ListUnapplied(ctx, 0, limit, start, end) - require.NoError(t, err) - - records := page.Records + for _, limit := range []int{1, 3, 5, 30} { + t.Run(fmt.Sprintf("limit-%d", limit), func(t *testing.T) { + records := []stripe.ProjectRecord{} - for page.Next { - page, err = projectRecordsDB.ListUnapplied(ctx, page.NextOffset, limit, start, end) - require.NoError(t, err) + var page stripe.ProjectRecordsPage + for { + page, err = projectRecordsDB.ListUnapplied(ctx, page.Cursor, limit, start, end) + require.NoError(t, err) - records = append(records, page.Records...) - } - - require.Equal(t, recordsLen, len(records)) - assert.False(t, page.Next) - assert.Equal(t, int64(0), page.NextOffset) - - for _, record := range page.Records { - for _, createRecord := range createProjectRecords { - if record.ProjectID != createRecord.ProjectID { - continue + records = append(records, page.Records...) + if !page.Next { + break + } } - assert.NotNil(t, record.ID) - assert.Equal(t, 16, len(record.ID)) - assert.Equal(t, createRecord.ProjectID, record.ProjectID) - assert.Equal(t, createRecord.Storage, record.Storage) - assert.Equal(t, createRecord.Egress, record.Egress) - assert.Equal(t, createRecord.Segments, record.Segments) - assert.True(t, start.Equal(record.PeriodStart)) - assert.True(t, end.Equal(record.PeriodEnd)) - } + require.Equal(t, recordsLen, len(records)) + assert.False(t, page.Next) + assert.Equal(t, uuid.UUID{}, page.Cursor) + + for _, record := range page.Records { + for _, createRecord := range createProjectRecords { + if record.ProjectID != createRecord.ProjectID { + continue + } + + assert.NotNil(t, record.ID) + assert.Equal(t, 16, len(record.ID)) + assert.Equal(t, createRecord.ProjectID, record.ProjectID) + assert.Equal(t, createRecord.Storage, record.Storage) + assert.Equal(t, createRecord.Egress, record.Egress) + assert.Equal(t, createRecord.Segments, record.Segments) + assert.True(t, start.Equal(record.PeriodStart)) + assert.True(t, end.Equal(record.PeriodEnd)) + } + } + }) } }) } diff --git a/satellite/payments/stripe/service.go b/satellite/payments/stripe/service.go index 65ed0f8db964..9d3e5b27e318 100644 --- a/satellite/payments/stripe/service.go +++ b/satellite/payments/stripe/service.go @@ -243,7 +243,7 @@ func (service *Service) InvoiceApplyProjectRecords(ctx context.Context, period t } // we are always starting from offset 0 because applyProjectRecords is changing project record state to applied - recordsPage, err := service.db.ProjectRecords().ListUnapplied(ctx, 0, service.listingLimit, start, end) + recordsPage, err := service.db.ProjectRecords().ListUnapplied(ctx, uuid.UUID{}, service.listingLimit, start, end) if err != nil { return Error.Wrap(err) } diff --git a/satellite/payments/stripe/service_test.go b/satellite/payments/stripe/service_test.go index 1f58edf11486..90a2fe4958dc 100644 --- a/satellite/payments/stripe/service_test.go +++ b/satellite/payments/stripe/service_test.go @@ -139,7 +139,7 @@ func TestService_InvoiceElementsProcessing(t *testing.T) { end := time.Date(period.Year(), period.Month()+1, 1, 0, 0, 0, 0, time.UTC) // check if we have project record for each project - recordsPage, err := satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, 0, 40, start, end) + recordsPage, err := satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, uuid.UUID{}, 40, start, end) require.NoError(t, err) require.Equal(t, numberOfProjects, len(recordsPage.Records)) @@ -147,7 +147,7 @@ func TestService_InvoiceElementsProcessing(t *testing.T) { require.NoError(t, err) // verify that we applied all unapplied project records - recordsPage, err = satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, 0, 40, start, end) + recordsPage, err = satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, uuid.UUID{}, 40, start, end) require.NoError(t, err) require.Equal(t, 0, len(recordsPage.Records)) }) @@ -284,7 +284,7 @@ func TestService_ProjectsWithMembers(t *testing.T) { start := time.Date(period.Year(), period.Month(), 1, 0, 0, 0, 0, time.UTC) end := time.Date(period.Year(), period.Month()+1, 1, 0, 0, 0, 0, time.UTC) - recordsPage, err := satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, 0, 40, start, end) + recordsPage, err := satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, uuid.UUID{}, 40, start, end) require.NoError(t, err) require.Equal(t, len(projects), len(recordsPage.Records)) }) diff --git a/satellite/satellitedb/database.go b/satellite/satellitedb/database.go index aada891a1c8c..498b0df5314e 100644 --- a/satellite/satellitedb/database.go +++ b/satellite/satellitedb/database.go @@ -398,3 +398,13 @@ func (dbc *satelliteDBCollectionTesting) ProductionMigration() *migrate.Migratio func (dbc *satelliteDBCollectionTesting) TestMigration() *migrate.Migration { return dbc.getByName("").TestMigration() } + +func withRows(rows tagsql.Rows, err error) func(func(tagsql.Rows) error) error { + return func(callback func(tagsql.Rows) error) error { + if err != nil { + return err + } + err := callback(rows) + return errs.Combine(rows.Err(), rows.Close(), err) + } +} diff --git a/satellite/satellitedb/dbx/billing.dbx b/satellite/satellitedb/dbx/billing.dbx index a901cd3bb004..c3e0219ffb08 100644 --- a/satellite/satellitedb/dbx/billing.dbx +++ b/satellite/satellitedb/dbx/billing.dbx @@ -232,12 +232,6 @@ read one ( where stripecoinpayments_invoice_project_record.period_start = ? where stripecoinpayments_invoice_project_record.period_end = ? ) -read limitoffset ( - select stripecoinpayments_invoice_project_record - where stripecoinpayments_invoice_project_record.period_start = ? - where stripecoinpayments_invoice_project_record.period_end = ? - where stripecoinpayments_invoice_project_record.state = ? -) // stripecoinpayments_tx_conversion_rate contains information about a conversion-rate that was used in a transaction. model stripecoinpayments_tx_conversion_rate ( diff --git a/satellite/satellitedb/dbx/satellitedb.dbx.go b/satellite/satellitedb/dbx/satellitedb.dbx.go index 8e2b787fc65f..dd0ed1d2ce54 100644 --- a/satellite/satellitedb/dbx/satellitedb.dbx.go +++ b/satellite/satellitedb/dbx/satellitedb.dbx.go @@ -14567,57 +14567,6 @@ func (obj *pgxImpl) Get_StripecoinpaymentsInvoiceProjectRecord_By_ProjectId_And_ } -func (obj *pgxImpl) Limited_StripecoinpaymentsInvoiceProjectRecord_By_PeriodStart_And_PeriodEnd_And_State(ctx context.Context, - stripecoinpayments_invoice_project_record_period_start StripecoinpaymentsInvoiceProjectRecord_PeriodStart_Field, - stripecoinpayments_invoice_project_record_period_end StripecoinpaymentsInvoiceProjectRecord_PeriodEnd_Field, - stripecoinpayments_invoice_project_record_state StripecoinpaymentsInvoiceProjectRecord_State_Field, - limit int, offset int64) ( - rows []*StripecoinpaymentsInvoiceProjectRecord, err error) { - defer mon.Task()(&ctx)(&err) - - var __embed_stmt = __sqlbundle_Literal("SELECT stripecoinpayments_invoice_project_records.id, stripecoinpayments_invoice_project_records.project_id, stripecoinpayments_invoice_project_records.storage, stripecoinpayments_invoice_project_records.egress, stripecoinpayments_invoice_project_records.objects, stripecoinpayments_invoice_project_records.segments, stripecoinpayments_invoice_project_records.period_start, stripecoinpayments_invoice_project_records.period_end, stripecoinpayments_invoice_project_records.state, stripecoinpayments_invoice_project_records.created_at FROM stripecoinpayments_invoice_project_records WHERE stripecoinpayments_invoice_project_records.period_start = ? AND stripecoinpayments_invoice_project_records.period_end = ? AND stripecoinpayments_invoice_project_records.state = ? LIMIT ? OFFSET ?") - - var __values []interface{} - __values = append(__values, stripecoinpayments_invoice_project_record_period_start.value(), stripecoinpayments_invoice_project_record_period_end.value(), stripecoinpayments_invoice_project_record_state.value()) - - __values = append(__values, limit, offset) - - var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) - obj.logStmt(__stmt, __values...) - - for { - rows, err = func() (rows []*StripecoinpaymentsInvoiceProjectRecord, err error) { - __rows, err := obj.driver.QueryContext(ctx, __stmt, __values...) - if err != nil { - return nil, err - } - defer __rows.Close() - - for __rows.Next() { - stripecoinpayments_invoice_project_record := &StripecoinpaymentsInvoiceProjectRecord{} - err = __rows.Scan(&stripecoinpayments_invoice_project_record.Id, &stripecoinpayments_invoice_project_record.ProjectId, &stripecoinpayments_invoice_project_record.Storage, &stripecoinpayments_invoice_project_record.Egress, &stripecoinpayments_invoice_project_record.Objects, &stripecoinpayments_invoice_project_record.Segments, &stripecoinpayments_invoice_project_record.PeriodStart, &stripecoinpayments_invoice_project_record.PeriodEnd, &stripecoinpayments_invoice_project_record.State, &stripecoinpayments_invoice_project_record.CreatedAt) - if err != nil { - return nil, err - } - rows = append(rows, stripecoinpayments_invoice_project_record) - } - err = __rows.Err() - if err != nil { - return nil, err - } - return rows, nil - }() - if err != nil { - if obj.shouldRetry(err) { - continue - } - return nil, obj.makeErr(err) - } - return rows, nil - } - -} - func (obj *pgxImpl) Get_StripecoinpaymentsTxConversionRate_By_TxId(ctx context.Context, stripecoinpayments_tx_conversion_rate_tx_id StripecoinpaymentsTxConversionRate_TxId_Field) ( stripecoinpayments_tx_conversion_rate *StripecoinpaymentsTxConversionRate, err error) { @@ -22416,57 +22365,6 @@ func (obj *pgxcockroachImpl) Get_StripecoinpaymentsInvoiceProjectRecord_By_Proje } -func (obj *pgxcockroachImpl) Limited_StripecoinpaymentsInvoiceProjectRecord_By_PeriodStart_And_PeriodEnd_And_State(ctx context.Context, - stripecoinpayments_invoice_project_record_period_start StripecoinpaymentsInvoiceProjectRecord_PeriodStart_Field, - stripecoinpayments_invoice_project_record_period_end StripecoinpaymentsInvoiceProjectRecord_PeriodEnd_Field, - stripecoinpayments_invoice_project_record_state StripecoinpaymentsInvoiceProjectRecord_State_Field, - limit int, offset int64) ( - rows []*StripecoinpaymentsInvoiceProjectRecord, err error) { - defer mon.Task()(&ctx)(&err) - - var __embed_stmt = __sqlbundle_Literal("SELECT stripecoinpayments_invoice_project_records.id, stripecoinpayments_invoice_project_records.project_id, stripecoinpayments_invoice_project_records.storage, stripecoinpayments_invoice_project_records.egress, stripecoinpayments_invoice_project_records.objects, stripecoinpayments_invoice_project_records.segments, stripecoinpayments_invoice_project_records.period_start, stripecoinpayments_invoice_project_records.period_end, stripecoinpayments_invoice_project_records.state, stripecoinpayments_invoice_project_records.created_at FROM stripecoinpayments_invoice_project_records WHERE stripecoinpayments_invoice_project_records.period_start = ? AND stripecoinpayments_invoice_project_records.period_end = ? AND stripecoinpayments_invoice_project_records.state = ? LIMIT ? OFFSET ?") - - var __values []interface{} - __values = append(__values, stripecoinpayments_invoice_project_record_period_start.value(), stripecoinpayments_invoice_project_record_period_end.value(), stripecoinpayments_invoice_project_record_state.value()) - - __values = append(__values, limit, offset) - - var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) - obj.logStmt(__stmt, __values...) - - for { - rows, err = func() (rows []*StripecoinpaymentsInvoiceProjectRecord, err error) { - __rows, err := obj.driver.QueryContext(ctx, __stmt, __values...) - if err != nil { - return nil, err - } - defer __rows.Close() - - for __rows.Next() { - stripecoinpayments_invoice_project_record := &StripecoinpaymentsInvoiceProjectRecord{} - err = __rows.Scan(&stripecoinpayments_invoice_project_record.Id, &stripecoinpayments_invoice_project_record.ProjectId, &stripecoinpayments_invoice_project_record.Storage, &stripecoinpayments_invoice_project_record.Egress, &stripecoinpayments_invoice_project_record.Objects, &stripecoinpayments_invoice_project_record.Segments, &stripecoinpayments_invoice_project_record.PeriodStart, &stripecoinpayments_invoice_project_record.PeriodEnd, &stripecoinpayments_invoice_project_record.State, &stripecoinpayments_invoice_project_record.CreatedAt) - if err != nil { - return nil, err - } - rows = append(rows, stripecoinpayments_invoice_project_record) - } - err = __rows.Err() - if err != nil { - return nil, err - } - return rows, nil - }() - if err != nil { - if obj.shouldRetry(err) { - continue - } - return nil, obj.makeErr(err) - } - return rows, nil - } - -} - func (obj *pgxcockroachImpl) Get_StripecoinpaymentsTxConversionRate_By_TxId(ctx context.Context, stripecoinpayments_tx_conversion_rate_tx_id StripecoinpaymentsTxConversionRate_TxId_Field) ( stripecoinpayments_tx_conversion_rate *StripecoinpaymentsTxConversionRate, err error) { @@ -29364,19 +29262,6 @@ func (rx *Rx) Limited_StorjscanPayment_By_ToAddress_OrderBy_Desc_BlockNumber_Des return tx.Limited_StorjscanPayment_By_ToAddress_OrderBy_Desc_BlockNumber_Desc_LogIndex(ctx, storjscan_payment_to_address, limit, offset) } -func (rx *Rx) Limited_StripecoinpaymentsInvoiceProjectRecord_By_PeriodStart_And_PeriodEnd_And_State(ctx context.Context, - stripecoinpayments_invoice_project_record_period_start StripecoinpaymentsInvoiceProjectRecord_PeriodStart_Field, - stripecoinpayments_invoice_project_record_period_end StripecoinpaymentsInvoiceProjectRecord_PeriodEnd_Field, - stripecoinpayments_invoice_project_record_state StripecoinpaymentsInvoiceProjectRecord_State_Field, - limit int, offset int64) ( - rows []*StripecoinpaymentsInvoiceProjectRecord, err error) { - var tx *Tx - if tx, err = rx.getTx(ctx); err != nil { - return - } - return tx.Limited_StripecoinpaymentsInvoiceProjectRecord_By_PeriodStart_And_PeriodEnd_And_State(ctx, stripecoinpayments_invoice_project_record_period_start, stripecoinpayments_invoice_project_record_period_end, stripecoinpayments_invoice_project_record_state, limit, offset) -} - func (rx *Rx) Paged_BucketBandwidthRollupArchive_By_IntervalStart_GreaterOrEqual(ctx context.Context, bucket_bandwidth_rollup_archive_interval_start_greater_or_equal BucketBandwidthRollupArchive_IntervalStart_Field, limit int, start *Paged_BucketBandwidthRollupArchive_By_IntervalStart_GreaterOrEqual_Continuation) ( @@ -30513,13 +30398,6 @@ type Methods interface { limit int, offset int64) ( rows []*StorjscanPayment, err error) - Limited_StripecoinpaymentsInvoiceProjectRecord_By_PeriodStart_And_PeriodEnd_And_State(ctx context.Context, - stripecoinpayments_invoice_project_record_period_start StripecoinpaymentsInvoiceProjectRecord_PeriodStart_Field, - stripecoinpayments_invoice_project_record_period_end StripecoinpaymentsInvoiceProjectRecord_PeriodEnd_Field, - stripecoinpayments_invoice_project_record_state StripecoinpaymentsInvoiceProjectRecord_State_Field, - limit int, offset int64) ( - rows []*StripecoinpaymentsInvoiceProjectRecord, err error) - Paged_BucketBandwidthRollupArchive_By_IntervalStart_GreaterOrEqual(ctx context.Context, bucket_bandwidth_rollup_archive_interval_start_greater_or_equal BucketBandwidthRollupArchive_IntervalStart_Field, limit int, start *Paged_BucketBandwidthRollupArchive_By_IntervalStart_GreaterOrEqual_Continuation) ( diff --git a/satellite/satellitedb/invoiceprojectrecords.go b/satellite/satellitedb/invoiceprojectrecords.go index fda85e7f1b06..210903029766 100644 --- a/satellite/satellitedb/invoiceprojectrecords.go +++ b/satellite/satellitedb/invoiceprojectrecords.go @@ -12,6 +12,7 @@ import ( "github.com/zeebo/errs" "storj.io/common/uuid" + "storj.io/private/tagsql" "storj.io/storj/satellite/payments/stripe" "storj.io/storj/satellite/satellitedb/dbx" ) @@ -129,36 +130,40 @@ func (db *invoiceProjectRecords) Consume(ctx context.Context, id uuid.UUID) (err } // ListUnapplied returns project records page with unapplied project records. -func (db *invoiceProjectRecords) ListUnapplied(ctx context.Context, offset int64, limit int, start, end time.Time) (_ stripe.ProjectRecordsPage, err error) { +// Cursor is not included into listing results. +func (db *invoiceProjectRecords) ListUnapplied(ctx context.Context, cursor uuid.UUID, limit int, start, end time.Time) (page stripe.ProjectRecordsPage, err error) { defer mon.Task()(&ctx)(&err) - var page stripe.ProjectRecordsPage + err = withRows(db.db.QueryContext(ctx, db.db.Rebind(` + SELECT + id, project_id, storage, egress, segments, period_start, period_end, state + FROM + stripecoinpayments_invoice_project_records + WHERE + id > ? AND period_start = ? AND period_end = ? AND state = ? + LIMIT ? + `), cursor, start, end, invoiceProjectRecordStateUnapplied.Int(), limit+1))(func(rows tagsql.Rows) error { + for rows.Next() { + var record stripe.ProjectRecord + err := rows.Scan(&record.ID, &record.ProjectID, &record.Storage, &record.Egress, &record.Segments, &record.PeriodStart, &record.PeriodEnd, &record.State) + if err != nil { + return Error.New("failed to scan stripe invoice project records: %w", err) + } - dbxRecords, err := db.db.Limited_StripecoinpaymentsInvoiceProjectRecord_By_PeriodStart_And_PeriodEnd_And_State(ctx, - dbx.StripecoinpaymentsInvoiceProjectRecord_PeriodStart(start), - dbx.StripecoinpaymentsInvoiceProjectRecord_PeriodEnd(end), - dbx.StripecoinpaymentsInvoiceProjectRecord_State(invoiceProjectRecordStateUnapplied.Int()), - limit+1, - offset, - ) + page.Records = append(page.Records, record) + } + return nil + }) if err != nil { return stripe.ProjectRecordsPage{}, err } - if len(dbxRecords) == limit+1 { + if len(page.Records) == limit+1 { page.Next = true - page.NextOffset = offset + int64(limit) - dbxRecords = dbxRecords[:len(dbxRecords)-1] - } - - for _, dbxRecord := range dbxRecords { - record, err := fromDBXInvoiceProjectRecord(dbxRecord) - if err != nil { - return stripe.ProjectRecordsPage{}, err - } + page.Records = page.Records[:len(page.Records)-1] - page.Records = append(page.Records, *record) + page.Cursor = page.Records[len(page.Records)-1].ID } return page, nil