diff --git a/engine_test.go b/engine_test.go index 8e39287de..180720502 100644 --- a/engine_test.go +++ b/engine_test.go @@ -499,6 +499,64 @@ func TestNaturalJoinDisjoint(t *testing.T) { ) } +func TestInnerNestedInNaturalJoins(t *testing.T) { + require := require.New(t) + + table1 := mem.NewTable("table1", sql.Schema{ + {Name: "i", Type: sql.Int32, Source: "table1"}, + {Name: "f", Type: sql.Float64, Source: "table1"}, + {Name: "t", Type: sql.Text, Source: "table1"}, + }) + + require.Nil(table1.Insert(sql.NewRow(int32(1), float64(2.1), "table1"))) + require.Nil(table1.Insert(sql.NewRow(int32(1), float64(2.1), "table1"))) + require.Nil(table1.Insert(sql.NewRow(int32(10), float64(2.1), "table1"))) + + table2 := mem.NewTable("table2", sql.Schema{ + {Name: "i2", Type: sql.Int32, Source: "table2"}, + {Name: "f2", Type: sql.Float64, Source: "table2"}, + {Name: "t2", Type: sql.Text, Source: "table2"}, + }) + + require.Nil(table2.Insert(sql.NewRow(int32(1), float64(2.2), "table2"))) + require.Nil(table2.Insert(sql.NewRow(int32(1), float64(2.2), "table2"))) + require.Nil(table2.Insert(sql.NewRow(int32(20), float64(2.2), "table2"))) + + table3 := mem.NewTable("table3", sql.Schema{ + {Name: "i", Type: sql.Int32, Source: "table3"}, + {Name: "f2", Type: sql.Float64, Source: "table3"}, + {Name: "t3", Type: sql.Text, Source: "table3"}, + }) + + require.Nil(table3.Insert(sql.NewRow(int32(1), float64(2.3), "table3"))) + require.Nil(table3.Insert(sql.NewRow(int32(2), float64(2.3), "table3"))) + require.Nil(table3.Insert(sql.NewRow(int32(30), float64(2.3), "table3"))) + + db := mem.NewDatabase("mydb") + db.AddTable("table1", table1) + db.AddTable("table2", table2) + db.AddTable("table3", table3) + + e := sqle.NewDefault() + e.AddDatabase(db) + + _, iter, err := e.Query(sql.NewEmptyContext(), `SELECT * FROM table1 INNER JOIN table2 ON table1.i = table2.i2 NATURAL JOIN table3`) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + require.Equal( + []sql.Row{ + {int32(1), float64(2.2), float64(2.1), "table1", int32(1), "table2", "table3"}, + {int32(1), float64(2.2), float64(2.1), "table1", int32(1), "table2", "table3"}, + {int32(1), float64(2.2), float64(2.1), "table1", int32(1), "table2", "table3"}, + {int32(1), float64(2.2), float64(2.1), "table1", int32(1), "table2", "table3"}, + }, + rows, + ) +} + func testQuery(t *testing.T, e *sqle.Engine, q string, r []sql.Row) { t.Run(q, func(t *testing.T) { require := require.New(t) diff --git a/sql/analyzer/rules.go b/sql/analyzer/rules.go index 6e246b1be..9adce94e8 100644 --- a/sql/analyzer/rules.go +++ b/sql/analyzer/rules.go @@ -384,6 +384,10 @@ func resolveStar(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { switch n := n.(type) { case *plan.Project: + if !n.Child.Resolved() { + return n, nil + } + expressions, err := expandStars(n.Projections, n.Child.Schema()) if err != nil { return nil, err @@ -391,6 +395,10 @@ func resolveStar(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { return plan.NewProject(expressions, n.Child), nil case *plan.GroupBy: + if !n.Child.Resolved() { + return n, nil + } + aggregate, err := expandStars(n.Aggregate, n.Child.Schema()) if err != nil { return nil, err diff --git a/sql/analyzer/rules_test.go b/sql/analyzer/rules_test.go index 9de12077b..d06ec2ad4 100644 --- a/sql/analyzer/rules_test.go +++ b/sql/analyzer/rules_test.go @@ -881,6 +881,115 @@ func TestPushdownProjectionAndFilters(t *testing.T) { require.Equal(expected, result) } +func TestMixInnerAndNaturalJoins(t *testing.T) { + require := require.New(t) + + table := &pushdownProjectionAndFiltersTable{mem.NewTable("mytable", sql.Schema{ + {Name: "i", Type: sql.Int32, Source: "mytable"}, + {Name: "f", Type: sql.Float64, Source: "mytable"}, + {Name: "t", Type: sql.Text, Source: "mytable"}, + })} + + table2 := &pushdownProjectionAndFiltersTable{mem.NewTable("mytable2", sql.Schema{ + {Name: "i2", Type: sql.Int32, Source: "mytable2"}, + {Name: "f2", Type: sql.Float64, Source: "mytable2"}, + {Name: "t2", Type: sql.Text, Source: "mytable2"}, + })} + + table3 := &pushdownProjectionAndFiltersTable{mem.NewTable("mytable3", sql.Schema{ + {Name: "i", Type: sql.Int32, Source: "mytable3"}, + {Name: "f2", Type: sql.Float64, Source: "mytable3"}, + {Name: "t3", Type: sql.Text, Source: "mytable3"}, + })} + + db := mem.NewDatabase("mydb") + db.AddTable("mytable", table) + db.AddTable("mytable2", table2) + db.AddTable("mytable3", table3) + + catalog := &sql.Catalog{Databases: []sql.Database{db}} + a := NewDefault(catalog) + a.CurrentDatabase = "mydb" + + node := plan.NewProject( + []sql.Expression{ + expression.NewStar(), + }, + plan.NewNaturalJoin( + plan.NewInnerJoin( + plan.NewUnresolvedTable("mytable"), + plan.NewUnresolvedTable("mytable2"), + expression.NewEquals( + expression.NewUnresolvedQualifiedColumn("mytable", "i"), + expression.NewUnresolvedQualifiedColumn("mytable2", "i2"), + ), + ), + plan.NewUnresolvedTable("mytable3"), + ), + ) + + expected := plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false), + expression.NewGetFieldWithTable(4, sql.Float64, "mytable2", "f2", false), + expression.NewGetFieldWithTable(1, sql.Float64, "mytable", "f", false), + expression.NewGetFieldWithTable(2, sql.Text, "mytable", "t", false), + expression.NewGetFieldWithTable(3, sql.Int32, "mytable2", "i2", false), + expression.NewGetFieldWithTable(5, sql.Text, "mytable2", "t2", false), + expression.NewGetFieldWithTable(8, sql.Text, "mytable3", "t3", false), + }, + plan.NewInnerJoin( + plan.NewInnerJoin( + plan.NewPushdownProjectionAndFiltersTable( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false), + expression.NewGetFieldWithTable(1, sql.Float64, "mytable", "f", false), + expression.NewGetFieldWithTable(2, sql.Text, "mytable", "t", false), + }, + nil, + table, + ), + plan.NewPushdownProjectionAndFiltersTable( + []sql.Expression{ + expression.NewGetFieldWithTable(1, sql.Float64, "mytable2", "f2", false), + expression.NewGetFieldWithTable(0, sql.Int32, "mytable2", "i2", false), + expression.NewGetFieldWithTable(2, sql.Text, "mytable2", "t2", false), + }, + nil, + table2, + ), + expression.NewEquals( + expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false), + expression.NewGetFieldWithTable(3, sql.Int32, "mytable2", "i2", false), + ), + ), + plan.NewPushdownProjectionAndFiltersTable( + []sql.Expression{ + expression.NewGetFieldWithTable(2, sql.Text, "mytable3", "t3", false), + expression.NewGetFieldWithTable(0, sql.Int32, "mytable3", "i", false), + expression.NewGetFieldWithTable(1, sql.Float64, "mytable3", "f2", false), + }, + nil, + table3, + ), + expression.NewAnd( + expression.NewEquals( + expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false), + expression.NewGetFieldWithTable(6, sql.Int32, "mytable3", "i", false), + ), + expression.NewEquals( + expression.NewGetFieldWithTable(4, sql.Float64, "mytable2", "f2", false), + expression.NewGetFieldWithTable(7, sql.Float64, "mytable3", "f2", false), + ), + ), + ), + ) + + result, err := a.Analyze(sql.NewEmptyContext(), node) + require.NoError(err) + require.Equal(expected, result) +} + func TestPushdownIndexable(t *testing.T) { require := require.New(t) a := NewDefault(sql.NewCatalog())