Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
278 changes: 277 additions & 1 deletion pgdog/src/frontend/router/parser/aggregate.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use pg_query::protobuf::Integer;
use pg_query::protobuf::{a_const::Val, SelectStmt};
use pg_query::protobuf::{a_const::Val, Node, SelectStmt, String as PgQueryString};
use pg_query::NodeEnum;

use crate::frontend::router::parser::{ExpressionRegistry, Function};
Expand Down Expand Up @@ -67,6 +67,60 @@ pub struct Aggregate {
group_by: Vec<usize>,
}

/// Find the position in the `SELECT` target list of the column named by a
/// `GROUP BY` column reference.
///
/// `column_names` holds the dotted parts of the `GROUP BY` reference in the
/// order they were written (for example `["example", "user_id"]` for
/// `example.user_id`).
///
/// Only targets whose value is a plain column reference are considered;
/// expressions, function calls and the like are skipped. Names are compared
/// with `columns_match`.
///
/// Returns the zero-based index of the first matching target, or `None` when
/// no target matches.
fn target_list_to_index(stmt: &SelectStmt, column_names: Vec<&String>) -> Option<usize> {
    stmt.target_list.iter().position(|target| {
        // Drill down `ResTarget -> val -> ColumnRef`; anything else cannot be
        // matched by name.
        let column_ref = match target.node.as_ref() {
            Some(NodeEnum::ResTarget(res_target)) => {
                match res_target.val.as_ref().and_then(|val| val.node.as_ref()) {
                    Some(NodeEnum::ColumnRef(column_ref)) => column_ref,
                    _ => return false,
                }
            }
            _ => return false,
        };

        // Every part of the reference must be a plain name. A reference such
        // as `example.*` carries a non-string part; keeping only its string
        // parts would turn it into `example` and allow a false match, so the
        // whole target is skipped instead.
        let select_names: Option<Vec<&String>> = column_ref
            .fields
            .iter()
            .map(|field| match field.node.as_ref() {
                Some(NodeEnum::String(PgQueryString { sval, .. })) => Some(sval),
                _ => None,
            })
            .collect();

        match select_names {
            Some(names) if !names.is_empty() => columns_match(&column_names, &names),
            _ => false,
        }
    })
}

/// Decide whether a `GROUP BY` column reference and a `SELECT` column
/// reference name the same column.
///
/// Each argument is the list of dotted parts of one reference, outermost
/// qualifier first (for example `["example", "user_id"]` for
/// `example.user_id`).
///
/// Two references match when they are identical, or when the less qualified
/// one equals the trailing parts of the more qualified one. This covers
/// `user_id` vs `example.user_id` as well as deeper qualification such as
/// `example.user_id` vs `public.example.user_id`. The qualifiers that only
/// one side spells out are assumed to be compatible.
///
/// An empty reference matches only another empty reference.
fn columns_match(group_by_names: &[&String], select_names: &[&String]) -> bool {
    // `ends_with` treats an empty needle as a match for anything, so the
    // empty case has to be decided before the suffix comparison.
    if group_by_names.is_empty() || select_names.is_empty() {
        return group_by_names.is_empty() && select_names.is_empty();
    }

    // Whichever side is shorter must be a suffix of the other. Equal-length
    // references reduce to plain equality, and `ends_with` is `false` when
    // the needle is longer than the slice, so trying both directions is safe.
    select_names.ends_with(group_by_names) || group_by_names.ends_with(select_names)
}

impl Aggregate {
/// Figure out what aggregates are present and which ones PgDog supports.
pub fn parse(stmt: &SelectStmt) -> Result<Self, Error> {
Expand All @@ -81,6 +135,20 @@ impl Aggregate {
Val::Ival(Integer { ival }) => Some(*ival as usize - 1), // We use 0-indexed arrays, Postgres uses 1-indexed.
_ => None,
}),
NodeEnum::ColumnRef(column_ref) => {
let column_names: Vec<&String> = column_ref
.fields
.iter()
.filter_map(|node| match node {
Node {
node:
Some(NodeEnum::String(PgQueryString { sval: column_name })),
} => Some(column_name),
_ => None,
})
.collect();
Some(target_list_to_index(stmt, column_names))
}
_ => None,
})
})
Expand Down Expand Up @@ -381,4 +449,212 @@ mod test {
_ => panic!("not a select"),
}
}

