Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tidb_query_datatype, tidb_query_expr: Add div_precision_increment support in dag request #16622

Merged
merged 10 commits into from
Mar 13, 2024
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions components/tidb_query_datatype/src/codec/mysql/decimal.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
// Copyright 2016 TiKV Project Authors. Licensed under Apache-2.0.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should leave a blank line.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.

use std::{
cmp,
cmp::Ordering,
Expand Down Expand Up @@ -1714,7 +1713,7 @@ impl Decimal {
dec_encoded_len(&[prec, frac]).unwrap_or(3)
}

fn div(&self, rhs: &Decimal, frac_incr: u8) -> Option<Res<Decimal>> {
pub fn div(&self, rhs: &Decimal, frac_incr: u8) -> Option<Res<Decimal>> {
let result_frac_cnt =
cmp::min(self.result_frac_cnt.saturating_add(frac_incr), MAX_FRACTION);
let mut res = do_div_mod_impl(self, rhs, frac_incr, false, Some(result_frac_cnt));
Expand Down
12 changes: 12 additions & 0 deletions components/tidb_query_datatype/src/expr/ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ impl SqlMode {

const DEFAULT_MAX_WARNING_CNT: usize = 64;

const DEFAULT_DIV_PRECISION_INCR: i32 = 4;

#[derive(Clone, Debug)]
pub struct EvalConfig {
/// timezone to use when parse/calculate time.
Expand All @@ -72,6 +74,7 @@ pub struct EvalConfig {
pub sql_mode: SqlMode,

pub paging_size: Option<u64>,
pub div_precision_increment: i32,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe the type should be u8?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, updated

}

impl Default for EvalConfig {
Expand All @@ -98,6 +101,9 @@ impl EvalConfig {
if req.has_sql_mode() {
eval_cfg.set_sql_mode(SqlMode::from_bits_truncate(req.get_sql_mode()));
}
if req.has_div_precision_increment() {
eval_cfg.set_div_precision_incr(req.get_div_precision_increment() as i32);
}
Ok(eval_cfg)
}

Expand All @@ -108,6 +114,7 @@ impl EvalConfig {
max_warning_cnt: DEFAULT_MAX_WARNING_CNT,
sql_mode: SqlMode::empty(),
paging_size: None,
div_precision_increment: DEFAULT_DIV_PRECISION_INCR,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of adding a new default value, how about using the existing value DEFAULT_DIV_FRAC_INCR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, changed it.

}
}

Expand All @@ -127,6 +134,11 @@ impl EvalConfig {
self
}

pub fn set_div_precision_incr(&mut self, new_value: i32) -> &mut Self {
self.div_precision_increment = new_value;
self
}

pub fn set_time_zone_by_name(&mut self, tz_name: &str) -> Result<&mut Self> {
match Tz::from_tz_name(tz_name) {
Some(tz) => {
Expand Down
47 changes: 36 additions & 11 deletions components/tidb_query_expr/src/impl_arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -496,17 +496,19 @@ impl ArithmeticOpWithCtx for DecimalDivide {
type T = Decimal;

fn calc(ctx: &mut EvalContext, lhs: &Decimal, rhs: &Decimal) -> Result<Option<Decimal>> {
Ok(if let Some(value) = lhs / rhs {
value
.into_result_with_overflow_err(
ctx,
Error::overflow("DECIMAL", format!("({} / {})", lhs, rhs)),
)
.map(Some)
} else {
// TODO: handle RpnFuncExtra's field_type, round the result if is needed.
ctx.handle_division_by_zero().map(|_| None)
}?)
Ok(
if let Some(value) = lhs.div(rhs, ctx.cfg.div_precision_increment as u8) {
value
.into_result_with_overflow_err(
ctx,
Error::overflow("DECIMAL", format!("({} / {})", lhs, rhs)),
)
.map(Some)
} else {
// TODO: handle RpnFuncExtra's field_type, round the result if is needed.
ctx.handle_division_by_zero().map(|_| None)
}?,
)
}
}

Expand Down Expand Up @@ -1237,6 +1239,29 @@ mod tests {

assert_eq!(actual, expected, "lhs={:?}, rhs={:?}", lhs, rhs);
}

let cases2 = vec![
(Some("2.2"), Some("1.3"), Some("1.692"), 2),
(Some("2.2"), Some("1.3"), Some("1.6923"), 3),
(Some("2.2"), Some("1.3"), Some("1.69231"), 4),
(None, Some("2"), None, 4),
(Some("123"), None, None, 4),
];
for (lhs, rhs, expected, frac_incr) in cases2 {
let mut cfg = EvalConfig::new();
cfg.set_div_precision_incr(frac_incr);
let ctx = EvalContext::new(cfg.into());
let actual: Option<Decimal> = RpnFnScalarEvaluator::new_for_test(ctx)
.push_param(lhs.map(|s| Decimal::from_str(s).unwrap()))
.push_param(rhs.map(|s| Decimal::from_str(s).unwrap()))
.evaluate(ScalarFuncSig::DivideDecimal)
.unwrap();

let expected = expected.map(|s| Decimal::from_str(s).unwrap());
if let (Some(lhs_), Some(rhs_)) = (expected, actual) {
assert_eq!(format!("{lhs_}"), format!("{rhs_}"));
}
}
}

#[test]
Expand Down
10 changes: 10 additions & 0 deletions components/tidb_query_expr/src/types/test_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@ impl RpnFnScalarEvaluator {
}
}

/// Creates a new `RpnFnScalarEvaluator` for test usage.
pub fn new_for_test(ctx: EvalContext) -> Self {
Self {
rpn_expr_builder: RpnExpressionBuilder::new_for_test(),
return_field_type: None,
context: Some(ctx),
metadata: None,
}
}

/// Pushes a parameter as the value of an argument for evaluation. The field
/// type will be auto inferred by choosing an arbitrary field type that
/// matches the field type of the given value.
Expand Down