Skip to content

Commit

Permalink
satellite/satellitedb: add billing ID info to customers db
Browse files Browse the repository at this point in the history
Add billing_customer_id column to stripe_customers table and add update
invoicing to use this field when present.

Change-Id: I4f0e42cb4923e8b092a7f29d41632c549b14b509
  • Loading branch information
dlamarmorgan committed Mar 20, 2024
1 parent 9dd6c51 commit 4eabe4a
Show file tree
Hide file tree
Showing 10 changed files with 993 additions and 17 deletions.
3 changes: 3 additions & 0 deletions satellite/payments/stripe/customers.go
Expand Up @@ -30,6 +30,8 @@ type CustomersDB interface {
UpdatePackage(ctx context.Context, userID uuid.UUID, packagePlan *string, timestamp *time.Time) (*Customer, error)
// GetPackageInfo returns the package plan and time of purchase for a user.
GetPackageInfo(ctx context.Context, userID uuid.UUID) (packagePlan *string, purchaseTime *time.Time, err error)
// GetStripeIDs returns stripe customer and billing ids.
GetStripeIDs(ctx context.Context, userID uuid.UUID) (billingID *string, customerID string, err error)

// TODO: get rid of this.
Raw() *dbx.DB
Expand All @@ -38,6 +40,7 @@ type CustomersDB interface {
// Customer holds customer id, user id, and package information.
type Customer struct {
ID string
BillingID *string
UserID uuid.UUID
PackagePlan *string
PackagePurchasedAt *time.Time
Expand Down
40 changes: 36 additions & 4 deletions satellite/payments/stripe/service.go
Expand Up @@ -540,7 +540,7 @@ func (service *Service) applyProjectRecords(ctx context.Context, records []Proje
continue
}

cusID, err := service.db.Customers().GetCustomerID(ctx, proj.OwnerID)
billingID, cusID, err := service.db.Customers().GetStripeIDs(ctx, proj.OwnerID)
if err != nil {
if errors.Is(err, ErrNoCustomer) {
service.log.Warn("Stripe customer does not exist for project owner.", zap.Stringer("Owner ID", proj.OwnerID), zap.Stringer("Project ID", proj.ID))
Expand All @@ -552,7 +552,7 @@ func (service *Service) applyProjectRecords(ctx context.Context, records []Proje

record := record
limiter.Go(ctx, func() {
skipped, err := service.createInvoiceItems(ctx, cusID, proj.Name, record, proj.OwnerID, period)
skipped, err := service.createInvoiceItems(ctx, billingID, cusID, proj.Name, record, proj.OwnerID, period)
if err != nil {
mu.Lock()
errGrp.Add(errs.Wrap(err))
Expand Down Expand Up @@ -618,7 +618,7 @@ func (service *Service) applyToBeAggregatedProjectRecords(ctx context.Context, r
}

// createInvoiceItems creates invoice line items for stripe customer.
func (service *Service) createInvoiceItems(ctx context.Context, cusID, projName string, record ProjectRecord, userID uuid.UUID, period time.Time) (skipped bool, err error) {
func (service *Service) createInvoiceItems(ctx context.Context, billingID *string, cusID, projName string, record ProjectRecord, userID uuid.UUID, period time.Time) (skipped bool, err error) {
defer mon.Task()(&ctx)(&err)

if !service.useIdempotency {
Expand Down Expand Up @@ -651,10 +651,21 @@ func (service *Service) createInvoiceItems(ctx context.Context, cusID, projName
}

items := service.InvoiceItemsFromProjectUsage(projName, usages, false)

var invoiceID *string
if billingID == nil {
billingID = &cusID
} else {
// create parent invoice
invoiceID, err = service.createParentInvoice(ctx, *billingID, cusID, projName, period)
}
for _, item := range items {
item.Params = stripe.Params{Context: ctx}
item.Currency = stripe.String(string(stripe.CurrencyUSD))
item.Customer = stripe.String(cusID)
item.Customer = stripe.String(*billingID)
if invoiceID != nil {
item.Invoice = invoiceID
}
// TODO: do not expose regular project ID.
item.AddMetadata("projectID", record.ProjectID.String())

Expand Down Expand Up @@ -1284,6 +1295,27 @@ func (service *Service) createInvoices(ctx context.Context, customers []Customer
return scheduled, draft, errGrp.Err()
}

// createParentInvoice creates a parent invoice for the customer.
func (service *Service) createParentInvoice(ctx context.Context, billingID, cusID, projName string, period time.Time) (invoiceID *string, err error) {
defer mon.Task()(&ctx)(&err)

description := fmt.Sprintf("Storj Cloud Storage for child project %s and period %s %d", projName, period.UTC().Month(), period.UTC().Year())
stripeInvoice, err := service.stripeClient.Invoices().New(
&stripe.InvoiceParams{
Params: stripe.Params{Context: ctx},
Customer: stripe.String(billingID),
AutoAdvance: stripe.Bool(false),
Description: stripe.String(description),
PendingInvoiceItemsBehavior: stripe.String("exclude"),
Metadata: map[string]string{"Child Account": cusID},
},
)
if err != nil {
return nil, err
}
return &stripeInvoice.ID, nil
}

// SetInvoiceStatus will set all open invoices within the specified date range to the requested status.
func (service *Service) SetInvoiceStatus(ctx context.Context, startPeriod, endPeriod time.Time, status string, dryRun bool) (err error) {
defer mon.Task()(&ctx)(&err)
Expand Down
21 changes: 19 additions & 2 deletions satellite/satellitedb/customers.go
Expand Up @@ -61,6 +61,22 @@ func (customers *customers) GetCustomerID(ctx context.Context, userID uuid.UUID)
return idRow.CustomerId, nil
}

// GetStripeIDs returns stripe customer and billing ids.
func (customers *customers) GetStripeIDs(ctx context.Context, userID uuid.UUID) (billingID *string, customerID string, err error) {
defer mon.Task()(&ctx, userID)(&err)

idRow, err := customers.db.Get_StripeCustomer_CustomerId_StripeCustomer_BillingCustomerId_By_UserId(ctx, dbx.StripeCustomer_UserId(userID[:]))
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, "", stripe.ErrNoCustomer
}

return nil, "", err
}

return idRow.BillingCustomerId, idRow.CustomerId, nil
}

// GetUserID return userID given stripe customer id.
func (customers *customers) GetUserID(ctx context.Context, customerID string) (_ uuid.UUID, err error) {
defer mon.Task()(&ctx)(&err)
Expand All @@ -83,7 +99,7 @@ func (customers *customers) List(ctx context.Context, userIDCursor uuid.UUID, li

rows, err := customers.db.QueryContext(ctx, customers.db.Rebind(`
SELECT
stripe_customers.user_id, stripe_customers.customer_id, stripe_customers.package_plan, stripe_customers.purchased_package_at
stripe_customers.user_id, stripe_customers.customer_id, stripe_customers.billing_customer_id, stripe_customers.package_plan, stripe_customers.purchased_package_at
FROM
stripe_customers
WHERE
Expand All @@ -103,7 +119,7 @@ func (customers *customers) List(ctx context.Context, userIDCursor uuid.UUID, li
results := []stripe.Customer{}
for rows.Next() {
var customer stripe.Customer
err := rows.Scan(&customer.UserID, &customer.ID, &customer.PackagePlan, &customer.PackagePurchasedAt)
err := rows.Scan(&customer.UserID, &customer.ID, &customer.BillingID, &customer.PackagePlan, &customer.PackagePurchasedAt)
if err != nil {
return stripe.CustomersPage{}, errs.New("unable to get stripe customer: %+v", err)
}
Expand Down Expand Up @@ -171,6 +187,7 @@ func fromDBXCustomer(dbxCustomer *dbx.StripeCustomer) (*stripe.Customer, error)
return &stripe.Customer{
ID: dbxCustomer.CustomerId,
UserID: userID,
BillingID: dbxCustomer.BillingCustomerId,
PackagePlan: dbxCustomer.PackagePlan,
PackagePurchasedAt: dbxCustomer.PurchasedPackageAt,
}, nil
Expand Down
6 changes: 6 additions & 0 deletions satellite/satellitedb/dbx/billing.dbx
Expand Up @@ -7,6 +7,8 @@ model stripe_customer (
field user_id blob
// customer_id is the Stripe customer identifier.
field customer_id text
// billing_customer_id is the Stripe customer identifier of the account to send invoices to.
field billing_customer_id text ( nullable, updatable )
// package_plan is the package plan currently applicable to the customer.
field package_plan text ( nullable, updatable )
// purchased_package_at is the time the package plan was purchased.
Expand All @@ -30,6 +32,10 @@ read one (
select stripe_customer.user_id
where stripe_customer.customer_id = ?
)
read one (
select stripe_customer.customer_id stripe_customer.billing_customer_id
where stripe_customer.user_id = ?
)

update stripe_customer (
where stripe_customer.user_id = ?
Expand Down

0 comments on commit 4eabe4a

Please sign in to comment.