Skip to content

Commit

Permalink
AWS SDK 2 upgrade: Fix unit tests in core - spring-atticgh-267
Browse files Browse the repository at this point in the history
  • Loading branch information
tinexw committed Feb 11, 2020
1 parent 642c876 commit fd02381
Show file tree
Hide file tree
Showing 16 changed files with 149 additions and 178 deletions.
Expand Up @@ -60,7 +60,7 @@ private AmazonWebserviceClientConfigurationUtils() {

public static BeanDefinitionHolder registerAmazonWebserviceClient(Object source,
BeanDefinitionRegistry registry, String serviceNameClassName,
String customRegionProvider, String customRegion) {
String customRegionProvider) {

String beanName = getBeanName(serviceNameClassName);

Expand All @@ -70,7 +70,7 @@ public static BeanDefinitionHolder registerAmazonWebserviceClient(Object source,
}

BeanDefinition definition = getAmazonWebserviceClientBeanDefinition(source,
serviceNameClassName, customRegionProvider, customRegion, registry);
serviceNameClassName, customRegionProvider, registry);
BeanDefinitionHolder holder = new BeanDefinitionHolder(definition, beanName);
BeanDefinitionReaderUtils.registerBeanDefinition(holder, registry);

Expand All @@ -79,13 +79,7 @@ public static BeanDefinitionHolder registerAmazonWebserviceClient(Object source,

public static AbstractBeanDefinition getAmazonWebserviceClientBeanDefinition(
Object source, String serviceNameClassName, String customRegionProvider,
String customRegion, BeanDefinitionRegistry beanDefinitionRegistry) {

if (StringUtils.hasText(customRegionProvider)
&& StringUtils.hasText(customRegion)) {
throw new IllegalArgumentException(
"Only region or regionProvider can be configured, but not both");
}
BeanDefinitionRegistry beanDefinitionRegistry) {

registerCredentialsProviderIfNeeded(beanDefinitionRegistry);

Expand All @@ -103,9 +97,6 @@ public static AbstractBeanDefinition getAmazonWebserviceClientBeanDefinition(
if (StringUtils.hasText(customRegionProvider)) {
builder.addPropertyReference("regionProvider", customRegionProvider);
}
else if (StringUtils.hasText(customRegion)) {
builder.addPropertyValue("customRegion", customRegion);
}
else {
registerRegionProviderBeanIfNeeded(beanDefinitionRegistry);
builder.addPropertyReference("regionProvider", REGION_PROVIDER_BEAN_NAME);
Expand Down
Expand Up @@ -95,7 +95,7 @@ protected T createInstance() throws Exception {
builder.credentialsProvider(this.credentialsProvider);
}

else if (this.regionProvider != null) {
if (this.regionProvider != null) {
builder.region(this.regionProvider.getRegion());
}
else {
Expand Down
Expand Up @@ -31,8 +31,6 @@
*/
public final class XmlWebserviceConfigurationUtils {

private static final String REGION_ATTRIBUTE_NAME = "region";

private static final String REGION_PROVIDER_ATTRIBUTE_NAME = "region-provider";

private XmlWebserviceConfigurationUtils() {
Expand All @@ -57,7 +55,6 @@ public static AbstractBeanDefinition parseCustomClientElement(Element element,
try {
return getAmazonWebserviceClientBeanDefinition(source, serviceClassName,
element.getAttribute(REGION_PROVIDER_ATTRIBUTE_NAME),
element.getAttribute(REGION_ATTRIBUTE_NAME),
parserContext.getRegistry());
}
catch (Exception e) {
Expand All @@ -72,8 +69,7 @@ private static BeanDefinitionHolder parseAndRegisterDefaultAmazonWebserviceClien
try {
return registerAmazonWebserviceClient(source, parserContext.getRegistry(),
serviceClassName,
element.getAttribute(REGION_PROVIDER_ATTRIBUTE_NAME),
element.getAttribute(REGION_ATTRIBUTE_NAME));
element.getAttribute(REGION_PROVIDER_ATTRIBUTE_NAME));
}
catch (Exception e) {
parserContext.getReaderContext().error(e.getMessage(), source, e);
Expand Down
Expand Up @@ -59,18 +59,16 @@ public Class<?> getObjectType() {

@Override
protected AwsCredentialsProvider createInstance() throws Exception {
AwsCredentialsProviderChain.Builder awsCredentialsProviderChainBuilder;
if (this.delegates.isEmpty()) {
awsCredentialsProviderChainBuilder = AwsCredentialsProviderChain.builder()
.credentialsProviders(DefaultCredentialsProvider.create());
return DefaultCredentialsProvider.builder()
.reuseLastProviderEnabled(false)
.build();
}
else {
awsCredentialsProviderChainBuilder = AwsCredentialsProviderChain.builder()
.credentialsProviders(this.delegates);
return AwsCredentialsProviderChain.builder()
.credentialsProviders(this.delegates)
.reuseLastProviderEnabled(false).build();
}

awsCredentialsProviderChainBuilder.reuseLastProviderEnabled(false);
return awsCredentialsProviderChainBuilder.build();
}

}
Expand Up @@ -20,13 +20,16 @@
import java.util.concurrent.ConcurrentHashMap;

import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.awscore.client.config.AwsClientOption;
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.S3ClientBuilder;

import org.springframework.aop.framework.Advised;
import org.springframework.aop.support.AopUtils;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;

/**
Expand All @@ -39,20 +42,21 @@
*/
public class AmazonS3ClientFactory {

private static final String CREDENTIALS_PROVIDER_FIELD_NAME = "awsCredentialsProvider";
private static final String CLIENT_CONFIGURATION_FIELD_NAME = "clientConfiguration";

private final ConcurrentHashMap<String, S3Client> clientCache = new ConcurrentHashMap<>(
Region.regions().size());

private final Field credentialsProviderField;
private final Field clientConfigurationField;

public AmazonS3ClientFactory() {
this.credentialsProviderField = ReflectionUtils.findField(S3Client.class,
CREDENTIALS_PROVIDER_FIELD_NAME);
Assert.notNull(this.credentialsProviderField,
"Credentials Provider field not found, this class does not work with the current "
Class<?> s3ClientClass = ClassUtils.resolveClassName("software.amazon.awssdk.services.s3.DefaultS3Client", null);
this.clientConfigurationField = ReflectionUtils.findField(s3ClientClass,
CLIENT_CONFIGURATION_FIELD_NAME);
Assert.notNull(this.clientConfigurationField,
"Client Configuration field not found, this class does not work with the current "
+ "AWS SDK release");
ReflectionUtils.makeAccessible(this.credentialsProviderField);
ReflectionUtils.makeAccessible(this.clientConfigurationField);
}

private static S3Client getAmazonS3ClientFromProxy(S3Client amazonS3) {
Expand Down Expand Up @@ -85,14 +89,14 @@ public S3Client createClientForRegion(S3Client prototype, String region) {
return this.clientCache.get(region);
}

// TODO SDK2 migration: find a different solution to use the same credentials provider.
private S3ClientBuilder buildAmazonS3ForRegion(S3Client prototype, Region region) {
S3ClientBuilder clientBuilder = S3Client.builder();

S3Client target = getAmazonS3ClientFromProxy(prototype);
if (target != null) {
AwsCredentialsProvider awsCredentialsProvider = (AwsCredentialsProvider) ReflectionUtils
.getField(this.credentialsProviderField, target);
SdkClientConfiguration sdkClientConfiguration = (SdkClientConfiguration) ReflectionUtils
.getField(this.clientConfigurationField, target);
AwsCredentialsProvider awsCredentialsProvider = sdkClientConfiguration.option(AwsClientOption.CREDENTIALS_PROVIDER);
clientBuilder.credentialsProvider(awsCredentialsProvider);
}

Expand Down
Expand Up @@ -38,6 +38,7 @@
*
* @author Greg Turnquist
* @author Agim Emruli
* @author Kristine Jetzke
* @since 1.1
*/
public final class AmazonS3ProxyFactory {
Expand Down Expand Up @@ -135,6 +136,7 @@ public Object invoke(MethodInvocation invocation) throws Throwable {
private S3Client buildAmazonS3ForRedirectLocation(S3Client prototype,
S3Exception e) {
try {
// TODO SDK2 migration: add integration test
final String region = e.awsErrorDetails().sdkHttpResponse()
.firstMatchingHeader("x-amx-bucket-region")
.orElseThrow(() -> new RuntimeException(
Expand Down
Expand Up @@ -139,7 +139,7 @@ public String getFilename() throws IllegalStateException {
}

@Override
public URL getURL() throws IOException {
public URL getURL() {
return this.amazonS3.utilities().getUrl(GetUrlRequest.builder()
.bucket(this.bucketName).key(this.objectName).build());
}
Expand Down
Expand Up @@ -50,7 +50,7 @@ public void registerAmazonWebserviceClient_withMinimalConfiguration_returnsDefau

BeanDefinitionHolder beanDefinitionHolder = AmazonWebserviceClientConfigurationUtils
.registerAmazonWebserviceClient(new Object(), beanFactory,
AmazonTestWebserviceClient.class.getName(), null, null);
AmazonTestWebserviceClient.class.getName(), null);

// Act
beanFactory.preInstantiateSingletons();
Expand All @@ -77,8 +77,7 @@ public void registerAmazonWebserviceClient_withCustomRegionProviderConfiguration

BeanDefinitionHolder beanDefinitionHolder = AmazonWebserviceClientConfigurationUtils
.registerAmazonWebserviceClient(new Object(), beanFactory,
AmazonTestWebserviceClient.class.getName(), "myRegionProvider",
null);
AmazonTestWebserviceClient.class.getName(), "myRegionProvider");

// Act
beanFactory.preInstantiateSingletons();
Expand All @@ -90,55 +89,6 @@ public void registerAmazonWebserviceClient_withCustomRegionProviderConfiguration
assertThat(beanDefinitionHolder.getBeanName()).isEqualTo("amazonTestWebservice");
}

@Test
public void registerAmazonWebserviceClient_withCustomRegionConfiguration_returnsBeanDefinitionWithRegionConfigured()
throws Exception {
// Arrange
DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory();
beanFactory.registerSingleton(
AmazonWebserviceClientConfigurationUtils.CREDENTIALS_PROVIDER_BEAN_NAME,
new StaticAwsCredentialsProvider());

BeanDefinitionHolder beanDefinitionHolder = AmazonWebserviceClientConfigurationUtils
.registerAmazonWebserviceClient(new Object(), beanFactory,
AmazonTestWebserviceClient.class.getName(), null,
Region.EU_WEST_1.id());

// Act
beanFactory.preInstantiateSingletons();
AmazonTestWebserviceClient client = beanFactory.getBean(
beanDefinitionHolder.getBeanName(), AmazonTestWebserviceClient.class);

// Assert
assertThat(client).isNotNull();
assertThat(beanDefinitionHolder.getBeanName()).isEqualTo("amazonTestWebservice");
}

@Test
public void registerAmazonWebserviceClient_withCustomRegionAndRegionProviderConfigured_reportsError()
throws Exception {
// Arrange
this.expectedException.expect(IllegalArgumentException.class);
this.expectedException.expectMessage(
"Only region or regionProvider can be configured, but not both");

DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory();
beanFactory.registerSingleton(
AmazonWebserviceClientConfigurationUtils.CREDENTIALS_PROVIDER_BEAN_NAME,
new StaticAwsCredentialsProvider());

BeanDefinitionHolder beanDefinitionHolder = AmazonWebserviceClientConfigurationUtils
.registerAmazonWebserviceClient(new Object(), beanFactory,
AmazonTestWebserviceClient.class.getName(), "someProvider",
Region.EU_WEST_1.id());

// Act
beanFactory.getBean(beanDefinitionHolder.getBeanName(),
AmazonTestWebserviceClient.class);

// Assert
}

@Test
public void generateBeanName_withInterfaceAndCapitalLetterInSequence_producesDeCapitalizedBeanName()
throws Exception {
Expand Down
Expand Up @@ -79,7 +79,9 @@ public void testCreateWithMultiple() throws Exception {
AwsBasicCredentials foo = AwsBasicCredentials.create("foo", "foo");
AwsBasicCredentials bar = AwsBasicCredentials.create("bar", "bar");

when(first.resolveCredentials()).thenReturn(null, foo);
when(first.resolveCredentials())
.thenThrow(new RuntimeException("first call fails"))
.thenReturn(foo);
when(second.resolveCredentials()).thenReturn(bar);

assertThat(provider.resolveCredentials()).isEqualTo(bar);
Expand Down
Expand Up @@ -54,7 +54,7 @@ public void getObject_userTagDataAvailable_objectContainsAllAvailableKeys()
TagDescription.builder().key("keyA").resourceType(ResourceType.INSTANCE)
.value("valueA").build(),
TagDescription.builder().key("keyB").resourceType(ResourceType.INSTANCE)
.value("keyB").build())
.value("valueB").build())
.build();

when(amazonEC2.describeTags(describeTagsRequest)).thenReturn(describeTagsResult);
Expand Down
Expand Up @@ -95,10 +95,8 @@ private static CloudFormationClient makeAmazonCloudFormationClient(
.entrySet()) {
String stackName = entry.getKey();

ListStackResourcesResponse listStackResourcesResult = mock(
ListStackResourcesResponse.class);
when(listStackResourcesResult.stackResourceSummaries())
.thenReturn(entry.getValue());
ListStackResourcesResponse listStackResourcesResult = ListStackResourcesResponse.builder()
.stackResourceSummaries(entry.getValue()).build();

when(amazonCloudFormationClient.listStackResources(
ArgumentMatchers.<ListStackResourcesRequest>argThat(
Expand All @@ -111,12 +109,11 @@ private static CloudFormationClient makeAmazonCloudFormationClient(

private static StackResourceSummary makeStackResourceSummary(String logicalResourceId,
String physicalResourceId) {
StackResourceSummary stackResourceSummary = mock(StackResourceSummary.class);
when(stackResourceSummary.logicalResourceId()).thenReturn(logicalResourceId);
when(stackResourceSummary.physicalResourceId()).thenReturn(physicalResourceId);
when(stackResourceSummary.resourceType())
.thenReturn(logicalResourceId.endsWith("Stack")
? "AWS::CloudFormation::Stack" : "Amazon::SES::Test");
StackResourceSummary stackResourceSummary = StackResourceSummary.builder()
.logicalResourceId(logicalResourceId)
.physicalResourceId(physicalResourceId)
.resourceType(logicalResourceId.endsWith("Stack") ? "AWS::CloudFormation::Stack" : "Amazon::SES::Test")
.build();
return stackResourceSummary;
}

Expand Down
Expand Up @@ -19,15 +19,18 @@
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import software.amazon.awssdk.awscore.client.config.AwsClientOption;
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;

import org.springframework.test.util.ReflectionTestUtils;

import static org.assertj.core.api.Assertions.assertThat;

/**
* @author Agim Emruli
*/
// TODO SDK2 migration: test for missing header
public class AmazonS3ClientFactoryTest {

@Rule
Expand All @@ -40,7 +43,7 @@ public void createClientForEndpointUrl_withNullEndpoint_throwsIllegalArgumentExc
S3Client amazonS3 = S3Client.builder().region(Region.US_WEST_2).build();

this.expectedException.expect(IllegalArgumentException.class);
this.expectedException.expectMessage("Endpoint Url must not be null");
this.expectedException.expectMessage("Region must not be null");

// Act
amazonS3ClientFactory.createClientForRegion(amazonS3, null);
Expand Down Expand Up @@ -75,8 +78,8 @@ public void createClientForEndpointUrl_withRegion_createClientForRegion() {
"us-west-1");

// Prepare
// TODO SDK2 migration: update and uncomment
// assertThat(newClient.getRegionName()).isEqualTo(Region.US_WEST_1);
SdkClientConfiguration clientConfiguration = (SdkClientConfiguration) ReflectionTestUtils.getField(newClient, "clientConfiguration");
assertThat(clientConfiguration.option(AwsClientOption.AWS_REGION)).isEqualTo(Region.US_WEST_1);
}

@Test
Expand All @@ -90,8 +93,8 @@ public void createClientForEndpointUrl_withProxiedClient_createClientForCustomRe
"eu-central-1");

// Prepare
// TODO SDK2 migration: update and uncomment
// assertThat(newClient.getRegionName()).isEqualTo(Regions.EU_CENTRAL_1.getName());
SdkClientConfiguration clientConfiguration = (SdkClientConfiguration) ReflectionTestUtils.getField(newClient, "clientConfiguration");
assertThat(clientConfiguration.option(AwsClientOption.AWS_REGION)).isEqualTo(Region.EU_CENTRAL_1);
}

@Test
Expand Down
Expand Up @@ -29,6 +29,8 @@
import org.springframework.aop.framework.Advised;
import org.springframework.aop.framework.ProxyFactory;
import org.springframework.aop.support.AopUtils;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.util.ClassUtils;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -72,12 +74,13 @@ public void verifyDoubleWrappingHandled() throws Exception {

@Test
public void verifyPolymorphicHandling() {
Class<?> defaultS3ClientClass = ClassUtils.resolveClassName("software.amazon.awssdk.services.s3.DefaultS3Client", null);

S3Client amazonS3 = mock(S3Client.class);
S3Client proxy1 = AmazonS3ProxyFactory.createProxy(amazonS3);

assertThat(S3Client.class.isAssignableFrom(proxy1.getClass())).isTrue();
assertThat(S3Client.class.isAssignableFrom(proxy1.getClass())).isFalse();
assertThat(defaultS3ClientClass.isAssignableFrom(proxy1.getClass())).isFalse();

S3Client amazonS3Client = S3Client.builder().region(Region.US_WEST_2).build();
S3Client proxy2 = AmazonS3ProxyFactory.createProxy(amazonS3Client);
Expand Down

0 comments on commit fd02381

Please sign in to comment.