Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Constraint member types are refactored as standalone shapes. #2256

Merged
merged 19 commits into from
Feb 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,9 @@ produced Rust code that did not compile"""
references = ["smithy-rs#2352", "smithy-rs#2343"]
meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "server"}
author = "82marbag"

[[smithy-rs]]
message = "Support for constraint traits on member shapes (constraint trait precedence) has been added."
references = ["smithy-rs#1969"]
meta = { "breaking" = false, "tada" = true, "bug" = false, "target" = "server" }
author = "drganjoo"
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import software.amazon.smithy.rust.codegen.client.testutil.testSymbolProvider
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.implBlock
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
Expand Down Expand Up @@ -51,6 +52,7 @@ abstract class ClientEventStreamBaseRequirements : EventStreamTestRequirements<C
)

override fun renderBuilderForShape(
rustCrate: RustCrate,
writer: RustWriter,
codegenContext: ClientCodegenContext,
shape: StructureShape,
Expand All @@ -73,6 +75,7 @@ abstract class ClientEventStreamBaseRequirements : EventStreamTestRequirements<C
}

override fun renderError(
rustCrate: RustCrate,
writer: RustWriter,
codegenContext: ClientCodegenContext,
shape: StructureShape,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety
import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject
import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest

class ClientEventStreamMarshallerGeneratorTest {
@ParameterizedTest
@ArgumentsSource(TestCasesProvider::class)
fun test(testCase: EventStreamTestModels.TestCase) {
EventStreamTestTools.runTestCase(
EventStreamTestTools.setupTestCase(
testCase,
object : ClientEventStreamBaseRequirements() {
override fun renderGenerator(
Expand All @@ -41,6 +42,6 @@ class ClientEventStreamMarshallerGeneratorTest {
},
CodegenTarget.CLIENT,
EventStreamTestVariety.Marshall,
)
).compileAndTest()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety
import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest

class ClientEventStreamUnmarshallerGeneratorTest {
@ParameterizedTest
@ArgumentsSource(TestCasesProvider::class)
fun test(testCase: EventStreamTestModels.TestCase) {
EventStreamTestTools.runTestCase(
EventStreamTestTools.setupTestCase(
testCase,
object : ClientEventStreamBaseRequirements() {
override fun renderGenerator(
Expand All @@ -44,6 +45,6 @@ class ClientEventStreamUnmarshallerGeneratorTest {
},
CodegenTarget.CLIENT,
EventStreamTestVariety.Unmarshall,
)
).compileAndTest()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.DirectedWalker
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant
Expand Down Expand Up @@ -62,6 +63,7 @@ interface EventStreamTestRequirements<C : CodegenContext> {

/** Render a builder for the given shape */
fun renderBuilderForShape(
rustCrate: RustCrate,
writer: RustWriter,
codegenContext: C,
shape: StructureShape,
Expand All @@ -76,17 +78,21 @@ interface EventStreamTestRequirements<C : CodegenContext> {
)

/** Render an error struct and builder */
fun renderError(writer: RustWriter, codegenContext: C, shape: StructureShape)
fun renderError(rustCrate: RustCrate, writer: RustWriter, codegenContext: C, shape: StructureShape)
}

object EventStreamTestTools {
fun <C : CodegenContext> runTestCase(
fun <C : CodegenContext> setupTestCase(
testCase: EventStreamTestModels.TestCase,
requirements: EventStreamTestRequirements<C>,
codegenTarget: CodegenTarget,
variety: EventStreamTestVariety,
) {
val model = EventStreamNormalizer.transform(OperationNormalizer.transform(testCase.model))
transformers: List<(Model) -> Model> = listOf(),
): TestWriterDelegator {
val model = (listOf(OperationNormalizer::transform, EventStreamNormalizer::transform) + transformers).fold(testCase.model) { model, transformer ->
transformer(model)
}

val serviceShape = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape
val codegenContext = requirements.createCodegenContext(
model,
Expand All @@ -104,7 +110,8 @@ object EventStreamTestTools {
EventStreamTestVariety.Unmarshall -> writeUnmarshallTestCases(testCase, codegenTarget, generator)
}
}
test.project.compileAndTest()

return test.project
}

private fun <C : CodegenContext> generateTestProject(
Expand All @@ -128,7 +135,7 @@ object EventStreamTestTools {
requirements.renderOperationError(this, model, symbolProvider, operationShape)
requirements.renderOperationError(this, model, symbolProvider, unionShape)
for (shape in errors) {
requirements.renderError(this, codegenContext, shape)
requirements.renderError(project, this, codegenContext, shape)
}
}
val inputOutput = model.lookup<StructureShape>("test#TestStreamInputOutput")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class PythonServerCodegenVisitor(
serviceShape: ServiceShape,
symbolVisitorConfig: SymbolVisitorConfig,
publicConstrainedTypes: Boolean,
includeConstraintShapeProvider: Boolean,
) = RustServerCodegenPythonPlugin.baseSymbolProvider(model, serviceShape, symbolVisitorConfig, publicConstrainedTypes)

val serverSymbolProviders = ServerSymbolProviders.from(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class RustServerCodegenPythonPlugin : SmithyBuildPlugin {
// Generate public constrained types for directly constrained shapes.
// In the Python server project, this is only done to generate constrained types for simple shapes (e.g.
// a `string` shape with the `length` trait), but these always remain `pub(crate)`.
.let { if (constrainedTypes) ConstrainedShapeSymbolProvider(it, model, serviceShape) else it }
.let { if (constrainedTypes) ConstrainedShapeSymbolProvider(it, model, serviceShape, constrainedTypes) else it }
// Generate different types for EventStream shapes (e.g. transcribe streaming)
.let { EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model, CodegenTarget.SERVER) }
// Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@ import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShortShape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.LengthTrait
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.Visibility
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.contextName
Expand All @@ -29,9 +33,13 @@ import software.amazon.smithy.rust.codegen.core.smithy.handleRustBoxing
import software.amazon.smithy.rust.codegen.core.smithy.locatedIn
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.smithy.symbolBuilder
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.orNull
import software.amazon.smithy.rust.codegen.core.util.toPascalCase
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderModule
import software.amazon.smithy.rust.codegen.server.smithy.traits.SyntheticStructureFromConstrainedMemberTrait

/**
* The [ConstrainedShapeSymbolProvider] returns, for a given _directly_
Expand All @@ -56,14 +64,16 @@ class ConstrainedShapeSymbolProvider(
private val base: RustSymbolProvider,
private val model: Model,
private val serviceShape: ServiceShape,
private val publicConstrainedTypes: Boolean,
) : WrappingSymbolProvider(base) {
private val nullableIndex = NullableIndex.of(model)

private fun publicConstrainedSymbolForMapOrCollectionShape(shape: Shape): Symbol {
check(shape is MapShape || shape is CollectionShape)

val rustType = RustType.Opaque(shape.contextName(serviceShape).toPascalCase())
return symbolBuilder(shape, rustType).locatedIn(ServerRustModule.Model).build()
val (name, module) = getMemberNameAndModule(shape, serviceShape, ServerRustModule.Model, !publicConstrainedTypes)
val rustType = RustType.Opaque(name)
return symbolBuilder(shape, rustType).locatedIn(module).build()
}

override fun toSymbol(shape: Shape): Symbol {
Expand All @@ -74,8 +84,14 @@ class ConstrainedShapeSymbolProvider(
val target = model.expectShape(shape.target)
val targetSymbol = this.toSymbol(target)
// Handle boxing first, so we end up with `Option<Box<_>>`, not `Box<Option<_>>`.
handleOptionality(handleRustBoxing(targetSymbol, shape), shape, nullableIndex, base.config().nullabilityCheckMode)
handleOptionality(
handleRustBoxing(targetSymbol, shape),
shape,
nullableIndex,
base.config().nullabilityCheckMode,
)
}

is MapShape -> {
if (shape.isDirectlyConstrained(base)) {
check(shape.hasTrait<LengthTrait>()) {
Expand All @@ -91,6 +107,7 @@ class ConstrainedShapeSymbolProvider(
.build()
}
}

is CollectionShape -> {
if (shape.isDirectlyConstrained(base)) {
check(constrainedCollectionCheck(shape)) {
Expand All @@ -105,8 +122,11 @@ class ConstrainedShapeSymbolProvider(

is StringShape, is IntegerShape, is ShortShape, is LongShape, is ByteShape, is BlobShape -> {
if (shape.isDirectlyConstrained(base)) {
val rustType = RustType.Opaque(shape.contextName(serviceShape).toPascalCase())
symbolBuilder(shape, rustType).locatedIn(ServerRustModule.Model).build()
// A standalone constrained shape goes into `ModelsModule`, but one
// arising from a constrained member shape goes into a module for the container.
val (name, module) = getMemberNameAndModule(shape, serviceShape, ServerRustModule.Model, !publicConstrainedTypes)
val rustType = RustType.Opaque(name)
symbolBuilder(shape, rustType).locatedIn(module).build()
} else {
base.toSymbol(shape)
}
Expand All @@ -122,9 +142,51 @@ class ConstrainedShapeSymbolProvider(
* - That it has no unsupported constraints applied.
*/
private fun constrainedCollectionCheck(shape: CollectionShape): Boolean {
val supportedConstraintTraits = supportedCollectionConstraintTraits.mapNotNull { shape.getTrait(it).orNull() }.toSet()
val supportedConstraintTraits =
supportedCollectionConstraintTraits.mapNotNull { shape.getTrait(it).orNull() }.toSet()
val allConstraintTraits = allConstraintTraits.mapNotNull { shape.getTrait(it).orNull() }.toSet()

return supportedConstraintTraits.isNotEmpty() && allConstraintTraits.subtract(supportedConstraintTraits).isEmpty()
return supportedConstraintTraits.isNotEmpty() && allConstraintTraits.subtract(supportedConstraintTraits)
.isEmpty()
}

/**
* Returns the pair (Rust Symbol Name, Inline Module) for the shape. At the time of model transformation all
* constrained member shapes are extracted and are given a model-wide unique name. However, the generated code
* for the new shapes is in a module that is named after the containing shape (structure, list, map or union).
* The new shape's Rust Symbol is renamed from `{structureName}{memberName}` to `{structure_name}::{member_name}`
*/
private fun getMemberNameAndModule(
shape: Shape,
serviceShape: ServiceShape,
defaultModule: RustModule.LeafModule,
pubCrateServerBuilder: Boolean,
): Pair<String, RustModule.LeafModule> {
val syntheticMemberTrait = shape.getTrait<SyntheticStructureFromConstrainedMemberTrait>()
?: return Pair(shape.contextName(serviceShape), defaultModule)

return if (syntheticMemberTrait.container is StructureShape) {
val builderModule = syntheticMemberTrait.container.serverBuilderModule(base, pubCrateServerBuilder)
val renameTo = syntheticMemberTrait.member.memberName ?: syntheticMemberTrait.member.id.name
Pair(renameTo.toPascalCase(), builderModule)
} else {
// For non-structure shapes, the new shape defined for a constrained member shape
// needs to be placed in an inline module named `pub {container_name_in_snake_case}`.
val moduleName = RustReservedWords.escapeIfNeeded(syntheticMemberTrait.container.id.name.toSnakeCase())
val innerModuleName = moduleName + if (pubCrateServerBuilder) {
"_internal"
} else {
""
}

val innerModule = RustModule.new(
innerModuleName,
visibility = Visibility.publicIf(!pubCrateServerBuilder, Visibility.PUBCRATE),
parent = defaultModule,
inline = true,
)
val renameTo = syntheticMemberTrait.member.memberName ?: syntheticMemberTrait.member.id.name
Pair(renameTo.toPascalCase(), innerModule)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ import software.amazon.smithy.rust.codegen.core.smithy.contextName
import software.amazon.smithy.rust.codegen.core.smithy.locatedIn
import software.amazon.smithy.rust.codegen.core.smithy.module
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol
import software.amazon.smithy.rust.codegen.server.smithy.traits.SyntheticStructureFromConstrainedMemberTrait

/**
* The [ConstraintViolationSymbolProvider] returns, for a given constrained
Expand Down Expand Up @@ -79,15 +81,29 @@ class ConstraintViolationSymbolProvider(

private fun Shape.shapeModule(): RustModule.LeafModule {
val documentation = if (publicConstrainedTypes && this.isDirectlyConstrained(base)) {
"See [`${this.contextName(serviceShape)}`]."
val symbol = base.toSymbol(this)
"See [`${this.contextName(serviceShape)}`]($symbol)."
} else {
null
}
return RustModule.new(

val syntheticTrait = getTrait<SyntheticStructureFromConstrainedMemberTrait>()

val (module, name) = if (syntheticTrait != null) {
// For constrained member shapes, the ConstraintViolation code needs to go in an inline rust module
// that is a descendant of the module that contains the extracted shape itself.
val overriddenMemberModule = this.getParentAndInlineModuleForConstrainedMember(base, publicConstrainedTypes)!!
val name = syntheticTrait.member.memberName
Pair(overriddenMemberModule.second, RustReservedWords.escapeIfNeeded(name).toSnakeCase())
} else {
// Need to use the context name so we get the correct name for maps.
name = RustReservedWords.escapeIfNeeded(this.contextName(serviceShape)).toSnakeCase(),
Pair(ServerRustModule.Model, RustReservedWords.escapeIfNeeded(this.contextName(serviceShape)).toSnakeCase())
}

return RustModule.new(
name = name,
visibility = visibility,
parent = ServerRustModule.Model,
parent = module,
inline = true,
documentation = documentation,
)
Expand Down