Skip to content

Commit

Permalink
feat: read fields support for search (#381)
Browse files Browse the repository at this point in the history
  • Loading branch information
adilansari committed Jul 25, 2022
1 parent 72a86de commit 45a7e5d
Show file tree
Hide file tree
Showing 9 changed files with 193 additions and 20 deletions.
2 changes: 1 addition & 1 deletion api/proto
Submodule proto updated from 02fad3 to 4669ee
11 changes: 8 additions & 3 deletions api/server/v1/marshaler.go
Expand Up @@ -118,9 +118,14 @@ func (x *SearchRequest) UnmarshalJSON(data []byte) error {
case "sort":
// delaying the sort deserialization
x.Sort = value
case "fields":
// not decoding it here and let it decode during fields parsing
x.Fields = value
case "include_fields":
if err := jsoniter.Unmarshal(value, &x.IncludeFields); err != nil {
return err
}
case "exclude_fields":
if err := jsoniter.Unmarshal(value, &x.ExcludeFields); err != nil {
return err
}
case "page_size":
if err := jsoniter.Unmarshal(value, &x.PageSize); err != nil {
return err
Expand Down
4 changes: 2 additions & 2 deletions api/server/v1/marshaler_test.go
Expand Up @@ -33,7 +33,7 @@ func TestJSONEncoding(t *testing.T) {
t.Run("unmarshal SearchRequest", func(t *testing.T) {
inputDoc := []byte(`{"q":"my search text","search_fields":["first_name","last_name"],
"filter":{"last_name":"Steve"},"facet":{"facet stat":0},
"sort":[{"salary":"$asc"}],"fields":["employment","history"]}`)
"sort":[{"salary":"$asc"}],"include_fields":["employment","history"]}`)

req := &SearchRequest{}
err := json.Unmarshal(inputDoc, req)
Expand All @@ -43,7 +43,7 @@ func TestJSONEncoding(t *testing.T) {
require.Equal(t, []byte(`{"last_name":"Steve"}`), req.GetFilter())
require.Equal(t, []byte(`{"facet stat":0}`), req.GetFacet())
require.Equal(t, []byte(`[{"salary":"$asc"}]`), req.GetSort())
require.Equal(t, []byte(`["employment","history"]`), req.GetFields())
require.Equal(t, []string{"employment", "history"}, req.GetIncludeFields())
})

t.Run("marshal SearchResponse", func(t *testing.T) {
Expand Down
4 changes: 4 additions & 0 deletions api/server/v1/validator.go
Expand Up @@ -103,6 +103,10 @@ func (x *SearchRequest) Validate() error {
return err
}

if len(x.IncludeFields) > 0 && len(x.ExcludeFields) > 0 {
return Errorf(Code_INVALID_ARGUMENT, "Cannot use both `include_fields` and `exclude_fields` together")
}

if err := isValidPaginationParam("page", int(x.Page)); err != nil {
return err
}
Expand Down
8 changes: 4 additions & 4 deletions query/read/fields.go
Expand Up @@ -52,7 +52,7 @@ func BuildFields(reqFields jsoniter.RawMessage) (*FieldFactory, error) {
return err
}

factory.addField(&SimpleField{
factory.AddField(&SimpleField{
Name: string(key),
Incl: include,
})
Expand All @@ -64,7 +64,7 @@ func BuildFields(reqFields jsoniter.RawMessage) (*FieldFactory, error) {
return err
}

factory.addField(&SimpleField{
factory.AddField(&SimpleField{
Name: string(key),
Incl: include == 1,
})
Expand All @@ -74,7 +74,7 @@ func BuildFields(reqFields jsoniter.RawMessage) (*FieldFactory, error) {
if err != nil {
return err
}
factory.addField(NewExprField(string(key), expr))
factory.AddField(NewExprField(string(key), expr))
default:
return api.Errorf(api.Code_INVALID_ARGUMENT, "only boolean/integer is supported as value")
}
Expand All @@ -93,7 +93,7 @@ type FieldFactory struct {
FetchedValues map[string]*JSONObject
}

func (factory *FieldFactory) addField(f Field) {
func (factory *FieldFactory) AddField(f Field) {
if !f.Include() {
factory.Exclude[f.Alias()] = f
return
Expand Down
21 changes: 14 additions & 7 deletions query/search/search.go
Expand Up @@ -16,18 +16,20 @@ package search

import (
"github.com/tigrisdata/tigris/query/filter"
"github.com/tigrisdata/tigris/query/read"
)

const (
all = "*"
)

type Query struct {
Q string
Fields []string
Facets Facets
PageSize int
WrappedF *filter.WrappedFilter
Q string
SearchFields []string
Facets Facets
PageSize int
WrappedF *filter.WrappedFilter
ReadFields *read.FieldFactory
}

func (q *Query) ToSearchFacetSize() int {
Expand Down Expand Up @@ -63,7 +65,7 @@ func (q *Query) ToSearchFacets() string {

func (q *Query) ToSearchFields() string {
var fields string
for i, f := range q.Fields {
for i, f := range q.SearchFields {
if i != 0 {
fields += ","
}
Expand Down Expand Up @@ -104,7 +106,12 @@ func (b *Builder) Facets(facets Facets) *Builder {
}

func (b *Builder) SearchFields(f []string) *Builder {
b.query.Fields = f
b.query.SearchFields = f
return b
}

func (b *Builder) ReadFields(f *read.FieldFactory) *Builder {
b.query.ReadFields = f
return b
}

Expand Down
44 changes: 43 additions & 1 deletion server/services/v1/query_runner.go
Expand Up @@ -536,6 +536,11 @@ func (runner *SearchQueryRunner) Run(ctx context.Context, tx transaction.Tx, ten
return nil, ctx, err
}

fieldSelection, err := runner.getFieldSelection(collection.GetFields())
if err != nil {
return nil, ctx, err
}

pageSize := int(runner.req.PageSize)
if pageSize == 0 {
pageSize = defaultPerPage
Expand All @@ -548,6 +553,7 @@ func (runner *SearchQueryRunner) Run(ctx context.Context, tx transaction.Tx, ten
Facets(facets).
PageSize(pageSize).
Filter(wrappedF).
ReadFields(fieldSelection).
Build()

var rowReader *SearchRowReader
Expand Down Expand Up @@ -665,7 +671,6 @@ func (runner *SearchQueryRunner) getFacetFields(collFields []*schema.Field) (qse
}
found = true
break

}
if !found {
return qsearch.Facets{}, api.Errorf(api.Code_INVALID_ARGUMENT, "`%s` is not a schema field", ff.Name)
Expand All @@ -675,6 +680,43 @@ func (runner *SearchQueryRunner) getFacetFields(collFields []*schema.Field) (qse
return facets, nil
}

func (runner *SearchQueryRunner) getFieldSelection(collFields []*schema.Field) (*read.FieldFactory, error) {
var selectionFields []string

// Only one of include/exclude. Honor inclusion over exclusion
if len(runner.req.IncludeFields) > 0 {
selectionFields = runner.req.IncludeFields
} else if len(runner.req.ExcludeFields) > 0 {
selectionFields = runner.req.ExcludeFields
} else {
return nil, nil
}

factory := &read.FieldFactory{
Include: map[string]read.Field{},
Exclude: map[string]read.Field{},
}

for _, sf := range selectionFields {
found := false
for _, cf := range collFields {
if sf == cf.FieldName {
found = true
}
}
if !found {
return nil, api.Errorf(api.Code_INVALID_ARGUMENT, "`%s` is not a schema field", sf)
}

factory.AddField(&read.SimpleField{
Name: sf,
Incl: len(runner.req.IncludeFields) > 0,
})
}

return factory, nil
}

type CollectionQueryRunner struct {
*BaseQueryRunner

Expand Down
97 changes: 97 additions & 0 deletions server/services/v1/query_runner_test.go
@@ -0,0 +1,97 @@
package v1

import (
"testing"

"github.com/stretchr/testify/assert"
api "github.com/tigrisdata/tigris/api/server/v1"
"github.com/tigrisdata/tigris/schema"
)

func TestSearchQueryRunner_getFieldSelection(t *testing.T) {
t.Run("only include fields are provided", func(t *testing.T) {
collFields := []*schema.Field{
{FieldName: "field_1"},
{FieldName: "field_2"},
}

runner := &SearchQueryRunner{
req: &api.SearchRequest{
IncludeFields: []string{"field_1", "field_2"},
},
}

factory, err := runner.getFieldSelection(collFields)

assert.Nil(t, err)
assert.NotNil(t, factory)
assert.Empty(t, factory.Exclude)
assert.Len(t, factory.Include, 2)
assert.Contains(t, factory.Include, "field_1")
assert.Contains(t, factory.Include, "field_2")
})

t.Run("only exclude fields are provided", func(t *testing.T) {
collFields := []*schema.Field{
{FieldName: "field_1"},
{FieldName: "field_2"},
}

runner := &SearchQueryRunner{
req: &api.SearchRequest{
ExcludeFields: []string{"field_1", "field_2"},
},
}

factory, err := runner.getFieldSelection(collFields)

assert.Nil(t, err)
assert.NotNil(t, factory)
assert.Empty(t, factory.Include)
assert.Len(t, factory.Exclude, 2)
assert.Contains(t, factory.Exclude, "field_1")
assert.Contains(t, factory.Exclude, "field_2")
})

t.Run("no fields to include or exclude", func(t *testing.T) {
collFields := []*schema.Field{
{FieldName: "field_1"},
{FieldName: "field_2"},
}
runner := &SearchQueryRunner{req: &api.SearchRequest{}}

factory, err := runner.getFieldSelection(collFields)

assert.Nil(t, err)
assert.Nil(t, factory)
})

t.Run("no schema fields are defined", func(t *testing.T) {
var collFields []*schema.Field
runner := &SearchQueryRunner{req: &api.SearchRequest{}}

factory, err := runner.getFieldSelection(collFields)

assert.Nil(t, err)
assert.Nil(t, factory)
})

t.Run("selection fields are not in schema", func(t *testing.T) {
collFields := []*schema.Field{
{FieldName: "field_1"},
{FieldName: "field_2"},
}

runner := &SearchQueryRunner{
req: &api.SearchRequest{
ExcludeFields: []string{"field_2", "field_3"},
},
}

factory, err := runner.getFieldSelection(collFields)

assert.Nil(t, factory)
assert.NotNil(t, err)
assert.ErrorContains(t, err, "`field_3` is not a schema field")
})
}
22 changes: 20 additions & 2 deletions server/services/v1/search_reader.go
Expand Up @@ -20,9 +20,11 @@ import (
jsoniter "github.com/json-iterator/go"
api "github.com/tigrisdata/tigris/api/server/v1"
"github.com/tigrisdata/tigris/query/filter"
"github.com/tigrisdata/tigris/query/read"
qsearch "github.com/tigrisdata/tigris/query/search"
"github.com/tigrisdata/tigris/schema"
"github.com/tigrisdata/tigris/store/search"
ulog "github.com/tigrisdata/tigris/util/log"
tsApi "github.com/typesense/typesense-go/typesense/api"
)

Expand All @@ -38,6 +40,7 @@ type page struct {
hits *HitsResponse
wrappedF *filter.WrappedFilter
collection *schema.DefaultCollection
readFields *read.FieldFactory
}

func newPage(collection *schema.DefaultCollection, query *qsearch.Query) *page {
Expand All @@ -47,6 +50,7 @@ func newPage(collection *schema.DefaultCollection, query *qsearch.Query) *page {
cap: query.PageSize,
wrappedF: query.WrappedF,
collection: collection,
readFields: query.ReadFields,
}
}

Expand Down Expand Up @@ -89,11 +93,25 @@ func (p *page) readRow(row *Row) bool {
continue
}

// set the raw data now after marshaling it
if row.Data.RawData, p.err = jsoniter.Marshal(doc); p.err != nil {
var rawData []byte

// marshal the doc as bytes
rawData, p.err = jsoniter.Marshal(doc)
if p.err != nil {
return false
}

// apply field selection
if p.readFields != nil {
newValue, err := p.readFields.Apply(rawData)
if ulog.E(err) {
return false
}
row.Data.RawData = newValue
} else {
row.Data.RawData = rawData
}

return true
}

Expand Down

0 comments on commit 45a7e5d

Please sign in to comment.