diff --git a/pgdog/src/backend/pool/connection/aggregate.rs b/pgdog/src/backend/pool/connection/aggregate.rs index 78600fb1d..e7c5ef814 100644 --- a/pgdog/src/backend/pool/connection/aggregate.rs +++ b/pgdog/src/backend/pool/connection/aggregate.rs @@ -719,6 +719,7 @@ mod test { Decoder, }; use bytes::Bytes; + use pg_query::{protobuf::SelectStmt, NodeEnum}; use std::collections::VecDeque; #[test] @@ -787,22 +788,30 @@ mod test { } } + fn select(stmt: &str) -> SelectStmt { + let stmt = pg_query::parse(stmt) + .unwrap() + .protobuf + .stmts + .remove(0) + .stmt + .unwrap(); + match stmt.node.unwrap() { + NodeEnum::SelectStmt(stmt) => *stmt, + _ => panic!("not a select"), + } + } + + fn parse(stmt: &str) -> Aggregate { + Aggregate::parse(&select(stmt)) + } + #[test] fn aggregate_count_with_int_typecast() { // Regression test for https://github.com/pgdogdev/pgdog/issues/861 // SELECT COUNT(*)::int returns int4 from each shard; the accumulator // must merge the per-shard values and preserve the requested type. - let stmt = pg_query::parse("SELECT COUNT(*)::int FROM users") - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - let aggregate = match stmt.stmt.unwrap().node.unwrap() { - pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt), - _ => panic!("expected select stmt"), - }; + let aggregate = parse("SELECT COUNT(*)::int FROM users"); let rd = RowDescription::new(&[integer_field("count")]); let decoder = Decoder::from(&rd); @@ -830,17 +839,7 @@ mod test { #[test] fn aggregate_count_default_bigint() { // SELECT COUNT(*) (no cast) should still merge correctly and stay bigint. - let stmt = pg_query::parse("SELECT COUNT(*) FROM users") - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - let aggregate = match stmt.stmt.unwrap().node.unwrap() { - pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt), - _ => panic!("expected select stmt"), - }; + let aggregate = parse("SELECT COUNT(*) FROM users"); let rd = RowDescription::new(&[Field::bigint("count")]); let decoder = Decoder::from(&rd); @@ -866,17 +865,7 @@ mod test { #[test] fn aggregate_merges_avg_with_count() { - let stmt = pg_query::parse("SELECT COUNT(price), AVG(price) FROM menu") - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - let aggregate = match stmt.stmt.unwrap().node.unwrap() { - pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt), - _ => panic!("expected select stmt"), - }; + let aggregate = parse("SELECT COUNT(price), AVG(price) FROM menu"); let rd = RowDescription::new(&[Field::bigint("count"), Field::double("avg")]); let decoder = Decoder::from(&rd); @@ -906,17 +895,7 @@ mod test { #[test] fn aggregate_avg_without_count_passthrough() { - let stmt = pg_query::parse("SELECT AVG(price) FROM menu") - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - let aggregate = match stmt.stmt.unwrap().node.unwrap() { - pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt), - _ => panic!("expected select stmt"), - }; + let aggregate = parse("SELECT AVG(price) FROM menu"); let rd = RowDescription::new(&[Field::double("avg")]); let decoder = Decoder::from(&rd); @@ -943,17 +922,7 @@ mod test { #[test] fn aggregate_avg_with_rewrite_helper() { - let stmt = pg_query::parse("SELECT AVG(price) FROM menu") - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - let aggregate = match stmt.stmt.unwrap().node.unwrap() { - pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt), - _ => panic!("expected select stmt"), - }; + let aggregate = parse("SELECT AVG(price) FROM menu"); let rd = RowDescription::new(&[ Field::double("avg"), @@ -992,17 +961,7 @@ mod test { #[test] fn aggregate_multiple_avg_with_helpers() { - let stmt = pg_query::parse("SELECT AVG(price), AVG(discount) FROM menu") - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - let aggregate = match stmt.stmt.unwrap().node.unwrap() { - pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt), - _ => panic!("expected select stmt"), - }; + let aggregate = parse("SELECT AVG(price), AVG(discount) FROM menu"); let rd = RowDescription::new(&[ Field::double("avg_price"), @@ -1055,17 +1014,7 @@ mod test { #[test] fn aggregate_stddev_samp_with_helpers() { - let stmt = pg_query::parse("SELECT STDDEV(price) FROM menu") - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - let aggregate = match stmt.stmt.unwrap().node.unwrap() { - pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt), - _ => panic!("expected select stmt"), - }; + let aggregate = parse("SELECT STDDEV(price) FROM menu"); let rd = RowDescription::new(&[ Field::double("stddev_price"), @@ -1132,17 +1081,7 @@ mod test { #[test] fn aggregate_var_pop_with_helpers() { - let stmt = pg_query::parse("SELECT VAR_POP(price) FROM menu") - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - let aggregate = match stmt.stmt.unwrap().node.unwrap() { - pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt), - _ => panic!("expected select stmt"), - }; + let aggregate = parse("SELECT VAR_POP(price) FROM menu"); let rd = RowDescription::new(&[ Field::double("var_price"), @@ -1201,17 +1140,7 @@ mod test { #[test] fn aggregate_distinct_count_not_paired() { - let stmt = pg_query::parse("SELECT COUNT(DISTINCT price), AVG(price) FROM menu") - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - let aggregate = match stmt.stmt.unwrap().node.unwrap() { - pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt), - _ => panic!("expected select stmt"), - }; + let aggregate = parse("SELECT COUNT(DISTINCT price), AVG(price) FROM menu"); let rd = RowDescription::new(&[Field::bigint("count"), Field::double("avg")]); let decoder = Decoder::from(&rd); @@ -1239,17 +1168,7 @@ mod test { #[test] fn aggregate_errors_when_helper_alias_missing() { - let stmt = pg_query::parse("SELECT AVG(price) FROM menu") - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - let aggregate = match stmt.stmt.unwrap().node.unwrap() { - pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt), - _ => panic!("expected select stmt"), - }; + let aggregate = parse("SELECT AVG(price) FROM menu"); let rd = RowDescription::new(&[Field::double("avg")]); let decoder = Decoder::from(&rd); @@ -1282,17 +1201,7 @@ mod test { #[test] fn aggregate_group_by_merges_rows() { - let stmt = pg_query::parse("SELECT price, SUM(quantity) FROM menu GROUP BY 1") - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - let aggregate = match stmt.stmt.unwrap().node.unwrap() { - pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt), - _ => panic!("expected select stmt"), - }; + let aggregate = parse("SELECT price, SUM(quantity) FROM menu GROUP BY 1"); let rd = RowDescription::new(&[Field::double("price"), Field::bigint("sum")]); let decoder = Decoder::from(&rd); @@ -1333,17 +1242,7 @@ mod test { #[test] fn aggregate_group_by_multidimensional_arrays_uses_raw_bytes() { - let stmt = pg_query::parse("SELECT matrix, COUNT(*) FROM samples GROUP BY 1") - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - let aggregate = match stmt.stmt.unwrap().node.unwrap() { - pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt), - _ => panic!("expected select stmt"), - }; + let aggregate = parse("SELECT matrix, COUNT(*) FROM samples GROUP BY 1"); let rd = RowDescription::new(&[integer_array_field("matrix"), Field::bigint("count")]); let decoder = Decoder::from(&rd); @@ -1389,18 +1288,7 @@ mod test { #[test] fn aggregate_group_by_interval_arrays_preserves_postgres_text_output() { - let stmt = - pg_query::parse("SELECT sample_interval_array, COUNT(*) FROM samples GROUP BY 1") - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - let aggregate = match stmt.stmt.unwrap().node.unwrap() { - pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt), - _ => panic!("expected select stmt"), - }; + let aggregate = parse("SELECT sample_interval_array, COUNT(*) FROM samples GROUP BY 1"); let rd = RowDescription::new(&[ interval_array_field("sample_interval_array"), diff --git a/pgdog/src/frontend/router/parser/aggregate.rs b/pgdog/src/frontend/router/parser/aggregate.rs index 759728af6..fd88a4410 100644 --- a/pgdog/src/frontend/router/parser/aggregate.rs +++ b/pgdog/src/frontend/router/parser/aggregate.rs @@ -244,415 +244,215 @@ impl Aggregate { mod test { use super::*; - #[test] - fn test_parse_aggregate() { - let query = pg_query::parse("SELECT COUNT(*)::bigint FROM users") + fn select(stmt: &str) -> SelectStmt { + let stmt = pg_query::parse(stmt) .unwrap() .protobuf .stmts - .first() - .cloned() + .remove(0) + .stmt .unwrap(); - match query.stmt.unwrap().node.unwrap() { - NodeEnum::SelectStmt(stmt) => { - let aggr = Aggregate::parse(&stmt); - assert_eq!( - aggr.targets().first().unwrap().function, - AggregateFunction::Count - ); - } - + match stmt.node.unwrap() { + NodeEnum::SelectStmt(stmt) => *stmt, _ => panic!("not a select"), } } + fn parse(stmt: &str) -> Aggregate { + Aggregate::parse(&select(stmt)) + } + + #[test] + fn test_parse_aggregate() { + let aggr = parse("SELECT COUNT(*)::bigint FROM users"); + assert_eq!( + aggr.targets().first().unwrap().function, + AggregateFunction::Count + ); + } + #[test] fn test_parse_avg_count_expr_id_matches() { - let query = pg_query::parse("SELECT COUNT(price), AVG(price) FROM menu") - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - match query.stmt.unwrap().node.unwrap() { - NodeEnum::SelectStmt(stmt) => { - let aggr = Aggregate::parse(&stmt); - assert_eq!(aggr.targets().len(), 2); - let count = &aggr.targets()[0]; - let avg = &aggr.targets()[1]; - assert!(matches!(count.function(), AggregateFunction::Count)); - assert!(matches!(avg.function(), AggregateFunction::Avg)); - assert_eq!(count.expr_id(), avg.expr_id()); - assert!(!count.is_distinct()); - assert!(!avg.is_distinct()); - } - _ => panic!("not a select"), - } + let aggr = parse("SELECT COUNT(price), AVG(price) FROM menu"); + assert_eq!(aggr.targets().len(), 2); + let count = &aggr.targets()[0]; + let avg = &aggr.targets()[1]; + assert!(matches!(count.function(), AggregateFunction::Count)); + assert!(matches!(avg.function(), AggregateFunction::Avg)); + assert_eq!(count.expr_id(), avg.expr_id()); + assert!(!count.is_distinct()); + assert!(!avg.is_distinct()); } #[test] fn test_parse_avg_count_expr_id_differs() { - let query = pg_query::parse("SELECT COUNT(price), AVG(cost) FROM menu") - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - match query.stmt.unwrap().node.unwrap() { - NodeEnum::SelectStmt(stmt) => { - let aggr = Aggregate::parse(&stmt); - assert_eq!(aggr.targets().len(), 2); - let count = &aggr.targets()[0]; - let avg = &aggr.targets()[1]; - assert!(matches!(count.function(), AggregateFunction::Count)); - assert!(matches!(avg.function(), AggregateFunction::Avg)); - assert_ne!(count.expr_id(), avg.expr_id()); - assert!(!count.is_distinct()); - assert!(!avg.is_distinct()); - } - _ => panic!("not a select"), - } + let aggr = parse("SELECT COUNT(price), AVG(cost) FROM menu"); + assert_eq!(aggr.targets().len(), 2); + let count = &aggr.targets()[0]; + let avg = &aggr.targets()[1]; + assert!(matches!(count.function(), AggregateFunction::Count)); + assert!(matches!(avg.function(), AggregateFunction::Avg)); + assert_ne!(count.expr_id(), avg.expr_id()); + assert!(!count.is_distinct()); + assert!(!avg.is_distinct()); } #[test] fn test_parse_distinct_count_not_matching_avg() { - let query = pg_query::parse("SELECT COUNT(DISTINCT price), AVG(price) FROM menu") - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - match query.stmt.unwrap().node.unwrap() { - NodeEnum::SelectStmt(stmt) => { - let aggr = Aggregate::parse(&stmt); - assert_eq!(aggr.targets().len(), 2); - let count = &aggr.targets()[0]; - let avg = &aggr.targets()[1]; - assert!(matches!(count.function(), AggregateFunction::Count)); - assert!(matches!(avg.function(), AggregateFunction::Avg)); - assert!(count.is_distinct()); - assert!(!avg.is_distinct()); - assert_eq!(count.expr_id(), avg.expr_id()); - } - _ => panic!("not a select"), - } + let aggr = parse("SELECT COUNT(DISTINCT price), AVG(price) FROM menu"); + assert_eq!(aggr.targets().len(), 2); + let count = &aggr.targets()[0]; + let avg = &aggr.targets()[1]; + assert!(matches!(count.function(), AggregateFunction::Count)); + assert!(matches!(avg.function(), AggregateFunction::Avg)); + assert!(count.is_distinct()); + assert!(!avg.is_distinct()); + assert_eq!(count.expr_id(), avg.expr_id()); } #[test] fn test_parse_stddev_variants() { - let query = pg_query::parse( - "SELECT STDDEV(price), STDDEV_SAMP(price), STDDEV_POP(price) FROM menu", - ) - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - match query.stmt.unwrap().node.unwrap() { - NodeEnum::SelectStmt(stmt) => { - let aggr = Aggregate::parse(&stmt); - assert_eq!(aggr.targets().len(), 3); - assert!(matches!( - aggr.targets()[0].function(), - AggregateFunction::StddevSamp - )); - assert!(matches!( - aggr.targets()[1].function(), - AggregateFunction::StddevSamp - )); - assert!(matches!( - aggr.targets()[2].function(), - AggregateFunction::StddevPop - )); - } - _ => panic!("not a select"), - } + let aggr = parse("SELECT STDDEV(price), STDDEV_SAMP(price), STDDEV_POP(price) FROM menu"); + assert_eq!(aggr.targets().len(), 3); + assert!(matches!( + aggr.targets()[0].function(), + AggregateFunction::StddevSamp + )); + assert!(matches!( + aggr.targets()[1].function(), + AggregateFunction::StddevSamp + )); + assert!(matches!( + aggr.targets()[2].function(), + AggregateFunction::StddevPop + )); } #[test] fn test_parse_variance_variants() { - let query = - pg_query::parse("SELECT VARIANCE(price), VAR_SAMP(price), VAR_POP(price) FROM menu") - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - match query.stmt.unwrap().node.unwrap() { - NodeEnum::SelectStmt(stmt) => { - let aggr = Aggregate::parse(&stmt); - assert_eq!(aggr.targets().len(), 3); - assert!(matches!( - aggr.targets()[0].function(), - AggregateFunction::VarSamp - )); - assert!(matches!( - aggr.targets()[1].function(), - AggregateFunction::VarSamp - )); - assert!(matches!( - aggr.targets()[2].function(), - AggregateFunction::VarPop - )); - } - _ => panic!("not a select"), - } + let aggr = parse("SELECT VARIANCE(price), VAR_SAMP(price), VAR_POP(price) FROM menu"); + assert_eq!(aggr.targets().len(), 3); + assert!(matches!( + aggr.targets()[0].function(), + AggregateFunction::VarSamp + )); + assert!(matches!( + aggr.targets()[1].function(), + AggregateFunction::VarSamp + )); + assert!(matches!( + aggr.targets()[2].function(), + AggregateFunction::VarPop + )); } #[test] fn test_parse_group_by_ordinals() { - let query = - pg_query::parse("SELECT price, category_id, SUM(quantity) FROM menu GROUP BY 1, 2") - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - match query.stmt.unwrap().node.unwrap() { - NodeEnum::SelectStmt(stmt) => { - let aggr = Aggregate::parse(&stmt); - assert_eq!(aggr.group_by(), &[0, 1]); - assert_eq!(aggr.targets().len(), 1); - let target = &aggr.targets()[0]; - assert!(matches!(target.function(), AggregateFunction::Sum)); - assert_eq!(target.column(), 2); - } - _ => panic!("not a select"), - } + let aggr = parse("SELECT price, category_id, SUM(quantity) FROM menu GROUP BY 1, 2"); + assert_eq!(aggr.group_by(), &[0, 1]); + assert_eq!(aggr.targets().len(), 1); + let target = &aggr.targets()[0]; + assert!(matches!(target.function(), AggregateFunction::Sum)); + assert_eq!(target.column(), 2); } #[test] fn test_parse_sum_distinct_sets_flag() { - let query = pg_query::parse("SELECT SUM(DISTINCT price) FROM menu") - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - match query.stmt.unwrap().node.unwrap() { - NodeEnum::SelectStmt(stmt) => { - let aggr = Aggregate::parse(&stmt); - assert_eq!(aggr.targets().len(), 1); - let target = &aggr.targets()[0]; - assert!(matches!(target.function(), AggregateFunction::Sum)); - assert!(target.is_distinct()); - } - _ => panic!("not a select"), - } + let aggr = parse("SELECT SUM(DISTINCT price) FROM menu"); + assert_eq!(aggr.targets().len(), 1); + let target = &aggr.targets()[0]; + assert!(matches!(target.function(), AggregateFunction::Sum)); + assert!(target.is_distinct()); } #[test] fn test_parse_group_by_column_name_single() { - let query = pg_query::parse("SELECT user_id, COUNT(1) FROM example GROUP BY user_id") - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - match query.stmt.unwrap().node.unwrap() { - NodeEnum::SelectStmt(stmt) => { - let aggr = Aggregate::parse(&stmt); - assert_eq!(aggr.group_by(), &[0]); - assert_eq!(aggr.targets().len(), 1); - let target = &aggr.targets()[0]; - assert!(matches!(target.function(), AggregateFunction::Count)); - assert_eq!(target.column(), 1); - } - _ => panic!("not a select"), - } + let aggr = parse("SELECT user_id, COUNT(1) FROM example GROUP BY user_id"); + assert_eq!(aggr.group_by(), &[0]); + assert_eq!(aggr.targets().len(), 1); + let target = &aggr.targets()[0]; + assert!(matches!(target.function(), AggregateFunction::Count)); + assert_eq!(target.column(), 1); } #[test] fn test_parse_group_by_column_name_multiple() { - let query = pg_query::parse( + let aggr = parse( "SELECT COUNT(*), user_id, category_id FROM example GROUP BY user_id, category_id", - ) - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - match query.stmt.unwrap().node.unwrap() { - NodeEnum::SelectStmt(stmt) => { - let aggr = Aggregate::parse(&stmt); - assert_eq!(aggr.group_by(), &[1, 2]); - assert_eq!(aggr.targets().len(), 1); - let target = &aggr.targets()[0]; - assert!(matches!(target.function(), AggregateFunction::Count)); - assert_eq!(target.column(), 0); - } - _ => panic!("not a select"), - } + ); + assert_eq!(aggr.group_by(), &[1, 2]); + assert_eq!(aggr.targets().len(), 1); + let target = &aggr.targets()[0]; + assert!(matches!(target.function(), AggregateFunction::Count)); + assert_eq!(target.column(), 0); } #[test] fn test_parse_group_by_qualified_column_name() { - let query = pg_query::parse( - "SELECT COUNT(1), example.user_id FROM example GROUP BY example.user_id", - ) - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - match query.stmt.unwrap().node.unwrap() { - NodeEnum::SelectStmt(stmt) => { - let aggr = Aggregate::parse(&stmt); - assert_eq!(aggr.group_by(), &[1]); - assert_eq!(aggr.targets().len(), 1); - let target = &aggr.targets()[0]; - assert!(matches!(target.function(), AggregateFunction::Count)); - assert_eq!(target.column(), 0); - } - _ => panic!("not a select"), - } + let aggr = parse("SELECT COUNT(1), example.user_id FROM example GROUP BY example.user_id"); + assert_eq!(aggr.group_by(), &[1]); + assert_eq!(aggr.targets().len(), 1); + let target = &aggr.targets()[0]; + assert!(matches!(target.function(), AggregateFunction::Count)); + assert_eq!(target.column(), 0); } #[test] fn test_parse_group_by_mixed_ordinal_and_column_name() { - let query = pg_query::parse( - "SELECT user_id, category_id, SUM(quantity) FROM example GROUP BY user_id, 2", - ) - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - match query.stmt.unwrap().node.unwrap() { - NodeEnum::SelectStmt(stmt) => { - let aggr = Aggregate::parse(&stmt); - assert_eq!(aggr.group_by(), &[0, 1]); - assert_eq!(aggr.targets().len(), 1); - let target = &aggr.targets()[0]; - assert!(matches!(target.function(), AggregateFunction::Sum)); - assert_eq!(target.column(), 2); - } - _ => panic!("not a select"), - } + let aggr = + parse("SELECT user_id, category_id, SUM(quantity) FROM example GROUP BY user_id, 2"); + assert_eq!(aggr.group_by(), &[0, 1]); + assert_eq!(aggr.targets().len(), 1); + let target = &aggr.targets()[0]; + assert!(matches!(target.function(), AggregateFunction::Sum)); + assert_eq!(target.column(), 2); } #[test] fn test_parse_group_by_column_not_in_select() { - let query = pg_query::parse("SELECT COUNT(*) FROM example GROUP BY user_id") - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - match query.stmt.unwrap().node.unwrap() { - NodeEnum::SelectStmt(stmt) => { - let aggr = Aggregate::parse(&stmt); - let empty: Vec = vec![]; - assert_eq!(aggr.group_by(), empty.as_slice()); - assert_eq!(aggr.targets().len(), 1); - } - _ => panic!("not a select"), - } + let aggr = parse("SELECT COUNT(*) FROM example GROUP BY user_id"); + assert!(aggr.group_by().is_empty()); + assert_eq!(aggr.targets().len(), 1); } #[test] fn test_parse_group_by_with_multiple_aggregates() { - let query = pg_query::parse( - "SELECT COUNT(*), SUM(price), user_id, AVG(price) FROM example GROUP BY user_id", - ) - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - match query.stmt.unwrap().node.unwrap() { - NodeEnum::SelectStmt(stmt) => { - let aggr = Aggregate::parse(&stmt); - assert_eq!(aggr.group_by(), &[2]); - assert_eq!(aggr.targets().len(), 3); - assert!(matches!( - aggr.targets()[0].function(), - AggregateFunction::Count - )); - assert!(matches!( - aggr.targets()[1].function(), - AggregateFunction::Sum - )); - assert!(matches!( - aggr.targets()[2].function(), - AggregateFunction::Avg - )); - } - _ => panic!("not a select"), - } + let aggr = + parse("SELECT COUNT(*), SUM(price), user_id, AVG(price) FROM example GROUP BY user_id"); + assert_eq!(aggr.group_by(), &[2]); + assert_eq!(aggr.targets().len(), 3); + assert!(matches!( + aggr.targets()[0].function(), + AggregateFunction::Count + )); + assert!(matches!( + aggr.targets()[1].function(), + AggregateFunction::Sum + )); + assert!(matches!( + aggr.targets()[2].function(), + AggregateFunction::Avg + )); } #[test] fn test_parse_group_by_qualified_matches_select_unqualified() { - let query = - pg_query::parse("SELECT user_id, COUNT(1) FROM example GROUP BY example.user_id") - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - match query.stmt.unwrap().node.unwrap() { - NodeEnum::SelectStmt(stmt) => { - let aggr = Aggregate::parse(&stmt); - assert_eq!(aggr.group_by(), &[0]); - assert_eq!(aggr.targets().len(), 1); - } - _ => panic!("not a select"), - } + let aggr = parse("SELECT user_id, COUNT(1) FROM example GROUP BY example.user_id"); + assert_eq!(aggr.group_by(), &[0]); + assert_eq!(aggr.targets().len(), 1); } #[test] fn test_parse_group_by_unqualified_matches_select_qualified() { - let query = - pg_query::parse("SELECT example.user_id, COUNT(1) FROM example GROUP BY user_id") - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - match query.stmt.unwrap().node.unwrap() { - NodeEnum::SelectStmt(stmt) => { - let aggr = Aggregate::parse(&stmt); - assert_eq!(aggr.group_by(), &[0]); - assert_eq!(aggr.targets().len(), 1); - } - _ => panic!("not a select"), - } + let aggr = parse("SELECT example.user_id, COUNT(1) FROM example GROUP BY user_id"); + assert_eq!(aggr.group_by(), &[0]); + assert_eq!(aggr.targets().len(), 1); } #[test] fn test_parse_group_by_both_qualified_order_matters() { - let query = pg_query::parse( - "SELECT example.user_id, COUNT(1) FROM example GROUP BY example.user_id", - ) - .unwrap() - .protobuf - .stmts - .first() - .cloned() - .unwrap(); - match query.stmt.unwrap().node.unwrap() { - NodeEnum::SelectStmt(stmt) => { - let aggr = Aggregate::parse(&stmt); - assert_eq!(aggr.group_by(), &[0]); - assert_eq!(aggr.targets().len(), 1); - } - _ => panic!("not a select"), - } + let aggr = parse("SELECT example.user_id, COUNT(1) FROM example GROUP BY example.user_id"); + assert_eq!(aggr.group_by(), &[0]); + assert_eq!(aggr.targets().len(), 1); } }