Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(binder): Incorrect cast when specifying columns #8770

Merged
merged 11 commits into from
Apr 5, 2023
24 changes: 24 additions & 0 deletions e2e_test/batch/basic/dml.slt.part
Original file line number Diff line number Diff line change
@@ -1,6 +1,30 @@
statement ok
SET RW_IMPLICIT_FLUSH TO true;

statement ok
create table t1 (v1 real, v2 int, v3 varchar);


# Insert

statement ok
insert into t1 (v2, v1, v3) values (1, 2, 'a'), (3, 4, 'b');

query RI rowsort
select * from t1;
----
2 1 a
4 3 b

statement ok
insert into t1 (v2, v1) values (1, 2), (3, 4);

statement ok
insert into t1 values (1, 2), (3, 4);

statement ok
drop table t1;

statement ok
create table t (v1 real, v2 int);

Expand Down
6 changes: 3 additions & 3 deletions src/frontend/planner_test/tests/testdata/insert.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,13 @@
- name: To many target columns
sql: |
create table t (v1 int, v2 int);
insert into t (v1, v2, v2) values (5, 6);
binder_error: 'Bind error: INSERT has more target columns than values'
insert into t (v1, v2) values (5);
binder_error: 'Bind error: INSERT has more target columns than expressions'
- name: Not enough target columns
sql: |
create table t (v1 int, v2 int);
insert into t (v1) values (5, 6);
binder_error: 'Bind error: INSERT has less target columns than values'
binder_error: 'Bind error: INSERT has more expressions than target columns'
- name: insert literal null
sql: |
create table t(v1 int);
Expand Down
239 changes: 132 additions & 107 deletions src/frontend/src/binder/insert.rs
xxchan marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashSet;
use std::collections::{HashMap, HashSet};

use itertools::Itertools;
use risingwave_common::catalog::{Schema, TableVersionId};
use risingwave_common::catalog::{ColumnCatalog, Schema, TableVersionId};
use risingwave_common::error::{ErrorCode, Result, RwError};
use risingwave_common::types::DataType;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_sqlparser::ast::{Ident, ObjectName, Query, SelectItem, SetExpr};
use risingwave_sqlparser::ast::{Ident, ObjectName, Query, SelectItem};

