diff --git a/aws/client/aws-client-restjson/src/it/java/software/amazon/smithy/java/client/aws/restjson/RestJson1ProtocolTests.java b/aws/client/aws-client-restjson/src/it/java/software/amazon/smithy/java/client/aws/restjson/RestJson1ProtocolTests.java index 006df5ad8..2c93b2222 100644 --- a/aws/client/aws-client-restjson/src/it/java/software/amazon/smithy/java/client/aws/restjson/RestJson1ProtocolTests.java +++ b/aws/client/aws-client-restjson/src/it/java/software/amazon/smithy/java/client/aws/restjson/RestJson1ProtocolTests.java @@ -35,8 +35,6 @@ public class RestJson1ProtocolTests { @HttpClientRequestTests @ProtocolTestFilter( skipTests = { - // TODO: support checksums in requests - "RestJsonHttpChecksumRequired", // TODO: These tests require a payload even when the httpPayload member is null. Should it? "RestJsonHttpWithHeadersButNoPayload", "RestJsonHttpWithEmptyStructurePayload", diff --git a/client/client-core/src/test/java/software/amazon/smithy/java/client/core/ClientTest.java b/client/client-core/src/test/java/software/amazon/smithy/java/client/core/ClientTest.java index 95113e535..e231b41bf 100644 --- a/client/client-core/src/test/java/software/amazon/smithy/java/client/core/ClientTest.java +++ b/client/client-core/src/test/java/software/amazon/smithy/java/client/core/ClientTest.java @@ -32,6 +32,7 @@ import software.amazon.smithy.java.client.http.mock.MockPlugin; import software.amazon.smithy.java.client.http.mock.MockQueue; import software.amazon.smithy.java.client.http.plugins.ApplyHttpRetryInfoPlugin; +import software.amazon.smithy.java.client.http.plugins.HttpChecksumPlugin; import software.amazon.smithy.java.client.http.plugins.UserAgentPlugin; import software.amazon.smithy.java.core.serde.document.Document; import software.amazon.smithy.java.dynamicclient.DynamicClient; @@ -95,6 +96,7 @@ public void tracksPlugins() throws URISyntaxException { // And HttpMessageExchange applies the UserAgentPlugin and ApplyHttpRetryInfoPlugin. UserAgentPlugin.class, ApplyHttpRetryInfoPlugin.class, + HttpChecksumPlugin.class, // User plugins are applied last. FooPlugin.class)); } diff --git a/client/client-http/src/main/java/software/amazon/smithy/java/client/http/HttpMessageExchange.java b/client/client-http/src/main/java/software/amazon/smithy/java/client/http/HttpMessageExchange.java index f3509e961..edbb68ec2 100644 --- a/client/client-http/src/main/java/software/amazon/smithy/java/client/http/HttpMessageExchange.java +++ b/client/client-http/src/main/java/software/amazon/smithy/java/client/http/HttpMessageExchange.java @@ -8,6 +8,7 @@ import software.amazon.smithy.java.client.core.ClientConfig; import software.amazon.smithy.java.client.core.MessageExchange; import software.amazon.smithy.java.client.http.plugins.ApplyHttpRetryInfoPlugin; +import software.amazon.smithy.java.client.http.plugins.HttpChecksumPlugin; import software.amazon.smithy.java.client.http.plugins.UserAgentPlugin; import software.amazon.smithy.java.http.api.HttpRequest; import software.amazon.smithy.java.http.api.HttpResponse; @@ -19,6 +20,7 @@ * */ public final class HttpMessageExchange implements MessageExchange { @@ -33,5 +35,6 @@ private HttpMessageExchange() {} public void configureClient(ClientConfig.Builder config) { config.applyPlugin(new UserAgentPlugin()); config.applyPlugin(new ApplyHttpRetryInfoPlugin()); + config.applyPlugin(new HttpChecksumPlugin()); } } diff --git a/client/client-http/src/main/java/software/amazon/smithy/java/client/http/plugins/HttpChecksumPlugin.java b/client/client-http/src/main/java/software/amazon/smithy/java/client/http/plugins/HttpChecksumPlugin.java new file mode 100644 index 000000000..5b894721c --- /dev/null +++ b/client/client-http/src/main/java/software/amazon/smithy/java/client/http/plugins/HttpChecksumPlugin.java @@ -0,0 +1,68 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.client.http.plugins; + +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.Base64; +import software.amazon.smithy.java.client.core.ClientConfig; +import software.amazon.smithy.java.client.core.ClientPlugin; +import software.amazon.smithy.java.client.core.interceptors.ClientInterceptor; +import software.amazon.smithy.java.client.core.interceptors.RequestHook; +import software.amazon.smithy.java.core.schema.TraitKey; +import software.amazon.smithy.java.http.api.HttpRequest; +import software.amazon.smithy.java.io.ByteBufferUtils; +import software.amazon.smithy.model.traits.HttpChecksumRequiredTrait; +import software.amazon.smithy.utils.ListUtils; +import software.amazon.smithy.utils.SmithyInternalApi; + +/** + * Plugin that adds Content-MD5 header for operations with @httpChecksumRequired trait. + */ +@SmithyInternalApi +public final class HttpChecksumPlugin implements ClientPlugin { + + @Override + public void configureClient(ClientConfig.Builder config) { + config.addInterceptor(HttpChecksumInterceptor.INSTANCE); + } + + static final class HttpChecksumInterceptor implements ClientInterceptor { + private static final ClientInterceptor INSTANCE = new HttpChecksumInterceptor(); + private static final TraitKey CHECKSUM_REQUIRED_TRAIT_KEY = + TraitKey.get(HttpChecksumRequiredTrait.class); + + @Override + public RequestT modifyBeforeTransmit(RequestHook hook) { + return hook.mapRequest(HttpRequest.class, HttpChecksumInterceptor::processRequest); + } + + private static HttpRequest processRequest(RequestHook hook) { + if (hook.operation().schema().hasTrait(CHECKSUM_REQUIRED_TRAIT_KEY)) { + return addContentMd5Header(hook.request()); + } + return hook.request(); + } + + static HttpRequest addContentMd5Header(HttpRequest request) { + var body = request.body(); + if (body != null) { + var buffer = body.waitForByteBuffer(); + var bytes = ByteBufferUtils.getBytes(buffer); + try { + byte[] hash = MessageDigest.getInstance("MD5").digest(bytes); + String base64Hash = Base64.getEncoder().encodeToString(hash); + return request.toBuilder() + .withReplacedHeader("Content-MD5", ListUtils.of(base64Hash)) + .build(); + } catch (NoSuchAlgorithmException e) { + throw new IllegalStateException("Unable to fetch message digest instance for MD5", e); + } + } + return request; + } + } +} diff --git a/client/client-http/src/test/java/software/amazon/smithy/java/client/http/plugins/HttpChecksumPluginTest.java b/client/client-http/src/test/java/software/amazon/smithy/java/client/http/plugins/HttpChecksumPluginTest.java new file mode 100644 index 000000000..48c208f71 --- /dev/null +++ b/client/client-http/src/test/java/software/amazon/smithy/java/client/http/plugins/HttpChecksumPluginTest.java @@ -0,0 +1,53 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.client.http.plugins; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.java.http.api.HttpRequest; +import software.amazon.smithy.java.io.datastream.DataStream; + +public class HttpChecksumPluginTest { + + @Test + public void interceptorAddsContentMd5HeaderForKnownBody() throws Exception { + var interceptor = new HttpChecksumPlugin.HttpChecksumInterceptor(); + var req = HttpRequest.builder() + .uri(new URI("/")) + .method("POST") + .body(DataStream.ofBytes("test body".getBytes(StandardCharsets.UTF_8))) + .build(); + + var result = interceptor.addContentMd5Header(req); + + var headers = result.headers().allValues("Content-MD5"); + assertThat(headers, hasSize(1)); + assertThat(headers.get(0), equalTo("u/mv50Mcr1+Jpgi8MejYIg==")); + } + + @Test + public void interceptorReplacesExistingContentMd5Header() throws Exception { + var interceptor = new HttpChecksumPlugin.HttpChecksumInterceptor(); + var req = HttpRequest.builder() + .uri(new URI("/")) + .method("POST") + .body(DataStream.ofBytes("test body".getBytes(StandardCharsets.UTF_8))) + .withAddedHeader("Content-MD5", "wrong-hash") + .build(); + + var result = interceptor.addContentMd5Header(req); + + var headers = result.headers().allValues("Content-MD5"); + assertThat(headers, hasSize(1)); + assertThat(headers.get(0), equalTo("u/mv50Mcr1+Jpgi8MejYIg==")); + } + +}