Skip to content
This repository has been archived by the owner on Mar 14, 2022. It is now read-only.

Commit

Permalink
Add custom Marshal/Unmarshal to support empty lists
Browse files Browse the repository at this point in the history
This makes it possible to round trip our data without loosing empty
lists. Note, dynamodb doesn't support empty strings, and these changes
do nothing to work around that.

Fixes #387

See: aws/aws-sdk-go#682
  • Loading branch information
jcoyne committed Apr 24, 2018
1 parent 94bb998 commit 781515d
Show file tree
Hide file tree
Showing 8 changed files with 606 additions and 26 deletions.
3 changes: 1 addition & 2 deletions db/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
"github.com/sul-dlss-labs/taco/datautils"
)

Expand Down Expand Up @@ -45,7 +44,7 @@ func (d *DynamodbDatabase) query(params *dynamodb.QueryInput) (*datautils.Resour

func respToResource(item map[string]*dynamodb.AttributeValue) (*datautils.Resource, error) {
var json datautils.JSONObject
if err := dynamodbattribute.UnmarshalMap(item, &json); err != nil {
if err := UnmarshalMap(item, &json); err != nil {
return nil, err
}

Expand Down
73 changes: 53 additions & 20 deletions db/database_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package db

import (
"encoding/json"
"io/ioutil"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -17,28 +19,28 @@ func TestRetrieveVersion(t *testing.T) {
id := "9999"
json := datautils.JSONObject{
"externalIdentifier": id,
"tacoIdentifier": "7777777",
"version": 1,
"label": "Hello world",
"tacoIdentifier": "7777777",
"version": 1,
"label": "Hello world",
}
if err := database.Insert(datautils.NewResource(json)); err != nil {
panic(err)
}

json = datautils.JSONObject{
"externalIdentifier": id,
"tacoIdentifier": "7777778",
"version": 2,
"label": "Middle one",
"tacoIdentifier": "7777778",
"version": 2,
"label": "Middle one",
}
if err := database.Insert(datautils.NewResource(json)); err != nil {
panic(err)
}
json = datautils.JSONObject{
"externalIdentifier": id,
"tacoIdentifier": "7777779",
"version": 3,
"label": "Hello world",
"tacoIdentifier": "7777779",
"version": 3,
"label": "Hello world",
}
if err := database.Insert(datautils.NewResource(json)); err != nil {
panic(err)
Expand All @@ -57,35 +59,53 @@ func TestRetrieveLatest(t *testing.T) {
id := "9999"
json := datautils.JSONObject{
"externalIdentifier": id,
"tacoIdentifier": "7777777",
"version": 1,
"label": "Hello world",
"tacoIdentifier": "7777777",
"version": 1,
"label": "Hello world",
}
if err := database.Insert(datautils.NewResource(json)); err != nil {
panic(err)
}

json = datautils.JSONObject{
"externalIdentifier": id,
"tacoIdentifier": "7777778",
"version": 2,
"label": "Middle one",
"tacoIdentifier": "7777778",
"version": 2,
"label": "Middle one",
}
if err := database.Insert(datautils.NewResource(json)); err != nil {
panic(err)
}
json = datautils.JSONObject{
"externalIdentifier": id,
"tacoIdentifier": "7777779",
"version": 3,
"label": "Hello world",
"tacoIdentifier": "7777779",
"version": 3,
"label": "Hello world",
}
if err := database.Insert(datautils.NewResource(json)); err != nil {
panic(err)
}
record, err := database.RetrieveLatest(id)
result, err := database.RetrieveLatest(id)
assert.Nil(t, err)
assert.Equal(t, 3, record.Version())
assert.Equal(t, 3, result.Version())
}

func TestRoundTrip(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
id := "7777777"
database := initDatabase()
jsonData := jsonData()
jsonData["externalIdentifier"] = id
resource := datautils.NewResource(jsonData)
err := database.Insert(resource)
assert.Nil(t, err)

result, err := database.RetrieveLatest(id)
assert.Nil(t, err)
// Data that comes out should be the same as the data that went in.
assert.Equal(t, jsonData, result.JSON)
}

func TestRetrieveLatestNotFound(t *testing.T) {
Expand All @@ -103,3 +123,16 @@ func initDatabase() Database {
Table: testConfig.ResourceTableName,
}
}

func jsonData() datautils.JSONObject {
byt, err := ioutil.ReadFile("../examples/update_request.json")
if err != nil {
panic(err)
}
var postData datautils.JSONObject

if err := json.Unmarshal(byt, &postData); err != nil {
panic(err)
}
return postData
}
235 changes: 235 additions & 0 deletions db/decode.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
package db

import (
"reflect"
"strconv"
"time"

"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
"github.com/sul-dlss-labs/taco/datautils"
)

// A Decoder provides unmarshaling AttributeValues to Go value types.
type Decoder struct {
// Instructs the decoder to decode AttributeValue Numbers as
// Number type instead of float64 when the destination type
// is interface{}. Similar to encoding/json.Number
UseNumber bool
}

// UnmarshalMap is an alias for Unmarshal which unmarshals from
// a map of AttributeValues.
//
// The output value provided must be a non-nil pointer
func UnmarshalMap(m map[string]*dynamodb.AttributeValue, out *datautils.JSONObject) error {
return NewDecoder().Decode(&dynamodb.AttributeValue{M: m}, out)
}

// NewDecoder creates a new Decoder with default configuration. Use
// the `opts` functional options to override the default configuration.
func NewDecoder(opts ...func(*Decoder)) *Decoder {
d := &Decoder{}
for _, o := range opts {
o(d)
}

return d
}

// Decode will unmarshal an AttributeValue into a Go value type. An error
// will be return if the decoder is unable to unmarshal the AttributeValue
// to the provide Go value type.
//
// The output value provided must be a non-nil pointer
func (d *Decoder) Decode(av *dynamodb.AttributeValue, out *datautils.JSONObject, opts ...func(*Decoder)) error {
v := reflect.ValueOf(out)
if v.Kind() != reflect.Ptr || v.IsNil() || !v.IsValid() {
return &dynamodbattribute.InvalidUnmarshalError{Type: reflect.TypeOf(out)}
}

return d.decode(av, v)
}

var stringInterfaceMapType = reflect.TypeOf(map[string]interface{}(nil))
var timeType = reflect.TypeOf(time.Time{})

func (d *Decoder) decode(av *dynamodb.AttributeValue, v reflect.Value) error {
var u dynamodbattribute.Unmarshaler
if av == nil || av.NULL != nil {
u, v = indirect(v, true)
if u != nil {
return u.UnmarshalDynamoDBAttributeValue(av)
}
return d.decodeNull(v)
}

u, v = indirect(v, false)
if u != nil {
return u.UnmarshalDynamoDBAttributeValue(av)
}

switch {
case len(av.B) != 0:
panic("Not implemented B")
case av.BOOL != nil:
return d.decodeBool(av.BOOL, v)
case len(av.BS) != 0:
panic("Not implemented BS")
case av.L != nil:
return d.decodeList(av.L, v)
case len(av.M) != 0:
return d.decodeMap(av.M, v)
case av.N != nil:
return d.decodeNumber(av.N, v)
case len(av.NS) != 0:
panic("Not implemented NS")
case av.S != nil:
return d.decodeString(av.S, v)
case len(av.SS) != 0:
panic("Not implemented SS")
}

return nil
}

func (d *Decoder) decodeBool(b *bool, v reflect.Value) error {
switch v.Kind() {
case reflect.Bool, reflect.Interface:
v.Set(reflect.ValueOf(*b).Convert(v.Type()))
default:
return &dynamodbattribute.UnmarshalTypeError{Value: "bool", Type: v.Type()}
}

return nil
}

func (d *Decoder) decodeNumber(n *string, v reflect.Value) error {
// Default to float64 for all numbers
i, err := strconv.ParseFloat(*n, 64)
if err != nil {
return &dynamodbattribute.UnmarshalTypeError{Value: "number", Type: v.Type()}
}
v.Set(reflect.ValueOf(i))
return nil
}

func (d *Decoder) decodeList(avList []*dynamodb.AttributeValue, v reflect.Value) error {
switch v.Kind() {
case reflect.Interface:
s := make([]interface{}, len(avList))
for i, av := range avList {
if err := d.decode(av, reflect.ValueOf(&s[i]).Elem()); err != nil {
return err
}
}
v.Set(reflect.ValueOf(s))
return nil
default:
return &dynamodbattribute.UnmarshalTypeError{Value: "list", Type: v.Type()}
}
}

func (d *Decoder) decodeMap(avMap map[string]*dynamodb.AttributeValue, v reflect.Value) error {
switch v.Kind() {
case reflect.Map:
t := v.Type()
if t.Key().Kind() != reflect.String {
return &dynamodbattribute.UnmarshalTypeError{Value: "map string key", Type: t.Key()}
}
if v.IsNil() {
v.Set(reflect.MakeMap(t))
}
case reflect.Struct:
case reflect.Interface:
v.Set(reflect.MakeMap(stringInterfaceMapType))
v = v.Elem()
default:
return &dynamodbattribute.UnmarshalTypeError{Value: "map", Type: v.Type()}
}

if v.Kind() == reflect.Map {
for k, av := range avMap {
key := reflect.ValueOf(k)
elem := reflect.New(v.Type().Elem()).Elem()
if err := d.decode(av, elem); err != nil {
return err
}
v.SetMapIndex(key, elem)
}
} else if v.Kind() == reflect.Struct {
panic("Not implemented struct")
}

return nil
}

func (d *Decoder) decodeNull(v reflect.Value) error {
if v.IsValid() && v.CanSet() {
v.Set(reflect.Zero(v.Type()))
}

return nil
}

func (d *Decoder) decodeString(s *string, v reflect.Value) error {
// To maintain backwards compatibility with ConvertFrom family of methods which
// converted strings to time.Time structs
if v.Type().ConvertibleTo(timeType) {
t, err := time.Parse(time.RFC3339, *s)
if err != nil {
return err
}
v.Set(reflect.ValueOf(t).Convert(v.Type()))
return nil
}

switch v.Kind() {
case reflect.String:
v.SetString(*s)
case reflect.Interface:
// Ensure type aliasing is handled properly
v.Set(reflect.ValueOf(*s).Convert(v.Type()))
default:
return &dynamodbattribute.UnmarshalTypeError{Value: "string", Type: v.Type()}
}

return nil
}

// indirect will walk a value's interface or pointer value types. Returning
// the final value or the value a unmarshaler is defined on.
//
// Based on the enoding/json type reflect value type indirection in Go Stdlib
// https://golang.org/src/encoding/json/decode.go indirect func.
func indirect(v reflect.Value, decodingNull bool) (dynamodbattribute.Unmarshaler, reflect.Value) {
if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() {
v = v.Addr()
}
for {
if v.Kind() == reflect.Interface && !v.IsNil() {
e := v.Elem()
if e.Kind() == reflect.Ptr && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Ptr) {
v = e
continue
}
}
if v.Kind() != reflect.Ptr {
break
}
if v.Elem().Kind() != reflect.Ptr && decodingNull && v.CanSet() {
break
}
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
if v.Type().NumMethod() > 0 {
if u, ok := v.Interface().(dynamodbattribute.Unmarshaler); ok {
return u, reflect.Value{}
}
}
v = v.Elem()
}

return nil, v
}
Loading

0 comments on commit 781515d

Please sign in to comment.