Skip to content

Commit

Permalink
Merge 42b56b6 into 42c5d43
Browse files Browse the repository at this point in the history
  • Loading branch information
ovr committed Aug 16, 2022
2 parents 42c5d43 + 42b56b6 commit 29d0509
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 52 deletions.
23 changes: 4 additions & 19 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -533,8 +533,10 @@ impl fmt::Display for Expr {
Expr::UnaryOp { op, expr } => {
if op == &UnaryOperator::PGPostfixFactorial {
write!(f, "{}{}", expr, op)
} else {
} else if op == &UnaryOperator::Not {
write!(f, "{} {}", op, expr)
} else {
write!(f, "{}{}", op, expr)
}
}
Expr::Cast { expr, data_type } => write!(f, "CAST({} AS {})", expr, data_type),
Expand Down Expand Up @@ -1088,7 +1090,7 @@ pub enum Statement {
local: bool,
hivevar: bool,
variable: ObjectName,
value: Vec<SetVariableValue>,
value: Vec<Expr>,
},
/// SET NAMES 'charset_name' [COLLATE 'collation_name']
///
Expand Down Expand Up @@ -2733,23 +2735,6 @@ impl fmt::Display for ShowStatementFilter {
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum SetVariableValue {
Ident(Ident),
Literal(Value),
}

impl fmt::Display for SetVariableValue {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
use SetVariableValue::*;
match self {
Ident(ident) => write!(f, "{}", ident),
Literal(literal) => write!(f, "{}", literal),
}
}
}

/// Sqlite specific syntax
///
/// https://sqlite.org/lang_conflict.html
Expand Down
1 change: 1 addition & 0 deletions src/dialect/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ impl Dialect for MySqlDialect {
|| ('A'..='Z').contains(&ch)
|| ch == '_'
|| ch == '$'
|| ch == '@'
|| ('\u{0080}'..='\u{ffff}').contains(&ch)
}

Expand Down
20 changes: 5 additions & 15 deletions src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3737,22 +3737,12 @@ impl<'a> Parser<'a> {
} else if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
let mut values = vec![];
loop {
let token = self.peek_token();
let value = match (self.parse_value(), token) {
(Ok(value), _) => SetVariableValue::Literal(value),
(Err(_), Token::Word(ident)) => SetVariableValue::Ident(ident.to_ident()),
(Err(_), Token::Minus) => {
let next_token = self.next_token();
match next_token {
Token::Word(ident) => SetVariableValue::Ident(Ident {
quote_style: ident.quote_style,
value: format!("-{}", ident.value),
}),
_ => self.expected("word", next_token)?,
}
}
(Err(_), unexpected) => self.expected("variable value", unexpected)?,
let value = if let Ok(expr) = self.parse_expr() {
expr
} else {
self.expected("variable value", self.peek_token())?
};

values.push(value);
if self.consume_token(&Token::Comma) {
continue;
Expand Down
8 changes: 4 additions & 4 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ fn parse_select_count_wildcard() {

#[test]
fn parse_select_count_distinct() {
let sql = "SELECT COUNT(DISTINCT + x) FROM customer";
let sql = "SELECT COUNT(DISTINCT +x) FROM customer";
let select = verified_only_select(sql);
assert_eq!(
&Expr::Function(Function {
Expand All @@ -597,8 +597,8 @@ fn parse_select_count_distinct() {
);

one_statement_parses_to(
"SELECT COUNT(ALL + x) FROM customer",
"SELECT COUNT(+ x) FROM customer",
"SELECT COUNT(ALL +x) FROM customer",
"SELECT COUNT(+x) FROM customer",
);

let sql = "SELECT COUNT(ALL DISTINCT + x) FROM customer";
Expand Down Expand Up @@ -754,7 +754,7 @@ fn parse_compound_expr_2() {
#[test]
fn parse_unary_math() {
use self::Expr::*;
let sql = "- a + - b";
let sql = "-a + -b";
assert_eq!(
BinaryOp {
left: Box::new(UnaryOp {
Expand Down
9 changes: 6 additions & 3 deletions tests/sqlparser_hive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
//! Test SQL syntax specific to Hive. The parser based on the generic dialect
//! is also tested (on the inputs it can handle).

use sqlparser::ast::{CreateFunctionUsing, Ident, ObjectName, SetVariableValue, Statement};
use sqlparser::ast::{CreateFunctionUsing, Expr, Ident, ObjectName, Statement, UnaryOperator};
use sqlparser::dialect::{GenericDialect, HiveDialect};
use sqlparser::parser::ParserError;
use sqlparser::test_utils::*;
Expand Down Expand Up @@ -220,14 +220,17 @@ fn set_statement_with_minus() {
Ident::new("java"),
Ident::new("opts")
]),
value: vec![SetVariableValue::Ident("-Xmx4g".into())],
value: vec![Expr::UnaryOp {
op: UnaryOperator::Minus,
expr: Box::new(Expr::Identifier(Ident::new("Xmx4g")))
}],
}
);

assert_eq!(
hive().parse_sql_statements("SET hive.tez.java.opts = -"),
Err(ParserError::ParserError(
"Expected word, found: EOF".to_string()
"Expected variable value, found: EOF".to_string()
))
)
}
Expand Down
20 changes: 20 additions & 0 deletions tests/sqlparser_mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,26 @@ fn parse_use() {
);
}

#[test]
fn parse_set_variables() {
mysql_and_generic().verified_stmt("SET sql_mode = CONCAT(@@sql_mode, ',STRICT_TRANS_TABLES')");
assert_eq!(
mysql_and_generic().verified_stmt("SET LOCAL autocommit = 1"),
Statement::SetVariable {
local: true,
hivevar: false,
variable: ObjectName(vec!["autocommit".into()]),
value: vec![Expr::Value(Value::Number(
#[cfg(not(feature = "bigdecimal"))]
"1".to_string(),
#[cfg(feature = "bigdecimal")]
bigdecimal::BigDecimal::from(1),
false
))],
}
);
}

#[test]
fn parse_create_table_auto_increment() {
let sql = "CREATE TABLE foo (bar INT PRIMARY KEY AUTO_INCREMENT)";
Expand Down
34 changes: 23 additions & 11 deletions tests/sqlparser_postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
mod test_utils;
use test_utils::*;

use sqlparser::ast::Value::Boolean;
use sqlparser::ast::*;
use sqlparser::dialect::{GenericDialect, PostgreSqlDialect};
use sqlparser::parser::ParserError;
Expand Down Expand Up @@ -782,7 +781,10 @@ fn parse_set() {
local: false,
hivevar: false,
variable: ObjectName(vec![Ident::new("a")]),
value: vec![SetVariableValue::Ident("b".into())],
value: vec![Expr::Identifier(Ident {
value: "b".into(),
quote_style: None
})],
}
);

Expand All @@ -793,9 +795,7 @@ fn parse_set() {
local: false,
hivevar: false,
variable: ObjectName(vec![Ident::new("a")]),
value: vec![SetVariableValue::Literal(Value::SingleQuotedString(
"b".into()
))],
value: vec![Expr::Value(Value::SingleQuotedString("b".into()))],
}
);

Expand All @@ -806,7 +806,13 @@ fn parse_set() {
local: false,
hivevar: false,
variable: ObjectName(vec![Ident::new("a")]),
value: vec![SetVariableValue::Literal(number("0"))],
value: vec![Expr::Value(Value::Number(
#[cfg(not(feature = "bigdecimal"))]
"0".to_string(),
#[cfg(feature = "bigdecimal")]
bigdecimal::BigDecimal::from(0),
false,
))],
}
);

Expand All @@ -817,7 +823,10 @@ fn parse_set() {
local: false,
hivevar: false,
variable: ObjectName(vec![Ident::new("a")]),
value: vec![SetVariableValue::Ident("DEFAULT".into())],
value: vec![Expr::Identifier(Ident {
value: "DEFAULT".into(),
quote_style: None
})],
}
);

Expand All @@ -828,7 +837,7 @@ fn parse_set() {
local: true,
hivevar: false,
variable: ObjectName(vec![Ident::new("a")]),
value: vec![SetVariableValue::Ident("b".into())],
value: vec![Expr::Identifier("b".into())],
}
);

Expand All @@ -839,7 +848,10 @@ fn parse_set() {
local: false,
hivevar: false,
variable: ObjectName(vec![Ident::new("a"), Ident::new("b"), Ident::new("c")]),
value: vec![SetVariableValue::Ident("b".into())],
value: vec![Expr::Identifier(Ident {
value: "b".into(),
quote_style: None
})],
}
);

Expand All @@ -859,7 +871,7 @@ fn parse_set() {
Ident::new("reducer"),
Ident::new("parallelism")
]),
value: vec![SetVariableValue::Literal(Boolean(false))],
value: vec![Expr::Value(Value::Boolean(false))],
}
);

Expand Down Expand Up @@ -1107,7 +1119,7 @@ fn parse_pg_unary_ops() {
];

for (str_op, op) in pg_unary_ops {
let select = pg().verified_only_select(&format!("SELECT {} a", &str_op));
let select = pg().verified_only_select(&format!("SELECT {}a", &str_op));
assert_eq!(
SelectItem::UnnamedExpr(Expr::UnaryOp {
op: op.clone(),
Expand Down

0 comments on commit 29d0509

Please sign in to comment.