From e6e1a13e36eaf3a4e09a2bbd1eca3dfcb3d955d6 Mon Sep 17 00:00:00 2001 From: Sergey Petrunin Date: Mon, 25 Jul 2022 17:00:20 +0300 Subject: [PATCH] Federation: improve entities handling on sdl merge (#395) * [TT-4815] Remove duplicated scalars from federated schemas [changelog] internal: Duplicated scalars among merged schemas will be removed so only one of that scalar name will exist within the federated schema. * [TT-4815] Responded to PR feedback [changelog] internal: Changed magic number to constant; renamed variable; removed unnecessary code. * [TT-5116] Duplicate enums and unions are merged into a single enum or union respectively [changelog] internal: Duplicate enums and unions are merged into a single enum or union respectively. * [TT-5116] Responding to PR feedback [changelog] internal: Reduced some code duplication through interfaces. * [TT-5116] Some renaming and clean up. [changelog] internal: Renamed and clean up some methods; proposed some efficiency solutions. * [TT-5116] Duplicate fieldless value types are removed through single visitor. [changelog] internal: Duplicate enums, scalars and unions are merged through the same generic visitor. * [TT-5116] Return an error if duplicate value types are not identically composed. [changelog] internal: Added error handling for when duplicate values are not identically composed. * [TT-5116] Added sldmerge_tests [changelog] internal: Added new tests; improved test formatting; improved some naming. * [TT-5116] Removed unnecessary code [changelog] internal: Removed unnecessary code; tidied code up. * [TT-5116] Minor additions to sdlmerge_test [changelog] internal: Capitalised enum values; added one more type to unions. * [TT-5117] Implemented removal of duplicate interfaces [changelog] internal: Implemented removal of duplicate interfaces. (cherry picked from commit c2388b481778a25938e67d4e900262c11377c76c) * [TT-5117] Implemented duplicate Input/Object/Interface removal [changelog] internal: Implemented duplicate Input/Object/Interface removal. (cherry picked from commit edc4f3dfe03a844d8c59e5485bafea5eb483e612) * [TT-5117] Removed duplicate code [changelog] internal: Removed duplicate code. (cherry picked from commit d12e057bcd9c0d7f2be41d468a536919688c47a2) * [TT-5117] Changed receiver to pointer. [changelog] internal: Changed receiver to pointer. (cherry picked from commit 5ab441c0e3ec16b2297ddbab3239b05afd829798) * [TT-5117] Removed unused interface. [changelog] internal: Removed unused interface. (cherry picked from commit cdd0b14fe3b5313d27c7bba7c0fdbb999b9e344c) * [TT-5117] Re-added interface tests [changelog] internal: Re-added interface tests; improved test naming. (cherry picked from commit 06e0c8769d86e74cd0f3fc8e5e9ec090ff17371b) * [TT-5117] Corrected comparison of nested field values. [changelog] internal: Field values are deeply compared; added more tests. (cherry picked from commit a68d94d657e320ef6127a91ecf80e2b11e169879) * [TT-5117] Expanded fielded and sdlmerge test suites. [changelog] internal: Expanded fielded and sdlmerge test suites. (cherry picked from commit de97afa1a0b8e8a51c2a35a27d41942a6267a30e) * [TT-5199] Subgraphs are normalized before a merge is attempted. [changelog] internal: Subgraphs are normalized before merge; added missing extension handlers. (cherry picked from commit f4e821cf4aec3f4491f11e058e72ffb80256d4e3) * [TT-5199] Removed redundant boolean. [changelog] internal: Removed redundant boolean due to early return. (cherry picked from commit 560e9e124cad3bf6c763ab981829c3202e7c9ab5) * [TT-5199] Extracted pre-merge subgraph normalization into separate method. [changelog] internal: Extracted pre-merge subgraph normalization into separate method. (cherry picked from commit 83295b3d013b15e4059d92877e1002336f9a5f3a) * [TT-5117] Entities cannot be shared types [changelog] internal: Added logic to produce error if duplicate entity exists; renamed value type to shared type. (cherry picked from commit 79783e84d5c228d6a174f340d2a3a78a1c814974) * [TT-5117] Handled an edge case where entity is last item in the schema [changelog] internal: Added logic to still error if the entity is the last item in the schema. (cherry picked from commit 091ded185532d88b302a36d36275fbe48751a4a2) * [TT-5117] Inputs cannot be entities [changelog] internal: Removed logic to handle input entities. (cherry picked from commit d8a9fa67c505cb98b094f4c4cb4ba6bb2199f012) * [TT-5279] Shared type, extension orphan and validation clean up [changelog] internal: Shared types cannot be extended; added unique union member validation; unresolved extension orphans in the supergraph return an error (cherry picked from commit 9295cc4231f885caca8e3dc36a72a90971c95cf3) * [TT-5279] Subgraphs with empty type bodies are invalidated before extension could occur [changelog] internal: Added visitor to return an error if a type has an empty body. (cherry picked from commit 6820fd53050879bb1835e3ab98d4aaef72335d43) * [TT-5279] Moved empty type body validation to astvalidation [changelog] internal: Moved visitor to astvalidation. Added validation for type extensions. Handled some edge cases. (cherry picked from commit a8b8544d7280f4b5a02bc5dbb328781f6981245c) * [TT-5279] Handled edge cases around pre-validation of subgraphs [changelog] internal: Added check for root type; added logic to ignore private fields; removed duplicated line in rule. (cherry picked from commit e96a4558750014f0eb8979f1ce7b6bc324cf5246) * [TT-5279] Final clean up [changelog] internal: Removed unnecessary visitor; changed tests to new expected outcome (cherry picked from commit 96d3cc4cf75a7b21d18384f18c8e19b5d6cc4e37) * [TT-5279] Changed reused string to constant [changelog] internal: Changed reused parse error to string constant. (cherry picked from commit b271c75112df67463484be3a9d4d6dfd6b7258ff) * [TT-5279] Early return for root types when checking whether an object is empty [changelog] internal: Changed check for root type to come first so unnecessary code is not performed. (cherry picked from commit 4430be00e360ecf221c89beb10fc8627321938d7) * [TT-5279] Responding to PR feedback [changelog] internal: Added vars to prevent repeated slice creation; added helper function to determine whether type only contains reserved fields. (cherry picked from commit d4210ec4c3432ee6619660b24732c73be36fb93c) * [TT-5279] Use strings to create byte slice vars [changelog] internal: Changed explicit byte slice creation to string conversion. (cherry picked from commit 55244b8e1bfe08ad17cf5ca6f575749c5167ad9d) * [TT-5279] Removed cached struct fields [changelog] internal: Removed variables that cached struct fields. (cherry picked from commit 323747203d516a8acc88b0fee0034328b86723f2) * [TT-5460] Handling of entities PoC [changelog] internal: Added handling for extension of entities. (cherry picked from commit c0ac6e588cc2314d775f303f0aa9b9ed803db16b) * [TT-5460] Fixed gosimple complaint [changelog] internal: Removed explicit length for empty map. (cherry picked from commit 163b185caf0a25a70ad39fc0fe1eadaa7815a530) * [TT-5460] Refactors based on feedback [changelog] internal: Refactored implementation of entity handling; moved entity handling and shared type handling into own files. (cherry picked from commit 59ce18229ff13b23c26f3a840a906d04266bb0e4) * [TT-5460] Minor formatting changes. [changelog] internal: Changed some formatting; exported and reused existing constant. (cherry picked from commit 2e2c5217a9e446537199be0ea53aa498571b59d9) * [TT-5460] Reduced code duplication and cognitive complexity [changelog] internal: Reduced code duplication; reduced cognitive complexity of external directive validation method. (cherry picked from commit 6cde0ac2f09c49cadd84f0488da45f3bcc43d614) * [TT-5460] Further reduced code duplication [changelog] internal: Moved more generic isEntity method to entityValidator. (cherry picked from commit 97a6a5fa08bfcac4c1caa449e19ca46cd161220a) * [TT-5460] SOm refactoring of entity handling. [changelog] internal: Renamed some errors; added logic for edge cases. (cherry picked from commit bcceac453a7e4052dd9fd263b258be0f05c7678a) * [TT-5460] Removed most validation of entities [changelog] internal: Removed most validation of entities. (cherry picked from commit 6059c34c7209b407e520f63d200eb0345128e58f) * [TT-5460] Responded to PR feedback [changelog] internal: Changed placeholder bool to empty struct in entity set; refactored handling of duplicated entities; removed unnecessary checks that are handled elsewhere. (cherry picked from commit 9e715cf84cfd91b4331c39c49a0d725109e8125f) * [TT-5460] Responded to more PR feedback [changelog] internal: Added immediate return when producing an error; added entitySet type and pass that rather than normalizer. (cherry picked from commit f7dd1be82853b531e5001b83f79ca9142bd0860f) * [TT-5460] Final PR feedback changes [changelog] internal: Renamed isTypeEntity; changed logic so document is only changed if the extension is valid; separated exported and unexported const blocks in Plan. (cherry picked from commit 16baec357db48f4891758e5469247c28450b3b29) * [TT-5460] Removed unnecessary variable. [changelog] internal: Removed unnecessary variable. (cherry picked from commit 5389d862d801d89b120bf0284f6c37fea593b572) * [federation-changes-for-upstream] Changed import paths [changelog] internal: Changed import paths. Co-authored-by: David Stutt --- pkg/ast/ast_object_type_definition_test.go | 10 +- pkg/ast/ast_root_operation_type_definition.go | 15 +- pkg/ast/ast_test.go | 2 +- .../definition_normalization.go | 20 + .../definition_normalization_test.go | 137 ++++- pkg/astnormalization/enum_type_extending.go | 18 +- .../input_object_type_extending.go | 19 +- .../interface_type_extending.go | 19 +- pkg/astnormalization/object_type_extending.go | 18 +- pkg/astnormalization/scalar_type_extending.go | 18 +- pkg/astnormalization/union_type_extending.go | 19 +- pkg/asttransform/baseschema.go | 6 +- pkg/asttransform/typename_visitor.go | 4 +- pkg/astvalidation/definition_validation.go | 2 + pkg/astvalidation/rule.go | 2 + pkg/astvalidation/rule_known_type_names.go | 5 +- .../rule_populated_type_bodies.go | 116 ++++ .../rule_populated_type_bodies_test.go | 126 +++++ .../rule_unique_field_definition_names.go | 16 +- .../rule_unique_union_member_types.go | 93 +++ .../rule_unique_union_member_types_test.go | 30 + pkg/engine/plan/local_type_field_extractor.go | 7 +- pkg/engine/plan/required_field_extractor.go | 2 +- pkg/execution/datasource_http_json_test.go | 10 +- pkg/execution/planning_test.go | 24 +- pkg/federation/sdlmerge/collect_entities.go | 62 ++ .../sdlmerge/collect_entities_test.go | 86 +++ .../sdlmerge/enum_type_extending.go | 50 ++ .../sdlmerge/enum_type_extending_test.go | 107 ++++ .../sdlmerge/input_type_extending.go | 50 ++ .../sdlmerge/input_type_extending_test.go | 83 +++ .../sdlmerge/interface_type_extending.go | 40 +- .../sdlmerge/interface_type_extending_test.go | 206 +++++-- .../sdlmerge/object_type_extending.go | 42 +- .../sdlmerge/object_type_extending_test.go | 272 ++++++--- .../remove_duplicate_fielded_shared_types.go | 107 ++++ ...ove_duplicate_fielded_shared_types_test.go | 529 ++++++++++++++++++ ...remove_duplicate_fieldless_shared_types.go | 98 ++++ ...e_duplicate_fieldless_shared_types_test.go | 373 ++++++++++++ .../sdlmerge/remove_type_extensions_test.go | 10 +- .../sdlmerge/scalar_type_extending.go | 49 ++ .../sdlmerge/scalar_type_extending_test.go | 29 + pkg/federation/sdlmerge/sdlmerge.go | 120 +++- pkg/federation/sdlmerge/sdlmerge_test.go | 508 +++++++++++++++-- pkg/federation/sdlmerge/shared_types.go | 168 ++++++ .../sdlmerge/union_type_extending.go | 25 +- .../sdlmerge/union_type_extending_test.go | 90 ++- pkg/operationreport/externalerror.go | 40 ++ 48 files changed, 3568 insertions(+), 314 deletions(-) create mode 100644 pkg/astvalidation/rule_populated_type_bodies.go create mode 100644 pkg/astvalidation/rule_populated_type_bodies_test.go create mode 100644 pkg/astvalidation/rule_unique_union_member_types.go create mode 100644 pkg/astvalidation/rule_unique_union_member_types_test.go create mode 100644 pkg/federation/sdlmerge/collect_entities.go create mode 100644 pkg/federation/sdlmerge/collect_entities_test.go create mode 100644 pkg/federation/sdlmerge/enum_type_extending.go create mode 100644 pkg/federation/sdlmerge/enum_type_extending_test.go create mode 100644 pkg/federation/sdlmerge/input_type_extending.go create mode 100644 pkg/federation/sdlmerge/input_type_extending_test.go create mode 100644 pkg/federation/sdlmerge/remove_duplicate_fielded_shared_types.go create mode 100644 pkg/federation/sdlmerge/remove_duplicate_fielded_shared_types_test.go create mode 100644 pkg/federation/sdlmerge/remove_duplicate_fieldless_shared_types.go create mode 100644 pkg/federation/sdlmerge/remove_duplicate_fieldless_shared_types_test.go create mode 100644 pkg/federation/sdlmerge/scalar_type_extending.go create mode 100644 pkg/federation/sdlmerge/scalar_type_extending_test.go create mode 100644 pkg/federation/sdlmerge/shared_types.go diff --git a/pkg/ast/ast_object_type_definition_test.go b/pkg/ast/ast_object_type_definition_test.go index 6373dda5a..d35623f8b 100644 --- a/pkg/ast/ast_object_type_definition_test.go +++ b/pkg/ast/ast_object_type_definition_test.go @@ -35,15 +35,15 @@ func TestDocument_RemoveObjectTypeDefinition(t *testing.T) { t.Run("remove query type", func(t *testing.T) { doc := prepareDoc() - doc.RemoveObjectTypeDefinition([]byte("Query")) + doc.RemoveObjectTypeDefinition(ast.DefaultQueryTypeName) docStr, _ := astprinter.PrintString(doc, nil) assert.Equal(t, "type Mutation {mutationName: String} type Country {code: String} interface Model {id: String}", docStr) }) t.Run("remove query and mutations types", func(t *testing.T) { doc := prepareDoc() - doc.RemoveObjectTypeDefinition([]byte("Query")) - doc.RemoveObjectTypeDefinition([]byte("Mutation")) + doc.RemoveObjectTypeDefinition(ast.DefaultQueryTypeName) + doc.RemoveObjectTypeDefinition(ast.DefaultMutationTypeName) docStr, _ := astprinter.PrintString(doc, nil) assert.Equal(t, "type Country {code: String} interface Model {id: String}", docStr) @@ -51,8 +51,8 @@ func TestDocument_RemoveObjectTypeDefinition(t *testing.T) { t.Run("remove all types", func(t *testing.T) { doc := prepareDoc() - doc.RemoveObjectTypeDefinition([]byte("Query")) - doc.RemoveObjectTypeDefinition([]byte("Mutation")) + doc.RemoveObjectTypeDefinition(ast.DefaultQueryTypeName) + doc.RemoveObjectTypeDefinition(ast.DefaultMutationTypeName) doc.RemoveObjectTypeDefinition([]byte("Country")) docStr, _ := astprinter.PrintString(doc, nil) diff --git a/pkg/ast/ast_root_operation_type_definition.go b/pkg/ast/ast_root_operation_type_definition.go index f352cce57..0f252dbc8 100644 --- a/pkg/ast/ast_root_operation_type_definition.go +++ b/pkg/ast/ast_root_operation_type_definition.go @@ -1,9 +1,14 @@ package ast import ( + "bytes" "github.com/wundergraph/graphql-go-tools/pkg/lexer/position" ) +var DefaultQueryTypeName = []byte("Query") +var DefaultMutationTypeName = []byte("Mutation") +var DefaultSubscriptionTypeName = []byte("Subscription") + type RootOperationTypeDefinitionList struct { LBrace position.Position // { Refs []int // RootOperationTypeDefinition @@ -51,11 +56,11 @@ func (d *Document) RootOperationTypeDefinitionIsLastInSchemaDefinition(ref int, func (d *Document) CreateRootOperationTypeDefinition(operationType OperationType, rootNodeRef int) (ref int) { switch operationType { case OperationTypeQuery: - d.Index.QueryTypeName = []byte("Query") + d.Index.QueryTypeName = DefaultQueryTypeName case OperationTypeMutation: - d.Index.MutationTypeName = []byte("Mutation") + d.Index.MutationTypeName = DefaultMutationTypeName case OperationTypeSubscription: - d.Index.SubscriptionTypeName = []byte("Subscription") + d.Index.SubscriptionTypeName = DefaultSubscriptionTypeName default: return } @@ -137,3 +142,7 @@ func (d *Document) ReplaceRootOperationTypeDefinition(name string, operationType ref = d.ImportRootOperationTypeDefinition(name, operationType) return ref, true } + +func IsRootType(nameBytes []byte) bool { + return bytes.Equal(DefaultQueryTypeName, nameBytes) || bytes.Equal(DefaultMutationTypeName, nameBytes) || bytes.Equal(DefaultSubscriptionTypeName, nameBytes) +} diff --git a/pkg/ast/ast_test.go b/pkg/ast/ast_test.go index 272717090..5b62870c0 100644 --- a/pkg/ast/ast_test.go +++ b/pkg/ast/ast_test.go @@ -400,7 +400,7 @@ func TestDocument_NodeByName(t *testing.T) { t.Run("when node name is Query", func(t *testing.T) { t.Run("NodeByName", func(t *testing.T) { - node, exists := doc.NodeByName([]byte("Query")) + node, exists := doc.NodeByName(ast.DefaultQueryTypeName) assert.Equal(t, ast.NodeKindObjectTypeDefinition, node.Kind) assert.True(t, exists) }) diff --git a/pkg/astnormalization/definition_normalization.go b/pkg/astnormalization/definition_normalization.go index 0b15bb60c..fa897cf36 100644 --- a/pkg/astnormalization/definition_normalization.go +++ b/pkg/astnormalization/definition_normalization.go @@ -41,6 +41,26 @@ func (o *DefinitionNormalizer) setupWalkers() { o.walker = &walker } +func NewSubgraphDefinitionNormalizer() *DefinitionNormalizer { + normalizer := &DefinitionNormalizer{} + normalizer.setupSubgraphWalkers() + return normalizer +} + +func (o *DefinitionNormalizer) setupSubgraphWalkers() { + walker := astvisitor.NewWalker(48) + + extendObjectTypeDefinitionKeepingOrphans(&walker) + extendInputObjectTypeDefinitionKeepingOrphans(&walker) + extendEnumTypeDefinitionKeepingOrphans(&walker) + extendInterfaceTypeDefinitionKeepingOrphans(&walker) + extendScalarTypeDefinitionKeepingOrphans(&walker) + extendUnionTypeDefinitionKeepingOrphans(&walker) + removeMergedTypeExtensions(&walker) + + o.walker = &walker +} + // NormalizeDefinition applies all registered rules to the AST func (o *DefinitionNormalizer) NormalizeDefinition(definition *ast.Document, report *operationreport.Report) { o.walker.Walk(definition, nil, report) diff --git a/pkg/astnormalization/definition_normalization_test.go b/pkg/astnormalization/definition_normalization_test.go index a9ae5d07d..a5550204c 100644 --- a/pkg/astnormalization/definition_normalization_test.go +++ b/pkg/astnormalization/definition_normalization_test.go @@ -78,8 +78,8 @@ func TestNormalizeDefinition(t *testing.T) { lat: Float lon: Float planet: Planet - }`, - ) + } + `) }) t.Run("removes type extension and includes interfaces when type already has implements interface", func(t *testing.T) { @@ -100,7 +100,8 @@ func TestNormalizeDefinition(t *testing.T) { interface Entity { id: ID - }`, ` + } + `, ` schema { query: Query } type User implements Named & Entity { @@ -114,8 +115,8 @@ func TestNormalizeDefinition(t *testing.T) { interface Entity { id: ID - }`, - ) + } + `) }) t.Run("removes extensions and creates missing schema and root operation types", func(t *testing.T) { @@ -133,7 +134,129 @@ func TestNormalizeDefinition(t *testing.T) { } type Subscription { textCounter: String - }`, - ) + } + `) + }) +} + +func TestNormalizeSubgraphDefinition(t *testing.T) { + run := func(t *testing.T, definition, expectedOutput string) { + t.Helper() + + definitionDocument := unsafeparser.ParseGraphqlDocumentString(definition) + expectedOutputDocument := unsafeparser.ParseGraphqlDocumentString(expectedOutput) + + report := operationreport.Report{} + normalizer := NewSubgraphDefinitionNormalizer() + normalizer.NormalizeDefinition(&definitionDocument, &report) + + if report.HasErrors() { + t.Fatal(report.Error()) + } + + got := mustString(astprinter.PrintString(&definitionDocument, nil)) + want := mustString(astprinter.PrintString(&expectedOutputDocument, nil)) + + assert.Equal(t, want, got) + } + + t.Run("Extension orphans are not deleted", func(t *testing.T) { + run(t, ` + extend type Rival { + version: Version! + } + + enum Badge { + BOULDER + SOUL + } + + extend enum Version { + SILVER + } + + extend input Deposit { + quantity: Int! + } + + extend interface GymLeader { + badge: Badge! + } + + type Pokemon { + name: String! + } + + extend interface Trainer { + age: Int! + } + + union Types = Water | Fire + + extend input Move { + name: String + } + + input Deposit { + item: String! + } + + extend enum Badge { + EARTH + } + + extend union Berry = Oran + + extend type Pokemon { + types: Types! + } + + extend union Types = Grass + + interface Trainer { + name: String! + } + `, ` + extend type Rival { + version: Version! + } + + enum Badge { + BOULDER + SOUL + EARTH + } + + extend enum Version { + SILVER + } + + extend interface GymLeader { + badge: Badge! + } + + type Pokemon { + name: String! + types: Types! + } + + union Types = Water | Fire | Grass + + extend input Move { + name: String + } + + input Deposit { + item: String! + quantity: Int! + } + + extend union Berry = Oran + + interface Trainer { + name: String! + age: Int! + } + `) }) } diff --git a/pkg/astnormalization/enum_type_extending.go b/pkg/astnormalization/enum_type_extending.go index 0682fd2b3..e0ca63436 100644 --- a/pkg/astnormalization/enum_type_extending.go +++ b/pkg/astnormalization/enum_type_extending.go @@ -13,12 +13,22 @@ func extendEnumTypeDefinition(walker *astvisitor.Walker) { walker.RegisterEnterEnumTypeExtensionVisitor(&visitor) } +func extendEnumTypeDefinitionKeepingOrphans(walker *astvisitor.Walker) { + visitor := extendEnumTypeDefinitionVisitor{ + Walker: walker, + keepExtensionOrphans: true, + } + walker.RegisterEnterDocumentVisitor(&visitor) + walker.RegisterEnterEnumTypeExtensionVisitor(&visitor) +} + type extendEnumTypeDefinitionVisitor struct { *astvisitor.Walker - operation *ast.Document + operation *ast.Document + keepExtensionOrphans bool } -func (e *extendEnumTypeDefinitionVisitor) EnterDocument(operation, definition *ast.Document) { +func (e *extendEnumTypeDefinitionVisitor) EnterDocument(operation, _ *ast.Document) { e.operation = operation } @@ -36,5 +46,9 @@ func (e *extendEnumTypeDefinitionVisitor) EnterEnumTypeExtension(ref int) { return } + if e.keepExtensionOrphans { + return + } + e.operation.ImportAndExtendEnumTypeDefinitionByEnumTypeExtension(ref) } diff --git a/pkg/astnormalization/input_object_type_extending.go b/pkg/astnormalization/input_object_type_extending.go index 82ffd6675..ea44430fa 100644 --- a/pkg/astnormalization/input_object_type_extending.go +++ b/pkg/astnormalization/input_object_type_extending.go @@ -13,17 +13,26 @@ func extendInputObjectTypeDefinition(walker *astvisitor.Walker) { walker.RegisterEnterInputObjectTypeExtensionVisitor(&visitor) } +func extendInputObjectTypeDefinitionKeepingOrphans(walker *astvisitor.Walker) { + visitor := extendInputObjectTypeDefinitionVisitor{ + Walker: walker, + keepExtensionOrphans: true, + } + walker.RegisterEnterDocumentVisitor(&visitor) + walker.RegisterEnterInputObjectTypeExtensionVisitor(&visitor) +} + type extendInputObjectTypeDefinitionVisitor struct { *astvisitor.Walker - operation *ast.Document + operation *ast.Document + keepExtensionOrphans bool } -func (e *extendInputObjectTypeDefinitionVisitor) EnterDocument(operation, definition *ast.Document) { +func (e *extendInputObjectTypeDefinitionVisitor) EnterDocument(operation, _ *ast.Document) { e.operation = operation } func (e *extendInputObjectTypeDefinitionVisitor) EnterInputObjectTypeExtension(ref int) { - nodes, exists := e.operation.Index.NodesByNameBytes(e.operation.InputObjectTypeExtensionNameBytes(ref)) if !exists { return @@ -37,5 +46,9 @@ func (e *extendInputObjectTypeDefinitionVisitor) EnterInputObjectTypeExtension(r return } + if e.keepExtensionOrphans { + return + } + e.operation.ImportAndExtendInputObjectTypeDefinitionByInputObjectTypeExtension(ref) } diff --git a/pkg/astnormalization/interface_type_extending.go b/pkg/astnormalization/interface_type_extending.go index 4902640ac..c80d3b28d 100644 --- a/pkg/astnormalization/interface_type_extending.go +++ b/pkg/astnormalization/interface_type_extending.go @@ -13,17 +13,26 @@ func extendInterfaceTypeDefinition(walker *astvisitor.Walker) { walker.RegisterEnterInterfaceTypeExtensionVisitor(&visitor) } +func extendInterfaceTypeDefinitionKeepingOrphans(walker *astvisitor.Walker) { + visitor := extendInterfaceTypeDefinitionVisitor{ + Walker: walker, + keepExtensionOrphans: true, + } + walker.RegisterEnterDocumentVisitor(&visitor) + walker.RegisterEnterInterfaceTypeExtensionVisitor(&visitor) +} + type extendInterfaceTypeDefinitionVisitor struct { *astvisitor.Walker - operation *ast.Document + operation *ast.Document + keepExtensionOrphans bool } -func (e *extendInterfaceTypeDefinitionVisitor) EnterDocument(operation, definition *ast.Document) { +func (e *extendInterfaceTypeDefinitionVisitor) EnterDocument(operation, _ *ast.Document) { e.operation = operation } func (e *extendInterfaceTypeDefinitionVisitor) EnterInterfaceTypeExtension(ref int) { - nodes, exists := e.operation.Index.NodesByNameBytes(e.operation.InterfaceTypeExtensionNameBytes(ref)) if !exists { return @@ -37,5 +46,9 @@ func (e *extendInterfaceTypeDefinitionVisitor) EnterInterfaceTypeExtension(ref i return } + if e.keepExtensionOrphans { + return + } + e.operation.ImportAndExtendInterfaceTypeDefinitionByInterfaceTypeExtension(ref) } diff --git a/pkg/astnormalization/object_type_extending.go b/pkg/astnormalization/object_type_extending.go index eec5ccb62..ffa2e7cdc 100644 --- a/pkg/astnormalization/object_type_extending.go +++ b/pkg/astnormalization/object_type_extending.go @@ -13,12 +13,22 @@ func extendObjectTypeDefinition(walker *astvisitor.Walker) { walker.RegisterEnterObjectTypeExtensionVisitor(&visitor) } +func extendObjectTypeDefinitionKeepingOrphans(walker *astvisitor.Walker) { + visitor := extendObjectTypeDefinitionVisitor{ + Walker: walker, + keepExtensionOrphans: true, + } + walker.RegisterEnterDocumentVisitor(&visitor) + walker.RegisterEnterObjectTypeExtensionVisitor(&visitor) +} + type extendObjectTypeDefinitionVisitor struct { *astvisitor.Walker - operation *ast.Document + operation *ast.Document + keepExtensionOrphans bool } -func (e *extendObjectTypeDefinitionVisitor) EnterDocument(operation, definition *ast.Document) { +func (e *extendObjectTypeDefinitionVisitor) EnterDocument(operation, _ *ast.Document) { e.operation = operation } @@ -37,5 +47,9 @@ func (e *extendObjectTypeDefinitionVisitor) EnterObjectTypeExtension(ref int) { return } + if e.keepExtensionOrphans { + return + } + e.operation.ImportAndExtendObjectTypeDefinitionByObjectTypeExtension(ref) } diff --git a/pkg/astnormalization/scalar_type_extending.go b/pkg/astnormalization/scalar_type_extending.go index dafa2210d..de4086adf 100644 --- a/pkg/astnormalization/scalar_type_extending.go +++ b/pkg/astnormalization/scalar_type_extending.go @@ -13,12 +13,22 @@ func extendScalarTypeDefinition(walker *astvisitor.Walker) { walker.RegisterEnterScalarTypeExtensionVisitor(&visitor) } +func extendScalarTypeDefinitionKeepingOrphans(walker *astvisitor.Walker) { + visitor := extendScalarTypeDefinitionVisitor{ + Walker: walker, + keepExtensionOrphans: true, + } + walker.RegisterEnterDocumentVisitor(&visitor) + walker.RegisterEnterScalarTypeExtensionVisitor(&visitor) +} + type extendScalarTypeDefinitionVisitor struct { *astvisitor.Walker - operation *ast.Document + operation *ast.Document + keepExtensionOrphans bool } -func (e *extendScalarTypeDefinitionVisitor) EnterDocument(operation, definition *ast.Document) { +func (e *extendScalarTypeDefinitionVisitor) EnterDocument(operation, _ *ast.Document) { e.operation = operation } @@ -37,5 +47,9 @@ func (e *extendScalarTypeDefinitionVisitor) EnterScalarTypeExtension(ref int) { return } + if e.keepExtensionOrphans { + return + } + e.operation.ImportAndExtendScalarTypeDefinitionByScalarTypeExtension(ref) } diff --git a/pkg/astnormalization/union_type_extending.go b/pkg/astnormalization/union_type_extending.go index bcd8740e3..ac5ba6e29 100644 --- a/pkg/astnormalization/union_type_extending.go +++ b/pkg/astnormalization/union_type_extending.go @@ -13,17 +13,26 @@ func extendUnionTypeDefinition(walker *astvisitor.Walker) { walker.RegisterEnterUnionTypeExtensionVisitor(&visitor) } +func extendUnionTypeDefinitionKeepingOrphans(walker *astvisitor.Walker) { + visitor := extendUnionTypeDefinitionVisitor{ + Walker: walker, + keepExtensionOrphans: true, + } + walker.RegisterEnterDocumentVisitor(&visitor) + walker.RegisterEnterUnionTypeExtensionVisitor(&visitor) +} + type extendUnionTypeDefinitionVisitor struct { *astvisitor.Walker - operation *ast.Document + operation *ast.Document + keepExtensionOrphans bool } -func (e *extendUnionTypeDefinitionVisitor) EnterDocument(operation, definition *ast.Document) { +func (e *extendUnionTypeDefinitionVisitor) EnterDocument(operation, _ *ast.Document) { e.operation = operation } func (e *extendUnionTypeDefinitionVisitor) EnterUnionTypeExtension(ref int) { - nodes, exists := e.operation.Index.NodesByNameBytes(e.operation.UnionTypeExtensionNameBytes(ref)) if !exists { return @@ -37,5 +46,9 @@ func (e *extendUnionTypeDefinitionVisitor) EnterUnionTypeExtension(ref int) { return } + if e.keepExtensionOrphans { + return + } + e.operation.ImportAndExtendUnionTypeDefinitionByUnionTypeExtension(ref) } diff --git a/pkg/asttransform/baseschema.go b/pkg/asttransform/baseschema.go index 7fd6c4646..b2168fd4f 100644 --- a/pkg/asttransform/baseschema.go +++ b/pkg/asttransform/baseschema.go @@ -54,11 +54,11 @@ func addMissingRootOperationTypeDefinitions(definition *ast.Document) { typeName := definition.ObjectTypeDefinitionNameBytes(definition.RootNodes[i].Ref) switch { - case bytes.Equal(typeName, []byte("Query")): + case bytes.Equal(typeName, ast.DefaultQueryTypeName): rootOperationTypeRefs = createRootOperationTypeIfNotExists(definition, rootOperationTypeRefs, ast.OperationTypeQuery, i) - case bytes.Equal(typeName, []byte("Mutation")): + case bytes.Equal(typeName, ast.DefaultMutationTypeName): rootOperationTypeRefs = createRootOperationTypeIfNotExists(definition, rootOperationTypeRefs, ast.OperationTypeMutation, i) - case bytes.Equal(typeName, []byte("Subscription")): + case bytes.Equal(typeName, ast.DefaultSubscriptionTypeName): rootOperationTypeRefs = createRootOperationTypeIfNotExists(definition, rootOperationTypeRefs, ast.OperationTypeSubscription, i) default: continue diff --git a/pkg/asttransform/typename_visitor.go b/pkg/asttransform/typename_visitor.go index 950790f31..036553c66 100644 --- a/pkg/asttransform/typename_visitor.go +++ b/pkg/asttransform/typename_visitor.go @@ -11,8 +11,6 @@ import ( const typenameFieldName = "__typename" -var defaultSubscriptionName = []byte("Subscription") - type TypeNameVisitor struct { *astvisitor.Walker definition *ast.Document @@ -61,7 +59,7 @@ func (v *TypeNameVisitor) LeaveInterfaceTypeDefinition(ref int) { func (v *TypeNameVisitor) LeaveObjectTypeDefinition(ref int) { objectTypeDefName := v.definition.ObjectTypeDefinitionNameBytes(ref) if bytes.Equal(objectTypeDefName, v.definition.Index.SubscriptionTypeName) || - bytes.Equal(objectTypeDefName, defaultSubscriptionName) { + bytes.Equal(objectTypeDefName, ast.DefaultSubscriptionTypeName) { return } diff --git a/pkg/astvalidation/definition_validation.go b/pkg/astvalidation/definition_validation.go index 99b8e6e22..57be57034 100644 --- a/pkg/astvalidation/definition_validation.go +++ b/pkg/astvalidation/definition_validation.go @@ -8,10 +8,12 @@ import ( func DefaultDefinitionValidator() *DefinitionValidator { return NewDefinitionValidator( + PopulatedTypeBodies(), UniqueOperationTypes(), UniqueTypeNames(), UniqueFieldDefinitionNames(), UniqueEnumValueNames(), + UniqueUnionMemberTypes(), KnownTypeNames(), RequireDefinedTypesForExtensions(), ImplementTransitiveInterfaces(), diff --git a/pkg/astvalidation/rule.go b/pkg/astvalidation/rule.go index 46136f9f4..ab0814b0a 100644 --- a/pkg/astvalidation/rule.go +++ b/pkg/astvalidation/rule.go @@ -4,5 +4,7 @@ import ( "github.com/wundergraph/graphql-go-tools/pkg/astvisitor" ) +var reservedFieldPrefix = []byte("__") + // Rule is hook to register callback functions on the Walker type Rule func(walker *astvisitor.Walker) diff --git a/pkg/astvalidation/rule_known_type_names.go b/pkg/astvalidation/rule_known_type_names.go index a10a84e5f..704fde0f2 100644 --- a/pkg/astvalidation/rule_known_type_names.go +++ b/pkg/astvalidation/rule_known_type_names.go @@ -35,13 +35,13 @@ type knownTypeNamesVisitor struct { referencedTypeNames map[uint64][]byte } -func (u *knownTypeNamesVisitor) EnterDocument(operation, definition *ast.Document) { +func (u *knownTypeNamesVisitor) EnterDocument(operation, _ *ast.Document) { u.definition = operation u.definedTypeNameHashs = make(map[uint64]bool) u.referencedTypeNames = make(map[uint64][]byte) } -func (u *knownTypeNamesVisitor) LeaveDocument(operation, definition *ast.Document) { +func (u *knownTypeNamesVisitor) LeaveDocument(_, _ *ast.Document) { for referencedTypeNameHash, referencedTypeName := range u.referencedTypeNames { if !u.definedTypeNameHashs[referencedTypeNameHash] { u.Report.AddExternalError(operationreport.ErrTypeUndefined(referencedTypeName)) @@ -54,7 +54,6 @@ func (u *knownTypeNamesVisitor) LeaveDocument(operation, definition *ast.Documen func (u *knownTypeNamesVisitor) EnterRootOperationTypeDefinition(ref int) { referencedTypeName := u.definition.Input.ByteSlice(u.definition.RootOperationTypeDefinitions[ref].NamedType.Name) u.saveReferencedTypeName(referencedTypeName) - u.saveReferencedTypeName(referencedTypeName) } func (u *knownTypeNamesVisitor) EnterFieldDefinition(ref int) { diff --git a/pkg/astvalidation/rule_populated_type_bodies.go b/pkg/astvalidation/rule_populated_type_bodies.go new file mode 100644 index 000000000..16820288f --- /dev/null +++ b/pkg/astvalidation/rule_populated_type_bodies.go @@ -0,0 +1,116 @@ +package astvalidation + +import ( + "bytes" + "github.com/wundergraph/graphql-go-tools/pkg/ast" + "github.com/wundergraph/graphql-go-tools/pkg/astvisitor" + "github.com/wundergraph/graphql-go-tools/pkg/operationreport" +) + +type populatedTypeBodiesVisitor struct { + *astvisitor.Walker + definition *ast.Document +} + +func PopulatedTypeBodies() Rule { + return func(walker *astvisitor.Walker) { + visitor := &populatedTypeBodiesVisitor{ + Walker: walker, + definition: nil, + } + + walker.RegisterEnterDocumentVisitor(visitor) + walker.RegisterEnterEnumTypeDefinitionVisitor(visitor) + walker.RegisterEnterEnumTypeExtensionVisitor(visitor) + walker.RegisterEnterInputObjectTypeDefinitionVisitor(visitor) + walker.RegisterEnterInputObjectTypeExtensionVisitor(visitor) + walker.RegisterEnterInterfaceTypeDefinitionVisitor(visitor) + walker.RegisterEnterInterfaceTypeExtensionVisitor(visitor) + walker.RegisterEnterObjectTypeDefinitionVisitor(visitor) + walker.RegisterEnterObjectTypeExtensionVisitor(visitor) + } +} + +func (p *populatedTypeBodiesVisitor) EnterDocument(operation, _ *ast.Document) { + p.definition = operation +} + +func (p populatedTypeBodiesVisitor) EnterEnumTypeDefinition(ref int) { + if !p.definition.EnumTypeDefinitions[ref].HasEnumValuesDefinition { + p.Report.AddExternalError(operationreport.ErrTypeBodyMustNotBeEmpty("enum", p.definition.EnumTypeDefinitionNameString(ref))) + return + } +} + +func (p *populatedTypeBodiesVisitor) EnterEnumTypeExtension(ref int) { + if !p.definition.EnumTypeExtensions[ref].HasEnumValuesDefinition { + p.Report.AddExternalError(operationreport.ErrTypeBodyMustNotBeEmpty("enum extension", p.definition.EnumTypeExtensionNameString(ref))) + return + } +} + +func (p populatedTypeBodiesVisitor) EnterInputObjectTypeDefinition(ref int) { + if !p.definition.InputObjectTypeDefinitions[ref].HasInputFieldsDefinition { + p.Report.AddExternalError(operationreport.ErrTypeBodyMustNotBeEmpty("input", p.definition.InputObjectTypeDefinitionNameString(ref))) + return + } +} + +func (p *populatedTypeBodiesVisitor) EnterInputObjectTypeExtension(ref int) { + if !p.definition.InputObjectTypeExtensions[ref].HasInputFieldsDefinition { + p.Report.AddExternalError(operationreport.ErrTypeBodyMustNotBeEmpty("input extension", p.definition.InputObjectTypeExtensionNameString(ref))) + return + } +} + +func (p populatedTypeBodiesVisitor) EnterInterfaceTypeDefinition(ref int) { + switch p.definition.InterfaceTypeDefinitions[ref].HasFieldDefinitions { + case true: + if !p.doesTypeOnlyContainReservedFields(p.definition.InterfaceTypeDefinitions[ref].FieldsDefinition.Refs) { + return + } + fallthrough + case false: + p.Report.AddExternalError(operationreport.ErrTypeBodyMustNotBeEmpty("interface", p.definition.InterfaceTypeDefinitionNameString(ref))) + return + } +} + +func (p *populatedTypeBodiesVisitor) EnterInterfaceTypeExtension(ref int) { + if !p.definition.InterfaceTypeExtensions[ref].HasFieldDefinitions { + p.Report.AddExternalError(operationreport.ErrTypeBodyMustNotBeEmpty("interface extension", p.definition.InterfaceTypeExtensionNameString(ref))) + return + } +} + +func (p populatedTypeBodiesVisitor) EnterObjectTypeDefinition(ref int) { + nameBytes := p.definition.ObjectTypeDefinitionNameBytes(ref) + object := p.definition.ObjectTypeDefinitions[ref] + switch object.HasFieldDefinitions { + case true: + if ast.IsRootType(nameBytes) || !p.doesTypeOnlyContainReservedFields(p.definition.ObjectTypeDefinitions[ref].FieldsDefinition.Refs) { + return + } + fallthrough + case false: + p.Report.AddExternalError(operationreport.ErrTypeBodyMustNotBeEmpty("object", string(nameBytes))) + return + } +} + +func (p *populatedTypeBodiesVisitor) EnterObjectTypeExtension(ref int) { + if !p.definition.ObjectTypeExtensions[ref].HasFieldDefinitions { + p.Report.AddExternalError(operationreport.ErrTypeBodyMustNotBeEmpty("object extension", p.definition.ObjectTypeExtensionNameString(ref))) + return + } +} + +func (p *populatedTypeBodiesVisitor) doesTypeOnlyContainReservedFields(refs []int) bool { + for _, fieldRef := range refs { + fieldNameBytes := p.definition.FieldDefinitionNameBytes(fieldRef) + if len(fieldNameBytes) < 2 || !bytes.HasPrefix(fieldNameBytes, reservedFieldPrefix) { + return false + } + } + return true +} diff --git a/pkg/astvalidation/rule_populated_type_bodies_test.go b/pkg/astvalidation/rule_populated_type_bodies_test.go new file mode 100644 index 000000000..38c7c02e9 --- /dev/null +++ b/pkg/astvalidation/rule_populated_type_bodies_test.go @@ -0,0 +1,126 @@ +package astvalidation + +import ( + "testing" +) + +func TestPopulatedTypeBodies(t *testing.T) { + t.Run("Definition", func(t *testing.T) { + t.Run("Populated type bodies are valid", func(t *testing.T) { + runDefinitionValidation(t, ` + enum Species { + CAT + } + + extend enum Color { + DOG + } + + input Message { + content: String! + } + + extend input Message { + updated: DateTime! + } + + interface Animal { + species: Species! + } + + extend interface Animal { + age: Int! + } + + type Cat implements Animal { + species: Species! + } + + extend type Cat implements Animal { + age: Int! + } + `, Valid, PopulatedTypeBodies(), + ) + }) + + t.Run("Empty enum is invalid", func(t *testing.T) { + runDefinitionValidation(t, ` + enum Species { + } + `, Invalid, PopulatedTypeBodies(), + ) + }) + + t.Run("Empty enum extension is invalid", func(t *testing.T) { + runDefinitionValidation(t, ` + enum Species { + CAT + } + + extend enum Species { + } + `, Invalid, PopulatedTypeBodies(), + ) + }) + + t.Run("Empty input is invalid", func(t *testing.T) { + runDefinitionValidation(t, ` + input Message { + } + `, Invalid, PopulatedTypeBodies(), + ) + }) + + t.Run("Empty input extension is invalid", func(t *testing.T) { + runDefinitionValidation(t, ` + input Message { + content: String! + } + + extend input Message { + } + `, Invalid, PopulatedTypeBodies(), + ) + }) + + t.Run("Empty interface is invalid", func(t *testing.T) { + runDefinitionValidation(t, ` + interface Animal { + } + `, Invalid, PopulatedTypeBodies(), + ) + }) + + t.Run("Empty interface extension is invalid", func(t *testing.T) { + runDefinitionValidation(t, ` + interface Animal { + species: String! + } + + extend interface Animal { + } + `, Invalid, PopulatedTypeBodies(), + ) + }) + + t.Run("Empty object is invalid", func(t *testing.T) { + runDefinitionValidation(t, ` + type Cat { + } + `, Invalid, PopulatedTypeBodies(), + ) + }) + + t.Run("Empty object extension is invalid", func(t *testing.T) { + runDefinitionValidation(t, ` + type Cat { + species: String! + } + + extend type Cat { + } + `, Invalid, PopulatedTypeBodies(), + ) + }) + }) +} diff --git a/pkg/astvalidation/rule_unique_field_definition_names.go b/pkg/astvalidation/rule_unique_field_definition_names.go index dc5580f36..405a2cdd4 100644 --- a/pkg/astvalidation/rule_unique_field_definition_names.go +++ b/pkg/astvalidation/rule_unique_field_definition_names.go @@ -39,7 +39,7 @@ type uniqueFieldDefinitionNamesVisitor struct { usedFieldNames map[uint64]hashedFieldNames // map of hashed type names containing a map of hashed field names } -func (u *uniqueFieldDefinitionNamesVisitor) EnterDocument(operation, definition *ast.Document) { +func (u *uniqueFieldDefinitionNamesVisitor) EnterDocument(operation, _ *ast.Document) { u.definition = operation u.currentTypeName = u.currentTypeName[:0] u.currentTypeNameHash = 0 @@ -66,7 +66,7 @@ func (u *uniqueFieldDefinitionNamesVisitor) EnterObjectTypeDefinition(ref int) { u.setCurrentTypeName(typeName, ast.NodeKindObjectTypeDefinition) } -func (u *uniqueFieldDefinitionNamesVisitor) LeaveObjectTypeDefinition(ref int) { +func (u *uniqueFieldDefinitionNamesVisitor) LeaveObjectTypeDefinition(_ int) { u.unsetCurrentTypeName() } @@ -75,7 +75,7 @@ func (u *uniqueFieldDefinitionNamesVisitor) EnterObjectTypeExtension(ref int) { u.setCurrentTypeName(typeName, ast.NodeKindObjectTypeExtension) } -func (u *uniqueFieldDefinitionNamesVisitor) LeaveObjectTypeExtension(ref int) { +func (u *uniqueFieldDefinitionNamesVisitor) LeaveObjectTypeExtension(_ int) { u.unsetCurrentTypeName() } @@ -84,7 +84,7 @@ func (u *uniqueFieldDefinitionNamesVisitor) EnterInterfaceTypeDefinition(ref int u.setCurrentTypeName(typeName, ast.NodeKindInterfaceTypeDefinition) } -func (u *uniqueFieldDefinitionNamesVisitor) LeaveInterfaceTypeDefinition(ref int) { +func (u *uniqueFieldDefinitionNamesVisitor) LeaveInterfaceTypeDefinition(_ int) { u.unsetCurrentTypeName() } @@ -93,7 +93,7 @@ func (u *uniqueFieldDefinitionNamesVisitor) EnterInterfaceTypeExtension(ref int) u.setCurrentTypeName(typeName, ast.NodeKindInterfaceTypeExtension) } -func (u *uniqueFieldDefinitionNamesVisitor) LeaveInterfaceTypeExtension(ref int) { +func (u *uniqueFieldDefinitionNamesVisitor) LeaveInterfaceTypeExtension(_ int) { u.unsetCurrentTypeName() } @@ -102,7 +102,7 @@ func (u *uniqueFieldDefinitionNamesVisitor) EnterInputObjectTypeDefinition(ref i u.setCurrentTypeName(typeName, ast.NodeKindInputObjectTypeDefinition) } -func (u *uniqueFieldDefinitionNamesVisitor) LeaveInputObjectTypeDefinition(ref int) { +func (u *uniqueFieldDefinitionNamesVisitor) LeaveInputObjectTypeDefinition(_ int) { u.unsetCurrentTypeName() } @@ -111,7 +111,7 @@ func (u *uniqueFieldDefinitionNamesVisitor) EnterInputObjectTypeExtension(ref in u.setCurrentTypeName(typeName, ast.NodeKindInputObjectTypeExtension) } -func (u *uniqueFieldDefinitionNamesVisitor) LeaveInputObjectTypeExtension(ref int) { +func (u *uniqueFieldDefinitionNamesVisitor) LeaveInputObjectTypeExtension(_ int) { u.unsetCurrentTypeName() } @@ -132,7 +132,7 @@ func (u *uniqueFieldDefinitionNamesVisitor) unsetCurrentTypeName() { } func (u *uniqueFieldDefinitionNamesVisitor) checkField(fieldName ast.ByteSlice) { - if bytes.HasPrefix(fieldName, []byte("__")) { // don't validate graphql reserved fields + if bytes.HasPrefix(fieldName, reservedFieldPrefix) { // don't validate graphql reserved fields return } diff --git a/pkg/astvalidation/rule_unique_union_member_types.go b/pkg/astvalidation/rule_unique_union_member_types.go new file mode 100644 index 000000000..204a2ac53 --- /dev/null +++ b/pkg/astvalidation/rule_unique_union_member_types.go @@ -0,0 +1,93 @@ +package astvalidation + +import ( + "github.com/cespare/xxhash/v2" + + "github.com/wundergraph/graphql-go-tools/pkg/ast" + "github.com/wundergraph/graphql-go-tools/pkg/astvisitor" + "github.com/wundergraph/graphql-go-tools/pkg/operationreport" +) + +type hashedMembers map[uint64]bool + +type uniqueUnionMemberTypesVisitor struct { + *astvisitor.Walker + definition *ast.Document + currentUnionName ast.ByteSlice + currentUnionHash uint64 + presentMembers map[uint64]hashedMembers +} + +func UniqueUnionMemberTypes() Rule { + return func(walker *astvisitor.Walker) { + visitor := &uniqueUnionMemberTypesVisitor{ + Walker: walker, + } + + walker.RegisterEnterDocumentVisitor(visitor) + walker.RegisterEnterUnionTypeDefinitionVisitor(visitor) + walker.RegisterEnterUnionMemberTypeVisitor(visitor) + walker.RegisterUnionTypeDefinitionVisitor(visitor) + walker.RegisterUnionTypeExtensionVisitor(visitor) + } +} + +func (u *uniqueUnionMemberTypesVisitor) EnterDocument(operation, _ *ast.Document) { + u.definition = operation + u.currentUnionName = u.currentUnionName[:0] + u.currentUnionHash = 0 + u.presentMembers = make(map[uint64]hashedMembers) +} + +func (u *uniqueUnionMemberTypesVisitor) EnterUnionTypeDefinition(ref int) { + unionName := u.definition.UnionTypeDefinitionNameBytes(ref) + u.setCurrentUnion(unionName) +} + +func (u *uniqueUnionMemberTypesVisitor) LeaveUnionTypeDefinition(_ int) { + u.unsetCurrentUnion() +} + +func (u *uniqueUnionMemberTypesVisitor) EnterUnionMemberType(ref int) { + memberName := u.definition.TypeNameBytes(ref) + u.checkMemberName(memberName) +} + +func (u *uniqueUnionMemberTypesVisitor) EnterUnionTypeExtension(ref int) { + unionName := u.definition.UnionTypeExtensionNameBytes(ref) + u.setCurrentUnion(unionName) +} + +func (u *uniqueUnionMemberTypesVisitor) LeaveUnionTypeExtension(_ int) { + u.unsetCurrentUnion() +} + +func (u *uniqueUnionMemberTypesVisitor) setCurrentUnion(unionName ast.ByteSlice) { + u.currentUnionName = unionName + u.currentUnionHash = xxhash.Sum64(unionName) +} + +func (u *uniqueUnionMemberTypesVisitor) unsetCurrentUnion() { + u.currentUnionName = u.currentUnionName[:0] + u.currentUnionHash = 0 +} + +func (u *uniqueUnionMemberTypesVisitor) checkMemberName(memberName ast.ByteSlice) { + if len(u.currentUnionName) == 0 || u.currentUnionHash == 0 { + return + } + + memberNameHash := xxhash.Sum64(memberName) + memberNames, ok := u.presentMembers[u.currentUnionHash] + if !ok { + memberNames = make(hashedMembers) + } + + if memberNames[memberNameHash] { + u.Report.AddExternalError(operationreport.ErrUnionMembersMustBeUnique(u.currentUnionName, memberName)) + return + } + + memberNames[memberNameHash] = true + u.presentMembers[u.currentUnionHash] = memberNames +} diff --git a/pkg/astvalidation/rule_unique_union_member_types_test.go b/pkg/astvalidation/rule_unique_union_member_types_test.go new file mode 100644 index 000000000..0783ac5e7 --- /dev/null +++ b/pkg/astvalidation/rule_unique_union_member_types_test.go @@ -0,0 +1,30 @@ +package astvalidation + +import ( + "testing" +) + +func TestUniqueMemberTypes(t *testing.T) { + t.Run("Definition", func(t *testing.T) { + t.Run("Union with a single member is valid", func(t *testing.T) { + runDefinitionValidation(t, ` + union Foo = Bar + `, Valid, UniqueUnionMemberTypes(), + ) + }) + + t.Run("Union with many members is valid", func(t *testing.T) { + runDefinitionValidation(t, ` + union Foo = Bar | FooBar | BarFoo + `, Valid, UniqueUnionMemberTypes(), + ) + }) + + t.Run("Union with duplicate members is invalid", func(t *testing.T) { + runDefinitionValidation(t, ` + union Foo = Bar | Bar + `, Invalid, UniqueUnionMemberTypes(), + ) + }) + }) +} diff --git a/pkg/engine/plan/local_type_field_extractor.go b/pkg/engine/plan/local_type_field_extractor.go index e7b2ea4c3..7408fcb27 100644 --- a/pkg/engine/plan/local_type_field_extractor.go +++ b/pkg/engine/plan/local_type_field_extractor.go @@ -4,8 +4,9 @@ import ( "github.com/wundergraph/graphql-go-tools/pkg/ast" ) +const FederationKeyDirectiveName = "key" + const ( - federationKeyDirectiveName = "key" federationRequireDirectiveName = "requires" federationExternalDirectiveName = "external" ) @@ -200,13 +201,13 @@ func (e *LocalTypeFieldExtractor) getNodeInfo(node ast.Node) *nodeInformation { nodeInfo, ok := e.nodeInfoMap[typeName] if ok { // if this node has the key directive, we need to add it to the node information - nodeInfo.hasKeyDirective = nodeInfo.hasKeyDirective || e.document.NodeHasDirectiveByNameString(node, federationKeyDirectiveName) + nodeInfo.hasKeyDirective = nodeInfo.hasKeyDirective || e.document.NodeHasDirectiveByNameString(node, FederationKeyDirectiveName) return nodeInfo } nodeInfo = &nodeInformation{ typeName: typeName, - hasKeyDirective: e.document.NodeHasDirectiveByNameString(node, federationKeyDirectiveName), + hasKeyDirective: e.document.NodeHasDirectiveByNameString(node, FederationKeyDirectiveName), requiredFields: make(map[string]struct{}), } diff --git a/pkg/engine/plan/required_field_extractor.go b/pkg/engine/plan/required_field_extractor.go index 618d3ee2b..f84b7cc4c 100644 --- a/pkg/engine/plan/required_field_extractor.go +++ b/pkg/engine/plan/required_field_extractor.go @@ -119,7 +119,7 @@ func requiredFieldsByRequiresDirective(document *ast.Document, fieldDefinitionRe func (f *RequiredFieldExtractor) primaryKeyFieldsIfObjectTypeIsEntity(objectType ast.ObjectTypeDefinition) (keyFields []string, ok bool) { for _, directiveRef := range objectType.Directives.Refs { - if directiveName := f.document.DirectiveNameString(directiveRef); directiveName != federationKeyDirectiveName { + if directiveName := f.document.DirectiveNameString(directiveRef); directiveName != FederationKeyDirectiveName { continue } diff --git a/pkg/execution/datasource_http_json_test.go b/pkg/execution/datasource_http_json_test.go index 7f869a6ee..ba78f950f 100644 --- a/pkg/execution/datasource_http_json_test.go +++ b/pkg/execution/datasource_http_json_test.go @@ -113,7 +113,7 @@ func TestHttpJsonDataSourcePlanner_Plan(t *testing.T) { Args: []datasource.Argument{ &datasource.StaticVariableArgument{ Name: []byte("root_type_name"), - Value: []byte("Query"), + Value: ast.DefaultQueryTypeName, }, &datasource.StaticVariableArgument{ Name: []byte("root_field_name"), @@ -219,7 +219,7 @@ func TestHttpJsonDataSourcePlanner_Plan(t *testing.T) { Args: []datasource.Argument{ &datasource.StaticVariableArgument{ Name: []byte("root_type_name"), - Value: []byte("Query"), + Value: ast.DefaultQueryTypeName, }, &datasource.StaticVariableArgument{ Name: []byte("root_field_name"), @@ -306,7 +306,7 @@ func TestHttpJsonDataSourcePlanner_Plan(t *testing.T) { Args: []datasource.Argument{ &datasource.StaticVariableArgument{ Name: []byte("root_type_name"), - Value: []byte("Query"), + Value: ast.DefaultQueryTypeName, }, &datasource.StaticVariableArgument{ Name: []byte("root_field_name"), @@ -415,7 +415,7 @@ func TestHttpJsonDataSourcePlanner_Plan(t *testing.T) { Args: []datasource.Argument{ &datasource.StaticVariableArgument{ Name: []byte("root_type_name"), - Value: []byte("Query"), + Value: ast.DefaultQueryTypeName, }, &datasource.StaticVariableArgument{ Name: []byte("root_field_name"), @@ -565,7 +565,7 @@ func TestHttpJsonDataSourcePlanner_Plan(t *testing.T) { Args: []datasource.Argument{ &datasource.StaticVariableArgument{ Name: []byte("root_type_name"), - Value: []byte("Query"), + Value: ast.DefaultQueryTypeName, }, &datasource.StaticVariableArgument{ Name: []byte("root_field_name"), diff --git a/pkg/execution/planning_test.go b/pkg/execution/planning_test.go index 3f7ec12cd..5630d27d5 100644 --- a/pkg/execution/planning_test.go +++ b/pkg/execution/planning_test.go @@ -168,7 +168,7 @@ func TestPlanner_Plan(t *testing.T) { Args: []datasource.Argument{ &datasource.StaticVariableArgument{ Name: []byte("root_type_name"), - Value: []byte("Query"), + Value: ast.DefaultQueryTypeName, }, &datasource.StaticVariableArgument{ Name: []byte("root_field_name"), @@ -289,7 +289,7 @@ func TestPlanner_Plan(t *testing.T) { Args: []datasource.Argument{ &datasource.StaticVariableArgument{ Name: []byte("root_type_name"), - Value: []byte("Mutation"), + Value: ast.DefaultMutationTypeName, }, &datasource.StaticVariableArgument{ Name: []byte("root_field_name"), @@ -461,7 +461,7 @@ func TestPlanner_Plan(t *testing.T) { Args: []datasource.Argument{ &datasource.StaticVariableArgument{ Name: []byte("root_type_name"), - Value: []byte("Query"), + Value: ast.DefaultQueryTypeName, }, &datasource.StaticVariableArgument{ Name: []byte("root_field_name"), @@ -488,7 +488,7 @@ func TestPlanner_Plan(t *testing.T) { Args: []datasource.Argument{ &datasource.StaticVariableArgument{ Name: []byte("root_type_name"), - Value: []byte("Query"), + Value: ast.DefaultQueryTypeName, }, &datasource.StaticVariableArgument{ Name: []byte("root_field_name"), @@ -694,7 +694,7 @@ func TestPlanner_Plan(t *testing.T) { Args: []datasource.Argument{ &datasource.StaticVariableArgument{ Name: []byte("root_type_name"), - Value: []byte("Query"), + Value: ast.DefaultQueryTypeName, }, &datasource.StaticVariableArgument{ Name: []byte("root_field_name"), @@ -773,7 +773,7 @@ func TestPlanner_Plan(t *testing.T) { Args: []datasource.Argument{ &datasource.StaticVariableArgument{ Name: []byte("root_type_name"), - Value: []byte("Query"), + Value: ast.DefaultQueryTypeName, }, &datasource.StaticVariableArgument{ Name: []byte("root_field_name"), @@ -851,7 +851,7 @@ func TestPlanner_Plan(t *testing.T) { Args: []datasource.Argument{ &datasource.StaticVariableArgument{ Name: []byte("root_type_name"), - Value: []byte("Query"), + Value: ast.DefaultQueryTypeName, }, &datasource.StaticVariableArgument{ Name: []byte("root_field_name"), @@ -942,7 +942,7 @@ func TestPlanner_Plan(t *testing.T) { Args: []datasource.Argument{ &datasource.StaticVariableArgument{ Name: []byte("root_type_name"), - Value: []byte("Query"), + Value: ast.DefaultQueryTypeName, }, &datasource.StaticVariableArgument{ Name: []byte("root_field_name"), @@ -1046,7 +1046,7 @@ func TestPlanner_Plan(t *testing.T) { Args: []datasource.Argument{ &datasource.StaticVariableArgument{ Name: []byte("root_type_name"), - Value: []byte("Query"), + Value: ast.DefaultQueryTypeName, }, &datasource.StaticVariableArgument{ Name: []byte("root_field_name"), @@ -1377,7 +1377,7 @@ func TestPlanner_Plan(t *testing.T) { Args: []datasource.Argument{ &datasource.StaticVariableArgument{ Name: []byte("root_type_name"), - Value: []byte("Query"), + Value: ast.DefaultQueryTypeName, }, &datasource.StaticVariableArgument{ Name: []byte("root_field_name"), @@ -1498,7 +1498,7 @@ func TestPlanner_Plan(t *testing.T) { Args: []datasource.Argument{ &datasource.StaticVariableArgument{ Name: []byte("root_type_name"), - Value: []byte("Query"), + Value: ast.DefaultQueryTypeName, }, &datasource.StaticVariableArgument{ Name: []byte("root_field_name"), @@ -1632,7 +1632,7 @@ func TestPlanner_Plan(t *testing.T) { Args: []datasource.Argument{ &datasource.StaticVariableArgument{ Name: []byte("root_type_name"), - Value: []byte("Query"), + Value: ast.DefaultQueryTypeName, }, &datasource.StaticVariableArgument{ Name: []byte("root_field_name"), diff --git a/pkg/federation/sdlmerge/collect_entities.go b/pkg/federation/sdlmerge/collect_entities.go new file mode 100644 index 000000000..2b07f8434 --- /dev/null +++ b/pkg/federation/sdlmerge/collect_entities.go @@ -0,0 +1,62 @@ +package sdlmerge + +import ( + "github.com/wundergraph/graphql-go-tools/pkg/ast" + "github.com/wundergraph/graphql-go-tools/pkg/astvisitor" + "github.com/wundergraph/graphql-go-tools/pkg/engine/plan" + "github.com/wundergraph/graphql-go-tools/pkg/operationreport" +) + +type collectEntitiesVisitor struct { + *astvisitor.Walker + document *ast.Document + collectedEntities entitySet +} + +func newCollectEntitiesVisitor(collectedEntities entitySet) *collectEntitiesVisitor { + return &collectEntitiesVisitor{ + collectedEntities: collectedEntities, + } +} + +func (c *collectEntitiesVisitor) Register(walker *astvisitor.Walker) { + c.Walker = walker + walker.RegisterEnterDocumentVisitor(c) + walker.RegisterEnterInterfaceTypeDefinitionVisitor(c) + walker.RegisterEnterObjectTypeDefinitionVisitor(c) +} + +func (c *collectEntitiesVisitor) EnterDocument(operation, _ *ast.Document) { + c.document = operation +} + +func (c *collectEntitiesVisitor) EnterInterfaceTypeDefinition(ref int) { + interfaceType := c.document.InterfaceTypeDefinitions[ref] + name := c.document.InterfaceTypeDefinitionNameString(ref) + if err := c.resolvePotentialEntity(name, interfaceType.Directives.Refs); err != nil { + c.StopWithExternalErr(*err) + } +} + +func (c *collectEntitiesVisitor) EnterObjectTypeDefinition(ref int) { + objectType := c.document.ObjectTypeDefinitions[ref] + name := c.document.ObjectTypeDefinitionNameString(ref) + if err := c.resolvePotentialEntity(name, objectType.Directives.Refs); err != nil { + c.StopWithExternalErr(*err) + } +} + +func (c *collectEntitiesVisitor) resolvePotentialEntity(name string, directiveRefs []int) *operationreport.ExternalError { + if _, exists := c.collectedEntities[name]; exists { + err := operationreport.ErrEntitiesMustNotBeDuplicated(name) + return &err + } + for _, directiveRef := range directiveRefs { + if c.document.DirectiveNameString(directiveRef) != plan.FederationKeyDirectiveName { + continue + } + c.collectedEntities[name] = struct{}{} + return nil + } + return nil +} diff --git a/pkg/federation/sdlmerge/collect_entities_test.go b/pkg/federation/sdlmerge/collect_entities_test.go new file mode 100644 index 000000000..f7813546c --- /dev/null +++ b/pkg/federation/sdlmerge/collect_entities_test.go @@ -0,0 +1,86 @@ +package sdlmerge + +import ( + "github.com/stretchr/testify/assert" + "github.com/wundergraph/graphql-go-tools/internal/pkg/unsafeparser" + "github.com/wundergraph/graphql-go-tools/pkg/astvisitor" + "github.com/wundergraph/graphql-go-tools/pkg/operationreport" + "testing" +) + +func TestCollectEntities(t *testing.T) { + t.Run("Valid entities are collected", func(t *testing.T) { + collectEntities(t, newCollectEntitiesVisitor(newTestNormalizer(false)), ` + type Dog @key(fields: "name") @key(fields: "id") { + id: ID! + name: String! + } + + type Cat @key(fields: "species") { + id: ID! + species: String! + } + `, entitySet{ + "Dog": {}, + "Cat": {}, + }) + }) + + t.Run("Valid entities are collected", func(t *testing.T) { + collectEntitiesAndExpectError(t, newCollectEntitiesVisitor(newTestNormalizer(false)), ` + type Dog @key(fields: "name") @key(fields: "id") { + id: ID! + name: String! + } + + type Dog @key(fields: "name") @key(fields: "id") { + id: ID! + name: String! + } + + type Cat @key(fields: "species") { + id: ID! + species: String! + } + `, duplicateEntityErrorMessage("Dog")) + }) +} + +var collectEntities = func(t *testing.T, visitor *collectEntitiesVisitor, operation string, expectedEntities entitySet) { + operationDocument := unsafeparser.ParseGraphqlDocumentString(operation) + report := operationreport.Report{} + walker := astvisitor.NewWalker(48) + + visitor.Register(&walker) + + walker.Walk(&operationDocument, nil, &report) + + if report.HasErrors() { + t.Fatal(report.Error()) + } + + got := visitor.collectedEntities + + assert.Equal(t, expectedEntities, got) +} + +var collectEntitiesAndExpectError = func(t *testing.T, visitor *collectEntitiesVisitor, operation string, expectedError string) { + operationDocument := unsafeparser.ParseGraphqlDocumentString(operation) + report := operationreport.Report{} + walker := astvisitor.NewWalker(48) + + visitor.Register(&walker) + + walker.Walk(&operationDocument, nil, &report) + + var got string + if report.HasErrors() { + if report.InternalErrors == nil { + got = report.ExternalErrors[0].Message + } else { + got = report.InternalErrors[0].Error() + } + } + + assert.Equal(t, expectedError, got) +} diff --git a/pkg/federation/sdlmerge/enum_type_extending.go b/pkg/federation/sdlmerge/enum_type_extending.go new file mode 100644 index 000000000..57a586c79 --- /dev/null +++ b/pkg/federation/sdlmerge/enum_type_extending.go @@ -0,0 +1,50 @@ +package sdlmerge + +import ( + "github.com/wundergraph/graphql-go-tools/pkg/ast" + "github.com/wundergraph/graphql-go-tools/pkg/astvisitor" + "github.com/wundergraph/graphql-go-tools/pkg/operationreport" +) + +type extendEnumTypeDefinitionVisitor struct { + *astvisitor.Walker + document *ast.Document +} + +func newExtendEnumTypeDefinition() *extendEnumTypeDefinitionVisitor { + return &extendEnumTypeDefinitionVisitor{} +} + +func (e *extendEnumTypeDefinitionVisitor) Register(walker *astvisitor.Walker) { + e.Walker = walker + walker.RegisterEnterDocumentVisitor(e) + walker.RegisterEnterEnumTypeExtensionVisitor(e) +} + +func (e *extendEnumTypeDefinitionVisitor) EnterDocument(operation, _ *ast.Document) { + e.document = operation +} + +func (e *extendEnumTypeDefinitionVisitor) EnterEnumTypeExtension(ref int) { + nodes, exists := e.document.Index.NodesByNameBytes(e.document.EnumTypeExtensionNameBytes(ref)) + if !exists { + return + } + + hasExtended := false + for i := range nodes { + if nodes[i].Kind != ast.NodeKindEnumTypeDefinition { + continue + } + if hasExtended { + e.StopWithExternalErr(operationreport.ErrSharedTypesMustNotBeExtended(e.document.EnumTypeExtensionNameString(ref))) + return + } + e.document.ExtendEnumTypeDefinitionByEnumTypeExtension(nodes[i].Ref, ref) + hasExtended = true + } + + if !hasExtended { + e.StopWithExternalErr(operationreport.ErrExtensionOrphansMustResolveInSupergraph(e.document.EnumTypeExtensionNameBytes(ref))) + } +} diff --git a/pkg/federation/sdlmerge/enum_type_extending_test.go b/pkg/federation/sdlmerge/enum_type_extending_test.go new file mode 100644 index 000000000..01e290153 --- /dev/null +++ b/pkg/federation/sdlmerge/enum_type_extending_test.go @@ -0,0 +1,107 @@ +package sdlmerge + +import "testing" + +func TestExtendEnumObjectType(t *testing.T) { + t.Run("extend simple enum type by field", func(t *testing.T) { + run(t, newExtendEnumTypeDefinition(), ` + enum Starters { + BULBASAUR + CHARMANDER + SQUIRTLE + } + + extend enum Starters { + CHIKORITA + } + `, ` + enum Starters { + BULBASAUR + CHARMANDER + SQUIRTLE + CHIKORITA + } + + extend enum Starters { + CHIKORITA + } + `) + }) + + t.Run("extend simple enum type by directive", func(t *testing.T) { + run(t, newExtendEnumTypeDefinition(), ` + enum Starters { + BULBASAUR + CHARMANDER + SQUIRTLE + } + + extend enum Starters @deprecated(reason: "some reason") + `, ` + enum Starters @deprecated(reason: "some reason") { + BULBASAUR + CHARMANDER + SQUIRTLE + } + + extend enum Starters @deprecated(reason: "some reason") + `) + }) + + t.Run("extend enum type by complex extends", func(t *testing.T) { + run(t, newExtendEnumTypeDefinition(), ` + enum Starters { + BULBASAUR + CHARMANDER + SQUIRTLE + } + + extend enum Starters @deprecated(reason: "some reason") @skip(if: false) { + CHIKORITA + CYNDAQUIL + } + `, ` + enum Starters @deprecated(reason: "some reason") @skip(if: false) { + BULBASAUR + CHARMANDER + SQUIRTLE + CHIKORITA + CYNDAQUIL + } + + extend enum Starters @deprecated(reason: "some reason") @skip(if: false) { + CHIKORITA + CYNDAQUIL + } + `) + }) + + t.Run("Extending an enum that is a shared type returns an error", func(t *testing.T) { + runAndExpectError(t, newExtendEnumTypeDefinition(), ` + enum Starters { + BULBASAUR + CHARMANDER + SQUIRTLE + } + + enum Starters { + BULBASAUR + CHARMANDER + SQUIRTLE + } + + extend enum Starters @deprecated(reason: "some reason") @skip(if: false) { + CHIKORITA + CYNDAQUIL + } + `, sharedTypeExtensionErrorMessage("Starters")) + }) + + t.Run("Unresolved enum extension orphan returns an error", func(t *testing.T) { + runAndExpectError(t, newExtendEnumTypeDefinition(), ` + extend enum Badges { + BOULDER + } + `, unresolvedExtensionOrphansErrorMessage("Badges")) + }) +} diff --git a/pkg/federation/sdlmerge/input_type_extending.go b/pkg/federation/sdlmerge/input_type_extending.go new file mode 100644 index 000000000..2292cad80 --- /dev/null +++ b/pkg/federation/sdlmerge/input_type_extending.go @@ -0,0 +1,50 @@ +package sdlmerge + +import ( + "github.com/wundergraph/graphql-go-tools/pkg/ast" + "github.com/wundergraph/graphql-go-tools/pkg/astvisitor" + "github.com/wundergraph/graphql-go-tools/pkg/operationreport" +) + +func newExtendInputObjectTypeDefinition() *extendInputObjectTypeDefinitionVisitor { + return &extendInputObjectTypeDefinitionVisitor{} +} + +type extendInputObjectTypeDefinitionVisitor struct { + *astvisitor.Walker + document *ast.Document +} + +func (e *extendInputObjectTypeDefinitionVisitor) Register(walker *astvisitor.Walker) { + e.Walker = walker + walker.RegisterEnterDocumentVisitor(e) + walker.RegisterEnterInputObjectTypeExtensionVisitor(e) +} + +func (e *extendInputObjectTypeDefinitionVisitor) EnterDocument(operation, _ *ast.Document) { + e.document = operation +} + +func (e *extendInputObjectTypeDefinitionVisitor) EnterInputObjectTypeExtension(ref int) { + nodes, exists := e.document.Index.NodesByNameBytes(e.document.InputObjectTypeExtensionNameBytes(ref)) + if !exists { + return + } + + hasExtended := false + for i := range nodes { + if nodes[i].Kind != ast.NodeKindInputObjectTypeDefinition { + continue + } + if hasExtended { + e.StopWithExternalErr(operationreport.ErrSharedTypesMustNotBeExtended(e.document.InputObjectTypeExtensionNameString(ref))) + return + } + e.document.ExtendInputObjectTypeDefinitionByInputObjectTypeExtension(nodes[i].Ref, ref) + hasExtended = true + } + + if !hasExtended { + e.StopWithExternalErr(operationreport.ErrExtensionOrphansMustResolveInSupergraph(e.document.InputObjectTypeExtensionNameBytes(ref))) + } +} diff --git a/pkg/federation/sdlmerge/input_type_extending_test.go b/pkg/federation/sdlmerge/input_type_extending_test.go new file mode 100644 index 000000000..4cfe91628 --- /dev/null +++ b/pkg/federation/sdlmerge/input_type_extending_test.go @@ -0,0 +1,83 @@ +package sdlmerge + +import "testing" + +func TestExtendInputObjectType(t *testing.T) { + t.Run("extend simple input type by field", func(t *testing.T) { + run(t, newExtendInputObjectTypeDefinition(), ` + input Mammal { + name: String + } + extend input Mammal { + furType: String + } + `, ` + input Mammal { + name: String + furType: String + } + extend input Mammal { + furType: String + } + `) + }) + + t.Run("extend simple input type by directive", func(t *testing.T) { + run(t, newExtendInputObjectTypeDefinition(), ` + input Mammal { + name: String + } + extend input Mammal @deprecated(reason: "some reason") + `, ` + input Mammal @deprecated(reason: "some reason") { + name: String + } + extend input Mammal @deprecated(reason: "some reason") + `) + }) + + t.Run("extend input type by complex extends", func(t *testing.T) { + run(t, newExtendInputObjectTypeDefinition(), ` + input Mammal { + name: String + } + extend input Mammal @deprecated(reason: "some reason") @skip(if: false) { + furType: String + age: Int + } + `, ` + input Mammal @deprecated(reason: "some reason") @skip(if: false) { + name: String + furType: String + age: Int + } + extend input Mammal @deprecated(reason: "some reason") @skip(if: false) { + furType: String + age: Int + } + `) + }) + + t.Run("Extending an input that is a shared type returns an error", func(t *testing.T) { + runAndExpectError(t, newExtendInputObjectTypeDefinition(), ` + input Mammal { + name: String + } + input Mammal { + name: String + } + extend input Mammal @deprecated(reason: "some reason") @skip(if: false) { + furType: String + age: Int + } + `, sharedTypeExtensionErrorMessage("Mammal")) + }) + + t.Run("Unresolved input extension orphan returns an error", func(t *testing.T) { + runAndExpectError(t, newExtendInputObjectTypeDefinition(), ` + extend input Badges { + name: String! + } + `, unresolvedExtensionOrphansErrorMessage("Badges")) + }) +} diff --git a/pkg/federation/sdlmerge/interface_type_extending.go b/pkg/federation/sdlmerge/interface_type_extending.go index 211d3285e..0379d2ce9 100644 --- a/pkg/federation/sdlmerge/interface_type_extending.go +++ b/pkg/federation/sdlmerge/interface_type_extending.go @@ -3,37 +3,61 @@ package sdlmerge import ( "github.com/wundergraph/graphql-go-tools/pkg/ast" "github.com/wundergraph/graphql-go-tools/pkg/astvisitor" + "github.com/wundergraph/graphql-go-tools/pkg/operationreport" ) -func newExtendInterfaceTypeDefinition() *extendInterfaceTypeDefinitionVisitor { - return &extendInterfaceTypeDefinitionVisitor{} +func newExtendInterfaceTypeDefinition(collectedEntities entitySet) *extendInterfaceTypeDefinitionVisitor { + return &extendInterfaceTypeDefinitionVisitor{ + collectedEntities: collectedEntities, + } } type extendInterfaceTypeDefinitionVisitor struct { - operation *ast.Document + *astvisitor.Walker + document *ast.Document + collectedEntities entitySet } func (e *extendInterfaceTypeDefinitionVisitor) Register(walker *astvisitor.Walker) { + e.Walker = walker walker.RegisterEnterDocumentVisitor(e) walker.RegisterEnterInterfaceTypeExtensionVisitor(e) } -func (e *extendInterfaceTypeDefinitionVisitor) EnterDocument(operation, definition *ast.Document) { - e.operation = operation +func (e *extendInterfaceTypeDefinitionVisitor) EnterDocument(operation, _ *ast.Document) { + e.document = operation } func (e *extendInterfaceTypeDefinitionVisitor) EnterInterfaceTypeExtension(ref int) { - - nodes, exists := e.operation.Index.NodesByNameBytes(e.operation.InterfaceTypeExtensionNameBytes(ref)) + nameBytes := e.document.InterfaceTypeExtensionNameBytes(ref) + nodes, exists := e.document.Index.NodesByNameBytes(nameBytes) if !exists { return } + var nodeToExtend *ast.Node + isEntity := false for i := range nodes { if nodes[i].Kind != ast.NodeKindInterfaceTypeDefinition { continue } - e.operation.ExtendInterfaceTypeDefinitionByInterfaceTypeExtension(nodes[i].Ref, ref) + if nodeToExtend != nil { + e.StopWithExternalErr(*multipleExtensionError(isEntity, nameBytes)) + return + } + var err *operationreport.ExternalError + extension := e.document.InterfaceTypeExtensions[ref] + if isEntity, err = e.collectedEntities.isExtensionForEntity(nameBytes, extension.Directives.Refs, e.document); err != nil { + e.StopWithExternalErr(*err) + return + } + nodeToExtend = &nodes[i] + } + + if nodeToExtend == nil { + e.StopWithExternalErr(operationreport.ErrExtensionOrphansMustResolveInSupergraph(e.document.InterfaceTypeExtensionNameBytes(ref))) return } + + e.document.ExtendInterfaceTypeDefinitionByInterfaceTypeExtension(nodeToExtend.Ref, ref) } diff --git a/pkg/federation/sdlmerge/interface_type_extending_test.go b/pkg/federation/sdlmerge/interface_type_extending_test.go index 0e9c6b142..70080267a 100644 --- a/pkg/federation/sdlmerge/interface_type_extending_test.go +++ b/pkg/federation/sdlmerge/interface_type_extending_test.go @@ -4,55 +4,169 @@ import "testing" func TestExtendInterfaceType(t *testing.T) { t.Run("extend simple interface type by field", func(t *testing.T) { - run(t, newExtendInterfaceTypeDefinition(), ` - interface Mammal { - name: String - } - extend interface Mammal { - furType: String - } - `, ` - interface Mammal { - name: String - furType: String - } - extend interface Mammal { - furType: String - } - `) + run(t, newExtendInterfaceTypeDefinition(newTestNormalizer(false)), ` + interface Mammal { + name: String + } + + extend interface Mammal { + furType: String + } + `, ` + interface Mammal { + name: String + furType: String + } + + extend interface Mammal { + furType: String + } + `) }) + t.Run("extend simple interface type by directive", func(t *testing.T) { - run(t, newExtendInterfaceTypeDefinition(), ` - interface Mammal { - name: String - } - extend interface Mammal @deprecated(reason: "some reason") - `, ` - interface Mammal @deprecated(reason: "some reason") { - name: String - } - extend interface Mammal @deprecated(reason: "some reason") - `) + run(t, newExtendInterfaceTypeDefinition(newTestNormalizer(false)), ` + interface Mammal { + name: String + } + + extend interface Mammal @deprecated(reason: "some reason") + `, ` + interface Mammal @deprecated(reason: "some reason") { + name: String + } + + extend interface Mammal @deprecated(reason: "some reason") + `) }) + t.Run("extend interface type by complex extends", func(t *testing.T) { - run(t, newExtendInterfaceTypeDefinition(), ` - interface Mammal { - name: String - } - extend interface Mammal @deprecated(reason: "some reason") @skip(if: false) { - furType: String - age: Int - } - `, ` - interface Mammal @deprecated(reason: "some reason") @skip(if: false) { - name: String - furType: String - age: Int - } - extend interface Mammal @deprecated(reason: "some reason") @skip(if: false) { - furType: String - age: Int - } - `) + run(t, newExtendInterfaceTypeDefinition(newTestNormalizer(false)), ` + interface Mammal { + name: String + } + + extend interface Mammal @deprecated(reason: "some reason") @skip(if: false) { + furType: String + age: Int + } + `, ` + interface Mammal @deprecated(reason: "some reason") @skip(if: false) { + name: String + furType: String + age: Int + } + + extend interface Mammal @deprecated(reason: "some reason") @skip(if: false) { + furType: String + age: Int + } + `) + }) + + t.Run("Extending an interface that is a shared type returns an error", func(t *testing.T) { + runAndExpectError(t, newExtendInterfaceTypeDefinition(newTestNormalizer(false)), ` + interface Mammal { + name: String + } + + interface Mammal { + name: String + } + + extend interface Mammal @deprecated(reason: "some reason") @skip(if: false) { + furType: String + age: Int + } + `, sharedTypeExtensionErrorMessage("Mammal")) + }) + + t.Run("Unresolved interface extension orphan returns an error", func(t *testing.T) { + runAndExpectError(t, newExtendInterfaceTypeDefinition(newTestNormalizer(false)), ` + extend interface Mammal { + name: String! + } + `, unresolvedExtensionOrphansErrorMessage("Mammal")) + }) + + t.Run("Entity is extended successfully", func(t *testing.T) { + run(t, newExtendInterfaceTypeDefinition(newTestNormalizer(true)), ` + interface Mammal @key(fields: "name") { + name: String! + } + + extend interface Mammal @key(fields: "name") { + name: String! @external + age: Int! + } + `, ` + interface Mammal @key(fields: "name") @key(fields: "name") { + + name: String! + name: String! @external + age: Int! + } + + extend interface Mammal @key(fields: "name") { + name: String! @external + age: Int! + } + `) + }) + + t.Run("No key directive on entity extension returns an error", func(t *testing.T) { + runAndExpectError(t, newExtendInterfaceTypeDefinition(newTestNormalizer(true)), ` + interface Mammal @key(fields: "name") { + name: String! + } + + extend interface Mammal { + name: String! @external + age: Int! + } + `, noKeyDirectiveErrorMessage("Mammal")) + }) + + t.Run("Non-key directive when extending an entity returns an error", func(t *testing.T) { + runAndExpectError(t, newExtendInterfaceTypeDefinition(newTestNormalizer(true)), ` + interface Mammal @key(fields: "name") { + name: String! + } + + extend interface Mammal @deprecated { + name: String! @external + age: Int! + } + `, noKeyDirectiveErrorMessage("Mammal")) + }) + + t.Run("Extending multiple entities returns an error", func(t *testing.T) { + runAndExpectError(t, newExtendInterfaceTypeDefinition(newTestNormalizer(true)), ` + interface Mammal @key(fields: "name") { + name: String! + } + + interface Mammal @key(fields: "name") { + name: String! + } + + extend interface Mammal @key(fields: "name") { + name: String! @external + age: Int! + } + `, duplicateEntityErrorMessage("Mammal")) + }) + + t.Run("A non-entity that is extended by an extension with a key directive returns an error", func(t *testing.T) { + runAndExpectError(t, newExtendInterfaceTypeDefinition(newTestNormalizer(false)), ` + interface Mammal { + name: String! + } + + extend interface Mammal @key(fields: "name") { + name: String! @external + age: Int! + } + `, nonEntityExtensionErrorMessage("Mammal")) }) } diff --git a/pkg/federation/sdlmerge/object_type_extending.go b/pkg/federation/sdlmerge/object_type_extending.go index 71223c9ba..aada85a70 100644 --- a/pkg/federation/sdlmerge/object_type_extending.go +++ b/pkg/federation/sdlmerge/object_type_extending.go @@ -3,36 +3,64 @@ package sdlmerge import ( "github.com/wundergraph/graphql-go-tools/pkg/ast" "github.com/wundergraph/graphql-go-tools/pkg/astvisitor" + "github.com/wundergraph/graphql-go-tools/pkg/operationreport" ) -func newExtendObjectTypeDefinition() *extendObjectTypeDefinitionVisitor { - return &extendObjectTypeDefinitionVisitor{} +func newExtendObjectTypeDefinition(collectedEntities entitySet) *extendObjectTypeDefinitionVisitor { + return &extendObjectTypeDefinitionVisitor{ + collectedEntities: collectedEntities, + } } type extendObjectTypeDefinitionVisitor struct { - operation *ast.Document + *astvisitor.Walker + document *ast.Document + collectedEntities entitySet } func (e *extendObjectTypeDefinitionVisitor) Register(walker *astvisitor.Walker) { + e.Walker = walker walker.RegisterEnterDocumentVisitor(e) walker.RegisterEnterObjectTypeExtensionVisitor(e) } -func (e *extendObjectTypeDefinitionVisitor) EnterDocument(operation, definition *ast.Document) { - e.operation = operation +func (e *extendObjectTypeDefinitionVisitor) EnterDocument(operation, _ *ast.Document) { + e.document = operation } func (e *extendObjectTypeDefinitionVisitor) EnterObjectTypeExtension(ref int) { - nodes, exists := e.operation.Index.NodesByNameBytes(e.operation.ObjectTypeExtensionNameBytes(ref)) + nameBytes := e.document.ObjectTypeExtensionNameBytes(ref) + nodes, exists := e.document.Index.NodesByNameBytes(nameBytes) if !exists { return } + var nodeToExtend *ast.Node + isEntity := false for i := range nodes { if nodes[i].Kind != ast.NodeKindObjectTypeDefinition { continue } - e.operation.ExtendObjectTypeDefinitionByObjectTypeExtension(nodes[i].Ref, ref) + if nodeToExtend != nil { + e.StopWithExternalErr(*multipleExtensionError(isEntity, nameBytes)) + return + } + var err *operationreport.ExternalError + extension := e.document.ObjectTypeExtensions[ref] + if isEntity, err = e.collectedEntities.isExtensionForEntity(nameBytes, extension.Directives.Refs, e.document); err != nil { + e.StopWithExternalErr(*err) + return + } + nodeToExtend = &nodes[i] + if ast.IsRootType(nameBytes) { + break + } + } + + if nodeToExtend == nil { + e.StopWithExternalErr(operationreport.ErrExtensionOrphansMustResolveInSupergraph(nameBytes)) return } + + e.document.ExtendObjectTypeDefinitionByObjectTypeExtension(nodeToExtend.Ref, ref) } diff --git a/pkg/federation/sdlmerge/object_type_extending_test.go b/pkg/federation/sdlmerge/object_type_extending_test.go index b9ebe817f..df75d1758 100644 --- a/pkg/federation/sdlmerge/object_type_extending_test.go +++ b/pkg/federation/sdlmerge/object_type_extending_test.go @@ -4,89 +4,209 @@ import "testing" func TestExtendObjectType(t *testing.T) { t.Run("extend object type by field", func(t *testing.T) { - run(t, newExtendObjectTypeDefinition(), ` - type Dog { - name: String - } - extend type Dog { - favoriteToy: String - } - `, ` - type Dog { - name: String - favoriteToy: String - } - extend type Dog { - favoriteToy: String - } - `) + run(t, newExtendObjectTypeDefinition(newTestNormalizer(false)), ` + type Dog { + name: String + } + + extend type Dog { + favoriteToy: String + } + `, ` + type Dog { + name: String + favoriteToy: String + } + + extend type Dog { + favoriteToy: String + } + `) }) + t.Run("extend object type by directive", func(t *testing.T) { - run(t, newExtendObjectTypeDefinition(), ` - type Cat { - name: String - } - extend type Cat @deprecated(reason: "not as cool as dogs") - `, ` - type Cat @deprecated(reason: "not as cool as dogs") { - name: String - } - extend type Cat @deprecated(reason: "not as cool as dogs") - `) + run(t, newExtendObjectTypeDefinition(newTestNormalizer(false)), ` + type Cat { + name: String + } + + extend type Cat @deprecated(reason: "not as cool as dogs") + `, ` + type Cat @deprecated(reason: "not as cool as dogs") { + name: String + } + + extend type Cat @deprecated(reason: "not as cool as dogs") + `) }) + t.Run("extend object type by multiple field", func(t *testing.T) { - run(t, newExtendObjectTypeDefinition(), ` - type Dog { - name: String - } - extend type Dog { - favoriteToy: String - breed: String - } - `, ` - type Dog { - name: String - favoriteToy: String - breed: String - } - extend type Dog { - favoriteToy: String - breed: String - } - `) + run(t, newExtendObjectTypeDefinition(newTestNormalizer(false)), ` + type Dog { + name: String + } + + extend type Dog { + favoriteToy: String + breed: String + } + `, ` + type Dog { + name: String + favoriteToy: String + breed: String + } + + extend type Dog { + favoriteToy: String + breed: String + } + `) }) + t.Run("extend object type by multiple directives", func(t *testing.T) { - run(t, newExtendObjectTypeDefinition(), ` - type Cat { - name: String - } - extend type Cat @deprecated(reason: "not as cool as dogs") @skip(if: false) - `, ` - type Cat @deprecated(reason: "not as cool as dogs") @skip(if: false) { - name: String - } - extend type Cat @deprecated(reason: "not as cool as dogs") @skip(if: false) - `) + run(t, newExtendObjectTypeDefinition(newTestNormalizer(false)), ` + type Cat { + name: String + } + + extend type Cat @deprecated(reason: "not as cool as dogs") @skip(if: false) + `, ` + type Cat @deprecated(reason: "not as cool as dogs") @skip(if: false) { + name: String + } + + extend type Cat @deprecated(reason: "not as cool as dogs") @skip(if: false) + `) }) + t.Run("extend object type by complex extends", func(t *testing.T) { - run(t, newExtendObjectTypeDefinition(), ` - type Cat { - name: String - } - extend type Cat @deprecated(reason: "not as cool as dogs") @skip(if: false) { - age: Int - breed: String - } - `, ` - type Cat @deprecated(reason: "not as cool as dogs") @skip(if: false) { - name: String - age: Int - breed: String - } - extend type Cat @deprecated(reason: "not as cool as dogs") @skip(if: false) { - age: Int - breed: String - } - `) + run(t, newExtendObjectTypeDefinition(newTestNormalizer(false)), ` + type Cat { + name: String + } + + extend type Cat @deprecated(reason: "not as cool as dogs") @skip(if: false) { + age: Int + breed: String + } + `, ` + type Cat @deprecated(reason: "not as cool as dogs") @skip(if: false) { + name: String + age: Int + breed: String + } + + extend type Cat @deprecated(reason: "not as cool as dogs") @skip(if: false) { + age: Int + breed: String + } + `) + }) + + t.Run("Extending an object that is a shared type returns an error", func(t *testing.T) { + runAndExpectError(t, newExtendObjectTypeDefinition(newTestNormalizer(false)), ` + type Mammal { + name: String + } + + type Mammal { + name: String + } + + extend type Mammal @deprecated(reason: "not as cool as dogs") @skip(if: false) { + age: Int + breed: String + } + `, sharedTypeExtensionErrorMessage("Mammal")) + }) + + t.Run("Unresolved object extension orphan returns an error", func(t *testing.T) { + runAndExpectError(t, newExtendObjectTypeDefinition(newTestNormalizer(false)), ` + extend type Mammal { + name: String! + } + `, unresolvedExtensionOrphansErrorMessage("Mammal")) + }) + + t.Run("Entity is extended successfully", func(t *testing.T) { + run(t, newExtendObjectTypeDefinition(newTestNormalizer(true)), ` + type Mammal @key(fields: "name") { + name: String! + } + + extend type Mammal @key(fields: "name") { + name: String! @external + age: Int! + } + `, ` + type Mammal @key(fields: "name") @key(fields: "name") { + + name: String! + name: String! @external + age: Int! + } + + extend type Mammal @key(fields: "name") { + name: String! @external + age: Int! + } + `) + }) + + t.Run("No key directive on entity extension returns an error", func(t *testing.T) { + runAndExpectError(t, newExtendObjectTypeDefinition(newTestNormalizer(true)), ` + type Mammal @key(fields: "name") { + name: String! + } + + extend type Mammal { + name: String! @external + age: Int! + } + `, noKeyDirectiveErrorMessage("Mammal")) + }) + + t.Run("Non-key directive when extending an entity returns an error", func(t *testing.T) { + runAndExpectError(t, newExtendObjectTypeDefinition(newTestNormalizer(true)), ` + type Mammal @key(fields: "name") { + name: String! + } + + extend type Mammal @deprecated { + name: String! @external + age: Int! + } + `, noKeyDirectiveErrorMessage("Mammal")) + }) + + t.Run("Extending multiple entities returns an error", func(t *testing.T) { + runAndExpectError(t, newExtendObjectTypeDefinition(newTestNormalizer(true)), ` + type Mammal @key(fields: "name") { + name: String! + } + + type Mammal @key(fields: "name") { + name: String! + } + + extend type Mammal @key(fields: "name") { + name: String! @external + age: Int! + } + `, duplicateEntityErrorMessage("Mammal")) + }) + + t.Run("A non-entity that is extended by an extension with a key directive returns an error", func(t *testing.T) { + runAndExpectError(t, newExtendObjectTypeDefinition(newTestNormalizer(false)), ` + type Mammal { + name: String! + } + + extend type Mammal @key(fields: "name") { + name: String! @external + age: Int! + } + `, nonEntityExtensionErrorMessage("Mammal")) }) } diff --git a/pkg/federation/sdlmerge/remove_duplicate_fielded_shared_types.go b/pkg/federation/sdlmerge/remove_duplicate_fielded_shared_types.go new file mode 100644 index 000000000..cc618efd7 --- /dev/null +++ b/pkg/federation/sdlmerge/remove_duplicate_fielded_shared_types.go @@ -0,0 +1,107 @@ +package sdlmerge + +import ( + "github.com/wundergraph/graphql-go-tools/pkg/ast" + "github.com/wundergraph/graphql-go-tools/pkg/astvisitor" + "github.com/wundergraph/graphql-go-tools/pkg/operationreport" +) + +type removeDuplicateFieldedSharedTypesVisitor struct { + *astvisitor.Walker + document *ast.Document + sharedTypeSet map[string]fieldedSharedType + rootNodesToRemove []ast.Node + lastInputRef int + lastInterfaceRef int + lastObjectRef int +} + +func newRemoveDuplicateFieldedSharedTypesVisitor() *removeDuplicateFieldedSharedTypesVisitor { + return &removeDuplicateFieldedSharedTypesVisitor{ + nil, + nil, + make(map[string]fieldedSharedType), + nil, + ast.InvalidRef, + ast.InvalidRef, + ast.InvalidRef, + } +} + +func (r *removeDuplicateFieldedSharedTypesVisitor) Register(walker *astvisitor.Walker) { + r.Walker = walker + walker.RegisterEnterDocumentVisitor(r) + walker.RegisterEnterInputObjectTypeDefinitionVisitor(r) + walker.RegisterEnterInterfaceTypeDefinitionVisitor(r) + walker.RegisterEnterObjectTypeDefinitionVisitor(r) + walker.RegisterLeaveDocumentVisitor(r) +} + +func (r *removeDuplicateFieldedSharedTypesVisitor) EnterDocument(operation, _ *ast.Document) { + r.document = operation +} + +func (r *removeDuplicateFieldedSharedTypesVisitor) EnterInputObjectTypeDefinition(ref int) { + if ref <= r.lastInputRef { + return + } + name := r.document.InputObjectTypeDefinitionNameString(ref) + refs := r.document.InputObjectTypeDefinitions[ref].InputFieldsDefinition.Refs + input, exists := r.sharedTypeSet[name] + if exists { + if !input.areFieldsIdentical(refs) { + r.StopWithExternalErr(operationreport.ErrSharedTypesMustBeIdenticalToFederate(name)) + return + } + r.rootNodesToRemove = append(r.rootNodesToRemove, ast.Node{Kind: ast.NodeKindInputObjectTypeDefinition, Ref: ref}) + } else { + r.sharedTypeSet[name] = newFieldedSharedType(r.document, ast.NodeKindInputValueDefinition, refs) + } + r.lastInputRef = ref +} + +func (r *removeDuplicateFieldedSharedTypesVisitor) EnterInterfaceTypeDefinition(ref int) { + if ref <= r.lastInterfaceRef { + return + } + name := r.document.InterfaceTypeDefinitionNameString(ref) + interfaceType := r.document.InterfaceTypeDefinitions[ref] + refs := interfaceType.FieldsDefinition.Refs + iFace, exists := r.sharedTypeSet[name] + if exists { + if !iFace.areFieldsIdentical(refs) { + r.StopWithExternalErr(operationreport.ErrSharedTypesMustBeIdenticalToFederate(name)) + return + } + r.rootNodesToRemove = append(r.rootNodesToRemove, ast.Node{Kind: ast.NodeKindInterfaceTypeDefinition, Ref: ref}) + } else { + r.sharedTypeSet[name] = newFieldedSharedType(r.document, ast.NodeKindFieldDefinition, refs) + } + r.lastInterfaceRef = ref +} + +func (r *removeDuplicateFieldedSharedTypesVisitor) EnterObjectTypeDefinition(ref int) { + if ref <= r.lastObjectRef { + return + } + name := r.document.ObjectTypeDefinitionNameString(ref) + objectType := r.document.ObjectTypeDefinitions[ref] + refs := objectType.FieldsDefinition.Refs + object, exists := r.sharedTypeSet[name] + if exists { + if !object.areFieldsIdentical(refs) { + r.StopWithExternalErr(operationreport.ErrSharedTypesMustBeIdenticalToFederate(name)) + return + } + r.rootNodesToRemove = append(r.rootNodesToRemove, ast.Node{Kind: ast.NodeKindObjectTypeDefinition, Ref: ref}) + } else { + r.sharedTypeSet[name] = newFieldedSharedType(r.document, ast.NodeKindFieldDefinition, refs) + } + r.lastObjectRef = ref +} + +func (r *removeDuplicateFieldedSharedTypesVisitor) LeaveDocument(_, _ *ast.Document) { + if r.rootNodesToRemove != nil { + r.document.DeleteRootNodes(r.rootNodesToRemove) + } +} diff --git a/pkg/federation/sdlmerge/remove_duplicate_fielded_shared_types_test.go b/pkg/federation/sdlmerge/remove_duplicate_fielded_shared_types_test.go new file mode 100644 index 000000000..ec67045af --- /dev/null +++ b/pkg/federation/sdlmerge/remove_duplicate_fielded_shared_types_test.go @@ -0,0 +1,529 @@ +package sdlmerge + +import ( + "testing" +) + +func TestRemoveDuplicateFieldedValueTypes(t *testing.T) { + t.Run("Same name empty inputs are merged into a single input", func(t *testing.T) { + run(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + input Trainer { + } + + input Trainer { + } + `, ` + input Trainer { + } + `) + }) + + t.Run("Identical same name inputs are merged into a single input", func(t *testing.T) { + run(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + input Trainer { + name: String! + age: Int! + } + + input Trainer { + name: String! + age: Int! + } + + input Trainer { + name: String! + age: Int! + } + `, ` + input Trainer { + name: String! + age: Int! + } + `) + }) + + t.Run("Identical same name inputs are merged into a single input regardless of field order", func(t *testing.T) { + run(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + input Trainer { + age: Int! + name: String! + } + + input Trainer { + name: String! + age: Int! + } + + input Trainer { + name: String! + age: Int! + } + `, ` + input Trainer { + age: Int! + name: String! + } + `) + }) + + t.Run("Groups of identical same name inputs are respectively merged into single inputs", func(t *testing.T) { + run(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + input Pokemon { + type: [Type!]! + isEvolved: Boolean! + } + + input Trainer { + name: String! + age: Int! + } + + input Trainer { + name: String! + age: Int! + } + + input Trainer { + name: String! + age: Int! + } + + input Pokemon { + type: [Type!]! + isEvolved: Boolean! + } + `, ` + input Pokemon { + type: [Type!]! + isEvolved: Boolean! + } + + input Trainer { + name: String! + age: Int! + } + `) + }) + + t.Run("Same name inputs with different nullability of fields return an error", func(t *testing.T) { + runAndExpectError(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + input Trainer { + name: String! + age: Int! + } + + input Trainer { + name: String + age: Int! + } + + input Trainer { + name: String! + age: Int + } + `, NonIdenticalSharedTypeErrorMessage("Trainer")) + }) + + t.Run("Same name inputs with different fields return an error", func(t *testing.T) { + runAndExpectError(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + input Trainer { + name: String + age: Int + } + + input Trainer { + name: String + age: Int + } + + input Trainer { + name: String + age: Int + badges: Int + } + `, NonIdenticalSharedTypeErrorMessage("Trainer")) + }) + + t.Run("Same name inputs with a slight difference in nested field values return an error", func(t *testing.T) { + runAndExpectError(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + input Pokemon { + type: [[[[Type!]]!]!]! + } + + input Pokemon { + type: [[[[Type!]]]!]! + } + + input Pokemon { + type: [[[[Type!]]!]!]! + } + `, NonIdenticalSharedTypeErrorMessage("Pokemon")) + }) + + t.Run("Same name inputs with different non-nullable field values return an error", func(t *testing.T) { + runAndExpectError(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + input Trainer { + name: String! + age: Int! + } + + input Trainer { + name: String! + age: Int! + } + + input Trainer { + name: String! + age: String! + } + `, NonIdenticalSharedTypeErrorMessage("Trainer")) + }) + + t.Run("Same name empty interfaces are merged into a single interface", func(t *testing.T) { + run(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + interface Trainer { + } + + interface Trainer { + } + `, ` + interface Trainer { + } + `) + }) + + t.Run("Identical same name interfaces are merged into a single interface", func(t *testing.T) { + run(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + interface Trainer { + name: String! + age: Int! + } + + interface Trainer { + name: String! + age: Int! + } + + interface Trainer { + name: String! + age: Int! + } + `, ` + interface Trainer { + name: String! + age: Int! + } + `) + }) + + t.Run("Identical same name interfaces are merged into a single input regardless of field order", func(t *testing.T) { + run(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + interface Trainer { + age: Int! + name: String! + } + + interface Trainer { + name: String! + age: Int! + } + + interface Trainer { + name: String! + age: Int! + } + `, ` + interface Trainer { + age: Int! + name: String! + } + `) + }) + + t.Run("Groups of identical same name interfaces are respectively merged into single interfaces", func(t *testing.T) { + run(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + interface Pokemon { + type: [Type!]! + isEvolved: Boolean! + } + + interface Trainer { + name: String! + age: Int! + } + + interface Trainer { + name: String! + age: Int! + } + + interface Trainer { + name: String! + age: Int! + } + + interface Pokemon { + type: [Type!]! + isEvolved: Boolean! + } + `, ` + interface Pokemon { + type: [Type!]! + isEvolved: Boolean! + } + + interface Trainer { + name: String! + age: Int! + } + `) + }) + + t.Run("Same name interfaces with different nullability of fields return an error", func(t *testing.T) { + runAndExpectError(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + interface Trainer { + name: String! + age: Int! + } + + interface Trainer { + name: String + age: Int! + } + + interface Trainer { + name: String! + age: Int + } + `, NonIdenticalSharedTypeErrorMessage("Trainer")) + }) + + t.Run("Same name interfaces with different fields return an error", func(t *testing.T) { + runAndExpectError(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + interface Trainer { + name: String + age: Int + } + + interface Trainer { + name: String + age: Int + } + + interface Trainer { + name: String + age: Int + badges: Int + } + `, NonIdenticalSharedTypeErrorMessage("Trainer")) + }) + + t.Run("Same name interfaces with a slight difference in nested field values return an error", func(t *testing.T) { + runAndExpectError(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + interface Pokemon { + type: [[[[Type!]]!]!]! + } + + interface Pokemon { + type: [[[[Type!]]]!]! + } + + interface Pokemon { + type: [[[[Type!]]!]!]! + } + `, NonIdenticalSharedTypeErrorMessage("Pokemon")) + }) + + t.Run("Same name interfaces with different non-nullable field values return an error", func(t *testing.T) { + runAndExpectError(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + interface Trainer { + name: String! + age: Int! + } + + interface Trainer { + name: String! + age: Int! + } + + interface Trainer { + name: String! + age: String! + } + `, NonIdenticalSharedTypeErrorMessage("Trainer")) + }) + + t.Run("Same name empty objects are merged into a single object", func(t *testing.T) { + run(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + type Trainer { + } + + type Trainer { + } + `, ` + type Trainer { + } + `) + }) + + t.Run("Identical same name objects are merged into a single object", func(t *testing.T) { + run(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + type Trainer { + name: String! + age: Int! + } + + type Trainer { + name: String! + age: Int! + } + + type Trainer { + name: String! + age: Int! + } + `, ` + type Trainer { + name: String! + age: Int! + } + `) + }) + + t.Run("Identical same name objects are merged into a single input regardless of field order", func(t *testing.T) { + run(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + type Trainer { + age: Int! + name: String! + } + + type Trainer { + name: String! + age: Int! + } + + type Trainer { + name: String! + age: Int! + } + `, ` + type Trainer { + age: Int! + name: String! + } + `) + }) + + t.Run("Groups of identical same name objects are respectively merged into single objects", func(t *testing.T) { + run(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + type Pokemon { + type: [Type!]! + isEvolved: Boolean! + } + + type Trainer { + name: String! + age: Int! + } + + type Trainer { + name: String! + age: Int! + } + + type Trainer { + name: String! + age: Int! + } + + type Pokemon { + type: [Type!]! + isEvolved: Boolean! + } + `, ` + type Pokemon { + type: [Type!]! + isEvolved: Boolean! + } + + type Trainer { + name: String! + age: Int! + } + `) + }) + + t.Run("Same name objects with different nullability of fields return an error", func(t *testing.T) { + runAndExpectError(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + type Trainer { + name: String! + age: Int! + } + + type Trainer { + name: String + age: Int! + } + + type Trainer { + name: String! + age: Int + } + `, NonIdenticalSharedTypeErrorMessage("Trainer")) + }) + + t.Run("Same name objects with different fields return an error", func(t *testing.T) { + runAndExpectError(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + type Trainer { + name: String + age: Int + } + + type Trainer { + name: String + age: Int + } + + type Trainer { + name: String + age: Int + badges: Int + } + `, NonIdenticalSharedTypeErrorMessage("Trainer")) + }) + + t.Run("Same name objects with a slight difference in nested field values return an error", func(t *testing.T) { + runAndExpectError(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + type Pokemon { + type: [[[[Type!]]!]!]! + } + + type Pokemon { + type: [[[[Type!]]]!]! + } + + type Pokemon { + type: [[[[Type!]]!]!]! + } + `, NonIdenticalSharedTypeErrorMessage("Pokemon")) + }) + + t.Run("Same name objects with different non-nullable field values return an error", func(t *testing.T) { + runAndExpectError(t, newRemoveDuplicateFieldedSharedTypesVisitor(), ` + type Trainer { + name: String! + age: Int! + } + + type Trainer { + name: String! + age: Int! + } + + type Trainer { + name: String! + age: String! + } + `, NonIdenticalSharedTypeErrorMessage("Trainer")) + }) +} diff --git a/pkg/federation/sdlmerge/remove_duplicate_fieldless_shared_types.go b/pkg/federation/sdlmerge/remove_duplicate_fieldless_shared_types.go new file mode 100644 index 000000000..7fcf8c7a7 --- /dev/null +++ b/pkg/federation/sdlmerge/remove_duplicate_fieldless_shared_types.go @@ -0,0 +1,98 @@ +package sdlmerge + +import ( + "github.com/wundergraph/graphql-go-tools/pkg/ast" + "github.com/wundergraph/graphql-go-tools/pkg/astvisitor" + "github.com/wundergraph/graphql-go-tools/pkg/operationreport" +) + +type removeDuplicateFieldlessSharedTypesVisitor struct { + *astvisitor.Walker + document *ast.Document + sharedTypeSet map[string]fieldlessSharedType + rootNodesToRemove []ast.Node + lastEnumRef int + lastUnionRef int + lastScalarRef int +} + +func newRemoveDuplicateFieldlessSharedTypesVisitor() *removeDuplicateFieldlessSharedTypesVisitor { + return &removeDuplicateFieldlessSharedTypesVisitor{ + nil, + nil, + make(map[string]fieldlessSharedType), + nil, + ast.InvalidRef, + ast.InvalidRef, + ast.InvalidRef, + } +} + +func (r *removeDuplicateFieldlessSharedTypesVisitor) Register(walker *astvisitor.Walker) { + r.Walker = walker + walker.RegisterEnterDocumentVisitor(r) + walker.RegisterEnterEnumTypeDefinitionVisitor(r) + walker.RegisterEnterScalarTypeDefinitionVisitor(r) + walker.RegisterEnterUnionTypeDefinitionVisitor(r) + walker.RegisterLeaveDocumentVisitor(r) +} + +func (r *removeDuplicateFieldlessSharedTypesVisitor) EnterDocument(operation, _ *ast.Document) { + r.document = operation +} + +func (r *removeDuplicateFieldlessSharedTypesVisitor) EnterEnumTypeDefinition(ref int) { + if ref <= r.lastEnumRef { + return + } + name := r.document.EnumTypeDefinitionNameString(ref) + enum, exists := r.sharedTypeSet[name] + if exists { + if !enum.areValuesIdentical(r.document.EnumTypeDefinitions[ref].EnumValuesDefinition.Refs) { + r.StopWithExternalErr(operationreport.ErrSharedTypesMustBeIdenticalToFederate(name)) + return + } + r.rootNodesToRemove = append(r.rootNodesToRemove, ast.Node{Kind: ast.NodeKindEnumTypeDefinition, Ref: ref}) + } else { + r.sharedTypeSet[name] = newEnumSharedType(r.document, ref) + } + r.lastEnumRef = ref +} + +func (r *removeDuplicateFieldlessSharedTypesVisitor) EnterScalarTypeDefinition(ref int) { + if ref <= r.lastScalarRef { + return + } + name := r.document.ScalarTypeDefinitionNameString(ref) + _, exists := r.sharedTypeSet[name] + if exists { + r.rootNodesToRemove = append(r.rootNodesToRemove, ast.Node{Kind: ast.NodeKindScalarTypeDefinition, Ref: ref}) + } else { + r.sharedTypeSet[name] = scalarSharedType{} + } + r.lastScalarRef = ref +} + +func (r *removeDuplicateFieldlessSharedTypesVisitor) EnterUnionTypeDefinition(ref int) { + if ref <= r.lastUnionRef { + return + } + name := r.document.UnionTypeDefinitionNameString(ref) + union, exists := r.sharedTypeSet[name] + if exists { + if !union.areValuesIdentical(r.document.UnionTypeDefinitions[ref].UnionMemberTypes.Refs) { + r.StopWithExternalErr(operationreport.ErrSharedTypesMustBeIdenticalToFederate(name)) + return + } + r.rootNodesToRemove = append(r.rootNodesToRemove, ast.Node{Kind: ast.NodeKindUnionTypeDefinition, Ref: ref}) + } else { + r.sharedTypeSet[name] = newUnionSharedType(r.document, ref) + } + r.lastUnionRef = ref +} + +func (r *removeDuplicateFieldlessSharedTypesVisitor) LeaveDocument(_, _ *ast.Document) { + if r.rootNodesToRemove != nil { + r.document.DeleteRootNodes(r.rootNodesToRemove) + } +} diff --git a/pkg/federation/sdlmerge/remove_duplicate_fieldless_shared_types_test.go b/pkg/federation/sdlmerge/remove_duplicate_fieldless_shared_types_test.go new file mode 100644 index 000000000..acefe2e25 --- /dev/null +++ b/pkg/federation/sdlmerge/remove_duplicate_fieldless_shared_types_test.go @@ -0,0 +1,373 @@ +package sdlmerge + +import ( + "fmt" + "testing" +) + +func TestRemoveDuplicateFieldlessSharedTypes(t *testing.T) { + t.Run("Input and output are identical when no duplications", func(t *testing.T) { + run(t, newRemoveDuplicateFieldlessSharedTypesVisitor(), ` + enum Pokemon { + BULBASAUR, + CHARMANDER, + SQUIRTLE, + } + `, ` + enum Pokemon { + BULBASAUR, + CHARMANDER, + SQUIRTLE, + } + `) + }) + + t.Run("Identical same name enums are merged", func(t *testing.T) { + run(t, newRemoveDuplicateFieldlessSharedTypesVisitor(), ` + enum Pokemon { + BULBASAUR, + CHARMANDER, + SQUIRTLE, + } + + enum Pokemon { + BULBASAUR, + CHARMANDER, + SQUIRTLE, + } + `, ` + enum Pokemon { + BULBASAUR, + CHARMANDER, + SQUIRTLE, + } + `) + }) + + t.Run("Identical same name enums are merged into a single input regardless of value order", func(t *testing.T) { + run(t, newRemoveDuplicateFieldlessSharedTypesVisitor(), ` + enum Pokemon { + BULBASAUR, + CHARMANDER, + SQUIRTLE, + } + + enum Pokemon { + SQUIRTLE, + CHARMANDER, + BULBASAUR, + } + `, ` + enum Pokemon { + BULBASAUR, + CHARMANDER, + SQUIRTLE, + } + `) + }) + + t.Run("Same name enums with different values return an error", func(t *testing.T) { + runAndExpectError(t, newRemoveDuplicateFieldlessSharedTypesVisitor(), ` + enum Pokemon { + BULBASAUR, + CHARMANDER, + SQUIRTLE, + } + + enum Pokemon { + BULBASAUR, + CHARMANDER, + SQUIRTLE, + MEW, + } + `, NonIdenticalSharedTypeErrorMessage(pokemon)) + }) + + t.Run("Empty and populated same name enums return an error", func(t *testing.T) { + runAndExpectError(t, newRemoveDuplicateFieldlessSharedTypesVisitor(), ` + enum Pokemon { + } + + enum Pokemon { + CHARMANDER, + SQUIRTLE, + } + `, NonIdenticalSharedTypeErrorMessage(pokemon)) + }) + + t.Run("Empty enums are merged into a single empty enum", func(t *testing.T) { + run(t, newRemoveDuplicateFieldlessSharedTypesVisitor(), ` + enum Pokemon { + } + + enum Pokemon { + } + `, ` + enum Pokemon { + } + `) + }) + + t.Run("Same name enums with no overlapping values return an error", func(t *testing.T) { + runAndExpectError(t, newRemoveDuplicateFieldlessSharedTypesVisitor(), ` + enum Pokemon { + BULBASAUR, + CHARMANDER, + } + + enum Pokemon { + SQUIRTLE, + MEW, + } + `, NonIdenticalSharedTypeErrorMessage(pokemon)) + }) + + t.Run("Same name enums with varying overlapping values return an error", func(t *testing.T) { + runAndExpectError(t, newRemoveDuplicateFieldlessSharedTypesVisitor(), ` + enum Pokemon { + BULBASAUR, + CHARMANDER, + } + + enum Pokemon { + CHARMANDER, + MEW, + } + + enum Pokemon { + BULBASAUR, + MEW, + } + + enum Pokemon { + BULBASAUR, + SQUIRTLE, + } + `, NonIdenticalSharedTypeErrorMessage(pokemon)) + }) + + t.Run("Different groups of same name enums return an error immediately upon invalidation", func(t *testing.T) { + runAndExpectError(t, newRemoveDuplicateFieldlessSharedTypesVisitor(), ` + enum Cities { + CERULEAN, + SAFFRON, + } + + enum Types { + GRASS, + FIRE, + ROCK, + } + + enum Badges { + } + + enum Types { + FIRE, + WATER, + } + + enum Badges { + BOULDER, + VOLCANO, + EARTH, + } + + enum Cities { + VIRIDIAN, + SAFFRON, + CELADON, + } + + enum Types { + ROCK, + GRASS, + FIRE, + WATER, + } + + enum Badges { + MARSH, + SOUL, + VOLCANO, + THUNDER, + RAINBOW, + CASCADE, + } + + enum Badges { + VOLCANO, + RAINBOW, + BOULDER, + SOUL, + } + + enum Types { + WATER, + FIRE, + } + + enum Badges { + } + + enum Badges { + EARTH, + THUNDER, + } + + enum Cities { + } + + enum Cities { + CERULEAN, + CELADON, + } + `, NonIdenticalSharedTypeErrorMessage(types)) + }) + + t.Run("Input and output are identical when no scalar duplications", func(t *testing.T) { + run(t, newRemoveDuplicateFieldlessSharedTypesVisitor(), ` + scalar DateTime + `, ` + scalar DateTime + `) + }) + + t.Run("Same name scalars are removed to leave only one", func(t *testing.T) { + run(t, newRemoveDuplicateFieldlessSharedTypesVisitor(), ` + scalar DateTime + + scalar DateTime + + scalar DateTime + `, ` + scalar DateTime + `) + }) + + t.Run("Any more than one of a same name scalar are removed", func(t *testing.T) { + run(t, newRemoveDuplicateFieldlessSharedTypesVisitor(), ` + scalar DateTime + + scalar BigInt + + scalar BigInt + + scalar CustomScalar + + scalar DateTime + + scalar UniqueScalar + + scalar BigInt + + scalar CustomScalar + + scalar CustomScalar + + scalar DateTime + + scalar CustomScalar + + scalar DateTime + `, ` + scalar DateTime + + scalar BigInt + + scalar CustomScalar + + scalar UniqueScalar + `) + }) + + t.Run("Input and output are identical when no duplications", func(t *testing.T) { + run(t, newRemoveDuplicateFieldlessSharedTypesVisitor(), ` + union Types = Grass | Fire | Water + `, ` + union Types = Grass | Fire | Water + `) + }) + + t.Run("Identical same name unions are merged", func(t *testing.T) { + run(t, newRemoveDuplicateFieldlessSharedTypesVisitor(), ` + union Types = Grass | Fire | Water + + union Types = Grass | Fire | Water + `, ` + union Types = Grass | Fire | Water + `) + }) + + t.Run("Identical same name unions are merged into a single input regardless of value order", func(t *testing.T) { + run(t, newRemoveDuplicateFieldlessSharedTypesVisitor(), ` + union Types = Grass | Fire | Water + + union Types = Water | Grass | Fire + `, ` + union Types = Grass | Fire | Water + `) + }) + + t.Run("Same name unions with different values return an error", func(t *testing.T) { + runAndExpectError(t, newRemoveDuplicateFieldlessSharedTypesVisitor(), ` + union Types = Grass | Fire | Water + + union Types = Grass | Fire | Water | Rock + `, NonIdenticalSharedTypeErrorMessage(types)) + }) + + t.Run("Same name unions with no overlapping values return an error", func(t *testing.T) { + runAndExpectError(t, newRemoveDuplicateFieldlessSharedTypesVisitor(), ` + union Types = Grass | Fire + + union Types = Water | Rock + `, NonIdenticalSharedTypeErrorMessage(types)) + }) + + t.Run("Same name unions with varying overlapping values return an error", func(t *testing.T) { + runAndExpectError(t, newRemoveDuplicateFieldlessSharedTypesVisitor(), ` + union Types = Grass | Fire + + union Types = Fire | Water + + union Types = Rock | Grass + + union Types = Water | Fire + `, NonIdenticalSharedTypeErrorMessage(types)) + }) + + t.Run("Different groups of same name unions return an error immediately upon invalidation", func(t *testing.T) { + runAndExpectError(t, newRemoveDuplicateFieldlessSharedTypesVisitor(), ` + union Cities = Cerulean | Saffron + + union Types = Grass | Fire | Rock + + union Badges = Boulder | Volcano | Earth + + union Cities = Viridian | Saffron | Celadon + + union Types = Rock | Grass | Fire | Water + + union Badges = Marsh | Soul | Volcano | Thunder | Rainbow | Cascade + + union Badges = Volcano | Rainbow | Boulder | Soul + + union Types = Water | Fire + + union Badges = Earth | Thunder + + union Cities = Cerulean | Celadon + `, NonIdenticalSharedTypeErrorMessage(cities)) + }) +} + +const ( + cities = "Cities" + pokemon = "Pokemon" + types = "Types" +) + +func NonIdenticalSharedTypeErrorMessage(typeName string) string { + return fmt.Sprintf("the shared type named '%s' must be identical in any subgraphs to federate", typeName) +} diff --git a/pkg/federation/sdlmerge/remove_type_extensions_test.go b/pkg/federation/sdlmerge/remove_type_extensions_test.go index 1c2def439..f6aec0073 100644 --- a/pkg/federation/sdlmerge/remove_type_extensions_test.go +++ b/pkg/federation/sdlmerge/remove_type_extensions_test.go @@ -17,7 +17,7 @@ func TestRemoveTypeExtensions(t *testing.T) { favoriteToy: String } `, - newExtendObjectTypeDefinition(), + newExtendObjectTypeDefinition(newTestNormalizer(false)), newRemoveMergedTypeExtensions()) }) t.Run("remove single type extension of directive", func(t *testing.T) { @@ -31,7 +31,7 @@ func TestRemoveTypeExtensions(t *testing.T) { name: String } `, - newExtendObjectTypeDefinition(), + newExtendObjectTypeDefinition(newTestNormalizer(false)), newRemoveMergedTypeExtensions()) }) t.Run("remove multiple type extensions at once", func(t *testing.T) { @@ -49,7 +49,7 @@ func TestRemoveTypeExtensions(t *testing.T) { age: Int } `, - newExtendObjectTypeDefinition(), + newExtendObjectTypeDefinition(newTestNormalizer(false)), newRemoveMergedTypeExtensions()) }) t.Run("remove interface type extensions", func(t *testing.T) { @@ -68,7 +68,7 @@ func TestRemoveTypeExtensions(t *testing.T) { age: Int } `, - newExtendInterfaceTypeDefinition(), + newExtendInterfaceTypeDefinition(newTestNormalizer(false)), newRemoveMergedTypeExtensions()) }) t.Run("keep not merged type extension", func(t *testing.T) { @@ -81,7 +81,7 @@ func TestRemoveTypeExtensions(t *testing.T) { field: String! } `, - newExtendInterfaceTypeDefinition(), + newExtendInterfaceTypeDefinition(newTestNormalizer(false)), newRemoveMergedTypeExtensions(), ) }) diff --git a/pkg/federation/sdlmerge/scalar_type_extending.go b/pkg/federation/sdlmerge/scalar_type_extending.go new file mode 100644 index 000000000..24669a9ce --- /dev/null +++ b/pkg/federation/sdlmerge/scalar_type_extending.go @@ -0,0 +1,49 @@ +package sdlmerge + +import ( + "github.com/wundergraph/graphql-go-tools/pkg/ast" + "github.com/wundergraph/graphql-go-tools/pkg/astvisitor" + "github.com/wundergraph/graphql-go-tools/pkg/operationreport" +) + +func newExtendScalarTypeDefinition() *extendScalarTypeDefinitionVisitor { + return &extendScalarTypeDefinitionVisitor{} +} + +type extendScalarTypeDefinitionVisitor struct { + *astvisitor.Walker + document *ast.Document +} + +func (e *extendScalarTypeDefinitionVisitor) Register(walker *astvisitor.Walker) { + e.Walker = walker + walker.RegisterEnterDocumentVisitor(e) + walker.RegisterEnterScalarTypeExtensionVisitor(e) +} + +func (e *extendScalarTypeDefinitionVisitor) EnterDocument(operation, _ *ast.Document) { + e.document = operation +} + +func (e *extendScalarTypeDefinitionVisitor) EnterScalarTypeExtension(ref int) { + nodes, exists := e.document.Index.NodesByNameBytes(e.document.ScalarTypeExtensionNameBytes(ref)) + if !exists { + return + } + + hasExtended := false + for i := range nodes { + if nodes[i].Kind != ast.NodeKindScalarTypeDefinition { + continue + } + if hasExtended { + e.StopWithExternalErr(operationreport.ErrSharedTypesMustNotBeExtended(e.document.ScalarTypeExtensionNameString(ref))) + return + } + e.document.ExtendScalarTypeDefinitionByScalarTypeExtension(nodes[i].Ref, ref) + hasExtended = true + } + if !hasExtended { + e.StopWithExternalErr(operationreport.ErrExtensionOrphansMustResolveInSupergraph(e.document.ScalarTypeExtensionNameBytes(ref))) + } +} diff --git a/pkg/federation/sdlmerge/scalar_type_extending_test.go b/pkg/federation/sdlmerge/scalar_type_extending_test.go new file mode 100644 index 000000000..4dbb82dca --- /dev/null +++ b/pkg/federation/sdlmerge/scalar_type_extending_test.go @@ -0,0 +1,29 @@ +package sdlmerge + +import "testing" + +func TestExtendScalarType(t *testing.T) { + t.Run("Scalar types can be extended", func(t *testing.T) { + run(t, newExtendScalarTypeDefinition(), ` + scalar Attack + extend scalar Attack @deprecated(reason: "some reason") @skip(if: false) + `, ` + scalar Attack @deprecated(reason: "some reason") @skip(if: false) + extend scalar Attack @deprecated(reason: "some reason") @skip(if: false) + `) + }) + + t.Run("Extending a scalar that is a shared type returns an error", func(t *testing.T) { + runAndExpectError(t, newExtendScalarTypeDefinition(), ` + scalar Attack + scalar Attack + extend scalar Attack @deprecated(reason: "some reason") @skip(if: false) + `, sharedTypeExtensionErrorMessage("Attack")) + }) + + t.Run("Unresolved scalar extension orphan returns an error", func(t *testing.T) { + runAndExpectError(t, newExtendScalarTypeDefinition(), ` + extend scalar Badges @onScalar + `, unresolvedExtensionOrphansErrorMessage("Badges")) + }) +} diff --git a/pkg/federation/sdlmerge/sdlmerge.go b/pkg/federation/sdlmerge/sdlmerge.go index aaa0f7773..79f98a7f4 100644 --- a/pkg/federation/sdlmerge/sdlmerge.go +++ b/pkg/federation/sdlmerge/sdlmerge.go @@ -2,6 +2,9 @@ package sdlmerge import ( "fmt" + "github.com/wundergraph/graphql-go-tools/pkg/asttransform" + "github.com/wundergraph/graphql-go-tools/pkg/astvalidation" + "github.com/wundergraph/graphql-go-tools/pkg/engine/plan" "strings" "github.com/wundergraph/graphql-go-tools/pkg/ast" @@ -12,11 +15,15 @@ import ( "github.com/wundergraph/graphql-go-tools/pkg/operationreport" ) -const rootOperationTypeDefinitions = ` - type Query {} - type Mutation {} - type Subscription {} -` +const ( + rootOperationTypeDefinitions = ` + type Query {} + type Mutation {} + type Subscription {} + ` + + parseDocumentError = "parse graphql document string: %s" +) type Visitor interface { Register(walker *astvisitor.Walker) @@ -33,6 +40,12 @@ func MergeSDLs(SDLs ...string) (string, error) { rawDocs := make([]string, 0, len(SDLs)+1) rawDocs = append(rawDocs, rootOperationTypeDefinitions) rawDocs = append(rawDocs, SDLs...) + if validationError := validateSubgraphs(rawDocs[1:]); validationError != nil { + return "", validationError + } + if normalizationError := normalizeSubgraphs(rawDocs[1:]); normalizationError != nil { + return "", normalizationError + } doc, report := astparser.ParseGraphqlDocumentString(strings.Join(rawDocs, "\n")) if report.HasErrors() { @@ -56,23 +69,73 @@ func MergeSDLs(SDLs ...string) (string, error) { return out, nil } +func validateSubgraphs(subgraphs []string) error { + validator := astvalidation.NewDefinitionValidator( + astvalidation.PopulatedTypeBodies(), astvalidation.KnownTypeNames(), + ) + for _, subgraph := range subgraphs { + doc, report := astparser.ParseGraphqlDocumentString(subgraph) + if err := asttransform.MergeDefinitionWithBaseSchema(&doc); err != nil { + return err + } + if report.HasErrors() { + return fmt.Errorf(parseDocumentError, report.Error()) + } + validator.Validate(&doc, &report) + if report.HasErrors() { + return fmt.Errorf("validate schema: %s", report.Error()) + } + } + return nil +} + +func normalizeSubgraphs(subgraphs []string) error { + subgraphNormalizer := astnormalization.NewSubgraphDefinitionNormalizer() + for i, subgraph := range subgraphs { + doc, report := astparser.ParseGraphqlDocumentString(subgraph) + if report.HasErrors() { + return fmt.Errorf(parseDocumentError, report.Error()) + } + subgraphNormalizer.NormalizeDefinition(&doc, &report) + if report.HasErrors() { + return fmt.Errorf("normalize schema: %s", report.Error()) + } + out, err := astprinter.PrintString(&doc, nil) + if err != nil { + return fmt.Errorf("stringify schema: %s", err.Error()) + } + subgraphs[i] = out + } + return nil +} + type normalizer struct { walkers []*astvisitor.Walker } +type entitySet map[string]struct{} + func (m *normalizer) setupWalkers() { + collectedEntities := make(entitySet) visitorGroups := [][]Visitor{ - // visitors for extending objects and interfaces { - newExtendInterfaceTypeDefinition(), + newCollectEntitiesVisitor(collectedEntities), + }, + { + newExtendEnumTypeDefinition(), + newExtendInputObjectTypeDefinition(), + newExtendInterfaceTypeDefinition(collectedEntities), + newExtendScalarTypeDefinition(), newExtendUnionTypeDefinition(), - newExtendObjectTypeDefinition(), + newExtendObjectTypeDefinition(collectedEntities), newRemoveEmptyObjectTypeDefinition(), newRemoveMergedTypeExtensions(), }, - // visitors for clean up federated duplicated fields and directives + // visitors for cleaning up federated duplicated fields and directives { newRemoveFieldDefinitions("external"), + newRemoveDuplicateFieldedSharedTypesVisitor(), + newRemoveDuplicateFieldlessSharedTypesVisitor(), newRemoveInterfaceDefinitionDirective("key"), newRemoveObjectTypeDefinitionDirective("key"), newRemoveFieldDefinitionDirective("provides", "requires"), @@ -100,3 +163,42 @@ func (m *normalizer) normalize(operation *ast.Document) error { return nil } + +func (e entitySet) isExtensionForEntity(nameBytes []byte, directiveRefs []int, document *ast.Document) (bool, *operationreport.ExternalError) { + name := string(nameBytes) + hasDirectives := len(directiveRefs) > 0 + if _, exists := e[name]; !exists { + if !hasDirectives || !isEntityExtension(directiveRefs, document) { + return false, nil + } + err := operationreport.ErrExtensionWithKeyDirectiveMustExtendEntity(name) + return false, &err + } + if !hasDirectives { + err := operationreport.ErrEntityExtensionMustHaveKeyDirective(name) + return false, &err + } + if isEntityExtension(directiveRefs, document) { + return true, nil + } + err := operationreport.ErrEntityExtensionMustHaveKeyDirective(name) + return false, &err +} + +func isEntityExtension(directiveRefs []int, document *ast.Document) bool { + for _, directiveRef := range directiveRefs { + if document.DirectiveNameString(directiveRef) == plan.FederationKeyDirectiveName { + return true + } + } + return false +} + +func multipleExtensionError(isEntity bool, nameBytes []byte) *operationreport.ExternalError { + if isEntity { + err := operationreport.ErrEntitiesMustNotBeDuplicated(string(nameBytes)) + return &err + } + err := operationreport.ErrSharedTypesMustNotBeExtended(string(nameBytes)) + return &err +} diff --git a/pkg/federation/sdlmerge/sdlmerge_test.go b/pkg/federation/sdlmerge/sdlmerge_test.go index 6aab0f04c..995e90ec0 100644 --- a/pkg/federation/sdlmerge/sdlmerge_test.go +++ b/pkg/federation/sdlmerge/sdlmerge_test.go @@ -1,6 +1,7 @@ package sdlmerge import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -11,6 +12,15 @@ import ( "github.com/wundergraph/graphql-go-tools/pkg/operationreport" ) +var testEntitySet = entitySet{"Mammal": {}} + +func newTestNormalizer(withEntity bool) entitySet { + if withEntity { + return testEntitySet + } + return make(entitySet) +} + type composeVisitor []Visitor func (c composeVisitor) Register(walker *astvisitor.Walker) { @@ -20,7 +30,6 @@ func (c composeVisitor) Register(walker *astvisitor.Walker) { } var run = func(t *testing.T, visitor Visitor, operation, expectedOutput string) { - operationDocument := unsafeparser.ParseGraphqlDocumentString(operation) expectedOutputDocument := unsafeparser.ParseGraphqlDocumentString(expectedOutput) report := operationreport.Report{} @@ -40,6 +49,27 @@ var run = func(t *testing.T, visitor Visitor, operation, expectedOutput string) assert.Equal(t, want, got) } +var runAndExpectError = func(t *testing.T, visitor Visitor, operation, expectedError string) { + operationDocument := unsafeparser.ParseGraphqlDocumentString(operation) + report := operationreport.Report{} + walker := astvisitor.NewWalker(48) + + visitor.Register(&walker) + + walker.Walk(&operationDocument, nil, &report) + + var got string + if report.HasErrors() { + if report.InternalErrors == nil { + got = report.ExternalErrors[0].Message + } else { + got = report.InternalErrors[0].Error() + } + } + + assert.Equal(t, expectedError, got) +} + func runMany(t *testing.T, operation, expectedOutput string, visitors ...Visitor) { run(t, composeVisitor(visitors), operation, expectedOutput) } @@ -68,20 +98,48 @@ func TestMergeSDLs(t *testing.T) { } } + runMergeTestAndExpectError := func(expectedError string, sdls ...string) func(t *testing.T) { + return func(t *testing.T) { + _, err := MergeSDLs(sdls...) + + assert.Equal(t, expectedError, err.Error()) + } + } + t.Run("should merge all sdls successfully", runMergeTest( federatedSchema, accountSchema, productSchema, reviewSchema, likeSchema, disLikeSchema, paymentSchema, onlinePaymentSchema, classicPaymentSchema, )) - t.Run("should merge product and review sdl and leave `extend type User` in the schema", runMergeTest( - productAndReviewFederatedSchema, + t.Run("When merging product and review, the unresolved orphan extension for User will return an error", runMergeTestAndExpectError( + unresolvedExtensionOrphansMergeErrorMessage("User"), productSchema, reviewSchema, )) - t.Run("should merge product and extends directives sdl and leave the type extension definition in the schema", runMergeTest( - productAndExtendsDirectivesFederatedSchema, + t.Run("When merging product and extendsDirectives, the unresolved orphan extension for User will return an error", runMergeTestAndExpectError( + unresolvedExtensionOrphansMergeErrorMessage("User"), productSchema, extendsDirectivesSchema, )) + + t.Run("Non-identical duplicate enums should return an error", runMergeTestAndExpectError( + nonIdenticalSharedTypeMergeErrorMessage("Satisfaction"), + productSchema, negativeTestingLikeSchema, + )) + + t.Run("Non-identical duplicate unions should return an error", runMergeTestAndExpectError( + nonIdenticalSharedTypeMergeErrorMessage("AlphaNumeric"), + accountSchema, negativeTestingReviewSchema, + )) + + t.Run("Entity duplicates should return an error", runMergeTestAndExpectError( + duplicateEntityMergeErrorMessage("User"), + accountSchema, negativeTestingAccountSchema, + )) + + t.Run("The first type encountered without a body should return an error", runMergeTestAndExpectError( + emptyTypeBodyErrorMessage("object", "Message"), + accountSchema, negativeTestingProductSchema, + )) } const ( @@ -89,88 +147,381 @@ const ( extend type Query { me: User } - + + union AlphaNumeric = Int | String | Float + + scalar DateTime + + scalar CustomScalar + type User @key(fields: "id") { id: ID! username: String! + created: DateTime! + reputation: CustomScalar! + } + + enum Satisfaction { + HAPPY, + NEUTRAL, + UNHAPPY, + } + ` + + negativeTestingAccountSchema = ` + extend type Query { + me: User + } + + union AlphaNumeric = Int | String | Float + + scalar DateTime + + scalar CustomScalar + + type User { + id: ID! + username: String! + created: DateTime! + reputation: CustomScalar! + } + + enum Satisfaction { + HAPPY, + NEUTRAL, + UNHAPPY, } ` + productSchema = ` + enum Satisfaction { + UNHAPPY, + HAPPY, + NEUTRAL, + } + + scalar CustomScalar + + extend type Query { + topProducts(first: Int = 5): [Product] + } + + enum Department { + COSMETICS, + ELECTRONICS, + GROCERIES, + } + + interface ProductInfo { + departments: [Department!]! + averageSatisfaction: Satisfaction! + } + + scalar BigInt + + type Product implements ProductInfo @key(fields: "upc") { + upc: String! + name: String! + price: Int! + worth: BigInt! + reputation: CustomScalar! + departments: [Department!]! + averageSatisfaction: Satisfaction! + } + + union AlphaNumeric = Int | String | Float + ` + + negativeTestingProductSchema = ` + enum Satisfaction { + UNHAPPY, + HAPPY, + NEUTRAL, + } + + scalar CustomScalar + extend type Query { topProducts(first: Int = 5): [Product] } + + enum Department { + COSMETICS, + ELECTRONICS, + GROCERIES, + } + + interface ProductInfo { + departments: [Department!]! + averageSatisfaction: Satisfaction! + } + + type Message { + } + + scalar BigInt - type Product @key(fields: "upc") { + type Product implements ProductInfo @key(fields: "upc") { upc: String! name: String! price: Int! + worth: BigInt! + reputation: CustomScalar! + departments: [Department!]! + averageSatisfaction: Satisfaction! + } + + extend type Message { + content: String! } + + union AlphaNumeric = Int | String | Float ` reviewSchema = ` + scalar DateTime + + input ReviewInput { + body: String! + author: User! @provides(fields: "username") + product: Product! + updated: DateTime! + inputType: AlphaNumeric! + } + type Review { + id: ID! + created: DateTime! body: String! author: User! @provides(fields: "username") product: Product! + updated: DateTime! + inputType: AlphaNumeric! } + type Query { + getReview(id: ID!): Review + } + + type Mutation { + createReview(input: ReviewInput): Review + updateReview(id: ID!, input: ReviewInput): Review + } + + enum Department { + GROCERIES, + COSMETICS, + ELECTRONICS, + } + extend type User @key(fields: "id") { id: ID! @external reviews: [Review] } - - extend type Product @key(fields: "upc") { + + scalar BigInt + + extend type Product implements ProductInfo @key(fields: "upc") { upc: String! @external name: String! @external reviews: [Review] @requires(fields: "name") + sales: BigInt! + } + + enum Satisfaction { + HAPPY, + NEUTRAL, + UNHAPPY, + } + + extend type Subscription { + review: Review! + } + + interface ProductInfo { + departments: [Department!]! + averageSatisfaction: Satisfaction! + } + ` + + negativeTestingReviewSchema = ` + scalar DateTime + + input ReviewInput { + body: String! + author: User! @provides(fields: "username") + product: Product! + updated: DateTime! + inputType: AlphaNumeric! + } + + type Review { + id: ID! + created: DateTime! + body: String! + author: User! @provides(fields: "username") + product: Product! + updated: DateTime! + inputType: AlphaNumeric! + } + + type Query { + getReview(id: ID!): Review + } + + type Mutation { + createReview(input: ReviewInput): Review + updateReview(id: ID!, input: ReviewInput): Review + } + + interface ProductInfo { + departments: [Department!]! + averageSatisfaction: Satisfaction! + } + + enum Department { + COSMETICS, + ELECTRONICS, + GROCERIES, + } + + extend type User @key(fields: "id") { + id: ID! @external + reviews: [Review] + } + + scalar BigInt + + union AlphaNumeric = BigInt | String + + enum Satisfaction { + HAPPY, + NEUTRAL, + UNHAPPY, } extend type Subscription { review: Review! } ` + likeSchema = ` + scalar DateTime + type Like @key(fields: "id") { id: ID! productId: ID! userId: ID! + date: DateTime! } + + enum Satisfaction { + HAPPY, + NEUTRAL, + UNHAPPY, + } + + type Query { + likesCount(productID: ID!): Int! + likes(productID: ID!): [Like]! + } + ` + negativeTestingLikeSchema = ` + scalar DateTime + + type Like @key(fields: "id") { + id: ID! + productId: ID! + userId: ID! + date: DateTime! + } + + enum Satisfaction { + HAPPY, + NEUTRAL, + UNHAPPY, + DEVASTATED, + } + type Query { likesCount(productID: ID!): Int! likes(productID: ID!): [Like]! } ` + disLikeSchema = ` type Like @key(fields: "id") @extends { id: ID! @external isDislike: Boolean! } + + enum Satisfaction { + HAPPY, + NEUTRAL, + UNHAPPY, + } + ` paymentSchema = ` + enum Satisfaction { + HAPPY, + NEUTRAL, + UNHAPPY, + } + interface PaymentType { name: String! } ` onlinePaymentSchema = ` + extend enum Satisfaction { + UNHAPPY + } + + scalar DateTime + + union AlphaNumeric = Int | String + + scalar BigInt + interface PaymentType @extends { email: String! + date: DateTime! + amount: BigInt! + } + + extend union AlphaNumeric = Float + + enum Satisfaction { + HAPPY + NEUTRAL } ` classicPaymentSchema = ` + union AlphaNumeric = Int | String | Float + + scalar CustomScalar + extend interface PaymentType { number: String! + reputation: CustomScalar! } ` extendsDirectivesSchema = ` + scalar DateTime + type Comment { body: String! author: User! + created: DateTime! } - + type User @extends @key(fields: "id") { id: ID! @external comments: [Comment] } - + + union AlphaNumeric = Int | String | Float + interface PaymentType @extends { name: String! } @@ -179,97 +530,134 @@ const ( type Query { me: User topProducts(first: Int = 5): [Product] + getReview(id: ID!): Review likesCount(productID: ID!): Int! likes(productID: ID!): [Like]! } + + type Mutation { + createReview(input: ReviewInput): Review + updateReview(id: ID!, input: ReviewInput): Review + } type Subscription { review: Review! } + + union AlphaNumeric = Int | String | Float + + scalar DateTime + + scalar CustomScalar type User { id: ID! username: String! + created: DateTime! + reputation: CustomScalar! reviews: [Review] } - type Product { + enum Satisfaction { + HAPPY, + NEUTRAL, + UNHAPPY, + } + + enum Department { + COSMETICS, + ELECTRONICS, + GROCERIES, + } + + interface ProductInfo { + departments: [Department!]! + averageSatisfaction: Satisfaction! + } + + scalar BigInt + + type Product implements ProductInfo { upc: String! name: String! price: Int! + worth: BigInt! + reputation: CustomScalar! + departments: [Department!]! + averageSatisfaction: Satisfaction! reviews: [Review] + sales: BigInt! + } + + input ReviewInput { + body: String! + author: User! @provides(fields: "username") + product: Product! + updated: DateTime! + inputType: AlphaNumeric! } type Review { + id: ID! + created: DateTime! body: String! author: User! product: Product! + updated: DateTime! + inputType: AlphaNumeric! } + type Like { id: ID! productId: ID! userId: ID! + date: DateTime! isDislike: Boolean! } interface PaymentType { name: String! email: String! + date: DateTime! + amount: BigInt! number: String! + reputation: CustomScalar! } ` +) - productAndReviewFederatedSchema = ` - type Query { - topProducts(first: Int = 5): [Product] - } +func nonIdenticalSharedTypeMergeErrorMessage(typeName string) string { + return fmt.Sprintf("merge ast: walk: external: the shared type named '%s' must be identical in any subgraphs to federate, locations: [], path: []", typeName) +} - type Subscription { - review: Review! - } - - type Product { - upc: String! - name: String! - price: Int! - reviews: [Review] - } +func duplicateEntityMergeErrorMessage(typeName string) string { + return fmt.Sprintf("merge ast: walk: external: the entity named '%s' is defined in the subgraph(s) more than once, locations: [], path: []", typeName) +} - type Review { - body: String! - author: User! - product: Product! - } - - extend type User @key(fields: "id") { - id: ID! @external - reviews: [Review] - } - ` +func sharedTypeExtensionErrorMessage(typeName string) string { + return fmt.Sprintf("the type named '%s' cannot be extended because it is a shared type", typeName) +} - productAndExtendsDirectivesFederatedSchema = ` - type Query { - topProducts(first: Int = 5): [Product] - } - - type Product { - upc: String! - name: String! - price: Int! - } +func emptyTypeBodyErrorMessage(definitionType, typeName string) string { + return fmt.Sprintf("validate schema: external: the %s named '%s' is invalid due to an empty body, locations: [], path: []", definitionType, typeName) +} - type Comment { - body: String! - author: User! - } +func unresolvedExtensionOrphansErrorMessage(typeName string) string { + return fmt.Sprintf("the extension orphan named '%s' was never resolved in the supergraph", typeName) +} - extend type User @key(fields: "id") { - id: ID! @external - comments: [Comment] - } +func unresolvedExtensionOrphansMergeErrorMessage(typeName string) string { + return fmt.Sprintf("merge ast: walk: external: the extension orphan named '%s' was never resolved in the supergraph, locations: [], path: []", typeName) +} - extend interface PaymentType { - name: String! - } - ` -) +func noKeyDirectiveErrorMessage(typeName string) string { + return fmt.Sprintf("an extension of the entity named '%s' does not have a key directive", typeName) +} + +func nonEntityExtensionErrorMessage(typeName string) string { + return fmt.Sprintf("the extension named '%s' has a key directive but there is no entity of the same name", typeName) +} + +func duplicateEntityErrorMessage(typeName string) string { + return fmt.Sprintf("the entity named '%s' is defined in the subgraph(s) more than once", typeName) +} diff --git a/pkg/federation/sdlmerge/shared_types.go b/pkg/federation/sdlmerge/shared_types.go new file mode 100644 index 000000000..83cde0390 --- /dev/null +++ b/pkg/federation/sdlmerge/shared_types.go @@ -0,0 +1,168 @@ +package sdlmerge + +import "github.com/wundergraph/graphql-go-tools/pkg/ast" + +type fieldlessSharedType interface { + areValuesIdentical(valueRefsToCompare []int) bool + valueRefs() []int + valueName(ref int) string +} + +func createValueSet(f fieldlessSharedType) map[string]bool { + valueSet := make(map[string]bool) + for _, valueRef := range f.valueRefs() { + valueSet[f.valueName(valueRef)] = true + } + return valueSet +} + +type fieldedSharedType struct { + document *ast.Document + fieldKind ast.NodeKind + fieldRefs []int + fieldSet map[string]int +} + +func newFieldedSharedType(document *ast.Document, fieldKind ast.NodeKind, fieldRefs []int) fieldedSharedType { + f := fieldedSharedType{ + document, + fieldKind, + fieldRefs, + nil, + } + f.createFieldSet() + return f +} + +func (f fieldedSharedType) areFieldsIdentical(fieldRefsToCompare []int) bool { + if len(f.fieldRefs) != len(fieldRefsToCompare) { + return false + } + for _, fieldRef := range fieldRefsToCompare { + actualFieldName := f.fieldName(fieldRef) + expectedTypeRef, exists := f.fieldSet[actualFieldName] + if !exists { + return false + } + actualTypeRef := f.fieldTypeRef(fieldRef) + if !f.document.TypesAreCompatibleDeep(expectedTypeRef, actualTypeRef) { + return false + } + } + return true +} + +func (f *fieldedSharedType) createFieldSet() { + fieldSet := make(map[string]int) + for _, fieldRef := range f.fieldRefs { + fieldSet[f.fieldName(fieldRef)] = f.fieldTypeRef(fieldRef) + } + f.fieldSet = fieldSet +} + +func (f fieldedSharedType) fieldName(ref int) string { + switch f.fieldKind { + case ast.NodeKindInputValueDefinition: + return f.document.InputValueDefinitionNameString(ref) + default: + return f.document.FieldDefinitionNameString(ref) + } +} + +func (f fieldedSharedType) fieldTypeRef(ref int) int { + switch f.fieldKind { + case ast.NodeKindInputValueDefinition: + return f.document.InputValueDefinitions[ref].Type + default: + return f.document.FieldDefinitions[ref].Type + } +} + +type enumSharedType struct { + *ast.EnumTypeDefinition + document *ast.Document + valueSet map[string]bool +} + +func newEnumSharedType(document *ast.Document, ref int) enumSharedType { + e := enumSharedType{ + &document.EnumTypeDefinitions[ref], + document, + nil, + } + e.valueSet = createValueSet(e) + return e +} + +func (e enumSharedType) areValuesIdentical(valueRefsToCompare []int) bool { + if len(e.valueRefs()) != len(valueRefsToCompare) { + return false + } + for _, valueRefToCompare := range valueRefsToCompare { + name := e.valueName(valueRefToCompare) + if !e.valueSet[name] { + return false + } + } + return true +} + +func (e enumSharedType) valueRefs() []int { + return e.EnumValuesDefinition.Refs +} + +func (e enumSharedType) valueName(ref int) string { + return e.document.EnumValueDefinitionNameString(ref) +} + +type unionSharedType struct { + *ast.UnionTypeDefinition + document *ast.Document + valueSet map[string]bool +} + +func newUnionSharedType(document *ast.Document, ref int) unionSharedType { + u := unionSharedType{ + &document.UnionTypeDefinitions[ref], + document, + nil, + } + u.valueSet = createValueSet(u) + return u +} + +func (u unionSharedType) areValuesIdentical(valueRefsToCompare []int) bool { + if len(u.valueRefs()) != len(valueRefsToCompare) { + return false + } + for _, refToCompare := range valueRefsToCompare { + name := u.valueName(refToCompare) + if !u.valueSet[name] { + return false + } + } + return true +} + +func (u unionSharedType) valueRefs() []int { + return u.UnionMemberTypes.Refs +} + +func (u unionSharedType) valueName(ref int) string { + return u.document.TypeNameString(ref) +} + +type scalarSharedType struct { +} + +func (_ scalarSharedType) areValuesIdentical(_ []int) bool { + return true +} + +func (_ scalarSharedType) valueRefs() []int { + return nil +} + +func (_ scalarSharedType) valueName(_ int) string { + return "" +} diff --git a/pkg/federation/sdlmerge/union_type_extending.go b/pkg/federation/sdlmerge/union_type_extending.go index 1ae8d5fe5..49624daaf 100644 --- a/pkg/federation/sdlmerge/union_type_extending.go +++ b/pkg/federation/sdlmerge/union_type_extending.go @@ -3,6 +3,7 @@ package sdlmerge import ( "github.com/wundergraph/graphql-go-tools/pkg/ast" "github.com/wundergraph/graphql-go-tools/pkg/astvisitor" + "github.com/wundergraph/graphql-go-tools/pkg/operationreport" ) func newExtendUnionTypeDefinition() *extendUnionTypeDefinitionVisitor { @@ -10,30 +11,40 @@ func newExtendUnionTypeDefinition() *extendUnionTypeDefinitionVisitor { } type extendUnionTypeDefinitionVisitor struct { - operation *ast.Document + *astvisitor.Walker + document *ast.Document } func (e *extendUnionTypeDefinitionVisitor) Register(walker *astvisitor.Walker) { + e.Walker = walker walker.RegisterEnterDocumentVisitor(e) walker.RegisterEnterUnionTypeExtensionVisitor(e) } -func (e *extendUnionTypeDefinitionVisitor) EnterDocument(operation, definition *ast.Document) { - e.operation = operation +func (e *extendUnionTypeDefinitionVisitor) EnterDocument(operation, _ *ast.Document) { + e.document = operation } func (e *extendUnionTypeDefinitionVisitor) EnterUnionTypeExtension(ref int) { - - nodes, exists := e.operation.Index.NodesByNameBytes(e.operation.UnionTypeExtensionNameBytes(ref)) + nodes, exists := e.document.Index.NodesByNameBytes(e.document.UnionTypeExtensionNameBytes(ref)) if !exists { return } + hasExtended := false for i := range nodes { if nodes[i].Kind != ast.NodeKindUnionTypeDefinition { continue } - e.operation.ExtendUnionTypeDefinitionByUnionTypeExtension(nodes[i].Ref, ref) - return + if hasExtended { + e.StopWithExternalErr(operationreport.ErrSharedTypesMustNotBeExtended(e.document.UnionTypeExtensionNameString(ref))) + return + } + e.document.ExtendUnionTypeDefinitionByUnionTypeExtension(nodes[i].Ref, ref) + hasExtended = true + } + + if !hasExtended { + e.StopWithExternalErr(operationreport.ErrExtensionOrphansMustResolveInSupergraph(e.document.UnionTypeExtensionNameBytes(ref))) } } diff --git a/pkg/federation/sdlmerge/union_type_extending_test.go b/pkg/federation/sdlmerge/union_type_extending_test.go index c831c2836..9db731293 100644 --- a/pkg/federation/sdlmerge/union_type_extending_test.go +++ b/pkg/federation/sdlmerge/union_type_extending_test.go @@ -1,35 +1,71 @@ package sdlmerge -import "testing" +import ( + "testing" +) func TestExtendUnionType(t *testing.T) { t.Run("extend union types", func(t *testing.T) { run(t, newExtendUnionTypeDefinition(), ` - type Dog { - name: String - } - union Animal = Dog - - type Cat { - name: String - } - type Bird { - name: String - } - extend union Animal = Bird | Cat - `, ` - type Dog { - name: String - } - union Animal = Dog | Bird | Cat - - type Cat { - name: String - } - type Bird { - name: String - } - extend union Animal = Bird | Cat - `) + type Dog { + name: String + } + + union Animal = Dog + + type Cat { + name: String + } + + type Bird { + name: String + } + + extend union Animal = Bird | Cat + `, ` + type Dog { + name: String + } + + union Animal = Dog | Bird | Cat + + type Cat { + name: String + } + + type Bird { + name: String + } + + extend union Animal = Bird | Cat + `) + }) + + t.Run("Extending a union that is a shared type returns an error", func(t *testing.T) { + runAndExpectError(t, newExtendUnionTypeDefinition(), ` + type Dog { + name: String + } + + union Animal = Dog + + type Cat { + name: String + } + + type Bird { + name: String + } + + union Animal = Dog + + extend union Animal = Bird | Cat + `, sharedTypeExtensionErrorMessage("Animal")) + }) + + t.Run("Unresolved union extension orphan returns an error", func(t *testing.T) { + runAndExpectError(t, newExtendUnionTypeDefinition(), ` + extend union Badges = Boulder + `, unresolvedExtensionOrphansErrorMessage("Badges")) }) } diff --git a/pkg/operationreport/externalerror.go b/pkg/operationreport/externalerror.go index be4632090..44a3a9aff 100644 --- a/pkg/operationreport/externalerror.go +++ b/pkg/operationreport/externalerror.go @@ -244,6 +244,11 @@ func ErrEnumValueNameMustBeUnique(enumName, enumValueName ast.ByteSlice) (err Ex return err } +func ErrUnionMembersMustBeUnique(unionName, memberName ast.ByteSlice) (err ExternalError) { + err.Message = fmt.Sprintf("union member '%s.%s' can only be defined once", unionName, memberName) + return err +} + func ErrTransitiveInterfaceNotImplemented(typeName, transitiveInterfaceName ast.ByteSlice) (err ExternalError) { err.Message = fmt.Sprintf("type %s does not implement transitive interface %s", typeName, transitiveInterfaceName) return err @@ -263,3 +268,38 @@ func ErrImplementingTypeDoesNotHaveFields(typeName ast.ByteSlice) (err ExternalE err.Message = fmt.Sprintf("type '%s' implements an interface but does not have any fields defined", typeName) return err } + +func ErrSharedTypesMustBeIdenticalToFederate(typeName string) (err ExternalError) { + err.Message = fmt.Sprintf("the shared type named '%s' must be identical in any subgraphs to federate", typeName) + return err +} + +func ErrEntitiesMustNotBeDuplicated(typeName string) (err ExternalError) { + err.Message = fmt.Sprintf("the entity named '%s' is defined in the subgraph(s) more than once", typeName) + return err +} + +func ErrSharedTypesMustNotBeExtended(typeName string) (err ExternalError) { + err.Message = fmt.Sprintf("the type named '%s' cannot be extended because it is a shared type", typeName) + return err +} + +func ErrExtensionOrphansMustResolveInSupergraph(extensionNameBytes []byte) (err ExternalError) { + err.Message = fmt.Sprintf("the extension orphan named '%s' was never resolved in the supergraph", extensionNameBytes) + return err +} + +func ErrTypeBodyMustNotBeEmpty(definitionType, typeName string) (err ExternalError) { + err.Message = fmt.Sprintf("the %s named '%s' is invalid due to an empty body", definitionType, typeName) + return err +} + +func ErrEntityExtensionMustHaveKeyDirective(typeName string) (err ExternalError) { + err.Message = fmt.Sprintf("an extension of the entity named '%s' does not have a key directive", typeName) + return err +} + +func ErrExtensionWithKeyDirectiveMustExtendEntity(typeName string) (err ExternalError) { + err.Message = fmt.Sprintf("the extension named '%s' has a key directive but there is no entity of the same name", typeName) + return err +}