Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

function: implement regexp_matches #794

Merged
merged 1 commit into from
Aug 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ We support and actively test against certain third-party clients to ensure compa
|`NOW()`|Returns the current timestamp.|
|`NULLIF(expr1, expr2)`|Returns NULL if expr1 = expr2 is true, otherwise returns expr1.|
|`POW(X, Y)`|Returns the value of X raised to the power of Y.|
|`REGEXP_MATCHES(text, pattern, [flags])`|Returns an array with the matches of the pattern in the given text. Flags can be given to control certain behaviours of the regular expression. Currently, only the `i` flag is supported, to make the comparison case insensitive.|
|`REPEAT(str, count)`|Returns a string consisting of the string str repeated count times.|
|`REPLACE(str,from_str,to_str)`|Returns the string str with all occurrences of the string from_str replaced by the string to_str.|
|`REVERSE(str)`|Returns the string str with the order of the characters reversed.|
Expand Down
28 changes: 28 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1454,6 +1454,34 @@ var queries = []struct {
ORDER BY table_type, table_schema, table_name`,
[]sql.Row{{"mydb", "mytable", "TABLE"}},
},
{
`SELECT REGEXP_MATCHES("bopbeepbop", "bop")`,
[]sql.Row{{[]interface{}{"bop", "bop"}}},
},
{
`SELECT EXPLODE(REGEXP_MATCHES("bopbeepbop", "bop"))`,
[]sql.Row{{"bop"}, {"bop"}},
},
{
`SELECT EXPLODE(REGEXP_MATCHES("helloworld", "bop"))`,
[]sql.Row{},
},
erizocosmico marked this conversation as resolved.
Show resolved Hide resolved
{
`SELECT EXPLODE(REGEXP_MATCHES("", ""))`,
[]sql.Row{{""}},
},
{
`SELECT REGEXP_MATCHES(NULL, "")`,
[]sql.Row{{nil}},
},
{
`SELECT REGEXP_MATCHES("", NULL)`,
[]sql.Row{{nil}},
},
{
`SELECT REGEXP_MATCHES("", "", NULL)`,
[]sql.Row{{nil}},
},
}

func TestQueries(t *testing.T) {
Expand Down
204 changes: 204 additions & 0 deletions sql/expression/function/regexp_matches.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
package function

import (
"fmt"
"regexp"
"strings"

"github.com/src-d/go-mysql-server/sql"
"github.com/src-d/go-mysql-server/sql/expression"
errors "gopkg.in/src-d/go-errors.v1"
)

// RegexpMatches returns the matches of a regular expression.
type RegexpMatches struct {
Text sql.Expression
Pattern sql.Expression
Flags sql.Expression

cacheable bool
re *regexp.Regexp
}

// NewRegexpMatches creates a new RegexpMatches expression.
func NewRegexpMatches(args ...sql.Expression) (sql.Expression, error) {
var r RegexpMatches
switch len(args) {
case 3:
r.Flags = args[2]
fallthrough
case 2:
r.Text = args[0]
r.Pattern = args[1]
default:
return nil, sql.ErrInvalidArgumentNumber.New("regexp_matches", "2 or 3", len(args))
}

if canBeCached(r.Pattern) && (r.Flags == nil || canBeCached(r.Flags)) {
r.cacheable = true
}

return &r, nil
}

// Type implements the sql.Expression interface.
func (r *RegexpMatches) Type() sql.Type { return sql.Array(sql.Text) }

// IsNullable implements the sql.Expression interface.
func (r *RegexpMatches) IsNullable() bool { return true }

// Children implements the sql.Expression interface.
func (r *RegexpMatches) Children() []sql.Expression {
var result = []sql.Expression{r.Text, r.Pattern}
if r.Flags != nil {
result = append(result, r.Flags)
}
return result
}

// Resolved implements the sql.Expression interface.
func (r *RegexpMatches) Resolved() bool {
return r.Text.Resolved() && r.Pattern.Resolved() && (r.Flags == nil || r.Flags.Resolved())
}

// WithChildren implements the sql.Expression interface.
func (r *RegexpMatches) WithChildren(children ...sql.Expression) (sql.Expression, error) {
required := 2
if r.Flags != nil {
required = 3
}

if len(children) != required {
return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), required)
}

return NewRegexpMatches(children...)
}

func (r *RegexpMatches) String() string {
var args []string
for _, e := range r.Children() {
args = append(args, e.String())
}
return fmt.Sprintf("regexp_matches(%s)", strings.Join(args, ", "))
}

// Eval implements the sql.Expression interface.
func (r *RegexpMatches) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
span, ctx := ctx.Span("function.RegexpMatches")
defer span.Finish()

var re *regexp.Regexp
var err error
if r.cacheable {
if r.re == nil {
r.re, err = r.compileRegex(ctx, nil)
if err != nil {
return nil, err
}

if r.re == nil {
return nil, nil
}
}
re = r.re
} else {
re, err = r.compileRegex(ctx, row)
if err != nil {
return nil, err
}

if re == nil {
return nil, nil
}
}

text, err := r.Text.Eval(ctx, row)
if err != nil {
return nil, err
}

if text == nil {
return nil, nil
}

text, err = sql.Text.Convert(text)
if err != nil {
return nil, err
}

matches := re.FindAllStringSubmatch(text.(string), -1)
if len(matches) == 0 {
return nil, nil
}

var result []interface{}
for _, m := range matches {
for _, sm := range m {
result = append(result, sm)
}
}

return result, nil
}

func (r *RegexpMatches) compileRegex(ctx *sql.Context, row sql.Row) (*regexp.Regexp, error) {
pattern, err := r.Pattern.Eval(ctx, row)
if err != nil {
return nil, err
}

if pattern == nil {
return nil, nil
}

pattern, err = sql.Text.Convert(pattern)
if err != nil {
return nil, err
}

var flags string
if r.Flags != nil {
f, err := r.Flags.Eval(ctx, row)
if err != nil {
return nil, err
}

if f == nil {
return nil, nil
}

f, err = sql.Text.Convert(f)
if err != nil {
return nil, err
}

flags = f.(string)
for _, f := range flags {
if !validRegexpFlags[f] {
return nil, errInvalidRegexpFlag.New(f)
}
}

flags = fmt.Sprintf("(?%s)", flags)
}

return regexp.Compile(flags + pattern.(string))
}

var errInvalidRegexpFlag = errors.NewKind("invalid regexp flag: %v")

var validRegexpFlags = map[rune]bool{
'i': true,
}

func canBeCached(e sql.Expression) bool {
var hasCols bool
expression.Inspect(e, func(e sql.Expression) bool {
if _, ok := e.(*expression.GetField); ok {
hasCols = true
}
return true
})
return !hasCols
}
146 changes: 146 additions & 0 deletions sql/expression/function/regexp_matches_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package function

import (
"testing"

"github.com/src-d/go-mysql-server/sql"
"github.com/src-d/go-mysql-server/sql/expression"
"github.com/stretchr/testify/require"

errors "gopkg.in/src-d/go-errors.v1"
)

func TestRegexpMatches(t *testing.T) {
testCases := []struct {
pattern interface{}
text interface{}
flags interface{}
expected interface{}
err *errors.Kind
}{
{
`^foobar(.*)bye$`,
"foobarhellobye",
"",
[]interface{}{"foobarhellobye", "hello"},
nil,
},
{
"bop",
"bopbeepbop",
"",
[]interface{}{"bop", "bop"},
nil,
},
{
"bop",
"bopbeepBop",
"i",
[]interface{}{"bop", "Bop"},
nil,
},
{
"bop",
"helloworld",
"",
nil,
nil,
},
{
"foo",
"",
"",
nil,
nil,
},
{
"",
"",
"",
[]interface{}{""},
nil,
},
{
"bop",
nil,
"",
nil,
nil,
},
{
"bop",
"beep",
nil,
nil,
nil,
},
{
nil,
"bop",
"",
nil,
nil,
},
{
"bop",
"bopbeepBop",
"ix",
nil,
errInvalidRegexpFlag,
},
}
erizocosmico marked this conversation as resolved.
Show resolved Hide resolved

t.Run("cacheable", func(t *testing.T) {
for _, tt := range testCases {
var flags sql.Expression
if tt.flags != "" {
flags = expression.NewLiteral(tt.flags, sql.Text)
}
f, err := NewRegexpMatches(
expression.NewLiteral(tt.text, sql.Text),
expression.NewLiteral(tt.pattern, sql.Text),
flags,
)
require.NoError(t, err)

t.Run(f.String(), func(t *testing.T) {
require := require.New(t)
result, err := f.Eval(sql.NewEmptyContext(), nil)
if tt.err == nil {
require.NoError(err)
require.Equal(tt.expected, result)
} else {
require.Error(err)
require.True(tt.err.Is(err))
}
})
}
})

t.Run("not cacheable", func(t *testing.T) {
for _, tt := range testCases {
var flags sql.Expression
if tt.flags != "" {
flags = expression.NewGetField(2, sql.Text, "x", false)
}
f, err := NewRegexpMatches(
expression.NewGetField(0, sql.Text, "x", false),
expression.NewGetField(1, sql.Text, "x", false),
flags,
)
require.NoError(t, err)

t.Run(f.String(), func(t *testing.T) {
require := require.New(t)
result, err := f.Eval(sql.NewEmptyContext(), sql.Row{tt.text, tt.pattern, tt.flags})
if tt.err == nil {
require.NoError(err)
require.Equal(tt.expected, result)
} else {
require.Error(err)
require.True(tt.err.Is(err))
}
})
}
})
}
Loading