Skip to content

Commit

Permalink
Remove toEnumVariantName from RustSymbolProvider
Browse files Browse the repository at this point in the history
The `toEnumVariantName` function existed on symbol provider to work
around enum definitions not being shapes. In the future when we refactor
to use `EnumShape` instead of `EnumTrait`, there will be `MemberShape`s
for each enum member. This change incrementally moves us to that future
by creating fake `MemberShape`s in the enum generator from the enum
definition.
  • Loading branch information
jdisanti committed Feb 15, 2023
1 parent 1c1a3ef commit 81478db
Show file tree
Hide file tree
Showing 11 changed files with 147 additions and 126 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ import io.kotest.matchers.shouldBe
import org.junit.jupiter.api.Test
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.rust.codegen.client.testutil.testSymbolProvider
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.smithy.Default
import software.amazon.smithy.rust.codegen.core.smithy.defaultValue
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.util.lookup
Expand Down Expand Up @@ -38,8 +40,18 @@ internal class StreamingShapeSymbolProviderTest {
// "doing the right thing"
val modelWithOperationTraits = OperationNormalizer.transform(model)
val symbolProvider = testSymbolProvider(modelWithOperationTraits)
symbolProvider.toSymbol(modelWithOperationTraits.lookup<MemberShape>("test.synthetic#GenerateSpeechOutput\$data")).name shouldBe ("ByteStream")
symbolProvider.toSymbol(modelWithOperationTraits.lookup<MemberShape>("test.synthetic#GenerateSpeechInput\$data")).name shouldBe ("ByteStream")
modelWithOperationTraits.lookup<MemberShape>("test.synthetic#GenerateSpeechOutput\$data").also { shape ->
symbolProvider.toSymbol(shape).also { symbol ->
symbol.name shouldBe "data"
symbol.rustType() shouldBe RustType.Opaque("ByteStream", "aws_smithy_http::byte_stream")
}
}
modelWithOperationTraits.lookup<MemberShape>("test.synthetic#GenerateSpeechInput\$data").also { shape ->
symbolProvider.toSymbol(shape).also { symbol ->
symbol.name shouldBe "data"
symbol.rustType() shouldBe RustType.Opaque("ByteStream", "aws_smithy_http::byte_stream")
}
}
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,30 @@ import software.amazon.smithy.codegen.core.ReservedWordSymbolProvider
import software.amazon.smithy.codegen.core.ReservedWords
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.EnumShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.rust.codegen.core.smithy.MaybeRenamed
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.renamedFrom
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.core.util.orNull
import software.amazon.smithy.rust.codegen.core.util.toPascalCase

class RustReservedWordSymbolProvider(private val base: RustSymbolProvider, private val model: Model) :
WrappingSymbolProvider(base) {
private val internal =
ReservedWordSymbolProvider.builder().symbolProvider(base).memberReservedWords(RustReservedWords).build()

override fun toMemberName(shape: MemberShape): String {
val baseName = internal.toMemberName(shape)
return when (val container = model.expectShape(shape.container)) {
is StructureShape -> when (baseName) {
val baseName = super.toMemberName(shape)
val reservedWordReplacedName = internal.toMemberName(shape)
val container = model.expectShape(shape.container)
return when {
container is StructureShape -> when (baseName) {
"build" -> "build_value"
"builder" -> "builder_value"
"default" -> "default_value"
Expand All @@ -42,10 +43,10 @@ class RustReservedWordSymbolProvider(private val base: RustSymbolProvider, priva
"customize" -> "customize_value"
// To avoid conflicts with the error metadata `meta` field
"meta" -> "meta_value"
else -> baseName
else -> reservedWordReplacedName
}

is UnionShape -> when (baseName) {
container is UnionShape -> when (baseName) {
// Unions contain an `Unknown` variant. This exists to support parsing data returned from the server
// that represent union variants that have been added since this SDK was generated.
UnionGenerator.UnknownVariantName -> "${UnionGenerator.UnknownVariantName}Value"
Expand All @@ -55,7 +56,20 @@ class RustReservedWordSymbolProvider(private val base: RustSymbolProvider, priva
"Self" -> "SelfValue"
// Real models won't end in `_` so it's safe to stop here
"SelfValue" -> "SelfValue_"
else -> baseName
else -> reservedWordReplacedName
}

container is EnumShape || container.hasTrait<EnumTrait>() -> when (baseName) {
// Self cannot be used as a raw identifier, so we can't use the normal escaping strategy
// https://internals.rust-lang.org/t/raw-identifiers-dont-work-for-all-identifiers/9094/4
"Self" -> "SelfValue"
// Real models won't end in `_` so it's safe to stop here
"SelfValue" -> "SelfValue_"
// Unknown is used as the name of the variant containing unexpected values
"Unknown" -> "UnknownValue"
// Real models won't end in `_` so it's safe to stop here
"UnknownValue" -> "UnknownValue_"
else -> reservedWordReplacedName
}

else -> error("unexpected container: $container")
Expand All @@ -69,46 +83,31 @@ class RustReservedWordSymbolProvider(private val base: RustSymbolProvider, priva
* code generators to generate special docs.
*/
override fun toSymbol(shape: Shape): Symbol {
// Sanity check that the symbol provider stack is set up correctly
check(super.toSymbol(shape).renamedFrom() == null) {
"RustReservedWordSymbolProvider should only run once"
}

var renamedSymbol = internal.toSymbol(shape)
return when (shape) {
is MemberShape -> {
val container = model.expectShape(shape.container)
if (!(container is StructureShape || container is UnionShape)) {
val containerIsEnum = container is EnumShape || container.hasTrait<EnumTrait>()
if (container !is StructureShape && container !is UnionShape && !containerIsEnum) {
return base.toSymbol(shape)
}
val previousName = base.toMemberName(shape)
val escapedName = this.toMemberName(shape)
val baseSymbol = base.toSymbol(shape)
// if the names don't match and it isn't a simple escaping with `r#`, record a rename
baseSymbol.letIf(escapedName != previousName && !escapedName.contains("r#")) {
it.toBuilder().renamedFrom(previousName).build()
}
renamedSymbol.toBuilder().name(escapedName)
.letIf(escapedName != previousName && !escapedName.contains("r#")) {
it.renamedFrom(previousName)
}.build()
}

else -> base.toSymbol(shape)
}
}

override fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed? {
val baseName = base.toEnumVariantName(definition) ?: return null
check(definition.name.orNull()?.toPascalCase() == baseName.name) {
"Enum variants must already be in pascal case ${baseName.name} differed from ${baseName.name.toPascalCase()}. Definition: ${definition.name}"
}
check(baseName.renamedFrom == null) {
"definitions should only pass through the renamer once"
}
return when (baseName.name) {
// Self cannot be used as a raw identifier, so we can't use the normal escaping strategy
// https://internals.rust-lang.org/t/raw-identifiers-dont-work-for-all-identifiers/9094/4
"Self" -> MaybeRenamed("SelfValue", "Self")
// Real models won't end in `_` so it's safe to stop here
"SelfValue" -> MaybeRenamed("SelfValue_", "SelfValue")
// Unknown is used as the name of the variant containing unexpected values
"Unknown" -> MaybeRenamed("UnknownValue", "Unknown")
// Real models won't end in `_` so it's safe to stop here
"UnknownValue" -> MaybeRenamed("UnknownValue_", "UnknownValue")
else -> baseName
}
}
}

object RustReservedWords : ReservedWords {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.SensitiveTrait
import software.amazon.smithy.model.traits.StreamingTrait
Expand All @@ -33,7 +32,6 @@ import software.amazon.smithy.rust.codegen.core.util.hasTrait
*/
open class WrappingSymbolProvider(private val base: RustSymbolProvider) : RustSymbolProvider {
override fun config(): SymbolVisitorConfig = base.config()
override fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed? = base.toEnumVariantName(definition)
override fun toSymbol(shape: Shape): Symbol = base.toSymbol(shape)
override fun toMemberName(shape: MemberShape): String = base.toMemberName(shape)
override fun symbolForOperationError(operation: OperationShape): Symbol = base.symbolForOperationError(operation)
Expand Down Expand Up @@ -130,6 +128,11 @@ class BaseSymbolMetadataProvider(
}

is UnionShape, is CollectionShape, is MapShape -> RustMetadata(visibility = Visibility.PUBLIC)

// This covers strings with the enum trait for now, and can be removed once we're fully on EnumShape
// TODO(EnumShape): Remove this `is StringShape` match arm
is StringShape -> RustMetadata(visibility = Visibility.PUBLIC)

else -> TODO("Unrecognized container type: $container")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import software.amazon.smithy.model.shapes.BooleanShape
import software.amazon.smithy.model.shapes.ByteShape
import software.amazon.smithy.model.shapes.DocumentShape
import software.amazon.smithy.model.shapes.DoubleShape
import software.amazon.smithy.model.shapes.EnumShape
import software.amazon.smithy.model.shapes.FloatShape
import software.amazon.smithy.model.shapes.IntegerShape
import software.amazon.smithy.model.shapes.ListShape
Expand All @@ -35,7 +36,6 @@ import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.TimestampShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
Expand Down Expand Up @@ -179,7 +179,6 @@ data class MaybeRenamed(val name: String, val renamedFrom: String?)
*/
interface RustSymbolProvider : SymbolProvider, ModuleProvider {
fun config(): SymbolVisitorConfig
fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed?

override fun moduleForShape(shape: Shape): RustModule.LeafModule = config().moduleProvider.moduleForShape(shape)
override fun moduleForOperationError(operation: OperationShape): RustModule.LeafModule =
Expand Down Expand Up @@ -248,21 +247,13 @@ open class SymbolVisitor(
module.toType().resolve("${symbol.name}Error").toSymbol().toBuilder().locatedIn(module).build()
}

/**
* Return the name of a given `enum` variant. Note that this refers to `enum` in the Smithy context
* where enum is a trait that can be applied to [StringShape] and not in the Rust context of an algebraic data type.
*
* Because enum variants are not member shape, a separate handler is required.
*/
override fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed? {
val baseName = definition.name.orNull()?.toPascalCase() ?: return null
return MaybeRenamed(baseName, null)
}

override fun toMemberName(shape: MemberShape): String = when (val container = model.expectShape(shape.container)) {
is StructureShape -> shape.memberName.toSnakeCase()
is UnionShape -> shape.memberName.toPascalCase()
else -> error("unexpected container shape: $container")
override fun toMemberName(shape: MemberShape): String {
val container = model.expectShape(shape.container)
return when {
container is StructureShape -> shape.memberName.toSnakeCase()
container is UnionShape || container is EnumShape || container.hasTrait<EnumTrait>() -> shape.memberName.toPascalCase()
else -> error("unexpected container shape: $container")
}
}

override fun blobShape(shape: BlobShape?): Symbol {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.traits.DocumentationTrait
import software.amazon.smithy.model.traits.EnumDefinition
Expand All @@ -27,12 +29,14 @@ import software.amazon.smithy.rust.codegen.core.smithy.MaybeRenamed
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata
import software.amazon.smithy.rust.codegen.core.smithy.renamedFrom
import software.amazon.smithy.rust.codegen.core.util.REDACTION
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.expectTrait
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.orNull
import software.amazon.smithy.rust.codegen.core.util.shouldRedact
import software.amazon.smithy.rust.codegen.core.util.toPascalCase

data class EnumGeneratorContext(
val enumName: String,
Expand Down Expand Up @@ -71,14 +75,41 @@ abstract class EnumType {
}

/** Model that wraps [EnumDefinition] to calculate and cache values required to generate the Rust enum source. */
class EnumMemberModel(private val definition: EnumDefinition, private val symbolProvider: RustSymbolProvider) {
class EnumMemberModel(
private val parentShape: Shape,
private val definition: EnumDefinition,
private val symbolProvider: RustSymbolProvider,
) {
companion object {
/**
* Return the name of a given `enum` variant. Note that this refers to `enum` in the Smithy context
* where enum is a trait that can be applied to [StringShape] and not in the Rust context of an algebraic data type.
*
* Ordinarily, the symbol provider would determine this name, but the enum trait doesn't allow for this.
*
* TODO(EnumShape): Remove this function when refactoring to EnumShape.
*/
@Deprecated("This function will go away when we handle EnumShape instead of EnumTrait")
fun toEnumVariantName(
symbolProvider: RustSymbolProvider,
parentShape: Shape,
definition: EnumDefinition,
): MaybeRenamed? {
val name = definition.name.orNull()?.toPascalCase() ?: return null
// Create a fake member shape for symbol look up until we refactor to use EnumShape
val fakeMemberShape =
MemberShape.builder().id(parentShape.id.withMember(name)).target("smithy.api#String").build()
val symbol = symbolProvider.toSymbol(fakeMemberShape)
return MaybeRenamed(symbol.name, symbol.renamedFrom())
}
}
// Because enum variants always start with an upper case letter, they will never
// conflict with reserved words (which are always lower case), therefore, we never need
// to fall back to raw identifiers

val value: String get() = definition.value

fun name(): MaybeRenamed? = symbolProvider.toEnumVariantName(definition)
fun name(): MaybeRenamed? = toEnumVariantName(symbolProvider, parentShape, definition)

private fun renderDocumentation(writer: RustWriter) {
val name =
Expand All @@ -97,7 +128,7 @@ class EnumMemberModel(private val definition: EnumDefinition, private val symbol
}
}

fun derivedName() = checkNotNull(symbolProvider.toEnumVariantName(definition)).name
fun derivedName() = checkNotNull(toEnumVariantName(symbolProvider, parentShape, definition)).name

fun render(writer: RustWriter) {
renderDocumentation(writer)
Expand Down Expand Up @@ -138,7 +169,7 @@ open class EnumGenerator(
enumName = symbol.name,
enumMeta = symbol.expectRustMetadata(),
enumTrait = enumTrait,
sortedMembers = enumTrait.values.sortedBy { it.value }.map { EnumMemberModel(it, symbolProvider) },
sortedMembers = enumTrait.values.sortedBy { it.value }.map { EnumMemberModel(shape, it, symbolProvider) },
)

fun render(writer: RustWriter) {
Expand Down

0 comments on commit 81478db

Please sign in to comment.