diff --git a/.gitignore b/.gitignore index 054270c28..0627a45b2 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,5 @@ bazel-bin bazel-out bazel-testlogs bazel-zoekt + +.DS_STORE diff --git a/internal/tenant/rest_helper.go b/internal/tenant/rest_helper.go new file mode 100644 index 000000000..e5809a8cb --- /dev/null +++ b/internal/tenant/rest_helper.go @@ -0,0 +1,24 @@ +package tenant + +import ( + "context" + "fmt" + "net/http" + + "github.com/sourcegraph/zoekt/internal/tenant/internal/tenanttype" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func InjectTenantFromHeader(ctx context.Context, header http.Header) (context.Context, error) { + tenantID := header.Get("X-TENANT-ID") // TODO: we don't use headerKeyTenantID here so we don't need to change it and potentially break future grpc changes + if tenantID != "" { + tenant, err := tenanttype.Unmarshal(tenantID) + if err != nil { + return ctx, status.New(codes.InvalidArgument, fmt.Errorf("bad tenant value in header: %w", err).Error()).Err() + } + + return tenanttype.WithTenant(ctx, tenant), nil + } + return ctx, nil +} diff --git a/json/json.go b/json/json.go index 80b47348d..c4a0d1cf6 100644 --- a/json/json.go +++ b/json/json.go @@ -7,6 +7,7 @@ import ( "time" "github.com/sourcegraph/zoekt" + "github.com/sourcegraph/zoekt/internal/tenant" "github.com/sourcegraph/zoekt/query" ) @@ -85,6 +86,12 @@ func (s *jsonSearcher) jsonSearch(w http.ResponseWriter, req *http.Request) { defer cancel() } + ctx, err = tenant.InjectTenantFromHeader(ctx, req.Header) + if err != nil { + jsonError(w, http.StatusBadRequest, err.Error()) + return + } + if err := CalculateDefaultSearchLimits(ctx, q, s.Searcher, searchArgs.Opts); err != nil { jsonError(w, http.StatusInternalServerError, err.Error()) return @@ -146,6 +153,7 @@ func CalculateDefaultSearchLimits(ctx context.Context, } func (s *jsonSearcher) jsonList(w http.ResponseWriter, req *http.Request) { + ctx := req.Context() w.Header().Add("Content-Type", "application/json") if req.Method != "POST" { @@ -166,7 +174,13 @@ func (s *jsonSearcher) jsonList(w http.ResponseWriter, req *http.Request) { return } - listResult, err := s.Searcher.List(req.Context(), query, listArgs.Opts) + ctx, err = tenant.InjectTenantFromHeader(ctx, req.Header) + if err != nil { + jsonError(w, http.StatusBadRequest, err.Error()) + return + } + + listResult, err := s.Searcher.List(ctx, query, listArgs.Opts) if err != nil { jsonError(w, http.StatusInternalServerError, err.Error()) return