Skip to content
8 changes: 8 additions & 0 deletions admin/billing.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ func (s *Service) InitOrganizationBilling(ctx context.Context, org *database.Org
BillingCustomerID: org.BillingCustomerID,
PaymentCustomerID: org.PaymentCustomerID,
BillingEmail: org.BillingEmail,
BillingPlanName: org.BillingPlanName,
BillingPlanDisplayName: org.BillingPlanDisplayName,
CreatedByUserID: org.CreatedByUserID,
})
if err != nil {
Expand Down Expand Up @@ -162,6 +164,8 @@ func (s *Service) RepairOrganizationBilling(ctx context.Context, org *database.O
BillingCustomerID: org.BillingCustomerID,
PaymentCustomerID: org.PaymentCustomerID,
BillingEmail: org.BillingEmail,
BillingPlanName: org.BillingPlanName,
BillingPlanDisplayName: org.BillingPlanDisplayName,
CreatedByUserID: org.CreatedByUserID,
})
if err != nil {
Expand Down Expand Up @@ -221,6 +225,8 @@ func (s *Service) RepairOrganizationBilling(ctx context.Context, org *database.O
BillingCustomerID: org.BillingCustomerID,
PaymentCustomerID: org.PaymentCustomerID,
BillingEmail: org.BillingEmail,
BillingPlanName: org.BillingPlanName,
BillingPlanDisplayName: org.BillingPlanDisplayName,
CreatedByUserID: org.CreatedByUserID,
})
if err != nil {
Expand Down Expand Up @@ -294,6 +300,8 @@ func (s *Service) StartTrial(ctx context.Context, org *database.Organization) (*
BillingCustomerID: org.BillingCustomerID,
PaymentCustomerID: org.PaymentCustomerID,
BillingEmail: org.BillingEmail,
BillingPlanName: &plan.Name,
BillingPlanDisplayName: &plan.DisplayName,
CreatedByUserID: org.CreatedByUserID,
})
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions admin/billing/orb.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ func (o *Orb) ChangeSubscriptionPlan(ctx context.Context, subscriptionID string,
if err != nil {
return nil, err
}

return &Subscription{
ID: s.ID,
Customer: getBillingCustomerFromOrbCustomer(&s.Customer),
Expand Down Expand Up @@ -282,6 +283,7 @@ func (o *Orb) CancelSubscriptionsForCustomer(ctx context.Context, customerID str
cancelDate = sub.EndDate
}
}

return cancelDate, nil
}

Expand Down
58 changes: 56 additions & 2 deletions admin/billing/orb_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/orbcorp/orb-go"
"github.com/rilldata/rill/admin/jobs"
"github.com/rilldata/rill/runtime/pkg/httputil"
"github.com/rilldata/rill/runtime/pkg/observability"
"go.uber.org/zap"
)

Expand All @@ -25,7 +26,7 @@ const (
maxBodyBytes = int64(65536)
)

var interestingEvents = []string{"invoice.payment_succeeded", "invoice.payment_failed", "invoice.issue_failed"}
var interestingEvents = []string{"invoice.payment_succeeded", "invoice.payment_failed", "invoice.issue_failed", "subscription.started", "subscription.ended", "subscription.plan_changed"}