use super::statement::RewriteExprsRecursive;
use super::{BoundQuery, BoundSetExpr};
use super::BoundQuery;
use crate::binder::Binder;
use crate::catalog::TableId;
use crate::expr::{ExprImpl, InputRef};
Expand Down Expand Up @@ -89,7 +89,7 @@ impl Binder {
pub(super) fn bind_insert(
&mut self,
name: ObjectName,
columns: Vec<Ident>,
cols_to_insert_by_user: Vec<Ident>,
source: Query,
returning_items: Vec<SelectItem>,
) -> Result<BoundInsert> {
Expand All @@ -100,16 +100,11 @@ impl Binder {
let table_id = table_catalog.id;
let owner = table_catalog.owner;
let table_version_id = table_catalog.version_id().expect("table must be versioned");
let columns_to_insert = table_catalog.columns_to_insert().cloned().collect_vec();

let expected_types: Vec<DataType> = columns_to_insert
xxchan marked this conversation as resolved.
Show resolved Hide resolved
.iter()
.map(|c| c.data_type().clone())
.collect();
let cols_to_insert_in_table = table_catalog.columns_to_insert().cloned().collect_vec();

let generated_column_names: HashSet<_> = table_catalog.generated_column_names().collect();
for query_col in &columns {
let query_col_name = query_col.real_value();
for col in &cols_to_insert_by_user {
let query_col_name = col.real_value();
if generated_column_names.contains(query_col_name.as_str()) {
return Err(RwError::from(ErrorCode::BindError(format!(
"cannot insert a non-DEFAULT value into column \"{0}\". Column \"{0}\" is a generated column.",
Expand All @@ -135,57 +130,53 @@ impl Binder {
}
};

// When the column types of `source` query do not match `expected_types`, casting is
// needed.
let (returning_list, fields) = self.bind_returning_list(returning_items)?;
let is_returning = !returning_list.is_empty();

let col_indices_to_insert = get_col_indices_to_insert(
&cols_to_insert_in_table,
&cols_to_insert_by_user,
&table_name,
)?;
let expected_types: Vec<DataType> = col_indices_to_insert
.iter()
.map(|idx| cols_to_insert_in_table[*idx].data_type().clone())
.collect();

// When the column types of `source` query do not match `expected_types`,
// casting is needed.
//
// In PG, when the `source` is a `VALUES` without order / limit / offset, special treatment
// is given and it is NOT equivalent to assignment cast over potential implicit cast inside.
// For example, the following is valid:
//
// ```
// create table t (v1 time);
// insert into t values (timestamp '2020-01-01 01:02:03'), (time '03:04:05');
// ```
//
// But the followings are not:
//
// ```
// values (timestamp '2020-01-01 01:02:03'), (time '03:04:05');
// insert into t values (timestamp '2020-01-01 01:02:03'), (time '03:04:05') limit 1;
// ```
//
// Because `timestamp` can cast to `time` in assignment context, but no casting between them
// is allowed implicitly.
//
// In this case, assignment cast should be used directly in `VALUES`, suppressing its
// internal implicit cast.
// In other cases, the `source` query is handled on its own and assignment cast is done
// afterwards.
let (source, cast_exprs, nulls_inserted) = match source {
Query {
with: None,
body: SetExpr::Values(values),
order_by: order,
limit: None,
offset: None,
fetch: None,
} if order.is_empty() => {
let (values, nulls_inserted) =
self.bind_values(values, Some(expected_types.clone()))?;
xxchan marked this conversation as resolved.
Show resolved Hide resolved
let body = BoundSetExpr::Values(values.into());
(
BoundQuery {
body,
order: vec![],
limit: None,
offset: None,
with_ties: false,
extra_order_exprs: vec![],
},
vec![],
nulls_inserted,
)
}
query => {
let bound = self.bind_query(query)?;
let actual_types = bound.data_types();
let cast_exprs = match expected_types == actual_types {
let bound_query;
let cast_exprs;

match source.as_simple_values() {
None => {
bound_query = self.bind_query(source)?;
let actual_types = bound_query.data_types();
cast_exprs = match expected_types == actual_types {
true => vec![],
false => Self::cast_on_insert(
&expected_types,
Expand All @@ -196,71 +187,45 @@ impl Binder {
.collect(),
)?,
};
(bound, cast_exprs, false)
}
};

let mut target_table_col_indices: Vec<usize> = vec![];
'outer: for query_column in &columns {
let column_name = query_column.real_value();
for (col_idx, table_column) in columns_to_insert.iter().enumerate() {
if column_name == table_column.name() {
target_table_col_indices.push(col_idx);
continue 'outer;
Some(values) => {
assert!(!values.0.is_empty());
let num_value_cols = values.0[0].len();
let has_user_specified_columns = !cols_to_insert_by_user.is_empty();
let num_target_cols = if has_user_specified_columns {
cols_to_insert_by_user.len()
} else {
cols_to_insert_in_table.len()
};
let err_msg = match num_target_cols.cmp(&num_value_cols) {
std::cmp::Ordering::Equal => None,
std::cmp::Ordering::Greater => {
if has_user_specified_columns {
// e.g. insert into t (v1, v2) values (7)
Some("INSERT has more target columns than expressions")
} else {
// e.g. create table t (a int, b real)
// insert into t values (7)
// this kind of usage is fine, null values will be provided
// implicitly.
None
}
}
std::cmp::Ordering::Less => {
// e.g. create table t (a int, b real)
// insert into t (a) values (7, 13)
// or insert into t values (7, 13, 17)
Some("INSERT has more expressions than target columns")
}
};
if let Some(msg) = err_msg {
return Err(RwError::from(ErrorCode::BindError(msg.to_string())));
}
}
// Invalid column name found
return Err(RwError::from(ErrorCode::BindError(format!(
"Column {} not found in table {}",
column_name, table_name
))));
}

// create table t1 (v1 int, v2 int); insert into t1 (v2) values (5);
// We added the null values above. Above is equivalent to
// insert into t1 values (NULL, 5);
let target_table_col_indices = if !target_table_col_indices.is_empty() && nulls_inserted {
let provided_insert_cols: HashSet<usize> =
target_table_col_indices.iter().cloned().collect();

let mut result: Vec<usize> = target_table_col_indices.clone();
for i in 0..columns_to_insert.len() {
if !provided_insert_cols.contains(&i) {
result.push(i);
}
let values = self.bind_values(values.clone(), Some(expected_types))?;
bound_query = BoundQuery::with_values(values);
cast_exprs = vec![];
}
result
} else {
target_table_col_indices
};

let (returning_list, fields) = self.bind_returning_list(returning_items)?;
let is_returning = !returning_list.is_empty();
// validate that query has a value for each target column, if target columns are used
// create table t1 (v1 int, v2 int);
// insert into t1 (v1, v2, v2) values (5, 6); // ...more target columns than values
// insert into t1 (v1) values (5, 6); // ...less target columns than values
let err_msg = match target_table_col_indices.len().cmp(&expected_types.len()) {
std::cmp::Ordering::Equal => None,
std::cmp::Ordering::Greater => Some("INSERT has more target columns than values"),
std::cmp::Ordering::Less => Some("INSERT has less target columns than values"),
};

if let Some(msg) = err_msg && !target_table_col_indices.is_empty() {
return Err(RwError::from(ErrorCode::BindError(
msg.to_string(),
)));
}

// Check if column was used multiple times in query e.g.
// insert into t1 (v1, v1) values (1, 5);
let mut uniq_cols = target_table_col_indices.clone();
uniq_cols.sort_unstable();
uniq_cols.dedup();
if target_table_col_indices.len() != uniq_cols.len() {
return Err(RwError::from(ErrorCode::BindError(
"Column specified more than once".to_string(),
)));
}

let insert = BoundInsert {
Expand All @@ -269,8 +234,8 @@ impl Binder {
table_name,
owner,
row_id_index,
column_indices: target_table_col_indices,
source,
column_indices: col_indices_to_insert,
source: bound_query,
cast_exprs,
returning_list,
returning_schema: if is_returning {
Expand Down Expand Up @@ -302,3 +267,63 @@ impl Binder {
Err(ErrorCode::BindError(msg.into()).into())
}
}

/// Resolves the target column list of an `INSERT` into indices of
/// `cols_to_insert_in_table`.
///
/// Returned indices have the same length as `cols_to_insert_in_table`:
/// * the first elements have the same order as `cols_to_insert_by_user`;
/// * the rest are the columns not specified by the user, in table order.
///
/// If `cols_to_insert_by_user` is empty, the identity mapping `0..n` is returned.
///
/// # Errors
///
/// Returns a `BindError` when the user lists a column that is not among
/// `cols_to_insert_in_table`, or lists the same column more than once.
fn get_col_indices_to_insert(
    cols_to_insert_in_table: &[ColumnCatalog],
    cols_to_insert_by_user: &[Ident],
    table_name: &str,
) -> Result<Vec<usize>> {
    if cols_to_insert_by_user.is_empty() {
        return Ok((0..cols_to_insert_in_table.len()).collect());
    }

    // Column name -> table index. A slot holds `Some(idx)` until the user's column
    // list claims it, after which it is `None`; a second claim is a duplicate.
    let mut unclaimed: HashMap<&str, Option<usize>> = cols_to_insert_in_table
        .iter()
        .enumerate()
        .map(|(idx, col)| (col.name(), Some(idx)))
        .collect();

    let mut col_indices_to_insert: Vec<usize> =
        Vec::with_capacity(cols_to_insert_in_table.len());

    for ident in cols_to_insert_by_user {
        let col_name = ident.real_value();
        let Some(slot) = unclaimed.get_mut(col_name.as_str()) else {
            // Invalid column name found
            return Err(RwError::from(ErrorCode::BindError(format!(
                "Column {} not found in table {}",
                col_name, table_name
            ))));
        };
        let Some(idx) = slot.take() else {
            return Err(RwError::from(ErrorCode::BindError(
                "Column specified more than once".to_string(),
            )));
        };
        col_indices_to_insert.push(idx);
    }

    // Append the columns that are in the target table but were not named by the
    // user, keeping table order. Claimed slots are `None` and get filtered out.
    col_indices_to_insert.extend(
        cols_to_insert_in_table
            .iter()
            .filter_map(|col| unclaimed.get(col.name()).copied().flatten()),
    );

    Ok(col_indices_to_insert)
}
13 changes: 13 additions & 0 deletions src/frontend/src/binder/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
use risingwave_sqlparser::ast::{Cte, Expr, Fetch, OrderByExpr, Query, Value, With};

use super::statement::RewriteExprsRecursive;
use super::BoundValues;
use crate::binder::{Binder, BoundSetExpr};
use crate::expr::{CorrelatedId, Depth, ExprImpl, ExprRewriter};

Expand Down Expand Up @@ -95,6 +96,18 @@ impl BoundQuery {
self.body
.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id)
}

/// Simple `VALUES` without other clauses.
pub fn with_values(values: BoundValues) -> Self {
BoundQuery {
body: BoundSetExpr::Values(values.into()),
order: vec![],
limit: None,
offset: None,
with_ties: false,
extra_order_exprs: vec![],
}
}
}

impl RewriteExprsRecursive for BoundQuery {
Expand Down