From 92523fd9ca36974555cf45d486ed8852855ca4f5 Mon Sep 17 00:00:00 2001 From: Miguel Molina Date: Tue, 1 Aug 2017 10:42:38 +0200 Subject: [PATCH] scan bytea as []byte and vice-versa - []byte is now converted to bytea - []byte does no longer scan an array but a chunk of bytes in Scan - []byte does no longer has a postgres array but a chunk of bytes in Value Signed-off-by: Miguel Molina --- README.md | 3 +- generator/migration.go | 9 +++- generator/migration_test.go | 2 + types/slices.go | 66 +++++------------------ types/slices_test.go | 103 +++++++++++++++++++----------------- 5 files changed, 81 insertions(+), 102 deletions(-) diff --git a/README.md b/README.md index 81882e0..6d3e7fb 100644 --- a/README.md +++ b/README.md @@ -724,7 +724,8 @@ kallax migrate up --dir ./my-migrations --dsn 'user:pass@localhost:5432/dbname?s | `url.URL` | `text` | | `time.Time` | `timestamptz` | | `time.Duration` | `bigint` | -| `[]T` | `T'[]` * where `T'` is the SQL type of type `T` | +| `[]byte` | `bytea` | +| `[]T` | `T'[]` * where `T'` is the SQL type of type `T`, except for `T` = `byte` | | `map[K]V` | `jsonb` | | `struct` | `jsonb` | | `*struct` | `jsonb` | diff --git a/generator/migration.go b/generator/migration.go index a1886a3..fc88d2c 100644 --- a/generator/migration.go +++ b/generator/migration.go @@ -197,6 +197,7 @@ func (s *ColumnSchema) String() string { type ColumnType string const ( + ByteaColumn ColumnType = "bytea" SmallIntColumn ColumnType = "smallint" IntegerColumn ColumnType = "integer" BigIntColumn ColumnType = "bigint" @@ -225,6 +226,7 @@ func ArrayColumn(typ ColumnType) ColumnType { if strings.HasSuffix(string(typ), "[]") { return typ } + return typ + "[]" } @@ -833,7 +835,12 @@ func (t *packageTransformer) transformType(f *Field, pk bool) (ColumnType, error } if f.Kind == Array || f.Kind == Slice { - return ArrayColumn(typeMappings[removeTypePrefix(f.Type)]), nil + typ := removeTypePrefix(f.Type) + if typ == "byte" { + return ByteaColumn, nil + } + + return ArrayColumn(typeMappings[typ]), nil } if pk { diff --git a/generator/migration_test.go b/generator/migration_test.go index e130bdc..6d2b936 100644 --- a/generator/migration_test.go +++ b/generator/migration_test.go @@ -527,6 +527,7 @@ type Profile struct { // should be added anyway // should be added as bigint, as it is not a pk Metadata ProfileMetadata + SomeData []byte } type ProfileMetadata struct { @@ -569,6 +570,7 @@ func (s *PackageTransformerSuite) TestTransform() { mkCol("background", TextColumn, false, true, nil), mkCol("user_id", UUIDColumn, false, false, mkRef("users", "id", true)), mkCol("spouse", UUIDColumn, false, false, nil), + mkCol("some_data", ByteaColumn, false, true, nil), ), mkTable( "metadata", diff --git a/types/slices.go b/types/slices.go index aea57b0..952dbc3 100644 --- a/types/slices.go +++ b/types/slices.go @@ -71,10 +71,10 @@ func Slice(v interface{}) SQLType { return (*Int8Array)(&v) case *[]int8: return (*Int8Array)(v) - case []uint8: - return (*Uint8Array)(&v) - case *[]uint8: - return (*Uint8Array)(v) + case []byte: + return (*ByteArray)(&v) + case *[]byte: + return (*ByteArray)(v) case *[]float32: return (*Float32Array)(v) case []float32: @@ -646,67 +646,29 @@ func (a Int8Array) Value() (driver.Value, error) { return "{}", nil } -// Uint8Array represents a one-dimensional array of the PostgreSQL unsigned integer type. -type Uint8Array []uint8 +// ByteArray represents a byte array `bytea`. +type ByteArray []uint8 // Scan implements the sql.Scanner interface. -func (a *Uint8Array) Scan(src interface{}) error { +func (a *ByteArray) Scan(src interface{}) error { switch src := src.(type) { case []byte: - return a.scanBytes(src) + *(*[]byte)(a) = src + return nil case string: - return a.scanBytes([]byte(src)) + *(*[]byte)(a) = []byte(src) + return nil case nil: *a = nil return nil } - return fmt.Errorf("kallax: cannot convert %T to Uint8Array", src) -} - -func (a *Uint8Array) scanBytes(src []byte) error { - elems, err := scanLinearArray(src, []byte{','}, "Uint8Array") - if err != nil { - return err - } - if *a != nil && len(elems) == 0 { - *a = (*a)[:0] - } else { - b := make(Uint8Array, len(elems)) - for i, v := range elems { - val, err := strconv.ParseUint(string(v), 10, 8) - if err != nil { - return fmt.Errorf("kallax: parsing array element index %d: %v", i, err) - } - b[i] = uint8(val) - } - *a = b - } - return nil + return fmt.Errorf("kallax: cannot convert %T to ByteArray", src) } // Value implements the driver.Valuer interface. -func (a Uint8Array) Value() (driver.Value, error) { - if a == nil { - return nil, nil - } - - if n := len(a); n > 0 { - // There will be at least two curly brackets, N bytes of values, - // and N-1 bytes of delimiters. - b := make([]byte, 1, 1+2*n) - b[0] = '{' - - b = strconv.AppendUint(b, uint64(a[0]), 10) - for i := 1; i < n; i++ { - b = append(b, ',') - b = strconv.AppendUint(b, uint64(a[i]), 10) - } - - return string(append(b, '}')), nil - } - - return "{}", nil +func (a ByteArray) Value() (driver.Value, error) { + return ([]byte)(a), nil } // Float32Array represents a one-dimensional array of the PostgreSQL real type. diff --git a/types/slices_test.go b/types/slices_test.go index f846f0f..f6584ef 100644 --- a/types/slices_test.go +++ b/types/slices_test.go @@ -14,8 +14,6 @@ import ( ) func TestSlice(t *testing.T) { - require := require.New(t) - cases := []struct { v interface{} input interface{} @@ -76,16 +74,6 @@ func TestSlice(t *testing.T) { []int8{1, 3, 4}, &([]int8{}), }, - { - &([]uint8{1, 3, 4}), - []uint8{1, 3, 4}, - &([]uint8{}), - }, - { - &([]byte{1, 3, 4}), - []byte{1, 3, 4}, - &([]byte{}), - }, { &([]float32{1., 3., .4}), []float32{1., 3., .4}, @@ -94,22 +82,35 @@ func TestSlice(t *testing.T) { } for _, c := range cases { - arr := Slice(c.v) - val, err := arr.Value() - require.Nil(err) + t.Run(reflect.TypeOf(c.input).String(), func(t *testing.T) { + require := require.New(t) + arr := Slice(c.v) + val, err := arr.Value() + require.NoError(err) + + pqArr := pq.Array(c.input) + pqVal, err := pqArr.Value() + require.NoError(err) + + require.Equal(pqVal, val) + require.NoError(Slice(c.dest).Scan(val)) + require.Equal(c.v, c.dest) + }) + } - pqArr := pq.Array(c.input) - pqVal, err := pqArr.Value() - require.Nil(err) + t.Run("[]byte", func(t *testing.T) { + require := require.New(t) + arr := Slice([]byte{1, 2, 3}) + val, err := arr.Value() + require.NoError(err) - require.Equal(pqVal, val) - require.Nil(Slice(c.dest).Scan(val)) - require.Equal(c.v, c.dest) - } + var b []byte + require.NoError(Slice(&b).Scan(val)) + require.Equal([]byte{1, 2, 3}, b) + }) } func TestSlice_Integration(t *testing.T) { - s := require.New(t) cases := []struct { name string typ string @@ -118,85 +119,91 @@ func TestSlice_Integration(t *testing.T) { }{ { "int8", - "smallint", + "smallint[]", []int8{math.MaxInt8, math.MinInt8}, &([]int8{}), }, { - "unsigned int8", - "smallint", - []uint8{math.MaxUint8, 0}, - &([]uint8{}), + "byte", + "bytea", + []byte{math.MaxUint8, 0}, + &([]byte{}), }, { "int16", - "smallint", + "smallint[]", []int16{math.MaxInt16, math.MinInt16}, &([]int16{}), }, { "unsigned int16", - "integer", + "integer[]", []uint16{math.MaxUint16, 0}, &([]uint16{}), }, { "int32", - "integer", + "integer[]", []int32{math.MaxInt32, math.MinInt32}, &([]int32{}), }, { "unsigned int32", - "bigint", + "bigint[]", []uint32{math.MaxUint32, 0}, &([]uint32{}), }, { "int/int64", - "bigint", + "bigint[]", []int{math.MaxInt64, math.MinInt64}, &([]int{}), }, { "unsigned int/int64", - "numeric(20)", + "numeric(20)[]", []uint{math.MaxUint64, 0}, &([]uint{}), }, { "float32", - "decimal(10,3)", + "decimal(10,3)[]", []float32{.3, .6}, &([]float32{.3, .6}), }, } db, err := openTestDB() - s.Nil(err) + require.NoError(t, err) defer func() { _, err = db.Exec("DROP TABLE IF EXISTS foo") - s.Nil(err) + require.NoError(t, err) - s.Nil(db.Close()) + require.NoError(t, db.Close()) }() for _, c := range cases { - _, err := db.Exec(fmt.Sprintf(`CREATE TABLE foo ( - testcol %s[] + t.Run(c.name, func(t *testing.T) { + require := require.New(t) + + _, err := db.Exec(fmt.Sprintf(`CREATE TABLE foo ( + testcol %s )`, c.typ)) - s.Nil(err, c.name) + require.NoError(err, c.name) - _, err = db.Exec("INSERT INTO foo (testcol) VALUES ($1)", Slice(c.input)) - s.Nil(err, c.name) + defer func() { + _, err := db.Exec("DROP TABLE foo") + require.NoError(err) + }() - s.Nil(db.QueryRow("SELECT testcol FROM foo LIMIT 1").Scan(Slice(c.dst)), c.name) - slice := reflect.ValueOf(c.dst).Elem().Interface() - s.Equal(c.input, slice, c.name) + _, err = db.Exec("INSERT INTO foo (testcol) VALUES ($1)", Slice(c.input)) + require.NoError(err, c.name) - _, err = db.Exec("DROP TABLE foo") - s.Nil(err, c.name) + require.NoError(db.QueryRow("SELECT testcol FROM foo LIMIT 1").Scan(Slice(c.dst)), c.name) + slice := reflect.ValueOf(c.dst).Elem().Interface() + require.Equal(c.input, slice, c.name) + }) } }