Skip to content

Commit

Permalink
lock down indirectType calls
Browse files Browse the repository at this point in the history
  • Loading branch information
s12chung committed Oct 9, 2023
1 parent ee08d9e commit 088fffb
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 46 deletions.
6 changes: 5 additions & 1 deletion pkg/firm/definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@ import (
// NewDefinition returns a new Definition
func NewDefinition[T any]() *Definition {
var zero T
typ := reflect.TypeOf(zero)
if typ.Kind() == reflect.Pointer {
panic(fmt.Sprintf("NewDefinition created with pointer type, dereference it: %v", typ.String()))
}
validator := &Definition{
typ: indirectType(reflect.TypeOf(zero)),
typ: typ,
topLevelRules: []Rule{},
ruleMap: RuleMap{},
}
Expand Down
3 changes: 1 addition & 2 deletions pkg/firm/firm.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ var ValidateAny = DefaultRegistry.ValidateAny
var DefaultRegistry = &Registry{}

// DefaultValidator is the validator used by registries for not found types when DefaultValidator is not defined
var DefaultValidator = MustNewValue[Any](NotFoundRule{})
var DefaultValidator = RuleValidator{Rule: NotFoundRule{}}

// NotFoundRule is the rule used for not found types in the DefaultValidator
type NotFoundRule struct{}
Expand Down Expand Up @@ -54,7 +54,6 @@ type BasicRule interface {
// Validator validates the data
type Validator interface {
Rule
Type() reflect.Type
ValidateAny(data any) ErrorMap
ValidateMerge(value reflect.Value, key string, errorMap ErrorMap)
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/firm/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ var integrationAnyTestCases = []integrationTestCase{
{name: "Data___topLevelValidates_half_raw", isValid: true, anyF: func() any {
return topLevelValidates{Primitive: 1}
}},
{name: "Data___topLevelValidates_half_pt", isValid: true, anyF: func() any {
{name: "Data___topLevelValidates_half_pt", isValid: false, anyF: func() any {
return &topLevelValidates{Primitive: 1}
}},
{name: "Data___topLevelValidates_empty_raw", isValid: false, anyF: func() any {
Expand All @@ -141,7 +141,7 @@ var integrationAnyTestCases = []integrationTestCase{
{name: "Full___any_raw", isValid: true, anyF: func() any {
return fullParent()
}},
{name: "Full___any_pt", isValid: true, anyF: func() any {
{name: "Full___any_pt", isValid: false, anyF: func() any {
full := fullParent()
return &full
}},
Expand All @@ -164,7 +164,7 @@ func TestIntegration(t *testing.T) {
if tc.f != nil {
data := tc.f()
require.Equal(tc.isValid, testRegistry.ValidateAny(data) == nil)
require.Equal(tc.isValid, testRegistry.ValidateAny(&data) == nil)
require.Equal(false, testRegistry.ValidateAny(&data) == nil)
return
}
require.Equal(tc.isValid, testRegistry.ValidateAny(tc.anyF()) == nil)
Expand Down
6 changes: 3 additions & 3 deletions pkg/firm/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ func (r *Registry) registeredStruct(definition *Definition) *ValueAny {
}

func (r *Registry) registerRecursionType(typ reflect.Type, rules *[]Rule) {
// Registry handles indirect types only, user can't control the types here
typ = indirectType(typ)

//nolint:exhaustive // just need these cases
Expand All @@ -75,15 +76,13 @@ func (r *Registry) registerRecursionType(typ reflect.Type, rules *[]Rule) {
}
}

// Type returns the Type the Registry handles
func (r *Registry) Type() reflect.Type { return r.Validator(nil).Type() }

// ValidateAny validates the data with the correct validator
func (r *Registry) ValidateAny(data any) ErrorMap {
value := reflect.ValueOf(data)
if !value.IsValid() {
return errInvalidValue
}
// value is used here, so can't use validateValue to save reflect.TypeOf call
return validateValueResult(r.DefaultedValidator(value.Type()), value)
}

Expand Down Expand Up @@ -119,6 +118,7 @@ func (r *Registry) Validator(typ reflect.Type) Validator {
if typ == nil || r.typeToValidator == nil {
return nil
}
// Registry only contains indirect types, make the function safe
typ = indirectType(typ)
validator := r.typeToValidator[typ]
if validator == nil {
Expand Down
10 changes: 5 additions & 5 deletions pkg/firm/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func notFoundError(data any) ErrorMap {
}

// nolint:funlen // a bunch of test cases
func TestRegistry_ValidateAny(t *testing.T) {
func TestRegistry_ValidateAll(t *testing.T) {
type testCase struct {
name string
definition *Definition
Expand Down Expand Up @@ -124,13 +124,13 @@ func TestRegistry_ValidateAny(t *testing.T) {
require.Equal(notFoundError(&data), registry.ValidateAny(&data))

notFoundTemplateError := &TemplateError{Template: "type, {{.RootTypeName}}, not found in Registry"}
testValidatesFull(t, true, registry, data, notFoundTemplateError, tc.expectedKeySuffix)
testValidatesFull(t, true, registry, &data, notFoundTemplateError, tc.expectedKeySuffix)
testValidateAllFull(t, true, registry, data, notFoundTemplateError, tc.expectedKeySuffix)
testValidateAllFull(t, true, registry, &data, notFoundTemplateError, tc.expectedKeySuffix)
return
}
data := tc.data()
testValidates(t, registry, data, tc.err, tc.expectedKeySuffix)
testValidates(t, registry, &data, tc.err, tc.expectedKeySuffix)
testValidateAll(t, registry, data, tc.err, tc.expectedKeySuffix)
require.Equal(typeCheckErrorResult(registry, &data), registry.ValidateAny(&data))
})
}
}
Expand Down
27 changes: 18 additions & 9 deletions pkg/firm/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func NewStructAny(typ reflect.Type, ruleMap RuleMap) (StructAny, error) {
rules := v
rm[k] = &rules
}
return StructAny{typ: indirectType(typ), ruleMap: rm}, nil
return StructAny{typ: typ, ruleMap: rm}, nil
}

// Struct validates structs
Expand Down Expand Up @@ -116,7 +116,7 @@ func NewSliceAny(typ reflect.Type, elementRules ...Rule) (SliceAny, error) {
return SliceAny{}, fmt.Errorf("element type: %w", err)
}
}
return SliceAny{typ: indirectType(typ), elementRules: elementRules}, nil
return SliceAny{typ: typ, elementRules: elementRules}, nil
}

// Slice validates slices and arrays
Expand Down Expand Up @@ -218,6 +218,18 @@ func (v ValueAny) TypeCheck(typ reflect.Type) *RuleTypeError { return typeCheck(
// Rules returns the rules for ValueAny
func (v ValueAny) Rules() []Rule { return v.rules }

// RuleValidator is a Validator wrapper around Rule
type RuleValidator struct{ Rule }

// ValidateAny validates the data
func (r RuleValidator) ValidateAny(data any) ErrorMap { return validateAny(r, data) }

// ValidateMerge validates the data value, also doing a merge with the errorMap (assumes TypeCheck is called)
func (r RuleValidator) ValidateMerge(value reflect.Value, key string, errorMap ErrorMap) {
value = indirect(value)
validateMerge(value, key, errorMap, []Rule{r.Rule})
}

func mustNewValidator[T any](f func() (T, error)) T {
validator, err := f()
if err != nil {
Expand All @@ -237,7 +249,9 @@ func validateAny(validator Validator, data any) ErrorMap {
}

func validateValueResult(validator Validator, value reflect.Value) ErrorMap {
if err := validator.TypeCheck(value.Type()); err != nil {
// Users often don't have control over whether any is a pointer, so we're generous
typ := indirectType(value.Type())
if err := validator.TypeCheck(typ); err != nil {
return ErrorMap{"TypeCheck": err.TemplateError()}
}

Expand All @@ -259,12 +273,7 @@ func validateMerge(value reflect.Value, key string, errorMap ErrorMap, rules []R
}

func typeCheck(typ, expectedType reflect.Type, kindString string) *RuleTypeError {
if expectedType == anyTyp {
return nil
}

iType := indirectType(typ)
if iType != expectedType {
if typ != expectedType {
if kindString != "" {
kindString += " "
}
Expand Down
107 changes: 87 additions & 20 deletions pkg/firm/validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ func TestStruct_Validate(t *testing.T) {
})
}

func TestStructAny_ValidateAny(t *testing.T) {
func TestStructAny_ValidateAll(t *testing.T) {
validator := testRegistry.Validator(reflect.TypeOf(parent{}))

tcs := []struct {
Expand All @@ -531,8 +531,8 @@ func TestStructAny_ValidateAny(t *testing.T) {
for i, key := range tc.errorKeys {
errKeySuffixes[i] = joinKeys(key, presentRuleKey)
}
testValidates(t, validator, rawData, presentRuleError(""), errKeySuffixes...)
testValidates(t, validator, &rawData, presentRuleError(""), errKeySuffixes...)
testValidateAll(t, validator, rawData, presentRuleError(""), errKeySuffixes...)
require.Equal(t, typeCheckErrorResult(validator, &rawData), validator.ValidateAny(&rawData))
})
}
}
Expand All @@ -548,7 +548,7 @@ func TestStructAny_TypeCheck(t *testing.T) {
badCondition string
}{
{name: "matching struct", data: parent{}},
{name: "matching struct pointer", data: &parent{}},
{name: "matching struct pointer", data: &parent{}, badCondition: badCondition},
{name: "other struct", data: Child{}, badCondition: badCondition},
{name: "not struct", data: 1, badCondition: badCondition},
}
Expand Down Expand Up @@ -641,7 +641,7 @@ func TestSlice_Validate(t *testing.T) {
})
}

func TestSliceAny_ValidateAny(t *testing.T) {
func TestSliceAny_ValidateAll(t *testing.T) {
validator := sliceValidator

tcs := []struct {
Expand All @@ -665,8 +665,8 @@ func TestSliceAny_ValidateAny(t *testing.T) {
for i, key := range tc.errorKeys {
errKeySuffixes[i] = joinKeys(key, presentRuleKey)
}
testValidates(t, validator, rawData, presentRuleError(""), errKeySuffixes...)
testValidates(t, validator, &rawData, presentRuleError(""), errKeySuffixes...)
testValidateAll(t, validator, rawData, presentRuleError(""), errKeySuffixes...)
require.Equal(t, typeCheckErrorResult(validator, &rawData), validator.ValidateAny(&rawData))
})
}
}
Expand All @@ -681,7 +681,7 @@ func TestSliceAny_TypeCheck(t *testing.T) {
badCondition string
}{
{name: "matching slice", data: []sliceValidatorElement{}},
{name: "matching slice pointer", data: &[]sliceValidatorElement{}},
{name: "matching slice pointer", data: &[]sliceValidatorElement{}, badCondition: badCondition},
{name: "other slice", data: []int{}, badCondition: badCondition},
{name: "not slice", data: 1, badCondition: badCondition},
}
Expand Down Expand Up @@ -744,23 +744,43 @@ func TestValue_Validate(t *testing.T) {
})
}

func TestValueAny_ValidateAny(t *testing.T) {
validator, err := NewValueAny(reflect.TypeOf(0), presentRule{})
require.NoError(t, err)

func TestValueAny_ValidateAll(t *testing.T) {
edgeTcs := []struct {
name string
validator Validator
data any
result ErrorMap
name string
rule Rule
data any

newError bool
result ErrorMap
typeCheckError bool
}{
{name: "invalid", validator: validator, data: nil, result: errInvalidValue},
{name: "invalid", rule: presentRule{}, data: nil, result: errInvalidValue},
{name: "bad_type_with_rule_validator", rule: onlyKindRule{kind: reflect.String}, data: 1, newError: true},
{name: "bad_type_after_new", rule: onlyKindRule{kind: reflect.Bool}, data: 1, typeCheckError: true},
}
for _, tc := range edgeTcs {
tc := tc
t.Run(tc.name, func(t *testing.T) { require.Equal(t, tc.result, tc.validator.ValidateAny(tc.data)) })

t.Run(tc.name, func(t *testing.T) {
require := require.New(t)

validator, err := NewValueAny(reflect.TypeOf(true), tc.rule)
if tc.newError {
require.Equal(NewRuleTypeError(reflect.TypeOf(true), "is not string"), err)
return
}

require.NoError(err)
result := tc.result
if result == nil && tc.typeCheckError {
result = typeCheckErrorResult(validator, tc.data)
}
require.Equal(result, validator.ValidateAny(tc.data))
})
}

validator, err := NewValueAny(reflect.TypeOf(0), presentRule{})
require.NoError(t, err)
type testCase struct {
name string
data any
Expand All @@ -772,7 +792,10 @@ func TestValueAny_ValidateAny(t *testing.T) {
}
for _, tc := range tcs {
tc := tc
t.Run(tc.name, func(t *testing.T) { testValidates(t, validator, tc.data, tc.err, presentRuleKey) })
t.Run(tc.name, func(t *testing.T) {
testValidateAll(t, validator, tc.data, tc.err, presentRuleKey)
require.Equal(t, typeCheckErrorResult(validator, &tc.data), validator.ValidateAny(&tc.data))
})
}
}

Expand All @@ -786,7 +809,7 @@ func TestValueAny_TypeCheck(t *testing.T) {
badCondition string
}{
{name: "matching int", data: 0},
{name: "matching int pointer", data: &i},
{name: "matching int pointer", data: &i, badCondition: "is not matching of type int"},
{name: "not int", data: []int{}, badCondition: "is not matching of type int"},
}

Expand All @@ -803,3 +826,47 @@ func TestValueAny_TypeCheck(t *testing.T) {
})
}
}

func TestRuleValidator_ValidateAll(t *testing.T) {
edgeTcs := []struct {
name string
rule Rule
data any
result ErrorMap
typeCheckError bool
}{
{name: "invalid", rule: presentRule{}, data: nil, result: errInvalidValue},
{name: "bad_type", rule: onlyKindRule{kind: reflect.Bool}, data: 1, typeCheckError: true},
}
for _, tc := range edgeTcs {
tc := tc

t.Run(tc.name, func(t *testing.T) {
result := tc.result
if result == nil && tc.typeCheckError {
result = typeCheckErrorResult(tc.rule, tc.data)
}

validator := RuleValidator{Rule: tc.rule}
require.Equal(t, result, validator.ValidateAny(tc.data))
})
}

validator := RuleValidator{Rule: presentRule{}}
type testCase struct {
name string
data any
err *TemplateError
}
tcs := []testCase{
{name: "not_zero", data: 1},
{name: "zero", data: 0, err: presentRuleError("")},
}
for _, tc := range tcs {
tc := tc
t.Run(tc.name, func(t *testing.T) {
testValidateAll(t, validator, tc.data, tc.err, presentRuleKey)
testValidateAll(t, validator, &tc.data, nil)
})
}
}
6 changes: 3 additions & 3 deletions pkg/firm/ztesthelpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ func typeCheckErrorResult(rule Rule, data any) ErrorMap {
return ErrorMap{"TypeCheck": rule.TypeCheck(reflect.TypeOf(data)).TemplateError()}
}

func testValidates(t *testing.T, validator Validator, data any, err *TemplateError, keySuffixes ...string) {
testValidatesFull(t, false, validator, data, err, keySuffixes...)
func testValidateAll(t *testing.T, validator Validator, data any, err *TemplateError, keySuffixes ...string) {
testValidateAllFull(t, false, validator, data, err, keySuffixes...)
}

func testValidatesFull(t *testing.T, skipValidate bool, validator Validator, data any, err *TemplateError, keySuffixes ...string) {
func testValidateAllFull(t *testing.T, skipValidate bool, validator Validator, data any, err *TemplateError, keySuffixes ...string) {
require := require.New(t)

var validateValueExpected ErrorMap
Expand Down

0 comments on commit 088fffb

Please sign in to comment.