Skip to content

Commit

Permalink
SNOW-1416000 Simplify structured objects reading
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus committed May 22, 2024
1 parent a0145ec commit 5f4ac95
Show file tree
Hide file tree
Showing 3 changed files with 373 additions and 6 deletions.
21 changes: 15 additions & 6 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,12 +435,12 @@ The data not have any corresponding schema, so values in table may be slightly d
Semistuctured variants, objects and arrays are always represented as strings for scanning:
rows, err := db.Query("SELECT {'a': 'b'}::OBJECT")
// handle error
defer rows.Close()
rows.Next()
var v string
err := rows.Scan(&v)
rows, err := db.Query("SELECT {'a': 'b'}::OBJECT")
// handle error
defer rows.Close()
rows.Next()
var v string
err := rows.Scan(&v)
When inserting, a marker indicating correct type must be used, for example:
Expand All @@ -466,6 +466,8 @@ Example table definition:
2. Implement sql.Scanner interface:
a)
func (so *simpleObject) Scan(val any) error {
st := val.(StructuredObject)
var err error
Expand All @@ -478,6 +480,13 @@ Example table definition:
return nil
}
b)
func (so *simpleObject) Scan(val any) error {
st := val.(StructuredObject)
return st.ScanTo(so)
}
3. Use it in regular scan:
var res simpleObject
Expand Down
168 changes: 168 additions & 0 deletions structured_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import (
"math/big"
"reflect"
"strconv"
"strings"
"time"
"unicode"
)

// ObjectType Empty marker of an object used in column type ScanType function
Expand Down Expand Up @@ -41,6 +43,7 @@ type StructuredObject interface {
GetNullTime(fieldName string) (sql.NullTime, error)
GetStruct(fieldName string, scanner sql.Scanner) (sql.Scanner, error)
GetRaw(fieldName string) (any, error)
ScanTo(so sql.Scanner) error
}

// ArrayOfScanners Helper type for scanning array of sql.Scanner values.
Expand Down Expand Up @@ -388,6 +391,153 @@ func (st *structuredType) GetRaw(fieldName string) (any, error) {
return st.values[fieldName], nil
}

