Skip to content

Commit

Permalink
Prevent test dependencies from leaking into production (smithy-lang#2264
Browse files Browse the repository at this point in the history
)

* Prevent test dependencies from leaking into production

* refactor & fix tests

* fix tests take two

* fix more tests

* Fix missed called to mergeDependencyFeatures

* Add test

* fix glacier compilation

* fix more tests

* fix one more test
  • Loading branch information
rcoh committed Feb 6, 2023
1 parent 7bf9251 commit c9275fb
Show file tree
Hide file tree
Showing 20 changed files with 323 additions and 123 deletions.
Expand Up @@ -16,7 +16,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.client.Fluen
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientGenerics
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientSection
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope
import software.amazon.smithy.rust.codegen.core.rustlang.Feature
import software.amazon.smithy.rust.codegen.core.rustlang.GenericTypeArg
import software.amazon.smithy.rust.codegen.core.rustlang.RustGenerics
Expand Down Expand Up @@ -228,7 +227,7 @@ private class AwsFluentClientDocs(private val codegenContext: CodegenContext) :
private val serviceShape = codegenContext.serviceShape
private val crateName = codegenContext.moduleUseName()
private val codegenScope =
arrayOf("aws_config" to AwsCargoDependency.awsConfig(codegenContext.runtimeConfig).copy(scope = DependencyScope.Dev).toType())
arrayOf("aws_config" to AwsCargoDependency.awsConfig(codegenContext.runtimeConfig).toDevDependency().toType())

// If no `aws-config` version is provided, assume that docs referencing `aws-config` cannot be given.
// Also, STS and SSO must NOT reference `aws-config` since that would create a circular dependency.
Expand Down
Expand Up @@ -7,7 +7,6 @@ package software.amazon.smithy.rustsdk

