Skip to content
4 changes: 4 additions & 0 deletions eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ func (d *indexData) simplify(in query.Q) query.Q {
return d.simplifyMultiRepo(q, func(repo *Repository) bool {
return r.Set[repo.Name]
})
case *query.RepoIDs:
return d.simplifyMultiRepo(q, func(repo *Repository) bool {
return r.Repos.Contains(repo.ID)
})
case *query.Language:
_, has := d.metaData.LanguageMap[r.Language]
if !has && d.metaData.IndexFeatureVersion < 12 {
Expand Down
24 changes: 24 additions & 0 deletions eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,30 @@ func TestSimplifyRepoSet(t *testing.T) {
}
}

func TestSimplifyRepoIDs(t *testing.T) {
d := compoundReposShard(t, "foo", "bar")
all := &query.RepoIDs{Repos: roaring.BitmapOf(hash("foo"), hash("bar"))}
some := &query.RepoIDs{Repos: roaring.BitmapOf(hash("foo"), hash("banana"))}
none := &query.RepoIDs{Repos: roaring.BitmapOf(hash("banana"))}

tr := cmp.Transformer("", func(b *roaring.Bitmap) []uint32 { return b.ToArray() })

got := d.simplify(all)
if d := cmp.Diff(&query.Const{Value: true}, got, tr); d != "" {
t.Fatalf("-want, +got:\n%s", d)
}

got = d.simplify(some)
if d := cmp.Diff(some, got, tr); d != "" {
t.Fatalf("-want, +got:\n%s", d)
}

got = d.simplify(none)
if d := cmp.Diff(&query.Const{Value: false}, got); d != "" {
t.Fatalf("-want, +got:\n%s", d)
}
}

func TestSimplifyRepo(t *testing.T) {
re := func(pat string) *query.Repo {
t.Helper()
Expand Down
15 changes: 10 additions & 5 deletions json/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ type jsonSearcher struct {
}

type jsonSearchArgs struct {
Q string
Opts *zoekt.SearchOptions
Q string
RepoIDs *[]uint32
Opts *zoekt.SearchOptions
}

type jsonSearchReply struct {
Expand Down Expand Up @@ -67,25 +68,29 @@ func (s *jsonSearcher) jsonSearch(w http.ResponseWriter, req *http.Request) {
searchArgs.Opts = &zoekt.SearchOptions{}
}

query, err := query.Parse(searchArgs.Q)
q, err := query.Parse(searchArgs.Q)
if err != nil {
jsonError(w, http.StatusBadRequest, err.Error())
return
}

if searchArgs.RepoIDs != nil {
q = query.NewAnd(q, query.NewRepoIDs(*searchArgs.RepoIDs...))
}

// Set a timeout if the user hasn't specified one.
if searchArgs.Opts.MaxWallTime == 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, defaultTimeout)
defer cancel()
}

if err := CalculateDefaultSearchLimits(ctx, query, s.Searcher, searchArgs.Opts); err != nil {
if err := CalculateDefaultSearchLimits(ctx, q, s.Searcher, searchArgs.Opts); err != nil {
jsonError(w, http.StatusInternalServerError, err.Error())
return
}

searchResult, err := s.Searcher.Search(ctx, query, searchArgs.Opts)
searchResult, err := s.Searcher.Search(ctx, q, searchArgs.Opts)
if err != nil {
jsonError(w, http.StatusInternalServerError, err.Error())
return
Expand Down
74 changes: 74 additions & 0 deletions json/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,80 @@ func TestClientServer(t *testing.T) {
}
}

func TestClientServerWithRepoIDsProvided(t *testing.T) {
searchQuery := "hello"
expectedSearch := mustParse(searchQuery)
expectedSearch = query.NewAnd(expectedSearch, query.NewRepoIDs(1, 3, 5, 7))
mock := &mockSearcher.MockSearcher{
WantSearch: expectedSearch,
SearchResult: &zoekt.SearchResult{
Files: []zoekt.FileMatch{
{FileName: "bin.go"},
},
},
}

ts := httptest.NewServer(zjson.JSONServer(mock))
defer ts.Close()

searchBody := "{\"Q\":\"hello\",\"RepoIDs\":[1,3,5,7]}"

r, err := http.Post(ts.URL+"/search", "application/json", bytes.NewBufferString(searchBody))
if err != nil {
t.Fatal(err)
}
if r.StatusCode != 200 {
body, _ := io.ReadAll(r.Body)
t.Fatalf("Got status code %d, err %s", r.StatusCode, string(body))
}

var searchResult struct{ Result *zoekt.SearchResult }
err = json.NewDecoder(r.Body).Decode(&searchResult)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(searchResult.Result, mock.SearchResult) {
t.Fatalf("\na %+v\nb %+v", searchResult.Result, mock.SearchResult)
}
}

func TestClientServerWithEmptyRepoIDsProvided(t *testing.T) {
searchQuery := "hello"
expectedSearch := mustParse(searchQuery)
expectedSearch = query.NewAnd(expectedSearch, query.NewRepoIDs())
mock := &mockSearcher.MockSearcher{
WantSearch: expectedSearch,
SearchResult: &zoekt.SearchResult{
Files: []zoekt.FileMatch{
{FileName: "bin.go"},
},
},
}

ts := httptest.NewServer(zjson.JSONServer(mock))
defer ts.Close()

searchBody := "{\"Q\":\"hello\",\"RepoIDs\":[]}"

r, err := http.Post(ts.URL+"/search", "application/json", bytes.NewBufferString(searchBody))
if err != nil {
t.Fatal(err)
}
if r.StatusCode != 200 {
body, _ := io.ReadAll(r.Body)
t.Fatalf("Got status code %d, err %s", r.StatusCode, string(body))
}

var searchResult struct{ Result *zoekt.SearchResult }
err = json.NewDecoder(r.Body).Decode(&searchResult)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(searchResult.Result, mock.SearchResult) {
t.Fatalf("\na %+v\nb %+v", searchResult.Result, mock.SearchResult)
}
}

func TestProgressNotEncodedInSearch(t *testing.T) {
searchQuery := "hello"
mock := &mockSearcher.MockSearcher{
Expand Down
15 changes: 15 additions & 0 deletions matchtree.go
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,21 @@ func (d *indexData) newMatchTree(q query.Q) (matchTree, error) {
},
}, nil

case *query.RepoIDs:
reposWant := make([]bool, len(d.repoMetaData))
for repoIdx, r := range d.repoMetaData {
if s.Repos.Contains(r.ID) {
reposWant[repoIdx] = true
}
}
return &docMatchTree{
reason: "RepoIDs",
numDocs: d.numDocs(),
predicate: func(docID uint32) bool {
return reposWant[d.repos[docID]]
},
}, nil

case *query.Repo:
reposWant := make([]bool, len(d.repoMetaData))
for repoIdx, r := range d.repoMetaData {
Expand Down
23 changes: 23 additions & 0 deletions matchtree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,26 @@ func TestBranchesRepos(t *testing.T) {
t.Fatalf("expect %d documents, but got at least 1 more", len(want))
}
}

func TestRepoIDs(t *testing.T) {
d := &indexData{
repoMetaData: []Repository{{Name: "r0", ID: 0}, {Name: "r1", ID: 1}, {Name: "r2", ID: 2}, {Name: "r3", ID: 3}},
fileBranchMasks: []uint64{1, 1, 1, 1, 1, 1},
repos: []uint16{0, 0, 1, 2, 3, 3},
}
mt, err := d.newMatchTree(&query.RepoIDs{Repos: roaring.BitmapOf(1, 3, 99)})
if err != nil {
t.Fatal(err)
}
want := []uint32{2, 4, 5}
for i := 0; i < len(want); i++ {
nextDoc := mt.nextDoc()
if nextDoc != want[i] {
t.Fatalf("want %d, got %d", want[i], nextDoc)
}
mt.prepare(nextDoc)
}
if mt.nextDoc() != maxUInt32 {
t.Fatalf("expected %d document, but got at least 1 more", len(want))
}
}
27 changes: 27 additions & 0 deletions query/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,27 @@ func (q *BranchesRepos) String() string {
return sb.String()
}

// NewRepoIDs is a helper for creating a RepoIDs which
// searches only the matched repos.
func NewRepoIDs(ids ...uint32) *RepoIDs {
return &RepoIDs{Repos: roaring.BitmapOf(ids...)}
}

func (q *RepoIDs) String() string {
var sb strings.Builder

sb.WriteString("(repoids ")

if size := q.Repos.GetCardinality(); size > 1 {
sb.WriteString("count:" + strconv.FormatUint(size, 10))
} else {
sb.WriteString("repoid=" + q.Repos.String())
}

sb.WriteString(")")
return sb.String()
}

// MarshalBinary implements a specialized encoder for BranchesRepos.
func (q BranchesRepos) MarshalBinary() ([]byte, error) {
return branchesReposEncode(q.List)
Expand All @@ -249,6 +270,12 @@ type BranchRepos struct {
Repos *roaring.Bitmap
}

// Similar to BranchRepos but will be used to match only by repoid and
// therefore matches all branches
type RepoIDs struct {
Repos *roaring.Bitmap
}

// RepoSet is a list of repos to match. It is a Sourcegraph addition and only
// used in the RPC interface for efficient checking of large repo lists.
type RepoSet struct {
Expand Down
1 change: 1 addition & 0 deletions rpc/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ func RegisterGob() {
gobRegister(&query.Regexp{})
gobRegister(&query.RepoRegexp{})
gobRegister(&query.RepoSet{})
gobRegister(&query.RepoIDs{})
gobRegister(&query.Repo{})
gobRegister(&query.Substring{})
gobRegister(&query.Symbol{})
Expand Down
9 changes: 9 additions & 0 deletions shards/shards.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,11 @@ func selectRepoSet(shards []*rankedShard, q query.Q) ([]*rankedShard, query.Q) {
hasRepos = hasReposForPredicate(func(repo *zoekt.Repository) bool {
return setQuery.Set[repo.Name]
})
case *query.RepoIDs:
setSize = int(setQuery.Repos.GetCardinality())
hasRepos = hasReposForPredicate(func(repo *zoekt.Repository) bool {
return setQuery.Repos.Contains(repo.ID)
})
case *query.BranchesRepos:
for _, br := range setQuery.List {
setSize += int(br.Repos.GetCardinality())
Expand Down Expand Up @@ -445,6 +450,10 @@ func selectRepoSet(shards []*rankedShard, q query.Q) ([]*rankedShard, query.Q) {
and.Children[i] = &query.Const{Value: true}
return filtered, query.Simplify(and)

case *query.RepoIDs:
and.Children[i] = &query.Const{Value: true}
return filtered, query.Simplify(and)

case *query.BranchesRepos:
// We can only replace if all the repos want the same branches. We
// simplify and just check that we are requesting 1 branch. The common
Expand Down
13 changes: 11 additions & 2 deletions shards/shards_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,22 +226,25 @@ func TestShardedSearcher_Ranking(t *testing.T) {
}
}

func TestFilteringShardsByRepoSet(t *testing.T) {
func TestFilteringShardsByRepoSetOrBranchesReposOrRepoIDs(t *testing.T) {
ss := newShardedSearcher(1)

repoSetNames := []string{}
repoIDs := []uint32{}
n := 10 * runtime.GOMAXPROCS(0)
for i := 0; i < n; i++ {
shardName := fmt.Sprintf("shard%d", i)
repoName := fmt.Sprintf("repository%.3d", i)
repoID := hash(repoName)

if i%3 == 0 {
repoSetNames = append(repoSetNames, repoName)
repoIDs = append(repoIDs, repoID)
}

ss.replace(map[string]zoekt.Searcher{
shardName: &rankSearcher{
repo: &zoekt.Repository{ID: hash(repoName), Name: repoName},
repo: &zoekt.Repository{ID: repoID, Name: repoName},
rank: uint16(n - i),
},
})
Expand All @@ -266,6 +269,8 @@ func TestFilteringShardsByRepoSet(t *testing.T) {
set := query.NewRepoSet(repoSetNames...)
sub := &query.Substring{Pattern: "bla"}

repoIDsQuery := query.NewRepoIDs(repoIDs...)

queries := []query.Q{
query.NewAnd(set, sub),
// Test with the same reposet again
Expand All @@ -274,6 +279,10 @@ func TestFilteringShardsByRepoSet(t *testing.T) {
query.NewAnd(branchesRepos, sub),
// Test with the same repoBranches with IDs again
query.NewAnd(branchesRepos, sub),

query.NewAnd(repoIDsQuery, sub),
// Test with the same repoIDs again
query.NewAnd(repoIDsQuery, sub),
}

for _, q := range queries {
Expand Down