Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(expr): round to left of decimal point when 2nd arg is negative #10961

Merged
merged 1 commit into from
Jul 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions src/common/src/types/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,8 @@ impl Sub for Decimal {
}

impl Decimal {
pub const MAX_PRECISION: u8 = 28;

/// TODO: handle nan and inf
pub fn mantissa(&self) -> i128 {
match self {
Expand Down Expand Up @@ -457,6 +459,38 @@ impl Decimal {
}
}

/// Round to the left of the decimal point, for example `31.5` -> `30`.
#[must_use]
pub fn round_left_ties_away(&self, left: u32) -> Option<Self> {
let Self::Normalized(mut d) = self else { return Some(*self) };

// First, move the decimal point to the left so that we can reuse `round`. This is more
// efficient than division.
let old_scale = d.scale();
let new_scale = old_scale.saturating_add(left);
const MANTISSA_UP: i128 = 5 * 10i128.pow(Decimal::MAX_PRECISION as _);
let d = match new_scale.cmp(&Self::MAX_PRECISION.add(1).into()) {
// trivial within 28 digits
std::cmp::Ordering::Less => {
d.set_scale(new_scale).unwrap();
d.round_dp_with_strategy(0, RoundingStrategy::MidpointAwayFromZero)
}
// Special case: scale cannot be 29, but it may or may not be >= 0.5e+29
std::cmp::Ordering::Equal => (d.mantissa() / MANTISSA_UP).signum().into(),
// always 0 for >= 30 digits
std::cmp::Ordering::Greater => 0.into(),
};

// Then multiply back. Note that we cannot move decimal point to the right in order to get
// more zeros.
match left > Decimal::MAX_PRECISION.into() {
true => d.is_zero().then(|| 0.into()),
false => d
.checked_mul(RustDecimal::from_i128_with_scale(10i128.pow(left), 0))
.map(Self::Normalized),
}
}

#[must_use]
pub fn ceil(&self) -> Self {
match self {
Expand Down
7 changes: 3 additions & 4 deletions src/connector/src/parser/avro/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

use apache_avro::Schema;
use itertools::Itertools;
use risingwave_common::types::DataType;
use risingwave_common::types::{DataType, Decimal};
use risingwave_pb::plan_common::ColumnDesc;

pub fn avro_schema_to_column_descs(schema: &Schema) -> anyhow::Result<Vec<ColumnDesc>> {
Expand All @@ -30,7 +30,6 @@ pub fn avro_schema_to_column_descs(schema: &Schema) -> anyhow::Result<Vec<Column
}
}

const RW_DECIMAL_MAX_PRECISION: usize = 28;
const DBZ_VARIABLE_SCALE_DECIMAL_NAME: &str = "VariableScaleDecimal";
const DBZ_VARIABLE_SCALE_DECIMAL_NAMESPACE: &str = "io.debezium.data";

Expand Down Expand Up @@ -81,10 +80,10 @@ fn avro_type_mapping(schema: &Schema) -> anyhow::Result<DataType> {
Schema::Float => DataType::Float32,
Schema::Double => DataType::Float64,
Schema::Decimal { precision, .. } => {
if precision > &RW_DECIMAL_MAX_PRECISION {
if *precision > Decimal::MAX_PRECISION.into() {
tracing::warn!(
"RisingWave supports decimal precision up to {}, but got {}. Will truncate.",
RW_DECIMAL_MAX_PRECISION,
Decimal::MAX_PRECISION,
precision
);
}
Expand Down
1 change: 0 additions & 1 deletion src/connector/src/parser/unified/avro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,6 @@ where
}
}

const RW_DECIMAL_MAX_PRECISION: usize = 28;
pub(crate) fn avro_decimal_to_rust_decimal(
avro_decimal: AvroDecimal,
_precision: usize,
Expand Down
47 changes: 34 additions & 13 deletions src/expr/src/vector_op/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use num_traits::Zero;
use risingwave_common::types::{Decimal, F64};
use risingwave_expr_macro::function;

use crate::{ExprError, Result};

#[function("round_digit(decimal, int32) -> decimal")]
pub fn round_digits<D: Into<i32>>(input: Decimal, digits: D) -> Decimal {
let digits = digits.into();
pub fn round_digits(input: Decimal, digits: i32) -> Result<Decimal> {
if digits < 0 {
Decimal::zero()
input
.round_left_ties_away(digits.unsigned_abs())
.ok_or(ExprError::NumericOverflow)
} else {
// rust_decimal can only handle up to 28 digits of scale
input.round_dp_ties_away(std::cmp::min(digits as u32, 28))
Ok(input.round_dp_ties_away((digits as u32).min(Decimal::MAX_PRECISION.into())))
}
}

Expand Down Expand Up @@ -78,22 +80,41 @@ mod tests {
use super::ceil_f64;
use crate::vector_op::round::*;

fn do_test(input: &str, digits: i32, expected_output: &str) {
fn do_test(input: &str, digits: i32, expected_output: Option<&str>) {
let v = Decimal::from_str(input).unwrap();
let rounded_value = round_digits(v, digits);
assert_eq!(expected_output, rounded_value.to_string().as_str());
let rounded_value = round_digits(v, digits).ok();
assert_eq!(
expected_output,
rounded_value.as_ref().map(ToString::to_string).as_deref()
);
}

#[test]
fn test_round_digits() {
do_test("21.666666666666666666666666667", 4, "21.6667");
do_test("84818.33333333333333333333333", 4, "84818.3333");
do_test("84818.15", 1, "84818.2");
do_test("21.372736", -1, "0");
do_test("21.666666666666666666666666667", 4, Some("21.6667"));
do_test("84818.33333333333333333333333", 4, Some("84818.3333"));
do_test("84818.15", 1, Some("84818.2"));
do_test("21.372736", -1, Some("20"));
do_test("-79228162514264337593543950335", -30, Some("0"));
do_test("-79228162514264337593543950335", -29, None);
do_test("-79228162514264337593543950335", -28, None);
do_test(
"-79228162514264337593543950335",
-27,
Some("-79000000000000000000000000000"),
);
do_test("-792.28162514264337593543950335", -4, Some("0"));
do_test("-792.28162514264337593543950335", -3, Some("-1000"));
do_test("-792.28162514264337593543950335", -2, Some("-800"));
do_test("-792.28162514264337593543950335", -1, Some("-790"));
do_test("-50000000000000000000000000000", -29, None);
do_test("-49999999999999999999999999999", -29, Some("0"));
do_test("-500.00000000000000000000000000", -3, Some("-1000"));
do_test("-499.99999999999999999999999999", -3, Some("0"));
// When digit extends past original scale, it should just return original scale.
// Intuitively, it does not make sense after rounding `0` it becomes `0.000`. Precision
// should always be less or equal, not more.
do_test("0", 340, "0");
do_test("0", 340, Some("0"));
}

#[test]
Expand Down
Loading