Skip to content
This repository has been archived by the owner on Apr 2, 2024. It is now read-only.

Commit

Permalink
Add proper CORS handling by the HTTP API
Browse files Browse the repository at this point in the history
A flag can be set by the user to specify the allowed
origins to access the HTTP API. A wrapper handler
sets the proper CORS headers if the flag is enabled.
  • Loading branch information
Blagoj Atanasovski authored and cevian committed Aug 7, 2020
1 parent 6693256 commit 98b9f0d
Show file tree
Hide file tree
Showing 14 changed files with 226 additions and 68 deletions.
52 changes: 36 additions & 16 deletions cmd/timescale-prometheus/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,23 @@ import (
"net/http"
pprof "net/http/pprof"
"os"
"regexp"
"sync/atomic"
"time"

"github.com/prometheus/common/route"

"github.com/jackc/pgx/v4/pgxpool"
_ "github.com/jackc/pgx/v4/stdlib"

"github.com/jamiealquiza/envy"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/timescale/timescale-prometheus/pkg/api"
"github.com/timescale/timescale-prometheus/pkg/log"
"github.com/timescale/timescale-prometheus/pkg/pgclient"
"github.com/timescale/timescale-prometheus/pkg/pgmodel"
"github.com/timescale/timescale-prometheus/pkg/util"

"github.com/jamiealquiza/envy"

"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/timescale/timescale-prometheus/pkg/query"
"github.com/timescale/timescale-prometheus/pkg/util"
)

