Skip to content

Commit

Permalink
Merge pull request #215 from stephenafamo/wrap-marshallable-types
Browse files Browse the repository at this point in the history
Add wrapper types for Text/Binary Marshaler and Stringer
  • Loading branch information
stephenafamo committed May 21, 2024
2 parents 77e9ec0 + b09d8a2 commit a462d60
Show file tree
Hide file tree
Showing 16 changed files with 221 additions and 47 deletions.
2 changes: 1 addition & 1 deletion .golangci.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Options for analysis running.
run:
timeout: 15m
skip-files: [scanto/fakedb_test.go]

linters:
# # Disable all linters.
Expand Down Expand Up @@ -65,3 +64,4 @@ issues:
- gocyclo
- dupl
- gosec
- nilnil
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add `bobgen-sql` a code generation driver for SQL schema files. Supports PostgreSQL and SQLite.
- Add new properties `compare_expr` and `compare_expr_imports` to the `types` configuration. This is used when comparing primary keys and in testing.
- Add `never_required` to relationships configuration. This makes sure the factories does not require the relationship to be set. Useful if you're not using foreign keys. (thanks @jacobmolby)
- Add wrapper types for Stringer, TextMarshaler/Unmarshaler, and BinaryMarshaler/Unmarshaler to the `types` configuration.
- Make generated enum types implement the `fmt.Stringer`, `encoding.TextMarshaler`, `encoding.TextUnmarshaler`, `encoding.BinaryMarshaler` and `encoding.BinaryUnmarshaler` interfaces.

### Fixed

Expand Down
2 changes: 1 addition & 1 deletion gen/bobgen-atlas/driver/atlas.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ func (d *driver) translateColumnType(c drivers.Column, tableKey string, typ sche
}

if enum, ok := d.enums[enumName]; ok {
c.Type = enum.Type
c.Type = helpers.EnumType(d.types, enum.Type)
} else {
c.Type = "string"
}
Expand Down
35 changes: 22 additions & 13 deletions gen/bobgen-helpers/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,7 @@ func GetConfigFromProvider[DriverConfig any](provider koanf.Provider, driverConf
return config, driverConfig, nil
}

const parrayImport = `"github.com/stephenafamo/bob/types/parray"`

