Skip to content

Commit

Permalink
Make modules in codegen-core configurable (smithy-lang#2336)
Browse files Browse the repository at this point in the history
* Refactor modules to be configurable in `codegen-core`
* Remove panicking default test symbol provider
* Remove as many references to Error/Types as possible
* Rename module constants
  • Loading branch information
jdisanti committed Feb 10, 2023
1 parent de18667 commit cdc710d
Show file tree
Hide file tree
Showing 87 changed files with 858 additions and 719 deletions.
Expand Up @@ -75,7 +75,7 @@ private class AwsClientGenerics(private val types: Types) : FluentClientGenerics
override fun sendBounds(
operation: Symbol,
operationOutput: Symbol,
operationError: RuntimeType,
operationError: Symbol,
retryClassifier: RuntimeType,
): Writable =
writable { }
Expand Down
Expand Up @@ -35,7 +35,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator
import software.amazon.smithy.rust.codegen.core.util.cloneOperation
import software.amazon.smithy.rust.codegen.core.util.expectTrait
Expand Down Expand Up @@ -155,7 +154,7 @@ class AwsInputPresignedMethod(
}

private fun RustWriter.writeInputPresignedMethod(section: OperationSection.InputImpl) {
val operationError = operationShape.errorSymbol(symbolProvider)
val operationError = symbolProvider.symbolForOperationError(operationShape)
val presignableOp = PRESIGNABLE_OPERATIONS.getValue(operationShape.id)

val makeOperationOp = if (presignableOp.hasModelTransforms()) {
Expand Down
Expand Up @@ -6,10 +6,10 @@
package software.amazon.smithy.rustsdk

import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
Expand Down Expand Up @@ -105,7 +105,7 @@ class SdkConfigDecorator : ClientCodegenDecorator {
val codegenScope = arrayOf(
"SdkConfig" to AwsRuntimeType.awsTypes(codegenContext.runtimeConfig).resolve("sdk_config::SdkConfig"),
)
rustCrate.withModule(RustModule.Config) {
rustCrate.withModule(ClientRustModule.Config) {
rustTemplate(
"""
impl From<&#{SdkConfig}> for Builder {
Expand Down
Expand Up @@ -24,7 +24,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.Cli
import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientProtocolLoader
import software.amazon.smithy.rust.codegen.client.smithy.transformers.AddErrorMessage
import software.amazon.smithy.rust.codegen.client.smithy.transformers.RemoveEventStreamOperations
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.smithy.DirectedWalker
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
Expand All @@ -33,15 +32,12 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerat
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.OperationErrorGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.eventStreamErrorSymbol
import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.core.smithy.transformers.EventStreamNormalizer
import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer
import software.amazon.smithy.rust.codegen.core.smithy.transformers.eventStreamErrors
import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors
import software.amazon.smithy.rust.codegen.core.util.CommandFailed
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.isEventStream
Expand All @@ -56,7 +52,6 @@ class ClientCodegenVisitor(
context: PluginContext,
private val codegenDecorator: ClientCodegenDecorator,
) : ShapeVisitor.Default<Unit>() {

private val logger = Logger.getLogger(javaClass.name)
private val settings = ClientRustSettings.from(context.model, context.settings)

Expand All @@ -69,12 +64,12 @@ class ClientCodegenVisitor(
private val protocolGenerator: ClientProtocolGenerator

init {
val symbolVisitorConfig =
SymbolVisitorConfig(
runtimeConfig = settings.runtimeConfig,
renameExceptions = settings.codegenConfig.renameExceptions,
nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_ZERO_VALUE_V1,
)
val symbolVisitorConfig = SymbolVisitorConfig(
runtimeConfig = settings.runtimeConfig,
renameExceptions = settings.codegenConfig.renameExceptions,
nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_ZERO_VALUE_V1,
moduleProvider = ClientModuleProvider,
)
val baseModel = baselineTransform(context.model)
val untransformedService = settings.getService(baseModel)
val (protocol, generator) = ClientProtocolLoader(
Expand Down Expand Up @@ -224,13 +219,8 @@ class ClientCodegenVisitor(
UnionGenerator(model, symbolProvider, this, shape, renderUnknownVariant = true).render()
}
if (shape.isEventStream()) {
rustCrate.withModule(RustModule.Error) {
val symbol = symbolProvider.toSymbol(shape)
val errors = shape.eventStreamErrors()
.map { model.expectShape(it.asMemberShape().get().target, StructureShape::class.java) }
val errorSymbol = shape.eventStreamErrorSymbol(symbolProvider)
OperationErrorGenerator(model, symbolProvider, symbol, errors)
.renderErrors(this, errorSymbol, symbol)
rustCrate.withModule(ClientRustModule.Error) {
OperationErrorGenerator(model, symbolProvider, shape).render(this)
}
}
}
Expand All @@ -239,14 +229,8 @@ class ClientCodegenVisitor(
* Generate errors for operation shapes
*/
override fun operationShape(shape: OperationShape) {
rustCrate.withModule(RustModule.Error) {
val operationSymbol = symbolProvider.toSymbol(shape)
OperationErrorGenerator(
model,
symbolProvider,
operationSymbol,
shape.operationErrors(model).map { it.asStructureShape().get() },
).render(this)
rustCrate.withModule(ClientRustModule.Error) {
OperationErrorGenerator(model, symbolProvider, shape).render(this)
}
}
}
@@ -0,0 +1,59 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rust.codegen.client.smithy

import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.smithy.ModuleProvider
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait

/**
* Modules for code generated client crates.
*/
object ClientRustModule {
/** crate::client */
val client = Client.self
object Client {
/** crate::client */
val self = RustModule.public("client", "Client and fluent builders for calling the service.")

/** crate::client::customize */
val customize = RustModule.public("customize", "Operation customization and supporting types", parent = self)
}

val Config = RustModule.public("config", documentation = "Configuration for the service.")
val Error = RustModule.public("error", documentation = "All error types that operations can return. Documentation on these types is copied from the model.")
val Operation = RustModule.public("operation", documentation = "All operations that this crate can perform.")
val Model = RustModule.public("model", documentation = "Data structures used by operation inputs/outputs. Documentation on these types is copied from the model.")
val Input = RustModule.public("input", documentation = "Input structures for operations. Documentation on these types is copied from the model.")
val Output = RustModule.public("output", documentation = "Output structures for operations. Documentation on these types is copied from the model.")
val Types = RustModule.public("types", documentation = "Data primitives referenced by other data types.")
}

object ClientModuleProvider : ModuleProvider {
override fun moduleForShape(shape: Shape): RustModule.LeafModule = when (shape) {
is OperationShape -> ClientRustModule.Operation
is StructureShape -> when {
shape.hasTrait<ErrorTrait>() -> ClientRustModule.Error
shape.hasTrait<SyntheticInputTrait>() -> ClientRustModule.Input
shape.hasTrait<SyntheticOutputTrait>() -> ClientRustModule.Output
else -> ClientRustModule.Model
}
else -> ClientRustModule.Model
}

override fun moduleForOperationError(operation: OperationShape): RustModule.LeafModule =
ClientRustModule.Error

override fun moduleForEventStreamError(eventStream: UnionShape): RustModule.LeafModule =
ClientRustModule.Error
}
Expand Up @@ -5,9 +5,9 @@

package software.amazon.smithy.rust.codegen.client.smithy.customizations

import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
Expand Down Expand Up @@ -235,7 +235,7 @@ class ResiliencyConfigCustomization(codegenContext: CodegenContext) : ConfigCust

class ResiliencyReExportCustomization(private val runtimeConfig: RuntimeConfig) {
fun extras(rustCrate: RustCrate) {
rustCrate.withModule(RustModule.Config) {
rustCrate.withModule(ClientRustModule.Config) {
rustTemplate(
"""
pub use #{sleep}::{AsyncSleep, Sleep};
Expand Down
Expand Up @@ -7,6 +7,7 @@ package software.amazon.smithy.rust.codegen.client.smithy.customize

import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule
import software.amazon.smithy.rust.codegen.client.smithy.customizations.EndpointPrefixGenerator
import software.amazon.smithy.rust.codegen.client.smithy.customizations.HttpChecksumRequiredGenerator
import software.amazon.smithy.rust.codegen.client.smithy.customizations.HttpVersionListCustomization
Expand Down Expand Up @@ -65,6 +66,8 @@ class RequiredCustomizations : ClientCodegenDecorator {
// Re-export resiliency types
ResiliencyReExportCustomization(codegenContext.runtimeConfig).extras(rustCrate)

pubUseSmithyTypes(codegenContext.runtimeConfig, codegenContext.model, rustCrate)
rustCrate.withModule(ClientRustModule.Types) {
pubUseSmithyTypes(codegenContext.runtimeConfig, codegenContext.model)(this)
}
}
}
Expand Up @@ -24,7 +24,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.findMemberWithTrait
Expand Down Expand Up @@ -79,7 +78,7 @@ class PaginatorGenerator private constructor(
private val inputType = symbolProvider.toSymbol(operation.inputShape(model))
private val outputShape = operation.outputShape(model)
private val outputType = symbolProvider.toSymbol(outputShape)
private val errorType = operation.errorSymbol(symbolProvider)
private val errorType = symbolProvider.symbolForOperationError(operation)

private fun paginatorType(): RuntimeType = RuntimeType.forInlineFun(
paginatorName,
Expand Down
Expand Up @@ -7,12 +7,12 @@ package software.amazon.smithy.rust.codegen.client.smithy.generators

import software.amazon.smithy.model.knowledge.TopDownIndex
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfigGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolTestGenerator
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ServiceErrorGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport
Expand Down Expand Up @@ -58,7 +58,7 @@ class ServiceGenerator(

ServiceErrorGenerator(clientCodegenContext, operations).render(rustCrate)

rustCrate.withModule(RustModule.Config) {
rustCrate.withModule(ClientRustModule.Config) {
ServiceConfigGenerator.withBaseBehavior(
clientCodegenContext,
extraCustomizations = decorator.configCustomizations(clientCodegenContext, listOf()),
Expand Down
Expand Up @@ -5,6 +5,7 @@

package software.amazon.smithy.rust.codegen.client.smithy.generators.client

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
Expand Down Expand Up @@ -65,7 +66,7 @@ sealed class FluentClientSection(name: String) : Section(name) {
/** Write custom code into an operation fluent builder's impl block */
data class FluentBuilderImpl(
val operationShape: OperationShape,
val operationErrorType: RuntimeType,
val operationErrorType: Symbol,
) : FluentClientSection("FluentBuilderImpl")

/** Write custom code into the docs */
Expand Down
Expand Up @@ -44,7 +44,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations
import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata
import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.util.inputShape
Expand Down Expand Up @@ -168,7 +167,7 @@ class FluentClientGenerator(

val output = operation.outputShape(model)
val operationOk = symbolProvider.toSymbol(output)
val operationErr = operation.errorSymbol(symbolProvider).toSymbol()
val operationErr = symbolProvider.symbolForOperationError(operation)

val inputFieldsBody =
generateOperationShapeDocs(writer, symbolProvider, operation, model).joinToString("\n") {
Expand Down Expand Up @@ -263,7 +262,7 @@ class FluentClientGenerator(
"bounds" to generics.bounds,
) {
val outputType = symbolProvider.toSymbol(operation.outputShape(model))
val errorType = operation.errorSymbol(symbolProvider)
val errorType = symbolProvider.symbolForOperationError(operation)

// Have to use fully-qualified result here or else it could conflict with an op named Result
rustTemplate(
Expand Down Expand Up @@ -333,7 +332,7 @@ class FluentClientGenerator(
customizations,
FluentClientSection.FluentBuilderImpl(
operation,
operation.errorSymbol(symbolProvider),
symbolProvider.symbolForOperationError(operation),
),
)
input.members().forEach { member ->
Expand Down
Expand Up @@ -28,7 +28,7 @@ interface FluentClientGenerics {
val bounds: Writable

/** Bounds for generated `send()` functions */
fun sendBounds(operation: Symbol, operationOutput: Symbol, operationError: RuntimeType, retryClassifier: RuntimeType): Writable
fun sendBounds(operation: Symbol, operationOutput: Symbol, operationError: Symbol, retryClassifier: RuntimeType): Writable

/** Convert this `FluentClientGenerics` into the more general `RustGenerics` */
fun toRustGenerics(): RustGenerics
Expand Down Expand Up @@ -70,7 +70,7 @@ data class FlexibleClientGenerics(
}

/** Bounds for generated `send()` functions */
override fun sendBounds(operation: Symbol, operationOutput: Symbol, operationError: RuntimeType, retryClassifier: RuntimeType): Writable = writable {
override fun sendBounds(operation: Symbol, operationOutput: Symbol, operationError: Symbol, retryClassifier: RuntimeType): Writable = writable {
rustTemplate(
"""
where
Expand Down
Expand Up @@ -34,11 +34,7 @@ class ResponseBindingGenerator(

fun generateDeserializePayloadFn(
binding: HttpBindingDescriptor,
errorT: RuntimeType,
errorSymbol: Symbol,
payloadParser: RustWriter.(String) -> Unit,
): RuntimeType = httpBindingGenerator.generateDeserializePayloadFn(
binding,
errorT,
payloadParser,
)
): RuntimeType = httpBindingGenerator.generateDeserializePayloadFn(binding, errorSymbol, payloadParser)
}

0 comments on commit cdc710d

Please sign in to comment.