Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract WebhookEventProcessor from Proxy #1066

Merged
merged 5 commits into from
May 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 22 additions & 176 deletions pkg/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package proxy

import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -80,8 +79,6 @@ type Config struct {
// UseConfiguredWebhooks loads webhooks config from user's account
UseConfiguredWebhooks bool

// EndpointsRoutes is a mapping of local webhook endpoint urls to the events they consume
EndpointRoutes []EndpointRoute
// List of events to listen and proxy
Events []string

Expand Down Expand Up @@ -114,12 +111,9 @@ type Config struct {
type Proxy struct {
cfg *Config

endpointClients []*EndpointClient
stripeAuthClient *stripeauth.Client
webSocketClient *websocket.Client

// Events is the supported event types for the command
events map[string]bool
stripeAuthClient *stripeauth.Client
webSocketClient *websocket.Client
webhookEventProcessor *WebhookEventProcessor
}

const maxConnectAttempts = 3
Expand All @@ -133,6 +127,12 @@ func (p *Proxy) IsConnected() <-chan struct{} {
return p.webSocketClient.Connected()
}

func (p *Proxy) sendMessage(msg *websocket.OutgoingMessage) {
if p.webSocketClient != nil {
p.webSocketClient.SendMessage(msg)
}
}

// Run sets the websocket connection and starts the Goroutines to forward
// incoming events to the local endpoint.
func (p *Proxy) Run(ctx context.Context) error {
Expand Down Expand Up @@ -162,7 +162,7 @@ func (p *Proxy) Run(ctx context.Context) error {
Log: p.cfg.Log,
NoWSS: p.cfg.NoWSS,
ReconnectInterval: time.Duration(session.ReconnectDelay) * time.Second,
EventHandler: websocket.EventHandlerFunc(p.processWebhookEvent),
EventHandler: p.webhookEventProcessor,
},
)

Expand Down Expand Up @@ -223,7 +223,6 @@ func GetSessionSecret(ctx context.Context, client stripe.RequestPerformer, devic
p, err := Init(ctx, &Config{
Client: client,
DeviceName: deviceName,
EndpointRoutes: make([]EndpointRoute, 0),
WebSocketFeature: "webhooks",
})
if err != nil {
Expand Down Expand Up @@ -282,35 +281,13 @@ func (p *Proxy) createSession(ctx context.Context) (*stripeauth.StripeCLISession
return session, err
}

func (p *Proxy) filterWebhookEvent(msg *websocket.WebhookEvent) bool {
if msg.Endpoint.APIVersion != nil && !p.cfg.UseLatestAPIVersion {
p.cfg.Log.WithFields(log.Fields{
"prefix": "proxy.Proxy.filterWebhookEvent",
"api_version": getAPIVersionString(msg.Endpoint.APIVersion),
}).Debugf("Received event with non-default API version, ignoring")

return true
}

if msg.Endpoint.APIVersion == nil && p.cfg.UseLatestAPIVersion {
p.cfg.Log.WithFields(log.Fields{
"prefix": "proxy.Proxy.filterWebhookEvent",
}).Debugf("Received event with default API version, ignoring")

return true
}

return false
}

// This function outputs the event payload in the format specified.
// Currently only supports JSON.
func (p *Proxy) formatOutput(format string, eventPayload string) string {
func formatOutput(format string, eventPayload string) string {
var event map[string]interface{}
err := json.Unmarshal([]byte(eventPayload), &event)
if err != nil {
p.cfg.Log.Debug("Received malformed event from Stripe, ignoring")
return fmt.Sprint(err)
return fmt.Sprintf("Received malformed event: %s", err)
}
switch strings.ToUpper(format) {
// The distinction between this and PrintJSON is that this output is stripped of all pretty format.
Expand All @@ -322,122 +299,6 @@ func (p *Proxy) formatOutput(format string, eventPayload string) string {
}
}

func (p *Proxy) processWebhookEvent(msg websocket.IncomingMessage) {
if msg.WebhookEvent == nil {
p.cfg.Log.Debug("WebSocket specified for Webhooks received non-webhook event")
return
}

webhookEvent := msg.WebhookEvent

p.cfg.Log.WithFields(log.Fields{
"prefix": "proxy.Proxy.processWebhookEvent",
"webhook_id": webhookEvent.WebhookID,
"webhook_converesation_id": webhookEvent.WebhookConversationID,
}).Debugf("Processing webhook event")

var evt StripeEvent

err := json.Unmarshal([]byte(webhookEvent.EventPayload), &evt)
if err != nil {
p.cfg.Log.Debug("Received malformed event from Stripe, ignoring")
return
}

req, err := ExtractRequestData(evt.RequestData)

if err != nil {
p.cfg.Log.Debug("Received malformed event from Stripe, ignoring")
return
}

evt.Request = req

p.cfg.Log.WithFields(log.Fields{
"prefix": "proxy.Proxy.processWebhookEvent",
"webhook_id": webhookEvent.WebhookID,
"webhook_conversation_id": webhookEvent.WebhookConversationID,
"event_id": evt.ID,
"event_type": evt.Type,
"api_version": getAPIVersionString(msg.Endpoint.APIVersion),
}).Trace("Webhook event trace")

// at this point the message is valid so we can acknowledge it
ackMessage := websocket.NewEventAck(webhookEvent.WebhookID, webhookEvent.WebhookConversationID)
p.webSocketClient.SendMessage(ackMessage)

if p.filterWebhookEvent(webhookEvent) {
return
}

evtCtx := eventContext{
webhookID: webhookEvent.WebhookID,
webhookConversationID: webhookEvent.WebhookConversationID,
event: &evt,
}

if p.events["*"] || p.events[evt.Type] {
p.cfg.OutCh <- websocket.DataElement{
Data: evt,
Marshaled: p.formatOutput(outputFormatJSON, webhookEvent.EventPayload),
}

for _, endpoint := range p.endpointClients {
if endpoint.SupportsEventType(evt.IsConnect(), evt.Type) {
// TODO: handle errors returned by endpointClients
go endpoint.Post(
evtCtx,
webhookEvent.EventPayload,
webhookEvent.HTTPHeaders,
)
}
}
}
}

func (p *Proxy) processEndpointResponse(evtCtx eventContext, forwardURL string, resp *http.Response) {
buf, err := io.ReadAll(resp.Body)
if err != nil {
p.cfg.OutCh <- websocket.ErrorElement{
Error: FailedToReadResponseError{Err: err},
}
return
}

body := truncate(string(buf), maxBodySize, true)

p.cfg.OutCh <- websocket.DataElement{
Data: EndpointResponse{
Event: evtCtx.event,
Resp: resp,
},
}

idx := 0
headers := make(map[string]string)

for k, v := range resp.Header {
headers[truncate(k, maxHeaderKeySize, false)] = truncate(v[0], maxHeaderValueSize, true)
idx++

if idx > maxNumHeaders {
break
}
}

if p.webSocketClient != nil {
msg := websocket.NewWebhookResponse(
evtCtx.webhookID,
evtCtx.webhookConversationID,
forwardURL,
resp.StatusCode,
body,
headers,
)
p.webSocketClient.SendMessage(msg)
}
}

//
// Public functions
//
Expand Down Expand Up @@ -514,37 +375,22 @@ func Init(ctx context.Context, cfg *Config) (*Proxy, error) {
}
}

processorConfig := &WebhookEventProcessorConfig{
Log: cfg.Log,
Events: cfg.Events,
OutCh: cfg.OutCh,
UseLatestAPIVersion: cfg.UseLatestAPIVersion,
SkipVerify: cfg.SkipVerify,
Timeout: cfg.Timeout,
}

p := &Proxy{
cfg: cfg,
stripeAuthClient: stripeauth.NewClient(cfg.Client, &stripeauth.Config{
Log: cfg.Log,
}),
events: convertToMap(cfg.Events),
}

for _, route := range endpointRoutes {
// append to endpointClients
p.endpointClients = append(p.endpointClients, NewEndpointClient(
route.URL,
route.ForwardHeaders,
route.Connect,
route.EventTypes,
&EndpointConfig{
HTTPClient: &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
Timeout: time.Duration(cfg.Timeout) * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: cfg.SkipVerify},
},
},
Log: p.cfg.Log,
ResponseHandler: EndpointResponseHandlerFunc(p.processEndpointResponse),
OutCh: p.cfg.OutCh,
},
))
}
p.webhookEventProcessor = NewWebhookEventProcessor(p.sendMessage, endpointRoutes, processorConfig)

return p, nil
}
Expand Down
34 changes: 17 additions & 17 deletions pkg/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ func TestFilterWebhookEvent(t *testing.T) {
},
}

