diff --git a/engine_test.go b/engine_test.go index e7cc82846..1b0d39cdc 100644 --- a/engine_test.go +++ b/engine_test.go @@ -2459,6 +2459,66 @@ func TestDescribeNoPruneColumns(t *testing.T) { require.Len(p.Schema(), 3) } +var generatorQueries = []struct { + query string + expected []sql.Row +}{ + { + `SELECT a, EXPLODE(b), c FROM t`, + []sql.Row{ + {int64(1), "a", "first"}, + {int64(1), "b", "first"}, + {int64(2), "c", "second"}, + {int64(2), "d", "second"}, + {int64(3), "e", "third"}, + {int64(3), "f", "third"}, + }, + }, + { + `SELECT a, EXPLODE(b) AS x, c FROM t`, + []sql.Row{ + {int64(1), "a", "first"}, + {int64(1), "b", "first"}, + {int64(2), "c", "second"}, + {int64(2), "d", "second"}, + {int64(3), "e", "third"}, + {int64(3), "f", "third"}, + }, + }, + { + `SELECT a, EXPLODE(b) AS x, c FROM t WHERE x = 'e'`, + []sql.Row{ + {int64(3), "e", "third"}, + }, + }, +} + +func TestGenerators(t *testing.T) { + table := mem.NewPartitionedTable("t", sql.Schema{ + {Name: "a", Type: sql.Int64, Source: "t"}, + {Name: "b", Type: sql.Array(sql.Text), Source: "t"}, + {Name: "c", Type: sql.Text, Source: "t"}, + }, testNumPartitions) + + insertRows( + t, table, + sql.NewRow(int64(1), []interface{}{"a", "b"}, "first"), + sql.NewRow(int64(2), []interface{}{"c", "d"}, "second"), + sql.NewRow(int64(3), []interface{}{"e", "f"}, "third"), + ) + + db := mem.NewDatabase("db") + db.AddTable("t", table) + + catalog := sql.NewCatalog() + catalog.AddDatabase(db) + e := sqle.New(catalog, analyzer.NewDefault(catalog), new(sqle.Config)) + + for _, q := range generatorQueries { + testQuery(t, e, q.query, q.expected) + } +} + func insertRows(t *testing.T, table sql.Inserter, rows ...sql.Row) { t.Helper() diff --git a/server/handler.go b/server/handler.go index 34e2e70a3..06f048ab5 100644 --- a/server/handler.go +++ b/server/handler.go @@ -125,7 +125,12 @@ func (h *Handler) ComQuery( return err } - r.Rows = append(r.Rows, rowToSQL(schema, row)) + outputRow, err := rowToSQL(schema, row) + if err != nil { + return err + } + + r.Rows = append(r.Rows, outputRow) r.RowsAffected++ } @@ -203,13 +208,17 @@ func (h *Handler) handleKill(conn *mysql.Conn, query string) (bool, error) { return true, nil } -func rowToSQL(s sql.Schema, row sql.Row) []sqltypes.Value { +func rowToSQL(s sql.Schema, row sql.Row) ([]sqltypes.Value, error) { o := make([]sqltypes.Value, len(row)) + var err error for i, v := range row { - o[i] = s[i].Type.SQL(v) + o[i], err = s[i].Type.SQL(v) + if err != nil { + return nil, err + } } - return o + return o, nil } func schemaToFields(s sql.Schema) []*query.Field { diff --git a/sql/analyzer/resolve_generators.go b/sql/analyzer/resolve_generators.go new file mode 100644 index 000000000..437cf332a --- /dev/null +++ b/sql/analyzer/resolve_generators.go @@ -0,0 +1,97 @@ +package analyzer + +import ( + "gopkg.in/src-d/go-errors.v1" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/expression/function" + "github.com/src-d/go-mysql-server/sql/plan" +) + +var ( + errMultipleGenerators = errors.NewKind("there can't be more than 1 instance of EXPLODE in a SELECT") + errExplodeNotArray = errors.NewKind("argument of type %q given to EXPLODE, expecting array") +) + +func resolveGenerators(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + return n.TransformUp(func(n sql.Node) (sql.Node, error) { + p, ok := n.(*plan.Project) + if !ok { + return n, nil + } + + projection := p.Projections + + g, err := findGenerator(projection) + if err != nil { + return nil, err + } + + // There might be no generator in the project, in that case we don't + // have to do anything. + if g == nil { + return n, nil + } + + projection[g.idx] = g.expr + + var name string + if n, ok := g.expr.(sql.Nameable); ok { + name = n.Name() + } else { + name = g.expr.String() + } + + return plan.NewGenerate( + plan.NewProject(projection, p.Child), + expression.NewGetField(g.idx, g.expr.Type(), name, g.expr.IsNullable()), + ), nil + }) +} + +type generator struct { + idx int + expr sql.Expression +} + +// findGenerator will find in the given projection a generator column. If there +// is no generator, it will return nil. +// If there are is than one generator or the argument to explode is not an +// array it will fail. +// All occurrences of Explode will be replaced with Generate. +func findGenerator(exprs []sql.Expression) (*generator, error) { + var g = &generator{idx: -1} + for i, e := range exprs { + var found bool + switch e := e.(type) { + case *function.Explode: + found = true + g.expr = function.NewGenerate(e.Child) + case *expression.Alias: + if exp, ok := e.Child.(*function.Explode); ok { + found = true + g.expr = expression.NewAlias( + function.NewGenerate(exp.Child), + e.Name(), + ) + } + } + + if found { + if g.idx >= 0 { + return nil, errMultipleGenerators.New() + } + g.idx = i + + if !sql.IsArray(g.expr.Type()) { + return nil, errExplodeNotArray.New(g.expr.Type()) + } + } + } + + if g.expr == nil { + return nil, nil + } + + return g, nil +} diff --git a/sql/analyzer/resolve_generators_test.go b/sql/analyzer/resolve_generators_test.go new file mode 100644 index 000000000..0709e0d81 --- /dev/null +++ b/sql/analyzer/resolve_generators_test.go @@ -0,0 +1,117 @@ +package analyzer + +import ( + "testing" + + "github.com/stretchr/testify/require" + "gopkg.in/src-d/go-errors.v1" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/expression/function" + "github.com/src-d/go-mysql-server/sql/plan" +) + +func TestResolveGenerators(t *testing.T) { + testCases := []struct { + name string + node sql.Node + expected sql.Node + err *errors.Kind + }{ + { + name: "regular explode", + node: plan.NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Int64, "a", false), + function.NewExplode(expression.NewGetField(1, sql.Array(sql.Int64), "b", false)), + expression.NewGetField(2, sql.Int64, "c", false), + }, + plan.NewUnresolvedTable("foo", ""), + ), + expected: plan.NewGenerate( + plan.NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Int64, "a", false), + function.NewGenerate(expression.NewGetField(1, sql.Array(sql.Int64), "b", false)), + expression.NewGetField(2, sql.Int64, "c", false), + }, + plan.NewUnresolvedTable("foo", ""), + ), + expression.NewGetField(1, sql.Array(sql.Int64), "EXPLODE(b)", false), + ), + err: nil, + }, + { + name: "explode with alias", + node: plan.NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Int64, "a", false), + expression.NewAlias( + function.NewExplode( + expression.NewGetField(1, sql.Array(sql.Int64), "b", false), + ), + "x", + ), + expression.NewGetField(2, sql.Int64, "c", false), + }, + plan.NewUnresolvedTable("foo", ""), + ), + expected: plan.NewGenerate( + plan.NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Int64, "a", false), + expression.NewAlias( + function.NewGenerate( + expression.NewGetField(1, sql.Array(sql.Int64), "b", false), + ), + "x", + ), + expression.NewGetField(2, sql.Int64, "c", false), + }, + plan.NewUnresolvedTable("foo", ""), + ), + expression.NewGetField(1, sql.Array(sql.Int64), "x", false), + ), + err: nil, + }, + { + name: "non array type on explode", + node: plan.NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Int64, "a", false), + function.NewExplode(expression.NewGetField(1, sql.Int64, "b", false)), + }, + plan.NewUnresolvedTable("foo", ""), + ), + expected: nil, + err: errExplodeNotArray, + }, + { + name: "more than one generator", + node: plan.NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Int64, "a", false), + function.NewExplode(expression.NewGetField(1, sql.Array(sql.Int64), "b", false)), + function.NewExplode(expression.NewGetField(2, sql.Array(sql.Int64), "c", false)), + }, + plan.NewUnresolvedTable("foo", ""), + ), + expected: nil, + err: errMultipleGenerators, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + result, err := resolveGenerators(sql.NewEmptyContext(), nil, tt.node) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.Equal(tt.expected, result) + } + }) + } +} diff --git a/sql/analyzer/rules.go b/sql/analyzer/rules.go index afc1152bc..c2b8daf00 100644 --- a/sql/analyzer/rules.go +++ b/sql/analyzer/rules.go @@ -34,6 +34,7 @@ var OnceBeforeDefault = []Rule{ // OnceAfterDefault contains the rules to be applied just once after the // DefaultRules. var OnceAfterDefault = []Rule{ + {"resolve_generators", resolveGenerators}, {"remove_unnecessary_converts", removeUnnecessaryConverts}, {"assign_catalog", assignCatalog}, {"prune_columns", pruneColumns}, diff --git a/sql/analyzer/validation_rules.go b/sql/analyzer/validation_rules.go index 3d4dae59f..06d1f13b6 100644 --- a/sql/analyzer/validation_rules.go +++ b/sql/analyzer/validation_rules.go @@ -19,6 +19,7 @@ const ( validateIndexCreationRule = "validate_index_creation" validateCaseResultTypesRule = "validate_case_result_types" validateIntervalUsageRule = "validate_interval_usage" + validateExplodeUsageRule = "validate_explode_usage" ) var ( @@ -51,6 +52,11 @@ var ( "invalid use of an interval, which can only be used with DATE_ADD, " + "DATE_SUB and +/- operators to subtract from or add to a date", ) + // ErrExplodeInvalidUse is returned when an EXPLODE function is used + // outside a Project node. + ErrExplodeInvalidUse = errors.NewKind( + "using EXPLODE is not supported outside a Project node", + ) ) // DefaultValidationRules to apply while analyzing nodes. @@ -63,6 +69,7 @@ var DefaultValidationRules = []Rule{ {validateIndexCreationRule, validateIndexCreation}, {validateCaseResultTypesRule, validateCaseResultTypes}, {validateIntervalUsageRule, validateIntervalUsage}, + {validateExplodeUsageRule, validateExplodeUsage}, } func validateIsResolved(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { @@ -290,6 +297,31 @@ func validateIntervalUsage(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, return n, nil } +func validateExplodeUsage(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + var invalid bool + plan.InspectExpressions(n, func(e sql.Expression) bool { + // If it's already invalid just skip everything else. + if invalid { + return false + } + + // All usage of Explode will be incorrect because the ones in projects + // would have already been converted to Generate, so we only have to + // look for those. + if _, ok := e.(*function.Explode); ok { + invalid = true + } + + return true + }) + + if invalid { + return nil, ErrExplodeInvalidUse.New() + } + + return n, nil +} + func stringContains(strs []string, target string) bool { for _, s := range strs { if s == target { diff --git a/sql/analyzer/validation_rules_test.go b/sql/analyzer/validation_rules_test.go index 2d915deff..7bb83d1bb 100644 --- a/sql/analyzer/validation_rules_test.go +++ b/sql/analyzer/validation_rules_test.go @@ -582,6 +582,98 @@ func TestValidateIntervalUsage(t *testing.T) { } } +func TestValidateExplodeUsage(t *testing.T) { + testCases := []struct { + name string + node sql.Node + ok bool + }{ + { + "valid", + plan.NewGenerate( + plan.NewProject( + []sql.Expression{ + expression.NewAlias( + function.NewGenerate( + expression.NewGetField(0, sql.Array(sql.Int64), "f", false), + ), + "foo", + ), + }, + plan.NewUnresolvedTable("dual", ""), + ), + expression.NewGetField(0, sql.Array(sql.Int64), "foo", false), + ), + true, + }, + { + "where", + plan.NewFilter( + function.NewArrayLength( + function.NewExplode( + expression.NewGetField(0, sql.Array(sql.Int64), "foo", false), + ), + ), + plan.NewGenerate( + plan.NewProject( + []sql.Expression{ + expression.NewAlias( + function.NewGenerate( + expression.NewGetField(0, sql.Array(sql.Int64), "f", false), + ), + "foo", + ), + }, + plan.NewUnresolvedTable("dual", ""), + ), + expression.NewGetField(0, sql.Array(sql.Int64), "foo", false), + ), + ), + false, + }, + { + "group by", + plan.NewGenerate( + plan.NewGroupBy( + []sql.Expression{ + expression.NewAlias( + function.NewExplode( + expression.NewGetField(0, sql.Array(sql.Int64), "f", false), + ), + "foo", + ), + }, + []sql.Expression{ + expression.NewAlias( + function.NewExplode( + expression.NewGetField(0, sql.Array(sql.Int64), "f", false), + ), + "foo", + ), + }, + plan.NewUnresolvedTable("dual", ""), + ), + expression.NewGetField(0, sql.Array(sql.Int64), "foo", false), + ), + false, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + _, err := validateExplodeUsage(sql.NewEmptyContext(), nil, tt.node) + if tt.ok { + require.NoError(err) + } else { + require.Error(err) + require.True(ErrExplodeInvalidUse.Is(err)) + } + }) + } +} + type dummyNode struct{ resolved bool } func (n dummyNode) String() string { return "dummynode" } diff --git a/sql/expression/function/explode.go b/sql/expression/function/explode.go new file mode 100644 index 000000000..5c96f509a --- /dev/null +++ b/sql/expression/function/explode.go @@ -0,0 +1,95 @@ +package function + +import ( + "fmt" + + "github.com/src-d/go-mysql-server/sql" +) + +// Explode is a function that generates a row for each value of its child. +// It is a placeholder expression node. +type Explode struct { + Child sql.Expression +} + +// NewExplode creates a new Explode function. +func NewExplode(child sql.Expression) sql.Expression { + return &Explode{child} +} + +// Resolved implements the sql.Expression interface. +func (e *Explode) Resolved() bool { return e.Child.Resolved() } + +// Children implements the sql.Expression interface. +func (e *Explode) Children() []sql.Expression { return []sql.Expression{e.Child} } + +// IsNullable implements the sql.Expression interface. +func (e *Explode) IsNullable() bool { return e.Child.IsNullable() } + +// Type implements the sql.Expression interface. +func (e *Explode) Type() sql.Type { + return sql.UnderlyingType(e.Child.Type()) +} + +// Eval implements the sql.Expression interface. +func (e *Explode) Eval(*sql.Context, sql.Row) (interface{}, error) { + panic("eval method of Explode is only a placeholder") +} + +func (e *Explode) String() string { + return fmt.Sprintf("EXPLODE(%s)", e.Child) +} + +// TransformUp implements the sql.Expression interface. +func (e *Explode) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { + c, err := f(e.Child) + if err != nil { + return nil, err + } + + return f(NewExplode(c)) +} + +// Generate is a function that generates a row for each value of its child. +// This is the non-placeholder counterpart of Explode. +type Generate struct { + Child sql.Expression +} + +// NewGenerate creates a new Generate function. +func NewGenerate(child sql.Expression) sql.Expression { + return &Generate{child} +} + +// Resolved implements the sql.Expression interface. +func (e *Generate) Resolved() bool { return e.Child.Resolved() } + +// Children implements the sql.Expression interface. +func (e *Generate) Children() []sql.Expression { return []sql.Expression{e.Child} } + +// IsNullable implements the sql.Expression interface. +func (e *Generate) IsNullable() bool { return e.Child.IsNullable() } + +// Type implements the sql.Expression interface. +func (e *Generate) Type() sql.Type { + return e.Child.Type() +} + +// Eval implements the sql.Expression interface. +func (e *Generate) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + return e.Child.Eval(ctx, row) +} + +func (e *Generate) String() string { + return fmt.Sprintf("EXPLODE(%s)", e.Child) +} + +// TransformUp implements the sql.Expression interface. +func (e *Generate) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { + c, err := f(e.Child) + if err != nil { + return nil, err + } + + return f(NewGenerate(c)) +} diff --git a/sql/expression/function/registry.go b/sql/expression/function/registry.go index ffa93a317..c1712ee5b 100644 --- a/sql/expression/function/registry.go +++ b/sql/expression/function/registry.go @@ -88,4 +88,5 @@ var Defaults = []sql.Function{ sql.Function1{Name: "length", Fn: NewLength}, sql.Function1{Name: "char_length", Fn: NewCharLength}, sql.Function1{Name: "character_length", Fn: NewCharLength}, + sql.Function1{Name: "explode", Fn: NewExplode}, } diff --git a/sql/generator.go b/sql/generator.go new file mode 100644 index 000000000..218b1db66 --- /dev/null +++ b/sql/generator.go @@ -0,0 +1,57 @@ +package sql + +import ( + "io" + + "gopkg.in/src-d/go-errors.v1" +) + +// Generator will generate a set of values for a given row. +type Generator interface { + // Next value in the generator. + Next() (interface{}, error) + // Close the generator and dispose resources. + Close() error +} + +// ErrNotGenerator is returned when the value cannot be converted to a +// generator. +var ErrNotGenerator = errors.NewKind("cannot convert value of type %T to a generator") + +// ToGenerator converts a value to a generator if possible. +func ToGenerator(v interface{}) (Generator, error) { + switch v := v.(type) { + case Generator: + return v, nil + case []interface{}: + return NewArrayGenerator(v), nil + case nil: + return NewArrayGenerator(nil), nil + default: + return nil, ErrNotGenerator.New(v) + } +} + +// NewArrayGenerator creates a generator for a given array. +func NewArrayGenerator(array []interface{}) Generator { + return &arrayGenerator{array, 0} +} + +type arrayGenerator struct { + array []interface{} + pos int +} + +func (g *arrayGenerator) Next() (interface{}, error) { + if g.pos >= len(g.array) { + return nil, io.EOF + } + + g.pos++ + return g.array[g.pos-1], nil +} + +func (g *arrayGenerator) Close() error { + g.pos = len(g.array) + return nil +} diff --git a/sql/generator_test.go b/sql/generator_test.go new file mode 100644 index 000000000..145411333 --- /dev/null +++ b/sql/generator_test.go @@ -0,0 +1,54 @@ +package sql + +import ( + "fmt" + "io" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestArrayGenerator(t *testing.T) { + require := require.New(t) + + expected := []interface{}{"a", "b", "c"} + gen := NewArrayGenerator(expected) + + var values []interface{} + for { + v, err := gen.Next() + if err != nil { + if err == io.EOF { + break + } + require.NoError(err) + } + values = append(values, v) + } + + require.Equal(expected, values) +} + +func TestToGenerator(t *testing.T) { + require := require.New(t) + + gen, err := ToGenerator([]interface{}{1, 2, 3}) + require.NoError(err) + require.Equal(NewArrayGenerator([]interface{}{1, 2, 3}), gen) + + gen, err = ToGenerator(new(fakeGen)) + require.NoError(err) + require.Equal(new(fakeGen), gen) + + gen, err = ToGenerator(nil) + require.NoError(err) + require.Equal(NewArrayGenerator(nil), gen) + + _, err = ToGenerator("foo") + require.Error(err) +} + +type fakeGen struct{} + +func (fakeGen) Next() (interface{}, error) { return nil, fmt.Errorf("not implemented") } +func (fakeGen) Close() error { return nil } diff --git a/sql/plan/generate.go b/sql/plan/generate.go new file mode 100644 index 000000000..841a45890 --- /dev/null +++ b/sql/plan/generate.go @@ -0,0 +1,152 @@ +package plan + +import ( + "fmt" + "io" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +// Generate will explode rows using a generator. +type Generate struct { + UnaryNode + Column *expression.GetField +} + +// NewGenerate creates a new generate node. +func NewGenerate(child sql.Node, col *expression.GetField) *Generate { + return &Generate{UnaryNode{child}, col} +} + +// Schema implements the sql.Node interface. +func (g *Generate) Schema() sql.Schema { + s := g.Child.Schema() + col := s[g.Column.Index()] + s[g.Column.Index()] = &sql.Column{ + Name: g.Column.Name(), + Type: sql.UnderlyingType(col.Type), + Nullable: col.Nullable, + } + return s +} + +// RowIter implements the sql.Node interface. +func (g *Generate) RowIter(ctx *sql.Context) (sql.RowIter, error) { + span, ctx := ctx.Span("plan.Generate") + + childIter, err := g.Child.RowIter(ctx) + if err != nil { + return nil, err + } + + return sql.NewSpanIter(span, &generateIter{ + child: childIter, + idx: g.Column.Index(), + }), nil +} + +func (g *Generate) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { + col, err := g.Column.TransformUp(f) + if err != nil { + return nil, err + } + + field, ok := col.(*expression.GetField) + if !ok { + return nil, fmt.Errorf("column of Generate node transformed into %T, must be GetField", col) + } + + return NewGenerate(g.Child, field), nil +} + +// TransformUp implements the sql.Node interface. +func (g *Generate) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { + child, err := g.Child.TransformUp(f) + if err != nil { + return nil, err + } + + return f(NewGenerate(child, g.Column)) +} + +// TransformExpressionsUp implements the sql.Node interface. +func (g *Generate) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { + child, err := g.Child.TransformExpressionsUp(f) + if err != nil { + return nil, err + } + + col, err := g.Column.TransformUp(f) + if err != nil { + return nil, err + } + + field, ok := col.(*expression.GetField) + if !ok { + return nil, fmt.Errorf("column of Generate node transformed into %T, must be GetField", col) + } + + return NewGenerate(child, field), nil +} + +func (g *Generate) String() string { + tp := sql.NewTreePrinter() + _ = tp.WriteNode("Generate(%s)", g.Column) + _ = tp.WriteChildren(g.Child.String()) + return tp.String() +} + +type generateIter struct { + child sql.RowIter + idx int + + gen sql.Generator + row sql.Row +} + +func (i *generateIter) Next() (sql.Row, error) { + for { + if i.gen == nil { + var err error + i.row, err = i.child.Next() + if err != nil { + return nil, err + } + + i.gen, err = sql.ToGenerator(i.row[i.idx]) + if err != nil { + return nil, err + } + } + + val, err := i.gen.Next() + if err != nil { + if err == io.EOF { + if err := i.gen.Close(); err != nil { + return nil, err + } + + i.gen = nil + continue + } + return nil, err + } + + var row = make(sql.Row, len(i.row)) + copy(row, i.row) + row[i.idx] = val + return row, nil + } +} + +func (i *generateIter) Close() error { + if i.gen != nil { + if err := i.gen.Close(); err != nil { + _ = i.child.Close() + return err + } + } + + return i.child.Close() +} diff --git a/sql/plan/generate_test.go b/sql/plan/generate_test.go new file mode 100644 index 000000000..ba35f33cf --- /dev/null +++ b/sql/plan/generate_test.go @@ -0,0 +1,88 @@ +package plan + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +func TestGenerateRowIter(t *testing.T) { + require := require.New(t) + + child := newFakeNode( + sql.Schema{ + {Name: "a", Type: sql.Text, Source: "foo"}, + {Name: "b", Type: sql.Array(sql.Text), Source: "foo"}, + {Name: "c", Type: sql.Int64, Source: "foo"}, + }, + sql.RowsToRowIter( + sql.Row{"first", sql.NewArrayGenerator([]interface{}{"a", "b"}), int64(1)}, + sql.Row{"second", sql.NewArrayGenerator([]interface{}{"c", "d"}), int64(2)}, + ), + ) + + iter, err := NewGenerate( + child, + expression.NewGetFieldWithTable(1, sql.Array(sql.Text), "foo", "b", false), + ).RowIter(sql.NewEmptyContext()) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + expected := []sql.Row{ + {"first", "a", int64(1)}, + {"first", "b", int64(1)}, + {"second", "c", int64(2)}, + {"second", "d", int64(2)}, + } + + require.Equal(expected, rows) +} + +func TestGenerateSchema(t *testing.T) { + require := require.New(t) + + schema := NewGenerate( + newFakeNode( + sql.Schema{ + {Name: "a", Type: sql.Text, Source: "foo"}, + {Name: "b", Type: sql.Array(sql.Text), Source: "foo"}, + {Name: "c", Type: sql.Int64, Source: "foo"}, + }, + nil, + ), + expression.NewGetField(1, sql.Array(sql.Text), "foobar", false), + ).Schema() + + expected := sql.Schema{ + {Name: "a", Type: sql.Text, Source: "foo"}, + {Name: "foobar", Type: sql.Text}, + {Name: "c", Type: sql.Int64, Source: "foo"}, + } + + require.Equal(expected, schema) +} + +type fakeNode struct { + schema sql.Schema + iter sql.RowIter +} + +func newFakeNode(s sql.Schema, iter sql.RowIter) *fakeNode { + return &fakeNode{s, iter} +} + +func (n *fakeNode) Children() []sql.Node { return nil } +func (n *fakeNode) Resolved() bool { return true } +func (n *fakeNode) Schema() sql.Schema { return n.schema } +func (n *fakeNode) RowIter(*sql.Context) (sql.RowIter, error) { return n.iter, nil } +func (n *fakeNode) String() string { return "fakeNode" } +func (n *fakeNode) TransformUp(sql.TransformNodeFunc) (sql.Node, error) { + panic("placeholder") +} +func (n *fakeNode) TransformExpressionsUp(sql.TransformExprFunc) (sql.Node, error) { + panic("placeholder") +} diff --git a/sql/type.go b/sql/type.go index 341ab9506..bea1cf710 100644 --- a/sql/type.go +++ b/sql/type.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "io" "math" "reflect" "strconv" @@ -148,7 +149,7 @@ type Type interface { // The result will be 0 if a==b, -1 if a < b, and +1 if a > b. Compare(interface{}, interface{}) (int, error) // SQL returns the sqltypes.Value for the given value. - SQL(interface{}) sqltypes.Value + SQL(interface{}) (sqltypes.Value, error) fmt.Stringer } @@ -266,8 +267,8 @@ func (t nullT) Type() query.Type { } // SQL implements Type interface. -func (t nullT) SQL(interface{}) sqltypes.Value { - return sqltypes.NULL +func (t nullT) SQL(interface{}) (sqltypes.Value, error) { + return sqltypes.NULL, nil } // Convert implements Type interface. @@ -300,26 +301,26 @@ func (t numberT) Type() query.Type { } // SQL implements Type interface. -func (t numberT) SQL(v interface{}) sqltypes.Value { +func (t numberT) SQL(v interface{}) (sqltypes.Value, error) { if _, ok := v.(nullT); ok { - return sqltypes.NULL + return sqltypes.NULL, nil } switch t.t { case sqltypes.Int32: - return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)) + return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)), nil case sqltypes.Int64: - return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)) + return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)), nil case sqltypes.Uint32: - return sqltypes.MakeTrusted(t.t, strconv.AppendUint(nil, cast.ToUint64(v), 10)) + return sqltypes.MakeTrusted(t.t, strconv.AppendUint(nil, cast.ToUint64(v), 10)), nil case sqltypes.Uint64: - return sqltypes.MakeTrusted(t.t, strconv.AppendUint(nil, cast.ToUint64(v), 10)) + return sqltypes.MakeTrusted(t.t, strconv.AppendUint(nil, cast.ToUint64(v), 10)), nil case sqltypes.Float32: - return sqltypes.MakeTrusted(t.t, strconv.AppendFloat(nil, cast.ToFloat64(v), 'f', -1, 64)) + return sqltypes.MakeTrusted(t.t, strconv.AppendFloat(nil, cast.ToFloat64(v), 'f', -1, 64)), nil case sqltypes.Float64: - return sqltypes.MakeTrusted(t.t, strconv.AppendFloat(nil, cast.ToFloat64(v), 'f', -1, 64)) + return sqltypes.MakeTrusted(t.t, strconv.AppendFloat(nil, cast.ToFloat64(v), 'f', -1, 64)), nil default: - return sqltypes.MakeTrusted(t.t, []byte{}) + return sqltypes.MakeTrusted(t.t, []byte{}), nil } } @@ -426,16 +427,20 @@ var TimestampLayouts = []string{ } // SQL implements Type interface. -func (t timestampT) SQL(v interface{}) sqltypes.Value { +func (t timestampT) SQL(v interface{}) (sqltypes.Value, error) { if _, ok := v.(nullT); ok { - return sqltypes.NULL + return sqltypes.NULL, nil + } + + v, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err } - time := MustConvert(t, v).(time.Time) return sqltypes.MakeTrusted( sqltypes.Timestamp, - []byte(time.Format(TimestampLayout)), - ) + []byte(v.(time.Time).Format(TimestampLayout)), + ), nil } // Convert implements Type interface. @@ -498,16 +503,20 @@ func (t dateT) Type() query.Type { return sqltypes.Date } -func (t dateT) SQL(v interface{}) sqltypes.Value { +func (t dateT) SQL(v interface{}) (sqltypes.Value, error) { if _, ok := v.(nullT); ok { - return sqltypes.NULL + return sqltypes.NULL, nil + } + + v, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err } - time := MustConvert(t, v).(time.Time) return sqltypes.MakeTrusted( sqltypes.Timestamp, - []byte(time.Format(DateLayout)), - ) + []byte(v.(time.Time).Format(DateLayout)), + ), nil } func (t dateT) Convert(v interface{}) (interface{}, error) { @@ -551,12 +560,17 @@ func (t textT) Type() query.Type { } // SQL implements Type interface. -func (t textT) SQL(v interface{}) sqltypes.Value { +func (t textT) SQL(v interface{}) (sqltypes.Value, error) { if _, ok := v.(nullT); ok { - return sqltypes.NULL + return sqltypes.NULL, nil } - return sqltypes.MakeTrusted(sqltypes.Text, []byte(MustConvert(t, v).(string))) + v, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err + } + + return sqltypes.MakeTrusted(sqltypes.Text, []byte(v.(string))), nil } // Convert implements Type interface. @@ -583,9 +597,9 @@ func (t booleanT) Type() query.Type { } // SQL implements Type interface. -func (t booleanT) SQL(v interface{}) sqltypes.Value { +func (t booleanT) SQL(v interface{}) (sqltypes.Value, error) { if _, ok := v.(nullT); ok { - return sqltypes.NULL + return sqltypes.NULL, nil } b := []byte{'0'} @@ -593,7 +607,7 @@ func (t booleanT) SQL(v interface{}) sqltypes.Value { b[0] = '1' } - return sqltypes.MakeTrusted(sqltypes.Bit, b) + return sqltypes.MakeTrusted(sqltypes.Bit, b), nil } // Convert implements Type interface. @@ -655,12 +669,17 @@ func (t blobT) Type() query.Type { } // SQL implements Type interface. -func (t blobT) SQL(v interface{}) sqltypes.Value { +func (t blobT) SQL(v interface{}) (sqltypes.Value, error) { if _, ok := v.(nullT); ok { - return sqltypes.NULL + return sqltypes.NULL, nil + } + + v, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err } - return sqltypes.MakeTrusted(sqltypes.Blob, MustConvert(t, v).([]byte)) + return sqltypes.MakeTrusted(sqltypes.Blob, v.([]byte)), nil } // Convert implements Type interface. @@ -694,11 +713,17 @@ func (t jsonT) Type() query.Type { } // SQL implements Type interface. -func (t jsonT) SQL(v interface{}) sqltypes.Value { +func (t jsonT) SQL(v interface{}) (sqltypes.Value, error) { if _, ok := v.(nullT); ok { - return sqltypes.NULL + return sqltypes.NULL, nil } - return sqltypes.MakeTrusted(sqltypes.TypeJSON, MustConvert(t, v).([]byte)) + + v, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err + } + + return sqltypes.MakeTrusted(sqltypes.TypeJSON, v.([]byte)), nil } // Convert implements Type interface. @@ -734,12 +759,12 @@ func (t tupleT) Type() query.Type { return sqltypes.Expression } -func (t tupleT) SQL(v interface{}) sqltypes.Value { +func (t tupleT) SQL(v interface{}) (sqltypes.Value, error) { if _, ok := v.(nullT); ok { - return sqltypes.NULL + return sqltypes.NULL, nil } - panic("unable to convert tuple type to SQL") + return sqltypes.Value{}, fmt.Errorf("unable to convert tuple type to SQL") } func (t tupleT) Convert(v interface{}) (interface{}, error) { @@ -799,24 +824,58 @@ func (t arrayT) Type() query.Type { return sqltypes.TypeJSON } -func (t arrayT) SQL(v interface{}) sqltypes.Value { +func (t arrayT) SQL(v interface{}) (sqltypes.Value, error) { + if _, ok := v.(nullT); ok { + return sqltypes.NULL, nil + } + + v, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err + } + return JSON.SQL(v) } func (t arrayT) Convert(v interface{}) (interface{}, error) { - if vals, ok := v.([]interface{}); ok { - var result = make([]interface{}, len(vals)) - for i, v := range vals { + switch v := v.(type) { + case []interface{}: + var result = make([]interface{}, len(v)) + for i, v := range v { var err error result[i], err = t.underlying.Convert(v) if err != nil { return nil, err } } - return result, nil + case Generator: + var values []interface{} + for { + val, err := v.Next() + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + + val, err = t.underlying.Convert(val) + if err != nil { + return nil, err + } + + values = append(values, val) + } + + if err := v.Close(); err != nil { + return nil, err + } + + return values, nil + default: + return nil, ErrNotArray.New(v) } - return nil, ErrNotArray.New(v) } func (t arrayT) Compare(a, b interface{}) (int, error) { @@ -853,16 +912,6 @@ func (t arrayT) Compare(a, b interface{}) (int, error) { return 0, nil } -// MustConvert calls the Convert function from a given Type, it err panics. -func MustConvert(t Type, v interface{}) interface{} { - c, err := t.Convert(v) - if err != nil { - panic(err) - } - - return c -} - // IsNumber checks if t is a number type func IsNumber(t Type) bool { return IsInteger(t) || IsDecimal(t) @@ -961,3 +1010,14 @@ func MySQLTypeName(t Type) string { return "UNKNOWN" } } + +// UnderlyingType returns the underlying type of an array if the type is an +// array, or the type itself in any other case. +func UnderlyingType(t Type) Type { + a, ok := t.(arrayT) + if !ok { + return t + } + + return a.underlying +} diff --git a/sql/type_test.go b/sql/type_test.go index 0b3eaa629..fb8de7f3b 100644 --- a/sql/type_test.go +++ b/sql/type_test.go @@ -13,8 +13,8 @@ func TestIsNull(t *testing.T) { require.True(t, IsNull(nil)) n := numberT{sqltypes.Uint64} - require.Equal(t, sqltypes.NULL, n.SQL(Null)) - require.Equal(t, sqltypes.NewUint64(0), n.SQL(uint64(0))) + require.Equal(t, sqltypes.NULL, mustSQL(n.SQL(Null))) + require.Equal(t, sqltypes.NewUint64(0), mustSQL(n.SQL(uint64(0)))) } func TestText(t *testing.T) { @@ -70,7 +70,8 @@ func TestFloat64(t *testing.T) { var f = numberT{ t: query.Type_FLOAT64, } - val := f.SQL(23.222) + val, err := f.SQL(23.222) + require.NoError(err) require.True(val.IsFloat()) require.Equal(sqltypes.NewFloat64(23.222), val) } @@ -97,7 +98,8 @@ func TestTimestamp(t *testing.T) { v.(time.Time).Format(TimestampLayout), ) - sql := Timestamp.SQL(now) + sql, err := Timestamp.SQL(now) + require.NoError(err) require.Equal([]byte(now.Format(TimestampLayout)), sql.Raw()) after := now.Add(time.Second) @@ -167,7 +169,8 @@ func TestDate(t *testing.T) { v.(time.Time).Format(DateLayout), ) - sql := Date.SQL(now) + sql, err := Date.SQL(now) + require.NoError(err) require.Equal([]byte(now.Format(DateLayout)), sql.Raw()) after := now.Add(time.Second) @@ -186,7 +189,6 @@ func TestBlob(t *testing.T) { convert(t, Blob, "", []byte{}) convert(t, Blob, nil, []byte(nil)) - MustConvert(Blob, nil) _, err := Blob.Convert(1) require.NotNil(err) @@ -221,9 +223,8 @@ func TestTuple(t *testing.T) { convert(t, typ, []interface{}{1, 2, 3}, []interface{}{int32(1), "2", int64(3)}) - require.Panics(func() { - typ.SQL(nil) - }) + _, err = typ.SQL(nil) + require.Error(err) require.Equal(sqltypes.Expression, typ.Type()) @@ -245,6 +246,12 @@ func TestArray(t *testing.T) { require.True(ErrNotArray.Is(err)) convert(t, typ, []interface{}{1, 2, 3}, []interface{}{int64(1), int64(2), int64(3)}) + convert( + t, + typ, + NewArrayGenerator([]interface{}{1, 2, 3}), + []interface{}{int64(1), int64(2), int64(3)}, + ) require.Equal(sqltypes.TypeJSON, typ.Type()) @@ -257,6 +264,21 @@ func TestArray(t *testing.T) { gt(t, typ, []interface{}{1, 3, 3}, []interface{}{1, 2, 3}) gt(t, typ, []interface{}{1, 2, 4}, []interface{}{1, 2, 3}) gt(t, typ, []interface{}{1, 2, 4}, []interface{}{5, 6}) + + expected := []byte("[1,2,3]") + + v, err := Array(Int64).SQL([]interface{}{1, 2, 3}) + require.NoError(err) + require.Equal(expected, v.Raw()) + + v, err = Array(Int64).SQL(NewArrayGenerator([]interface{}{1, 2, 3})) + require.NoError(err) + require.Equal(expected, v.Raw()) +} + +func TestUnderlyingType(t *testing.T) { + require.Equal(t, Text, UnderlyingType(Array(Text))) + require.Equal(t, Text, UnderlyingType(Text)) } func eq(t *testing.T, typ Type, a, b interface{}) { @@ -292,3 +314,10 @@ func convertErr(t *testing.T, typ Type, val interface{}) { _, err := typ.Convert(val) require.Error(t, err) } + +func mustSQL(v sqltypes.Value, err error) sqltypes.Value { + if err != nil { + panic(err) + } + return v +}