Skip to content
This repository has been archived by the owner on Jan 2, 2024. It is now read-only.

Commit

Permalink
cmd/tier: accept -checkout on subscribe (#215)
Browse files Browse the repository at this point in the history
This commit adds Tier flare to Stripe's Checkout features. Unlike using
Stripe Checkout directly, Tier will suppress duplicate subscriptions,
and allow clients to address Customers and Prices using Tier Org and
Plan/Feature names, respectively, which allow clients to "bring your
own" IDs.

This commit also adds a new --checkout flag to the subscribe command,
generating a Stripe Checkout link via the Tier API. It also
adds an optional --cancel_url flag for specifying the URL to
send the user if they back out of the checkout session.
  • Loading branch information
bmizerany committed Jan 21, 2023
1 parent 30d4122 commit 6e72ddf
Show file tree
Hide file tree
Showing 12 changed files with 319 additions and 77 deletions.
63 changes: 49 additions & 14 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"tier.run/api/apitypes"
"tier.run/api/materialize"
"tier.run/control"
"tier.run/refs"
"tier.run/stripe"
"tier.run/trweb"
"tier.run/values"
Expand Down Expand Up @@ -124,12 +125,21 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if trweb.WriteError(w, lookupErr(err)) || trweb.WriteError(w, err) {
return
}
var e *control.ValidationError
if errors.As(err, &e) {
var ve *control.ValidationError
if errors.As(err, &ve) {
trweb.WriteError(w, &trweb.HTTPError{
Status: 400,
Code: "invalid_request",
Message: e.Message,
Message: ve.Message,
})
return
}
var pe *refs.ParseError
if errors.As(err, &pe) {
trweb.WriteError(w, &trweb.HTTPError{
Status: 400,
Code: "invalid_request",
Message: pe.Message,
})
return
}
Expand All @@ -154,6 +164,8 @@ func (h *Handler) serve(w http.ResponseWriter, r *http.Request) error {
return h.serveReport(w, r)
case "/v1/subscribe":
return h.serveSubscribe(w, r)
case "/v1/checkout":
return h.serveCheckout(w, r)
case "/v1/phase":
return h.servePhase(w, r)
case "/v1/pull":
Expand All @@ -165,11 +177,45 @@ func (h *Handler) serve(w http.ResponseWriter, r *http.Request) error {
}
}

func (h *Handler) serveCheckout(w http.ResponseWriter, r *http.Request) error {
var cr apitypes.CheckoutRequest
if err := trweb.DecodeStrict(r, &cr); err != nil {
return err
}
m, err := h.c.Pull(r.Context(), 0)
if err != nil {
return err
}
fs, err := control.ExpandPlans(m, cr.Features...)
if err != nil {
return err
}
link, err := h.c.Checkout(r.Context(), cr.Org, cr.SuccessURL, &control.CheckoutParams{
TrialDays: cr.TrialDays,
Features: fs,
CancelURL: cr.CancelURL,
})
if err != nil {
return err
}
return httpJSON(w, &apitypes.CheckoutResponse{URL: link})
}

func (h *Handler) serveSubscribe(w http.ResponseWriter, r *http.Request) error {
var sr apitypes.ScheduleRequest
if err := trweb.DecodeStrict(r, &sr); err != nil {
return err
}
if sr.Info != nil {
info := infoToOrgInfo(sr.Info)
if err := h.c.PutCustomer(r.Context(), sr.Org, info); err != nil {
return err
}
}
if len(sr.Phases) == 0 {
return nil
}

var phases []control.Phase
if len(sr.Phases) > 0 {
m, err := h.c.Pull(r.Context(), 0)
Expand All @@ -189,17 +235,6 @@ func (h *Handler) serveSubscribe(w http.ResponseWriter, r *http.Request) error {
})
}
}
if sr.Info != nil {
info := infoToOrgInfo(sr.Info)
if err := h.c.PutCustomer(r.Context(), sr.Org, info); err != nil {
return err
}
}

if len(phases) == 0 {
return nil
}

return h.c.Schedule(r.Context(), sr.Org, phases)
}

Expand Down
60 changes: 56 additions & 4 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,40 @@ func newTestClient(t *testing.T) (*tier.Client, *control.Client) {
return tc, cc
}

