diff --git a/satellite/payments/stripe/invoices.go b/satellite/payments/stripe/invoices.go index 25821cccde8c..5f7321083bd9 100644 --- a/satellite/payments/stripe/invoices.go +++ b/satellite/payments/stripe/invoices.go @@ -110,23 +110,52 @@ func (invoices *invoices) Get(ctx context.Context, invoiceID string) (*payments. // AttemptPayOverdueInvoices attempts to pay a user's open, overdue invoices. func (invoices *invoices) AttemptPayOverdueInvoices(ctx context.Context, userID uuid.UUID) (err error) { + defer mon.Task()(&ctx, userID)(&err) + customerID, err := invoices.service.db.Customers().GetCustomerID(ctx, userID) if err != nil { return Error.Wrap(err) } - params := &stripe.InvoiceListParams{ - ListParams: stripe.ListParams{Context: ctx}, - Customer: &customerID, - Status: stripe.String(string(stripe.InvoiceStatusOpen)), + stripeInvoices, err := invoices.service.getInvoices(ctx, customerID, time.Unix(0, 0)) + if err != nil { + return Error.Wrap(err) } - var errGrp errs.Group + if len(stripeInvoices) == 0 { + return nil + } - invoicesIterator := invoices.service.stripeClient.Invoices().List(params) - for invoicesIterator.Next() { - stripeInvoice := invoicesIterator.Invoice() + // first check users token balance + monetaryTokenBalance, err := invoices.service.billingDB.GetBalance(ctx, userID) + if err != nil { + invoices.service.log.Error("error getting token balance", zap.Error(err)) + return Error.Wrap(err) + } + if monetaryTokenBalance.BaseUnits() > 0 { + err := invoices.service.PayInvoicesWithTokenBalance(ctx, userID, customerID, stripeInvoices) + if err != nil { + invoices.service.log.Error("error paying invoice(s) with token balance", zap.Error(err)) + return Error.Wrap(err) + } + // get invoices again to see if any are still unpaid + stripeInvoices, err = invoices.service.getInvoices(ctx, customerID, time.Unix(0, 0)) + if err != nil { + invoices.service.log.Error("error getting invoices for stripe customer", zap.String(customerID, customerID), zap.Error(err)) + return Error.Wrap(err) + } + } + if len(stripeInvoices) > 0 { + return invoices.attemptPayOverdueInvoicesWithCC(ctx, stripeInvoices) + } + return nil +} +// AttemptPayOverdueInvoices attempts to pay a user's open, overdue invoices. +func (invoices *invoices) attemptPayOverdueInvoicesWithCC(ctx context.Context, stripeInvoices []stripe.Invoice) (err error) { + var errGrp errs.Group + + for _, stripeInvoice := range stripeInvoices { params := &stripe.InvoicePayParams{Params: stripe.Params{Context: ctx}} invResponse, err := invoices.service.stripeClient.Invoices().Pay(stripeInvoice.ID, params) if err != nil { @@ -140,10 +169,6 @@ func (invoices *invoices) AttemptPayOverdueInvoices(ctx context.Context, userID } - if err = invoicesIterator.Err(); err != nil { - return Error.Wrap(err) - } - return errGrp.Err() } diff --git a/satellite/payments/stripe/invoices_test.go b/satellite/payments/stripe/invoices_test.go index 16e30244b042..3c2b9da66598 100644 --- a/satellite/payments/stripe/invoices_test.go +++ b/satellite/payments/stripe/invoices_test.go @@ -5,13 +5,18 @@ package stripe_test import ( "testing" + "time" "github.com/stretchr/testify/require" "github.com/stripe/stripe-go/v72" + "storj.io/common/currency" "storj.io/common/testcontext" "storj.io/common/testrand" + "storj.io/storj/private/blockchain" "storj.io/storj/private/testplanet" + "storj.io/storj/satellite/console" + "storj.io/storj/satellite/payments/billing" stripe1 "storj.io/storj/satellite/payments/stripe" ) @@ -71,3 +76,148 @@ func TestInvoices(t *testing.T) { }) }) } + +func TestPayOverdueInvoices(t *testing.T) { + testplanet.Run(t, testplanet.Config{ + SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0, + }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) { + satellite := planet.Satellites[0] + + // create user + user, err := satellite.AddUser(ctx, console.CreateUser{ + FullName: "testuser", + Email: "user@test", + }, 1) + require.NoError(t, err) + customer, err := satellite.DB.StripeCoinPayments().Customers().GetCustomerID(ctx, user.ID) + require.NoError(t, err) + + amount1 := int64(75) + amount2 := int64(100) + curr := string(stripe.CurrencyUSD) + + // create invoice items for first invoice + inv1Item1, err := satellite.API.Payments.StripeClient.InvoiceItems().New(&stripe.InvoiceItemParams{ + Params: stripe.Params{Context: ctx}, + Amount: &amount1, + Currency: &curr, + Customer: &customer, + }) + require.NoError(t, err) + inv1Item2, err := satellite.API.Payments.StripeClient.InvoiceItems().New(&stripe.InvoiceItemParams{ + Params: stripe.Params{Context: ctx}, + Amount: &amount1, + Currency: &curr, + Customer: &customer, + }) + require.NoError(t, err) + Inv1Items := make([]*stripe.InvoiceUpcomingInvoiceItemParams, 0, 2) + Inv1Items = append(Inv1Items, &stripe.InvoiceUpcomingInvoiceItemParams{ + InvoiceItem: &inv1Item1.ID, + Amount: &amount1, + Currency: &curr, + }) + Inv1Items = append(Inv1Items, &stripe.InvoiceUpcomingInvoiceItemParams{ + InvoiceItem: &inv1Item2.ID, + Amount: &amount1, + Currency: &curr, + }) + + // invoice items for second invoice + inv2Item1, err := satellite.API.Payments.StripeClient.InvoiceItems().New(&stripe.InvoiceItemParams{ + Params: stripe.Params{Context: ctx}, + Amount: &amount2, + Currency: &curr, + Customer: &customer, + }) + require.NoError(t, err) + inv2Item2, err := satellite.API.Payments.StripeClient.InvoiceItems().New(&stripe.InvoiceItemParams{ + Params: stripe.Params{Context: ctx}, + Amount: &amount2, + Currency: &curr, + Customer: &customer, + }) + require.NoError(t, err) + Inv2Items := make([]*stripe.InvoiceUpcomingInvoiceItemParams, 0, 2) + Inv2Items = append(Inv2Items, &stripe.InvoiceUpcomingInvoiceItemParams{ + InvoiceItem: &inv2Item1.ID, + Amount: &amount2, + Currency: &curr, + }) + Inv2Items = append(Inv2Items, &stripe.InvoiceUpcomingInvoiceItemParams{ + InvoiceItem: &inv2Item2.ID, + Amount: &amount2, + Currency: &curr, + }) + + // create invoice one + inv1, err := satellite.API.Payments.StripeClient.Invoices().New(&stripe.InvoiceParams{ + Params: stripe.Params{Context: ctx}, + Customer: &customer, + InvoiceItems: Inv1Items, + DefaultPaymentMethod: stripe.String(stripe1.MockInvoicesPaySuccess), + }) + require.NoError(t, err) + + // create invoice two + inv2, err := satellite.API.Payments.StripeClient.Invoices().New(&stripe.InvoiceParams{ + Params: stripe.Params{Context: ctx}, + Customer: &customer, + InvoiceItems: Inv2Items, + DefaultPaymentMethod: stripe.String(stripe1.MockInvoicesPaySuccess), + }) + require.NoError(t, err) + + finalizeParams := &stripe.InvoiceFinalizeParams{Params: stripe.Params{Context: ctx}} + + // finalize invoice one + inv1, err = satellite.API.Payments.StripeClient.Invoices().FinalizeInvoice(inv1.ID, finalizeParams) + require.NoError(t, err) + require.Equal(t, stripe.InvoiceStatusOpen, inv1.Status) + + // finalize invoice two + inv2, err = satellite.API.Payments.StripeClient.Invoices().FinalizeInvoice(inv2.ID, finalizeParams) + require.NoError(t, err) + require.Equal(t, stripe.InvoiceStatusOpen, inv2.Status) + + // setup storjscan wallet and user balance + address, err := blockchain.BytesToAddress(testrand.Bytes(20)) + require.NoError(t, err) + userID := user.ID + err = satellite.DB.Wallets().Add(ctx, userID, address) + require.NoError(t, err) + // User balance is not enough to cover full amount of both invoices + _, err = satellite.DB.Billing().Insert(ctx, billing.Transaction{ + UserID: userID, + Amount: currency.AmountFromBaseUnits(300, currency.USDollars), + Description: "token payment credit", + Source: billing.StorjScanSource, + Status: billing.TransactionStatusCompleted, + Type: billing.TransactionTypeCredit, + Metadata: nil, + Timestamp: time.Now(), + CreatedAt: time.Now(), + }) + require.NoError(t, err) + + // attempt to pay user invoices, CC should be used to cover remainder after token balance + err = satellite.API.Payments.Accounts.Invoices().AttemptPayOverdueInvoices(ctx, userID) + require.NoError(t, err) + + iter := satellite.API.Payments.StripeClient.Invoices().List(&stripe.InvoiceListParams{ + ListParams: stripe.ListParams{Context: ctx}, + }) + var stripeInvoices []*stripe.Invoice + for iter.Next() { + stripeInvoices = append(stripeInvoices, iter.Invoice()) + } + require.Equal(t, 2, len(stripeInvoices)) + require.Equal(t, stripe.InvoiceStatusPaid, stripeInvoices[0].Status) + require.Equal(t, stripe.InvoiceStatusPaid, stripeInvoices[1].Status) + require.NoError(t, iter.Err()) + balance, err := satellite.DB.Billing().GetBalance(ctx, userID) + require.NoError(t, err) + require.False(t, balance.IsNegative()) + require.Zero(t, balance.BaseUnits()) + }) +}