diff --git a/CHANGELOG.md b/CHANGELOG.md index 30db969b..05cc1cb0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add new properties `compare_expr` and `compare_expr_imports` to the `types` configuration. This is used when comparing primary keys and in testing. - Add `never_required` to relationships configuration. This makes sure the factories does not require the relationship to be set. Useful if you're not using foreign keys. (thanks @jacobmolby) - Add wrapper types for Stringer, TextMarshaler/Unmarshaler, and BinaryMarshaler/Unmarshaler to the `types` configuration. +- Make generated enum types implement the `fmt.Stringer`, `encoding.TextMarshaler`, `encoding.TextUnmarshaler`, `encoding.BinaryMarshaler` and `encoding.BinaryUnmarshaler` interfaces. ### Fixed diff --git a/gen/templates/models/singleton/bob_enums.go.tpl b/gen/templates/models/singleton/bob_enums.go.tpl index e01d19a3..e514e796 100644 --- a/gen/templates/models/singleton/bob_enums.go.tpl +++ b/gen/templates/models/singleton/bob_enums.go.tpl @@ -1,6 +1,10 @@ +{{if .Enums}} +{{$.Importer.Import "fmt"}} +{{$.Importer.Import "database/sql/driver"}} +{{end}} + {{- range $enum := $.Enums}} {{$allvals := "\n"}} - type {{$enum.Type}} string // Enum values for {{$enum.Type}} const ( @@ -15,5 +19,46 @@ return []{{$enum.Type}}{ {{$allvals}} } } + type {{$enum.Type}} string + + func (e {{$enum.Type}}) String() string { + return string(e) + } + + func (e {{$enum.Type}}) MarshalText() ([]byte, error) { + return []byte(e), nil + } + + func (e *{{$enum.Type}}) UnmarshalText(text []byte) error { + return e.Scan(text) + } + + func (e {{$enum.Type}}) MarshalBinary() ([]byte, error) { + return []byte(e), nil + } + + func (e *{{$enum.Type}}) UnmarshalBinary(data []byte) error { + return e.Scan(data) + } + + func (e {{$enum.Type}}) Value() (driver.Value, error) { + return string(e), nil + } + + func (e *{{$enum.Type}}) Scan(value any) error { + switch x := value.(type) { + case string: + *e = {{$enum.Type}}(x) + return nil + case []byte: + *e = {{$enum.Type}}(x) + return nil + case nil: + return nil + default: + return fmt.Errorf("cannot scan type %T: %v", value, value) + } + } + {{end -}} diff --git a/gen/templates/models/singleton/bob_main_test.go.tpl b/gen/templates/models/singleton/bob_main_test.go.tpl index 319810ae..16155d61 100644 --- a/gen/templates/models/singleton/bob_main_test.go.tpl +++ b/gen/templates/models/singleton/bob_main_test.go.tpl @@ -6,17 +6,16 @@ {{- if hasKey $doneTypes $colTyp}}{{continue}}{{end -}} {{- $_ := set $doneTypes $colTyp nil -}} {{- $typInfo := index $.Types $column.Type -}} - {{- if $typInfo.NoRandomizationTest}}{{continue}}{{end -}} + {{- if $typInfo.NoScannerValuerTest}}{{continue}}{{end -}} {{- if isPrimitiveType $colTyp}}{{continue}}{{end -}} - {{- if has $colTyp (list "bool" "[]byte" "time.Time")}}{{continue}}{{end -}} {{- $.Importer.ImportList $typInfo.Imports -}} {{$.Importer.Import "database/sql"}} {{$.Importer.Import "database/sql/driver"}} // Make sure the type {{$colTyp}} satisfies database/sql.Scanner - var _ sql.Scanner = &{{$colTyp}}{} + var _ sql.Scanner = (*{{$colTyp}})(nil) // Make sure the type {{$colTyp}} satisfies database/sql/driver.Valuer - var _ driver.Valuer = {{$colTyp}}{} + var _ driver.Valuer = *new({{$colTyp}}) {{end -}} {{- end}}