Skip to content

Commit

Permalink
lock down indirectType calls
Browse files Browse the repository at this point in the history
  • Loading branch information
s12chung committed Oct 9, 2023
1 parent ee08d9e commit 5f74e66
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 50 deletions.
6 changes: 5 additions & 1 deletion pkg/firm/definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@ import (
// NewDefinition returns a new Definition
func NewDefinition[T any]() *Definition {
var zero T
typ := reflect.TypeOf(zero)
if typ.Kind() == reflect.Pointer {
panic(fmt.Sprintf("NewDefinition created with pointer type, dereference it: %v", typ.String()))
}
validator := &Definition{
typ: indirectType(reflect.TypeOf(zero)),
typ: typ,
topLevelRules: []Rule{},
ruleMap: RuleMap{},
}
Expand Down
3 changes: 1 addition & 2 deletions pkg/firm/firm.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ var ValidateAny = DefaultRegistry.ValidateAny
var DefaultRegistry = &Registry{}

// DefaultValidator is the validator used by registries for not found types when DefaultValidator is not defined
var DefaultValidator = MustNewValue[Any](NotFoundRule{})
var DefaultValidator = RuleValidator{Rule: NotFoundRule{}}

// NotFoundRule is the rule used for not found types in the DefaultValidator
type NotFoundRule struct{}
Expand Down Expand Up @@ -54,7 +54,6 @@ type BasicRule interface {
// Validator validates the data
type Validator interface {
Rule
Type() reflect.Type
ValidateAny(data any) ErrorMap
ValidateMerge(value reflect.Value, key string, errorMap ErrorMap)
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/firm/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ func (r *Registry) registeredStruct(definition *Definition) *ValueAny {
}

func (r *Registry) registerRecursionType(typ reflect.Type, rules *[]Rule) {
// Registry handles indirect types only, user can't control the types here
typ = indirectType(typ)

//nolint:exhaustive // just need these cases
Expand All @@ -75,15 +76,13 @@ func (r *Registry) registerRecursionType(typ reflect.Type, rules *[]Rule) {
}
}

// Type returns the Type the Registry handles
func (r *Registry) Type() reflect.Type { return r.Validator(nil).Type() }

// ValidateAny validates the data with the correct validator
func (r *Registry) ValidateAny(data any) ErrorMap {
value := reflect.ValueOf(data)
if !value.IsValid() {
return errInvalidValue
}
// value is used here, so can't use validateValue to save reflect.TypeOf call
return validateValueResult(r.DefaultedValidator(value.Type()), value)
}

Expand Down Expand Up @@ -119,6 +118,7 @@ func (r *Registry) Validator(typ reflect.Type) Validator {
if typ == nil || r.typeToValidator == nil {
return nil
}
// Registry only contains indirect types, make the function safe
typ = indirectType(typ)
validator := r.typeToValidator[typ]
if validator == nil {
Expand Down
10 changes: 5 additions & 5 deletions pkg/firm/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func notFoundError(data any) ErrorMap {
}

// nolint:funlen // a bunch of test cases
func TestRegistry_ValidateAny(t *testing.T) {
func TestRegistry_ValidateAll(t *testing.T) {
type testCase struct {
name string
definition *Definition
Expand Down Expand Up @@ -124,13 +124,13 @@ func TestRegistry_ValidateAny(t *testing.T) {
require.Equal(notFoundError(&data), registry.ValidateAny(&data))

notFoundTemplateError := &TemplateError{Template: "type, {{.RootTypeName}}, not found in Registry"}
testValidatesFull(t, true, registry, data, notFoundTemplateError, tc.expectedKeySuffix)
testValidatesFull(t, true, registry, &data, notFoundTemplateError, tc.expectedKeySuffix)
testValidateAllFull(t, true, registry, data, notFoundTemplateError, tc.expectedKeySuffix)
testValidateAllFull(t, true, registry, &data, notFoundTemplateError, tc.expectedKeySuffix)
return
}
data := tc.data()
testValidates(t, registry, data, tc.err, tc.expectedKeySuffix)
testValidates(t, registry, &data, tc.err, tc.expectedKeySuffix)
testValidateAll(t, registry, data, tc.err, tc.expectedKeySuffix)
testValidateAll(t, registry, &data, tc.err, tc.expectedKeySuffix)
})
}
}
Expand Down
40 changes: 28 additions & 12 deletions pkg/firm/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ func NewStruct[T any](ruleMap RuleMap) (Struct[T], error) {

// NewStructAny returns a new StructAny
func NewStructAny(typ reflect.Type, ruleMap RuleMap) (StructAny, error) {
if typ == nil || typ.Kind() != reflect.Struct {
return StructAny{}, fmt.Errorf("type is not a Struct")
if typ == nil {
return StructAny{}, fmt.Errorf("type, nil, is not a Struct")
}
if typ.Kind() != reflect.Struct {
return StructAny{}, fmt.Errorf("type, %v, is not a Struct", typ.String())
}

for fieldName, rules := range ruleMap {
Expand All @@ -41,7 +44,7 @@ func NewStructAny(typ reflect.Type, ruleMap RuleMap) (StructAny, error) {
rules := v
rm[k] = &rules
}
return StructAny{typ: indirectType(typ), ruleMap: rm}, nil
return StructAny{typ: typ, ruleMap: rm}, nil
}

// Struct validates structs
Expand Down Expand Up @@ -116,7 +119,7 @@ func NewSliceAny(typ reflect.Type, elementRules ...Rule) (SliceAny, error) {
return SliceAny{}, fmt.Errorf("element type: %w", err)
}
}
return SliceAny{typ: indirectType(typ), elementRules: elementRules}, nil
return SliceAny{typ: typ, elementRules: elementRules}, nil
}

// Slice validates slices and arrays
Expand Down Expand Up @@ -171,8 +174,12 @@ func NewValue[T any](rules ...Rule) (Value[T], error) {
// NewValueAny returns a ValueAny
func NewValueAny(typ reflect.Type, rules ...Rule) (ValueAny, error) {
if typ == nil {
typ = anyTyp
return ValueAny{}, fmt.Errorf("type is nil, not recommended")
}
if typ.Kind() == reflect.Pointer {
return ValueAny{}, fmt.Errorf("type, %v, is a Pointer, not recommended", typ.String())
}

for _, rule := range rules {
if err := rule.TypeCheck(typ); err != nil {
return ValueAny{}, err
Expand Down Expand Up @@ -218,6 +225,18 @@ func (v ValueAny) TypeCheck(typ reflect.Type) *RuleTypeError { return typeCheck(
// Rules returns the rules for ValueAny
func (v ValueAny) Rules() []Rule { return v.rules }

// RuleValidator is a Validator wrapper around Rule
type RuleValidator struct{ Rule }

// ValidateAny validates the data
func (r RuleValidator) ValidateAny(data any) ErrorMap { return validateAny(r, data) }

// ValidateMerge validates the data value, also doing a merge with the errorMap (assumes TypeCheck is called)
func (r RuleValidator) ValidateMerge(value reflect.Value, key string, errorMap ErrorMap) {
value = indirect(value)
validateMerge(value, key, errorMap, []Rule{r.Rule})
}

func mustNewValidator[T any](f func() (T, error)) T {
validator, err := f()
if err != nil {
Expand All @@ -237,7 +256,9 @@ func validateAny(validator Validator, data any) ErrorMap {
}

func validateValueResult(validator Validator, value reflect.Value) ErrorMap {
if err := validator.TypeCheck(value.Type()); err != nil {
// Users often don't have control over whether any is a pointer, so we're generous
typ := indirectType(value.Type())
if err := validator.TypeCheck(typ); err != nil {
return ErrorMap{"TypeCheck": err.TemplateError()}
}

Expand All @@ -259,12 +280,7 @@ func validateMerge(value reflect.Value, key string, errorMap ErrorMap, rules []R
}

func typeCheck(typ, expectedType reflect.Type, kindString string) *RuleTypeError {
if expectedType == anyTyp {
return nil
}

iType := indirectType(typ)
if iType != expectedType {
if typ != expectedType {
if kindString != "" {
kindString += " "
}
Expand Down
115 changes: 91 additions & 24 deletions pkg/firm/validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,8 @@ func TestNewStructAny(t *testing.T) {
}{
{name: "normal", data: Child{}, ruleMap: RuleMap{"Validates": {presentRule{}}}},
{name: "non_exported_field", data: Child{}, ruleMap: RuleMap{"private": {presentRule{}}}},
{name: "nil_type", data: nil, err: fmt.Errorf("type is not a Struct")},
{name: "nil_type", data: nil, err: fmt.Errorf("type, nil, is not a Struct")},
{name: "pointer", data: &Child{}, err: fmt.Errorf("type, *firm.Child, is not a Struct")},
{name: "non_matching_field", data: Child{}, ruleMap: RuleMap{"No": {presentRule{}}}, err: fmt.Errorf("field, No, not found in type: firm.Child")},
{name: "no_matching_rule", data: Child{}, ruleMap: RuleMap{"Validates": {noMatchingRule}},
err: fmt.Errorf("field, Validates, in firm.Child: %w", noMatchingRule.TypeCheck(reflect.TypeOf("")))},
Expand Down Expand Up @@ -507,7 +508,7 @@ func TestStruct_Validate(t *testing.T) {
})
}

func TestStructAny_ValidateAny(t *testing.T) {
func TestStructAny_ValidateAll(t *testing.T) {
validator := testRegistry.Validator(reflect.TypeOf(parent{}))

tcs := []struct {
Expand All @@ -531,8 +532,8 @@ func TestStructAny_ValidateAny(t *testing.T) {
for i, key := range tc.errorKeys {
errKeySuffixes[i] = joinKeys(key, presentRuleKey)
}
testValidates(t, validator, rawData, presentRuleError(""), errKeySuffixes...)
testValidates(t, validator, &rawData, presentRuleError(""), errKeySuffixes...)
testValidateAll(t, validator, rawData, presentRuleError(""), errKeySuffixes...)
testValidateAll(t, validator, &rawData, presentRuleError(""), errKeySuffixes...)
})
}
}
Expand All @@ -548,7 +549,7 @@ func TestStructAny_TypeCheck(t *testing.T) {
badCondition string
}{
{name: "matching struct", data: parent{}},
{name: "matching struct pointer", data: &parent{}},
{name: "matching struct pointer", data: &parent{}, badCondition: badCondition},
{name: "other struct", data: Child{}, badCondition: badCondition},
{name: "not struct", data: 1, badCondition: badCondition},
}
Expand Down Expand Up @@ -609,6 +610,7 @@ func TestNewSliceAny(t *testing.T) {
}{
{name: "normal", data: []Child{}, rules: []Rule{presentRule{}}},
{name: "nil_type", data: nil, err: fmt.Errorf("type, nil, is not a Slice or Array")},
{name: "pointer", data: &[]Child{}, err: fmt.Errorf("type, *[]firm.Child, is not a Slice or Array")},
{name: "not_slice", data: Child{}, err: fmt.Errorf("type, firm.Child, is not a Slice or Array")},
{name: "no_matching_rule", data: []Child{}, rules: []Rule{noMatchingRule},
err: fmt.Errorf("element type: %w", noMatchingRule.TypeCheck(reflect.TypeOf(Child{})))},
Expand Down Expand Up @@ -641,7 +643,7 @@ func TestSlice_Validate(t *testing.T) {
})
}

func TestSliceAny_ValidateAny(t *testing.T) {
func TestSliceAny_ValidateAll(t *testing.T) {
validator := sliceValidator

tcs := []struct {
Expand All @@ -665,8 +667,8 @@ func TestSliceAny_ValidateAny(t *testing.T) {
for i, key := range tc.errorKeys {
errKeySuffixes[i] = joinKeys(key, presentRuleKey)
}
testValidates(t, validator, rawData, presentRuleError(""), errKeySuffixes...)
testValidates(t, validator, &rawData, presentRuleError(""), errKeySuffixes...)
testValidateAll(t, validator, rawData, presentRuleError(""), errKeySuffixes...)
testValidateAll(t, validator, &rawData, presentRuleError(""), errKeySuffixes...)
})
}
}
Expand All @@ -681,7 +683,7 @@ func TestSliceAny_TypeCheck(t *testing.T) {
badCondition string
}{
{name: "matching slice", data: []sliceValidatorElement{}},
{name: "matching slice pointer", data: &[]sliceValidatorElement{}},
{name: "matching slice pointer", data: &[]sliceValidatorElement{}, badCondition: badCondition},
{name: "other slice", data: []int{}, badCondition: badCondition},
{name: "not slice", data: 1, badCondition: badCondition},
}
Expand All @@ -707,10 +709,9 @@ func TestNewValueAny(t *testing.T) {
err error
}{
{name: "normal", data: i, rules: []Rule{intRule}},
{name: "int_pointer", data: &i, rules: []Rule{intRule}, err: intRule.TypeCheck(reflect.TypeOf(&i))},
{name: "int_pointer", data: &i, err: fmt.Errorf("type, *int, is a Pointer, not recommended")},
{name: "nil_type", data: nil, err: fmt.Errorf("type is nil, not recommended")},
{name: "not_int", data: []int{}, rules: []Rule{intRule}, err: intRule.TypeCheck(reflect.TypeOf([]int{}))},
{name: "nil_type", data: nil, rules: []Rule{intRule}, err: intRule.TypeCheck(anyTyp)},
{name: "nil_with_presentRule", data: nil, rules: []Rule{presentRule{}}},
}
for _, tc := range tcs {
tc := tc
Expand Down Expand Up @@ -744,23 +745,43 @@ func TestValue_Validate(t *testing.T) {
})
}

func TestValueAny_ValidateAny(t *testing.T) {
validator, err := NewValueAny(reflect.TypeOf(0), presentRule{})
require.NoError(t, err)

func TestValueAny_ValidateAll(t *testing.T) {
edgeTcs := []struct {
name string
validator Validator
data any
result ErrorMap
name string
rule Rule
data any

newError bool
result ErrorMap
typeCheckError bool
}{
{name: "invalid", validator: validator, data: nil, result: errInvalidValue},
{name: "invalid", rule: presentRule{}, data: nil, result: errInvalidValue},
{name: "bad_type_with_rule_validator", rule: onlyKindRule{kind: reflect.String}, data: 1, newError: true},
{name: "bad_type_after_new", rule: onlyKindRule{kind: reflect.Bool}, data: 1, typeCheckError: true},
}
for _, tc := range edgeTcs {
tc := tc
t.Run(tc.name, func(t *testing.T) { require.Equal(t, tc.result, tc.validator.ValidateAny(tc.data)) })

t.Run(tc.name, func(t *testing.T) {
require := require.New(t)

validator, err := NewValueAny(reflect.TypeOf(true), tc.rule)
if tc.newError {
require.Equal(NewRuleTypeError(reflect.TypeOf(true), "is not string"), err)
return
}

require.NoError(err)
result := tc.result
if result == nil && tc.typeCheckError {
result = typeCheckErrorResult(validator, tc.data)
}
require.Equal(result, validator.ValidateAny(tc.data))
})
}

validator, err := NewValueAny(reflect.TypeOf(0), presentRule{})
require.NoError(t, err)
type testCase struct {
name string
data any
Expand All @@ -772,7 +793,9 @@ func TestValueAny_ValidateAny(t *testing.T) {
}
for _, tc := range tcs {
tc := tc
t.Run(tc.name, func(t *testing.T) { testValidates(t, validator, tc.data, tc.err, presentRuleKey) })
t.Run(tc.name, func(t *testing.T) {
testValidateAll(t, validator, tc.data, tc.err, presentRuleKey)
})
}
}

Expand All @@ -786,7 +809,7 @@ func TestValueAny_TypeCheck(t *testing.T) {
badCondition string
}{
{name: "matching int", data: 0},
{name: "matching int pointer", data: &i},
{name: "matching int pointer", data: &i, badCondition: "is not matching of type int"},
{name: "not int", data: []int{}, badCondition: "is not matching of type int"},
}

Expand All @@ -803,3 +826,47 @@ func TestValueAny_TypeCheck(t *testing.T) {
})
}
}

func TestRuleValidator_ValidateAll(t *testing.T) {
edgeTcs := []struct {
name string
rule Rule
data any
result ErrorMap
typeCheckError bool
}{
{name: "invalid", rule: presentRule{}, data: nil, result: errInvalidValue},
{name: "bad_type", rule: onlyKindRule{kind: reflect.Bool}, data: 1, typeCheckError: true},
}
for _, tc := range edgeTcs {
tc := tc

t.Run(tc.name, func(t *testing.T) {
result := tc.result
if result == nil && tc.typeCheckError {
result = typeCheckErrorResult(tc.rule, tc.data)
}

validator := RuleValidator{Rule: tc.rule}
require.Equal(t, result, validator.ValidateAny(tc.data))
})
}

validator := RuleValidator{Rule: presentRule{}}
type testCase struct {
name string
data any
err *TemplateError
}
tcs := []testCase{
{name: "not_zero", data: 1},
{name: "zero", data: 0, err: presentRuleError("")},
}
for _, tc := range tcs {
tc := tc
t.Run(tc.name, func(t *testing.T) {
testValidateAll(t, validator, tc.data, tc.err, presentRuleKey)
testValidateAll(t, validator, &tc.data, nil)
})
}
}
Loading

0 comments on commit 5f74e66

Please sign in to comment.