Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

Commit

Permalink
sql/analyzer: back-propagate expression names after adding convert
Browse files Browse the repository at this point in the history
Signed-off-by: Miguel Molina <miguel@erizocosmi.co>
  • Loading branch information
erizocosmico committed Jun 3, 2019
1 parent 572d64b commit b99763e
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 1 deletion.
23 changes: 23 additions & 0 deletions sql/analyzer/convert_dates.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
if err != nil {
return nil, err
}

aggregate[i] = agg

if agg.String() != a.String() {
nodeReplacements[tableCol{"", a.String()}] = agg.String()
}
}

var grouping = make([]sql.Expression, len(exp.Grouping))
Expand All @@ -69,6 +74,24 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
}

result = plan.NewGroupBy(aggregate, grouping, exp.Child)
case *plan.Project:
var projections = make([]sql.Expression, len(exp.Projections))
for i, e := range exp.Projections {
expr, err := e.TransformUp(func(e sql.Expression) (sql.Expression, error) {
return addDateConvert(e, exp, replacements, nodeReplacements, expressions, true)
})
if err != nil {
return nil, err
}

projections[i] = expr

if expr.String() != e.String() {
nodeReplacements[tableCol{"", e.String()}] = expr.String()
}
}

result = plan.NewProject(projections, exp.Child)
default:
result, err = exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
return addDateConvert(e, n, replacements, nodeReplacements, expressions, true)
Expand Down
35 changes: 34 additions & 1 deletion sql/analyzer/convert_dates_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ package analyzer
import (
"testing"

"github.com/stretchr/testify/require"
"github.com/src-d/go-mysql-server/mem"
"github.com/src-d/go-mysql-server/sql"
"github.com/src-d/go-mysql-server/sql/expression"
"github.com/src-d/go-mysql-server/sql/expression/function"
"github.com/src-d/go-mysql-server/sql/expression/function/aggregation"
"github.com/src-d/go-mysql-server/sql/plan"
"github.com/stretchr/testify/require"
)

func TestConvertDates(t *testing.T) {
Expand Down Expand Up @@ -249,6 +249,39 @@ func TestConvertDatesGroupBy(t *testing.T) {
require.Equal(t, expected, result)
}

func TestConvertDatesFieldReference(t *testing.T) {
table := plan.NewResolvedTable(mem.NewTable("t", nil))
input := plan.NewFilter(
expression.NewEquals(
expression.NewGetField(0, sql.Int64, "DAYOFWEEK(foo)", false),
expression.NewLiteral("2019-06-06 00:00:00", sql.Text),
),
plan.NewProject([]sql.Expression{
function.NewDayOfWeek(
expression.NewGetField(0, sql.Timestamp, "foo", false),
),
}, table),
)
expected := plan.NewFilter(
expression.NewEquals(
expression.NewGetField(0, sql.Int64, "DAYOFWEEK(convert(foo, datetime))", false),
expression.NewLiteral("2019-06-06 00:00:00", sql.Text),
),
plan.NewProject([]sql.Expression{
function.NewDayOfWeek(
expression.NewConvert(
expression.NewGetField(0, sql.Timestamp, "foo", false),
expression.ConvertToDatetime,
),
),
}, table),
)

result, err := convertDates(sql.NewEmptyContext(), nil, input)
require.NoError(t, err)
require.Equal(t, expected, result)
}

func newDateAdd(l, r sql.Expression) sql.Expression {
e, _ := function.NewDateAdd(l, r)
return e
Expand Down

0 comments on commit b99763e

Please sign in to comment.