Skip to content

Commit

Permalink
Added support for custom mutual TLS header checking
Browse files Browse the repository at this point in the history
  • Loading branch information
robiball committed Oct 12, 2018
1 parent 2d6486b commit d690a8b
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 25 deletions.
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func main() {
handlers.Init(sw)
log.Info("Connected to Slack API")
// Start a server to respond to callbacks from Slack
s := server.NewSlackHandler("/slack", appToken, signingSecret, log.Error, log.Errorf)
s := server.NewSlackHandler("/slack", appToken, signingSecret, nil, log.Error, log.Errorf)
s.HandleCommand("/help-me", handlers.HelpRequest)
s.HandleCallback("dialog_submission", "HelpRequest", handlers.HelpCallback)
addr := viper.GetString("listen-address")
Expand Down
31 changes: 24 additions & 7 deletions server/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

"github.com/mitchellh/mapstructure"
"github.com/nlopes/slack"
"regexp"
)

// Request wraps http.Request
Expand All @@ -25,25 +26,41 @@ type Request struct {
}

// Validate the request comes from Slack
func (r *Request) Validate(secret string) error {
func (r *Request) Validate(secret string, dnHeader *string) error {
// If a dnHeader has been provided, check that the header contains the slack CN
if dnHeader != nil {
slackDNHeader := r.Header.Get(*dnHeader)
dnError := fmt.Errorf("invalid CN in DN header")

r, _ := regexp.Compile("CN=(.*?),")
cn := r.FindStringSubmatch(slackDNHeader)
if len(cn) != 2 { // It should match the CN exactly one, and contain the CN value as a group
return dnError
}

if cn[1] != "platform-tls-client.slack.com" {
return dnError
}
}

slackTimestampHeader := r.Header.Get("X-Slack-Request-Timestamp")
slackTimestamp, err := strconv.ParseInt(slackTimestampHeader, 10, 64)

// Abort if timestamp is invalid
if err != nil {
return fmt.Errorf("Invalid timestamp sent from slack: %s", err)
return fmt.Errorf("invalid timestamp sent from slack: %s", err)
}

// Abort if timestamp is stale (older than 5 minutes)
now := int64(time.Now().Unix())
if (now - slackTimestamp) > (60 * 5) {
return fmt.Errorf("Stale timestamp sent from slack: %s", err)
return fmt.Errorf("stale timestamp sent from slack: %s", err)
}

// Abort if request body is invalid
body, err := ioutil.ReadAll(r.Body)
if err != nil {
return fmt.Errorf("Invalid request body sent from slack: %s", err)
return fmt.Errorf("invalid request body sent from slack: %s", err)
}
slackBody := string(body)

Expand All @@ -54,7 +71,7 @@ func (r *Request) Validate(secret string) error {
sec.Write(slackBaseStr)
mySig := fmt.Sprintf("v0=%s", []byte(hex.EncodeToString(sec.Sum(nil))))
if mySig != slackSignature {
return errors.New("Invalid signature sent from slack")
return errors.New("invalid signature sent from slack")
}
// All good! The request is valid
r.Body = ioutil.NopCloser(bytes.NewBuffer(body))
Expand All @@ -78,10 +95,10 @@ func (r *Request) parsePayload() error {
var payload CallbackPayload
j := r.Form.Get("payload")
if j == "" {
return errors.New("Empty payload")
return errors.New("empty payload")
}
if err := json.Unmarshal([]byte(j), &payload); err != nil {
return fmt.Errorf("Error parsing payload JSON: %s", err)
return fmt.Errorf("error parsing payload JSON: %s", err)
}
r.payload = payload
return nil
Expand Down
6 changes: 4 additions & 2 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@ type SlackHandler struct {
basePath string
appToken string
secretToken string
dnHeader *string // Used for Mutual TLS
}

// NewSlackHandler returns an initialised SlackHandler
func NewSlackHandler(basePath, appToken, secretToken string, l LogFunc, lf LogfFunc) *SlackHandler {
func NewSlackHandler(basePath, appToken, secretToken string, dnHeader *string, l LogFunc, lf LogfFunc) *SlackHandler {
return &SlackHandler{
DefaultRoute: func(res *Response, req *Request, ctx interface{}) error {
res.Text(http.StatusNotFound, "Not found")
Expand All @@ -48,6 +49,7 @@ func NewSlackHandler(basePath, appToken, secretToken string, l LogFunc, lf LogfF
basePath: basePath,
appToken: appToken,
secretToken: secretToken,
dnHeader: dnHeader,
}
}

Expand Down Expand Up @@ -90,7 +92,7 @@ func (h *SlackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

// If the request did not look like it came from slack, 400 and abort
if err := req.Validate(h.secretToken); err != nil {
if err := req.Validate(h.secretToken, h.dnHeader); err != nil {
h.Logf("Bad request from slack: %s", err)
res.Text(400, "invalid slack request")
return
Expand Down
56 changes: 41 additions & 15 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (

var (
slackSecret = "fake_secret"
dnHeader = "dummy-dn"
basePath = "/slack"
errString string
logf = func(msg string, i ...interface{}) {
Expand All @@ -41,6 +42,9 @@ func addSlackHeaders(body string, r *http.Request) {
h.Write(slackBaseStr)
mySig := fmt.Sprintf("v0=%s", []byte(hex.EncodeToString(h.Sum(nil))))
r.Header.Set("X-Slack-Signature", mySig)

// Add dummy mutual TLS header
r.Header.Set(dnHeader, "CN=platform-tls-client.slack.com,O=Slack Technologies")
}

func performGenericRequest(raw, path string, s *SlackHandler) *http.Response {
Expand All @@ -65,7 +69,7 @@ func TestMatchSlashCommand(t *testing.T) {
}
return nil
}
s := NewSlackHandler(basePath, "TOKEN", slackSecret, log, logf)
s := NewSlackHandler(basePath, "TOKEN", slackSecret, &dnHeader, log, logf)
s.HandleCommand("/bob-test", h)
resp := performGenericRequest(raw, basePath, s)

Expand All @@ -87,7 +91,7 @@ func TestUnmatchedSlashCommand(t *testing.T) {
}
return nil
}
s := NewSlackHandler(basePath, "TOKEN", slackSecret, log, logf)
s := NewSlackHandler(basePath, "TOKEN", slackSecret, &dnHeader, log, logf)
s.HandleCommand("/foobar", h)
resp := performGenericRequest(raw, basePath, s)

Expand All @@ -110,7 +114,7 @@ func TestDialogSubmissionEvent(t *testing.T) {
}
return nil
}
s := NewSlackHandler(basePath, "TOKEN", slackSecret, log, logf)
s := NewSlackHandler(basePath, "TOKEN", slackSecret, &dnHeader, log, logf)
s.HandleCallback("dialog_submission", "employee_offsite_1138b", h)
resp := performGenericRequest(raw, basePath, s)

Expand All @@ -131,7 +135,7 @@ func TestMalformedActionEvent(t *testing.T) {
"Fail on invalid JSON",
"payload=ssion%22%3A%20%7B%22name%22%3A%20%22Sigourney%20Dreamweaver%22%2C%22email%22%3A%20%22sigdre%40example.com%22%2C%22phone%22%3A%20%22%2B1%20800-555-1212%22%2C%22meal%22%3A%20%22burrito%22%2C%22comment%22%3A%20%22No%20sour%20cream%20please%22%2C%22team_channel%22%3A%20%22C0LFFBKPB%22%2C%22who_should_sing%22%3A%20%22U0MJRG1AL%22%7D%2C%22callback_id%22%3A%20%22employee_offsite_1138b%22%2C%22team%22%3A%20%7B%22id%22%3A%20%22T1ABCD2E12%22%2C%22domain%22%3A%20%22coverbands%22%7D%2C%22user%22%3A%20%7B%22id%22%3A%20%22W12A3BCDEF%22%2C%22name%22%3A%20%22dreamweaver%22%7D%2C%22channel%22%3A%20%7B%22id%22%3A%20%22C1AB2C3DE%22%2C%22name%22%3A%20%22coverthon-1999%22%7D%2C%22action_ts%22%3A%20%22936893340.702759%22%2C%22token%22%3A%20%22TOKEN%22%2C%22response_url%22%3A%20%22https%3A%2F%2Fhooks.slack.com%2Fapp%2FT012AB0A1%2F123456789%2FJpmK0yzoZDeRiqfeduTBYXWQ%22",
404,
"Error parsing payload: Error parsing payload JSON: invalid character 's' looking for beginning of value",
"Error parsing payload: error parsing payload JSON: invalid character 's' looking for beginning of value",
},
{
"Fail on missing value for 'type'",
Expand All @@ -149,7 +153,7 @@ func TestMalformedActionEvent(t *testing.T) {

for _, tc := range tt {
t.Run(tc.name, func(T *testing.T) {
s := NewSlackHandler(basePath, "TOKEN", slackSecret, log, logf)
s := NewSlackHandler(basePath, "TOKEN", slackSecret, &dnHeader, log, logf)
resp := performGenericRequest(tc.raw, basePath, s)
if resp.StatusCode != tc.sCode {
t.Errorf("Expected a %d status. Got '%d'", tc.sCode, resp.StatusCode)
Expand All @@ -165,7 +169,7 @@ func TestMatchPath(t *testing.T) {
h := func(res *Response, req *Request, ctx interface{}) error {
return nil
}
s := NewSlackHandler("/slack", "TOKEN", slackSecret, log, logf)
s := NewSlackHandler("/slack", "TOKEN", slackSecret, &dnHeader, log, logf)
s.HandlePath("/foo", h)
raw := "foo=bar"
resp := performGenericRequest(raw, "/foo", s)
Expand All @@ -178,20 +182,20 @@ func TestMatchPath(t *testing.T) {

func TestHandlerErrors(t *testing.T) {
h := func(res *Response, req *Request, ctx interface{}) error {
return fmt.Errorf("Serious problem")
return fmt.Errorf("serious problem")
}
s := NewSlackHandler("/slack", "TOKEN", slackSecret, log, logf)
s := NewSlackHandler("/slack", "TOKEN", slackSecret, &dnHeader, log, logf)
s.HandlePath("/foo", h)
raw := "foo=bar"
performGenericRequest(raw, "/foo", s)

if errString != "HTTP handler error: Serious problem" {
if errString != "HTTP handler error: serious problem" {
t.Fatalf("Unexpected error string: %s", errString)
}
}

func TestMissingTimestamp(t *testing.T) {
s := NewSlackHandler("/slack", "TOKEN", slackSecret, log, logf)
s := NewSlackHandler("/slack", "TOKEN", slackSecret, nil, log, logf)
s.HandlePath("/foo", nil)
req := httptest.NewRequest("POST", "/foo", nil)
w := httptest.NewRecorder()
Expand All @@ -201,13 +205,13 @@ func TestMissingTimestamp(t *testing.T) {
if resp.StatusCode != 400 {
t.Fatalf("Expected a 400 status. Got '%d'", resp.StatusCode)
}
if !strings.HasPrefix(errString, "Bad request from slack: Invalid timestamp sent from slack") {
if !strings.HasPrefix(errString, "Bad request from slack: invalid timestamp sent from slack") {
t.Fatalf("Unexpected error string: %s", errString)
}
}

func TestStaleTimestamp(t *testing.T) {
s := NewSlackHandler("/slack", "TOKEN", slackSecret, log, logf)
s := NewSlackHandler("/slack", "TOKEN", slackSecret, nil, log, logf)
s.HandlePath("/foo", nil)
req := httptest.NewRequest("POST", "/foo", nil)

Expand All @@ -223,13 +227,13 @@ func TestStaleTimestamp(t *testing.T) {
if resp.StatusCode != 400 {
t.Fatalf("Expected a 400 status. Got '%d'", resp.StatusCode)
}
if !strings.HasPrefix(errString, "Bad request from slack: Stale timestamp sent from slack") {
if !strings.HasPrefix(errString, "Bad request from slack: stale timestamp sent from slack") {
t.Fatalf("Unexpected error string: %s", errString)
}
}

func TestInvalidSecret(t *testing.T) {
s := NewSlackHandler("/slack", "TOKEN", "bad_secret", log, logf)
s := NewSlackHandler("/slack", "TOKEN", "bad_secret", &dnHeader, log, logf)
s.HandlePath("/foo", nil)
raw := "text"
body := bytes.NewBufferString(raw)
Expand All @@ -243,7 +247,29 @@ func TestInvalidSecret(t *testing.T) {
if resp.StatusCode != 400 {
t.Fatalf("Expected a 400 status. Got '%d'", resp.StatusCode)
}
if !strings.HasPrefix(errString, "Bad request from slack: Invalid signature sent from slack") {
if !strings.HasPrefix(errString, "Bad request from slack: invalid signature sent from slack") {
t.Fatalf("Unexpected error string: %s", errString)
}
}

func TestInvalidDN(t *testing.T) {
dnHeader := "slack-dn"
s := NewSlackHandler("/slack", "TOKEN", "bad_secret", &dnHeader, log, logf)
s.HandlePath("/foo", nil)

req := httptest.NewRequest("POST", "/foo", nil)
req.Header.Set(dnHeader, "not.slack.com")

w := httptest.NewRecorder()
s.ServeHTTP(w, req)

resp := w.Result()
if resp.StatusCode != 400 {
t.Fatalf("Expected a 400 status. Got '%d'", resp.StatusCode)
}
if !strings.HasPrefix(errString, "Bad request from slack: invalid CN in DN header") {
t.Fatalf("Unexpected error string: %s", errString)
}
}


0 comments on commit d690a8b

Please sign in to comment.