Skip to content

Commit

Permalink
feat: add bfloat16 support
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Jul 22, 2022
1 parent 1e8ca47 commit 510b9ca
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 69 deletions.
25 changes: 25 additions & 0 deletions ch/bfloat16/bfloat16.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package bfloat16

import (
"math"
)

type Map map[T]uint64

type T uint16

func From(f float64) T {
return FromFloat32(float32(f))
}

func FromFloat32(f float32) T {
return T(math.Float32bits(f) >> 16)
}

func (f T) Float32() float32 {
return math.Float32frombits(uint32(f) << 16)
}

func (f T) Float64() float64 {
return float64(f.Float32())
}
66 changes: 66 additions & 0 deletions ch/chschema/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"strings"
"time"

"github.com/uptrace/go-clickhouse/ch/bfloat16"
"github.com/uptrace/go-clickhouse/ch/chproto"
"github.com/uptrace/go-clickhouse/ch/internal"

Expand Down Expand Up @@ -1164,3 +1165,68 @@ func newLCKeyType(typ int64) lcKey {
panic("not reached")
}
}

//------------------------------------------------------------------------------

type BFloat16HistColumn struct {
ColumnOf[bfloat16.Map]
}

var _ Columnar = (*BFloat16HistColumn)(nil)

func NewBFloat16HistColumn(typ reflect.Type, chType string, numRow int) Columnar {
return &BFloat16HistColumn{
ColumnOf: NewColumnOf[bfloat16.Map](numRow),
}
}

func (c BFloat16HistColumn) Type() reflect.Type {
return bfloat16MapType
}

func (c *BFloat16HistColumn) ReadFrom(rd *chproto.Reader, numRow int) error {
if numRow == 0 {
return nil
}

c.Alloc(numRow)

for i := range c.Column {
n, err := rd.Uvarint()
if err != nil {
return err
}

data := make(bfloat16.Map, n)

for j := 0; j < int(n); j++ {
value, err := rd.UInt16()
if err != nil {
return err
}

count, err := rd.UInt64()
if err != nil {
return err
}

data[bfloat16.T(value)] = count
}

c.Column[i] = data
}

return nil
}

func (c BFloat16HistColumn) WriteTo(wr *chproto.Writer) error {
for _, m := range c.Column {
wr.Uvarint(uint64(len(m)))

for k, v := range m {
wr.UInt16(uint16(k))
wr.UInt64(v)
}
}
return nil
}
120 changes: 69 additions & 51 deletions ch/chschema/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strings"
"time"

"github.com/uptrace/go-clickhouse/ch/bfloat16"
"github.com/uptrace/go-clickhouse/ch/chtype"
"github.com/uptrace/go-clickhouse/ch/internal"
)
Expand Down Expand Up @@ -41,50 +42,6 @@ var chType = [...]string{
reflect.UnsafePointer: "",
}

// keep in sync with ColumnFactory
func clickhouseType(typ reflect.Type) string {
switch typ {
case timeType:
return chtype.DateTime
case ipType:
return chtype.IPv6
}

kind := typ.Kind()
switch kind {
case reflect.Ptr:
if typ.Elem().Kind() == reflect.Struct {
return chtype.String
}
return fmt.Sprintf("Nullable(%s)", clickhouseType(typ.Elem()))
case reflect.Slice:
switch elem := typ.Elem(); elem.Kind() {
case reflect.Ptr:
if elem.Elem().Kind() == reflect.Struct {
return chtype.String // json
}
case reflect.Struct:
if elem != timeType {
return chtype.String // json
}
case reflect.Uint8:
return chtype.String // []byte
}

return "Array(" + clickhouseType(typ.Elem()) + ")"
case reflect.Array:
if isUUID(typ) {
return chtype.UUID
}
}

if s := chType[kind]; s != "" {
return s
}

panic(fmt.Errorf("ch: unsupported Go type: %s", typ))
}

type NewColumnFunc func(typ reflect.Type, chType string, numRow int) Columnar

