Skip to content

Commit

Permalink
Fix TestRestTemplate request factory management
Browse files Browse the repository at this point in the history
This commit fixes two issues in `TestRestTemplate`:

* it improves the detection of the underlying request factory, using
reflection to look inside the intercepting request factory if
interceptors were configured

* it avoids reusing the same request factory when creating a new
`TestRestTemplate` with `withBasicAuth`. Sharing the same instance would
result in sharing authentication state (HTTP cookies). Since the
original request factory can't be detected consistently, a new one is
selected automatically

See gh-8697
  • Loading branch information
bclozel committed Feb 16, 2018
1 parent 51de220 commit db7268b
Showing 1 changed file with 22 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,17 @@ public TestRestTemplate(RestTemplateBuilder restTemplateBuilder, String username
httpClientOptions);
}

private static RestTemplate buildRestTemplate(
RestTemplateBuilder restTemplateBuilder) {
Assert.notNull(restTemplateBuilder, "RestTemplateBuilder must not be null");
return restTemplateBuilder.build();
}

private TestRestTemplate(RestTemplate restTemplate, String username, String password,
HttpClientOption... httpClientOptions) {
Assert.notNull(restTemplate, "RestTemplate must not be null");
this.httpClientOptions = httpClientOptions;
if (restTemplate.getRequestFactory().getClass().getName()
if (getRequestFactoryClass(restTemplate).getName()
.equals("org.springframework.http.client.HttpComponentsClientHttpRequestFactory")) {
restTemplate.setRequestFactory(
new CustomHttpComponentsClientHttpRequestFactory(httpClientOptions));
Expand All @@ -148,10 +154,16 @@ private TestRestTemplate(RestTemplate restTemplate, String username, String pass
this.restTemplate = restTemplate;
}

private static RestTemplate buildRestTemplate(
RestTemplateBuilder restTemplateBuilder) {
Assert.notNull(restTemplateBuilder, "RestTemplateBuilder must not be null");
return restTemplateBuilder.build();
private Class<? extends ClientHttpRequestFactory> getRequestFactoryClass(RestTemplate restTemplate) {
ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory();
if (InterceptingClientHttpRequestFactory.class.isAssignableFrom(requestFactory.getClass())) {
Field requestFactoryField = ReflectionUtils
.findField(RestTemplate.class, "requestFactory");
ReflectionUtils.makeAccessible(requestFactoryField);
requestFactory = (ClientHttpRequestFactory)
ReflectionUtils.getField(requestFactoryField, restTemplate);
}
return requestFactory.getClass();
}

private void addAuthentication(RestTemplate restTemplate, String username,
Expand Down Expand Up @@ -1022,30 +1034,18 @@ public RestTemplate getRestTemplate() {
* @since 1.4.1
*/
public TestRestTemplate withBasicAuth(String username, String password) {
RestTemplate restTemplate = new RestTemplate();
restTemplate.setMessageConverters(getRestTemplate().getMessageConverters());
restTemplate.setInterceptors(getRestTemplate().getInterceptors());
restTemplate.setRequestFactory(getRequestFactory(getRestTemplate()));
restTemplate.setUriTemplateHandler(getRestTemplate().getUriTemplateHandler());
RestTemplate restTemplate = new RestTemplateBuilder()
.messageConverters(getRestTemplate().getMessageConverters())
.interceptors(getRestTemplate().getInterceptors())
.uriTemplateHandler(getRestTemplate().getUriTemplateHandler())
.build();
TestRestTemplate testRestTemplate = new TestRestTemplate(restTemplate, username,
password, this.httpClientOptions);
testRestTemplate.getRestTemplate()
.setErrorHandler(getRestTemplate().getErrorHandler());
return testRestTemplate;
}

private ClientHttpRequestFactory getRequestFactory(RestTemplate restTemplate) {
ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory();
if (InterceptingClientHttpRequestFactory.class.isAssignableFrom(requestFactory.getClass())) {
Field requestFactoryField = ReflectionUtils
.findField(RestTemplate.class, "requestFactory");
ReflectionUtils.makeAccessible(requestFactoryField);
requestFactory = (ClientHttpRequestFactory)
ReflectionUtils.getField(requestFactoryField, getRestTemplate());
}
return requestFactory;
}

@SuppressWarnings({ "rawtypes", "unchecked" })
private RequestEntity<?> createRequestEntityWithRootAppliedUri(
RequestEntity<?> requestEntity) {
Expand Down

0 comments on commit db7268b

Please sign in to comment.