require.False(t, proxyUseDefault.filterWebhookEvent(evtDefault))
require.True(t, proxyUseDefault.filterWebhookEvent(evtLatest))
require.False(t, proxyUseDefault.webhookEventProcessor.filterWebhookEvent(evtDefault))
require.True(t, proxyUseDefault.webhookEventProcessor.filterWebhookEvent(evtLatest))

require.True(t, proxyUseLatest.filterWebhookEvent(evtDefault))
require.False(t, proxyUseLatest.filterWebhookEvent(evtLatest))
require.True(t, proxyUseLatest.webhookEventProcessor.filterWebhookEvent(evtDefault))
require.False(t, proxyUseLatest.webhookEventProcessor.filterWebhookEvent(evtLatest))
}

func TestTruncate(t *testing.T) {
Expand Down Expand Up @@ -147,11 +147,11 @@ func TestForwardToOnly(t *testing.T) {
}
p, err := Init(context.Background(), &cfg)
require.NoError(t, err)
require.Equal(t, 2, len(p.endpointClients))
require.EqualValues(t, "http://localhost:4242", p.endpointClients[0].URL)
require.EqualValues(t, false, p.endpointClients[0].connect)
require.EqualValues(t, "http://localhost:4242", p.endpointClients[1].URL)
require.EqualValues(t, true, p.endpointClients[1].connect)
require.Equal(t, 2, len(p.webhookEventProcessor.endpointClients))
require.EqualValues(t, "http://localhost:4242", p.webhookEventProcessor.endpointClients[0].URL)
require.EqualValues(t, false, p.webhookEventProcessor.endpointClients[0].connect)
require.EqualValues(t, "http://localhost:4242", p.webhookEventProcessor.endpointClients[1].URL)
require.EqualValues(t, true, p.webhookEventProcessor.endpointClients[1].connect)
}

