Skip to content
This repository has been archived by the owner on Feb 23, 2023. It is now read-only.

Commit

Permalink
Retain collection implementation types
Browse files Browse the repository at this point in the history
This commit makes sure that custom collection implementation types are
retained if possible.

Closes gh-1483
  • Loading branch information
snicoll committed Feb 8, 2022
1 parent e184df9 commit 9dd2342
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 20 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2019-2021 the original author or authors.
* Copyright 2019-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,17 +17,21 @@
package org.springframework.aot.context.bootstrap.generator.bean.support;

import java.lang.reflect.Executable;
import java.lang.reflect.Modifier;
import java.lang.reflect.Parameter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import com.squareup.javapoet.CodeBlock;
import com.squareup.javapoet.CodeBlock.Builder;
Expand Down Expand Up @@ -110,10 +114,7 @@ else if (value instanceof List) {
code.add("$T.emptyList()", Collections.class);
}
else {
code.add("$T.of(", List.class);
ResolvableType collectionType = parameterType.as(List.class).getGenerics()[0];
code.add(writeAll(list, (item) -> collectionType));
code.add(")");
writeCollection(code, list, parameterType, ArrayList.class);
}
}
else if (value instanceof Set) {
Expand All @@ -122,10 +123,7 @@ else if (value instanceof Set) {
code.add("$T.emptySet()", Collections.class);
}
else {
code.add("$T.of(", Set.class);
ResolvableType collectionType = parameterType.as(Set.class).getGenerics()[0];
code.add(writeAll(set, (item) -> collectionType));
code.add(")");
writeCollection(code, set, parameterType, LinkedHashSet.class);
}
}
else if (value instanceof Map) {
Expand Down Expand Up @@ -172,6 +170,27 @@ else if (value instanceof BeanReference) {
}
}

private <T> void writeCollection(Builder code, Object value, ResolvableType parameterType, Class<?> fallbackCollectionType) {
code.add("$T.of(", Stream.class);
ResolvableType collectionType = parameterType.as(Collection.class).getGenerics()[0];
code.add(writeAll((Collection<?>) value, (item) -> collectionType));
code.add(").collect($T.toCollection($T::new))", Collectors.class,
inferCollectionType(value.getClass(), fallbackCollectionType));
}

private Class<?> inferCollectionType(Class<?> valueType, Class<?> fallbackCollectionType) {
if (Modifier.isPublic(valueType.getModifiers())) {
try {
valueType.getConstructor();
return valueType;
}
catch (NoSuchMethodException ex) {
// Ignore
}
}
return fallbackCollectionType;
}

