diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index 138632f31..d7beb48fd 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -2,7 +2,6 @@ package proxy import ( "context" - "crypto/tls" "encoding/json" "errors" "fmt" @@ -80,8 +79,6 @@ type Config struct { // UseConfiguredWebhooks loads webhooks config from user's account UseConfiguredWebhooks bool - // EndpointsRoutes is a mapping of local webhook endpoint urls to the events they consume - EndpointRoutes []EndpointRoute // List of events to listen and proxy Events []string @@ -114,12 +111,9 @@ type Config struct { type Proxy struct { cfg *Config - endpointClients []*EndpointClient - stripeAuthClient *stripeauth.Client - webSocketClient *websocket.Client - - // Events is the supported event types for the command - events map[string]bool + stripeAuthClient *stripeauth.Client + webSocketClient *websocket.Client + webhookEventProcessor *WebhookEventProcessor } const maxConnectAttempts = 3 @@ -133,6 +127,12 @@ func (p *Proxy) IsConnected() <-chan struct{} { return p.webSocketClient.Connected() } +func (p *Proxy) sendMessage(msg *websocket.OutgoingMessage) { + if p.webSocketClient != nil { + p.webSocketClient.SendMessage(msg) + } +} + // Run sets the websocket connection and starts the Goroutines to forward // incoming events to the local endpoint. func (p *Proxy) Run(ctx context.Context) error { @@ -162,7 +162,7 @@ func (p *Proxy) Run(ctx context.Context) error { Log: p.cfg.Log, NoWSS: p.cfg.NoWSS, ReconnectInterval: time.Duration(session.ReconnectDelay) * time.Second, - EventHandler: websocket.EventHandlerFunc(p.processWebhookEvent), + EventHandler: p.webhookEventProcessor, }, ) @@ -223,7 +223,6 @@ func GetSessionSecret(ctx context.Context, client stripe.RequestPerformer, devic p, err := Init(ctx, &Config{ Client: client, DeviceName: deviceName, - EndpointRoutes: make([]EndpointRoute, 0), WebSocketFeature: "webhooks", }) if err != nil { @@ -282,35 +281,13 @@ func (p *Proxy) createSession(ctx context.Context) (*stripeauth.StripeCLISession return session, err } -func (p *Proxy) filterWebhookEvent(msg *websocket.WebhookEvent) bool { - if msg.Endpoint.APIVersion != nil && !p.cfg.UseLatestAPIVersion { - p.cfg.Log.WithFields(log.Fields{ - "prefix": "proxy.Proxy.filterWebhookEvent", - "api_version": getAPIVersionString(msg.Endpoint.APIVersion), - }).Debugf("Received event with non-default API version, ignoring") - - return true - } - - if msg.Endpoint.APIVersion == nil && p.cfg.UseLatestAPIVersion { - p.cfg.Log.WithFields(log.Fields{ - "prefix": "proxy.Proxy.filterWebhookEvent", - }).Debugf("Received event with default API version, ignoring") - - return true - } - - return false -} - // This function outputs the event payload in the format specified. // Currently only supports JSON. -func (p *Proxy) formatOutput(format string, eventPayload string) string { +func formatOutput(format string, eventPayload string) string { var event map[string]interface{} err := json.Unmarshal([]byte(eventPayload), &event) if err != nil { - p.cfg.Log.Debug("Received malformed event from Stripe, ignoring") - return fmt.Sprint(err) + return fmt.Sprintf("Received malformed event: %s", err) } switch strings.ToUpper(format) { // The distinction between this and PrintJSON is that this output is stripped of all pretty format. @@ -322,122 +299,6 @@ func (p *Proxy) formatOutput(format string, eventPayload string) string { } } -func (p *Proxy) processWebhookEvent(msg websocket.IncomingMessage) { - if msg.WebhookEvent == nil { - p.cfg.Log.Debug("WebSocket specified for Webhooks received non-webhook event") - return - } - - webhookEvent := msg.WebhookEvent - - p.cfg.Log.WithFields(log.Fields{ - "prefix": "proxy.Proxy.processWebhookEvent", - "webhook_id": webhookEvent.WebhookID, - "webhook_converesation_id": webhookEvent.WebhookConversationID, - }).Debugf("Processing webhook event") - - var evt StripeEvent - - err := json.Unmarshal([]byte(webhookEvent.EventPayload), &evt) - if err != nil { - p.cfg.Log.Debug("Received malformed event from Stripe, ignoring") - return - } - - req, err := ExtractRequestData(evt.RequestData) - - if err != nil { - p.cfg.Log.Debug("Received malformed event from Stripe, ignoring") - return - } - - evt.Request = req - - p.cfg.Log.WithFields(log.Fields{ - "prefix": "proxy.Proxy.processWebhookEvent", - "webhook_id": webhookEvent.WebhookID, - "webhook_conversation_id": webhookEvent.WebhookConversationID, - "event_id": evt.ID, - "event_type": evt.Type, - "api_version": getAPIVersionString(msg.Endpoint.APIVersion), - }).Trace("Webhook event trace") - - // at this point the message is valid so we can acknowledge it - ackMessage := websocket.NewEventAck(webhookEvent.WebhookID, webhookEvent.WebhookConversationID) - p.webSocketClient.SendMessage(ackMessage) - - if p.filterWebhookEvent(webhookEvent) { - return - } - - evtCtx := eventContext{ - webhookID: webhookEvent.WebhookID, - webhookConversationID: webhookEvent.WebhookConversationID, - event: &evt, - } - - if p.events["*"] || p.events[evt.Type] { - p.cfg.OutCh <- websocket.DataElement{ - Data: evt, - Marshaled: p.formatOutput(outputFormatJSON, webhookEvent.EventPayload), - } - - for _, endpoint := range p.endpointClients { - if endpoint.SupportsEventType(evt.IsConnect(), evt.Type) { - // TODO: handle errors returned by endpointClients - go endpoint.Post( - evtCtx, - webhookEvent.EventPayload, - webhookEvent.HTTPHeaders, - ) - } - } - } -} - -func (p *Proxy) processEndpointResponse(evtCtx eventContext, forwardURL string, resp *http.Response) { - buf, err := io.ReadAll(resp.Body) - if err != nil { - p.cfg.OutCh <- websocket.ErrorElement{ - Error: FailedToReadResponseError{Err: err}, - } - return - } - - body := truncate(string(buf), maxBodySize, true) - - p.cfg.OutCh <- websocket.DataElement{ - Data: EndpointResponse{ - Event: evtCtx.event, - Resp: resp, - }, - } - - idx := 0 - headers := make(map[string]string) - - for k, v := range resp.Header { - headers[truncate(k, maxHeaderKeySize, false)] = truncate(v[0], maxHeaderValueSize, true) - idx++ - - if idx > maxNumHeaders { - break - } - } - - if p.webSocketClient != nil { - msg := websocket.NewWebhookResponse( - evtCtx.webhookID, - evtCtx.webhookConversationID, - forwardURL, - resp.StatusCode, - body, - headers, - ) - p.webSocketClient.SendMessage(msg) - } -} - // // Public functions // @@ -514,37 +375,22 @@ func Init(ctx context.Context, cfg *Config) (*Proxy, error) { } } + processorConfig := &WebhookEventProcessorConfig{ + Log: cfg.Log, + Events: cfg.Events, + OutCh: cfg.OutCh, + UseLatestAPIVersion: cfg.UseLatestAPIVersion, + SkipVerify: cfg.SkipVerify, + Timeout: cfg.Timeout, + } + p := &Proxy{ cfg: cfg, stripeAuthClient: stripeauth.NewClient(cfg.Client, &stripeauth.Config{ Log: cfg.Log, }), - events: convertToMap(cfg.Events), - } - - for _, route := range endpointRoutes { - // append to endpointClients - p.endpointClients = append(p.endpointClients, NewEndpointClient( - route.URL, - route.ForwardHeaders, - route.Connect, - route.EventTypes, - &EndpointConfig{ - HTTPClient: &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - Timeout: time.Duration(cfg.Timeout) * time.Second, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: cfg.SkipVerify}, - }, - }, - Log: p.cfg.Log, - ResponseHandler: EndpointResponseHandlerFunc(p.processEndpointResponse), - OutCh: p.cfg.OutCh, - }, - )) } + p.webhookEventProcessor = NewWebhookEventProcessor(p.sendMessage, endpointRoutes, processorConfig) return p, nil } diff --git a/pkg/proxy/proxy_test.go b/pkg/proxy/proxy_test.go index 6e49514ef..2fd43da56 100644 --- a/pkg/proxy/proxy_test.go +++ b/pkg/proxy/proxy_test.go @@ -28,11 +28,11 @@ func TestFilterWebhookEvent(t *testing.T) { }, } - require.False(t, proxyUseDefault.filterWebhookEvent(evtDefault)) - require.True(t, proxyUseDefault.filterWebhookEvent(evtLatest)) + require.False(t, proxyUseDefault.webhookEventProcessor.filterWebhookEvent(evtDefault)) + require.True(t, proxyUseDefault.webhookEventProcessor.filterWebhookEvent(evtLatest)) - require.True(t, proxyUseLatest.filterWebhookEvent(evtDefault)) - require.False(t, proxyUseLatest.filterWebhookEvent(evtLatest)) + require.True(t, proxyUseLatest.webhookEventProcessor.filterWebhookEvent(evtDefault)) + require.False(t, proxyUseLatest.webhookEventProcessor.filterWebhookEvent(evtLatest)) } func TestTruncate(t *testing.T) { @@ -147,11 +147,11 @@ func TestForwardToOnly(t *testing.T) { } p, err := Init(context.Background(), &cfg) require.NoError(t, err) - require.Equal(t, 2, len(p.endpointClients)) - require.EqualValues(t, "http://localhost:4242", p.endpointClients[0].URL) - require.EqualValues(t, false, p.endpointClients[0].connect) - require.EqualValues(t, "http://localhost:4242", p.endpointClients[1].URL) - require.EqualValues(t, true, p.endpointClients[1].connect) + require.Equal(t, 2, len(p.webhookEventProcessor.endpointClients)) + require.EqualValues(t, "http://localhost:4242", p.webhookEventProcessor.endpointClients[0].URL) + require.EqualValues(t, false, p.webhookEventProcessor.endpointClients[0].connect) + require.EqualValues(t, "http://localhost:4242", p.webhookEventProcessor.endpointClients[1].URL) + require.EqualValues(t, true, p.webhookEventProcessor.endpointClients[1].connect) } func TestForwardConnectToOnly(t *testing.T) { @@ -161,9 +161,9 @@ func TestForwardConnectToOnly(t *testing.T) { } p, err := Init(context.Background(), &cfg) require.NoError(t, err) - require.Equal(t, 1, len(p.endpointClients)) - require.EqualValues(t, "http://localhost:4242/connect", p.endpointClients[0].URL) - require.EqualValues(t, true, p.endpointClients[0].connect) + require.Equal(t, 1, len(p.webhookEventProcessor.endpointClients)) + require.EqualValues(t, "http://localhost:4242/connect", p.webhookEventProcessor.endpointClients[0].URL) + require.EqualValues(t, true, p.webhookEventProcessor.endpointClients[0].connect) } func TestForwardToAndForwardConnectTo(t *testing.T) { @@ -173,11 +173,11 @@ func TestForwardToAndForwardConnectTo(t *testing.T) { } p, err := Init(context.Background(), &cfg) require.NoError(t, err) - require.Equal(t, 2, len(p.endpointClients)) - require.EqualValues(t, "http://localhost:4242", p.endpointClients[0].URL) - require.EqualValues(t, false, p.endpointClients[0].connect) - require.EqualValues(t, "http://localhost:4242/connect", p.endpointClients[1].URL) - require.EqualValues(t, true, p.endpointClients[1].connect) + require.Equal(t, 2, len(p.webhookEventProcessor.endpointClients)) + require.EqualValues(t, "http://localhost:4242", p.webhookEventProcessor.endpointClients[0].URL) + require.EqualValues(t, false, p.webhookEventProcessor.endpointClients[0].connect) + require.EqualValues(t, "http://localhost:4242/connect", p.webhookEventProcessor.endpointClients[1].URL) + require.EqualValues(t, true, p.webhookEventProcessor.endpointClients[1].connect) } func TestExtractRequestData(t *testing.T) { diff --git a/pkg/proxy/webhook_event_processor.go b/pkg/proxy/webhook_event_processor.go new file mode 100644 index 000000000..b9a99ea06 --- /dev/null +++ b/pkg/proxy/webhook_event_processor.go @@ -0,0 +1,222 @@ +package proxy + +import ( + "crypto/tls" + "encoding/json" + "io" + "net/http" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/stripe/stripe-cli/pkg/websocket" +) + +// WebhookEventProcessorConfig defines the external inputs that infuence the +// behavior of a WebhookEventProcessor. +type WebhookEventProcessorConfig struct { + // The logger used to log messages to stdin/err + Log *log.Logger + + // List of events to listen and proxy + Events []string + + // OutCh is the channel to send logs and statuses to for processing in other packages + OutCh chan websocket.IElement + + // Indicates whether to filter events formatted with the default or latest API version + UseLatestAPIVersion bool + + // Indicates whether to skip certificate verification when forwarding webhooks to HTTPS endpoints + SkipVerify bool + + // Override default timeout + Timeout int64 +} + +// WebhookEventProcessor encapsulates logic around processing and forwarding +// webhook events. +type WebhookEventProcessor struct { + cfg *WebhookEventProcessorConfig + + // Events is the supported event types for the command + events map[string]bool + endpointClients []*EndpointClient + sendMessage func(*websocket.OutgoingMessage) +} + +// NewWebhookEventProcessor constructs a WebhookEventProcessor from the provided +// websocket delivery function, route table, and config. +func NewWebhookEventProcessor(sendMessage func(*websocket.OutgoingMessage), routes []EndpointRoute, cfg *WebhookEventProcessorConfig) *WebhookEventProcessor { + p := &WebhookEventProcessor{ + cfg: cfg, + events: convertToMap(cfg.Events), + sendMessage: sendMessage, + } + + for _, route := range routes { + // append to endpointClients + p.endpointClients = append(p.endpointClients, NewEndpointClient( + route.URL, + route.ForwardHeaders, + route.Connect, + route.EventTypes, + &EndpointConfig{ + HTTPClient: &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + Timeout: time.Duration(cfg.Timeout) * time.Second, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: cfg.SkipVerify}, + }, + }, + Log: cfg.Log, + ResponseHandler: EndpointResponseHandlerFunc(p.processEndpointResponse), + OutCh: cfg.OutCh, + }, + )) + } + + return p +} + +// ProcessEvent processes webhook events, notifying listeners via the configured +// OutCh, sending acknowledgements with the configured websocket sender, and +// forwarding events to configured endpoints. +// +// ProcessEvent implements the websocket.EndpointResponseHandler interface. +func (p *WebhookEventProcessor) ProcessEvent(msg websocket.IncomingMessage) { + if msg.WebhookEvent == nil { + p.cfg.Log.Debug("WebSocket specified for Webhooks received non-webhook event") + return + } + + webhookEvent := msg.WebhookEvent + + p.cfg.Log.WithFields(log.Fields{ + "prefix": "proxy.WebhookEventProcessor.ProcessEvent", + "webhook_id": webhookEvent.WebhookID, + "webhook_converesation_id": webhookEvent.WebhookConversationID, + }).Debugf("Processing webhook event") + + var evt StripeEvent + + err := json.Unmarshal([]byte(webhookEvent.EventPayload), &evt) + if err != nil { + p.cfg.Log.Debug("Received malformed event from Stripe, ignoring") + return + } + + req, err := ExtractRequestData(evt.RequestData) + + if err != nil { + p.cfg.Log.Debug("Received malformed event from Stripe, ignoring") + return + } + + evt.Request = req + + p.cfg.Log.WithFields(log.Fields{ + "prefix": "proxy.WebhookEventProcessor.ProcessEvent", + "webhook_id": webhookEvent.WebhookID, + "webhook_conversation_id": webhookEvent.WebhookConversationID, + "event_id": evt.ID, + "event_type": evt.Type, + "api_version": getAPIVersionString(webhookEvent.Endpoint.APIVersion), + }).Trace("Webhook event trace") + + // at this point the message is valid so we can acknowledge it + ackMessage := websocket.NewEventAck(webhookEvent.WebhookID, webhookEvent.WebhookConversationID) + p.sendMessage(ackMessage) + + if p.filterWebhookEvent(webhookEvent) { + return + } + + evtCtx := eventContext{ + webhookID: webhookEvent.WebhookID, + webhookConversationID: webhookEvent.WebhookConversationID, + event: &evt, + } + + if p.events["*"] || p.events[evt.Type] { + p.cfg.OutCh <- websocket.DataElement{ + Data: evt, + Marshaled: formatOutput(outputFormatJSON, webhookEvent.EventPayload), + } + + for _, endpoint := range p.endpointClients { + if endpoint.SupportsEventType(evt.IsConnect(), evt.Type) { + // TODO: handle errors returned by endpointClients + go endpoint.Post( + evtCtx, + webhookEvent.EventPayload, + webhookEvent.HTTPHeaders, + ) + } + } + } +} + +func (p *WebhookEventProcessor) filterWebhookEvent(msg *websocket.WebhookEvent) bool { + if msg.Endpoint.APIVersion != nil && !p.cfg.UseLatestAPIVersion { + p.cfg.Log.WithFields(log.Fields{ + "prefix": "proxy.WebhookEventProcessor.filterWebhookEvent", + "api_version": getAPIVersionString(msg.Endpoint.APIVersion), + }).Debugf("Received event with non-default API version, ignoring") + + return true + } + + if msg.Endpoint.APIVersion == nil && p.cfg.UseLatestAPIVersion { + p.cfg.Log.WithFields(log.Fields{ + "prefix": "proxy.WebhookEventProcessor.filterWebhookEvent", + }).Debugf("Received event with default API version, ignoring") + + return true + } + + return false +} + +func (p *WebhookEventProcessor) processEndpointResponse(evtCtx eventContext, forwardURL string, resp *http.Response) { + buf, err := io.ReadAll(resp.Body) + if err != nil { + p.cfg.OutCh <- websocket.ErrorElement{ + Error: FailedToReadResponseError{Err: err}, + } + return + } + + body := truncate(string(buf), maxBodySize, true) + + p.cfg.OutCh <- websocket.DataElement{ + Data: EndpointResponse{ + Event: evtCtx.event, + Resp: resp, + }, + } + + idx := 0 + headers := make(map[string]string) + + for k, v := range resp.Header { + headers[truncate(k, maxHeaderKeySize, false)] = truncate(v[0], maxHeaderValueSize, true) + idx++ + + if idx > maxNumHeaders { + break + } + } + + msg := websocket.NewWebhookResponse( + evtCtx.webhookID, + evtCtx.webhookConversationID, + forwardURL, + resp.StatusCode, + body, + headers, + ) + p.sendMessage(msg) +}