Skip to content
This repository has been archived by the owner on Jan 2, 2024. It is now read-only.

Commit

Permalink
api: cancel schedules with empty phases (#205)
Browse files Browse the repository at this point in the history
This commit changes /v1/subscribe to accept an update consisting solely
of one zero-value phase. Such an update represents the desire to cancel
the org schedule effective immediately.
  • Loading branch information
bmizerany committed Jan 3, 2023
1 parent 313127f commit 6d2351d
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 56 deletions.
11 changes: 7 additions & 4 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ func (h *Handler) serveSubscribe(w http.ResponseWriter, r *http.Request) error {
if err := trweb.DecodeStrict(r, &sr); err != nil {
return err
}

var phases []control.Phase
if len(sr.Phases) > 0 {
m, err := h.c.Pull(r.Context(), 0)
Expand All @@ -166,9 +165,13 @@ func (h *Handler) serveSubscribe(w http.ResponseWriter, r *http.Request) error {
})
}
}

info := (*control.OrgInfo)(sr.Info)
return h.c.ScheduleNow(r.Context(), sr.Org, info, phases)
if sr.Info != nil {
info := (*control.OrgInfo)(sr.Info)
if err := h.c.PutCustomer(r.Context(), sr.Org, info); err != nil {
return err
}
}
return h.c.ScheduleNow(r.Context(), sr.Org, phases)
}

func (h *Handler) serveReport(w http.ResponseWriter, r *http.Request) error {
Expand Down
7 changes: 6 additions & 1 deletion cmd/tier/tier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,8 @@ const responsePricesValidPlan = `

func TestSubscribeUnexpectedMissingCustomer(t *testing.T) {
tt := testtier(t, func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/v1/subscription_schedules" {
t.Logf("request: %s %s", r.Method, r.URL.Path)
if wants(r, "POST", "/v1/subscription_schedules") {
w.WriteHeader(400)
io.WriteString(w, `{
"error": {
Expand Down Expand Up @@ -564,3 +565,7 @@ func chdir(t *testing.T, dir string) {
}
})
}

func wants(r *http.Request, method, path string) bool {
return r.Method == method && r.URL.Path == path
}
75 changes: 44 additions & 31 deletions control/schedule.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,12 @@ func (c *Client) lookupSubscription(ctx context.Context, org, name string) (sub
return s, nil
}

