Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

Commit

Permalink
Merge pull request #237 from mcarmonaa/fix/mix-natural-and-inner-joins
Browse files Browse the repository at this point in the history
sql/analyzer: fix star resolution when there are natural joins involved
  • Loading branch information
ajnavarro committed Jun 20, 2018
2 parents acce4ea + f547549 commit 0084abf
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 0 deletions.
58 changes: 58 additions & 0 deletions engine_test.go
Expand Up @@ -499,6 +499,64 @@ func TestNaturalJoinDisjoint(t *testing.T) {
)
}

func TestInnerNestedInNaturalJoins(t *testing.T) {
require := require.New(t)

table1 := mem.NewTable("table1", sql.Schema{
{Name: "i", Type: sql.Int32, Source: "table1"},
{Name: "f", Type: sql.Float64, Source: "table1"},
{Name: "t", Type: sql.Text, Source: "table1"},
})

require.Nil(table1.Insert(sql.NewRow(int32(1), float64(2.1), "table1")))
require.Nil(table1.Insert(sql.NewRow(int32(1), float64(2.1), "table1")))
require.Nil(table1.Insert(sql.NewRow(int32(10), float64(2.1), "table1")))

table2 := mem.NewTable("table2", sql.Schema{
{Name: "i2", Type: sql.Int32, Source: "table2"},
{Name: "f2", Type: sql.Float64, Source: "table2"},
{Name: "t2", Type: sql.Text, Source: "table2"},
})

require.Nil(table2.Insert(sql.NewRow(int32(1), float64(2.2), "table2")))
require.Nil(table2.Insert(sql.NewRow(int32(1), float64(2.2), "table2")))
require.Nil(table2.Insert(sql.NewRow(int32(20), float64(2.2), "table2")))

table3 := mem.NewTable("table3", sql.Schema{
{Name: "i", Type: sql.Int32, Source: "table3"},
{Name: "f2", Type: sql.Float64, Source: "table3"},
{Name: "t3", Type: sql.Text, Source: "table3"},
})

require.Nil(table3.Insert(sql.NewRow(int32(1), float64(2.3), "table3")))
require.Nil(table3.Insert(sql.NewRow(int32(2), float64(2.3), "table3")))
require.Nil(table3.Insert(sql.NewRow(int32(30), float64(2.3), "table3")))

db := mem.NewDatabase("mydb")
db.AddTable("table1", table1)
db.AddTable("table2", table2)
db.AddTable("table3", table3)

e := sqle.NewDefault()
e.AddDatabase(db)

_, iter, err := e.Query(sql.NewEmptyContext(), `SELECT * FROM table1 INNER JOIN table2 ON table1.i = table2.i2 NATURAL JOIN table3`)
require.NoError(err)

rows, err := sql.RowIterToRows(iter)
require.NoError(err)

require.Equal(
[]sql.Row{
{int32(1), float64(2.2), float64(2.1), "table1", int32(1), "table2", "table3"},
{int32(1), float64(2.2), float64(2.1), "table1", int32(1), "table2", "table3"},
{int32(1), float64(2.2), float64(2.1), "table1", int32(1), "table2", "table3"},
{int32(1), float64(2.2), float64(2.1), "table1", int32(1), "table2", "table3"},
},
rows,
)
}