type orbWebhook struct {
orb *Orb
Expand Down Expand Up @@ -86,6 +87,36 @@ func (o *orbWebhook) handleWebhook(w http.ResponseWriter, r *http.Request) error
}
// inefficient one time conversion to named logger as its rare event and no need to log every thing else with named logger
o.orb.logger.Named("billing").Warn("invoice issue failed", zap.String("customer_id", ie.OrbInvoice.Customer.ExternalCustomerID), zap.String("invoice_id", ie.OrbInvoice.ID), zap.String("props", fmt.Sprintf("%v", ie.Properties)))
case "subscription.started":
var se subscriptionEvent
err = json.Unmarshal(payload, &se)
if err != nil {
return httputil.Errorf(http.StatusBadRequest, "error parsing event data: %w", err)
}
err = o.handlePlanChange(r.Context(), se) // as of now we are just using this to update plan cache
if err != nil {
return httputil.Errorf(http.StatusInternalServerError, "error handling event: %w", err)
}
case "subscription.ended":
var se subscriptionEvent
err = json.Unmarshal(payload, &se)
if err != nil {
return httputil.Errorf(http.StatusBadRequest, "error parsing event data: %w", err)
}
err = o.handlePlanChange(r.Context(), se) // as of now we are just using this to update plan cache
if err != nil {
return httputil.Errorf(http.StatusInternalServerError, "error handling event: %w", err)
}
case "subscription.plan_changed":
var se subscriptionEvent
err = json.Unmarshal(payload, &se)
if err != nil {
return httputil.Errorf(http.StatusBadRequest, "error parsing event data: %w", err)
}
err = o.handlePlanChange(r.Context(), se) // as of now we are just using this to update plan cache
if err != nil {
return httputil.Errorf(http.StatusInternalServerError, "error handling event: %w", err)
}
default:
// do nothing
}
Expand All @@ -97,10 +128,11 @@ func (o *orbWebhook) handleWebhook(w http.ResponseWriter, r *http.Request) error
func (o *orbWebhook) handleInvoicePaymentSucceeded(ctx context.Context, ie invoiceEvent) error {
res, err := o.jobs.PaymentSuccess(ctx, ie.OrbInvoice.Customer.ExternalCustomerID, ie.OrbInvoice.ID)
if err != nil {
o.orb.logger.Error("failed to insert invoice payment success job", zap.String("billing_customer_id", ie.OrbInvoice.Customer.ExternalCustomerID), zap.Error(err), observability.ZapCtx(ctx))
return err
}
if res.Duplicate {
o.orb.logger.Debug("duplicate invoice payment success event", zap.String("event_d", ie.ID))
o.orb.logger.Debug("duplicate invoice payment success event", zap.String("event_id", ie.ID))
}
return nil
}
Expand All @@ -117,6 +149,7 @@ func (o *orbWebhook) handleInvoicePaymentFailed(ctx context.Context, ie invoiceE
ie.OrbInvoice.PaymentFailedAt,
)
if err != nil {
o.orb.logger.Error("failed to insert invoice payment failed job", zap.String("billing_customer_id", ie.OrbInvoice.Customer.ExternalCustomerID), zap.Error(err), observability.ZapCtx(ctx))
return err
}
if res.Duplicate {
Expand All @@ -125,6 +158,19 @@ func (o *orbWebhook) handleInvoicePaymentFailed(ctx context.Context, ie invoiceE
return nil
}

func (o *orbWebhook) handlePlanChange(ctx context.Context, se subscriptionEvent) error {
if se.OrbSubscription.Customer.ExternalCustomerID == "" {
return nil
}

_, err := o.jobs.PlanChanged(ctx, se.OrbSubscription.Customer.ExternalCustomerID)
if err != nil {
o.orb.logger.Error("failed to insert plan changed job", zap.String("billing_customer_id", se.OrbSubscription.Customer.ExternalCustomerID), zap.Error(err), observability.ZapCtx(ctx))
return err
}
return nil
}

// Validates whether or not the webhook payload was sent by Orb.
func (o *orbWebhook) verifySignature(payload []byte, headers http.Header, now time.Time) error {
if o.orb.webhookSecret == "" {
Expand Down Expand Up @@ -192,3 +238,11 @@ type invoiceEvent struct {
Properties interface{} `json:"properties"`
OrbInvoice orb.Invoice `json:"invoice"`
}

type subscriptionEvent struct {
ID string `json:"id"`
CreatedAt time.Time `json:"created_at"`
Type string `json:"type"`
Properties interface{} `json:"properties"`
OrbSubscription orb.Subscription `json:"subscription"`
}
11 changes: 7 additions & 4 deletions admin/billing/payment/stripe_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ package payment
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"

"github.com/rilldata/rill/admin/jobs"
"github.com/rilldata/rill/runtime/pkg/httputil"
"github.com/rilldata/rill/runtime/pkg/observability"
"github.com/stripe/stripe-go/v79"
"github.com/stripe/stripe-go/v79/webhook"
"go.uber.org/zap"
Expand Down Expand Up @@ -96,7 +96,8 @@ func (s *stripeWebhook) handleWebhook(w http.ResponseWriter, r *http.Request) er
func (s *stripeWebhook) handlePaymentMethodAdded(ctx context.Context, eventID string, method *stripe.PaymentMethod) error {
res, err := s.jobs.PaymentMethodAdded(ctx, method.ID, method.Customer.ID, string(method.Type), time.UnixMilli(method.Created*1000))
if err != nil {
return fmt.Errorf("failed to add payment method added: %w", err)
s.stripe.logger.Error("failed to add payment method added job", zap.String("payment_customer_id", method.Customer.ID), zap.Error(err), observability.ZapCtx(ctx))
return err
}
if res.Duplicate {
s.stripe.logger.Debug("duplicate payment_method.attached event", zap.String("event_id", eventID))
Expand All @@ -108,7 +109,8 @@ func (s *stripeWebhook) handlePaymentMethodAdded(ctx context.Context, eventID st
func (s *stripeWebhook) handlePaymentMethodRemoved(ctx context.Context, eventID, customerID string, method *stripe.PaymentMethod) error {
res, err := s.jobs.PaymentMethodRemoved(ctx, method.ID, customerID, time.UnixMilli(method.Created*1000))
if err != nil {
return fmt.Errorf("failed to add payment method added: %w", err)
s.stripe.logger.Error("failed to add payment method removed job", zap.String("payment_customer_id", customerID), zap.Error(err), observability.ZapCtx(ctx))
return err
}
if res.Duplicate {
s.stripe.logger.Debug("duplicate payment_method.detached event", zap.String("event_id", eventID))
Expand All @@ -120,7 +122,8 @@ func (s *stripeWebhook) handlePaymentMethodRemoved(ctx context.Context, eventID,
func (s *stripeWebhook) handleCustomerAddressUpdated(ctx context.Context, eventID string, eventTime time.Time, customer *stripe.Customer) error {
res, err := s.jobs.CustomerAddressUpdated(ctx, customer.ID, eventTime)
if err != nil {
return fmt.Errorf("failed to add customer updated event: %w", err)
s.stripe.logger.Error("failed to add customer updated job", zap.String("payment_customer_id", customer.ID), zap.Error(err), observability.ZapCtx(ctx))
return err
}
if res.Duplicate {
s.stripe.logger.Debug("duplicate customer.updated event", zap.String("event_d", eventID))
Expand Down
4 changes: 4 additions & 0 deletions admin/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,8 @@ type Organization struct {
BillingCustomerID string `db:"billing_customer_id"`
PaymentCustomerID string `db:"payment_customer_id"`
BillingEmail string `db:"billing_email"`
BillingPlanName *string `db:"billing_plan_name"`
BillingPlanDisplayName *string `db:"billing_plan_display_name"`
CreatedByUserID *string `db:"created_by_user_id"`
}

Expand Down Expand Up @@ -372,6 +374,8 @@ type UpdateOrganizationOptions struct {
BillingCustomerID string
PaymentCustomerID string
BillingEmail string
BillingPlanName *string
BillingPlanDisplayName *string
CreatedByUserID *string
}

Expand Down
3 changes: 3 additions & 0 deletions admin/database/postgres/migrations/0059.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-- keeping the fields nullable as we need to be able to distinguish between plan not cached vs no plan
ALTER TABLE orgs ADD COLUMN billing_plan_name TEXT;
ALTER TABLE orgs ADD COLUMN billing_plan_display_name TEXT;
4 changes: 2 additions & 2 deletions admin/database/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ func (c *connection) UpdateOrganization(ctx context.Context, id string, opts *da

res := &database.Organization{}
err := c.getDB(ctx).QueryRowxContext(ctx,
`UPDATE orgs SET name=$1, display_name=$2, description=$3, logo_asset_id=$4, favicon_asset_id=$5, custom_domain=$6, quota_projects=$7, quota_deployments=$8, quota_slots_total=$9, quota_slots_per_deployment=$10, quota_outstanding_invites=$11, quota_storage_limit_bytes_per_deployment=$12, billing_customer_id=$13, payment_customer_id=$14, billing_email=$15, created_by_user_id=$16, updated_on=now() WHERE id=$17 RETURNING *`,
opts.Name, opts.DisplayName, opts.Description, opts.LogoAssetID, opts.FaviconAssetID, opts.CustomDomain, opts.QuotaProjects, opts.QuotaDeployments, opts.QuotaSlotsTotal, opts.QuotaSlotsPerDeployment, opts.QuotaOutstandingInvites, opts.QuotaStorageLimitBytesPerDeployment, opts.BillingCustomerID, opts.PaymentCustomerID, opts.BillingEmail, opts.CreatedByUserID, id).StructScan(res)
`UPDATE orgs SET name=$1, display_name=$2, description=$3, logo_asset_id=$4, favicon_asset_id=$5, custom_domain=$6, quota_projects=$7, quota_deployments=$8, quota_slots_total=$9, quota_slots_per_deployment=$10, quota_outstanding_invites=$11, quota_storage_limit_bytes_per_deployment=$12, billing_customer_id=$13, payment_customer_id=$14, billing_email=$15, created_by_user_id=$16, billing_plan_name=$17, billing_plan_display_name=$18, updated_on=now() WHERE id=$19 RETURNING *`,
opts.Name, opts.DisplayName, opts.Description, opts.LogoAssetID, opts.FaviconAssetID, opts.CustomDomain, opts.QuotaProjects, opts.QuotaDeployments, opts.QuotaSlotsTotal, opts.QuotaSlotsPerDeployment, opts.QuotaOutstandingInvites, opts.QuotaStorageLimitBytesPerDeployment, opts.BillingCustomerID, opts.PaymentCustomerID, opts.BillingEmail, opts.CreatedByUserID, opts.BillingPlanName, opts.BillingPlanDisplayName, id).StructScan(res)
if err != nil {
return nil, parseErr("org", err)
}
Expand Down
2 changes: 2 additions & 0 deletions admin/jobs/jobs.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ type Client interface {
RepairOrgBilling(ctx context.Context, orgID string) (*InsertResult, error) // biller is just used for deduplication
StartOrgTrial(ctx context.Context, orgID string) (*InsertResult, error)
DeleteOrg(ctx context.Context, orgID string) (*InsertResult, error)

PlanChanged(ctx context.Context, billingCustomerID string) (*InsertResult, error)
}

type InsertResult struct {
Expand Down
4 changes: 4 additions & 0 deletions admin/jobs/noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,7 @@ func (n *noop) StartOrgTrial(ctx context.Context, orgID string) (*InsertResult,
func (n *noop) DeleteOrg(ctx context.Context, orgID string) (*InsertResult, error) {
return nil, nil
}

func (n *noop) PlanChanged(ctx context.Context, billingCustomerID string) (*InsertResult, error) {
return nil, nil
}
68 changes: 68 additions & 0 deletions admin/jobs/river/biller_event_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/rilldata/rill/admin"
"github.com/rilldata/rill/admin/billing"
"github.com/rilldata/rill/admin/database"
"github.com/rilldata/rill/runtime/pkg/email"
"github.com/riverqueue/river"
Expand Down Expand Up @@ -308,3 +309,70 @@ func (w *PaymentFailedGracePeriodCheckWorker) checkFailedInvoicesForOrg(ctx cont
}
return hasOverdue, nil
}

type PlanChangedArgs struct {
BillingCustomerID string
}

func (PlanChangedArgs) Kind() string { return "plan_changed" }

type PlanChangedWorker struct {
river.WorkerDefaults[PlanChangedArgs]
admin *admin.Service
}

func (w *PlanChangedWorker) Work(ctx context.Context, job *river.Job[PlanChangedArgs]) error {
org, err := w.admin.DB.FindOrganizationForBillingCustomerID(ctx, job.Args.BillingCustomerID)
if err != nil {
if errors.Is(err, database.ErrNotFound) {
// org got deleted, ignore
return nil
}
return fmt.Errorf("failed to find organization of billing customer id %q: %w", job.Args.BillingCustomerID, err)
}

orgName := org.Name
// something related to plan changed, just fetch the latest plan from the biller
sub, err := w.admin.Biller.GetActiveSubscription(ctx, org.BillingCustomerID)
if err != nil && !errors.Is(err, billing.ErrNotFound) {
return fmt.Errorf("failed to get subscriptions for org %q: %w", orgName, err)
}

var planDisplayName string
var planName string
if sub == nil {
planDisplayName = ""
planName = ""
} else {
planDisplayName = sub.Plan.DisplayName
planName = sub.Plan.Name
}

if org.BillingPlanName == nil || *org.BillingPlanName != planName {
_, err = w.admin.DB.UpdateOrganization(ctx, org.ID, &database.UpdateOrganizationOptions{
Name: org.Name,
DisplayName: org.DisplayName,
Description: org.Description,
LogoAssetID: org.LogoAssetID,
FaviconAssetID: org.FaviconAssetID,
CustomDomain: org.CustomDomain,
QuotaProjects: org.QuotaProjects,
QuotaDeployments: org.QuotaDeployments,
QuotaSlotsTotal: org.QuotaSlotsTotal,
QuotaSlotsPerDeployment: org.QuotaSlotsPerDeployment,
QuotaOutstandingInvites: org.QuotaOutstandingInvites,
QuotaStorageLimitBytesPerDeployment: org.QuotaStorageLimitBytesPerDeployment,
BillingCustomerID: org.BillingCustomerID,
PaymentCustomerID: org.PaymentCustomerID,
BillingEmail: org.BillingEmail,
BillingPlanName: &planName,
BillingPlanDisplayName: &planDisplayName,
CreatedByUserID: org.CreatedByUserID,
})
if err != nil {
return fmt.Errorf("failed to update plan cache for org %q: %w", orgName, err)
}
}

return nil
}
15 changes: 15 additions & 0 deletions admin/jobs/river/river.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func New(ctx context.Context, dsn string, adm *admin.Service) (jobs.Client, erro
river.AddWorker(workers, &PaymentFailedWorker{admin: adm, logger: billingLogger})
river.AddWorker(workers, &PaymentSuccessWorker{admin: adm, logger: billingLogger})
river.AddWorker(workers, &PaymentFailedGracePeriodCheckWorker{admin: adm, logger: billingLogger})
river.AddWorker(workers, &PlanChangedWorker{admin: adm})

// trial checks worker
river.AddWorker(workers, &TrialEndingSoonWorker{admin: adm, logger: billingLogger})
Expand Down Expand Up @@ -373,6 +374,20 @@ func (c *Client) DeleteOrg(ctx context.Context, orgID string) (*jobs.InsertResul
}, nil
}

func (c *Client) PlanChanged(ctx context.Context, billingCustomerID string) (*jobs.InsertResult, error) {
res, err := c.riverClient.Insert(ctx, PlanChangedArgs{
BillingCustomerID: billingCustomerID,
}, &river.InsertOpts{})
if err != nil {
return nil, err
}

return &jobs.InsertResult{
ID: res.Job.ID,
Duplicate: res.UniqueSkippedAsDuplicate,
}, nil
}

type ErrorHandler struct {
logger *zap.Logger
}
Expand Down
2 changes: 2 additions & 0 deletions admin/jobs/river/subscription_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ func (w *SubscriptionCancellationCheckWorker) subscriptionCancellationCheck(ctx
BillingCustomerID: org.BillingCustomerID,
PaymentCustomerID: org.PaymentCustomerID,
BillingEmail: org.BillingEmail,
BillingPlanName: org.BillingPlanName,
BillingPlanDisplayName: org.BillingPlanDisplayName,
CreatedByUserID: org.CreatedByUserID,
})
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions admin/jobs/river/trial_checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,8 @@ func (w *TrialGracePeriodCheckWorker) trialGracePeriodCheck(ctx context.Context)
BillingCustomerID: org.BillingCustomerID,
PaymentCustomerID: org.PaymentCustomerID,
BillingEmail: org.BillingEmail,
BillingPlanName: org.BillingPlanName,
BillingPlanDisplayName: org.BillingPlanDisplayName,
CreatedByUserID: org.CreatedByUserID,
})
if err != nil {
Expand Down
Loading
Loading