Skip to content

Commit

Permalink
Add context and dest for mapper functions
Browse files Browse the repository at this point in the history
  • Loading branch information
yuin committed Apr 6, 2024
1 parent 1dd50c8 commit 382a4d1
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 66 deletions.
31 changes: 18 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ type TodoMapperHelper interface {
}

type TodoMapper interface {
TodoModelToTodo(pkg00000.Context, *pkg00001.TodoModel) (*pkg00002.Todo, error)
TodoToTodoModel(pkg00000.Context, *pkg00002.Todo) (*pkg00001.TodoModel, error)
TodoModelToTodo(pkg00000.Context, *pkg00001.TodoModel, *pkg00002.Todo) error
TodoToTodoModel(pkg00000.Context, *pkg00002.Todo, *pkg00001.TodoModel) error
}

// ... (TodoMapper default implementation)
Expand All @@ -162,7 +162,8 @@ Mapping codes look like the following:
t.Fatal(err)
}
todoMapper, _ := obj.(TodoMapper)
entity, err := todoMapper.ModelToEntity(ctx, model)
var entity Todo
err := todoMapper.ModelToEntity(ctx, model, &entity)
```
### Custom mappers
Expand All @@ -180,16 +181,18 @@ Example: `string <-> time.Time` mapper
type TimeStringMapper struct {
}

func (m *TimeStringMapper) StringToTime(ctx context.Context, source string) (*time.Time, error) {
t, err := time.Parse(time.RFC3339, source)
if err != nil {
return nil, err
}
return &t, nil
func (m *TimeStringMapper) StringToTime(ctx context.Context, source string, dest *time.Time) error {
t, err := time.Parse(time.RFC3339, source)
if err != nil {
return err
}
*dest = t
return nil
}

func (m *TimeStringMapper) TimeToString(ctx context.Context, source *time.Time) (string, error) {
return source.Format(time.RFC3339), nil
func (m *TimeStringMapper) TimeToString(ctx context.Context, source *time.Time, dest *string) error {
*dest = source.Format(time.RFC3339)
return nil
}

type Mappers interface {
Expand Down Expand Up @@ -224,12 +227,14 @@ func AddTimeToStringMapper(mappers Mappers) {
`Mappers.AddMapperFuncFactory` takes qualified type names as arguments. A qualified type name is `FULL_PACKAGE_PATH#TYPENAME`(i.e. `time#Time`, `example.com/testmod/domain#Todo`).
Argument types and return types in custom mapping functions must be a
Source argument types in custom mapping functions must be a
- Raw value: primitive types(i.e. `string`, `int`, `slice` ...)
- Pointer: others
So `func (m *TimeStringMapper) TimeToString(ctx context.Context, source *time.Time) (string, error)` defines source type as a pointer(`*time.Time`) and return type as a raw value(`string`) .
Destination arguments are pointers.
So `func (m *TimeStringMapper) TimeToString(ctx context.Context, source *time.Time, dest *string) error` defines source type as a pointer(`*time.Time`).

`Mappers.Add` finds given mapper methods name like 'XxxToYyy' and calls `AddMapperFuncFactory`.

