Skip to content

Commit

Permalink
fix: Do not extend integer to float if it's set by existing schema
Browse files Browse the repository at this point in the history
Same for UUID, Times, byte arrays. Obey the schema.
  • Loading branch information
efirs committed May 13, 2023
1 parent 1f1d50f commit 7560dfd
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 15 deletions.
2 changes: 0 additions & 2 deletions cmd/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,11 @@ func guaranteeFloatsInFirstRecord(ctx context.Context, coll string, docs []json.

if cnt == 0 {
schema.DetectArrayOfObjects = true
schema.DetectIntegers = false
schema.ReplaceNumber = true

fixNumbers(ctx, coll, docs)

schema.DetectArrayOfObjects = false
schema.DetectIntegers = true
schema.ReplaceNumber = false

cnt, err := client.GetDB().Count(ctx, coll, driver.Filter("{}"))
Expand Down
29 changes: 16 additions & 13 deletions schema/inference.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,14 @@ func parseDateTime(s string) bool {
return false
}

func parseNumber(v any) (string, string, error) {
func parseNumber(v any, existing *schema.Field) (string, string, error) {
n, ok := v.(json.Number)
if !ok {
return "", "", ErrExpectedNumber
}

if _, err := n.Int64(); err != nil || !DetectIntegers {
_, err = n.Float64()
if err != nil {
if _, err := n.Int64(); err != nil || (!DetectIntegers && (existing == nil || existing.Type != typeInteger)) {
if _, err = n.Float64(); err != nil {
return "", "", err
}

Expand All @@ -94,27 +93,31 @@ func parseNumber(v any) (string, string, error) {
return typeInteger, "", nil
}

func translateStringType(v interface{}) (string, string, error) {
func needNarrowing(detect bool, existing *schema.Field, format string) bool {
return detect || existing != nil && existing.Format == format
}

func translateStringType(v any, existing *schema.Field) (string, string, error) {
t := reflect.TypeOf(v)

if t.PkgPath() == "encoding/json" && t.Name() == "Number" {
return parseNumber(v)
return parseNumber(v, existing)
}

s, ok := v.(string)
if !ok {
return "", "", ErrExpectedString
}

if parseDateTime(s) && DetectTimes {
if parseDateTime(s) && needNarrowing(DetectTimes, existing, formatDateTime) {
return typeString, formatDateTime, nil
}

if _, err := uuid.Parse(s); err == nil && DetectUUIDs {
if _, err := uuid.Parse(s); err == nil && needNarrowing(DetectUUIDs, existing, formatUUID) {
return typeString, formatUUID, nil
}

if len(s) != 0 && DetectByteArrays {
if len(s) != 0 && needNarrowing(DetectByteArrays, existing, formatByte) {
b := make([]byte, base64.StdEncoding.DecodedLen(len(s)))
if _, err := base64.StdEncoding.Decode(b, []byte(s)); err == nil {
return typeString, formatByte, nil
Expand All @@ -124,7 +127,7 @@ func translateStringType(v interface{}) (string, string, error) {
return typeString, "", nil
}

func translateType(v interface{}) (string, string, error) {
func translateType(v any, existing *schema.Field) (string, string, error) {
t := reflect.TypeOf(v)

//nolint:golint,exhaustive
Expand All @@ -134,7 +137,7 @@ func translateType(v interface{}) (string, string, error) {
case reflect.Float64:
return typeNumber, "", nil
case reflect.String:
return translateStringType(v)
return translateStringType(v, existing)
case reflect.Slice, reflect.Array:
return typeArray, "", nil
case reflect.Map:
Expand Down Expand Up @@ -226,7 +229,7 @@ func traverseObject(name string, existingField *schema.Field, newField *schema.F

func traverseArray(name string, existingField *schema.Field, newField *schema.Field, v any) error {
for i := 0; i < reflect.ValueOf(v).Len(); i++ {
t, format, err := translateType(reflect.ValueOf(v).Index(i).Interface())
t, format, err := translateType(reflect.ValueOf(v).Index(i).Interface(), existingField)
if err != nil {
return err
}
Expand Down Expand Up @@ -330,7 +333,7 @@ func traverseFields(sch map[string]*schema.Field, fields map[string]any, autoGen
continue
}

t, format, err := translateType(val)
t, format, err := translateType(val, sch[name])
if err != nil {
return err
}
Expand Down

0 comments on commit 7560dfd

Please sign in to comment.