/// A single bare column name in `GROUP BY` resolves to that column's
/// zero-based position in the `SELECT` list.
#[test]
fn test_parse_group_by_column_name_single() {
    let parsed =
        pg_query::parse("SELECT user_id, COUNT(1) FROM example GROUP BY user_id").unwrap();
    let raw = parsed.protobuf.stmts.first().cloned().unwrap();
    let select = match raw.stmt.unwrap().node.unwrap() {
        NodeEnum::SelectStmt(select) => select,
        _ => panic!("not a select"),
    };

    let aggregate = Aggregate::parse(&select).unwrap();
    assert_eq!(aggregate.group_by(), &[0]);
    assert_eq!(aggregate.targets().len(), 1);

    // The only aggregate is COUNT, sitting in the second SELECT column.
    let count = &aggregate.targets()[0];
    assert!(matches!(count.function(), AggregateFunction::Count));
    assert_eq!(count.column(), 1);
}

/// Several bare column names in `GROUP BY` each resolve to their own
/// position in the `SELECT` list, in `GROUP BY` order.
#[test]
fn test_parse_group_by_column_name_multiple() {
    let sql =
        "SELECT COUNT(*), user_id, category_id FROM example GROUP BY user_id, category_id";
    let parsed = pg_query::parse(sql).unwrap();
    let raw = parsed.protobuf.stmts.first().cloned().unwrap();
    let select = match raw.stmt.unwrap().node.unwrap() {
        NodeEnum::SelectStmt(select) => select,
        _ => panic!("not a select"),
    };

    let aggregate = Aggregate::parse(&select).unwrap();
    assert_eq!(aggregate.group_by(), &[1, 2]);
    assert_eq!(aggregate.targets().len(), 1);

    // COUNT(*) is the first SELECT column.
    let count = &aggregate.targets()[0];
    assert!(matches!(count.function(), AggregateFunction::Count));
    assert_eq!(count.column(), 0);
}

/// A table-qualified name in `GROUP BY` matches the identically qualified
/// column in the `SELECT` list.
#[test]
fn test_parse_group_by_qualified_column_name() {
    let sql = "SELECT COUNT(1), example.user_id FROM example GROUP BY example.user_id";
    let parsed = pg_query::parse(sql).unwrap();
    let raw = parsed.protobuf.stmts.first().cloned().unwrap();
    let select = match raw.stmt.unwrap().node.unwrap() {
        NodeEnum::SelectStmt(select) => select,
        _ => panic!("not a select"),
    };

    let aggregate = Aggregate::parse(&select).unwrap();
    assert_eq!(aggregate.group_by(), &[1]);
    assert_eq!(aggregate.targets().len(), 1);

    // COUNT(1) is the first SELECT column.
    let count = &aggregate.targets()[0];
    assert!(matches!(count.function(), AggregateFunction::Count));
    assert_eq!(count.column(), 0);
}

/// A `GROUP BY` list may mix a column name with a 1-based ordinal; both end
/// up as zero-based positions.
#[test]
fn test_parse_group_by_mixed_ordinal_and_column_name() {
    let sql = "SELECT user_id, category_id, SUM(quantity) FROM example GROUP BY user_id, 2";
    let parsed = pg_query::parse(sql).unwrap();
    let raw = parsed.protobuf.stmts.first().cloned().unwrap();
    let select = match raw.stmt.unwrap().node.unwrap() {
        NodeEnum::SelectStmt(select) => select,
        _ => panic!("not a select"),
    };

    let aggregate = Aggregate::parse(&select).unwrap();
    assert_eq!(aggregate.group_by(), &[0, 1]);
    assert_eq!(aggregate.targets().len(), 1);

    // SUM(quantity) is the third SELECT column.
    let sum = &aggregate.targets()[0];
    assert!(matches!(sum.function(), AggregateFunction::Sum));
    assert_eq!(sum.column(), 2);
}