Expand Down
4 changes: 2 additions & 2 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ type MapperFuncField struct {

// Signature returns a function signature.
func (m *MapperFuncField) Signature(mctx *MappingContext) string {
return fmt.Sprintf("func(%s.Context, %s) (%s, error)",
return fmt.Sprintf("func(%s.Context, %s, %s) error",
mctx.GetImportAlias("context"),
GetPreferableTypeSource(m.Source, mctx),
GetPreferableTypeSource(m.Dest, mctx))
GetPointerTypeSource(m.Dest, mctx))
}

// NewMappingContext returns new [MappingContext] .
Expand Down
114 changes: 81 additions & 33 deletions gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,9 @@ type MappingValue interface {
// CanGet returns true if this value is readable.
CanGet() bool

// CanAddr returns true if this value with getter is addressable.
CanAddr() bool

// GetSetterSource returns a source code of the setter.
GetSetterSource(valueSource string) string

Expand Down Expand Up @@ -431,6 +434,10 @@ func (v *localMappingValue) CanGet() bool {
return true
}

func (v *localMappingValue) CanAddr() bool {
return true
}

func (v *localMappingValue) GetSetterSource(valueSource string) string {
return fmt.Sprintf("%s = %s", v.name, valueSource)
}
Expand Down Expand Up @@ -514,6 +521,19 @@ func (v *objectPropertyMappingValue) GetGetterSource() string {
return ""
}

func (v *objectPropertyMappingValue) CanAddr() bool {
if len(v.exportedFieldName) != 0 {
return true
}
if v.getter != nil {
typ := v.getter.Type().(*types.Signature).Results().At(0).Type()
if _, ok := typ.(*types.Pointer); ok {
return true
}
}
return false
}

func (v *objectPropertyMappingValue) CanGet() bool {
return len(v.exportedFieldName) > 0 || v.getter != nil
}
Expand Down Expand Up @@ -753,7 +773,7 @@ func (g *generator) Generate() error {
a := elem.A
b := elem.B
aArgSource := GetPreferableTypeSource(a.Type(), mctx)
bArgSource := GetPreferableTypeSource(b.Type(), mctx)
bArgSource := GetPointerTypeSource(b.Type(), mctx)
p("type %sHelper interface {", mapping.Name)
p(" %s(%s.Context, %s, %s) error", mapping.MethodName(OperandA),
mctx.GetImportAlias("context"), aArgSource, bArgSource)
Expand All @@ -764,10 +784,10 @@ func (g *generator) Generate() error {
p("}")
p("")
p("type %s interface {", mapping.Name)
p("%s(%s.Context, %s) (%s, error) ", mapping.MethodName(OperandA),
p("%s(%s.Context, %s, %s) error", mapping.MethodName(OperandA),
mctx.GetImportAlias("context"), aArgSource, bArgSource)
if mapping.Bidirectional {
p("%s(%s.Context, %s) (%s, error) ", mapping.MethodName(OperandB),
p("%s(%s.Context, %s, %s) error", mapping.MethodName(OperandB),
mctx.GetImportAlias("context"), bArgSource, aArgSource)
}
p("}")
Expand Down Expand Up @@ -921,19 +941,18 @@ func genMapFunc(printer Printer, mapping *Mapping,
source types.Object, dest types.Object, typ OperandType, mctx *MappingContext) error {
p := printer.P

p("func (m *%s) %s(ctx %s.Context, source *%s) (*%s, error) {",
p("func (m *%s) %s(ctx %s.Context, source *%s, dest *%s) error {",
mapping.PrivateName(), mapping.MethodName(typ), mctx.GetImportAlias("context"),
GetSource(source.Type(), mctx), GetSource(dest.Type(), mctx))
p(" dest := &%s{}", GetSource(dest.Type(), mctx))
if err := genMapFuncBody(printer, source, "source", dest, "dest", &mapping.ObjectMapping, typ, mctx); err != nil {
return err
}
p(" if m.helper != nil {")
p(" if err := m.helper.%s(ctx, source, dest); err != nil {", mapping.MethodName(typ))
p(" return nil, err")
p(" return err")
p(" }")
p(" }")
p(" return dest, nil")
p(" return nil")
p("}")

return nil
Expand Down Expand Up @@ -1184,39 +1203,53 @@ func genAssignStmt(printer Printer,
sourceIsPointerPreferable := IsPointerPreferableType(sourceType)
destTypeName := GetQualifiedTypeName(destType)
_, destIsPointer := destType.(*types.Pointer)
destIsPointerPreferable := IsPointerPreferableType(destType)

// Try to execute custom mapper
argName := ""
switch {
case sourceIsPointerPreferable && sourceIsPointer:
argName = sourceSig
case sourceIsPointerPreferable && !sourceIsPointer:
argName = "&(" + sourceSig + ")"
case !sourceIsPointerPreferable && sourceIsPointer:
argName = "*(" + sourceSig + ")"
case !sourceIsPointerPreferable && !sourceIsPointer:
argName = sourceSig
}

mf := mctx.GetMapperFuncFieldName(sourceType, destType)
if mf != nil {
p("if m.%s != nil {", mf.FieldName)
p(" if v, err := m.%s(ctx, %s); err != nil {", mf.FieldName, argName)
p(" return nil, err")
p(" } else {")
var argName string
switch {
case destIsPointer && destIsPointerPreferable:
p(destValue.GetSetterSource("v"))
case destIsPointer && !destIsPointerPreferable:
p(destValue.GetSetterSource("&v"))
case !destIsPointer && destIsPointerPreferable:
p(destValue.GetSetterSource("*v"))
case !destIsPointer && !destIsPointerPreferable:
case sourceIsPointerPreferable && sourceIsPointer:
argName = sourceSig
case sourceIsPointerPreferable && !sourceIsPointer:
if sourceValue.CanAddr() {
argName = "&(" + sourceSig + ")"
} else {
p("s := %s", sourceSig)
argName = "&s"
}
case !sourceIsPointerPreferable && sourceIsPointer:
argName = "*(" + sourceSig + ")"
case !sourceIsPointerPreferable && !sourceIsPointer:
argName = sourceSig
}

var destName string
if destValue.CanAddr() {
if destIsPointer {
destName = destValue.GetGetterSource()
} else {
destName = "&(" + destValue.GetGetterSource() + ")"
}
} else {
p("var v %s", GetSource(destType, mctx))
if destIsPointer {
destName = "v"
} else {
destName = "&v"
}
}
p(" if err := m.%s(ctx, %s, %s); err != nil {", mf.FieldName, argName, destName)
p(" return err")
if destValue.CanAddr() {
p("}")
} else {
p(" } else {")
p(destValue.GetSetterSource("v"))
p(" }")
}
p(" }")
p("}")
p("} else { ")
}

if sourceTypeName == destTypeName {
Expand All @@ -1226,19 +1259,34 @@ func genAssignStmt(printer Printer,
case sourceIsPointer && !destIsPointer:
p(destValue.GetSetterSource("*(" + sourceSig + ")"))
case !sourceIsPointer && destIsPointer:
p(destValue.GetSetterSource("&(" + sourceSig + ")"))
if sourceValue.CanAddr() {
p(destValue.GetSetterSource("&(" + sourceSig + ")"))
} else {
a := mctx.NextVarCount()
p("s%d := ", a, sourceSig)
p(destValue.GetSetterSource(fmt.Sprintf("&s%d", a)))
}
case !sourceIsPointer && !destIsPointer:
p(destValue.GetSetterSource(sourceSig))

}
if mf != nil {
p("}")
}
return
}

if CanCast(sourceType, destType) {
genAssignStmt(printer,
NewLocalMappingValue(fmt.Sprintf("%s(%s)", GetSource(destType, mctx), sourceSig), destType),
destValue, mctx)
if mf != nil {
p("}")
}
return
}
if mf != nil {
p("}")
}

}
4 changes: 2 additions & 2 deletions mappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,14 @@ func (d *mappers) Add(name string, obj any) {
method := typ.Method(i)
if mapperFuncNamePattern.MatchString(method.Name) {
ft := reflect.TypeOf(method.Func.Interface())
if ft.NumIn() != 3 || ft.NumOut() != 2 {
if ft.NumIn() != 4 || ft.NumOut() != 1 {
continue
}
in := ft.In(2)
if in.Kind() == reflect.Ptr {
in = in.Elem()
}
out := ft.Out(0)
out := ft.In(3)
if out.Kind() == reflect.Ptr {
out = out.Elem()
}
Expand Down
12 changes: 7 additions & 5 deletions testdata/testmod/mapper/my_mappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,18 @@ import (
type TimeStringMapper struct {
}

func (m *TimeStringMapper) StringToTime(ctx context.Context, source string) (*time.Time, error) {
func (m *TimeStringMapper) StringToTime(ctx context.Context, source string, dest *time.Time) error {
t, err := time.Parse(time.RFC3339, source)
if err != nil {
return nil, err
return err
}
return &t, nil
*dest = t
return nil
}

func (m *TimeStringMapper) TimeToString(ctx context.Context, source *time.Time) (string, error) {
return source.Format(time.RFC3339), nil
func (m *TimeStringMapper) TimeToString(ctx context.Context, source *time.Time, dest *string) error {
*dest = source.Format(time.RFC3339)
return nil
}

func AddTimeToStringMapper(mappers interface {
Expand Down

0 comments on commit 382a4d1

Please sign in to comment.