func testQuery(t *testing.T, e *sqle.Engine, q string, r []sql.Row) {
t.Run(q, func(t *testing.T) {
require := require.New(t)
Expand Down
8 changes: 8 additions & 0 deletions sql/analyzer/rules.go
Expand Up @@ -384,13 +384,21 @@ func resolveStar(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {

switch n := n.(type) {
case *plan.Project:
if !n.Child.Resolved() {
return n, nil
}

expressions, err := expandStars(n.Projections, n.Child.Schema())
if err != nil {
return nil, err
}

return plan.NewProject(expressions, n.Child), nil
case *plan.GroupBy:
if !n.Child.Resolved() {
return n, nil
}

aggregate, err := expandStars(n.Aggregate, n.Child.Schema())
if err != nil {
return nil, err
Expand Down
109 changes: 109 additions & 0 deletions sql/analyzer/rules_test.go
Expand Up @@ -881,6 +881,115 @@ func TestPushdownProjectionAndFilters(t *testing.T) {
require.Equal(expected, result)
}

func TestMixInnerAndNaturalJoins(t *testing.T) {
require := require.New(t)

table := &pushdownProjectionAndFiltersTable{mem.NewTable("mytable", sql.Schema{
{Name: "i", Type: sql.Int32, Source: "mytable"},
{Name: "f", Type: sql.Float64, Source: "mytable"},
{Name: "t", Type: sql.Text, Source: "mytable"},
})}

table2 := &pushdownProjectionAndFiltersTable{mem.NewTable("mytable2", sql.Schema{
{Name: "i2", Type: sql.Int32, Source: "mytable2"},
{Name: "f2", Type: sql.Float64, Source: "mytable2"},
{Name: "t2", Type: sql.Text, Source: "mytable2"},
})}

table3 := &pushdownProjectionAndFiltersTable{mem.NewTable("mytable3", sql.Schema{
{Name: "i", Type: sql.Int32, Source: "mytable3"},
{Name: "f2", Type: sql.Float64, Source: "mytable3"},
{Name: "t3", Type: sql.Text, Source: "mytable3"},
})}

db := mem.NewDatabase("mydb")
db.AddTable("mytable", table)
db.AddTable("mytable2", table2)
db.AddTable("mytable3", table3)

catalog := &sql.Catalog{Databases: []sql.Database{db}}
a := NewDefault(catalog)
a.CurrentDatabase = "mydb"

node := plan.NewProject(
[]sql.Expression{
expression.NewStar(),
},
plan.NewNaturalJoin(
plan.NewInnerJoin(
plan.NewUnresolvedTable("mytable"),
plan.NewUnresolvedTable("mytable2"),
expression.NewEquals(
expression.NewUnresolvedQualifiedColumn("mytable", "i"),
expression.NewUnresolvedQualifiedColumn("mytable2", "i2"),
),
),
plan.NewUnresolvedTable("mytable3"),
),
)

expected := plan.NewProject(
[]sql.Expression{
expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false),
expression.NewGetFieldWithTable(4, sql.Float64, "mytable2", "f2", false),
expression.NewGetFieldWithTable(1, sql.Float64, "mytable", "f", false),
expression.NewGetFieldWithTable(2, sql.Text, "mytable", "t", false),
expression.NewGetFieldWithTable(3, sql.Int32, "mytable2", "i2", false),
expression.NewGetFieldWithTable(5, sql.Text, "mytable2", "t2", false),
expression.NewGetFieldWithTable(8, sql.Text, "mytable3", "t3", false),
},
plan.NewInnerJoin(
plan.NewInnerJoin(
plan.NewPushdownProjectionAndFiltersTable(
[]sql.Expression{
expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false),
expression.NewGetFieldWithTable(1, sql.Float64, "mytable", "f", false),
expression.NewGetFieldWithTable(2, sql.Text, "mytable", "t", false),
},
nil,
table,
),
plan.NewPushdownProjectionAndFiltersTable(
[]sql.Expression{
expression.NewGetFieldWithTable(1, sql.Float64, "mytable2", "f2", false),
expression.NewGetFieldWithTable(0, sql.Int32, "mytable2", "i2", false),
expression.NewGetFieldWithTable(2, sql.Text, "mytable2", "t2", false),
},
nil,
table2,
),
expression.NewEquals(
expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false),
expression.NewGetFieldWithTable(3, sql.Int32, "mytable2", "i2", false),
),
),
plan.NewPushdownProjectionAndFiltersTable(
[]sql.Expression{
expression.NewGetFieldWithTable(2, sql.Text, "mytable3", "t3", false),
expression.NewGetFieldWithTable(0, sql.Int32, "mytable3", "i", false),
expression.NewGetFieldWithTable(1, sql.Float64, "mytable3", "f2", false),
},
nil,
table3,
),
expression.NewAnd(
expression.NewEquals(
expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false),
expression.NewGetFieldWithTable(6, sql.Int32, "mytable3", "i", false),
),
expression.NewEquals(
expression.NewGetFieldWithTable(4, sql.Float64, "mytable2", "f2", false),
expression.NewGetFieldWithTable(7, sql.Float64, "mytable3", "f2", false),
),
),
),
)

result, err := a.Analyze(sql.NewEmptyContext(), node)
require.NoError(err)
require.Equal(expected, result)
}

func TestPushdownIndexable(t *testing.T) {
require := require.New(t)
a := NewDefault(sql.NewCatalog())
Expand Down

0 comments on commit 0084abf

Please sign in to comment.