Skip to content

Commit

Permalink
Add wrapper types for Text/Binary Marshaler and Stringer
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenafamo committed May 20, 2024
1 parent 77e9ec0 commit ce8198f
Show file tree
Hide file tree
Showing 10 changed files with 147 additions and 26 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add `bobgen-sql` a code generation driver for SQL schema files. Supports PostgreSQL and SQLite.
- Add new properties `compare_expr` and `compare_expr_imports` to the `types` configuration. This is used when comparing primary keys and in testing.
- Add `never_required` to relationships configuration. This makes sure the factories does not require the relationship to be set. Useful if you're not using foreign keys. (thanks @jacobmolby)
- Add wrapper types for Stringer, TextMarshaler/Unmarshaler, and BinaryMarshaler/Unmarshaler to the `types` configuration.

### Fixed

Expand Down
2 changes: 1 addition & 1 deletion gen/bobgen-atlas/driver/atlas.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ func (d *driver) translateColumnType(c drivers.Column, tableKey string, typ sche
}

if enum, ok := d.enums[enumName]; ok {
c.Type = enum.Type
c.Type = helpers.EnumType(d.types, enum.Type)
} else {
c.Type = "string"
}
Expand Down
27 changes: 17 additions & 10 deletions gen/bobgen-helpers/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,7 @@ func GetConfigFromProvider[DriverConfig any](provider koanf.Provider, driverConf
return config, driverConfig, nil
}

const parrayImport = `"github.com/stephenafamo/bob/types/parray"`

