Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Make the struct mapping output public" #6

Merged
merged 1 commit into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,5 @@ type TypeConverter interface {
type RowValidator = func(cols []string, vals []reflect.Value) bool

type StructMapperSource interface {
GetMapping(reflect.Type) (Mapping, error)
getMapping(reflect.Type) (mapping, error)
}
14 changes: 7 additions & 7 deletions mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,18 @@ func (v visited) copy() visited {
}

type mapinfo struct {
Name string
Position []int
Init [][]int
IsPointer bool
name string
position []int
init [][]int
isPointer bool
}

type Mapping []mapinfo
type mapping []mapinfo

func (m Mapping) cols() []string {
func (m mapping) cols() []string {
cols := make([]string, len(m))
for i, info := range m {
cols[i] = info.Name
cols[i] = info.name
}

return cols
Expand Down
24 changes: 12 additions & 12 deletions mapper_struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func structMapperFrom[T any](ctx context.Context, c cols, s StructMapperSource,
return ErrorMapper[T](err)
}

mapping, err := s.GetMapping(typ)
mapping, err := s.getMapping(typ)
if err != nil {
return ErrorMapper[T](err)
}
Expand Down Expand Up @@ -114,7 +114,7 @@ func WithMapperMods(mods ...MapperMod) MappingOption {
}
}

func mapperFromMapping[T any](m Mapping, typ reflect.Type, isPointer bool, opts mappingOptions) func(context.Context, cols) (func(*Row) (any, error), func(any) (T, error)) {
func mapperFromMapping[T any](m mapping, typ reflect.Type, isPointer bool, opts mappingOptions) func(context.Context, cols) (func(*Row) (any, error), func(any) (T, error)) {
return func(ctx context.Context, c cols) (func(*Row) (any, error), func(any) (T, error)) {
// Filter the mapping so we only ask for the available columns
filtered, err := filterColumns(ctx, c, m, opts.structTagPrefix)
Expand Down Expand Up @@ -142,7 +142,7 @@ func mapperFromMapping[T any](m Mapping, typ reflect.Type, isPointer bool, opts
type regular[T any] struct {
isPointer bool
typ reflect.Type
filtered Mapping
filtered mapping
converter TypeConverter
validator RowValidator
}
Expand All @@ -157,7 +157,7 @@ func (s regular[T]) regular() (func(*Row) (any, error), func(any) (T, error)) {
}

for _, info := range s.filtered {
for _, v := range info.Init {
for _, v := range info.init {
pv := row.FieldByIndex(v)
if !pv.IsZero() {
continue
Expand All @@ -166,8 +166,8 @@ func (s regular[T]) regular() (func(*Row) (any, error), func(any) (T, error)) {
pv.Set(reflect.New(pv.Type().Elem()))
}

fv := row.FieldByIndex(info.Position)
v.ScheduleScanx(info.Name, fv.Addr())
fv := row.FieldByIndex(info.position)
v.ScheduleScanx(info.name, fv.Addr())
}

return row, nil
Expand All @@ -189,9 +189,9 @@ func (s regular[T]) allOptions() (func(*Row) (any, error), func(any) (T, error))
for i, info := range s.filtered {
var ft reflect.Type
if s.isPointer {
ft = s.typ.Elem().FieldByIndex(info.Position).Type
ft = s.typ.Elem().FieldByIndex(info.position).Type
} else {
ft = s.typ.FieldByIndex(info.Position).Type
ft = s.typ.FieldByIndex(info.position).Type
}

if s.converter != nil {
Expand All @@ -200,7 +200,7 @@ func (s regular[T]) allOptions() (func(*Row) (any, error), func(any) (T, error))
row[i] = reflect.New(ft)
}

v.ScheduleScanx(info.Name, row[i])
v.ScheduleScanx(info.name, row[i])
}

return row, nil
Expand All @@ -220,7 +220,7 @@ func (s regular[T]) allOptions() (func(*Row) (any, error), func(any) (T, error))
}

for i, info := range s.filtered {
for _, v := range info.Init {
for _, v := range info.init {
pv := row.FieldByIndex(v)
if !pv.IsZero() {
continue
Expand All @@ -236,8 +236,8 @@ func (s regular[T]) allOptions() (func(*Row) (any, error), func(any) (T, error))
val = vals[i].Elem()
}

fv := row.FieldByIndex(info.Position)
if info.IsPointer {
fv := row.FieldByIndex(info.position)
if info.isPointer {
fv.Elem().Set(val)
} else {
fv.Set(val)
Expand Down
4 changes: 2 additions & 2 deletions mapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,14 +429,14 @@ func TestScannable(t *testing.T) {
t.Fatalf("couldn't get mapper source: %v", err)
}

m, err := src.GetMapping(reflect.TypeOf(BlogWithScannableUser{}))
m, err := src.getMapping(reflect.TypeOf(BlogWithScannableUser{}))
if err != nil {
t.Fatalf("couldn't get mapping: %v", err)
}

var marked bool
for _, info := range m {
if info.Name == "user" {
if info.name == "user" {
marked = true
}
}
Expand Down
40 changes: 20 additions & 20 deletions source.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func newDefaultMapperSourceImpl() *mapperSourceImpl {
fieldMapperFn: snakeCaseFieldFunc,
scannableTypes: []reflect.Type{reflect.TypeOf((*sql.Scanner)(nil)).Elem()},
maxDepth: 3,
cache: make(map[reflect.Type]Mapping),
cache: make(map[reflect.Type]mapping),
}
}

Expand Down Expand Up @@ -118,11 +118,11 @@ type mapperSourceImpl struct {
fieldMapperFn func(string) string
scannableTypes []reflect.Type
maxDepth int
cache map[reflect.Type]Mapping
cache map[reflect.Type]mapping
mutex sync.RWMutex
}

func (s *mapperSourceImpl) GetMapping(typ reflect.Type) (Mapping, error) {
func (s *mapperSourceImpl) getMapping(typ reflect.Type) (mapping, error) {
s.mutex.RLock()
m, ok := s.cache[typ]
s.mutex.RUnlock()
Expand All @@ -140,7 +140,7 @@ func (s *mapperSourceImpl) GetMapping(typ reflect.Type) (Mapping, error) {
return m, nil
}

func (s *mapperSourceImpl) setMappings(typ reflect.Type, prefix string, v visited, m *Mapping, inits [][]int, position ...int) {
func (s *mapperSourceImpl) setMappings(typ reflect.Type, prefix string, v visited, m *mapping, inits [][]int, position ...int) {
count := v[typ]
if count > s.maxDepth {
return
Expand All @@ -160,10 +160,10 @@ func (s *mapperSourceImpl) setMappings(typ reflect.Type, prefix string, v visite
for _, scannable := range s.scannableTypes {
if reflect.PtrTo(typ).Implements(scannable) {
*m = append(*m, mapinfo{
Name: prefix,
Position: position,
Init: inits,
IsPointer: isPointer,
name: prefix,
position: position,
init: inits,
isPointer: isPointer,
})
return
}
Expand Down Expand Up @@ -219,28 +219,28 @@ func (s *mapperSourceImpl) setMappings(typ reflect.Type, prefix string, v visite
}

*m = append(*m, mapinfo{
Name: key,
Position: currentIndex,
Init: inits,
IsPointer: isPointer,
name: key,
position: currentIndex,
init: inits,
isPointer: isPointer,
})
}

// If it has no exported field (such as time.Time) then we attempt to
// directly scan into it
if !hasExported {
*m = append(*m, mapinfo{
Name: prefix,
Position: position,
Init: inits,
IsPointer: isPointer,
name: prefix,
position: position,
init: inits,
isPointer: isPointer,
})
}
}

func filterColumns(ctx context.Context, c cols, m Mapping, prefix string) (Mapping, error) {
func filterColumns(ctx context.Context, c cols, m mapping, prefix string) (mapping, error) {
// Filter the mapping so we only ask for the available columns
filtered := make(Mapping, 0, len(c))
filtered := make(mapping, 0, len(c))
for _, name := range c {
key := name
if prefix != "" {
Expand All @@ -252,8 +252,8 @@ func filterColumns(ctx context.Context, c cols, m Mapping, prefix string) (Mappi
}

for _, info := range m {
if key == info.Name {
info.Name = name
if key == info.name {
info.name = name
filtered = append(filtered, info)
break
}
Expand Down