Skip to content

Commit

Permalink
Move some utility functions around
Browse files Browse the repository at this point in the history
  • Loading branch information
roblillack committed Mar 23, 2021
1 parent 648864e commit ba9ed71
Show file tree
Hide file tree
Showing 14 changed files with 157 additions and 135 deletions.
25 changes: 20 additions & 5 deletions cmd/mars-gen/reflect.go
Expand Up @@ -33,8 +33,6 @@ type SourceInfo struct {
// controllerSpecs lists type info for all structs found under
// app/controllers/... that embed (directly or indirectly) mars.Controller
controllerSpecs []*TypeInfo
// testSuites list the types that constitute the set of application tests.
testSuites []*TypeInfo
}

// TypeInfo summarizes information about a struct type in the app source code.
Expand Down Expand Up @@ -445,6 +443,15 @@ func getStructTypeDecl(decl ast.Decl, fset *token.FileSet) (spec *ast.TypeSpec,
return
}

func containsString(list []string, target string) bool {
for _, el := range list {
if el == target {
return true
}
}
return false
}

// TypesThatEmbed returns all types that (directly or indirectly) embed the
// target type, which must be a fully qualified type name,
// e.g. "github.com/roblillack/mars.Controller"
Expand All @@ -462,8 +469,7 @@ func (s *SourceInfo) TypesThatEmbed(targetType string) (filtered []*TypeInfo) {
// Look through all known structs.
for _, spec := range s.StructSpecs {
// If this one has been processed or is already in nodeQueue, then skip it.
if mars.ContainsString(processed, spec.String()) ||
mars.ContainsString(nodeQueue, spec.String()) {
if containsString(processed, spec.String()) || containsString(nodeQueue, spec.String()) {
continue
}

Expand Down Expand Up @@ -502,10 +508,19 @@ type TypeExpr struct {
Valid bool
}

func firstNonEmpty(strs ...string) string {
for _, str := range strs {
if len(str) > 0 {
return str
}
}
return ""
}

// TypeName returns the fully-qualified type name for this expression.
// The caller may optionally specify a package name to override the default.
func (e TypeExpr) TypeName(pkgOverride string) string {
pkgName := mars.FirstNonEmpty(pkgOverride, e.PkgName)
pkgName := firstNonEmpty(pkgOverride, e.PkgName)
if pkgName == "" {
return e.Expr
}
Expand Down
20 changes: 20 additions & 0 deletions cookie.go
@@ -0,0 +1,20 @@
package mars

import (
"net/url"
"regexp"
)

var (
cookieKeyValueParser = regexp.MustCompile("\x00([^:]*):([^\x00]*)\x00")
)

// parseKeyValueCookie takes the raw (escaped) cookie value and parses out key values.
func parseKeyValueCookie(val string, cb func(key, val string)) {
val, _ = url.QueryUnescape(val)
if matches := cookieKeyValueParser.FindAllStringSubmatch(val, -1); matches != nil {
for _, match := range matches {
cb(match[1], match[2])
}
}
}
2 changes: 1 addition & 1 deletion filterconfig.go
Expand Up @@ -85,7 +85,7 @@ func FilterAction(methodRef interface{}) FilterConfigurator {
}

controllerType := methodType.In(0)
method := FindMethod(controllerType, methodValue)
method := findMethod(controllerType, methodValue)
if method == nil {
panic("Action not found on controller " + controllerType.Name())
}
Expand Down
2 changes: 1 addition & 1 deletion flash.go
Expand Up @@ -66,7 +66,7 @@ func restoreFlash(req *http.Request) Flash {
Out: make(map[string]string),
}
if cookie, err := req.Cookie(CookiePrefix + "_FLASH"); err == nil {
ParseKeyValueCookie(cookie.Value, func(key, val string) {
parseKeyValueCookie(cookie.Value, func(key, val string) {
flash.Data[key] = val
})
}
Expand Down
18 changes: 18 additions & 0 deletions reflection.go
@@ -0,0 +1,18 @@
package mars

import (
"reflect"
)

// Return the reflect.Method, given a Receiver type and Func value.
func findMethod(recvType reflect.Type, funcVal reflect.Value) *reflect.Method {
// It is not possible to get the name of the method from the Func.
// Instead, compare it to each method of the Controller.
for i := 0; i < recvType.NumMethod(); i++ {
method := recvType.Method(i)
if method.Func.Pointer() == funcVal.Pointer() {
return &method
}
}
return nil
}
36 changes: 36 additions & 0 deletions reflection_test.go
@@ -0,0 +1,36 @@
package mars

import (
"reflect"
"testing"
)

type T struct{}

func (t *T) Hello() {}

func TestFindMethod(t *testing.T) {
for name, tv := range map[string]struct {
reflect.Type
reflect.Value
}{
"Hello": {reflect.TypeOf(&T{}), reflect.ValueOf((*T).Hello)},
"Helper": {reflect.TypeOf(t), reflect.ValueOf((*testing.T).Helper)},
"": {reflect.TypeOf(t), reflect.ValueOf((reflect.Type).Comparable)},
} {
m := findMethod(tv.Type, tv.Value)
if name == "" {
if m != nil {
t.Errorf("method found that shouldn't be here: %v", m)
}
continue
}
if m == nil {
t.Errorf("No method found when looking for %s", name)
continue
}
if m.Name != name {
t.Errorf("Expected method %s, got %s: %v", name, m.Name, m)
}
}
}
2 changes: 1 addition & 1 deletion results.go
Expand Up @@ -387,7 +387,7 @@ func getRedirectUrl(item interface{}) (string, error) {
if typ.Kind() == reflect.Func && typ.NumIn() > 0 {
// Get the Controller Method
recvType := typ.In(0)
method := FindMethod(recvType, val)
method := findMethod(recvType, val)
if method == nil {
return "", errors.New("couldn't find method")
}
Expand Down
2 changes: 1 addition & 1 deletion session.go
Expand Up @@ -125,7 +125,7 @@ func GetSessionFromCookie(cookie *http.Cookie) Session {
return session
}

ParseKeyValueCookie(data, func(key, val string) {
parseKeyValueCookie(data, func(key, val string) {
session[key] = val
})

Expand Down
11 changes: 10 additions & 1 deletion template.go
Expand Up @@ -470,6 +470,15 @@ func (loader *TemplateLoader) Template(name string, funcMaps ...Args) (Template,
return GoTemplate{tmpl, loader, funcMap}, err
}

// Reads the lines of the given file.
func readLines(filename string) ([]string, error) {
bytes, err := ioutil.ReadFile(filename)
if err != nil {
return nil, err
}
return strings.Split(string(bytes), "\n"), nil
}

// Adapter for Go Templates.
type GoTemplate struct {
*template.Template
Expand All @@ -487,7 +496,7 @@ func (gotmpl GoTemplate) Render(wr io.Writer, arg interface{}) error {
}

func (gotmpl GoTemplate) Content() []string {
content, _ := ReadLines(gotmpl.loader.templatePaths[gotmpl.Name()])
content, _ := readLines(gotmpl.loader.templatePaths[gotmpl.Name()])
return content
}

Expand Down
43 changes: 43 additions & 0 deletions testing/equal.go
@@ -0,0 +1,43 @@
package testing

import "reflect"

// Equal is a helper for comparing value equality, following these rules:
// - Values with equivalent types are compared with reflect.DeepEqual
// - int, uint, and float values are compared without regard to the type width.
// for example, Equal(int32(5), int64(5)) == true
// - strings and byte slices are converted to strings before comparison.
// - else, return false.
func Equal(a, b interface{}) bool {
if reflect.TypeOf(a) == reflect.TypeOf(b) {
return reflect.DeepEqual(a, b)
}
switch a.(type) {
case int, int8, int16, int32, int64:
switch b.(type) {
case int, int8, int16, int32, int64:
return reflect.ValueOf(a).Int() == reflect.ValueOf(b).Int()
}
case uint, uint8, uint16, uint32, uint64:
switch b.(type) {
case uint, uint8, uint16, uint32, uint64:
return reflect.ValueOf(a).Uint() == reflect.ValueOf(b).Uint()
}
case float32, float64:
switch b.(type) {
case float32, float64:
return reflect.ValueOf(a).Float() == reflect.ValueOf(b).Float()
}
case string:
switch b.(type) {
case []byte:
return a.(string) == string(b.([]byte))
}
case []byte:
switch b.(type) {
case string:
return b.(string) == string(a.([]byte))
}
}
return false
}
2 changes: 1 addition & 1 deletion util_test.go → testing/equal_test.go
@@ -1,4 +1,4 @@
package mars
package testing

import (
"reflect"
Expand Down
8 changes: 4 additions & 4 deletions testing/testsuite.go
Expand Up @@ -16,9 +16,9 @@ import (
"regexp"
"strings"

"github.com/roblillack/mars"

"golang.org/x/net/websocket"

"github.com/roblillack/mars"
)

type TestSuite struct {
Expand Down Expand Up @@ -276,13 +276,13 @@ func (t *TestSuite) AssertHeader(name, value string) {
}

func (t *TestSuite) AssertEqual(expected, actual interface{}) {
if !mars.Equal(expected, actual) {
if !Equal(expected, actual) {
panic(fmt.Errorf("(expected) %v != %v (actual)", expected, actual))
}
}

func (t *TestSuite) AssertNotEqual(expected, actual interface{}) {
if mars.Equal(expected, actual) {
if Equal(expected, actual) {
panic(fmt.Errorf("(expected) %v == %v (actual)", expected, actual))
}
}
Expand Down
119 changes: 0 additions & 119 deletions util.go

This file was deleted.

0 comments on commit ba9ed71

Please sign in to comment.