forked from yuin/sesame
/
context.go
151 lines (134 loc) · 3.93 KB
/
context.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
package sesame
import (
"fmt"
"go/types"
"path"
"strings"
)
// MappingContext is a struct that holds contextual data for the code
// generation: the package being generated, the imports it needs, a counter
// for generated variables, and the mapper function fields registered so far.
type MappingContext struct {
	// absPkgPath is the absolute path of the package being generated.
	absPkgPath string
	// imports maps an import path to the alias assigned to it.
	imports map[string]string
	// varCount is the next value handed out by NextVarCount.
	varCount int
	// mapperFuncFields lists registered mapper function fields in
	// registration order.
	mapperFuncFields []*MapperFuncField
	// mapperFuncCount is the number of mapper function fields registered so
	// far; it is used to number generated field names.
	mapperFuncCount int
	// aliases counts how many imports have been assigned each alias.
	aliases map[string]int
}
// MapperFuncField describes a field that holds a mapper function converting
// a value of the Source type into a value of the Dest type. The function type
// of such a field is rendered by Signature as "func(Source) (Dest, error)".
type MapperFuncField struct {
	// FieldName is the name of the field.
	FieldName string
	// MapperFuncName is the name of the mapper function.
	MapperFuncName string
	// Source is the source type of the function.
	Source types.Type
	// Dest is the destination type of the function.
	Dest types.Type
}
// Signature renders the Go function type of the mapper held by this field in
// the form "func(<source>) (<dest>, error)". Each type is rendered with
// GetPreferableTypeSource using mctx, source type first.
func (m *MapperFuncField) Signature(mctx *MappingContext) string {
	srcType := GetPreferableTypeSource(m.Source, mctx)
	dstType := GetPreferableTypeSource(m.Dest, mctx)
	return fmt.Sprintf("func(%s) (%s, error)", srcType, dstType)
}
// NewMappingContext creates a [MappingContext] for generating code in the
// package at absPkgPath. The import and alias maps and the mapper field list
// start out empty (non-nil); all counters start at zero.
func NewMappingContext(absPkgPath string) *MappingContext {
	mc := new(MappingContext)
	mc.absPkgPath = absPkgPath
	mc.imports = make(map[string]string)
	mc.aliases = make(map[string]int)
	mc.mapperFuncFields = make([]*MapperFuncField, 0)
	return mc
}
// AbsolutePackagePath reports the absolute path of the package whose file
// this mapping context generates, as passed to NewMappingContext.
func (c *MappingContext) AbsolutePackagePath() (absPath string) {
	absPath = c.absPkgPath
	return absPath
}
// AddImport registers importpath and assigns it an alias that is unique
// within this context. Adding the package being generated, or a path that is
// already registered, is a no-op.
//
// The alias is derived from the last element of importpath, with '-' and '.'
// (which cannot appear in a Go identifier) replaced by '_'. If that alias is
// already taken, a numeric suffix starting at 2 is appended and incremented
// until an unused alias is found, e.g. "foo", "foo2", "foo3".
func (c *MappingContext) AddImport(importpath string) {
	if importpath == c.AbsolutePackagePath() {
		return
	}
	if _, ok := c.imports[importpath]; ok {
		return
	}
	_, last := path.Split(importpath)
	base := strings.NewReplacer("-", "_", ".", "_").Replace(last)
	alias := base
	// Probe successive suffixes instead of trusting a single per-base
	// counter: a third import with the same base name, or a package whose
	// own name already ends in a digit (e.g. "foo2"), would otherwise be
	// handed an alias that is already in use.
	for n := 2; c.aliases[alias] > 0; n++ {
		alias = fmt.Sprintf("%s%d", base, n)
	}
	c.imports[importpath] = alias
	c.aliases[alias]++
}
// GetImportAlias returns the alias for importPath, registering the import
// with AddImport first if it is not known yet. It returns "" when no alias
// is recorded, e.g. for the package being generated itself.
func (c *MappingContext) GetImportAlias(importPath string) string {
	c.AddImport(importPath)
	// A missing key yields the zero value "", which is the desired result.
	return c.imports[importPath]
}
// GetImportPath resolves alias back to its fully qualified import path.
// An empty alias resolves to AbsolutePackagePath. If no registered import
// uses alias, alias itself is returned unchanged.
func (c *MappingContext) GetImportPath(alias string) string {
	if alias == "" {
		return c.AbsolutePackagePath()
	}
	resolved := alias
	for importPath, candidate := range c.imports {
		if candidate == alias {
			resolved = importPath
			break
		}
	}
	return resolved
}
// Imports exposes every registered import as a map from import path to
// alias. The map returned is the context's own storage, not a copy, so
// changes made through it are visible to the context.
func (c *MappingContext) Imports() map[string]string {
	registered := c.imports
	return registered
}
// NextVarCount hands out sequential numbers: it advances the internal
// counter by one and returns the value it had before the call.
func (c *MappingContext) NextVarCount() int {
	c.varCount++
	return c.varCount - 1
}
// AddMapperFuncField registers a mapper function field that converts
// sourceType into destType. Nothing is registered when both types have the
// same qualified name, or when a field with the same mapper function name
// already exists. New fields are named "mapper00000", "mapper00001", ... in
// registration order.
func (c *MappingContext) AddMapperFuncField(sourceType types.Type, destType types.Type) {
	if GetQualifiedTypeName(sourceType) == GetQualifiedTypeName(destType) {
		return
	}
	name := mappersName(sourceType, destType)
	for _, existing := range c.mapperFuncFields {
		if existing.MapperFuncName == name {
			return
		}
	}
	field := &MapperFuncField{
		FieldName:      fmt.Sprintf("mapper%05d", c.mapperFuncCount),
		MapperFuncName: name,
		Source:         sourceType,
		Dest:           destType,
	}
	c.mapperFuncCount++
	c.mapperFuncFields = append(c.mapperFuncFields, field)
}
// GetMapperFuncFieldName returns the [MapperFuncField] that converts
// sourceType into destType, registering it with AddMapperFuncField first if
// needed. Despite its name, it returns the whole field descriptor rather
// than only the field name; read FieldName on the result for the name.
//
// It returns nil when no field with the matching mapper function name is
// registered. That can happen when sourceType and destType have the same
// qualified name, because AddMapperFuncField skips such pairs.
func (c *MappingContext) GetMapperFuncFieldName(sourceType types.Type, destType types.Type) *MapperFuncField {
	c.AddMapperFuncField(sourceType, destType)
	mapperFuncName := mappersName(sourceType, destType)
	for _, m := range c.mapperFuncFields {
		if m.MapperFuncName == mapperFuncName {
			return m
		}
	}
	return nil
}
// MapperFuncFields returns every registered [MapperFuncField] in the order
// it was added. The slice returned is the context's internal storage, not a
// copy.
func (c *MappingContext) MapperFuncFields() []*MapperFuncField {
	fields := c.mapperFuncFields
	return fields
}