diff --git a/engine_pilosa_test.go b/engine_pilosa_test.go index 00380a1b9..a6da4c5a9 100644 --- a/engine_pilosa_test.go +++ b/engine_pilosa_test.go @@ -207,4 +207,3 @@ func TestCreateIndex(t *testing.T) { require.NoError(os.RemoveAll(tmpDir)) }() } - diff --git a/engine_test.go b/engine_test.go index 3ce2c8f81..1ba6ef32a 100644 --- a/engine_test.go +++ b/engine_test.go @@ -582,7 +582,7 @@ var queries = []struct { { `SELECT COALESCE(NULL, NULL, NULL, COALESCE(NULL, 1234567890))`, []sql.Row{ - {int64(1234567890)}, + {int32(1234567890)}, }, }, { @@ -981,7 +981,7 @@ var queries = []struct { }, { `SELECT -1`, - []sql.Row{{int64(-1)}}, + []sql.Row{{int8(-1)}}, }, { ` @@ -1043,13 +1043,13 @@ var queries = []struct { { `SELECT nullif(123, 321)`, []sql.Row{ - {int64(123)}, + {int8(123)}, }, }, { `SELECT ifnull(123, NULL)`, []sql.Row{ - {int64(123)}, + {int8(123)}, }, }, { @@ -1061,19 +1061,19 @@ var queries = []struct { { `SELECT ifnull(NULL, 123)`, []sql.Row{ - {int64(123)}, + {int8(123)}, }, }, { `SELECT ifnull(123, 123)`, []sql.Row{ - {int64(123)}, + {int8(123)}, }, }, { `SELECT ifnull(123, 321)`, []sql.Row{ - {int64(123)}, + {int8(123)}, }, }, { @@ -1085,7 +1085,7 @@ var queries = []struct { { `SELECT round(15, 1)`, []sql.Row{ - {int64(15)}, + {int8(15)}, }, }, { @@ -1452,7 +1452,7 @@ var queries = []struct { }, { `SELECT 1 FROM mytable GROUP BY i HAVING i > 1`, - []sql.Row{{int64(1)}, {int64(1)}}, + []sql.Row{{int8(1)}, {int8(1)}}, }, { `SELECT avg(i) FROM mytable GROUP BY i HAVING avg(i) > 1`, @@ -1887,8 +1887,8 @@ func TestInsertInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(math.MaxInt8), int64(math.MaxInt16), int64(math.MaxInt32), int64(math.MaxInt64), - int64(math.MaxUint8), int64(math.MaxUint16), int64(math.MaxUint32), uint64(math.MaxUint64), + int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64), + uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64), float64(math.MaxFloat32), float64(math.MaxFloat64), timeParse(sql.TimestampLayout, "2132-04-05 12:51:36"), timeParse(sql.DateLayout, "2231-11-07"), "random text", true, `{"key":"value"}`, "blobdata", @@ -1905,8 +1905,8 @@ func TestInsertInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(math.MaxInt8), int64(math.MaxInt16), int64(math.MaxInt32), int64(math.MaxInt64), - int64(math.MaxUint8), int64(math.MaxUint16), int64(math.MaxUint32), uint64(math.MaxUint64), + int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64), + uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64), float64(math.MaxFloat32), float64(math.MaxFloat64), timeParse(sql.TimestampLayout, "2132-04-05 12:51:36"), timeParse(sql.DateLayout, "2231-11-07"), "random text", true, `{"key":"value"}`, "blobdata", @@ -1923,8 +1923,8 @@ func TestInsertInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(-math.MaxInt8 - 1), int64(-math.MaxInt16 - 1), int64(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), - int64(0), int64(0), int64(0), int64(0), + int64(999), int8(-math.MaxInt8 - 1), int16(-math.MaxInt16 - 1), int32(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), + uint8(0), uint16(0), uint32(0), uint64(0), float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), "", false, ``, "", @@ -1941,8 +1941,8 @@ func TestInsertInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(-math.MaxInt8 - 1), int64(-math.MaxInt16 - 1), int64(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), - int64(0), int64(0), int64(0), int64(0), + int64(999), int8(-math.MaxInt8 - 1), int16(-math.MaxInt16 - 1), int32(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), + uint8(0), uint16(0), uint32(0), uint64(0), float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), "", false, ``, "", @@ -2087,8 +2087,8 @@ func TestReplaceInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(math.MaxInt8), int64(math.MaxInt16), int64(math.MaxInt32), int64(math.MaxInt64), - int64(math.MaxUint8), int64(math.MaxUint16), int64(math.MaxUint32), uint64(math.MaxUint64), + int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64), + uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64), float64(math.MaxFloat32), float64(math.MaxFloat64), timeParse(sql.TimestampLayout, "2132-04-05 12:51:36"), timeParse(sql.DateLayout, "2231-11-07"), "random text", true, `{"key":"value"}`, "blobdata", @@ -2105,8 +2105,8 @@ func TestReplaceInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(math.MaxInt8), int64(math.MaxInt16), int64(math.MaxInt32), int64(math.MaxInt64), - int64(math.MaxUint8), int64(math.MaxUint16), int64(math.MaxUint32), uint64(math.MaxUint64), + int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64), + uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64), float64(math.MaxFloat32), float64(math.MaxFloat64), timeParse(sql.TimestampLayout, "2132-04-05 12:51:36"), timeParse(sql.DateLayout, "2231-11-07"), "random text", true, `{"key":"value"}`, "blobdata", @@ -2123,8 +2123,8 @@ func TestReplaceInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(-math.MaxInt8 - 1), int64(-math.MaxInt16 - 1), int64(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), - int64(0), int64(0), int64(0), int64(0), + int64(999), int8(-math.MaxInt8 - 1), int16(-math.MaxInt16 - 1), int32(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), + uint8(0), uint16(0), uint32(0), uint64(0), float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), "", false, ``, "", @@ -2141,8 +2141,8 @@ func TestReplaceInto(t *testing.T) { []sql.Row{{int64(1)}}, "SELECT * FROM typestable WHERE id = 999;", []sql.Row{{ - int64(999), int64(-math.MaxInt8 - 1), int64(-math.MaxInt16 - 1), int64(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), - int64(0), int64(0), int64(0), int64(0), + int64(999), int8(-math.MaxInt8 - 1), int16(-math.MaxInt16 - 1), int32(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), + uint8(0), uint16(0), uint32(0), uint64(0), float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), "", false, ``, "", @@ -2947,7 +2947,7 @@ func TestSessionVariables(t *testing.T) { rows, err := sql.RowIterToRows(iter) require.NoError(err) - require.Equal([]sql.Row{{int64(1), ",STRICT_TRANS_TABLES"}}, rows) + require.Equal([]sql.Row{{int8(1), ",STRICT_TRANS_TABLES"}}, rows) } func TestSessionVariablesONOFF(t *testing.T) { diff --git a/internal/sockstate/netstat_linux.go b/internal/sockstate/netstat_linux.go index 435c7deaa..a7eb6ff62 100644 --- a/internal/sockstate/netstat_linux.go +++ b/internal/sockstate/netstat_linux.go @@ -22,8 +22,8 @@ import ( const ( pathTCP4Tab = "/proc/net/tcp" pathTCP6Tab = "/proc/net/tcp6" - ipv4StrLen = 8 - ipv6StrLen = 32 + ipv4StrLen = 8 + ipv6StrLen = 32 ) type procFd struct { diff --git a/sql/expression/function/ceil_round_floor.go b/sql/expression/function/ceil_round_floor.go index 4c8c1d757..c056f82da 100644 --- a/sql/expression/function/ceil_round_floor.go +++ b/sql/expression/function/ceil_round_floor.go @@ -203,6 +203,18 @@ func (r *Round) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { dVal = float64(dNum) case int32: dVal = float64(dNum) + case int16: + dVal = float64(dNum) + case int8: + dVal = float64(dNum) + case uint64: + dVal = float64(dNum) + case uint32: + dVal = float64(dNum) + case uint16: + dVal = float64(dNum) + case uint8: + dVal = float64(dNum) case int: dVal = float64(dNum) default: @@ -233,6 +245,18 @@ func (r *Round) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return int64(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil case int32: return int32(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case int16: + return int16(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case int8: + return int8(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case uint64: + return uint64(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case uint32: + return uint32(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case uint16: + return uint16(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case uint8: + return uint8(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil case int: return int(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil default: diff --git a/sql/expression/function/ceil_round_floor_test.go b/sql/expression/function/ceil_round_floor_test.go index 2ad014c42..4af2456ef 100644 --- a/sql/expression/function/ceil_round_floor_test.go +++ b/sql/expression/function/ceil_round_floor_test.go @@ -156,6 +156,50 @@ func TestRound(t *testing.T) { {"int32 with float d", sql.Int32, sql.Float64, sql.NewRow(int32(5), float32(2.123)), int32(5), nil}, {"int32 with float negative d", sql.Int32, sql.Float64, sql.NewRow(int32(52), float32(-1)), int32(50), nil}, {"int32 with blob d", sql.Int32, sql.Blob, sql.NewRow(int32(5), []byte{1, 2, 3}), int32(5), nil}, + {"int16 is nil", sql.Int16, sql.Int16, sql.NewRow(nil, nil), nil, nil}, + {"int16 without d", sql.Int16, sql.Int16, sql.NewRow(int16(5), nil), int16(5), nil}, + {"int16 with d", sql.Int16, sql.Int16, sql.NewRow(int16(5), 2), int16(5), nil}, + {"int16 with negative d", sql.Int16, sql.Int16, sql.NewRow(int16(52), -1), int16(50), nil}, + {"int16 with float d", sql.Int16, sql.Float64, sql.NewRow(int16(5), float32(2.123)), int16(5), nil}, + {"int16 with float negative d", sql.Int16, sql.Float64, sql.NewRow(int16(52), float32(-1)), int16(50), nil}, + {"int16 with blob d", sql.Int16, sql.Blob, sql.NewRow(int16(5), []byte{1, 2, 3}), int16(5), nil}, + {"int8 is nil", sql.Int8, sql.Int8, sql.NewRow(nil, nil), nil, nil}, + {"int8 without d", sql.Int8, sql.Int8, sql.NewRow(int8(5), nil), int8(5), nil}, + {"int8 with d", sql.Int8, sql.Int8, sql.NewRow(int8(5), 2), int8(5), nil}, + {"int8 with negative d", sql.Int8, sql.Int8, sql.NewRow(int8(52), -1), int8(50), nil}, + {"int8 with float d", sql.Int8, sql.Float64, sql.NewRow(int8(5), float32(2.123)), int8(5), nil}, + {"int8 with float negative d", sql.Int8, sql.Float64, sql.NewRow(int8(52), float32(-1)), int8(50), nil}, + {"int8 with blob d", sql.Int8, sql.Blob, sql.NewRow(int8(5), []byte{1, 2, 3}), int8(5), nil}, + {"uint64 is nil", sql.Uint64, sql.Int32, sql.NewRow(nil, nil), nil, nil}, + {"uint64 without d", sql.Uint64, sql.Int32, sql.NewRow(uint64(5), nil), uint64(5), nil}, + {"uint64 with d", sql.Uint64, sql.Int32, sql.NewRow(uint64(5), 2), uint64(5), nil}, + {"uint64 with negative d", sql.Uint64, sql.Int32, sql.NewRow(uint64(52), -1), uint64(50), nil}, + {"uint64 with float d", sql.Uint64, sql.Float64, sql.NewRow(uint64(5), float32(2.123)), uint64(5), nil}, + {"uint64 with float negative d", sql.Uint64, sql.Float64, sql.NewRow(uint64(52), float32(-1)), uint64(50), nil}, + {"uint32 with blob d", sql.Uint32, sql.Blob, sql.NewRow(uint32(5), []byte{1, 2, 3}), uint32(5), nil}, + {"uint32 is nil", sql.Uint32, sql.Int32, sql.NewRow(nil, nil), nil, nil}, + {"uint32 without d", sql.Uint32, sql.Int32, sql.NewRow(uint32(5), nil), uint32(5), nil}, + {"uint32 with d", sql.Uint32, sql.Int32, sql.NewRow(uint32(5), 2), uint32(5), nil}, + {"uint32 with negative d", sql.Uint32, sql.Int32, sql.NewRow(uint32(52), -1), uint32(50), nil}, + {"uint32 with float d", sql.Uint32, sql.Float64, sql.NewRow(uint32(5), float32(2.123)), uint32(5), nil}, + {"uint32 with float negative d", sql.Uint32, sql.Float64, sql.NewRow(uint32(52), float32(-1)), uint32(50), nil}, + {"uint32 with blob d", sql.Uint32, sql.Blob, sql.NewRow(uint32(5), []byte{1, 2, 3}), uint32(5), nil}, + {"uint16 with blob d", sql.Uint16, sql.Blob, sql.NewRow(uint16(5), []byte{1, 2, 3}), uint16(5), nil}, + {"uint16 is nil", sql.Uint16, sql.Int16, sql.NewRow(nil, nil), nil, nil}, + {"uint16 without d", sql.Uint16, sql.Int16, sql.NewRow(uint16(5), nil), uint16(5), nil}, + {"uint16 with d", sql.Uint16, sql.Int16, sql.NewRow(uint16(5), 2), uint16(5), nil}, + {"uint16 with negative d", sql.Uint16, sql.Int16, sql.NewRow(uint16(52), -1), uint16(50), nil}, + {"uint16 with float d", sql.Uint16, sql.Float64, sql.NewRow(uint16(5), float32(2.123)), uint16(5), nil}, + {"uint16 with float negative d", sql.Uint16, sql.Float64, sql.NewRow(uint16(52), float32(-1)), uint16(50), nil}, + {"uint16 with blob d", sql.Uint16, sql.Blob, sql.NewRow(uint16(5), []byte{1, 2, 3}), uint16(5), nil}, + {"uint8 with blob d", sql.Uint8, sql.Blob, sql.NewRow(uint8(5), []byte{1, 2, 3}), uint8(5), nil}, + {"uint8 is nil", sql.Uint8, sql.Int8, sql.NewRow(nil, nil), nil, nil}, + {"uint8 without d", sql.Uint8, sql.Int8, sql.NewRow(uint8(5), nil), uint8(5), nil}, + {"uint8 with d", sql.Uint8, sql.Int8, sql.NewRow(uint8(5), 2), uint8(5), nil}, + {"uint8 with negative d", sql.Uint8, sql.Int8, sql.NewRow(uint8(52), -1), uint8(50), nil}, + {"uint8 with float d", sql.Uint8, sql.Float64, sql.NewRow(uint8(5), float32(2.123)), uint8(5), nil}, + {"uint8 with float negative d", sql.Uint8, sql.Float64, sql.NewRow(uint8(52), float32(-1)), uint8(50), nil}, + {"uint8 with blob d", sql.Uint8, sql.Blob, sql.NewRow(uint8(5), []byte{1, 2, 3}), uint8(5), nil}, {"blob is nil", sql.Blob, sql.Int32, sql.NewRow(nil, nil), nil, nil}, {"blob is ok", sql.Blob, sql.Int32, sql.NewRow([]byte{1, 2, 3}, nil), int32(0), nil}, {"text int without d", sql.Text, sql.Int32, sql.NewRow("5", nil), int32(5), nil}, diff --git a/sql/expression/function/time.go b/sql/expression/function/time.go index 2385aec0d..c0dcaf3d4 100644 --- a/sql/expression/function/time.go +++ b/sql/expression/function/time.go @@ -363,8 +363,10 @@ func (d *YearWeek) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } if val != nil { - if mode, ok = val.(int64); ok { - mode %= 8 // mode in [0, 7] + if i64, err := sql.Int64.Convert(val); err == nil { + if mode, ok = i64.(int64); ok { + mode %= 8 // mode in [0, 7] + } } } yyyy, week := calcWeek(yyyy, mm, dd, weekMode(mode)|weekBehaviourYear) diff --git a/sql/index/pilosa/lookup.go b/sql/index/pilosa/lookup.go index e50aeeabd..29173921b 100644 --- a/sql/index/pilosa/lookup.go +++ b/sql/index/pilosa/lookup.go @@ -621,6 +621,14 @@ func decodeGob(k []byte, value interface{}) (interface{}, error) { var v string err := decoder.Decode(&v) return v, err + case int8: + var v int8 + err := decoder.Decode(&v) + return v, err + case int16: + var v int16 + err := decoder.Decode(&v) + return v, err case int32: var v int32 err := decoder.Decode(&v) @@ -629,6 +637,14 @@ func decodeGob(k []byte, value interface{}) (interface{}, error) { var v int64 err := decoder.Decode(&v) return v, err + case uint8: + var v uint8 + err := decoder.Decode(&v) + return v, err + case uint16: + var v uint16 + err := decoder.Decode(&v) + return v, err case uint32: var v uint32 err := decoder.Decode(&v) @@ -688,6 +704,36 @@ func compare(a, b interface{}) (int, error) { } return strings.Compare(a, v), nil + case int8: + v, ok := b.(int8) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + + return 1, nil + case int16: + v, ok := b.(int16) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + + return 1, nil case int32: v, ok := b.(int32) if !ok { @@ -717,6 +763,36 @@ func compare(a, b interface{}) (int, error) { return -1, nil } + return 1, nil + case uint8: + v, ok := b.(uint8) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + + return 1, nil + case uint16: + v, ok := b.(uint16) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + return 1, nil case uint32: v, ok := b.(uint32) diff --git a/sql/index/pilosa/lookup_test.go b/sql/index/pilosa/lookup_test.go index e93a3e587..b93da3ca2 100644 --- a/sql/index/pilosa/lookup_test.go +++ b/sql/index/pilosa/lookup_test.go @@ -98,8 +98,12 @@ func TestCompare(t *testing.T) { func TestDecodeGob(t *testing.T) { testCases := []interface{}{ "foo", + int8(1), + int16(1), int32(1), int64(1), + uint8(1), + uint16(1), uint32(1), uint64(1), float64(1), diff --git a/sql/parse/parse.go b/sql/parse/parse.go index 14d5e3045..e77819f2a 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -680,8 +680,14 @@ func getInt64Literal(ctx *sql.Context, expr sqlparser.Expr, errStr string) (*exp } nl, ok := e.(*expression.Literal) - if !ok || nl.Type() != sql.Int64 { + if !ok || !sql.IsInteger(nl.Type()) { return nil, ErrUnsupportedFeature.New(errStr) + } else { + i64, err := sql.Int64.Convert(nl.Value()) + if err != nil { + return nil, ErrUnsupportedFeature.New(errStr) + } + return expression.NewLiteral(i64, sql.Int64), nil } return nl, nil @@ -749,12 +755,14 @@ func selectToProjectOrGroupBy( for i, ge := range groupingExprs { // if GROUP BY index if l, ok := ge.(*expression.Literal); ok && sql.IsNumber(l.Type()) { - if idx, ok := l.Value().(int64); ok && idx > 0 && idx <= agglen { - aggexpr := selectExprs[idx-1] - if alias, ok := aggexpr.(*expression.Alias); ok { - aggexpr = expression.NewUnresolvedColumn(alias.Name()) + if i64, err := sql.Int64.Convert(l.Value()); err == nil { + if idx, ok := i64.(int64); ok && idx > 0 && idx <= agglen { + aggexpr := selectExprs[idx-1] + if alias, ok := aggexpr.(*expression.Alias); ok { + aggexpr = expression.NewUnresolvedColumn(alias.Name()) + } + groupingExprs[i] = aggexpr } - groupingExprs[i] = aggexpr } } } @@ -950,22 +958,46 @@ func isAggregateFunc(v *sqlparser.FuncExpr) bool { return v.IsAggregate() } +// Convert an integer, represented by the specified string in the specified +// base, to its smallest representation possible, out of: +// int8, uint8, int16, uint16, int32, uint32, int64 and uint64 +func convertInt(value string, base int) (sql.Expression, error) { + if i8, err := strconv.ParseInt(value, base, 8); err == nil { + return expression.NewLiteral(int8(i8), sql.Int8), nil + } + if ui8, err := strconv.ParseUint(value, base, 8); err == nil { + return expression.NewLiteral(uint8(ui8), sql.Uint8), nil + } + if i16, err := strconv.ParseInt(value, base, 16); err == nil { + return expression.NewLiteral(int16(i16), sql.Int16), nil + } + if ui16, err := strconv.ParseUint(value, base, 16); err == nil { + return expression.NewLiteral(uint16(ui16), sql.Uint16), nil + } + if i32, err := strconv.ParseInt(value, base, 32); err == nil { + return expression.NewLiteral(int32(i32), sql.Int32), nil + } + if ui32, err := strconv.ParseUint(value, base, 32); err == nil { + return expression.NewLiteral(uint32(ui32), sql.Uint32), nil + } + if i64, err := strconv.ParseInt(value, base, 64); err == nil { + return expression.NewLiteral(int64(i64), sql.Int64), nil + } + + ui64, err := strconv.ParseUint(value, base, 64) + if err != nil { + return nil, err + } + + return expression.NewLiteral(uint64(ui64), sql.Uint64), nil +} + func convertVal(v *sqlparser.SQLVal) (sql.Expression, error) { switch v.Type { case sqlparser.StrVal: return expression.NewLiteral(string(v.Val), sql.Text), nil case sqlparser.IntVal: - //TODO: Use smallest integer representation and widen later. - val, err := strconv.ParseInt(string(v.Val), 10, 64) - if err != nil { - // Might be a uint64 value that is greater than int64 max - val, checkErr := strconv.ParseUint(string(v.Val), 10, 64) - if checkErr != nil { - return nil, err - } - return expression.NewLiteral(val, sql.Uint64), nil - } - return expression.NewLiteral(val, sql.Int64), nil + return convertInt(string(v.Val), 10) case sqlparser.FloatVal: val, err := strconv.ParseFloat(string(v.Val), 64) if err != nil { @@ -980,11 +1012,7 @@ func convertVal(v *sqlparser.SQLVal) (sql.Expression, error) { v = strings.Trim(v[1:], "'") } - val, err := strconv.ParseInt(v, 16, 64) - if err != nil { - return nil, err - } - return expression.NewLiteral(val, sql.Int64), nil + return convertInt(v, 16) case sqlparser.HexVal: val, err := v.HexDecode() if err != nil { diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index be176a9c6..ad8bfc62d 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -1,6 +1,7 @@ package parse import ( + "math" "testing" "github.com/src-d/go-mysql-server/sql/expression" @@ -181,7 +182,7 @@ var fixtures = map[string]sql.Node{ plan.NewFilter( expression.NewEquals( expression.NewUnresolvedColumn("qux"), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), ), plan.NewUnresolvedTable("foo", ""), ), @@ -258,7 +259,7 @@ var fixtures = map[string]sql.Node{ plan.NewUnresolvedTable("t1", ""), plan.NewValues([][]sql.Expression{{ expression.NewLiteral("a", sql.Text), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), }}), false, []string{"col1", "col2"}, @@ -267,7 +268,7 @@ var fixtures = map[string]sql.Node{ plan.NewUnresolvedTable("t1", ""), plan.NewValues([][]sql.Expression{{ expression.NewLiteral("a", sql.Text), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), }}), true, []string{"col1", "col2"}, @@ -360,7 +361,7 @@ var fixtures = map[string]sql.Node{ plan.NewFilter( expression.NewEquals( expression.NewUnresolvedColumn("a"), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), ), plan.NewUnresolvedTable("foo", ""), ), @@ -432,9 +433,9 @@ var fixtures = map[string]sql.Node{ plan.NewFilter( expression.NewNot( expression.NewBetween( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(2), sql.Int64), - expression.NewLiteral(int64(5), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(2), sql.Int8), + expression.NewLiteral(int8(5), sql.Int8), ), ), plan.NewUnresolvedTable("foo", ""), @@ -444,16 +445,16 @@ var fixtures = map[string]sql.Node{ []sql.Expression{expression.NewStar()}, plan.NewFilter( expression.NewBetween( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(2), sql.Int64), - expression.NewLiteral(int64(5), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(2), sql.Int8), + expression.NewLiteral(int8(5), sql.Int8), ), plan.NewUnresolvedTable("foo", ""), ), ), `SELECT 0x01AF`: plan.NewProject( []sql.Expression{ - expression.NewLiteral(int64(431), sql.Int64), + expression.NewLiteral(int16(431), sql.Int16), }, plan.NewUnresolvedTable("dual", ""), ), @@ -470,12 +471,12 @@ var fixtures = map[string]sql.Node{ "somefunc", false, expression.NewTuple( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(2), sql.Int8), ), expression.NewTuple( - expression.NewLiteral(int64(3), sql.Int64), - expression.NewLiteral(int64(4), sql.Int64), + expression.NewLiteral(int8(3), sql.Int8), + expression.NewLiteral(int8(4), sql.Int8), ), ), plan.NewUnresolvedTable("b", ""), @@ -486,7 +487,7 @@ var fixtures = map[string]sql.Node{ plan.NewFilter( expression.NewEquals( expression.NewLiteral(":foo_id", sql.Text), - expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int8(2), sql.Int8), ), plan.NewUnresolvedTable("foo", ""), ), @@ -510,13 +511,13 @@ var fixtures = map[string]sql.Node{ ), `SELECT CAST(-3 AS UNSIGNED) FROM foo`: plan.NewProject( []sql.Expression{ - expression.NewConvert(expression.NewLiteral(int64(-3), sql.Int64), expression.ConvertToUnsigned), + expression.NewConvert(expression.NewLiteral(int8(-3), sql.Int8), expression.ConvertToUnsigned), }, plan.NewUnresolvedTable("foo", ""), ), `SELECT 2 = 2 FROM foo`: plan.NewProject( []sql.Expression{ - expression.NewEquals(expression.NewLiteral(int64(2), sql.Int64), expression.NewLiteral(int64(2), sql.Int64)), + expression.NewEquals(expression.NewLiteral(int8(2), sql.Int8), expression.NewLiteral(int8(2), sql.Int8)), }, plan.NewUnresolvedTable("foo", ""), ), @@ -560,10 +561,10 @@ var fixtures = map[string]sql.Node{ []sql.Expression{expression.NewStar()}, plan.NewFilter( expression.NewIn( - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), expression.NewTuple( expression.NewLiteral("1", sql.Text), - expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int8(2), sql.Int8), ), ), plan.NewUnresolvedTable("foo", ""), @@ -573,10 +574,10 @@ var fixtures = map[string]sql.Node{ []sql.Expression{expression.NewStar()}, plan.NewFilter( expression.NewNotIn( - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), expression.NewTuple( expression.NewLiteral("1", sql.Text), - expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int8(2), sql.Int8), ), ), plan.NewUnresolvedTable("foo", ""), @@ -585,12 +586,12 @@ var fixtures = map[string]sql.Node{ `SELECT a, b FROM t ORDER BY 2, 1`: plan.NewSort( []plan.SortField{ { - Column: expression.NewLiteral(int64(2), sql.Int64), + Column: expression.NewLiteral(int8(2), sql.Int8), Order: plan.Ascending, NullOrdering: plan.NullsFirst, }, { - Column: expression.NewLiteral(int64(1), sql.Int64), + Column: expression.NewLiteral(int8(1), sql.Int8), Order: plan.Ascending, NullOrdering: plan.NullsFirst, }, @@ -605,22 +606,22 @@ var fixtures = map[string]sql.Node{ ), `SELECT 1 + 1;`: plan.NewProject( []sql.Expression{ - expression.NewPlus(expression.NewLiteral(int64(1), sql.Int64), expression.NewLiteral(int64(1), sql.Int64)), + expression.NewPlus(expression.NewLiteral(int8(1), sql.Int8), expression.NewLiteral(int8(1), sql.Int8)), }, plan.NewUnresolvedTable("dual", ""), ), `SELECT 1 * (2 + 1);`: plan.NewProject( []sql.Expression{ - expression.NewMult(expression.NewLiteral(int64(1), sql.Int64), - expression.NewPlus(expression.NewLiteral(int64(2), sql.Int64), expression.NewLiteral(int64(1), sql.Int64))), + expression.NewMult(expression.NewLiteral(int8(1), sql.Int8), + expression.NewPlus(expression.NewLiteral(int8(2), sql.Int8), expression.NewLiteral(int8(1), sql.Int8))), }, plan.NewUnresolvedTable("dual", ""), ), `SELECT (0 - 1) * (1 | 1);`: plan.NewProject( []sql.Expression{ expression.NewMult( - expression.NewMinus(expression.NewLiteral(int64(0), sql.Int64), expression.NewLiteral(int64(1), sql.Int64)), - expression.NewBitOr(expression.NewLiteral(int64(1), sql.Int64), expression.NewLiteral(int64(1), sql.Int64)), + expression.NewMinus(expression.NewLiteral(int8(0), sql.Int8), expression.NewLiteral(int8(1), sql.Int8)), + expression.NewBitOr(expression.NewLiteral(int8(1), sql.Int8), expression.NewLiteral(int8(1), sql.Int8)), ), }, plan.NewUnresolvedTable("dual", ""), @@ -628,8 +629,8 @@ var fixtures = map[string]sql.Node{ `SELECT (1 << 3) % (2 div 1);`: plan.NewProject( []sql.Expression{ expression.NewMod( - expression.NewShiftLeft(expression.NewLiteral(int64(1), sql.Int64), expression.NewLiteral(int64(3), sql.Int64)), - expression.NewIntDiv(expression.NewLiteral(int64(2), sql.Int64), expression.NewLiteral(int64(1), sql.Int64))), + expression.NewShiftLeft(expression.NewLiteral(int8(1), sql.Int8), expression.NewLiteral(int8(3), sql.Int8)), + expression.NewIntDiv(expression.NewLiteral(int8(2), sql.Int8), expression.NewLiteral(int8(1), sql.Int8))), }, plan.NewUnresolvedTable("dual", ""), ), @@ -645,7 +646,7 @@ var fixtures = map[string]sql.Node{ `SELECT '1.0' + 2;`: plan.NewProject( []sql.Expression{ expression.NewPlus( - expression.NewLiteral("1.0", sql.Text), expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral("1.0", sql.Text), expression.NewLiteral(int8(2), sql.Int8), ), }, plan.NewUnresolvedTable("dual", ""), @@ -714,7 +715,7 @@ var fixtures = map[string]sql.Node{ expression.NewUnresolvedFunction( "max", true, expression.NewUnresolvedColumn("i"), ), - expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int8(2), sql.Int8), "/", ), }, @@ -742,7 +743,7 @@ var fixtures = map[string]sql.Node{ `SET autocommit=1, foo="bar"`: plan.NewSet( plan.SetVariable{ Name: "autocommit", - Value: expression.NewLiteral(int64(1), sql.Int64), + Value: expression.NewLiteral(int8(1), sql.Int8), }, plan.SetVariable{ Name: "foo", @@ -752,7 +753,7 @@ var fixtures = map[string]sql.Node{ `SET @@session.autocommit=1, foo="bar"`: plan.NewSet( plan.SetVariable{ Name: "@@session.autocommit", - Value: expression.NewLiteral(int64(1), sql.Int64), + Value: expression.NewLiteral(int8(1), sql.Int8), }, plan.SetVariable{ Name: "foo", @@ -818,11 +819,11 @@ var fixtures = map[string]sql.Node{ `SET SESSION NET_READ_TIMEOUT= 700, SESSION NET_WRITE_TIMEOUT= 700`: plan.NewSet( plan.SetVariable{ Name: "@@session.net_read_timeout", - Value: expression.NewLiteral(int64(700), sql.Int64), + Value: expression.NewLiteral(int16(700), sql.Int16), }, plan.SetVariable{ Name: "@@session.net_write_timeout", - Value: expression.NewLiteral(int64(700), sql.Int64), + Value: expression.NewLiteral(int16(700), sql.Int16), }, ), `SET gtid_mode=DEFAULT`: plan.NewSet( @@ -975,11 +976,11 @@ var fixtures = map[string]sql.Node{ expression.NewUnresolvedColumn("foo"), []expression.CaseBranch{ { - Cond: expression.NewLiteral(int64(1), sql.Int64), + Cond: expression.NewLiteral(int8(1), sql.Int8), Value: expression.NewLiteral("foo", sql.Text), }, { - Cond: expression.NewLiteral(int64(2), sql.Int64), + Cond: expression.NewLiteral(int8(2), sql.Int8), Value: expression.NewLiteral("bar", sql.Text), }, }, @@ -992,11 +993,11 @@ var fixtures = map[string]sql.Node{ expression.NewUnresolvedColumn("foo"), []expression.CaseBranch{ { - Cond: expression.NewLiteral(int64(1), sql.Int64), + Cond: expression.NewLiteral(int8(1), sql.Int8), Value: expression.NewLiteral("foo", sql.Text), }, { - Cond: expression.NewLiteral(int64(2), sql.Int64), + Cond: expression.NewLiteral(int8(2), sql.Int8), Value: expression.NewLiteral("bar", sql.Text), }, }, @@ -1011,14 +1012,14 @@ var fixtures = map[string]sql.Node{ { Cond: expression.NewEquals( expression.NewUnresolvedColumn("foo"), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), ), Value: expression.NewLiteral("foo", sql.Text), }, { Cond: expression.NewEquals( expression.NewUnresolvedColumn("foo"), - expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int8(2), sql.Int8), ), Value: expression.NewLiteral("bar", sql.Text), }, @@ -1055,7 +1056,7 @@ var fixtures = map[string]sql.Node{ []sql.Expression{expression.NewArithmetic( expression.NewLiteral("2018-05-01", sql.Text), expression.NewInterval( - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), "DAY", ), "+", @@ -1066,7 +1067,7 @@ var fixtures = map[string]sql.Node{ []sql.Expression{expression.NewArithmetic( expression.NewLiteral("2018-05-01", sql.Text), expression.NewInterval( - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), "DAY", ), "-", @@ -1076,7 +1077,7 @@ var fixtures = map[string]sql.Node{ `SELECT INTERVAL 1 DAY + '2018-05-01'`: plan.NewProject( []sql.Expression{expression.NewArithmetic( expression.NewInterval( - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), "DAY", ), expression.NewLiteral("2018-05-01", sql.Text), @@ -1089,13 +1090,13 @@ var fixtures = map[string]sql.Node{ expression.NewArithmetic( expression.NewLiteral("2018-05-01", sql.Text), expression.NewInterval( - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), "DAY", ), "+", ), expression.NewInterval( - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), "DAY", ), "+", @@ -1105,7 +1106,7 @@ var fixtures = map[string]sql.Node{ `SELECT COUNT(*) FROM foo GROUP BY a HAVING COUNT(*) > 5`: plan.NewHaving( expression.NewGreaterThan( expression.NewUnresolvedFunction("count", true, expression.NewStar()), - expression.NewLiteral(int64(5), sql.Int64), + expression.NewLiteral(int8(5), sql.Int8), ), plan.NewGroupBy( []sql.Expression{expression.NewUnresolvedFunction("count", true, expression.NewStar())}, @@ -1117,7 +1118,7 @@ var fixtures = map[string]sql.Node{ plan.NewHaving( expression.NewGreaterThan( expression.NewUnresolvedFunction("count", true, expression.NewStar()), - expression.NewLiteral(int64(5), sql.Int64), + expression.NewLiteral(int8(5), sql.Int8), ), plan.NewGroupBy( []sql.Expression{expression.NewUnresolvedFunction("count", true, expression.NewStar())}, @@ -1132,8 +1133,8 @@ var fixtures = map[string]sql.Node{ plan.NewUnresolvedTable("foo", ""), plan.NewUnresolvedTable("bar", ""), expression.NewEquals( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(1), sql.Int8), ), ), ), @@ -1143,8 +1144,8 @@ var fixtures = map[string]sql.Node{ plan.NewUnresolvedTable("foo", ""), plan.NewUnresolvedTable("bar", ""), expression.NewEquals( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(1), sql.Int8), ), ), ), @@ -1154,8 +1155,8 @@ var fixtures = map[string]sql.Node{ plan.NewUnresolvedTable("foo", ""), plan.NewUnresolvedTable("bar", ""), expression.NewEquals( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(1), sql.Int8), ), ), ), @@ -1165,8 +1166,8 @@ var fixtures = map[string]sql.Node{ plan.NewUnresolvedTable("foo", ""), plan.NewUnresolvedTable("bar", ""), expression.NewEquals( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(1), sql.Int8), ), ), ), @@ -1191,6 +1192,23 @@ var fixtures = map[string]sql.Node{ []sql.Expression{}, plan.NewUnresolvedTable("foo", ""), ), + `SELECT -128, 127, 255, -32768, 32767, 65535, -2147483648, 2147483647, 4294967295, -9223372036854775808, 9223372036854775807, 18446744073709551615`: plan.NewProject( + []sql.Expression{ + expression.NewLiteral(int8(math.MinInt8), sql.Int8), + expression.NewLiteral(int8(math.MaxInt8), sql.Int8), + expression.NewLiteral(uint8(math.MaxUint8), sql.Uint8), + expression.NewLiteral(int16(math.MinInt16), sql.Int16), + expression.NewLiteral(int16(math.MaxInt16), sql.Int16), + expression.NewLiteral(uint16(math.MaxUint16), sql.Uint16), + expression.NewLiteral(int32(math.MinInt32), sql.Int32), + expression.NewLiteral(int32(math.MaxInt32), sql.Int32), + expression.NewLiteral(uint32(math.MaxUint32), sql.Uint32), + expression.NewLiteral(int64(math.MinInt64), sql.Int64), + expression.NewLiteral(int64(math.MaxInt64), sql.Int64), + expression.NewLiteral(uint64(math.MaxUint64), sql.Uint64), + }, + plan.NewUnresolvedTable("dual", ""), + ), } func TestParse(t *testing.T) { diff --git a/sql/plan/insert.go b/sql/plan/insert.go index 807191f4b..06e638d26 100644 --- a/sql/plan/insert.go +++ b/sql/plan/insert.go @@ -11,15 +11,12 @@ import ( // ErrInsertIntoNotSupported is thrown when a table doesn't support inserts var ErrInsertIntoNotSupported = errors.NewKind("table doesn't support INSERT INTO") var ErrReplaceIntoNotSupported = errors.NewKind("table doesn't support REPLACE INTO") -var ErrInsertIntoMismatchValueCount = - errors.NewKind("number of values does not match number of columns provided") +var ErrInsertIntoMismatchValueCount = errors.NewKind("number of values does not match number of columns provided") var ErrInsertIntoUnsupportedValues = errors.NewKind("%T is unsupported for inserts") var ErrInsertIntoDuplicateColumn = errors.NewKind("duplicate column name %v") var ErrInsertIntoNonexistentColumn = errors.NewKind("invalid column name %v") -var ErrInsertIntoNonNullableDefaultNullColumn = - errors.NewKind("column name '%v' is non-nullable but attempted to set default value of null") -var ErrInsertIntoNonNullableProvidedNull = - errors.NewKind("column name '%v' is non-nullable but attempted to set a value of null") +var ErrInsertIntoNonNullableDefaultNullColumn = errors.NewKind("column name '%v' is non-nullable but attempted to set default value of null") +var ErrInsertIntoNonNullableProvidedNull = errors.NewKind("column name '%v' is non-nullable but attempted to set a value of null") // InsertInto is a node describing the insertion into some table. type InsertInto struct { @@ -148,6 +145,20 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) { return i, err } + // Convert integer values in row to specified type in schema + for colIdx, oldValue := range row { + dstColType := projExprs[colIdx].Type() + + if sql.IsInteger(dstColType) && oldValue != nil { + newValue, err := dstColType.Convert(oldValue) + if err != nil { + return i, err + } + + row[colIdx] = newValue + } + } + if replaceable != nil { if err = replaceable.Delete(ctx, row); err != nil { if err != sql.ErrDeleteRowNotFound { diff --git a/sql/type.go b/sql/type.go index caadd9254..94e825328 100644 --- a/sql/type.go +++ b/sql/type.go @@ -335,10 +335,18 @@ func (t numberT) SQL(v interface{}) (sqltypes.Value, error) { } switch t.t { + case sqltypes.Int8: + return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)), nil + case sqltypes.Int16: + return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)), nil case sqltypes.Int32: return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)), nil case sqltypes.Int64: return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)), nil + case sqltypes.Uint8: + return sqltypes.MakeTrusted(t.t, strconv.AppendUint(nil, cast.ToUint64(v), 10)), nil + case sqltypes.Uint16: + return sqltypes.MakeTrusted(t.t, strconv.AppendUint(nil, cast.ToUint64(v), 10)), nil case sqltypes.Uint32: return sqltypes.MakeTrusted(t.t, strconv.AppendUint(nil, cast.ToUint64(v), 10)), nil case sqltypes.Uint64: @@ -735,7 +743,6 @@ func (t charT) Compare(a interface{}, b interface{}) (int, error) { return strings.Compare(a.(string), b.(string)), nil } - type varCharT struct { length int } @@ -1155,12 +1162,12 @@ func IsNumber(t Type) bool { // IsSigned checks if t is a signed type. func IsSigned(t Type) bool { - return t == Int32 || t == Int64 + return t == Int8 || t == Int16 || t == Int32 || t == Int64 } // IsUnsigned checks if t is an unsigned type. func IsUnsigned(t Type) bool { - return t == Uint64 || t == Uint32 + return t == Uint8 || t == Uint16 || t == Uint32 || t == Uint64 } // IsInteger checks if t is a (U)Int32/64 type. diff --git a/sql/type_test.go b/sql/type_test.go index a73bf08be..5ca9ac9e5 100644 --- a/sql/type_test.go +++ b/sql/type_test.go @@ -46,16 +46,118 @@ func TestBoolean(t *testing.T) { eq(t, Boolean, false, false) } +// Test conversion of all types of numbers to the specified signed integer type +// in typ, where minusOne, zero and one are the expected values with the +// same type as typ +func testSignedInt(t *testing.T, typ Type, minusOne, zero, one interface{}) { + t.Helper() + + convert(t, typ, -1, minusOne) + convert(t, typ, int8(-1), minusOne) + convert(t, typ, int16(-1), minusOne) + convert(t, typ, int32(-1), minusOne) + convert(t, typ, int64(-1), minusOne) + convert(t, typ, 0, zero) + convert(t, typ, int8(0), zero) + convert(t, typ, int16(0), zero) + convert(t, typ, int32(0), zero) + convert(t, typ, int64(0), zero) + convert(t, typ, uint8(0), zero) + convert(t, typ, uint16(0), zero) + convert(t, typ, uint32(0), zero) + convert(t, typ, uint64(0), zero) + convert(t, typ, 1, one) + convert(t, typ, int8(1), one) + convert(t, typ, int16(1), one) + convert(t, typ, int32(1), one) + convert(t, typ, int64(1), one) + convert(t, typ, uint8(1), one) + convert(t, typ, uint16(1), one) + convert(t, typ, uint32(1), one) + convert(t, typ, uint64(1), one) + convert(t, typ, "-1", minusOne) + convert(t, typ, "0", zero) + convert(t, typ, "1", one) + convertErr(t, typ, "") + + lt(t, Int8, minusOne, one) + eq(t, Int8, zero, zero) + eq(t, Int8, minusOne, minusOne) + eq(t, Int8, one, one) + gt(t, Int8, one, minusOne) +} + +// Test conversion of all types of numbers to the specified unsigned integer +// type in typ, where zero and one are the expected values with the same type +// as typ. The expected errors when converting from negative numbers are also +// tested +func testUnsignedInt(t *testing.T, typ Type, zero, one interface{}) { + t.Helper() + + convertErr(t, typ, -1) + convertErr(t, typ, int8(-1)) + convertErr(t, typ, int16(-1)) + convertErr(t, typ, int32(-1)) + convertErr(t, typ, int64(-1)) + convert(t, typ, 0, zero) + convert(t, typ, int8(0), zero) + convert(t, typ, int16(0), zero) + convert(t, typ, int32(0), zero) + convert(t, typ, int64(0), zero) + convert(t, typ, uint8(0), zero) + convert(t, typ, uint16(0), zero) + convert(t, typ, uint32(0), zero) + convert(t, typ, uint64(0), zero) + convert(t, typ, 1, one) + convert(t, typ, int8(1), one) + convert(t, typ, int16(1), one) + convert(t, typ, int32(1), one) + convert(t, typ, int64(1), one) + convert(t, typ, uint8(1), one) + convert(t, typ, uint16(1), one) + convert(t, typ, uint32(1), one) + convert(t, typ, uint64(1), one) + convertErr(t, typ, "-1") + convert(t, typ, "0", zero) + convert(t, typ, "1", one) + convertErr(t, typ, "") + + lt(t, Int8, zero, one) + eq(t, Int8, zero, zero) + eq(t, Int8, one, one) + gt(t, Int8, one, zero) +} + +func TestInt8(t *testing.T) { + testSignedInt(t, Int8, int8(-1), int8(0), int8(1)) +} + +func TestInt16(t *testing.T) { + testSignedInt(t, Int16, int16(-1), int16(0), int16(1)) +} + func TestInt32(t *testing.T) { - convert(t, Int32, int32(1), int32(1)) - convert(t, Int32, 1, int32(1)) - convert(t, Int32, int64(1), int32(1)) - convert(t, Int32, "5", int32(5)) - convertErr(t, Int32, "") + testSignedInt(t, Int32, int32(-1), int32(0), int32(1)) +} - lt(t, Int32, int32(1), int32(2)) - eq(t, Int32, int32(1), int32(1)) - gt(t, Int32, int32(3), int32(2)) +func TestInt64(t *testing.T) { + testSignedInt(t, Int64, int64(-1), int64(0), int64(1)) +} + +func TestUint8(t *testing.T) { + testUnsignedInt(t, Uint8, uint8(0), uint8(1)) +} + +func TestUint16(t *testing.T) { + testUnsignedInt(t, Uint16, uint16(0), uint16(1)) +} + +func TestUint32(t *testing.T) { + testUnsignedInt(t, Uint32, uint32(0), uint32(1)) +} + +func TestUint64(t *testing.T) { + testUnsignedInt(t, Uint64, uint64(0), uint64(1)) } func TestNumberComparison(t *testing.T) { @@ -140,18 +242,6 @@ func TestNumberComparison(t *testing.T) { } } -func TestInt64(t *testing.T) { - convert(t, Int64, int32(1), int64(1)) - convert(t, Int64, 1, int64(1)) - convert(t, Int64, int64(1), int64(1)) - convertErr(t, Int64, "") - convert(t, Int64, "5", int64(5)) - - lt(t, Int64, int64(1), int64(2)) - eq(t, Int64, int64(1), int64(1)) - gt(t, Int64, int64(3), int64(2)) -} - func TestFloat64(t *testing.T) { require := require.New(t)