Skip to content

Commit

Permalink
Update stripeauth to support multiple websocket features (#1079)
Browse files Browse the repository at this point in the history
  • Loading branch information
bernerd-stripe committed May 25, 2023
1 parent 99d6afc commit be94053
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 28 deletions.
6 changes: 5 additions & 1 deletion pkg/logtailing/tailer.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,11 @@ func (t *Tailer) createSession(ctx context.Context) (*stripeauth.StripeCLISessio
// Try to authorize at least 5 times before failing. Sometimes we have random
// transient errors that we just need to retry for.
for i := 0; i <= 5; i++ {
session, err = t.stripeAuthClient.Authorize(ctx, t.cfg.DeviceName, requestLogsWebSocketFeature, &filters, nil)
session, err = t.stripeAuthClient.Authorize(ctx, stripeauth.CreateSessionRequest{
DeviceName: t.cfg.DeviceName,
WebSocketFeatures: []string{requestLogsWebSocketFeature},
Filters: &filters,
})

if err == nil {
exitCh <- struct{}{}
Expand Down
6 changes: 5 additions & 1 deletion pkg/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,11 @@ func (p *Proxy) createSession(ctx context.Context) (*stripeauth.StripeCLISession
ForwardConnectURL: p.cfg.ForwardConnectURL,
}

session, err = p.stripeAuthClient.Authorize(ctx, p.cfg.DeviceName, p.cfg.WebSocketFeature, nil, &devURLMap)
session, err = p.stripeAuthClient.Authorize(ctx, stripeauth.CreateSessionRequest{
DeviceName: p.cfg.DeviceName,
WebSocketFeatures: []string{p.cfg.WebSocketFeature},
DeviceURLMap: &devURLMap,
})

if err == nil {
exitCh <- struct{}{}
Expand Down
5 changes: 4 additions & 1 deletion pkg/samples/samples.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,10 @@ func ConfigureDotEnv(ctx context.Context, config *config.Config) (map[string]str
}
authClient := stripeauth.NewClient(stripeClient, nil)

authSession, err := authClient.Authorize(ctx, deviceName, "webhooks", nil, nil)
authSession, err := authClient.Authorize(ctx, stripeauth.CreateSessionRequest{
DeviceName: deviceName,
WebSocketFeatures: []string{"webhooks"},
})
if err != nil {
return nil, err
}
Expand Down
33 changes: 23 additions & 10 deletions pkg/stripeauth/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,26 +39,39 @@ type DeviceURLMap struct {
ForwardConnectURL string
}

// CreateSessionRequest defines the API input parameters for client.Authorize.
type CreateSessionRequest struct {
DeviceName string
WebSocketFeatures []string

Filters *string
DeviceURLMap *DeviceURLMap
}

// Authorize sends a request to Stripe to initiate a new CLI session.
func (c *Client) Authorize(ctx context.Context, deviceName string, websocketFeature string, filters *string, devURLMap *DeviceURLMap) (*StripeCLISession, error) {
func (c *Client) Authorize(ctx context.Context, req CreateSessionRequest) (*StripeCLISession, error) {
c.cfg.Log.WithFields(log.Fields{
"prefix": "stripeauth.client.Authorize",
}).Debug("Authenticating with Stripe...")

form := url.Values{}
form.Add("device_name", deviceName)
form.Add("websocket_feature", websocketFeature)

if filters != nil {
form.Add("filters", *filters)
form.Add("device_name", req.DeviceName)
for _, feature := range req.WebSocketFeatures {
form.Add("websocket_features[]", feature)
}

if devURLMap != nil && len(devURLMap.ForwardURL) > 0 {
form.Add("forward_to_url", devURLMap.ForwardURL)
if req.Filters != nil {
form.Add("filters", *req.Filters)
}

if devURLMap != nil && len(devURLMap.ForwardConnectURL) > 0 {
form.Add("forward_connect_to_url", devURLMap.ForwardConnectURL)
if devURLMap := req.DeviceURLMap; devURLMap != nil {
if len(devURLMap.ForwardURL) > 0 {
form.Add("forward_to_url", devURLMap.ForwardURL)
}

if len(devURLMap.ForwardConnectURL) > 0 {
form.Add("forward_connect_to_url", devURLMap.ForwardConnectURL)
}
}

resp, err := c.client.PerformRequest(ctx, http.MethodPost, stripeCLISessionPath, form.Encode(), nil)
Expand Down
44 changes: 29 additions & 15 deletions pkg/stripeauth/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package stripeauth
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
Expand Down Expand Up @@ -33,16 +32,19 @@ func TestAuthorize(t *testing.T) {
require.NotEmpty(t, r.UserAgent())
require.NotEmpty(t, r.Header.Get("X-Stripe-Client-User-Agent"))

body, err := io.ReadAll(r.Body)
require.NoError(t, err)
require.Equal(t, "device_name=my-device&websocket_feature=webhooks", string(body))
require.Equal(t, "my-device", r.FormValue("device_name"))
require.Equal(t, "webhooks", r.FormValue("websocket_features[]"))
}))
defer ts.Close()

baseURL, _ := url.Parse(ts.URL)
client := NewClient(&stripe.Client{APIKey: "sk_test_123", BaseURL: baseURL}, nil)

session, err := client.Authorize(context.Background(), "my-device", "webhooks", nil, nil)
session, err := client.Authorize(context.Background(), CreateSessionRequest{
DeviceName: "my-device",
WebSocketFeatures: []string{"webhooks"},
})
require.NoError(t, err)
require.NoError(t, err)
require.Equal(t, "some-id", session.WebSocketID)
require.Equal(t, "wss://example.com/subscribe/acct_123", session.WebSocketURL)
Expand All @@ -53,22 +55,23 @@ func TestAuthorize(t *testing.T) {

func TestUserAgent(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)

require.Regexp(t, regexp.MustCompile(`^Stripe/v1 stripe-cli/\w+$`), r.Header.Get("User-Agent"))
w.Write([]byte(`{}`))
}))
defer ts.Close()

baseURL, _ := url.Parse(ts.URL)
client := NewClient(&stripe.Client{APIKey: "sk_test_123", BaseURL: baseURL}, nil)

client.Authorize(context.Background(), "my-device", "webhooks", nil, nil)
_, err := client.Authorize(context.Background(), CreateSessionRequest{
DeviceName: "my-device",
WebSocketFeatures: []string{"webhooks"},
})
require.NoError(t, err)
}

func TestStripeClientUserAgent(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)

encodedUserAgent := r.Header.Get("X-Stripe-Client-User-Agent")
require.NotEmpty(t, encodedUserAgent)

Expand All @@ -79,23 +82,29 @@ func TestStripeClientUserAgent(t *testing.T) {
// Just test a few headers that we know to be stable.
require.Equal(t, "stripe-cli", userAgent["name"])
require.Equal(t, "stripe", userAgent["publisher"])

w.Write([]byte(`{}`))
}))
defer ts.Close()

baseURL, _ := url.Parse(ts.URL)
client := NewClient(&stripe.Client{APIKey: "sk_test_123", BaseURL: baseURL}, nil)

client.Authorize(context.Background(), "my-device", "webhooks", nil, nil)
_, err := client.Authorize(context.Background(), CreateSessionRequest{
DeviceName: "my-device",
WebSocketFeatures: []string{"webhooks"},
})
require.NoError(t, err)
}

func TestAuthorizeWithURLDeviceMap(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)

require.Equal(t, "my-device", r.FormValue("device_name"))
require.Equal(t, "webhooks", r.FormValue("websocket_feature"))
require.Equal(t, "webhooks", r.FormValue("websocket_features[]"))
require.Equal(t, "http://localhost:3000/events", r.FormValue("forward_to_url"))
require.Equal(t, "http://localhost:3000/connect/events", r.FormValue("forward_connect_to_url"))

w.Write([]byte(`{}`))
}))
defer ts.Close()

Expand All @@ -107,5 +116,10 @@ func TestAuthorizeWithURLDeviceMap(t *testing.T) {
ForwardConnectURL: "http://localhost:3000/connect/events",
}

client.Authorize(context.Background(), "my-device", "webhooks", nil, &devURLMap)
_, err := client.Authorize(context.Background(), CreateSessionRequest{
DeviceName: "my-device",
WebSocketFeatures: []string{"webhooks"},
DeviceURLMap: &devURLMap,
})
require.NoError(t, err)
}

0 comments on commit be94053

Please sign in to comment.