func AddPgEnumType(types drivers.Types, enum string) string {
func EnumType(types drivers.Types, enum string) string {
types[enum] = drivers.Type{
NoRandomizationTest: true, // enums are often not random enough
RandomExpr: fmt.Sprintf(`all := all%s()
Expand All @@ -148,6 +146,8 @@ func AddPgEnumType(types drivers.Types, enum string) string {
return enum
}

const parrayImport = `"github.com/stephenafamo/bob/types/parray"`

func AddPgEnumArrayType(types drivers.Types, enum string) string {
typ := fmt.Sprintf("parray.EnumArray[%s]", enum)

Expand Down Expand Up @@ -206,18 +206,25 @@ func Types() drivers.Types {
return any(f.Time().TimeBetween(min, max)).(T)`,
CompareExpr: `AAA.Equal(BBB)`,
},
"netip.Addr": {
Imports: importers.List{`"net/netip"`},
"types.Binary[netip.Addr, *netip.Addr]": {
Imports: importers.List{
`"net/netip"`,
`"github.com/stephenafamo/bob/types"`,
},
RandomExpr: `var addr [4]byte
rand.Read(addr[:])
return any(netip.AddrFrom4(addr)).(T)`,
ipAddr := netip.AddrFrom4(addr)
return any(types.Binary[netip.Addr, *netip.Addr]{Val: ipAddr}).(T)`,
RandomExprImports: importers.List{`"crypto/rand"`},
},
"net.HardwareAddr": {
Imports: importers.List{`"net"`},
"types.Stringer[net.HardwareAddr]": {
Imports: importers.List{
`"net"`,
`"github.com/stephenafamo/bob/types"`,
},
RandomExpr: `addr, _ := net.ParseMAC(f.Internet().MacAddress())
return any(addr).(T)`,
CompareExpr: `slices.Equal(AAA, BBB)`,
return any(types.Stringer[net.HardwareAddr]{Val: addr}).(T)`,
CompareExpr: `slices.Equal(AAA.Val, BBB.Val)`,
CompareExprImports: importers.List{`"slices"`},
},
"pq.BoolArray": {
Expand Down
2 changes: 1 addition & 1 deletion gen/bobgen-prisma/driver/prisma.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ func (d *driver) translateColumnType(c drivers.Column, isArray bool) drivers.Col
c.Type = "types.JSON[json.RawMessage]"
default:
if enum, ok := d.enums[c.DBType]; ok {
c.Type = enum.Type
c.Type = helpers.EnumType(d.types, enum.Type)
} else {
c.Type = "string"
}
Expand Down
6 changes: 3 additions & 3 deletions gen/bobgen-psql/driver/translate.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,14 @@ func (d *driver) translateColumnType(c drivers.Column, info colInfo) drivers.Col
case "uuid":
c.Type = "uuid.UUID"
case "inet", "cidr":
c.Type = "netip.Addr"
c.Type = "types.Binary[netip.Addr, *netip.Addr]"
case "macaddr":
c.Type = "net.HardwareAddr"
c.Type = "types.Stringer[net.HardwareAddr]"
case "ENUM":
c.Type = "string"
for _, e := range d.enums {
if e.Schema == info.UDTSchema && e.Name == info.UDTName {
c.Type = helpers.AddPgEnumType(d.types, e.Type)
c.Type = helpers.EnumType(d.types, e.Type)
}
}
case "ARRAY":
Expand Down
11 changes: 9 additions & 2 deletions gen/templates/factory/singleton/bobfactory_random_test.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,15 @@
{{- if eq $colTyp "bool"}}{{continue}}{{end -}}
{{- if $typInfo.NoRandomizationTest}}{{continue}}{{end -}}
{{- $.Importer.ImportList $typInfo.Imports -}}

func TestRandom_{{$colTyp | replace "." "_" | replace "[" "_" | replace "]" "_"}}(t *testing.T) {
func TestRandom_{{
$colTyp
| replace " " "_"
| replace "." "_"
| replace "," "_"
| replace "*" "_"
| replace "[" "_"
| replace "]" "_"
}}(t *testing.T) {
t.Parallel()
seen := make([]{{$colTyp}}, 10)
Expand Down
22 changes: 22 additions & 0 deletions gen/templates/models/singleton/bob_main_test.go.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
{{$doneTypes := dict }}
{{- range $table := .Tables}}
{{- $tAlias := $.Aliases.Table $table.Key}}
{{range $column := $table.Columns -}}
{{- $colTyp := $column.Type -}}
{{- if hasKey $doneTypes $colTyp}}{{continue}}{{end -}}
{{- $_ := set $doneTypes $colTyp nil -}}
{{- $typInfo := index $.Types $column.Type -}}
{{- if $typInfo.NoRandomizationTest}}{{continue}}{{end -}}
{{- if isPrimitiveType $colTyp}}{{continue}}{{end -}}
{{- if has $colTyp (list "bool" "[]byte" "time.Time")}}{{continue}}{{end -}}
{{- $.Importer.ImportList $typInfo.Imports -}}
{{$.Importer.Import "database/sql"}}
{{$.Importer.Import "database/sql/driver"}}
// Make sure the type {{$colTyp}} satisfies database/sql.Scanner
var _ sql.Scanner = &{{$colTyp}}{}

// Make sure the type {{$colTyp}} satisfies database/sql/driver.Valuer
var _ driver.Valuer = {{$colTyp}}{}

{{end -}}
{{- end}}
11 changes: 3 additions & 8 deletions types/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,16 @@ func (j JSON[T]) Value() (driver.Value, error) {

// Scan implements the Scanner interface.
func (j *JSON[T]) Scan(value any) error {
var err error

switch x := value.(type) {
case string:
err = json.NewDecoder(bytes.NewBuffer([]byte(x))).Decode(j)
return json.NewDecoder(bytes.NewBuffer([]byte(x))).Decode(j)
case []byte:
err = json.NewDecoder(bytes.NewBuffer(x)).Decode(j)
return json.NewDecoder(bytes.NewBuffer(x)).Decode(j)
case nil:
return nil

default:
err = fmt.Errorf("cannot scan type %T: %v", value, value)
return fmt.Errorf("cannot scan type %T: %v", value, value)
}

return err
}

// UnmarshalJSON implements json.Unmarshaler.
Expand Down
89 changes: 89 additions & 0 deletions types/marshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package types

import (
"database/sql/driver"
"encoding"
"fmt"
)

type Text[T interface {
encoding.TextMarshaler
}, Tp interface {
*T
encoding.TextUnmarshaler
}] struct {
Val T
}

func (t Text[T, Tp]) Value() (driver.Value, error) {
return t.Val.MarshalText()
}

func (t *Text[T, Tp]) Scan(value any) error {
switch x := value.(type) {
case string:
v := Tp(&t.Val)
return v.UnmarshalText([]byte(x))
case []byte:
v := Tp(&t.Val)
return v.UnmarshalText(x)
case nil:
return nil
default:
return fmt.Errorf("cannot scan type %T: %v", value, value)
}
}

type Binary[T interface {
encoding.BinaryMarshaler
}, Tp interface {
*T
encoding.BinaryUnmarshaler
}] struct {
Val T
}

func (b Binary[T, Tp]) Value() (driver.Value, error) {
return b.Val.MarshalBinary()
}

func (b *Binary[T, Tp]) Scan(value any) error {
switch x := value.(type) {
case string:
v := Tp(&b.Val)
return v.UnmarshalBinary([]byte(x))
case []byte:
v := Tp(&b.Val)
return v.UnmarshalBinary(x)
case nil:
return nil
default:
return fmt.Errorf("cannot scan type %T: %v", value, value)
}
}

type Stringer[T interface {
~[]byte | ~string
fmt.Stringer
}] struct {
Val T
}

func (s Stringer[T]) Value() (driver.Value, error) {
return []byte(s.Val), nil
}

func (s *Stringer[T]) Scan(value any) error {
switch x := value.(type) {
case string:
s.Val = T(x)
return nil
case []byte:
s.Val = T(x)
return nil
case nil:
return nil
default:
return fmt.Errorf("cannot scan type %T: %v", value, value)
}
}
2 changes: 1 addition & 1 deletion types/parray/enum_array.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (e *EnumArray[T]) Scan(src any) error {
// Value implements the driver.Valuer interface.
func (e EnumArray[T]) Value() (driver.Value, error) {
if e == nil {
return nil, nil
return nil, nil //nolint:nilnil
}

arr := make(pq.StringArray, len(e))
Expand Down

0 comments on commit ce8198f

Please sign in to comment.