Skip to content

Commit

Permalink
Extract core proxy logic outside of cobra listen cmd (#574)
Browse files Browse the repository at this point in the history
* a few things to cleanup and move

* wtv

* fix endpoint routes

* move withSIGTERMCancel to cmd

* cleanup

* moved everything... almost

* lint

* remove empty liste_test file

* delete deprecated New function

* cleanup

* update test

* refactor visitor to be shared

* everything... works i think

* lint

* rename LogElement to DataElement and check type

* move visitor into websocket package

* remove old TODO

* cleanup
  • Loading branch information
pepin-stripe committed May 19, 2021
1 parent 1f68229 commit 594a477
Show file tree
Hide file tree
Showing 13 changed files with 481 additions and 371 deletions.
255 changes: 114 additions & 141 deletions pkg/cmd/listen.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,27 @@ import (
"context"
"errors"
"fmt"
"net/url"
"strconv"
"os"
"os/signal"
"strings"
"syscall"
"time"

"github.com/briandowns/spinner"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/spf13/pflag"

"github.com/stripe/stripe-cli/pkg/ansi"
"github.com/stripe/stripe-cli/pkg/proxy"
"github.com/stripe/stripe-cli/pkg/requests"
"github.com/stripe/stripe-cli/pkg/stripe"
"github.com/stripe/stripe-cli/pkg/validators"
"github.com/stripe/stripe-cli/pkg/version"
"github.com/stripe/stripe-cli/pkg/websocket"
)

const webhooksWebSocketFeature = "webhooks"
const timeLayout = "2006-01-02 15:04:05"
const outputFormatJSON = "JSON"

type listenCmd struct {
cmd *cobra.Command
Expand Down Expand Up @@ -96,181 +101,149 @@ Stripe account.`,
// Normally, this function would be listed alphabetically with the others declared in this file,
// but since it's acting as the core functionality for the cmd above, I'm keeping it close.
func (lc *listenCmd) runListenCmd(cmd *cobra.Command, args []string) error {
if !lc.printJSON && !lc.onlyPrintSecret && !lc.skipUpdate {
version.CheckLatestVersion()
}

deviceName, err := Config.Profile.GetDeviceName()
if err != nil {
return err
}

endpointRoutes := make([]proxy.EndpointRoute, 0)

key, err := Config.Profile.GetAPIKey(lc.livemode)
if err != nil {
return err
}

if !lc.printJSON && !lc.onlyPrintSecret && !lc.skipUpdate {
version.CheckLatestVersion()
}

for _, event := range lc.events {
if _, found := validEvents[event]; !found {
fmt.Printf("Warning: You're attempting to listen for \"%s\", which isn't a valid event\n", event)
// --print-secret option
if lc.onlyPrintSecret {
secret, err := proxy.GetSessionSecret(deviceName, key, lc.apiBaseURL)
if err != nil {
return err
}
fmt.Printf("%s\n", secret)
return nil
}

if len(lc.events) == 0 {
lc.events = []string{"*"}
}

if len(lc.forwardConnectURL) == 0 {
lc.forwardConnectURL = lc.forwardURL
}

// default to non connect headers if no forward connect headers
if len(lc.forwardConnectHeaders) == 0 {
lc.forwardConnectHeaders = lc.forwardHeaders
}

if len(lc.forwardURL) > 0 {
endpointRoutes = append(endpointRoutes, proxy.EndpointRoute{
URL: parseURL(lc.forwardURL),
ForwardHeaders: lc.forwardHeaders,
Connect: false,
EventTypes: lc.events,
})
}

if len(lc.forwardConnectURL) > 0 {
endpointRoutes = append(endpointRoutes, proxy.EndpointRoute{
URL: parseURL(lc.forwardConnectURL),
ForwardHeaders: lc.forwardConnectHeaders,
Connect: true,
EventTypes: lc.events,
})
}

// validate forward-urls args
if lc.useConfiguredWebhooks && len(lc.forwardURL) > 0 {
if strings.HasPrefix(lc.forwardURL, "/") {
return errors.New("--forward-to cannot be a relative path when loading webhook endpoints from the API")
}

if strings.HasPrefix(lc.forwardConnectURL, "/") {
return errors.New("--forward-connect-to cannot be a relative path when loading webhook endpoints from the API")
}

endpoints := lc.getEndpointsFromAPI(key)
if len(endpoints.Data) == 0 {
return errors.New("You have not defined any webhook endpoints on your account. Go to the Stripe Dashboard to add some: https://dashboard.stripe.com/test/webhooks")
}

endpointRoutes = buildEndpointRoutes(endpoints, parseURL(lc.forwardURL), parseURL(lc.forwardConnectURL), lc.forwardHeaders, lc.forwardConnectHeaders)
} else if lc.useConfiguredWebhooks && len(lc.forwardURL) == 0 {
return errors.New("--load-from-webhooks-api requires a location to forward to with --forward-to")
}

p := proxy.New(&proxy.Config{
DeviceName: deviceName,
Key: key,
EndpointRoutes: endpointRoutes,
APIBaseURL: lc.apiBaseURL,
WebSocketFeature: webhooksWebSocketFeature,
PrintJSON: lc.printJSON,
Format: lc.format,
UseLatestAPIVersion: lc.latestAPIVersion,
SkipVerify: lc.skipVerify,
Log: log.StandardLogger(),
NoWSS: lc.noWSS,
}, lc.events)

if lc.onlyPrintSecret {
secret, err := p.GetSessionSecret(context.Background())
if err != nil {
return err
}
fmt.Printf("%s\n", secret)
return nil
}

err = p.Run(context.Background())
logger := log.StandardLogger()
proxyVisitor := createVisitor(logger, lc.format, lc.printJSON)
proxyOutCh := make(chan websocket.IElement)

p, err := proxy.Init(&proxy.Config{
DeviceName: deviceName,
Key: key,
ForwardURL: lc.forwardURL,
ForwardHeaders: lc.forwardHeaders,
ForwardConnectURL: lc.forwardConnectURL,
ForwardConnectHeaders: lc.forwardConnectHeaders,
UseConfiguredWebhooks: lc.useConfiguredWebhooks,
APIBaseURL: lc.apiBaseURL,
WebSocketFeature: webhooksWebSocketFeature,
PrintJSON: lc.printJSON,
UseLatestAPIVersion: lc.latestAPIVersion,
SkipVerify: lc.skipVerify,
Log: logger,
NoWSS: lc.noWSS,
Events: lc.events,
OutCh: proxyOutCh,
})
if err != nil {
return err
}

return nil
}

func (lc *listenCmd) getEndpointsFromAPI(secretKey string) requests.WebhookEndpointList {
apiBaseURL := lc.apiBaseURL
if apiBaseURL == "" {
apiBaseURL = stripe.DefaultAPIBaseURL
}
ctx := withSIGTERMCancel(context.Background(), func() {
log.WithFields(log.Fields{
"prefix": "proxy.Proxy.Run",
}).Debug("Ctrl+C received, cleaning up...")
})

return requests.WebhookEndpointsList(apiBaseURL, "2019-03-14", secretKey, &Config.Profile)
}
go p.Run(ctx)

func buildEndpointRoutes(endpoints requests.WebhookEndpointList, forwardURL, forwardConnectURL string, forwardHeaders []string, forwardConnectHeaders []string) []proxy.EndpointRoute {
endpointRoutes := make([]proxy.EndpointRoute, 0)

for _, endpoint := range endpoints.Data {
u, err := url.Parse(endpoint.URL)
// Silently skip over invalid paths
if err == nil {
// Since webhooks in the dashboard may have a more generic url, only extract
// the path. We'll use this with `localhost` or with the `--forward-to` flag
if endpoint.Application == "" {
endpointRoutes = append(endpointRoutes, proxy.EndpointRoute{
URL: buildForwardURL(forwardURL, u),
ForwardHeaders: forwardHeaders,
Connect: false,
EventTypes: endpoint.EnabledEvents,
})
} else {
endpointRoutes = append(endpointRoutes, proxy.EndpointRoute{
URL: buildForwardURL(forwardConnectURL, u),
ForwardHeaders: forwardConnectHeaders,
Connect: true,
EventTypes: endpoint.EnabledEvents,
})
}
for el := range proxyOutCh {
err := el.Accept(proxyVisitor)
if err != nil {
return err
}
}

return endpointRoutes
return nil
}

// parseURL parses the potentially incomplete URL provided in the configuration
// and returns a full URL
func parseURL(url string) string {
_, err := strconv.Atoi(url)
if err == nil {
// If the input is just a number, assume it's a port number
url = "localhost:" + url
}

if strings.HasPrefix(url, "/") {
// If the input starts with a /, assume it's a relative path
url = "localhost" + url
}
func withSIGTERMCancel(ctx context.Context, onCancel func()) context.Context {
// Create a context that will be canceled when Ctrl+C is pressed
ctx, cancel := context.WithCancel(ctx)

if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") {
// Add the protocol if it's not already there
url = "http://" + url
}
interruptCh := make(chan os.Signal, 1)
signal.Notify(interruptCh, os.Interrupt, syscall.SIGTERM)

return url
go func() {
<-interruptCh
onCancel()
cancel()
}()
return ctx
}

func buildForwardURL(forwardURL string, destination *url.URL) string {
f, err := url.Parse(forwardURL)
if err != nil {
log.Fatalf("Provided forward url cannot be parsed: %s", forwardURL)
}
func createVisitor(logger *log.Logger, format string, printJSON bool) *websocket.Visitor {
var s *spinner.Spinner

return &websocket.Visitor{
VisitError: func(ee websocket.ErrorElement) error {
ansi.StopSpinner(s, "", logger.Out)
logger.Fatal(ee.Error)
return ee.Error
},
VisitStatus: func(se websocket.StateElement) error {
switch se.State {
case websocket.Loading:
s = ansi.StartNewSpinner("Getting ready...", logger.Out)
case websocket.Reconnecting:
ansi.StartSpinner(s, "Session expired, reconnecting...", logger.Out)
case websocket.Ready:
ansi.StopSpinner(s, fmt.Sprintf("Ready! Your webhook signing secret is %s (^C to quit)", ansi.Bold(se.Data[0])), logger.Out)
case websocket.Done:
ansi.StopSpinner(s, "", logger.Out)
}
return nil
},
VisitData: func(de websocket.DataElement) error {
stripeEvent, ok := de.Data.(proxy.StripeEvent)
if !ok {
return fmt.Errorf("VisitData received unexpected type for DataElement, got %T expected %T", de, proxy.StripeEvent{})
}

return fmt.Sprintf(
"%s://%s%s%s",
f.Scheme,
f.Host,
strings.TrimSuffix(f.Path, "/"), // avoids having a double "//"
destination.Path,
)
if strings.ToUpper(format) == outputFormatJSON || printJSON {
fmt.Println(de.Marshaled)
} else {
maybeConnect := ""
if stripeEvent.IsConnect() {
maybeConnect = "connect "
}

localTime := time.Now().Format(timeLayout)

color := ansi.Color(os.Stdout)
outputStr := fmt.Sprintf("%s --> %s%s [%s]",
color.Faint(localTime),
maybeConnect,
ansi.Linkify(ansi.Bold(stripeEvent.Type), stripeEvent.URLForEventType(), logger.Out),
ansi.Linkify(stripeEvent.ID, stripeEvent.URLForEventID(), logger.Out),
)
fmt.Println(outputStr)
}
return nil
},
}
}
75 changes: 0 additions & 75 deletions pkg/cmd/listen_test.go

This file was deleted.

0 comments on commit 594a477

Please sign in to comment.