Skip to content

Commit

Permalink
Move RustSymbolProvider and related types out of SymbolVisitor (s…
Browse files Browse the repository at this point in the history
…mithy-lang#2380)

* Move base `RustSymbolProvider` types out of `SymbolVisitor`
* Rename `SymbolVisitorConfig` to `RustSymbolProviderConfig`
  • Loading branch information
jdisanti committed Feb 17, 2023
1 parent 3d00767 commit afb1f16
Show file tree
Hide file tree
Showing 18 changed files with 127 additions and 110 deletions.
Expand Up @@ -31,7 +31,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.implBlock
import software.amazon.smithy.rust.codegen.core.smithy.DirectedWalker
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
Expand Down Expand Up @@ -67,7 +67,7 @@ class ClientCodegenVisitor(
private val protocolGenerator: ClientProtocolGenerator

init {
val symbolVisitorConfig = SymbolVisitorConfig(
val rustSymbolProviderConfig = RustSymbolProviderConfig(
runtimeConfig = settings.runtimeConfig,
renameExceptions = settings.codegenConfig.renameExceptions,
nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_ZERO_VALUE_V1,
Expand All @@ -82,7 +82,7 @@ class ClientCodegenVisitor(
model = codegenDecorator.transformModel(untransformedService, baseModel)
// the model transformer _might_ change the service shape
val service = settings.getService(model)
symbolProvider = RustClientCodegenPlugin.baseSymbolProvider(model, service, symbolVisitorConfig)
symbolProvider = RustClientCodegenPlugin.baseSymbolProvider(model, service, rustSymbolProviderConfig)

codegenContext = ClientCodegenContext(model, symbolProvider, service, protocol, settings, codegenDecorator)

Expand Down
Expand Up @@ -23,10 +23,10 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordSymbolP
import software.amazon.smithy.rust.codegen.core.smithy.BaseSymbolMetadataProvider
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.EventStreamSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig
import software.amazon.smithy.rust.codegen.core.smithy.StreamingShapeMetadataProvider
import software.amazon.smithy.rust.codegen.core.smithy.StreamingShapeSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig
import java.util.logging.Level
import java.util.logging.Logger

Expand Down Expand Up @@ -74,10 +74,10 @@ class RustClientCodegenPlugin : ClientDecoratableBuildPlugin() {
* The Symbol provider is composed of a base [SymbolVisitor] which handles the core functionality, then is layered
* with other symbol providers, documented inline, to handle the full scope of Smithy types.
*/
fun baseSymbolProvider(model: Model, serviceShape: ServiceShape, symbolVisitorConfig: SymbolVisitorConfig) =
SymbolVisitor(model, serviceShape = serviceShape, config = symbolVisitorConfig)
fun baseSymbolProvider(model: Model, serviceShape: ServiceShape, rustSymbolProviderConfig: RustSymbolProviderConfig) =
SymbolVisitor(model, serviceShape = serviceShape, config = rustSymbolProviderConfig)
// Generate different types for EventStream shapes (e.g. transcribe streaming)
.let { EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model, CodegenTarget.CLIENT) }
.let { EventStreamSymbolProvider(rustSymbolProviderConfig.runtimeConfig, it, model, CodegenTarget.CLIENT) }
// Generate `ByteStream` instead of `Blob` for streaming binary shapes (e.g. S3 GetObject)
.let { StreamingShapeSymbolProvider(it, model) }
// Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes
Expand Down
Expand Up @@ -19,7 +19,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.CoreRustSettings
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig
import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig
import software.amazon.smithy.rust.codegen.core.testutil.testRustSettings

Expand Down Expand Up @@ -49,7 +49,7 @@ fun clientTestRustSettings(
customizationConfig,
)

val ClientTestSymbolVisitorConfig = SymbolVisitorConfig(
val ClientTestRustSymbolProviderConfig = RustSymbolProviderConfig(
runtimeConfig = TestRuntimeConfig,
renameExceptions = true,
nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_ZERO_VALUE_V1,
Expand All @@ -60,7 +60,7 @@ fun testSymbolProvider(model: Model, serviceShape: ServiceShape? = null): RustSy
RustClientCodegenPlugin.baseSymbolProvider(
model,
serviceShape ?: ServiceShape.builder().version("test").id("test#Service").build(),
ClientTestSymbolVisitorConfig,
ClientTestRustSymbolProviderConfig,
)

fun testCodegenContext(
Expand Down
Expand Up @@ -10,7 +10,7 @@ import org.junit.jupiter.api.Test
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rust.codegen.client.testutil.ClientTestSymbolVisitorConfig
import software.amazon.smithy.rust.codegen.client.testutil.ClientTestRustSymbolProviderConfig
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.EventStreamSymbolProvider
Expand Down Expand Up @@ -46,7 +46,7 @@ class EventStreamSymbolProviderTest {
)

val service = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape
val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, ClientTestSymbolVisitorConfig), model, CodegenTarget.CLIENT)
val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, ClientTestRustSymbolProviderConfig), model, CodegenTarget.CLIENT)

// Look up the synthetic input/output rather than the original input/output
val inputStream = model.expectShape(ShapeId.from("test.synthetic#TestOperationInput\$inputStream")) as MemberShape
Expand Down Expand Up @@ -82,7 +82,7 @@ class EventStreamSymbolProviderTest {
)

val service = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape
val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, ClientTestSymbolVisitorConfig), model, CodegenTarget.CLIENT)
val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, ClientTestRustSymbolProviderConfig), model, CodegenTarget.CLIENT)

// Look up the synthetic input/output rather than the original input/output
val inputStream = model.expectShape(ShapeId.from("test.synthetic#TestOperationInput\$inputStream")) as MemberShape
Expand Down
@@ -0,0 +1,73 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rust.codegen.core.smithy

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.knowledge.NullableIndex
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule

/**
* SymbolProvider interface that carries both the inner configuration and a function to produce an enum variant name.
*/
interface RustSymbolProvider : SymbolProvider, ModuleProvider {
fun config(): RustSymbolProviderConfig
fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed?

override fun moduleForShape(shape: Shape): RustModule.LeafModule = config().moduleProvider.moduleForShape(shape)
override fun moduleForOperationError(operation: OperationShape): RustModule.LeafModule =
config().moduleProvider.moduleForOperationError(operation)
override fun moduleForEventStreamError(eventStream: UnionShape): RustModule.LeafModule =
config().moduleProvider.moduleForEventStreamError(eventStream)

/** Returns the symbol for an operation error */
fun symbolForOperationError(operation: OperationShape): Symbol

/** Returns the symbol for an event stream error */
fun symbolForEventStreamError(eventStream: UnionShape): Symbol
}

/**
* Provider for RustModules so that the symbol provider knows where to organize things.
*/
interface ModuleProvider {
/** Returns the module for a shape */
fun moduleForShape(shape: Shape): RustModule.LeafModule

/** Returns the module for an operation error */
fun moduleForOperationError(operation: OperationShape): RustModule.LeafModule

/** Returns the module for an event stream error */
fun moduleForEventStreamError(eventStream: UnionShape): RustModule.LeafModule
}

/**
* Configuration for symbol providers.
*/
data class RustSymbolProviderConfig(
val runtimeConfig: RuntimeConfig,
val renameExceptions: Boolean,
val nullabilityCheckMode: NullableIndex.CheckMode,
val moduleProvider: ModuleProvider,
)

/**
* Default delegator to enable easily decorating another symbol provider.
*/
open class WrappingSymbolProvider(private val base: RustSymbolProvider) : RustSymbolProvider {
override fun config(): RustSymbolProviderConfig = base.config()
override fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed? = base.toEnumVariantName(definition)
override fun toSymbol(shape: Shape): Symbol = base.toSymbol(shape)
override fun toMemberName(shape: MemberShape): String = base.toMemberName(shape)
override fun symbolForOperationError(operation: OperationShape): Symbol = base.symbolForOperationError(operation)
override fun symbolForEventStreamError(eventStream: UnionShape): Symbol =
base.symbolForEventStreamError(eventStream)
}
Expand Up @@ -14,12 +14,10 @@ import software.amazon.smithy.model.shapes.ListShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.NumberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.SensitiveTrait
import software.amazon.smithy.model.traits.StreamingTrait
Expand All @@ -28,19 +26,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata
import software.amazon.smithy.rust.codegen.core.rustlang.Visibility
import software.amazon.smithy.rust.codegen.core.util.hasTrait

/**
* Default delegator to enable easily decorating another symbol provider.
*/
open class WrappingSymbolProvider(private val base: RustSymbolProvider) : RustSymbolProvider {
override fun config(): SymbolVisitorConfig = base.config()
override fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed? = base.toEnumVariantName(definition)
override fun toSymbol(shape: Shape): Symbol = base.toSymbol(shape)
override fun toMemberName(shape: MemberShape): String = base.toMemberName(shape)
override fun symbolForOperationError(operation: OperationShape): Symbol = base.symbolForOperationError(operation)
override fun symbolForEventStreamError(eventStream: UnionShape): Symbol =
base.symbolForEventStreamError(eventStream)
}

/**
* Attach `meta` to symbols. `meta` is used by the generators (e.g. StructureGenerator) to configure the generated models.
*
Expand Down
Expand Up @@ -64,27 +64,6 @@ val SimpleShapes: Map<KClass<out Shape>, RustType> = mapOf(
StringShape::class to RustType.String,
)

/**
* Provider for RustModules so that the symbol provider knows where to organize things.
*/
interface ModuleProvider {
/** Returns the module for a shape */
fun moduleForShape(shape: Shape): RustModule.LeafModule

/** Returns the module for an operation error */
fun moduleForOperationError(operation: OperationShape): RustModule.LeafModule

/** Returns the module for an event stream error */
fun moduleForEventStreamError(eventStream: UnionShape): RustModule.LeafModule
}

data class SymbolVisitorConfig(
val runtimeConfig: RuntimeConfig,
val renameExceptions: Boolean,
val nullabilityCheckMode: CheckMode,
val moduleProvider: ModuleProvider,
)

/**
* Track both the past and current name of a symbol
*
Expand All @@ -96,26 +75,6 @@ data class SymbolVisitorConfig(
*/
data class MaybeRenamed(val name: String, val renamedFrom: String?)

/**
* SymbolProvider interface that carries both the inner configuration and a function to produce an enum variant name.
*/
interface RustSymbolProvider : SymbolProvider, ModuleProvider {
fun config(): SymbolVisitorConfig
fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed?

override fun moduleForShape(shape: Shape): RustModule.LeafModule = config().moduleProvider.moduleForShape(shape)
override fun moduleForOperationError(operation: OperationShape): RustModule.LeafModule =
config().moduleProvider.moduleForOperationError(operation)
override fun moduleForEventStreamError(eventStream: UnionShape): RustModule.LeafModule =
config().moduleProvider.moduleForEventStreamError(eventStream)

/** Returns the symbol for an operation error */
fun symbolForOperationError(operation: OperationShape): Symbol

/** Returns the symbol for an event stream error */
fun symbolForEventStreamError(eventStream: UnionShape): Symbol
}

/**
* Make the return [value] optional if the [member] symbol is as well optional.
*/
Expand Down Expand Up @@ -148,11 +107,11 @@ fun Shape.contextName(serviceShape: ServiceShape?): String {
open class SymbolVisitor(
private val model: Model,
private val serviceShape: ServiceShape?,
private val config: SymbolVisitorConfig,
private val config: RustSymbolProviderConfig,
) : RustSymbolProvider,
ShapeVisitor<Symbol> {
private val nullableIndex = NullableIndex.of(model)
override fun config(): SymbolVisitorConfig = config
override fun config(): RustSymbolProviderConfig = config

override fun toSymbol(shape: Shape): Symbol {
return shape.accept(this)
Expand Down
Expand Up @@ -28,8 +28,8 @@ import software.amazon.smithy.rust.codegen.core.smithy.ModuleProvider
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeCrateLocation
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait
Expand Down Expand Up @@ -74,7 +74,7 @@ private object CodegenCoreTestModules {
}
}

val TestSymbolVisitorConfig = SymbolVisitorConfig(
val TestRustSymbolProviderConfig = RustSymbolProviderConfig(
runtimeConfig = TestRuntimeConfig,
renameExceptions = true,
nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_ZERO_VALUE_V1,
Expand Down Expand Up @@ -116,7 +116,7 @@ fun String.asSmithyModel(sourceLocation: String? = null, smithyVersion: String =
internal fun testSymbolProvider(model: Model): RustSymbolProvider = SymbolVisitor(
model,
ServiceShape.builder().version("test").id("test#Service").build(),
TestSymbolVisitorConfig,
TestRustSymbolProviderConfig,
).let { BaseSymbolMetadataProvider(it, model, additionalAttributes = listOf(Attribute.NonExhaustive)) }
.let { RustReservedWordSymbolProvider(it, model) }

Expand Down
Expand Up @@ -15,15 +15,15 @@ import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.rust.codegen.core.smithy.MaybeRenamed
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.orNull
import software.amazon.smithy.rust.codegen.core.util.toPascalCase

internal class RustReservedWordSymbolProviderTest {
class Stub : RustSymbolProvider {
override fun config(): SymbolVisitorConfig = PANIC()
override fun config(): RustSymbolProviderConfig = PANIC()
override fun symbolForOperationError(operation: OperationShape): Symbol = PANIC()
override fun symbolForEventStreamError(eventStream: UnionShape): Symbol = PANIC()

Expand Down
Expand Up @@ -18,7 +18,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.smithy.Default
import software.amazon.smithy.rust.codegen.core.smithy.MaybeRenamed
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig
import software.amazon.smithy.rust.codegen.core.smithy.setDefault
import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
Expand Down Expand Up @@ -66,7 +66,7 @@ internal class BuilderGeneratorTest {
val baseProvider = testSymbolProvider(StructureGeneratorTest.model)
val provider =
object : RustSymbolProvider {
override fun config(): SymbolVisitorConfig {
override fun config(): RustSymbolProviderConfig {
return baseProvider.config()
}

Expand Down

0 comments on commit afb1f16

Please sign in to comment.