private <T> CodeBlock writeAll(Iterable<T> items, Function<T, ResolvableType> elementType) {
MultiCodeBlock multi = new MultiCodeBlock();
items.forEach((item) -> multi.add((code) ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -400,9 +400,9 @@ void writePropertyAsListOfBeanDefinitions() {
CodeSnippet generatedCode = beanRegistration(beanDefinition, (code) -> code.add("() -> InjectionConfiguration::new"));
assertThat(generatedCode).lines().containsOnly(
"BeanDefinitionRegistrar.of(\"test\", InjectionConfiguration.class)",
" .instanceSupplier(() -> InjectionConfiguration::new).customize((bd) -> bd.getPropertyValues().addPropertyValue(\"names\", List.of(BeanDefinitionRegistrar.inner(SimpleConfiguration.class).withFactoryMethod(SimpleConfiguration.class, \"stringBean\")",
" .instanceSupplier(() -> InjectionConfiguration::new).customize((bd) -> bd.getPropertyValues().addPropertyValue(\"names\", Stream.of(BeanDefinitionRegistrar.inner(SimpleConfiguration.class).withFactoryMethod(SimpleConfiguration.class, \"stringBean\")",
" .instanceSupplier(() -> beanFactory.getBean(SimpleConfiguration.class).stringBean()).toBeanDefinition(), BeanDefinitionRegistrar.inner(SimpleConfiguration.class).withFactoryMethod(SimpleConfiguration.class, \"stringBean\")",
" .instanceSupplier(() -> beanFactory.getBean(SimpleConfiguration.class).stringBean()).toBeanDefinition()))).register(beanFactory);");
" .instanceSupplier(() -> beanFactory.getBean(SimpleConfiguration.class).stringBean()).toBeanDefinition()).collect(Collectors.toCollection(ArrayList::new)))).register(beanFactory);");
assertThat(generatedCode).hasImport(SimpleConfiguration.class);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2019-2021 the original author or authors.
* Copyright 2019-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,6 +20,7 @@
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
Expand All @@ -28,7 +29,9 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.junit.jupiter.api.Test;
Expand All @@ -40,6 +43,7 @@
import org.springframework.aot.context.bootstrap.generator.test.CodeSnippet;
import org.springframework.beans.factory.config.BeanReference;
import org.springframework.beans.factory.config.RuntimeBeanReference;
import org.springframework.beans.factory.support.ManagedList;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.core.ResolvableType;
import org.springframework.core.io.ResourceLoader;
Expand Down Expand Up @@ -75,10 +79,21 @@ void writeStringArray() {
}

@Test
void writeStringList() {
void writeStringNonPublicList() {
List<String> value = List.of("a", "test");
assertThat(write(value, ResolvableType.forClassWithGenerics(List.class, String.class)))
.isEqualTo("List.of(\"a\", \"test\")").hasImport(List.class);
.isEqualTo("Stream.of(\"a\", \"test\").collect(Collectors.toCollection(ArrayList::new))")
.hasImport(Stream.class).hasImport(Collectors.class).hasImport(ArrayList.class);
}

@Test
void writeStringCustomList() {
ManagedList<String> value = new ManagedList<>();
value.add("a");
value.add("test");
assertThat(write(value, ResolvableType.forClassWithGenerics(List.class, String.class)))
.isEqualTo("Stream.of(\"a\", \"test\").collect(Collectors.toCollection(ManagedList::new))")
.hasImport(Stream.class).hasImport(Collectors.class).hasImport(ManagedList.class);
}

@Test
Expand All @@ -89,10 +104,19 @@ void writeEmptyList() {
}

@Test
void writeStringSet() {
Set<String> value = new LinkedHashSet<>(Arrays.asList("a", "test"));
void writeStringNonPublicSet() {
Set<String> value = Set.of("a", "test");
assertThat(write(value, ResolvableType.forClassWithGenerics(Set.class, String.class)))
.endsWith(").collect(Collectors.toCollection(LinkedHashSet::new))")
.hasImport(Stream.class).hasImport(Collectors.class).hasImport(LinkedHashSet.class);
}

@Test
void writeStringCustomSet() {
Set<String> value = new TreeSet<>(Arrays.asList("a", "test"));
assertThat(write(value, ResolvableType.forClassWithGenerics(Set.class, String.class)))
.isEqualTo("Set.of(\"a\", \"test\")").hasImport(Set.class);
.isEqualTo("Stream.of(\"a\", \"test\").collect(Collectors.toCollection(TreeSet::new))")
.hasImport(Stream.class).hasImport(Collectors.class).hasImport(TreeSet.class);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2019-2021 the original author or authors.
* Copyright 2019-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -151,9 +151,9 @@ void writeInfrastructureWithLifecycleMethodsRegisterCreateMethod() {
" private InitDestroyBeanPostProcessor createInitDestroyBeanPostProcessor(",
" ConfigurableBeanFactory beanFactory) {",
" Map<String, List<String>> initMethods = new LinkedHashMap<>();",
" initMethods.put(\"testBean\", List.of(\"start\"));",
" initMethods.put(\"testBean\", Stream.of(\"start\").collect(Collectors.toCollection(ArrayList::new)));",
" Map<String, List<String>> destroyMethods = new LinkedHashMap<>();",
" destroyMethods.put(\"testBean\", List.of(\"stop\"));",
" destroyMethods.put(\"testBean\", Stream.of(\"stop\").collect(Collectors.toCollection(ArrayList::new)));",
" return new InitDestroyBeanPostProcessor(beanFactory, initMethods, destroyMethods);",
" }").contains("import " + InitDestroyBeanPostProcessor.class.getName() + ";");
}
Expand Down

0 comments on commit 9dd2342

Please sign in to comment.