diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestCteExecution.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestCteExecution.java
new file mode 100644
index 000000000000..5e89895eeddb
--- /dev/null
+++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestCteExecution.java
@@ -0,0 +1,643 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.hive;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.testing.MaterializedResult;
+import com.facebook.presto.testing.QueryRunner;
+import com.facebook.presto.tests.AbstractTestQueryFramework;
+import com.facebook.presto.tests.QueryAssertions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.testng.annotations.Test;
+
+import java.util.Optional;
+
+import static com.facebook.presto.SystemSessionProperties.CTE_MATERIALIZATION_STRATEGY;
+import static com.facebook.presto.SystemSessionProperties.PARTITIONING_PROVIDER_CATALOG;
+import static com.facebook.presto.SystemSessionProperties.PUSHDOWN_SUBFIELDS_ENABLED;
+import static com.facebook.presto.testing.assertions.Assert.assertEquals;
+import static io.airlift.tpch.TpchTable.CUSTOMER;
+import static io.airlift.tpch.TpchTable.LINE_ITEM;
+import static io.airlift.tpch.TpchTable.NATION;
+import static io.airlift.tpch.TpchTable.ORDERS;
+import static io.airlift.tpch.TpchTable.PART;
+import static io.airlift.tpch.TpchTable.PART_SUPPLIER;
+import static io.airlift.tpch.TpchTable.REGION;
+import static io.airlift.tpch.TpchTable.SUPPLIER;
+
+/**
+ * Checks that queries run with {@code cte_materialization_strategy = ALL} return the same
+ * results as the same queries run with {@code cte_materialization_strategy = NONE}, using
+ * the TPC-H tables set up through {@link HiveQueryRunner}.
+ */
+@Test(singleThreaded = true)
+public class TestCteExecution
+        extends AbstractTestQueryFramework
+{
+    // Multi-stage CTE query shared by the SELECT, CTAS, INSERT and VIEW tests so that all of them run identical SQL.
+    private static final String SUPPLIER_PART_INFO_QUERY = "WITH supplier_region AS (" +
+            " SELECT s.suppkey, s.name AS supplier_name, n.name AS nation_name, r.name AS region_name " +
+            " FROM SUPPLIER s " +
+            " JOIN NATION n ON s.nationkey = n.nationkey " +
+            " JOIN REGION r ON n.regionkey = r.regionkey), " +
+            " supplier_parts AS (" +
+            " SELECT sr.*, ps.partkey, ps.availqty, ps.supplycost " +
+            " FROM supplier_region sr " +
+            " JOIN partsupp ps ON sr.suppkey = ps.suppkey), " +
+            "parts_info AS (" +
+            " SELECT sp.*, p.name AS part_name, p.type AS part_type, p.size AS part_size " +
+            " FROM supplier_parts sp " +
+            " JOIN PART p ON sp.partkey = p.partkey), " +
+            " full_supplier_part_info AS (" +
+            " SELECT pi.*, n.comment AS nation_comment, r.comment AS region_comment " +
+            " FROM parts_info pi " +
+            " JOIN NATION n ON pi.nation_name = n.name " +
+            " JOIN REGION r ON pi.region_name = r.name) " +
+            "SELECT * FROM full_supplier_part_info " +
+            "WHERE part_type LIKE '%BRASS' " +
+            "ORDER BY region_name, supplier_name";
+
+    @Override
+    protected QueryRunner createQueryRunner()
+            throws Exception
+    {
+        return HiveQueryRunner.createQueryRunner(
+                ImmutableList.of(ORDERS, CUSTOMER, LINE_ITEM, PART_SUPPLIER, NATION, REGION, PART, SUPPLIER),
+                ImmutableMap.of(
+                        "query.partitioning-provider-catalog", "hive"),
+                "sql-standard",
+                ImmutableMap.of("hive.pushdown-filter-enabled", "true",
+                        "hive.enable-parquet-dereference-pushdown", "true"),
+                Optional.empty());
+    }
+
+    @Test
+    public void testSimplePersistentCte()
+    {
+        verifyResults("WITH temp as (SELECT orderkey FROM ORDERS) " +
+                "SELECT * FROM temp t1 ");
+    }
+
+    @Test
+    public void testPersistentCteWithTimeStampWithTimeZoneType()
+    {
+        String testQuery = "WITH cte AS (" +
+                " SELECT ts FROM (VALUES " +
+                " (CAST('2023-01-01 00:00:00.000 UTC' AS TIMESTAMP WITH TIME ZONE)), " +
+                " (CAST('2023-06-01 12:00:00.000 UTC' AS TIMESTAMP WITH TIME ZONE)), " +
+                " (CAST('2023-12-31 23:59:59.999 UTC' AS TIMESTAMP WITH TIME ZONE))" +
+                " ) AS t(ts)" +
+                ")" +
+                "SELECT ts FROM cte";
+        verifyResults(testQuery);
+    }
+
+    @Test
+    public void testPersistentCteWithStructTypes()
+    {
+        String testQuery = "WITH temp AS (" +
+                " SELECT * FROM (VALUES " +
+                " (CAST(ROW('example_status', 100) AS ROW(status VARCHAR, amount INTEGER)), 1)," +
+                " (CAST(ROW('another_status', 200) AS ROW(status VARCHAR, amount INTEGER)), 2)" +
+                " ) AS t (order_details, orderkey)" +
+                ") SELECT * FROM temp";
+        verifyResults(testQuery);
+    }
+
+    @Test
+    public void testCteWithZeroLengthVarchar()
+    {
+        String testQuery = "WITH temp AS (" +
+                " SELECT * FROM (VALUES " +
+                " (CAST('' AS VARCHAR(0)), 9)" +
+                " ) AS t (text_column, number_column)" +
+                ") SELECT * FROM temp";
+        verifyResults(testQuery);
+    }
+
+    @Test
+    public void testDependentPersistentCtes()
+    {
+        String testQuery = "WITH cte1 AS (SELECT orderkey FROM ORDERS WHERE orderkey < 100), " +
+                " cte2 AS (SELECT * FROM cte1 WHERE orderkey > 50) " +
+                "SELECT * FROM cte2";
+        verifyResults(testQuery);
+    }
+
+    @Test
+    public void testMultipleIndependentPersistentCtes()
+    {
+        String testQuery = "WITH cte1 AS (SELECT orderkey FROM ORDERS WHERE orderkey < 100), " +
+                " cte2 AS (SELECT custkey FROM CUSTOMER WHERE custkey < 50) " +
+                "SELECT * FROM cte1, cte2 WHERE cte1.orderkey = cte2.custkey";
+        verifyResults(testQuery);
+    }
+
+    @Test
+    public void testNestedPersistentCtes()
+    {
+        String testQuery = "WITH cte1 AS (" +
+                " SELECT orderkey FROM ORDERS WHERE orderkey IN " +
+                " (WITH cte2 AS (SELECT orderkey FROM ORDERS WHERE orderkey < 100) " +
+                " SELECT orderkey FROM cte2 WHERE orderkey > 50)" +
+                ") SELECT * FROM cte1";
+        verifyResults(testQuery);
+    }
+
+    @Test
+    public void testRefinedCtesOutsideScope()
+    {
+        String testQuery = "WITH cte1 AS ( WITH cte2 as (SELECT orderkey FROM ORDERS WHERE orderkey < 100)" +
+                "SELECT * FROM cte2), " +
+                " cte2 AS (SELECT * FROM customer WHERE custkey < 50) " +
+                "SELECT * FROM cte2 JOIN cte1 ON true";
+        verifyResults(testQuery);
+    }
+
+    @Test
+    public void testComplexRefinedCtesOutsideScope()
+    {
+        String testQuery = "WITH " +
+                "cte1 AS ( " +
+                " SELECT orderkey, totalprice FROM ORDERS WHERE orderkey < 100 " +
+                "), " +
+                "cte2 AS ( " +
+                " WITH cte3 AS ( WITH cte4 AS (SELECT orderkey, totalprice FROM cte1 WHERE totalprice > 1000) SELECT * FROM cte4) " +
+                " SELECT cte3.orderkey FROM cte3 " +
+                "), " +
+                "cte3 AS ( " +
+                " SELECT * FROM customer WHERE custkey < 50 " +
+                ") " +
+                "SELECT cte3.*, cte2.orderkey FROM cte3 JOIN cte2 ON cte3.custkey = cte2.orderkey";
+        verifyResults(testQuery);
+    }
+
+    @Test
+    public void testChainedPersistentCtes()
+    {
+        String testQuery = "WITH cte1 AS (SELECT orderkey FROM ORDERS WHERE orderkey < 100), " +
+                " cte2 AS (SELECT orderkey FROM cte1 WHERE orderkey > 50), " +
+                " cte3 AS (SELECT orderkey FROM cte2 WHERE orderkey < 75) " +
+                "SELECT * FROM cte3";
+        verifyResults(testQuery);
+    }
+
+    @Test
+    public void testSimplePersistentCteWithJoinInCteDef()
+    {
+        String testQuery = "WITH temp as " +
+                "(SELECT * FROM ORDERS o1 " +
+                "JOIN ORDERS o2 ON o1.orderkey = o2.orderkey) " +
+                "SELECT * FROM temp t1 ";
+        verifyResults(testQuery);
+    }
+
+    @Test
+    public void testSimplePersistentCteMultipleUses()
+    {
+        String testQuery = " WITH temp as" +
+                " (SELECT * FROM ORDERS) " +
+                "SELECT * FROM temp t1 JOIN temp t2 on " +
+                "t1.orderkey = t2.orderkey WHERE t1.orderkey < 10";
+        verifyResults(testQuery);
+    }
+
+    @Test
+    public void testPersistentCteMultipleColumns()
+    {
+        String testQuery = " WITH temp as (SELECT * FROM ORDERS) " +
+                "SELECT * FROM temp t1";
+        verifyResults(testQuery);
+    }
+
+    @Test
+    public void testJoinAndAggregationWithPersistentCtes()
+    {
+        String testQuery = "WITH cte1 AS (" +
+                " SELECT orderkey, COUNT(*) as item_count FROM lineitem" +
+                " GROUP BY orderkey)," +
+                " cte2 AS (" +
+                " SELECT c.custkey, c.name FROM CUSTOMER c" +
+                " WHERE c.mktsegment = 'BUILDING')" +
+                " SELECT * FROM cte1" +
+                " JOIN cte2 ON cte1.orderkey = cte2.custkey";
+        verifyResults(testQuery);
+    }
+
+    @Test
+    public void testNestedPersistentCtes2()
+    {
+        String testQuery = "WITH cte1 AS (" +
+                " WITH cte2 AS (" +
+                " SELECT nationkey FROM NATION" +
+                " WHERE regionkey = 1)" +
+                " SELECT * FROM cte2" +
+                " WHERE nationkey < 5)" +
+                "SELECT * FROM cte1";
+        verifyResults(testQuery);
+    }
+
+    @Test
+    public void testPersistentCteWithUnion()
+    {
+        String testQuery = "WITH cte AS (" +
+                " SELECT orderkey FROM ORDERS WHERE orderkey < 100" +
+                " UNION" +
+                " SELECT orderkey FROM ORDERS WHERE orderkey > 500)" +
+                "SELECT * FROM cte";
+        verifyResults(testQuery);
+    }
+
+    @Test
+    public void testPersistentCteWithSelfJoin()
+    {
+        String testQuery = "WITH cte AS (" +
+                " SELECT * FROM ORDERS)" +
+                "SELECT * FROM cte c1" +
+                " JOIN cte c2 ON c1.orderkey = c2.orderkey WHERE c1.orderkey < 100";
+        verifyResults(testQuery);
+    }
+
+    @Test
+    public void testPersistentCteWithWindowFunction()
+    {
+        String testQuery = "WITH cte AS (" +
+                " SELECT *, ROW_NUMBER() OVER(PARTITION BY orderstatus ORDER BY orderkey) as row" +
+                " FROM ORDERS)" +
+                "SELECT * FROM cte WHERE row <= 5";
+        verifyResults(testQuery);
+    }
+
+    @Test
+    public void testPersistentCteWithMultipleDependentSubCtes()
+    {
+        String testQuery = "WITH cte1 AS (" +
+                " SELECT * FROM ORDERS)," +
+                " cte2 AS (SELECT * FROM cte1 WHERE orderkey < 100)," +
+                " cte3 AS (SELECT * FROM cte1 WHERE orderkey >= 100)" +
+                "SELECT * FROM cte2 UNION ALL SELECT * FROM cte3";
+        verifyResults(testQuery);
+    }
+
+    @Test
+    public void testTopCustomersByOrderValue()
+    {
+        String testQuery = "WITH cte AS (" +
+                " SELECT c.custkey, c.name, SUM(o.totalprice) as total_spent " +
+                " FROM CUSTOMER c JOIN ORDERS o ON c.custkey = o.custkey " +
+                " GROUP BY c.custkey, c.name)" +
+                "SELECT * FROM cte " +
+                "ORDER BY total_spent DESC " +
+                "LIMIT 5";
+        verifyResults(testQuery, true);
+    }
+
+    @Test
+    public void testSupplierDataAnalysis()
+    {
+        String testQuery = "WITH cte AS (" +
+                " SELECT s.suppkey, s.name, n.name as nation, r.name as region, ROUND(SUM(ps.supplycost), 8) as total_supply_cost " +
+                " FROM partsupp ps JOIN SUPPLIER s ON ps.suppkey = s.suppkey " +
+                " JOIN NATION n ON s.nationkey = n.nationkey " +
+                " JOIN REGION r ON n.regionkey = r.regionkey " +
+                " GROUP BY s.suppkey, s.name, n.name, r.name) " +
+                "SELECT * FROM cte " +
+                "WHERE total_supply_cost > 1000";
+        verifyResults(testQuery);
+    }
+
+    @Test
+    public void testCustomerOrderPatternAnalysis()
+    {
+        String testQuery = "WITH cte AS (" +
+                " SELECT c.name as customer_name, r.name as region_name, EXTRACT(year FROM o.orderdate) as order_year, COUNT(*) as order_count " +
+                " FROM CUSTOMER c JOIN ORDERS o ON c.custkey = o.custkey " +
+                " JOIN NATION n ON c.nationkey = n.nationkey " +
+                " JOIN REGION r ON n.regionkey = r.regionkey " +
+                " GROUP BY c.name, r.name, EXTRACT(year FROM o.orderdate)) " +
+                "SELECT * FROM cte " +
+                "ORDER BY customer_name, order_year";
+        verifyResults(testQuery);
+    }
+
+    @Test
+    public void testLowStockAnalysis()
+    {
+        String testQuery = "WITH cte AS (" +
+                " SELECT p.partkey, p.name, p.type, SUM(ps.availqty) as total_qty " +
+                " FROM PART p JOIN partsupp ps ON p.partkey = ps.partkey " +
+                " GROUP BY p.partkey, p.name, p.type) " +
+                "SELECT * FROM cte " +
+                "WHERE total_qty < 100";
+        verifyResults(testQuery);
+    }
+
+    @Test
+    public void testComplexChainOfDependentAndNestedPersistentCtes()
+    {
+        String testQuery = "WITH " +
+                " cte1 AS (" +
+                " SELECT * FROM ORDERS WHERE orderkey < 1000" +
+                " )," +
+                " cte2 AS (" +
+                " SELECT * FROM cte1 WHERE custkey < 500" +
+                " )," +
+                " cte3 AS (" +
+                " SELECT cte2.*, cte1.totalprice AS cte1_totalprice " +
+                " FROM cte2 " +
+                " JOIN cte1 ON cte2.orderkey = cte1.orderkey " +
+                " WHERE cte1.totalprice < 150000" +
+                " )," +
+                " cte4 AS (" +
+                " SELECT * FROM cte3 WHERE orderstatus = 'O'" +
+                " )," +
+                " cte5 AS (" +
+                " SELECT orderkey FROM cte4 WHERE cte1_totalprice < 100000" +
+                " )," +
+                " cte6 AS (" +
+                " SELECT * FROM cte5, LATERAL (" +
+                " SELECT * FROM cte2 WHERE cte2.orderkey = cte5.orderkey" +
+                " ) x" +
+                " )" +
+                "SELECT * FROM cte6";
+        verifyResults(testQuery);
+    }
+
+    @Test
+    public void testComplexQuery1()
+    {
+        String testQuery = "WITH customer_nation AS (" +
+                " SELECT c.custkey, c.name, n.name AS nation_name, r.name AS region_name " +
+                " FROM CUSTOMER c " +
+                " JOIN NATION n ON c.nationkey = n.nationkey " +
+                " JOIN REGION r ON n.regionkey = r.regionkey), " +
+                " customer_orders AS (" +
+                " SELECT co.custkey, co.name, co.nation_name, co.region_name, o.orderkey, o.orderdate " +
+                " FROM customer_nation co " +
+                " JOIN ORDERS o ON co.custkey = o.custkey), " +
+                "order_lineitems AS (" +
+                " SELECT co.*, l.partkey, l.quantity, l.extendedprice " +
+                " FROM customer_orders co " +
+                " JOIN lineitem l ON co.orderkey = l.orderkey), " +
+                " customer_part_analysis AS (" +
+                " SELECT ol.*, p.name AS part_name, p.type AS part_type " +
+                " FROM order_lineitems ol " +
+                " JOIN PART p ON ol.partkey = p.partkey) " +
+                "SELECT * FROM customer_part_analysis " +
+                "WHERE region_name = 'AMERICA' " +
+                "ORDER BY nation_name, custkey, orderdate";
+        verifyResults(testQuery);
+    }
+
+    @Test
+    public void testComplexQuery2()
+    {
+        verifyResults(SUPPLIER_PART_INFO_QUERY);
+    }
+
+    @Test
+    public void testSimplePersistentCteForCtasQueries()
+    {
+        QueryRunner queryRunner = getQueryRunner();
+        try {
+            // Create tables with Ctas
+            queryRunner.execute(getMaterializedSession(),
+                    "CREATE TABLE persistent_table as (WITH temp as (SELECT orderkey FROM ORDERS) " +
+                            "SELECT * FROM temp t1 )");
+            queryRunner.execute(getSession(),
+                    "CREATE TABLE non_persistent_table as (WITH temp as (SELECT orderkey FROM ORDERS) " +
+                            "SELECT * FROM temp t1) ");
+
+            assertTestTablesMatch(queryRunner);
+        }
+        finally {
+            dropTestTables(queryRunner);
+        }
+    }
+
+    @Test
+    public void testComplexPersistentCteForCtasQueries()
+    {
+        QueryRunner queryRunner = getQueryRunner();
+        try {
+            // Create tables with Ctas
+            queryRunner.execute(getMaterializedSession(),
+                    "CREATE TABLE persistent_table as ( " + SUPPLIER_PART_INFO_QUERY + ")");
+            queryRunner.execute(getSession(),
+                    "CREATE TABLE non_persistent_table as ( " + SUPPLIER_PART_INFO_QUERY + ")");
+
+            assertTestTablesMatch(queryRunner);
+        }
+        finally {
+            dropTestTables(queryRunner);
+        }
+    }
+
+    @Test
+    public void testSimplePersistentCteForInsertQueries()
+    {
+        QueryRunner queryRunner = getQueryRunner();
+        try {
+            // Create tables without data
+            queryRunner.execute(getSession(), "CREATE TABLE persistent_table (orderkey BIGINT)");
+            queryRunner.execute(getSession(), "CREATE TABLE non_persistent_table (orderkey BIGINT)");
+
+            // Insert data into tables using CTEs
+            queryRunner.execute(getMaterializedSession(),
+                    "INSERT INTO persistent_table " +
+                            "WITH temp AS (SELECT orderkey FROM ORDERS) " +
+                            "SELECT * FROM temp");
+            queryRunner.execute(getSession(),
+                    "INSERT INTO non_persistent_table " +
+                            "WITH temp AS (SELECT orderkey FROM ORDERS) " +
+                            "SELECT * FROM temp");
+
+            assertTestTablesMatch(queryRunner);
+        }
+        finally {
+            dropTestTables(queryRunner);
+        }
+    }
+
+    @Test
+    public void testComplexPersistentCteForInsertQueries()
+    {
+        QueryRunner queryRunner = getQueryRunner();
+        String createTableBase = " (suppkey BIGINT, supplier_name VARCHAR, nation_name VARCHAR, region_name VARCHAR, " +
+                "partkey BIGINT, availqty BIGINT, supplycost DOUBLE, " +
+                "part_name VARCHAR, part_type VARCHAR, part_size BIGINT, " +
+                "nation_comment VARCHAR, region_comment VARCHAR)";
+        try {
+            // Create tables without data
+            queryRunner.execute(getSession(), "CREATE TABLE persistent_table" + createTableBase);
+            queryRunner.execute(getSession(), "CREATE TABLE non_persistent_table" + createTableBase);
+
+            // Insert data into tables using CTEs
+            queryRunner.execute(getMaterializedSession(), "INSERT INTO persistent_table " + SUPPLIER_PART_INFO_QUERY);
+            queryRunner.execute(getSession(), "INSERT INTO non_persistent_table " + SUPPLIER_PART_INFO_QUERY);
+
+            assertTestTablesMatch(queryRunner);
+        }
+        finally {
+            dropTestTables(queryRunner);
+        }
+    }
+
+    @Test
+    public void testSimplePersistentCteForViewQueries()
+    {
+        QueryRunner queryRunner = getQueryRunner();
+        try {
+            // Create views
+            queryRunner.execute(getMaterializedSession(),
+                    "CREATE VIEW persistent_view AS WITH temp AS (SELECT orderkey FROM ORDERS) " +
+                            "SELECT * FROM temp");
+            queryRunner.execute(getSession(),
+                    "CREATE VIEW non_persistent_view AS WITH temp AS (SELECT orderkey FROM ORDERS) " +
+                            "SELECT * FROM temp");
+
+            assertTestViewsMatch(queryRunner);
+        }
+        finally {
+            dropTestViews(queryRunner);
+        }
+    }
+
+    @Test
+    public void testComplexPersistentCteForViewQueries()
+    {
+        QueryRunner queryRunner = getQueryRunner();
+        try {
+            // Create views
+            queryRunner.execute(getMaterializedSession(), "CREATE View persistent_view as " + SUPPLIER_PART_INFO_QUERY);
+            queryRunner.execute(getSession(), "CREATE View non_persistent_view as " + SUPPLIER_PART_INFO_QUERY);
+
+            assertTestViewsMatch(queryRunner);
+        }
+        finally {
+            dropTestViews(queryRunner);
+        }
+    }
+
+    /**
+     * Runs {@code query} with the materialized session and with the baseline session and
+     * asserts that both runs return the same rows, ignoring row order.
+     */
+    private void verifyResults(String query)
+    {
+        verifyResults(query, false);
+    }
+
+    /**
+     * Runs {@code query} with the materialized session and with the baseline session and
+     * asserts that both runs return the same rows.
+     *
+     * @param checkOrdering whether the rows must also appear in the same order
+     */
+    private void verifyResults(String query, boolean checkOrdering)
+    {
+        QueryRunner queryRunner = getQueryRunner();
+        compareResults(
+                queryRunner.execute(getMaterializedSession(), query),
+                queryRunner.execute(getSession(), query),
+                checkOrdering);
+    }
+
+    /**
+     * Asserts that {@code persistent_table} and {@code non_persistent_table} hold the same rows.
+     */
+    private void assertTestTablesMatch(QueryRunner queryRunner)
+    {
+        compareResults(
+                queryRunner.execute(getSession(), "SELECT * FROM persistent_table"),
+                queryRunner.execute(getSession(), "SELECT * FROM non_persistent_table"));
+    }
+
+    /**
+     * Drops the tables created by the CTAS and INSERT tests. {@code IF EXISTS} keeps the cleanup
+     * from masking the original failure when a table was never created.
+     */
+    private void dropTestTables(QueryRunner queryRunner)
+    {
+        queryRunner.execute(getSession(), "DROP TABLE IF EXISTS persistent_table");
+        queryRunner.execute(getSession(), "DROP TABLE IF EXISTS non_persistent_table");
+    }
+
+    /**
+     * Asserts that {@code persistent_view}, read with the materialized session, returns the
+     * same rows as {@code non_persistent_view} read with the baseline session.
+     */
+    private void assertTestViewsMatch(QueryRunner queryRunner)
+    {
+        compareResults(
+                queryRunner.execute(getMaterializedSession(), "SELECT * FROM persistent_view"),
+                queryRunner.execute(getSession(), "SELECT * FROM non_persistent_view"));
+    }
+
+    /**
+     * Drops the views created by the VIEW tests; {@code IF EXISTS} tolerates a view that was never created.
+     */
+    private void dropTestViews(QueryRunner queryRunner)
+    {
+        queryRunner.execute(getSession(), "DROP VIEW IF EXISTS persistent_view");
+        queryRunner.execute(getSession(), "DROP VIEW IF EXISTS non_persistent_view");
+    }
+
+    private void compareResults(MaterializedResult actual, MaterializedResult expected)
+    {
+        compareResults(actual, expected, false);
+    }
+
+    /**
+     * Asserts that both results have the same row count and the same rows.
+     *
+     * @param checkOrdering when true the rows are compared in order, otherwise order is ignored
+     */
+    private void compareResults(MaterializedResult actual, MaterializedResult expected, boolean checkOrdering)
+    {
+        // Verify result count
+        assertEquals(actual.getRowCount(),
+                expected.getRowCount(), String.format("Expected %d rows got %d rows", expected.getRowCount(), actual.getRowCount()));
+        if (checkOrdering) {
+            assertEquals(actual.getMaterializedRows(), expected.getMaterializedRows(), "Correctness check failed! Rows are not equal");
+            return;
+        }
+        QueryAssertions.assertEqualsIgnoreOrder(actual, expected, "Correctness check failed! Rows are not equal");
+    }
+
+    /**
+     * Baseline session: {@code cte_materialization_strategy} is set to NONE.
+     */
+    @Override
+    protected Session getSession()
+    {
+        return Session.builder(super.getSession())
+                .setSystemProperty(PUSHDOWN_SUBFIELDS_ENABLED, "true")
+                .setSystemProperty(CTE_MATERIALIZATION_STRATEGY, "NONE")
+                .build();
+    }
+
+    /**
+     * Session under test: {@code cte_materialization_strategy} is set to ALL and the partitioning provider catalog to hive.
+     */
+    protected Session getMaterializedSession()
+    {
+        return Session.builder(super.getSession())
+                .setSystemProperty(PUSHDOWN_SUBFIELDS_ENABLED, "true")
+                .setSystemProperty(CTE_MATERIALIZATION_STRATEGY, "ALL")
+                .setSystemProperty(PARTITIONING_PROVIDER_CATALOG, "hive")
+                .build();
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java
index 7f139042c475..3829808eaf0c 100644
--- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java
+++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java
@@ -29,6 +29,7 @@
 import com.facebook.presto.sql.analyzer.FeaturesConfig;
 import com.facebook.presto.sql.analyzer.FeaturesConfig.AggregationIfToFilterRewriteStrategy;
 import com.facebook.presto.sql.analyzer.FeaturesConfig.AggregationPartitioningMergingStrategy;
+import com.facebook.presto.sql.analyzer.FeaturesConfig.CteMaterializationStrategy;
 import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType;
 import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinNotNullInferenceStrategy;
 import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy;
@@ -186,6 +187,7 @@ public final class SystemSessionProperties
     public static final String MAX_DRIVERS_PER_TASK = "max_drivers_per_task";
     public static final String MAX_TASKS_PER_STAGE = "max_tasks_per_stage";
     public static final String DEFAULT_FILTER_FACTOR_ENABLED = "default_filter_factor_enabled";
+    public static final String CTE_MATERIALIZATION_STRATEGY = "cte_materialization_strategy";
     public static final String DEFAULT_JOIN_SELECTIVITY_COEFFICIENT = "default_join_selectivity_coefficient";
public static final String PUSH_LIMIT_THROUGH_OUTER_JOIN = "push_limit_through_outer_join"; public static final String OPTIMIZE_CONSTANT_GROUPING_KEYS = "optimize_constant_grouping_keys"; @@ -1038,6 +1040,18 @@ public SystemSessionProperties( "use a default filter factor for unknown filters in a filter node", featuresConfig.isDefaultFilterFactorEnabled(), false), + new PropertyMetadata<>( + CTE_MATERIALIZATION_STRATEGY, + format("The strategy to materialize common table expressions. Options are %s", + Stream.of(CteMaterializationStrategy.values()) + .map(CteMaterializationStrategy::name) + .collect(joining(","))), + VARCHAR, + CteMaterializationStrategy.class, + featuresConfig.getCteMaterializationStrategy(), + false, + value -> CteMaterializationStrategy.valueOf(((String) value).toUpperCase()), + CteMaterializationStrategy::name), new PropertyMetadata<>( DEFAULT_JOIN_SELECTIVITY_COEFFICIENT, "use a default join selectivity coefficient factor when column statistics are not available in a join node", @@ -2310,6 +2324,11 @@ public static DataSize getFilterAndProjectMinOutputPageSize(Session session) return session.getSystemProperty(FILTER_AND_PROJECT_MIN_OUTPUT_PAGE_SIZE, DataSize.class); } + public static CteMaterializationStrategy getCteMaterializationStrategy(Session session) + { + return session.getSystemProperty(CTE_MATERIALIZATION_STRATEGY, CteMaterializationStrategy.class); + } + public static int getFilterAndProjectMinOutputPageRowCount(Session session) { return session.getSystemProperty(FILTER_AND_PROJECT_MIN_OUTPUT_PAGE_ROW_COUNT, Integer.class); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/TemporaryTableUtil.java b/presto-main/src/main/java/com/facebook/presto/sql/TemporaryTableUtil.java new file mode 100644 index 000000000000..fd4e793d6a7b --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/TemporaryTableUtil.java @@ -0,0 +1,387 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not 
use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql; + +import com.facebook.presto.Session; +import com.facebook.presto.SystemSessionProperties; +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.NewTableLayout; +import com.facebook.presto.metadata.PartitioningMetadata; +import com.facebook.presto.metadata.TableLayout; +import com.facebook.presto.metadata.TableLayoutResult; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorNewTableLayout; +import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.SourceLocation; +import com.facebook.presto.spi.TableHandle; +import com.facebook.presto.spi.TableMetadata; +import com.facebook.presto.spi.VariableAllocator; +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.spi.statistics.TableStatisticsMetadata; +import com.facebook.presto.sql.planner.BasePlanFragmenter; +import com.facebook.presto.sql.planner.Partitioning; 
+import com.facebook.presto.sql.planner.PartitioningHandle; +import com.facebook.presto.sql.planner.PartitioningScheme; +import com.facebook.presto.sql.planner.StatisticsAggregationPlanner; +import com.facebook.presto.sql.planner.plan.ExchangeNode; +import com.facebook.presto.sql.planner.plan.StatisticAggregations; +import com.facebook.presto.sql.planner.plan.TableFinishNode; +import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; +import com.facebook.presto.sql.planner.plan.TableWriterNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import static com.facebook.presto.SystemSessionProperties.getTaskPartitionedWriterCount; +import static com.facebook.presto.SystemSessionProperties.isTableWriterMergeOperatorEnabled; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.ensureSourceOrderingGatheringExchange; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.gatheringExchange; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.partitionedExchange; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static 
com.google.common.collect.Iterables.concat; +import static java.lang.String.format; +import static java.util.function.Function.identity; + +// Planner Util for creating temporary tables +public class TemporaryTableUtil +{ + private TemporaryTableUtil() + { + } + + public static TableScanNode createTemporaryTableScan( + Metadata metadata, + Session session, + PlanNodeIdAllocator idAllocator, + Optional sourceLocation, + TableHandle tableHandle, + List outputVariables, + Map variableToColumnMap, + PartitioningMetadata expectedPartitioningMetadata) + { + Map columnHandles = metadata.getColumnHandles(session, tableHandle); + Map outputColumns = outputVariables.stream() + .collect(toImmutableMap(identity(), variableToColumnMap::get)); + Set outputColumnHandles = outputColumns.values().stream() + .map(ColumnMetadata::getName) + .map(columnHandles::get) + .collect(toImmutableSet()); + + TableLayoutResult selectedLayout = metadata.getLayout(session, tableHandle, Constraint.alwaysTrue(), Optional.of(outputColumnHandles)); + verify(selectedLayout.getUnenforcedConstraint().equals(TupleDomain.all()), "temporary table layout shouldn't enforce any constraints"); + verify(!selectedLayout.getLayout().getColumns().isPresent(), "temporary table layout must provide all the columns"); + TableLayout.TablePartitioning expectedPartitioning = new TableLayout.TablePartitioning( + expectedPartitioningMetadata.getPartitioningHandle(), + expectedPartitioningMetadata.getPartitionColumns().stream() + .map(columnHandles::get) + .collect(toImmutableList())); + verify(selectedLayout.getLayout().getTablePartitioning().equals(Optional.of(expectedPartitioning)), "invalid temporary table partitioning"); + + Map assignments = outputVariables.stream() + .collect(toImmutableMap(identity(), variable -> columnHandles.get(outputColumns.get(variable).getName()))); + + return new TableScanNode( + sourceLocation, + idAllocator.getNextId(), + selectedLayout.getLayout().getNewTableHandle(), + outputVariables, + 
assignments, + TupleDomain.all(), + TupleDomain.all()); + } + + public static Map assignTemporaryTableColumnNames(Collection outputVariables, + Collection constantPartitioningVariables) + { + ImmutableMap.Builder result = ImmutableMap.builder(); + int column = 0; + for (VariableReferenceExpression outputVariable : concat(outputVariables, constantPartitioningVariables)) { + String columnName = format("_c%d_%s", column, outputVariable.getName()); + result.put(outputVariable, new ColumnMetadata(columnName, outputVariable.getType())); + column++; + } + return result.build(); + } + + public static BasePlanFragmenter.PartitioningVariableAssignments assignPartitioningVariables(VariableAllocator variableAllocator, + Partitioning partitioning) + { + ImmutableList.Builder variables = ImmutableList.builder(); + ImmutableMap.Builder constants = ImmutableMap.builder(); + for (RowExpression argument : partitioning.getArguments()) { + checkArgument(argument instanceof ConstantExpression || argument instanceof VariableReferenceExpression, + format("Expect argument to be ConstantExpression or VariableReferenceExpression, got %s (%s)", argument.getClass(), argument)); + VariableReferenceExpression variable; + if (argument instanceof ConstantExpression) { + variable = variableAllocator.newVariable(argument.getSourceLocation(), "constant_partition", argument.getType()); + constants.put(variable, argument); + } + else { + variable = (VariableReferenceExpression) argument; + } + variables.add(variable); + } + return new BasePlanFragmenter.PartitioningVariableAssignments(variables.build(), constants.build()); + } + + public static TableFinishNode createTemporaryTableWriteWithoutExchanges( + Metadata metadata, + Session session, + PlanNodeIdAllocator idAllocator, + VariableAllocator variableAllocator, + PlanNode source, + TableHandle tableHandle, + List outputs, + Map variableToColumnMap, + PartitioningMetadata partitioningMetadata, + VariableReferenceExpression outputVar) + { + 
SchemaTableName schemaTableName = metadata.getTableMetadata(session, tableHandle).getTable(); + TableWriterNode.InsertReference insertReference = new TableWriterNode.InsertReference(tableHandle, schemaTableName); + List outputColumnNames = outputs.stream() + .map(variableToColumnMap::get) + .map(ColumnMetadata::getName) + .collect(toImmutableList()); + Set outputNotNullColumnVariables = outputs.stream() + .filter(variable -> variableToColumnMap.get(variable) != null && !(variableToColumnMap.get(variable).isNullable())) + .collect(Collectors.toSet()); + PartitioningHandle partitioningHandle = partitioningMetadata.getPartitioningHandle(); + List partitionColumns = partitioningMetadata.getPartitionColumns(); + Map columnNameToVariable = variableToColumnMap.entrySet().stream() + .collect(toImmutableMap(entry -> entry.getValue().getName(), Map.Entry::getKey)); + List partitioningVariables = partitionColumns.stream() + .map(columnNameToVariable::get) + .collect(toImmutableList()); + PartitioningScheme partitioningScheme = new PartitioningScheme( + Partitioning.create(partitioningHandle, partitioningVariables), + outputs, + Optional.empty(), + false, + Optional.empty()); + return new TableFinishNode( + source.getSourceLocation(), + idAllocator.getNextId(), + new TableWriterNode( + source.getSourceLocation(), + idAllocator.getNextId(), + source, + Optional.of(insertReference), + variableAllocator.newVariable("rows", BIGINT), + variableAllocator.newVariable("fragments", VARBINARY), + variableAllocator.newVariable("commitcontext", VARBINARY), + outputs, + outputColumnNames, + outputNotNullColumnVariables, + Optional.of(partitioningScheme), + Optional.empty(), + Optional.empty(), + Optional.empty()), + Optional.of(insertReference), + outputVar, + Optional.empty(), + Optional.empty()); + } + + public static TableFinishNode createTemporaryTableWriteWithExchanges( + Metadata metadata, + Session session, + PlanNodeIdAllocator idAllocator, + VariableAllocator variableAllocator, + 
StatisticsAggregationPlanner statisticsAggregationPlanner, + Optional sourceLocation, + TableHandle tableHandle, + Map variableToColumnMap, + List outputs, + List> inputs, + List sources, + Map constantExpressions, + PartitioningMetadata partitioningMetadata) + { + if (!constantExpressions.isEmpty()) { + List constantVariables = ImmutableList.copyOf(constantExpressions.keySet()); + outputs = ImmutableList.builder() + .addAll(outputs) + .addAll(constantVariables) + .build(); + inputs = inputs.stream() + .map(input -> ImmutableList.builder() + .addAll(input) + .addAll(constantVariables) + .build()) + .collect(toImmutableList()); + + // update sources + sources = sources.stream() + .map(source -> { + Assignments.Builder assignments = Assignments.builder(); + source.getOutputVariables().forEach(variable -> assignments.put(variable, new VariableReferenceExpression(variable.getSourceLocation(), variable.getName(), variable.getType()))); + constantVariables.forEach(variable -> assignments.put(variable, constantExpressions.get(variable))); + return new ProjectNode(source.getSourceLocation(), idAllocator.getNextId(), source, assignments.build(), ProjectNode.Locality.LOCAL); + }) + .collect(toImmutableList()); + } + + NewTableLayout insertLayout = metadata.getInsertLayout(session, tableHandle) + // TODO: support insert into non partitioned table + .orElseThrow(() -> new IllegalArgumentException("insertLayout for the temporary table must be present")); + + PartitioningHandle partitioningHandle = partitioningMetadata.getPartitioningHandle(); + List partitionColumns = partitioningMetadata.getPartitionColumns(); + ConnectorNewTableLayout expectedNewTableLayout = new ConnectorNewTableLayout(partitioningHandle.getConnectorHandle(), partitionColumns); + verify(insertLayout.getLayout().equals(expectedNewTableLayout), "unexpected new table layout"); + + Map columnNameToVariable = variableToColumnMap.entrySet().stream() + .collect(toImmutableMap(entry -> entry.getValue().getName(), 
Map.Entry::getKey)); + List partitioningVariables = partitionColumns.stream() + .map(columnNameToVariable::get) + .collect(toImmutableList()); + + List outputColumnNames = outputs.stream() + .map(variableToColumnMap::get) + .map(ColumnMetadata::getName) + .collect(toImmutableList()); + Set outputNotNullColumnVariables = outputs.stream() + .filter(variable -> variableToColumnMap.get(variable) != null && !(variableToColumnMap.get(variable).isNullable())) + .collect(Collectors.toSet()); + + SchemaTableName schemaTableName = metadata.getTableMetadata(session, tableHandle).getTable(); + TableWriterNode.InsertReference insertReference = new TableWriterNode.InsertReference(tableHandle, schemaTableName); + + PartitioningScheme partitioningScheme = new PartitioningScheme( + Partitioning.create(partitioningHandle, partitioningVariables), + outputs, + Optional.empty(), + false, + Optional.empty()); + + ExchangeNode writerRemoteSource = new ExchangeNode( + sourceLocation, + idAllocator.getNextId(), + REPARTITION, + REMOTE_STREAMING, + partitioningScheme, + sources, + inputs, + false, + Optional.empty()); + + ExchangeNode writerSource; + if (getTaskPartitionedWriterCount(session) == 1) { + writerSource = gatheringExchange( + idAllocator.getNextId(), + LOCAL, + writerRemoteSource); + } + else { + writerSource = partitionedExchange( + idAllocator.getNextId(), + LOCAL, + writerRemoteSource, + partitioningScheme); + } + + String catalogName = tableHandle.getConnectorId().getCatalogName(); + TableMetadata tableMetadata = metadata.getTableMetadata(session, tableHandle); + TableStatisticsMetadata statisticsMetadata = metadata.getStatisticsCollectionMetadataForWrite(session, catalogName, tableMetadata.getMetadata()); + StatisticsAggregationPlanner.TableStatisticAggregation statisticsResult = statisticsAggregationPlanner.createStatisticsAggregation(statisticsMetadata, columnNameToVariable); + StatisticAggregations.Parts aggregations = 
statisticsResult.getAggregations().splitIntoPartialAndFinal(variableAllocator, metadata.getFunctionAndTypeManager()); + PlanNode tableWriterMerge; + + // Disabled by default. Enable when the column statistics are essential for future runtime adaptive plan optimizations + boolean enableStatsCollectionForTemporaryTable = SystemSessionProperties.isEnableStatsCollectionForTemporaryTable(session); + + if (isTableWriterMergeOperatorEnabled(session)) { + StatisticAggregations.Parts localAggregations = aggregations.getPartialAggregation().splitIntoPartialAndIntermediate(variableAllocator, metadata.getFunctionAndTypeManager()); + tableWriterMerge = new TableWriterMergeNode( + sourceLocation, + idAllocator.getNextId(), + gatheringExchange( + idAllocator.getNextId(), + LOCAL, + new TableWriterNode( + sourceLocation, + idAllocator.getNextId(), + writerSource, + Optional.of(insertReference), + variableAllocator.newVariable("partialrows", BIGINT), + variableAllocator.newVariable("partialfragments", VARBINARY), + variableAllocator.newVariable("partialtablecommitcontext", VARBINARY), + outputs, + outputColumnNames, + outputNotNullColumnVariables, + Optional.of(partitioningScheme), + Optional.empty(), + enableStatsCollectionForTemporaryTable ? Optional.of(localAggregations.getPartialAggregation()) : Optional.empty(), + Optional.empty())), + variableAllocator.newVariable("intermediaterows", BIGINT), + variableAllocator.newVariable("intermediatefragments", VARBINARY), + variableAllocator.newVariable("intermediatetablecommitcontext", VARBINARY), + enableStatsCollectionForTemporaryTable ? 
Optional.of(localAggregations.getIntermediateAggregation()) : Optional.empty()); + } + else { + tableWriterMerge = new TableWriterNode( + sourceLocation, + idAllocator.getNextId(), + writerSource, + Optional.of(insertReference), + variableAllocator.newVariable("partialrows", BIGINT), + variableAllocator.newVariable("partialfragments", VARBINARY), + variableAllocator.newVariable("partialtablecommitcontext", VARBINARY), + outputs, + outputColumnNames, + outputNotNullColumnVariables, + Optional.of(partitioningScheme), + Optional.empty(), + enableStatsCollectionForTemporaryTable ? Optional.of(aggregations.getPartialAggregation()) : Optional.empty(), + Optional.empty()); + } + + return new TableFinishNode( + sourceLocation, + idAllocator.getNextId(), + ensureSourceOrderingGatheringExchange( + idAllocator.getNextId(), + REMOTE_STREAMING, + tableWriterMerge), + Optional.of(insertReference), + variableAllocator.newVariable("rows", BIGINT), + enableStatsCollectionForTemporaryTable ? Optional.of(aggregations.getFinalAggregation()) : Optional.empty(), + enableStatsCollectionForTemporaryTable ? 
Optional.of(statisticsResult.getDescriptor()) : Optional.empty()); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/CTEInformationCollector.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/CTEInformationCollector.java index b2f0aaa9008b..3f35872eb618 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/CTEInformationCollector.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/CTEInformationCollector.java @@ -23,9 +23,9 @@ public class CTEInformationCollector { private final HashMap cteInformationMap = new HashMap<>(); - public void addCTEReference(String cteName, boolean isView) + public void addCTEReference(String cteName, boolean isView, boolean isMaterialized) { - cteInformationMap.putIfAbsent(cteName, new CTEInformation(cteName, 0, isView)); + cteInformationMap.putIfAbsent(cteName, new CTEInformation(cteName, 0, isView, isMaterialized)); cteInformationMap.get(cteName).incrementReferences(); } @@ -33,4 +33,9 @@ public List getCTEInformationList() { return ImmutableList.copyOf(cteInformationMap.values()); } + + public HashMap getCteInformationMap() + { + return cteInformationMap; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 292994227531..e9f979c0f3e4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -90,6 +90,8 @@ public class FeaturesConfig private DataSize maxRevocableMemoryPerTask = new DataSize(500, MEGABYTE); private JoinReorderingStrategy joinReorderingStrategy = JoinReorderingStrategy.AUTOMATIC; private PartialMergePushdownStrategy partialMergePushdownStrategy = PartialMergePushdownStrategy.NONE; + + private CteMaterializationStrategy cteMaterializationStrategy = CteMaterializationStrategy.NONE; private int 
/**
 * Strategies for deciding whether common table expressions (CTEs) are materialized.
 */
public enum CteMaterializationStrategy
{
    ALL, // Materialize all CTEs
    NONE // Materialize no CTEs
}
com.facebook.presto.operator.StageExecutionDescriptor; -import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorId; -import com.facebook.presto.spi.ConnectorNewTableLayout; -import com.facebook.presto.spi.Constraint; import com.facebook.presto.spi.PrestoException; -import com.facebook.presto.spi.SchemaTableName; -import com.facebook.presto.spi.SourceLocation; import com.facebook.presto.spi.TableHandle; -import com.facebook.presto.spi.TableMetadata; import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; -import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.OutputNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.SequenceNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.ValuesNode; -import com.facebook.presto.spi.relation.ConstantExpression; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; -import com.facebook.presto.spi.statistics.TableStatisticsMetadata; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; @@ -55,10 +42,8 @@ import com.facebook.presto.sql.planner.plan.PlanFragmentId; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; -import com.facebook.presto.sql.planner.plan.StatisticAggregations; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.facebook.presto.sql.planner.plan.TableFinishNode; -import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import 
com.facebook.presto.sql.planner.plan.TableWriterNode; import com.facebook.presto.sql.planner.sanity.PlanChecker; import com.google.common.base.Preconditions; @@ -67,20 +52,19 @@ import com.google.common.collect.ImmutableSet; import java.util.ArrayList; -import java.util.Collection; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; -import static com.facebook.presto.SystemSessionProperties.getTaskPartitionedWriterCount; import static com.facebook.presto.SystemSessionProperties.isForceSingleNodeOutput; -import static com.facebook.presto.SystemSessionProperties.isTableWriterMergeOperatorEnabled; -import static com.facebook.presto.common.type.BigintType.BIGINT; -import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; +import static com.facebook.presto.sql.TemporaryTableUtil.assignPartitioningVariables; +import static com.facebook.presto.sql.TemporaryTableUtil.assignTemporaryTableColumnNames; +import static com.facebook.presto.sql.TemporaryTableUtil.createTemporaryTableScan; +import static com.facebook.presto.sql.TemporaryTableUtil.createTemporaryTableWriteWithExchanges; import static com.facebook.presto.sql.planner.BasePlanFragmenter.FragmentProperties; import static com.facebook.presto.sql.planner.PlanFragmenterUtils.isCoordinatorOnlyDistribution; import static com.facebook.presto.sql.planner.SchedulingOrderVisitor.scheduleOrder; @@ -89,24 +73,16 @@ import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.isCompatibleSystemPartitioning; import static com.facebook.presto.sql.planner.VariablesExtractor.extractOutputVariables; -import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL; import static 
com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_MATERIALIZED; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION; -import static com.facebook.presto.sql.planner.plan.ExchangeNode.ensureSourceOrderingGatheringExchange; -import static com.facebook.presto.sql.planner.plan.ExchangeNode.gatheringExchange; -import static com.facebook.presto.sql.planner.plan.ExchangeNode.partitionedExchange; import static com.facebook.presto.sql.planner.planPrinter.PlanPrinter.jsonFragmentPlan; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.ImmutableMap.toImmutableMap; -import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static com.google.common.collect.Iterables.concat; import static java.lang.String.format; import static java.util.Objects.requireNonNull; -import static java.util.function.Function.identity; /** * Main rewriter that creates plan fragments @@ -125,6 +101,8 @@ public abstract class BasePlanFragmenter private final Set outputTableWriterNodeIds; private final StatisticsAggregationPlanner statisticsAggregationPlanner; + private Map cteNameToTableScanMap = new HashMap<>(); + public BasePlanFragmenter( Session session, Metadata metadata, @@ -223,6 +201,31 @@ public PlanNode visitTableFinish(TableFinishNode node, RewriteContext context) + { + // Since this is topologically sorted by the LogicalCtePlanner, need to make sure that execution order follows + // Can be optimized further to avoid non dependents from getting blocked + int cteProducerCount = node.getCteProducers().size(); + checkArgument(cteProducerCount >= 1, "Sequence Node has 0 CTE producers"); + PlanNode source = 
node.getCteProducers().get(cteProducerCount - 1); + FragmentProperties childProperties = new FragmentProperties(new PartitioningScheme( + Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), + source.getOutputVariables())); + SubPlan lastSubPlan = buildSubPlan(source, childProperties, context); + + for (int sourceIndex = cteProducerCount - 2; sourceIndex >= 0; sourceIndex--) { + source = node.getCteProducers().get(sourceIndex); + childProperties = new FragmentProperties(new PartitioningScheme( + Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), + source.getOutputVariables())); + childProperties.addChildren(ImmutableList.of(lastSubPlan)); + lastSubPlan = buildSubPlan(source, childProperties, context); + } + context.get().addChildren(ImmutableList.of(lastSubPlan)); + return node.getPrimarySource().accept(this, context); + } + @Override public PlanNode visitMetadataDelete(MetadataDeleteNode node, RewriteContext context) { @@ -325,7 +328,7 @@ private PlanNode createRemoteMaterializedExchange(ExchangeNode exchange, Rewrite "The catalog must support providing a custom partitioning and storing temporary tables.")); Partitioning partitioning = partitioningScheme.getPartitioning(); - PartitioningVariableAssignments partitioningVariableAssignments = assignPartitioningVariables(partitioning); + PartitioningVariableAssignments partitioningVariableAssignments = assignPartitioningVariables(variableAllocator, partitioning); Map variableToColumnMap = assignTemporaryTableColumnNames(exchange.getOutputVariables(), partitioningVariableAssignments.getConstants().keySet()); List partitioningVariables = partitioningVariableAssignments.getVariables(); List partitionColumns = partitioningVariables.stream() @@ -353,6 +356,9 @@ private PlanNode createRemoteMaterializedExchange(ExchangeNode exchange, Rewrite } TableScanNode scan = createTemporaryTableScan( + metadata, + session, + idAllocator, exchange.getSourceLocation(), temporaryTableHandle, 
exchange.getOutputVariables(), @@ -362,7 +368,12 @@ private PlanNode createRemoteMaterializedExchange(ExchangeNode exchange, Rewrite checkArgument( !exchange.getPartitioningScheme().isReplicateNullsAndAny(), "materialized remote exchange is not supported when replicateNullsAndAny is needed"); - TableFinishNode write = createTemporaryTableWrite( + TableFinishNode write = createTemporaryTableWriteWithExchanges( + metadata, + session, + idAllocator, + variableAllocator, + statisticsAggregationPlanner, scan.getSourceLocation(), temporaryTableHandle, variableToColumnMap, @@ -383,240 +394,6 @@ private PlanNode createRemoteMaterializedExchange(ExchangeNode exchange, Rewrite return visitTableScan(scan, context); } - private PartitioningVariableAssignments assignPartitioningVariables(Partitioning partitioning) - { - ImmutableList.Builder variables = ImmutableList.builder(); - ImmutableMap.Builder constants = ImmutableMap.builder(); - for (RowExpression argument : partitioning.getArguments()) { - checkArgument(argument instanceof ConstantExpression || argument instanceof VariableReferenceExpression, format("Expect argument to be ConstantExpression or VariableReferenceExpression, get %s (%s)", argument.getClass(), argument)); - VariableReferenceExpression variable; - if (argument instanceof ConstantExpression) { - variable = variableAllocator.newVariable(argument.getSourceLocation(), "constant_partition", argument.getType()); - constants.put(variable, argument); - } - else { - variable = (VariableReferenceExpression) argument; - } - variables.add(variable); - } - return new PartitioningVariableAssignments(variables.build(), constants.build()); - } - - private Map assignTemporaryTableColumnNames(Collection outputVariables, Collection constantPartitioningVariables) - { - ImmutableMap.Builder result = ImmutableMap.builder(); - int column = 0; - for (VariableReferenceExpression outputVariable : concat(outputVariables, constantPartitioningVariables)) { - String columnName = 
format("_c%d_%s", column, outputVariable.getName()); - result.put(outputVariable, new ColumnMetadata(columnName, outputVariable.getType())); - column++; - } - return result.build(); - } - - private TableScanNode createTemporaryTableScan( - Optional sourceLocation, - TableHandle tableHandle, - List outputVariables, - Map variableToColumnMap, - PartitioningMetadata expectedPartitioningMetadata) - { - Map columnHandles = metadata.getColumnHandles(session, tableHandle); - Map outputColumns = outputVariables.stream() - .collect(toImmutableMap(identity(), variableToColumnMap::get)); - Set outputColumnHandles = outputColumns.values().stream() - .map(ColumnMetadata::getName) - .map(columnHandles::get) - .collect(toImmutableSet()); - - TableLayoutResult selectedLayout = metadata.getLayout(session, tableHandle, Constraint.alwaysTrue(), Optional.of(outputColumnHandles)); - verify(selectedLayout.getUnenforcedConstraint().equals(TupleDomain.all()), "temporary table layout shouldn't enforce any constraints"); - verify(!selectedLayout.getLayout().getColumns().isPresent(), "temporary table layout must provide all the columns"); - TableLayout.TablePartitioning expectedPartitioning = new TableLayout.TablePartitioning( - expectedPartitioningMetadata.getPartitioningHandle(), - expectedPartitioningMetadata.getPartitionColumns().stream() - .map(columnHandles::get) - .collect(toImmutableList())); - verify(selectedLayout.getLayout().getTablePartitioning().equals(Optional.of(expectedPartitioning)), "invalid temporary table partitioning"); - - Map assignments = outputVariables.stream() - .collect(toImmutableMap(identity(), variable -> columnHandles.get(outputColumns.get(variable).getName()))); - - return new TableScanNode( - sourceLocation, - idAllocator.getNextId(), - selectedLayout.getLayout().getNewTableHandle(), - outputVariables, - assignments, - TupleDomain.all(), - TupleDomain.all()); - } - - private TableFinishNode createTemporaryTableWrite( - Optional sourceLocation, TableHandle 
tableHandle, - Map variableToColumnMap, - List outputs, - List> inputs, - List sources, - Map constantExpressions, - PartitioningMetadata partitioningMetadata) - { - if (!constantExpressions.isEmpty()) { - List constantVariables = ImmutableList.copyOf(constantExpressions.keySet()); - - // update outputs - outputs = ImmutableList.builder() - .addAll(outputs) - .addAll(constantVariables) - .build(); - - // update inputs - inputs = inputs.stream() - .map(input -> ImmutableList.builder() - .addAll(input) - .addAll(constantVariables) - .build()) - .collect(toImmutableList()); - - // update sources - sources = sources.stream() - .map(source -> { - Assignments.Builder assignments = Assignments.builder(); - source.getOutputVariables().forEach(variable -> assignments.put(variable, new VariableReferenceExpression(variable.getSourceLocation(), variable.getName(), variable.getType()))); - constantVariables.forEach(variable -> assignments.put(variable, constantExpressions.get(variable))); - return new ProjectNode(source.getSourceLocation(), idAllocator.getNextId(), source, assignments.build(), ProjectNode.Locality.LOCAL); - }) - .collect(toImmutableList()); - } - - NewTableLayout insertLayout = metadata.getInsertLayout(session, tableHandle) - // TODO: support insert into non partitioned table - .orElseThrow(() -> new IllegalArgumentException("insertLayout for the temporary table must be present")); - - PartitioningHandle partitioningHandle = partitioningMetadata.getPartitioningHandle(); - List partitionColumns = partitioningMetadata.getPartitionColumns(); - ConnectorNewTableLayout expectedNewTableLayout = new ConnectorNewTableLayout(partitioningHandle.getConnectorHandle(), partitionColumns); - verify(insertLayout.getLayout().equals(expectedNewTableLayout), "unexpected new table layout"); - - Map columnNameToVariable = variableToColumnMap.entrySet().stream() - .collect(toImmutableMap(entry -> entry.getValue().getName(), Map.Entry::getKey)); - List partitioningVariables = 
partitionColumns.stream() - .map(columnNameToVariable::get) - .collect(toImmutableList()); - - List outputColumnNames = outputs.stream() - .map(variableToColumnMap::get) - .map(ColumnMetadata::getName) - .collect(toImmutableList()); - Set outputNotNullColumnVariables = outputs.stream() - .filter(variable -> variableToColumnMap.get(variable) != null && !(variableToColumnMap.get(variable).isNullable())) - .collect(Collectors.toSet()); - - SchemaTableName schemaTableName = metadata.getTableMetadata(session, tableHandle).getTable(); - TableWriterNode.InsertReference insertReference = new TableWriterNode.InsertReference(tableHandle, schemaTableName); - - PartitioningScheme partitioningScheme = new PartitioningScheme( - Partitioning.create(partitioningHandle, partitioningVariables), - outputs, - Optional.empty(), - false, - Optional.empty()); - - ExchangeNode writerRemoteSource = new ExchangeNode( - sourceLocation, - idAllocator.getNextId(), - REPARTITION, - REMOTE_STREAMING, - partitioningScheme, - sources, - inputs, - false, - Optional.empty()); - - ExchangeNode writerSource; - if (getTaskPartitionedWriterCount(session) == 1) { - writerSource = gatheringExchange( - idAllocator.getNextId(), - LOCAL, - writerRemoteSource); - } - else { - writerSource = partitionedExchange( - idAllocator.getNextId(), - LOCAL, - writerRemoteSource, - partitioningScheme); - } - - String catalogName = tableHandle.getConnectorId().getCatalogName(); - TableMetadata tableMetadata = metadata.getTableMetadata(session, tableHandle); - TableStatisticsMetadata statisticsMetadata = metadata.getStatisticsCollectionMetadataForWrite(session, catalogName, tableMetadata.getMetadata()); - StatisticsAggregationPlanner.TableStatisticAggregation statisticsResult = statisticsAggregationPlanner.createStatisticsAggregation(statisticsMetadata, columnNameToVariable); - StatisticAggregations.Parts aggregations = statisticsResult.getAggregations().splitIntoPartialAndFinal(variableAllocator, 
metadata.getFunctionAndTypeManager()); - PlanNode tableWriterMerge; - - // Disabled by default. Enable when the column statistics are essential for future runtime adaptive plan optimizations - boolean enableStatsCollectionForTemporaryTable = SystemSessionProperties.isEnableStatsCollectionForTemporaryTable(session); - - if (isTableWriterMergeOperatorEnabled(session)) { - StatisticAggregations.Parts localAggregations = aggregations.getPartialAggregation().splitIntoPartialAndIntermediate(variableAllocator, metadata.getFunctionAndTypeManager()); - tableWriterMerge = new TableWriterMergeNode( - sourceLocation, - idAllocator.getNextId(), - gatheringExchange( - idAllocator.getNextId(), - LOCAL, - new TableWriterNode( - sourceLocation, - idAllocator.getNextId(), - writerSource, - Optional.of(insertReference), - variableAllocator.newVariable("partialrows", BIGINT), - variableAllocator.newVariable("partialfragments", VARBINARY), - variableAllocator.newVariable("partialtablecommitcontext", VARBINARY), - outputs, - outputColumnNames, - outputNotNullColumnVariables, - Optional.of(partitioningScheme), - Optional.empty(), - enableStatsCollectionForTemporaryTable ? Optional.of(localAggregations.getPartialAggregation()) : Optional.empty(), - Optional.empty())), - variableAllocator.newVariable("intermediaterows", BIGINT), - variableAllocator.newVariable("intermediatefragments", VARBINARY), - variableAllocator.newVariable("intermediatetablecommitcontext", VARBINARY), - enableStatsCollectionForTemporaryTable ? 
Optional.of(localAggregations.getIntermediateAggregation()) : Optional.empty()); - } - else { - tableWriterMerge = new TableWriterNode( - sourceLocation, - idAllocator.getNextId(), - writerSource, - Optional.of(insertReference), - variableAllocator.newVariable("partialrows", BIGINT), - variableAllocator.newVariable("partialfragments", VARBINARY), - variableAllocator.newVariable("partialtablecommitcontext", VARBINARY), - outputs, - outputColumnNames, - outputNotNullColumnVariables, - Optional.of(partitioningScheme), - Optional.empty(), - enableStatsCollectionForTemporaryTable ? Optional.of(aggregations.getPartialAggregation()) : Optional.empty(), - Optional.empty()); - } - - return new TableFinishNode( - sourceLocation, - idAllocator.getNextId(), - ensureSourceOrderingGatheringExchange( - idAllocator.getNextId(), - REMOTE_STREAMING, - tableWriterMerge), - Optional.of(insertReference), - variableAllocator.newVariable("rows", BIGINT), - enableStatsCollectionForTemporaryTable ? Optional.of(aggregations.getFinalAggregation()) : Optional.empty(), - enableStatsCollectionForTemporaryTable ? 
Optional.of(statisticsResult.getDescriptor()) : Optional.empty()); - } - private SubPlan buildSubPlan(PlanNode node, FragmentProperties properties, RewriteContext context) { PlanFragmentId planFragmentId = nextFragmentId(); @@ -756,12 +533,12 @@ public Set getPartitionedSources() } } - private static class PartitioningVariableAssignments + public static class PartitioningVariableAssignments { private final List variables; private final Map constants; - private PartitioningVariableAssignments(List variables, Map constants) + public PartitioningVariableAssignments(List variables, Map constants) { this.variables = ImmutableList.copyOf(requireNonNull(variables, "variables is null")); this.constants = ImmutableMap.copyOf(requireNonNull(constants, "constants is null")); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/OutputExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/OutputExtractor.java index 6a6981110112..979c056ece97 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/OutputExtractor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/OutputExtractor.java @@ -17,6 +17,7 @@ import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.TableWriterNode; import com.google.common.base.VerifyException; @@ -61,6 +62,15 @@ public Void visitTableWriter(TableWriterNode node, Void context) return null; } + @Override + public Void visitSequence(SequenceNode node, Void context) + { + // Left children of sequence are ignored since they don't output anything + node.getPrimarySource().accept(this, context); + + return null; + } + @Override public Void visitPlan(PlanNode node, Void context) { diff --git
a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 57caee0d9b24..ac61a9c46824 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -144,12 +144,14 @@ import com.facebook.presto.sql.planner.optimizations.IndexJoinOptimizer; import com.facebook.presto.sql.planner.optimizations.KeyBasedSampler; import com.facebook.presto.sql.planner.optimizations.LimitPushDown; +import com.facebook.presto.sql.planner.optimizations.LogicalCteOptimizer; import com.facebook.presto.sql.planner.optimizations.MergeJoinForSortedInputOptimizer; import com.facebook.presto.sql.planner.optimizations.MergePartialAggregationsWithFilter; import com.facebook.presto.sql.planner.optimizations.MetadataDeleteOptimizer; import com.facebook.presto.sql.planner.optimizations.MetadataQueryOptimizer; import com.facebook.presto.sql.planner.optimizations.OptimizeMixedDistinctAggregations; import com.facebook.presto.sql.planner.optimizations.PayloadJoinOptimizer; +import com.facebook.presto.sql.planner.optimizations.PhysicalCteOptimizer; import com.facebook.presto.sql.planner.optimizations.PlanOptimizer; import com.facebook.presto.sql.planner.optimizations.PredicatePushDown; import com.facebook.presto.sql.planner.optimizations.PrefilterForLimitingAggregation; @@ -275,6 +277,8 @@ public PlanOptimizers( new PruneLimitColumns(), new PruneTableScanColumns()); + builder.add(new LogicalCteOptimizer(metadata)); + IterativeOptimizer inlineProjections = new IterativeOptimizer( metadata, ruleStats, @@ -760,6 +764,7 @@ public PlanOptimizers( statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.of(new PushTableWriteThroughUnion()))); // Must run before AddExchanges + builder.add(new PhysicalCteOptimizer(metadata)); // Must run before AddExchanges builder.add(new 
StatsRecordingPlanOptimizer(optimizerStats, new AddExchanges(metadata, sqlParser, partitioningProviderManager))); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index dbf53c24f900..986f84c5e517 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -28,6 +28,7 @@ import com.facebook.presto.spi.constraints.TableConstraint; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.CteReferenceNode; import com.facebook.presto.spi.plan.ExceptNode; import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.IntersectNode; @@ -106,6 +107,7 @@ import java.util.Set; import java.util.stream.IntStream; +import static com.facebook.presto.SystemSessionProperties.getCteMaterializationStrategy; import static com.facebook.presto.SystemSessionProperties.getQueryAnalyzerTimeout; import static com.facebook.presto.common.type.TypeUtils.isEnumType; import static com.facebook.presto.metadata.MetadataUtil.createQualifiedObjectName; @@ -116,9 +118,11 @@ import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.getSourceLocation; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.isEqualComparisonExpression; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.resolveEnumLiteral; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.CteMaterializationStrategy.ALL; import static com.facebook.presto.sql.analyzer.SemanticExceptions.notSupportedException; import static com.facebook.presto.sql.planner.PlannerUtils.newVariable; import static com.facebook.presto.sql.planner.TranslateExpressionsUtil.toRowExpression; +import static 
com.facebook.presto.sql.planner.optimizations.CteUtils.isCteMaterializable; import static com.facebook.presto.sql.tree.Join.Type.INNER; import static com.facebook.presto.sql.tree.Join.Type.LEFT; import static com.facebook.presto.sql.tree.Join.Type.RIGHT; @@ -180,8 +184,18 @@ protected RelationPlan visitTable(Table node, SqlPlannerContext context) if (namedQuery.isFromView()) { cteName = createQualifiedObjectName(session, node, node.getName()).toString(); } - session.getCteInformationCollector().addCTEReference(cteName, namedQuery.isFromView()); + context.getNestedCteStack().push(cteName, namedQuery.getQuery()); RelationPlan subPlan = process(namedQuery.getQuery(), context); + context.getNestedCteStack().pop(namedQuery.getQuery()); + boolean shouldBeMaterialized = getCteMaterializationStrategy(session).equals(ALL) && isCteMaterializable(subPlan.getRoot().getOutputVariables()); + session.getCteInformationCollector().addCTEReference(cteName, namedQuery.isFromView(), shouldBeMaterialized); + if (shouldBeMaterialized) { + subPlan = new RelationPlan( + new CteReferenceNode(getSourceLocation(node.getLocation()), + idAllocator.getNextId(), subPlan.getRoot(), context.getNestedCteStack().getRawPath(cteName)), + subPlan.getScope(), + subPlan.getFieldMappings()); + } // Add implicit coercions if view query produces types that don't match the declared output types // of the view (e.g., if the underlying tables referenced by the view changed) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SqlPlannerContext.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SqlPlannerContext.java index d6853d343dc8..cb1f538e6c4b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/SqlPlannerContext.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SqlPlannerContext.java @@ -16,6 +16,12 @@ import com.facebook.presto.Session; import com.facebook.presto.spi.PrestoException; import 
com.facebook.presto.sql.relational.SqlToRowExpressionTranslator; +import com.facebook.presto.sql.tree.Query; +import com.google.common.annotations.VisibleForTesting; + +import java.util.HashMap; +import java.util.Map; +import java.util.Stack; import static com.facebook.presto.SystemSessionProperties.getMaxLeafNodesInPlan; import static com.facebook.presto.SystemSessionProperties.isLeafNodeLimitEnabled; @@ -28,10 +34,18 @@ public class SqlPlannerContext private int leafNodesInLogicalPlan; private final SqlToRowExpressionTranslator.Context translatorContext; + private final NestedCteStack nestedCteStack; + public SqlPlannerContext(int leafNodesInLogicalPlan) { this.leafNodesInLogicalPlan = leafNodesInLogicalPlan; this.translatorContext = new SqlToRowExpressionTranslator.Context(); + this.nestedCteStack = new NestedCteStack(); + } + + public NestedCteStack getNestedCteStack() + { + return nestedCteStack; } public SqlToRowExpressionTranslator.Context getTranslatorContext() @@ -49,4 +63,58 @@ public void incrementLeafNodes(Session session) } } } + + public class NestedCteStack + { + @VisibleForTesting + public static final String delimiter = "_*%$_"; + private final Stack cteStack; + private final Map rawCtePathMap; + + public NestedCteStack() + { + this.cteStack = new Stack<>(); + this.rawCtePathMap = new HashMap<>(); + } + + public void push(String cteName, Query query) + { + this.cteStack.push(cteName); + if (query.getWith().isPresent()) { + // All ctes defined in this context should have their paths updated + query.getWith().get().getQueries().forEach(with -> this.addNestedCte(with.getName().toString())); + } + } + + public void pop(Query query) + { + this.cteStack.pop(); + if (query.getWith().isPresent()) { + query.getWith().get().getQueries().forEach(with -> this.removeNestedCte(with.getName().toString())); + } + } + + public String getRawPath(String cteName) + { + if (!this.rawCtePathMap.containsKey(cteName)) { + return cteName; + } + return 
this.rawCtePathMap.get(cteName); + } + + private void addNestedCte(String cteName) + { + this.rawCtePathMap.put(cteName, getCurrentRelativeCtePath() + delimiter + cteName); + } + + private void removeNestedCte(String cteName) + { + this.rawCtePathMap.remove(cteName); + } + + public String getCurrentRelativeCtePath() + { + return String.join(delimiter, cteStack); + } + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineJoinDistributionType.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineJoinDistributionType.java index bab91d2c6be9..e57025d7259a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineJoinDistributionType.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineJoinDistributionType.java @@ -22,6 +22,7 @@ import com.facebook.presto.cost.TaskCountEstimator; import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.plan.CteConsumerNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.ValuesNode; @@ -198,7 +199,7 @@ static double getSourceTablesSizeInBytes(PlanNode node, Lookup lookup, StatsProv } List sourceNodes = PlanNodeSearcher.searchFrom(node, lookup) - .whereIsInstanceOfAny(ImmutableList.of(TableScanNode.class, ValuesNode.class, RemoteSourceNode.class)) + .whereIsInstanceOfAny(ImmutableList.of(TableScanNode.class, ValuesNode.class, RemoteSourceNode.class, CteConsumerNode.class)) .findAll(); return sourceNodes.stream() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java index 061a3023a8ef..ddff237a1db2 100644 --- 
a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java @@ -36,6 +36,7 @@ import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.SequenceNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.UnionNode; @@ -95,6 +96,7 @@ import java.util.Optional; import java.util.Set; import java.util.function.Function; +import java.util.stream.Stream; import static com.facebook.presto.SystemSessionProperties.getAggregationPartitioningMergingStrategy; import static com.facebook.presto.SystemSessionProperties.getExchangeMaterializationStrategy; @@ -616,6 +618,22 @@ public PlanWithProperties visitFilter(FilterNode node, PreferredProperties prefe return rebaseAndDeriveProperties(node, planChild(node, preferredProperties)); } + @Override + public PlanWithProperties visitSequence(SequenceNode node, PreferredProperties preferredProperties) + { + List leftPlans = node.getCteProducers().stream() + .map(source -> accept(source, PreferredProperties.any())) + .collect(toImmutableList()); + PlanWithProperties rightPlan = accept(node.getPrimarySource(), preferredProperties); + List childrenNodes = Stream.concat( + leftPlans.stream().map(PlanWithProperties::getNode), + Stream.of(rightPlan.getNode()) + ).collect(toImmutableList()); + return new PlanWithProperties( + node.replaceChildren(childrenNodes), + rightPlan.getProperties()); + } + @Override public PlanWithProperties visitTableScan(TableScanNode node, PreferredProperties preferredProperties) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyConnectorOptimization.java 
b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyConnectorOptimization.java index 3309f292d585..d5144ec97a4a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyConnectorOptimization.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyConnectorOptimization.java @@ -19,6 +19,9 @@ import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.CteConsumerNode; +import com.facebook.presto.spi.plan.CteProducerNode; +import com.facebook.presto.spi.plan.CteReferenceNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.ExceptNode; import com.facebook.presto.spi.plan.FilterNode; @@ -55,6 +58,9 @@ public class ApplyConnectorOptimization implements PlanOptimizer { static final Set> CONNECTOR_ACCESSIBLE_PLAN_NODES = ImmutableSet.of( + CteProducerNode.class, + CteConsumerNode.class, + CteReferenceNode.class, DistinctLimitNode.class, FilterNode.class, TableScanNode.class, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/CteUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/CteUtils.java new file mode 100644 index 000000000000..ac4a140bef28 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/CteUtils.java @@ -0,0 +1,76 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.common.type.RowType; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.common.type.Varchars; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.relation.VariableReferenceExpression; + +import java.util.List; + +import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; + +public class CteUtils +{ + private CteUtils() + { + } + + // Determines whether the CTE can be materialized. + public static boolean isCteMaterializable(List outputVariables) + { + return outputVariables.stream().anyMatch(CteUtils::isVariableMaterializable) + && outputVariables.stream() + .allMatch(variableReferenceExpression -> { + if (Varchars.isVarcharType(variableReferenceExpression.getType())) { + return isSupportedVarcharType((VarcharType) variableReferenceExpression.getType()); + } + return true; + }); + } + + /* + Fetches the index of the first variable that can be materialized. + ToDo: Implement usage of NDV (number of distinct values) statistics to identify the best partitioning variable, + as temporary tables are bucketed. + */ + public static Integer getCtePartitionIndex(List outputVariables) + { + for (int i = 0; i < outputVariables.size(); i++) { + if (isVariableMaterializable(outputVariables.get(i))) { + return i; + } + } + throw new PrestoException(GENERIC_INTERNAL_ERROR, "No Partitioning index found"); + } + + /* + Currently, Hive bucketing does not support the Presto type 'ROW'. + */ + public static boolean isVariableMaterializable(VariableReferenceExpression var) + { + return !(var.getType() instanceof RowType); + } + + /* + While Presto supports Varchar of length 0 (as discussed in https://github.com/trinodb/trino/issues/1136), + Hive does not support this. 
+ */ + private static boolean isSupportedVarcharType(VarcharType varcharType) + { + return (varcharType.isUnbounded() || varcharType.getLengthSafe() != 0); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java index 560f38ca8eee..e62d20849f97 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java @@ -25,6 +25,7 @@ import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.SequenceNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.UnionNode; import com.facebook.presto.spi.relation.CallExpression; @@ -166,6 +167,23 @@ public PlanWithProperties visitApply(ApplyNode node, HashComputationSet context) return new PlanWithProperties(node, ImmutableMap.of()); } + @Override + public PlanWithProperties visitSequence(SequenceNode node, HashComputationSet context) + { + // Each CTE producer is planned against fresh, empty hash computation sets; only the primary source's hash variables are reported upward + List cteProducers = node.getCteProducers().stream() + .map(c -> + planAndEnforce(c, new HashComputationSet(), true, new HashComputationSet()).getNode()) + .collect(ImmutableList.toImmutableList()); + PlanWithProperties primarySource = plan(node.getPrimarySource(), context); + return new PlanWithProperties( + replaceChildren(node, ImmutableList.builder() + .addAll(cteProducers) + .add(primarySource.getNode()) + .build()), + primarySource.getHashVariables()); + } + @Override public PlanWithProperties visitLateralJoin(LateralJoinNode node, HashComputationSet context) { diff --git
b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/LogicalCteOptimizer.java new file mode 100644 index 000000000000..408bc4939ee0 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/LogicalCteOptimizer.java @@ -0,0 +1,236 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.VariableAllocator; +import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.eventlistener.CTEInformation; +import com.facebook.presto.spi.plan.CteConsumerNode; +import com.facebook.presto.spi.plan.CteProducerNode; +import com.facebook.presto.spi.plan.CteReferenceNode; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.plan.SequenceNode; +import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.sql.planner.plan.ApplyNode; +import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.graph.GraphBuilder; +import com.google.common.graph.MutableGraph; +import com.google.common.graph.Traverser; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import 
java.util.Optional; +import java.util.Stack; + +import static com.facebook.presto.SystemSessionProperties.getCteMaterializationStrategy; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.CteMaterializationStrategy.ALL; +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +/* + * Transformation of CTE Reference Nodes: + * This process converts CTE reference nodes into corresponding CteProducers and Consumers. + * Makes sure that execution deadlocks do not exist + * + * Example: + * Before Transformation: + * JOIN + * |-- CTEReference(cte2) + * | `-- TABLESCAN2 + * `-- CTEReference(cte3) + * `-- TABLESCAN3 + * + * After Transformation: + * SEQUENCE(cte1) + * |-- CTEProducer(cte2) + * | `-- TABLESCAN2 + * |-- CTEProducer(cte3) + * | `-- TABLESCAN3 + * `-- JOIN + * |-- CTEConsumer(cte2) + * `-- CTEConsumer(cte3) + */ +public class LogicalCteOptimizer + implements PlanOptimizer +{ + private final Metadata metadata; + + public LogicalCteOptimizer(Metadata metadata) + { + this.metadata = metadata; + } + + @Override + public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + { + requireNonNull(plan, "plan is null"); + requireNonNull(session, "session is null"); + requireNonNull(variableAllocator, "variableAllocator is null"); + requireNonNull(idAllocator, "idAllocator is null"); + requireNonNull(warningCollector, "warningCollector is null"); + if (!getCteMaterializationStrategy(session).equals(ALL) + || session.getCteInformationCollector().getCTEInformationList().stream().noneMatch(CTEInformation::isMaterialized)) { + return PlanOptimizerResult.optimizerResult(plan, false); + } + CteEnumerator cteEnumerator = new CteEnumerator(idAllocator, variableAllocator); + PlanNode rewrittenPlan = 
cteEnumerator.transformPersistentCtes(plan); + return PlanOptimizerResult.optimizerResult(rewrittenPlan, cteEnumerator.isPlanRewritten()); + } + + public class CteEnumerator // wraps the root's single child in a SequenceNode that lists the collected CTE producers ahead of the rewritten plan + { + private final PlanNodeIdAllocator planNodeIdAllocator; + private final VariableAllocator variableAllocator; + + private boolean isPlanRewritten; // set by transformPersistentCtes; true only when a SequenceNode was inserted + + public CteEnumerator(PlanNodeIdAllocator planNodeIdAllocator, VariableAllocator variableAllocator) + { + this.planNodeIdAllocator = requireNonNull(planNodeIdAllocator, "planNodeIdAllocator must not be null"); + this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator must not be null"); + } + + public PlanNode transformPersistentCtes(PlanNode root) + { + checkArgument(root.getSources().size() == 1, "expected root to have exactly 1 source"); // the SequenceNode replaces the root's only child below + CteTransformerContext context = new CteTransformerContext(); + PlanNode transformedCte = SimplePlanRewriter.rewriteWith(new CteConsumerTransformer(planNodeIdAllocator, variableAllocator), + root, context); + List topologicalOrderedList = context.getTopologicalOrdering(); + if (topologicalOrderedList.isEmpty()) { // empty ordering: hand back the original root unchanged + isPlanRewritten = false; + return root; + } + isPlanRewritten = true; + SequenceNode sequenceNode = new SequenceNode(root.getSourceLocation(), planNodeIdAllocator.getNextId(), topologicalOrderedList, + transformedCte.getSources().get(0)); + return root.replaceChildren(Arrays.asList(sequenceNode)); + } + + public boolean isPlanRewritten() + { + return isPlanRewritten; + } + } + + public class CteConsumerTransformer + extends SimplePlanRewriter + { + private final PlanNodeIdAllocator idAllocator; + + private final VariableAllocator variableAllocator; + + public CteConsumerTransformer(PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator) + { + this.idAllocator = requireNonNull(idAllocator, "idAllocator must not be null"); + this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator must not be null"); + } + + @Override + public PlanNode
visitCteReference(CteReferenceNode node, RewriteContext context) + { + context.get().addDependency(node.getCteName()); + context.get().pushActiveCte(node.getCteName()); + // So that dependent CTEs are processed properly + PlanNode actualSource = context.rewrite(node.getSource(), context.get()); + context.get().popActiveCte(); + CteProducerNode cteProducerSource = new CteProducerNode(node.getSourceLocation(), + idAllocator.getNextId(), + actualSource, + node.getCteName(), + variableAllocator.newVariable("rows", BIGINT), node.getOutputVariables()); + context.get().addProducer(node.getCteName(), cteProducerSource); + return new CteConsumerNode(node.getSourceLocation(), idAllocator.getNextId(), actualSource.getOutputVariables(), node.getCteName()); + } + + @Override + public PlanNode visitApply(ApplyNode node, RewriteContext context) + { + return new ApplyNode(node.getSourceLocation(), + idAllocator.getNextId(), + context.rewrite(node.getInput(), + context.get()), + context.rewrite(node.getSubquery(), + context.get()), + node.getSubqueryAssignments(), + node.getCorrelation(), + node.getOriginSubqueryError(), + node.getMayParticipateInAntiJoin()); + }} + + public class CteTransformerContext + { + public Map cteProducerMap; + + // a -> b indicates that b needs to be processed before a + MutableGraph graph; + public Stack activeCteStack; + + public CteTransformerContext() + { + cteProducerMap = new HashMap<>(); + // The cte graph will never have cycles because sql won't allow it + graph = GraphBuilder.directed().build(); + activeCteStack = new Stack<>(); + } + + public Map getCteProducerMap() + { + return ImmutableMap.copyOf(cteProducerMap); + } + + public void addProducer(String cteName, CteProducerNode cteProducer) + { + cteProducerMap.putIfAbsent(cteName, cteProducer); + } + + public void pushActiveCte(String cte) + { + this.activeCteStack.push(cte); + } + + public String popActiveCte() + { + return this.activeCteStack.pop(); + } + + public Optional peekActiveCte() + { 
+ return (this.activeCteStack.isEmpty()) ? Optional.empty() : Optional.ofNullable(this.activeCteStack.peek()); + } + + public void addDependency(String currentCte) + { + graph.addNode(currentCte); + Optional parentCte = peekActiveCte(); + parentCte.ifPresent(s -> graph.putEdge(currentCte, s)); + } + + public List getTopologicalOrdering() + { + ImmutableList.Builder topSortedCteProducerListBuilder = ImmutableList.builder(); + Traverser.forGraph(graph).depthFirstPostOrder(graph.nodes()) + .forEach(cteName -> topSortedCteProducerListBuilder.add(cteProducerMap.get(cteName))); + return topSortedCteProducerListBuilder.build(); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PhysicalCteOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PhysicalCteOptimizer.java new file mode 100644 index 000000000000..f9388d766b2e --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PhysicalCteOptimizer.java @@ -0,0 +1,307 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.Session; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.PartitioningMetadata; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.TableHandle; +import com.facebook.presto.spi.VariableAllocator; +import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.eventlistener.CTEInformation; +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.CteConsumerNode; +import com.facebook.presto.spi.plan.CteProducerNode; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.BasePlanFragmenter; +import com.facebook.presto.sql.planner.Partitioning; +import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; +import com.google.common.collect.ImmutableList; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.SystemSessionProperties.getCteMaterializationStrategy; +import static com.facebook.presto.SystemSessionProperties.getHashPartitionCount; +import static com.facebook.presto.SystemSessionProperties.getPartitioningProviderCatalog; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; +import static com.facebook.presto.sql.TemporaryTableUtil.assignPartitioningVariables; +import static com.facebook.presto.sql.TemporaryTableUtil.assignTemporaryTableColumnNames; 
+import static com.facebook.presto.sql.TemporaryTableUtil.createTemporaryTableScan;
+import static com.facebook.presto.sql.TemporaryTableUtil.createTemporaryTableWriteWithoutExchanges;
+import static com.facebook.presto.sql.analyzer.FeaturesConfig.CteMaterializationStrategy.ALL;
+import static com.facebook.presto.sql.planner.optimizations.CteUtils.getCtePartitionIndex;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static java.lang.String.format;
+import static java.util.Objects.requireNonNull;
+
+/*
+ * PhysicalCteOptimizer Transformation:
+ * This optimizer modifies the logical plan by transforming CTE producers into table writes
+ * and CTE consumers into table scans.
+ *
+ * Example:
+ * Before Transformation:
+ *   CTEProducer(cteX)
+ *   |-- SomeOperation
+ *   `-- CTEConsumer(cteX)
+ *
+ * After Transformation:
+ *   TableWrite(cteX)
+ *   |-- SomeOperation
+ *   `-- TableScan(cteX)
+ */
+public class PhysicalCteOptimizer
+        implements PlanOptimizer
+{
+    private final Metadata metadata;
+
+    public PhysicalCteOptimizer(Metadata metadata)
+    {
+        this.metadata = requireNonNull(metadata, "metadata is null");
+    }
+
+    /**
+     * Rewrites the plan only when the CTE materialization strategy is ALL and at least one CTE
+     * recorded in the session is marked as materialized; otherwise the plan is returned unchanged.
+     */
+    @Override
+    public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
+    {
+        requireNonNull(plan, "plan is null");
+        requireNonNull(session, "session is null");
+        requireNonNull(variableAllocator, "variableAllocator is null");
+        requireNonNull(idAllocator, "idAllocator is null");
+        requireNonNull(warningCollector, "warningCollector is null");
+        if (!getCteMaterializationStrategy(session).equals(ALL)
+                || session.getCteInformationCollector().getCTEInformationList().stream().noneMatch(CTEInformation::isMaterialized)) {
+            return PlanOptimizerResult.optimizerResult(plan, false);
+        }
+        PhysicalCteTransformerContext context = new PhysicalCteTransformerContext();
+        CteProducerRewiter cteProducerRewiter = new CteProducerRewiter(session, idAllocator, variableAllocator);
+        CteConsumerRewrite cteConsumerRewrite = new CteConsumerRewrite(session, idAllocator, variableAllocator);
+        // Producers are rewritten first: they register the temporary tables that the consumer rewrite looks up
+        PlanNode producerReplaced = SimplePlanRewriter.rewriteWith(cteProducerRewiter, plan, context);
+        PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(cteConsumerRewrite, producerReplaced, context);
+        return PlanOptimizerResult.optimizerResult(rewrittenPlan,
+                cteConsumerRewrite.isPlanRewritten() || cteProducerRewiter.isPlanRewritten());
+    }
+
+    /**
+     * Replaces each CteProducerNode with a write into a newly created temporary table and
+     * records a scan template for that table in the context, keyed by CTE name.
+     */
+    public class CteProducerRewiter
+            extends SimplePlanRewriter<PhysicalCteTransformerContext>
+    {
+        private final PlanNodeIdAllocator idAllocator;
+
+        private final VariableAllocator variableAllocator;
+
+        private final Session session;
+
+        private boolean isPlanRewritten;
+
+        public CteProducerRewiter(Session session, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator)
+        {
+            this.idAllocator = requireNonNull(idAllocator, "idAllocator must not be null");
+            this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator must not be null");
+            this.session = requireNonNull(session, "session must not be null");
+        }
+
+        @Override
+        public PlanNode visitCteProducer(CteProducerNode node, RewriteContext<PhysicalCteTransformerContext> context)
+        {
+            isPlanRewritten = true;
+            // Create Table Metadata
+            PlanNode actualSource = node.getSource();
+            VariableReferenceExpression partitionVariable = actualSource.getOutputVariables()
+                    .get(getCtePartitionIndex(actualSource.getOutputVariables()));
+            List<Type> partitioningTypes = Arrays.asList(partitionVariable.getType());
+            String partitioningProviderCatalog = getPartitioningProviderCatalog(session);
+            // The column selected by getCtePartitionIndex is the single partitioning column
+            Partitioning partitioning = Partitioning.create(
+                    metadata.getPartitioningHandleForExchange(session, partitioningProviderCatalog,
+                            getHashPartitionCount(session), partitioningTypes),
+                    Arrays.asList(partitionVariable));
+            BasePlanFragmenter.PartitioningVariableAssignments partitioningVariableAssignments
+                    = assignPartitioningVariables(variableAllocator, partitioning);
+            Map<VariableReferenceExpression, ColumnMetadata> variableToColumnMap =
+                    assignTemporaryTableColumnNames(actualSource.getOutputVariables(),
+                            partitioningVariableAssignments.getConstants().keySet());
+            List<VariableReferenceExpression> partitioningVariables = partitioningVariableAssignments.getVariables();
+            List<String> partitionColumns = partitioningVariables.stream()
+                    .map(variable -> variableToColumnMap.get(variable).getName())
+                    .collect(toImmutableList());
+            PartitioningMetadata partitioningMetadata = new PartitioningMetadata(partitioning.getHandle(), partitionColumns);
+
+            TableHandle temporaryTableHandle;
+            try {
+                temporaryTableHandle = metadata.createTemporaryTable(
+                        session,
+                        partitioningProviderCatalog,
+                        ImmutableList.copyOf(variableToColumnMap.values()),
+                        Optional.of(partitioningMetadata));
+                // Remember the scan of the new table so that consumers of this CTE can be rewritten later
+                context.get().put(node.getCteName(),
+                        new PhysicalCteTransformerContext.TemporaryTableInfo(
+                                createTemporaryTableScan(
+                                        metadata,
+                                        session,
+                                        idAllocator,
+                                        node.getSourceLocation(),
+                                        temporaryTableHandle,
+                                        actualSource.getOutputVariables(),
+                                        variableToColumnMap,
+                                        partitioningMetadata), node.getOutputVariables()));
+            }
+            catch (PrestoException e) {
+                if (e.getErrorCode().equals(NOT_SUPPORTED.toErrorCode())) {
+                    // Re-throw with the catalog name so the failure is actionable; the cause is preserved
+                    throw new PrestoException(
+                            NOT_SUPPORTED,
+                            format("Temporary table cannot be created in catalog \"%s\": %s", partitioningProviderCatalog, e.getMessage()),
+                            e);
+                }
+                throw e;
+            }
+            // Create the writer
+            return createTemporaryTableWriteWithoutExchanges(
+                    metadata,
+                    session,
+                    idAllocator,
+                    variableAllocator,
+                    actualSource,
+                    temporaryTableHandle,
+                    actualSource.getOutputVariables(),
+                    variableToColumnMap,
+                    partitioningMetadata,
+                    node.getRowCountVariable());
+        }
+
+        public boolean isPlanRewritten()
+        {
+            return isPlanRewritten;
+        }
+    }
+
+    /**
+     * Replaces each CteConsumerNode with a projection over a fresh scan of the temporary table
+     * that the producer rewrite registered for the same CTE name.
+     */
+    public class CteConsumerRewrite
+            extends SimplePlanRewriter<PhysicalCteTransformerContext>
+    {
+        private final PlanNodeIdAllocator idAllocator;
+
+        private final VariableAllocator variableAllocator;
+
+        private final Session session;
+
+        private boolean isPlanRewritten;
+
+        public CteConsumerRewrite(Session session, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator)
+        {
+            this.idAllocator = requireNonNull(idAllocator, "idAllocator must not be null");
+            this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator must not be null");
+            this.session = requireNonNull(session, "session must not be null");
+        }
+
+        @Override
+        public PlanNode visitCteConsumer(CteConsumerNode node, RewriteContext<PhysicalCteTransformerContext> context)
+        {
+            isPlanRewritten = true;
+            // Create Table Metadata
+            PhysicalCteTransformerContext.TemporaryTableInfo tableInfo = requireNonNull(
+                    context.get().getTableInfo(node.getCteName()),
+                    format("No temporary table was registered for CTE \"%s\"", node.getCteName()));
+            TableScanNode tempScan = tableInfo.getTableScanNode();
+
+            // Need to create new Variables for temp table scans to avoid duplicate references
+            List<VariableReferenceExpression> newOutputVariables = new ArrayList<>();
+            Map<VariableReferenceExpression, ColumnHandle> newColumnAssignmentsMap = new HashMap<>();
+            for (VariableReferenceExpression oldVariable : tempScan.getOutputVariables()) {
+                VariableReferenceExpression newVariable = variableAllocator.newVariable(oldVariable);
+                newOutputVariables.add(newVariable);
+                newColumnAssignmentsMap.put(newVariable, tempScan.getAssignments().get(oldVariable));
+            }
+
+            TableScanNode tableScanNode = new TableScanNode(
+                    Optional.empty(),
+                    idAllocator.getNextId(),
+                    tempScan.getTable(),
+                    newOutputVariables,
+                    newColumnAssignmentsMap,
+                    tempScan.getCurrentConstraint(),
+                    tempScan.getEnforcedConstraint());
+
+            // The temporary table scan might have columns removed by the UnaliasSymbolReferences and other optimizers (it's a plan tree after all),
+            // use originalOutputVariables (which are also canonicalized and maintained) and add them back
+            Map<VariableReferenceExpression, VariableReferenceExpression> intermediateReferenceMap = new HashMap<>();
+            for (int i = 0; i < tempScan.getOutputVariables().size(); i++) {
+                intermediateReferenceMap.put(tempScan.getOutputVariables().get(i), newOutputVariables.get(i));
+            }
+
+            // Consumer outputs are matched to the producer's original outputs by position
+            Assignments.Builder assignments = Assignments.builder();
+            for (int i = 0; i < tableInfo.getOriginalOutputVariables().size(); i++) {
+                assignments.put(node.getOutputVariables().get(i), intermediateReferenceMap.get(tableInfo.getOriginalOutputVariables().get(i)));
+            }
+            return new ProjectNode(Optional.empty(), idAllocator.getNextId(), Optional.empty(),
+                    tableScanNode, assignments.build(), ProjectNode.Locality.LOCAL);
+        }
+
+        public boolean isPlanRewritten()
+        {
+            return isPlanRewritten;
+        }
+    }
+
+    /**
+     * State shared between the producer and consumer rewrites: the temporary table created
+     * for each CTE name.
+     */
+    public static class PhysicalCteTransformerContext
+    {
+        private final Map<String, TemporaryTableInfo> cteNameToTableInfo = new HashMap<>();
+
+        public void put(String cteName, TemporaryTableInfo handle)
+        {
+            cteNameToTableInfo.put(cteName, handle);
+        }
+
+        /**
+         * @return the info registered for {@code cteName}, or {@code null} if none was registered
+         */
+        public TemporaryTableInfo getTableInfo(String cteName)
+        {
+            return cteNameToTableInfo.get(cteName);
+        }
+
+        /**
+         * Scan of a CTE's temporary table together with the output variables of the producer it replaced.
+         */
+        public static class TemporaryTableInfo
+        {
+            private final TableScanNode tableScanNode;
+            private final List<VariableReferenceExpression> originalOutputVariables;
+
+            public TemporaryTableInfo(TableScanNode tableScanNode, List<VariableReferenceExpression> originalOutputVariables)
+            {
+                this.tableScanNode = requireNonNull(tableScanNode, "tableScanNode must not be null");
+                this.originalOutputVariables = requireNonNull(originalOutputVariables, "originalOutputVariables must not be null");
+            }
+
+            public List<VariableReferenceExpression> getOriginalOutputVariables()
+            {
+                return originalOutputVariables;
+            }
+
+            public TableScanNode getTableScanNode()
+            {
+                return tableScanNode;
+            }
+        }
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java
index 1bee9ff07cbf..460c5631c9bc 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java
@@ -32,6 +32,7 @@
 import com.facebook.presto.spi.plan.OutputNode;
 import 
com.facebook.presto.spi.plan.PlanNode;
 import com.facebook.presto.spi.plan.ProjectNode;
+import com.facebook.presto.spi.plan.SequenceNode;
 import com.facebook.presto.spi.plan.TableScanNode;
 import com.facebook.presto.spi.plan.TopNNode;
 import com.facebook.presto.spi.plan.ValuesNode;
@@ -756,6 +757,13 @@ public ActualProperties visitValues(ValuesNode node, List<ActualProperties> cont
                     .build();
         }
 
+        @Override
+        public ActualProperties visitSequence(SequenceNode node, List<ActualProperties> context)
+        {
+            // Return the properties of the rightmost (last) source; the other sources' properties are ignored
+            return context.get(context.size() - 1);
+        }
+
         @Override
         public ActualProperties visitTableScan(TableScanNode node, List<ActualProperties> inputProperties)
         {
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java
index ebec09098fe3..65df155ae2dd 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java
@@ -20,6 +20,8 @@
 import com.facebook.presto.spi.plan.AggregationNode;
 import com.facebook.presto.spi.plan.AggregationNode.Aggregation;
 import com.facebook.presto.spi.plan.Assignments;
+import com.facebook.presto.spi.plan.CteConsumerNode;
+import com.facebook.presto.spi.plan.CteProducerNode;
 import com.facebook.presto.spi.plan.DistinctLimitNode;
 import com.facebook.presto.spi.plan.ExceptNode;
 import com.facebook.presto.spi.plan.FilterNode;
@@ -30,6 +32,7 @@
 import com.facebook.presto.spi.plan.PlanNode;
 import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
 import com.facebook.presto.spi.plan.ProjectNode;
+import com.facebook.presto.spi.plan.SequenceNode;
 import com.facebook.presto.spi.plan.SetOperationNode;
 import com.facebook.presto.spi.plan.TableScanNode;
 import com.facebook.presto.spi.plan.TopNNode;
@@ -503,6 +506,34 @@ public PlanNode visitFilter(FilterNode node, 
RewriteContext<Set<VariableReferenceExpression>> context)
+        @Override
+        public PlanNode visitCteConsumer(CteConsumerNode node, RewriteContext<Set<VariableReferenceExpression>> context)
+        {
+            // Some output can be pruned but current implementation of PhysicalCteOptimizer does not allow CteConsumer pruning
+            return node;
+        }
+
+        @Override
+        public PlanNode visitCteProducer(CteProducerNode node, RewriteContext<Set<VariableReferenceExpression>> context)
+        {
+            Set<VariableReferenceExpression> expectedInputs = ImmutableSet.copyOf(node.getOutputVariables());
+            PlanNode source = context.rewrite(node.getSource(), expectedInputs);
+            return new CteProducerNode(node.getSourceLocation(), node.getId(), source, node.getCteName(), node.getRowCountVariable(), node.getOutputVariables());
+        }
+
+        @Override
+        public PlanNode visitSequence(SequenceNode node, RewriteContext<Set<VariableReferenceExpression>> context)
+        {
+            ImmutableSet.Builder<VariableReferenceExpression> cteProducersBuilder = ImmutableSet.builder();
+            node.getCteProducers().forEach(leftSource -> cteProducersBuilder.addAll(leftSource.getOutputVariables()));
+            Set<VariableReferenceExpression> leftInputs = cteProducersBuilder.build();
+            List<PlanNode> cteProducers = node.getCteProducers().stream()
+                    .map(leftSource -> context.rewrite(leftSource, leftInputs)).collect(toImmutableList());
+            Set<VariableReferenceExpression> rightInputs = ImmutableSet.copyOf(node.getPrimarySource().getOutputVariables());
+            PlanNode primarySource = context.rewrite(node.getPrimarySource(), rightInputs);
+            return new SequenceNode(node.getSourceLocation(), node.getId(), cteProducers, primarySource);
+        }
+
         @Override
         public PlanNode visitGroupId(GroupIdNode node, RewriteContext<Set<VariableReferenceExpression>> context)
         {
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java
index 1e005268bd5f..79b13554dc64 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java
@@ -34,6 +34,7 @@
 import com.facebook.presto.spi.function.LambdaDescriptor;
 import com.facebook.presto.spi.function.StandardFunctionResolution;
 import 
com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.CteProducerNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.MarkDistinctNode; @@ -284,6 +285,13 @@ public PlanNode visitOutput(OutputNode node, RewriteContext context) return context.defaultRewrite(node, context.get()); } + @Override + public PlanNode visitCteProducer(CteProducerNode node, RewriteContext context) + { + context.get().variables.addAll(node.getSource().getOutputVariables()); + return context.defaultRewrite(node, context.get()); + } + @Override public PlanNode visitProject(ProjectNode node, RewriteContext context) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java index 83baddf6f014..b37adf5c99be 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java @@ -26,6 +26,7 @@ import com.facebook.presto.spi.plan.OutputNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.SequenceNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.UnionNode; @@ -269,6 +270,11 @@ public StreamProperties visitMergeJoin(MergeJoinNode node, List inputProperties) + { + return new StreamProperties(MULTIPLE, Optional.empty(), false); + } @Override public StreamProperties visitValues(ValuesNode node, List context) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java 
b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java index 8881e0240d25..d48839efb52f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -20,6 +20,9 @@ import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.CteConsumerNode; +import com.facebook.presto.spi.plan.CteProducerNode; +import com.facebook.presto.spi.plan.CteReferenceNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.ExceptNode; import com.facebook.presto.spi.plan.FilterNode; @@ -32,6 +35,7 @@ import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.SequenceNode; import com.facebook.presto.spi.plan.SetOperationNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.TopNNode; @@ -85,6 +89,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.getNodeLocation; import static com.facebook.presto.sql.planner.optimizations.ApplyNodeUtil.verifySubquerySupported; @@ -125,7 +130,6 @@ public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider requireNonNull(types, "types is null"); requireNonNull(variableAllocator, "variableAllocator is null"); requireNonNull(idAllocator, "idAllocator is null"); - PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(new Rewriter(types, functionAndTypeManager, warningCollector), plan); return PlanOptimizerResult.optimizerResult(rewrittenPlan, !rewrittenPlan.equals(plan)); } @@ 
-136,11 +140,14 @@ private static class Rewriter
         private final Map mapping = new HashMap<>();
         private final TypeProvider types;
         private final RowExpressionDeterminismEvaluator determinismEvaluator;
+
+        private final FunctionAndTypeManager functionAndTypeManager;
         private final WarningCollector warningCollector;
 
         private Rewriter(TypeProvider types, FunctionAndTypeManager functionAndTypeManager, WarningCollector warningCollector)
         {
             this.types = types;
+            this.functionAndTypeManager = functionAndTypeManager;
             this.determinismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager);
             this.warningCollector = warningCollector;
         }
@@ -154,6 +161,39 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext<Void> cont
             return mapper.map(node, source);
         }
 
+        @Override
+        public PlanNode visitCteReference(CteReferenceNode node, RewriteContext<Void> context)
+        {
+            PlanNode source = context.rewrite(node.getSource());
+            return new CteReferenceNode(node.getSourceLocation(), node.getId(), source, node.getCteName());
+        }
+
+        @Override
+        public PlanNode visitCteProducer(CteProducerNode node, RewriteContext<Void> context)
+        {
+            PlanNode source = context.rewrite(node.getSource());
+            List<VariableReferenceExpression> canonical = Lists.transform(node.getOutputVariables(), this::canonicalize);
+            return new CteProducerNode(node.getSourceLocation(), node.getId(), source, node.getCteName(), node.getRowCountVariable(), canonical);
+        }
+
+        @Override
+        public PlanNode visitCteConsumer(CteConsumerNode node, RewriteContext<Void> context)
+        {
+            // No rewrite on source by cte consumer
+            return node;
+        }
+
+        @Override
+        public PlanNode visitSequence(SequenceNode node, RewriteContext<Void> context)
+        {
+            // Each producer is rewritten by a fresh Rewriter, so its alias mapping starts out empty
+            List<PlanNode> cteProducers = node.getCteProducers().stream().map(c ->
+                            SimplePlanRewriter.rewriteWith(new Rewriter(types, functionAndTypeManager, warningCollector), c))
+                    .collect(Collectors.toList());
+            PlanNode primarySource = context.rewrite(node.getPrimarySource());
+            return new SequenceNode(node.getSourceLocation(), node.getId(), cteProducers, primarySource);
+        }
+
@Override public PlanNode visitGroupId(GroupIdNode node, RewriteContext context) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java index 0a7f9e3b0584..aca617f3fa1a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.plan.PlanVisitor; +import com.facebook.presto.spi.plan.SequenceNode; import com.facebook.presto.sql.planner.CanonicalJoinNode; import com.facebook.presto.sql.planner.CanonicalTableScanNode; import com.facebook.presto.sql.planner.StatsEquivalentPlanNodeWithLimit; @@ -181,4 +182,8 @@ public R visitStatsEquivalentPlanNodeWithLimit(StatsEquivalentPlanNodeWithLimit { return visitPlan(node, context); } + public R visitSequence(SequenceNode node, C context) + { + return visitPlan(node, context); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java index 456c1b98a805..44fcf13ae3ab 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java @@ -37,6 +37,9 @@ import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.CteConsumerNode; +import com.facebook.presto.spi.plan.CteProducerNode; +import com.facebook.presto.spi.plan.CteReferenceNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.ExceptNode; import 
com.facebook.presto.spi.plan.FilterNode;
@@ -48,6 +51,7 @@
 import com.facebook.presto.spi.plan.PlanNode;
 import com.facebook.presto.spi.plan.PlanNodeId;
 import com.facebook.presto.spi.plan.ProjectNode;
+import com.facebook.presto.spi.plan.SequenceNode;
 import com.facebook.presto.spi.plan.TableScanNode;
 import com.facebook.presto.spi.plan.TopNNode;
 import com.facebook.presto.spi.plan.UnionNode;
@@ -111,6 +115,7 @@
 import io.airlift.units.Duration;
 
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.Comparator;
 import java.util.LinkedList;
 import java.util.List;
@@ -829,6 +834,41 @@ public Void visitTableScan(TableScanNode node, Void context)
             return null;
         }
 
+        @Override
+        public Void visitSequence(SequenceNode node, Void context)
+        {
+            NodeRepresentation nodeOutput;
+            nodeOutput = addNode(node, "Sequence");
+            nodeOutput.appendDetails(getCteExecutionOrder(node));
+
+            return processChildren(node, context);
+        }
+
+        @Override
+        public Void visitCteConsumer(CteConsumerNode node, Void context)
+        {
+            NodeRepresentation nodeOutput;
+            nodeOutput = addNode(node, "CteConsumer");
+            nodeOutput.appendDetailsLine("CTE_NAME: %s", node.getCteName());
+            return processChildren(node, context);
+        }
+
+        @Override
+        public Void visitCteProducer(CteProducerNode node, Void context)
+        {
+            NodeRepresentation nodeOutput;
+            nodeOutput = addNode(node, "CteProducer");
+            nodeOutput.appendDetailsLine("CTE_NAME: %s", node.getCteName());
+            return processChildren(node, context);
+        }
+
+        @Override
+        public Void visitCteReference(CteReferenceNode node, Void context)
+        {
+            addNode(node, "CteReference");
+            return processChildren(node, context);
+        }
+
         @Override
         public Void visitValues(ValuesNode node, Void context)
         {
@@ -1354,6 +1394,23 @@ public NodeRepresentation addNode(PlanNode rootNode, String name, String identif
         }
     }
 
+    public static String getCteExecutionOrder(SequenceNode node)
+    {
+        // Collect into an explicit ArrayList: Collectors.toList() does not guarantee a mutable list, and the list is reversed in place below
+        List<CteProducerNode> cteProducers = node.getCteProducers().stream()
+                .filter(c -> (c instanceof CteProducerNode))
+                .map(CteProducerNode.class::cast)
+                .collect(Collectors.toCollection(ArrayList::new));
+        if (cteProducers.isEmpty()) {
+            return "";
+        }
+        Collections.reverse(cteProducers);
+        return format("executionOrder = %s",
+                cteProducers.stream()
+                        .map(CteProducerNode::getCteName)
+                        .collect(Collectors.joining(" -> ", "{", "}")));
+    }
+
     public static String getDynamicFilterAssignments(AbstractJoinNode node)
     {
         if (node.getDynamicFilters().isEmpty()) {
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java
index f3064046a07c..c636076a72a4 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java
@@ -18,6 +18,9 @@
 import com.facebook.presto.spi.WarningCollector;
 import com.facebook.presto.spi.plan.AggregationNode;
 import com.facebook.presto.spi.plan.AggregationNode.Aggregation;
+import com.facebook.presto.spi.plan.CteConsumerNode;
+import com.facebook.presto.spi.plan.CteProducerNode;
+import com.facebook.presto.spi.plan.CteReferenceNode;
 import com.facebook.presto.spi.plan.DistinctLimitNode;
 import com.facebook.presto.spi.plan.ExceptNode;
 import com.facebook.presto.spi.plan.FilterNode;
@@ -27,6 +30,7 @@
 import com.facebook.presto.spi.plan.OutputNode;
 import com.facebook.presto.spi.plan.PlanNode;
 import com.facebook.presto.spi.plan.ProjectNode;
+import com.facebook.presto.spi.plan.SequenceNode;
 import com.facebook.presto.spi.plan.SetOperationNode;
 import com.facebook.presto.spi.plan.TableScanNode;
 import com.facebook.presto.spi.plan.TopNNode;
@@ -523,6 +527,39 @@ public Void visitTableScan(TableScanNode node, Set<VariableReferenceExpression>
             return null;
         }
 
+        @Override
+        public Void visitCteReference(CteReferenceNode node, Set<VariableReferenceExpression> boundVariables)
+        {
+            node.getSource().accept(this, boundVariables);
+            return null;
+        }
+
+        @Override
+        public Void visitCteProducer(CteProducerNode node, Set<VariableReferenceExpression> boundVariables)
+        {
+            PlanNode source = node.getSource();
+            source.accept(this, boundVariables);
+            checkDependencies(source.getOutputVariables(), node.getOutputVariables(),
+                    "Invalid node. Output column dependencies (%s) not in source plan output (%s)",
+                    node.getOutputVariables(), source.getOutputVariables());
+
+            return null;
+        }
+
+        @Override
+        public Void visitCteConsumer(CteConsumerNode node, Set<VariableReferenceExpression> boundVariables)
+        {
+            //We don't have to do a check here as CteConsumerNode has no dependencies.
+            return null;
+        }
+
+        @Override
+        public Void visitSequence(SequenceNode node, Set<VariableReferenceExpression> boundVariables)
+        {
+            node.getSources().forEach(plan -> plan.accept(this, boundVariables));
+            return null;
+        }
+
         @Override
         public Void visitValues(ValuesNode node, Set<VariableReferenceExpression> boundVariables)
         {
diff --git a/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java b/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java
index e9df17dea691..0cd6677dd86b 100644
--- a/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java
+++ b/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java
@@ -28,6 +28,7 @@
 import com.facebook.presto.spi.plan.OutputNode;
 import com.facebook.presto.spi.plan.PlanNode;
 import com.facebook.presto.spi.plan.ProjectNode;
+import com.facebook.presto.spi.plan.SequenceNode;
 import com.facebook.presto.spi.plan.TableScanNode;
 import com.facebook.presto.spi.plan.TopNNode;
 import com.facebook.presto.spi.plan.UnionNode;
@@ -84,6 +85,7 @@
 
 import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.createSymbolReference;
 import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION;
+import static com.facebook.presto.sql.planner.planPrinter.PlanPrinter.getCteExecutionOrder;
 import static com.facebook.presto.sql.planner.planPrinter.PlanPrinter.getDynamicFilterAssignments;
 import static 
com.facebook.presto.sql.planner.planPrinter.TextRenderer.formatAsLong; import static com.facebook.presto.sql.planner.planPrinter.TextRenderer.formatDouble; @@ -111,6 +113,8 @@ private enum NodeType SINK, WINDOW, UNION, + + SEQUENCE, SORT, SAMPLE, MARK_DISTINCT, @@ -140,6 +144,7 @@ private enum NodeType .put(NodeType.SINK, "indianred1") .put(NodeType.WINDOW, "darkolivegreen4") .put(NodeType.UNION, "turquoise4") + .put(NodeType.SEQUENCE, "turquoise4") .put(NodeType.MARK_DISTINCT, "violet") .put(NodeType.TABLE_WRITER, "cyan") .put(NodeType.TABLE_WRITER_MERGE, "cyan4") @@ -194,6 +199,7 @@ public static String printDistributed(SubPlan plan, FunctionAndTypeManager funct return output.toString(); } + public static String printDistributedFromFragments(List allFragments, FunctionAndTypeManager functionAndTypeManager, Session session) { PlanNodeIdGenerator idGenerator = new PlanNodeIdGenerator(); @@ -271,6 +277,18 @@ public Void visitPlan(PlanNode node, Void context) throw new UnsupportedOperationException(format("Node %s does not have a Graphviz visitor", node.getClass().getName())); } + @Override + public Void visitSequence(SequenceNode node, Void context) + { + String expression = getCteExecutionOrder(node); + printNode(node, "Sequence", expression, NODE_COLORS.get(NodeType.SEQUENCE)); + for (PlanNode planNode : node.getSources()) { + planNode.accept(this, context); + } + + return null; + } + @Override public Void visitTableWriter(TableWriterNode node, Void context) { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index fd26cd7c1feb..52cd7e87eabf 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -19,6 +19,7 @@ import com.facebook.presto.operator.aggregation.histogram.HistogramGroupImplementation; import 
com.facebook.presto.operator.aggregation.multimapagg.MultimapAggGroupImplementation; import com.facebook.presto.sql.analyzer.FeaturesConfig.AggregationIfToFilterRewriteStrategy; +import com.facebook.presto.sql.analyzer.FeaturesConfig.CteMaterializationStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy; @@ -254,7 +255,8 @@ public void testDefaults() .setUseHBOForScaledWriters(false) .setRemoveRedundantCastToVarcharInJoin(true) .setHandleComplexEquiJoins(false) - .setSkipHashGenerationForJoinWithTableScanInput(false)); + .setSkipHashGenerationForJoinWithTableScanInput(false) + .setCteMaterializationStrategy(CteMaterializationStrategy.NONE)); } @Test @@ -455,6 +457,7 @@ public void testExplicitPropertyMappings() .put("optimizer.rewrite-constant-array-contains-to-in", "true") .put("optimizer.use-hbo-for-scaled-writers", "true") .put("optimizer.remove-redundant-cast-to-varchar-in-join", "false") + .put("cte-materialization-strategy", "ALL") .put("optimizer.handle-complex-equi-joins", "true") .put("optimizer.skip-hash-generation-for-join-with-table-scan-input", "true") .build(); @@ -656,7 +659,8 @@ public void testExplicitPropertyMappings() .setUseHBOForScaledWriters(true) .setRemoveRedundantCastToVarcharInJoin(false) .setHandleComplexEquiJoins(true) - .setSkipHashGenerationForJoinWithTableScanInput(true); + .setSkipHashGenerationForJoinWithTableScanInput(true) + .setCteMaterializationStrategy(CteMaterializationStrategy.ALL); assertFullMapping(properties, expected); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/CteConsumerMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/CteConsumerMatcher.java new file mode 100644 index 000000000000..6b14ff7f8bbc --- /dev/null +++ 
b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/CteConsumerMatcher.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.StatsProvider; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.plan.CteConsumerNode; +import com.facebook.presto.spi.plan.PlanNode; + +import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; +import static com.facebook.presto.sql.planner.assertions.MatchResult.match; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +final class CteConsumerMatcher + implements Matcher +{ + private final String expectedCteName; + + public CteConsumerMatcher(String cteName) + { + this.expectedCteName = requireNonNull(cteName, "expectedCteName is null"); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof CteConsumerNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + CteConsumerNode otherNode = (CteConsumerNode) node; 
+ if (!expectedCteName.equalsIgnoreCase(otherNode.getCteName())) { + return NO_MATCH; + } + + return match(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .omitNullValues() + .add("expectedCteName", expectedCteName) + .toString(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/CteProducerMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/CteProducerMatcher.java new file mode 100644 index 000000000000..2e386dbb9bcb --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/CteProducerMatcher.java @@ -0,0 +1,62 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.StatsProvider; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.plan.CteProducerNode; +import com.facebook.presto.spi.plan.PlanNode; + +import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; +import static com.facebook.presto.sql.planner.assertions.MatchResult.match; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +final class CteProducerMatcher + implements Matcher +{ + private final String expectedCteName; + public CteProducerMatcher(String cteName) + { + this.expectedCteName = requireNonNull(cteName, "expectedCteName is null"); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof CteProducerNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + CteProducerNode otherNode = (CteProducerNode) node; + if (!expectedCteName.equalsIgnoreCase(otherNode.getCteName())) { + return NO_MATCH; + } + return match(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .omitNullValues() + .add("expectedCteName", expectedCteName) + .toString(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java index 17d6e766c3e3..34f3d6124615 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java +++ 
b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java @@ -21,6 +21,8 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Step; +import com.facebook.presto.spi.plan.CteConsumerNode; +import com.facebook.presto.spi.plan.CteProducerNode; import com.facebook.presto.spi.plan.ExceptNode; import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.IntersectNode; @@ -29,6 +31,7 @@ import com.facebook.presto.spi.plan.OutputNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.SequenceNode; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.UnionNode; import com.facebook.presto.spi.plan.ValuesNode; @@ -439,6 +442,23 @@ public static PlanMatchPattern join( .with(joinMatcher); } + public static PlanMatchPattern cteConsumer(String cteName) + { + CteConsumerMatcher cteConsumerMatcher = new CteConsumerMatcher(cteName); + return node(CteConsumerNode.class).with(cteConsumerMatcher); + } + + public static PlanMatchPattern cteProducer(String cteName, PlanMatchPattern source) + { + CteProducerMatcher cteProducerMatcher = new CteProducerMatcher(cteName); + return node(CteProducerNode.class, source).with(cteProducerMatcher); + } + + public static PlanMatchPattern sequence(PlanMatchPattern... 
sources) + { + return node(SequenceNode.class, sources); + } + public static PlanMatchPattern spatialJoin(String expectedFilter, PlanMatchPattern left, PlanMatchPattern right) { return spatialJoin(expectedFilter, Optional.empty(), left, right); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestLogicalCteOptimizer.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestLogicalCteOptimizer.java new file mode 100644 index 000000000000..57cff35abbd9 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestLogicalCteOptimizer.java @@ -0,0 +1,211 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.Session; +import com.facebook.presto.sql.Optimizer; +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import java.util.List; + +import static com.facebook.presto.SystemSessionProperties.CTE_MATERIALIZATION_STRATEGY; +import static com.facebook.presto.sql.planner.SqlPlannerContext.NestedCteStack.delimiter; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.cteConsumer; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.cteProducer; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.sequence; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestLogicalCteOptimizer + extends BasePlanTest +{ + @Test + public void testConvertSimpleCte() + { + assertUnitPlan("WITH temp as (SELECT orderkey FROM ORDERS) " + + "SELECT * FROM temp t1 ", + anyTree( + sequence(cteProducer("temp", anyTree(tableScan("orders"))), + anyTree(cteConsumer("temp"))))); + } + + @Test + public void testSimpleRedefinedCteWithSameName() + { + assertUnitPlan("WITH temp as " + + "( with temp as (SELECT orderkey FROM ORDERS) SELECT * FROM temp) " + + "SELECT * FROM temp", + anyTree( + sequence( + cteProducer("temp", anyTree(cteConsumer("temp" + delimiter + "temp"))), + cteProducer("temp" + delimiter + "temp", anyTree(tableScan("orders"))), + anyTree(cteConsumer("temp"))))); + } + + @Test + public void testComplexRedefinedNestedCtes() + { + assertUnitPlan( + "WITH " + + "cte1 AS ( " + 
+ " SELECT orderkey, totalprice FROM ORDERS WHERE orderkey < 100 " + + "), " + + "cte2 AS ( " + + " WITH cte3 AS ( WITH cte4 AS (SELECT orderkey, totalprice FROM cte1 WHERE totalprice > 1000) SELECT * FROM cte4) " + + " SELECT cte3.orderkey FROM cte3 " + + "), " + + "cte3 AS ( " + + " SELECT * FROM customer WHERE custkey < 50 " + + ") " + + "SELECT cte3.*, cte2.orderkey FROM cte3 JOIN cte2 ON cte3.custkey = cte2.orderkey", + anyTree( + sequence( + cteProducer("cte3", anyTree(tableScan("customer"))), + cteProducer("cte2", anyTree(cteConsumer("cte2" + delimiter + "cte3"))), + cteProducer("cte2" + delimiter + "cte3", anyTree(cteConsumer("cte2" + delimiter + "cte3" + delimiter + "cte4"))), + cteProducer("cte2" + delimiter + "cte3" + delimiter + "cte4", anyTree(cteConsumer("cte1"))), + cteProducer("cte1", anyTree(tableScan("orders"))), + anyTree( + join( + anyTree(cteConsumer("cte3")), + anyTree(cteConsumer("cte2"))))))); + } + + @Test + public void testRedefinedCtesInDifferentScope() + { + assertUnitPlan("WITH cte1 AS ( WITH cte2 as (SELECT orderkey FROM ORDERS WHERE orderkey < 100)" + + "SELECT * FROM cte2), " + + " cte2 AS (SELECT * FROM customer WHERE custkey < 50) " + + "SELECT * FROM cte2 JOIN cte1 ON true", + anyTree( + sequence( + cteProducer("cte2", anyTree(tableScan("customer"))), + cteProducer("cte1", anyTree(cteConsumer("cte1" + delimiter + "cte2"))), + cteProducer("cte1" + delimiter + "cte2", anyTree(tableScan("orders"))), + anyTree(join(anyTree(cteConsumer("cte2")), anyTree(cteConsumer("cte1"))))))); + } + + @Test + public void testNestedCte() + { + assertUnitPlan("WITH temp1 as (SELECT orderkey FROM ORDERS), " + + " temp2 as (SELECT * FROM temp1) " + + "SELECT * FROM temp2", + anyTree( + sequence(cteProducer("temp2", anyTree(cteConsumer("temp1"))), + cteProducer("temp1", anyTree(tableScan("orders"))), + anyTree(cteConsumer("temp2"))))); + } + + @Test + public void testMultipleIndependentCtes() + { + assertUnitPlan("WITH temp1 as (SELECT orderkey FROM 
ORDERS), " + + " temp2 as (SELECT custkey FROM CUSTOMER) " + + "SELECT * FROM temp1, temp2", + anyTree( + sequence(cteProducer("temp1", anyTree(tableScan("orders"))), + cteProducer("temp2", anyTree(tableScan("customer"))), + anyTree(join(anyTree(cteConsumer("temp1")), anyTree(cteConsumer("temp2"))))))); + } + + @Test + public void testDependentCtes() + { + assertUnitPlan("WITH temp1 as (SELECT orderkey FROM ORDERS), " + + " temp2 as (SELECT orderkey FROM temp1) " + + "SELECT * FROM temp2 , temp1", + anyTree( + sequence(cteProducer("temp2", anyTree(cteConsumer("temp1"))), + cteProducer("temp1", anyTree(tableScan("orders"))), + anyTree(join(anyTree(cteConsumer("temp2")), anyTree(cteConsumer("temp1"))))))); + } + + @Test + public void testComplexCteWithJoins() + { + assertUnitPlan( + "WITH cte_orders AS (SELECT orderkey, custkey FROM ORDERS), " + + " cte_line_item AS (SELECT l.orderkey, l.suppkey FROM lineitem l JOIN cte_orders o ON l.orderkey = o.orderkey) " + + "SELECT li.orderkey, s.suppkey, s.name FROM cte_line_item li JOIN SUPPLIER s ON li.suppkey = s.suppkey", + anyTree( + sequence( + cteProducer("cte_line_item", + anyTree( + join( + anyTree(tableScan("lineitem")), + anyTree(cteConsumer("cte_orders"))))), + cteProducer("cte_orders", anyTree(tableScan("orders"))), + anyTree( + join( + anyTree(cteConsumer("cte_line_item")), + anyTree(tableScan("supplier"))))))); + } + + @Test + public void tesNoPersistentCteOnlyWithRowType() + { + assertUnitPlan("WITH temp AS " + + "( SELECT CAST(ROW('example_status', 100) AS ROW(status VARCHAR, amount INTEGER)) AS order_details" + + " FROM (VALUES (1))" + + ") SELECT * FROM temp", + anyTree(values("1"))); + } + + @Test + public void testSimplePersistentCteWithRowTypeAndNonRowType() + { + assertUnitPlan("WITH temp AS (" + + " SELECT * FROM (VALUES " + + " (CAST(ROW('example_status', 100) AS ROW(status VARCHAR, amount INTEGER)), 1)," + + " (CAST(ROW('another_status', 200) AS ROW(status VARCHAR, amount INTEGER)), 2)" + + " ) AS t 
(order_details, orderkey)" + + ") SELECT * FROM temp", + anyTree( + sequence( + cteProducer("temp", anyTree(values("status", "amount"))), + anyTree(cteConsumer("temp"))))); + } + + @Test + public void testNoPersistentCteWithZeroLengthVarcharType() + { + assertUnitPlan("WITH temp AS (" + + " SELECT * FROM (VALUES " + + " (CAST('' AS VARCHAR(0)), 9)" + + " ) AS t (text_column, number_column)" + + ") SELECT * FROM temp", + anyTree(values("text_column", "number_column"))); + } + + private void assertUnitPlan(String sql, PlanMatchPattern pattern) + { + List optimizers = ImmutableList.of( + new LogicalCteOptimizer(getQueryRunner().getMetadata())); + assertPlan(sql, getSession(), Optimizer.PlanStage.OPTIMIZED, pattern, optimizers); + } + + private Session getSession() + { + return Session.builder(this.getQueryRunner().getDefaultSession()) + .setSystemProperty(CTE_MATERIALIZATION_STRATEGY, "ALL") + .build(); + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/eventlistener/CTEInformation.java b/presto-spi/src/main/java/com/facebook/presto/spi/eventlistener/CTEInformation.java index 85d1cb824c32..ab635298ad4b 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/eventlistener/CTEInformation.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/eventlistener/CTEInformation.java @@ -25,16 +25,19 @@ public class CTEInformation //number of references of the CTE in the query private int numberOfReferences; private final boolean isView; + private final boolean isMaterialized; @JsonCreator public CTEInformation( @JsonProperty("cteName") String cteName, @JsonProperty("numberOfReferences") int numberOfReferences, - @JsonProperty("isView") boolean isView) + @JsonProperty("isView") boolean isView, + @JsonProperty("isMaterialized") boolean isMaterialized) { this.cteName = requireNonNull(cteName, "cteName is null"); this.numberOfReferences = numberOfReferences; this.isView = isView; + this.isMaterialized = isMaterialized; } @JsonProperty @@ -43,6 +46,12 @@ 
public String getCteName() return cteName; } + @JsonProperty + public boolean isMaterialized() + { + return isMaterialized; + } + @JsonProperty public int getNumberOfReferences() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/plan/CteConsumerNode.java b/presto-spi/src/main/java/com/facebook/presto/spi/plan/CteConsumerNode.java new file mode 100644 index 000000000000..dee7cb5db6ac --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/plan/CteConsumerNode.java @@ -0,0 +1,103 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.spi.plan; + +import com.facebook.presto.spi.SourceLocation; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import javax.annotation.concurrent.Immutable; + +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +@Immutable +public final class CteConsumerNode + extends PlanNode +{ + private final String cteName; + private final List originalOutputVariables; + + @JsonCreator + public CteConsumerNode( + Optional sourceLocation, + @JsonProperty("id") PlanNodeId id, + @JsonProperty("outputvars") List originalOutputVariables, + @JsonProperty("cteName") String cteName) + { + this(sourceLocation, id, Optional.empty(), originalOutputVariables, cteName); + } + + public CteConsumerNode( + Optional sourceLocation, + PlanNodeId id, + Optional statsEquivalentPlanNode, + List originalOutputVariables, + String cteName) + { + super(sourceLocation, id, statsEquivalentPlanNode); + this.cteName = requireNonNull(cteName, "cteName must not be null"); + this.originalOutputVariables = requireNonNull(originalOutputVariables, "originalOutputVariables must not be null"); + } + + @Override + public List getSources() + { + // CteConsumer should be the leaf node + return Collections.emptyList(); + } + + @Override + public List getOutputVariables() + { + return originalOutputVariables; + } + + @Override + public PlanNode replaceChildren(List newChildren) + { + // this function expects a new instance + checkArgument(newChildren.size() == 0, "expected newChildren to contain 0 node"); + return new CteConsumerNode(getSourceLocation(), getId(), getStatsEquivalentPlanNode(), originalOutputVariables, cteName); + } + + @Override + public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) + { + return new 
CteConsumerNode(getSourceLocation(), getId(), statsEquivalentPlanNode, originalOutputVariables, cteName); + } + + @Override + public R accept(PlanVisitor visitor, C context) + { + return visitor.visitCteConsumer(this, context); + } + + @JsonProperty + public String getCteName() + { + return cteName; + } + + private static void checkArgument(boolean condition, String message) + { + if (!condition) { + throw new IllegalArgumentException(message); + } + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/plan/CteProducerNode.java b/presto-spi/src/main/java/com/facebook/presto/spi/plan/CteProducerNode.java new file mode 100644 index 000000000000..e01675e8840b --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/plan/CteProducerNode.java @@ -0,0 +1,122 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.spi.plan; + +import com.facebook.presto.spi.SourceLocation; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import javax.annotation.concurrent.Immutable; + +import java.util.List; +import java.util.Optional; + +import static java.util.Collections.singletonList; +import static java.util.Objects.requireNonNull; + +@Immutable +public final class CteProducerNode + extends PlanNode +{ + private final PlanNode source; + private final String cteName; + private final VariableReferenceExpression rowCountVariable; + private final List originalOutputVariables; + + @JsonCreator + public CteProducerNode( + Optional sourceLocation, + @JsonProperty("id") PlanNodeId id, + @JsonProperty("source") PlanNode source, + @JsonProperty("cteName") String cteName, + @JsonProperty("rowCountVariable") VariableReferenceExpression rowCountVariable, + @JsonProperty("originalOutputVariables") List originalOutputVariables) + { + this(sourceLocation, id, Optional.empty(), source, cteName, rowCountVariable, originalOutputVariables); + } + + public CteProducerNode( + Optional sourceLocation, + PlanNodeId id, + Optional statsEquivalentPlanNode, + PlanNode source, + String cteName, + VariableReferenceExpression rowCountVariable, + List originalOutputVariables) + { + super(sourceLocation, id, statsEquivalentPlanNode); + // Inside your method or constructor + this.cteName = requireNonNull(cteName, "cteName must not be null"); + this.source = requireNonNull(source, "source must not be null"); + this.rowCountVariable = requireNonNull(rowCountVariable, "rowCountVariable must not be null"); + this.originalOutputVariables = requireNonNull(originalOutputVariables, "originalOutputVariables must not be null"); + } + + @Override + public List getSources() + { + return singletonList(source); + } + + @JsonProperty + public PlanNode getSource() + { + 
return source; + } + + @Override + public List getOutputVariables() + { + return originalOutputVariables; + } + + @Override + public PlanNode replaceChildren(List newChildren) + { + checkArgument(newChildren.size() == 1, "expected newChildren to contain 1 node"); + return new CteProducerNode(newChildren.get(0).getSourceLocation(), getId(), getStatsEquivalentPlanNode(), newChildren.get(0), + cteName, rowCountVariable, originalOutputVariables); + } + + @Override + public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) + { + return new CteProducerNode(getSourceLocation(), getId(), statsEquivalentPlanNode, source, cteName, rowCountVariable, originalOutputVariables); + } + + @Override + public R accept(PlanVisitor visitor, C context) + { + return visitor.visitCteProducer(this, context); + } + + @JsonProperty + public String getCteName() + { + return cteName; + } + + public VariableReferenceExpression getRowCountVariable() + { + return rowCountVariable; + } + + private static void checkArgument(boolean condition, String message) + { + if (!condition) { + throw new IllegalArgumentException(message); + } + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/plan/CteReferenceNode.java b/presto-spi/src/main/java/com/facebook/presto/spi/plan/CteReferenceNode.java new file mode 100644 index 000000000000..3df6e9fefa37 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/plan/CteReferenceNode.java @@ -0,0 +1,107 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.plan; + +import com.facebook.presto.spi.SourceLocation; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import javax.annotation.concurrent.Immutable; + +import java.util.List; +import java.util.Optional; + +import static java.util.Collections.singletonList; +import static java.util.Objects.requireNonNull; + +@Immutable +public final class CteReferenceNode + extends PlanNode +{ + private final PlanNode source; + private final String cteName; + + @JsonCreator + public CteReferenceNode( + Optional sourceLocation, + @JsonProperty("id") PlanNodeId id, + @JsonProperty("source") PlanNode source, + @JsonProperty("cteName") String cteName) + { + this(sourceLocation, id, Optional.empty(), source, cteName); + } + + public CteReferenceNode( + Optional sourceLocation, + PlanNodeId id, + Optional statsEquivalentPlanNode, + PlanNode source, + String cteName) + { + super(sourceLocation, id, statsEquivalentPlanNode); + this.cteName = requireNonNull(cteName, "cteName must not be null"); + this.source = requireNonNull(source, "source must not be null"); + } + + @Override + public List getSources() + { + return singletonList(source); + } + + @JsonProperty + public PlanNode getSource() + { + return source; + } + + @Override + public List getOutputVariables() + { + return source.getOutputVariables(); + } + + @Override + public PlanNode replaceChildren(List newChildren) + { + requireNonNull(newChildren, "newChildren is null"); + checkArgument(newChildren.size() == 1, "expected newChildren to contain 1 node"); + return new CteReferenceNode(newChildren.get(0).getSourceLocation(), getId(), getStatsEquivalentPlanNode(), newChildren.get(0), cteName); + } + + @Override + public PlanNode 
assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) + { + return new CteReferenceNode(getSourceLocation(), getId(), statsEquivalentPlanNode, source, cteName); + } + + @Override + public R accept(PlanVisitor visitor, C context) + { + return visitor.visitCteReference(this, context); + } + + public String getCteName() + { + return cteName; + } + + private static void checkArgument(boolean condition, String message) + { + if (!condition) { + throw new IllegalArgumentException(message); + } + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/plan/PlanVisitor.java b/presto-spi/src/main/java/com/facebook/presto/spi/plan/PlanVisitor.java index 600a1129b81e..1ee77f1e988b 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/plan/PlanVisitor.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/plan/PlanVisitor.java @@ -84,4 +84,24 @@ public R visitDistinctLimit(DistinctLimitNode node, C context) { return visitPlan(node, context); } + + public R visitCteReference(CteReferenceNode node, C context) + { + return visitPlan(node, context); + } + + public R visitCteProducer(CteProducerNode node, C context) + { + return visitPlan(node, context); + } + + public R visitCteConsumer(CteConsumerNode node, C context) + { + return visitPlan(node, context); + } + + public R visitSequence(SequenceNode node, C context) + { + return visitPlan(node, context); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/plan/SequenceNode.java b/presto-spi/src/main/java/com/facebook/presto/spi/plan/SequenceNode.java new file mode 100644 index 000000000000..09be90025b07 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/plan/SequenceNode.java @@ -0,0 +1,96 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.plan; + +import com.facebook.presto.spi.SourceLocation; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +public class SequenceNode + extends PlanNode +{ + // cteProducers {l1,l2,l3} will be in {l3, l2,l1} order + private final List cteProducers; + private final PlanNode primarySource; + + @JsonCreator + public SequenceNode(Optional sourceLocation, + @JsonProperty("id") PlanNodeId planNodeId, + @JsonProperty("cteProducers") List left, + @JsonProperty("primarySource") PlanNode primarySource) + { + this(sourceLocation, planNodeId, Optional.empty(), left, primarySource); + } + + public SequenceNode(Optional sourceLocation, + PlanNodeId planNodeId, + Optional statsEquivalentPlanNode, + List leftList, + PlanNode primarySource) + { + super(sourceLocation, planNodeId, statsEquivalentPlanNode); + this.cteProducers = leftList; + this.primarySource = primarySource; + } + + @JsonProperty + public List getCteProducers() + { + return this.cteProducers; + } + + @JsonProperty + public PlanNode getPrimarySource() + { + return this.primarySource; + } + + @Override + public List getSources() + { + List children = new ArrayList<>(cteProducers); + children.add(primarySource); + return children; + } + + @Override + public List getOutputVariables() + { + return primarySource.getOutputVariables(); + } + + @Override + 
public PlanNode replaceChildren(List newChildren) + { + return new SequenceNode(newChildren.get(0).getSourceLocation(), getId(), getStatsEquivalentPlanNode(), + newChildren.subList(0, newChildren.size() - 1), newChildren.get(newChildren.size() - 1)); + } + + @Override + public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) + { + return new SequenceNode(getSourceLocation(), getId(), statsEquivalentPlanNode, cteProducers, this.getPrimarySource()); + } + + @Override + public R accept(PlanVisitor visitor, C context) + { + return visitor.visitSequence(this, context); + } +}