From c3ae6f7eaf20e77014d6054d18cd0f03f754599f Mon Sep 17 00:00:00 2001 From: Julian Antonielli Date: Tue, 28 Feb 2023 20:26:20 +0000 Subject: [PATCH] Refactor event stream tests with `{client,server}IntegrationTest`s (#2342) * Refactor `ClientEventStreamUnmarshallerGeneratorTest` to use `clientIntegrationTest` (WIP) * Refactor `ClientEventStreamUnmarshallerGeneratorTest` with `clientIntegrationTest` * Refactor `ClientEventStreamUnmarshallerGeneratorTest` to use generic test cases * Start refactoring `ServerEventStreamUnmarshallerGeneratorTest` * Make `ServerEventStreamUnmarshallerGeneratorTest` tests work * Uncomment other test models * Allow unused on `parse_generic_error` * Rename `ServerEventStreamUnmarshallerGeneratorTest` * Make `EventStreamUnmarshallTestCases` codegenTarget-agnostic * Refactor `ClientEventStreamMarshallerGeneratorTest`: Tests run but fail * Refactor `ServerEventStreamMarshallerGeneratorTest` * Move `.into()` calls to `conditionalBuilderInput` * Add "context" to TODO * Fix client unmarshall tests * Fix clippy lint * Fix more clippy lints * Add docs for `event_stream_serde` module * Fix client tests * Remove `#[allow(missing_docs)]` from event stream module * Remove unused `EventStreamTestTools` * Add `smithy-validation-model` test dep to `codegen-client` * Temporarily add docs to make tests compile * Undo change in model * Make event stream unmarshaller tests a unit test * Remove unused code * Make `ServerEventStreamUnmarshallerGeneratorTest` a unit test * Make `ServerEventStreamMarshallerGeneratorTest` a unit test * Make `ServerEventStreamMarshallerGeneratorTest` pass * Make remaining tests non-integration tests * Make event stream serde module private again * Remove unnecessary clippy allowances * Remove clippy allowance * Remove docs for `event_stream_serde` module * Remove docs for `$unmarshallerTypeName::new` * Remove more unnecessary docs * Remove more superfluous docs * Undo unnecessary diffs * Uncomment last test * Make `conditionalBuilderInput` internal --- codegen-client/build.gradle.kts | 4 + .../ClientEventStreamBaseRequirements.kt | 98 ----- ...lientEventStreamMarshallerGeneratorTest.kt | 47 +-- ...entEventStreamUnmarshallerGeneratorTest.kt | 77 ++-- .../rust/codegen/core/rustlang/RustType.kt | 2 + .../rust/codegen/core/rustlang/RustWriter.kt | 2 +- .../parse/EventStreamUnmarshallerGenerator.kt | 20 +- .../EventStreamErrorMarshallerGenerator.kt | 5 +- .../EventStreamMarshallerGenerator.kt | 3 +- .../testutil/EventStreamMarshallTestCases.kt | 364 ++++++++++-------- .../core/testutil/EventStreamTestModels.kt | 21 +- .../core/testutil/EventStreamTestTools.kt | 185 --------- .../EventStreamUnmarshallTestCases.kt | 351 +++++++++-------- .../ServerEventStreamBaseRequirements.kt | 121 ------ ...erverEventStreamMarshallerGeneratorTest.kt | 68 ++-- ...verEventStreamUnmarshallerGeneratorTest.kt | 70 +--- .../inlineable/src/rest_xml_wrapped_errors.rs | 1 + 17 files changed, 541 insertions(+), 898 deletions(-) delete mode 100644 codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamBaseRequirements.kt delete mode 100644 codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestTools.kt delete mode 100644 codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamBaseRequirements.kt diff --git a/codegen-client/build.gradle.kts b/codegen-client/build.gradle.kts index ba6ac6ac1b..62d543beb4 100644 --- a/codegen-client/build.gradle.kts +++ b/codegen-client/build.gradle.kts @@ -28,6 +28,10 @@ dependencies { implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-waiters:$smithyVersion") implementation("software.amazon.smithy:smithy-rules-engine:$smithyVersion") + + // `smithy.framework#ValidationException` is defined here, which is used in event stream +// marshalling/unmarshalling tests. + testImplementation("software.amazon.smithy:smithy-validation-model:$smithyVersion") } tasks.compileKotlin { diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamBaseRequirements.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamBaseRequirements.kt deleted file mode 100644 index 3ac00e5cd6..0000000000 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamBaseRequirements.kt +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.client.smithy.protocols.eventstream - -import org.junit.jupiter.api.extension.ExtensionContext -import org.junit.jupiter.params.provider.Arguments -import org.junit.jupiter.params.provider.ArgumentsProvider -import software.amazon.smithy.model.Model -import software.amazon.smithy.model.shapes.ServiceShape -import software.amazon.smithy.model.shapes.Shape -import software.amazon.smithy.model.shapes.ShapeId -import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.model.traits.ErrorTrait -import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext -import software.amazon.smithy.rust.codegen.client.smithy.customize.CombinedClientCodegenDecorator -import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorGenerator -import software.amazon.smithy.rust.codegen.client.smithy.generators.error.OperationErrorGenerator -import software.amazon.smithy.rust.codegen.client.testutil.testClientRustSettings -import software.amazon.smithy.rust.codegen.client.testutil.testSymbolProvider -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.rustlang.implBlock -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.RustCrate -import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestRequirements -import software.amazon.smithy.rust.codegen.core.util.expectTrait -import java.util.stream.Stream - -class TestCasesProvider : ArgumentsProvider { - override fun provideArguments(context: ExtensionContext?): Stream = - EventStreamTestModels.TEST_CASES.map { Arguments.of(it) }.stream() -} - -abstract class ClientEventStreamBaseRequirements : EventStreamTestRequirements { - override fun createCodegenContext( - model: Model, - serviceShape: ServiceShape, - protocolShapeId: ShapeId, - codegenTarget: CodegenTarget, - ): ClientCodegenContext = ClientCodegenContext( - model, - testSymbolProvider(model), - serviceShape, - protocolShapeId, - testClientRustSettings(), - CombinedClientCodegenDecorator(emptyList()), - ) - - override fun renderBuilderForShape( - rustCrate: RustCrate, - writer: RustWriter, - codegenContext: ClientCodegenContext, - shape: StructureShape, - ) { - BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape, emptyList()).apply { - render(writer) - } - writer.implBlock(codegenContext.symbolProvider.toSymbol(shape)) { - BuilderGenerator.renderConvenienceMethod(writer, codegenContext.symbolProvider, shape) - } - } - - override fun renderOperationError( - writer: RustWriter, - model: Model, - symbolProvider: RustSymbolProvider, - operationOrEventStream: Shape, - ) { - OperationErrorGenerator(model, symbolProvider, operationOrEventStream, emptyList()).render(writer) - } - - override fun renderError( - rustCrate: RustCrate, - writer: RustWriter, - codegenContext: ClientCodegenContext, - shape: StructureShape, - ) { - val errorTrait = shape.expectTrait() - val errorGenerator = ErrorGenerator( - codegenContext.model, - codegenContext.symbolProvider, - shape, - errorTrait, - emptyList(), - ) - rustCrate.useShapeWriter(shape) { - errorGenerator.renderStruct(this) - } - rustCrate.withModule(codegenContext.symbolProvider.moduleForBuilder(shape)) { - errorGenerator.renderBuilder(this) - } - } -} diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamMarshallerGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamMarshallerGeneratorTest.kt index b73d12a8d8..349f6a8cf3 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamMarshallerGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamMarshallerGeneratorTest.kt @@ -5,43 +5,30 @@ package software.amazon.smithy.rust.codegen.client.smithy.protocols.eventstream +import org.junit.jupiter.api.extension.ExtensionContext import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.ArgumentsProvider import org.junit.jupiter.params.provider.ArgumentsSource -import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol -import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.EventStreamMarshallerGenerator +import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamMarshallTestCases.writeMarshallTestCases import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety -import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject -import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig -import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.testModule +import java.util.stream.Stream class ClientEventStreamMarshallerGeneratorTest { @ParameterizedTest @ArgumentsSource(TestCasesProvider::class) fun test(testCase: EventStreamTestModels.TestCase) { - EventStreamTestTools.setupTestCase( - testCase, - object : ClientEventStreamBaseRequirements() { - override fun renderGenerator( - codegenContext: ClientCodegenContext, - project: TestEventStreamProject, - protocol: Protocol, - ): RuntimeType = EventStreamMarshallerGenerator( - project.model, - CodegenTarget.CLIENT, - TestRuntimeConfig, - project.symbolProvider, - project.streamShape, - protocol.structuredDataSerializer(project.operationShape), - testCase.requestContentType, - ).render() - }, - CodegenTarget.CLIENT, - EventStreamTestVariety.Marshall, - ).compileAndTest() + clientIntegrationTest(testCase.model) { _, rustCrate -> + rustCrate.testModule { + writeMarshallTestCases(testCase, optionalBuilderInputs = false) + } + } } } + +class TestCasesProvider : ArgumentsProvider { + override fun provideArguments(context: ExtensionContext?): Stream = + EventStreamTestModels.TEST_CASES.map { Arguments.of(it) }.stream() +} diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamUnmarshallerGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamUnmarshallerGeneratorTest.kt index 11a253ceef..db44d62a61 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamUnmarshallerGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamUnmarshallerGeneratorTest.kt @@ -7,39 +7,60 @@ package software.amazon.smithy.rust.codegen.client.smithy.protocols.eventstream import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.ArgumentsSource -import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol -import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.EventStreamUnmarshallerGenerator +import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest +import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety -import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject -import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.writeUnmarshallTestCases +import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams +import software.amazon.smithy.rust.codegen.core.testutil.testModule +import software.amazon.smithy.rust.codegen.core.testutil.unitTest class ClientEventStreamUnmarshallerGeneratorTest { @ParameterizedTest @ArgumentsSource(TestCasesProvider::class) fun test(testCase: EventStreamTestModels.TestCase) { - EventStreamTestTools.setupTestCase( - testCase, - object : ClientEventStreamBaseRequirements() { - override fun renderGenerator( - codegenContext: ClientCodegenContext, - project: TestEventStreamProject, - protocol: Protocol, - ): RuntimeType { - return EventStreamUnmarshallerGenerator( - protocol, - codegenContext, - project.operationShape, - project.streamShape, - ).render() - } - }, - CodegenTarget.CLIENT, - EventStreamTestVariety.Unmarshall, - ).compileAndTest() + clientIntegrationTest( + testCase.model, + IntegrationTestParams(service = "test#TestService", addModuleToEventStreamAllowList = true), + ) { _, rustCrate -> + val generator = "crate::event_stream_serde::TestStreamUnmarshaller" + + rustCrate.testModule { + rust("##![allow(unused_imports, dead_code)]") + writeUnmarshallTestCases(testCase, optionalBuilderInputs = false) + + unitTest( + "unknown_message", + """ + let message = msg("event", "NewUnmodeledMessageType", "application/octet-stream", b"hello, world!"); + let result = $generator::new().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + assert!(expect_event(result.unwrap()).is_unknown()); + """, + ) + + unitTest( + "generic_error", + """ + let message = msg( + "exception", + "UnmodeledError", + "${testCase.responseContentType}", + br#"${testCase.validUnmodeledError}"# + ); + let result = $generator::new().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + match expect_error(result.unwrap()) { + TestStreamError::Unhandled(err) => { + let message = format!("{}", aws_smithy_types::error::display::DisplayErrorContext(&err)); + let expected = "message: \"unmodeled error\""; + assert!(message.contains(expected), "Expected '{message}' to contain '{expected}'"); + } + kind => panic!("expected generic error, but got {:?}", kind), + } + """, + ) + } + } } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt index 0ada867b42..9503607980 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt @@ -470,9 +470,11 @@ class Attribute(val inner: Writable) { val AllowDeprecated = Attribute(allow("deprecated")) val AllowIrrefutableLetPatterns = Attribute(allow("irrefutable_let_patterns")) val AllowUnreachableCode = Attribute(allow("unreachable_code")) + val AllowUnreachablePatterns = Attribute(allow("unreachable_patterns")) val AllowUnusedImports = Attribute(allow("unused_imports")) val AllowUnusedMut = Attribute(allow("unused_mut")) val AllowUnusedVariables = Attribute(allow("unused_variables")) + val AllowMissingDocs = Attribute(allow("missing_docs")) val CfgTest = Attribute(cfg("test")) val DenyMissingDocs = Attribute(deny("missing_docs")) val DocHidden = Attribute(doc("hidden")) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt index 19a1262696..1628fe9cca 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt @@ -116,7 +116,7 @@ private fun , U> T.withTemplate( * This enables conditionally wrapping a block in a prefix/suffix, e.g. * * ``` - * writer.withBlock("Some(", ")", conditional = symbol.isOptional()) { + * writer.conditionalBlock("Some(", ")", conditional = symbol.isOptional()) { * write("symbolValue") * } * ``` diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt index ae80c72c06..4fa98fa6ee 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt @@ -43,6 +43,9 @@ import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.toPascalCase +fun RustModule.Companion.eventStreamSerdeModule(): RustModule.LeafModule = + private("event_stream_serde") + class EventStreamUnmarshallerGenerator( private val protocol: Protocol, codegenContext: CodegenContext, @@ -60,7 +63,7 @@ class EventStreamUnmarshallerGenerator( symbolProvider.symbolForEventStreamError(unionShape) } private val smithyEventStream = RuntimeType.smithyEventStream(runtimeConfig) - private val eventStreamSerdeModule = RustModule.private("event_stream_serde") + private val eventStreamSerdeModule = RustModule.eventStreamSerdeModule() private val codegenScope = arrayOf( "Blob" to RuntimeType.blob(runtimeConfig), "expect_fns" to smithyEventStream.resolve("smithy"), @@ -84,15 +87,16 @@ class EventStreamUnmarshallerGenerator( } private fun RustWriter.renderUnmarshaller(unmarshallerType: RuntimeType, unionSymbol: Symbol) { + val unmarshallerTypeName = unmarshallerType.name rust( """ ##[non_exhaustive] ##[derive(Debug)] - pub struct ${unmarshallerType.name}; + pub struct $unmarshallerTypeName; - impl ${unmarshallerType.name} { + impl $unmarshallerTypeName { pub fn new() -> Self { - ${unmarshallerType.name} + $unmarshallerTypeName } } """, @@ -154,6 +158,7 @@ class EventStreamUnmarshallerGenerator( "Output" to unionSymbol, *codegenScope, ) + false -> rustTemplate( "return Err(#{Error}::unmarshalling(format!(\"unrecognized :event-type: {}\", _unknown_variant)));", *codegenScope, @@ -179,6 +184,7 @@ class EventStreamUnmarshallerGenerator( *codegenScope, ) } + payloadOnly -> { withBlock("let parsed = ", ";") { renderParseProtocolPayload(unionMember) @@ -189,6 +195,7 @@ class EventStreamUnmarshallerGenerator( *codegenScope, ) } + else -> { rust("let mut builder = #T::default();", symbolProvider.symbolForBuilder(unionStruct)) val payloadMember = unionStruct.members().firstOrNull { it.hasTrait() } @@ -265,6 +272,7 @@ class EventStreamUnmarshallerGenerator( is BlobShape -> { rustTemplate("#{Blob}::new(message.payload().as_ref())", *codegenScope) } + is StringShape -> { rustTemplate( """ @@ -275,6 +283,7 @@ class EventStreamUnmarshallerGenerator( *codegenScope, ) } + is UnionShape, is StructureShape -> { renderParseProtocolPayload(member) } @@ -312,6 +321,7 @@ class EventStreamUnmarshallerGenerator( *codegenScope, ) } + CodegenTarget.SERVER -> {} } @@ -350,6 +360,7 @@ class EventStreamUnmarshallerGenerator( ) } } + CodegenTarget.SERVER -> { val target = model.expectShape(member.target, StructureShape::class.java) val parser = protocol.structuredDataParser(operationShape).errorParser(target) @@ -391,6 +402,7 @@ class EventStreamUnmarshallerGenerator( CodegenTarget.CLIENT -> { rustTemplate("Ok(#{UnmarshalledMessage}::Error(#{OpError}::generic(generic)))", *codegenScope) } + CodegenTarget.SERVER -> { rustTemplate( """ diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt index ebd42d609d..3a0c5c1b30 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt @@ -25,6 +25,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant import software.amazon.smithy.rust.codegen.core.smithy.generators.unknownVariantError +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.eventStreamSerdeModule import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticEventStreamUnionTrait import software.amazon.smithy.rust.codegen.core.smithy.transformers.eventStreamErrors @@ -49,7 +50,7 @@ class EventStreamErrorMarshallerGenerator( } else { symbolProvider.symbolForEventStreamError(unionShape) } - private val eventStreamSerdeModule = RustModule.private("event_stream_serde") + private val eventStreamSerdeModule = RustModule.eventStreamSerdeModule() private val errorsShape = unionShape.expectTrait() private val codegenScope = arrayOf( "MarshallMessage" to smithyEventStream.resolve("frame::MarshallMessage"), @@ -126,7 +127,7 @@ class EventStreamErrorMarshallerGenerator( } } - fun RustWriter.renderMarshallEvent(unionMember: MemberShape, eventStruct: StructureShape) { + private fun RustWriter.renderMarshallEvent(unionMember: MemberShape, eventStruct: StructureShape) { val headerMembers = eventStruct.members().filter { it.hasTrait() } val payloadMember = eventStruct.members().firstOrNull { it.hasTrait() } for (member in headerMembers) { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt index cb6833aaf7..201cd82ed5 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt @@ -38,6 +38,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant import software.amazon.smithy.rust.codegen.core.smithy.generators.unknownVariantError import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.eventStreamSerdeModule import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.hasTrait @@ -53,7 +54,7 @@ open class EventStreamMarshallerGenerator( private val payloadContentType: String, ) { private val smithyEventStream = RuntimeType.smithyEventStream(runtimeConfig) - private val eventStreamSerdeModule = RustModule.private("event_stream_serde") + private val eventStreamSerdeModule = RustModule.eventStreamSerdeModule() private val codegenScope = arrayOf( "MarshallMessage" to smithyEventStream.resolve("frame::MarshallMessage"), "Message" to smithyEventStream.resolve("frame::Message"), diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt index 6e82fc1b2c..95ea5677a4 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt @@ -5,26 +5,35 @@ package software.amazon.smithy.rust.codegen.core.testutil +import org.intellij.lang.annotations.Language import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.util.dq -internal object EventStreamMarshallTestCases { - internal fun RustWriter.writeMarshallTestCases( +object EventStreamMarshallTestCases { + fun RustWriter.writeMarshallTestCases( testCase: EventStreamTestModels.TestCase, - generator: RuntimeType, + optionalBuilderInputs: Boolean, ) { + val generator = "crate::event_stream_serde::TestStreamMarshaller" + val protocolTestHelpers = CargoDependency.smithyProtocolTestHelpers(TestRuntimeConfig) .copy(scope = DependencyScope.Compile) + + fun builderInput( + @Language("Rust", prefix = "macro_rules! foo { () => {{\n", suffix = "\n}}}") + input: String, + vararg ctx: Pair, + ): Writable = conditionalBuilderInput(input, conditional = optionalBuilderInputs, ctx = ctx) + rustTemplate( """ use aws_smithy_eventstream::frame::{Message, Header, HeaderValue, MarshallMessage}; use std::collections::HashMap; use aws_smithy_types::{Blob, DateTime}; - use crate::error::*; use crate::model::*; use #{validate_body}; @@ -46,163 +55,192 @@ internal object EventStreamMarshallTestCases { "MediaType" to protocolTestHelpers.toType().resolve("MediaType"), ) - unitTest( - "message_with_blob", - """ - let event = TestStream::MessageWithBlob( - MessageWithBlob::builder().data(Blob::new(&b"hello, world!"[..])).build() - ); - let result = ${format(generator)}().marshall(event); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - let message = result.unwrap(); - let headers = headers_to_map(message.headers()); - assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); - assert_eq!(&str_header("MessageWithBlob"), *headers.get(":event-type").unwrap()); - assert_eq!(&str_header("application/octet-stream"), *headers.get(":content-type").unwrap()); - assert_eq!(&b"hello, world!"[..], message.payload()); - """, - ) - - unitTest( - "message_with_string", - """ - let event = TestStream::MessageWithString( - MessageWithString::builder().data("hello, world!").build() - ); - let result = ${format(generator)}().marshall(event); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - let message = result.unwrap(); - let headers = headers_to_map(message.headers()); - assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); - assert_eq!(&str_header("MessageWithString"), *headers.get(":event-type").unwrap()); - assert_eq!(&str_header("text/plain"), *headers.get(":content-type").unwrap()); - assert_eq!(&b"hello, world!"[..], message.payload()); - """, - ) - - unitTest( - "message_with_struct", - """ - let event = TestStream::MessageWithStruct( - MessageWithStruct::builder().some_struct( - TestStruct::builder() - .some_string("hello") - .some_int(5) - .build() - ).build() - ); - let result = ${format(generator)}().marshall(event); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - let message = result.unwrap(); - let headers = headers_to_map(message.headers()); - assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); - assert_eq!(&str_header("MessageWithStruct"), *headers.get(":event-type").unwrap()); - assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap()); - - validate_body( - message.payload(), - ${testCase.validTestStruct.dq()}, - MediaType::from(${testCase.requestContentType.dq()}) - ).unwrap(); - """, - ) - - unitTest( - "message_with_union", - """ - let event = TestStream::MessageWithUnion(MessageWithUnion::builder().some_union( - TestUnion::Foo("hello".into()) - ).build()); - let result = ${format(generator)}().marshall(event); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - let message = result.unwrap(); - let headers = headers_to_map(message.headers()); - assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); - assert_eq!(&str_header("MessageWithUnion"), *headers.get(":event-type").unwrap()); - assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap()); - - validate_body( - message.payload(), - ${testCase.validTestUnion.dq()}, - MediaType::from(${testCase.requestContentType.dq()}) - ).unwrap(); - """, - ) - - unitTest( - "message_with_headers", - """ - let event = TestStream::MessageWithHeaders(MessageWithHeaders::builder() - .blob(Blob::new(&b"test"[..])) - .boolean(true) - .byte(55i8) - .int(100_000i32) - .long(9_000_000_000i64) - .short(16_000i16) - .string("test") - .timestamp(DateTime::from_secs(5)) - .build() - ); - let result = ${format(generator)}().marshall(event); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - let actual_message = result.unwrap(); - let expected_message = Message::new(&b""[..]) - .add_header(Header::new(":message-type", HeaderValue::String("event".into()))) - .add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaders".into()))) - .add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into()))) - .add_header(Header::new("boolean", HeaderValue::Bool(true))) - .add_header(Header::new("byte", HeaderValue::Byte(55i8))) - .add_header(Header::new("int", HeaderValue::Int32(100_000i32))) - .add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64))) - .add_header(Header::new("short", HeaderValue::Int16(16_000i16))) - .add_header(Header::new("string", HeaderValue::String("test".into()))) - .add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5)))); - assert_eq!(expected_message, actual_message); - """, - ) - - unitTest( - "message_with_header_and_payload", - """ - let event = TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder() - .header("header") - .payload(Blob::new(&b"payload"[..])) - .build() - ); - let result = ${format(generator)}().marshall(event); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - let actual_message = result.unwrap(); - let expected_message = Message::new(&b"payload"[..]) - .add_header(Header::new(":message-type", HeaderValue::String("event".into()))) - .add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaderAndPayload".into()))) - .add_header(Header::new("header", HeaderValue::String("header".into()))) - .add_header(Header::new(":content-type", HeaderValue::String("application/octet-stream".into()))); - assert_eq!(expected_message, actual_message); - """, - ) - - unitTest( - "message_with_no_header_payload_traits", - """ - let event = TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder() - .some_int(5) - .some_string("hello") - .build() - ); - let result = ${format(generator)}().marshall(event); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - let message = result.unwrap(); - let headers = headers_to_map(message.headers()); - assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); - assert_eq!(&str_header("MessageWithNoHeaderPayloadTraits"), *headers.get(":event-type").unwrap()); - assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap()); - - validate_body( - message.payload(), - ${testCase.validMessageWithNoHeaderPayloadTraits.dq()}, - MediaType::from(${testCase.requestContentType.dq()}) - ).unwrap(); - """, - ) + unitTest("message_with_blob") { + rustTemplate( + """ + let event = TestStream::MessageWithBlob( + MessageWithBlob::builder().data(#{BlobInput:W}).build() + ); + let result = $generator::new().marshall(event); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + let message = result.unwrap(); + let headers = headers_to_map(message.headers()); + assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); + assert_eq!(&str_header("MessageWithBlob"), *headers.get(":event-type").unwrap()); + assert_eq!(&str_header("application/octet-stream"), *headers.get(":content-type").unwrap()); + assert_eq!(&b"hello, world!"[..], message.payload()); + """, + "BlobInput" to builderInput("Blob::new(&b\"hello, world!\"[..])"), + ) + } + + unitTest("message_with_string") { + rustTemplate( + """ + let event = TestStream::MessageWithString( + MessageWithString::builder().data(#{StringInput}).build() + ); + let result = $generator::new().marshall(event); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + let message = result.unwrap(); + let headers = headers_to_map(message.headers()); + assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); + assert_eq!(&str_header("MessageWithString"), *headers.get(":event-type").unwrap()); + assert_eq!(&str_header("text/plain"), *headers.get(":content-type").unwrap()); + assert_eq!(&b"hello, world!"[..], message.payload()); + """, + "StringInput" to builderInput("\"hello, world!\""), + ) + } + + unitTest("message_with_struct") { + rustTemplate( + """ + let event = TestStream::MessageWithStruct( + MessageWithStruct::builder().some_struct(#{StructInput}).build() + ); + let result = $generator::new().marshall(event); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + let message = result.unwrap(); + let headers = headers_to_map(message.headers()); + assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); + assert_eq!(&str_header("MessageWithStruct"), *headers.get(":event-type").unwrap()); + assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap()); + + validate_body( + message.payload(), + ${testCase.validTestStruct.dq()}, + MediaType::from(${testCase.mediaType.dq()}) + ).unwrap(); + """, + "StructInput" to + builderInput( + """ + TestStruct::builder() + .some_string(#{StringInput}) + .some_int(#{IntInput}) + .build() + """, + "IntInput" to builderInput("5"), + "StringInput" to builderInput("\"hello\""), + ), + ) + } + + unitTest("message_with_union") { + rustTemplate( + """ + let event = TestStream::MessageWithUnion(MessageWithUnion::builder() + .some_union(#{UnionInput}) + .build() + ); + let result = $generator::new().marshall(event); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + let message = result.unwrap(); + let headers = headers_to_map(message.headers()); + assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); + assert_eq!(&str_header("MessageWithUnion"), *headers.get(":event-type").unwrap()); + assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap()); + + validate_body( + message.payload(), + ${testCase.validTestUnion.dq()}, + MediaType::from(${testCase.mediaType.dq()}) + ).unwrap(); + """, + "UnionInput" to builderInput("TestUnion::Foo(\"hello\".into())"), + ) + } + + unitTest("message_with_headers") { + rustTemplate( + """ + let event = TestStream::MessageWithHeaders(MessageWithHeaders::builder() + .blob(#{BlobInput}) + .boolean(#{BooleanInput}) + .byte(#{ByteInput}) + .int(#{IntInput}) + .long(#{LongInput}) + .short(#{ShortInput}) + .string(#{StringInput}) + .timestamp(#{TimestampInput}) + .build() + ); + let result = $generator::new().marshall(event); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + let actual_message = result.unwrap(); + let expected_message = Message::new(&b""[..]) + .add_header(Header::new(":message-type", HeaderValue::String("event".into()))) + .add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaders".into()))) + .add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into()))) + .add_header(Header::new("boolean", HeaderValue::Bool(true))) + .add_header(Header::new("byte", HeaderValue::Byte(55i8))) + .add_header(Header::new("int", HeaderValue::Int32(100_000i32))) + .add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64))) + .add_header(Header::new("short", HeaderValue::Int16(16_000i16))) + .add_header(Header::new("string", HeaderValue::String("test".into()))) + .add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5)))); + assert_eq!(expected_message, actual_message); + """, + "BlobInput" to builderInput("Blob::new(&b\"test\"[..])"), + "BooleanInput" to builderInput("true"), + "ByteInput" to builderInput("55i8"), + "IntInput" to builderInput("100_000i32"), + "LongInput" to builderInput("9_000_000_000i64"), + "ShortInput" to builderInput("16_000i16"), + "StringInput" to builderInput("\"test\""), + "TimestampInput" to builderInput("DateTime::from_secs(5)"), + ) + } + + unitTest("message_with_header_and_payload") { + rustTemplate( + """ + let event = TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder() + .header(#{HeaderInput}) + .payload(#{PayloadInput}) + .build() + ); + let result = $generator::new().marshall(event); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + let actual_message = result.unwrap(); + let expected_message = Message::new(&b"payload"[..]) + .add_header(Header::new(":message-type", HeaderValue::String("event".into()))) + .add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaderAndPayload".into()))) + .add_header(Header::new("header", HeaderValue::String("header".into()))) + .add_header(Header::new(":content-type", HeaderValue::String("application/octet-stream".into()))); + assert_eq!(expected_message, actual_message); + """, + "HeaderInput" to builderInput("\"header\""), + "PayloadInput" to builderInput("Blob::new(&b\"payload\"[..])"), + ) + } + + unitTest("message_with_no_header_payload_traits") { + rustTemplate( + """ + let event = TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder() + .some_int(#{IntInput}) + .some_string(#{StringInput}) + .build() + ); + let result = $generator::new().marshall(event); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + let message = result.unwrap(); + let headers = headers_to_map(message.headers()); + assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); + assert_eq!(&str_header("MessageWithNoHeaderPayloadTraits"), *headers.get(":event-type").unwrap()); + assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap()); + + validate_body( + message.payload(), + ${testCase.validMessageWithNoHeaderPayloadTraits.dq()}, + MediaType::from(${testCase.mediaType.dq()}) + ).unwrap(); + """, + "IntInput" to builderInput("5"), + "StringInput" to builderInput("\"hello\""), + ) + } } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt index 58ab85eec6..e944a552a0 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt @@ -19,6 +19,7 @@ private fun fillInBaseModel( ): String = """ namespace test + use smithy.framework#ValidationException use aws.protocols#$protocolName union TestUnion { @@ -69,12 +70,20 @@ private fun fillInBaseModel( MessageWithNoHeaderPayloadTraits: MessageWithNoHeaderPayloadTraits, SomeError: SomeError, } - structure TestStreamInputOutput { @httpPayload @required value: TestStream } + + structure TestStreamInputOutput { + @required + @httpPayload + value: TestStream + } + + @http(method: "POST", uri: "/test") operation TestStreamOp { input: TestStreamInputOutput, output: TestStreamInputOutput, - errors: [SomeError], + errors: [SomeError, ValidationException], } + $extraServiceAnnotations @$protocolName service TestService { version: "123", operations: [TestStreamOp] } @@ -92,6 +101,7 @@ object EventStreamTestModels { data class TestCase( val protocolShapeId: String, val model: Model, + val mediaType: String, val requestContentType: String, val responseContentType: String, val validTestStruct: String, @@ -111,7 +121,8 @@ object EventStreamTestModels { TestCase( protocolShapeId = "aws.protocols#restJson1", model = restJson1(), - requestContentType = "application/json", + mediaType = "application/json", + requestContentType = "application/vnd.amazon.eventstream", responseContentType = "application/json", validTestStruct = """{"someString":"hello","someInt":5}""", validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""", @@ -126,6 +137,7 @@ object EventStreamTestModels { TestCase( protocolShapeId = "aws.protocols#awsJson1_1", model = awsJson11(), + mediaType = "application/x-amz-json-1.1", requestContentType = "application/x-amz-json-1.1", responseContentType = "application/x-amz-json-1.1", validTestStruct = """{"someString":"hello","someInt":5}""", @@ -141,7 +153,8 @@ object EventStreamTestModels { TestCase( protocolShapeId = "aws.protocols#restXml", model = restXml(), - requestContentType = "application/xml", + mediaType = "application/xml", + requestContentType = "application/vnd.amazon.eventstream", responseContentType = "application/xml", validTestStruct = """ diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestTools.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestTools.kt deleted file mode 100644 index cfa92fb70c..0000000000 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestTools.kt +++ /dev/null @@ -1,185 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.core.testutil - -import software.amazon.smithy.model.Model -import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.model.shapes.ServiceShape -import software.amazon.smithy.model.shapes.Shape -import software.amazon.smithy.model.shapes.ShapeId -import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.model.shapes.UnionShape -import software.amazon.smithy.model.traits.ErrorTrait -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.DirectedWalker -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.RustCrate -import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant -import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol -import software.amazon.smithy.rust.codegen.core.smithy.transformers.EventStreamNormalizer -import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamMarshallTestCases.writeMarshallTestCases -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.writeUnmarshallTestCases -import software.amazon.smithy.rust.codegen.core.util.hasTrait -import software.amazon.smithy.rust.codegen.core.util.lookup -import software.amazon.smithy.rust.codegen.core.util.outputShape - -data class TestEventStreamProject( - val model: Model, - val serviceShape: ServiceShape, - val operationShape: OperationShape, - val streamShape: UnionShape, - val symbolProvider: RustSymbolProvider, - val project: TestWriterDelegator, -) - -enum class EventStreamTestVariety { - Marshall, - Unmarshall, -} - -interface EventStreamTestRequirements { - /** Create a codegen context for the tests */ - fun createCodegenContext( - model: Model, - serviceShape: ServiceShape, - protocolShapeId: ShapeId, - codegenTarget: CodegenTarget, - ): C - - /** Render the event stream marshall/unmarshall code generator */ - fun renderGenerator( - codegenContext: C, - project: TestEventStreamProject, - protocol: Protocol, - ): RuntimeType - - /** Render a builder for the given shape */ - fun renderBuilderForShape( - rustCrate: RustCrate, - writer: RustWriter, - codegenContext: C, - shape: StructureShape, - ) - - /** Render an operation error for the given operation and error shapes */ - fun renderOperationError( - writer: RustWriter, - model: Model, - symbolProvider: RustSymbolProvider, - operationOrEventStream: Shape, - ) - - /** Render an error struct and builder */ - fun renderError(rustCrate: RustCrate, writer: RustWriter, codegenContext: C, shape: StructureShape) -} - -object EventStreamTestTools { - fun setupTestCase( - testCase: EventStreamTestModels.TestCase, - requirements: EventStreamTestRequirements, - codegenTarget: CodegenTarget, - variety: EventStreamTestVariety, - transformers: List<(Model) -> Model> = listOf(), - ): TestWriterDelegator { - val model = (listOf(OperationNormalizer::transform, EventStreamNormalizer::transform) + transformers).fold(testCase.model) { model, transformer -> - transformer(model) - } - - val serviceShape = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape - val codegenContext = requirements.createCodegenContext( - model, - serviceShape, - ShapeId.from(testCase.protocolShapeId), - codegenTarget, - ) - val test = generateTestProject(requirements, codegenContext, codegenTarget) - val protocol = testCase.protocolBuilder(codegenContext) - val generator = requirements.renderGenerator(codegenContext, test, protocol) - - test.project.lib { - when (variety) { - EventStreamTestVariety.Marshall -> writeMarshallTestCases(testCase, generator) - EventStreamTestVariety.Unmarshall -> writeUnmarshallTestCases(testCase, codegenTarget, generator) - } - } - - return test.project - } - - private fun generateTestProject( - requirements: EventStreamTestRequirements, - codegenContext: C, - codegenTarget: CodegenTarget, - ): TestEventStreamProject { - val model = codegenContext.model - val symbolProvider = codegenContext.symbolProvider - val operationShape = model.expectShape(ShapeId.from("test#TestStreamOp")) as OperationShape - val unionShape = model.expectShape(ShapeId.from("test#TestStream")) as UnionShape - val walker = DirectedWalker(model) - - val project = TestWorkspace.testProject(symbolProvider) - val errors = model.serviceShapes - .flatMap { walker.walkShapes(it) } - .filterIsInstance() - .filter { shape -> shape.hasTrait() } - check(errors.isNotEmpty()) { "must have at least one error modeled" } - project.withModule(codegenContext.symbolProvider.moduleForShape(errors[0])) { - requirements.renderOperationError(this, model, symbolProvider, operationShape) - requirements.renderOperationError(this, model, symbolProvider, unionShape) - for (shape in errors) { - requirements.renderError(project, this, codegenContext, shape) - } - } - val inputOutput = model.lookup("test#TestStreamInputOutput") - project.withModule(codegenContext.symbolProvider.moduleForShape(inputOutput)) { - recursivelyGenerateModels(project, model, symbolProvider, inputOutput, this, codegenTarget) - } - operationShape.outputShape(model).also { outputShape -> - outputShape.renderWithModelBuilder(model, symbolProvider, project) - } - return TestEventStreamProject( - model, - codegenContext.serviceShape, - operationShape, - unionShape, - symbolProvider, - project, - ) - } - - private fun recursivelyGenerateModels( - rustCrate: RustCrate, - model: Model, - symbolProvider: RustSymbolProvider, - shape: Shape, - writer: RustWriter, - mode: CodegenTarget, - ) { - for (member in shape.members()) { - if (member.target.namespace == "smithy.api") { - continue - } - val target = model.expectShape(member.target) - when (target) { - is StructureShape -> target.renderWithModelBuilder(model, symbolProvider, rustCrate) - is UnionShape -> UnionGenerator( - model, - symbolProvider, - writer, - target, - renderUnknownVariant = mode.renderUnknownVariant(), - ).render() - else -> TODO("EventStreamTestTools doesn't support rendering $target") - } - recursivelyGenerateModels(rustCrate, model, symbolProvider, target, writer, mode) - } - } -} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamUnmarshallTestCases.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamUnmarshallTestCases.kt index 36b3efafd0..daff01af3f 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamUnmarshallTestCases.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamUnmarshallTestCases.kt @@ -5,17 +5,22 @@ package software.amazon.smithy.rust.codegen.core.testutil +import org.intellij.lang.annotations.Language +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock import software.amazon.smithy.rust.codegen.core.rustlang.rust -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable -internal object EventStreamUnmarshallTestCases { - internal fun RustWriter.writeUnmarshallTestCases( +object EventStreamUnmarshallTestCases { + fun RustWriter.writeUnmarshallTestCases( testCase: EventStreamTestModels.TestCase, - codegenTarget: CodegenTarget, - generator: RuntimeType, + optionalBuilderInputs: Boolean = false, ) { + val generator = "crate::event_stream_serde::TestStreamUnmarshaller" + rust( """ use aws_smithy_eventstream::frame::{Header, HeaderValue, Message, UnmarshallMessage, UnmarshalledMessage}; @@ -53,202 +58,199 @@ internal object EventStreamUnmarshallTestCases { """, ) - unitTest( - name = "message_with_blob", - test = """ + unitTest("message_with_blob") { + rustTemplate( + """ let message = msg("event", "MessageWithBlob", "application/octet-stream", b"hello, world!"); - let result = ${format(generator)}().unmarshall(&message); + let result = $generator::new().unmarshall(&message); assert!(result.is_ok(), "expected ok, got: {:?}", result); assert_eq!( TestStream::MessageWithBlob( - MessageWithBlob::builder().data(Blob::new(&b"hello, world!"[..])).build() + MessageWithBlob::builder().data(#{DataInput:W}).build() ), expect_event(result.unwrap()) ); - """, - ) + """, + "DataInput" to conditionalBuilderInput( + """ + Blob::new(&b"hello, world!"[..]) + """, + conditional = optionalBuilderInputs, + ), - if (codegenTarget == CodegenTarget.CLIENT) { - unitTest( - "unknown_message", + ) + } + + unitTest("message_with_string") { + rustTemplate( """ - let message = msg("event", "NewUnmodeledMessageType", "application/octet-stream", b"hello, world!"); - let result = ${format(generator)}().unmarshall(&message); + let message = msg("event", "MessageWithString", "text/plain", b"hello, world!"); + let result = $generator::new().unmarshall(&message); assert!(result.is_ok(), "expected ok, got: {:?}", result); assert_eq!( - TestStream::Unknown, + TestStream::MessageWithString(MessageWithString::builder().data(#{DataInput}).build()), expect_event(result.unwrap()) ); """, + "DataInput" to conditionalBuilderInput("\"hello, world!\"", conditional = optionalBuilderInputs), ) } - unitTest( - "message_with_string", - """ - let message = msg("event", "MessageWithString", "text/plain", b"hello, world!"); - let result = ${format(generator)}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - assert_eq!( - TestStream::MessageWithString(MessageWithString::builder().data("hello, world!").build()), - expect_event(result.unwrap()) - ); - """, - ) - - unitTest( - "message_with_struct", - """ - let message = msg( - "event", - "MessageWithStruct", - "${testCase.responseContentType}", - br#"${testCase.validTestStruct}"# - ); - let result = ${format(generator)}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - assert_eq!( - TestStream::MessageWithStruct(MessageWithStruct::builder().some_struct( + unitTest("message_with_struct") { + rustTemplate( + """ + let message = msg( + "event", + "MessageWithStruct", + "${testCase.responseContentType}", + br##"${testCase.validTestStruct}"## + ); + let result = $generator::new().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + assert_eq!( + TestStream::MessageWithStruct(MessageWithStruct::builder().some_struct(#{StructInput}).build()), + expect_event(result.unwrap()) + ); + """, + "StructInput" to conditionalBuilderInput( + """ TestStruct::builder() - .some_string("hello") - .some_int(5) + .some_string(#{StringInput}) + .some_int(#{IntInput}) .build() - ).build()), - expect_event(result.unwrap()) - ); - """, - ) + """, + conditional = optionalBuilderInputs, + "StringInput" to conditionalBuilderInput("\"hello\"", conditional = optionalBuilderInputs), + "IntInput" to conditionalBuilderInput("5", conditional = optionalBuilderInputs), + ), - unitTest( - "message_with_union", - """ - let message = msg( - "event", - "MessageWithUnion", - "${testCase.responseContentType}", - br#"${testCase.validTestUnion}"# - ); - let result = ${format(generator)}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - assert_eq!( - TestStream::MessageWithUnion(MessageWithUnion::builder().some_union( - TestUnion::Foo("hello".into()) - ).build()), - expect_event(result.unwrap()) - ); - """, - ) + ) + } - unitTest( - "message_with_headers", - """ - let message = msg("event", "MessageWithHeaders", "application/octet-stream", b"") - .add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into()))) - .add_header(Header::new("boolean", HeaderValue::Bool(true))) - .add_header(Header::new("byte", HeaderValue::Byte(55i8))) - .add_header(Header::new("int", HeaderValue::Int32(100_000i32))) - .add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64))) - .add_header(Header::new("short", HeaderValue::Int16(16_000i16))) - .add_header(Header::new("string", HeaderValue::String("test".into()))) - .add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5)))); - let result = ${format(generator)}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - assert_eq!( - TestStream::MessageWithHeaders(MessageWithHeaders::builder() - .blob(Blob::new(&b"test"[..])) - .boolean(true) - .byte(55i8) - .int(100_000i32) - .long(9_000_000_000i64) - .short(16_000i16) - .string("test") - .timestamp(DateTime::from_secs(5)) - .build() - ), - expect_event(result.unwrap()) - ); - """, - ) + unitTest("message_with_union") { + rustTemplate( + """ + let message = msg( + "event", + "MessageWithUnion", + "${testCase.responseContentType}", + br##"${testCase.validTestUnion}"## + ); + let result = $generator::new().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + assert_eq!( + TestStream::MessageWithUnion(MessageWithUnion::builder().some_union(#{UnionInput}).build()), + expect_event(result.unwrap()) + ); + """, + "UnionInput" to conditionalBuilderInput("TestUnion::Foo(\"hello\".into())", conditional = optionalBuilderInputs), + ) + } - unitTest( - "message_with_header_and_payload", - """ - let message = msg("event", "MessageWithHeaderAndPayload", "application/octet-stream", b"payload") - .add_header(Header::new("header", HeaderValue::String("header".into()))); - let result = ${format(generator)}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - assert_eq!( - TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder() - .header("header") - .payload(Blob::new(&b"payload"[..])) - .build() - ), - expect_event(result.unwrap()) - ); - """, - ) + unitTest("message_with_headers") { + rustTemplate( + """ + let message = msg("event", "MessageWithHeaders", "application/octet-stream", b"") + .add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into()))) + .add_header(Header::new("boolean", HeaderValue::Bool(true))) + .add_header(Header::new("byte", HeaderValue::Byte(55i8))) + .add_header(Header::new("int", HeaderValue::Int32(100_000i32))) + .add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64))) + .add_header(Header::new("short", HeaderValue::Int16(16_000i16))) + .add_header(Header::new("string", HeaderValue::String("test".into()))) + .add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5)))); + let result = $generator::new().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + assert_eq!( + TestStream::MessageWithHeaders(MessageWithHeaders::builder() + .blob(#{BlobInput}) + .boolean(#{BoolInput}) + .byte(#{ByteInput}) + .int(#{IntInput}) + .long(#{LongInput}) + .short(#{ShortInput}) + .string(#{StringInput}) + .timestamp(#{TimestampInput}) + .build() + ), + expect_event(result.unwrap()) + ); + """, + "BlobInput" to conditionalBuilderInput("Blob::new(&b\"test\"[..])", conditional = optionalBuilderInputs), + "BoolInput" to conditionalBuilderInput("true", conditional = optionalBuilderInputs), + "ByteInput" to conditionalBuilderInput("55i8", conditional = optionalBuilderInputs), + "IntInput" to conditionalBuilderInput("100_000i32", conditional = optionalBuilderInputs), + "LongInput" to conditionalBuilderInput("9_000_000_000i64", conditional = optionalBuilderInputs), + "ShortInput" to conditionalBuilderInput("16_000i16", conditional = optionalBuilderInputs), + "StringInput" to conditionalBuilderInput("\"test\"", conditional = optionalBuilderInputs), + "TimestampInput" to conditionalBuilderInput("DateTime::from_secs(5)", conditional = optionalBuilderInputs), + ) + } - unitTest( - "message_with_no_header_payload_traits", - """ - let message = msg( - "event", - "MessageWithNoHeaderPayloadTraits", - "${testCase.responseContentType}", - br#"${testCase.validMessageWithNoHeaderPayloadTraits}"# - ); - let result = ${format(generator)}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - assert_eq!( - TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder() - .some_int(5) - .some_string("hello") - .build() - ), - expect_event(result.unwrap()) - ); - """, - ) + unitTest("message_with_header_and_payload") { + rustTemplate( + """ + let message = msg("event", "MessageWithHeaderAndPayload", "application/octet-stream", b"payload") + .add_header(Header::new("header", HeaderValue::String("header".into()))); + let result = $generator::new().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + assert_eq!( + TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder() + .header(#{HeaderInput}) + .payload(#{PayloadInput}) + .build() + ), + expect_event(result.unwrap()) + ); + """, + "HeaderInput" to conditionalBuilderInput("\"header\"", conditional = optionalBuilderInputs), + "PayloadInput" to conditionalBuilderInput("Blob::new(&b\"payload\"[..])", conditional = optionalBuilderInputs), + ) + } - unitTest( - "some_error", - """ - let message = msg( - "exception", - "SomeError", - "${testCase.responseContentType}", - br#"${testCase.validSomeError}"# - ); - let result = ${format(generator)}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - match expect_error(result.unwrap()) { - TestStreamError::SomeError(err) => assert_eq!(Some("some error"), err.message()), - kind => panic!("expected SomeError, but got {:?}", kind), - } - """, - ) + unitTest("message_with_no_header_payload_traits") { + rustTemplate( + """ + let message = msg( + "event", + "MessageWithNoHeaderPayloadTraits", + "${testCase.responseContentType}", + br##"${testCase.validMessageWithNoHeaderPayloadTraits}"## + ); + let result = $generator::new().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + assert_eq!( + TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder() + .some_int(#{IntInput}) + .some_string(#{StringInput}) + .build() + ), + expect_event(result.unwrap()) + ); + """, + "IntInput" to conditionalBuilderInput("5", conditional = optionalBuilderInputs), + "StringInput" to conditionalBuilderInput("\"hello\"", conditional = optionalBuilderInputs), + ) + } - if (codegenTarget == CodegenTarget.CLIENT) { - unitTest( - "error_metadata", + unitTest("some_error") { + rustTemplate( """ let message = msg( "exception", - "UnmodeledError", + "SomeError", "${testCase.responseContentType}", - br#"${testCase.validUnmodeledError}"# + br##"${testCase.validSomeError}"## ); - let result = ${format(generator)}().unmarshall(&message); + let result = $generator::new().unmarshall(&message); assert!(result.is_ok(), "expected ok, got: {:?}", result); match expect_error(result.unwrap()) { - TestStreamError::Unhandled(err) => { - let message = format!("{}", aws_smithy_types::error::display::DisplayErrorContext(&err)); - let expected = "message: \"unmodeled error\""; - assert!(message.contains(expected), "Expected '{message}' to contain '{expected}'"); - } - kind => panic!("expected error metadata, but got {:?}", kind), + TestStreamError::SomeError(err) => assert_eq!(Some("some error"), err.message()), + #{AllowUnreachablePatterns:W} + kind => panic!("expected SomeError, but got {:?}", kind), } """, + "AllowUnreachablePatterns" to writable { Attribute.AllowUnreachablePatterns.render(this) }, ) } @@ -261,10 +263,21 @@ internal object EventStreamUnmarshallTestCases { "wrong-content-type", br#"${testCase.validTestStruct}"# ); - let result = ${format(generator)}().unmarshall(&message); + let result = $generator::new().unmarshall(&message); assert!(result.is_err(), "expected error, got: {:?}", result); assert!(format!("{}", result.err().unwrap()).contains("expected :content-type to be")); """, ) } } + +internal fun conditionalBuilderInput( + @Language("Rust", prefix = "macro_rules! foo { () => {{\n", suffix = "\n}}}") contents: String, + conditional: Boolean, + vararg ctx: Pair, +): Writable = + writable { + conditionalBlock("Some(", ".into())", conditional = conditional) { + rustTemplate(contents, *ctx) + } + } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamBaseRequirements.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamBaseRequirements.kt deleted file mode 100644 index 8891b8d849..0000000000 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamBaseRequirements.kt +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.server.smithy.protocols.eventstream - -import org.junit.jupiter.api.extension.ExtensionContext -import org.junit.jupiter.params.provider.Arguments -import org.junit.jupiter.params.provider.ArgumentsProvider -import software.amazon.smithy.model.Model -import software.amazon.smithy.model.shapes.ServiceShape -import software.amazon.smithy.model.shapes.Shape -import software.amazon.smithy.model.shapes.ShapeId -import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.rustlang.implBlock -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.RustCrate -import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplGenerator -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestRequirements -import software.amazon.smithy.rust.codegen.core.util.getTrait -import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenConfig -import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext -import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionConversionGenerator -import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGenerator -import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGeneratorWithoutPublicConstrainedTypes -import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerOperationErrorGenerator -import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext -import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestRustSettings -import java.util.stream.Stream - -data class TestCase( - val eventStreamTestCase: EventStreamTestModels.TestCase, - val publicConstrainedTypes: Boolean, -) { - override fun toString(): String = "$eventStreamTestCase, publicConstrainedTypes = $publicConstrainedTypes" -} - -class TestCasesProvider : ArgumentsProvider { - override fun provideArguments(context: ExtensionContext?): Stream = - EventStreamTestModels.TEST_CASES - .flatMap { testCase -> - listOf( - TestCase(testCase, publicConstrainedTypes = false), - TestCase(testCase, publicConstrainedTypes = true), - ) - }.map { Arguments.of(it) }.stream() -} - -abstract class ServerEventStreamBaseRequirements : EventStreamTestRequirements { - abstract val publicConstrainedTypes: Boolean - - override fun createCodegenContext( - model: Model, - serviceShape: ServiceShape, - protocolShapeId: ShapeId, - codegenTarget: CodegenTarget, - ): ServerCodegenContext = serverTestCodegenContext( - model, - serviceShape, - serverTestRustSettings( - codegenConfig = ServerCodegenConfig(publicConstrainedTypes = publicConstrainedTypes), - ), - protocolShapeId, - ) - - override fun renderBuilderForShape( - rustCrate: RustCrate, - writer: RustWriter, - codegenContext: ServerCodegenContext, - shape: StructureShape, - ) { - val validationExceptionConversionGenerator = SmithyValidationExceptionConversionGenerator(codegenContext) - if (codegenContext.settings.codegenConfig.publicConstrainedTypes) { - ServerBuilderGenerator(codegenContext, shape, validationExceptionConversionGenerator).apply { - render(rustCrate, writer) - writer.implBlock(codegenContext.symbolProvider.toSymbol(shape)) { - renderConvenienceMethod(writer) - } - } - } else { - ServerBuilderGeneratorWithoutPublicConstrainedTypes(codegenContext, shape, validationExceptionConversionGenerator).apply { - render(rustCrate, writer) - writer.implBlock(codegenContext.symbolProvider.toSymbol(shape)) { - renderConvenienceMethod(writer) - } - } - } - } - - override fun renderOperationError( - writer: RustWriter, - model: Model, - symbolProvider: RustSymbolProvider, - operationOrEventStream: Shape, - ) { - ServerOperationErrorGenerator(model, symbolProvider, operationOrEventStream).render(writer) - } - - override fun renderError( - rustCrate: RustCrate, - writer: RustWriter, - codegenContext: ServerCodegenContext, - shape: StructureShape, - ) { - StructureGenerator(codegenContext.model, codegenContext.symbolProvider, writer, shape, listOf()).render() - ErrorImplGenerator( - codegenContext.model, - codegenContext.symbolProvider, - writer, - shape, - shape.getTrait()!!, - listOf(), - ).render(CodegenTarget.SERVER) - renderBuilderForShape(rustCrate, writer, codegenContext, shape) - } -} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamMarshallerGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamMarshallerGeneratorTest.kt index cf0c4f94b1..c4c15742e3 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamMarshallerGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamMarshallerGeneratorTest.kt @@ -5,49 +5,43 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols.eventstream +import org.junit.jupiter.api.extension.ExtensionContext import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.ArgumentsProvider import org.junit.jupiter.params.provider.ArgumentsSource -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol -import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.EventStreamMarshallerGenerator -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety -import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject -import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig -import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest -import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext -import software.amazon.smithy.rust.codegen.server.smithy.renderInlineMemoryModules +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamMarshallTestCases.writeMarshallTestCases +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels +import software.amazon.smithy.rust.codegen.core.testutil.testModule +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest +import java.util.stream.Stream class ServerEventStreamMarshallerGeneratorTest { @ParameterizedTest @ArgumentsSource(TestCasesProvider::class) fun test(testCase: TestCase) { - val testProject = EventStreamTestTools.setupTestCase( - testCase.eventStreamTestCase, - object : ServerEventStreamBaseRequirements() { - override val publicConstrainedTypes: Boolean get() = testCase.publicConstrainedTypes - - override fun renderGenerator( - codegenContext: ServerCodegenContext, - project: TestEventStreamProject, - protocol: Protocol, - ): RuntimeType { - return EventStreamMarshallerGenerator( - project.model, - CodegenTarget.SERVER, - TestRuntimeConfig, - project.symbolProvider, - project.streamShape, - protocol.structuredDataSerializer(project.operationShape), - testCase.eventStreamTestCase.requestContentType, - ).render() - } - }, - CodegenTarget.SERVER, - EventStreamTestVariety.Marshall, - ) - testProject.renderInlineMemoryModules() - testProject.compileAndTest() + serverIntegrationTest(testCase.eventStreamTestCase.model) { _, rustCrate -> + rustCrate.testModule { + writeMarshallTestCases(testCase.eventStreamTestCase, optionalBuilderInputs = true) + } + } } } + +data class TestCase( + val eventStreamTestCase: EventStreamTestModels.TestCase, + val publicConstrainedTypes: Boolean, +) { + override fun toString(): String = "$eventStreamTestCase, publicConstrainedTypes = $publicConstrainedTypes" +} + +class TestCasesProvider : ArgumentsProvider { + override fun provideArguments(context: ExtensionContext?): Stream = + EventStreamTestModels.TEST_CASES + .flatMap { testCase -> + listOf( + TestCase(testCase, publicConstrainedTypes = false), + TestCase(testCase, publicConstrainedTypes = true), + ) + }.map { Arguments.of(it) }.stream() +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamUnmarshallerGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamUnmarshallerGeneratorTest.kt index 11e9d65b4d..7d88d00ee6 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamUnmarshallerGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamUnmarshallerGeneratorTest.kt @@ -7,21 +7,10 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols.eventstream import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.ArgumentsSource -import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.rustlang.implBlock -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.RustCrate -import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator -import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol -import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.EventStreamUnmarshallerGenerator -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety -import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject -import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest -import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext -import software.amazon.smithy.rust.codegen.server.smithy.transformers.ConstrainedMemberTransform +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.writeUnmarshallTestCases +import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams +import software.amazon.smithy.rust.codegen.core.testutil.testModule +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest class ServerEventStreamUnmarshallerGeneratorTest { @ParameterizedTest @@ -33,45 +22,16 @@ class ServerEventStreamUnmarshallerGeneratorTest { return } - val testProject = EventStreamTestTools.setupTestCase( - testCase.eventStreamTestCase, - object : ServerEventStreamBaseRequirements() { - override val publicConstrainedTypes: Boolean get() = testCase.publicConstrainedTypes - - override fun renderGenerator( - codegenContext: ServerCodegenContext, - project: TestEventStreamProject, - protocol: Protocol, - ): RuntimeType { - return EventStreamUnmarshallerGenerator( - protocol, - codegenContext, - project.operationShape, - project.streamShape, - ).render() - } - - // TODO(https://github.com/awslabs/smithy-rs/issues/1442): Delete this function override to use the correct builder from the parent class - override fun renderBuilderForShape( - rustCrate: RustCrate, - writer: RustWriter, - codegenContext: ServerCodegenContext, - shape: StructureShape, - ) { - rustCrate.withModule(codegenContext.symbolProvider.moduleForBuilder(shape)) { - BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape, emptyList()).render(this) - } - rustCrate.moduleFor(shape) { - writer.implBlock(codegenContext.symbolProvider.toSymbol(shape)) { - BuilderGenerator.renderConvenienceMethod(this, codegenContext.symbolProvider, shape) - } - } - } - }, - CodegenTarget.SERVER, - EventStreamTestVariety.Unmarshall, - transformers = listOf(ConstrainedMemberTransform::transform), - ) - testProject.compileAndTest() + serverIntegrationTest( + testCase.eventStreamTestCase.model, + IntegrationTestParams(service = "test#TestService", addModuleToEventStreamAllowList = true), + ) { _, rustCrate -> + rustCrate.testModule { + writeUnmarshallTestCases( + testCase.eventStreamTestCase, + optionalBuilderInputs = true, + ) + } + } } } diff --git a/rust-runtime/inlineable/src/rest_xml_wrapped_errors.rs b/rust-runtime/inlineable/src/rest_xml_wrapped_errors.rs index 2511929a06..b735b77249 100644 --- a/rust-runtime/inlineable/src/rest_xml_wrapped_errors.rs +++ b/rust-runtime/inlineable/src/rest_xml_wrapped_errors.rs @@ -14,6 +14,7 @@ pub fn body_is_error(body: &[u8]) -> Result { Ok(scoped.start_el().matches("ErrorResponse")) } +#[allow(dead_code)] pub fn parse_error_metadata(body: &[u8]) -> Result { let mut doc = Document::try_from(body)?; let mut root = doc.root_element()?;