Skip to content

Commit

Permalink
Merge pull request #189 from stripe/sergeyrud-move-custom-request-han…
Browse files Browse the repository at this point in the history
…dler-after-the-main-check

Move the custom request handler call after the main ACL check
  • Loading branch information
kevinv-stripe committed May 31, 2023
2 parents c227b0d + 445d0d6 commit 8dd3072
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 25 deletions.
3 changes: 2 additions & 1 deletion pkg/smokescreen/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,9 @@ type Config struct {

// Custom handler for users to allow running code per requests, users can pass in custom methods to verify requests based
// on headers, code for metrics etc.
// If smokescreen denies a request, this handler is not called.
// If the handler returns an error, smokescreen will deny the request.
CustomRequestHandler func(*http.Request) error
PostDecisionRequestHandler func(*http.Request) error
}

type missingRoleError struct {
Expand Down
38 changes: 18 additions & 20 deletions pkg/smokescreen/smokescreen.go
Original file line number Diff line number Diff line change
Expand Up @@ -477,16 +477,6 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer {
}

sctx.logger.WithField("url", req.RequestURI).Debug("received HTTP proxy request")

// Call the custom request handler if it exists
if config.CustomRequestHandler != nil {
err = config.CustomRequestHandler(req)
if err != nil {
pctx.Error = denyError{err}
return req, rejectResponse(pctx, pctx.Error)
}
}

sctx.decision, sctx.lookupTime, pctx.Error = checkIfRequestShouldBeProxied(config, req, destination)

// Returning any kind of response in this handler is goproxy's way of short circuiting
Expand All @@ -499,6 +489,15 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer {
return req, rejectResponse(pctx, denyError{errors.New(sctx.decision.reason)})
}

// Call the custom request handler if it exists
if config.PostDecisionRequestHandler != nil {
err = config.PostDecisionRequestHandler(req)
if err != nil {
pctx.Error = denyError{err}
return req, rejectResponse(pctx, pctx.Error)
}
}

// Proceed with proxying the request
return req, nil
})
Expand Down Expand Up @@ -621,16 +620,6 @@ func handleConnect(config *Config, pctx *goproxy.ProxyCtx) (string, error) {
pctx.Error = denyError{err}
return "", pctx.Error
}

// Call the custom request handler if it exists
if config.CustomRequestHandler != nil {
err = config.CustomRequestHandler(pctx.Req)
if err != nil {
pctx.Error = denyError{err}
return "", pctx.Error
}
}

sctx.decision, sctx.lookupTime, pctx.Error = checkIfRequestShouldBeProxied(config, pctx.Req, destination)
if pctx.Error != nil {
return "", denyError{pctx.Error}
Expand All @@ -639,6 +628,15 @@ func handleConnect(config *Config, pctx *goproxy.ProxyCtx) (string, error) {
return "", denyError{errors.New(sctx.decision.reason)}
}

// Call the custom request handler if it exists
if config.PostDecisionRequestHandler != nil {
err = config.PostDecisionRequestHandler(pctx.Req)
if err != nil {
pctx.Error = denyError{err}
return "", pctx.Error
}
}

return destination.String(), nil
}

Expand Down
8 changes: 4 additions & 4 deletions pkg/smokescreen/smokescreen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1070,7 +1070,7 @@ func TestCustomRequestHandler(t *testing.T) {
return nil
}

t.Run("CustomRequestHandler works for HTTPS", func(t *testing.T) {
t.Run("PostDecisionRequestHandler works for HTTPS", func(t *testing.T) {
testCases := []struct {
header http.Header
expectedError bool
Expand All @@ -1088,7 +1088,7 @@ func TestCustomRequestHandler(t *testing.T) {
r.NoError(err)
err = cfg.SetAllowAddresses([]string{"127.0.0.1"})
r.NoError(err)
cfg.CustomRequestHandler = customRequestHandler
cfg.PostDecisionRequestHandler = customRequestHandler

l, err := net.Listen("tcp", "localhost:0")
r.NoError(err)
Expand Down Expand Up @@ -1119,7 +1119,7 @@ func TestCustomRequestHandler(t *testing.T) {
}
})

t.Run("CustomRequestHandler works for HTTP", func(t *testing.T) {
t.Run("PostDecisionRequestHandler works for HTTP", func(t *testing.T) {
testCases := []struct {
header string
expectedError bool
Expand All @@ -1137,7 +1137,7 @@ func TestCustomRequestHandler(t *testing.T) {
r.NoError(err)
err = cfg.SetAllowAddresses([]string{"127.0.0.1"})
r.NoError(err)
cfg.CustomRequestHandler = customRequestHandler
cfg.PostDecisionRequestHandler = customRequestHandler

l, err := net.Listen("tcp", "localhost:0")
r.NoError(err)
Expand Down

0 comments on commit 8dd3072

Please sign in to comment.