Skip to content
This repository has been archived by the owner on Jan 2, 2024. It is now read-only.

Commit

Permalink
api: add tax behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
bmizerany committed Mar 10, 2023
1 parent 108add4 commit 5ce81a2
Show file tree
Hide file tree
Showing 12 changed files with 148 additions and 58 deletions.
12 changes: 5 additions & 7 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,10 @@ func (h *Handler) serveSubscribe(w http.ResponseWriter, r *http.Request) error {
return err
}
phases = append(phases, control.Phase{
Trial: p.Trial,
Effective: p.Effective,
Features: fs,
AutomaticTax: sr.Tax.Automatic,
Trial: p.Trial,
Effective: p.Effective,
Features: fs,
Tax: sr.Tax,
})
}
}
Expand Down Expand Up @@ -342,9 +342,7 @@ func (h *Handler) servePhase(w http.ResponseWriter, r *http.Request) error {
Plans: p.Plans,
Fragments: p.Fragments(),
Trial: p.Trial,
Tax: apitypes.Taxation{
Automatic: p.AutomaticTax,
},
Tax: p.Tax,
})
}
}
Expand Down
3 changes: 2 additions & 1 deletion api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"tier.run/refs"
"tier.run/stripe"
"tier.run/stripe/stroke"
"tier.run/types/tax"
"tier.run/types/they"
)

