diff --git a/mysqldef/decimal.go b/mysqldef/decimal.go index 535465a80f3c9..4276e52d52afe 100644 --- a/mysqldef/decimal.go +++ b/mysqldef/decimal.go @@ -598,6 +598,11 @@ func (d Decimal) Value() (driver.Value, error) { return d.String(), nil } +// BigIntValue returns the *bit.Int value member of decimal. +func (d Decimal) BigIntValue() *big.Int { + return d.value +} + // UnmarshalText implements the encoding.TextUnmarshaler interface for XML // deserialization. func (d *Decimal) UnmarshalText(text []byte) error { diff --git a/mysqldef/decimal_test.go b/mysqldef/decimal_test.go index 232beb9d801d2..88448d1f7d2cb 100644 --- a/mysqldef/decimal_test.go +++ b/mysqldef/decimal_test.go @@ -868,5 +868,4 @@ func didPanic(f func()) bool { }() return ret - } diff --git a/util/codec/codec.go b/util/codec/codec.go index b0abe61616a95..2133ac27481a5 100644 --- a/util/codec/codec.go +++ b/util/codec/codec.go @@ -39,6 +39,7 @@ const ( formatStringFlag = 's' formatBytesFlag = 'b' formatDurationFlag = 't' + formatDecimalFlag = 'c' ) var sepKey = []byte{0x00, 0x00} @@ -107,6 +108,9 @@ func EncodeKey(args ...interface{}) ([]byte, error) { // duration may have negative value, so we cannot use String to encode directly. b = EncodeInt(b, int64(v.Duration)) format = append(format, formatDurationFlag) + case mysql.Decimal: + b = EncodeDecimal(b, v) + format = append(format, formatDecimalFlag) case nil: // We will 0x00, 0x00 for nil. // The []byte{} will be encoded as 0x00, 0x01. @@ -177,6 +181,8 @@ func DecodeKey(b []byte) ([]interface{}, error) { // use max fsp, let outer to do round manually. v[i] = mysql.Duration{Duration: time.Duration(r), Fsp: mysql.MaxFsp} } + case formatDecimalFlag: + b, v[i], err = DecodeDecimal(b) case formatNilFlag: if len(b) < 2 || (b[0] != 0x00 && b[1] != 0x00) { return nil, errors.Errorf("malformed encoded nil") diff --git a/util/codec/codec_test.go b/util/codec/codec_test.go index f1b42705ff527..837bd56ba8771 100644 --- a/util/codec/codec_test.go +++ b/util/codec/codec_test.go @@ -513,3 +513,111 @@ func (s *testCodecSuite) TestDuration(c *C) { c.Assert(ret, Equals, t.Ret) } } + +func (s *testCodecSuite) TestDecimal(c *C) { + tbl := []string{ + "1234.00", + "1234", + "12.34", + "12.340", + "0.1234", + "0.0", + "0", + "-0.0", + "-0.0000", + "-1234.00", + "-1234", + "-12.34", + "-12.340", + "-0.1234"} + + for _, t := range tbl { + m, err := mysql.ParseDecimal(t) + c.Assert(err, IsNil) + b, err := EncodeKey(m) + c.Assert(err, IsNil) + v, err := DecodeKey(b) + c.Assert(err, IsNil) + c.Assert(v, HasLen, 1) + vv, ok := v[0].(mysql.Decimal) + c.Assert(ok, IsTrue) + c.Assert(vv.Equals(m), IsTrue) + } + + tblCmp := []struct { + Arg1 interface{} + Arg2 interface{} + Ret int + }{ + // Test for float type decimal. + {"1234", "123400", -1}, + {"12340", "123400", -1}, + {"1234", "1234.5", -1}, + {"1234", "1234.0000", 0}, + {"1234", "12.34", 1}, + {"12.34", "12.35", -1}, + {"0.1234", "12.3400", -1}, + {"0.1234", "0.1235", -1}, + {"0.123400", "12.34", -1}, + {"12.34000", "12.34", 0}, + {"0.01234", "0.01235", -1}, + {"0.1234", "0", 1}, + {"0.0000", "0", 0}, + {"0.0001", "0", 1}, + {"0.0001", "0.0000", 1}, + {"0", "-0.0000", 0}, + {"-0.0001", "0", -1}, + {"-0.1234", "0", -1}, + {"-0.1234", "0.1234", -1}, + {"-1.234", "-12.34", 1}, + {"-0.1234", "-12.34", 1}, + {"-12.34", "1234", -1}, + {"-12.34", "-12.35", 1}, + {"-0.01234", "-0.01235", 1}, + {"-1234", "-123400", 1}, + {"-12340", "-123400", 1}, + + // Test for int type decimal. + {-1, 1, -1}, + {math.MaxInt64, math.MinInt64, 1}, + {math.MaxInt64, math.MaxInt32, 1}, + {math.MinInt32, math.MaxInt16, -1}, + {math.MinInt64, math.MaxInt8, -1}, + {0, math.MaxInt8, -1}, + {math.MinInt8, 0, -1}, + {math.MinInt16, math.MaxInt16, -1}, + {1, -1, 1}, + {1, 0, 1}, + {-1, 0, -1}, + {0, 0, 0}, + {math.MaxInt16, math.MaxInt16, 0}, + + // Test for uint type decimal. + {uint64(0), uint64(0), 0}, + {uint64(1), uint64(0), 1}, + {uint64(0), uint64(1), -1}, + {uint64(math.MaxInt8), uint64(math.MaxInt16), -1}, + {uint64(math.MaxUint32), uint64(math.MaxInt32), 1}, + {uint64(math.MaxUint8), uint64(math.MaxInt8), 1}, + {uint64(math.MaxUint16), uint64(math.MaxInt32), -1}, + {uint64(math.MaxUint64), uint64(math.MaxInt64), 1}, + {uint64(math.MaxInt64), uint64(math.MaxUint32), 1}, + {uint64(math.MaxUint64), uint64(0), 1}, + {uint64(0), uint64(math.MaxUint64), -1}, + } + + for _, t := range tblCmp { + m1, err := mysql.ConvertToDecimal(t.Arg1) + c.Assert(err, IsNil) + m2, err := mysql.ConvertToDecimal(t.Arg2) + c.Assert(err, IsNil) + + b1, err := EncodeKey(m1) + c.Assert(err, IsNil) + b2, err := EncodeKey(m2) + c.Assert(err, IsNil) + + ret := bytes.Compare(b1, b2) + c.Assert(ret, Equals, t.Ret) + } +} diff --git a/util/codec/decimal.go b/util/codec/decimal.go new file mode 100644 index 0000000000000..f33b5ab39f8e1 --- /dev/null +++ b/util/codec/decimal.go @@ -0,0 +1,189 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package codec + +import ( + "bytes" + "math/big" + + "github.com/juju/errors" + mysql "github.com/pingcap/tidb/mysqldef" +) + +const ( + negativeSign int64 = 8 + zeroSign int64 = 16 + positiveSign int64 = 24 +) + +func codecSign(value int64) int64 { + if value < 0 { + return negativeSign + } + + return positiveSign +} + +func encodeExp(expValue int64, expSign int64, valSign int64) int64 { + if expSign == negativeSign { + expValue = -expValue + } + + if expSign != valSign { + expValue = ^expValue + } + + return expValue +} + +func decodeExp(expValue int64, expSign int64, valSign int64) int64 { + if expSign != valSign { + expValue = ^expValue + } + + if expSign == negativeSign { + expValue = -expValue + } + + return expValue +} + +func codecValue(value []byte, valSign int64) { + if valSign == negativeSign { + reverseBytes(value) + } +} + +// EncodeDecimal encodes a decimal d into a byte slice which can be sorted lexicographically later. +// EncodeDecimal guarantees that the encoded value is in ascending order for comparison. +// Decimal encoding: +// EncodeInt -> value sign +// EncodeInt -> exp sign +// EncodeInt -> exp value +// EncodeBytes -> abs value bytes +func EncodeDecimal(b []byte, d mysql.Decimal) []byte { + if d.Equals(mysql.ZeroDecimal) { + return EncodeInt(b, zeroSign) + } + + v := d.BigIntValue() + valSign := codecSign(int64(v.Sign())) + + absVal := new(big.Int) + absVal.Abs(v) + + value := []byte(absVal.String()) + + // Trim right side "0", like "12.34000" -> "12.34" or "0.1234000" -> "0.1234". + if d.Exponent() != 0 { + value = bytes.TrimRight(value, "0") + } + + // Get exp and value, format is "value":"exp". + // like "12.34" -> "0.1234":"2". + // like "-0.01234" -> "-0.1234":"-1". + exp := int64(0) + div := big.NewInt(10) + for ; ; exp++ { + if absVal.Sign() == 0 { + break + } + absVal = absVal.Div(absVal, div) + } + + expVal := exp + int64(d.Exponent()) + expSign := codecSign(expVal) + + // For negtive exp, do bit reverse for exp. + // For negtive decimal, do bit reverse for exp and value. + expVal = encodeExp(expVal, expSign, valSign) + codecValue(value, valSign) + + r := EncodeInt(b, valSign) + r = EncodeInt(r, expSign) + r = EncodeInt(r, expVal) + r = EncodeBytes(r, value) + return r +} + +// DecodeDecimal decodes bytes to decimal. +// DecodeFloat decodes a float from a byte slice +// Decimal decoding: +// DecodeInt -> value sign +// DecodeInt -> exp sign +// DecodeInt -> exp value +// DecodeBytes -> abs value bytes +func DecodeDecimal(b []byte) ([]byte, mysql.Decimal, error) { + var ( + r []byte + d mysql.Decimal + err error + ) + + // Decode value sign. + valSign := zeroSign + r, valSign, err = DecodeInt(b) + if err != nil { + return r, d, errors.Trace(err) + } + if valSign == zeroSign { + d, err = mysql.ParseDecimal("0") + return r, d, errors.Trace(err) + } + + // Decode exp sign. + expSign := zeroSign + r, expSign, err = DecodeInt(r) + if err != nil { + return r, d, errors.Trace(err) + } + + // Decode exp value. + expVal := int64(0) + r, expVal, err = DecodeInt(r) + if err != nil { + return r, d, errors.Trace(err) + } + expVal = decodeExp(expVal, expSign, valSign) + + // Decode abs value bytes. + value := []byte{} + r, value, err = DecodeBytes(r) + if err != nil { + return r, d, errors.Trace(err) + } + codecValue(value, valSign) + + // Generate decimal string value. + var decimalStr []byte + if valSign == negativeSign { + decimalStr = append(decimalStr, '-') + } + + if expVal <= 0 { + // Like decimal "0.1234" or "0.01234". + decimalStr = append(decimalStr, '0') + decimalStr = append(decimalStr, '.') + decimalStr = append(decimalStr, bytes.Repeat([]byte{'0'}, -int(expVal))...) + decimalStr = append(decimalStr, value...) + } else { + // Like decimal "12.34". + decimalStr = append(decimalStr, value[:expVal]...) + decimalStr = append(decimalStr, '.') + decimalStr = append(decimalStr, value[expVal:]...) + } + + d, err = mysql.ParseDecimal(string(decimalStr)) + return r, d, errors.Trace(err) +} diff --git a/util/codec/decimal_test.go b/util/codec/decimal_test.go new file mode 100644 index 0000000000000..28532e766e6f3 --- /dev/null +++ b/util/codec/decimal_test.go @@ -0,0 +1,52 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package codec + +import ( + . "github.com/pingcap/check" + mysql "github.com/pingcap/tidb/mysqldef" +) + +var _ = Suite(&testDecimalSuite{}) + +type testDecimalSuite struct { +} + +func (s *testDecimalSuite) TestDecimalCodec(c *C) { + inputs := []struct { + Input float64 + }{ + {float64(123400)}, + {float64(1234)}, + {float64(12.34)}, + {float64(0.1234)}, + {float64(0.01234)}, + {float64(-0.1234)}, + {float64(-0.01234)}, + {float64(12.3400)}, + {float64(-12.34)}, + {float64(0.00000)}, + {float64(0)}, + {float64(-0.0)}, + {float64(-0.000)}, + } + + for _, input := range inputs { + v := mysql.NewDecimalFromFloat(input.Input) + b := EncodeDecimal([]byte{}, v) + _, d, err := DecodeDecimal(b) + c.Assert(err, IsNil) + c.Assert(v.Equals(d), IsTrue) + } +}