Skip to content
This repository has been archived by the owner on Jan 2, 2024. It is now read-only.

Commit

Permalink
api: accept payment on schedule (#257)
Browse files Browse the repository at this point in the history
This changes scheduling such that clients may specify payment methods
other than the default payment method for a customer. It is now the
recommended way for associating payment methods with a schedule.

In a future commit, support for listing payment methods via the Tier API
will be added.
  • Loading branch information
bmizerany committed Feb 20, 2023
1 parent a9af4b2 commit 237d45b
Show file tree
Hide file tree
Showing 8 changed files with 288 additions and 68 deletions.
18 changes: 17 additions & 1 deletion api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
"strings"

"github.com/kr/pretty"
"golang.org/x/exp/slices"
Expand Down Expand Up @@ -120,6 +121,17 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
})
return
}

var ipe *stripe.Error
if errors.As(err, &ipe) && strings.Contains(ipe.Message, "No such PaymentMethod") {
trweb.WriteError(w, &trweb.HTTPError{
Status: 400,
Code: "invalid_payment_method",
Message: ipe.Message,
})
return
}

if trweb.WriteError(w, lookupErr(err)) || trweb.WriteError(w, err) {
return
}
Expand Down Expand Up @@ -234,7 +246,11 @@ func (h *Handler) serveSubscribe(w http.ResponseWriter, r *http.Request) error {
})
}
}
return h.c.Schedule(r.Context(), sr.Org, phases)

return h.c.Schedule(r.Context(), sr.Org, control.ScheduleParams{
PaymentMethod: sr.PaymentMethodID,
Phases: phases,
})
}

func (h *Handler) serveReport(w http.ResponseWriter, r *http.Request) error {
Expand Down
21 changes: 21 additions & 0 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,27 @@ func TestAPISubscribe(t *testing.T) {
Code: "TERR1020",
Message: "feature or plan not found",
})

_, err := tc.Schedule(ctx, "org:test", &tier.ScheduleParams{
Phases: []apitypes.Phase{
{Trial: true, Features: []string{"plan:test@0"}},
},
PaymentMethodID: "pm_card_us",
})

// Quick lint check to make sure the PaymentMethod made it to Stripe.
// In production, payment methods can be set on a sub by sub basis;
// however in test mode, we may only use test payment methods, and in
// test mode, stripe does not accept test payment methods on a sub by
// sub basis, so there is no real way to test our support for this
// feature. Instead, here, we just check stripe complains about the
// payment method to show it saw what we wanted it to see in
// production.
diff.Test(t, t.Errorf, err, &apitypes.Error{
Status: 400,
Code: "invalid_payment_method",
Message: "No such PaymentMethod: 'pm_card_us'",
})
}