func TestForwardConnectToOnly(t *testing.T) {
Expand All @@ -161,9 +161,9 @@ func TestForwardConnectToOnly(t *testing.T) {
}
p, err := Init(context.Background(), &cfg)
require.NoError(t, err)
require.Equal(t, 1, len(p.endpointClients))
require.EqualValues(t, "http://localhost:4242/connect", p.endpointClients[0].URL)
require.EqualValues(t, true, p.endpointClients[0].connect)
require.Equal(t, 1, len(p.webhookEventProcessor.endpointClients))
require.EqualValues(t, "http://localhost:4242/connect", p.webhookEventProcessor.endpointClients[0].URL)
require.EqualValues(t, true, p.webhookEventProcessor.endpointClients[0].connect)
}

func TestForwardToAndForwardConnectTo(t *testing.T) {
Expand All @@ -173,11 +173,11 @@ func TestForwardToAndForwardConnectTo(t *testing.T) {
}
p, err := Init(context.Background(), &cfg)
require.NoError(t, err)
require.Equal(t, 2, len(p.endpointClients))
require.EqualValues(t, "http://localhost:4242", p.endpointClients[0].URL)
require.EqualValues(t, false, p.endpointClients[0].connect)
require.EqualValues(t, "http://localhost:4242/connect", p.endpointClients[1].URL)
require.EqualValues(t, true, p.endpointClients[1].connect)
require.Equal(t, 2, len(p.webhookEventProcessor.endpointClients))
require.EqualValues(t, "http://localhost:4242", p.webhookEventProcessor.endpointClients[0].URL)
require.EqualValues(t, false, p.webhookEventProcessor.endpointClients[0].connect)
require.EqualValues(t, "http://localhost:4242/connect", p.webhookEventProcessor.endpointClients[1].URL)
require.EqualValues(t, true, p.webhookEventProcessor.endpointClients[1].connect)
}

func TestExtractRequestData(t *testing.T) {
Expand Down
Loading
Loading