Skip to content

Commit

Permalink
[TT-5279] Moved empty type body validation to astvalidation
Browse files Browse the repository at this point in the history
[changelog]
internal: Moved visitor to astvalidation. Added validation for type extensions. Handled some edge cases.
  • Loading branch information
David Stutt committed May 19, 2022
1 parent 6820fd5 commit a8b8544
Show file tree
Hide file tree
Showing 6 changed files with 359 additions and 11 deletions.
1 change: 1 addition & 0 deletions pkg/astvalidation/definition_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

func DefaultDefinitionValidator() *DefinitionValidator {
return NewDefinitionValidator(
PopulatedTypeBodies(),
UniqueOperationTypes(),
UniqueTypeNames(),
UniqueFieldDefinitionNames(),
Expand Down
136 changes: 136 additions & 0 deletions pkg/astvalidation/rule_populated_type_bodies.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
package astvalidation

import (
"github.com/jensneuse/graphql-go-tools/pkg/ast"
"github.com/jensneuse/graphql-go-tools/pkg/astvisitor"
"github.com/jensneuse/graphql-go-tools/pkg/operationreport"
)

type populatedTypeBodiesVisitor struct {
*astvisitor.Walker
definition *ast.Document
}

func PopulatedTypeBodies() Rule {
return func(walker *astvisitor.Walker) {
visitor := &populatedTypeBodiesVisitor{
Walker: walker,
definition: nil,
}

walker.RegisterEnterDocumentVisitor(visitor)
walker.RegisterEnterEnumTypeDefinitionVisitor(visitor)
walker.RegisterEnterEnumTypeExtensionVisitor(visitor)
walker.RegisterEnterInputObjectTypeDefinitionVisitor(visitor)
walker.RegisterEnterInputObjectTypeExtensionVisitor(visitor)
walker.RegisterEnterInterfaceTypeDefinitionVisitor(visitor)
walker.RegisterEnterInterfaceTypeExtensionVisitor(visitor)
walker.RegisterEnterObjectTypeDefinitionVisitor(visitor)
walker.RegisterEnterObjectTypeExtensionVisitor(visitor)
}
}

func (p *populatedTypeBodiesVisitor) EnterDocument(operation, _ *ast.Document) {
p.definition = operation
}

func (p populatedTypeBodiesVisitor) EnterEnumTypeDefinition(ref int) {
definition := p.definition
if !definition.EnumTypeDefinitions[ref].HasEnumValuesDefinition {
p.Report.AddExternalError(operationreport.ErrTypeBodyMustNotBeEmpty("enum", definition.EnumTypeDefinitionNameString(ref)))
return
}
}

func (p *populatedTypeBodiesVisitor) EnterEnumTypeExtension(ref int) {
definition := p.definition
if !definition.EnumTypeExtensions[ref].HasEnumValuesDefinition {
p.Report.AddExternalError(operationreport.ErrTypeBodyMustNotBeEmpty("enum extension", definition.EnumTypeExtensionNameString(ref)))
return
}
}

func (p populatedTypeBodiesVisitor) EnterInputObjectTypeDefinition(ref int) {
definition := p.definition
if !definition.InputObjectTypeDefinitions[ref].HasInputFieldsDefinition {
p.Report.AddExternalError(operationreport.ErrTypeBodyMustNotBeEmpty("input", definition.InputObjectTypeDefinitionNameString(ref)))
return
}
}

func (p *populatedTypeBodiesVisitor) EnterInputObjectTypeExtension(ref int) {
definition := p.definition
if !definition.InputObjectTypeExtensions[ref].HasInputFieldsDefinition {
p.Report.AddExternalError(operationreport.ErrTypeBodyMustNotBeEmpty("input extension", definition.InputObjectTypeExtensionNameString(ref)))
return
}
}

func (p populatedTypeBodiesVisitor) EnterInterfaceTypeDefinition(ref int) {
definition := p.definition
switch definition.InterfaceTypeDefinitions[ref].HasFieldDefinitions {
case true:
refs := definition.InterfaceTypeDefinitions[ref].FieldsDefinition.Refs
if len(refs) > 1 || definition.FieldDefinitionNameString(refs[0]) != typename {
return
}
fallthrough
case false:
p.Report.AddExternalError(operationreport.ErrTypeBodyMustNotBeEmpty("interface", definition.InterfaceTypeDefinitionNameString(ref)))
return
}
}

func (p *populatedTypeBodiesVisitor) EnterInterfaceTypeExtension(ref int) {
definition := p.definition
if !definition.InterfaceTypeExtensions[ref].HasFieldDefinitions {
p.Report.AddExternalError(operationreport.ErrTypeBodyMustNotBeEmpty("interface extension", definition.InterfaceTypeExtensionNameString(ref)))
return
}
}

func (p populatedTypeBodiesVisitor) EnterObjectTypeDefinition(ref int) {
definition := p.definition
nameBytes := definition.ObjectTypeDefinitionNameBytes(ref)
if isRootType(nameBytes) {
return
}
switch definition.ObjectTypeDefinitions[ref].HasFieldDefinitions {
case true:
refs := definition.ObjectTypeDefinitions[ref].FieldsDefinition.Refs
if len(refs) > 1 || definition.FieldDefinitionNameString(refs[0]) != typename {
return
}
fallthrough
case false:
p.Report.AddExternalError(operationreport.ErrTypeBodyMustNotBeEmpty("object", string(nameBytes)))
return
}
}

func (p *populatedTypeBodiesVisitor) EnterObjectTypeExtension(ref int) {
definition := p.definition
if !definition.ObjectTypeExtensions[ref].HasFieldDefinitions {
p.Report.AddExternalError(operationreport.ErrTypeBodyMustNotBeEmpty("object extension", definition.ObjectTypeExtensionNameString(ref)))
return
}
}

func isRootType(nameBytes []byte) bool {
length := len(nameBytes)
return isQuery(length, nameBytes) || isMutation(length, nameBytes) || isSubscription(length, nameBytes)
}

func isQuery(length int, b []byte) bool {
return length == 5 && b[0] == 'Q' && b[1] == 'u' && b[2] == 'e' && b[3] == 'r' && b[4] == 'y'
}

func isMutation(length int, b []byte) bool {
return length == 8 && b[0] == 'M' && b[1] == 'u' && b[2] == 't' && b[3] == 'a' && b[4] == 't' && b[5] == 'i' && b[6] == 'o' && b[7] == 'n'
}

func isSubscription(length int, b []byte) bool {
return length == 12 && b[0] == 'S' && b[1] == 'u' && b[2] == 'b' && b[3] == 's' && b[4] == 'c' && b[5] == 'r' && b[6] == 'i' && b[7] == 'p' && b[8] == 't' && b[9] == 'i' && b[10] == 'o' && b[11] == 'n'
}

const typename = "__typename"
126 changes: 126 additions & 0 deletions pkg/astvalidation/rule_populated_type_bodies_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package astvalidation

import (
"testing"
)

func TestPopulatedTypeBodies(t *testing.T) {
t.Run("Definition", func(t *testing.T) {
t.Run("Populated type bodies are valid", func(t *testing.T) {
runDefinitionValidation(t, `
enum Species {
CAT
}
extend enum Color {
DOG
}
input Message {
content: String!
}
extend input Message {
updated: DateTime!
}
interface Animal {
species: Species!
}
extend interface Animal {
age: Int!
}
type Cat implements Animal {
species: Species!
}
extend type Cat implements Animal {
age: Int!
}
`, Valid, PopulatedTypeBodies(),
)
})

t.Run("Empty enum is invalid", func(t *testing.T) {
runDefinitionValidation(t, `
enum Species {
}
`, Invalid, PopulatedTypeBodies(),
)
})

t.Run("Empty enum extension is invalid", func(t *testing.T) {
runDefinitionValidation(t, `
enum Species {
CAT
}
extend enum Species {
}
`, Invalid, PopulatedTypeBodies(),
)
})

t.Run("Empty input is invalid", func(t *testing.T) {
runDefinitionValidation(t, `
input Message {
}
`, Invalid, PopulatedTypeBodies(),
)
})

t.Run("Empty input extension is invalid", func(t *testing.T) {
runDefinitionValidation(t, `
input Message {
content: String!
}
extend input Message {
}
`, Invalid, PopulatedTypeBodies(),
)
})

t.Run("Empty interface is invalid", func(t *testing.T) {
runDefinitionValidation(t, `
interface Animal {
}
`, Invalid, PopulatedTypeBodies(),
)
})

t.Run("Empty interface extension is invalid", func(t *testing.T) {
runDefinitionValidation(t, `
interface Animal {
species: String!
}
extend interface Animal {
}
`, Invalid, PopulatedTypeBodies(),
)
})

t.Run("Empty object is invalid", func(t *testing.T) {
runDefinitionValidation(t, `
type Cat {
}
`, Invalid, PopulatedTypeBodies(),
)
})

t.Run("Empty object extension is invalid", func(t *testing.T) {
runDefinitionValidation(t, `
type Cat {
species: String!
}
extend type Cat {
}
`, Invalid, PopulatedTypeBodies(),
)
})
})
}
10 changes: 7 additions & 3 deletions pkg/federation/sdlmerge/object_type_extending.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (e *extendObjectTypeDefinitionVisitor) EnterObjectTypeExtension(ref int) {
}

hasExtended := false
shouldReturn := isQueryOrMutation(nameBytes)
shouldReturn := isRootType(nameBytes)
for i := range nodes {
if nodes[i].Kind != ast.NodeKindObjectTypeDefinition {
continue
Expand All @@ -49,9 +49,9 @@ func (e *extendObjectTypeDefinitionVisitor) EnterObjectTypeExtension(ref int) {
}
}

func isQueryOrMutation(nameBytes []byte) bool {
func isRootType(nameBytes []byte) bool {
length := len(nameBytes)
return isQuery(length, nameBytes) || isMutation(length, nameBytes)
return isQuery(length, nameBytes) || isMutation(length, nameBytes) || isSubscription(length, nameBytes)
}

func isQuery(length int, b []byte) bool {
Expand All @@ -61,3 +61,7 @@ func isQuery(length int, b []byte) bool {
func isMutation(length int, b []byte) bool {
return length == 8 && b[0] == 'M' && b[1] == 'u' && b[2] == 't' && b[3] == 'a' && b[4] == 't' && b[5] == 'i' && b[6] == 'o' && b[7] == 'n'
}

func isSubscription(length int, b []byte) bool {
return length == 12 && b[0] == 'S' && b[1] == 'u' && b[2] == 'b' && b[3] == 's' && b[4] == 'c' && b[5] == 'r' && b[6] == 'i' && b[7] == 'p' && b[8] == 't' && b[9] == 'i' && b[10] == 'o' && b[11] == 'n'
}
42 changes: 34 additions & 8 deletions pkg/federation/sdlmerge/sdlmerge.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sdlmerge

import (
"fmt"
"github.com/jensneuse/graphql-go-tools/pkg/astvalidation"
"strings"

"github.com/jensneuse/graphql-go-tools/pkg/ast"
Expand Down Expand Up @@ -29,12 +30,21 @@ func MergeAST(ast *ast.Document) error {
return normalizer.normalize(ast)
}

type ParsedSubgraph struct {
document *ast.Document
report *operationreport.Report
}

func MergeSDLs(SDLs ...string) (string, error) {
rawDocs := make([]string, 0, len(SDLs)+1)
rawDocs = append(rawDocs, rootOperationTypeDefinitions)
rawDocs = append(rawDocs, SDLs...)
if err := normalizeSubgraphs(rawDocs); err != nil {
return "", err
parsedSubgraphs, validationError := validateSubgraphs(rawDocs)
if validationError != nil {
return "", validationError
}
if normalizationError := normalizeSubgraphs(rawDocs, parsedSubgraphs); normalizationError != nil {
return "", normalizationError
}

doc, report := astparser.ParseGraphqlDocumentString(strings.Join(rawDocs, "\n"))
Expand All @@ -59,18 +69,34 @@ func MergeSDLs(SDLs ...string) (string, error) {
return out, nil
}

func normalizeSubgraphs(subgraphs []string) error {
subgraphNormalizer := astnormalization.NewSubgraphDefinitionNormalizer()
for i, subgraph := range subgraphs {
func validateSubgraphs(subgraphs []string) ([]ParsedSubgraph, error) {
validator := astvalidation.NewDefinitionValidator(astvalidation.PopulatedTypeBodies())
parsedSubgraphs := make([]ParsedSubgraph, 0, len(subgraphs))
for _, subgraph := range subgraphs {
doc, report := astparser.ParseGraphqlDocumentString(subgraph)
parsedSubgraph := ParsedSubgraph{&doc, &report}
if report.HasErrors() {
return fmt.Errorf("parse graphql document string: %s", report.Error())
return parsedSubgraphs, fmt.Errorf("parse graphql document string: %s", report.Error())
}
subgraphNormalizer.NormalizeDefinition(&doc, &report)
validator.Validate(&doc, &report)
if report.HasErrors() {
return parsedSubgraphs, fmt.Errorf("validate subgraph: %s", report.Error())
}
parsedSubgraphs = append(parsedSubgraphs, parsedSubgraph)
}
return parsedSubgraphs, nil
}

func normalizeSubgraphs(subgraphs []string, parsedSubgraph []ParsedSubgraph) error {
subgraphNormalizer := astnormalization.NewSubgraphDefinitionNormalizer()
for i := range subgraphs {
doc := parsedSubgraph[i].document
report := parsedSubgraph[i].report
subgraphNormalizer.NormalizeDefinition(doc, report)
if report.HasErrors() {
return fmt.Errorf("normalize subgraph: %s", report.Error())
}
out, err := astprinter.PrintString(&doc, nil)
out, err := astprinter.PrintString(doc, nil)
if err != nil {
return fmt.Errorf("stringify schema: %s", err.Error())
}
Expand Down
Loading

0 comments on commit a8b8544

Please sign in to comment.