func TestPhaseBadOrg(t *testing.T) {
Expand Down
7 changes: 4 additions & 3 deletions api/apitypes/apitypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,10 @@ type CheckoutRequest struct {
}

type ScheduleRequest struct {
Org string `json:"org"`
Info *OrgInfo `json:"info"`
Phases []Phase `json:"phases"`
Org string `json:"org"`
PaymentMethodID string `json:"payment_method_id"`
Info *OrgInfo `json:"info"`
Phases []Phase `json:"phases"`
}

// ScheduleResponse is the expected response from a schedule request. It is
Expand Down
12 changes: 7 additions & 5 deletions client/tier/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,15 +271,17 @@ type CheckoutParams struct {
}

type ScheduleParams struct {
Info *OrgInfo
Phases []Phase
Info *OrgInfo
Phases []Phase
PaymentMethodID string
}

func (c *Client) Schedule(ctx context.Context, org string, p *ScheduleParams) (*apitypes.ScheduleResponse, error) {
return fetchOK[*apitypes.ScheduleResponse, *apitypes.Error](ctx, c, "POST", "/v1/subscribe", &apitypes.ScheduleRequest{
Org: org,
Info: (*apitypes.OrgInfo)(p.Info),
Phases: p.Phases,
Org: org,
Info: (*apitypes.OrgInfo)(p.Info),
Phases: p.Phases,
PaymentMethodID: p.PaymentMethodID,
})

}
Expand Down
3 changes: 3 additions & 0 deletions cmd/tier/help.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ Checkout only flags:
--cancel_url=<cancel_url>
specify a cancel_url for Stripe Checkout. This flag is ignored
if --checkout is not set.
--paymentmethod=<paymentmethod_id>
specify a payment method to use for the subscription. This flag
is ignored with --checkout.
Global Flags:
Expand Down
2 changes: 2 additions & 0 deletions cmd/tier/tier.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ func runTier(cmd string, args []string) (err error) {
successURL := fs.String("checkout", "", "subscribe via Stripe checkout")
cancelURL := fs.String("cancel_url", "", "sets the cancel URL for use with -checkout")
requireBillingAddress := fs.Bool("require_billing_address", false, "require billing address for use with --checkout")
paymentMethod := fs.String("paymentmethod", "", "sets the Stripe payment method for the subscription (e.g. pm_123). It is ignored with --checkout")
if err := fs.Parse(args); err != nil {
return err
}
Expand Down Expand Up @@ -317,6 +318,7 @@ func runTier(cmd string, args []string) (err error) {
Info: &tier.OrgInfo{
Email: *email,
},
PaymentMethodID: *paymentMethod,
}
switch {
case *trial > 0:
Expand Down
111 changes: 62 additions & 49 deletions control/schedule.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (c *Client) lookupSubscription(ctx context.Context, org, name string) (sub
return s, nil
}

func (c *Client) createSchedule(ctx context.Context, org, name string, fromSub string, phases []Phase) (err error) {
func (c *Client) createSchedule(ctx context.Context, org, name string, fromSub string, p ScheduleParams) (err error) {
defer errorfmt.Handlef("stripe: createSchedule: %q: %w", org, &err)

create := func(f stripe.Form) (string, error) {
Expand All @@ -200,23 +200,43 @@ func (c *Client) createSchedule(ctx context.Context, org, name string, fromSub s
}
// We can only update phases after the schedule is created from
// the subscription.
return c.updateSchedule(ctx, sid, name, phases)
return c.updateSchedule(ctx, sid, name, p)
} else {
defer errorfmt.Handlef("newSub: %w", &err)
cid, err := c.WhoIs(ctx, org)
if err != nil {
return err
}
var f stripe.Form
if p.PaymentMethod != "" {
f.Set("default_settings", "default_payment_method", p.PaymentMethod)
}
f.Set("customer", cid)
if err := addPhases(ctx, c, &f, false, name, phases); err != nil {
if err := addPhases(ctx, c, &f, false, name, p.Phases); err != nil {
return err
}
_, err = create(f)
return err
}
}

type stripeSubSchedule struct {
stripe.ID
Current struct {
Start int64 `json:"start_date"`
End int64 `json:"end_date"`
} `json:"current_phase"`
Phases []struct {
Metadata struct {
Name string `json:"tier.subscription"`
}
Start int64 `json:"start_date"`
Items []struct {
Price stripePrice
}
}
}

func (c *Client) lookupPhases(ctx context.Context, org string, s subscription, name string) (current Phase, all []Phase, err error) {
defer errorfmt.Handlef("lookupPhases: %w", &err)

Expand All @@ -225,26 +245,9 @@ func (c *Client) lookupPhases(ctx context.Context, org string, s subscription, n
return ps[0], ps, nil
}

type T struct {
stripe.ID
Current struct {
Start int64 `json:"start_date"`
End int64 `json:"end_date"`
} `json:"current_phase"`
Phases []struct {
Metadata struct {
Name string `json:"tier.subscription"`
}
Start int64 `json:"start_date"`
Items []struct {
Price stripePrice
}
}
}

g, ctx := errgroup.WithContext(ctx)

var ss T
var ss stripeSubSchedule
g.Go(func() error {
var f stripe.Form
f.Add("expand[]", "phases.items.price")
Expand Down Expand Up @@ -348,13 +351,16 @@ func subscriptionToPhases(org string, s subscription) []Phase {
return ps
}

func (c *Client) updateSchedule(ctx context.Context, schedID, name string, phases []Phase) (err error) {
func (c *Client) updateSchedule(ctx context.Context, schedID, name string, p ScheduleParams) (err error) {
defer errorfmt.Handlef("stripe: updateSchedule: %q: %w", schedID, &err)
if schedID == "" {
return errors.New("subscription id required")
}
var f stripe.Form
if err := addPhases(ctx, c, &f, true, name, phases); err != nil {
if p.PaymentMethod != "" {
f.Set("default_settings", "default_payment_method", p.PaymentMethod)
}
if err := addPhases(ctx, c, &f, true, name, p.Phases); err != nil {
return err
}
return c.Stripe.Do(ctx, "POST", "/v1/subscription_schedules/"+schedID, f, nil)
Expand Down Expand Up @@ -478,8 +484,13 @@ func (c *Client) Checkout(ctx context.Context, org string, successURL string, p
}
}

func (c *Client) Schedule(ctx context.Context, org string, phases []Phase) error {
err := c.schedule(ctx, org, phases)
type ScheduleParams struct {
PaymentMethod string
Phases []Phase
}

func (c *Client) Schedule(ctx context.Context, org string, p ScheduleParams) error {
err := c.schedule(ctx, org, p)
c.Logf("stripe: schedule: %v", err)
var e *stripe.Error
if errors.As(err, &e) {
Expand All @@ -493,21 +504,21 @@ func (c *Client) Schedule(ctx context.Context, org string, phases []Phase) error
return err
}

func (c *Client) schedule(ctx context.Context, org string, phases []Phase) (err error) {
func (c *Client) schedule(ctx context.Context, org string, p ScheduleParams) (err error) {
defer errorfmt.Handlef("tier: schedule: %q: %w", org, &err)

if err := c.PutCustomer(ctx, org, nil); err != nil {
return err
}

if len(phases) == 0 {
if len(p.Phases) == 0 {
return errors.New("tier: schedule: at least one phase required")
}

scheduleNow := phases[0].Effective.IsZero()
cancelNow := scheduleNow && len(phases[0].Features) == 0
scheduleNow := p.Phases[0].Effective.IsZero()
cancelNow := scheduleNow && len(p.Phases[0].Features) == 0

if cancelNow && len(phases) > 1 {
if cancelNow && len(p.Phases) > 1 {
return errors.New("tier: a cancel phase must be the final phase")
}

Expand All @@ -522,7 +533,7 @@ func (c *Client) schedule(ctx context.Context, org string, phases []Phase) (err
//
// If this is a "cancel immediately" request, it returns
// ErrInvalidCancel because there is no subscription to cancel.
return c.createSchedule(ctx, org, defaultScheduleName, "", phases)
return c.createSchedule(ctx, org, defaultScheduleName, "", p)
}
if err != nil {
return err
Expand All @@ -536,7 +547,7 @@ func (c *Client) schedule(ctx context.Context, org string, phases []Phase) (err

if s.ScheduleID == "" {
// We have a subscription, but it is has no active schedule, so start a new one.
return c.createSchedule(ctx, org, defaultScheduleName, s.ID, phases)
return c.createSchedule(ctx, org, defaultScheduleName, s.ID, p)
} else {
cp, _, err := c.lookupPhases(ctx, org, s, defaultScheduleName)
if err != nil {
Expand All @@ -546,18 +557,18 @@ func (c *Client) schedule(ctx context.Context, org string, phases []Phase) (err
if cp.Valid() {
if scheduleNow {
// attach phase to current
phases[0].Effective = cp.Effective
p.Phases[0].Effective = cp.Effective
} else {
phases = append([]Phase{cp}, phases...)
p.Phases = append([]Phase{cp}, p.Phases...)
}
}

err = c.updateSchedule(ctx, s.ScheduleID, defaultScheduleName, phases)
err = c.updateSchedule(ctx, s.ScheduleID, defaultScheduleName, p)
if isReleased(err) {
// Lost a race with the clock and the schedule was
// released just after seeing it, but before our
// update.
return c.createSchedule(ctx, org, defaultScheduleName, s.ID, phases)
return c.createSchedule(ctx, org, defaultScheduleName, s.ID, p)
}
if err != nil {
return err
Expand All @@ -582,9 +593,9 @@ func isReleased(err error) bool {
// taking over any in-progress schedule. The customer is billed immediately
// with prorations if any.
func (c *Client) SubscribeTo(ctx context.Context, org string, fs []refs.FeaturePlan) error {
return c.Schedule(ctx, org, []Phase{{
Features: fs,
}})
return c.Schedule(ctx, org, ScheduleParams{
Phases: []Phase{{Features: fs}},
})

}

Expand Down Expand Up @@ -840,24 +851,26 @@ func (c *Client) Isolated() bool {
return c.Stripe.AccountID != ""
}

type stripeCustomer struct {
stripe.ID
Email string
Metadata struct {
Org string `json:"tier.org"`
}
}

func (c *Client) WhoIs(ctx context.Context, org string) (id string, err error) {
defer errorfmt.Handlef("whois: %q: %w", org, &err)
if !strings.HasPrefix(org, "org:") {
return "", &ValidationError{Message: "org must be prefixed with \"org:\""}
}

cid, err := c.cache.load(org, func() (string, error) {
type T struct {
stripe.ID
Email string
Metadata struct {
Org string `json:"tier.org"`
}
}
var f stripe.Form
cus, err := stripe.List[T](ctx, c.Stripe, "GET", "/v1/customers", f).Find(func(v T) bool {
return v.Metadata.Org == org
})
cus, err := stripe.List[stripeCustomer](ctx, c.Stripe, "GET", "/v1/customers", f).
Find(func(v stripeCustomer) bool {
return v.Metadata.Org == org
})
if err != nil {
return "", err
}
Expand Down

0 comments on commit 237d45b

Please sign in to comment.