Skip to content
This repository has been archived by the owner on Apr 2, 2024. It is now read-only.

Commit

Permalink
Validate Request Headers
Browse files Browse the repository at this point in the history
Right now our validation consists of trying to deserialize requests into
the format we expect, and panicking if we cannot do so. Though this does
not bring down the connector as different connections are isolated from
each other, this has the downside that an invalid request can still e.g.
write to the database, if it happens to look like a valid request.

This commit adds logic to validate the headers prometheus sends on
remote requests so we can detect earlier if the data is invalid.
  • Loading branch information
JLockerman committed Aug 31, 2020
1 parent 9839127 commit 4ad81ac
Show file tree
Hide file tree
Showing 8 changed files with 197 additions and 25 deletions.
18 changes: 18 additions & 0 deletions cmd/timescale-prometheus/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,20 @@ var (
Help: "Total number of queries which failed on send to remote storage.",
},
)
invalidReadReqs = prometheus.NewCounter(
prometheus.CounterOpts{
Namespace: util.PromNamespace,
Name: "invalid_read_requests",
Help: "Total number of remote read requests with invalid metadata.",
},
)
invalidWriteReqs = prometheus.NewCounter(
prometheus.CounterOpts{
Namespace: util.PromNamespace,
Name: "invalid_write_requests",
Help: "Total number of remote write requests with invalid metadata.",
},
)
sentBatchDuration = prometheus.NewHistogram(
prometheus.HistogramOpts{
Namespace: util.PromNamespace,
Expand Down Expand Up @@ -140,6 +154,8 @@ func init() {
prometheus.MustRegister(sentSamples)
prometheus.MustRegister(failedSamples)
prometheus.MustRegister(failedQueries)
prometheus.MustRegister(invalidReadReqs)
prometheus.MustRegister(invalidWriteReqs)
prometheus.MustRegister(sentBatchDuration)
prometheus.MustRegister(queryBatchDuration)
prometheus.MustRegister(httpRequestDuration)
Expand Down Expand Up @@ -248,6 +264,8 @@ func main() {
ReceivedQueries: receivedQueries,
CachedMetricNames: cachedMetricNames,
CachedLabels: cachedLabels,
InvalidReadReqs: invalidReadReqs,
InvalidWriteReqs: invalidWriteReqs,
}
writeHandler := timeHandler(httpRequestDuration, "write", api.Write(client, elector, &promMetrics))
router.Post("/write", writeHandler)
Expand Down
23 changes: 21 additions & 2 deletions pkg/api/health_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ package api

import (
"fmt"
"github.com/timescale/timescale-prometheus/pkg/log"
"io"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"

"github.com/timescale/timescale-prometheus/pkg/log"
)

var (
Expand Down Expand Up @@ -50,7 +53,7 @@ func TestHealth(t *testing.T) {

healthHandle := Health(mock)

test := GenerateHandleTester(t, healthHandle)
test := GenerateHealthHandleTester(t, healthHandle)
w := test("GET", strings.NewReader(""))

if w.Code != c.httpStatus {
Expand All @@ -72,3 +75,19 @@ func TestHealth(t *testing.T) {
})
}
}

func GenerateHealthHandleTester(t *testing.T, handleFunc http.Handler) HandleTester {
return func(method string, body io.Reader) *httptest.ResponseRecorder {
req, err := http.NewRequest(method, "", body)
if err != nil {
t.Errorf("%v", err)
}
req.Header.Set(
"Content-Type",
"application/x-www-form-urlencoded; param=value",
)
w := httptest.NewRecorder()
handleFunc.ServeHTTP(w, req)
return w
}
}
2 changes: 2 additions & 0 deletions pkg/api/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ type Metrics struct {
QueryBatchDuration prometheus.Histogram
CachedMetricNames prometheus.CounterFunc
CachedLabels prometheus.CounterFunc
InvalidReadReqs prometheus.Counter
InvalidWriteReqs prometheus.Counter
}
50 changes: 46 additions & 4 deletions pkg/api/read.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
package api

import (
"fmt"
"io/ioutil"
"net/http"
"strings"
"time"

"github.com/gogo/protobuf/proto"
"github.com/golang/snappy"
"github.com/timescale/timescale-prometheus/pkg/log"
"github.com/timescale/timescale-prometheus/pkg/pgmodel"
"github.com/timescale/timescale-prometheus/pkg/prompb"
"io/ioutil"
"net/http"
"time"
)

func Read(reader pgmodel.Reader, metrics *Metrics) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !validateReadHeaders(w, r) {
metrics.InvalidReadReqs.Inc()
return
}

compressed, err := ioutil.ReadAll(r.Body)
if err != nil {
log.Error("msg", "Read error", "err", err.Error())
log.Error("msg", "Read header validation error", "err", err.Error())
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
Expand Down Expand Up @@ -68,3 +76,37 @@ func Read(reader pgmodel.Reader, metrics *Metrics) http.Handler {
}
})
}

func validateReadHeaders(w http.ResponseWriter, r *http.Request) bool {
// validate headers from https://github.com/prometheus/prometheus/blob/2bd077ed9724548b6a631b6ddba48928704b5c34/storage/remote/client.go
if r.Method != "POST" {
buildReadError(w, fmt.Sprintf("HTTP Method %s instead of POST", r.Method))
return false
}

if !strings.Contains(r.Header.Get("Content-Encoding"), "snappy") {
buildReadError(w, fmt.Sprintf("non-snappy compressed data got: %s", r.Header.Get("Content-Encoding")))
return false
}

if r.Header.Get("Content-Type") != "application/x-protobuf" {
buildReadError(w, "non-protobuf data")
return false
}

remoteReadVersion := r.Header.Get("X-Prometheus-Remote-Read-Version")
if remoteReadVersion == "" {
err := "missing X-Prometheus-Remote-Read-Version"
log.Warn("msg", "Read header validation error", "err", err)
} else if !strings.HasPrefix(remoteReadVersion, "0.1.") {
buildReadError(w, fmt.Sprintf("unexpected Remote-Read-Version %s, expected 0.1.X", remoteReadVersion))
return false
}

return true
}

func buildReadError(w http.ResponseWriter, err string) {
log.Error("msg", "Read header validation error", "err", err)
http.Error(w, err, http.StatusBadRequest)
}
36 changes: 32 additions & 4 deletions pkg/api/read_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ package api

import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/gogo/protobuf/proto"
"github.com/golang/snappy"
"github.com/timescale/timescale-prometheus/pkg/prompb"
"net/http"
"testing"
)

