Skip to content

Commit

Permalink
Extract WebhookEventProcessor from Proxy (#1066)
Browse files Browse the repository at this point in the history
* Refactor webhook event processing into WebhookEventProcessor

* Use specific type for API version

* Update log lines to match refactored types
  • Loading branch information
bernerd-stripe committed May 3, 2023
1 parent b1cfe3a commit d13c261
Show file tree
Hide file tree
Showing 3 changed files with 261 additions and 193 deletions.
198 changes: 22 additions & 176 deletions pkg/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package proxy

import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -80,8 +79,6 @@ type Config struct {
// UseConfiguredWebhooks loads webhooks config from user's account
UseConfiguredWebhooks bool

// EndpointsRoutes is a mapping of local webhook endpoint urls to the events they consume
EndpointRoutes []EndpointRoute
// List of events to listen and proxy
Events []string

Expand Down Expand Up @@ -114,12 +111,9 @@ type Config struct {
type Proxy struct {
cfg *Config

endpointClients []*EndpointClient
stripeAuthClient *stripeauth.Client
webSocketClient *websocket.Client

// Events is the supported event types for the command
events map[string]bool
stripeAuthClient *stripeauth.Client
webSocketClient *websocket.Client
webhookEventProcessor *WebhookEventProcessor
}

const maxConnectAttempts = 3
Expand All @@ -133,6 +127,12 @@ func (p *Proxy) IsConnected() <-chan struct{} {
return p.webSocketClient.Connected()
}

func (p *Proxy) sendMessage(msg *websocket.OutgoingMessage) {
if p.webSocketClient != nil {
p.webSocketClient.SendMessage(msg)
}
}

// Run sets the websocket connection and starts the Goroutines to forward
// incoming events to the local endpoint.
func (p *Proxy) Run(ctx context.Context) error {
Expand Down Expand Up @@ -162,7 +162,7 @@ func (p *Proxy) Run(ctx context.Context) error {
Log: p.cfg.Log,
NoWSS: p.cfg.NoWSS,
ReconnectInterval: time.Duration(session.ReconnectDelay) * time.Second,
EventHandler: websocket.EventHandlerFunc(p.processWebhookEvent),
EventHandler: p.webhookEventProcessor,
},
)

Expand Down Expand Up @@ -223,7 +223,6 @@ func GetSessionSecret(ctx context.Context, client stripe.RequestPerformer, devic
p, err := Init(ctx, &Config{
Client: client,
DeviceName: deviceName,
EndpointRoutes: make([]EndpointRoute, 0),
WebSocketFeature: "webhooks",
})
if err != nil {
Expand Down Expand Up @@ -282,35 +281,13 @@ func (p *Proxy) createSession(ctx context.Context) (*stripeauth.StripeCLISession
return session, err
}

func (p *Proxy) filterWebhookEvent(msg *websocket.WebhookEvent) bool {
if msg.Endpoint.APIVersion != nil && !p.cfg.UseLatestAPIVersion {
p.cfg.Log.WithFields(log.Fields{
"prefix": "proxy.Proxy.filterWebhookEvent",
"api_version": getAPIVersionString(msg.Endpoint.APIVersion),
}).Debugf("Received event with non-default API version, ignoring")

return true
}

if msg.Endpoint.APIVersion == nil && p.cfg.UseLatestAPIVersion {
p.cfg.Log.WithFields(log.Fields{
"prefix": "proxy.Proxy.filterWebhookEvent",
}).Debugf("Received event with default API version, ignoring")

return true
}

return false
}

// This function outputs the event payload in the format specified.
// Currently only supports JSON.
func (p *Proxy) formatOutput(format string, eventPayload string) string {
func formatOutput(format string, eventPayload string) string {
var event map[string]interface{}
err := json.Unmarshal([]byte(eventPayload), &event)
if err != nil {
p.cfg.Log.Debug("Received malformed event from Stripe, ignoring")
return fmt.Sprint(err)
return fmt.Sprintf("Received malformed event: %s", err)
}
switch strings.ToUpper(format) {
// The distinction between this and PrintJSON is that this output is stripped of all pretty format.
Expand All @@ -322,122 +299,6 @@ func (p *Proxy) formatOutput(format string, eventPayload string) string {
}
}

func (p *Proxy) processWebhookEvent(msg websocket.IncomingMessage) {
if msg.WebhookEvent == nil {
p.cfg.Log.Debug("WebSocket specified for Webhooks received non-webhook event")
return
}

webhookEvent := msg.WebhookEvent

p.cfg.Log.WithFields(log.Fields{
"prefix": "proxy.Proxy.processWebhookEvent",
"webhook_id": webhookEvent.WebhookID,
"webhook_converesation_id": webhookEvent.WebhookConversationID,
}).Debugf("Processing webhook event")

var evt StripeEvent

err := json.Unmarshal([]byte(webhookEvent.EventPayload), &evt)
if err != nil {
p.cfg.Log.Debug("Received malformed event from Stripe, ignoring")
return
}

req, err := ExtractRequestData(evt.RequestData)

if err != nil {
p.cfg.Log.Debug("Received malformed event from Stripe, ignoring")
return
}

evt.Request = req

p.cfg.Log.WithFields(log.Fields{
"prefix": "proxy.Proxy.processWebhookEvent",
"webhook_id": webhookEvent.WebhookID,
"webhook_conversation_id": webhookEvent.WebhookConversationID,
"event_id": evt.ID,
"event_type": evt.Type,
"api_version": getAPIVersionString(msg.Endpoint.APIVersion),
}).Trace("Webhook event trace")

// at this point the message is valid so we can acknowledge it
ackMessage := websocket.NewEventAck(webhookEvent.WebhookID, webhookEvent.WebhookConversationID)
p.webSocketClient.SendMessage(ackMessage)

if p.filterWebhookEvent(webhookEvent) {
return
}

evtCtx := eventContext{
webhookID: webhookEvent.WebhookID,
webhookConversationID: webhookEvent.WebhookConversationID,
event: &evt,
}

if p.events["*"] || p.events[evt.Type] {
p.cfg.OutCh <- websocket.DataElement{
Data: evt,
Marshaled: p.formatOutput(outputFormatJSON, webhookEvent.EventPayload),
}

for _, endpoint := range p.endpointClients {
if endpoint.SupportsEventType(evt.IsConnect(), evt.Type) {
// TODO: handle errors returned by endpointClients
go endpoint.Post(
evtCtx,
webhookEvent.EventPayload,
webhookEvent.HTTPHeaders,
)
}
}
}
}

func (p *Proxy) processEndpointResponse(evtCtx eventContext, forwardURL string, resp *http.Response) {
buf, err := io.ReadAll(resp.Body)
if err != nil {
p.cfg.OutCh <- websocket.ErrorElement{
Error: FailedToReadResponseError{Err: err},
}
return
}

body := truncate(string(buf), maxBodySize, true)

p.cfg.OutCh <- websocket.DataElement{
Data: EndpointResponse{
Event: evtCtx.event,
Resp: resp,
},
}

idx := 0
headers := make(map[string]string)

for k, v := range resp.Header {
headers[truncate(k, maxHeaderKeySize, false)] = truncate(v[0], maxHeaderValueSize, true)
idx++

if idx > maxNumHeaders {
break
}
}

if p.webSocketClient != nil {
msg := websocket.NewWebhookResponse(
evtCtx.webhookID,
evtCtx.webhookConversationID,
forwardURL,
resp.StatusCode,
body,
headers,
)
p.webSocketClient.SendMessage(msg)
}
}

//
// Public functions
//
Expand Down Expand Up @@ -514,37 +375,22 @@ func Init(ctx context.Context, cfg *Config) (*Proxy, error) {
}
}

processorConfig := &WebhookEventProcessorConfig{
Log: cfg.Log,
Events: cfg.Events,
OutCh: cfg.OutCh,
UseLatestAPIVersion: cfg.UseLatestAPIVersion,
SkipVerify: cfg.SkipVerify,
Timeout: cfg.Timeout,
}

p := &Proxy{
cfg: cfg,
stripeAuthClient: stripeauth.NewClient(cfg.Client, &stripeauth.Config{
Log: cfg.Log,
}),
events: convertToMap(cfg.Events),
}

for _, route := range endpointRoutes {
// append to endpointClients
p.endpointClients = append(p.endpointClients, NewEndpointClient(
route.URL,
route.ForwardHeaders,
route.Connect,
route.EventTypes,
&EndpointConfig{
HTTPClient: &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
Timeout: time.Duration(cfg.Timeout) * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: cfg.SkipVerify},
},
},
Log: p.cfg.Log,
ResponseHandler: EndpointResponseHandlerFunc(p.processEndpointResponse),
OutCh: p.cfg.OutCh,
},
))
}
p.webhookEventProcessor = NewWebhookEventProcessor(p.sendMessage, endpointRoutes, processorConfig)

return p, nil
}
Expand Down
34 changes: 17 additions & 17 deletions pkg/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ func TestFilterWebhookEvent(t *testing.T) {
},
}

require.False(t, proxyUseDefault.filterWebhookEvent(evtDefault))
require.True(t, proxyUseDefault.filterWebhookEvent(evtLatest))
require.False(t, proxyUseDefault.webhookEventProcessor.filterWebhookEvent(evtDefault))
require.True(t, proxyUseDefault.webhookEventProcessor.filterWebhookEvent(evtLatest))

require.True(t, proxyUseLatest.filterWebhookEvent(evtDefault))
require.False(t, proxyUseLatest.filterWebhookEvent(evtLatest))
require.True(t, proxyUseLatest.webhookEventProcessor.filterWebhookEvent(evtDefault))
require.False(t, proxyUseLatest.webhookEventProcessor.filterWebhookEvent(evtLatest))
}

