Skip to content

Commit

Permalink
support CastError
Browse files Browse the repository at this point in the history
  • Loading branch information
ZENOTME committed Feb 22, 2023
1 parent 8a242fe commit bf346bc
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 45 deletions.
3 changes: 2 additions & 1 deletion src/frontend/src/binder/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ impl Binder {
Expr::TypedString { data_type, value } => {
let s: ExprImpl = self.bind_string(value)?.into();
s.cast_explicit(bind_data_type(&data_type)?)
.map_err(Into::into)
}
Expr::Row(exprs) => self.bind_row(exprs),
// input ref
Expand Down Expand Up @@ -430,7 +431,7 @@ impl Binder {
return self.bind_array_cast(expr.clone(), data_type);
}
let lhs = self.bind_expr(expr)?;
lhs.cast_explicit(data_type)
lhs.cast_explicit(data_type).map_err(Into::into)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/binder/expr/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ impl Binder {
},
)
.into();
return lhs.cast_explicit(ty);
return lhs.cast_explicit(ty).map_err(Into::into);
}
let inner_type = if let DataType::List { datatype } = &ty {
*datatype.clone()
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/binder/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ impl Binder {
return exprs
.into_iter()
.zip_eq_fast(expected_types)
.map(|(e, t)| e.cast_assign(t.clone()))
.map(|(e, t)| e.cast_assign(t.clone()).map_err(Into::into))
.try_collect();
}
std::cmp::Ordering::Less => "INSERT has more expressions than target columns",
Expand Down
70 changes: 48 additions & 22 deletions src/frontend/src/expr/function_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@

use itertools::Itertools;
use risingwave_common::catalog::Schema;
use risingwave_common::error::{ErrorCode, Result};
use risingwave_common::error::{ErrorCode, Result as RwResult, RwError};
use risingwave_common::types::DataType;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_expr::vector_op::cast::literal_parsing;
use thiserror::Error;

use super::{cast_ok, infer_some_all, infer_type, CastContext, Expr, ExprImpl, Literal};
use crate::expr::{ExprDisplay, ExprType};
Expand Down Expand Up @@ -99,7 +100,7 @@ impl FunctionCall {
// The functions listed here are all variadic. Type signatures of functions that take a fixed
// number of arguments are checked
// [elsewhere](crate::expr::type_inference::build_type_derive_map).
pub fn new(func_type: ExprType, mut inputs: Vec<ExprImpl>) -> Result<Self> {
pub fn new(func_type: ExprType, mut inputs: Vec<ExprImpl>) -> RwResult<Self> {
let return_type = infer_type(func_type, &mut inputs)?;
Ok(Self {
func_type,
Expand All @@ -109,12 +110,16 @@ impl FunctionCall {
}

/// Create a cast expr over `child` to `target` type in `allows` context.
pub fn new_cast(child: ExprImpl, target: DataType, allows: CastContext) -> Result<ExprImpl> {
pub fn new_cast(
child: ExprImpl,
target: DataType,
allows: CastContext,
) -> Result<ExprImpl, CastError> {
if is_row_function(&child) {
// Row function will have empty fields in Datatype::Struct at this point. Therefore,
// we will need to take some special care to generate the cast types. For normal struct
// types, they will be handled in `cast_ok`.
return Self::cast_nested(child, target, allows);
return Self::cast_row_expr(child, target, allows);
}
if child.is_unknown() {
// `is_unknown` makes sure `as_literal` and `as_utf8` will never panic.
Expand Down Expand Up @@ -146,46 +151,51 @@ impl FunctionCall {
}
.into())
} else {
Err(ErrorCode::BindError(format!(
Err(CastError(format!(
"cannot cast type \"{}\" to \"{}\" in {:?} context",
source, target, allows
))
.into())
)))
}
}

/// Cast a `ROW` expression to the target type. We intentionally disallow casting arbitrary
/// expressions, like `ROW(1)::STRUCT<i INTEGER>` to `STRUCT<VARCHAR>`, although an integer
/// is castable to VARCHAR. This is to simplify the casting rules.
fn cast_nested(expr: ExprImpl, target_type: DataType, allows: CastContext) -> Result<ExprImpl> {
fn cast_row_expr(
expr: ExprImpl,
target_type: DataType,
allows: CastContext,
) -> Result<ExprImpl, CastError> {
let func = *expr.into_function_call().unwrap();
let (fields, field_names) = if let DataType::Struct(t) = &target_type {
(t.fields.clone(), t.field_names.clone())
} else {
return Err(ErrorCode::BindError(format!(
"column is of type '{}' but expression is of type record",
target_type
))
.into());
return Err(CastError(format!(
"cannot cast type \"{}\" to \"{}\" in {:?} context",
func.return_type(),
target_type,
allows
)));
};
let (func_type, inputs, _) = func.decompose();
let msg = match fields.len().cmp(&inputs.len()) {
match fields.len().cmp(&inputs.len()) {
std::cmp::Ordering::Equal => {
let inputs = inputs
.into_iter()
.zip_eq_fast(fields.to_vec())
.map(|(e, t)| Self::new_cast(e, t, allows))
.collect::<Result<Vec<_>>>()?;
.collect::<Result<Vec<_>, CastError>>()?;
let return_type = DataType::new_struct(
inputs.iter().map(|i| i.return_type()).collect_vec(),
field_names,
);
return Ok(FunctionCall::new_unchecked(func_type, inputs, return_type).into());
Ok(FunctionCall::new_unchecked(func_type, inputs, return_type).into())
}
std::cmp::Ordering::Less => "Input has too few columns.",
std::cmp::Ordering::Greater => "Input has too many columns.",
};
Err(ErrorCode::BindError(format!("cannot cast record to {} ({})", target_type, msg)).into())
std::cmp::Ordering::Less => Err(CastError("Input has too few columns.".to_string())),
std::cmp::Ordering::Greater => {
Err(CastError("Input has too many columns.".to_string()))
}
}
}

/// Construct a `FunctionCall` expr directly with the provided `return_type`, bypassing type
Expand All @@ -205,7 +215,7 @@ impl FunctionCall {
pub fn new_binary_op_func(
mut func_types: Vec<ExprType>,
mut inputs: Vec<ExprImpl>,
) -> Result<ExprImpl> {
) -> RwResult<ExprImpl> {
let expr_type = func_types.remove(0);
match expr_type {
ExprType::Some | ExprType::All => {
Expand Down Expand Up @@ -274,7 +284,7 @@ impl FunctionCall {
function_call: &risingwave_pb::expr::FunctionCall,
expr_type: ExprType,
ret_type: DataType,
) -> Result<Self> {
) -> RwResult<Self> {
let inputs: Vec<_> = function_call
.get_children()
.iter()
Expand Down Expand Up @@ -419,3 +429,19 @@ pub fn is_row_function(expr: &ExprImpl) -> bool {
}
false
}

/// Error returned when an expression cannot be cast to the requested type.
///
/// Wraps a human-readable message; the `Display` output is that message verbatim.
#[derive(Debug, Error)]
#[error("{0}")]
pub struct CastError(String);

impl From<CastError> for ErrorCode {
fn from(value: CastError) -> Self {
ErrorCode::BindError(value.to_string())
}
}

impl From<CastError> for RwError {
fn from(value: CastError) -> Self {
ErrorCode::BindError(value.to_string()).into()
}
}
20 changes: 11 additions & 9 deletions src/frontend/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use enum_as_inner::EnumAsInner;
use fixedbitset::FixedBitSet;
use paste::paste;
use risingwave_common::array::ListValue;
use risingwave_common::error::Result;
use risingwave_common::error::Result as RwResult;
use risingwave_common::types::{DataType, Datum, Scalar};
use risingwave_expr::expr::{build_from_prost, AggKind};
use risingwave_pb::expr::expr_node::RexNode;
Expand Down Expand Up @@ -177,22 +177,22 @@ impl ExprImpl {
}

/// Shorthand to create cast expr to `target` type in implicit context.
pub fn cast_implicit(self, target: DataType) -> Result<ExprImpl> {
pub fn cast_implicit(self, target: DataType) -> Result<ExprImpl, CastError> {
FunctionCall::new_cast(self, target, CastContext::Implicit)
}

/// Shorthand to create cast expr to `target` type in assign context.
pub fn cast_assign(self, target: DataType) -> Result<ExprImpl> {
pub fn cast_assign(self, target: DataType) -> Result<ExprImpl, CastError> {
FunctionCall::new_cast(self, target, CastContext::Assign)
}

/// Shorthand to create cast expr to `target` type in explicit context.
pub fn cast_explicit(self, target: DataType) -> Result<ExprImpl> {
pub fn cast_explicit(self, target: DataType) -> Result<ExprImpl, CastError> {
FunctionCall::new_cast(self, target, CastContext::Explicit)
}

/// Shorthand to enforce implicit cast to boolean
pub fn enforce_bool_clause(self, clause: &str) -> Result<ExprImpl> {
pub fn enforce_bool_clause(self, clause: &str) -> RwResult<ExprImpl> {
if self.is_unknown() {
let inner = self.cast_implicit(DataType::Boolean)?;
return Ok(inner);
Expand All @@ -218,26 +218,27 @@ impl ExprImpl {
/// References in `PostgreSQL`:
/// * [cast](https://github.com/postgres/postgres/blob/a3ff08e0b08dbfeb777ccfa8f13ebaa95d064c04/src/include/catalog/pg_cast.dat#L437-L444)
/// * [impl](https://github.com/postgres/postgres/blob/27b77ecf9f4d5be211900eda54d8155ada50d696/src/backend/utils/adt/bool.c#L204-L209)
pub fn cast_output(self) -> Result<ExprImpl> {
pub fn cast_output(self) -> RwResult<ExprImpl> {
if self.return_type() == DataType::Boolean {
return Ok(FunctionCall::new(ExprType::BoolOut, vec![self])?.into());
}
// Use normal cast for other types. Both `assign` and `explicit` can pass the castability
// check and there is no difference.
self.cast_assign(DataType::Varchar)
.map_err(|err| err.into())
}

/// Evaluate the expression on the given input.
///
/// TODO: This is a naive implementation. We should avoid proto ser/de.
/// Tracking issue: <https://github.com/risingwavelabs/risingwave/issues/3479>
fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
fn eval_row(&self, input: &OwnedRow) -> RwResult<Datum> {
let backend_expr = build_from_prost(&self.to_expr_proto())?;
backend_expr.eval_row(input).map_err(Into::into)
}

/// Evaluate a constant expression.
pub fn eval_row_const(&self) -> Result<Datum> {
pub fn eval_row_const(&self) -> RwResult<Datum> {
assert!(self.is_const());
self.eval_row(&OwnedRow::empty())
}
Expand Down Expand Up @@ -728,7 +729,7 @@ impl ExprImpl {
}
}

pub fn from_expr_proto(proto: &ExprNode) -> Result<Self> {
pub fn from_expr_proto(proto: &ExprNode) -> RwResult<Self> {
let rex_node = proto.get_rex_node()?;
let ret_type = proto.get_return_type()?.into();
let expr_type = proto.get_expr_type()?;
Expand Down Expand Up @@ -892,6 +893,7 @@ use risingwave_common::bail;
use risingwave_common::catalog::Schema;
use risingwave_common::row::OwnedRow;

use self::function_call::CastError;
use crate::binder::BoundSetExpr;
use crate::utils::Condition;

Expand Down
6 changes: 3 additions & 3 deletions src/frontend/src/expr/table_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use risingwave_common::types::{unnested_list_type, DataType, ScalarImpl};
use risingwave_pb::expr::table_function::Type;
use risingwave_pb::expr::TableFunction as TableFunctionProst;

use super::{Expr, ExprImpl, ExprRewriter, Result};
use super::{Expr, ExprImpl, ExprRewriter, RwResult};

/// A table function takes a row as input and returns a table. It is also known as Set-Returning
/// Function.
Expand Down Expand Up @@ -85,7 +85,7 @@ impl FromStr for TableFunctionType {
impl TableFunction {
/// Create a `TableFunction` expr with the return type inferred from `func_type` and types of
/// `inputs`.
pub fn new(func_type: TableFunctionType, args: Vec<ExprImpl>) -> Result<Self> {
pub fn new(func_type: TableFunctionType, args: Vec<ExprImpl>) -> RwResult<Self> {
// TODO: refactor into something like `FunctionCall::new`.
// Current implementation is copied from legacy code.

Expand All @@ -94,7 +94,7 @@ impl TableFunction {
// generate_series ( start timestamp, stop timestamp, step interval ) or
// generate_series ( start i32, stop i32, step i32 )

fn type_check(exprs: &[ExprImpl]) -> Result<DataType> {
fn type_check(exprs: &[ExprImpl]) -> RwResult<DataType> {
let mut exprs = exprs.iter();
let (start, stop, step) = exprs.next_tuple().unwrap();
match (start.return_type(), stop.return_type(), step.return_type()) {
Expand Down
6 changes: 4 additions & 2 deletions src/frontend/src/expr/type_inference/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,11 @@ pub fn align_array_and_element(
.enumerate()
.map(|(idx, input)| {
if idx == array_idx {
input.cast_implicit(array_type.clone())
input.cast_implicit(array_type.clone()).map_err(Into::into)
} else {
input.cast_implicit(common_ele_type.clone())
input
.cast_implicit(common_ele_type.clone())
.map_err(Into::into)
}
})
.try_collect();
Expand Down
8 changes: 4 additions & 4 deletions src/frontend/src/expr/type_inference/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

use itertools::Itertools as _;
use num_integer::Integer as _;
use risingwave_common::error::{ErrorCode, Result};
use risingwave_common::error::{ErrorCode, Result, RwError};
use risingwave_common::types::struct_type::StructType;
use risingwave_common::types::{DataType, DataTypeName, ScalarImpl};
use risingwave_common::util::iter_util::ZipEqFast;
Expand Down Expand Up @@ -48,7 +48,7 @@ pub fn infer_type(func_type: ExprType, inputs: &mut Vec<ExprImpl>) -> Result<Dat
.map(|(expr, t)| {
if DataTypeName::from(expr.return_type()) != *t {
if t.is_scalar() {
return expr.cast_implicit((*t).into());
return expr.cast_implicit((*t).into()).map_err(Into::into);
} else {
return Err(ErrorCode::BindError(format!(
"Cannot implicitly cast '{:?}' to polymorphic type {:?}",
Expand All @@ -59,7 +59,7 @@ pub fn infer_type(func_type: ExprType, inputs: &mut Vec<ExprImpl>) -> Result<Dat
}
Ok(expr)
})
.try_collect()?;
.try_collect::<_, _, RwError>()?;
Ok(sig.ret_type.into())
}

Expand Down Expand Up @@ -318,7 +318,7 @@ fn infer_type_for_special(
.enumerate()
.map(|(i, input)| match i {
// 0-th arg must be string
0 => input.cast_implicit(DataType::Varchar),
0 => input.cast_implicit(DataType::Varchar).map_err(Into::into),
// subsequent can be any type, using the output format
_ => input.cast_output(),
})
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/src/expr/window_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use parse_display::Display;
use risingwave_common::error::ErrorCode;
use risingwave_common::types::DataType;

use super::{Expr, ExprImpl, OrderBy, Result};
use super::{Expr, ExprImpl, OrderBy, RwResult};

/// A window function performs a calculation across a set of table rows that are somehow related to
/// the current row, according to the window spec `OVER (PARTITION BY .. ORDER BY ..)`.
Expand Down Expand Up @@ -79,7 +79,7 @@ impl WindowFunction {
partition_by: Vec<ExprImpl>,
order_by: OrderBy,
args: Vec<ExprImpl>,
) -> Result<Self> {
) -> RwResult<Self> {
if !args.is_empty() {
return Err(ErrorCode::BindError(format!(
"the length of args of {function_type} function should be 0"
Expand Down

0 comments on commit bf346bc

Please sign in to comment.