func TestRead(t *testing.T) {
Expand All @@ -27,6 +30,11 @@ func TestRead(t *testing.T) {
responseCode: http.StatusBadRequest,
requestBody: "123",
},
{
name: "bad header",
responseCode: http.StatusBadRequest,
requestBody: "123",
},
{
name: "malformed read request",
responseCode: http.StatusBadRequest,
Expand Down Expand Up @@ -68,16 +76,18 @@ func TestRead(t *testing.T) {
receivedQueriesCounter := &mockMetric{}
queryDurationHist := &mockMetric{}
failedQueriesCounter := &mockMetric{}
invalidReadReqs := &mockMetric{}
metrics := &Metrics{
QueryBatchDuration: queryDurationHist,
FailedQueries: failedQueriesCounter,
ReceivedQueries: receivedQueriesCounter,
InvalidReadReqs: invalidReadReqs,
}
handler := Read(mockReader, metrics)

test := GenerateHandleTester(t, handler)
test := GenerateReadHandleTester(t, handler, c.name == "bad header")

w := test("GET", getReader(c.requestBody))
w := test("POST", getReader(c.requestBody))

if w.Code != c.responseCode {
t.Errorf("Unexpected HTTP status code received: got %d wanted %d", w.Code, c.responseCode)
Expand Down Expand Up @@ -111,3 +121,21 @@ func (m *mockReader) Read(r *prompb.ReadRequest) (*prompb.ReadResponse, error) {
m.request = r
return m.response, m.err
}

func GenerateReadHandleTester(t *testing.T, handleFunc http.Handler, badHeader bool) HandleTester {
return func(method string, body io.Reader) *httptest.ResponseRecorder {
req, err := http.NewRequest(method, "", body)
if err != nil {
t.Errorf("%v", err)
}
if !badHeader {
req.Header.Add("Content-Encoding", "snappy")
req.Header.Set("Content-Type", "application/x-protobuf")
req.Header.Set("X-Prometheus-Remote-Read-Version", "0.1.0")
}

w := httptest.NewRecorder()
handleFunc.ServeHTTP(w, req)
return w
}
}
45 changes: 45 additions & 0 deletions pkg/api/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"io/ioutil"
"net/http"
"strings"
"sync/atomic"
"time"

Expand All @@ -18,6 +19,14 @@ import (

func Write(writer pgmodel.DBInserter, elector *util.Elector, metrics *Metrics) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

// we treat invalid requests as the same as no request for
// leadership-timeout purposes
if !validateWriteHeaders(w, r) {
metrics.InvalidWriteReqs.Inc()
return
}

shouldWrite, err := isWriter(elector)
if err != nil {
metrics.LeaderGauge.Set(0)
Expand Down Expand Up @@ -104,3 +113,39 @@ func getCounterValue(counter prometheus.Counter) float64 {
}
return dtoMetric.GetCounter().GetValue()
}

func validateWriteHeaders(w http.ResponseWriter, r *http.Request) bool {
// validate headers from https://github.com/prometheus/prometheus/blob/2bd077ed9724548b6a631b6ddba48928704b5c34/storage/remote/client.go
if r.Method != "POST" {
buildWriteError(w, fmt.Sprintf("HTTP Method %s instead of POST", r.Method))
return false
}

if !strings.Contains(r.Header.Get("Content-Encoding"), "snappy") {
buildWriteError(w, fmt.Sprintf("non-snappy compressed data got: %s", r.Header.Get("Content-Encoding")))
return false
}

if r.Header.Get("Content-Type") != "application/x-protobuf" {
buildWriteError(w, "non-protobuf data")
return false
}

remoteWriteVersion := r.Header.Get("X-Prometheus-Remote-Write-Version")
if remoteWriteVersion == "" {
buildWriteError(w, "Missing X-Prometheus-Remote-Write-Version header")
return false
}

if !strings.HasPrefix(remoteWriteVersion, "0.1.") {
buildWriteError(w, fmt.Sprintf("unexpected Remote-Write-Version %s, expected 0.1.X", remoteWriteVersion))
return false
}

return true
}

func buildWriteError(w http.ResponseWriter, err string) {
log.Error("msg", "Write header validation error", "err", err)
http.Error(w, err, http.StatusBadRequest)
}
38 changes: 24 additions & 14 deletions pkg/api/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@ package api

import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"testing/iotest"
"time"

"github.com/gogo/protobuf/proto"
"github.com/golang/snappy"
"github.com/prometheus/client_golang/prometheus"
Expand All @@ -10,13 +18,6 @@ import (
"github.com/timescale/timescale-prometheus/pkg/log"
"github.com/timescale/timescale-prometheus/pkg/prompb"
"github.com/timescale/timescale-prometheus/pkg/util"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"testing/iotest"
"time"
)

func TestWrite(t *testing.T) {
Expand Down Expand Up @@ -47,6 +48,12 @@ func TestWrite(t *testing.T) {
responseCode: http.StatusBadRequest,
requestBody: string(snappy.Encode(nil, []byte("test"))),
},
{
name: "bad header",
isLeader: false,
responseCode: http.StatusBadRequest,
requestBody: string(snappy.Encode(nil, []byte("test"))),
},
{
name: "write error",
isLeader: true,
Expand Down Expand Up @@ -93,6 +100,7 @@ func TestWrite(t *testing.T) {
failedSamplesGauge := &mockMetric{}
sentSamplesGauge := &mockMetric{}
sendBatchHistogram := &mockMetric{}
invalidWriteReqs := &mockMetric{}
mock := &mockInserter{
result: c.inserterResponse,
err: c.inserterErr,
Expand All @@ -104,12 +112,13 @@ func TestWrite(t *testing.T) {
FailedSamples: failedSamplesGauge,
SentSamples: sentSamplesGauge,
SentBatchDuration: sendBatchHistogram,
InvalidWriteReqs: invalidWriteReqs,
WriteThroughput: util.NewThroughputCalc(time.Second),
})

test := GenerateHandleTester(t, handler)
test := GenerateWriteHandleTester(t, handler, c.name == "bad header")

w := test("GET", getReader(c.requestBody))
w := test("POST", getReader(c.requestBody))

if w.Code != c.responseCode {
t.Errorf("Unexpected HTTP status code received: got %d wanted %d", w.Code, c.responseCode)
Expand Down Expand Up @@ -144,16 +153,17 @@ func writeRequestToString(r *prompb.WriteRequest) string {

type HandleTester func(method string, body io.Reader) *httptest.ResponseRecorder

func GenerateHandleTester(t *testing.T, handleFunc http.Handler) HandleTester {
func GenerateWriteHandleTester(t *testing.T, handleFunc http.Handler, badHeaders bool) HandleTester {
return func(method string, body io.Reader) *httptest.ResponseRecorder {
req, err := http.NewRequest(method, "", body)
if err != nil {
t.Errorf("%v", err)
}
req.Header.Set(
"Content-Type",
"application/x-www-form-urlencoded; param=value",
)
if !badHeaders {
req.Header.Add("Content-Encoding", "snappy")
req.Header.Set("Content-Type", "application/x-protobuf")
req.Header.Set("X-Prometheus-Remote-Write-Version", "0.1.0")
}
w := httptest.NewRecorder()
handleFunc.ServeHTTP(w, req)
return w
Expand Down

0 comments on commit 4ad81ac

Please sign in to comment.