func TestTruncate(t *testing.T) {
Expand Down Expand Up @@ -147,11 +147,11 @@ func TestForwardToOnly(t *testing.T) {
}
p, err := Init(context.Background(), &cfg)
require.NoError(t, err)
require.Equal(t, 2, len(p.endpointClients))
require.EqualValues(t, "http://localhost:4242", p.endpointClients[0].URL)
require.EqualValues(t, false, p.endpointClients[0].connect)
require.EqualValues(t, "http://localhost:4242", p.endpointClients[1].URL)
require.EqualValues(t, true, p.endpointClients[1].connect)
require.Equal(t, 2, len(p.webhookEventProcessor.endpointClients))
require.EqualValues(t, "http://localhost:4242", p.webhookEventProcessor.endpointClients[0].URL)
require.EqualValues(t, false, p.webhookEventProcessor.endpointClients[0].connect)
require.EqualValues(t, "http://localhost:4242", p.webhookEventProcessor.endpointClients[1].URL)
require.EqualValues(t, true, p.webhookEventProcessor.endpointClients[1].connect)
}

func TestForwardConnectToOnly(t *testing.T) {
Expand All @@ -161,9 +161,9 @@ func TestForwardConnectToOnly(t *testing.T) {
}
p, err := Init(context.Background(), &cfg)
require.NoError(t, err)
require.Equal(t, 1, len(p.endpointClients))
require.EqualValues(t, "http://localhost:4242/connect", p.endpointClients[0].URL)
require.EqualValues(t, true, p.endpointClients[0].connect)
require.Equal(t, 1, len(p.webhookEventProcessor.endpointClients))
require.EqualValues(t, "http://localhost:4242/connect", p.webhookEventProcessor.endpointClients[0].URL)
require.EqualValues(t, true, p.webhookEventProcessor.endpointClients[0].connect)
}

func TestForwardToAndForwardConnectTo(t *testing.T) {
Expand All @@ -173,11 +173,11 @@ func TestForwardToAndForwardConnectTo(t *testing.T) {
}
p, err := Init(context.Background(), &cfg)
require.NoError(t, err)
require.Equal(t, 2, len(p.endpointClients))
require.EqualValues(t, "http://localhost:4242", p.endpointClients[0].URL)
require.EqualValues(t, false, p.endpointClients[0].connect)
require.EqualValues(t, "http://localhost:4242/connect", p.endpointClients[1].URL)
require.EqualValues(t, true, p.endpointClients[1].connect)
require.Equal(t, 2, len(p.webhookEventProcessor.endpointClients))
require.EqualValues(t, "http://localhost:4242", p.webhookEventProcessor.endpointClients[0].URL)
require.EqualValues(t, false, p.webhookEventProcessor.endpointClients[0].connect)
require.EqualValues(t, "http://localhost:4242/connect", p.webhookEventProcessor.endpointClients[1].URL)
require.EqualValues(t, true, p.webhookEventProcessor.endpointClients[1].connect)
}

func TestExtractRequestData(t *testing.T) {
Expand Down
Loading

0 comments on commit d13c261

Please sign in to comment.