Skip to content
This repository has been archived by the owner on Oct 26, 2020. It is now read-only.

Validation for non-breakable chains of circular references in Input Objects #48

Merged
merged 3 commits into from Mar 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
@@ -1,6 +1,6 @@
package sangria.schema

import sangria.ast.{AstLocation, Document, ObjectTypeDefinition, ObjectTypeExtensionDefinition, UnionTypeDefinition, UnionTypeExtensionDefinition}
import sangria.ast.{AstLocation, Document, NamedType, NotNullType, ObjectTypeDefinition, ObjectTypeExtensionDefinition, UnionTypeDefinition, UnionTypeExtensionDefinition}

import language.higherKinds
import sangria.execution._
Expand All @@ -24,7 +24,8 @@ object SchemaValidationRule {
EnumValueReservedNameValidator,
ContainerMembersValidator,
ValidNamesValidator,
IntrospectionNamesValidator)
IntrospectionNamesValidator,
InputObjectTypeRecursionValidator)

val default: List[SchemaValidationRule] = List(
DefaultValuesValidationRule,
Expand Down Expand Up @@ -410,6 +411,29 @@ object EnumValueReservedNameValidator extends SchemaElementValidator {
else Vector.empty
}

object InputObjectTypeRecursionValidator extends SchemaElementValidator {
override def validateInputObjectType(schema: Schema[_, _], tpe: InputObjectType[_]): Vector[Violation] = {
containsRecursiveInputObject(tpe.namedType.name, List(), schema, tpe)
}

private def containsRecursiveInputObject(rootTypeName: String, path: List[String], schema: Schema[_, _], tpe: InputObjectType[_]): Vector[Violation] = {
val recursiveFields = tpe.fields.filter(childField => childField.fieldType.namedType.name == rootTypeName && !childField.fieldType.isOptional && !childField.fieldType.isList)
if (recursiveFields.nonEmpty) {
recursiveFields.flatMap(field => Vector(InputObjectTypeRecursion(tpe.name, field.name, path, None, Nil))).toVector
} else {
val childTypesToCheck = tpe.fields.filter(field => !field.fieldType.isOptional && !field.fieldType.isList && field.fieldType.isInstanceOf[InputObjectType[_]])
childTypesToCheck.foldLeft(Vector.empty[Violation]) { case (acc, field) =>
schema.getInputType(NotNullType(NamedType(field.fieldType.namedType.name))) match {
case Some(objectType: InputObjectType[_]) if objectType != tpe =>
val updatedPath = path :+ field.name
acc ++ containsRecursiveInputObject(rootTypeName, updatedPath, schema, objectType)
case _ => acc
}
}
}
}
}

trait SchemaElementValidator {
def validateUnionType(schema: Schema[_, _], tpe: UnionType[_]): Vector[Violation] = Vector.empty

Expand Down
Expand Up @@ -608,3 +608,7 @@ case class ExistingTypeViolation(typeName: String, sourceMapper: Option[SourceMa
case class InvalidTypeUsageViolation(expectedTypeKind: String, tpe: String, sourceMapper: Option[SourceMapper], locations: List[AstLocation]) extends AstNodeViolation {
lazy val simpleErrorMessage = s"Type '$tpe' is not an $expectedTypeKind type."
}

case class InputObjectTypeRecursion(name: String, fieldName: String, path: List[String], sourceMapper: Option[SourceMapper], locations: List[AstLocation]) extends AstNodeViolation {
lazy val simpleErrorMessage: String = s"Cannot reference InputObjectType '$name' within itself through a series of non-null fields: '$fieldName${if (path.isEmpty) "" else "."}${path.mkString(".")}'."
}
Expand Up @@ -882,6 +882,116 @@ class AstSchemaMaterializerSpec extends WordSpec with Matchers with FutureResult
error.getMessage should include ("Object type 'Query' can include field 'field1' only once.")
}

"accepts an Input Object with breakable circular reference" in {
val ast =
graphql"""
schema {
query: Query
}

type Query {
field(arg: SomeInputObject): String
}

input SomeInputObject {
self: SomeInputObject
arrayOfSelf: [SomeInputObject]
nonNullArrayOfSelf: [SomeInputObject]!
nonNullArrayOfNonNullSelf: [SomeInputObject!]!
intermediateSelf: AnotherInputObject
}

input AnotherInputObject {
parent: SomeInputObject
}
"""

noException should be thrownBy (Schema.buildFromAst(ast))
}

"rejects an Input Object with non-breakable circular reference" in {
val ast =
graphql"""
schema {
query: Query
}

type Query {
field(arg: SomeInputObject): String
}

input SomeInputObject {
nonNullSelf: SomeInputObject!
}
"""

val error = intercept [SchemaValidationException] (Schema.buildFromAst(ast))

error.getMessage should include ("Cannot reference InputObjectType 'SomeInputObject' within itself through a series of non-null fields: 'nonNullSelf'.")
}

"rejects Input Objects with non-breakable circular reference spread across them" in {
val ast =
graphql"""
schema {
query: Query
}

type Query {
field(arg: SomeInputObject): String
}

input SomeInputObject {
startLoop: AnotherInputObject!
}

input AnotherInputObject {
nextInLoop: YetAnotherInputObject!
}

input YetAnotherInputObject {
closeLoop: SomeInputObject!
}
"""

val error = intercept [SchemaValidationException] (Schema.buildFromAst(ast))

error.getMessage should include ("Cannot reference InputObjectType 'SomeInputObject' within itself through a series of non-null fields: 'startLoop.nextInLoop.closeLoop'.")
}

"rejects Input Objects with multiple non-breakable circular reference" in {
val ast =
graphql"""
schema {
query: Query
}

type Query {
field(arg: SomeInputObject): String
}

input SomeInputObject {
startLoop: AnotherInputObject!
}

input AnotherInputObject {
closeLoop: SomeInputObject!
startSecondLoop: YetAnotherInputObject!
}

input YetAnotherInputObject {
closeSecondLoop: AnotherInputObject!
nonNullSelf: YetAnotherInputObject!
}
"""

val error = intercept [SchemaValidationException] (Schema.buildFromAst(ast))

error.getMessage should include ("Cannot reference InputObjectType 'SomeInputObject' within itself through a series of non-null fields: 'startLoop.closeLoop'.")
error.getMessage should include ("Cannot reference InputObjectType 'AnotherInputObject' within itself through a series of non-null fields: 'closeLoop.startLoop'.")
error.getMessage should include ("Cannot reference InputObjectType 'YetAnotherInputObject' within itself through a series of non-null fields: 'nonNullSelf'.")
}

"don't allow to have extensions on non-existing types" in {
val ast =
graphql"""
Expand Down