func TestAPICheckout(t *testing.T) {
ctx := context.Background()
tc, cc := newTestClient(t)
m := []control.Feature{{
FeaturePlan: mpf("feature:x@plan:test@0"),
Interval: "@monthly",
Currency: "usd",
}}
cc.Push(ctx, m, pushLogger(t))

t.Run("card setup", func(t *testing.T) {
r, err := tc.Checkout(ctx, "org:test", "https://example.com/success", nil)
if err != nil {
t.Fatal(err)
}
t.Logf("checkout: %s", r.URL)
if r.URL == "" {
t.Error("unexpected empty checkout url")
}
})
t.Run("subscription", func(t *testing.T) {
r, err := tc.Checkout(ctx, "org:test", "https://example.com/success", &tier.CheckoutParams{
Features: []string{"feature:x@plan:test@0"},
})
if err != nil {
t.Fatal(err)
}
t.Logf("checkout: %s", r.URL)
if r.URL == "" {
t.Error("unexpected empty checkout url")
}
})
}

func TestAPISubscribe(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -186,8 +220,8 @@ func TestAPISubscribe(t *testing.T) {

sub("org:test", []string{"plan:test@0", "feature:nope@0"}, &apitypes.Error{
Status: 400,
Code: "feature_not_found",
Message: "feature not found",
Code: "TERR1020",
Message: "feature or plan not found",
})

sub("org:test", []string{"plan:nope@0"}, &apitypes.Error{
Expand Down Expand Up @@ -446,8 +480,9 @@ func TestScheduleWithCustomerInfoNoPhases(t *testing.T) {
},
}

_, err := tc.Schedule(ctx, "org:test", p)
diff.Test(t, t.Fatalf,
tc.Schedule(ctx, "org:test", p),
err,
&apitypes.Error{
Status: 400,
Code: "invalid_metadata",
Expand All @@ -466,7 +501,7 @@ func TestScheduleWithCustomerInfoNoPhases(t *testing.T) {
)
diff.Test(t, t.Fatalf, got, apitypes.WhoIsResponse{})

if err := tc.Schedule(ctx, "org:test", &tier.ScheduleParams{
if _, err := tc.Schedule(ctx, "org:test", &tier.ScheduleParams{
Info: &tier.OrgInfo{
Email: "test2@example2.com",
},
Expand Down Expand Up @@ -499,3 +534,20 @@ func maybeFailNow(t *testing.T) {
t.FailNow()
}
}

func pushLogger(t *testing.T) func(f control.Feature, err error) {
t.Helper()
return pushLogWith(t, t.Fatalf)
}

func pushLogWith(t *testing.T, fatalf func(string, ...any)) func(f control.Feature, err error) {
t.Helper()
return func(f control.Feature, err error) {
t.Helper()
if err == nil {
t.Logf("pushed %q", f.FeaturePlan)
} else {
fatalf("error pushing %q: %v", f.FeaturePlan, err)
}
}
}
44 changes: 30 additions & 14 deletions api/apitypes/apitypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ func (e *Error) Error() string {
}

type Phase struct {
Trial bool
Effective time.Time
Features []string
Trial bool `json:"trial,omitempty"`
Effective time.Time `json:"effective,omitempty"`
Features []string `json:"features,omitempty"`
}

type PhaseResponse struct {
Expand All @@ -32,7 +32,7 @@ type PhaseResponse struct {
}

type InvoiceSettings struct {
DefaultPaymentMethod string
DefaultPaymentMethod string `json:"default_payment_method"`
}

type OrgInfo struct {
Expand All @@ -42,22 +42,38 @@ type OrgInfo struct {
Phone string `json:"phone"`
Metadata map[string]string `json:"metadata"`

PaymentMethod string
InvoiceSettings InvoiceSettings
PaymentMethod string `json:"payment_method"`
InvoiceSettings InvoiceSettings `json:"invoice_settings"`
}

type CheckoutRequest struct {
Org string `json:"org"`
TrialDays int `json:"trial_days"`
Features []string `json:"features"`
SuccessURL string `json:"success_url"`
CancelURL string `json:"cancel_url"`
}

type ScheduleRequest struct {
Org string
Info *OrgInfo
Phases []Phase
Org string `json:"org"`
Info *OrgInfo `json:"info"`
Phases []Phase `json:"phases"`
}

// ScheduleResponse is the expected response from a schedule request. It is
// currently empty, reserved for furture use.
type ScheduleResponse struct{}

type CheckoutResponse struct {
URL string `json:"url"`
}

type ReportRequest struct {
Org string
Feature refs.Name
N int
At time.Time // default is time.Now()
Clobber bool
Org string `json:"org"`
Feature refs.Name `json:"feature"`
N int `json:"n"`
At time.Time `json:"at"`
Clobber bool `json:"clobber"`
}

type WhoIsResponse struct {
Expand Down
39 changes: 26 additions & 13 deletions client/tier/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,30 +238,43 @@ func (c *Client) Subscribe(ctx context.Context, org string, featuresAndPlans ...
return err
}

// Checkout creates a new checkout link for the provided org and features, if
// any; otherwise, if no features are specified, and payment setup link is
// returned instead.
func (c *Client) Checkout(ctx context.Context, org string, successURL string, p *CheckoutParams) (*apitypes.CheckoutResponse, error) {
if p == nil {
p = &CheckoutParams{}
}
r := &apitypes.CheckoutRequest{
Org: org,
SuccessURL: successURL,
CancelURL: p.CancelURL,
TrialDays: p.TrialDays,
Features: p.Features,
}
return fetch.OK[*apitypes.CheckoutResponse, *apitypes.Error](ctx, c.client(), "POST", c.baseURL("/v1/checkout"), r)
}

type Phase = apitypes.Phase
type OrgInfo = apitypes.OrgInfo

type CheckoutParams struct {
TrialDays int
Features []string
CancelURL string
}

type ScheduleParams struct {
Info *OrgInfo
Phases []Phase
}

func (c *Client) Schedule(ctx context.Context, org string, p *ScheduleParams) error {
_, err := fetch.OK[struct{}, *apitypes.Error](ctx, c.client(), "POST", c.baseURL("/v1/subscribe"), &apitypes.ScheduleRequest{
func (c *Client) Schedule(ctx context.Context, org string, p *ScheduleParams) (*apitypes.ScheduleResponse, error) {
return fetch.OK[*apitypes.ScheduleResponse, *apitypes.Error](ctx, c.client(), "POST", c.baseURL("/v1/subscribe"), &apitypes.ScheduleRequest{
Org: org,
Info: (*apitypes.OrgInfo)(p.Info),
Phases: copyPhases(p.Phases),
Phases: p.Phases,
})

return err
}

func copyPhases(phases []Phase) []apitypes.Phase {
c := make([]apitypes.Phase, len(phases))
for i, p := range phases {
c[i] = apitypes.Phase(p)
}
return c
}

func (c *Client) WhoAmI(ctx context.Context) (apitypes.WhoAmIResponse, error) {
Expand Down
13 changes: 13 additions & 0 deletions cmd/tier/accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package main

import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"io/fs"
"os"
Expand Down Expand Up @@ -67,6 +69,7 @@ func preallocateAccount() error {

func createSwitchAccount(ctx context.Context, sc *stripe.Client) (stripe.Account, error) {
return stripe.CreateAccount(ctx, sc, &stripe.AccountParams{
BusinessName: randomString("tier.switch."),
Meta: stripe.Meta{
"tier.account": "switch",
},
Expand Down Expand Up @@ -122,3 +125,13 @@ func cachePath(parts ...string) (string, error) {
}
return path, nil
}

// randomString returns a random 16 byte hexencoded string with the given
// prefix.
func randomString(prefix string) string {
s := make([]byte, 16)
if _, err := rand.Read(s); err != nil {
panic(err)
}
return prefix + hex.EncodeToString(s)
}
15 changes: 12 additions & 3 deletions cmd/tier/help.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,22 @@ Flags:
--email
set the org's email address
--cancel
cancel the org's subscription. It is an error to provide a plan
or featurePlan with this flag.
Checkout only flags:
--checkout=<success_url>
subscribe the org to plans and features using Stripe Checkout.
The success url is required, and may be used with the
--cancel_url flag.
--trial days
set the org's trial period to the provided number of days. If
negative, the tial period will last indefinitely, and no other
phase will come after it.
--cancel
cancel the org's subscription. It is an error to provide a plan
or featurePlan with this flag.
--cancel_url=<cancel_url>
specify a cancel_url for Stripe Checkout. This flag is ignored
if --checkout is not set.
Global Flags:
Expand Down

0 comments on commit 6e72ddf

Please sign in to comment.