From dc15bfbd45679da3a69f7ac02fe087fd36897fb8 Mon Sep 17 00:00:00 2001 From: Raphael DALMON Date: Tue, 21 Oct 2025 17:19:59 +0200 Subject: [PATCH 01/13] feat: support decimal --- README.md | 31 +++ buffer.go | 41 ++++ buffer_test.go | 141 +++++++++++ conf_parse.go | 4 +- decimal.go | 401 +++++++++++++++++++++++++++++++ export_test.go | 24 ++ http_sender.go | 112 ++++++++- http_sender_test.go | 34 ++- integration_test.go | 71 +++++- interop_test.go | 164 +++++++------ sender.go | 20 +- sender_pool.go | 5 + tcp_sender.go | 98 +++++++- tcp_sender_test.go | 19 ++ test/interop/questdb-client-test | 2 +- utils_test.go | 29 +++ 16 files changed, 1089 insertions(+), 107 deletions(-) create mode 100644 decimal.go diff --git a/README.md b/README.md index d6aeff7c..804cd522 100644 --- a/README.md +++ b/README.md @@ -188,6 +188,37 @@ will now also serialize ``float64`` (double-precision) columns as binary. You might see a performance uplift if this is a dominant data type in your ingestion workload. +## Decimal columns + +QuestDB server version 9.2.0 and newer supports decimal columns with arbitrary precision and scale. +The Go client converts supported decimal values to QuestDB's text/binary wire format automatically, pass any of the following to `DecimalColumn`: + +- `questdb.ScaledDecimal`, including helpers like `questdb.NewDecimalFromInt64` and `questdb.NewDecimal`. +- Types implementing `questdb.DecimalMarshaler`. +- `github.com/shopspring/decimal.Decimal` values or pointers. +- `nil` or `questdb.NullDecimal()` to send a `NULL`. + +```go +price := qdb.NewDecimalFromInt64(12345, 2) // 123.45 with scale 2 +commission := qdb.NewDecimal(big.NewInt(-750), 4) // -0.0750 with scale 4 + +err = sender. + Table("trades"). + Symbol("symbol", "ETH-USD"). + DecimalColumn("price", price). + DecimalColumn("commission", commission). + AtNow(ctx) +``` + +To emit textual decimals, pass a validated string literal (without the trailing `d`—the client adds it): + +```go +err = sender. + Table("quotes"). + DecimalColumn("mid", "1.23456"). + AtNow(ctx) +``` + ## Pooled Line Senders **Warning: Experimental feature designed for use with HTTP senders ONLY** diff --git a/buffer.go b/buffer.go index 85d586f1..bb71f3b6 100644 --- a/buffer.go +++ b/buffer.go @@ -573,6 +573,47 @@ func (b *buffer) Float64Column(name string, val float64) *buffer { return b } +func (b *buffer) DecimalColumn(name string, val any) *buffer { + if !b.prepareForField() { + return b + } + b.lastErr = b.writeColumnName(name) + if b.lastErr != nil { + return b + } + b.WriteByte('=') + if str, ok := val.(string); ok { + if err := validateDecimalText(str); err != nil { + b.lastErr = err + return b + } + b.WriteString(str) + b.WriteByte('d') + b.hasFields = true + return b + } + + dec, err := normalizeDecimalValue(val) + if err != nil { + b.lastErr = err + return b + } + scale, payload, err := dec.toBinary() + if err != nil { + b.lastErr = err + return b + } + b.WriteByte('=') + b.WriteByte(decimalBinaryTypeCode) + b.WriteByte(scale) + b.WriteByte(byte(len(payload))) + if len(payload) > 0 { + b.Write(payload) + } + b.hasFields = true + return b +} + func (b *buffer) Float64ColumnBinary(name string, val float64) *buffer { if !b.prepareForField() { return b diff --git a/buffer_test.go b/buffer_test.go index 646e5063..5eaa2a98 100644 --- a/buffer_test.go +++ b/buffer_test.go @@ -37,8 +37,23 @@ import ( "github.com/stretchr/testify/assert" ) +const decimalTypeCode byte = 0x17 + type bufWriterFn func(b *qdb.Buffer) error +type fakeShopspringDecimal struct { + coeff *big.Int + exp int32 +} + +func (f fakeShopspringDecimal) Coefficient() *big.Int { + return f.coeff +} + +func (f fakeShopspringDecimal) Exponent() int32 { + return f.exp +} + func newTestBuffer() qdb.Buffer { return qdb.NewBuffer(128*1024, 1024*1024, 127) } @@ -481,6 +496,132 @@ func TestFloat64ColumnBinary(t *testing.T) { } } +func TestDecimalColumnText(t *testing.T) { + prefix := []byte(testTable + " price==") + testCases := []struct { + name string + value any + expected []byte + }{ + { + name: "positive", + value: qdb.NewDecimalFromInt64(12345, 2), + expected: append(prefix, 0x17, 0x02, 0x02, 0x30, 0x39, 0x0A), + }, + { + name: "negative", + value: qdb.NewDecimal(big.NewInt(-12345), 3), + expected: append(prefix, 0x17, 0x03, 0x02, 0xCF, 0xC7, 0x0A), + }, + { + name: "zero with scale", + value: qdb.NewDecimalFromInt64(0, 4), + expected: append(prefix, 0x17, 0x04, 0x01, 0x0, 0x0A), + }, + { + name: "null decimal", + value: qdb.NullDecimal(), + expected: append(prefix, 0x17, 0x0, 0x0, 0x0A), + }, + { + name: "shopspring compatible", + value: fakeShopspringDecimal{coeff: big.NewInt(123456), exp: -4}, + expected: append(prefix, 0x17, 0x04, 0x03, 0x01, 0xE2, 0x40, 0x0A), + }, + { + name: "nil pointer treated as null", + value: (*fakeShopspringDecimal)(nil), + expected: append(prefix, 0x17, 0x0, 0x0, 0x0A), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + buf := newTestBuffer() + err := buf.Table(testTable).DecimalColumn("price", tc.value).At(time.Time{}, false) + assert.NoError(t, err) + assert.Equal(t, tc.expected, buf.Messages()) + }) + } +} + +func TestDecimalColumnStringValidation(t *testing.T) { + t.Run("valid strings", func(t *testing.T) { + testCases := []struct { + name string + value string + expected string + }{ + {"integer", "123", "123d"}, + {"decimal", "123.450", "123.450d"}, + {"negative", "-0.001", "-0.001d"}, + {"exponent positive", "1.2e3", "1.2e3d"}, + {"exponent negative", "-4.5E-2", "-4.5E-2d"}, + {"nan token", "NaN", "NaNd"}, + {"infinity token", "Infinity", "Infinityd"}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + buf := newTestBuffer() + err := buf.Table(testTable).DecimalColumn("price", tc.value).At(time.Time{}, false) + assert.NoError(t, err) + expected := []byte(testTable + " price=" + tc.expected + "\n") + assert.Equal(t, expected, buf.Messages()) + }) + } + }) + + t.Run("invalid strings", func(t *testing.T) { + testCases := []struct { + name string + value string + }{ + {"empty", ""}, + {"sign only", "+"}, + {"double dot", "12.3.4"}, + {"invalid char", "12a3"}, + {"exponent missing mantissa", "e10"}, + {"exponent no digits", "1.2e"}, + {"exponent sign no digits", "1.2e+"}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + buf := newTestBuffer() + err := buf.Table(testTable).DecimalColumn("price", tc.value).At(time.Time{}, false) + assert.Error(t, err) + assert.Contains(t, err.Error(), "decimal") + assert.Empty(t, buf.Messages()) + }) + } + }) +} + +func TestDecimalColumnErrors(t *testing.T) { + t.Run("invalid scale", func(t *testing.T) { + buf := newTestBuffer() + dec := qdb.NewDecimalFromInt64(1, 100) + err := buf.Table(testTable).DecimalColumn("price", dec).At(time.Time{}, false) + assert.ErrorContains(t, err, "decimal scale") + assert.Empty(t, buf.Messages()) + }) + + t.Run("overflow", func(t *testing.T) { + buf := newTestBuffer() + bigVal := new(big.Int).Lsh(big.NewInt(1), 2100) + dec := qdb.NewDecimal(bigVal, 0) + err := buf.Table(testTable).DecimalColumn("price", dec).At(time.Time{}, false) + assert.ErrorContains(t, err, "exceeds 256-bit range") + assert.Empty(t, buf.Messages()) + }) + + t.Run("unsupported type", func(t *testing.T) { + buf := newTestBuffer() + err := buf.Table(testTable).DecimalColumn("price", struct{}{}).At(time.Time{}, false) + assert.ErrorContains(t, err, "unsupported decimal column value type") + assert.Empty(t, buf.Messages()) + }) +} + func TestFloat64Array1DColumn(t *testing.T) { testCases := []struct { name string diff --git a/conf_parse.go b/conf_parse.go index 7957dc51..66e84c40 100644 --- a/conf_parse.go +++ b/conf_parse.go @@ -169,8 +169,8 @@ func confFromStr(conf string) (*lineSenderConfig, error) { return nil, NewInvalidConfigStrError("invalid %s value, %q is not a valid int", k, v) } pVersion := protocolVersion(version) - if pVersion < ProtocolVersion1 || pVersion > ProtocolVersion2 { - return nil, NewInvalidConfigStrError("current client only supports protocol version 1 (text format for all datatypes), 2 (binary format for part datatypes) or explicitly unset") + if pVersion < ProtocolVersion1 || pVersion > ProtocolVersion3 { + return nil, NewInvalidConfigStrError("current client only supports protocol version 1 (text format for all datatypes), 2 (binary format for part datatypes), 3 (decimals) or explicitly unset") } senderConf.protocolVersion = pVersion } diff --git a/decimal.go b/decimal.go new file mode 100644 index 00000000..32990519 --- /dev/null +++ b/decimal.go @@ -0,0 +1,401 @@ +/******************************************************************************* + * ___ _ ____ ____ + * / _ \ _ _ ___ ___| |_| _ \| __ ) + * | | | | | | |/ _ \/ __| __| | | | _ \ + * | |_| | |_| | __/\__ \ |_| |_| | |_) | + * \__\_\\__,_|\___||___/\__|____/|____/ + * + * Copyright (c) 2014-2019 Appsicle + * Copyright (c) 2019-2024 QuestDB + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ******************************************************************************/ + +package questdb + +import ( + "encoding/binary" + "fmt" + "math/big" + "reflect" +) + +const ( + decimalBinaryTypeCode byte = 0x17 + maxDecimalScale uint32 = 76 + maxDecimalBytes int = 127 +) + +// ScaledDecimal represents a decimal value as a two's complement big-endian byte slice and a scale. +// NULL decimals are represented by valid=false. +type ScaledDecimal struct { + scale uint32 + unscaled []byte + valid bool +} + +// DecimalMarshaler allows custom types to provide a QuestDB-compatible decimal representation. +type DecimalMarshaler interface { + QuestDBDecimal() (ScaledDecimal, error) +} + +type shopspringDecimal interface { + Coefficient() *big.Int + Exponent() int32 +} + +// NewScaledDecimal constructs a decimal from a two's complement big-endian unscaled value and a scale. +// A nil unscaled slice produces a NULL decimal. +func NewScaledDecimal(unscaled []byte, scale uint32) ScaledDecimal { + if len(unscaled) == 0 { + return NullDecimal() + } + return ScaledDecimal{ + scale: scale, + unscaled: normalizeTwosComplement(unscaled), + valid: true, + } +} + +// NewDecimal constructs a decimal from an arbitrary-precision integer and a scale. +// Providing a nil unscaled value produces a NULL decimal. +func NewDecimal(unscaled *big.Int, scale uint32) ScaledDecimal { + if unscaled == nil { + return NullDecimal() + } + return ScaledDecimal{ + scale: scale, + unscaled: bigIntToTwosComplement(unscaled), + valid: true, + } +} + +// NewDecimalFromInt64 constructs a decimal from a 64-bit integer and a scale. +func NewDecimalFromInt64(unscaled int64, scale uint32) ScaledDecimal { + var be [8]byte + binary.BigEndian.PutUint64(be[:], uint64(unscaled)) + payload := trimTwosComplement(be[:]) + return ScaledDecimal{ + scale: scale, + unscaled: payload, + valid: true, + } +} + +// NullDecimal returns a NULL decimal representation. +func NullDecimal() ScaledDecimal { + return ScaledDecimal{} +} + +// IsNull reports whether the decimal represents NULL. +func (d ScaledDecimal) IsNull() bool { + return !d.valid +} + +// Scale returns the decimal scale. +func (d ScaledDecimal) Scale() uint32 { + return d.scale +} + +// UnscaledValue returns a copy of the unscaled integer value. +// For NULL decimals it returns nil. +func (d ScaledDecimal) UnscaledValue() *big.Int { + if d.IsNull() { + return nil + } + return twosComplementToBigInt(d.unscaled) +} + +func (d ScaledDecimal) ensureValidScale() error { + if d.IsNull() { + return nil + } + if d.scale > maxDecimalScale { + return fmt.Errorf("decimal scale %d exceeds maximum %d", d.scale, maxDecimalScale) + } + return nil +} + +func (d ScaledDecimal) toBinary() (byte, []byte, error) { + if d.IsNull() { + return 0, nil, nil + } + if err := d.ensureValidScale(); err != nil { + return 0, nil, err + } + payload := append([]byte(nil), d.unscaled...) + if len(payload) == 0 { + payload = []byte{0} + } + if len(payload) > maxDecimalBytes { + return 0, nil, fmt.Errorf("decimal value exceeds 256-bit range (got %d bytes)", len(payload)) + } + return byte(d.scale), payload, nil +} + +func normalizeDecimalValue(value any) (ScaledDecimal, error) { + if value == nil { + return NullDecimal(), nil + } + + switch v := value.(type) { + case ScaledDecimal: + return canonicalDecimal(v), nil + case *ScaledDecimal: + if v == nil { + return NullDecimal(), nil + } + return canonicalDecimal(*v), nil + case DecimalMarshaler: + if isNilInterface(v) { + return NullDecimal(), nil + } + dec, err := v.QuestDBDecimal() + if err != nil { + return ScaledDecimal{}, err + } + return canonicalDecimal(dec), nil + } + + if dec, ok := convertShopspringDecimal(value); ok { + return dec, nil + } + + return ScaledDecimal{}, fmt.Errorf("unsupported decimal column value type %T", value) +} + +func canonicalDecimal(d ScaledDecimal) ScaledDecimal { + if !d.valid { + return NullDecimal() + } + if len(d.unscaled) == 0 { + return NullDecimal() + } + return ScaledDecimal{ + scale: d.scale, + unscaled: normalizeTwosComplement(d.unscaled), + valid: true, + } +} + +func convertShopspringDecimal(value any) (ScaledDecimal, bool) { + dec, ok := value.(shopspringDecimal) + if !ok { + return ScaledDecimal{}, false + } + if isNilInterface(dec) { + return NullDecimal(), true + } + + coeff := dec.Coefficient() + if coeff == nil { + return NullDecimal(), true + } + + exp := dec.Exponent() + if exp >= 0 { + unscaled := new(big.Int).Set(coeff) + unscaled.Mul(unscaled, bigPow10(int(exp))) + return NewDecimal(unscaled, 0), true + } + scale := uint32(-exp) + unscaled := new(big.Int).Set(coeff) + return NewDecimal(unscaled, scale), true +} + +func isNilInterface(value any) bool { + if value == nil { + return true + } + rv := reflect.ValueOf(value) + switch rv.Kind() { + case reflect.Interface, reflect.Pointer, reflect.Map, reflect.Slice, reflect.Func: + return rv.IsNil() + default: + return false + } +} + +func bigPow10(exponent int) *big.Int { + if exponent <= 0 { + return big.NewInt(1) + } + result := big.NewInt(1) + ten := big.NewInt(10) + for i := 0; i < exponent; i++ { + result.Mul(result, ten) + } + return result +} + +func bigIntToTwosComplement(value *big.Int) []byte { + if value.Sign() == 0 { + return []byte{0} + } + if value.Sign() > 0 { + bytes := value.Bytes() + if bytes[0]&0x80 != 0 { + return append([]byte{0x00}, bytes...) + } + return trimTwosComplement(bytes) + } + + bitLen := value.BitLen() + byteLen := (bitLen + 8) / 8 + if byteLen == 0 { + byteLen = 1 + } + + tmp := new(big.Int).Lsh(big.NewInt(1), uint(byteLen*8)) + tmp.Add(tmp, value) // value is negative, so this subtracts magnitude + bytes := tmp.Bytes() + if len(bytes) < int(byteLen) { + padding := make([]byte, int(byteLen)-len(bytes)) + bytes = append(padding, bytes...) + } + + bytes = trimTwosComplement(bytes) + if bytes[0]&0x80 == 0 { + bytes = append([]byte{0xFF}, bytes...) + } + return trimTwosComplement(bytes) +} + +func normalizeTwosComplement(src []byte) []byte { + if len(src) == 0 { + return []byte{0} + } + trimmed := trimTwosComplement(append([]byte(nil), src...)) + if len(trimmed) == 0 { + return []byte{0} + } + return trimmed +} + +func trimTwosComplement(bytes []byte) []byte { + if len(bytes) <= 1 { + return bytes + } + signBit := bytes[0] & 0x80 + i := 0 + for i < len(bytes)-1 { + if signBit == 0 { + if bytes[i] == 0x00 && bytes[i+1]&0x80 == 0 { + i++ + continue + } + } else { + if bytes[i] == 0xFF && bytes[i+1]&0x80 != 0 { + i++ + continue + } + } + break + } + return bytes[i:] +} + +func twosComplementToBigInt(bytes []byte) *big.Int { + if len(bytes) == 0 { + return big.NewInt(0) + } + if bytes[0]&0x80 == 0 { + return new(big.Int).SetBytes(bytes) + } + + inverted := make([]byte, len(bytes)) + for i := range bytes { + inverted[i] = ^bytes[i] + } + + magnitude := new(big.Int).SetBytes(inverted) + magnitude.Add(magnitude, big.NewInt(1)) + magnitude.Neg(magnitude) + return magnitude +} + +// validateDecimalText checks that the provided string is a valid decimal representation. +// It accepts numeric digits, optional sign, decimal point, exponent (e/E) and NaN/Infinity tokens. +func validateDecimalText(text string) error { + if text == "" { + return fmt.Errorf("decimal literal cannot be empty") + } + + switch text { + case "NaN", "Infinity", "+Infinity", "-Infinity": + return nil + } + + i := 0 + length := len(text) + if text[0] == '+' || text[0] == '-' { + if length == 1 { + return fmt.Errorf("decimal literal contains sign without digits") + } + i++ + } + + digits := 0 + seenDot := false + for i < length { + ch := text[i] + switch { + case ch >= '0' && ch <= '9': + digits++ + i++ + case ch == '.': + if seenDot { + return fmt.Errorf("decimal literal has multiple decimal points") + } + seenDot = true + i++ + case ch == 'e' || ch == 'E': + if digits == 0 && !seenDot { + return fmt.Errorf("decimal literal exponent without mantissa") + } + i++ + if i >= length { + return fmt.Errorf("decimal literal has incomplete exponent") + } + if text[i] == '+' || text[i] == '-' { + i++ + if i >= length { + return fmt.Errorf("decimal literal has incomplete exponent") + } + } + expDigits := 0 + for i < length && text[i] >= '0' && text[i] <= '9' { + i++ + expDigits++ + } + if expDigits == 0 { + return fmt.Errorf("decimal literal exponent has no digits") + } + if i != length { + return fmt.Errorf("decimal literal has trailing characters") + } + if digits == 0 && !seenDot { + return fmt.Errorf("decimal literal missing mantissa digits") + } + return nil + default: + return fmt.Errorf("decimal literal contains invalid character %q", ch) + } + } + + if digits == 0 { + return fmt.Errorf("decimal literal must contain at least one digit") + } + return nil +} diff --git a/export_test.go b/export_test.go index bcd02a66..21025153 100644 --- a/export_test.go +++ b/export_test.go @@ -63,12 +63,18 @@ func Messages(s LineSender) []byte { if hs, ok := s.(*httpLineSenderV2); ok { return hs.Messages() } + if hs, ok := s.(*httpLineSenderV3); ok { + return hs.Messages() + } if ts, ok := s.(*tcpLineSender); ok { return ts.Messages() } if ts, ok := s.(*tcpLineSenderV2); ok { return ts.Messages() } + if ts, ok := s.(*tcpLineSenderV3); ok { + return ts.Messages() + } panic("unexpected struct") } @@ -82,12 +88,18 @@ func MsgCount(s LineSender) int { if hs, ok := s.(*httpLineSenderV2); ok { return hs.MsgCount() } + if hs, ok := s.(*httpLineSenderV3); ok { + return hs.MsgCount() + } if ts, ok := s.(*tcpLineSender); ok { return ts.MsgCount() } if ts, ok := s.(*tcpLineSenderV2); ok { return ts.MsgCount() } + if ts, ok := s.(*tcpLineSenderV3); ok { + return ts.MsgCount() + } panic("unexpected struct") } @@ -101,12 +113,18 @@ func BufLen(s LineSender) int { if hs, ok := s.(*httpLineSenderV2); ok { return hs.BufLen() } + if hs, ok := s.(*httpLineSenderV3); ok { + return hs.BufLen() + } if ts, ok := s.(*tcpLineSender); ok { return ts.BufLen() } if ts, ok := s.(*tcpLineSenderV2); ok { return ts.BufLen() } + if ts, ok := s.(*tcpLineSenderV3); ok { + return ts.BufLen() + } panic("unexpected struct") } @@ -120,12 +138,18 @@ func ProtocolVersion(s LineSender) protocolVersion { if _, ok := s.(*httpLineSenderV2); ok { return ProtocolVersion2 } + if _, ok := s.(*httpLineSenderV3); ok { + return ProtocolVersion3 + } if _, ok := s.(*tcpLineSender); ok { return ProtocolVersion1 } if _, ok := s.(*tcpLineSenderV2); ok { return ProtocolVersion2 } + if _, ok := s.(*tcpLineSenderV3); ok { + return ProtocolVersion3 + } panic("unexpected struct") } diff --git a/http_sender.go b/http_sender.go index 7d070fc4..b88e1153 100644 --- a/http_sender.go +++ b/http_sender.go @@ -119,6 +119,10 @@ type httpLineSenderV2 struct { httpLineSender } +type httpLineSenderV3 struct { + httpLineSenderV2 +} + func newHttpLineSender(ctx context.Context, conf *lineSenderConfig) (LineSender, error) { var transport *http.Transport s := &httpLineSender{ @@ -175,12 +179,21 @@ func newHttpLineSender(ctx context.Context, conf *lineSenderConfig) (LineSender, } s.uri += fmt.Sprintf("://%s/write", s.address) - if pVersion == ProtocolVersion1 { + switch pVersion { + case ProtocolVersion1: return s, nil - } else { + case ProtocolVersion2: return &httpLineSenderV2{ *s, }, nil + case ProtocolVersion3: + return &httpLineSenderV3{ + httpLineSenderV2{ + *s, + }, + }, nil + default: + return nil, fmt.Errorf("unsupported protocol version %d", pVersion) } } @@ -292,6 +305,11 @@ func (s *httpLineSender) Float64Column(name string, val float64) LineSender { return s } +func (s *httpLineSender) DecimalColumn(name string, val any) LineSender { + s.buf.SetLastErr(errors.New("current protocol version does not support decimal")) + return s +} + func (s *httpLineSender) StringColumn(name, val string) LineSender { s.buf.StringColumn(name, val) return s @@ -509,16 +527,22 @@ func parseServerSettings(resp *http.Response, conf *lineSenderConfig) (protocolV return ProtocolVersion1, nil } - hasProtocolVersion1 := false + hasProtocol1 := false + hasProtocol2 := false for _, version := range versions { - if version == 2 { - return ProtocolVersion2, nil - } - if version == 1 { - hasProtocolVersion1 = true + switch version { + case 3: + return ProtocolVersion3, nil + case 2: + hasProtocol2 = true + case 1: + hasProtocol1 = true } } - if hasProtocolVersion1 { + if hasProtocol2 { + return ProtocolVersion2, nil + } + if hasProtocol1 { return ProtocolVersion1, nil } @@ -617,3 +641,73 @@ func (s *httpLineSenderV2) Float64ArrayNDColumn(name string, values *NdArray[flo s.buf.Float64ArrayNDColumn(name, values) return s } + +func (s *httpLineSenderV2) DecimalColumn(name string, val any) LineSender { + s.buf.SetLastErr(errors.New("current protocol version does not support decimal")) + return s +} + +func (s *httpLineSenderV3) Table(name string) LineSender { + s.buf.Table(name) + return s +} + +func (s *httpLineSenderV3) Symbol(name, val string) LineSender { + s.buf.Symbol(name, val) + return s +} + +func (s *httpLineSenderV3) Int64Column(name string, val int64) LineSender { + s.buf.Int64Column(name, val) + return s +} + +func (s *httpLineSenderV3) Long256Column(name string, val *big.Int) LineSender { + s.buf.Long256Column(name, val) + return s +} + +func (s *httpLineSenderV3) TimestampColumn(name string, ts time.Time) LineSender { + s.buf.TimestampColumn(name, ts) + return s +} + +func (s *httpLineSenderV3) StringColumn(name, val string) LineSender { + s.buf.StringColumn(name, val) + return s +} + +func (s *httpLineSenderV3) BoolColumn(name string, val bool) LineSender { + s.buf.BoolColumn(name, val) + return s +} + +func (s *httpLineSenderV3) Float64Column(name string, val float64) LineSender { + s.buf.Float64ColumnBinary(name, val) + return s +} + +func (s *httpLineSenderV3) Float64Array1DColumn(name string, values []float64) LineSender { + s.buf.Float64Array1DColumn(name, values) + return s +} + +func (s *httpLineSenderV3) Float64Array2DColumn(name string, values [][]float64) LineSender { + s.buf.Float64Array2DColumn(name, values) + return s +} + +func (s *httpLineSenderV3) Float64Array3DColumn(name string, values [][][]float64) LineSender { + s.buf.Float64Array3DColumn(name, values) + return s +} + +func (s *httpLineSenderV3) Float64ArrayNDColumn(name string, values *NdArray[float64]) LineSender { + s.buf.Float64ArrayNDColumn(name, values) + return s +} + +func (s *httpLineSenderV3) DecimalColumn(name string, val any) LineSender { + s.buf.DecimalColumn(name, val) + return s +} diff --git a/http_sender_test.go b/http_sender_test.go index 1c93c9a2..69088ca0 100644 --- a/http_sender_test.go +++ b/http_sender_test.go @@ -827,7 +827,7 @@ func TestAutoDetectProtocolVersionNewServer2(t *testing.T) { func TestAutoDetectProtocolVersionNewServer3(t *testing.T) { ctx := context.Background() - srv, err := newTestServerWithProtocol(readAndDiscard, "http", []int{2, 3}) + srv, err := newTestServerWithProtocol(readAndDiscard, "http", []int{2, 4}) assert.NoError(t, err) defer srv.Close() sender, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr())) @@ -838,13 +838,24 @@ func TestAutoDetectProtocolVersionNewServer3(t *testing.T) { func TestAutoDetectProtocolVersionNewServer4(t *testing.T) { ctx := context.Background() - srv, err := newTestServerWithProtocol(readAndDiscard, "http", []int{3}) + srv, err := newTestServerWithProtocol(readAndDiscard, "http", []int{4}) assert.NoError(t, err) defer srv.Close() _, err = qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr())) assert.ErrorContains(t, err, "server does not support current client") } +func TestAutoDetectProtocolVersionNewServer5(t *testing.T) { + ctx := context.Background() + + srv, err := newTestServerWithProtocol(readAndDiscard, "http", []int{2, 3}) + assert.NoError(t, err) + defer srv.Close() + sender, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr())) + assert.Equal(t, qdb.ProtocolVersion(sender), qdb.ProtocolVersion3) + assert.NoError(t, err) +} + func TestAutoDetectProtocolVersionError(t *testing.T) { ctx := context.Background() @@ -913,6 +924,25 @@ func TestArrayColumnUnsupportedInHttpProtocolV1(t *testing.T) { assert.Contains(t, err.Error(), "current protocol version does not support double-array") } +func TestDecimalColumnUnsupportedInHttpProtocolV2(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(50*time.Millisecond)) + defer cancel() + + srv, err := newTestServerWithProtocol(readAndDiscard, "http", []int{2}) + assert.NoError(t, err) + defer srv.Close() + sender, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr())) + assert.NoError(t, err) + defer sender.Close(ctx) + + err = sender. + Table(testTable). + DecimalColumn("price", "12.99"). + At(ctx, time.UnixMicro(1)) + assert.Error(t, err) + assert.Contains(t, err.Error(), "current protocol version does not support decimal") +} + func BenchmarkHttpLineSenderBatch1000(b *testing.B) { ctx := context.Background() diff --git a/integration_test.go b/integration_test.go index 37eb7b17..9074743f 100644 --- a/integration_test.go +++ b/integration_test.go @@ -31,6 +31,7 @@ import ( "math/big" "path/filepath" "reflect" + "strings" "testing" "time" @@ -492,47 +493,93 @@ func (suite *integrationTestSuite) TestE2EValidWrites() { Count: 3, }, }, + { + "decimal type", + testTable, + func(s qdb.LineSender) error { + err := s. + Table(testTable). + DecimalColumn("text_col", "123.45"). + DecimalColumn("binary_col", qdb.NewDecimalFromInt64(12345, 2)). + DecimalColumn("binary_neg_col", qdb.NewDecimalFromInt64(-12345, 2)). + DecimalColumn("binary_null_col", qdb.NullDecimal()). + At(ctx, time.UnixMicro(1)) + if err != nil { + return err + } + + return s. + Table(testTable). + DecimalColumn("text_col", "123.46"). + DecimalColumn("binary_col", qdb.NewDecimalFromInt64(12346, 2)). + DecimalColumn("binary_neg_col", qdb.NewDecimalFromInt64(-12346, 2)). + DecimalColumn("binary_null_col", qdb.NullDecimal()). + At(ctx, time.UnixMicro(2)) + }, + tableData{ + Columns: []column{ + {"text_col", "DECIMAL(18,3)"}, + {"binary_col", "DECIMAL(18,3)"}, + {"binary_neg_col", "DECIMAL(18,3)"}, + {"binary_null_col", "DECIMAL(18,3)"}, + {"timestamp", "TIMESTAMP"}, + }, + Dataset: [][]any{ + {"123.450", "123.450", "-123.450", nil, "1970-01-01T00:00:00.000001Z"}, + {"123.460", "123.460", "-123.460", nil, "1970-01-01T00:00:00.000002Z"}, + }, + Count: 2, + }, + }, } for _, tc := range testCases { for _, protocol := range []string{"tcp", "http"} { - for _, pVersion := range []int{0, 1, 2} { + for _, pVersion := range []int{0, 1, 2, 3} { suite.T().Run(fmt.Sprintf("%s: %s", tc.name, protocol), func(t *testing.T) { var ( sender qdb.LineSender err error ) - ignoreArray := false questdbC, err := setupQuestDB(ctx, noAuth) assert.NoError(t, err) + currentVersion := pVersion switch protocol { case "tcp": - if pVersion == 0 { + switch pVersion { + case 0: sender, err = qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(questdbC.ilpAddress)) - ignoreArray = true - } else if pVersion == 1 { + currentVersion = 1 + case 1: sender, err = qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(questdbC.ilpAddress), qdb.WithProtocolVersion(qdb.ProtocolVersion1)) - ignoreArray = true - } else if pVersion == 2 { + case 2: sender, err = qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(questdbC.ilpAddress), qdb.WithProtocolVersion(qdb.ProtocolVersion2)) + case 3: + sender, err = qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(questdbC.ilpAddress), qdb.WithProtocolVersion(qdb.ProtocolVersion3)) } assert.NoError(t, err) case "http": - if pVersion == 0 { + switch pVersion { + case 0: sender, err = qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(questdbC.httpAddress)) - } else if pVersion == 1 { + currentVersion = 3 + case 1: sender, err = qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(questdbC.httpAddress), qdb.WithProtocolVersion(qdb.ProtocolVersion1)) - ignoreArray = true - } else if pVersion == 2 { + case 2: sender, err = qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(questdbC.httpAddress), qdb.WithProtocolVersion(qdb.ProtocolVersion2)) + case 3: + sender, err = qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(questdbC.httpAddress), qdb.WithProtocolVersion(qdb.ProtocolVersion3)) } assert.NoError(t, err) default: panic(protocol) } - if ignoreArray && tc.name == "double array" { + if currentVersion < 2 && tc.name == "double array" { + return + } + if currentVersion < 3 && strings.Contains(tc.name, "decimal") { return } diff --git a/interop_test.go b/interop_test.go index 78f8ab4e..dee0d3f5 100644 --- a/interop_test.go +++ b/interop_test.go @@ -27,8 +27,10 @@ package questdb_test import ( "context" "encoding/json" + "fmt" "io" "os" + "strconv" "strings" "testing" @@ -60,8 +62,10 @@ type testColumn struct { } type testResult struct { - Status string `json:"status"` - Line string `json:"line"` + Status string `json:"status"` + Line string `json:"line"` + AnyLines []string `json:"anyLines"` + BinaryBase64 string `json:"binaryBase64"` } func TestTcpClientInterop(t *testing.T) { @@ -75,44 +79,9 @@ func TestTcpClientInterop(t *testing.T) { srv, err := newTestTcpServer(sendToBackChannel) assert.NoError(t, err) - sender, err := qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(srv.Addr()), qdb.WithProtocolVersion(qdb.ProtocolVersion1)) + sender, err := qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(srv.Addr()), qdb.WithProtocolVersion(qdb.ProtocolVersion3)) assert.NoError(t, err) - - sender.Table(tc.Table) - for _, s := range tc.Symbols { - sender.Symbol(s.Name, s.Value) - } - for _, s := range tc.Columns { - switch s.Type { - case "LONG": - sender.Int64Column(s.Name, int64(s.Value.(float64))) - case "DOUBLE": - sender.Float64Column(s.Name, s.Value.(float64)) - case "STRING": - sender.StringColumn(s.Name, s.Value.(string)) - case "BOOLEAN": - sender.BoolColumn(s.Name, s.Value.(bool)) - default: - assert.Fail(t, "unexpected column type: "+s.Type) - } - } - - err = sender.AtNow(ctx) - - switch tc.Result.Status { - case "SUCCESS": - assert.NoError(t, err) - err = sender.Flush(ctx) - assert.NoError(t, err) - - expectLines(t, srv.BackCh, strings.Split(tc.Result.Line, "\n")) - case "ERROR": - assert.Error(t, err) - default: - assert.Fail(t, "unexpected test status: "+tc.Result.Status) - } - - sender.Close(ctx) + execute(t, ctx, sender, srv.BackCh, tc) srv.Close() }) } @@ -129,49 +98,92 @@ func TestHttpClientInterop(t *testing.T) { srv, err := newTestHttpServer(sendToBackChannel) assert.NoError(t, err) - sender, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr()), qdb.WithProtocolVersion(qdb.ProtocolVersion1)) + sender, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr()), qdb.WithProtocolVersion(qdb.ProtocolVersion3)) assert.NoError(t, err) - - sender.Table(tc.Table) - for _, s := range tc.Symbols { - sender.Symbol(s.Name, s.Value) - } - for _, s := range tc.Columns { - switch s.Type { - case "LONG": - sender.Int64Column(s.Name, int64(s.Value.(float64))) - case "DOUBLE": - sender.Float64Column(s.Name, s.Value.(float64)) - case "STRING": - sender.StringColumn(s.Name, s.Value.(string)) - case "BOOLEAN": - sender.BoolColumn(s.Name, s.Value.(bool)) - default: - assert.Fail(t, "unexpected column type: "+s.Type) - } - } - - err = sender.AtNow(ctx) - - switch tc.Result.Status { - case "SUCCESS": - assert.NoError(t, err) - err = sender.Flush(ctx) - assert.NoError(t, err) - - expectLines(t, srv.BackCh, strings.Split(tc.Result.Line, "\n")) - case "ERROR": - assert.Error(t, err) - default: - assert.Fail(t, "unexpected test status: "+tc.Result.Status) - } - - sender.Close(ctx) + execute(t, ctx, sender, srv.BackCh, tc) srv.Close() }) } } +func execute(t *testing.T, ctx context.Context, sender qdb.LineSender, backCh chan string, tc testCase) { + sender.Table(tc.Table) + for _, s := range tc.Symbols { + sender.Symbol(s.Name, s.Value) + } + for _, s := range tc.Columns { + switch s.Type { + case "LONG": + sender.Int64Column(s.Name, int64(s.Value.(float64))) + case "DOUBLE": + sender.Float64Column(s.Name, s.Value.(float64)) + case "STRING": + sender.StringColumn(s.Name, s.Value.(string)) + case "BOOLEAN": + sender.BoolColumn(s.Name, s.Value.(bool)) + case "DECIMAL": + dec, err := parseDecimal64(s.Value.(string)) + assert.NoError(t, err) + sender.DecimalColumn(s.Name, dec) + default: + assert.Fail(t, "unexpected column type: "+s.Type) + } + } + + err := sender.AtNow(ctx) + + switch tc.Result.Status { + case "SUCCESS": + assert.NoError(t, err) + err = sender.Flush(ctx) + assert.NoError(t, err) + + if len(tc.Result.BinaryBase64) > 0 { + expectBinaryBase64(t, backCh, tc.Result.BinaryBase64) + } else if len(tc.Result.AnyLines) > 0 { + expectAnyLines(t, backCh, tc.Result.AnyLines) + } else { + expectLines(t, backCh, strings.Split(tc.Result.Line, "\n")) + } + case "ERROR": + assert.Error(t, err) + default: + assert.Fail(t, "unexpected test status: "+tc.Result.Status) + } + + sender.Close(ctx) +} + +// parseDecimal64 quick and dirty parser for a decimal64 value from its string representation +func parseDecimal64(s string) (qdb.ScaledDecimal, error) { + // Remove whitespace + s = strings.TrimSpace(s) + + // Check for empty string + if s == "" { + return qdb.ScaledDecimal{}, fmt.Errorf("empty string") + } + + // Find the decimal point and remove it + pointIndex := strings.Index(s, ".") + if pointIndex != -1 { + s = strings.ReplaceAll(s, ".", "") + } + + // Parse the integer part + unscaled, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return qdb.ScaledDecimal{}, err + } + + scale := 0 + if pointIndex != -1 { + scale = len(s) - pointIndex + } + + return qdb.NewDecimalFromInt64(unscaled, uint32(scale)), nil +} + func readTestCases() (testCases, error) { file, err := os.Open("./test/interop/questdb-client-test/ilp-client-interop-test.json") if err != nil { diff --git a/sender.go b/sender.go index b6f38f28..39b3c1b4 100644 --- a/sender.go +++ b/sender.go @@ -106,6 +106,17 @@ type LineSender interface { // '-', '*' '%%', '~', or a non-printable char. Float64Column(name string, val float64) LineSender + // DecimalColumn adds a decimal column value to the ILP message. + // + // Supported value types include questdb.Decimal, any custom type implementing + // questdb.DecimalMarshaler, github.com/shopspring/decimal.Decimal (value or pointer), + // and nil to encode a NULL decimal. + // + // Column name cannot contain any of the following characters: + // '\n', '\r', '?', '.', ',', ”', '"', '\', '/', ':', ')', '(', '+', + // '-', '*' '%%', '~', or a non-printable char. + DecimalColumn(name string, val any) LineSender + // StringColumn adds a string column value to the ILP message. // // Column name cannot contain any of the following characters: @@ -253,6 +264,7 @@ const ( protocolVersionUnset protocolVersion = 0 ProtocolVersion1 protocolVersion = 1 ProtocolVersion2 protocolVersion = 2 + ProtocolVersion3 protocolVersion = 3 ) type lineSenderConfig struct { @@ -479,8 +491,10 @@ func WithAutoFlushInterval(interval time.Duration) LineSenderOption { // - TCP transport does not negotiate the protocol version and uses [ProtocolVersion1] by // default. You must explicitly set [ProtocolVersion2] in order to ingest // arrays. +// - [ProtocolVersion3] enables decimal binary encoding (ILP v3). // // NOTE: QuestDB server version 9.0.0 or later is required for [ProtocolVersion2]. +// For [ProtocolVersion3], make sure the server advertises ILP v3 support via /settings. func WithProtocolVersion(version protocolVersion) LineSenderOption { return func(s *lineSenderConfig) { s.protocolVersion = version @@ -721,9 +735,9 @@ func validateConf(conf *lineSenderConfig) error { if conf.autoFlushInterval < 0 { return fmt.Errorf("auto flush interval is negative: %d", conf.autoFlushInterval) } - if conf.protocolVersion < protocolVersionUnset || conf.protocolVersion > ProtocolVersion2 { - return errors.New("current client only supports protocol version 1(text format for all datatypes), " + - "2(binary format for part datatypes) or explicitly unset") + if conf.protocolVersion < protocolVersionUnset || conf.protocolVersion > ProtocolVersion3 { + return errors.New("current client only supports protocol version 1 (text format for all datatypes), " + + "2 (binary format for floats/arrays), 3 (binary decimals) or explicitly unset") } return nil diff --git a/sender_pool.go b/sender_pool.go index 7da85726..9bf0870c 100644 --- a/sender_pool.go +++ b/sender_pool.go @@ -314,6 +314,11 @@ func (ps *pooledSender) Float64Column(name string, val float64) LineSender { return ps } +func (ps *pooledSender) DecimalColumn(name string, val any) LineSender { + ps.wrapped.DecimalColumn(name, val) + return ps +} + func (ps *pooledSender) StringColumn(name, val string) LineSender { ps.wrapped.StringColumn(name, val) return ps diff --git a/tcp_sender.go b/tcp_sender.go index f46c5de1..f331ba40 100644 --- a/tcp_sender.go +++ b/tcp_sender.go @@ -50,6 +50,10 @@ type tcpLineSenderV2 struct { tcpLineSender } +type tcpLineSenderV3 struct { + tcpLineSenderV2 +} + func newTcpLineSender(ctx context.Context, conf *lineSenderConfig) (LineSender, error) { var ( d net.Dialer @@ -136,12 +140,27 @@ func newTcpLineSender(ctx context.Context, conf *lineSenderConfig) (LineSender, s.conn = conn - if conf.protocolVersion == protocolVersionUnset || conf.protocolVersion == ProtocolVersion1 { + pVersion := conf.protocolVersion + if pVersion == protocolVersionUnset { + pVersion = ProtocolVersion1 + } + + switch pVersion { + case ProtocolVersion1: return s, nil - } else { + case ProtocolVersion2: return &tcpLineSenderV2{ *s, }, nil + case ProtocolVersion3: + return &tcpLineSenderV3{ + tcpLineSenderV2{ + *s, + }, + }, nil + default: + conn.Close() + return nil, fmt.Errorf("unsupported protocol version %d", pVersion) } } @@ -184,6 +203,11 @@ func (s *tcpLineSender) Float64Column(name string, val float64) LineSender { return s } +func (s *tcpLineSender) DecimalColumn(name string, val any) LineSender { + s.buf.SetLastErr(errors.New("current protocol version does not support decimal")) + return s +} + func (s *tcpLineSender) StringColumn(name, val string) LineSender { s.buf.StringColumn(name, val) return s @@ -346,3 +370,73 @@ func (s *tcpLineSenderV2) Float64ArrayNDColumn(name string, values *NdArray[floa s.buf.Float64ArrayNDColumn(name, values) return s } + +func (s *tcpLineSenderV2) DecimalColumn(name string, val any) LineSender { + s.buf.SetLastErr(errors.New("current protocol version does not support decimal")) + return s +} + +func (s *tcpLineSenderV3) Table(name string) LineSender { + s.buf.Table(name) + return s +} + +func (s *tcpLineSenderV3) Symbol(name, val string) LineSender { + s.buf.Symbol(name, val) + return s +} + +func (s *tcpLineSenderV3) Int64Column(name string, val int64) LineSender { + s.buf.Int64Column(name, val) + return s +} + +func (s *tcpLineSenderV3) Long256Column(name string, val *big.Int) LineSender { + s.buf.Long256Column(name, val) + return s +} + +func (s *tcpLineSenderV3) TimestampColumn(name string, ts time.Time) LineSender { + s.buf.TimestampColumn(name, ts) + return s +} + +func (s *tcpLineSenderV3) StringColumn(name, val string) LineSender { + s.buf.StringColumn(name, val) + return s +} + +func (s *tcpLineSenderV3) BoolColumn(name string, val bool) LineSender { + s.buf.BoolColumn(name, val) + return s +} + +func (s *tcpLineSenderV3) Float64Column(name string, val float64) LineSender { + s.buf.Float64ColumnBinary(name, val) + return s +} + +func (s *tcpLineSenderV3) Float64Array1DColumn(name string, values []float64) LineSender { + s.buf.Float64Array1DColumn(name, values) + return s +} + +func (s *tcpLineSenderV3) Float64Array2DColumn(name string, values [][]float64) LineSender { + s.buf.Float64Array2DColumn(name, values) + return s +} + +func (s *tcpLineSenderV3) Float64Array3DColumn(name string, values [][][]float64) LineSender { + s.buf.Float64Array3DColumn(name, values) + return s +} + +func (s *tcpLineSenderV3) Float64ArrayNDColumn(name string, values *NdArray[float64]) LineSender { + s.buf.Float64ArrayNDColumn(name, values) + return s +} + +func (s *tcpLineSenderV3) DecimalColumn(name string, val any) LineSender { + s.buf.DecimalColumn(name, val) + return s +} diff --git a/tcp_sender_test.go b/tcp_sender_test.go index ddc496e1..132e0cd8 100644 --- a/tcp_sender_test.go +++ b/tcp_sender_test.go @@ -366,6 +366,25 @@ func TestArrayColumnUnsupportedInTCPProtocolV1(t *testing.T) { assert.Contains(t, err.Error(), "current protocol version does not support double-array") } +func TestDecimalColumnUnsupportedInTCPProtocolV2(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(50*time.Millisecond)) + defer cancel() + + srv, err := newTestTcpServer(readAndDiscard) + assert.NoError(t, err) + defer srv.Close() + sender, err := qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(srv.Addr()), qdb.WithProtocolVersion(qdb.ProtocolVersion2)) + assert.NoError(t, err) + defer sender.Close(ctx) + + err = sender. + Table(testTable). + DecimalColumn("price", "12.99"). + At(ctx, time.UnixMicro(1)) + assert.Error(t, err) + assert.Contains(t, err.Error(), "current protocol version does not support decimal") +} + func BenchmarkLineSenderBatch1000(b *testing.B) { ctx := context.Background() diff --git a/test/interop/questdb-client-test b/test/interop/questdb-client-test index 42a30831..1aaa3f96 160000 --- a/test/interop/questdb-client-test +++ b/test/interop/questdb-client-test @@ -1 +1 @@ -Subproject commit 42a30831f1852ed0ac85ab466f5d9ad711b787ee +Subproject commit 1aaa3f96ab06c6bef7d1b08400c418ef87562e56 diff --git a/utils_test.go b/utils_test.go index d37cf5d5..9dfd79bb 100644 --- a/utils_test.go +++ b/utils_test.go @@ -26,6 +26,7 @@ package questdb_test import ( "bufio" + "encoding/base64" "encoding/json" "fmt" "io" @@ -39,6 +40,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "golang.org/x/exp/slices" ) type serverType int64 @@ -341,3 +343,30 @@ func expectLines(t *testing.T, linesCh chan string, expected []string) { return reflect.DeepEqual(expected, actual) }, 10*time.Second, 100*time.Millisecond) } + +func expectAnyLines(t *testing.T, linesCh chan string, expected []string) { + assert.Eventually(t, func() bool { + select { + case l := <-linesCh: + return slices.Contains(expected, l) + default: + return false + } + }, 10*time.Second, 100*time.Millisecond) +} + +func expectBinaryBase64(t *testing.T, linesCh chan string, expected string) { + data, err := base64.StdEncoding.DecodeString(expected) + assert.NoError(t, err) + + actual := make([]byte, 0) + assert.Eventually(t, func() bool { + select { + case l := <-linesCh: + actual = append(actual, []byte(l+"\n")...) + default: + return false + } + return slices.Equal(data, actual) + }, 10*time.Second, 100*time.Millisecond) +} From c186a733c48a05ea4317fadbe47c79ee2c990c0c Mon Sep 17 00:00:00 2001 From: Raphael DALMON Date: Tue, 21 Oct 2025 17:22:01 +0200 Subject: [PATCH 02/13] refactor: remove unused constant for decimal type code --- buffer_test.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/buffer_test.go b/buffer_test.go index 5eaa2a98..e24dc07d 100644 --- a/buffer_test.go +++ b/buffer_test.go @@ -37,8 +37,6 @@ import ( "github.com/stretchr/testify/assert" ) -const decimalTypeCode byte = 0x17 - type bufWriterFn func(b *qdb.Buffer) error type fakeShopspringDecimal struct { From 6c3459f545d50c871abc7c403700ac0bef4c5b5b Mon Sep 17 00:00:00 2001 From: Raphael DALMON Date: Wed, 22 Oct 2025 09:19:18 +0200 Subject: [PATCH 03/13] docs: clarify comment on NewScaledDecimal regarding nil/empty unscaled slice --- decimal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/decimal.go b/decimal.go index 32990519..a488c2a6 100644 --- a/decimal.go +++ b/decimal.go @@ -56,7 +56,7 @@ type shopspringDecimal interface { } // NewScaledDecimal constructs a decimal from a two's complement big-endian unscaled value and a scale. -// A nil unscaled slice produces a NULL decimal. +// A nil/empty unscaled slice produces a NULL decimal. func NewScaledDecimal(unscaled []byte, scale uint32) ScaledDecimal { if len(unscaled) == 0 { return NullDecimal() From f486bab1b257ee5b1ae5f41336a8aa8b00c94a76 Mon Sep 17 00:00:00 2001 From: Raphael DALMON Date: Wed, 22 Oct 2025 09:21:35 +0200 Subject: [PATCH 04/13] refactor: simplify protocol version handling in integration tests --- integration_test.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/integration_test.go b/integration_test.go index 9074743f..89065924 100644 --- a/integration_test.go +++ b/integration_test.go @@ -545,13 +545,11 @@ func (suite *integrationTestSuite) TestE2EValidWrites() { questdbC, err := setupQuestDB(ctx, noAuth) assert.NoError(t, err) - currentVersion := pVersion switch protocol { case "tcp": switch pVersion { case 0: sender, err = qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(questdbC.ilpAddress)) - currentVersion = 1 case 1: sender, err = qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(questdbC.ilpAddress), qdb.WithProtocolVersion(qdb.ProtocolVersion1)) case 2: @@ -564,7 +562,6 @@ func (suite *integrationTestSuite) TestE2EValidWrites() { switch pVersion { case 0: sender, err = qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(questdbC.httpAddress)) - currentVersion = 3 case 1: sender, err = qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(questdbC.httpAddress), qdb.WithProtocolVersion(qdb.ProtocolVersion1)) case 2: @@ -576,10 +573,11 @@ func (suite *integrationTestSuite) TestE2EValidWrites() { default: panic(protocol) } - if currentVersion < 2 && tc.name == "double array" { + senderVersion := qdb.ProtocolVersion(sender) + if senderVersion < 2 && tc.name == "double array" { return } - if currentVersion < 3 && strings.Contains(tc.name, "decimal") { + if senderVersion < 3 && strings.Contains(tc.name, "decimal") { return } From 100bba6d778b3010b68c10fea285e92ec0b8a3ed Mon Sep 17 00:00:00 2001 From: Raphael DALMON Date: Wed, 22 Oct 2025 09:25:28 +0200 Subject: [PATCH 05/13] fix: update error message for decimal value size limit --- buffer_test.go | 2 +- decimal.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/buffer_test.go b/buffer_test.go index e24dc07d..c5e6f5c9 100644 --- a/buffer_test.go +++ b/buffer_test.go @@ -608,7 +608,7 @@ func TestDecimalColumnErrors(t *testing.T) { bigVal := new(big.Int).Lsh(big.NewInt(1), 2100) dec := qdb.NewDecimal(bigVal, 0) err := buf.Table(testTable).DecimalColumn("price", dec).At(time.Time{}, false) - assert.ErrorContains(t, err, "exceeds 256-bit range") + assert.ErrorContains(t, err, "exceeds 127-bytes limit") assert.Empty(t, buf.Messages()) }) diff --git a/decimal.go b/decimal.go index a488c2a6..59d5d684 100644 --- a/decimal.go +++ b/decimal.go @@ -139,7 +139,7 @@ func (d ScaledDecimal) toBinary() (byte, []byte, error) { payload = []byte{0} } if len(payload) > maxDecimalBytes { - return 0, nil, fmt.Errorf("decimal value exceeds 256-bit range (got %d bytes)", len(payload)) + return 0, nil, fmt.Errorf("decimal value exceeds 127-bytes limit (got %d bytes)", len(payload)) } return byte(d.scale), payload, nil } From 5a3e29b2f440d0c4c1179387d87b03c489f8f2c0 Mon Sep 17 00:00:00 2001 From: Raphael DALMON Date: Wed, 22 Oct 2025 15:35:00 +0200 Subject: [PATCH 06/13] fix: improve validation for decimal text exponent handling --- decimal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/decimal.go b/decimal.go index 59d5d684..d8fdde23 100644 --- a/decimal.go +++ b/decimal.go @@ -361,7 +361,7 @@ func validateDecimalText(text string) error { seenDot = true i++ case ch == 'e' || ch == 'E': - if digits == 0 && !seenDot { + if digits == 0 { return fmt.Errorf("decimal literal exponent without mantissa") } i++ From b2e697e2f11d07b396a38371dd58bd8a10dfd8f3 Mon Sep 17 00:00:00 2001 From: Raphael DALMON Date: Tue, 28 Oct 2025 10:55:38 +0100 Subject: [PATCH 07/13] fix: handle null decimals and improve error reporting in DecimalColumn --- buffer.go | 7 ++++++- buffer_test.go | 7 +++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/buffer.go b/buffer.go index bb71f3b6..36e520b6 100644 --- a/buffer.go +++ b/buffer.go @@ -581,12 +581,12 @@ func (b *buffer) DecimalColumn(name string, val any) *buffer { if b.lastErr != nil { return b } - b.WriteByte('=') if str, ok := val.(string); ok { if err := validateDecimalText(str); err != nil { b.lastErr = err return b } + b.WriteByte('=') b.WriteString(str) b.WriteByte('d') b.hasFields = true @@ -603,6 +603,11 @@ func (b *buffer) DecimalColumn(name string, val any) *buffer { b.lastErr = err return b } + if len(payload) == 0 { + // Don't write null decimals + return b + } + b.WriteByte('=') b.WriteByte('=') b.WriteByte(decimalBinaryTypeCode) b.WriteByte(scale) diff --git a/buffer_test.go b/buffer_test.go index c5e6f5c9..b53b41ab 100644 --- a/buffer_test.go +++ b/buffer_test.go @@ -618,6 +618,13 @@ func TestDecimalColumnErrors(t *testing.T) { assert.ErrorContains(t, err, "unsupported decimal column value type") assert.Empty(t, buf.Messages()) }) + + t.Run("no column", func(t *testing.T) { + buf := newTestBuffer() + err := buf.Table(testTable).DecimalColumn("price", qdb.NullDecimal()).At(time.Time{}, false) + assert.ErrorContains(t, err, "no symbols or columns were provided: invalid message") + assert.Empty(t, buf.Messages()) + }) } func TestFloat64Array1DColumn(t *testing.T) { From a1c3981767374e7e2146bfde7b68e786d6352a2a Mon Sep 17 00:00:00 2001 From: Raphael DALMON Date: Tue, 28 Oct 2025 10:57:47 +0100 Subject: [PATCH 08/13] fix: remove unnecessary check for mantissa digits in decimal validation --- decimal.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/decimal.go b/decimal.go index d8fdde23..ef68641d 100644 --- a/decimal.go +++ b/decimal.go @@ -385,9 +385,6 @@ func validateDecimalText(text string) error { if i != length { return fmt.Errorf("decimal literal has trailing characters") } - if digits == 0 && !seenDot { - return fmt.Errorf("decimal literal missing mantissa digits") - } return nil default: return fmt.Errorf("decimal literal contains invalid character %q", ch) From d2b6670b9045ffb3e18a94f061b977eccc8a4cd3 Mon Sep 17 00:00:00 2001 From: Raphael DALMON Date: Wed, 29 Oct 2025 14:51:14 +0100 Subject: [PATCH 09/13] refactor: rework decimal integration to conform to other types --- README.md | 17 ++-- buffer.go | 70 +++++++++------ buffer_test.go | 156 +++++++++++++++++++++++++++------- decimal.go | 201 ++++++++++++++------------------------------ http_sender.go | 42 +++++++-- http_sender_test.go | 2 +- integration_test.go | 21 +++-- interop_test.go | 2 +- sender.go | 25 +++++- sender_pool.go | 14 ++- tcp_sender.go | 42 +++++++-- tcp_sender_test.go | 2 +- utils_test.go | 2 +- 13 files changed, 361 insertions(+), 235 deletions(-) diff --git a/README.md b/README.md index 804cd522..f1f051ff 100644 --- a/README.md +++ b/README.md @@ -191,12 +191,11 @@ ingestion workload. ## Decimal columns QuestDB server version 9.2.0 and newer supports decimal columns with arbitrary precision and scale. -The Go client converts supported decimal values to QuestDB's text/binary wire format automatically, pass any of the following to `DecimalColumn`: +The Go client converts supported decimal values to QuestDB's text/binary wire format automatically: -- `questdb.ScaledDecimal`, including helpers like `questdb.NewDecimalFromInt64` and `questdb.NewDecimal`. -- Types implementing `questdb.DecimalMarshaler`. -- `github.com/shopspring/decimal.Decimal` values or pointers. -- `nil` or `questdb.NullDecimal()` to send a `NULL`. +- `DecimalColumnScaled`: `questdb.ScaledDecimal`, including helpers like `questdb.NewDecimalFromInt64` and `questdb.NewDecimal`. +- `DecimalColumnShopspring`: `github.com/shopspring/decimal.Decimal` values or pointers. +- `DecimalColumnString`: `string` literals representing decimal values (validated at runtime). ```go price := qdb.NewDecimalFromInt64(12345, 2) // 123.45 with scale 2 @@ -205,17 +204,17 @@ commission := qdb.NewDecimal(big.NewInt(-750), 4) // -0.0750 with scale 4 err = sender. Table("trades"). Symbol("symbol", "ETH-USD"). - DecimalColumn("price", price). - DecimalColumn("commission", commission). + DecimalColumnScaled("price", price). + DecimalColumnScaled("commission", commission). AtNow(ctx) ``` -To emit textual decimals, pass a validated string literal (without the trailing `d`—the client adds it): +To emit textual decimals, pass a validated string literal: ```go err = sender. Table("quotes"). - DecimalColumn("mid", "1.23456"). + DecimalColumnString("mid", "1.23456"). AtNow(ctx) ``` diff --git a/buffer.go b/buffer.go index 36e520b6..a3ce1c99 100644 --- a/buffer.go +++ b/buffer.go @@ -573,52 +573,70 @@ func (b *buffer) Float64Column(name string, val float64) *buffer { return b } -func (b *buffer) DecimalColumn(name string, val any) *buffer { +func (b *buffer) DecimalColumnScaled(name string, val ScaledDecimal) *buffer { if !b.prepareForField() { return b } - b.lastErr = b.writeColumnName(name) - if b.lastErr != nil { + return b.decimalColumnScaled(name, val) +} + +func (b *buffer) decimalColumnScaled(name string, val ScaledDecimal) *buffer { + if err := val.ensureValidScale(); err != nil { + b.lastErr = err return b } - if str, ok := val.(string); ok { - if err := validateDecimalText(str); err != nil { - b.lastErr = err - return b - } - b.WriteByte('=') - b.WriteString(str) - b.WriteByte('d') - b.hasFields = true + if val.IsNull() { + // Don't write null decimals + return b + } + b.lastErr = b.writeColumnName(name) + if b.lastErr != nil { return b } + b.WriteByte('=') + b.WriteByte('=') + b.WriteByte(decimalBinaryTypeCode) + b.WriteByte((uint8)(val.scale)) + b.WriteByte(32 - val.offset) + b.Write(val.unscaled[val.offset:]) + b.hasFields = true + return b +} - dec, err := normalizeDecimalValue(val) - if err != nil { - b.lastErr = err +func (b *buffer) DecimalColumnString(name string, val string) *buffer { + if !b.prepareForField() { return b } - scale, payload, err := dec.toBinary() - if err != nil { + if err := validateDecimalText(val); err != nil { b.lastErr = err return b } - if len(payload) == 0 { - // Don't write null decimals + b.lastErr = b.writeColumnName(name) + if b.lastErr != nil { return b } b.WriteByte('=') - b.WriteByte('=') - b.WriteByte(decimalBinaryTypeCode) - b.WriteByte(scale) - b.WriteByte(byte(len(payload))) - if len(payload) > 0 { - b.Write(payload) - } + b.WriteString(val) + b.WriteByte('d') b.hasFields = true return b } +func (b *buffer) DecimalColumnShopspring(name string, val ShopspringDecimal) *buffer { + if !b.prepareForField() { + return b + } + if val == nil { + return b + } + dec, err := convertShopspringDecimal(val) + if err != nil { + b.lastErr = err + return b + } + return b.decimalColumnScaled(name, dec) +} + func (b *buffer) Float64ColumnBinary(name string, val float64) *buffer { if !b.prepareForField() { return b diff --git a/buffer_test.go b/buffer_test.go index b53b41ab..e478efa7 100644 --- a/buffer_test.go +++ b/buffer_test.go @@ -494,11 +494,14 @@ func TestFloat64ColumnBinary(t *testing.T) { } } -func TestDecimalColumnText(t *testing.T) { +func TestDecimalColumnScaled(t *testing.T) { + negative, err := qdb.NewDecimal(big.NewInt(-12345), 3) + assert.NoError(t, err) + prefix := []byte(testTable + " price==") testCases := []struct { name string - value any + value qdb.ScaledDecimal expected []byte }{ { @@ -508,7 +511,7 @@ func TestDecimalColumnText(t *testing.T) { }, { name: "negative", - value: qdb.NewDecimal(big.NewInt(-12345), 3), + value: negative, expected: append(prefix, 0x17, 0x03, 0x02, 0xCF, 0xC7, 0x0A), }, { @@ -516,29 +519,132 @@ func TestDecimalColumnText(t *testing.T) { value: qdb.NewDecimalFromInt64(0, 4), expected: append(prefix, 0x17, 0x04, 0x01, 0x0, 0x0A), }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + buf := newTestBuffer() + err := buf.Table(testTable).DecimalColumnScaled("price", tc.value).At(time.Time{}, false) + assert.NoError(t, err) + assert.Equal(t, tc.expected, buf.Messages()) + }) + } +} + +func TestDecimalColumnScaledTrimmingAndPadding(t *testing.T) { + prefix := []byte(testTable + " price==") + + testCases := []struct { + name string + value qdb.ScaledDecimal + expectedBytes []byte + }{ + { + name: "127 boundary", + value: qdb.NewDecimalFromInt64(127, 0), + expectedBytes: []byte{0x17, 0x00, 0x01, 0x7F}, + }, + { + name: "128 sign extension", + value: qdb.NewDecimalFromInt64(128, 0), + expectedBytes: []byte{0x17, 0x00, 0x02, 0x00, 0x80}, + }, + { + name: "255 sign extension", + value: qdb.NewDecimalFromInt64(255, 0), + expectedBytes: []byte{0x17, 0x00, 0x02, 0x00, 0xFF}, + }, + { + name: "32768 sign extension", + value: qdb.NewDecimalFromInt64(32768, 0), + expectedBytes: []byte{0x17, 0x00, 0x03, 0x00, 0x80, 0x00}, + }, { - name: "null decimal", - value: qdb.NullDecimal(), - expected: append(prefix, 0x17, 0x0, 0x0, 0x0A), + name: "-1", + value: qdb.NewDecimalFromInt64(-1, 0), + expectedBytes: []byte{0x17, 0x00, 0x01, 0xFF}, }, { - name: "shopspring compatible", - value: fakeShopspringDecimal{coeff: big.NewInt(123456), exp: -4}, - expected: append(prefix, 0x17, 0x04, 0x03, 0x01, 0xE2, 0x40, 0x0A), + name: "-2", + value: qdb.NewDecimalFromInt64(-2, 0), + expectedBytes: []byte{0x17, 0x00, 0x01, 0xFE}, }, { - name: "nil pointer treated as null", - value: (*fakeShopspringDecimal)(nil), - expected: append(prefix, 0x17, 0x0, 0x0, 0x0A), + name: "-127", + value: qdb.NewDecimalFromInt64(-127, 0), + expectedBytes: []byte{0x17, 0x00, 0x01, 0x81}, + }, + { + name: "-128", + value: qdb.NewDecimalFromInt64(-128, 0), + expectedBytes: []byte{0x17, 0x00, 0x01, 0x80}, + }, + { + name: "-129", + value: qdb.NewDecimalFromInt64(-129, 0), + expectedBytes: []byte{0x17, 0x00, 0x02, 0xFF, 0x7F}, + }, + { + name: "-256 sign extension", + value: qdb.NewDecimalFromInt64(-256, 0), + expectedBytes: []byte{0x17, 0x00, 0x02, 0xFF, 0x00}, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { buf := newTestBuffer() - err := buf.Table(testTable).DecimalColumn("price", tc.value).At(time.Time{}, false) + + err := buf.Table(testTable).DecimalColumnScaled("price", tc.value).At(time.Time{}, false) assert.NoError(t, err) - assert.Equal(t, tc.expected, buf.Messages()) + + expected := append(append([]byte{}, prefix...), tc.expectedBytes...) + expected = append(expected, '\n') + assert.Equal(t, expected, buf.Messages()) + }) + } +} + +func TestDecimalColumnShopspring(t *testing.T) { + prefix := []byte(testTable + " price==") + + testCases := []struct { + name string + value fakeShopspringDecimal + expectedBytes []byte + }{ + { + name: "negative exponent scales value", + value: fakeShopspringDecimal{coeff: big.NewInt(12345), exp: -2}, + expectedBytes: []byte{0x17, 0x02, 0x02, 0x30, 0x39}, + }, + { + name: "positive exponent multiplies coefficient", + value: fakeShopspringDecimal{coeff: big.NewInt(123), exp: 2}, + expectedBytes: []byte{0x17, 0x00, 0x02, 0x30, 0x0C}, + }, + { + name: "positive value sign extension", + value: fakeShopspringDecimal{coeff: big.NewInt(128), exp: 0}, + expectedBytes: []byte{0x17, 0x00, 0x02, 0x00, 0x80}, + }, + { + name: "negative value sign extension", + value: fakeShopspringDecimal{coeff: big.NewInt(-12345), exp: -3}, + expectedBytes: []byte{0x17, 0x03, 0x02, 0xCF, 0xC7}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + buf := newTestBuffer() + + err := buf.Table(testTable).DecimalColumnShopspring("price", tc.value).At(time.Time{}, false) + assert.NoError(t, err) + + expected := append(append([]byte{}, prefix...), tc.expectedBytes...) + expected = append(expected, '\n') + assert.Equal(t, expected, buf.Messages()) }) } } @@ -561,7 +667,7 @@ func TestDecimalColumnStringValidation(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { buf := newTestBuffer() - err := buf.Table(testTable).DecimalColumn("price", tc.value).At(time.Time{}, false) + err := buf.Table(testTable).DecimalColumnString("price", tc.value).At(time.Time{}, false) assert.NoError(t, err) expected := []byte(testTable + " price=" + tc.expected + "\n") assert.Equal(t, expected, buf.Messages()) @@ -585,7 +691,7 @@ func TestDecimalColumnStringValidation(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { buf := newTestBuffer() - err := buf.Table(testTable).DecimalColumn("price", tc.value).At(time.Time{}, false) + err := buf.Table(testTable).DecimalColumnString("price", tc.value).At(time.Time{}, false) assert.Error(t, err) assert.Contains(t, err.Error(), "decimal") assert.Empty(t, buf.Messages()) @@ -598,30 +704,20 @@ func TestDecimalColumnErrors(t *testing.T) { t.Run("invalid scale", func(t *testing.T) { buf := newTestBuffer() dec := qdb.NewDecimalFromInt64(1, 100) - err := buf.Table(testTable).DecimalColumn("price", dec).At(time.Time{}, false) + err := buf.Table(testTable).DecimalColumnScaled("price", dec).At(time.Time{}, false) assert.ErrorContains(t, err, "decimal scale") assert.Empty(t, buf.Messages()) }) t.Run("overflow", func(t *testing.T) { - buf := newTestBuffer() bigVal := new(big.Int).Lsh(big.NewInt(1), 2100) - dec := qdb.NewDecimal(bigVal, 0) - err := buf.Table(testTable).DecimalColumn("price", dec).At(time.Time{}, false) - assert.ErrorContains(t, err, "exceeds 127-bytes limit") - assert.Empty(t, buf.Messages()) - }) - - t.Run("unsupported type", func(t *testing.T) { - buf := newTestBuffer() - err := buf.Table(testTable).DecimalColumn("price", struct{}{}).At(time.Time{}, false) - assert.ErrorContains(t, err, "unsupported decimal column value type") - assert.Empty(t, buf.Messages()) + _, err := qdb.NewDecimal(bigVal, 0) + assert.ErrorContains(t, err, "exceeds 32 bytes") }) t.Run("no column", func(t *testing.T) { buf := newTestBuffer() - err := buf.Table(testTable).DecimalColumn("price", qdb.NullDecimal()).At(time.Time{}, false) + err := buf.Table(testTable).DecimalColumnShopspring("price", nil).At(time.Time{}, false) assert.ErrorContains(t, err, "no symbols or columns were provided: invalid message") assert.Empty(t, buf.Messages()) }) diff --git a/decimal.go b/decimal.go index ef68641d..0fb5fef7 100644 --- a/decimal.go +++ b/decimal.go @@ -28,79 +28,81 @@ import ( "encoding/binary" "fmt" "math/big" - "reflect" ) const ( decimalBinaryTypeCode byte = 0x17 maxDecimalScale uint32 = 76 - maxDecimalBytes int = 127 ) // ScaledDecimal represents a decimal value as a two's complement big-endian byte slice and a scale. -// NULL decimals are represented by valid=false. +// NULL decimals are represented by an offset of 32. type ScaledDecimal struct { scale uint32 - unscaled []byte - valid bool + unscaled [32]byte + offset uint8 } -// DecimalMarshaler allows custom types to provide a QuestDB-compatible decimal representation. -type DecimalMarshaler interface { - QuestDBDecimal() (ScaledDecimal, error) -} - -type shopspringDecimal interface { +type ShopspringDecimal interface { Coefficient() *big.Int Exponent() int32 } // NewScaledDecimal constructs a decimal from a two's complement big-endian unscaled value and a scale. // A nil/empty unscaled slice produces a NULL decimal. -func NewScaledDecimal(unscaled []byte, scale uint32) ScaledDecimal { +func NewScaledDecimal(unscaled []byte, scale uint32) (ScaledDecimal, error) { if len(unscaled) == 0 { - return NullDecimal() + return ScaledDecimal{ + offset: 32, + }, nil + } + normalized, offset, err := normalizeTwosComplement(unscaled) + if err != nil { + return ScaledDecimal{}, err } return ScaledDecimal{ scale: scale, - unscaled: normalizeTwosComplement(unscaled), - valid: true, - } + unscaled: normalized, + offset: offset, + }, nil } // NewDecimal constructs a decimal from an arbitrary-precision integer and a scale. // Providing a nil unscaled value produces a NULL decimal. -func NewDecimal(unscaled *big.Int, scale uint32) ScaledDecimal { +func NewDecimal(unscaled *big.Int, scale uint32) (ScaledDecimal, error) { if unscaled == nil { - return NullDecimal() + return ScaledDecimal{ + offset: 32, + }, nil + } + unscaledRaw, offset, err := bigIntToTwosComplement(unscaled) + if err != nil { + return ScaledDecimal{}, err } return ScaledDecimal{ scale: scale, - unscaled: bigIntToTwosComplement(unscaled), - valid: true, - } + unscaled: unscaledRaw, + offset: offset, + }, nil } // NewDecimalFromInt64 constructs a decimal from a 64-bit integer and a scale. func NewDecimalFromInt64(unscaled int64, scale uint32) ScaledDecimal { var be [8]byte binary.BigEndian.PutUint64(be[:], uint64(unscaled)) - payload := trimTwosComplement(be[:]) + offset := trimTwosComplement(be[:]) + payload := [32]byte{} + copy(payload[32-(8-offset):], be[offset:]) return ScaledDecimal{ scale: scale, unscaled: payload, - valid: true, + offset: uint8(32 - (8 - offset)), } } -// NullDecimal returns a NULL decimal representation. -func NullDecimal() ScaledDecimal { - return ScaledDecimal{} -} - // IsNull reports whether the decimal represents NULL. func (d ScaledDecimal) IsNull() bool { - return !d.valid + return d.offset >= 32 } // Scale returns the decimal scale. @@ -114,7 +116,7 @@ func (d ScaledDecimal) UnscaledValue() *big.Int { if d.IsNull() { return nil } - return twosComplementToBigInt(d.unscaled) + return twosComplementToBigInt(d.unscaled[d.offset:]) } func (d ScaledDecimal) ensureValidScale() error { @@ -127,104 +129,26 @@ func (d ScaledDecimal) ensureValidScale() error { return nil } -func (d ScaledDecimal) toBinary() (byte, []byte, error) { - if d.IsNull() { - return 0, nil, nil - } - if err := d.ensureValidScale(); err != nil { - return 0, nil, err - } - payload := append([]byte(nil), d.unscaled...) - if len(payload) == 0 { - payload = []byte{0} - } - if len(payload) > maxDecimalBytes { - return 0, nil, fmt.Errorf("decimal value exceeds 127-bytes limit (got %d bytes)", len(payload)) - } - return byte(d.scale), payload, nil -} - -func normalizeDecimalValue(value any) (ScaledDecimal, error) { - if value == nil { - return NullDecimal(), nil - } - - switch v := value.(type) { - case ScaledDecimal: - return canonicalDecimal(v), nil - case *ScaledDecimal: - if v == nil { - return NullDecimal(), nil - } - return canonicalDecimal(*v), nil - case DecimalMarshaler: - if isNilInterface(v) { - return NullDecimal(), nil - } - dec, err := v.QuestDBDecimal() - if err != nil { - return ScaledDecimal{}, err - } - return canonicalDecimal(dec), nil - } - - if dec, ok := convertShopspringDecimal(value); ok { - return dec, nil - } - - return ScaledDecimal{}, fmt.Errorf("unsupported decimal column value type %T", value) -} - -func canonicalDecimal(d ScaledDecimal) ScaledDecimal { - if !d.valid { - return NullDecimal() - } - if len(d.unscaled) == 0 { - return NullDecimal() - } - return ScaledDecimal{ - scale: d.scale, - unscaled: normalizeTwosComplement(d.unscaled), - valid: true, - } -} - -func convertShopspringDecimal(value any) (ScaledDecimal, bool) { - dec, ok := value.(shopspringDecimal) - if !ok { - return ScaledDecimal{}, false - } - if isNilInterface(dec) { - return NullDecimal(), true - } - - coeff := dec.Coefficient() +func convertShopspringDecimal(value ShopspringDecimal) (ScaledDecimal, error) { + coeff := value.Coefficient() if coeff == nil { - return NullDecimal(), true + return ScaledDecimal{ + offset: 32, + }, nil } - exp := dec.Exponent() + exp := value.Exponent() + var scale uint32 + var unscaled *big.Int if exp >= 0 { - unscaled := new(big.Int).Set(coeff) + unscaled = new(big.Int).Set(coeff) unscaled.Mul(unscaled, bigPow10(int(exp))) - return NewDecimal(unscaled, 0), true - } - scale := uint32(-exp) - unscaled := new(big.Int).Set(coeff) - return NewDecimal(unscaled, scale), true -} - -func isNilInterface(value any) bool { - if value == nil { - return true - } - rv := reflect.ValueOf(value) - switch rv.Kind() { - case reflect.Interface, reflect.Pointer, reflect.Map, reflect.Slice, reflect.Func: - return rv.IsNil() - default: - return false + scale = 0 + } else { + scale = uint32(-exp) + unscaled = new(big.Int).Set(coeff) } + return NewDecimal(unscaled, scale) } func bigPow10(exponent int) *big.Int { @@ -239,16 +163,16 @@ func bigPow10(exponent int) *big.Int { return result } -func bigIntToTwosComplement(value *big.Int) []byte { +func bigIntToTwosComplement(value *big.Int) ([32]byte, uint8, error) { if value.Sign() == 0 { - return []byte{0} + return [32]byte{0}, 32, nil } if value.Sign() > 0 { bytes := value.Bytes() if bytes[0]&0x80 != 0 { - return append([]byte{0x00}, bytes...) + bytes = append([]byte{0x00}, bytes...) } - return trimTwosComplement(bytes) + return normalizeTwosComplement(bytes) } bitLen := value.BitLen() @@ -265,27 +189,30 @@ func bigIntToTwosComplement(value *big.Int) []byte { bytes = append(padding, bytes...) } - bytes = trimTwosComplement(bytes) if bytes[0]&0x80 == 0 { bytes = append([]byte{0xFF}, bytes...) } - return trimTwosComplement(bytes) + return normalizeTwosComplement(bytes) } -func normalizeTwosComplement(src []byte) []byte { +// normalizeTwosComplement normalizes a two's complement big-endian byte slice to fit within 32 bytes and returns the normalized value along with the offset to the first significant byte. +func normalizeTwosComplement(src []byte) ([32]byte, uint8, error) { if len(src) == 0 { - return []byte{0} + return [32]byte{0}, 32, nil } - trimmed := trimTwosComplement(append([]byte(nil), src...)) - if len(trimmed) == 0 { - return []byte{0} + offset := trimTwosComplement(src) + if len(src)-offset > 32 { + return [32]byte{}, 0, fmt.Errorf("decimal unscaled value exceeds 32 bytes") } - return trimmed + var trimmed [32]byte + copy(trimmed[32-(len(src)-offset):], src[offset:]) + return trimmed, uint8(32 - (len(src) - offset)), nil } -func trimTwosComplement(bytes []byte) []byte { +// trimTwosComplement removes redundant sign bytes from a two's complement big-endian byte slice and returns the offset to the first significant byte. +func trimTwosComplement(bytes []byte) int { if len(bytes) <= 1 { - return bytes + return 0 } signBit := bytes[0] & 0x80 i := 0 @@ -303,7 +230,7 @@ func trimTwosComplement(bytes []byte) []byte { } break } - return bytes[i:] + return i } func twosComplementToBigInt(bytes []byte) *big.Int { diff --git a/http_sender.go b/http_sender.go index b88e1153..e2900932 100644 --- a/http_sender.go +++ b/http_sender.go @@ -305,8 +305,18 @@ func (s *httpLineSender) Float64Column(name string, val float64) LineSender { return s } -func (s *httpLineSender) DecimalColumn(name string, val any) LineSender { - s.buf.SetLastErr(errors.New("current protocol version does not support decimal")) +func (s *httpLineSender) DecimalColumnString(name string, val string) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) + return s +} + +func (s *httpLineSender) DecimalColumnShopspring(name string, val ShopspringDecimal) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) + return s +} + +func (s *httpLineSender) DecimalColumnScaled(name string, val ScaledDecimal) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) return s } @@ -642,8 +652,18 @@ func (s *httpLineSenderV2) Float64ArrayNDColumn(name string, values *NdArray[flo return s } -func (s *httpLineSenderV2) DecimalColumn(name string, val any) LineSender { - s.buf.SetLastErr(errors.New("current protocol version does not support decimal")) +func (s *httpLineSenderV2) DecimalColumnString(name string, val string) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) + return s +} + +func (s *httpLineSenderV2) DecimalColumnShopspring(name string, val ShopspringDecimal) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) + return s +} + +func (s *httpLineSenderV2) DecimalColumnScaled(name string, val ScaledDecimal) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) return s } @@ -707,7 +727,17 @@ func (s *httpLineSenderV3) Float64ArrayNDColumn(name string, values *NdArray[flo return s } -func (s *httpLineSenderV3) DecimalColumn(name string, val any) LineSender { - s.buf.DecimalColumn(name, val) +func (s *httpLineSenderV3) DecimalColumnScaled(name string, val ScaledDecimal) LineSender { + s.buf.DecimalColumnScaled(name, val) + return s +} + +func (s *httpLineSenderV3) DecimalColumnString(name string, val string) LineSender { + s.buf.DecimalColumnString(name, val) + return s +} + +func (s *httpLineSenderV3) DecimalColumnShopspring(name string, val ShopspringDecimal) LineSender { + s.buf.DecimalColumnShopspring(name, val) return s } diff --git a/http_sender_test.go b/http_sender_test.go index 69088ca0..963a8c66 100644 --- a/http_sender_test.go +++ b/http_sender_test.go @@ -937,7 +937,7 @@ func TestDecimalColumnUnsupportedInHttpProtocolV2(t *testing.T) { err = sender. Table(testTable). - DecimalColumn("price", "12.99"). + DecimalColumnString("price", "12.99"). At(ctx, time.UnixMicro(1)) assert.Error(t, err) assert.Contains(t, err.Error(), "current protocol version does not support decimal") diff --git a/integration_test.go b/integration_test.go index 89065924..955cd5f5 100644 --- a/integration_test.go +++ b/integration_test.go @@ -499,10 +499,10 @@ func (suite *integrationTestSuite) TestE2EValidWrites() { func(s qdb.LineSender) error { err := s. Table(testTable). - DecimalColumn("text_col", "123.45"). - DecimalColumn("binary_col", qdb.NewDecimalFromInt64(12345, 2)). - DecimalColumn("binary_neg_col", qdb.NewDecimalFromInt64(-12345, 2)). - DecimalColumn("binary_null_col", qdb.NullDecimal()). + DecimalColumnString("text_col", "123.45"). + DecimalColumnScaled("binary_col", qdb.NewDecimalFromInt64(12345, 2)). + DecimalColumnScaled("binary_neg_col", qdb.NewDecimalFromInt64(-12345, 2)). + DecimalColumnShopspring("binary_null_col", nil). At(ctx, time.UnixMicro(1)) if err != nil { return err @@ -510,10 +510,10 @@ func (suite *integrationTestSuite) TestE2EValidWrites() { return s. Table(testTable). - DecimalColumn("text_col", "123.46"). - DecimalColumn("binary_col", qdb.NewDecimalFromInt64(12346, 2)). - DecimalColumn("binary_neg_col", qdb.NewDecimalFromInt64(-12346, 2)). - DecimalColumn("binary_null_col", qdb.NullDecimal()). + DecimalColumnString("text_col", "123.46"). + DecimalColumnScaled("binary_col", qdb.NewDecimalFromInt64(12346, 2)). + DecimalColumnScaled("binary_neg_col", qdb.NewDecimalFromInt64(-12346, 2)). + DecimalColumnShopspring("binary_null_col", nil). At(ctx, time.UnixMicro(2)) }, tableData{ @@ -521,12 +521,11 @@ func (suite *integrationTestSuite) TestE2EValidWrites() { {"text_col", "DECIMAL(18,3)"}, {"binary_col", "DECIMAL(18,3)"}, {"binary_neg_col", "DECIMAL(18,3)"}, - {"binary_null_col", "DECIMAL(18,3)"}, {"timestamp", "TIMESTAMP"}, }, Dataset: [][]any{ - {"123.450", "123.450", "-123.450", nil, "1970-01-01T00:00:00.000001Z"}, - {"123.460", "123.460", "-123.460", nil, "1970-01-01T00:00:00.000002Z"}, + {"123.450", "123.450", "-123.450", "1970-01-01T00:00:00.000001Z"}, + {"123.460", "123.460", "-123.460", "1970-01-01T00:00:00.000002Z"}, }, Count: 2, }, diff --git a/interop_test.go b/interop_test.go index dee0d3f5..97d7bd3d 100644 --- a/interop_test.go +++ b/interop_test.go @@ -124,7 +124,7 @@ func execute(t *testing.T, ctx context.Context, sender qdb.LineSender, backCh ch case "DECIMAL": dec, err := parseDecimal64(s.Value.(string)) assert.NoError(t, err) - sender.DecimalColumn(s.Name, dec) + sender.DecimalColumnScaled(s.Name, dec) default: assert.Fail(t, "unexpected column type: "+s.Type) } diff --git a/sender.go b/sender.go index 39b3c1b4..cbc38cdb 100644 --- a/sender.go +++ b/sender.go @@ -40,6 +40,7 @@ var ( errFlushWithPendingMessage = errors.New("pending ILP message must be finalized with At or AtNow before calling Flush") errClosedSenderAt = errors.New("cannot queue new messages on a closed LineSender") errDoubleSenderClose = errors.New("double sender close") + errDecimalNotSupported = errors.New("current protocol version does not support decimal") ) // LineSender allows you to insert rows into QuestDB by sending ILP @@ -108,14 +109,30 @@ type LineSender interface { // DecimalColumn adds a decimal column value to the ILP message. // - // Supported value types include questdb.Decimal, any custom type implementing - // questdb.DecimalMarshaler, github.com/shopspring/decimal.Decimal (value or pointer), - // and nil to encode a NULL decimal. + // Serializes the decimal value using the text representation. // // Column name cannot contain any of the following characters: // '\n', '\r', '?', '.', ',', ”', '"', '\', '/', ':', ')', '(', '+', // '-', '*' '%%', '~', or a non-printable char. - DecimalColumn(name string, val any) LineSender + DecimalColumnString(name string, val string) LineSender + + // DecimalColumnScaled adds a decimal column value to the ILP message. + // + // Serializes the decimal value using the binary representation. + // + // Column name cannot contain any of the following characters: + // '\n', '\r', '?', '.', ',', ”', '"', '\', '/', ':', ')', '(', '+', + // '-', '*' '%%', '~', or a non-printable char. + DecimalColumnScaled(name string, val ScaledDecimal) LineSender + + // DecimalColumnScaled adds a decimal column value to the ILP message. + // + // Serializes the decimal value using the binary representation. + // + // Column name cannot contain any of the following characters: + // '\n', '\r', '?', '.', ',', ”', '"', '\', '/', ':', ')', '(', '+', + // '-', '*' '%%', '~', or a non-printable char. + DecimalColumnShopspring(name string, val ShopspringDecimal) LineSender // StringColumn adds a string column value to the ILP message. // diff --git a/sender_pool.go b/sender_pool.go index 9bf0870c..eb71332d 100644 --- a/sender_pool.go +++ b/sender_pool.go @@ -314,8 +314,18 @@ func (ps *pooledSender) Float64Column(name string, val float64) LineSender { return ps } -func (ps *pooledSender) DecimalColumn(name string, val any) LineSender { - ps.wrapped.DecimalColumn(name, val) +func (ps *pooledSender) DecimalColumnString(name string, val string) LineSender { + ps.wrapped.DecimalColumnString(name, val) + return ps +} + +func (ps *pooledSender) DecimalColumnShopspring(name string, val ShopspringDecimal) LineSender { + ps.wrapped.DecimalColumnShopspring(name, val) + return ps +} + +func (ps *pooledSender) DecimalColumnScaled(name string, val ScaledDecimal) LineSender { + ps.wrapped.DecimalColumnScaled(name, val) return ps } diff --git a/tcp_sender.go b/tcp_sender.go index f331ba40..ac4e25be 100644 --- a/tcp_sender.go +++ b/tcp_sender.go @@ -203,8 +203,18 @@ func (s *tcpLineSender) Float64Column(name string, val float64) LineSender { return s } -func (s *tcpLineSender) DecimalColumn(name string, val any) LineSender { - s.buf.SetLastErr(errors.New("current protocol version does not support decimal")) +func (s *tcpLineSender) DecimalColumnString(name string, val string) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) + return s +} + +func (s *tcpLineSender) DecimalColumnShopspring(name string, val ShopspringDecimal) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) + return s +} + +func (s *tcpLineSender) DecimalColumnScaled(name string, val ScaledDecimal) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) return s } @@ -371,8 +381,18 @@ func (s *tcpLineSenderV2) Float64ArrayNDColumn(name string, values *NdArray[floa return s } -func (s *tcpLineSenderV2) DecimalColumn(name string, val any) LineSender { - s.buf.SetLastErr(errors.New("current protocol version does not support decimal")) +func (s *tcpLineSenderV2) DecimalColumnString(name string, val string) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) + return s +} + +func (s *tcpLineSenderV2) DecimalColumnShopspring(name string, val ShopspringDecimal) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) + return s +} + +func (s *tcpLineSenderV2) DecimalColumnScaled(name string, val ScaledDecimal) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) return s } @@ -436,7 +456,17 @@ func (s *tcpLineSenderV3) Float64ArrayNDColumn(name string, values *NdArray[floa return s } -func (s *tcpLineSenderV3) DecimalColumn(name string, val any) LineSender { - s.buf.DecimalColumn(name, val) +func (s *tcpLineSenderV3) DecimalColumnString(name string, val string) LineSender { + s.buf.DecimalColumnString(name, val) + return s +} + +func (s *tcpLineSenderV3) DecimalColumnScaled(name string, val ScaledDecimal) LineSender { + s.buf.DecimalColumnScaled(name, val) + return s +} + +func (s *tcpLineSenderV3) DecimalColumnShopspring(name string, val ShopspringDecimal) LineSender { + s.buf.DecimalColumnShopspring(name, val) return s } diff --git a/tcp_sender_test.go b/tcp_sender_test.go index 132e0cd8..cf1f57bf 100644 --- a/tcp_sender_test.go +++ b/tcp_sender_test.go @@ -379,7 +379,7 @@ func TestDecimalColumnUnsupportedInTCPProtocolV2(t *testing.T) { err = sender. Table(testTable). - DecimalColumn("price", "12.99"). + DecimalColumnString("price", "12.99"). At(ctx, time.UnixMicro(1)) assert.Error(t, err) assert.Contains(t, err.Error(), "current protocol version does not support decimal") diff --git a/utils_test.go b/utils_test.go index 9dfd79bb..de532ec5 100644 --- a/utils_test.go +++ b/utils_test.go @@ -74,7 +74,7 @@ func newTestTcpServer(serverType serverType) (*testServer, error) { } func newTestHttpServer(serverType serverType) (*testServer, error) { - return newTestServerWithProtocol(serverType, "http", []int{1, 2}) + return newTestServerWithProtocol(serverType, "http", []int{1, 2, 3}) } func newTestHttpServerWithErrMsg(serverType serverType, errMsg string) (*testServer, error) { From a7f4c07c65687c7647cef6e7c0e5ff8654a6084a Mon Sep 17 00:00:00 2001 From: Raphael DALMON Date: Wed, 29 Oct 2025 15:05:34 +0100 Subject: [PATCH 10/13] fix: handle null decimals in DecimalColumn methods to prevent writing null values --- buffer.go | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/buffer.go b/buffer.go index a3ce1c99..4e5507be 100644 --- a/buffer.go +++ b/buffer.go @@ -574,6 +574,10 @@ func (b *buffer) Float64Column(name string, val float64) *buffer { } func (b *buffer) DecimalColumnScaled(name string, val ScaledDecimal) *buffer { + if val.IsNull() { + // Don't write null decimals + return b + } if !b.prepareForField() { return b } @@ -585,10 +589,6 @@ func (b *buffer) decimalColumnScaled(name string, val ScaledDecimal) *buffer { b.lastErr = err return b } - if val.IsNull() { - // Don't write null decimals - return b - } b.lastErr = b.writeColumnName(name) if b.lastErr != nil { return b @@ -623,10 +623,10 @@ func (b *buffer) DecimalColumnString(name string, val string) *buffer { } func (b *buffer) DecimalColumnShopspring(name string, val ShopspringDecimal) *buffer { - if !b.prepareForField() { + if val == nil { return b } - if val == nil { + if b.lastErr != nil { return b } dec, err := convertShopspringDecimal(val) @@ -634,6 +634,13 @@ func (b *buffer) DecimalColumnShopspring(name string, val ShopspringDecimal) *bu b.lastErr = err return b } + if dec.IsNull() { + // Don't write null decimals + return b + } + if !b.prepareForField() { + return b + } return b.decimalColumnScaled(name, dec) } From b6e5de3110b74b67c0ca61c1db3fec7f32e5ff11 Mon Sep 17 00:00:00 2001 From: Raphael DALMON Date: Wed, 29 Oct 2025 15:08:38 +0100 Subject: [PATCH 11/13] fix: adjust twos complement return value for zero in bigIntToTwosComplement --- buffer_test.go | 5 +++++ decimal.go | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/buffer_test.go b/buffer_test.go index e478efa7..cc1c56ac 100644 --- a/buffer_test.go +++ b/buffer_test.go @@ -618,6 +618,11 @@ func TestDecimalColumnShopspring(t *testing.T) { value: fakeShopspringDecimal{coeff: big.NewInt(12345), exp: -2}, expectedBytes: []byte{0x17, 0x02, 0x02, 0x30, 0x39}, }, + { + name: "zero", + value: fakeShopspringDecimal{coeff: big.NewInt(0), exp: 0}, + expectedBytes: []byte{0x17, 0x00, 0x01, 0x00}, + }, { name: "positive exponent multiplies coefficient", value: fakeShopspringDecimal{coeff: big.NewInt(123), exp: 2}, diff --git a/decimal.go b/decimal.go index 0fb5fef7..cd0907db 100644 --- a/decimal.go +++ b/decimal.go @@ -165,7 +165,7 @@ func bigPow10(exponent int) *big.Int { func bigIntToTwosComplement(value *big.Int) ([32]byte, uint8, error) { if value.Sign() == 0 { - return [32]byte{0}, 32, nil + return [32]byte{0}, 31, nil } if value.Sign() > 0 { bytes := value.Bytes() From f3222af3221f2751ff40f95ae924ffe562f838d5 Mon Sep 17 00:00:00 2001 From: Raphael DALMON Date: Wed, 29 Oct 2025 15:09:57 +0100 Subject: [PATCH 12/13] fix: correct method name in LineSender interface documentation --- sender.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sender.go b/sender.go index cbc38cdb..f4e4abf8 100644 --- a/sender.go +++ b/sender.go @@ -125,7 +125,7 @@ type LineSender interface { // '-', '*' '%%', '~', or a non-printable char. DecimalColumnScaled(name string, val ScaledDecimal) LineSender - // DecimalColumnScaled adds a decimal column value to the ILP message. + // DecimalColumnShopspring adds a decimal column value to the ILP message. // // Serializes the decimal value using the binary representation. // From 840577312ee34aec7e61b9ec7ffa817ed26de207 Mon Sep 17 00:00:00 2001 From: Raphael DALMON Date: Wed, 29 Oct 2025 15:12:55 +0100 Subject: [PATCH 13/13] fix: rename IsNull method to isNull for consistency in ScaledDecimal --- buffer.go | 4 ++-- decimal.go | 39 +++------------------------------------ 2 files changed, 5 insertions(+), 38 deletions(-) diff --git a/buffer.go b/buffer.go index 4e5507be..18d86149 100644 --- a/buffer.go +++ b/buffer.go @@ -574,7 +574,7 @@ func (b *buffer) Float64Column(name string, val float64) *buffer { } func (b *buffer) DecimalColumnScaled(name string, val ScaledDecimal) *buffer { - if val.IsNull() { + if val.isNull() { // Don't write null decimals return b } @@ -634,7 +634,7 @@ func (b *buffer) DecimalColumnShopspring(name string, val ShopspringDecimal) *bu b.lastErr = err return b } - if dec.IsNull() { + if dec.isNull() { // Don't write null decimals return b } diff --git a/decimal.go b/decimal.go index cd0907db..8653591d 100644 --- a/decimal.go +++ b/decimal.go @@ -100,27 +100,13 @@ func NewDecimalFromInt64(unscaled int64, scale uint32) ScaledDecimal { } } -// IsNull reports whether the decimal represents NULL. -func (d ScaledDecimal) IsNull() bool { +// isNull reports whether the decimal represents NULL. +func (d ScaledDecimal) isNull() bool { return d.offset >= 32 } -// Scale returns the decimal scale. -func (d ScaledDecimal) Scale() uint32 { - return d.scale -} - -// UnscaledValue returns a copy of the unscaled integer value. -// For NULL decimals it returns nil. -func (d ScaledDecimal) UnscaledValue() *big.Int { - if d.IsNull() { - return nil - } - return twosComplementToBigInt(d.unscaled[d.offset:]) -} - func (d ScaledDecimal) ensureValidScale() error { - if d.IsNull() { + if d.isNull() { return nil } if d.scale > maxDecimalScale { @@ -233,25 +219,6 @@ func trimTwosComplement(bytes []byte) int { return i } -func twosComplementToBigInt(bytes []byte) *big.Int { - if len(bytes) == 0 { - return big.NewInt(0) - } - if bytes[0]&0x80 == 0 { - return new(big.Int).SetBytes(bytes) - } - - inverted := make([]byte, len(bytes)) - for i := range bytes { - inverted[i] = ^bytes[i] - } - - magnitude := new(big.Int).SetBytes(inverted) - magnitude.Add(magnitude, big.NewInt(1)) - magnitude.Neg(magnitude) - return magnitude -} - // validateDecimalText checks that the provided string is a valid decimal representation. // It accepts numeric digits, optional sign, decimal point, exponent (e/E) and NaN/Infinity tokens. func validateDecimalText(text string) error {