From 4f38dff115b0596f348d766b2ee7a4c411f612af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=98yvind=20Raddum=20Berg?= Date: Tue, 7 May 2024 15:02:55 +0200 Subject: [PATCH] Add derivation of `Schema` for union types (closes #1926) --- .../caliban/schema/SchemaDerivation.scala | 2 + .../caliban/schema/TypeUnionDerivation.scala | 76 +++++++++++++++++++ .../caliban/schema/Scala3DerivesSpec.scala | 50 ++++++++++++ 3 files changed, 128 insertions(+) create mode 100644 core/src/main/scala-3/caliban/schema/TypeUnionDerivation.scala diff --git a/core/src/main/scala-3/caliban/schema/SchemaDerivation.scala b/core/src/main/scala-3/caliban/schema/SchemaDerivation.scala index 4262eb5f2..22151b070 100644 --- a/core/src/main/scala-3/caliban/schema/SchemaDerivation.scala +++ b/core/src/main/scala-3/caliban/schema/SchemaDerivation.scala @@ -140,6 +140,8 @@ trait SchemaDerivation[R] extends CommonSchemaDerivation { inline def genDebug[R, A]: Schema[R, A] = PrintDerived(derived[R, A]) + inline def unionType[T]: Schema[R, T] = ${ TypeUnionDerivation.typeUnionSchema[R, T] } + final lazy val auto = new AutoSchemaDerivation[Any] {} final class SemiAuto[A](impl: Schema[R, A]) extends Schema[R, A] { diff --git a/core/src/main/scala-3/caliban/schema/TypeUnionDerivation.scala b/core/src/main/scala-3/caliban/schema/TypeUnionDerivation.scala new file mode 100644 index 000000000..ceed327e2 --- /dev/null +++ b/core/src/main/scala-3/caliban/schema/TypeUnionDerivation.scala @@ -0,0 +1,76 @@ +package caliban.schema + +import caliban.introspection.adt.__Type + +import scala.quoted.* + +object TypeUnionDerivation { + inline def derived[R, T]: Schema[R, T] = ${ typeUnionSchema[R, T] } + + def typeUnionSchema[R: Type, T: Type](using quotes: Quotes): Expr[Schema[R, T]] = { + import quotes.reflect.* + + class TypeAndSchema[A](val typeRef: String, val schema: Expr[Schema[R, A]], val tpe: Type[A]) + + def rec[A](using tpe: Type[A]): List[TypeAndSchema[?]] = + TypeRepr.of(using tpe).dealias match { + case OrType(l, r) => + rec(using l.asType.asInstanceOf[Type[Any]]) ++ rec(using r.asType.asInstanceOf[Type[Any]]) + case otherRepr => + val otherString: String = otherRepr.show + val expr: TypeAndSchema[A] = + Expr.summon[Schema[R, A]] match { + case Some(foundSchema) => + TypeAndSchema[A](otherString, foundSchema, otherRepr.asType.asInstanceOf[Type[A]]) + case None => + quotes.reflect.report.errorAndAbort(s"Couldn't resolve Schema[Any, $otherString]") + } + + List(expr) + } + + val typeAndSchemas: List[TypeAndSchema[?]] = rec[T] + + val schemaByTypeNameList: Expr[List[(String, Schema[R, Any])]] = Expr.ofList( + typeAndSchemas.map { case (tas: TypeAndSchema[a]) => + given Type[a] = tas.tpe + '{ (${ Expr(tas.typeRef) }, ${ tas.schema }.asInstanceOf[Schema[R, Any]]) } + } + ) + val name = TypeRepr.of[T].show + + if (name.contains("|")) { + report.error( + s"You must explicitly add type parameter to derive Schema for a union type in order to capture the name of the type alias" + ) + } + + '{ + val schemaByName: Map[String, Schema[R, Any]] = ${ schemaByTypeNameList }.toMap + new Schema[R, T] { + + def resolve(value: T): Step[R] = { + var ret: Step[R] = null + ${ + Expr.block( + typeAndSchemas.map { case (tas: TypeAndSchema[a]) => + given Type[a] = tas.tpe + + '{ if value.isInstanceOf[a] then ret = schemaByName(${ Expr(tas.typeRef) }).resolve(value) } + }, + '{ require(ret != null, s"no schema for ${value}") } + ) + } + ret + } + + def toType(isInput: Boolean, isSubscription: Boolean): __Type = + Types.makeUnion( + Some(${ Expr(name) }), + None, + schemaByName.values.map(_.toType_(isInput, isSubscription)).toList + ) + } + } + } +} diff --git a/core/src/test/scala-3/caliban/schema/Scala3DerivesSpec.scala b/core/src/test/scala-3/caliban/schema/Scala3DerivesSpec.scala index e2bc255f8..5b7030dd6 100644 --- a/core/src/test/scala-3/caliban/schema/Scala3DerivesSpec.scala +++ b/core/src/test/scala-3/caliban/schema/Scala3DerivesSpec.scala @@ -273,6 +273,56 @@ object Scala3DerivesSpec extends ZIOSpecDefault { data1 == """{"enum2String":"ENUM1"}""", data2 == """{"enum2String":"ENUM2"}""" ) + }, + test("union type") { + final case class Foo(value: String) derives Schema.SemiAuto + final case class Bar(foo: Int) derives Schema.SemiAuto + final case class Baz(bar: Int) derives Schema.SemiAuto + type Payload = Foo | Bar | Baz + + given Schema[Any, Payload] = Schema.unionType[Payload] + + final case class QueryInput(isFoo: Boolean) derives ArgBuilder, Schema.SemiAuto + final case class Query(testQuery: QueryInput => zio.UIO[Payload]) derives Schema.SemiAuto + + val gql = graphQL(RootResolver(Query(i => ZIO.succeed(if (i.isFoo) Foo("foo") else Bar(1))))) + + val expectedSchema = + """ +schema { + query: Query +} + +union Payload = Foo | Bar | Baz + +type Bar { + foo: Int! +} + +type Baz { + bar: Int! +} + +type Foo { + value: String! +} + +type Query { + testQuery(isFoo: Boolean!): Payload! +} +""".stripMargin + val interpreter = gql.interpreterUnsafe + + for { + res1 <- interpreter.execute("{ testQuery(isFoo: true){ ... on Foo { value } } }") + res2 <- interpreter.execute("{ testQuery(isFoo: false){ ... on Bar { foo } } }") + data1 = res1.data.toString + data2 = res2.data.toString + } yield assertTrue( + data1 == """{"testQuery":{"value":"foo"}}""", + data2 == """{"testQuery":{"foo":1}}""", + gql.render == expectedSchema + ) } ) }