Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

Commit

Permalink
function: implement regexp_matches
Browse files Browse the repository at this point in the history
Signed-off-by: Miguel Molina <miguel@erizocosmi.co>
  • Loading branch information
erizocosmico committed Aug 5, 2019
1 parent 550cc54 commit 18e2012
Show file tree
Hide file tree
Showing 5 changed files with 309 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ We support and actively test against certain third-party clients to ensure compa
|`NOW()`|Returns the current timestamp.|
|`NULLIF(expr1, expr2)`|Returns NULL if expr1 = expr2 is true, otherwise returns expr1.|
|`POW(X, Y)`|Returns the value of X raised to the power of Y.|
|`REGEXP_MATCHES(text, pattern, [flags])`|Returns an array with all the matches of the pattern in the given text, including the contents of any capture groups, or NULL if there are no matches. Flags can be given to control certain behaviours of the regular expression. Currently, only the `i` flag is supported, which makes the matching case-insensitive.|
|`REPEAT(str, count)`|Returns a string consisting of the string str repeated count times.|
|`REPLACE(str,from_str,to_str)`|Returns the string str with all occurrences of the string from_str replaced by the string to_str.|
|`REVERSE(str)`|Returns the string str with the order of the characters reversed.|
Expand Down
12 changes: 12 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1454,6 +1454,18 @@ var queries = []struct {
ORDER BY table_type, table_schema, table_name`,
[]sql.Row{{"mydb", "mytable", "TABLE"}},
},
{
`SELECT REGEXP_MATCHES("bopbeepbop", "bop")`,
[]sql.Row{{[]interface{}{"bop", "bop"}}},
},
{
`SELECT EXPLODE(REGEXP_MATCHES("bopbeepbop", "bop"))`,
[]sql.Row{{"bop"}, {"bop"}},
},
{
`SELECT EXPLODE(REGEXP_MATCHES("helloworld", "bop"))`,
[]sql.Row{},
},
}

func TestQueries(t *testing.T) {
Expand Down
184 changes: 184 additions & 0 deletions sql/expression/function/regexp_matches.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
package function

import (
"fmt"
"regexp"
"strings"

"github.com/src-d/go-mysql-server/sql"
"github.com/src-d/go-mysql-server/sql/expression"
errors "gopkg.in/src-d/go-errors.v1"
)

// RegexpMatches is the expression behind REGEXP_MATCHES(text, pattern [, flags]);
// it returns the matches of a regular expression found in a text.
type RegexpMatches struct {
	// Text is the expression to search in, Pattern is the regular expression
	// to apply and Flags holds the optional regexp flags (nil when the
	// function was called without them).
	Text, Pattern, Flags sql.Expression

	// cacheable is set by NewRegexpMatches when neither Pattern nor Flags
	// reference a column, which lets Eval compile the regexp only once.
	cacheable bool
	// re is the compiled regexp that Eval stores and reuses when cacheable is true.
	re *regexp.Regexp
}

// NewRegexpMatches creates a new RegexpMatches expression from either
// (text, pattern) or (text, pattern, flags). Any other number of arguments
// results in an sql.ErrInvalidArgumentNumber error.
func NewRegexpMatches(args ...sql.Expression) (sql.Expression, error) {
	if n := len(args); n < 2 || n > 3 {
		return nil, sql.ErrInvalidArgumentNumber.New("regexp_matches", "2 or 3", n)
	}

	r := &RegexpMatches{Text: args[0], Pattern: args[1]}
	if len(args) == 3 {
		r.Flags = args[2]
	}

	// The compiled regexp can only be reused across rows when neither the
	// pattern nor the flags depend on the row being evaluated.
	r.cacheable = canBeCached(r.Pattern) && (r.Flags == nil || canBeCached(r.Flags))

	return r, nil
}

// Type implements the sql.Expression interface. The result type is always
// an array of text values.
func (r *RegexpMatches) Type() sql.Type {
	return sql.Array(sql.Text)
}

// IsNullable implements the sql.Expression interface. It always reports
// true: Eval returns nil when the pattern does not match.
func (r *RegexpMatches) IsNullable() bool {
	return true
}

// Children implements the sql.Expression interface. It returns Text and
// Pattern, followed by Flags only when flags were provided.
func (r *RegexpMatches) Children() []sql.Expression {
	if r.Flags == nil {
		return []sql.Expression{r.Text, r.Pattern}
	}
	return []sql.Expression{r.Text, r.Pattern, r.Flags}
}

// Resolved implements the sql.Expression interface. The expression is
// resolved once Text, Pattern and, if present, Flags are all resolved.
func (r *RegexpMatches) Resolved() bool {
	if !r.Text.Resolved() || !r.Pattern.Resolved() {
		return false
	}
	return r.Flags == nil || r.Flags.Resolved()
}

// WithChildren implements the sql.Expression interface. The number of
// children must equal the number of arguments this expression was built
// with (2, or 3 when flags are present); a new expression is then created
// from them with NewRegexpMatches.
func (r *RegexpMatches) WithChildren(children ...sql.Expression) (sql.Expression, error) {
	want := 2
	if r.Flags != nil {
		want++
	}

	if got := len(children); got != want {
		return nil, sql.ErrInvalidChildrenNumber.New(r, got, want)
	}

	return NewRegexpMatches(children...)
}

// String returns the textual form of the call, e.g.
// regexp_matches(text, pattern, flags), built from the String of every child.
func (r *RegexpMatches) String() string {
	children := r.Children()
	parts := make([]string, len(children))
	for i, child := range children {
		parts[i] = child.String()
	}
	return fmt.Sprintf("regexp_matches(%s)", strings.Join(parts, ", "))
}

// Eval implements the sql.Expression interface.
//
// It compiles the pattern (with the optional flags), evaluates Text and
// returns a flat []interface{} of strings: for every match, the full match
// followed by the contents of each capture group.
//
// The result is nil (SQL NULL) when Text, Pattern or Flags evaluate to NULL,
// or when the pattern does not match the text at all.
//
// NOTE: when the expression is cacheable, the compiled regexp is stored in
// the receiver on first use; no synchronization is performed here.
func (r *RegexpMatches) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
	span, ctx := ctx.Span("function.RegexpMatches")
	defer span.Finish()

	var re *regexp.Regexp
	var err error
	if r.cacheable {
		if r.re == nil {
			// Pattern and flags do not reference any column, so they can be
			// evaluated without a row and the result reused afterwards.
			r.re, err = r.compileRegex(ctx, nil)
			if err != nil {
				return nil, err
			}
		}
		re = r.re
	} else {
		re, err = r.compileRegex(ctx, row)
		if err != nil {
			return nil, err
		}
	}

	// A nil regexp without an error means the pattern or the flags were NULL.
	if re == nil {
		return nil, nil
	}

	text, err := r.Text.Eval(ctx, row)
	if err != nil {
		return nil, err
	}

	// NULL input yields a NULL result instead of being matched as a string.
	if text == nil {
		return nil, nil
	}

	text, err = sql.Text.Convert(text)
	if err != nil {
		return nil, err
	}

	matches := re.FindAllStringSubmatch(text.(string), -1)
	if len(matches) == 0 {
		return nil, nil
	}

	// Every match holds the full match plus one entry per capture group.
	result := make([]interface{}, 0, len(matches)*(re.NumSubexp()+1))
	for _, match := range matches {
		for _, sub := range match {
			result = append(result, sub)
		}
	}

	return result, nil
}

// compileRegex evaluates Pattern and Flags against the given row and
// compiles the resulting regular expression. Flags, when present, are
// validated against validRegexpFlags and prepended to the pattern as a
// "(?flags)" group.
//
// It returns (nil, nil) when the pattern or the flags evaluate to NULL, and
// an errInvalidRegexpFlag error when an unsupported flag is given.
func (r *RegexpMatches) compileRegex(ctx *sql.Context, row sql.Row) (*regexp.Regexp, error) {
	pattern, err := r.Pattern.Eval(ctx, row)
	if err != nil {
		return nil, err
	}

	if pattern == nil {
		return nil, nil
	}

	pattern, err = sql.Text.Convert(pattern)
	if err != nil {
		return nil, err
	}

	var prefix string
	if r.Flags != nil {
		f, err := r.Flags.Eval(ctx, row)
		if err != nil {
			return nil, err
		}

		if f == nil {
			return nil, nil
		}

		f, err = sql.Text.Convert(f)
		if err != nil {
			return nil, err
		}

		flags := f.(string)
		for _, flag := range flags {
			if !validRegexpFlags[flag] {
				// Convert the rune to a string so the message shows the
				// offending character rather than its numeric code point.
				return nil, errInvalidRegexpFlag.New(string(flag))
			}
		}

		prefix = fmt.Sprintf("(?%s)", flags)
	}

	return regexp.Compile(prefix + pattern.(string))
}

var (
	// errInvalidRegexpFlag is the error kind returned when the flags given
	// to regexp_matches contain a flag that is not in validRegexpFlags.
	errInvalidRegexpFlag = errors.NewKind("invalid regexp flag: %v")

	// validRegexpFlags is the set of regexp flags accepted by
	// regexp_matches. Only 'i' is listed for now.
	validRegexpFlags = map[rune]bool{
		'i': true,
	}
)

// canBeCached reports whether the given expression tree contains no
// *expression.GetField node, that is, whether it does not read any column
// of the row being evaluated.
func canBeCached(e sql.Expression) bool {
	cacheable := true
	expression.Inspect(e, func(node sql.Expression) bool {
		if _, isField := node.(*expression.GetField); isField {
			cacheable = false
		}
		// Always keep walking, the whole tree is visited regardless.
		return true
	})
	return cacheable
}
111 changes: 111 additions & 0 deletions sql/expression/function/regexp_matches_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package function

import (
"testing"

"github.com/src-d/go-mysql-server/sql"
"github.com/src-d/go-mysql-server/sql/expression"
"github.com/stretchr/testify/require"

errors "gopkg.in/src-d/go-errors.v1"
)

// TestRegexpMatches runs every case twice: once with literal arguments, so
// the compiled regexp can be cached, and once with the arguments read from
// the row through GetField expressions, so it cannot.
func TestRegexpMatches(t *testing.T) {
	type testCase struct {
		pattern  string
		text     string
		flags    string
		expected interface{}
		err      *errors.Kind
	}

	testCases := []testCase{
		{
			pattern:  `^foobar(.*)bye$`,
			text:     "foobarhellobye",
			expected: []interface{}{"foobarhellobye", "hello"},
		},
		{
			pattern:  "bop",
			text:     "bopbeepbop",
			expected: []interface{}{"bop", "bop"},
		},
		{
			pattern:  "bop",
			text:     "bopbeepBop",
			flags:    "i",
			expected: []interface{}{"bop", "Bop"},
		},
		{
			pattern: "bop",
			text:    "helloworld",
		},
		{
			pattern: "bop",
			text:    "bopbeepBop",
			flags:   "ix",
			err:     errInvalidRegexpFlag,
		},
	}

	// check evaluates f against row in a subtest named after the expression
	// and compares the outcome with what the test case expects.
	check := func(t *testing.T, f sql.Expression, row sql.Row, tc testCase) {
		t.Run(f.String(), func(t *testing.T) {
			require := require.New(t)
			result, err := f.Eval(sql.NewEmptyContext(), row)
			if tc.err != nil {
				require.Error(err)
				require.True(tc.err.Is(err))
				return
			}
			require.NoError(err)
			require.Equal(tc.expected, result)
		})
	}

	t.Run("cacheable", func(t *testing.T) {
		for _, tc := range testCases {
			// Flags stay nil when the case has none.
			var flags sql.Expression
			if tc.flags != "" {
				flags = expression.NewLiteral(tc.flags, sql.Text)
			}

			f, err := NewRegexpMatches(
				expression.NewLiteral(tc.text, sql.Text),
				expression.NewLiteral(tc.pattern, sql.Text),
				flags,
			)
			require.NoError(t, err)

			check(t, f, nil, tc)
		}
	})

	t.Run("not cacheable", func(t *testing.T) {
		for _, tc := range testCases {
			// Flags stay nil when the case has none.
			var flags sql.Expression
			if tc.flags != "" {
				flags = expression.NewGetField(2, sql.Text, "x", false)
			}

			f, err := NewRegexpMatches(
				expression.NewGetField(0, sql.Text, "x", false),
				expression.NewGetField(1, sql.Text, "x", false),
				flags,
			)
			require.NoError(t, err)

			check(t, f, sql.Row{tc.text, tc.pattern, tc.flags}, tc)
		}
	})
}
1 change: 1 addition & 0 deletions sql/expression/function/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,5 @@ var Defaults = []sql.Function{
sql.Function1{Name: "char_length", Fn: NewCharLength},
sql.Function1{Name: "character_length", Fn: NewCharLength},
sql.Function1{Name: "explode", Fn: NewExplode},
sql.FunctionN{Name: "regexp_matches", Fn: NewRegexpMatches},
}

0 comments on commit 18e2012

Please sign in to comment.