Skip to content

Commit

Permalink
Constraint member types are refactored as standalone shapes. (smithy-…
Browse files Browse the repository at this point in the history
…lang#2256)

* Constraint member types are refactored as standalone shapes.

* ModelModule to ServerRustModule.model

* Constraints are written to the correct module

* Code generates for non-public constrained types.

* Removed a comment

* Using ConcurrentHashmap just to be on the safe side

* Clippy warnings removed on constraints, k.into() if gated

* Wordings for some of the checks changed

* Test need to call rustCrate.renderInlineMemoryModules

* ktlintFormat related changes

* RustCrate need to be passed for server builder

* Param renamed in getParentAndInlineModuleForConstrainedMember

* pubCrate to publicConstrainedType rename

* PythonServer symbol builder needed to pass publicConstrainedTypes

* @required still remains on the member shape after transformation

* ConcurrentLinkedQueue used for root RustWriters

* runTestCase does not run the tests but just sets them up, hence has been renamed

* CHANGELOG added

---------

Co-authored-by: Fahad Zubair <fahadzub@amazon.com>
  • Loading branch information
drganjoo and Fahad Zubair committed Feb 17, 2023
1 parent 5a5a7c4 commit 3d00767
Show file tree
Hide file tree
Showing 57 changed files with 1,861 additions and 308 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.next.toml
Expand Up @@ -229,3 +229,9 @@ produced Rust code that did not compile"""
references = ["smithy-rs#2352", "smithy-rs#2343"]
meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "server"}
author = "82marbag"

[[smithy-rs]]
message = "Support for constraint traits on member shapes (constraint trait precedence) has been added."
references = ["smithy-rs#1969"]
meta = { "breaking" = false, "tada" = true, "bug" = false, "target" = "server" }
author = "drganjoo"
Expand Up @@ -23,6 +23,7 @@ import software.amazon.smithy.rust.codegen.client.testutil.testSymbolProvider
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.implBlock
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
Expand Down Expand Up @@ -51,6 +52,7 @@ abstract class ClientEventStreamBaseRequirements : EventStreamTestRequirements<C
)

override fun renderBuilderForShape(
rustCrate: RustCrate,
writer: RustWriter,
codegenContext: ClientCodegenContext,
shape: StructureShape,
Expand All @@ -73,6 +75,7 @@ abstract class ClientEventStreamBaseRequirements : EventStreamTestRequirements<C
}

override fun renderError(
rustCrate: RustCrate,
writer: RustWriter,
codegenContext: ClientCodegenContext,
shape: StructureShape,
Expand Down
Expand Up @@ -17,12 +17,13 @@ import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety
import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject
import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest

class ClientEventStreamMarshallerGeneratorTest {
@ParameterizedTest
@ArgumentsSource(TestCasesProvider::class)
fun test(testCase: EventStreamTestModels.TestCase) {
EventStreamTestTools.runTestCase(
EventStreamTestTools.setupTestCase(
testCase,
object : ClientEventStreamBaseRequirements() {
override fun renderGenerator(
Expand All @@ -41,6 +42,6 @@ class ClientEventStreamMarshallerGeneratorTest {
},
CodegenTarget.CLIENT,
EventStreamTestVariety.Marshall,
)
).compileAndTest()
}
}
Expand Up @@ -19,12 +19,13 @@ import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety
import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest

class ClientEventStreamUnmarshallerGeneratorTest {
@ParameterizedTest
@ArgumentsSource(TestCasesProvider::class)
fun test(testCase: EventStreamTestModels.TestCase) {
EventStreamTestTools.runTestCase(
EventStreamTestTools.setupTestCase(
testCase,
object : ClientEventStreamBaseRequirements() {
override fun renderGenerator(
Expand All @@ -44,6 +45,6 @@ class ClientEventStreamUnmarshallerGeneratorTest {
},
CodegenTarget.CLIENT,
EventStreamTestVariety.Unmarshall,
)
).compileAndTest()
}
}
Expand Up @@ -18,6 +18,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.DirectedWalker
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant
Expand Down Expand Up @@ -62,6 +63,7 @@ interface EventStreamTestRequirements<C : CodegenContext> {

/** Render a builder for the given shape */
fun renderBuilderForShape(
rustCrate: RustCrate,
writer: RustWriter,
codegenContext: C,
shape: StructureShape,
Expand All @@ -76,17 +78,21 @@ interface EventStreamTestRequirements<C : CodegenContext> {
)

/** Render an error struct and builder */
fun renderError(writer: RustWriter, codegenContext: C, shape: StructureShape)
fun renderError(rustCrate: RustCrate, writer: RustWriter, codegenContext: C, shape: StructureShape)
}

object EventStreamTestTools {
fun <C : CodegenContext> runTestCase(
fun <C : CodegenContext> setupTestCase(
testCase: EventStreamTestModels.TestCase,
requirements: EventStreamTestRequirements<C>,
codegenTarget: CodegenTarget,
variety: EventStreamTestVariety,
) {
val model = EventStreamNormalizer.transform(OperationNormalizer.transform(testCase.model))
transformers: List<(Model) -> Model> = listOf(),
): TestWriterDelegator {
val model = (listOf(OperationNormalizer::transform, EventStreamNormalizer::transform) + transformers).fold(testCase.model) { model, transformer ->
transformer(model)
}

val serviceShape = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape
val codegenContext = requirements.createCodegenContext(
model,
Expand All @@ -104,7 +110,8 @@ object EventStreamTestTools {
EventStreamTestVariety.Unmarshall -> writeUnmarshallTestCases(testCase, codegenTarget, generator)
}
}
test.project.compileAndTest()

return test.project
}

private fun <C : CodegenContext> generateTestProject(
Expand All @@ -128,7 +135,7 @@ object EventStreamTestTools {
requirements.renderOperationError(this, model, symbolProvider, operationShape)
requirements.renderOperationError(this, model, symbolProvider, unionShape)
for (shape in errors) {
requirements.renderError(this, codegenContext, shape)
requirements.renderError(project, this, codegenContext, shape)
}
}
val inputOutput = model.lookup<StructureShape>("test#TestStreamInputOutput")
Expand Down
Expand Up @@ -78,6 +78,7 @@ class PythonServerCodegenVisitor(
serviceShape: ServiceShape,
symbolVisitorConfig: SymbolVisitorConfig,
publicConstrainedTypes: Boolean,
includeConstraintShapeProvider: Boolean,
) = RustServerCodegenPythonPlugin.baseSymbolProvider(model, serviceShape, symbolVisitorConfig, publicConstrainedTypes)

val serverSymbolProviders = ServerSymbolProviders.from(
Expand Down
Expand Up @@ -79,7 +79,7 @@ class RustServerCodegenPythonPlugin : SmithyBuildPlugin {
// Generate public constrained types for directly constrained shapes.
// In the Python server project, this is only done to generate constrained types for simple shapes (e.g.
// a `string` shape with the `length` trait), but these always remain `pub(crate)`.
.let { if (constrainedTypes) ConstrainedShapeSymbolProvider(it, model, serviceShape) else it }
.let { if (constrainedTypes) ConstrainedShapeSymbolProvider(it, model, serviceShape, constrainedTypes) else it }
// Generate different types for EventStream shapes (e.g. transcribe streaming)
.let { EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model, CodegenTarget.SERVER) }
// Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes
Expand Down
Expand Up @@ -19,8 +19,12 @@ import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShortShape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.LengthTrait
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.Visibility
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.contextName
Expand All @@ -29,9 +33,13 @@ import software.amazon.smithy.rust.codegen.core.smithy.handleRustBoxing
import software.amazon.smithy.rust.codegen.core.smithy.locatedIn
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.smithy.symbolBuilder
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.orNull
import software.amazon.smithy.rust.codegen.core.util.toPascalCase
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderModule
import software.amazon.smithy.rust.codegen.server.smithy.traits.SyntheticStructureFromConstrainedMemberTrait

/**
* The [ConstrainedShapeSymbolProvider] returns, for a given _directly_
Expand All @@ -56,14 +64,16 @@ class ConstrainedShapeSymbolProvider(
private val base: RustSymbolProvider,
private val model: Model,
private val serviceShape: ServiceShape,
private val publicConstrainedTypes: Boolean,
) : WrappingSymbolProvider(base) {
private val nullableIndex = NullableIndex.of(model)

private fun publicConstrainedSymbolForMapOrCollectionShape(shape: Shape): Symbol {
check(shape is MapShape || shape is CollectionShape)

val rustType = RustType.Opaque(shape.contextName(serviceShape).toPascalCase())
return symbolBuilder(shape, rustType).locatedIn(ServerRustModule.Model).build()
val (name, module) = getMemberNameAndModule(shape, serviceShape, ServerRustModule.Model, !publicConstrainedTypes)
val rustType = RustType.Opaque(name)
return symbolBuilder(shape, rustType).locatedIn(module).build()
}

override fun toSymbol(shape: Shape): Symbol {
Expand All @@ -74,8 +84,14 @@ class ConstrainedShapeSymbolProvider(
val target = model.expectShape(shape.target)
val targetSymbol = this.toSymbol(target)
// Handle boxing first, so we end up with `Option<Box<_>>`, not `Box<Option<_>>`.
handleOptionality(handleRustBoxing(targetSymbol, shape), shape, nullableIndex, base.config().nullabilityCheckMode)
handleOptionality(
handleRustBoxing(targetSymbol, shape),
shape,
nullableIndex,
base.config().nullabilityCheckMode,
)
}

is MapShape -> {
if (shape.isDirectlyConstrained(base)) {
check(shape.hasTrait<LengthTrait>()) {
Expand All @@ -91,6 +107,7 @@ class ConstrainedShapeSymbolProvider(
.build()
}
}

is CollectionShape -> {
if (shape.isDirectlyConstrained(base)) {
check(constrainedCollectionCheck(shape)) {
Expand All @@ -105,8 +122,11 @@ class ConstrainedShapeSymbolProvider(

is StringShape, is IntegerShape, is ShortShape, is LongShape, is ByteShape, is BlobShape -> {
if (shape.isDirectlyConstrained(base)) {
val rustType = RustType.Opaque(shape.contextName(serviceShape).toPascalCase())
symbolBuilder(shape, rustType).locatedIn(ServerRustModule.Model).build()
// A standalone constrained shape goes into `ModelsModule`, but one
// arising from a constrained member shape goes into a module for the container.
val (name, module) = getMemberNameAndModule(shape, serviceShape, ServerRustModule.Model, !publicConstrainedTypes)
val rustType = RustType.Opaque(name)
symbolBuilder(shape, rustType).locatedIn(module).build()
} else {
base.toSymbol(shape)
}
Expand All @@ -122,9 +142,51 @@ class ConstrainedShapeSymbolProvider(
* - That it has no unsupported constraints applied.
*/
private fun constrainedCollectionCheck(shape: CollectionShape): Boolean {
val supportedConstraintTraits = supportedCollectionConstraintTraits.mapNotNull { shape.getTrait(it).orNull() }.toSet()
val supportedConstraintTraits =
supportedCollectionConstraintTraits.mapNotNull { shape.getTrait(it).orNull() }.toSet()
val allConstraintTraits = allConstraintTraits.mapNotNull { shape.getTrait(it).orNull() }.toSet()

return supportedConstraintTraits.isNotEmpty() && allConstraintTraits.subtract(supportedConstraintTraits).isEmpty()
return supportedConstraintTraits.isNotEmpty() && allConstraintTraits.subtract(supportedConstraintTraits)
.isEmpty()
}

/**
* Returns the pair (Rust Symbol Name, Inline Module) for the shape. At the time of model transformation all
* constrained member shapes are extracted and are given a model-wide unique name. However, the generated code
* for the new shapes is in a module that is named after the containing shape (structure, list, map or union).
* The new shape's Rust Symbol is renamed from `{structureName}{memberName}` to `{structure_name}::{member_name}`
*/
private fun getMemberNameAndModule(
shape: Shape,
serviceShape: ServiceShape,
defaultModule: RustModule.LeafModule,
pubCrateServerBuilder: Boolean,
): Pair<String, RustModule.LeafModule> {
val syntheticMemberTrait = shape.getTrait<SyntheticStructureFromConstrainedMemberTrait>()
?: return Pair(shape.contextName(serviceShape), defaultModule)

return if (syntheticMemberTrait.container is StructureShape) {
val builderModule = syntheticMemberTrait.container.serverBuilderModule(base, pubCrateServerBuilder)
val renameTo = syntheticMemberTrait.member.memberName ?: syntheticMemberTrait.member.id.name
Pair(renameTo.toPascalCase(), builderModule)
} else {
// For non-structure shapes, the new shape defined for a constrained member shape
// needs to be placed in an inline module named `pub {container_name_in_snake_case}`.
val moduleName = RustReservedWords.escapeIfNeeded(syntheticMemberTrait.container.id.name.toSnakeCase())
val innerModuleName = moduleName + if (pubCrateServerBuilder) {
"_internal"
} else {
""
}

val innerModule = RustModule.new(
innerModuleName,
visibility = Visibility.publicIf(!pubCrateServerBuilder, Visibility.PUBCRATE),
parent = defaultModule,
inline = true,
)
val renameTo = syntheticMemberTrait.member.memberName ?: syntheticMemberTrait.member.id.name
Pair(renameTo.toPascalCase(), innerModule)
}
}
}
Expand Up @@ -29,8 +29,10 @@ import software.amazon.smithy.rust.codegen.core.smithy.contextName
import software.amazon.smithy.rust.codegen.core.smithy.locatedIn
import software.amazon.smithy.rust.codegen.core.smithy.module
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol
import software.amazon.smithy.rust.codegen.server.smithy.traits.SyntheticStructureFromConstrainedMemberTrait

/**
* The [ConstraintViolationSymbolProvider] returns, for a given constrained
Expand Down Expand Up @@ -79,15 +81,29 @@ class ConstraintViolationSymbolProvider(

private fun Shape.shapeModule(): RustModule.LeafModule {
val documentation = if (publicConstrainedTypes && this.isDirectlyConstrained(base)) {
"See [`${this.contextName(serviceShape)}`]."
val symbol = base.toSymbol(this)
"See [`${this.contextName(serviceShape)}`]($symbol)."
} else {
null
}
return RustModule.new(

val syntheticTrait = getTrait<SyntheticStructureFromConstrainedMemberTrait>()

val (module, name) = if (syntheticTrait != null) {
// For constrained member shapes, the ConstraintViolation code needs to go in an inline rust module
// that is a descendant of the module that contains the extracted shape itself.
val overriddenMemberModule = this.getParentAndInlineModuleForConstrainedMember(base, publicConstrainedTypes)!!
val name = syntheticTrait.member.memberName
Pair(overriddenMemberModule.second, RustReservedWords.escapeIfNeeded(name).toSnakeCase())
} else {
// Need to use the context name so we get the correct name for maps.
name = RustReservedWords.escapeIfNeeded(this.contextName(serviceShape)).toSnakeCase(),
Pair(ServerRustModule.Model, RustReservedWords.escapeIfNeeded(this.contextName(serviceShape)).toSnakeCase())
}

return RustModule.new(
name = name,
visibility = visibility,
parent = ServerRustModule.Model,
parent = module,
inline = true,
documentation = documentation,
)
Expand Down

0 comments on commit 3d00767

Please sign in to comment.