Skip to content

Commit

Permalink
Fix a few issues found by fuzzing (FerretDB#345)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlekSi authored and w84thesun committed Mar 11, 2022
1 parent 5083f63 commit 549f737
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 10 deletions.
17 changes: 14 additions & 3 deletions internal/handlers/jsonb1/where.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,12 @@ func fieldExpr(field string, expr *types.Document, p *pg.Placeholder) (sql strin
}
sql += "NOT("

argSql, arg, err = fieldExpr(field, value.(*types.Document), p)
var exprValue *types.Document
if exprValue, err = common.AssertType[*types.Document](value); err != nil {
err = lazyerrors.Errorf("fieldExpr: %w", err)
return
}
argSql, arg, err = fieldExpr(field, exprValue, p)
if err != nil {
err = lazyerrors.Errorf("fieldExpr: %w", err)
return
Expand All @@ -150,11 +155,17 @@ func fieldExpr(field string, expr *types.Document, p *pg.Placeholder) (sql strin
case "$in":
// {field: {$in: [value1, value2, ...]}}
sql += "_jsonb->" + p.Next() + " IN"
argSql, arg, err = common.InArray(value.(*types.Array), p, scalar)
var arr *types.Array
if arr, err = common.AssertType[*types.Array](value); err == nil {
argSql, arg, err = common.InArray(arr, p, scalar)
}
case "$nin":
// {field: {$nin: [value1, value2, ...]}}
sql += "_jsonb->" + p.Next() + " NOT IN"
argSql, arg, err = common.InArray(value.(*types.Array), p, scalar)
var arr *types.Array
if arr, err = common.AssertType[*types.Array](value); err == nil {
argSql, arg, err = common.InArray(arr, p, scalar)
}
case "$eq":
// {field: {$eq: value}}
// TODO special handling for regex
Expand Down
17 changes: 14 additions & 3 deletions internal/handlers/sql/where.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,12 @@ func fieldExpr(field string, expr *types.Document, p *pg.Placeholder) (sql strin
}
sql += "NOT("

argSql, arg, err = fieldExpr(field, value.(*types.Document), p)
var exprValue *types.Document
if exprValue, err = common.AssertType[*types.Document](value); err != nil {
err = lazyerrors.Errorf("fieldExpr: %w", err)
return
}
argSql, arg, err = fieldExpr(field, exprValue, p)
if err != nil {
err = lazyerrors.Errorf("fieldExpr: %w", err)
return
Expand All @@ -97,11 +102,17 @@ func fieldExpr(field string, expr *types.Document, p *pg.Placeholder) (sql strin
case "$in":
// {field: {$in: [value1, value2, ...]}}
sql += " IN"
argSql, arg, err = common.InArray(value.(*types.Array), p, scalar)
var arr *types.Array
if arr, err = common.AssertType[*types.Array](value); err == nil {
argSql, arg, err = common.InArray(arr, p, scalar)
}
case "$nin":
// {field: {$nin: [value1, value2, ...]}}
sql += " NOT IN"
argSql, arg, err = common.InArray(value.(*types.Array), p, scalar)
var arr *types.Array
if arr, err = common.AssertType[*types.Array](value); err == nil {
argSql, arg, err = common.InArray(arr, p, scalar)
}
case "$eq":
// {field: {$eq: value}}
// TODO special handling for regex
Expand Down
9 changes: 7 additions & 2 deletions internal/types/document.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,14 @@ func (d *Document) Keys() []string {
return d.keys
}

// Command returns the first document's key, this is often used as a command name.
// Command returns the first document's key lowercased. This is often used as a command name.
// It returns an empty string if document is nil or empty.
func (d *Document) Command() string {
return strings.ToLower(d.keys[0])
keys := d.Keys()
if len(keys) == 0 {
return ""
}
return strings.ToLower(keys[0])
}

func (d *Document) add(key string, value any) error {
Expand Down
8 changes: 6 additions & 2 deletions internal/types/document_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func TestDocument(t *testing.T) {
assert.Zero(t, doc.Len())
assert.Nil(t, doc.Map())
assert.Nil(t, doc.Keys())
assert.Equal(t, "", doc.Command())
})

t.Run("ZeroValues", func(t *testing.T) {
Expand All @@ -46,10 +47,11 @@ func TestDocument(t *testing.T) {
assert.Equal(t, 0, doc.Len())
assert.Nil(t, doc.m)
assert.Nil(t, doc.keys)
assert.Equal(t, "", doc.Command())

err := doc.Set("foo", Null)
err := doc.Set("Foo", Null)
assert.NoError(t, err)
value, err := doc.Get("foo")
value, err := doc.Get("Foo")
assert.NoError(t, err)
assert.Equal(t, Null, value)

Expand All @@ -58,6 +60,8 @@ func TestDocument(t *testing.T) {

err = doc.Set("bar", nil)
assert.EqualError(t, err, `types.Document.validate: types.validateValue: unsupported type: <nil> (<nil>)`)

assert.Equal(t, "foo", doc.Command())
})

t.Run("NewDocument", func(t *testing.T) {
Expand Down

0 comments on commit 549f737

Please sign in to comment.