From 18f928eaed574997271dc0c1f3b3227d18ee559a Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Sat, 21 Feb 2026 07:58:53 -0800 Subject: [PATCH 1/2] Handle DROP VIEW cleanup and parser span fixes --- include/yardstick_ffi.h | 1 + src/yardstick_extension.cpp | 5 + src/yardstick_parser_ffi.cpp | 459 +++++++++++++++++++++++-- yardstick-rs/src/ffi.rs | 22 +- yardstick-rs/src/parser_ffi.rs | 26 ++ yardstick-rs/src/sql/measures.rs | 573 ++++++++++++++++++++++++------- yardstick-rs/src/sql/mod.rs | 1 + 7 files changed, 941 insertions(+), 146 deletions(-) diff --git a/include/yardstick_ffi.h b/include/yardstick_ffi.h index d076b26..52f6a38 100644 --- a/include/yardstick_ffi.h +++ b/include/yardstick_ffi.h @@ -315,6 +315,7 @@ struct YardstickMeasureAggResult { bool yardstick_has_as_measure(const char* sql); bool yardstick_has_aggregate(const char* sql); +bool yardstick_drop_measure_view_from_sql(const char* sql); bool yardstick_has_curly_brace(const char* sql); bool yardstick_has_at_syntax(const char* sql); diff --git a/src/yardstick_extension.cpp b/src/yardstick_extension.cpp index 4a1eca1..b89b10d 100644 --- a/src/yardstick_extension.cpp +++ b/src/yardstick_extension.cpp @@ -42,6 +42,7 @@ extern "C" { extern "C" { bool yardstick_has_as_measure(const char *sql); bool yardstick_has_aggregate(const char *sql); + bool yardstick_drop_measure_view_from_sql(const char *sql); YardstickCreateViewResult yardstick_process_create_view(const char *sql); YardstickAggregateResult yardstick_expand_aggregate(const char *sql); void yardstick_free(char *ptr); @@ -338,6 +339,10 @@ ParserExtensionParseResult yardstick_parse(ParserExtensionInfo *, sql_to_check = semantic_stripped; } + if (yardstick_drop_measure_view_from_sql(sql_to_check.c_str())) { + return ParserExtensionParseResult(); + } + // Check for AGGREGATE() function if (yardstick_has_aggregate(sql_to_check.c_str())) { YardstickAggregateResult result = yardstick_expand_aggregate(sql_to_check.c_str()); diff --git a/src/yardstick_parser_ffi.cpp b/src/yardstick_parser_ffi.cpp index 1be1e03..1b07a24 100644 --- a/src/yardstick_parser_ffi.cpp +++ b/src/yardstick_parser_ffi.cpp @@ -33,6 +33,7 @@ #include "duckdb/parser/group_by_node.hpp" #include "duckdb/common/string_util.hpp" +#include #include #include #include @@ -52,6 +53,405 @@ static char* safe_strdup(const std::string& s) { return strdup(s.c_str()); } +static bool IsBoundaryChar(char c) { + return !std::isalnum(static_cast(c)) && c != '_'; +} + +static size_t SkipWhitespaceAndComments(const std::string& sql, size_t idx) { + while (idx < sql.size()) { + if (std::isspace(static_cast(sql[idx]))) { + idx++; + continue; + } + if (sql[idx] == '-' && idx + 1 < sql.size() && sql[idx + 1] == '-') { + idx += 2; + while (idx < sql.size() && sql[idx] != '\n' && sql[idx] != '\r') { + idx++; + } + continue; + } + if (sql[idx] == '/' && idx + 1 < sql.size() && sql[idx + 1] == '*') { + idx += 2; + while (idx + 1 < sql.size() && !(sql[idx] == '*' && sql[idx + 1] == '/')) { + idx++; + } + if (idx + 1 < sql.size()) { + idx += 2; + } else { + idx = sql.size(); + } + continue; + } + break; + } + return idx; +} + +static size_t FindMatchingParen(const std::string& sql, size_t open_pos) { + size_t depth = 0; + bool in_single = false; + bool in_double = false; + bool in_backtick = false; + bool in_bracket = false; + bool in_line_comment = false; + bool in_block_comment = false; + + for (size_t i = open_pos; i < sql.size(); i++) { + char c = sql[i]; + + if (in_line_comment) { + if (c == '\n' || c == '\r') { + in_line_comment = false; + } + continue; + } + if (in_block_comment) { + if (c == '*' && i + 1 < sql.size() && sql[i + 1] == '/') { + in_block_comment = false; + i++; + } + continue; + } + + if (in_single) { + if (c == '\'') { + if (i + 1 < sql.size() && sql[i + 1] == '\'') { + i++; + } else { + in_single = false; + } + } + continue; + } + if (in_double) { + if (c == '"') { + if (i + 1 < sql.size() && sql[i + 1] == '"') { + i++; + } else { + in_double = false; + } + } + continue; + } + if (in_backtick) { + if (c == '`') { + in_backtick = false; + } + continue; + } + if (in_bracket) { + if (c == ']') { + in_bracket = false; + } + continue; + } + + if (c == '\'') { + in_single = true; + continue; + } + if (c == '"') { + in_double = true; + continue; + } + if (c == '`') { + in_backtick = true; + continue; + } + if (c == '[') { + in_bracket = true; + continue; + } + if (c == '-' && i + 1 < sql.size() && sql[i + 1] == '-') { + in_line_comment = true; + i++; + continue; + } + if (c == '/' && i + 1 < sql.size() && sql[i + 1] == '*') { + in_block_comment = true; + i++; + continue; + } + + if (c == '(') { + depth++; + continue; + } + if (c == ')') { + if (depth == 0) { + return i; + } + depth--; + if (depth == 0) { + return i; + } + } + } + + return std::string::npos; +} + +static size_t FindTopLevelFrom(const std::string& sql) { + std::string upper = StringUtil::Upper(sql); + const std::string keyword = "FROM"; + + size_t depth = 0; + bool in_single = false; + bool in_double = false; + bool in_backtick = false; + bool in_bracket = false; + bool in_line_comment = false; + bool in_block_comment = false; + + for (size_t i = 0; i + keyword.size() <= upper.size(); i++) { + char c = sql[i]; + + if (in_line_comment) { + if (c == '\n' || c == '\r') { + in_line_comment = false; + } + continue; + } + if (in_block_comment) { + if (c == '*' && i + 1 < sql.size() && sql[i + 1] == '/') { + in_block_comment = false; + i++; + } + continue; + } + + if (in_single) { + if (c == '\'') { + if (i + 1 < sql.size() && sql[i + 1] == '\'') { + i++; + } else { + in_single = false; + } + } + continue; + } + if (in_double) { + if (c == '"') { + if (i + 1 < upper.size() && upper[i + 1] == '"') { + i++; + } else { + in_double = false; + } + } + continue; + } + if (in_backtick) { + if (c == '`') { + in_backtick = false; + } + continue; + } + if (in_bracket) { + if (c == ']') { + in_bracket = false; + } + continue; + } + + if (c == '\'') { + in_single = true; + continue; + } + if (c == '"') { + in_double = true; + continue; + } + if (c == '-' && i + 1 < sql.size() && sql[i + 1] == '-') { + in_line_comment = true; + i++; + continue; + } + if (c == '/' && i + 1 < sql.size() && sql[i + 1] == '*') { + in_block_comment = true; + i++; + continue; + } + if (c == '`') { + in_backtick = true; + continue; + } + if (c == '[') { + in_bracket = true; + continue; + } + + if (c == '(') { + depth++; + continue; + } + if (c == ')') { + if (depth > 0) { + depth--; + } + continue; + } + + if (depth == 0 && upper.compare(i, keyword.size(), keyword) == 0) { + char prev = i == 0 ? '\0' : upper[i - 1]; + char next = i + keyword.size() < upper.size() ? upper[i + keyword.size()] : '\0'; + if (IsBoundaryChar(prev) && IsBoundaryChar(next)) { + return i; + } + } + } + + return std::string::npos; +} + +static size_t FindSelectItemEnd(const std::string& sql, size_t start, size_t from_pos) { + if (start >= sql.size()) { + return start; + } + size_t limit = from_pos == std::string::npos ? sql.size() : from_pos; + size_t depth = 0; + bool in_single = false; + bool in_double = false; + bool in_backtick = false; + bool in_bracket = false; + bool in_line_comment = false; + bool in_block_comment = false; + + for (size_t i = start; i < limit; i++) { + char c = sql[i]; + + if (in_line_comment) { + if (c == '\n' || c == '\r') { + in_line_comment = false; + } + continue; + } + if (in_block_comment) { + if (c == '*' && i + 1 < limit && sql[i + 1] == '/') { + in_block_comment = false; + i++; + } + continue; + } + + if (in_single) { + if (c == '\'') { + if (i + 1 < limit && sql[i + 1] == '\'') { + i++; + } else { + in_single = false; + } + } + continue; + } + if (in_double) { + if (c == '"') { + if (i + 1 < limit && sql[i + 1] == '"') { + i++; + } else { + in_double = false; + } + } + continue; + } + if (in_backtick) { + if (c == '`') { + in_backtick = false; + } + continue; + } + if (in_bracket) { + if (c == ']') { + in_bracket = false; + } + continue; + } + + if (c == '\'') { + in_single = true; + continue; + } + if (c == '"') { + in_double = true; + continue; + } + if (c == '-' && i + 1 < limit && sql[i + 1] == '-') { + in_line_comment = true; + i++; + continue; + } + if (c == '/' && i + 1 < limit && sql[i + 1] == '*') { + in_block_comment = true; + i++; + continue; + } + if (c == '`') { + in_backtick = true; + continue; + } + if (c == '[') { + in_bracket = true; + continue; + } + + if (c == '(') { + depth++; + continue; + } + if (c == ')') { + if (depth > 0) { + depth--; + } + continue; + } + + if (depth == 0 && c == ',') { + return i; + } + } + + return limit; +} + +static size_t FindAggregateCallEnd(const std::string& sql, size_t start) { + std::string upper = StringUtil::Upper(sql); + const std::string keyword = "AGGREGATE"; + if (start >= sql.size() || upper.compare(start, keyword.size(), keyword) != 0) { + return start; + } + + size_t i = SkipWhitespaceAndComments(sql, start + keyword.size()); + if (i >= sql.size() || sql[i] != '(') { + return start; + } + + size_t close = FindMatchingParen(sql, i); + if (close == std::string::npos) { + return start; + } + size_t end = close + 1; + + while (end < sql.size()) { + size_t j = SkipWhitespaceAndComments(sql, end); + if (j + 2 > sql.size()) { + break; + } + if (upper.compare(j, 2, "AT") != 0) { + break; + } + size_t k = SkipWhitespaceAndComments(sql, j + 2); + if (k >= sql.size() || sql[k] != '(') { + break; + } + size_t at_close = FindMatchingParen(sql, k); + if (at_close == std::string::npos) { + break; + } + end = at_close + 1; + } + + return end; +} + //============================================================================= // Helper: Check if expression is a standard aggregate function //============================================================================= @@ -83,7 +483,8 @@ struct AggregateCallInfo { std::vector modifiers; }; -static void FindAggregateCalls(ParsedExpression* expr, std::vector& results); +static void FindAggregateCalls(ParsedExpression* expr, std::vector& results, + const std::string& sql); static void CollectTablesFromTableRef(TableRef* ref, std::vector& tables); static bool ExpressionContainsAggregate(ParsedExpression* expr); static bool ExpressionContainsMeasureRef(ParsedExpression* expr); @@ -93,7 +494,8 @@ static void QualifyColumnRefs(ParsedExpression* expr, const std::string& qualifi // AST Walking: Find AGGREGATE() function calls //============================================================================= -static void FindAggregateCalls(ParsedExpression* expr, std::vector& results) { +static void FindAggregateCalls(ParsedExpression* expr, std::vector& results, + const std::string& sql) { if (!expr) return; switch (expr->expression_class) { @@ -122,8 +524,8 @@ static void FindAggregateCalls(ParsedExpression* expr, std::vector(expr->ToString().length()); + size_t end_pos = FindAggregateCallEnd(sql, info.start_pos); + info.end_pos = static_cast(end_pos); // Parse AT modifiers from remaining arguments // AT syntax in DuckDB typically appears as special function arguments @@ -192,25 +594,25 @@ static void FindAggregateCalls(ParsedExpression* expr, std::vectorchildren) { - FindAggregateCalls(child.get(), results); + FindAggregateCalls(child.get(), results, sql); } if (func->filter) { - FindAggregateCalls(func->filter.get(), results); + FindAggregateCalls(func->filter.get(), results, sql); } break; } case ExpressionClass::COMPARISON: { auto* comp = static_cast(expr); - FindAggregateCalls(comp->left.get(), results); - FindAggregateCalls(comp->right.get(), results); + FindAggregateCalls(comp->left.get(), results, sql); + FindAggregateCalls(comp->right.get(), results, sql); break; } case ExpressionClass::CONJUNCTION: { auto* conj = static_cast(expr); for (auto& child : conj->children) { - FindAggregateCalls(child.get(), results); + FindAggregateCalls(child.get(), results, sql); } break; } @@ -218,7 +620,7 @@ static void FindAggregateCalls(ParsedExpression* expr, std::vector(expr); for (auto& child : op->children) { - FindAggregateCalls(child.get(), results); + FindAggregateCalls(child.get(), results, sql); } break; } @@ -226,25 +628,25 @@ static void FindAggregateCalls(ParsedExpression* expr, std::vector(expr); for (auto& check : case_expr->case_checks) { - FindAggregateCalls(check.when_expr.get(), results); - FindAggregateCalls(check.then_expr.get(), results); + FindAggregateCalls(check.when_expr.get(), results, sql); + FindAggregateCalls(check.then_expr.get(), results, sql); } if (case_expr->else_expr) { - FindAggregateCalls(case_expr->else_expr.get(), results); + FindAggregateCalls(case_expr->else_expr.get(), results, sql); } break; } case ExpressionClass::CAST: { auto* cast = static_cast(expr); - FindAggregateCalls(cast->child.get(), results); + FindAggregateCalls(cast->child.get(), results, sql); break; } case ExpressionClass::SUBQUERY: { auto* subq = static_cast(expr); if (subq->child) { - FindAggregateCalls(subq->child.get(), results); + FindAggregateCalls(subq->child.get(), results, sql); } // Note: We don't recurse into the subquery itself break; @@ -253,22 +655,22 @@ static void FindAggregateCalls(ParsedExpression* expr, std::vector(expr); for (auto& child : window->children) { - FindAggregateCalls(child.get(), results); + FindAggregateCalls(child.get(), results, sql); } for (auto& part : window->partitions) { - FindAggregateCalls(part.get(), results); + FindAggregateCalls(part.get(), results, sql); } if (window->filter_expr) { - FindAggregateCalls(window->filter_expr.get(), results); + FindAggregateCalls(window->filter_expr.get(), results, sql); } break; } case ExpressionClass::BETWEEN: { auto* between = static_cast(expr); - FindAggregateCalls(between->input.get(), results); - FindAggregateCalls(between->lower.get(), results); - FindAggregateCalls(between->upper.get(), results); + FindAggregateCalls(between->input.get(), results, sql); + FindAggregateCalls(between->lower.get(), results, sql); + FindAggregateCalls(between->upper.get(), results, sql); break; } @@ -587,22 +989,22 @@ extern "C" YardstickAggregateCallList* yardstick_find_aggregates(const char* sql // Search in SELECT list for (auto& expr : select_node->select_list) { - FindAggregateCalls(expr.get(), aggregates); + FindAggregateCalls(expr.get(), aggregates, sql); } // Search in WHERE clause if (select_node->where_clause) { - FindAggregateCalls(select_node->where_clause.get(), aggregates); + FindAggregateCalls(select_node->where_clause.get(), aggregates, sql); } // Search in HAVING clause if (select_node->having) { - FindAggregateCalls(select_node->having.get(), aggregates); + FindAggregateCalls(select_node->having.get(), aggregates, sql); } // Search in GROUP BY expressions for (auto& expr : select_node->groups.group_expressions) { - FindAggregateCalls(expr.get(), aggregates); + FindAggregateCalls(expr.get(), aggregates, sql); } } @@ -705,6 +1107,10 @@ extern "C" YardstickSelectInfo* yardstick_parse_select(const char* sql) { } auto* select_node = static_cast(select_stmt->node.get()); + size_t from_pos = FindTopLevelFrom(sql); + if (from_pos == std::string::npos) { + from_pos = sql.size(); + } // Process SELECT list std::vector items; @@ -718,7 +1124,8 @@ extern "C" YardstickSelectInfo* yardstick_parse_select(const char* sql) { } else { item.start_pos = 0; } - item.end_pos = item.start_pos + static_cast(expr->ToString().length()); + size_t end_pos = FindSelectItemEnd(sql, item.start_pos, from_pos); + item.end_pos = static_cast(end_pos); item.is_aggregate = ExpressionContainsAggregate(expr.get()); item.is_star = expr->expression_class == ExpressionClass::STAR; diff --git a/yardstick-rs/src/ffi.rs b/yardstick-rs/src/ffi.rs index 0288121..9ecaafe 100644 --- a/yardstick-rs/src/ffi.rs +++ b/yardstick-rs/src/ffi.rs @@ -11,8 +11,9 @@ use std::os::raw::c_char; use std::ptr; use crate::sql::{ - expand_aggregate_with_at, expand_curly_braces, get_measure_aggregation, has_aggregate_function, - has_as_measure, has_at_syntax, has_curly_brace_measure, process_create_view, + drop_measure_view_from_sql, expand_aggregate_with_at, expand_curly_braces, + get_measure_aggregation, has_aggregate_function, has_as_measure, has_at_syntax, + has_curly_brace_measure, process_create_view, }; /// Result from processing CREATE VIEW with AS MEASURE @@ -73,6 +74,23 @@ pub extern "C" fn yardstick_has_aggregate(sql: *const c_char) -> bool { has_aggregate_function(sql_str) } +/// Drop a measure view from the catalog if the SQL is a DROP VIEW statement +#[no_mangle] +pub extern "C" fn yardstick_drop_measure_view_from_sql(sql: *const c_char) -> bool { + if sql.is_null() { + return false; + } + + let sql_str = unsafe { + match CStr::from_ptr(sql).to_str() { + Ok(s) => s, + Err(_) => return false, + } + }; + + drop_measure_view_from_sql(sql_str) +} + /// Check if SQL contains curly brace measure syntax: `{column}` #[no_mangle] pub extern "C" fn yardstick_has_curly_brace(sql: *const c_char) -> bool { diff --git a/yardstick-rs/src/parser_ffi.rs b/yardstick-rs/src/parser_ffi.rs index fe85b73..b2b1fb8 100644 --- a/yardstick-rs/src/parser_ffi.rs +++ b/yardstick-rs/src/parser_ffi.rs @@ -971,6 +971,19 @@ mod tests { assert_eq!(calls[0].modifiers[0].modifier_type, AtType::AllGlobal); } + #[test] + #[ignore = "requires C++ library to be linked"] + fn test_find_aggregates_with_comments_in_at_chain() { + let sql = "SELECT AGGREGATE(revenue) /* keep */ AT (ALL) FROM sales"; + let calls = find_aggregates(sql).unwrap(); + assert_eq!(calls.len(), 1); + let call = &calls[0]; + assert_eq!( + &sql[call.start_pos as usize..call.end_pos as usize], + "AGGREGATE(revenue) /* keep */ AT (ALL)" + ); + } + #[test] #[ignore = "requires C++ library to be linked"] fn test_parse_select() { @@ -982,6 +995,19 @@ mod tests { assert_eq!(info.group_by_cols.len(), 1); } + #[test] + #[ignore = "requires C++ library to be linked"] + fn test_parse_select_item_positions_ignore_comment_tokens() { + let sql = "SELECT region /* , fake comma FROM fake */, AGGREGATE(revenue) FROM sales"; + let info = parse_select(sql).unwrap(); + assert_eq!(info.items.len(), 2); + let first = &info.items[0]; + assert_eq!( + sql[first.start_pos as usize..first.end_pos as usize].trim(), + "region /* , fake comma FROM fake */" + ); + } + #[test] #[ignore = "requires C++ library to be linked"] fn test_parse_expression() { diff --git a/yardstick-rs/src/sql/measures.rs b/yardstick-rs/src/sql/measures.rs index 231c084..4f30894 100644 --- a/yardstick-rs/src/sql/measures.rs +++ b/yardstick-rs/src/sql/measures.rs @@ -694,6 +694,39 @@ pub fn extract_view_name(sql: &str) -> Option { .map(|(_, name)| name.to_string()) } +/// Extract view name from DROP VIEW statement +pub fn extract_drop_view_name(sql: &str) -> Option { + let upper = sql.to_uppercase(); + let mut idx = skip_ws_and_comments(sql, 0); + if !matches_keyword_at(&upper, idx, "DROP") { + return None; + } + idx += "DROP".len(); + idx = skip_ws_and_comments(sql, idx); + + if !matches_keyword_at(&upper, idx, "VIEW") { + return None; + } + idx += "VIEW".len(); + idx = skip_ws_and_comments(sql, idx); + + if matches_keyword_at(&upper, idx, "IF") { + idx += "IF".len(); + idx = skip_ws_and_comments(sql, idx); + if !matches_keyword_at(&upper, idx, "EXISTS") { + return None; + } + idx += "EXISTS".len(); + idx = skip_ws_and_comments(sql, idx); + } + + let end = parse_qualified_name_span(sql, idx)?; + if !is_statement_tail(sql, end) { + return None; + } + extract_last_qualified_identifier(&sql[idx..end]) +} + /// Extract table name from SQL FROM clause pub fn extract_table_name_from_sql(sql: &str) -> Option { extract_table_and_alias_from_sql(sql).map(|(name, _)| name) @@ -702,91 +735,66 @@ pub fn extract_table_name_from_sql(sql: &str) -> Option { /// Extract table name and optional alias from SQL FROM clause /// Returns (table_name, Option) pub fn extract_table_and_alias_from_sql(sql: &str) -> Option<(String, Option)> { - fn skip_ws_and_comments(sql: &str, mut idx: usize) -> usize { - let bytes = sql.as_bytes(); - while idx < bytes.len() { - let c = bytes[idx] as char; - if c.is_whitespace() { - idx += 1; - continue; - } - if c == '-' && idx + 1 < bytes.len() && bytes[idx + 1] as char == '-' { - idx += 2; - while idx < bytes.len() { - let ch = bytes[idx] as char; - idx += 1; - if ch == '\n' || ch == '\r' { - break; - } - } - continue; - } - if c == '/' && idx + 1 < bytes.len() && bytes[idx + 1] as char == '*' { - idx += 2; - while idx + 1 < bytes.len() { - let ch = bytes[idx] as char; - if ch == '*' && bytes[idx + 1] as char == '/' { - idx += 2; - break; - } - idx += 1; - } - continue; - } - break; - } - idx - } - let from_pos = find_top_level_keyword(sql, "FROM", 0)?; - let mut idx = from_pos + 4; - idx = skip_ws_and_comments(sql, idx); - if idx >= sql.len() { + let upper = sql.to_uppercase(); + let i = skip_ws_and_comments(sql, from_pos + 4); + let bytes = sql.as_bytes(); + if i >= bytes.len() || bytes[i] == b'(' { return None; } - let table_start = idx; - while idx < sql.len() && is_table_ident_char(sql.as_bytes()[idx] as char) { - idx += 1; + let table_end = parse_qualified_name_span(sql, i)?; + let raw_table = sql[i..table_end].trim(); + let table_name = + extract_last_qualified_identifier(raw_table).unwrap_or_else(|| raw_table.to_string()); + + let j = skip_ws_and_comments(sql, table_end); + if j >= bytes.len() || sql[j..].starts_with(';') { + return Some((table_name, None)); } - if table_start == idx { - return None; + + if matches_keyword_at(&upper, j, "WHERE") + || matches_keyword_at(&upper, j, "GROUP") + || matches_keyword_at(&upper, j, "ORDER") + || matches_keyword_at(&upper, j, "LIMIT") + || matches_keyword_at(&upper, j, "HAVING") + || matches_keyword_at(&upper, j, "JOIN") + { + return Some((table_name, None)); } - let table = sql[table_start..idx].to_string(); - idx = skip_ws_and_comments(sql, idx); - if idx >= sql.len() || sql.as_bytes()[idx] as char == ';' { - return Some((table, None)); + if let Some((_, alias)) = parse_alias_span(sql, j) { + Some((table_name, Some(alias))) + } else { + Some((table_name, None)) } +} - let rest = &sql[idx..]; - let rest_upper = rest.to_uppercase(); - let mut rest_after_as = rest; - if rest_upper.starts_with("AS") { - let after_as = &rest[2..]; - if after_as - .chars() - .next() - .map_or(false, |ch| ch.is_whitespace()) - { - let mut as_idx = idx + 2; - as_idx = skip_ws_and_comments(sql, as_idx); - rest_after_as = &sql[as_idx..]; - } +fn insert_primary_table_alias(sql: &str, alias: &str) -> Option { + // Cases: + // - no top-level FROM: leave unchanged + // - FROM starts with subquery: leave unchanged + // - existing alias (with or without AS, quoted or unquoted): leave unchanged + // - no alias: insert the provided alias after the primary table name + let from_pos = find_top_level_keyword(sql, "FROM", 0)?; + let bytes = sql.as_bytes(); + let i = skip_ws_and_comments(sql, from_pos + 4); + if i >= bytes.len() || bytes[i] == b'(' { + return None; } - if let Ok((_, alias)) = identifier(rest_after_as.trim_start()) { - let alias_upper = alias.to_uppercase(); - if matches!( - alias_upper.as_str(), - "FROM" | "WHERE" | "GROUP" | "ORDER" | "LIMIT" | "HAVING" | "JOIN" - ) { - return Some((table, None)); - } - Some((table, Some(alias.to_string()))) - } else { - Some((table, None)) + let table_end = parse_qualified_name_span(sql, i)?; + let j = skip_ws_and_comments(sql, table_end); + if parse_alias_span(sql, j).is_some() { + return None; } + + let mut result = String::with_capacity(sql.len() + alias.len() + 1); + result.push_str(&sql[..table_end]); + result.push(' '); + result.push_str(alias); + result.push_str(&sql[table_end..]); + Some(result) } /// Extract the SELECT/WITH query from a CREATE VIEW statement @@ -827,8 +835,205 @@ fn is_boundary_char(ch: Option) -> bool { ch.map_or(true, |c| !c.is_alphanumeric() && c != '_') } -fn is_table_ident_char(ch: char) -> bool { - ch.is_alphanumeric() || ch == '_' || ch == '.' +fn is_statement_tail(sql: &str, mut idx: usize) -> bool { + let bytes = sql.as_bytes(); + idx = skip_ws_and_comments(sql, idx); + while idx < bytes.len() && bytes[idx] == b';' { + idx += 1; + idx = skip_ws_and_comments(sql, idx); + } + idx == bytes.len() +} + +fn skip_ws_and_comments(sql: &str, mut idx: usize) -> usize { + let bytes = sql.as_bytes(); + while idx < bytes.len() { + if bytes[idx].is_ascii_whitespace() { + idx += 1; + continue; + } + if bytes[idx] == b'-' && idx + 1 < bytes.len() && bytes[idx + 1] == b'-' { + idx += 2; + while idx < bytes.len() { + let ch = bytes[idx]; + idx += 1; + if ch == b'\n' || ch == b'\r' { + break; + } + } + continue; + } + if bytes[idx] == b'/' && idx + 1 < bytes.len() && bytes[idx + 1] == b'*' { + idx += 2; + while idx + 1 < bytes.len() { + if bytes[idx] == b'*' && bytes[idx + 1] == b'/' { + idx += 2; + break; + } + idx += 1; + } + if idx + 1 >= bytes.len() { + idx = bytes.len(); + } + continue; + } + break; + } + idx +} + +fn is_ident_start_byte(b: u8) -> bool { + (b as char).is_ascii_alphabetic() || b == b'_' +} + +fn is_ident_part_byte(b: u8) -> bool { + (b as char).is_ascii_alphanumeric() || b == b'_' +} + +fn strip_identifier_quotes(input: &str) -> &str { + if input.len() >= 2 { + let bytes = input.as_bytes(); + let (first, last) = (bytes[0], bytes[bytes.len() - 1]); + if (first == b'"' && last == b'"') + || (first == b'`' && last == b'`') + || (first == b'[' && last == b']') + { + return &input[1..input.len() - 1]; + } + } + input +} + +fn extract_last_qualified_identifier(input: &str) -> Option { + let bytes = input.as_bytes(); + let mut i = skip_ws_and_comments(input, 0); + let (end, _) = parse_identifier_token(input, i)?; + let mut last_start = i; + let mut last_end = end; + i = end; + + loop { + let mut j = skip_ws_and_comments(input, i); + if j >= bytes.len() || bytes[j] != b'.' { + break; + } + j += 1; + j = skip_ws_and_comments(input, j); + let (next_end, _) = parse_identifier_token(input, j)?; + last_start = j; + last_end = next_end; + i = next_end; + } + + let tail = skip_ws_and_comments(input, i); + if tail < bytes.len() { + return None; + } + + Some(strip_identifier_quotes(input[last_start..last_end].trim()).to_string()) +} + +fn parse_identifier_token(sql: &str, start: usize) -> Option<(usize, bool)> { + let bytes = sql.as_bytes(); + if start >= bytes.len() { + return None; + } + let first = bytes[start]; + if first == b'"' || first == b'`' || first == b'[' { + let end_char = match first { + b'"' => b'"', + b'`' => b'`', + _ => b']', + }; + let mut i = start + 1; + while i < bytes.len() { + if bytes[i] == end_char { + if end_char == b'"' && i + 1 < bytes.len() && bytes[i + 1] == b'"' { + i += 2; + continue; + } + return Some((i + 1, true)); + } + i += 1; + } + return None; + } + if !is_ident_start_byte(first) { + return None; + } + let mut i = start + 1; + while i < bytes.len() && is_ident_part_byte(bytes[i]) { + i += 1; + } + Some((i, false)) +} + +fn parse_qualified_name_span(sql: &str, start: usize) -> Option { + let bytes = sql.as_bytes(); + let mut i = start; + let (mut end, _) = parse_identifier_token(sql, i)?; + i = end; + + loop { + let mut j = skip_ws_and_comments(sql, i); + if j >= bytes.len() || bytes[j] != b'.' { + break; + } + j += 1; + j = skip_ws_and_comments(sql, j); + let (next_end, _) = parse_identifier_token(sql, j)?; + end = next_end; + i = end; + } + + Some(end) +} + +fn parse_alias_span(sql: &str, start: usize) -> Option<(usize, String)> { + let upper = sql.to_uppercase(); + let mut i = skip_ws_and_comments(sql, start); + let bytes = sql.as_bytes(); + if i >= bytes.len() { + return None; + } + + if matches_keyword_at(&upper, i, "AS") { + i += 2; + i = skip_ws_and_comments(sql, i); + } + + let (end, quoted) = parse_identifier_token(sql, i)?; + let raw = sql[i..end].trim().to_string(); + if quoted { + return Some((end, raw)); + } + + let alias_upper = raw.to_uppercase(); + if matches!( + alias_upper.as_str(), + "WHERE" + | "GROUP" + | "ORDER" + | "LIMIT" + | "HAVING" + | "JOIN" + | "ON" + | "USING" + | "INNER" + | "LEFT" + | "RIGHT" + | "FULL" + | "CROSS" + | "UNION" + | "EXCEPT" + | "INTERSECT" + | "WINDOW" + | "QUALIFY" + ) { + return None; + } + + Some((end, raw)) } fn find_top_level_keyword(sql: &str, keyword: &str, start: usize) -> Option { @@ -1021,42 +1226,6 @@ fn find_top_level_keyword(sql: &str, keyword: &str, start: usize) -> Option Option { - let from_pos = find_top_level_keyword(sql, "FROM", 0)?; - let bytes = sql.as_bytes(); - let mut idx = from_pos + 4; - while idx < bytes.len() && bytes[idx].is_ascii_whitespace() { - idx += 1; - } - if idx >= bytes.len() { - return None; - } - - let table_start = idx; - while idx < bytes.len() && is_table_ident_char(bytes[idx] as char) { - idx += 1; - } - if table_start == idx { - return None; - } - - let table_token = &sql[table_start..idx]; - let table_simple = table_token - .split('.') - .next_back() - .unwrap_or(table_token); - if !table_simple.eq_ignore_ascii_case(table_name) { - return None; - } - - let mut updated = String::with_capacity(sql.len() + alias.len() + 1); - updated.push_str(&sql[..idx]); - updated.push(' '); - updated.push_str(alias); - updated.push_str(&sql[idx..]); - Some(updated) -} - fn find_first_top_level_keyword(sql: &str, start: usize, keywords: &[&str]) -> Option { keywords .iter() @@ -1682,6 +1851,54 @@ fn expand_derived_measure_expr(expr: &str, measure_view: &MeasureView) -> String let mut chars = expr.chars().peekable(); while let Some(c) = chars.next() { + if c == '\'' { + result.push(c); + while let Some(next) = chars.next() { + result.push(next); + if next == '\'' { + if chars.peek() == Some(&'\'') { + result.push(chars.next().unwrap()); + } else { + break; + } + } + } + continue; + } + if c == '"' { + result.push(c); + while let Some(next) = chars.next() { + result.push(next); + if next == '"' { + if chars.peek() == Some(&'"') { + result.push(chars.next().unwrap()); + } else { + break; + } + } + } + continue; + } + if c == '`' { + result.push(c); + while let Some(next) = chars.next() { + result.push(next); + if next == '`' { + break; + } + } + continue; + } + if c == '[' { + result.push(c); + while let Some(next) = chars.next() { + result.push(next); + if next == ']' { + break; + } + } + continue; + } if c.is_alphabetic() || c == '_' { // Collect identifier let mut ident = String::from(c); @@ -4047,18 +4264,16 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { let mut result_sql = sql; // Handle alias for the primary table if needed for correlation + let mut insert_outer_alias = false; let primary_alias: Option = if needs_outer_alias { if let Some(ref pt) = from_info.primary_table { if pt.has_alias { Some(pt.effective_name.clone()) + } else if insert_primary_table_alias(&result_sql, "_outer").is_some() { + insert_outer_alias = true; + Some("_outer".to_string()) } else { - // No alias on primary table, add _outer - if let Some(updated_sql) = insert_table_alias(&result_sql, &pt.name, "_outer") { - result_sql = updated_sql; - Some("_outer".to_string()) - } else { - None - } + None } } else { None @@ -4275,6 +4490,12 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { } } + if insert_outer_alias { + if let Some(updated_sql) = insert_primary_table_alias(&result_sql, "_outer") { + result_sql = updated_sql; + } + } + // Check if there are still any remaining AGGREGATE calls (shouldn't be, but just in case) if has_aggregate_function(&result_sql) { return expand_aggregate(&result_sql); @@ -4348,6 +4569,27 @@ pub fn clear_measure_views() { views.clear(); } +pub fn drop_measure_view(view_name: &str) -> bool { + let mut views = MEASURE_VIEWS.lock().unwrap(); + let key = views + .keys() + .find(|k| k.eq_ignore_ascii_case(view_name)) + .cloned(); + if let Some(k) = key { + views.remove(&k); + return true; + } + false +} + +pub fn drop_measure_view_from_sql(sql: &str) -> bool { + let name = match extract_drop_view_name(sql) { + Some(name) => name, + None => return false, + }; + drop_measure_view(&name) +} + pub fn get_measure_aggregation(column_name: &str) -> Option<(String, String)> { let views = MEASURE_VIEWS.lock().unwrap(); @@ -4991,6 +5233,47 @@ FROM orders"#; ); } + #[test] + #[serial] + fn test_drop_measure_view_from_sql() { + clear_measure_views(); + store_measure_view( + "orders_v", + vec![ViewMeasure { + column_name: "revenue".to_string(), + expression: "SUM(amount)".to_string(), + is_decomposable: true, + }], + "SELECT year, SUM(amount) AS revenue FROM orders GROUP BY year", + Some("orders".to_string()), + ); + + assert!(get_measure_view("orders_v").is_some()); + assert!(!drop_measure_view_from_sql( + "DROP VIEW orders_v /* invalid tail */ extra" + )); + assert!(get_measure_view("orders_v").is_some()); + assert!(drop_measure_view_from_sql("DROP VIEW orders_v;")); + assert!(get_measure_view("orders_v").is_none()); + } + + #[test] + fn test_extract_drop_view_name_variants() { + assert_eq!( + extract_drop_view_name("DROP VIEW IF EXISTS analytics.orders_v"), + Some("orders_v".to_string()) + ); + assert_eq!( + extract_drop_view_name("drop view if exists \"Analytics\".\"Order.View\""), + Some("Order.View".to_string()) + ); + assert_eq!( + extract_drop_view_name("DROP VIEW /* keep */ IF EXISTS [dbo].[Orders View]"), + Some("Orders View".to_string()) + ); + assert_eq!(extract_drop_view_name("DROP VIEW orders_v extra"), None); + } + #[test] fn test_extract_base_relation_sql_with_cte() { let view_query = "WITH base AS (SELECT * FROM orders) \ @@ -5287,6 +5570,30 @@ FROM orders"#; assert_eq!(expanded3, "SUM(revenue) * 100"); } + #[test] + fn test_expand_derived_measure_expr_ignores_string_literals() { + let mv = MeasureView { + view_name: "sales_v".to_string(), + measures: vec![ViewMeasure { + column_name: "revenue".to_string(), + expression: "SUM(amount)".to_string(), + is_decomposable: true, + }], + base_query: "".to_string(), + base_table: Some("sales".to_string()), + base_relation_sql: None, + dimension_exprs: HashMap::new(), + group_by_cols: Vec::new(), + }; + + let expanded = + expand_derived_measure_expr("CASE WHEN status = 'revenue' THEN revenue ELSE 0 END", &mv); + assert_eq!( + expanded, + "CASE WHEN status = 'revenue' THEN SUM(revenue) ELSE 0 END" + ); + } + #[test] fn test_extract_table_and_alias() { // No alias @@ -5318,6 +5625,36 @@ FROM orders"#; extract_table_and_alias_from_sql("SELECT x FROM orders GROUP BY x"), Some(("orders".to_string(), None)) ); + + // Quoted qualified names with dots inside identifiers + assert_eq!( + extract_table_and_alias_from_sql( + "SELECT * FROM \"sales.schema\".\"orders.table\" o WHERE 1=1" + ), + Some(("orders.table".to_string(), Some("o".to_string()))) + ); + + // Alias after comments (no AS) + assert_eq!( + extract_table_and_alias_from_sql("SELECT * FROM orders /* source */ o"), + Some(("orders".to_string(), Some("o".to_string()))) + ); + + // Alias after comments (with AS) + assert_eq!( + extract_table_and_alias_from_sql("SELECT * FROM orders AS /* source */ o"), + Some(("orders".to_string(), Some("o".to_string()))) + ); + } + + #[test] + fn test_insert_primary_table_alias_skips_comments() { + let sql = "SELECT year, AGGREGATE(revenue) AT (SET year = year - 1) FROM sales_v /* c */ GROUP BY year"; + let updated = insert_primary_table_alias(sql, "_outer").unwrap(); + assert!(updated.contains("FROM sales_v _outer /* c */ GROUP BY")); + + let sql_with_alias = "SELECT year FROM sales_v /* c */ s GROUP BY year"; + assert!(insert_primary_table_alias(sql_with_alias, "_outer").is_none()); } #[test] diff --git a/yardstick-rs/src/sql/mod.rs b/yardstick-rs/src/sql/mod.rs index 868bcd2..acc24ad 100644 --- a/yardstick-rs/src/sql/mod.rs +++ b/yardstick-rs/src/sql/mod.rs @@ -4,6 +4,7 @@ pub mod measures; pub use measures::{ // Processing functions + drop_measure_view_from_sql, expand_aggregate, expand_aggregate_with_at, expand_curly_braces, From 7c547138b4e2efd7a8a707941b8831f8abc96712 Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Sat, 21 Feb 2026 08:23:23 -0800 Subject: [PATCH 2/2] Fix parse_select SQL length handling --- src/yardstick_parser_ffi.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/yardstick_parser_ffi.cpp b/src/yardstick_parser_ffi.cpp index 1b07a24..079ec9e 100644 --- a/src/yardstick_parser_ffi.cpp +++ b/src/yardstick_parser_ffi.cpp @@ -1109,7 +1109,7 @@ extern "C" YardstickSelectInfo* yardstick_parse_select(const char* sql) { auto* select_node = static_cast(select_stmt->node.get()); size_t from_pos = FindTopLevelFrom(sql); if (from_pos == std::string::npos) { - from_pos = sql.size(); + from_pos = std::strlen(sql); } // Process SELECT list