From 71819b4df39ef86073770f0e3ab5e758222728f6 Mon Sep 17 00:00:00 2001 From: Miguel Molina Date: Mon, 3 Jun 2019 15:29:09 +0200 Subject: [PATCH] sql/analyzer: back-propagate expression names after adding convert Signed-off-by: Miguel Molina --- sql/analyzer/convert_dates.go | 25 ++++++++++++++++++++- sql/analyzer/convert_dates_test.go | 35 +++++++++++++++++++++++++++++- 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/sql/analyzer/convert_dates.go b/sql/analyzer/convert_dates.go index 5307463f0..1397d867a 100644 --- a/sql/analyzer/convert_dates.go +++ b/sql/analyzer/convert_dates.go @@ -54,7 +54,12 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { if err != nil { return nil, err } + aggregate[i] = agg + + if _, ok := agg.(*expression.Alias); !ok && agg.String() != a.String() { + nodeReplacements[tableCol{"", a.String()}] = agg.String() + } } var grouping = make([]sql.Expression, len(exp.Grouping)) @@ -69,9 +74,27 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { } result = plan.NewGroupBy(aggregate, grouping, exp.Child) + case *plan.Project: + var projections = make([]sql.Expression, len(exp.Projections)) + for i, e := range exp.Projections { + expr, err := e.TransformUp(func(e sql.Expression) (sql.Expression, error) { + return addDateConvert(e, exp, replacements, nodeReplacements, expressions, true) + }) + if err != nil { + return nil, err + } + + projections[i] = expr + + if _, ok := expr.(*expression.Alias); !ok && expr.String() != e.String() { + nodeReplacements[tableCol{"", e.String()}] = expr.String() + } + } + + result = plan.NewProject(projections, exp.Child) default: result, err = exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { - return addDateConvert(e, n, replacements, nodeReplacements, expressions, true) + return addDateConvert(e, n, replacements, nodeReplacements, expressions, false) }) } diff --git a/sql/analyzer/convert_dates_test.go b/sql/analyzer/convert_dates_test.go index 82cac245b..934ec5db3 100644 --- a/sql/analyzer/convert_dates_test.go +++ b/sql/analyzer/convert_dates_test.go @@ -3,13 +3,13 @@ package analyzer import ( "testing" - "github.com/stretchr/testify/require" "github.com/src-d/go-mysql-server/mem" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql/expression/function" "github.com/src-d/go-mysql-server/sql/expression/function/aggregation" "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" ) func TestConvertDates(t *testing.T) { @@ -249,6 +249,39 @@ func TestConvertDatesGroupBy(t *testing.T) { require.Equal(t, expected, result) } +func TestConvertDatesFieldReference(t *testing.T) { + table := plan.NewResolvedTable(mem.NewTable("t", nil)) + input := plan.NewFilter( + expression.NewEquals( + expression.NewGetField(0, sql.Int64, "DAYOFWEEK(foo)", false), + expression.NewLiteral("2019-06-06 00:00:00", sql.Text), + ), + plan.NewProject([]sql.Expression{ + function.NewDayOfWeek( + expression.NewGetField(0, sql.Timestamp, "foo", false), + ), + }, table), + ) + expected := plan.NewFilter( + expression.NewEquals( + expression.NewGetField(0, sql.Int64, "DAYOFWEEK(convert(foo, datetime))", false), + expression.NewLiteral("2019-06-06 00:00:00", sql.Text), + ), + plan.NewProject([]sql.Expression{ + function.NewDayOfWeek( + expression.NewConvert( + expression.NewGetField(0, sql.Timestamp, "foo", false), + expression.ConvertToDatetime, + ), + ), + }, table), + ) + + result, err := convertDates(sql.NewEmptyContext(), nil, input) + require.NoError(t, err) + require.Equal(t, expected, result) +} + func newDateAdd(l, r sql.Expression) sql.Expression { e, _ := function.NewDateAdd(l, r) return e