var kindToColumn = [...]NewColumnFunc{
Expand Down Expand Up @@ -141,6 +98,13 @@ func ColumnFactory(typ reflect.Type, chType string) NewColumnFunc {
chType = chSubType(chType, "SimpleAggregateFunction(")
} else if s := dateTimeType(chType); s != "" {
chType = s
} else if funcName, _ := aggFuncNameAndType(chType); funcName != "" {
switch funcName {
case "quantileBFloat16", "quantilesBFloat16":
return NewBFloat16HistColumn
default:
panic(fmt.Errorf("unsupported ClickHouse type: %s", chType))
}
}

switch typ {
Expand Down Expand Up @@ -270,12 +234,13 @@ var (
float32Type = reflect.TypeOf(float32(0))
float64Type = reflect.TypeOf(float64(0))

stringType = reflect.TypeOf("")
bytesType = reflect.TypeOf((*[]byte)(nil)).Elem()
uuidType = reflect.TypeOf((*UUID)(nil)).Elem()
timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
ipType = reflect.TypeOf((*net.IP)(nil)).Elem()
ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem()
stringType = reflect.TypeOf("")
bytesType = reflect.TypeOf((*[]byte)(nil)).Elem()
uuidType = reflect.TypeOf((*UUID)(nil)).Elem()
timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
ipType = reflect.TypeOf((*net.IP)(nil)).Elem()
ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem()
bfloat16MapType = reflect.TypeOf((*bfloat16.Map)(nil)).Elem()

int64SliceType = reflect.TypeOf((*[]int64)(nil)).Elem()
uint64SliceType = reflect.TypeOf((*[]uint64)(nil)).Elem()
Expand Down Expand Up @@ -342,6 +307,51 @@ func goType(chType string) reflect.Type {
panic(fmt.Errorf("unsupported ClickHouse type=%q", chType))
}

// clickhouseType returns ClickHouse type for the given Go type.
// Keep in sync with ColumnFactory.
func clickhouseType(typ reflect.Type) string {
switch typ {
case timeType:
return chtype.DateTime
case ipType:
return chtype.IPv6
}

kind := typ.Kind()
switch kind {
case reflect.Ptr:
if typ.Elem().Kind() == reflect.Struct {
return chtype.String
}
return fmt.Sprintf("Nullable(%s)", clickhouseType(typ.Elem()))
case reflect.Slice:
switch elem := typ.Elem(); elem.Kind() {
case reflect.Ptr:
if elem.Elem().Kind() == reflect.Struct {
return chtype.String // json
}
case reflect.Struct:
if elem != timeType {
return chtype.String // json
}
case reflect.Uint8:
return chtype.String // []byte
}

return "Array(" + clickhouseType(typ.Elem()) + ")"
case reflect.Array:
if isUUID(typ) {
return chtype.UUID
}
}

if s := chType[kind]; s != "" {
return s
}

panic(fmt.Errorf("ch: unsupported Go type: %s", typ))
}

func chArrayElemType(s string) string {
s = chSubType(s, "Array(")
if s == "" {
Expand Down Expand Up @@ -401,7 +411,15 @@ func nullableType(s string) string {
}

func aggFuncNameAndType(chType string) (funcName, funcType string) {
s := chSubType(chType, "SimpleAggregateFunction(")
var s string

for _, prefix := range []string{"SimpleAggregateFunction(", "AggregateFunction("} {
s = chSubType(chType, prefix)
if s != "" {
break
}
}

if s == "" {
return "", ""
}
Expand Down
30 changes: 16 additions & 14 deletions ch/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,23 +261,25 @@ func parseDSN(dsn string) ([]Option, error) {

switch u.Scheme {
case "ch", "clickhouse":
if u.Host != "" {
addr := u.Host
if !strings.Contains(addr, ":") {
addr += ":5432"
}
opts = append(opts, WithAddr(addr))
}
// ok
default:
return nil, errors.New("ch: unknown scheme: " + u.Scheme)
}

if len(u.Path) > 1 {
opts = append(opts, WithDatabase(u.Path[1:]))
if u.Host != "" {
addr := u.Host
if !strings.Contains(addr, ":") {
addr += ":5432"
}
opts = append(opts, WithAddr(addr))
}

if host := q.string("host"); host != "" {
opts = append(opts, WithAddr(host))
}
default:
return nil, errors.New("ch: unknown scheme: " + u.Scheme)
if len(u.Path) > 1 {
opts = append(opts, WithDatabase(u.Path[1:]))
}

if host := q.string("host"); host != "" {
opts = append(opts, WithAddr(host))
}

if u.User != nil {
Expand Down
10 changes: 6 additions & 4 deletions ch/query_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,19 +192,21 @@ func (q *InsertQuery) Exec(ctx context.Context) (sql.Result, error) {
query := internal.String(queryBytes)

ctx, evt := q.db.beforeQuery(ctx, q, query, nil, q.tableModel)

var res *result
var retErr error

if q.tableModel != nil {
fields, err := q.getFields()
if err != nil {
return nil, err
}
res, err = q.db.insert(ctx, q.tableModel, query, fields)
res, retErr = q.db.insert(ctx, q.tableModel, query, fields)
} else {
res, err = q.db.exec(ctx, query)
res, retErr = q.db.exec(ctx, query)
}

q.db.afterQuery(ctx, evt, res, err)
q.db.afterQuery(ctx, evt, res, retErr)

return res, err
return res, retErr
}

0 comments on commit 510b9ca

Please sign in to comment.