diff --git a/components/tidb_query/src/rpn_expr/impl_cast.rs b/components/tidb_query/src/rpn_expr/impl_cast.rs index 8c0fb692a39..f23769bf87a 100644 --- a/components/tidb_query/src/rpn_expr/impl_cast.rs +++ b/components/tidb_query/src/rpn_expr/impl_cast.rs @@ -11,7 +11,7 @@ use tipb::{Expr, FieldType}; use crate::codec::convert::*; use crate::codec::data_type::*; use crate::codec::error::{ERR_DATA_OUT_OF_RANGE, WARN_DATA_TRUNCATED}; -use crate::codec::mysql::Time; +use crate::codec::mysql::{binary_literal, Time}; use crate::codec::Error; use crate::expr::EvalContext; use crate::rpn_expr::{RpnExpressionNode, RpnFnCallExtra, RpnFnMeta}; @@ -77,10 +77,14 @@ fn get_cast_fn_rpn_meta( } } (EvalType::Bytes, EvalType::Real) => { - if !from_field_type.is_unsigned() { - cast_string_as_signed_real_fn_meta() - } else { - cast_string_as_unsigned_real_fn_meta() + match ( + from_field_type.is_binary_string_like(), + to_field_type.is_unsigned(), + ) { + (true, true) => cast_binary_string_as_unsigned_real_fn_meta(), + (true, false) => cast_binary_string_as_signed_real_fn_meta(), + (false, true) => cast_string_as_unsigned_real_fn_meta(), + (false, false) => cast_string_as_signed_real_fn_meta(), } } (EvalType::Decimal, EvalType::Real) => { @@ -469,8 +473,6 @@ fn cast_string_as_signed_real( match val { None => Ok(None), Some(val) => { - // FIXME: in TiDB's builtinCastStringAsRealSig, if val is IsBinaryLiteral, - // then return evalReal directly let r: f64 = val.convert(ctx)?; let r = produce_float_with_specified_tp(ctx, extra.ret_field_type, r)?; Ok(Real::new(r).ok()) @@ -478,6 +480,23 @@ fn cast_string_as_signed_real( } } +#[rpn_fn(capture = [ctx, extra])] +#[inline] +fn cast_binary_string_as_signed_real( + ctx: &mut EvalContext, + extra: &RpnFnCallExtra, + val: &Option, +) -> Result> { + match val { + None => Ok(None), + Some(val) => { + let r = binary_literal::to_uint(ctx, val)? as i64 as f64; + let r = produce_float_with_specified_tp(ctx, extra.ret_field_type, r)?; + Ok(Real::new(r).ok()) + } + } +} + #[rpn_fn(capture = [ctx, extra, metadata], metadata_type = tipb::InUnionMetadata)] #[inline] fn cast_string_as_unsigned_real( @@ -489,14 +508,28 @@ fn cast_string_as_unsigned_real( match val { None => Ok(None), Some(val) => { - // FIXME: in TiDB's builtinCastStringAsRealSig, if val is IsBinaryLiteral, - // then return evalReal directly let mut r: f64 = val.convert(ctx)?; if metadata.get_in_union() && r < 0f64 { r = 0f64; } let r = produce_float_with_specified_tp(ctx, extra.ret_field_type, r)?; - // FIXME: negative number to unsigned real's logic may be wrong here. + Ok(Real::new(r).ok()) + } + } +} + +#[rpn_fn(capture = [ctx, extra])] +#[inline] +fn cast_binary_string_as_unsigned_real( + ctx: &mut EvalContext, + extra: &RpnFnCallExtra, + val: &Option, +) -> Result> { + match val { + None => Ok(None), + Some(val) => { + let r = binary_literal::to_uint(ctx, val)? as f64; + let r = produce_float_with_specified_tp(ctx, extra.ret_field_type, r)?; Ok(Real::new(r).ok()) } } @@ -1254,7 +1287,7 @@ mod tests { metadata } - struct RetFieldTypeConfig { + struct FieldTypeConfig { unsigned: bool, flen: isize, decimal: isize, @@ -1263,9 +1296,9 @@ mod tests { collation: Option, } - impl Default for RetFieldTypeConfig { + impl Default for FieldTypeConfig { fn default() -> Self { - RetFieldTypeConfig { + FieldTypeConfig { unsigned: false, flen: UNSPECIFIED_LENGTH, decimal: UNSPECIFIED_LENGTH, @@ -1276,8 +1309,8 @@ mod tests { } } - impl From for FieldType { - fn from(config: RetFieldTypeConfig) -> Self { + impl From for FieldType { + fn from(config: FieldTypeConfig) -> Self { let mut ft = FieldType::default(); if let Some(c) = config.charset { ft.set_charset(String::from(c)); @@ -1657,9 +1690,9 @@ mod tests { } .into(); let metadata = make_metadata(cond.in_union()); - let rft = RetFieldTypeConfig { + let rft = FieldTypeConfig { unsigned: cond.is_unsigned(), - ..RetFieldTypeConfig::default() + ..FieldTypeConfig::default() } .into(); let extra = make_extra(&rft); @@ -2339,9 +2372,17 @@ mod tests { } #[test] - fn test_string_as_signed_real() { - test_none_with_ctx_and_extra(cast_string_as_signed_real); + fn test_cast_string_as_real() { + // None + { + let output: Option = RpnFnScalarEvaluator::new() + .push_param(ScalarValue::Bytes(None)) + .evaluate(ScalarFuncSig::CastStringAsReal) + .unwrap(); + assert_eq!(output, None); + } + // signed let ul = UNSPECIFIED_LENGTH; let cs: Vec<(String, f64, isize, isize, bool, bool)> = vec![ // (input, expect, flen, decimal, truncated, overflow) @@ -2376,49 +2417,48 @@ mod tests { (String::from("-1234abc"), -0.9f64, 1, 1, true, true), ]; - for (input, expect, flen, decimal, truncated, overflow) in cs { - let mut ctx = CtxConfig { - overflow_as_warning: true, - truncate_as_warning: true, - ..CtxConfig::default() - } - .into(); - let rft = RetFieldTypeConfig { - unsigned: false, - flen, - decimal, - ..RetFieldTypeConfig::default() - } - .into(); - let extra = make_extra(&rft); - let r = cast_string_as_signed_real(&mut ctx, &extra, &Some(input.clone().into_bytes())); - let r = r.map(|x| x.map(|x| x.into_inner())); - let log = format!( - "input: {}, expect: {}, flen: {}, decimal: {}, expect_truncated: {}, expect_overflow: {}", - input.as_str(), expect, flen, decimal, truncated, overflow + for (input, expected, flen, decimal, truncated, overflow) in cs { + let (result, ctx) = RpnFnScalarEvaluator::new() + .context(CtxConfig { + overflow_as_warning: true, + truncate_as_warning: true, + ..CtxConfig::default() + }) + .push_param(input.clone().into_bytes()) + .evaluate_raw( + FieldTypeConfig { + unsigned: false, + flen, + decimal, + tp: Some(FieldTypeTp::Double), + ..FieldTypeConfig::default() + }, + ScalarFuncSig::CastStringAsReal, + ); + let output: Option = result.unwrap().into(); + assert!( + (output.unwrap().into_inner() - expected).abs() < std::f64::EPSILON, + "input={:?}", + input ); - check_result(Some(&expect), &r, log.as_str()); - match (truncated, overflow) { - (true, true) => { - assert_eq!(ctx.warnings.warning_cnt, 2, "{}", log.as_str()); - let a = ctx.warnings.warnings[0].get_code(); - let b = ctx.warnings.warnings[1].get_code(); - let (a, b) = if a > b { (b, a) } else { (a, b) }; - assert_eq!(a, ERR_TRUNCATE_WRONG_VALUE, "{}", log.as_str()); - assert_eq!(b, ERR_DATA_OUT_OF_RANGE, "{}", log.as_str()); - } - (true, false) => check_warning(&ctx, Some(ERR_TRUNCATE_WRONG_VALUE), log.as_str()), - (false, true) => check_overflow(&ctx, true, log.as_str()), - _ => (), - } + let (warning_cnt, warnings) = match (truncated, overflow) { + (true, true) => (2, vec![ERR_TRUNCATE_WRONG_VALUE, ERR_DATA_OUT_OF_RANGE]), + (true, false) => (1, vec![ERR_TRUNCATE_WRONG_VALUE]), + (false, true) => (1, vec![ERR_DATA_OUT_OF_RANGE]), + _ => (0, vec![]), + }; + assert_eq!(ctx.warnings.warning_cnt, warning_cnt); + let mut got_warnings = ctx + .warnings + .warnings + .iter() + .map(|w| w.get_code()) + .collect::>(); + got_warnings.sort(); + assert_eq!(got_warnings, warnings); } - } - - #[test] - fn test_string_as_unsigned_real() { - test_none_with_ctx_and_extra_and_metadata(cast_string_as_unsigned_real); - let ul = UNSPECIFIED_LENGTH; + // unsigned let cs: Vec<(String, f64, isize, isize, bool, bool, bool)> = vec![ // (input, expect, flen, decimal, truncated, overflow, in_union) @@ -2559,46 +2599,50 @@ mod tests { (String::from("-1234abc"), 0.0, ul, ul, true, false, true), ]; - for (input, expect, flen, decimal, truncated, overflow, in_union) in cs { - let mut ctx = CtxConfig { - overflow_as_warning: true, - truncate_as_warning: true, - ..CtxConfig::default() - } - .into(); - let metadata = make_metadata(in_union); - let rft = RetFieldTypeConfig { - unsigned: true, - flen, - decimal, - ..RetFieldTypeConfig::default() - } - .into(); - let extra = make_extra(&rft); - - let p = Some(input.clone().into_bytes()); - let r = cast_string_as_unsigned_real(&mut ctx, &extra, &metadata, &p); - let r = r.map(|x| x.map(|x| x.into_inner())); - - let log = format!( - "input: {}, expect: {}, flen: {}, decimal: {}, expect_truncated: {}, expect_overflow: {}, in_union: {}", - input.as_str(), expect, flen, decimal, truncated, overflow, in_union + for (input, expected, flen, decimal, truncated, overflow, in_union) in cs { + let (result, ctx) = RpnFnScalarEvaluator::new() + .context(CtxConfig { + overflow_as_warning: true, + truncate_as_warning: true, + ..CtxConfig::default() + }) + .metadata(Box::new(make_metadata(in_union))) + .push_param(input.clone().into_bytes()) + .evaluate_raw( + FieldTypeConfig { + unsigned: true, + flen, + decimal, + tp: Some(FieldTypeTp::Double), + ..FieldTypeConfig::default() + }, + ScalarFuncSig::CastStringAsReal, + ); + let output: Option = result.unwrap().into(); + assert!( + (output.unwrap().into_inner() - expected).abs() < std::f64::EPSILON, + "input:{:?}, expected:{:?}, flen:{:?}, decimal:{:?}, truncated:{:?}, overflow:{:?}, in_union:{:?}", + input, expected, flen, decimal, truncated, overflow, in_union ); - - check_result(Some(&expect), &r, log.as_str()); - match (truncated, overflow) { - (true, true) => { - assert_eq!(ctx.warnings.warning_cnt, 2, "{}", log.as_str()); - let a = ctx.warnings.warnings[0].get_code(); - let b = ctx.warnings.warnings[1].get_code(); - let (a, b) = if a > b { (b, a) } else { (a, b) }; - assert_eq!(a, ERR_TRUNCATE_WRONG_VALUE, "{}", log.as_str()); - assert_eq!(b, ERR_DATA_OUT_OF_RANGE, "{}", log.as_str()); - } - (true, false) => check_warning(&ctx, Some(ERR_TRUNCATE_WRONG_VALUE), log.as_str()), - (false, true) => check_overflow(&ctx, true, log.as_str()), - _ => (), - } + let (warning_cnt, warnings) = match (truncated, overflow) { + (true, true) => (2, vec![ERR_TRUNCATE_WRONG_VALUE, ERR_DATA_OUT_OF_RANGE]), + (true, false) => (1, vec![ERR_TRUNCATE_WRONG_VALUE]), + (false, true) => (1, vec![ERR_DATA_OUT_OF_RANGE]), + _ => (0, vec![]), + }; + let mut got_warnings = ctx + .warnings + .warnings + .iter() + .map(|w| w.get_code()) + .collect::>(); + got_warnings.sort(); + assert_eq!( + ctx.warnings.warning_cnt, warning_cnt, + "input:{:?}, expected:{:?}, flen:{:?}, decimal:{:?}, truncated:{:?}, overflow:{:?}, in_union:{:?}, warnings:{:?}", + input, expected, flen, decimal, truncated, overflow, in_union,got_warnings, + ); + assert_eq!(got_warnings, warnings); } // not in union, neg @@ -2675,48 +2719,80 @@ mod tests { ], ), ]; - for (input, expect, flen, decimal, err_codes) in cs { - let mut ctx = CtxConfig { - overflow_as_warning: true, - truncate_as_warning: true, - ..CtxConfig::default() - } - .into(); - let metadata = make_metadata(false); - let rft = RetFieldTypeConfig { - unsigned: true, - flen, - decimal, - ..RetFieldTypeConfig::default() - } - .into(); - let extra = make_extra(&rft); - - let p = Some(input.clone().into_bytes()); - let r = cast_string_as_unsigned_real(&mut ctx, &extra, &metadata, &p); - let r = r.map(|x| x.map(|x| x.into_inner())); - let log = format!( - "input: {}, expect: {}, flen: {}, decimal: {}, err_code: {:?}", - input.as_str(), - expect, - flen, - decimal, - err_codes - ); - check_result(Some(&expect), &r, log.as_str()); - assert_eq!( - ctx.warnings.warning_cnt, - err_codes.len(), - "{}", - log.as_str() + for (input, expected, flen, decimal, err_codes) in cs { + let (result, ctx) = RpnFnScalarEvaluator::new() + .context(CtxConfig { + overflow_as_warning: true, + truncate_as_warning: true, + ..CtxConfig::default() + }) + .metadata(Box::new(make_metadata(false))) + .push_param(input.clone().into_bytes()) + .evaluate_raw( + FieldTypeConfig { + unsigned: true, + flen, + decimal, + tp: Some(FieldTypeTp::Double), + ..FieldTypeConfig::default() + }, + ScalarFuncSig::CastStringAsReal, + ); + let output: Option = result.unwrap().into(); + assert!( + (output.unwrap().into_inner() - expected).abs() < std::f64::EPSILON, + "input={:?}", + input ); + + assert_eq!(ctx.warnings.warning_cnt, err_codes.len()); for (idx, err) in err_codes.iter().enumerate() { assert_eq!( ctx.warnings.warnings[idx].get_code(), *err, - "{}", - log.as_str() + "input: {:?}", + input + ); + } + } + + // binary literal + let cases = vec![ + (vec![0x01, 0x02, 0x03], Some(f64::from(0x010203))), + (vec![0x01, 0x02, 0x03, 0x4], Some(f64::from(0x01020304))), + ( + vec![0x01, 0x02, 0x03, 0x4, 0x05, 0x06, 0x06, 0x06, 0x06], + None, + ), + ]; + for (input, expected) in cases { + let output: Result> = RpnFnScalarEvaluator::new() + .metadata(Box::new(make_metadata(false))) + .return_field_type(FieldTypeConfig { + flen: tidb_query_datatype::UNSPECIFIED_LENGTH, + decimal: tidb_query_datatype::UNSPECIFIED_LENGTH, + tp: Some(FieldTypeTp::Double), + ..FieldTypeConfig::default() + }) + .push_param_with_field_type( + input.clone(), + FieldTypeConfig { + tp: Some(FieldTypeTp::VarString), + collation: Some(Collation::Binary), + ..FieldTypeConfig::default() + }, ) + .evaluate(ScalarFuncSig::CastStringAsReal); + + if let Some(exp) = expected { + assert!(output.is_ok(), "input: {:?}", input); + assert!( + (output.unwrap().unwrap().into_inner() - exp).abs() < std::f64::EPSILON, + "input={:?}", + input + ); + } else { + assert!(output.is_err()); } } } @@ -3149,12 +3225,12 @@ mod tests { FlenType::ExtraOne => (res_len + 1) as isize, FlenType::Unspecified => UNSPECIFIED_LENGTH, }; - let rft = RetFieldTypeConfig { + let rft = FieldTypeConfig { flen, charset: Some(charset), tp: Some(*tp), collation: Some(*collation), - ..RetFieldTypeConfig::default() + ..FieldTypeConfig::default() } .into(); let extra = make_extra(&rft); @@ -3847,11 +3923,11 @@ mod tests { let ctx_in_dml_flag = vec![Flag::IN_INSERT_STMT, Flag::IN_UPDATE_OR_DELETE_STMT]; for in_dml_flag in ctx_in_dml_flag { let (res_flen, res_decimal) = (res_flen as isize, res_decimal as isize); - let rft = RetFieldTypeConfig { + let rft = FieldTypeConfig { unsigned: is_unsigned, flen: res_flen, decimal: res_decimal, - ..RetFieldTypeConfig::default() + ..FieldTypeConfig::default() } .into(); let metadata = make_metadata(in_union); @@ -4895,9 +4971,9 @@ mod tests { ..CtxConfig::default() } .into(); - let rft = RetFieldTypeConfig { + let rft = FieldTypeConfig { decimal: fsp, - ..RetFieldTypeConfig::default() + ..FieldTypeConfig::default() } .into(); let extra = make_extra(&rft); @@ -4951,9 +5027,9 @@ mod tests { ..CtxConfig::default() } .into(); - let rft = RetFieldTypeConfig { + let rft = FieldTypeConfig { decimal: fsp as isize, - ..RetFieldTypeConfig::default() + ..FieldTypeConfig::default() } .into(); let extra = make_extra(&rft); @@ -5146,9 +5222,9 @@ mod tests { for (s, fsp, expect_fsp, expect) in cs { let mut ctx = EvalContext::default(); - let rft = RetFieldTypeConfig { + let rft = FieldTypeConfig { decimal: expect_fsp, - ..RetFieldTypeConfig::default() + ..FieldTypeConfig::default() } .into(); let extra = make_extra(&rft); @@ -5180,9 +5256,9 @@ mod tests { ]; for (input, input_fsp, output_fsp, expect) in cs { - let rft = RetFieldTypeConfig { + let rft = FieldTypeConfig { decimal: output_fsp as isize, - ..RetFieldTypeConfig::default() + ..FieldTypeConfig::default() } .into(); let extra = make_extra(&rft); diff --git a/components/tidb_query/src/rpn_expr/types/test_util.rs b/components/tidb_query/src/rpn_expr/types/test_util.rs index ada3a095a0c..ed9aa009166 100644 --- a/components/tidb_query/src/rpn_expr/types/test_util.rs +++ b/components/tidb_query/src/rpn_expr/types/test_util.rs @@ -1,5 +1,7 @@ // Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0. +use std::any::Any; + use tipb::{Expr, ExprType, FieldType, ScalarFuncSig}; use crate::codec::batch::LazyBatchColumnVec; @@ -16,6 +18,7 @@ pub struct RpnFnScalarEvaluator { rpn_expr_builder: RpnExpressionBuilder, return_field_type: Option, context: Option, + metadata: Option>, } impl RpnFnScalarEvaluator { @@ -25,6 +28,7 @@ impl RpnFnScalarEvaluator { rpn_expr_builder: RpnExpressionBuilder::new(), return_field_type: None, context: None, + metadata: None, } } @@ -67,8 +71,14 @@ impl RpnFnScalarEvaluator { /// Sets the context to use during evaluation. /// /// If not set, a default `EvalContext` will be used. - pub fn context(mut self, context: EvalContext) -> Self { - self.context = Some(context); + pub fn context(mut self, context: impl Into) -> Self { + self.context = Some(context.into()); + self + } + + /// Sets the metadata to use during evaluation. + pub fn metadata(mut self, metadata: Box) -> Self { + self.metadata = Some(metadata); self } @@ -115,11 +125,14 @@ impl RpnFnScalarEvaluator { return (Err(e), context); } - let metadata = match (func.metadata_expr_ptr)(&mut fun_sig_expr) { - Ok(metadata) => metadata, - Err(e) => { - return (Err(e), context); - } + let metadata = match self.metadata { + Some(metadata) => metadata, + None => match (func.metadata_expr_ptr)(&mut fun_sig_expr) { + Ok(metadata) => metadata, + Err(e) => { + return (Err(e), context); + } + }, }; let expr = self .rpn_expr_builder diff --git a/components/tidb_query_codegen/src/rpn_function.rs b/components/tidb_query_codegen/src/rpn_function.rs index 1f4bed9d4d8..1347ca2897e 100644 --- a/components/tidb_query_codegen/src/rpn_function.rs +++ b/components/tidb_query_codegen/src/rpn_function.rs @@ -166,8 +166,8 @@ //! let (regex, arg) = self.extract(0); //! let regex = build_regex(regex); //! let mut result = Vec::with_capacity(output_rows); -//! for row in 0..output_rows { -//! let (text, _) = arg.extract(row); +//! for row_index in 0..output_rows { +//! let (text, _) = arg.extract(row_index); //! result.push(regex_match_impl(®ex, text)?); //! } //! Ok(Evaluable::into_vector_value(result)) @@ -603,8 +603,10 @@ fn generate_metadata_type_checker( extra: &mut crate::rpn_expr::RpnFnCallExtra<'_>, expr: &mut ::tipb::Expr, ) #where_clause { - let metadata = #metadata_expr; - #fn_body + for row_index in 0..output_rows { + let metadata = #metadata_expr; + #fn_body + } } }; } @@ -1061,8 +1063,8 @@ impl NormalRpnFn { #downcast_metadata let arg = &self; let mut result = Vec::with_capacity(output_rows); - for row in 0..output_rows { - #(let (#extract, arg) = arg.extract(row));*; + for row_index in 0..output_rows { + #(let (#extract, arg) = arg.extract(row_index));*; result.push( #fn_ident #ty_generics_turbofish ( #(#captures,)* #(#call_arg),* )?); } Ok(crate::codec::data_type::Evaluable::into_vector_value(result)) @@ -1246,9 +1248,9 @@ mod tests_normal { ) -> crate::Result { let arg = &self; let mut result = Vec::with_capacity(output_rows); - for row in 0..output_rows { - let (arg0, arg) = arg.extract(row); - let (arg1, arg) = arg.extract(row); + for row_index in 0..output_rows { + let (arg0, arg) = arg.extract(row_index); + let (arg1, arg) = arg.extract(row_index); result.push(foo(arg0, arg1)?); } Ok(crate::codec::data_type::Evaluable::into_vector_value(result)) @@ -1414,8 +1416,8 @@ mod tests_normal { ) -> crate::Result { let arg = &self; let mut result = Vec::with_capacity(output_rows); - for row in 0..output_rows { - let (arg0, arg) = arg.extract(row); + for row_index in 0..output_rows { + let (arg0, arg) = arg.extract(row_index); result.push(foo :: (arg0)?); } Ok(crate::codec::data_type::Evaluable::into_vector_value(result)) @@ -1560,9 +1562,9 @@ mod tests_normal { ) -> crate::Result { let arg = &self; let mut result = Vec::with_capacity(output_rows); - for row in 0..output_rows { - let (arg0, arg) = arg.extract(row); - let (arg1, arg) = arg.extract(row); + for row_index in 0..output_rows { + let (arg0, arg) = arg.extract(row_index); + let (arg1, arg) = arg.extract(row_index); result.push(foo(ctx, arg0, arg1)?); } Ok(crate::codec::data_type::Evaluable::into_vector_value(result))