diff --git a/README.md b/README.md
index 388d3cc56..2429f9fa7 100644
--- a/README.md
+++ b/README.md
@@ -103,6 +103,7 @@ We support and actively test against certain third-party clients to ensure compa
 |`NOW()`|Returns the current timestamp.|
 |`NULLIF(expr1, expr2)`|Returns NULL if expr1 = expr2 is true, otherwise returns expr1.|
 |`POW(X, Y)`|Returns the value of X raised to the power of Y.|
+|`REGEXP_MATCHES(text, pattern, [flags])`|Returns an array with the matches of the pattern in the given text. Flags can be given to control certain behaviours of the regular expression. Currently, only the `i` flag is supported, to make the comparison case insensitive.|
 |`REPEAT(str, count)`|Returns a string consisting of the string str repeated count times.|
 |`REPLACE(str,from_str,to_str)`|Returns the string str with all occurrences of the string from_str replaced by the string to_str.|
 |`REVERSE(str)`|Returns the string str with the order of the characters reversed.|
diff --git a/engine_test.go b/engine_test.go
index 41568444c..b3381d03a 100644
--- a/engine_test.go
+++ b/engine_test.go
@@ -1454,6 +1454,34 @@ var queries = []struct {
 		ORDER BY table_type, table_schema, table_name`,
 		[]sql.Row{{"mydb", "mytable", "TABLE"}},
 	},
+	{
+		`SELECT REGEXP_MATCHES("bopbeepbop", "bop")`,
+		[]sql.Row{{[]interface{}{"bop", "bop"}}},
+	},
+	{
+		`SELECT EXPLODE(REGEXP_MATCHES("bopbeepbop", "bop"))`,
+		[]sql.Row{{"bop"}, {"bop"}},
+	},
+	{
+		`SELECT EXPLODE(REGEXP_MATCHES("helloworld", "bop"))`,
+		[]sql.Row{},
+	},
+	{
+		`SELECT EXPLODE(REGEXP_MATCHES("", ""))`,
+		[]sql.Row{{""}},
+	},
+	{
+		`SELECT REGEXP_MATCHES(NULL, "")`,
+		[]sql.Row{{nil}},
+	},
+	{
+		`SELECT REGEXP_MATCHES("", NULL)`,
+		[]sql.Row{{nil}},
+	},
+	{
+		`SELECT REGEXP_MATCHES("", "", NULL)`,
+		[]sql.Row{{nil}},
+	},
 }
 
 func TestQueries(t *testing.T) {
diff --git a/sql/expression/function/regexp_matches.go b/sql/expression/function/regexp_matches.go
new file mode 100644
index 000000000..417e91f5e
--- /dev/null
+++ b/sql/expression/function/regexp_matches.go
@@ -0,0 +1,227 @@
+package function
+
+import (
+	"fmt"
+	"regexp"
+	"strings"
+
+	"github.com/src-d/go-mysql-server/sql"
+	"github.com/src-d/go-mysql-server/sql/expression"
+	errors "gopkg.in/src-d/go-errors.v1"
+)
+
+// RegexpMatches returns the matches of a regular expression.
+type RegexpMatches struct {
+	Text    sql.Expression
+	Pattern sql.Expression
+	Flags   sql.Expression
+
+	// cacheable is set by NewRegexpMatches when neither the pattern nor the
+	// flags contain a GetField expression.
+	cacheable bool
+	// re is the compiled regular expression that Eval keeps between calls
+	// when the expression is cacheable.
+	re *regexp.Regexp
+}
+
+// NewRegexpMatches creates a new RegexpMatches expression. It takes the text
+// and the pattern, optionally followed by the flags.
+func NewRegexpMatches(args ...sql.Expression) (sql.Expression, error) {
+	var r RegexpMatches
+	switch len(args) {
+	case 3:
+		r.Flags = args[2]
+		fallthrough
+	case 2:
+		r.Text = args[0]
+		r.Pattern = args[1]
+	default:
+		return nil, sql.ErrInvalidArgumentNumber.New("regexp_matches", "2 or 3", len(args))
+	}
+
+	if canBeCached(r.Pattern) && (r.Flags == nil || canBeCached(r.Flags)) {
+		r.cacheable = true
+	}
+
+	return &r, nil
+}
+
+// Type implements the sql.Expression interface.
+func (r *RegexpMatches) Type() sql.Type { return sql.Array(sql.Text) }
+
+// IsNullable implements the sql.Expression interface.
+func (r *RegexpMatches) IsNullable() bool { return true }
+
+// Children implements the sql.Expression interface.
+func (r *RegexpMatches) Children() []sql.Expression {
+	var result = []sql.Expression{r.Text, r.Pattern}
+	if r.Flags != nil {
+		result = append(result, r.Flags)
+	}
+	return result
+}
+
+// Resolved implements the sql.Expression interface.
+func (r *RegexpMatches) Resolved() bool {
+	return r.Text.Resolved() && r.Pattern.Resolved() && (r.Flags == nil || r.Flags.Resolved())
+}
+
+// WithChildren implements the sql.Expression interface.
+func (r *RegexpMatches) WithChildren(children ...sql.Expression) (sql.Expression, error) {
+	required := 2
+	if r.Flags != nil {
+		required = 3
+	}
+
+	if len(children) != required {
+		return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), required)
+	}
+
+	return NewRegexpMatches(children...)
+}
+
+// String returns the expression rendered as a regexp_matches call.
+func (r *RegexpMatches) String() string {
+	var args []string
+	for _, e := range r.Children() {
+		args = append(args, e.String())
+	}
+	return fmt.Sprintf("regexp_matches(%s)", strings.Join(args, ", "))
+}
+
+// Eval implements the sql.Expression interface. It returns nil when the text,
+// the pattern or the flags evaluate to nil and when there are no matches.
+// Otherwise it returns every match followed by its submatches, flattened into
+// a single slice.
+func (r *RegexpMatches) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
+	span, ctx := ctx.Span("function.RegexpMatches")
+	defer span.Finish()
+
+	var re *regexp.Regexp
+	var err error
+	if r.cacheable {
+		// NOTE(review): r.re is filled in lazily without synchronization;
+		// confirm Eval is never called concurrently on the same expression.
+		if r.re == nil {
+			r.re, err = r.compileRegex(ctx, nil)
+			if err != nil {
+				return nil, err
+			}
+
+			if r.re == nil {
+				return nil, nil
+			}
+		}
+		re = r.re
+	} else {
+		re, err = r.compileRegex(ctx, row)
+		if err != nil {
+			return nil, err
+		}
+
+		if re == nil {
+			return nil, nil
+		}
+	}
+
+	text, err := r.Text.Eval(ctx, row)
+	if err != nil {
+		return nil, err
+	}
+
+	if text == nil {
+		return nil, nil
+	}
+
+	text, err = sql.Text.Convert(text)
+	if err != nil {
+		return nil, err
+	}
+
+	matches := re.FindAllStringSubmatch(text.(string), -1)
+	if len(matches) == 0 {
+		return nil, nil
+	}
+
+	var result []interface{}
+	for _, m := range matches {
+		for _, sm := range m {
+			result = append(result, sm)
+		}
+	}
+
+	return result, nil
+}
+
+// compileRegex evaluates the pattern and the flags against the given row and
+// compiles the resulting regular expression. It returns a nil regexp and a
+// nil error when the pattern or the flags evaluate to nil.
+func (r *RegexpMatches) compileRegex(ctx *sql.Context, row sql.Row) (*regexp.Regexp, error) {
+	pattern, err := r.Pattern.Eval(ctx, row)
+	if err != nil {
+		return nil, err
+	}
+
+	if pattern == nil {
+		return nil, nil
+	}
+
+	pattern, err = sql.Text.Convert(pattern)
+	if err != nil {
+		return nil, err
+	}
+
+	var flags string
+	if r.Flags != nil {
+		f, err := r.Flags.Eval(ctx, row)
+		if err != nil {
+			return nil, err
+		}
+
+		if f == nil {
+			return nil, nil
+		}
+
+		f, err = sql.Text.Convert(f)
+		if err != nil {
+			return nil, err
+		}
+
+		flags = f.(string)
+		for _, flag := range flags {
+			if !validRegexpFlags[flag] {
+				// Pass the flag as a string: a rune formatted with %v
+				// prints its code point instead of the character.
+				return nil, errInvalidRegexpFlag.New(string(flag))
+			}
+		}
+
+		// Only prepend a flag group when there are flags to set.
+		if flags != "" {
+			flags = fmt.Sprintf("(?%s)", flags)
+		}
+	}
+
+	return regexp.Compile(flags + pattern.(string))
+}
+
+// errInvalidRegexpFlag is returned when a flag is not in validRegexpFlags.
+var errInvalidRegexpFlag = errors.NewKind("invalid regexp flag: %v")
+
+// validRegexpFlags is the set of flags accepted by regexp_matches.
+var validRegexpFlags = map[rune]bool{
+	'i': true,
+}
+
+// canBeCached reports whether no GetField expression is found when
+// inspecting e.
+func canBeCached(e sql.Expression) bool {
+	var hasCols bool
+	expression.Inspect(e, func(e sql.Expression) bool {
+		if _, ok := e.(*expression.GetField); ok {
+			hasCols = true
+		}
+		return true
+	})
+	return !hasCols
+}
diff --git a/sql/expression/function/regexp_matches_test.go b/sql/expression/function/regexp_matches_test.go
new file mode 100644
index 000000000..4a7fc35c5
--- /dev/null
+++ b/sql/expression/function/regexp_matches_test.go
@@ -0,0 +1,146 @@
+package function
+
+import (
+	"testing"
+
+	"github.com/src-d/go-mysql-server/sql"
+	"github.com/src-d/go-mysql-server/sql/expression"
+	"github.com/stretchr/testify/require"
+
+	errors "gopkg.in/src-d/go-errors.v1"
+)
+
+func TestRegexpMatches(t *testing.T) {
+	testCases := []struct {
+		pattern  interface{}
+		text     interface{}
+		flags    interface{}
+		expected interface{}
+		err      *errors.Kind
+	}{
+		{
+			`^foobar(.*)bye$`,
+			"foobarhellobye",
+			"",
+			[]interface{}{"foobarhellobye", "hello"},
+			nil,
+		},
+		{
+			"bop",
+			"bopbeepbop",
+			"",
+			[]interface{}{"bop", "bop"},
+			nil,
+		},
+		{
+			"bop",
+			"bopbeepBop",
+			"i",
+			[]interface{}{"bop", "Bop"},
+			nil,
+		},
+		{
+			"bop",
+			"helloworld",
+			"",
+			nil,
+			nil,
+		},
+		{
+			"foo",
+			"",
+			"",
+			nil,
+			nil,
+		},
+		{
+			"",
+			"",
+			"",
+			[]interface{}{""},
+			nil,
+		},
+		{
+			"bop",
+			nil,
+			"",
+			nil,
+			nil,
+		},
+		{
+			"bop",
+			"beep",
+			nil,
+			nil,
+			nil,
+		},
+		{
+			nil,
+			"bop",
+			"",
+			nil,
+			nil,
+		},
+		{
+			"bop",
+			"bopbeepBop",
+			"ix",
+			nil,
+			errInvalidRegexpFlag,
+		},
+	}
+
+	t.Run("cacheable", func(t *testing.T) {
+		for _, tt := range testCases {
+			var flags sql.Expression
+			if tt.flags != "" {
+				flags = expression.NewLiteral(tt.flags, sql.Text)
+			}
+			f, err := NewRegexpMatches(
+				expression.NewLiteral(tt.text, sql.Text),
+				expression.NewLiteral(tt.pattern, sql.Text),
+				flags,
+			)
+			require.NoError(t, err)
+
+			t.Run(f.String(), func(t *testing.T) {
+				require := require.New(t)
+				result, err := f.Eval(sql.NewEmptyContext(), nil)
+				if tt.err == nil {
+					require.NoError(err)
+					require.Equal(tt.expected, result)
+				} else {
+					require.Error(err)
+					require.True(tt.err.Is(err))
+				}
+			})
+		}
+	})
+
+	t.Run("not cacheable", func(t *testing.T) {
+		for _, tt := range testCases {
+			var flags sql.Expression
+			if tt.flags != "" {
+				flags = expression.NewGetField(2, sql.Text, "x", false)
+			}
+			f, err := NewRegexpMatches(
+				expression.NewGetField(0, sql.Text, "x", false),
+				expression.NewGetField(1, sql.Text, "x", false),
+				flags,
+			)
+			require.NoError(t, err)
+
+			t.Run(f.String(), func(t *testing.T) {
+				require := require.New(t)
+				result, err := f.Eval(sql.NewEmptyContext(), sql.Row{tt.text, tt.pattern, tt.flags})
+				if tt.err == nil {
+					require.NoError(err)
+					require.Equal(tt.expected, result)
+				} else {
+					require.Error(err)
+					require.True(tt.err.Is(err))
+				}
+			})
+		}
+	})
+}
diff --git a/sql/expression/function/registry.go b/sql/expression/function/registry.go
index 2a7b07411..7d2586bba 100644
--- a/sql/expression/function/registry.go
+++ b/sql/expression/function/registry.go
@@ -98,4 +98,5 @@ var Defaults = []sql.Function{
 	sql.Function1{Name: "char_length", Fn: NewCharLength},
 	sql.Function1{Name: "character_length", Fn: NewCharLength},
 	sql.Function1{Name: "explode", Fn: NewExplode},
+	sql.FunctionN{Name: "regexp_matches", Fn: NewRegexpMatches},
 }