Skip to content

Commit

Permalink
Support Gradle-style Kotlin bean API
Browse files Browse the repository at this point in the history
val context = GenericApplicationContext {
    registerBean<Foo>()
    registerBean { Bar(it.getBean<Foo>()) }
}

Issue: SPR-15126
  • Loading branch information
sdeleuze committed Jan 10, 2017
1 parent 1af905c commit f8461d8
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,33 @@ object BeanFactoryExtension {
*/
fun <T : Any> BeanFactory.getBean(requiredType: KClass<T>) = getBean(requiredType.java)

/**
* @see BeanFactory.getBean(Class<T>)
*/
inline fun <reified T : Any> BeanFactory.getBean() = getBean(T::class.java)

/**
* @see BeanFactory.getBean(String, Class<T>)
*/
fun <T : Any> BeanFactory.getBean(name: String, requiredType: KClass<T>) =
getBean(name, requiredType.java)

/**
* @see BeanFactory.getBean(String, Class<T>)
*/
inline fun <reified T : Any> BeanFactory.getBean(name: String) =
getBean(name, T::class.java)

/**
* @see BeanFactory.getBean(Class<T>, Object...)
*/
fun <T : Any> BeanFactory.getBean(requiredType: KClass<T>, vararg args:Any) =
getBean(requiredType.java, *args)

/**
* @see BeanFactory.getBean(Class<T>, Object...)
*/
inline fun <reified T : Any> BeanFactory.getBean(vararg args:Any) =
getBean(T::class.java, *args)

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,42 +16,84 @@ object ListableBeanFactoryExtension {
fun <T : Any> ListableBeanFactory.getBeanNamesForType(type: KClass<T>) =
getBeanNamesForType(type.java)

/**
* @see ListableBeanFactory.getBeanNamesForType(Class<?>)
*/
inline fun <reified T : Any> ListableBeanFactory.getBeanNamesForType() =
getBeanNamesForType(T::class.java)

/**
* @see ListableBeanFactory.getBeanNamesForType(Class<?>, boolean, boolean)
*/
fun <T : Any> ListableBeanFactory.getBeanNamesForType(type: KClass<T>,
includeNonSingletons: Boolean, allowEagerInit: Boolean) =
getBeanNamesForType(type.java, includeNonSingletons, allowEagerInit)

/**
* @see ListableBeanFactory.getBeanNamesForType(Class<?>, boolean, boolean)
*/
inline fun <reified T : Any> ListableBeanFactory.getBeanNamesForType(includeNonSingletons: Boolean, allowEagerInit: Boolean) =
getBeanNamesForType(T::class.java, includeNonSingletons, allowEagerInit)

/**
* @see ListableBeanFactory.getBeansOfType(Class<T>)
*/
fun <T : Any> ListableBeanFactory.getBeansOfType(type: KClass<T>) =
getBeansOfType(type.java)

/**
* @see ListableBeanFactory.getBeansOfType(Class<T>)
*/
inline fun <reified T : Any> ListableBeanFactory.getBeansOfType() =
getBeansOfType(T::class.java)

/**
* @see ListableBeanFactory.getBeansOfType(Class<T>, boolean, boolean)
*/
fun <T : Any> ListableBeanFactory.getBeansOfType(type: KClass<T>,
includeNonSingletons: Boolean, allowEagerInit: Boolean) =
getBeansOfType(type.java, includeNonSingletons, allowEagerInit)

/**
* @see ListableBeanFactory.getBeansOfType(Class<T>, boolean, boolean)
*/
inline fun <reified T : Any> ListableBeanFactory.getBeansOfType(includeNonSingletons: Boolean, allowEagerInit: Boolean) =
getBeansOfType(T::class.java, includeNonSingletons, allowEagerInit)

/**
* @see ListableBeanFactory.getBeanNamesForAnnotation
*/
fun <T : Annotation> ListableBeanFactory.getBeanNamesForAnnotation(type: KClass<T>) =
getBeanNamesForAnnotation(type.java)

/**
* @see ListableBeanFactory.getBeanNamesForAnnotation
*/
inline fun <reified T : Annotation> ListableBeanFactory.getBeanNamesForAnnotation() =
getBeanNamesForAnnotation(T::class.java)

/**
* @see ListableBeanFactory.getBeansWithAnnotation
*/
fun <T : Annotation> ListableBeanFactory.getBeansWithAnnotation(type: KClass<T>) =
getBeansWithAnnotation(type.java)

/**
* @see ListableBeanFactory.getBeansWithAnnotation
*/
inline fun <reified T : Annotation> ListableBeanFactory.getBeansWithAnnotation() =
getBeansWithAnnotation(T::class.java)

/**
* @see ListableBeanFactoryExtension.findAnnotationOnBean
*/
fun <T : Annotation> ListableBeanFactory.findAnnotationOnBean(beanName:String, type: KClass<T>) =
findAnnotationOnBean(beanName, type.java)

/**
* @see ListableBeanFactoryExtension.findAnnotationOnBean
*/
inline fun <reified T : Annotation> ListableBeanFactory.findAnnotationOnBean(beanName:String) =
findAnnotationOnBean(beanName, T::class.java)

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ object GenericApplicationContextExtension {
registerBean(beanClass.java, *customizers)
}

/**
* @see GenericApplicationContext.registerBean(Class<T>, BeanDefinitionCustomizer...)
*/
inline fun <reified T : Any> GenericApplicationContext.registerBean(vararg customizers: BeanDefinitionCustomizer) {
registerBean(T::class.java, *customizers)
}

/**
* @see GenericApplicationContext.registerBean(String, Class<T>, BeanDefinitionCustomizer...)
*/
Expand All @@ -31,6 +38,13 @@ object GenericApplicationContextExtension {
registerBean(beanName, beanClass.java, *customizers)
}

/**
* @see GenericApplicationContext.registerBean(String, Class<T>, BeanDefinitionCustomizer...)
*/
inline fun <reified T : Any> GenericApplicationContext.registerBean(beanName: String, vararg customizers: BeanDefinitionCustomizer) {
registerBean(beanName, T::class.java, *customizers)
}

/**
* @see GenericApplicationContext.registerBean(Class<T>, Supplier<T>, BeanDefinitionCustomizer...)
*/
Expand All @@ -46,4 +60,6 @@ object GenericApplicationContextExtension {
vararg customizers: BeanDefinitionCustomizer, crossinline function: (ApplicationContext) -> T) {
registerBean(name, T::class.java, Supplier { function.invoke(this) }, *customizers)
}

fun GenericApplicationContext(configure: GenericApplicationContext.()->Unit) = GenericApplicationContext().apply(configure)
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import org.junit.Assert.assertNotNull
import org.junit.Test
import org.springframework.context.support.GenericApplicationContextExtension.registerBean
import org.springframework.beans.factory.BeanFactoryExtension.getBean
import org.springframework.context.support.GenericApplicationContextExtension.GenericApplicationContext

class GenericApplicationContextExtensionTests {

Expand Down Expand Up @@ -59,6 +60,17 @@ class GenericApplicationContextExtensionTests {
assertNotNull(context.getBean("b"))
}

@Test
fun registerBeanWithGradleStyleApi() {
val context = GenericApplicationContext {
registerBean<BeanA>()
registerBean { BeanB(it.getBean<BeanA>()) }
}
context.refresh()
assertNotNull(context.getBean<BeanA>())
assertNotNull(context.getBean<BeanB>())
}

internal class BeanA

internal class BeanB(val a: BeanA)
Expand Down

0 comments on commit f8461d8

Please sign in to comment.