diff --git a/satellite/payments/stripe/projectrecords.go b/satellite/payments/stripe/projectrecords.go index 8b16002aacfe..435487c4b9a6 100644 --- a/satellite/payments/stripe/projectrecords.go +++ b/satellite/payments/stripe/projectrecords.go @@ -26,7 +26,8 @@ type ProjectRecordsDB interface { // Consume consumes invoice project record. Consume(ctx context.Context, id uuid.UUID) error // ListUnapplied returns project records page with unapplied project records. - ListUnapplied(ctx context.Context, offset int64, limit int, start, end time.Time) (ProjectRecordsPage, error) + // Cursor is not included into listing results. + ListUnapplied(ctx context.Context, cursor uuid.UUID, limit int, start, end time.Time) (ProjectRecordsPage, error) } // CreateProjectRecord holds info needed for creation new invoice @@ -52,9 +53,9 @@ type ProjectRecord struct { // ProjectRecordsPage holds project records and // indicates if there is more data available -// and provides next offset. +// and provides cursor for next listing. type ProjectRecordsPage struct { - Records []ProjectRecord - Next bool - NextOffset int64 + Records []ProjectRecord + Next bool + Cursor uuid.UUID } diff --git a/satellite/payments/stripe/projectrecords_test.go b/satellite/payments/stripe/projectrecords_test.go index 4f3f7667cae9..86772493a88a 100644 --- a/satellite/payments/stripe/projectrecords_test.go +++ b/satellite/payments/stripe/projectrecords_test.go @@ -4,6 +4,7 @@ package stripe_test import ( + "fmt" "testing" "time" @@ -50,7 +51,7 @@ func TestProjectRecords(t *testing.T) { assert.Equal(t, stripe.ErrProjectRecordExists, err) }) - page, err := projectRecordsDB.ListUnapplied(ctx, 0, 1, start, end) + page, err := projectRecordsDB.ListUnapplied(ctx, uuid.UUID{}, 1, start, end) require.NoError(t, err) require.Equal(t, 1, len(page.Records)) @@ -59,7 +60,7 @@ func TestProjectRecords(t *testing.T) { require.NoError(t, err) }) - page, err = projectRecordsDB.ListUnapplied(ctx, 0, 1, start, end) + page, err = projectRecordsDB.ListUnapplied(ctx, uuid.UUID{}, 1, start, end) require.NoError(t, err) require.Equal(t, 0, len(page.Records)) }) @@ -74,8 +75,7 @@ func TestProjectRecordsList(t *testing.T) { projectRecordsDB := db.StripeCoinPayments().ProjectRecords() - const limit = 5 - const recordsLen = limit * 4 + const recordsLen = 20 var createProjectRecords []stripe.CreateProjectRecord for i := 0; i < recordsLen; i++ { @@ -95,37 +95,42 @@ func TestProjectRecordsList(t *testing.T) { err := projectRecordsDB.Create(ctx, createProjectRecords, start, end) require.NoError(t, err) - page, err := projectRecordsDB.ListUnapplied(ctx, 0, limit, start, end) - require.NoError(t, err) - - records := page.Records + for _, limit := range []int{1, 3, 5, 30} { + t.Run(fmt.Sprintf("limit-%d", limit), func(t *testing.T) { + records := []stripe.ProjectRecord{} - for page.Next { - page, err = projectRecordsDB.ListUnapplied(ctx, page.NextOffset, limit, start, end) - require.NoError(t, err) + var page stripe.ProjectRecordsPage + for { + page, err = projectRecordsDB.ListUnapplied(ctx, page.Cursor, limit, start, end) + require.NoError(t, err) - records = append(records, page.Records...) - } - - require.Equal(t, recordsLen, len(records)) - assert.False(t, page.Next) - assert.Equal(t, int64(0), page.NextOffset) - - for _, record := range page.Records { - for _, createRecord := range createProjectRecords { - if record.ProjectID != createRecord.ProjectID { - continue + records = append(records, page.Records...) + if !page.Next { + break + } } - assert.NotNil(t, record.ID) - assert.Equal(t, 16, len(record.ID)) - assert.Equal(t, createRecord.ProjectID, record.ProjectID) - assert.Equal(t, createRecord.Storage, record.Storage) - assert.Equal(t, createRecord.Egress, record.Egress) - assert.Equal(t, createRecord.Segments, record.Segments) - assert.True(t, start.Equal(record.PeriodStart)) - assert.True(t, end.Equal(record.PeriodEnd)) - } + require.Equal(t, recordsLen, len(records)) + assert.False(t, page.Next) + assert.Equal(t, uuid.UUID{}, page.Cursor) + + for _, record := range page.Records { + for _, createRecord := range createProjectRecords { + if record.ProjectID != createRecord.ProjectID { + continue + } + + assert.NotNil(t, record.ID) + assert.Equal(t, 16, len(record.ID)) + assert.Equal(t, createRecord.ProjectID, record.ProjectID) + assert.Equal(t, createRecord.Storage, record.Storage) + assert.Equal(t, createRecord.Egress, record.Egress) + assert.Equal(t, createRecord.Segments, record.Segments) + assert.True(t, start.Equal(record.PeriodStart)) + assert.True(t, end.Equal(record.PeriodEnd)) + } + } + }) } }) } diff --git a/satellite/payments/stripe/service.go b/satellite/payments/stripe/service.go index 65ed0f8db964..9d3e5b27e318 100644 --- a/satellite/payments/stripe/service.go +++ b/satellite/payments/stripe/service.go @@ -243,7 +243,7 @@ func (service *Service) InvoiceApplyProjectRecords(ctx context.Context, period t } // we are always starting from offset 0 because applyProjectRecords is changing project record state to applied - recordsPage, err := service.db.ProjectRecords().ListUnapplied(ctx, 0, service.listingLimit, start, end) + recordsPage, err := service.db.ProjectRecords().ListUnapplied(ctx, uuid.UUID{}, service.listingLimit, start, end) if err != nil { return Error.Wrap(err) } diff --git a/satellite/payments/stripe/service_test.go b/satellite/payments/stripe/service_test.go index 1f58edf11486..90a2fe4958dc 100644 --- a/satellite/payments/stripe/service_test.go +++ b/satellite/payments/stripe/service_test.go @@ -139,7 +139,7 @@ func TestService_InvoiceElementsProcessing(t *testing.T) { end := time.Date(period.Year(), period.Month()+1, 1, 0, 0, 0, 0, time.UTC) // check if we have project record for each project - recordsPage, err := satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, 0, 40, start, end) + recordsPage, err := satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, uuid.UUID{}, 40, start, end) require.NoError(t, err) require.Equal(t, numberOfProjects, len(recordsPage.Records)) @@ -147,7 +147,7 @@ func TestService_InvoiceElementsProcessing(t *testing.T) { require.NoError(t, err) // verify that we applied all unapplied project records - recordsPage, err = satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, 0, 40, start, end) + recordsPage, err = satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, uuid.UUID{}, 40, start, end) require.NoError(t, err) require.Equal(t, 0, len(recordsPage.Records)) }) @@ -284,7 +284,7 @@ func TestService_ProjectsWithMembers(t *testing.T) { start := time.Date(period.Year(), period.Month(), 1, 0, 0, 0, 0, time.UTC) end := time.Date(period.Year(), period.Month()+1, 1, 0, 0, 0, 0, time.UTC) - recordsPage, err := satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, 0, 40, start, end) + recordsPage, err := satellite.DB.StripeCoinPayments().ProjectRecords().ListUnapplied(ctx, uuid.UUID{}, 40, start, end) require.NoError(t, err) require.Equal(t, len(projects), len(recordsPage.Records)) }) diff --git a/satellite/satellitedb/database.go b/satellite/satellitedb/database.go index aada891a1c8c..498b0df5314e 100644 --- a/satellite/satellitedb/database.go +++ b/satellite/satellitedb/database.go @@ -398,3 +398,13 @@ func (dbc *satelliteDBCollectionTesting) ProductionMigration() *migrate.Migratio func (dbc *satelliteDBCollectionTesting) TestMigration() *migrate.Migration { return dbc.getByName("").TestMigration() } + +func withRows(rows tagsql.Rows, err error) func(func(tagsql.Rows) error) error { + return func(callback func(tagsql.Rows) error) error { + if err != nil { + return err + } + err := callback(rows) + return errs.Combine(rows.Err(), rows.Close(), err) + } +} diff --git a/satellite/satellitedb/dbx/billing.dbx b/satellite/satellitedb/dbx/billing.dbx index a901cd3bb004..c3e0219ffb08 100644 --- a/satellite/satellitedb/dbx/billing.dbx +++ b/satellite/satellitedb/dbx/billing.dbx @@ -232,12 +232,6 @@ read one ( where stripecoinpayments_invoice_project_record.period_start = ? where stripecoinpayments_invoice_project_record.period_end = ? ) -read limitoffset ( - select stripecoinpayments_invoice_project_record - where stripecoinpayments_invoice_project_record.period_start = ? - where stripecoinpayments_invoice_project_record.period_end = ? - where stripecoinpayments_invoice_project_record.state = ? -) // stripecoinpayments_tx_conversion_rate contains information about a conversion-rate that was used in a transaction. model stripecoinpayments_tx_conversion_rate ( diff --git a/satellite/satellitedb/dbx/satellitedb.dbx.go b/satellite/satellitedb/dbx/satellitedb.dbx.go index 8e2b787fc65f..dd0ed1d2ce54 100644 --- a/satellite/satellitedb/dbx/satellitedb.dbx.go +++ b/satellite/satellitedb/dbx/satellitedb.dbx.go @@ -14567,57 +14567,6 @@ func (obj *pgxImpl) Get_StripecoinpaymentsInvoiceProjectRecord_By_ProjectId_And_ } -func (obj *pgxImpl) Limited_StripecoinpaymentsInvoiceProjectRecord_By_PeriodStart_And_PeriodEnd_And_State(ctx context.Context, - stripecoinpayments_invoice_project_record_period_start StripecoinpaymentsInvoiceProjectRecord_PeriodStart_Field, - stripecoinpayments_invoice_project_record_period_end StripecoinpaymentsInvoiceProjectRecord_PeriodEnd_Field, - stripecoinpayments_invoice_project_record_state StripecoinpaymentsInvoiceProjectRecord_State_Field, - limit int, offset int64) ( - rows []*StripecoinpaymentsInvoiceProjectRecord, err error) { - defer mon.Task()(&ctx)(&err) - - var __embed_stmt = __sqlbundle_Literal("SELECT stripecoinpayments_invoice_project_records.id, stripecoinpayments_invoice_project_records.project_id, stripecoinpayments_invoice_project_records.storage, stripecoinpayments_invoice_project_records.egress, stripecoinpayments_invoice_project_records.objects, stripecoinpayments_invoice_project_records.segments, stripecoinpayments_invoice_project_records.period_start, stripecoinpayments_invoice_project_records.period_end, stripecoinpayments_invoice_project_records.state, stripecoinpayments_invoice_project_records.created_at FROM stripecoinpayments_invoice_project_records WHERE stripecoinpayments_invoice_project_records.period_start = ? AND stripecoinpayments_invoice_project_records.period_end = ? AND stripecoinpayments_invoice_project_records.state = ? LIMIT ? OFFSET ?") - - var __values []interface{} - __values = append(__values, stripecoinpayments_invoice_project_record_period_start.value(), stripecoinpayments_invoice_project_record_period_end.value(), stripecoinpayments_invoice_project_record_state.value()) - - __values = append(__values, limit, offset) - - var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) - obj.logStmt(__stmt, __values...) - - for { - rows, err = func() (rows []*StripecoinpaymentsInvoiceProjectRecord, err error) { - __rows, err := obj.driver.QueryContext(ctx, __stmt, __values...) - if err != nil { - return nil, err - } - defer __rows.Close() - - for __rows.Next() { - stripecoinpayments_invoice_project_record := &StripecoinpaymentsInvoiceProjectRecord{} - err = __rows.Scan(&stripecoinpayments_invoice_project_record.Id, &stripecoinpayments_invoice_project_record.ProjectId, &stripecoinpayments_invoice_project_record.Storage, &stripecoinpayments_invoice_project_record.Egress, &stripecoinpayments_invoice_project_record.Objects, &stripecoinpayments_invoice_project_record.Segments, &stripecoinpayments_invoice_project_record.PeriodStart, &stripecoinpayments_invoice_project_record.PeriodEnd, &stripecoinpayments_invoice_project_record.State, &stripecoinpayments_invoice_project_record.CreatedAt) - if err != nil { - return nil, err - } - rows = append(rows, stripecoinpayments_invoice_project_record) - } - err = __rows.Err() - if err != nil { - return nil, err - } - return rows, nil - }() - if err != nil { - if obj.shouldRetry(err) { - continue - } - return nil, obj.makeErr(err) - } - return rows, nil - } - -} - func (obj *pgxImpl) Get_StripecoinpaymentsTxConversionRate_By_TxId(ctx context.Context, stripecoinpayments_tx_conversion_rate_tx_id StripecoinpaymentsTxConversionRate_TxId_Field) ( stripecoinpayments_tx_conversion_rate *StripecoinpaymentsTxConversionRate, err error) { @@ -22416,57 +22365,6 @@ func (obj *pgxcockroachImpl) Get_StripecoinpaymentsInvoiceProjectRecord_By_Proje } -func (obj *pgxcockroachImpl) Limited_StripecoinpaymentsInvoiceProjectRecord_By_PeriodStart_And_PeriodEnd_And_State(ctx context.Context, - stripecoinpayments_invoice_project_record_period_start StripecoinpaymentsInvoiceProjectRecord_PeriodStart_Field, - stripecoinpayments_invoice_project_record_period_end StripecoinpaymentsInvoiceProjectRecord_PeriodEnd_Field, - stripecoinpayments_invoice_project_record_state StripecoinpaymentsInvoiceProjectRecord_State_Field, - limit int, offset int64) ( - rows []*StripecoinpaymentsInvoiceProjectRecord, err error) { - defer mon.Task()(&ctx)(&err) - - var __embed_stmt = __sqlbundle_Literal("SELECT stripecoinpayments_invoice_project_records.id, stripecoinpayments_invoice_project_records.project_id, stripecoinpayments_invoice_project_records.storage, stripecoinpayments_invoice_project_records.egress, stripecoinpayments_invoice_project_records.objects, stripecoinpayments_invoice_project_records.segments, stripecoinpayments_invoice_project_records.period_start, stripecoinpayments_invoice_project_records.period_end, stripecoinpayments_invoice_project_records.state, stripecoinpayments_invoice_project_records.created_at FROM stripecoinpayments_invoice_project_records WHERE stripecoinpayments_invoice_project_records.period_start = ? AND stripecoinpayments_invoice_project_records.period_end = ? AND stripecoinpayments_invoice_project_records.state = ? LIMIT ? OFFSET ?") - - var __values []interface{} - __values = append(__values, stripecoinpayments_invoice_project_record_period_start.value(), stripecoinpayments_invoice_project_record_period_end.value(), stripecoinpayments_invoice_project_record_state.value()) - - __values = append(__values, limit, offset) - - var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) - obj.logStmt(__stmt, __values...) - - for { - rows, err = func() (rows []*StripecoinpaymentsInvoiceProjectRecord, err error) { - __rows, err := obj.driver.QueryContext(ctx, __stmt, __values...) - if err != nil { - return nil, err - } - defer __rows.Close() - - for __rows.Next() { - stripecoinpayments_invoice_project_record := &StripecoinpaymentsInvoiceProjectRecord{} - err = __rows.Scan(&stripecoinpayments_invoice_project_record.Id, &stripecoinpayments_invoice_project_record.ProjectId, &stripecoinpayments_invoice_project_record.Storage, &stripecoinpayments_invoice_project_record.Egress, &stripecoinpayments_invoice_project_record.Objects, &stripecoinpayments_invoice_project_record.Segments, &stripecoinpayments_invoice_project_record.PeriodStart, &stripecoinpayments_invoice_project_record.PeriodEnd, &stripecoinpayments_invoice_project_record.State, &stripecoinpayments_invoice_project_record.CreatedAt) - if err != nil { - return nil, err - } - rows = append(rows, stripecoinpayments_invoice_project_record) - } - err = __rows.Err() - if err != nil { - return nil, err - } - return rows, nil - }() - if err != nil { - if obj.shouldRetry(err) { - continue - } - return nil, obj.makeErr(err) - } - return rows, nil - } - -} - func (obj *pgxcockroachImpl) Get_StripecoinpaymentsTxConversionRate_By_TxId(ctx context.Context, stripecoinpayments_tx_conversion_rate_tx_id StripecoinpaymentsTxConversionRate_TxId_Field) ( stripecoinpayments_tx_conversion_rate *StripecoinpaymentsTxConversionRate, err error) { @@ -29364,19 +29262,6 @@ func (rx *Rx) Limited_StorjscanPayment_By_ToAddress_OrderBy_Desc_BlockNumber_Des return tx.Limited_StorjscanPayment_By_ToAddress_OrderBy_Desc_BlockNumber_Desc_LogIndex(ctx, storjscan_payment_to_address, limit, offset) } -func (rx *Rx) Limited_StripecoinpaymentsInvoiceProjectRecord_By_PeriodStart_And_PeriodEnd_And_State(ctx context.Context, - stripecoinpayments_invoice_project_record_period_start StripecoinpaymentsInvoiceProjectRecord_PeriodStart_Field, - stripecoinpayments_invoice_project_record_period_end StripecoinpaymentsInvoiceProjectRecord_PeriodEnd_Field, - stripecoinpayments_invoice_project_record_state StripecoinpaymentsInvoiceProjectRecord_State_Field, - limit int, offset int64) ( - rows []*StripecoinpaymentsInvoiceProjectRecord, err error) { - var tx *Tx - if tx, err = rx.getTx(ctx); err != nil { - return - } - return tx.Limited_StripecoinpaymentsInvoiceProjectRecord_By_PeriodStart_And_PeriodEnd_And_State(ctx, stripecoinpayments_invoice_project_record_period_start, stripecoinpayments_invoice_project_record_period_end, stripecoinpayments_invoice_project_record_state, limit, offset) -} - func (rx *Rx) Paged_BucketBandwidthRollupArchive_By_IntervalStart_GreaterOrEqual(ctx context.Context, bucket_bandwidth_rollup_archive_interval_start_greater_or_equal BucketBandwidthRollupArchive_IntervalStart_Field, limit int, start *Paged_BucketBandwidthRollupArchive_By_IntervalStart_GreaterOrEqual_Continuation) ( @@ -30513,13 +30398,6 @@ type Methods interface { limit int, offset int64) ( rows []*StorjscanPayment, err error) - Limited_StripecoinpaymentsInvoiceProjectRecord_By_PeriodStart_And_PeriodEnd_And_State(ctx context.Context, - stripecoinpayments_invoice_project_record_period_start StripecoinpaymentsInvoiceProjectRecord_PeriodStart_Field, - stripecoinpayments_invoice_project_record_period_end StripecoinpaymentsInvoiceProjectRecord_PeriodEnd_Field, - stripecoinpayments_invoice_project_record_state StripecoinpaymentsInvoiceProjectRecord_State_Field, - limit int, offset int64) ( - rows []*StripecoinpaymentsInvoiceProjectRecord, err error) - Paged_BucketBandwidthRollupArchive_By_IntervalStart_GreaterOrEqual(ctx context.Context, bucket_bandwidth_rollup_archive_interval_start_greater_or_equal BucketBandwidthRollupArchive_IntervalStart_Field, limit int, start *Paged_BucketBandwidthRollupArchive_By_IntervalStart_GreaterOrEqual_Continuation) ( diff --git a/satellite/satellitedb/invoiceprojectrecords.go b/satellite/satellitedb/invoiceprojectrecords.go index fda85e7f1b06..210903029766 100644 --- a/satellite/satellitedb/invoiceprojectrecords.go +++ b/satellite/satellitedb/invoiceprojectrecords.go @@ -12,6 +12,7 @@ import ( "github.com/zeebo/errs" "storj.io/common/uuid" + "storj.io/private/tagsql" "storj.io/storj/satellite/payments/stripe" "storj.io/storj/satellite/satellitedb/dbx" ) @@ -129,36 +130,40 @@ func (db *invoiceProjectRecords) Consume(ctx context.Context, id uuid.UUID) (err } // ListUnapplied returns project records page with unapplied project records. -func (db *invoiceProjectRecords) ListUnapplied(ctx context.Context, offset int64, limit int, start, end time.Time) (_ stripe.ProjectRecordsPage, err error) { +// Cursor is not included into listing results. +func (db *invoiceProjectRecords) ListUnapplied(ctx context.Context, cursor uuid.UUID, limit int, start, end time.Time) (page stripe.ProjectRecordsPage, err error) { defer mon.Task()(&ctx)(&err) - var page stripe.ProjectRecordsPage + err = withRows(db.db.QueryContext(ctx, db.db.Rebind(` + SELECT + id, project_id, storage, egress, segments, period_start, period_end, state + FROM + stripecoinpayments_invoice_project_records + WHERE + id > ? AND period_start = ? AND period_end = ? AND state = ? + LIMIT ? + `), cursor, start, end, invoiceProjectRecordStateUnapplied.Int(), limit+1))(func(rows tagsql.Rows) error { + for rows.Next() { + var record stripe.ProjectRecord + err := rows.Scan(&record.ID, &record.ProjectID, &record.Storage, &record.Egress, &record.Segments, &record.PeriodStart, &record.PeriodEnd, &record.State) + if err != nil { + return Error.New("failed to scan stripe invoice project records: %w", err) + } - dbxRecords, err := db.db.Limited_StripecoinpaymentsInvoiceProjectRecord_By_PeriodStart_And_PeriodEnd_And_State(ctx, - dbx.StripecoinpaymentsInvoiceProjectRecord_PeriodStart(start), - dbx.StripecoinpaymentsInvoiceProjectRecord_PeriodEnd(end), - dbx.StripecoinpaymentsInvoiceProjectRecord_State(invoiceProjectRecordStateUnapplied.Int()), - limit+1, - offset, - ) + page.Records = append(page.Records, record) + } + return nil + }) if err != nil { return stripe.ProjectRecordsPage{}, err } - if len(dbxRecords) == limit+1 { + if len(page.Records) == limit+1 { page.Next = true - page.NextOffset = offset + int64(limit) - dbxRecords = dbxRecords[:len(dbxRecords)-1] - } - - for _, dbxRecord := range dbxRecords { - record, err := fromDBXInvoiceProjectRecord(dbxRecord) - if err != nil { - return stripe.ProjectRecordsPage{}, err - } + page.Records = page.Records[:len(page.Records)-1] - page.Records = append(page.Records, *record) + page.Cursor = page.Records[len(page.Records)-1].ID } return page, nil