func (st *structuredType) ScanTo(so sql.Scanner) error {
v := reflect.Indirect(reflect.ValueOf(so))
t := v.Type()
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
if st.shouldIgnoreField(field) {
continue
}
switch field.Type.Kind() {
case reflect.String:
s, err := st.GetString(st.getFieldName(field))
if err != nil {
return err
}
v.FieldByName(field.Name).SetString(s)
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
i, err := st.GetInt64(st.getFieldName(field))
if err != nil {
return err
}
v.FieldByName(field.Name).SetInt(i)
case reflect.Uint8:
b, err := st.GetByte(st.getFieldName(field))
if err != nil {
return err
}
v.FieldByName(field.Name).SetUint(uint64(int64(b)))
case reflect.Float32, reflect.Float64:
f, err := st.GetFloat64(st.getFieldName(field))
if err != nil {
return err
}
v.FieldByName(field.Name).SetFloat(f)
case reflect.Bool:
b, err := st.GetBool(st.getFieldName(field))
if err != nil {
return err
}
v.FieldByName(field.Name).SetBool(b)
case reflect.Slice, reflect.Array:
switch field.Type.Elem().Kind() {
case reflect.Uint8:
b, err := st.GetBytes(st.getFieldName(field))
if err != nil {
return err
}
v.FieldByName(field.Name).SetBytes(b)
default:
raw, err := st.GetRaw(st.getFieldName(field))
if err != nil {
return err
}
if raw != nil {
v.FieldByName(field.Name).Set(reflect.ValueOf(raw))
}
}
case reflect.Map:
raw, err := st.GetRaw(st.getFieldName(field))
if err != nil {
return err
}
if raw != nil {
v.FieldByName(field.Name).Set(reflect.ValueOf(raw))
}
case reflect.Struct:
a := v.FieldByName(field.Name).Interface()
if _, ok := a.(time.Time); ok {
time, err := st.GetTime(st.getFieldName(field))
if err != nil {
return err
}
v.FieldByName(field.Name).Set(reflect.ValueOf(time))
} else if _, ok := a.(sql.Scanner); ok {
scanner := reflect.New(reflect.TypeOf(a)).Interface().(sql.Scanner)
s, err := st.GetStruct(st.getFieldName(field), scanner)
if err != nil {
return err
}
v.FieldByName(field.Name).Set(reflect.Indirect(reflect.ValueOf(s)))
} else if _, ok := a.(sql.NullString); ok {
ns, err := st.GetNullString(st.getFieldName(field))
if err != nil {
return err
}
v.FieldByName(field.Name).Set(reflect.ValueOf(ns))
} else if _, ok := a.(sql.NullByte); ok {
nb, err := st.GetNullByte(st.getFieldName(field))
if err != nil {
return err
}
v.FieldByName(field.Name).Set(reflect.ValueOf(nb))
} else if _, ok := a.(sql.NullBool); ok {
nb, err := st.GetNullBool(st.getFieldName(field))
if err != nil {
return err
}
v.FieldByName(field.Name).Set(reflect.ValueOf(nb))
} else if _, ok := a.(sql.NullInt16); ok {
ni, err := st.GetNullInt16(st.getFieldName(field))
if err != nil {
return err
}
v.FieldByName(field.Name).Set(reflect.ValueOf(ni))
} else if _, ok := a.(sql.NullInt32); ok {
ni, err := st.GetNullInt32(st.getFieldName(field))
if err != nil {
return err
}
v.FieldByName(field.Name).Set(reflect.ValueOf(ni))
} else if _, ok := a.(sql.NullInt64); ok {
ni, err := st.GetNullInt64(st.getFieldName(field))
if err != nil {
return err
}
v.FieldByName(field.Name).Set(reflect.ValueOf(ni))
} else if _, ok := a.(sql.NullFloat64); ok {
nf, err := st.GetNullFloat64(st.getFieldName(field))
if err != nil {
return err
}
v.FieldByName(field.Name).Set(reflect.ValueOf(nf))
} else if _, ok := a.(sql.NullTime); ok {
nt, err := st.GetNullTime(st.getFieldName(field))
if err != nil {
return err
}
v.FieldByName(field.Name).Set(reflect.ValueOf(nt))
}
case reflect.Pointer:
switch field.Type.Elem().Kind() {
case reflect.Struct:
a := reflect.New(field.Type.Elem()).Interface()
s, err := st.GetStruct(st.getFieldName(field), a.(sql.Scanner))
if err != nil {
return err
}
if s != nil {
v.FieldByName(field.Name).Set(reflect.ValueOf(s))
}
default:
return errors.New("only struct pointers are supported")
}
}
}
return nil
}

func (st *structuredType) fieldMetadataByFieldName(fieldName string) (fieldMetadata, error) {
for _, fm := range st.fieldMetadata {
if fm.Name == fieldName {
Expand All @@ -405,3 +555,21 @@ func mapValuesNullableEnabled(ctx context.Context) bool {
d, ok := v.(bool)
return ok && d
}

func (st *structuredType) getFieldName(field reflect.StructField) string {
sfTag := field.Tag.Get("sf")
if sfTag != "" {
return strings.Split(sfTag, ",")[0]
}
r := []rune(field.Name)
r[0] = unicode.ToLower(r[0])
return string(r)
}

func (st *structuredType) shouldIgnoreField(field reflect.StructField) bool {
sfTag := field.Tag.Get("sf")
if sfTag == "" {
return false
}
return contains(strings.Split(sfTag, ",")[1:], "ignore")
}
Loading

0 comments on commit 5f4ac95

Please sign in to comment.