import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope
import software.amazon.smithy.rust.codegen.core.rustlang.Visibility
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeCrateLocation
Expand Down Expand Up @@ -63,7 +62,7 @@ object AwsRuntimeType {
fun awsCredentialTypes(runtimeConfig: RuntimeConfig) = AwsCargoDependency.awsCredentialTypes(runtimeConfig).toType()

fun awsCredentialTypesTestUtil(runtimeConfig: RuntimeConfig) =
AwsCargoDependency.awsCredentialTypes(runtimeConfig).copy(scope = DependencyScope.Dev).withFeature("test-util").toType()
AwsCargoDependency.awsCredentialTypes(runtimeConfig).toDevDependency().withFeature("test-util").toType()

fun awsEndpoint(runtimeConfig: RuntimeConfig) = AwsCargoDependency.awsEndpoint(runtimeConfig).toType()
fun awsHttp(runtimeConfig: RuntimeConfig) = AwsCargoDependency.awsHttp(runtimeConfig).toType()
Expand Down
Expand Up @@ -30,6 +30,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection
import software.amazon.smithy.rust.codegen.core.testutil.testDependenciesOnly
import java.nio.file.Files
import java.nio.file.Paths
import kotlin.io.path.absolute
Expand Down Expand Up @@ -72,7 +73,7 @@ class IntegrationTestDependencies(
private val hasBenches: Boolean,
) : LibRsCustomization() {
override fun section(section: LibRsSection) = when (section) {
is LibRsSection.Body -> writable {
is LibRsSection.Body -> testDependenciesOnly {
if (hasTests) {
val smithyClient = CargoDependency.smithyClient(runtimeConfig)
.copy(features = setOf("test-util"), scope = DependencyScope.Dev)
Expand All @@ -81,7 +82,7 @@ class IntegrationTestDependencies(
addDependency(SerdeJson)
addDependency(Tokio)
addDependency(FuturesUtil)
addDependency(Tracing)
addDependency(Tracing.toDevDependency())
addDependency(TracingSubscriber)
}
if (hasBenches) {
Expand All @@ -91,6 +92,7 @@ class IntegrationTestDependencies(
serviceSpecific.section(section)(this)
}
}

else -> emptySection
}

Expand All @@ -114,8 +116,8 @@ class S3TestDependencies : LibRsCustomization() {
override fun section(section: LibRsSection): Writable =
writable {
addDependency(AsyncStd)
addDependency(BytesUtils)
addDependency(FastRand)
addDependency(BytesUtils.toDevDependency())
addDependency(FastRand.toDevDependency())
addDependency(HdrHistogram)
addDependency(Smol)
addDependency(TempFile)
Expand Down
Expand Up @@ -33,13 +33,16 @@ private val UploadMultipartPart: ShapeId = ShapeId.from("com.amazonaws.glacier#U
private val Applies = setOf(UploadArchive, UploadMultipartPart)

class TreeHashHeader(private val runtimeConfig: RuntimeConfig) : OperationCustomization() {
private val glacierChecksums = RuntimeType.forInlineDependency(InlineAwsDependency.forRustFile("glacier_checksums"))
private val glacierChecksums = RuntimeType.forInlineDependency(
InlineAwsDependency.forRustFile(
"glacier_checksums",
additionalDependency = TreeHashDependencies.toTypedArray(),
),
)

override fun section(section: OperationSection): Writable {
return when (section) {
is OperationSection.MutateRequest -> writable {
TreeHashDependencies.forEach { dep ->
addDependency(dep)
}
rustTemplate(
"""
#{glacier_checksums}::add_checksum_treehash(
Expand All @@ -49,6 +52,7 @@ class TreeHashHeader(private val runtimeConfig: RuntimeConfig) : OperationCustom
"glacier_checksums" to glacierChecksums, "BuildError" to runtimeConfig.operationBuildError(),
)
}

else -> emptySection
}
}
Expand Down
Expand Up @@ -17,14 +17,14 @@ import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointTypesG
import software.amazon.smithy.rust.codegen.client.smithy.generators.clientInstantiator
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.AttributeKind
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.escape
import software.amazon.smithy.rust.codegen.core.rustlang.join
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.PublicImportSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName
import software.amazon.smithy.rust.codegen.core.testutil.integrationTest
Expand Down Expand Up @@ -146,8 +146,7 @@ class OperationInputTestGenerator(_ctx: ClientCodegenContext, private val test:
let _result = dbg!(#{invoke_operation});
#{assertion}
""",
"capture_request" to CargoDependency.smithyClient(runtimeConfig)
.withFeature("test-util").toType().resolve("test_connection::capture_request"),
"capture_request" to RuntimeType.captureRequest(runtimeConfig),
"conf" to config(testOperationInput),
"invoke_operation" to operationInvocation(testOperationInput),
"assertion" to writable {
Expand Down
Expand Up @@ -6,8 +6,8 @@
package software.amazon.smithy.rustsdk

import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.integrationTest
import software.amazon.smithy.rust.codegen.core.testutil.tokioTest
Expand Down Expand Up @@ -96,8 +96,7 @@ class EndpointsCredentialsTest {
let auth_header = req.headers().get("AUTHORIZATION").unwrap().to_str().unwrap();
assert!(auth_header.contains("/us-west-2/foobaz/aws4_request"), "{}", auth_header);
""",
"capture_request" to CargoDependency.smithyClient(context.runtimeConfig)
.withFeature("test-util").toType().resolve("test_connection::capture_request"),
"capture_request" to RuntimeType.captureRequest(context.runtimeConfig),
"Credentials" to AwsCargoDependency.awsCredentialTypes(context.runtimeConfig)
.withFeature("test-util").toType().resolve("Credentials"),
"Region" to AwsRuntimeType.awsTypes(context.runtimeConfig).resolve("region::Region"),
Expand All @@ -120,8 +119,7 @@ class EndpointsCredentialsTest {
let auth_header = req.headers().get("AUTHORIZATION").unwrap().to_str().unwrap();
assert!(auth_header.contains("/region-custom-auth/name-custom-auth/aws4_request"), "{}", auth_header);
""",
"capture_request" to CargoDependency.smithyClient(context.runtimeConfig)
.withFeature("test-util").toType().resolve("test_connection::capture_request"),
"capture_request" to RuntimeType.captureRequest(context.runtimeConfig),
"Credentials" to AwsCargoDependency.awsCredentialTypes(context.runtimeConfig)
.withFeature("test-util").toType().resolve("Credentials"),
"Region" to AwsRuntimeType.awsTypes(context.runtimeConfig).resolve("region::Region"),
Expand Down
Expand Up @@ -13,7 +13,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.symbol
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.derive
import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
Expand Down Expand Up @@ -59,7 +58,7 @@ val EndpointTests = RustModule.new(
documentation = "Generated endpoint tests",
parent = EndpointsModule,
inline = true,
).copy(rustMetadata = RustMetadata.TestModule)
).cfgTest()

// stdlib is isolated because it contains code generated names of stdlib functions–we want to ensure we avoid clashing
val EndpointsStdLib = RustModule.private("endpoint_lib", "Endpoints standard library functions")
Expand Down
Expand Up @@ -14,7 +14,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointCustom
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.Types
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName
import software.amazon.smithy.rust.codegen.client.smithy.generators.clientInstantiator
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.docs
import software.amazon.smithy.rust.codegen.core.rustlang.escape
Expand Down Expand Up @@ -48,8 +47,7 @@ internal class EndpointTestGenerator(
"Error" to types.resolveEndpointError,
"Document" to RuntimeType.document(runtimeConfig),
"HashMap" to RuntimeType.HashMap,
"capture_request" to CargoDependency.smithyClient(runtimeConfig)
.withFeature("test-util").toType().resolve("test_connection::capture_request"),
"capture_request" to RuntimeType.captureRequest(runtimeConfig),
)

private val instantiator = clientInstantiator(codegenContext)
Expand Down
Expand Up @@ -22,10 +22,8 @@ import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.generators.clientInstantiator
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.allow
import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Visibility
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.escape
import software.amazon.smithy.rust.codegen.core.rustlang.rust
Expand Down Expand Up @@ -91,14 +89,10 @@ class ProtocolTestGenerator(
if (allTests.isNotEmpty()) {
val operationName = operationSymbol.name
val testModuleName = "${operationName.toSnakeCase()}_request_test"
val moduleMeta = RustMetadata(
visibility = Visibility.PRIVATE,
additionalAttributes = listOf(
Attribute.CfgTest,
Attribute(allow("unreachable_code", "unused_variables")),
),
val additionalAttributes = listOf(
Attribute(allow("unreachable_code", "unused_variables")),
)
writer.withInlineModule(RustModule.LeafModule(testModuleName, moduleMeta, inline = true)) {
writer.withInlineModule(RustModule.inlineTests(testModuleName, additionalAttributes = additionalAttributes)) {
renderAllTestCases(allTests)
}
}
Expand Down
Expand Up @@ -133,6 +133,8 @@ data class CargoDependency(
return copy(features = features.toMutableSet().apply { add(feature) })
}

fun toDevDependency() = copy(scope = DependencyScope.Dev)

override fun version(): String = when (location) {
is CratesIo -> location.version
is Local -> "local"
Expand Down Expand Up @@ -220,7 +222,12 @@ data class CargoDependency(
val Smol: CargoDependency = CargoDependency("smol", CratesIo("1.2.0"), DependencyScope.Dev)
val TempFile: CargoDependency = CargoDependency("tempfile", CratesIo("3.2.0"), DependencyScope.Dev)
val Tokio: CargoDependency =
CargoDependency("tokio", CratesIo("1.8.4"), DependencyScope.Dev, features = setOf("macros", "test-util", "rt-multi-thread"))
CargoDependency(
"tokio",
CratesIo("1.8.4"),
DependencyScope.Dev,
features = setOf("macros", "test-util", "rt-multi-thread"),
)
val TracingAppender: CargoDependency = CargoDependency(
"tracing-appender",
CratesIo("0.2.2"),
Expand All @@ -236,12 +243,16 @@ data class CargoDependency(
fun smithyAsync(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-async")
fun smithyChecksums(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-checksums")
fun smithyClient(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-client")
fun smithyClientTestUtil(runtimeConfig: RuntimeConfig) =
smithyClient(runtimeConfig).toDevDependency().withFeature("test-util")

fun smithyEventStream(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-eventstream")
fun smithyHttp(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http")
fun smithyHttpTower(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http-tower")
fun smithyJson(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-json")
fun smithyProtocolTestHelpers(runtimeConfig: RuntimeConfig) =
runtimeConfig.smithyRuntimeCrate("smithy-protocol-test", scope = DependencyScope.Dev)

fun smithyQuery(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-query")
fun smithyTypes(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-types")
fun smithyXml(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-xml")
Expand Down
Expand Up @@ -32,7 +32,10 @@ sealed class RustModule {
val documentation: String? = null,
val parent: RustModule = LibRs,
val inline: Boolean = false,
/* module is a cfg(test) module */
val tests: Boolean = false,
) : RustModule() {

init {
check(!name.contains("::")) {
"Module names CANNOT contain `::`—modules must be nested with parent (name was: `$name`)"
Expand All @@ -45,6 +48,12 @@ sealed class RustModule {
"Module `$name` cannot be a module name—it is a reserved word."
}
}

/** Convert a module into a module gated with `#[cfg(test)]` */
fun cfgTest(): LeafModule = this.copy(
rustMetadata = rustMetadata.copy(additionalAttributes = rustMetadata.additionalAttributes + Attribute.CfgTest),
tests = true,
)
}

companion object {
Expand Down Expand Up @@ -78,12 +87,36 @@ sealed class RustModule {
fun pubCrate(name: String, documentation: String? = null, parent: RustModule): LeafModule =
new(name, visibility = Visibility.PUBCRATE, documentation = documentation, inline = false, parent = parent)

fun inlineTests(
name: String = "test",
parent: RustModule = LibRs,
additionalAttributes: List<Attribute> = listOf(),
) = new(
name,
Visibility.PRIVATE,
inline = true,
additionalAttributes = additionalAttributes,
parent = parent,
).cfgTest()

/* Common modules used across client, server and tests */
val Config = public("config", documentation = "Configuration for the service.")
val Error = public("error", documentation = "All error types that operations can return. Documentation on these types is copied from the model.")
val Model = public("model", documentation = "Data structures used by operation inputs/outputs. Documentation on these types is copied from the model.")
val Input = public("input", documentation = "Input structures for operations. Documentation on these types is copied from the model.")
val Output = public("output", documentation = "Output structures for operations. Documentation on these types is copied from the model.")
val Error = public(
"error",
documentation = "All error types that operations can return. Documentation on these types is copied from the model.",
)
val Model = public(
"model",
documentation = "Data structures used by operation inputs/outputs. Documentation on these types is copied from the model.",
)
val Input = public(
"input",
documentation = "Input structures for operations. Documentation on these types is copied from the model.",
)
val Output = public(
"output",
documentation = "Output structures for operations. Documentation on these types is copied from the model.",
)
val Types = public("types", documentation = "Data primitives referenced by other data types.")

/**
Expand Down

0 comments on commit c9275fb

Please sign in to comment.