Skip to content

Commit

Permalink
feat(frontend): support update column to default value (#8987)
Browse files Browse the repository at this point in the history
Signed-off-by: Bugen Zhao <i@bugenzhao.com>
  • Loading branch information
BugenZhao committed Apr 4, 2023
1 parent 570a253 commit 80c477f
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 14 deletions.
11 changes: 11 additions & 0 deletions e2e_test/batch/basic/dml.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ select v1, v2 from t order by v2;
45 810
35 1919

statement ok
update t set v1 = DEFAULT where v2 = 810;

query RI
select v1, v2 from t order by v2;
----
114 10
514 20
NULL 810
35 1919

# Delete

statement ok
Expand Down
8 changes: 8 additions & 0 deletions src/frontend/planner_test/tests/testdata/update.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@
└─BatchUpdate { table: t, exprs: [$1::Int32, $1, $2] }
└─BatchExchange { order: [], dist: Single }
└─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id], distribution: UpstreamHashShard(t._row_id) }
- sql: |
create table t (v1 int, v2 real);
update t set v1 = DEFAULT;
batch_plan: |
BatchExchange { order: [], dist: Single }
└─BatchUpdate { table: t, exprs: [null:Int32, $1, $2] }
└─BatchExchange { order: [], dist: Single }
└─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id], distribution: UpstreamHashShard(t._row_id) }
- sql: |
create table t (v1 int, v2 int);
update t set v1 = v2 + 1 where v2 > 0;
Expand Down
20 changes: 14 additions & 6 deletions src/frontend/src/binder/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use itertools::Itertools;
use risingwave_common::catalog::{Schema, TableVersionId};
use risingwave_common::error::{ErrorCode, Result};
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_sqlparser::ast::{Assignment, Expr, ObjectName, SelectItem};
use risingwave_sqlparser::ast::{Assignment, AssignmentValue, Expr, ObjectName, SelectItem};

use super::statement::RewriteExprsRecursive;
use super::{Binder, Relation};
Expand Down Expand Up @@ -114,17 +114,19 @@ impl Binder {
}

// (col1, col2) = (subquery)
(_ids, Expr::Subquery(_)) => {
(_ids, AssignmentValue::Expr(Expr::Subquery(_))) => {
return Err(ErrorCode::NotImplemented(
"subquery on the right side of multi-assignment".to_owned(),
None.into(),
)
.into())
}
// (col1, col2) = (expr1, expr2)
(ids, Expr::Row(values)) if ids.len() == values.len() => {
id.into_iter().zip_eq_fast(values.into_iter()).collect()
}
// TODO: support `DEFAULT` in multiple assignments
(ids, AssignmentValue::Expr(Expr::Row(values))) if ids.len() == values.len() => id
.into_iter()
.zip_eq_fast(values.into_iter().map(AssignmentValue::Expr))
.collect(),
// (col1, col2) = <other expr>
_ => {
return Err(ErrorCode::BindError(
Expand All @@ -148,7 +150,13 @@ impl Binder {
}
}

let value_expr = self.bind_expr(value)?.cast_assign(id_expr.return_type())?;
let value_expr = match value {
AssignmentValue::Expr(expr) => {
self.bind_expr(expr)?.cast_assign(id_expr.return_type())?
}
// TODO: specify default expression after we support non-`NULL` default values.
AssignmentValue::Default => ExprImpl::literal_null(id_expr.return_type()),
};

match assignment_exprs.entry(id_expr) {
Entry::Occupied(_) => {
Expand Down
22 changes: 20 additions & 2 deletions src/sqlparser/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1814,12 +1814,30 @@ impl fmt::Display for GrantObjects {
}
}

/// SQL assignment `foo = expr` as used in SQLUpdate
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum AssignmentValue {
/// An expression, e.g. `foo = 1`
Expr(Expr),
/// The `DEFAULT` keyword, e.g. `foo = DEFAULT`
Default,
}

impl fmt::Display for AssignmentValue {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
AssignmentValue::Expr(expr) => write!(f, "{}", expr),
AssignmentValue::Default => f.write_str("DEFAULT"),
}
}
}

/// SQL assignment `foo = { expr | DEFAULT }` as used in SQLUpdate
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Assignment {
pub id: Vec<Ident>,
pub value: Expr,
pub value: AssignmentValue,
}

impl fmt::Display for Assignment {
Expand Down
8 changes: 7 additions & 1 deletion src/sqlparser/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4035,7 +4035,13 @@ impl Parser {
pub fn parse_assignment(&mut self) -> Result<Assignment, ParserError> {
let id = self.parse_identifiers_non_keywords()?;
self.expect_token(&Token::Eq)?;
let value = self.parse_expr()?;

let value = if self.parse_keyword(Keyword::DEFAULT) {
AssignmentValue::Default
} else {
AssignmentValue::Expr(self.parse_expr()?)
};

Ok(Assignment { id, value })
}

Expand Down
14 changes: 9 additions & 5 deletions src/sqlparser/tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ fn parse_insert_values() {

#[test]
fn parse_update() {
let sql = "UPDATE t SET a = 1, b = 2, c = 3 WHERE d";
let sql = "UPDATE t SET a = 1, b = 2, c = 3, d = DEFAULT WHERE e";
match verified_stmt(sql) {
Statement::Update {
table_name,
Expand All @@ -106,19 +106,23 @@ fn parse_update() {
vec![
Assignment {
id: vec!["a".into()],
value: Expr::Value(number("1")),
value: AssignmentValue::Expr(Expr::Value(number("1"))),
},
Assignment {
id: vec!["b".into()],
value: Expr::Value(number("2")),
value: AssignmentValue::Expr(Expr::Value(number("2"))),
},
Assignment {
id: vec!["c".into()],
value: Expr::Value(number("3")),
value: AssignmentValue::Expr(Expr::Value(number("3"))),
},
Assignment {
id: vec!["d".into()],
value: AssignmentValue::Default,
}
]
);
assert_eq!(selection.unwrap(), Expr::Identifier("d".into()));
assert_eq!(selection.unwrap(), Expr::Identifier("e".into()));
}
_ => unreachable!(),
}
Expand Down

0 comments on commit 80c477f

Please sign in to comment.