Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tidb_query_datatype, tidb_query_expr: Add div_precision_increment support in dag request #16622

Merged
merged 10 commits into from
Mar 13, 2024
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions components/tidb_query_datatype/src/codec/mysql/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use crate::{
codec::{
convert::{self, ConvertTo},
data_type::*,
mysql::DEFAULT_DIV_FRAC_INCR,
Error, Result, TEN_POW,
},
expr::EvalContext,
Expand Down Expand Up @@ -138,7 +139,6 @@ const DIG_MASK: u32 = TEN_POW[8];
const WORD_BASE: u32 = TEN_POW[9];
const WORD_MAX: u32 = WORD_BASE - 1;
const MAX_FRACTION: u8 = 30;
const DEFAULT_DIV_FRAC_INCR: u8 = 4;
const DIG_2_BYTES: &[u8] = &[0, 1, 1, 2, 2, 3, 3, 4, 4, 4];
const FRAC_MAX: &[u32] = &[
900000000, 990000000, 999000000, 999900000, 999990000, 999999000, 999999900, 999999990,
Expand Down Expand Up @@ -1714,7 +1714,7 @@ impl Decimal {
dec_encoded_len(&[prec, frac]).unwrap_or(3)
}

fn div(&self, rhs: &Decimal, frac_incr: u8) -> Option<Res<Decimal>> {
pub fn div(&self, rhs: &Decimal, frac_incr: u8) -> Option<Res<Decimal>> {
let result_frac_cnt =
cmp::min(self.result_frac_cnt.saturating_add(frac_incr), MAX_FRACTION);
let mut res = do_div_mod_impl(self, rhs, frac_incr, false, Some(result_frac_cnt));
Expand Down
3 changes: 3 additions & 0 deletions components/tidb_query_datatype/src/codec/mysql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ pub const MIN_FSP: i8 = 0;
/// `DEFAULT_FSP` is the default digit of fractional seconds part.
/// `MySQL` use 0 as the default Fsp.
pub const DEFAULT_FSP: i8 = 0;
/// `DEFAULT_DIV_FRAC_INCR` is the default value of decimal divide precision
/// inrements.
pub const DEFAULT_DIV_FRAC_INCR: u8 = 4;

fn check_fsp(fsp: i8) -> Result<u8> {
if fsp == UNSPECIFIED_FSP {
Expand Down
12 changes: 11 additions & 1 deletion components/tidb_query_datatype/src/expr/ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use bitflags::bitflags;
use tipb::DagRequest;

use super::{Error, Result};
use crate::codec::mysql::Tz;
use crate::codec::mysql::{Tz, DEFAULT_DIV_FRAC_INCR};

bitflags! {
/// Please refer to SQLMode in `mysql/const.go` in repo `pingcap/parser` for details.
Expand Down Expand Up @@ -72,6 +72,7 @@ pub struct EvalConfig {
pub sql_mode: SqlMode,

pub paging_size: Option<u64>,
pub div_precision_increment: u8,
}

impl Default for EvalConfig {
Expand All @@ -98,6 +99,9 @@ impl EvalConfig {
if req.has_sql_mode() {
eval_cfg.set_sql_mode(SqlMode::from_bits_truncate(req.get_sql_mode()));
}
if req.has_div_precision_increment() {
eval_cfg.set_div_precision_incr(req.get_div_precision_increment() as u8);
}
Ok(eval_cfg)
}

Expand All @@ -108,6 +112,7 @@ impl EvalConfig {
max_warning_cnt: DEFAULT_MAX_WARNING_CNT,
sql_mode: SqlMode::empty(),
paging_size: None,
div_precision_increment: DEFAULT_DIV_FRAC_INCR,
}
}

Expand All @@ -127,6 +132,11 @@ impl EvalConfig {
self
}

pub fn set_div_precision_incr(&mut self, new_value: u8) -> &mut Self {
self.div_precision_increment = new_value;
self
}

pub fn set_time_zone_by_name(&mut self, tz_name: &str) -> Result<&mut Self> {
match Tz::from_tz_name(tz_name) {
Some(tz) => {
Expand Down
47 changes: 36 additions & 11 deletions components/tidb_query_expr/src/impl_arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -496,17 +496,19 @@ impl ArithmeticOpWithCtx for DecimalDivide {
type T = Decimal;

fn calc(ctx: &mut EvalContext, lhs: &Decimal, rhs: &Decimal) -> Result<Option<Decimal>> {
Ok(if let Some(value) = lhs / rhs {
value
.into_result_with_overflow_err(
ctx,
Error::overflow("DECIMAL", format!("({} / {})", lhs, rhs)),
)
.map(Some)
} else {
// TODO: handle RpnFuncExtra's field_type, round the result if is needed.
ctx.handle_division_by_zero().map(|_| None)
}?)
Ok(
if let Some(value) = lhs.div(rhs, ctx.cfg.div_precision_increment) {
value
.into_result_with_overflow_err(
ctx,
Error::overflow("DECIMAL", format!("({} / {})", lhs, rhs)),
)
.map(Some)
} else {
// TODO: handle RpnFuncExtra's field_type, round the result if is needed.
ctx.handle_division_by_zero().map(|_| None)
}?,
)
}
}

Expand Down Expand Up @@ -1237,6 +1239,29 @@ mod tests {

assert_eq!(actual, expected, "lhs={:?}, rhs={:?}", lhs, rhs);
}

let cases2 = vec![
(Some("2.2"), Some("1.3"), Some("1.692"), 2),
(Some("2.2"), Some("1.3"), Some("1.6923"), 3),
(Some("2.2"), Some("1.3"), Some("1.69231"), 4),
(None, Some("2"), None, 4),
(Some("123"), None, None, 4),
];
for (lhs, rhs, expected, frac_incr) in cases2 {
let mut cfg = EvalConfig::new();
cfg.set_div_precision_incr(frac_incr);
let ctx = EvalContext::new(cfg.into());
let actual: Option<Decimal> = RpnFnScalarEvaluator::new_for_test(ctx)
.push_param(lhs.map(|s| Decimal::from_str(s).unwrap()))
.push_param(rhs.map(|s| Decimal::from_str(s).unwrap()))
.evaluate(ScalarFuncSig::DivideDecimal)
.unwrap();

let expected = expected.map(|s| Decimal::from_str(s).unwrap());
if let (Some(lhs_), Some(rhs_)) = (expected, actual) {
assert_eq!(format!("{lhs_}"), format!("{rhs_}"));
}
}
}

#[test]
Expand Down
10 changes: 10 additions & 0 deletions components/tidb_query_expr/src/types/test_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@ impl RpnFnScalarEvaluator {
}
}

/// Creates a new `RpnFnScalarEvaluator` for test usage.
pub fn new_for_test(ctx: EvalContext) -> Self {
Self {
rpn_expr_builder: RpnExpressionBuilder::new_for_test(),
return_field_type: None,
context: Some(ctx),
metadata: None,
}
}

/// Pushes a parameter as the value of an argument for evaluation. The field
/// type will be auto inferred by choosing an arbitrary field type that
/// matches the field type of the given value.
Expand Down