diff --git a/internal/api/api.go b/internal/api/api.go index 00b3fafed9..c8009b08b5 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -10,13 +10,16 @@ import ( "io/ioutil" "net/http" "os" + "regexp" "strings" + "github.com/Masterminds/semver" "github.com/hashicorp/go-multierror" "github.com/jig/teereadcloser" "github.com/kballard/go-shellquote" "github.com/mattn/go-isatty" "github.com/pkg/errors" + "github.com/sourcegraph/codeintelutils" ) // Client instances provide methods to create API requests. @@ -53,7 +56,8 @@ type Request interface { // client is the internal concrete type implementing Client. type client struct { - opts ClientOpts + opts ClientOpts + supportsGzip *bool } // request is the internal concrete type implementing Request. @@ -113,6 +117,39 @@ func (c *client) NewRequest(query string, vars map[string]interface{}) Request { } func (c *client) NewHTTPRequest(ctx context.Context, method, p string, body io.Reader) (*http.Request, error) { + if c.supportsGzip == nil { + // set to false, unless we have a new enough version + supportsGzip := false + c.supportsGzip = &supportsGzip + + version, err := c.getSourcegraphVersion(ctx) + + // ignore errors; we only care if the version is sufficently new + if err == nil { + supportsGzip, err = sourcegraphVersionCheck(version, ">= 3.21.0", "2020-10-12") + if err == nil { + c.supportsGzip = &supportsGzip + } + } + } + + if *c.supportsGzip && body != nil { + body = codeintelutils.Gzip(body) + } + + req, err := c.createHTTPRequest(ctx, method, p, body) + if err != nil { + return nil, err + } + + if *c.supportsGzip { + req.Header.Set("Content-Encoding", "gzip") + } + + return req, nil +} + +func (c *client) createHTTPRequest(ctx context.Context, method, p string, body io.Reader) (*http.Request, error) { req, err := http.NewRequestWithContext(ctx, method, strings.TrimRight(c.opts.Endpoint, "/")+"/"+p, body) if err != nil { return nil, err @@ -126,6 +163,7 @@ func (c *client) NewHTTPRequest(ctx context.Context, method, p string, body io.R for k, v := range c.opts.AdditionalHeaders { req.Header.Set(k, v) } + return req, nil } @@ -265,3 +303,80 @@ func (r *request) curlCmd() (string, error) { s += fmt.Sprintf(" %s", shellquote.Join(r.client.opts.Endpoint+"/.api/graphql")) return s, nil } + +const sourcegraphVersionQuery = `query SourcegraphVersion { + site { + productVersion + } + } + ` + +func (c *client) getSourcegraphVersion(ctx context.Context) (string, error) { + var sourcegraphVersion struct { + Data struct { + Site struct { + ProductVersion string + } + } + } + + // Create the JSON object. + reqBody, err := json.Marshal(map[string]interface{}{ + "query": sourcegraphVersionQuery, + }) + if err != nil { + return "", err + } + + // Create the HTTP request. + req, err := c.createHTTPRequest(ctx, "POST", ".api/graphql", bytes.NewBuffer(reqBody)) + if err != nil { + return "", err + } + + // Perform the request. + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + respBytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + return "", err + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("checking sourcegraph backend version; got status code %d", resp.StatusCode) + } + + err = json.Unmarshal(respBytes, &sourcegraphVersion) + if err != nil { + return "", err + } + + return sourcegraphVersion.Data.Site.ProductVersion, err +} + +func sourcegraphVersionCheck(version, constraint, minDate string) (bool, error) { + if version == "dev" || version == "0.0.0+dev" { + return true, nil + } + + buildDate := regexp.MustCompile(`^\d+_(\d{4}-\d{2}-\d{2})_[a-z0-9]{7}$`) + matches := buildDate.FindStringSubmatch(version) + if len(matches) > 1 { + return matches[1] >= minDate, nil + } + + c, err := semver.NewConstraint(constraint) + if err != nil { + return false, nil + } + + v, err := semver.NewVersion(version) + if err != nil { + return false, err + } + return c.Check(v), nil +}