Expand Down Expand Up @@ -319,7 +320,7 @@ func TestScheduleAutomaticTax(t *testing.T) {
}
})
_, err := tc.Schedule(ctx, "org:test", &tier.ScheduleParams{
Tax: tier.Taxation{Automatic: true},
Tax: tax.Applied{Automatically: true},
Phases: []apitypes.Phase{
{
Features: []string{"plan:test@0"},
Expand Down
17 changes: 7 additions & 10 deletions api/apitypes/apitypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"tier.run/refs"
"tier.run/types/payment"
"tier.run/types/tax"
)

type Error struct {
Expand All @@ -26,18 +27,14 @@ type Phase struct {
Features []string `json:"features,omitempty"`
}

type Taxation struct {
Automatic bool `json:"automatic,omitempty"`
}

type PhaseResponse struct {
Effective time.Time `json:"effective,omitempty"`
End time.Time `json:"end,omitempty"`
Features []refs.FeaturePlan `json:"features,omitempty"`
Plans []refs.Plan `json:"plans,omitempty"`
Fragments []refs.FeaturePlan `json:"fragments,omitempty"`
Trial bool `json:"trial,omitempty"`
Tax Taxation `json:"tax,omitempty"`
Tax tax.Applied `json:"tax"`
}

func (pr PhaseResponse) MarshalJSON() ([]byte, error) {
Expand Down Expand Up @@ -91,11 +88,11 @@ type CheckoutRequest struct {
}

type ScheduleRequest struct {
Org string `json:"org"`
PaymentMethodID string `json:"payment_method_id"`
Info *OrgInfo `json:"info"`
Phases []Phase `json:"phases"`
Tax Taxation `json:"tax"`
Org string `json:"org"`
PaymentMethodID string `json:"payment_method_id"`
Info *OrgInfo `json:"info"`
Phases []Phase `json:"phases"`
Tax tax.Applied `json:"tax"`
}

// ScheduleResponse is the expected response from a schedule request. It is
Expand Down
48 changes: 42 additions & 6 deletions api/apitypes/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"

"tier.run/refs"
"tier.run/types/tax"
"tier.run/values"
)

Expand Down Expand Up @@ -53,12 +54,47 @@ type Divide struct {
}

type Feature struct {
Title string `json:"title,omitempty"`
Base float64 `json:"base,omitempty"`
Mode string `json:"mode,omitempty"`
Aggregate string `json:"aggregate,omitempty"`
Tiers []Tier `json:"tiers,omitempty"`
Divide *Divide `json:"divide,omitempty"`
Title string `json:"title,omitempty"`
Base float64 `json:"base,omitempty"`
Mode string `json:"mode,omitempty"`
Aggregate string `json:"aggregate,omitempty"`
Tiers []Tier `json:"tiers,omitempty"`
Divide Divide `json:"divide"`
Tax tax.Settings `json:"tax"`
}

func (v Feature) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Title string `json:"title,omitempty"`
Base float64 `json:"base,omitempty"`
Mode string `json:"mode,omitempty"`
Aggregate string `json:"aggregate,omitempty"`
Tiers []Tier `json:"tiers,omitempty"`
Divide *Divide `json:"divide,omitempty"`
Tax *tax.Settings `json:"tax,omitempty"`
}{
Title: v.Title,
Base: v.Base,
Mode: v.Mode,
Aggregate: v.Aggregate,
Tiers: v.Tiers,
Divide: zeroAsNil(v.Divide),
Tax: zeroAsNil(v.Tax),
})
}

// zeroAsNil returns a pointer to v if v is not the zero value for its type.
// If v implements IsZero, it is used to determine if v is the zero value.
func zeroAsNil[T comparable](v T) *T {
z, ok := any(v).(interface{ IsZero() bool })
if ok && z.IsZero() {
return nil
}
var zero T
if v == zero {
return nil
}
return &v
}

type Plan struct {
Expand Down
10 changes: 6 additions & 4 deletions api/materialize/views.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ func FromPricingHuJSON(data []byte) (fs []control.Feature, err error) {
for feature, f := range p.Features {
fn := feature.WithPlan(plan)

divide := values.Coalesce(f.Divide, &apitypes.Divide{})
ff := control.Feature{
FeaturePlan: fn,

Expand All @@ -55,8 +54,10 @@ func FromPricingHuJSON(data []byte) (fs []control.Feature, err error) {
Mode: values.Coalesce(f.Mode, "graduated"),
Aggregate: values.Coalesce(f.Aggregate, "sum"),

TransformDenominator: divide.By,
TransformRoundUp: divide.Rounding == "up",
TransformDenominator: f.Divide.By,
TransformRoundUp: f.Divide.Rounding == "up",

Tax: f.Tax,
}

if len(f.Tiers) > 0 {
Expand Down Expand Up @@ -110,13 +111,14 @@ func ToPricingJSON(fs []control.Feature) ([]byte, error) {
Mode: values.ZeroIf(f.Mode, "graduated"),
Aggregate: values.ZeroIf(f.Aggregate, "sum"),
Tiers: tiers,
Tax: f.Tax,
}
if f.TransformDenominator != 0 {
var round string
if f.TransformRoundUp {
round = "up"
}
af.Divide = &apitypes.Divide{
af.Divide = apitypes.Divide{
By: f.TransformDenominator,
Rounding: round,
}
Expand Down
27 changes: 27 additions & 0 deletions api/materialize/views_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"tier.run/client/tier"
"tier.run/control"
"tier.run/refs"
"tier.run/types/tax"
)

func TestPricingHuJSON(t *testing.T) {
Expand Down Expand Up @@ -41,6 +42,13 @@ func TestPricingHuJSON(t *testing.T) {
}
},
},
"plan:tax@1": {
"features": {
"feature:tax:not:included": {
"tax": {"included": true},
},
},
},
}
}`)

Expand All @@ -63,6 +71,16 @@ func TestPricingHuJSON(t *testing.T) {
Aggregate: "sum", // defaults
Base: 100,
},
{
PlanTitle: "plan:tax@1",
Title: "feature:tax:not:included@plan:tax@1",
FeaturePlan: refs.MustParseFeaturePlan("feature:tax:not:included@plan:tax@1"),
Currency: "usd",
Interval: "@monthly",
Mode: "graduated", // defaults
Aggregate: "sum",
Tax: tax.Settings{Included: true},
},
{
PlanTitle: "Just an example plan to show off features",
Title: "feature:volume@plan:example@1",
Expand Down Expand Up @@ -123,6 +141,14 @@ func TestPricingHuJSON(t *testing.T) {
"divide": {"by": 100, "rounding": "up"},
}
}
},
"plan:tax@1": {
"title": "plan:tax@1",
"features": {
"feature:tax:not:included": {
"tax": {"included": true}
}
}
}
}
}`)
Expand All @@ -134,6 +160,7 @@ func diffJSON(t *testing.T, got, want []byte) {
t.Helper()

format := func(b []byte) string {
t.Helper()
b, err := hujson.Standardize(b)
if err != nil {
t.Fatal(err)
Expand Down
5 changes: 2 additions & 3 deletions client/tier/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"tier.run/api/apitypes"
"tier.run/fetch"
"tier.run/refs"
"tier.run/types/tax"
)

// ClockHeader is the header used to pass the clock ID to the tier sidecar.
Expand Down Expand Up @@ -305,14 +306,12 @@ type CheckoutParams struct {
RequireBillingAddress bool
}

type Taxation = apitypes.Taxation

type ScheduleParams struct {
Info *OrgInfo
Phases []Phase
PaymentMethodID string

Tax Taxation
Tax tax.Applied
}

func (c *Client) Schedule(ctx context.Context, org string, p *ScheduleParams) (*apitypes.ScheduleResponse, error) {
Expand Down
9 changes: 5 additions & 4 deletions cmd/tier/tier.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"tier.run/control"
"tier.run/profile"
"tier.run/stripe"
"tier.run/types/tax"
"tier.run/version"
)

Expand Down Expand Up @@ -279,7 +280,7 @@ func runTier(cmd string, args []string) (err error) {
cancelURL := fs.String("cancel_url", "", "sets the cancel URL for use with -checkout")
requireBillingAddress := fs.Bool("require_billing_address", false, "require billing address for use with --checkout")
paymentMethod := fs.String("paymentmethod", "", "sets the Stripe payment method for the subscription (e.g. pm_123). It is ignored with --checkout")
tax := fs.String("tax", "", "sets the Stripe tax rate for the subscription ('auto' is currently the only supported value)")
taxtype := fs.String("tax", "", "sets the Stripe tax rate for the subscription ('auto' is currently the only supported value)")
if err := fs.Parse(args); err != nil {
return err
}
Expand All @@ -293,8 +294,8 @@ func runTier(cmd string, args []string) (err error) {
fmt.Fprintln(stderr, "tier: the -cancel flag must be used without arguments")
return errUsage
}
if *tax != "" && *tax != "auto" {
fmt.Fprintf(stderr, "tier: invalid tax rate %q\n", *tax)
if *taxtype != "" && *taxtype != "auto" {
fmt.Fprintf(stderr, "tier: invalid tax rate %q\n", *taxtype)
return errUsage
}

Expand Down Expand Up @@ -324,7 +325,7 @@ func runTier(cmd string, args []string) (err error) {
Email: *email,
},
PaymentMethodID: *paymentMethod,
Tax: tier.Taxation{Automatic: *tax == "auto"},
Tax: tax.Applied{Automatically: *taxtype == "auto"},
}
switch {
case *trial > 0:
Expand Down
13 changes: 11 additions & 2 deletions control/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"golang.org/x/sync/errgroup"
"tier.run/refs"
"tier.run/stripe"
"tier.run/types/tax"
"tier.run/values"
)

Expand Down Expand Up @@ -106,6 +107,8 @@ type Feature struct {

TransformDenominator int // the denominator for transforming usage
TransformRoundUp bool // whether to round up transformed usage; otherwise round down

Tax tax.Settings
}

// TODO(bmizerany): remove FQN and replace with simply adding the version to
Expand Down Expand Up @@ -330,9 +333,13 @@ func (c *Client) pushFeature(ctx context.Context, f Feature) (providerID string,
data.Set("metadata", "tier.limit", limit)
}

if f.Tax.Included {
data.Set("tax_behavior", "inclusive")
} else {
data.Set("tax_behavior", "exclusive")
}

// TODO(bmizerany): data.Set("active", ?)
// TODO(bmizerany): data.Set("tax_behavior", "?")
// TODO(bmizerany): data.Set("transform_quantity", "?")
// TODO(bmizerany): data.Set("currency_options", "?")

var v struct {
Expand Down Expand Up @@ -374,6 +381,7 @@ type stripePrice struct {
DivideBy int `json:"divide_by"`
Round string `json:"round"`
} `json:"transform_quantity"`
TaxBehavior string `json:"tax_behavior"`
}

func stripePriceToFeature(p stripePrice) Feature {
Expand All @@ -388,6 +396,7 @@ func stripePriceToFeature(p stripePrice) Feature {
Aggregate: aggregateFromStripe[p.Recurring.AggregateUsage],
TransformDenominator: p.TransformQuantity.DivideBy,
TransformRoundUp: p.TransformQuantity.Round == "up",
Tax: tax.Settings{Included: p.TaxBehavior == "inclusive"},
}

if len(p.Tiers) == 0 && p.Recurring.UsageType == "metered" {
Expand Down
9 changes: 9 additions & 0 deletions control/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"kr.dev/diff"
"tier.run/refs"
"tier.run/stripe/stroke"
"tier.run/types/tax"
)

func newTestClient(t *testing.T) *Client {
Expand Down Expand Up @@ -80,6 +81,14 @@ func TestRoundTrip(t *testing.T) {
{Upto: 1, Price: 100, Base: 0},
},
},
{
FeaturePlan: refs.MustParseFeaturePlan("feature:tax@0"),
Interval: "@daily",
Currency: "eur",
Title: "Test2",
Base: 1000,
Tax: tax.Settings{Included: true},
},
}

if !slices.IsSortedFunc(want, func(a, b Feature) bool {
Expand Down
Loading

0 comments on commit 5ce81a2

Please sign in to comment.