From 728f747279cee864592de0b9da788ff20324e099 Mon Sep 17 00:00:00 2001 From: Miguel Molina Date: Wed, 22 May 2019 14:48:51 +0200 Subject: [PATCH] sql: implement EXPLODE and generators This pull request implements the EXPLODE expressions and generators, which are used together to provide generation of rows based on a column or expression that can yield multiple values. For this, there have been some quite core changes: - SQL method of sql.Type now returns the sqltypes.Value and an error. This avoids panics in this method. Since generators can yield errors it's better to just return an error and not just kill the server. - Array type can now internally handle either arrays or generators, making it transparent for the user. If it's used in an EXPLODE expression, the generator will be used. Otherwise, it will be converted automatically to an array and used without the user even knowing what happened behind the scenes. That way, we don't need yet another type to represent generators. Aside from that, there are some new additions: - New EXPLODE function, which is just a placeholder. EXPLODE type is the underlying type of its argument for resolution purposes. During analysis Explode nodes will be replaced by Generate functions and a Generate node. - New Generate function, which is the non-placeholder version of EXPLODE and the one that goes into the final execution tree. This one returns the same type as its argument and let's the Generate plan node be the one that returns the underlying type once the values are generated. - Generate node, which wraps a Project in which there is one and only one explode expression. - resolve_generators analysis rule, which will turn projects with an Explode expression into a Generate node with the project as a children, replacing the Explode expressions with Generate expressions. - validate_explode_usage validation rule, which will ensure explode is not used outside a Project node. Signed-off-by: Miguel Molina --- engine_test.go | 60 +++++++++ server/handler.go | 17 ++- sql/analyzer/resolve_generators.go | 97 ++++++++++++++ sql/analyzer/resolve_generators_test.go | 117 +++++++++++++++++ sql/analyzer/rules.go | 1 + sql/analyzer/validation_rules.go | 32 +++++ sql/analyzer/validation_rules_test.go | 92 +++++++++++++ sql/expression/function/explode.go | 95 ++++++++++++++ sql/expression/function/registry.go | 1 + sql/generator.go | 57 ++++++++ sql/generator_test.go | 54 ++++++++ sql/plan/generate.go | 152 ++++++++++++++++++++++ sql/plan/generate_test.go | 88 +++++++++++++ sql/type.go | 166 ++++++++++++++++-------- sql/type_test.go | 47 +++++-- 15 files changed, 1010 insertions(+), 66 deletions(-) create mode 100644 sql/analyzer/resolve_generators.go create mode 100644 sql/analyzer/resolve_generators_test.go create mode 100644 sql/expression/function/explode.go create mode 100644 sql/generator.go create mode 100644 sql/generator_test.go create mode 100644 sql/plan/generate.go create mode 100644 sql/plan/generate_test.go diff --git a/engine_test.go b/engine_test.go index e7cc82846..1b0d39cdc 100644 --- a/engine_test.go +++ b/engine_test.go @@ -2459,6 +2459,66 @@ func TestDescribeNoPruneColumns(t *testing.T) { require.Len(p.Schema(), 3) } +var generatorQueries = []struct { + query string + expected []sql.Row +}{ + { + `SELECT a, EXPLODE(b), c FROM t`, + []sql.Row{ + {int64(1), "a", "first"}, + {int64(1), "b", "first"}, + {int64(2), "c", "second"}, + {int64(2), "d", "second"}, + {int64(3), "e", "third"}, + {int64(3), "f", "third"}, + }, + }, + { + `SELECT a, EXPLODE(b) AS x, c FROM t`, + []sql.Row{ + {int64(1), "a", "first"}, + {int64(1), "b", "first"}, + {int64(2), "c", "second"}, + {int64(2), "d", "second"}, + {int64(3), "e", "third"}, + {int64(3), "f", "third"}, + }, + }, + { + `SELECT a, EXPLODE(b) AS x, c FROM t WHERE x = 'e'`, + []sql.Row{ + {int64(3), "e", "third"}, + }, + }, +} + +func TestGenerators(t *testing.T) { + table := mem.NewPartitionedTable("t", sql.Schema{ + {Name: "a", Type: sql.Int64, Source: "t"}, + {Name: "b", Type: sql.Array(sql.Text), Source: "t"}, + {Name: "c", Type: sql.Text, Source: "t"}, + }, testNumPartitions) + + insertRows( + t, table, + sql.NewRow(int64(1), []interface{}{"a", "b"}, "first"), + sql.NewRow(int64(2), []interface{}{"c", "d"}, "second"), + sql.NewRow(int64(3), []interface{}{"e", "f"}, "third"), + ) + + db := mem.NewDatabase("db") + db.AddTable("t", table) + + catalog := sql.NewCatalog() + catalog.AddDatabase(db) + e := sqle.New(catalog, analyzer.NewDefault(catalog), new(sqle.Config)) + + for _, q := range generatorQueries { + testQuery(t, e, q.query, q.expected) + } +} + func insertRows(t *testing.T, table sql.Inserter, rows ...sql.Row) { t.Helper() diff --git a/server/handler.go b/server/handler.go index 34e2e70a3..06f048ab5 100644 --- a/server/handler.go +++ b/server/handler.go @@ -125,7 +125,12 @@ func (h *Handler) ComQuery( return err } - r.Rows = append(r.Rows, rowToSQL(schema, row)) + outputRow, err := rowToSQL(schema, row) + if err != nil { + return err + } + + r.Rows = append(r.Rows, outputRow) r.RowsAffected++ } @@ -203,13 +208,17 @@ func (h *Handler) handleKill(conn *mysql.Conn, query string) (bool, error) { return true, nil } -func rowToSQL(s sql.Schema, row sql.Row) []sqltypes.Value { +func rowToSQL(s sql.Schema, row sql.Row) ([]sqltypes.Value, error) { o := make([]sqltypes.Value, len(row)) + var err error for i, v := range row { - o[i] = s[i].Type.SQL(v) + o[i], err = s[i].Type.SQL(v) + if err != nil { + return nil, err + } } - return o + return o, nil } func schemaToFields(s sql.Schema) []*query.Field { diff --git a/sql/analyzer/resolve_generators.go b/sql/analyzer/resolve_generators.go new file mode 100644 index 000000000..437cf332a --- /dev/null +++ b/sql/analyzer/resolve_generators.go @@ -0,0 +1,97 @@ +package analyzer + +import ( + "gopkg.in/src-d/go-errors.v1" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/expression/function" + "github.com/src-d/go-mysql-server/sql/plan" +) + +var ( + errMultipleGenerators = errors.NewKind("there can't be more than 1 instance of EXPLODE in a SELECT") + errExplodeNotArray = errors.NewKind("argument of type %q given to EXPLODE, expecting array") +) + +func resolveGenerators(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + return n.TransformUp(func(n sql.Node) (sql.Node, error) { + p, ok := n.(*plan.Project) + if !ok { + return n, nil + } + + projection := p.Projections + + g, err := findGenerator(projection) + if err != nil { + return nil, err + } + + // There might be no generator in the project, in that case we don't + // have to do anything. + if g == nil { + return n, nil + } + + projection[g.idx] = g.expr + + var name string + if n, ok := g.expr.(sql.Nameable); ok { + name = n.Name() + } else { + name = g.expr.String() + } + + return plan.NewGenerate( + plan.NewProject(projection, p.Child), + expression.NewGetField(g.idx, g.expr.Type(), name, g.expr.IsNullable()), + ), nil + }) +} + +type generator struct { + idx int + expr sql.Expression +} + +// findGenerator will find in the given projection a generator column. If there +// is no generator, it will return nil. +// If there are is than one generator or the argument to explode is not an +// array it will fail. +// All occurrences of Explode will be replaced with Generate. +func findGenerator(exprs []sql.Expression) (*generator, error) { + var g = &generator{idx: -1} + for i, e := range exprs { + var found bool + switch e := e.(type) { + case *function.Explode: + found = true + g.expr = function.NewGenerate(e.Child) + case *expression.Alias: + if exp, ok := e.Child.(*function.Explode); ok { + found = true + g.expr = expression.NewAlias( + function.NewGenerate(exp.Child), + e.Name(), + ) + } + } + + if found { + if g.idx >= 0 { + return nil, errMultipleGenerators.New() + } + g.idx = i + + if !sql.IsArray(g.expr.Type()) { + return nil, errExplodeNotArray.New(g.expr.Type()) + } + } + } + + if g.expr == nil { + return nil, nil + } + + return g, nil +} diff --git a/sql/analyzer/resolve_generators_test.go b/sql/analyzer/resolve_generators_test.go new file mode 100644 index 000000000..0709e0d81 --- /dev/null +++ b/sql/analyzer/resolve_generators_test.go @@ -0,0 +1,117 @@ +package analyzer + +import ( + "testing" + + "github.com/stretchr/testify/require" + "gopkg.in/src-d/go-errors.v1" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/expression/function" + "github.com/src-d/go-mysql-server/sql/plan" +) + +func TestResolveGenerators(t *testing.T) { + testCases := []struct { + name string + node sql.Node + expected sql.Node + err *errors.Kind + }{ + { + name: "regular explode", + node: plan.NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Int64, "a", false), + function.NewExplode(expression.NewGetField(1, sql.Array(sql.Int64), "b", false)), + expression.NewGetField(2, sql.Int64, "c", false), + }, + plan.NewUnresolvedTable("foo", ""), + ), + expected: plan.NewGenerate( + plan.NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Int64, "a", false), + function.NewGenerate(expression.NewGetField(1, sql.Array(sql.Int64), "b", false)), + expression.NewGetField(2, sql.Int64, "c", false), + }, + plan.NewUnresolvedTable("foo", ""), + ), + expression.NewGetField(1, sql.Array(sql.Int64), "EXPLODE(b)", false), + ), + err: nil, + }, + { + name: "explode with alias", + node: plan.NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Int64, "a", false), + expression.NewAlias( + function.NewExplode( + expression.NewGetField(1, sql.Array(sql.Int64), "b", false), + ), + "x", + ), + expression.NewGetField(2, sql.Int64, "c", false), + }, + plan.NewUnresolvedTable("foo", ""), + ), + expected: plan.NewGenerate( + plan.NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Int64, "a", false), + expression.NewAlias( + function.NewGenerate( + expression.NewGetField(1, sql.Array(sql.Int64), "b", false), + ), + "x", + ), + expression.NewGetField(2, sql.Int64, "c", false), + }, + plan.NewUnresolvedTable("foo", ""), + ), + expression.NewGetField(1, sql.Array(sql.Int64), "x", false), + ), + err: nil, + }, + { + name: "non array type on explode", + node: plan.NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Int64, "a", false), + function.NewExplode(expression.NewGetField(1, sql.Int64, "b", false)), + }, + plan.NewUnresolvedTable("foo", ""), + ), + expected: nil, + err: errExplodeNotArray, + }, + { + name: "more than one generator", + node: plan.NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Int64, "a", false), + function.NewExplode(expression.NewGetField(1, sql.Array(sql.Int64), "b", false)), + function.NewExplode(expression.NewGetField(2, sql.Array(sql.Int64), "c", false)), + }, + plan.NewUnresolvedTable("foo", ""), + ), + expected: nil, + err: errMultipleGenerators, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + result, err := resolveGenerators(sql.NewEmptyContext(), nil, tt.node) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.Equal(tt.expected, result) + } + }) + } +} diff --git a/sql/analyzer/rules.go b/sql/analyzer/rules.go index afc1152bc..c2b8daf00 100644 --- a/sql/analyzer/rules.go +++ b/sql/analyzer/rules.go @@ -34,6 +34,7 @@ var OnceBeforeDefault = []Rule{ // OnceAfterDefault contains the rules to be applied just once after the // DefaultRules. var OnceAfterDefault = []Rule{ + {"resolve_generators", resolveGenerators}, {"remove_unnecessary_converts", removeUnnecessaryConverts}, {"assign_catalog", assignCatalog}, {"prune_columns", pruneColumns}, diff --git a/sql/analyzer/validation_rules.go b/sql/analyzer/validation_rules.go index 3d4dae59f..06d1f13b6 100644 --- a/sql/analyzer/validation_rules.go +++ b/sql/analyzer/validation_rules.go @@ -19,6 +19,7 @@ const ( validateIndexCreationRule = "validate_index_creation" validateCaseResultTypesRule = "validate_case_result_types" validateIntervalUsageRule = "validate_interval_usage" + validateExplodeUsageRule = "validate_explode_usage" ) var ( @@ -51,6 +52,11 @@ var ( "invalid use of an interval, which can only be used with DATE_ADD, " + "DATE_SUB and +/- operators to subtract from or add to a date", ) + // ErrExplodeInvalidUse is returned when an EXPLODE function is used + // outside a Project node. + ErrExplodeInvalidUse = errors.NewKind( + "using EXPLODE is not supported outside a Project node", + ) ) // DefaultValidationRules to apply while analyzing nodes. @@ -63,6 +69,7 @@ var DefaultValidationRules = []Rule{ {validateIndexCreationRule, validateIndexCreation}, {validateCaseResultTypesRule, validateCaseResultTypes}, {validateIntervalUsageRule, validateIntervalUsage}, + {validateExplodeUsageRule, validateExplodeUsage}, } func validateIsResolved(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { @@ -290,6 +297,31 @@ func validateIntervalUsage(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, return n, nil } +func validateExplodeUsage(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + var invalid bool + plan.InspectExpressions(n, func(e sql.Expression) bool { + // If it's already invalid just skip everything else. + if invalid { + return false + } + + // All usage of Explode will be incorrect because the ones in projects + // would have already been converted to Generate, so we only have to + // look for those. + if _, ok := e.(*function.Explode); ok { + invalid = true + } + + return true + }) + + if invalid { + return nil, ErrExplodeInvalidUse.New() + } + + return n, nil +} + func stringContains(strs []string, target string) bool { for _, s := range strs { if s == target { diff --git a/sql/analyzer/validation_rules_test.go b/sql/analyzer/validation_rules_test.go index 2d915deff..7bb83d1bb 100644 --- a/sql/analyzer/validation_rules_test.go +++ b/sql/analyzer/validation_rules_test.go @@ -582,6 +582,98 @@ func TestValidateIntervalUsage(t *testing.T) { } } +func TestValidateExplodeUsage(t *testing.T) { + testCases := []struct { + name string + node sql.Node + ok bool + }{ + { + "valid", + plan.NewGenerate( + plan.NewProject( + []sql.Expression{ + expression.NewAlias( + function.NewGenerate( + expression.NewGetField(0, sql.Array(sql.Int64), "f", false), + ), + "foo", + ), + }, + plan.NewUnresolvedTable("dual", ""), + ), + expression.NewGetField(0, sql.Array(sql.Int64), "foo", false), + ), + true, + }, + { + "where", + plan.NewFilter( + function.NewArrayLength( + function.NewExplode( + expression.NewGetField(0, sql.Array(sql.Int64), "foo", false), + ), + ), + plan.NewGenerate( + plan.NewProject( + []sql.Expression{ + expression.NewAlias( + function.NewGenerate( + expression.NewGetField(0, sql.Array(sql.Int64), "f", false), + ), + "foo", + ), + }, + plan.NewUnresolvedTable("dual", ""), + ), + expression.NewGetField(0, sql.Array(sql.Int64), "foo", false), + ), + ), + false, + }, + { + "group by", + plan.NewGenerate( + plan.NewGroupBy( + []sql.Expression{ + expression.NewAlias( + function.NewExplode( + expression.NewGetField(0, sql.Array(sql.Int64), "f", false), + ), + "foo", + ), + }, + []sql.Expression{ + expression.NewAlias( + function.NewExplode( + expression.NewGetField(0, sql.Array(sql.Int64), "f", false), + ), + "foo", + ), + }, + plan.NewUnresolvedTable("dual", ""), + ), + expression.NewGetField(0, sql.Array(sql.Int64), "foo", false), + ), + false, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + _, err := validateExplodeUsage(sql.NewEmptyContext(), nil, tt.node) + if tt.ok { + require.NoError(err) + } else { + require.Error(err) + require.True(ErrExplodeInvalidUse.Is(err)) + } + }) + } +} + type dummyNode struct{ resolved bool } func (n dummyNode) String() string { return "dummynode" } diff --git a/sql/expression/function/explode.go b/sql/expression/function/explode.go new file mode 100644 index 000000000..5c96f509a --- /dev/null +++ b/sql/expression/function/explode.go @@ -0,0 +1,95 @@ +package function + +import ( + "fmt" + + "github.com/src-d/go-mysql-server/sql" +) + +// Explode is a function that generates a row for each value of its child. +// It is a placeholder expression node. +type Explode struct { + Child sql.Expression +} + +// NewExplode creates a new Explode function. +func NewExplode(child sql.Expression) sql.Expression { + return &Explode{child} +} + +// Resolved implements the sql.Expression interface. +func (e *Explode) Resolved() bool { return e.Child.Resolved() } + +// Children implements the sql.Expression interface. +func (e *Explode) Children() []sql.Expression { return []sql.Expression{e.Child} } + +// IsNullable implements the sql.Expression interface. +func (e *Explode) IsNullable() bool { return e.Child.IsNullable() } + +// Type implements the sql.Expression interface. +func (e *Explode) Type() sql.Type { + return sql.UnderlyingType(e.Child.Type()) +} + +// Eval implements the sql.Expression interface. +func (e *Explode) Eval(*sql.Context, sql.Row) (interface{}, error) { + panic("eval method of Explode is only a placeholder") +} + +func (e *Explode) String() string { + return fmt.Sprintf("EXPLODE(%s)", e.Child) +} + +// TransformUp implements the sql.Expression interface. +func (e *Explode) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { + c, err := f(e.Child) + if err != nil { + return nil, err + } + + return f(NewExplode(c)) +} + +// Generate is a function that generates a row for each value of its child. +// This is the non-placeholder counterpart of Explode. +type Generate struct { + Child sql.Expression +} + +// NewGenerate creates a new Generate function. +func NewGenerate(child sql.Expression) sql.Expression { + return &Generate{child} +} + +// Resolved implements the sql.Expression interface. +func (e *Generate) Resolved() bool { return e.Child.Resolved() } + +// Children implements the sql.Expression interface. +func (e *Generate) Children() []sql.Expression { return []sql.Expression{e.Child} } + +// IsNullable implements the sql.Expression interface. +func (e *Generate) IsNullable() bool { return e.Child.IsNullable() } + +// Type implements the sql.Expression interface. +func (e *Generate) Type() sql.Type { + return e.Child.Type() +} + +// Eval implements the sql.Expression interface. +func (e *Generate) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + return e.Child.Eval(ctx, row) +} + +func (e *Generate) String() string { + return fmt.Sprintf("EXPLODE(%s)", e.Child) +} + +// TransformUp implements the sql.Expression interface. +func (e *Generate) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { + c, err := f(e.Child) + if err != nil { + return nil, err + } + + return f(NewGenerate(c)) +} diff --git a/sql/expression/function/registry.go b/sql/expression/function/registry.go index ffa93a317..c1712ee5b 100644 --- a/sql/expression/function/registry.go +++ b/sql/expression/function/registry.go @@ -88,4 +88,5 @@ var Defaults = []sql.Function{ sql.Function1{Name: "length", Fn: NewLength}, sql.Function1{Name: "char_length", Fn: NewCharLength}, sql.Function1{Name: "character_length", Fn: NewCharLength}, + sql.Function1{Name: "explode", Fn: NewExplode}, } diff --git a/sql/generator.go b/sql/generator.go new file mode 100644 index 000000000..218b1db66 --- /dev/null +++ b/sql/generator.go @@ -0,0 +1,57 @@ +package sql + +import ( + "io" + + "gopkg.in/src-d/go-errors.v1" +) + +// Generator will generate a set of values for a given row. +type Generator interface { + // Next value in the generator. + Next() (interface{}, error) + // Close the generator and dispose resources. + Close() error +} + +// ErrNotGenerator is returned when the value cannot be converted to a +// generator. +var ErrNotGenerator = errors.NewKind("cannot convert value of type %T to a generator") + +// ToGenerator converts a value to a generator if possible. +func ToGenerator(v interface{}) (Generator, error) { + switch v := v.(type) { + case Generator: + return v, nil + case []interface{}: + return NewArrayGenerator(v), nil + case nil: + return NewArrayGenerator(nil), nil + default: + return nil, ErrNotGenerator.New(v) + } +} + +// NewArrayGenerator creates a generator for a given array. +func NewArrayGenerator(array []interface{}) Generator { + return &arrayGenerator{array, 0} +} + +type arrayGenerator struct { + array []interface{} + pos int +} + +func (g *arrayGenerator) Next() (interface{}, error) { + if g.pos >= len(g.array) { + return nil, io.EOF + } + + g.pos++ + return g.array[g.pos-1], nil +} + +func (g *arrayGenerator) Close() error { + g.pos = len(g.array) + return nil +} diff --git a/sql/generator_test.go b/sql/generator_test.go new file mode 100644 index 000000000..145411333 --- /dev/null +++ b/sql/generator_test.go @@ -0,0 +1,54 @@ +package sql + +import ( + "fmt" + "io" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestArrayGenerator(t *testing.T) { + require := require.New(t) + + expected := []interface{}{"a", "b", "c"} + gen := NewArrayGenerator(expected) + + var values []interface{} + for { + v, err := gen.Next() + if err != nil { + if err == io.EOF { + break + } + require.NoError(err) + } + values = append(values, v) + } + + require.Equal(expected, values) +} + +func TestToGenerator(t *testing.T) { + require := require.New(t) + + gen, err := ToGenerator([]interface{}{1, 2, 3}) + require.NoError(err) + require.Equal(NewArrayGenerator([]interface{}{1, 2, 3}), gen) + + gen, err = ToGenerator(new(fakeGen)) + require.NoError(err) + require.Equal(new(fakeGen), gen) + + gen, err = ToGenerator(nil) + require.NoError(err) + require.Equal(NewArrayGenerator(nil), gen) + + _, err = ToGenerator("foo") + require.Error(err) +} + +type fakeGen struct{} + +func (fakeGen) Next() (interface{}, error) { return nil, fmt.Errorf("not implemented") } +func (fakeGen) Close() error { return nil } diff --git a/sql/plan/generate.go b/sql/plan/generate.go new file mode 100644 index 000000000..841a45890 --- /dev/null +++ b/sql/plan/generate.go @@ -0,0 +1,152 @@ +package plan + +import ( + "fmt" + "io" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +// Generate will explode rows using a generator. +type Generate struct { + UnaryNode + Column *expression.GetField +} + +// NewGenerate creates a new generate node. +func NewGenerate(child sql.Node, col *expression.GetField) *Generate { + return &Generate{UnaryNode{child}, col} +} + +// Schema implements the sql.Node interface. +func (g *Generate) Schema() sql.Schema { + s := g.Child.Schema() + col := s[g.Column.Index()] + s[g.Column.Index()] = &sql.Column{ + Name: g.Column.Name(), + Type: sql.UnderlyingType(col.Type), + Nullable: col.Nullable, + } + return s +} + +// RowIter implements the sql.Node interface. +func (g *Generate) RowIter(ctx *sql.Context) (sql.RowIter, error) { + span, ctx := ctx.Span("plan.Generate") + + childIter, err := g.Child.RowIter(ctx) + if err != nil { + return nil, err + } + + return sql.NewSpanIter(span, &generateIter{ + child: childIter, + idx: g.Column.Index(), + }), nil +} + +func (g *Generate) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { + col, err := g.Column.TransformUp(f) + if err != nil { + return nil, err + } + + field, ok := col.(*expression.GetField) + if !ok { + return nil, fmt.Errorf("column of Generate node transformed into %T, must be GetField", col) + } + + return NewGenerate(g.Child, field), nil +} + +// TransformUp implements the sql.Node interface. +func (g *Generate) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { + child, err := g.Child.TransformUp(f) + if err != nil { + return nil, err + } + + return f(NewGenerate(child, g.Column)) +} + +// TransformExpressionsUp implements the sql.Node interface. +func (g *Generate) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { + child, err := g.Child.TransformExpressionsUp(f) + if err != nil { + return nil, err + } + + col, err := g.Column.TransformUp(f) + if err != nil { + return nil, err + } + + field, ok := col.(*expression.GetField) + if !ok { + return nil, fmt.Errorf("column of Generate node transformed into %T, must be GetField", col) + } + + return NewGenerate(child, field), nil +} + +func (g *Generate) String() string { + tp := sql.NewTreePrinter() + _ = tp.WriteNode("Generate(%s)", g.Column) + _ = tp.WriteChildren(g.Child.String()) + return tp.String() +} + +type generateIter struct { + child sql.RowIter + idx int + + gen sql.Generator + row sql.Row +} + +func (i *generateIter) Next() (sql.Row, error) { + for { + if i.gen == nil { + var err error + i.row, err = i.child.Next() + if err != nil { + return nil, err + } + + i.gen, err = sql.ToGenerator(i.row[i.idx]) + if err != nil { + return nil, err + } + } + + val, err := i.gen.Next() + if err != nil { + if err == io.EOF { + if err := i.gen.Close(); err != nil { + return nil, err + } + + i.gen = nil + continue + } + return nil, err + } + + var row = make(sql.Row, len(i.row)) + copy(row, i.row) + row[i.idx] = val + return row, nil + } +} + +func (i *generateIter) Close() error { + if i.gen != nil { + if err := i.gen.Close(); err != nil { + _ = i.child.Close() + return err + } + } + + return i.child.Close() +} diff --git a/sql/plan/generate_test.go b/sql/plan/generate_test.go new file mode 100644 index 000000000..ba35f33cf --- /dev/null +++ b/sql/plan/generate_test.go @@ -0,0 +1,88 @@ +package plan + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +func TestGenerateRowIter(t *testing.T) { + require := require.New(t) + + child := newFakeNode( + sql.Schema{ + {Name: "a", Type: sql.Text, Source: "foo"}, + {Name: "b", Type: sql.Array(sql.Text), Source: "foo"}, + {Name: "c", Type: sql.Int64, Source: "foo"}, + }, + sql.RowsToRowIter( + sql.Row{"first", sql.NewArrayGenerator([]interface{}{"a", "b"}), int64(1)}, + sql.Row{"second", sql.NewArrayGenerator([]interface{}{"c", "d"}), int64(2)}, + ), + ) + + iter, err := NewGenerate( + child, + expression.NewGetFieldWithTable(1, sql.Array(sql.Text), "foo", "b", false), + ).RowIter(sql.NewEmptyContext()) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + expected := []sql.Row{ + {"first", "a", int64(1)}, + {"first", "b", int64(1)}, + {"second", "c", int64(2)}, + {"second", "d", int64(2)}, + } + + require.Equal(expected, rows) +} + +func TestGenerateSchema(t *testing.T) { + require := require.New(t) + + schema := NewGenerate( + newFakeNode( + sql.Schema{ + {Name: "a", Type: sql.Text, Source: "foo"}, + {Name: "b", Type: sql.Array(sql.Text), Source: "foo"}, + {Name: "c", Type: sql.Int64, Source: "foo"}, + }, + nil, + ), + expression.NewGetField(1, sql.Array(sql.Text), "foobar", false), + ).Schema() + + expected := sql.Schema{ + {Name: "a", Type: sql.Text, Source: "foo"}, + {Name: "foobar", Type: sql.Text}, + {Name: "c", Type: sql.Int64, Source: "foo"}, + } + + require.Equal(expected, schema) +} + +type fakeNode struct { + schema sql.Schema + iter sql.RowIter +} + +func newFakeNode(s sql.Schema, iter sql.RowIter) *fakeNode { + return &fakeNode{s, iter} +} + +func (n *fakeNode) Children() []sql.Node { return nil } +func (n *fakeNode) Resolved() bool { return true } +func (n *fakeNode) Schema() sql.Schema { return n.schema } +func (n *fakeNode) RowIter(*sql.Context) (sql.RowIter, error) { return n.iter, nil } +func (n *fakeNode) String() string { return "fakeNode" } +func (n *fakeNode) TransformUp(sql.TransformNodeFunc) (sql.Node, error) { + panic("placeholder") +} +func (n *fakeNode) TransformExpressionsUp(sql.TransformExprFunc) (sql.Node, error) { + panic("placeholder") +} diff --git a/sql/type.go b/sql/type.go index 341ab9506..bea1cf710 100644 --- a/sql/type.go +++ b/sql/type.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "io" "math" "reflect" "strconv" @@ -148,7 +149,7 @@ type Type interface { // The result will be 0 if a==b, -1 if a < b, and +1 if a > b. Compare(interface{}, interface{}) (int, error) // SQL returns the sqltypes.Value for the given value. - SQL(interface{}) sqltypes.Value + SQL(interface{}) (sqltypes.Value, error) fmt.Stringer } @@ -266,8 +267,8 @@ func (t nullT) Type() query.Type { } // SQL implements Type interface. -func (t nullT) SQL(interface{}) sqltypes.Value { - return sqltypes.NULL +func (t nullT) SQL(interface{}) (sqltypes.Value, error) { + return sqltypes.NULL, nil } // Convert implements Type interface. @@ -300,26 +301,26 @@ func (t numberT) Type() query.Type { } // SQL implements Type interface. -func (t numberT) SQL(v interface{}) sqltypes.Value { +func (t numberT) SQL(v interface{}) (sqltypes.Value, error) { if _, ok := v.(nullT); ok { - return sqltypes.NULL + return sqltypes.NULL, nil } switch t.t { case sqltypes.Int32: - return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)) + return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)), nil case sqltypes.Int64: - return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)) + return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)), nil case sqltypes.Uint32: - return sqltypes.MakeTrusted(t.t, strconv.AppendUint(nil, cast.ToUint64(v), 10)) + return sqltypes.MakeTrusted(t.t, strconv.AppendUint(nil, cast.ToUint64(v), 10)), nil case sqltypes.Uint64: - return sqltypes.MakeTrusted(t.t, strconv.AppendUint(nil, cast.ToUint64(v), 10)) + return sqltypes.MakeTrusted(t.t, strconv.AppendUint(nil, cast.ToUint64(v), 10)), nil case sqltypes.Float32: - return sqltypes.MakeTrusted(t.t, strconv.AppendFloat(nil, cast.ToFloat64(v), 'f', -1, 64)) + return sqltypes.MakeTrusted(t.t, strconv.AppendFloat(nil, cast.ToFloat64(v), 'f', -1, 64)), nil case sqltypes.Float64: - return sqltypes.MakeTrusted(t.t, strconv.AppendFloat(nil, cast.ToFloat64(v), 'f', -1, 64)) + return sqltypes.MakeTrusted(t.t, strconv.AppendFloat(nil, cast.ToFloat64(v), 'f', -1, 64)), nil default: - return sqltypes.MakeTrusted(t.t, []byte{}) + return sqltypes.MakeTrusted(t.t, []byte{}), nil } } @@ -426,16 +427,20 @@ var TimestampLayouts = []string{ } // SQL implements Type interface. -func (t timestampT) SQL(v interface{}) sqltypes.Value { +func (t timestampT) SQL(v interface{}) (sqltypes.Value, error) { if _, ok := v.(nullT); ok { - return sqltypes.NULL + return sqltypes.NULL, nil + } + + v, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err } - time := MustConvert(t, v).(time.Time) return sqltypes.MakeTrusted( sqltypes.Timestamp, - []byte(time.Format(TimestampLayout)), - ) + []byte(v.(time.Time).Format(TimestampLayout)), + ), nil } // Convert implements Type interface. @@ -498,16 +503,20 @@ func (t dateT) Type() query.Type { return sqltypes.Date } -func (t dateT) SQL(v interface{}) sqltypes.Value { +func (t dateT) SQL(v interface{}) (sqltypes.Value, error) { if _, ok := v.(nullT); ok { - return sqltypes.NULL + return sqltypes.NULL, nil + } + + v, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err } - time := MustConvert(t, v).(time.Time) return sqltypes.MakeTrusted( sqltypes.Timestamp, - []byte(time.Format(DateLayout)), - ) + []byte(v.(time.Time).Format(DateLayout)), + ), nil } func (t dateT) Convert(v interface{}) (interface{}, error) { @@ -551,12 +560,17 @@ func (t textT) Type() query.Type { } // SQL implements Type interface. -func (t textT) SQL(v interface{}) sqltypes.Value { +func (t textT) SQL(v interface{}) (sqltypes.Value, error) { if _, ok := v.(nullT); ok { - return sqltypes.NULL + return sqltypes.NULL, nil } - return sqltypes.MakeTrusted(sqltypes.Text, []byte(MustConvert(t, v).(string))) + v, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err + } + + return sqltypes.MakeTrusted(sqltypes.Text, []byte(v.(string))), nil } // Convert implements Type interface. @@ -583,9 +597,9 @@ func (t booleanT) Type() query.Type { } // SQL implements Type interface. -func (t booleanT) SQL(v interface{}) sqltypes.Value { +func (t booleanT) SQL(v interface{}) (sqltypes.Value, error) { if _, ok := v.(nullT); ok { - return sqltypes.NULL + return sqltypes.NULL, nil } b := []byte{'0'} @@ -593,7 +607,7 @@ func (t booleanT) SQL(v interface{}) sqltypes.Value { b[0] = '1' } - return sqltypes.MakeTrusted(sqltypes.Bit, b) + return sqltypes.MakeTrusted(sqltypes.Bit, b), nil } // Convert implements Type interface. @@ -655,12 +669,17 @@ func (t blobT) Type() query.Type { } // SQL implements Type interface. -func (t blobT) SQL(v interface{}) sqltypes.Value { +func (t blobT) SQL(v interface{}) (sqltypes.Value, error) { if _, ok := v.(nullT); ok { - return sqltypes.NULL + return sqltypes.NULL, nil + } + + v, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err } - return sqltypes.MakeTrusted(sqltypes.Blob, MustConvert(t, v).([]byte)) + return sqltypes.MakeTrusted(sqltypes.Blob, v.([]byte)), nil } // Convert implements Type interface. @@ -694,11 +713,17 @@ func (t jsonT) Type() query.Type { } // SQL implements Type interface. -func (t jsonT) SQL(v interface{}) sqltypes.Value { +func (t jsonT) SQL(v interface{}) (sqltypes.Value, error) { if _, ok := v.(nullT); ok { - return sqltypes.NULL + return sqltypes.NULL, nil } - return sqltypes.MakeTrusted(sqltypes.TypeJSON, MustConvert(t, v).([]byte)) + + v, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err + } + + return sqltypes.MakeTrusted(sqltypes.TypeJSON, v.([]byte)), nil } // Convert implements Type interface. @@ -734,12 +759,12 @@ func (t tupleT) Type() query.Type { return sqltypes.Expression } -func (t tupleT) SQL(v interface{}) sqltypes.Value { +func (t tupleT) SQL(v interface{}) (sqltypes.Value, error) { if _, ok := v.(nullT); ok { - return sqltypes.NULL + return sqltypes.NULL, nil } - panic("unable to convert tuple type to SQL") + return sqltypes.Value{}, fmt.Errorf("unable to convert tuple type to SQL") } func (t tupleT) Convert(v interface{}) (interface{}, error) { @@ -799,24 +824,58 @@ func (t arrayT) Type() query.Type { return sqltypes.TypeJSON } -func (t arrayT) SQL(v interface{}) sqltypes.Value { +func (t arrayT) SQL(v interface{}) (sqltypes.Value, error) { + if _, ok := v.(nullT); ok { + return sqltypes.NULL, nil + } + + v, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err + } + return JSON.SQL(v) } func (t arrayT) Convert(v interface{}) (interface{}, error) { - if vals, ok := v.([]interface{}); ok { - var result = make([]interface{}, len(vals)) - for i, v := range vals { + switch v := v.(type) { + case []interface{}: + var result = make([]interface{}, len(v)) + for i, v := range v { var err error result[i], err = t.underlying.Convert(v) if err != nil { return nil, err } } - return result, nil + case Generator: + var values []interface{} + for { + val, err := v.Next() + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + + val, err = t.underlying.Convert(val) + if err != nil { + return nil, err + } + + values = append(values, val) + } + + if err := v.Close(); err != nil { + return nil, err + } + + return values, nil + default: + return nil, ErrNotArray.New(v) } - return nil, ErrNotArray.New(v) } func (t arrayT) Compare(a, b interface{}) (int, error) { @@ -853,16 +912,6 @@ func (t arrayT) Compare(a, b interface{}) (int, error) { return 0, nil } -// MustConvert calls the Convert function from a given Type, it err panics. -func MustConvert(t Type, v interface{}) interface{} { - c, err := t.Convert(v) - if err != nil { - panic(err) - } - - return c -} - // IsNumber checks if t is a number type func IsNumber(t Type) bool { return IsInteger(t) || IsDecimal(t) @@ -961,3 +1010,14 @@ func MySQLTypeName(t Type) string { return "UNKNOWN" } } + +// UnderlyingType returns the underlying type of an array if the type is an +// array, or the type itself in any other case. +func UnderlyingType(t Type) Type { + a, ok := t.(arrayT) + if !ok { + return t + } + + return a.underlying +} diff --git a/sql/type_test.go b/sql/type_test.go index 0b3eaa629..fb8de7f3b 100644 --- a/sql/type_test.go +++ b/sql/type_test.go @@ -13,8 +13,8 @@ func TestIsNull(t *testing.T) { require.True(t, IsNull(nil)) n := numberT{sqltypes.Uint64} - require.Equal(t, sqltypes.NULL, n.SQL(Null)) - require.Equal(t, sqltypes.NewUint64(0), n.SQL(uint64(0))) + require.Equal(t, sqltypes.NULL, mustSQL(n.SQL(Null))) + require.Equal(t, sqltypes.NewUint64(0), mustSQL(n.SQL(uint64(0)))) } func TestText(t *testing.T) { @@ -70,7 +70,8 @@ func TestFloat64(t *testing.T) { var f = numberT{ t: query.Type_FLOAT64, } - val := f.SQL(23.222) + val, err := f.SQL(23.222) + require.NoError(err) require.True(val.IsFloat()) require.Equal(sqltypes.NewFloat64(23.222), val) } @@ -97,7 +98,8 @@ func TestTimestamp(t *testing.T) { v.(time.Time).Format(TimestampLayout), ) - sql := Timestamp.SQL(now) + sql, err := Timestamp.SQL(now) + require.NoError(err) require.Equal([]byte(now.Format(TimestampLayout)), sql.Raw()) after := now.Add(time.Second) @@ -167,7 +169,8 @@ func TestDate(t *testing.T) { v.(time.Time).Format(DateLayout), ) - sql := Date.SQL(now) + sql, err := Date.SQL(now) + require.NoError(err) require.Equal([]byte(now.Format(DateLayout)), sql.Raw()) after := now.Add(time.Second) @@ -186,7 +189,6 @@ func TestBlob(t *testing.T) { convert(t, Blob, "", []byte{}) convert(t, Blob, nil, []byte(nil)) - MustConvert(Blob, nil) _, err := Blob.Convert(1) require.NotNil(err) @@ -221,9 +223,8 @@ func TestTuple(t *testing.T) { convert(t, typ, []interface{}{1, 2, 3}, []interface{}{int32(1), "2", int64(3)}) - require.Panics(func() { - typ.SQL(nil) - }) + _, err = typ.SQL(nil) + require.Error(err) require.Equal(sqltypes.Expression, typ.Type()) @@ -245,6 +246,12 @@ func TestArray(t *testing.T) { require.True(ErrNotArray.Is(err)) convert(t, typ, []interface{}{1, 2, 3}, []interface{}{int64(1), int64(2), int64(3)}) + convert( + t, + typ, + NewArrayGenerator([]interface{}{1, 2, 3}), + []interface{}{int64(1), int64(2), int64(3)}, + ) require.Equal(sqltypes.TypeJSON, typ.Type()) @@ -257,6 +264,21 @@ func TestArray(t *testing.T) { gt(t, typ, []interface{}{1, 3, 3}, []interface{}{1, 2, 3}) gt(t, typ, []interface{}{1, 2, 4}, []interface{}{1, 2, 3}) gt(t, typ, []interface{}{1, 2, 4}, []interface{}{5, 6}) + + expected := []byte("[1,2,3]") + + v, err := Array(Int64).SQL([]interface{}{1, 2, 3}) + require.NoError(err) + require.Equal(expected, v.Raw()) + + v, err = Array(Int64).SQL(NewArrayGenerator([]interface{}{1, 2, 3})) + require.NoError(err) + require.Equal(expected, v.Raw()) +} + +func TestUnderlyingType(t *testing.T) { + require.Equal(t, Text, UnderlyingType(Array(Text))) + require.Equal(t, Text, UnderlyingType(Text)) } func eq(t *testing.T, typ Type, a, b interface{}) { @@ -292,3 +314,10 @@ func convertErr(t *testing.T, typ Type, val interface{}) { _, err := typ.Convert(val) require.Error(t, err) } + +func mustSQL(v sqltypes.Value, err error) sqltypes.Value { + if err != nil { + panic(err) + } + return v +}