type config struct {
Expand All @@ -44,6 +42,7 @@ type config struct {
prometheusTimeout time.Duration
electionInterval time.Duration
migrate bool
corsOrigin *regexp.Regexp
}

const (
Expand Down Expand Up @@ -148,8 +147,12 @@ func init() {
}

func main() {
cfg := parseFlags()
err := log.Init(cfg.logLevel)
cfg, err := parseFlags()
if err != nil {
fmt.Println("Version: ", Version, "Commit Hash: ", CommitHash)
fmt.Println("Fatal error: cannot parse flags ", err)
}
err = log.Init(cfg.logLevel)
if err != nil {
fmt.Println("Version: ", Version, "Commit Hash: ", CommitHash)
fmt.Println("Fatal error: cannot start logger", err)
Expand Down Expand Up @@ -231,7 +234,6 @@ func main() {
prometheus.MustRegister(labelsCacheCap)

router := route.New()

promMetrics := api.Metrics{
LeaderGauge: leaderGauge,
ReceivedSamples: receivedSamples,
Expand All @@ -253,21 +255,22 @@ func main() {
router.Get("/read", readHandler)
router.Post("/read", readHandler)

apiConf := &api.Config{AllowedOrigin: cfg.corsOrigin}
queryable := client.GetQueryable()
queryEngine := query.NewEngine(log.GetLogger(), time.Minute)
queryHandler := timeHandler(httpRequestDuration, "query", api.Query(queryEngine, queryable))
queryHandler := timeHandler(httpRequestDuration, "query", api.Query(apiConf, queryEngine, queryable))
router.Get("/api/v1/query", queryHandler)
router.Post("/api/v1/query", queryHandler)

queryRangeHandler := timeHandler(httpRequestDuration, "query_range", api.QueryRange(queryEngine, queryable))
queryRangeHandler := timeHandler(httpRequestDuration, "query_range", api.QueryRange(apiConf, queryEngine, queryable))
router.Get("/api/v1/query_range", queryRangeHandler)
router.Post("/api/v1/query_range", queryRangeHandler)

labelsHandler := timeHandler(httpRequestDuration, "labels", api.Labels(queryable))
labelsHandler := timeHandler(httpRequestDuration, "labels", api.Labels(apiConf, queryable))
router.Get("/api/v1/labels", labelsHandler)
router.Post("/api/v1/labels", labelsHandler)

labelValuesHandler := timeHandler(httpRequestDuration, "label/:name/values", api.LabelValues(queryable))
labelValuesHandler := timeHandler(httpRequestDuration, "label/:name/values", api.LabelValues(apiConf, queryable))
router.Get("/api/v1/label/:name/values", labelValuesHandler)

router.Get("/healthz", api.Health(client))
Expand All @@ -291,14 +294,23 @@ func main() {
}
}

func parseFlags() *config {
func parseFlags() (*config, error) {

cfg := &config{}

pgclient.ParseFlags(&cfg.pgmodelCfg)

flag.StringVar(&cfg.listenAddr, "web-listen-address", ":9201", "Address to listen on for web endpoints.")
flag.StringVar(&cfg.telemetryPath, "web-telemetry-path", "/metrics", "Address to listen on for web endpoints.")

var corsOriginFlag string
flag.StringVar(&corsOriginFlag, "web-cors-origin", ".*", `Regex for CORS origin. It is fully anchored. Example: 'https?://(domain1|domain2)\.com'`)
corsOriginRegex, err := compileAnchoredRegexString(corsOriginFlag)
if err != nil {
err = fmt.Errorf("could not compile CORS regex string %v: %w", corsOriginFlag, err)
return nil, err
}
cfg.corsOrigin = corsOriginRegex
flag.StringVar(&cfg.logLevel, "log-level", "debug", "The log level to use [ \"error\", \"warn\", \"info\", \"debug\" ].")
flag.IntVar(&cfg.haGroupLockID, "leader-election-pg-advisory-lock-id", 0, "Unique advisory lock id per adapter high-availability group. Set it if you want to use leader election implementation based on PostgreSQL advisory lock.")
flag.DurationVar(&cfg.prometheusTimeout, "leader-election-pg-advisory-lock-prometheus-timeout", -1, "Adapter will resign if there are no requests from Prometheus within a given timeout (0 means no timeout). "+
Expand All @@ -309,7 +321,7 @@ func parseFlags() *config {
envy.Parse("TS_PROM")
flag.Parse()

return cfg
return cfg, nil
}

func initElector(cfg *config) (*util.Elector, error) {
Expand Down Expand Up @@ -388,3 +400,11 @@ func timeHandler(histogramVec prometheus.ObserverVec, path string, handler http.
histogramVec.WithLabelValues(path).Observe(float64(elapsedMs))
}
}

func compileAnchoredRegexString(s string) (*regexp.Regexp, error) {
r, err := regexp.Compile("^(?:" + s + ")$")
if err != nil {
return nil, err
}
return r, nil
}
21 changes: 18 additions & 3 deletions pkg/api/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ import (
"io"
"math"
"net/http"
"regexp"
"strconv"
"time"

"github.com/pkg/errors"
"github.com/prometheus/common/model"
"github.com/prometheus/prometheus/promql/parser"
"github.com/prometheus/prometheus/storage"
"github.com/prometheus/prometheus/util/httputil"
"github.com/timescale/timescale-prometheus/pkg/promql"
)

Expand All @@ -24,8 +26,21 @@ var (
maxTimeFormatted = maxTime.Format(time.RFC3339Nano)
)

func setHeaders(w http.ResponseWriter, res *promql.Result, warnings storage.Warnings) {
w.Header().Set("Access-Control-Allow-Origin", "*")
type Config struct {
AllowedOrigin *regexp.Regexp
}

func corsWrapper(conf *Config, f http.HandlerFunc) http.HandlerFunc {
if conf.AllowedOrigin == nil {
return f
}
return func(w http.ResponseWriter, r *http.Request) {
httputil.SetCORS(w, conf.AllowedOrigin, r)
f(w, r)
}
}

func setResponseHeaders(w http.ResponseWriter, res *promql.Result, warnings storage.Warnings) {
w.Header().Set("Content-Type", "application/json")
if warnings != nil && len(warnings) > 0 {
w.Header().Set("Cache-Control", "no-store")
Expand All @@ -38,7 +53,7 @@ func setHeaders(w http.ResponseWriter, res *promql.Result, warnings storage.Warn
}

func respondQuery(w http.ResponseWriter, res *promql.Result, warnings storage.Warnings) {
setHeaders(w, res, warnings)
setResponseHeaders(w, res, warnings)
switch resVal := res.Value.(type) {
case promql.Vector:
warnings := make([]string, 0, len(res.Warnings))
Expand Down
108 changes: 108 additions & 0 deletions pkg/api/common_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package api

import (
"context"
"github.com/timescale/timescale-prometheus/pkg/log"
"net/http"
"net/http/httptest"
"reflect"
"regexp"
"testing"
)

func TestCORSWrapper(t *testing.T) {
_ = log.Init("debug")
acceptSpecific, _ := regexp.Compile("^(?:" + "http://some-site.com" + ")$")
acceptAny, _ := regexp.Compile("^(?:" + ".*" + ")$")

testCases := []struct {
name string
requestOrigin string
acceptedOrigin *regexp.Regexp
expectHeaders map[string][]string
}{
{
name: "No origin",
requestOrigin: "",
acceptedOrigin: acceptSpecific,
expectHeaders: map[string][]string{},
}, {
name: "Origin doesn't match accepted",
requestOrigin: "http://some-unknown-site.com",
acceptedOrigin: acceptSpecific,
expectHeaders: map[string][]string{
"Access-Control-Allow-Headers": {"Accept, Authorization, Content-Type, Origin"},
"Access-Control-Allow-Methods": {"GET, POST, OPTIONS"},
"Access-Control-Expose-Headers": {"Date"},
"Vary": {"Origin"},
},
},
{
name: "Origin matches accepted",
requestOrigin: "http://some-site.com",
acceptedOrigin: acceptSpecific,
expectHeaders: map[string][]string{
"Access-Control-Allow-Headers": {"Accept, Authorization, Content-Type, Origin"},
"Access-Control-Allow-Methods": {"GET, POST, OPTIONS"},
"Access-Control-Expose-Headers": {"Date"},
"Access-Control-Allow-Origin": {"http://some-site.com"},
"Vary": {"Origin"},
},
}, {
name: "Wildcard allowed origin",
requestOrigin: "http://any-site.com",
acceptedOrigin: acceptAny,
expectHeaders: map[string][]string{
"Access-Control-Allow-Headers": {"Accept, Authorization, Content-Type, Origin"},
"Access-Control-Allow-Methods": {"GET, POST, OPTIONS"},
"Access-Control-Expose-Headers": {"Date"},
"Access-Control-Allow-Origin": {"*"},
"Vary": {"Origin"},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
conf := &Config{}
if tc.acceptedOrigin != nil {
conf.AllowedOrigin = tc.acceptedOrigin
} else {
tc.acceptedOrigin = &regexp.Regexp{}
}
internalHandlerCalled := false
handler := corsWrapper(conf, func(http.ResponseWriter, *http.Request) {
internalHandlerCalled = true
})
w := doCORSWrapperRequest(t, handler, "http://localhost/", tc.requestOrigin)
if !internalHandlerCalled {
t.Fatalf("internal handler not called by CORS wrapper")
return
}
returnedHeaders := w.Header()
if len(returnedHeaders) != len(tc.expectHeaders) {
t.Fatalf("expected %d headers, got %d", len(tc.expectHeaders), len(returnedHeaders))
return
}
for hName, hValues := range tc.expectHeaders {
returnedValues := returnedHeaders[hName]
if !reflect.DeepEqual(hValues, returnedValues) {
t.Errorf("expected header %s with value %v; got %v", hName, hValues, returnedValues)
}
}
})

}

}

func doCORSWrapperRequest(t *testing.T, queryHandler http.Handler, url, origin string) *httptest.ResponseRecorder {
req, err := http.NewRequestWithContext(context.Background(), "GET", url, nil)
if err != nil {
t.Errorf("%v", err)
}

req.Header.Set("Origin", origin)
w := httptest.NewRecorder()
queryHandler.ServeHTTP(w, req)
return w
}
20 changes: 12 additions & 8 deletions pkg/api/label_values.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,30 @@ package api
import (
"context"
"fmt"
"math"
"net/http"

"github.com/NYTimes/gziphandler"
"github.com/prometheus/common/model"
"github.com/prometheus/common/route"
"github.com/timescale/timescale-prometheus/pkg/promql"
"github.com/timescale/timescale-prometheus/pkg/query"
"math"
"net/http"
)

func LabelValues(queriable *query.Queryable) http.Handler {
hf := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
func LabelValues(conf *Config, queryable *query.Queryable) http.Handler {
hf := corsWrapper(conf, labelValues(queryable))
return gziphandler.GzipHandler(hf)
}

func labelValues(queryable *query.Queryable) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
name := route.Param(ctx, "name")
if !model.LabelNameRE.MatchString(name) {
respondError(w, http.StatusBadRequest, fmt.Errorf("invalid label name: %s", name), "bad_data")
return
}
querier, err := queriable.Querier(context.Background(), math.MinInt64, math.MaxInt64)
querier, err := queryable.Querier(context.Background(), math.MinInt64, math.MaxInt64)
if err != nil {
respondError(w, http.StatusInternalServerError, err, "internal")
return
Expand All @@ -35,7 +41,5 @@ func LabelValues(queriable *query.Queryable) http.Handler {
respondLabels(w, &promql.Result{
Value: values,
}, warnings)
})

return gziphandler.GzipHandler(hf)
}
}
24 changes: 14 additions & 10 deletions pkg/api/labels.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ package api
import (
"context"
"encoding/json"
"math"
"net/http"
"strings"

"github.com/NYTimes/gziphandler"
"github.com/prometheus/prometheus/promql/parser"
"github.com/prometheus/prometheus/storage"
"github.com/timescale/timescale-prometheus/pkg/promql"
"github.com/timescale/timescale-prometheus/pkg/query"
"math"
"net/http"
"strings"
)

type labelsValue []string
Expand All @@ -23,9 +24,14 @@ func (l labelsValue) String() string {
return strings.Join(l, "\n")
}

func Labels(queriable *query.Queryable) http.Handler {
hf := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
querier, err := queriable.Querier(context.Background(), math.MinInt64, math.MaxInt64)
func Labels(conf *Config, queryable *query.Queryable) http.Handler {
hf := corsWrapper(conf, labelsHandler(queryable))
return gziphandler.GzipHandler(hf)
}

func labelsHandler(queryable *query.Queryable) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
querier, err := queryable.Querier(context.Background(), math.MinInt64, math.MaxInt64)
if err != nil {
respondError(w, http.StatusInternalServerError, err, "internal")
return
Expand All @@ -39,13 +45,11 @@ func Labels(queriable *query.Queryable) http.Handler {
respondLabels(w, &promql.Result{
Value: names,
}, warnings)
})

return gziphandler.GzipHandler(hf)
}
}

func respondLabels(w http.ResponseWriter, res *promql.Result, warnings storage.Warnings) {
setHeaders(w, res, warnings)
setResponseHeaders(w, res, warnings)
resp := &response{
Status: "success",
Data: res.Value,
Expand Down
2 changes: 1 addition & 1 deletion pkg/api/labels_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func TestLabels(t *testing.T) {
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
handler := Labels(query.NewQueryable(tc.querier))
handler := labelsHandler(query.NewQueryable(tc.querier))
w := doLabels(t, handler)

if w.Code != tc.expectCode {
Expand Down

0 comments on commit 98b9f0d

Please sign in to comment.