diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt index 6a20ba0073..281161eef4 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt @@ -105,14 +105,14 @@ class ClientCodegenVisitor( // Add errors attached at the service level to the models .let { ModelTransformer.create().copyServiceErrorsToOperations(it, settings.getService(it)) } // Add `Box` to recursive shapes as necessary - .let(RecursiveShapeBoxer::transform) + .let(RecursiveShapeBoxer()::transform) // Normalize the `message` field on errors when enabled in settings (default: true) .letIf(settings.codegenConfig.addMessageToErrors, AddErrorMessage::transform) // NormalizeOperations by ensuring every operation has an input & output shape .let(OperationNormalizer::transform) // Drop unsupported event stream operations from the model .let { RemoveEventStreamOperations.transform(it, settings) } - // - Normalize event stream operations + // Normalize event stream operations .let(EventStreamNormalizer::transform) /** diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/ResiliencyConfigCustomizationTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/ResiliencyConfigCustomizationTest.kt index 2d2cf76916..74d893bc17 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/ResiliencyConfigCustomizationTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/ResiliencyConfigCustomizationTest.kt @@ -36,7 +36,7 @@ internal class ResiliencyConfigCustomizationTest { @Test fun `generates a valid config`() { - val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) + val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val project = TestWorkspace.testProject() val codegenContext = testCodegenContext(model, settings = project.rustSettings()) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/RustBoxTrait.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/RustBoxTrait.kt index 2d5a313401..41144d5945 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/RustBoxTrait.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/RustBoxTrait.kt @@ -12,7 +12,8 @@ import software.amazon.smithy.model.traits.Trait /** * Trait indicating that this shape should be represented with `Box` when converted into Rust * - * This is used to handle recursive shapes. See RecursiveShapeBoxer. + * This is used to handle recursive shapes. + * See [software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer]. * * This trait is synthetic, applied during code generation, and never used in actual models. */ diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxer.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxer.kt index 4b47801a8d..d53751829f 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxer.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxer.kt @@ -7,25 +7,50 @@ package software.amazon.smithy.rust.codegen.core.smithy.transformers import software.amazon.smithy.codegen.core.TopologicalIndex import software.amazon.smithy.model.Model -import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.CollectionShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape -import software.amazon.smithy.model.shapes.SetShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait -object RecursiveShapeBoxer { +class RecursiveShapeBoxer( /** - * Transform a model which may contain recursive shapes into a model annotated with [RustBoxTrait] + * A predicate that determines when a cycle in the shape graph contains "indirection". If a cycle contains + * indirection, no shape needs to be tagged. What constitutes indirection is up to the caller to decide. + */ + private val containsIndirectionPredicate: (Collection) -> Boolean = ::containsIndirection, + /** + * A closure that gets called on one member shape of a cycle that does not contain indirection for "fixing". For + * example, the [RustBoxTrait] trait can be used to tag the member shape. + */ + private val boxShapeFn: (MemberShape) -> MemberShape = ::addRustBoxTrait, +) { + /** + * Transform a model which may contain recursive shapes. * - * When recursive shapes do NOT go through a List, Map, or Set, they must be boxed in Rust. This function will - * iteratively find loops & add the `RustBox` trait in a deterministic way until it reaches a fixed point. + * For example, when recursive shapes do NOT go through a `CollectionShape` or a `MapShape` shape, they must be + * boxed in Rust. This function will iteratively find cycles and call [boxShapeFn] on a member shape in the + * cycle to act on it. This is done in a deterministic way until it reaches a fixed point. * - * This function MUST be deterministic (always choose the same shapes to `Box`). If it is not, that is a bug. Even so + * This function MUST be deterministic (always choose the same shapes to fix). If it is not, that is a bug. Even so * this function may cause backward compatibility issues in certain pathological cases where a changes to recursive * structures cause different members to be boxed. We may need to address these via customizations. + * + * For example, given the following model, + * + * ```smithy + * namespace com.example + * + * structure Recursive { + * recursiveStruct: Recursive + * anotherField: Boolean + * } + * ``` + * + * The `com.example#Recursive$recursiveStruct` member shape is part of a cycle, but the + * `com.example#Recursive$anotherField` member shape is not. */ fun transform(model: Model): Model { val next = transformInner(model) @@ -37,16 +62,17 @@ object RecursiveShapeBoxer { } /** - * If [model] contains a recursive loop that must be boxed, apply one instance of [RustBoxTrait] return the new model. - * If [model] contains no loops, return null. + * If [model] contains a recursive loop that must be boxed, return the transformed model resulting form a call to + * [boxShapeFn]. + * If [model] contains no loops, return `null`. */ private fun transformInner(model: Model): Model? { - // Execute 1-step of the boxing algorithm in the path to reaching a fixed point - // 1. Find all the shapes that are part of a cycle - // 2. Find all the loops that those shapes are part of - // 3. Filter out the loops that go through a layer of indirection - // 3. Pick _just one_ of the remaining loops to fix - // 4. Select the member shape in that loop with the earliest shape id + // Execute 1 step of the boxing algorithm in the path to reaching a fixed point: + // 1. Find all the shapes that are part of a cycle. + // 2. Find all the loops that those shapes are part of. + // 3. Filter out the loops that go through a layer of indirection. + // 3. Pick _just one_ of the remaining loops to fix. + // 4. Select the member shape in that loop with the earliest shape id. // 5. Box it. // (External to this function) Go back to 1. val index = TopologicalIndex.of(model) @@ -58,34 +84,38 @@ object RecursiveShapeBoxer { // Flatten the connections into shapes. loops.map { it.shapes } } - val loopToFix = loops.firstOrNull { !containsIndirection(it) } + val loopToFix = loops.firstOrNull { !containsIndirectionPredicate(it) } return loopToFix?.let { loop: List -> check(loop.isNotEmpty()) - // pick the shape to box in a deterministic way + // Pick the shape to box in a deterministic way. val shapeToBox = loop.filterIsInstance().minByOrNull { it.id }!! ModelTransformer.create().mapShapes(model) { shape -> if (shape == shapeToBox) { - shape.asMemberShape().get().toBuilder().addTrait(RustBoxTrait()).build() + boxShapeFn(shape.asMemberShape().get()) } else { shape } } } } +} - /** - * Check if a List contains a shape which will use a pointer when represented in Rust, avoiding the - * need to add more Boxes - */ - private fun containsIndirection(loop: List): Boolean { - return loop.find { - when (it) { - is ListShape, - is MapShape, - is SetShape, -> true - else -> it.hasTrait() - } - } != null +/** + * Check if a `List` contains a shape which will use a pointer when represented in Rust, avoiding the + * need to add more `Box`es. + * + * Why `CollectionShape`s and `MapShape`s? Note that `CollectionShape`s get rendered in Rust as `Vec`, and + * `MapShape`s as `HashMap`; they're the only Smithy shapes that "organically" introduce indirection + * (via a pointer to the heap) in the recursive path. For other recursive paths, we thus have to introduce the + * indirection artificially ourselves using `Box`. + * + */ +private fun containsIndirection(loop: Collection): Boolean = loop.find { + when (it) { + is CollectionShape, is MapShape -> true + else -> it.hasTrait() } -} +} != null + +private fun addRustBoxTrait(shape: MemberShape): MemberShape = shape.toBuilder().addTrait(RustBoxTrait()).build() diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt index 10c61e125b..efb37235a2 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt @@ -82,7 +82,7 @@ class InstantiatorTest { @required num: Integer } - """.asSmithyModel().let { RecursiveShapeBoxer.transform(it) } + """.asSmithyModel().let { RecursiveShapeBoxer().transform(it) } private val codegenContext = testCodegenContext(model) private val symbolProvider = codegenContext.symbolProvider diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt index 72a2dce315..e6be7eee97 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt @@ -313,7 +313,7 @@ class StructureGeneratorTest { @Test fun `it generates accessor methods`() { val testModel = - RecursiveShapeBoxer.transform( + RecursiveShapeBoxer().transform( """ namespace test diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGeneratorTest.kt index b2c5195478..1f092ba71c 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGeneratorTest.kt @@ -42,7 +42,7 @@ class AwsQueryParserGeneratorTest { @Test fun `it modifies operation parsing to include Response and Result tags`() { - val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) + val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider val parserGenerator = AwsQueryParserGenerator( diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGeneratorTest.kt index 2532a0ceec..04dab966b5 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGeneratorTest.kt @@ -42,7 +42,7 @@ class Ec2QueryParserGeneratorTest { @Test fun `it modifies operation parsing to include Response and Result tags`() { - val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) + val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider val parserGenerator = Ec2QueryParserGenerator( diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt index ceca4a75b9..2530947261 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt @@ -114,7 +114,7 @@ class JsonParserGeneratorTest { @Test fun `generates valid deserializers`() { - val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) + val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider fun builderSymbol(shape: StructureShape): Symbol = diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt index 253a5c5916..ade2ab1c4b 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt @@ -94,7 +94,7 @@ internal class XmlBindingTraitParserGeneratorTest { @Test fun `generates valid parsers`() { - val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) + val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider val parserGenerator = XmlBindingTraitParserGenerator( diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/AwsQuerySerializerGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/AwsQuerySerializerGeneratorTest.kt index 963383b67c..a7f2519f05 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/AwsQuerySerializerGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/AwsQuerySerializerGeneratorTest.kt @@ -92,7 +92,7 @@ class AwsQuerySerializerGeneratorTest { true -> CodegenTarget.CLIENT false -> CodegenTarget.SERVER } - val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) + val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model, codegenTarget = codegenTarget) val symbolProvider = codegenContext.symbolProvider val parserGenerator = AwsQuerySerializerGenerator(testCodegenContext(model, codegenTarget = codegenTarget)) diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/Ec2QuerySerializerGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/Ec2QuerySerializerGeneratorTest.kt index 9d784b9a2b..2663fd271b 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/Ec2QuerySerializerGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/Ec2QuerySerializerGeneratorTest.kt @@ -85,7 +85,7 @@ class Ec2QuerySerializerGeneratorTest { @Test fun `generates valid serializers`() { - val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) + val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider val parserGenerator = Ec2QuerySerializerGenerator(codegenContext) diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGeneratorTest.kt index c56385057b..5e94e73432 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGeneratorTest.kt @@ -100,7 +100,7 @@ class JsonSerializerGeneratorTest { @Test fun `generates valid serializers`() { - val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) + val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider val parserSerializer = JsonSerializerGenerator( diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGeneratorTest.kt index e753b9e166..1bff7855e9 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGeneratorTest.kt @@ -105,7 +105,7 @@ internal class XmlBindingTraitSerializerGeneratorTest { @Test fun `generates valid serializers`() { - val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) + val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider val parserGenerator = XmlBindingTraitSerializerGenerator( diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxerTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxerTest.kt index 061814a73a..293e221713 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxerTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxerTest.kt @@ -31,7 +31,7 @@ internal class RecursiveShapeBoxerTest { hello: Hello } """.asSmithyModel() - RecursiveShapeBoxer.transform(model) shouldBe model + RecursiveShapeBoxer().transform(model) shouldBe model } @Test @@ -43,7 +43,7 @@ internal class RecursiveShapeBoxerTest { anotherField: Boolean } """.asSmithyModel() - val transformed = RecursiveShapeBoxer.transform(model) + val transformed = RecursiveShapeBoxer().transform(model) val member: MemberShape = transformed.lookup("com.example#Recursive\$RecursiveStruct") member.expectTrait() } @@ -70,7 +70,7 @@ internal class RecursiveShapeBoxerTest { third: SecondTree } """.asSmithyModel() - val transformed = RecursiveShapeBoxer.transform(model) + val transformed = RecursiveShapeBoxer().transform(model) val boxed = transformed.shapes().filter { it.hasTrait() }.toList() boxed.map { it.id.toString().removePrefix("com.example#") }.toSet() shouldBe setOf( "Atom\$add", diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapesIntegrationTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapesIntegrationTest.kt index 56993015b9..8b33c57694 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapesIntegrationTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapesIntegrationTest.kt @@ -66,7 +66,7 @@ class RecursiveShapesIntegrationTest { } output.message shouldContain "has infinite size" - val fixedProject = check(RecursiveShapeBoxer.transform(model)) + val fixedProject = check(RecursiveShapeBoxer().transform(model)) fixedProject.compileAndTest() } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index 7243b1fa6c..79572b9c38 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -78,6 +78,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.Ser import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerProtocolLoader import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput import software.amazon.smithy.rust.codegen.server.smithy.transformers.AttachValidationExceptionToConstrainedOperationInputsInAllowList +import software.amazon.smithy.rust.codegen.server.smithy.transformers.RecursiveConstraintViolationBoxer import software.amazon.smithy.rust.codegen.server.smithy.transformers.RemoveEbsModelValidationException import software.amazon.smithy.rust.codegen.server.smithy.transformers.ShapesReachableFromOperationInputTagger import java.util.logging.Logger @@ -162,7 +163,9 @@ open class ServerCodegenVisitor( // Add errors attached at the service level to the models .let { ModelTransformer.create().copyServiceErrorsToOperations(it, settings.getService(it)) } // Add `Box` to recursive shapes as necessary - .let(RecursiveShapeBoxer::transform) + .let(RecursiveShapeBoxer()::transform) + // Add `Box` to recursive constraint violations as necessary + .let(RecursiveConstraintViolationBoxer::transform) // Normalize operations by adding synthetic input and output shapes to every operation .let(OperationNormalizer::transform) // Remove the EBS model's own `ValidationException`, which collides with `smithy.framework#ValidationException` diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/CollectionConstraintViolationGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/CollectionConstraintViolationGenerator.kt index 7d503e72de..ef80796d5a 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/CollectionConstraintViolationGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/CollectionConstraintViolationGenerator.kt @@ -9,12 +9,15 @@ import software.amazon.smithy.model.shapes.CollectionShape import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.rustlang.join -import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.makeRustBoxed import software.amazon.smithy.rust.codegen.core.smithy.module +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.letIf import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput class CollectionConstraintViolationGenerator( @@ -38,16 +41,22 @@ class CollectionConstraintViolationGenerator( private val constraintsInfo: List = collectionConstraintsInfo.map { it.toTraitInfo() } fun render() { - val memberShape = model.expectShape(shape.member.target) + val targetShape = model.expectShape(shape.member.target) val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape) val constraintViolationName = constraintViolationSymbol.name - val isMemberConstrained = memberShape.canReachConstrainedShape(model, symbolProvider) + val isMemberConstrained = targetShape.canReachConstrainedShape(model, symbolProvider) val constraintViolationVisibility = Visibility.publicIf(publicConstrainedTypes, Visibility.PUBCRATE) modelsModuleWriter.withInlineModule(constraintViolationSymbol.module()) { val constraintViolationVariants = constraintsInfo.map { it.constraintViolationVariant }.toMutableList() if (isMemberConstrained) { constraintViolationVariants += { + val memberConstraintViolationSymbol = + constraintViolationSymbolProvider.toSymbol(targetShape).letIf( + shape.member.hasTrait(), + ) { + it.makeRustBoxed() + } rustTemplate( """ /// Constraint violation error when an element doesn't satisfy its own constraints. @@ -56,7 +65,7 @@ class CollectionConstraintViolationGenerator( ##[doc(hidden)] Member(usize, #{MemberConstraintViolationSymbol}) """, - "MemberConstraintViolationSymbol" to constraintViolationSymbolProvider.toSymbol(memberShape), + "MemberConstraintViolationSymbol" to memberConstraintViolationSymbol, ) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt index 4833fa3058..8bbf95ed88 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt @@ -11,10 +11,13 @@ import software.amazon.smithy.model.traits.LengthTrait import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.makeRustBoxed import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.letIf import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput class MapConstraintViolationGenerator( @@ -47,7 +50,14 @@ class MapConstraintViolationGenerator( constraintViolationCodegenScopeMutableList.add("KeyConstraintViolationSymbol" to constraintViolationSymbolProvider.toSymbol(keyShape)) } if (isValueConstrained(valueShape, model, symbolProvider)) { - constraintViolationCodegenScopeMutableList.add("ValueConstraintViolationSymbol" to constraintViolationSymbolProvider.toSymbol(valueShape)) + constraintViolationCodegenScopeMutableList.add( + "ValueConstraintViolationSymbol" to + constraintViolationSymbolProvider.toSymbol(valueShape).letIf( + shape.value.hasTrait(), + ) { + it.makeRustBoxed() + }, + ) constraintViolationCodegenScopeMutableList.add("KeySymbol" to constrainedShapeSymbolProvider.toSymbol(keyShape)) } val constraintViolationCodegenScope = constraintViolationCodegenScopeMutableList.toTypedArray() diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolations.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolations.kt index b2c89f8d73..bd968bb1b7 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolations.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolations.kt @@ -20,13 +20,13 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.makeRustBoxed -import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.letIf import software.amazon.smithy.rust.codegen.core.util.toPascalCase import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait /** * Renders constraint violation types that arise when building a structure shape builder. @@ -138,8 +138,8 @@ class ServerBuilderConstraintViolations( val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(targetShape) - // If the corresponding structure's member is boxed, box this constraint violation symbol too. - .letIf(constraintViolation.forMember.hasTrait()) { + // Box this constraint violation symbol if necessary. + .letIf(constraintViolation.forMember.hasTrait()) { it.makeRustBoxed() } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt index 178aa725b7..3d65e4c89e 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt @@ -49,6 +49,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape import software.amazon.smithy.rust.codegen.server.smithy.hasConstraintTraitOrTargetHasConstraintTrait import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput import software.amazon.smithy.rust.codegen.server.smithy.wouldHaveConstrainedWrapperTupleTypeWerePublicConstrainedTypesEnabled @@ -541,6 +542,8 @@ class ServerBuilderGenerator( val hasBox = builderMemberSymbol(member) .mapRustType { it.stripOuter() } .isRustBoxed() + val errHasBox = member.hasTrait() + if (hasBox) { writer.rustTemplate( """ @@ -548,11 +551,6 @@ class ServerBuilderGenerator( #{MaybeConstrained}::Constrained(x) => Ok(Box::new(x)), #{MaybeConstrained}::Unconstrained(x) => Ok(Box::new(x.try_into()?)), }) - .map(|res| - res${ if (constrainedTypeHoldsFinalType(member)) "" else ".map(|v| v.into())" } - .map_err(|err| ConstraintViolation::${constraintViolation.name()}(Box::new(err))) - ) - .transpose()? """, *codegenScope, ) @@ -563,16 +561,23 @@ class ServerBuilderGenerator( #{MaybeConstrained}::Constrained(x) => Ok(x), #{MaybeConstrained}::Unconstrained(x) => x.try_into(), }) - .map(|res| - res${if (constrainedTypeHoldsFinalType(member)) "" else ".map(|v| v.into())"} - .map_err(ConstraintViolation::${constraintViolation.name()}) - ) - .transpose()? """, *codegenScope, ) } + writer.rustTemplate( + """ + .map(|res| + res${if (constrainedTypeHoldsFinalType(member)) "" else ".map(|v| v.into())"} + ${if (errHasBox) ".map_err(Box::new)" else "" } + .map_err(ConstraintViolation::${constraintViolation.name()}) + ) + .transpose()? + """, + *codegenScope, + ) + // Constrained types are not public and this is a member shape that would have generated a // public constrained type, were the setting to be enabled. // We've just checked the constraints hold by going through the non-public diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGenerator.kt index b5a9d45895..3d233d28ef 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGenerator.kt @@ -19,11 +19,13 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained import software.amazon.smithy.rust.codegen.core.smithy.module +import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.UnconstrainedShapeSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained +import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait /** * Generates a Rust type for a constrained collection shape that is able to hold values for the corresponding @@ -107,7 +109,11 @@ class UnconstrainedCollectionGenerator( constrainedShapeSymbolProvider.toSymbol(shape.member) } val innerConstraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(innerShape) - + val boxErr = if (shape.member.hasTrait()) { + ".map_err(|(idx, inner_violation)| (idx, Box::new(inner_violation)))" + } else { + "" + } val constrainValueWritable = writable { conditionalBlock("inner.map(|inner| ", ").transpose()", constrainedMemberSymbol.isOptional()) { rust("inner.try_into().map_err(|inner_violation| (idx, inner_violation))") @@ -124,7 +130,9 @@ class UnconstrainedCollectionGenerator( #{ConstrainValueWritable:W} }) .collect(); - let inner = res.map_err(|(idx, inner_violation)| Self::Error::Member(idx, inner_violation))?; + let inner = res + $boxErr + .map_err(|(idx, inner_violation)| Self::Error::Member(idx, inner_violation))?; """, "Vec" to RuntimeType.Vec, "ConstrainedMemberSymbol" to constrainedMemberSymbol, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGenerator.kt index e18d372c75..b6445da017 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGenerator.kt @@ -20,10 +20,12 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained import software.amazon.smithy.rust.codegen.core.smithy.module +import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained +import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait /** * Generates a Rust type for a constrained map shape that is able to hold values for the corresponding @@ -125,6 +127,11 @@ class UnconstrainedMapGenerator( ) } val constrainValueWritable = writable { + val boxErr = if (shape.value.hasTrait()) { + ".map_err(Box::new)" + } else { + "" + } if (constrainedMemberValueSymbol.isOptional()) { // The map is `@sparse`. rustBlock("match v") { @@ -133,7 +140,7 @@ class UnconstrainedMapGenerator( // DRYing this up with the else branch below would make this less understandable. rustTemplate( """ - match #{ConstrainedValueSymbol}::try_from(v) { + match #{ConstrainedValueSymbol}::try_from(v)$boxErr { Ok(v) => Ok((k, Some(v))), Err(inner_constraint_violation) => Err(Self::Error::Value(k, inner_constraint_violation)), } @@ -145,7 +152,7 @@ class UnconstrainedMapGenerator( } else { rustTemplate( """ - match #{ConstrainedValueSymbol}::try_from(v) { + match #{ConstrainedValueSymbol}::try_from(v)$boxErr { Ok(v) => #{Epilogue:W}, Err(inner_constraint_violation) => Err(Self::Error::Value(k, inner_constraint_violation)), } @@ -214,9 +221,10 @@ class UnconstrainedMapGenerator( // ``` rustTemplate( """ - let hm: std::collections::HashMap<#{KeySymbol}, #{ValueSymbol}> = + let hm: #{HashMap}<#{KeySymbol}, #{ValueSymbol}> = hm.into_iter().map(|(k, v)| (k, v.into())).collect(); """, + "HashMap" to RuntimeType.HashMap, "KeySymbol" to symbolProvider.toSymbol(keyShape), "ValueSymbol" to symbolProvider.toSymbol(valueShape), ) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGenerator.kt index 72655675a0..dac275e051 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGenerator.kt @@ -33,6 +33,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput /** @@ -171,8 +172,8 @@ class UnconstrainedUnionGenerator( val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(targetShape) - // If the corresponding union's member is boxed, box this constraint violation symbol too. - .letIf(constraintViolation.forMember.hasTrait()) { + // Box this constraint violation symbol if necessary. + .letIf(constraintViolation.forMember.hasTrait()) { it.makeRustBoxed() } @@ -201,10 +202,15 @@ class UnconstrainedUnionGenerator( (!publicConstrainedTypes || !targetShape.isDirectlyConstrained(symbolProvider)) val (unconstrainedVar, boxIt) = if (member.hasTrait()) { - "(*unconstrained)" to ".map(Box::new).map_err(Box::new)" + "(*unconstrained)" to ".map(Box::new)" } else { "unconstrained" to "" } + val boxErr = if (member.hasTrait()) { + ".map_err(Box::new)" + } else { + "" + } if (resolveToNonPublicConstrainedType) { val constrainedSymbol = @@ -219,6 +225,7 @@ class UnconstrainedUnionGenerator( let constrained: #{ConstrainedSymbol} = $unconstrainedVar .try_into() $boxIt + $boxErr .map_err(Self::Error::${ConstraintViolation(member).name()})?; constrained.into() } @@ -231,6 +238,7 @@ class UnconstrainedUnionGenerator( $unconstrainedVar .try_into() $boxIt + $boxErr .map_err(Self::Error::${ConstraintViolation(member).name()})? """, ) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/ConstraintViolationRustBoxTrait.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/ConstraintViolationRustBoxTrait.kt new file mode 100644 index 0000000000..9aee2b884e --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/ConstraintViolationRustBoxTrait.kt @@ -0,0 +1,25 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.traits + +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.traits.Trait + +/** + * This shape is analogous to [software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait], but for the + * constraint violation graph. The sets of shapes we tag are different, and they are interpreted by the code generator + * differently, so we need a separate tag. + * + * This is used to handle recursive constraint violations. + * See [software.amazon.smithy.rust.codegen.server.smithy.transformers.RecursiveConstraintViolationBoxer]. + */ +class ConstraintViolationRustBoxTrait : Trait { + val ID = ShapeId.from("software.amazon.smithy.rust.codegen.smithy.rust.synthetic#constraintViolationBox") + override fun toNode(): Node = Node.objectNode() + + override fun toShapeId(): ShapeId = ID +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RecursiveConstraintViolationBoxer.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RecursiveConstraintViolationBoxer.kt new file mode 100644 index 0000000000..d2e41ead36 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RecursiveConstraintViolationBoxer.kt @@ -0,0 +1,78 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.transformers + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait + +object RecursiveConstraintViolationBoxer { + /** + * Transform a model which may contain recursive shapes into a model annotated with [ConstraintViolationRustBoxTrait]. + * + * See [RecursiveShapeBoxer] for how the tagging algorithm works. + * + * The constraint violation graph needs to box types in recursive paths more often. Since we don't collect + * constraint violations (yet, see [0]), the constraint violation graph never holds `Vec`s or `HashMap`s, + * only simple types. Indeed, the following simple recursive model: + * + * ```smithy + * union Recursive { + * list: List + * } + * + * @length(min: 69) + * list List { + * member: Recursive + * } + * ``` + * + * has a cycle that goes through a list shape, so no shapes in it need boxing in the regular shape graph. However, + * the constraint violation graph is infinitely recursive if we don't introduce boxing somewhere: + * + * ```rust + * pub mod model { + * pub mod list { + * pub enum ConstraintViolation { + * Length(usize), + * Member( + * usize, + * crate::model::recursive::ConstraintViolation, + * ), + * } + * } + * + * pub mod recursive { + * pub enum ConstraintViolation { + * List(crate::model::list::ConstraintViolation), + * } + * } + * } + * ``` + * + * So what we do to fix this is to configure the `RecursiveShapeBoxer` model transform so that the "cycles through + * lists and maps introduce indirection" assumption can be lifted. This allows this model transform to tag member + * shapes along recursive paths with a new trait, `ConstraintViolationRustBoxTrait`, that the constraint violation + * type generation then utilizes to ensure that no infinitely recursive constraint violation types get generated. + * Places where constraint violations are handled (like where unconstrained types are converted to constrained + * types) must account for the scenario where they now are or need to be boxed. + * + * [0] https://github.com/awslabs/smithy-rs/pull/2040 + */ + fun transform(model: Model): Model = RecursiveShapeBoxer( + containsIndirectionPredicate = ::constraintViolationLoopContainsIndirection, + boxShapeFn = ::addConstraintViolationRustBoxTrait, + ).transform(model) + + private fun constraintViolationLoopContainsIndirection(loop: Collection): Boolean = + loop.find { it.hasTrait() } != null + + private fun addConstraintViolationRustBoxTrait(memberShape: MemberShape): MemberShape = + memberShape.toBuilder().addTrait(ConstraintViolationRustBoxTrait()).build() +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RecursiveConstraintViolationsTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RecursiveConstraintViolationsTest.kt new file mode 100644 index 0000000000..8285d20c8b --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RecursiveConstraintViolationsTest.kt @@ -0,0 +1,185 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.ArgumentsProvider +import org.junit.jupiter.params.provider.ArgumentsSource +import software.amazon.smithy.model.Model +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest +import java.util.stream.Stream + +internal class RecursiveConstraintViolationsTest { + + data class TestCase( + /** The test name is only used in the generated report, to easily identify a failing test. **/ + val testName: String, + /** The model to generate **/ + val model: Model, + /** The shape ID of the member shape that should have the marker trait attached. **/ + val shapeIdWithConstraintViolationRustBoxTrait: String, + ) + + class RecursiveConstraintViolationsTestProvider : ArgumentsProvider { + private val baseModel = + """ + namespace com.amazonaws.recursiveconstraintviolations + + use aws.protocols#restJson1 + use smithy.framework#ValidationException + + @restJson1 + service RecursiveConstraintViolations { + operations: [ + Operation + ] + } + + @http(uri: "/operation", method: "POST") + operation Operation { + input: Recursive + output: Recursive + errors: [ValidationException] + } + """ + + private fun recursiveListModel(sparse: Boolean, listPrefix: String = ""): Pair = + """ + $baseModel + + structure Recursive { + list: ${listPrefix}List + } + + ${ if (sparse) { "@sparse" } else { "" } } + @length(min: 69) + list ${listPrefix}List { + member: Recursive + } + """.asSmithyModel() to if ("${listPrefix}List" < "Recursive") { + "com.amazonaws.recursiveconstraintviolations#${listPrefix}List\$member" + } else { + "com.amazonaws.recursiveconstraintviolations#Recursive\$list" + } + + private fun recursiveMapModel(sparse: Boolean, mapPrefix: String = ""): Pair = + """ + $baseModel + + structure Recursive { + map: ${mapPrefix}Map + } + + ${ if (sparse) { "@sparse" } else { "" } } + @length(min: 69) + map ${mapPrefix}Map { + key: String, + value: Recursive + } + """.asSmithyModel() to if ("${mapPrefix}Map" < "Recursive") { + "com.amazonaws.recursiveconstraintviolations#${mapPrefix}Map\$value" + } else { + "com.amazonaws.recursiveconstraintviolations#Recursive\$map" + } + + private fun recursiveUnionModel(unionPrefix: String = ""): Pair = + """ + $baseModel + + structure Recursive { + attributeValue: ${unionPrefix}AttributeValue + } + + // Named `${unionPrefix}AttributeValue` in honor of DynamoDB's famous `AttributeValue`. + // https://docs.rs/aws-sdk-dynamodb/latest/aws_sdk_dynamodb/model/enum.AttributeValue.html + union ${unionPrefix}AttributeValue { + set: SetAttribute + } + + @uniqueItems + list SetAttribute { + member: ${unionPrefix}AttributeValue + } + """.asSmithyModel() to + // The first loop the algorithm picks out to fix turns out to be the `list <-> union` loop: + // + // ``` + // [ + // ${unionPrefix}AttributeValue, + // ${unionPrefix}AttributeValue$set, + // SetAttribute, + // SetAttribute$member + // ] + // ``` + // + // Upon which, after fixing it, the other loop (`structure <-> list <-> union`) already contains + // indirection, so we disregard it. + // + // This is hence a good test in that it tests that `RecursiveConstraintViolationBoxer` does not + // superfluously add more indirection than strictly necessary. + // However, it is a bad test in that if the Smithy library ever returns the recursive paths in a + // different order, the (`structure <-> list <-> union`) loop might be fixed first, and this test might + // start to fail! So watch out for that. Nonetheless, `RecursiveShapeBoxer` calls out: + // + // This function MUST be deterministic (always choose the same shapes to `Box`). If it is not, that is a bug. + // + // So I think it's fair to write this test under the above assumption. + if ("${unionPrefix}AttributeValue" < "SetAttribute") { + "com.amazonaws.recursiveconstraintviolations#${unionPrefix}AttributeValue\$set" + } else { + "com.amazonaws.recursiveconstraintviolations#SetAttribute\$member" + } + + override fun provideArguments(context: ExtensionContext?): Stream { + val listModels = listOf(false, true).flatMap { isSparse -> + listOf("", "ZZZ").map { listPrefix -> + val (model, shapeIdWithConstraintViolationRustBoxTrait) = recursiveListModel(isSparse, listPrefix) + var testName = "${ if (isSparse) "sparse" else "non-sparse" } recursive list" + if (listPrefix.isNotEmpty()) { + testName += " with shape name prefix $listPrefix" + } + TestCase(testName, model, shapeIdWithConstraintViolationRustBoxTrait) + } + } + val mapModels = listOf(false, true).flatMap { isSparse -> + listOf("", "ZZZ").map { mapPrefix -> + val (model, shapeIdWithConstraintViolationRustBoxTrait) = recursiveMapModel(isSparse, mapPrefix) + var testName = "${ if (isSparse) "sparse" else "non-sparse" } recursive map" + if (mapPrefix.isNotEmpty()) { + testName += " with shape name prefix $mapPrefix" + } + TestCase(testName, model, shapeIdWithConstraintViolationRustBoxTrait) + } + } + val unionModels = listOf("", "ZZZ").map { unionPrefix -> + val (model, shapeIdWithConstraintViolationRustBoxTrait) = recursiveUnionModel(unionPrefix) + var testName = "recursive union" + if (unionPrefix.isNotEmpty()) { + testName += " with shape name prefix $unionPrefix" + } + TestCase(testName, model, shapeIdWithConstraintViolationRustBoxTrait) + } + return listOf(listModels, mapModels, unionModels) + .flatten() + .map { Arguments.of(it) }.stream() + } + } + + /** + * Ensures the models generate code that compiles. + * + * Make sure the tests in [software.amazon.smithy.rust.codegen.server.smithy.transformers.RecursiveConstraintViolationBoxerTest] + * are all passing before debugging any of these tests, since the former tests test preconditions for these. + */ + @ParameterizedTest + @ArgumentsSource(RecursiveConstraintViolationsTestProvider::class) + fun `recursive constraint violation code generation test`(testCase: TestCase) { + serverIntegrationTest(testCase.model) + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt index a65192d609..c7f2b2e5ff 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt @@ -123,7 +123,7 @@ class ServerInstantiatorTest { }, ]) string NamedEnum - """.asSmithyModel().let { RecursiveShapeBoxer.transform(it) } + """.asSmithyModel().let { RecursiveShapeBoxer().transform(it) } private val codegenContext = serverTestCodegenContext(model) private val symbolProvider = codegenContext.symbolProvider diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RecursiveConstraintViolationBoxerTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RecursiveConstraintViolationBoxerTest.kt new file mode 100644 index 0000000000..622d380603 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RecursiveConstraintViolationBoxerTest.kt @@ -0,0 +1,31 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.transformers + +import io.kotest.matchers.shouldBe +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ArgumentsSource +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.RecursiveConstraintViolationsTest +import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait +import kotlin.streams.toList + +internal class RecursiveConstraintViolationBoxerTest { + @ParameterizedTest + @ArgumentsSource(RecursiveConstraintViolationsTest.RecursiveConstraintViolationsTestProvider::class) + fun `recursive constraint violation boxer test`(testCase: RecursiveConstraintViolationsTest.TestCase) { + val transformed = RecursiveConstraintViolationBoxer.transform(testCase.model) + + val shapesWithConstraintViolationRustBoxTrait = transformed.shapes().filter { + it.hasTrait() + }.toList() + + // Only the provided member shape should have the trait attached. + shapesWithConstraintViolationRustBoxTrait shouldBe + listOf(transformed.lookup(testCase.shapeIdWithConstraintViolationRustBoxTrait)) + } +}