Skip to content

Commit

Permalink
make the struct mapping output public
Browse files Browse the repository at this point in the history
  • Loading branch information
RangelReale committed Jun 6, 2023
1 parent 6903067 commit 3de7f7a
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 42 deletions.
2 changes: 1 addition & 1 deletion interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,5 @@ type TypeConverter interface {
type RowValidator = func(cols []string, vals []reflect.Value) bool

type StructMapperSource interface {
getMapping(reflect.Type) (mapping, error)
GetMapping(reflect.Type) (Mapping, error)
}
14 changes: 7 additions & 7 deletions mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,18 @@ func (v visited) copy() visited {
}

type mapinfo struct {
name string
position []int
init [][]int
isPointer bool
Name string
Position []int
Init [][]int
IsPointer bool
}

type mapping []mapinfo
type Mapping []mapinfo

func (m mapping) cols() []string {
func (m Mapping) cols() []string {
cols := make([]string, len(m))
for i, info := range m {
cols[i] = info.name
cols[i] = info.Name
}

return cols
Expand Down
24 changes: 12 additions & 12 deletions mapper_struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func structMapperFrom[T any](ctx context.Context, c cols, s StructMapperSource,
return ErrorMapper[T](err)
}

mapping, err := s.getMapping(typ)
mapping, err := s.GetMapping(typ)
if err != nil {
return ErrorMapper[T](err)
}
Expand Down Expand Up @@ -114,7 +114,7 @@ func WithMapperMods(mods ...MapperMod) MappingOption {
}
}

func mapperFromMapping[T any](m mapping, typ reflect.Type, isPointer bool, opts mappingOptions) func(context.Context, cols) (func(*Row) (any, error), func(any) (T, error)) {
func mapperFromMapping[T any](m Mapping, typ reflect.Type, isPointer bool, opts mappingOptions) func(context.Context, cols) (func(*Row) (any, error), func(any) (T, error)) {
return func(ctx context.Context, c cols) (func(*Row) (any, error), func(any) (T, error)) {
// Filter the mapping so we only ask for the available columns
filtered, err := filterColumns(ctx, c, m, opts.structTagPrefix)
Expand Down Expand Up @@ -142,7 +142,7 @@ func mapperFromMapping[T any](m mapping, typ reflect.Type, isPointer bool, opts
type regular[T any] struct {
isPointer bool
typ reflect.Type
filtered mapping
filtered Mapping
converter TypeConverter
validator RowValidator
}
Expand All @@ -157,7 +157,7 @@ func (s regular[T]) regular() (func(*Row) (any, error), func(any) (T, error)) {
}

for _, info := range s.filtered {
for _, v := range info.init {
for _, v := range info.Init {
pv := row.FieldByIndex(v)
if !pv.IsZero() {
continue
Expand All @@ -166,8 +166,8 @@ func (s regular[T]) regular() (func(*Row) (any, error), func(any) (T, error)) {
pv.Set(reflect.New(pv.Type().Elem()))
}

fv := row.FieldByIndex(info.position)
v.ScheduleScanx(info.name, fv.Addr())
fv := row.FieldByIndex(info.Position)
v.ScheduleScanx(info.Name, fv.Addr())
}

return row, nil
Expand All @@ -189,9 +189,9 @@ func (s regular[T]) allOptions() (func(*Row) (any, error), func(any) (T, error))
for i, info := range s.filtered {
var ft reflect.Type
if s.isPointer {
ft = s.typ.Elem().FieldByIndex(info.position).Type
ft = s.typ.Elem().FieldByIndex(info.Position).Type
} else {
ft = s.typ.FieldByIndex(info.position).Type
ft = s.typ.FieldByIndex(info.Position).Type
}

if s.converter != nil {
Expand All @@ -200,7 +200,7 @@ func (s regular[T]) allOptions() (func(*Row) (any, error), func(any) (T, error))
row[i] = reflect.New(ft)
}

v.ScheduleScanx(info.name, row[i])
v.ScheduleScanx(info.Name, row[i])
}

return row, nil
Expand All @@ -220,7 +220,7 @@ func (s regular[T]) allOptions() (func(*Row) (any, error), func(any) (T, error))
}

for i, info := range s.filtered {
for _, v := range info.init {
for _, v := range info.Init {
pv := row.FieldByIndex(v)
if !pv.IsZero() {
continue
Expand All @@ -236,8 +236,8 @@ func (s regular[T]) allOptions() (func(*Row) (any, error), func(any) (T, error))
val = vals[i].Elem()
}

fv := row.FieldByIndex(info.position)
if info.isPointer {
fv := row.FieldByIndex(info.Position)
if info.IsPointer {
fv.Elem().Set(val)
} else {
fv.Set(val)
Expand Down
4 changes: 2 additions & 2 deletions mapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,14 +429,14 @@ func TestScannable(t *testing.T) {
t.Fatalf("couldn't get mapper source: %v", err)
}

m, err := src.getMapping(reflect.TypeOf(BlogWithScannableUser{}))
m, err := src.GetMapping(reflect.TypeOf(BlogWithScannableUser{}))
if err != nil {
t.Fatalf("couldn't get mapping: %v", err)
}

var marked bool
for _, info := range m {
if info.name == "user" {
if info.Name == "user" {
marked = true
}
}
Expand Down
40 changes: 20 additions & 20 deletions source.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func newDefaultMapperSourceImpl() *mapperSourceImpl {
fieldMapperFn: snakeCaseFieldFunc,
scannableTypes: []reflect.Type{reflect.TypeOf((*sql.Scanner)(nil)).Elem()},
maxDepth: 3,
cache: make(map[reflect.Type]mapping),
cache: make(map[reflect.Type]Mapping),
}
}

Expand Down Expand Up @@ -118,11 +118,11 @@ type mapperSourceImpl struct {
fieldMapperFn func(string) string
scannableTypes []reflect.Type
maxDepth int
cache map[reflect.Type]mapping
cache map[reflect.Type]Mapping
mutex sync.RWMutex
}

func (s *mapperSourceImpl) getMapping(typ reflect.Type) (mapping, error) {
func (s *mapperSourceImpl) GetMapping(typ reflect.Type) (Mapping, error) {
s.mutex.RLock()
m, ok := s.cache[typ]
s.mutex.RUnlock()
Expand All @@ -140,7 +140,7 @@ func (s *mapperSourceImpl) getMapping(typ reflect.Type) (mapping, error) {
return m, nil
}

func (s *mapperSourceImpl) setMappings(typ reflect.Type, prefix string, v visited, m *mapping, inits [][]int, position ...int) {
func (s *mapperSourceImpl) setMappings(typ reflect.Type, prefix string, v visited, m *Mapping, inits [][]int, position ...int) {
count := v[typ]
if count > s.maxDepth {
return
Expand All @@ -160,10 +160,10 @@ func (s *mapperSourceImpl) setMappings(typ reflect.Type, prefix string, v visite
for _, scannable := range s.scannableTypes {
if reflect.PtrTo(typ).Implements(scannable) {
*m = append(*m, mapinfo{
name: prefix,
position: position,
init: inits,
isPointer: isPointer,
Name: prefix,
Position: position,
Init: inits,
IsPointer: isPointer,
})
return
}
Expand Down Expand Up @@ -219,28 +219,28 @@ func (s *mapperSourceImpl) setMappings(typ reflect.Type, prefix string, v visite
}

*m = append(*m, mapinfo{
name: key,
position: currentIndex,
init: inits,
isPointer: isPointer,
Name: key,
Position: currentIndex,
Init: inits,
IsPointer: isPointer,
})
}

// If it has no exported field (such as time.Time) then we attempt to
// directly scan into it
if !hasExported {
*m = append(*m, mapinfo{
name: prefix,
position: position,
init: inits,
isPointer: isPointer,
Name: prefix,
Position: position,
Init: inits,
IsPointer: isPointer,
})
}
}

func filterColumns(ctx context.Context, c cols, m mapping, prefix string) (mapping, error) {
func filterColumns(ctx context.Context, c cols, m Mapping, prefix string) (Mapping, error) {
// Filter the mapping so we only ask for the available columns
filtered := make(mapping, 0, len(c))
filtered := make(Mapping, 0, len(c))
for _, name := range c {
key := name
if prefix != "" {
Expand All @@ -252,8 +252,8 @@ func filterColumns(ctx context.Context, c cols, m mapping, prefix string) (mappi
}

for _, info := range m {
if key == info.name {
info.name = name
if key == info.Name {
info.Name = name
filtered = append(filtered, info)
break
}
Expand Down

0 comments on commit 3de7f7a

Please sign in to comment.