func (c *Client) createSchedule(ctx context.Context, org, name string, fromSub string, info *OrgInfo, phases []Phase) (err error) {
func (c *Client) createSchedule(ctx context.Context, org, name string, fromSub string, phases []Phase) (err error) {
defer errorfmt.Handlef("stripe: createSubscription: %q: %w", org, &err)

// Update customer regardless of whether we have phases to update or
// not.
cid, err := c.putCustomer(ctx, org, info)
cid, err := c.putCustomer(ctx, org, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -236,8 +236,8 @@ func addPhases(ctx context.Context, c *Client, f *stripe.Form, update bool, name
return nil
}

func (c *Client) Schedule(ctx context.Context, org string, info *OrgInfo, phases []Phase) (err error) {
err = c.schedule(ctx, org, info, phases)
func (c *Client) Schedule(ctx context.Context, org string, phases []Phase) (err error) {
err = c.schedule(ctx, org, phases)
var e *stripe.Error
if errors.As(err, &e) {
if e.Code == "resource_missing" && e.Param == "customer" {
Expand All @@ -250,7 +250,7 @@ func (c *Client) Schedule(ctx context.Context, org string, info *OrgInfo, phases
return err
}

func (c *Client) schedule(ctx context.Context, org string, info *OrgInfo, phases []Phase) (err error) {
func (c *Client) schedule(ctx context.Context, org string, phases []Phase) (err error) {
defer errorfmt.Handlef("tier: schedule: %q: %w", org, &err)

for i, p := range phases {
Expand All @@ -262,34 +262,28 @@ func (c *Client) schedule(ctx context.Context, org string, info *OrgInfo, phases
}
}

if info != nil {
if _, err := c.putCustomer(ctx, org, info); err != nil {
return err
}
}

s, err := c.lookupSubscription(ctx, org, subscriptionNameTODO)
if errors.Is(err, ErrOrgNotFound) {
// We only need to pay the API penalty of creating a customer
// if we know in fact it does not exist.
if _, err := c.putCustomer(ctx, org, info); err != nil {
if _, err := c.putCustomer(ctx, org, nil); err != nil {
return err
}
return c.createSchedule(ctx, org, subscriptionNameTODO, "", info, phases)
return c.createSchedule(ctx, org, subscriptionNameTODO, "", phases)
}
if errors.Is(err, stripe.ErrNotFound) {
return c.createSchedule(ctx, org, subscriptionNameTODO, "", info, phases)
return c.createSchedule(ctx, org, subscriptionNameTODO, "", phases)
}
if err != nil {
return err
}
if s.ScheduleID != "" {
err = c.updateSchedule(ctx, s.ScheduleID, subscriptionNameTODO, phases)
if isReleased(err) {
return c.createSchedule(ctx, org, subscriptionNameTODO, s.ID, info, phases)
return c.createSchedule(ctx, org, subscriptionNameTODO, s.ID, phases)
}
} else {
return c.createSchedule(ctx, org, subscriptionNameTODO, s.ID, info, phases)
return c.createSchedule(ctx, org, subscriptionNameTODO, s.ID, phases)
}
return err
}
Expand All @@ -306,41 +300,60 @@ func isReleased(err error) bool {
return false
}

func (c *Client) scheduleCancel(ctx context.Context, org string) (err error) {
defer errorfmt.Handlef("tier: ScheduleCancel: %q: %w", org, &err)
s, err := c.lookupSubscription(ctx, org, subscriptionNameTODO)
if err != nil {
return err
}
var f stripe.Form
f.Set("prorate", true)
f.Set("invoice_now", true)
return c.Stripe.Do(ctx, "DELETE", "/v1/subscriptions/"+s.ID, f, nil)
}

// ScheduleNow is like Schedule but immediately starts the first phase as the
// current phase and cuts off any phases that have not yet been transitioned
// to.
//
// The first phase must have a zero Effective time to indicate that it should
// start now.
func (c *Client) ScheduleNow(ctx context.Context, org string, info *OrgInfo, phases []Phase) error {
func (c *Client) ScheduleNow(ctx context.Context, org string, phases []Phase) (err error) {
defer errorfmt.Handlef("tier: ScheduleNow: %q: %w", org, &err)
c.Logf("phases: %v", phases)
if len(phases) > 0 {
if !phases[0].Effective.IsZero() {
return errors.New("first phase must be effective now")
}
cps, err := c.LookupPhases(ctx, org)
if err != nil && !errors.Is(err, ErrOrgNotFound) {
return err
}
for _, p := range cps {
if p.Current {
p0 := phases[0]
p.Features = p0.Features
p.Trial = p0.Trial
phases[0] = p
break
if len(phases) == 1 && len(phases[0].Features) == 0 {
return c.scheduleCancel(ctx, org)
} else {
cps, err := c.LookupPhases(ctx, org)
if err != nil && !errors.Is(err, ErrOrgNotFound) {
return err
}
for _, p := range cps {
if p.Current {
p0 := phases[0]
p.Features = p0.Features
p.Trial = p0.Trial
phases[0] = p
break
}
}
}
}
return c.Schedule(ctx, org, info, phases)
return c.Schedule(ctx, org, phases)
}

// SubscribeTo subscribes org to the provided features effective immediately,
// taking over any in-progress schedule. The customer is billed immediately
// with prorations if any.
func (c *Client) SubscribeTo(ctx context.Context, org string, fs []refs.FeaturePlan) error {
return c.ScheduleNow(ctx, org, nil, []Phase{{
return c.ScheduleNow(ctx, org, []Phase{{
Features: fs,
}})

}

func (c *Client) lookupFeatures(ctx context.Context, keys []refs.FeaturePlan) ([]Feature, error) {
Expand Down Expand Up @@ -407,7 +420,7 @@ func (c *Client) LookupStatus(ctx context.Context, org string) (string, error) {
}

func (c *Client) LookupPhases(ctx context.Context, org string) (ps []Phase, err error) {
defer errorfmt.Handlef("LookupPhase: %w", &err)
defer errorfmt.Handlef("LookupPhases: %w", &err)

cid, err := c.WhoIs(ctx, org)
if err != nil {
Expand Down
97 changes: 77 additions & 20 deletions control/schedule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,14 @@ func (s *scheduleTester) advanceToNextPeriod() {
s.clock.Advance(eop)
}

func (s *scheduleTester) cancel(org string) {
s.t.Helper()
s.t.Logf("cancelling %s", org)
if err := s.cc.ScheduleNow(context.Background(), org, []Phase{{}}); err != nil {
s.t.Fatal(err)
}
}

func (s *scheduleTester) schedule(org string, trialDays int, fs ...refs.FeaturePlan) {
s.t.Helper()
s.t.Logf("subscribing %s to %v with trialDays=%d", org, fs, trialDays)
Expand All @@ -188,7 +196,7 @@ func (s *scheduleTester) schedule(org string, trialDays int, fs ...refs.FeatureP
Features: fs,
}}
}
if err := s.cc.ScheduleNow(context.Background(), org, nil, ps); err != nil {
if err := s.cc.ScheduleNow(context.Background(), org, ps); err != nil {
s.t.Fatalf("error subscribing: %v", err)
}
}
Expand Down Expand Up @@ -348,25 +356,73 @@ func TestSchedule_TrialSwapWithPaid(t *testing.T) {
}
}

func TestScheduleUpdateOrgOnSchedule(t *testing.T) {
info := &OrgInfo{Email: "test@foo.com"}
c := newTestClient(t)
ctx := context.Background()
c.Push(ctx, []Feature{{
FeaturePlan: mpf("feature:x@plan:test@0"),
Interval: "@daily",
func TestScheduleCancel(t *testing.T) {
featureX := mpf("feature:x@plan:test@0")
featureBase := mpf("feature:base@plan:test@0")

s := newScheduleTester(t)
s.push([]Feature{{
FeaturePlan: featureX,
Interval: "@monthly",
Currency: "usd",
}}, pushLogger(t))
err := c.Schedule(ctx, "org:example", info, []Phase{{
Features: []refs.FeaturePlan{mpf("feature:x@plan:test@0")},
Mode: "graduated",
Aggregate: "sum",
Tiers: []Tier{{Upto: Inf, Base: 1000}},
}, {
FeaturePlan: featureBase,
Interval: "@monthly",
Currency: "usd",
Base: 31 * 1000,
}})
if err != nil {
t.Fatalf("got %v, want nil", err)
}
err = c.ScheduleNow(ctx, "org:example", info, nil) // org update only
if err != nil {
t.Fatal(err)
}

s.schedule("org:paid", 0, featureX, featureBase)
s.report("org:paid", "feature:x", 99)
s.advance(10)
s.cancel("org:paid")
s.advanceToNextPeriod()

// check usage is billed
s.checkInvoices("org:paid", []Invoice{{
Lines: []InvoiceLineItem{{
Feature: featureBase,
Quantity: 1,
Amount: -21000,
Proration: true,
},
lineItem(featureX, 99, 0), // usage on feature:x
lineItem(featureX, 0, 1000), // flag fee on feature:x
},
SubtotalPreTax: -20000,
Subtotal: -20000,
TotalPreTax: -20000,
Total: -20000,
}, {
Lines: []InvoiceLineItem{
lineItem(featureBase, 1, 31000),
lineItem(featureX, 0, 0),
},
SubtotalPreTax: 31000,
Subtotal: 31000,
TotalPreTax: 31000,
Total: 31000,
}})
}

func TestScheduleCancelNoLimits(t *testing.T) {
featureX := mpf("feature:x@plan:test@0")
s := newScheduleTester(t)
s.push([]Feature{{
FeaturePlan: featureX,
Interval: "@monthly",
Currency: "usd",
Mode: "graduated",
Aggregate: "sum",
Tiers: []Tier{{Upto: Inf, Base: 1000}},
}})
s.schedule("org:paid", 0, featureX)
s.checkLimits("org:paid", []Usage{{Feature: featureX, Limit: Inf}})
s.cancel("org:paid")
s.checkLimits("org:paid", nil)
}

func TestScheduleMinMaxItems(t *testing.T) {
Expand All @@ -384,9 +440,10 @@ func TestScheduleMinMaxItems(t *testing.T) {

c.Push(ctx, fs, pushLogger(t))

// effectively cancel an org that does not exist
err := c.SubscribeTo(ctx, "org:example", nil)
if !errors.Is(err, ErrInvalidPhase) {
t.Fatalf("got %v, want %v", err, ErrTooManyItems)
if !errors.Is(err, ErrOrgNotFound) {
t.Fatalf("got %v, want %v", err, ErrOrgNotFound)
}

fps := FeaturePlans(fs)
Expand Down
7 changes: 7 additions & 0 deletions control/usage.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"time"

Expand Down Expand Up @@ -88,6 +89,12 @@ func (c *Client) LookupLimits(ctx context.Context, org string) ([]Usage, error)
}

lines, err := stripe.Slurp[T](ctx, c.Stripe, "GET", "/v1/invoices/upcoming/lines", f)
var se *stripe.Error
if errors.As(err, &se) {
if se.Code == "invoice_upcoming_none" {
return nil, nil
}
}
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 6d2351d

Please sign in to comment.