/// A `GROUP BY` column that does not appear in the `SELECT` list yields no
/// group-by positions, while the aggregate itself is still detected.
#[test]
fn test_parse_group_by_column_not_in_select() {
    let parsed = pg_query::parse("SELECT COUNT(*) FROM example GROUP BY user_id").unwrap();
    let raw = parsed.protobuf.stmts.first().cloned().unwrap();
    let select = match raw.stmt.unwrap().node.unwrap() {
        NodeEnum::SelectStmt(select) => select,
        _ => panic!("not a select"),
    };

    let aggregate = Aggregate::parse(&select).unwrap();
    let no_groups: &[usize] = &[];
    assert_eq!(aggregate.group_by(), no_groups);
    assert_eq!(aggregate.targets().len(), 1);
}

/// With several aggregates around the grouped column, the group-by position
/// is that column's index and the aggregates are reported in `SELECT` order.
#[test]
fn test_parse_group_by_with_multiple_aggregates() {
    let sql = "SELECT COUNT(*), SUM(price), user_id, AVG(price) FROM example GROUP BY user_id";
    let parsed = pg_query::parse(sql).unwrap();
    let raw = parsed.protobuf.stmts.first().cloned().unwrap();
    let select = match raw.stmt.unwrap().node.unwrap() {
        NodeEnum::SelectStmt(select) => select,
        _ => panic!("not a select"),
    };

    let aggregate = Aggregate::parse(&select).unwrap();
    assert_eq!(aggregate.group_by(), &[2]);
    assert_eq!(aggregate.targets().len(), 3);

    // Aggregates appear in the order they were written: COUNT, SUM, AVG.
    let first = aggregate.targets()[0].function();
    assert!(matches!(first, AggregateFunction::Count));
    let second = aggregate.targets()[1].function();
    assert!(matches!(second, AggregateFunction::Sum));
    let third = aggregate.targets()[2].function();
    assert!(matches!(third, AggregateFunction::Avg));
}

/// A table-qualified `GROUP BY` name matches the same column written
/// without a qualifier in the `SELECT` list.
#[test]
fn test_parse_group_by_qualified_matches_select_unqualified() {
    let sql = "SELECT user_id, COUNT(1) FROM example GROUP BY example.user_id";
    let parsed = pg_query::parse(sql).unwrap();
    let raw = parsed.protobuf.stmts.first().cloned().unwrap();
    let select = match raw.stmt.unwrap().node.unwrap() {
        NodeEnum::SelectStmt(select) => select,
        _ => panic!("not a select"),
    };

    let aggregate = Aggregate::parse(&select).unwrap();
    assert_eq!(aggregate.group_by(), &[0]);
    assert_eq!(aggregate.targets().len(), 1);
}

/// A bare `GROUP BY` name matches the same column written with a table
/// qualifier in the `SELECT` list.
#[test]
fn test_parse_group_by_unqualified_matches_select_qualified() {
    let sql = "SELECT example.user_id, COUNT(1) FROM example GROUP BY user_id";
    let parsed = pg_query::parse(sql).unwrap();
    let raw = parsed.protobuf.stmts.first().cloned().unwrap();
    let select = match raw.stmt.unwrap().node.unwrap() {
        NodeEnum::SelectStmt(select) => select,
        _ => panic!("not a select"),
    };

    let aggregate = Aggregate::parse(&select).unwrap();
    assert_eq!(aggregate.group_by(), &[0]);
    assert_eq!(aggregate.targets().len(), 1);
}

/// When both the `SELECT` column and the `GROUP BY` name carry the same
/// table qualifier, the group-by position is that column's index.
#[test]
fn test_parse_group_by_both_qualified_order_matters() {
    let sql = "SELECT example.user_id, COUNT(1) FROM example GROUP BY example.user_id";
    let parsed = pg_query::parse(sql).unwrap();
    let raw = parsed.protobuf.stmts.first().cloned().unwrap();
    let select = match raw.stmt.unwrap().node.unwrap() {
        NodeEnum::SelectStmt(select) => select,
        _ => panic!("not a select"),
    };

    let aggregate = Aggregate::parse(&select).unwrap();
    assert_eq!(aggregate.group_by(), &[0]);
    assert_eq!(aggregate.targets().len(), 1);
}
}
Loading