func AddPgEnumType(types drivers.Types, enum string) string {
func EnumType(types drivers.Types, enum string) string {
types[enum] = drivers.Type{
NoRandomizationTest: true, // enums are often not random enough
RandomExpr: fmt.Sprintf(`all := all%s()
Expand All @@ -148,6 +146,8 @@ func AddPgEnumType(types drivers.Types, enum string) string {
return enum
}

const parrayImport = `"github.com/stephenafamo/bob/types/parray"`

func AddPgEnumArrayType(types drivers.Types, enum string) string {
typ := fmt.Sprintf("parray.EnumArray[%s]", enum)

Expand Down Expand Up @@ -195,29 +195,38 @@ func AddPgGenericArrayType(types drivers.Types, singleTyp string) string {
func Types() drivers.Types {
return drivers.Types{
"[]byte": {
CompareExpr: `bytes.Equal(AAA, BBB)`,
CompareExprImports: importers.List{`"bytes"`},
CompareExpr: `bytes.Equal(AAA, BBB)`,
CompareExprImports: importers.List{`"bytes"`},
NoScannerValuerTest: true,
},
"time.Time": {
Imports: importers.List{`"time"`},
RandomExpr: `year := time.Hour * 24 * 365
min := time.Now().Add(-year)
max := time.Now().Add(year)
return any(f.Time().TimeBetween(min, max)).(T)`,
CompareExpr: `AAA.Equal(BBB)`,
CompareExpr: `AAA.Equal(BBB)`,
NoScannerValuerTest: true,
},
"netip.Addr": {
Imports: importers.List{`"net/netip"`},
"types.Text[netip.Addr, *netip.Addr]": {
Imports: importers.List{
`"net/netip"`,
`"github.com/stephenafamo/bob/types"`,
},
RandomExpr: `var addr [4]byte
rand.Read(addr[:])
return any(netip.AddrFrom4(addr)).(T)`,
ipAddr := netip.AddrFrom4(addr)
return any(types.Text[netip.Addr, *netip.Addr]{Val: ipAddr}).(T)`,
RandomExprImports: importers.List{`"crypto/rand"`},
},
"net.HardwareAddr": {
Imports: importers.List{`"net"`},
"types.Stringer[net.HardwareAddr]": {
Imports: importers.List{
`"net"`,
`"github.com/stephenafamo/bob/types"`,
},
RandomExpr: `addr, _ := net.ParseMAC(f.Internet().MacAddress())
return any(addr).(T)`,
CompareExpr: `slices.Equal(AAA, BBB)`,
return any(types.Stringer[net.HardwareAddr]{Val: addr}).(T)`,
CompareExpr: `slices.Equal(AAA.Val, BBB.Val)`,
CompareExprImports: importers.List{`"slices"`},
},
"pq.BoolArray": {
Expand Down
8 changes: 5 additions & 3 deletions gen/bobgen-mysql/driver/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func New(config Config) Interface {
config.Concurrency = 10
}

return &driver{config: config}
return &driver{config: config, types: helpers.Types()}
}

// driver holds the database connection string and a handle
Expand All @@ -59,6 +59,7 @@ type driver struct {
conn *sql.DB
dbName string

types drivers.Types
enums []drivers.Enum
enumMu sync.Mutex
}
Expand Down Expand Up @@ -222,7 +223,8 @@ func (d *driver) TableDetails(ctx context.Context, info drivers.TableInfo, colFi
column = d.translateColumnType(column, colFullType)
} else {
enumTyp := strmangle.TitleCase(tableName + "_" + colName)
column.Type = enumTyp
column.Type = helpers.EnumType(d.types, enumTyp)

d.enumMu.Lock()
d.enums = append(d.enums, drivers.Enum{
Type: enumTyp,
Expand Down Expand Up @@ -302,7 +304,7 @@ func (*driver) translateColumnType(c drivers.Column, fullType string) drivers.Co
}

func (d *driver) Types() drivers.Types {
return helpers.Types()
return d.types
}

func (d *driver) Constraints(ctx context.Context, _ drivers.ColumnFilter) (drivers.DBConstraints, error) {
Expand Down
2 changes: 1 addition & 1 deletion gen/bobgen-prisma/driver/prisma.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ func (d *driver) translateColumnType(c drivers.Column, isArray bool) drivers.Col
c.Type = "types.JSON[json.RawMessage]"
default:
if enum, ok := d.enums[c.DBType]; ok {
c.Type = enum.Type
c.Type = helpers.EnumType(d.types, enum.Type)
} else {
c.Type = "string"
}
Expand Down
24 changes: 12 additions & 12 deletions gen/bobgen-psql/driver/psql.golden.json
Original file line number Diff line number Diff line change
Expand Up @@ -1247,7 +1247,7 @@
"generated": false,
"autoincr": false,
"domain_name": "",
"type": "netip.Addr"
"type": "types.Text[netip.Addr, *netip.Addr]"
},
{
"name": "cidr_nnull",
Expand All @@ -1258,7 +1258,7 @@
"generated": false,
"autoincr": false,
"domain_name": "",
"type": "netip.Addr"
"type": "types.Text[netip.Addr, *netip.Addr]"
},
{
"name": "circle_null",
Expand Down Expand Up @@ -1313,7 +1313,7 @@
"generated": false,
"autoincr": false,
"domain_name": "",
"type": "netip.Addr"
"type": "types.Text[netip.Addr, *netip.Addr]"
},
{
"name": "inet_nnull",
Expand All @@ -1324,7 +1324,7 @@
"generated": false,
"autoincr": false,
"domain_name": "",
"type": "netip.Addr"
"type": "types.Text[netip.Addr, *netip.Addr]"
},
{
"name": "line_null",
Expand Down Expand Up @@ -1379,7 +1379,7 @@
"generated": false,
"autoincr": false,
"domain_name": "",
"type": "net.HardwareAddr"
"type": "types.Stringer[net.HardwareAddr]"
},
{
"name": "macaddr_nnull",
Expand All @@ -1390,7 +1390,7 @@
"generated": false,
"autoincr": false,
"domain_name": "",
"type": "net.HardwareAddr"
"type": "types.Stringer[net.HardwareAddr]"
},
{
"name": "money_null",
Expand Down Expand Up @@ -3046,7 +3046,7 @@
"generated": false,
"autoincr": false,
"domain_name": "",
"type": "netip.Addr"
"type": "types.Text[netip.Addr, *netip.Addr]"
},
{
"name": "cidr_nnull",
Expand All @@ -3057,7 +3057,7 @@
"generated": false,
"autoincr": false,
"domain_name": "",
"type": "netip.Addr"
"type": "types.Text[netip.Addr, *netip.Addr]"
},
{
"name": "circle_null",
Expand Down Expand Up @@ -3112,7 +3112,7 @@
"generated": false,
"autoincr": false,
"domain_name": "",
"type": "netip.Addr"
"type": "types.Text[netip.Addr, *netip.Addr]"
},
{
"name": "inet_nnull",
Expand All @@ -3123,7 +3123,7 @@
"generated": false,
"autoincr": false,
"domain_name": "",
"type": "netip.Addr"
"type": "types.Text[netip.Addr, *netip.Addr]"
},
{
"name": "line_null",
Expand Down Expand Up @@ -3178,7 +3178,7 @@
"generated": false,
"autoincr": false,
"domain_name": "",
"type": "net.HardwareAddr"
"type": "types.Stringer[net.HardwareAddr]"
},
{
"name": "macaddr_nnull",
Expand All @@ -3189,7 +3189,7 @@
"generated": false,
"autoincr": false,
"domain_name": "",
"type": "net.HardwareAddr"
"type": "types.Stringer[net.HardwareAddr]"
},
{
"name": "money_null",
Expand Down
6 changes: 3 additions & 3 deletions gen/bobgen-psql/driver/translate.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,14 @@ func (d *driver) translateColumnType(c drivers.Column, info colInfo) drivers.Col
case "uuid":
c.Type = "uuid.UUID"
case "inet", "cidr":
c.Type = "netip.Addr"
c.Type = "types.Text[netip.Addr, *netip.Addr]"
case "macaddr":
c.Type = "net.HardwareAddr"
c.Type = "types.Stringer[net.HardwareAddr]"
case "ENUM":
c.Type = "string"
for _, e := range d.enums {
if e.Schema == info.UDTSchema && e.Name == info.UDTName {
c.Type = helpers.AddPgEnumType(d.types, e.Type)
c.Type = helpers.EnumType(d.types, e.Type)
}
}
case "ARRAY":
Expand Down
4 changes: 4 additions & 0 deletions gen/drivers/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ type Type struct {
// Set this to true if the randomization should not be tested
// this is useful for low-cardinality types like bool
NoRandomizationTest bool `yaml:"no_randomization_test"`
// Set this to true if the test to see if the type implements
// the scanner and valuer interfaces should be skipped
// this is useful for types that are based on a primitive type
NoScannerValuerTest bool `yaml:"no_scanner_valuer_test"`
// CompareExpr is used to compare two values of this type
// if not provided, == is used
// Used AAA and BBB as placeholders for the two values
Expand Down
11 changes: 9 additions & 2 deletions gen/templates/factory/singleton/bobfactory_random_test.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,15 @@
{{- if eq $colTyp "bool"}}{{continue}}{{end -}}
{{- if $typInfo.NoRandomizationTest}}{{continue}}{{end -}}
{{- $.Importer.ImportList $typInfo.Imports -}}

func TestRandom_{{$colTyp | replace "." "_" | replace "[" "_" | replace "]" "_"}}(t *testing.T) {
func TestRandom_{{
$colTyp
| replace " " "_"
| replace "." "_"
| replace "," "_"
| replace "*" "_"
| replace "[" "_"
| replace "]" "_"
}}(t *testing.T) {
t.Parallel()
seen := make([]{{$colTyp}}, 10)
Expand Down
47 changes: 46 additions & 1 deletion gen/templates/models/singleton/bob_enums.go.tpl
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
{{if .Enums}}
{{$.Importer.Import "fmt"}}
{{$.Importer.Import "database/sql/driver"}}
{{end}}

{{- range $enum := $.Enums}}
{{$allvals := "\n"}}
type {{$enum.Type}} string

// Enum values for {{$enum.Type}}
const (
Expand All @@ -15,5 +19,46 @@
return []{{$enum.Type}}{ {{$allvals}} }
}

type {{$enum.Type}} string

func (e {{$enum.Type}}) String() string {
return string(e)
}

func (e {{$enum.Type}}) MarshalText() ([]byte, error) {
return []byte(e), nil
}

func (e *{{$enum.Type}}) UnmarshalText(text []byte) error {
return e.Scan(text)
}

func (e {{$enum.Type}}) MarshalBinary() ([]byte, error) {
return []byte(e), nil
}

func (e *{{$enum.Type}}) UnmarshalBinary(data []byte) error {
return e.Scan(data)
}

func (e {{$enum.Type}}) Value() (driver.Value, error) {
return string(e), nil
}

func (e *{{$enum.Type}}) Scan(value any) error {
switch x := value.(type) {
case string:
*e = {{$enum.Type}}(x)
return nil
case []byte:
*e = {{$enum.Type}}(x)
return nil
case nil:
return nil
default:
return fmt.Errorf("cannot scan type %T: %v", value, value)
}
}

{{end -}}

21 changes: 21 additions & 0 deletions gen/templates/models/singleton/bob_main_test.go.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
{{$doneTypes := dict }}
{{- range $table := .Tables}}
{{- $tAlias := $.Aliases.Table $table.Key}}
{{range $column := $table.Columns -}}
{{- $colTyp := $column.Type -}}
{{- if hasKey $doneTypes $colTyp}}{{continue}}{{end -}}
{{- $_ := set $doneTypes $colTyp nil -}}
{{- $typInfo := index $.Types $column.Type -}}
{{- if $typInfo.NoScannerValuerTest}}{{continue}}{{end -}}
{{- if isPrimitiveType $colTyp}}{{continue}}{{end -}}
{{- $.Importer.ImportList $typInfo.Imports -}}
{{$.Importer.Import "database/sql"}}
{{$.Importer.Import "database/sql/driver"}}
// Make sure the type {{$colTyp}} satisfies database/sql.Scanner
var _ sql.Scanner = (*{{$colTyp}})(nil)

// Make sure the type {{$colTyp}} satisfies database/sql/driver.Valuer
var _ driver.Valuer = *new({{$colTyp}})

{{end -}}
{{- end}}
2 changes: 1 addition & 1 deletion types/hstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func (h *HStore) Scan(value any) error {
// database column value will be set to NULL.
func (h HStore) Value() (driver.Value, error) {
if h == nil {
return nil, nil
return nil, nil //nolint:nilnil
}
parts := []string{}
for key, val := range h {
Expand Down
Loading

0 comments on commit a462d60

Please sign in to comment.