diff --git a/README.md b/README.md index d6aeff7c..f1f051ff 100644 --- a/README.md +++ b/README.md @@ -188,6 +188,36 @@ will now also serialize ``float64`` (double-precision) columns as binary. You might see a performance uplift if this is a dominant data type in your ingestion workload. +## Decimal columns + +QuestDB server version 9.2.0 and newer supports decimal columns with arbitrary precision and scale. +The Go client converts supported decimal values to QuestDB's text/binary wire format automatically: + +- `DecimalColumnScaled`: `questdb.ScaledDecimal`, including helpers like `questdb.NewDecimalFromInt64` and `questdb.NewDecimal`. +- `DecimalColumnShopspring`: `github.com/shopspring/decimal.Decimal` values or pointers. +- `DecimalColumnString`: `string` literals representing decimal values (validated at runtime). + +```go +price := qdb.NewDecimalFromInt64(12345, 2) // 123.45 with scale 2 +commission := qdb.NewDecimal(big.NewInt(-750), 4) // -0.0750 with scale 4 + +err = sender. + Table("trades"). + Symbol("symbol", "ETH-USD"). + DecimalColumnScaled("price", price). + DecimalColumnScaled("commission", commission). + AtNow(ctx) +``` + +To emit textual decimals, pass a validated string literal: + +```go +err = sender. + Table("quotes"). + DecimalColumnString("mid", "1.23456"). + AtNow(ctx) +``` + ## Pooled Line Senders **Warning: Experimental feature designed for use with HTTP senders ONLY** diff --git a/buffer.go b/buffer.go index 85d586f1..18d86149 100644 --- a/buffer.go +++ b/buffer.go @@ -573,6 +573,77 @@ func (b *buffer) Float64Column(name string, val float64) *buffer { return b } +func (b *buffer) DecimalColumnScaled(name string, val ScaledDecimal) *buffer { + if val.isNull() { + // Don't write null decimals + return b + } + if !b.prepareForField() { + return b + } + return b.decimalColumnScaled(name, val) +} + +func (b *buffer) decimalColumnScaled(name string, val ScaledDecimal) *buffer { + if err := val.ensureValidScale(); err != nil { + b.lastErr = err + return b + } + b.lastErr = b.writeColumnName(name) + if b.lastErr != nil { + return b + } + b.WriteByte('=') + b.WriteByte('=') + b.WriteByte(decimalBinaryTypeCode) + b.WriteByte((uint8)(val.scale)) + b.WriteByte(32 - val.offset) + b.Write(val.unscaled[val.offset:]) + b.hasFields = true + return b +} + +func (b *buffer) DecimalColumnString(name string, val string) *buffer { + if !b.prepareForField() { + return b + } + if err := validateDecimalText(val); err != nil { + b.lastErr = err + return b + } + b.lastErr = b.writeColumnName(name) + if b.lastErr != nil { + return b + } + b.WriteByte('=') + b.WriteString(val) + b.WriteByte('d') + b.hasFields = true + return b +} + +func (b *buffer) DecimalColumnShopspring(name string, val ShopspringDecimal) *buffer { + if val == nil { + return b + } + if b.lastErr != nil { + return b + } + dec, err := convertShopspringDecimal(val) + if err != nil { + b.lastErr = err + return b + } + if dec.isNull() { + // Don't write null decimals + return b + } + if !b.prepareForField() { + return b + } + return b.decimalColumnScaled(name, dec) +} + func (b *buffer) Float64ColumnBinary(name string, val float64) *buffer { if !b.prepareForField() { return b diff --git a/buffer_test.go b/buffer_test.go index 646e5063..cc1c56ac 100644 --- a/buffer_test.go +++ b/buffer_test.go @@ -39,6 +39,19 @@ import ( type bufWriterFn func(b *qdb.Buffer) error +type fakeShopspringDecimal struct { + coeff *big.Int + exp int32 +} + +func (f fakeShopspringDecimal) Coefficient() *big.Int { + return f.coeff +} + +func (f fakeShopspringDecimal) Exponent() int32 { + return f.exp +} + func newTestBuffer() qdb.Buffer { return qdb.NewBuffer(128*1024, 1024*1024, 127) } @@ -481,6 +494,240 @@ func TestFloat64ColumnBinary(t *testing.T) { } } +func TestDecimalColumnScaled(t *testing.T) { + negative, err := qdb.NewDecimal(big.NewInt(-12345), 3) + assert.NoError(t, err) + + prefix := []byte(testTable + " price==") + testCases := []struct { + name string + value qdb.ScaledDecimal + expected []byte + }{ + { + name: "positive", + value: qdb.NewDecimalFromInt64(12345, 2), + expected: append(prefix, 0x17, 0x02, 0x02, 0x30, 0x39, 0x0A), + }, + { + name: "negative", + value: negative, + expected: append(prefix, 0x17, 0x03, 0x02, 0xCF, 0xC7, 0x0A), + }, + { + name: "zero with scale", + value: qdb.NewDecimalFromInt64(0, 4), + expected: append(prefix, 0x17, 0x04, 0x01, 0x0, 0x0A), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + buf := newTestBuffer() + err := buf.Table(testTable).DecimalColumnScaled("price", tc.value).At(time.Time{}, false) + assert.NoError(t, err) + assert.Equal(t, tc.expected, buf.Messages()) + }) + } +} + +func TestDecimalColumnScaledTrimmingAndPadding(t *testing.T) { + prefix := []byte(testTable + " price==") + + testCases := []struct { + name string + value qdb.ScaledDecimal + expectedBytes []byte + }{ + { + name: "127 boundary", + value: qdb.NewDecimalFromInt64(127, 0), + expectedBytes: []byte{0x17, 0x00, 0x01, 0x7F}, + }, + { + name: "128 sign extension", + value: qdb.NewDecimalFromInt64(128, 0), + expectedBytes: []byte{0x17, 0x00, 0x02, 0x00, 0x80}, + }, + { + name: "255 sign extension", + value: qdb.NewDecimalFromInt64(255, 0), + expectedBytes: []byte{0x17, 0x00, 0x02, 0x00, 0xFF}, + }, + { + name: "32768 sign extension", + value: qdb.NewDecimalFromInt64(32768, 0), + expectedBytes: []byte{0x17, 0x00, 0x03, 0x00, 0x80, 0x00}, + }, + { + name: "-1", + value: qdb.NewDecimalFromInt64(-1, 0), + expectedBytes: []byte{0x17, 0x00, 0x01, 0xFF}, + }, + { + name: "-2", + value: qdb.NewDecimalFromInt64(-2, 0), + expectedBytes: []byte{0x17, 0x00, 0x01, 0xFE}, + }, + { + name: "-127", + value: qdb.NewDecimalFromInt64(-127, 0), + expectedBytes: []byte{0x17, 0x00, 0x01, 0x81}, + }, + { + name: "-128", + value: qdb.NewDecimalFromInt64(-128, 0), + expectedBytes: []byte{0x17, 0x00, 0x01, 0x80}, + }, + { + name: "-129", + value: qdb.NewDecimalFromInt64(-129, 0), + expectedBytes: []byte{0x17, 0x00, 0x02, 0xFF, 0x7F}, + }, + { + name: "-256 sign extension", + value: qdb.NewDecimalFromInt64(-256, 0), + expectedBytes: []byte{0x17, 0x00, 0x02, 0xFF, 0x00}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + buf := newTestBuffer() + + err := buf.Table(testTable).DecimalColumnScaled("price", tc.value).At(time.Time{}, false) + assert.NoError(t, err) + + expected := append(append([]byte{}, prefix...), tc.expectedBytes...) + expected = append(expected, '\n') + assert.Equal(t, expected, buf.Messages()) + }) + } +} + +func TestDecimalColumnShopspring(t *testing.T) { + prefix := []byte(testTable + " price==") + + testCases := []struct { + name string + value fakeShopspringDecimal + expectedBytes []byte + }{ + { + name: "negative exponent scales value", + value: fakeShopspringDecimal{coeff: big.NewInt(12345), exp: -2}, + expectedBytes: []byte{0x17, 0x02, 0x02, 0x30, 0x39}, + }, + { + name: "zero", + value: fakeShopspringDecimal{coeff: big.NewInt(0), exp: 0}, + expectedBytes: []byte{0x17, 0x00, 0x01, 0x00}, + }, + { + name: "positive exponent multiplies coefficient", + value: fakeShopspringDecimal{coeff: big.NewInt(123), exp: 2}, + expectedBytes: []byte{0x17, 0x00, 0x02, 0x30, 0x0C}, + }, + { + name: "positive value sign extension", + value: fakeShopspringDecimal{coeff: big.NewInt(128), exp: 0}, + expectedBytes: []byte{0x17, 0x00, 0x02, 0x00, 0x80}, + }, + { + name: "negative value sign extension", + value: fakeShopspringDecimal{coeff: big.NewInt(-12345), exp: -3}, + expectedBytes: []byte{0x17, 0x03, 0x02, 0xCF, 0xC7}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + buf := newTestBuffer() + + err := buf.Table(testTable).DecimalColumnShopspring("price", tc.value).At(time.Time{}, false) + assert.NoError(t, err) + + expected := append(append([]byte{}, prefix...), tc.expectedBytes...) + expected = append(expected, '\n') + assert.Equal(t, expected, buf.Messages()) + }) + } +} + +func TestDecimalColumnStringValidation(t *testing.T) { + t.Run("valid strings", func(t *testing.T) { + testCases := []struct { + name string + value string + expected string + }{ + {"integer", "123", "123d"}, + {"decimal", "123.450", "123.450d"}, + {"negative", "-0.001", "-0.001d"}, + {"exponent positive", "1.2e3", "1.2e3d"}, + {"exponent negative", "-4.5E-2", "-4.5E-2d"}, + {"nan token", "NaN", "NaNd"}, + {"infinity token", "Infinity", "Infinityd"}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + buf := newTestBuffer() + err := buf.Table(testTable).DecimalColumnString("price", tc.value).At(time.Time{}, false) + assert.NoError(t, err) + expected := []byte(testTable + " price=" + tc.expected + "\n") + assert.Equal(t, expected, buf.Messages()) + }) + } + }) + + t.Run("invalid strings", func(t *testing.T) { + testCases := []struct { + name string + value string + }{ + {"empty", ""}, + {"sign only", "+"}, + {"double dot", "12.3.4"}, + {"invalid char", "12a3"}, + {"exponent missing mantissa", "e10"}, + {"exponent no digits", "1.2e"}, + {"exponent sign no digits", "1.2e+"}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + buf := newTestBuffer() + err := buf.Table(testTable).DecimalColumnString("price", tc.value).At(time.Time{}, false) + assert.Error(t, err) + assert.Contains(t, err.Error(), "decimal") + assert.Empty(t, buf.Messages()) + }) + } + }) +} + +func TestDecimalColumnErrors(t *testing.T) { + t.Run("invalid scale", func(t *testing.T) { + buf := newTestBuffer() + dec := qdb.NewDecimalFromInt64(1, 100) + err := buf.Table(testTable).DecimalColumnScaled("price", dec).At(time.Time{}, false) + assert.ErrorContains(t, err, "decimal scale") + assert.Empty(t, buf.Messages()) + }) + + t.Run("overflow", func(t *testing.T) { + bigVal := new(big.Int).Lsh(big.NewInt(1), 2100) + _, err := qdb.NewDecimal(bigVal, 0) + assert.ErrorContains(t, err, "exceeds 32 bytes") + }) + + t.Run("no column", func(t *testing.T) { + buf := newTestBuffer() + err := buf.Table(testTable).DecimalColumnShopspring("price", nil).At(time.Time{}, false) + assert.ErrorContains(t, err, "no symbols or columns were provided: invalid message") + assert.Empty(t, buf.Messages()) + }) +} + func TestFloat64Array1DColumn(t *testing.T) { testCases := []struct { name string diff --git a/conf_parse.go b/conf_parse.go index 7957dc51..66e84c40 100644 --- a/conf_parse.go +++ b/conf_parse.go @@ -169,8 +169,8 @@ func confFromStr(conf string) (*lineSenderConfig, error) { return nil, NewInvalidConfigStrError("invalid %s value, %q is not a valid int", k, v) } pVersion := protocolVersion(version) - if pVersion < ProtocolVersion1 || pVersion > ProtocolVersion2 { - return nil, NewInvalidConfigStrError("current client only supports protocol version 1 (text format for all datatypes), 2 (binary format for part datatypes) or explicitly unset") + if pVersion < ProtocolVersion1 || pVersion > ProtocolVersion3 { + return nil, NewInvalidConfigStrError("current client only supports protocol version 1 (text format for all datatypes), 2 (binary format for part datatypes), 3 (decimals) or explicitly unset") } senderConf.protocolVersion = pVersion } diff --git a/decimal.go b/decimal.go new file mode 100644 index 00000000..8653591d --- /dev/null +++ b/decimal.go @@ -0,0 +1,292 @@ +/******************************************************************************* + * ___ _ ____ ____ + * / _ \ _ _ ___ ___| |_| _ \| __ ) + * | | | | | | |/ _ \/ __| __| | | | _ \ + * | |_| | |_| | __/\__ \ |_| |_| | |_) | + * \__\_\\__,_|\___||___/\__|____/|____/ + * + * Copyright (c) 2014-2019 Appsicle + * Copyright (c) 2019-2024 QuestDB + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ******************************************************************************/ + +package questdb + +import ( + "encoding/binary" + "fmt" + "math/big" +) + +const ( + decimalBinaryTypeCode byte = 0x17 + maxDecimalScale uint32 = 76 +) + +// ScaledDecimal represents a decimal value as a two's complement big-endian byte slice and a scale. +// NULL decimals are represented by an offset of 32. +type ScaledDecimal struct { + scale uint32 + unscaled [32]byte + offset uint8 +} + +type ShopspringDecimal interface { + Coefficient() *big.Int + Exponent() int32 +} + +// NewScaledDecimal constructs a decimal from a two's complement big-endian unscaled value and a scale. +// A nil/empty unscaled slice produces a NULL decimal. +func NewScaledDecimal(unscaled []byte, scale uint32) (ScaledDecimal, error) { + if len(unscaled) == 0 { + return ScaledDecimal{ + offset: 32, + }, nil + } + normalized, offset, err := normalizeTwosComplement(unscaled) + if err != nil { + return ScaledDecimal{}, err + } + return ScaledDecimal{ + scale: scale, + unscaled: normalized, + offset: offset, + }, nil +} + +// NewDecimal constructs a decimal from an arbitrary-precision integer and a scale. +// Providing a nil unscaled value produces a NULL decimal. +func NewDecimal(unscaled *big.Int, scale uint32) (ScaledDecimal, error) { + if unscaled == nil { + return ScaledDecimal{ + offset: 32, + }, nil + } + unscaledRaw, offset, err := bigIntToTwosComplement(unscaled) + if err != nil { + return ScaledDecimal{}, err + } + return ScaledDecimal{ + scale: scale, + unscaled: unscaledRaw, + offset: offset, + }, nil +} + +// NewDecimalFromInt64 constructs a decimal from a 64-bit integer and a scale. +func NewDecimalFromInt64(unscaled int64, scale uint32) ScaledDecimal { + var be [8]byte + binary.BigEndian.PutUint64(be[:], uint64(unscaled)) + offset := trimTwosComplement(be[:]) + payload := [32]byte{} + copy(payload[32-(8-offset):], be[offset:]) + return ScaledDecimal{ + scale: scale, + unscaled: payload, + offset: uint8(32 - (8 - offset)), + } +} + +// isNull reports whether the decimal represents NULL. +func (d ScaledDecimal) isNull() bool { + return d.offset >= 32 +} + +func (d ScaledDecimal) ensureValidScale() error { + if d.isNull() { + return nil + } + if d.scale > maxDecimalScale { + return fmt.Errorf("decimal scale %d exceeds maximum %d", d.scale, maxDecimalScale) + } + return nil +} + +func convertShopspringDecimal(value ShopspringDecimal) (ScaledDecimal, error) { + coeff := value.Coefficient() + if coeff == nil { + return ScaledDecimal{ + offset: 32, + }, nil + } + + exp := value.Exponent() + var scale uint32 + var unscaled *big.Int + if exp >= 0 { + unscaled = new(big.Int).Set(coeff) + unscaled.Mul(unscaled, bigPow10(int(exp))) + scale = 0 + } else { + scale = uint32(-exp) + unscaled = new(big.Int).Set(coeff) + } + return NewDecimal(unscaled, scale) +} + +func bigPow10(exponent int) *big.Int { + if exponent <= 0 { + return big.NewInt(1) + } + result := big.NewInt(1) + ten := big.NewInt(10) + for i := 0; i < exponent; i++ { + result.Mul(result, ten) + } + return result +} + +func bigIntToTwosComplement(value *big.Int) ([32]byte, uint8, error) { + if value.Sign() == 0 { + return [32]byte{0}, 31, nil + } + if value.Sign() > 0 { + bytes := value.Bytes() + if bytes[0]&0x80 != 0 { + bytes = append([]byte{0x00}, bytes...) + } + return normalizeTwosComplement(bytes) + } + + bitLen := value.BitLen() + byteLen := (bitLen + 8) / 8 + if byteLen == 0 { + byteLen = 1 + } + + tmp := new(big.Int).Lsh(big.NewInt(1), uint(byteLen*8)) + tmp.Add(tmp, value) // value is negative, so this subtracts magnitude + bytes := tmp.Bytes() + if len(bytes) < int(byteLen) { + padding := make([]byte, int(byteLen)-len(bytes)) + bytes = append(padding, bytes...) + } + + if bytes[0]&0x80 == 0 { + bytes = append([]byte{0xFF}, bytes...) + } + return normalizeTwosComplement(bytes) +} + +// normalizeTwosComplement normalizes a two's complement big-endian byte slice to fit within 32 bytes and returns the normalized value along with the offset to the first significant byte. +func normalizeTwosComplement(src []byte) ([32]byte, uint8, error) { + if len(src) == 0 { + return [32]byte{0}, 32, nil + } + offset := trimTwosComplement(src) + if len(src)-offset > 32 { + return [32]byte{}, 0, fmt.Errorf("decimal unscaled value exceeds 32 bytes") + } + var trimmed [32]byte + copy(trimmed[32-(len(src)-offset):], src[offset:]) + return trimmed, uint8(32 - (len(src) - offset)), nil +} + +// trimTwosComplement removes redundant sign bytes from a two's complement big-endian byte slice and returns the offset to the first significant byte. +func trimTwosComplement(bytes []byte) int { + if len(bytes) <= 1 { + return 0 + } + signBit := bytes[0] & 0x80 + i := 0 + for i < len(bytes)-1 { + if signBit == 0 { + if bytes[i] == 0x00 && bytes[i+1]&0x80 == 0 { + i++ + continue + } + } else { + if bytes[i] == 0xFF && bytes[i+1]&0x80 != 0 { + i++ + continue + } + } + break + } + return i +} + +// validateDecimalText checks that the provided string is a valid decimal representation. +// It accepts numeric digits, optional sign, decimal point, exponent (e/E) and NaN/Infinity tokens. +func validateDecimalText(text string) error { + if text == "" { + return fmt.Errorf("decimal literal cannot be empty") + } + + switch text { + case "NaN", "Infinity", "+Infinity", "-Infinity": + return nil + } + + i := 0 + length := len(text) + if text[0] == '+' || text[0] == '-' { + if length == 1 { + return fmt.Errorf("decimal literal contains sign without digits") + } + i++ + } + + digits := 0 + seenDot := false + for i < length { + ch := text[i] + switch { + case ch >= '0' && ch <= '9': + digits++ + i++ + case ch == '.': + if seenDot { + return fmt.Errorf("decimal literal has multiple decimal points") + } + seenDot = true + i++ + case ch == 'e' || ch == 'E': + if digits == 0 { + return fmt.Errorf("decimal literal exponent without mantissa") + } + i++ + if i >= length { + return fmt.Errorf("decimal literal has incomplete exponent") + } + if text[i] == '+' || text[i] == '-' { + i++ + if i >= length { + return fmt.Errorf("decimal literal has incomplete exponent") + } + } + expDigits := 0 + for i < length && text[i] >= '0' && text[i] <= '9' { + i++ + expDigits++ + } + if expDigits == 0 { + return fmt.Errorf("decimal literal exponent has no digits") + } + if i != length { + return fmt.Errorf("decimal literal has trailing characters") + } + return nil + default: + return fmt.Errorf("decimal literal contains invalid character %q", ch) + } + } + + if digits == 0 { + return fmt.Errorf("decimal literal must contain at least one digit") + } + return nil +} diff --git a/export_test.go b/export_test.go index bcd02a66..21025153 100644 --- a/export_test.go +++ b/export_test.go @@ -63,12 +63,18 @@ func Messages(s LineSender) []byte { if hs, ok := s.(*httpLineSenderV2); ok { return hs.Messages() } + if hs, ok := s.(*httpLineSenderV3); ok { + return hs.Messages() + } if ts, ok := s.(*tcpLineSender); ok { return ts.Messages() } if ts, ok := s.(*tcpLineSenderV2); ok { return ts.Messages() } + if ts, ok := s.(*tcpLineSenderV3); ok { + return ts.Messages() + } panic("unexpected struct") } @@ -82,12 +88,18 @@ func MsgCount(s LineSender) int { if hs, ok := s.(*httpLineSenderV2); ok { return hs.MsgCount() } + if hs, ok := s.(*httpLineSenderV3); ok { + return hs.MsgCount() + } if ts, ok := s.(*tcpLineSender); ok { return ts.MsgCount() } if ts, ok := s.(*tcpLineSenderV2); ok { return ts.MsgCount() } + if ts, ok := s.(*tcpLineSenderV3); ok { + return ts.MsgCount() + } panic("unexpected struct") } @@ -101,12 +113,18 @@ func BufLen(s LineSender) int { if hs, ok := s.(*httpLineSenderV2); ok { return hs.BufLen() } + if hs, ok := s.(*httpLineSenderV3); ok { + return hs.BufLen() + } if ts, ok := s.(*tcpLineSender); ok { return ts.BufLen() } if ts, ok := s.(*tcpLineSenderV2); ok { return ts.BufLen() } + if ts, ok := s.(*tcpLineSenderV3); ok { + return ts.BufLen() + } panic("unexpected struct") } @@ -120,12 +138,18 @@ func ProtocolVersion(s LineSender) protocolVersion { if _, ok := s.(*httpLineSenderV2); ok { return ProtocolVersion2 } + if _, ok := s.(*httpLineSenderV3); ok { + return ProtocolVersion3 + } if _, ok := s.(*tcpLineSender); ok { return ProtocolVersion1 } if _, ok := s.(*tcpLineSenderV2); ok { return ProtocolVersion2 } + if _, ok := s.(*tcpLineSenderV3); ok { + return ProtocolVersion3 + } panic("unexpected struct") } diff --git a/http_sender.go b/http_sender.go index 7d070fc4..e2900932 100644 --- a/http_sender.go +++ b/http_sender.go @@ -119,6 +119,10 @@ type httpLineSenderV2 struct { httpLineSender } +type httpLineSenderV3 struct { + httpLineSenderV2 +} + func newHttpLineSender(ctx context.Context, conf *lineSenderConfig) (LineSender, error) { var transport *http.Transport s := &httpLineSender{ @@ -175,12 +179,21 @@ func newHttpLineSender(ctx context.Context, conf *lineSenderConfig) (LineSender, } s.uri += fmt.Sprintf("://%s/write", s.address) - if pVersion == ProtocolVersion1 { + switch pVersion { + case ProtocolVersion1: return s, nil - } else { + case ProtocolVersion2: return &httpLineSenderV2{ *s, }, nil + case ProtocolVersion3: + return &httpLineSenderV3{ + httpLineSenderV2{ + *s, + }, + }, nil + default: + return nil, fmt.Errorf("unsupported protocol version %d", pVersion) } } @@ -292,6 +305,21 @@ func (s *httpLineSender) Float64Column(name string, val float64) LineSender { return s } +func (s *httpLineSender) DecimalColumnString(name string, val string) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) + return s +} + +func (s *httpLineSender) DecimalColumnShopspring(name string, val ShopspringDecimal) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) + return s +} + +func (s *httpLineSender) DecimalColumnScaled(name string, val ScaledDecimal) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) + return s +} + func (s *httpLineSender) StringColumn(name, val string) LineSender { s.buf.StringColumn(name, val) return s @@ -509,16 +537,22 @@ func parseServerSettings(resp *http.Response, conf *lineSenderConfig) (protocolV return ProtocolVersion1, nil } - hasProtocolVersion1 := false + hasProtocol1 := false + hasProtocol2 := false for _, version := range versions { - if version == 2 { - return ProtocolVersion2, nil - } - if version == 1 { - hasProtocolVersion1 = true + switch version { + case 3: + return ProtocolVersion3, nil + case 2: + hasProtocol2 = true + case 1: + hasProtocol1 = true } } - if hasProtocolVersion1 { + if hasProtocol2 { + return ProtocolVersion2, nil + } + if hasProtocol1 { return ProtocolVersion1, nil } @@ -617,3 +651,93 @@ func (s *httpLineSenderV2) Float64ArrayNDColumn(name string, values *NdArray[flo s.buf.Float64ArrayNDColumn(name, values) return s } + +func (s *httpLineSenderV2) DecimalColumnString(name string, val string) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) + return s +} + +func (s *httpLineSenderV2) DecimalColumnShopspring(name string, val ShopspringDecimal) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) + return s +} + +func (s *httpLineSenderV2) DecimalColumnScaled(name string, val ScaledDecimal) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) + return s +} + +func (s *httpLineSenderV3) Table(name string) LineSender { + s.buf.Table(name) + return s +} + +func (s *httpLineSenderV3) Symbol(name, val string) LineSender { + s.buf.Symbol(name, val) + return s +} + +func (s *httpLineSenderV3) Int64Column(name string, val int64) LineSender { + s.buf.Int64Column(name, val) + return s +} + +func (s *httpLineSenderV3) Long256Column(name string, val *big.Int) LineSender { + s.buf.Long256Column(name, val) + return s +} + +func (s *httpLineSenderV3) TimestampColumn(name string, ts time.Time) LineSender { + s.buf.TimestampColumn(name, ts) + return s +} + +func (s *httpLineSenderV3) StringColumn(name, val string) LineSender { + s.buf.StringColumn(name, val) + return s +} + +func (s *httpLineSenderV3) BoolColumn(name string, val bool) LineSender { + s.buf.BoolColumn(name, val) + return s +} + +func (s *httpLineSenderV3) Float64Column(name string, val float64) LineSender { + s.buf.Float64ColumnBinary(name, val) + return s +} + +func (s *httpLineSenderV3) Float64Array1DColumn(name string, values []float64) LineSender { + s.buf.Float64Array1DColumn(name, values) + return s +} + +func (s *httpLineSenderV3) Float64Array2DColumn(name string, values [][]float64) LineSender { + s.buf.Float64Array2DColumn(name, values) + return s +} + +func (s *httpLineSenderV3) Float64Array3DColumn(name string, values [][][]float64) LineSender { + s.buf.Float64Array3DColumn(name, values) + return s +} + +func (s *httpLineSenderV3) Float64ArrayNDColumn(name string, values *NdArray[float64]) LineSender { + s.buf.Float64ArrayNDColumn(name, values) + return s +} + +func (s *httpLineSenderV3) DecimalColumnScaled(name string, val ScaledDecimal) LineSender { + s.buf.DecimalColumnScaled(name, val) + return s +} + +func (s *httpLineSenderV3) DecimalColumnString(name string, val string) LineSender { + s.buf.DecimalColumnString(name, val) + return s +} + +func (s *httpLineSenderV3) DecimalColumnShopspring(name string, val ShopspringDecimal) LineSender { + s.buf.DecimalColumnShopspring(name, val) + return s +} diff --git a/http_sender_test.go b/http_sender_test.go index 1c93c9a2..963a8c66 100644 --- a/http_sender_test.go +++ b/http_sender_test.go @@ -827,7 +827,7 @@ func TestAutoDetectProtocolVersionNewServer2(t *testing.T) { func TestAutoDetectProtocolVersionNewServer3(t *testing.T) { ctx := context.Background() - srv, err := newTestServerWithProtocol(readAndDiscard, "http", []int{2, 3}) + srv, err := newTestServerWithProtocol(readAndDiscard, "http", []int{2, 4}) assert.NoError(t, err) defer srv.Close() sender, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr())) @@ -838,13 +838,24 @@ func TestAutoDetectProtocolVersionNewServer3(t *testing.T) { func TestAutoDetectProtocolVersionNewServer4(t *testing.T) { ctx := context.Background() - srv, err := newTestServerWithProtocol(readAndDiscard, "http", []int{3}) + srv, err := newTestServerWithProtocol(readAndDiscard, "http", []int{4}) assert.NoError(t, err) defer srv.Close() _, err = qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr())) assert.ErrorContains(t, err, "server does not support current client") } +func TestAutoDetectProtocolVersionNewServer5(t *testing.T) { + ctx := context.Background() + + srv, err := newTestServerWithProtocol(readAndDiscard, "http", []int{2, 3}) + assert.NoError(t, err) + defer srv.Close() + sender, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr())) + assert.Equal(t, qdb.ProtocolVersion(sender), qdb.ProtocolVersion3) + assert.NoError(t, err) +} + func TestAutoDetectProtocolVersionError(t *testing.T) { ctx := context.Background() @@ -913,6 +924,25 @@ func TestArrayColumnUnsupportedInHttpProtocolV1(t *testing.T) { assert.Contains(t, err.Error(), "current protocol version does not support double-array") } +func TestDecimalColumnUnsupportedInHttpProtocolV2(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(50*time.Millisecond)) + defer cancel() + + srv, err := newTestServerWithProtocol(readAndDiscard, "http", []int{2}) + assert.NoError(t, err) + defer srv.Close() + sender, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr())) + assert.NoError(t, err) + defer sender.Close(ctx) + + err = sender. + Table(testTable). + DecimalColumnString("price", "12.99"). + At(ctx, time.UnixMicro(1)) + assert.Error(t, err) + assert.Contains(t, err.Error(), "current protocol version does not support decimal") +} + func BenchmarkHttpLineSenderBatch1000(b *testing.B) { ctx := context.Background() diff --git a/integration_test.go b/integration_test.go index 37eb7b17..955cd5f5 100644 --- a/integration_test.go +++ b/integration_test.go @@ -31,6 +31,7 @@ import ( "math/big" "path/filepath" "reflect" + "strings" "testing" "time" @@ -492,47 +493,90 @@ func (suite *integrationTestSuite) TestE2EValidWrites() { Count: 3, }, }, + { + "decimal type", + testTable, + func(s qdb.LineSender) error { + err := s. + Table(testTable). + DecimalColumnString("text_col", "123.45"). + DecimalColumnScaled("binary_col", qdb.NewDecimalFromInt64(12345, 2)). + DecimalColumnScaled("binary_neg_col", qdb.NewDecimalFromInt64(-12345, 2)). + DecimalColumnShopspring("binary_null_col", nil). + At(ctx, time.UnixMicro(1)) + if err != nil { + return err + } + + return s. + Table(testTable). + DecimalColumnString("text_col", "123.46"). + DecimalColumnScaled("binary_col", qdb.NewDecimalFromInt64(12346, 2)). + DecimalColumnScaled("binary_neg_col", qdb.NewDecimalFromInt64(-12346, 2)). + DecimalColumnShopspring("binary_null_col", nil). + At(ctx, time.UnixMicro(2)) + }, + tableData{ + Columns: []column{ + {"text_col", "DECIMAL(18,3)"}, + {"binary_col", "DECIMAL(18,3)"}, + {"binary_neg_col", "DECIMAL(18,3)"}, + {"timestamp", "TIMESTAMP"}, + }, + Dataset: [][]any{ + {"123.450", "123.450", "-123.450", "1970-01-01T00:00:00.000001Z"}, + {"123.460", "123.460", "-123.460", "1970-01-01T00:00:00.000002Z"}, + }, + Count: 2, + }, + }, } for _, tc := range testCases { for _, protocol := range []string{"tcp", "http"} { - for _, pVersion := range []int{0, 1, 2} { + for _, pVersion := range []int{0, 1, 2, 3} { suite.T().Run(fmt.Sprintf("%s: %s", tc.name, protocol), func(t *testing.T) { var ( sender qdb.LineSender err error ) - ignoreArray := false questdbC, err := setupQuestDB(ctx, noAuth) assert.NoError(t, err) switch protocol { case "tcp": - if pVersion == 0 { + switch pVersion { + case 0: sender, err = qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(questdbC.ilpAddress)) - ignoreArray = true - } else if pVersion == 1 { + case 1: sender, err = qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(questdbC.ilpAddress), qdb.WithProtocolVersion(qdb.ProtocolVersion1)) - ignoreArray = true - } else if pVersion == 2 { + case 2: sender, err = qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(questdbC.ilpAddress), qdb.WithProtocolVersion(qdb.ProtocolVersion2)) + case 3: + sender, err = qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(questdbC.ilpAddress), qdb.WithProtocolVersion(qdb.ProtocolVersion3)) } assert.NoError(t, err) case "http": - if pVersion == 0 { + switch pVersion { + case 0: sender, err = qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(questdbC.httpAddress)) - } else if pVersion == 1 { + case 1: sender, err = qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(questdbC.httpAddress), qdb.WithProtocolVersion(qdb.ProtocolVersion1)) - ignoreArray = true - } else if pVersion == 2 { + case 2: sender, err = qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(questdbC.httpAddress), qdb.WithProtocolVersion(qdb.ProtocolVersion2)) + case 3: + sender, err = qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(questdbC.httpAddress), qdb.WithProtocolVersion(qdb.ProtocolVersion3)) } assert.NoError(t, err) default: panic(protocol) } - if ignoreArray && tc.name == "double array" { + senderVersion := qdb.ProtocolVersion(sender) + if senderVersion < 2 && tc.name == "double array" { + return + } + if senderVersion < 3 && strings.Contains(tc.name, "decimal") { return } diff --git a/interop_test.go b/interop_test.go index 78f8ab4e..97d7bd3d 100644 --- a/interop_test.go +++ b/interop_test.go @@ -27,8 +27,10 @@ package questdb_test import ( "context" "encoding/json" + "fmt" "io" "os" + "strconv" "strings" "testing" @@ -60,8 +62,10 @@ type testColumn struct { } type testResult struct { - Status string `json:"status"` - Line string `json:"line"` + Status string `json:"status"` + Line string `json:"line"` + AnyLines []string `json:"anyLines"` + BinaryBase64 string `json:"binaryBase64"` } func TestTcpClientInterop(t *testing.T) { @@ -75,44 +79,9 @@ func TestTcpClientInterop(t *testing.T) { srv, err := newTestTcpServer(sendToBackChannel) assert.NoError(t, err) - sender, err := qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(srv.Addr()), qdb.WithProtocolVersion(qdb.ProtocolVersion1)) + sender, err := qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(srv.Addr()), qdb.WithProtocolVersion(qdb.ProtocolVersion3)) assert.NoError(t, err) - - sender.Table(tc.Table) - for _, s := range tc.Symbols { - sender.Symbol(s.Name, s.Value) - } - for _, s := range tc.Columns { - switch s.Type { - case "LONG": - sender.Int64Column(s.Name, int64(s.Value.(float64))) - case "DOUBLE": - sender.Float64Column(s.Name, s.Value.(float64)) - case "STRING": - sender.StringColumn(s.Name, s.Value.(string)) - case "BOOLEAN": - sender.BoolColumn(s.Name, s.Value.(bool)) - default: - assert.Fail(t, "unexpected column type: "+s.Type) - } - } - - err = sender.AtNow(ctx) - - switch tc.Result.Status { - case "SUCCESS": - assert.NoError(t, err) - err = sender.Flush(ctx) - assert.NoError(t, err) - - expectLines(t, srv.BackCh, strings.Split(tc.Result.Line, "\n")) - case "ERROR": - assert.Error(t, err) - default: - assert.Fail(t, "unexpected test status: "+tc.Result.Status) - } - - sender.Close(ctx) + execute(t, ctx, sender, srv.BackCh, tc) srv.Close() }) } @@ -129,49 +98,92 @@ func TestHttpClientInterop(t *testing.T) { srv, err := newTestHttpServer(sendToBackChannel) assert.NoError(t, err) - sender, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr()), qdb.WithProtocolVersion(qdb.ProtocolVersion1)) + sender, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr()), qdb.WithProtocolVersion(qdb.ProtocolVersion3)) assert.NoError(t, err) - - sender.Table(tc.Table) - for _, s := range tc.Symbols { - sender.Symbol(s.Name, s.Value) - } - for _, s := range tc.Columns { - switch s.Type { - case "LONG": - sender.Int64Column(s.Name, int64(s.Value.(float64))) - case "DOUBLE": - sender.Float64Column(s.Name, s.Value.(float64)) - case "STRING": - sender.StringColumn(s.Name, s.Value.(string)) - case "BOOLEAN": - sender.BoolColumn(s.Name, s.Value.(bool)) - default: - assert.Fail(t, "unexpected column type: "+s.Type) - } - } - - err = sender.AtNow(ctx) - - switch tc.Result.Status { - case "SUCCESS": - assert.NoError(t, err) - err = sender.Flush(ctx) - assert.NoError(t, err) - - expectLines(t, srv.BackCh, strings.Split(tc.Result.Line, "\n")) - case "ERROR": - assert.Error(t, err) - default: - assert.Fail(t, "unexpected test status: "+tc.Result.Status) - } - - sender.Close(ctx) + execute(t, ctx, sender, srv.BackCh, tc) srv.Close() }) } } +func execute(t *testing.T, ctx context.Context, sender qdb.LineSender, backCh chan string, tc testCase) { + sender.Table(tc.Table) + for _, s := range tc.Symbols { + sender.Symbol(s.Name, s.Value) + } + for _, s := range tc.Columns { + switch s.Type { + case "LONG": + sender.Int64Column(s.Name, int64(s.Value.(float64))) + case "DOUBLE": + sender.Float64Column(s.Name, s.Value.(float64)) + case "STRING": + sender.StringColumn(s.Name, s.Value.(string)) + case "BOOLEAN": + sender.BoolColumn(s.Name, s.Value.(bool)) + case "DECIMAL": + dec, err := parseDecimal64(s.Value.(string)) + assert.NoError(t, err) + sender.DecimalColumnScaled(s.Name, dec) + default: + assert.Fail(t, "unexpected column type: "+s.Type) + } + } + + err := sender.AtNow(ctx) + + switch tc.Result.Status { + case "SUCCESS": + assert.NoError(t, err) + err = sender.Flush(ctx) + assert.NoError(t, err) + + if len(tc.Result.BinaryBase64) > 0 { + expectBinaryBase64(t, backCh, tc.Result.BinaryBase64) + } else if len(tc.Result.AnyLines) > 0 { + expectAnyLines(t, backCh, tc.Result.AnyLines) + } else { + expectLines(t, backCh, strings.Split(tc.Result.Line, "\n")) + } + case "ERROR": + assert.Error(t, err) + default: + assert.Fail(t, "unexpected test status: "+tc.Result.Status) + } + + sender.Close(ctx) +} + +// parseDecimal64 quick and dirty parser for a decimal64 value from its string representation +func parseDecimal64(s string) (qdb.ScaledDecimal, error) { + // Remove whitespace + s = strings.TrimSpace(s) + + // Check for empty string + if s == "" { + return qdb.ScaledDecimal{}, fmt.Errorf("empty string") + } + + // Find the decimal point and remove it + pointIndex := strings.Index(s, ".") + if pointIndex != -1 { + s = strings.ReplaceAll(s, ".", "") + } + + // Parse the integer part + unscaled, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return qdb.ScaledDecimal{}, err + } + + scale := 0 + if pointIndex != -1 { + scale = len(s) - pointIndex + } + + return qdb.NewDecimalFromInt64(unscaled, uint32(scale)), nil +} + func readTestCases() (testCases, error) { file, err := os.Open("./test/interop/questdb-client-test/ilp-client-interop-test.json") if err != nil { diff --git a/sender.go b/sender.go index b6f38f28..f4e4abf8 100644 --- a/sender.go +++ b/sender.go @@ -40,6 +40,7 @@ var ( errFlushWithPendingMessage = errors.New("pending ILP message must be finalized with At or AtNow before calling Flush") errClosedSenderAt = errors.New("cannot queue new messages on a closed LineSender") errDoubleSenderClose = errors.New("double sender close") + errDecimalNotSupported = errors.New("current protocol version does not support decimal") ) // LineSender allows you to insert rows into QuestDB by sending ILP @@ -106,6 +107,33 @@ type LineSender interface { // '-', '*' '%%', '~', or a non-printable char. Float64Column(name string, val float64) LineSender + // DecimalColumn adds a decimal column value to the ILP message. + // + // Serializes the decimal value using the text representation. + // + // Column name cannot contain any of the following characters: + // '\n', '\r', '?', '.', ',', ”', '"', '\', '/', ':', ')', '(', '+', + // '-', '*' '%%', '~', or a non-printable char. + DecimalColumnString(name string, val string) LineSender + + // DecimalColumnScaled adds a decimal column value to the ILP message. + // + // Serializes the decimal value using the binary representation. + // + // Column name cannot contain any of the following characters: + // '\n', '\r', '?', '.', ',', ”', '"', '\', '/', ':', ')', '(', '+', + // '-', '*' '%%', '~', or a non-printable char. + DecimalColumnScaled(name string, val ScaledDecimal) LineSender + + // DecimalColumnShopspring adds a decimal column value to the ILP message. + // + // Serializes the decimal value using the binary representation. + // + // Column name cannot contain any of the following characters: + // '\n', '\r', '?', '.', ',', ”', '"', '\', '/', ':', ')', '(', '+', + // '-', '*' '%%', '~', or a non-printable char. + DecimalColumnShopspring(name string, val ShopspringDecimal) LineSender + // StringColumn adds a string column value to the ILP message. // // Column name cannot contain any of the following characters: @@ -253,6 +281,7 @@ const ( protocolVersionUnset protocolVersion = 0 ProtocolVersion1 protocolVersion = 1 ProtocolVersion2 protocolVersion = 2 + ProtocolVersion3 protocolVersion = 3 ) type lineSenderConfig struct { @@ -479,8 +508,10 @@ func WithAutoFlushInterval(interval time.Duration) LineSenderOption { // - TCP transport does not negotiate the protocol version and uses [ProtocolVersion1] by // default. You must explicitly set [ProtocolVersion2] in order to ingest // arrays. +// - [ProtocolVersion3] enables decimal binary encoding (ILP v3). // // NOTE: QuestDB server version 9.0.0 or later is required for [ProtocolVersion2]. +// For [ProtocolVersion3], make sure the server advertises ILP v3 support via /settings. func WithProtocolVersion(version protocolVersion) LineSenderOption { return func(s *lineSenderConfig) { s.protocolVersion = version @@ -721,9 +752,9 @@ func validateConf(conf *lineSenderConfig) error { if conf.autoFlushInterval < 0 { return fmt.Errorf("auto flush interval is negative: %d", conf.autoFlushInterval) } - if conf.protocolVersion < protocolVersionUnset || conf.protocolVersion > ProtocolVersion2 { - return errors.New("current client only supports protocol version 1(text format for all datatypes), " + - "2(binary format for part datatypes) or explicitly unset") + if conf.protocolVersion < protocolVersionUnset || conf.protocolVersion > ProtocolVersion3 { + return errors.New("current client only supports protocol version 1 (text format for all datatypes), " + + "2 (binary format for floats/arrays), 3 (binary decimals) or explicitly unset") } return nil diff --git a/sender_pool.go b/sender_pool.go index 7da85726..eb71332d 100644 --- a/sender_pool.go +++ b/sender_pool.go @@ -314,6 +314,21 @@ func (ps *pooledSender) Float64Column(name string, val float64) LineSender { return ps } +func (ps *pooledSender) DecimalColumnString(name string, val string) LineSender { + ps.wrapped.DecimalColumnString(name, val) + return ps +} + +func (ps *pooledSender) DecimalColumnShopspring(name string, val ShopspringDecimal) LineSender { + ps.wrapped.DecimalColumnShopspring(name, val) + return ps +} + +func (ps *pooledSender) DecimalColumnScaled(name string, val ScaledDecimal) LineSender { + ps.wrapped.DecimalColumnScaled(name, val) + return ps +} + func (ps *pooledSender) StringColumn(name, val string) LineSender { ps.wrapped.StringColumn(name, val) return ps diff --git a/tcp_sender.go b/tcp_sender.go index f46c5de1..ac4e25be 100644 --- a/tcp_sender.go +++ b/tcp_sender.go @@ -50,6 +50,10 @@ type tcpLineSenderV2 struct { tcpLineSender } +type tcpLineSenderV3 struct { + tcpLineSenderV2 +} + func newTcpLineSender(ctx context.Context, conf *lineSenderConfig) (LineSender, error) { var ( d net.Dialer @@ -136,12 +140,27 @@ func newTcpLineSender(ctx context.Context, conf *lineSenderConfig) (LineSender, s.conn = conn - if conf.protocolVersion == protocolVersionUnset || conf.protocolVersion == ProtocolVersion1 { + pVersion := conf.protocolVersion + if pVersion == protocolVersionUnset { + pVersion = ProtocolVersion1 + } + + switch pVersion { + case ProtocolVersion1: return s, nil - } else { + case ProtocolVersion2: return &tcpLineSenderV2{ *s, }, nil + case ProtocolVersion3: + return &tcpLineSenderV3{ + tcpLineSenderV2{ + *s, + }, + }, nil + default: + conn.Close() + return nil, fmt.Errorf("unsupported protocol version %d", pVersion) } } @@ -184,6 +203,21 @@ func (s *tcpLineSender) Float64Column(name string, val float64) LineSender { return s } +func (s *tcpLineSender) DecimalColumnString(name string, val string) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) + return s +} + +func (s *tcpLineSender) DecimalColumnShopspring(name string, val ShopspringDecimal) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) + return s +} + +func (s *tcpLineSender) DecimalColumnScaled(name string, val ScaledDecimal) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) + return s +} + func (s *tcpLineSender) StringColumn(name, val string) LineSender { s.buf.StringColumn(name, val) return s @@ -346,3 +380,93 @@ func (s *tcpLineSenderV2) Float64ArrayNDColumn(name string, values *NdArray[floa s.buf.Float64ArrayNDColumn(name, values) return s } + +func (s *tcpLineSenderV2) DecimalColumnString(name string, val string) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) + return s +} + +func (s *tcpLineSenderV2) DecimalColumnShopspring(name string, val ShopspringDecimal) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) + return s +} + +func (s *tcpLineSenderV2) DecimalColumnScaled(name string, val ScaledDecimal) LineSender { + s.buf.SetLastErr(errDecimalNotSupported) + return s +} + +func (s *tcpLineSenderV3) Table(name string) LineSender { + s.buf.Table(name) + return s +} + +func (s *tcpLineSenderV3) Symbol(name, val string) LineSender { + s.buf.Symbol(name, val) + return s +} + +func (s *tcpLineSenderV3) Int64Column(name string, val int64) LineSender { + s.buf.Int64Column(name, val) + return s +} + +func (s *tcpLineSenderV3) Long256Column(name string, val *big.Int) LineSender { + s.buf.Long256Column(name, val) + return s +} + +func (s *tcpLineSenderV3) TimestampColumn(name string, ts time.Time) LineSender { + s.buf.TimestampColumn(name, ts) + return s +} + +func (s *tcpLineSenderV3) StringColumn(name, val string) LineSender { + s.buf.StringColumn(name, val) + return s +} + +func (s *tcpLineSenderV3) BoolColumn(name string, val bool) LineSender { + s.buf.BoolColumn(name, val) + return s +} + +func (s *tcpLineSenderV3) Float64Column(name string, val float64) LineSender { + s.buf.Float64ColumnBinary(name, val) + return s +} + +func (s *tcpLineSenderV3) Float64Array1DColumn(name string, values []float64) LineSender { + s.buf.Float64Array1DColumn(name, values) + return s +} + +func (s *tcpLineSenderV3) Float64Array2DColumn(name string, values [][]float64) LineSender { + s.buf.Float64Array2DColumn(name, values) + return s +} + +func (s *tcpLineSenderV3) Float64Array3DColumn(name string, values [][][]float64) LineSender { + s.buf.Float64Array3DColumn(name, values) + return s +} + +func (s *tcpLineSenderV3) Float64ArrayNDColumn(name string, values *NdArray[float64]) LineSender { + s.buf.Float64ArrayNDColumn(name, values) + return s +} + +func (s *tcpLineSenderV3) DecimalColumnString(name string, val string) LineSender { + s.buf.DecimalColumnString(name, val) + return s +} + +func (s *tcpLineSenderV3) DecimalColumnScaled(name string, val ScaledDecimal) LineSender { + s.buf.DecimalColumnScaled(name, val) + return s +} + +func (s *tcpLineSenderV3) DecimalColumnShopspring(name string, val ShopspringDecimal) LineSender { + s.buf.DecimalColumnShopspring(name, val) + return s +} diff --git a/tcp_sender_test.go b/tcp_sender_test.go index ddc496e1..cf1f57bf 100644 --- a/tcp_sender_test.go +++ b/tcp_sender_test.go @@ -366,6 +366,25 @@ func TestArrayColumnUnsupportedInTCPProtocolV1(t *testing.T) { assert.Contains(t, err.Error(), "current protocol version does not support double-array") } +func TestDecimalColumnUnsupportedInTCPProtocolV2(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(50*time.Millisecond)) + defer cancel() + + srv, err := newTestTcpServer(readAndDiscard) + assert.NoError(t, err) + defer srv.Close() + sender, err := qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(srv.Addr()), qdb.WithProtocolVersion(qdb.ProtocolVersion2)) + assert.NoError(t, err) + defer sender.Close(ctx) + + err = sender. + Table(testTable). + DecimalColumnString("price", "12.99"). + At(ctx, time.UnixMicro(1)) + assert.Error(t, err) + assert.Contains(t, err.Error(), "current protocol version does not support decimal") +} + func BenchmarkLineSenderBatch1000(b *testing.B) { ctx := context.Background() diff --git a/test/interop/questdb-client-test b/test/interop/questdb-client-test index 42a30831..1aaa3f96 160000 --- a/test/interop/questdb-client-test +++ b/test/interop/questdb-client-test @@ -1 +1 @@ -Subproject commit 42a30831f1852ed0ac85ab466f5d9ad711b787ee +Subproject commit 1aaa3f96ab06c6bef7d1b08400c418ef87562e56 diff --git a/utils_test.go b/utils_test.go index d37cf5d5..de532ec5 100644 --- a/utils_test.go +++ b/utils_test.go @@ -26,6 +26,7 @@ package questdb_test import ( "bufio" + "encoding/base64" "encoding/json" "fmt" "io" @@ -39,6 +40,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "golang.org/x/exp/slices" ) type serverType int64 @@ -72,7 +74,7 @@ func newTestTcpServer(serverType serverType) (*testServer, error) { } func newTestHttpServer(serverType serverType) (*testServer, error) { - return newTestServerWithProtocol(serverType, "http", []int{1, 2}) + return newTestServerWithProtocol(serverType, "http", []int{1, 2, 3}) } func newTestHttpServerWithErrMsg(serverType serverType, errMsg string) (*testServer, error) { @@ -341,3 +343,30 @@ func expectLines(t *testing.T, linesCh chan string, expected []string) { return reflect.DeepEqual(expected, actual) }, 10*time.Second, 100*time.Millisecond) } + +func expectAnyLines(t *testing.T, linesCh chan string, expected []string) { + assert.Eventually(t, func() bool { + select { + case l := <-linesCh: + return slices.Contains(expected, l) + default: + return false + } + }, 10*time.Second, 100*time.Millisecond) +} + +func expectBinaryBase64(t *testing.T, linesCh chan string, expected string) { + data, err := base64.StdEncoding.DecodeString(expected) + assert.NoError(t, err) + + actual := make([]byte, 0) + assert.Eventually(t, func() bool { + select { + case l := <-linesCh: + actual = append(actual, []byte(l+"\n")...) + default: + return false + } + return slices.Equal(data, actual) + }, 10*time.Second, 100*time.Millisecond) +}