/
binder.go
116 lines (100 loc) · 2.97 KB
/
binder.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
package web
import (
"encoding/json"
stdErrors "errors"
"net/http"
"reflect"
"strconv"
"strings"
)
// ErrContentTypeNotAllowed is returned by Bind when a request that carries a
// body declares a Content-Type other than JSON.
var ErrContentTypeNotAllowed = stdErrors.New("Only Content-Type application/json is allowed")
// DefaultBinder is the stock HTTP binder. It is stateless, so a single
// instance can be reused.
type DefaultBinder struct{}
// NewDefaultBinder returns a pointer to a freshly allocated DefaultBinder.
func NewDefaultBinder() *DefaultBinder {
	binder := new(DefaultBinder)
	return binder
}
// methodHasBody reports whether Bind should try to decode a request body for
// the given HTTP method. Only POST, PUT and DELETE qualify.
func methodHasBody(method string) bool {
	switch method {
	case http.MethodPost, http.MethodPut, http.MethodDelete:
		return true
	}
	return false
}
// Bind populates target from the incoming request.
//
// For POST, PUT and DELETE requests with a positive ContentLength, the body
// is decoded as JSON into target. The request's media type (the Content-Type
// value before any ";" parameters, trimmed and compared case-insensitively)
// must equal JSONContentType; otherwise ErrContentTypeNotAllowed is returned.
// JSON decoding errors are returned unchanged.
//
// target must be a non-nil pointer; anything else yields a
// *json.InvalidUnmarshalError instead of a panic. When target points to a
// struct, every settable (exported) field is then processed:
//   - bindRoute copies in the route parameter named by its `route` tag, and
//   - format normalises string data according to its `format` tag.
//
// Pointers to non-struct values (maps, slices, ...) are only JSON-decoded.
func (b *DefaultBinder) Bind(target interface{}, c *Context) error {
	if methodHasBody(c.Request.Method) && c.Request.ContentLength > 0 {
		if !isJSONMediaType(c.Request.GetHeader("Content-Type")) {
			return ErrContentTypeNotAllowed
		}
		if err := json.Unmarshal([]byte(c.Request.Body), target); err != nil {
			return err
		}
	}

	ptr := reflect.ValueOf(target)
	if ptr.Kind() != reflect.Ptr || ptr.IsNil() {
		// Same error encoding/json reports for nil / non-pointer targets.
		return &json.InvalidUnmarshalError{Type: reflect.TypeOf(target)}
	}
	targetValue := ptr.Elem()
	if targetValue.Kind() != reflect.Struct {
		// Route and format tags only exist on struct fields.
		return nil
	}

	targetType := targetValue.Type()
	for i := 0; i < targetValue.NumField(); i++ {
		if !targetValue.Field(i).CanSet() {
			// Unexported fields cannot be read via Interface() or written via
			// reflection; touching them would panic.
			continue
		}
		if err := b.bindRoute(i, targetValue, targetType, c.params); err != nil {
			return err
		}
		b.format(i, targetValue, targetType)
	}
	return nil
}

// isJSONMediaType reports whether a Content-Type header value names
// JSONContentType. Parameters such as "; charset=utf-8" are ignored, and the
// type/subtype comparison is case-insensitive (RFC 7231 §3.1.1.1).
func isJSONMediaType(header string) bool {
	mediaType := header
	if idx := strings.IndexByte(mediaType, ';'); idx >= 0 {
		mediaType = mediaType[:idx]
	}
	return strings.EqualFold(strings.TrimSpace(mediaType), JSONContentType)
}
// bindRoute copies the route parameter named by the `route` struct tag of
// field idx into that field of target.
//
// Supported destinations:
//   - string kinds: the raw parameter value is assigned as-is.
//   - signed integer kinds: the value is parsed as base-10 with the field's
//     own bit size, so out-of-range input is rejected instead of being
//     silently truncated. If that parse fails and *T implements
//     UnmarshalText([]byte) error (e.g. an enum with a textual form),
//     UnmarshalText is tried on a fresh zero value, which is stored on success.
//
// Binding is best-effort: a value that cannot be converted leaves the field
// unchanged. Fields without a `route` tag, other kinds, and unexported
// (non-settable) fields are ignored. The error result is currently always
// nil; it is kept so stricter validation can be added without an API change.
func (b *DefaultBinder) bindRoute(idx int, target reflect.Value, targetType reflect.Type, params StringMap) error {
	// Structural equivalent of encoding.TextUnmarshaler.
	type textUnmarshaler interface {
		UnmarshalText(text []byte) error
	}

	name := targetType.Field(idx).Tag.Get("route")
	if name == "" {
		return nil
	}
	field := target.Field(idx)
	if !field.CanSet() {
		return nil
	}
	raw := params[name]

	switch kind := field.Kind(); {
	case isInt(kind):
		if value, err := strconv.ParseInt(raw, 10, field.Type().Bits()); err == nil {
			field.SetInt(value)
			return nil
		}
		ptr := reflect.New(field.Type())
		if u, ok := ptr.Interface().(textUnmarshaler); ok {
			if u.UnmarshalText([]byte(raw)) == nil {
				field.Set(ptr.Elem())
			}
		}
	case isString(kind):
		field.SetString(raw)
	}
	return nil
}
// format normalises string data in field idx of target according to the
// field's `format` struct tag, via applyFormat: values are always
// whitespace-trimmed, and "lower" / "upper" additionally change case.
//
// It handles fields of string kind and slices whose element kind is string.
// Named types such as `type Email string`, `[]Email` or `type Tags []string`
// are included. Values are read with reflect's String()/Index() rather than
// type-asserting Interface() to string or []string, which would panic for
// named types. Other kinds and unexported (non-settable) fields are left
// untouched.
func (b *DefaultBinder) format(idx int, target reflect.Value, targetType reflect.Type) {
	field := target.Field(idx)
	if !field.CanSet() {
		return
	}
	format := targetType.Field(idx).Tag.Get("format")
	fieldType := field.Type()

	switch {
	case isString(fieldType.Kind()):
		field.SetString(applyFormat(format, field.String()))
	case fieldType.Kind() == reflect.Slice && isString(fieldType.Elem().Kind()):
		for i := 0; i < field.Len(); i++ {
			elem := field.Index(i)
			elem.SetString(applyFormat(format, elem.String()))
		}
	}
}
// isInt reports whether k is a signed integer kind: int, int8, int16, int32
// or int64. Unsigned kinds are not included.
func isInt(k reflect.Kind) bool {
	switch k {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return true
	default:
		return false
	}
}
// isString reports whether k is reflect.String. Named string types
// (e.g. `type Email string`) share this kind.
func isString(k reflect.Kind) bool {
	switch k {
	case reflect.String:
		return true
	}
	return false
}
// applyFormat trims leading and trailing white space from value, then applies
// the case conversion named by format: "lower" or "upper". Any other format,
// including "", returns the trimmed value with its case unchanged.
func applyFormat(format string, value string) string {
	trimmed := strings.TrimSpace(value)
	switch format {
	case "lower":
		return strings.ToLower(trimmed)
	case "upper":
		return strings.ToUpper(trimmed)
	default:
		return trimmed
	}
}