Skip to content

Commit

Permalink
INT-3539: Fix ChannelSecurityInterceptorBPP
Browse files Browse the repository at this point in the history
JIRA: https://jira.spring.io/browse/INT-3539

Previously the `ChannelSecurityInterceptorBeanPostProcessor` eagerly loaded all beans from its `afterPropertiesSet`
to retrieve the `ChannelSecurityInterceptor`s. According to the `BeanPostProcessor` nature the `afterPropertiesSet` hook isn't legitimate.
It may cause some bad side-effects for other beans, which might not been initialized yet.

Rework `ChannelSecurityInterceptorBeanPostProcessor` to accept `Collection<ChannelSecurityInterceptor>` in the ctor.
Rework `SecurityIntegrationConfigurationInitializer` to iterate over `BeanDefinition`s to determine those of them, which
are `ChannelSecurityInterceptor` or `ChannelSecurityInterceptorFactoryBean`.

**Cherry-pick to 4.0.x**
  • Loading branch information
artembilan committed Oct 29, 2014
1 parent 10b8468 commit 542fc4f
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 59 deletions.
Expand Up @@ -25,15 +25,10 @@
import org.springframework.aop.support.AopUtils;
import org.springframework.aop.support.DefaultPointcutAdvisor;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.ListableBeanFactory;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.integration.security.channel.ChannelSecurityInterceptor;
import org.springframework.integration.security.channel.ChannelSecurityMetadataSource;
import org.springframework.messaging.MessageChannel;
import org.springframework.util.Assert;

/**
* A {@link BeanPostProcessor} that proxies {@link MessageChannel}s to apply a {@link ChannelSecurityInterceptor}.
Expand All @@ -42,21 +37,12 @@
* @author Oleg Zhurakousky
* @author Artem Bilan
*/
public class ChannelSecurityInterceptorBeanPostProcessor implements BeanPostProcessor, BeanFactoryAware, InitializingBean {
public class ChannelSecurityInterceptorBeanPostProcessor implements BeanPostProcessor {

private volatile Collection<ChannelSecurityInterceptor> securityInterceptors;
private final Collection<ChannelSecurityInterceptor> securityInterceptors;

private ListableBeanFactory beanFactory;

@Override
public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
Assert.isInstanceOf(ListableBeanFactory.class, beanFactory);
this.beanFactory = (ListableBeanFactory) beanFactory;
}

@Override
public void afterPropertiesSet() throws Exception {
this.securityInterceptors = this.beanFactory.getBeansOfType(ChannelSecurityInterceptor.class).values();
public ChannelSecurityInterceptorBeanPostProcessor(Collection<ChannelSecurityInterceptor> securityInterceptors) {
this.securityInterceptors = securityInterceptors;
}

public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
Expand All @@ -65,7 +51,7 @@ public Object postProcessBeforeInitialization(Object bean, String beanName) thro

public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
if (bean instanceof MessageChannel) {
for (ChannelSecurityInterceptor securityInterceptor : securityInterceptors) {
for (ChannelSecurityInterceptor securityInterceptor : this.securityInterceptors) {
ChannelSecurityMetadataSource channelSecurityMetadataSource =
(ChannelSecurityMetadataSource) securityInterceptor.obtainSecurityMetadataSource();
if (this.shouldProxy(beanName, channelSecurityMetadataSource)) {
Expand All @@ -92,4 +78,5 @@ private boolean shouldProxy(String beanName, ChannelSecurityMetadataSource chann
}
return false;
}

}
Expand Up @@ -16,11 +16,23 @@

package org.springframework.integration.security.config;

import java.lang.reflect.Method;
import java.util.List;

import org.springframework.beans.BeansException;
import org.springframework.beans.factory.CannotLoadBeanClassException;
import org.springframework.beans.factory.annotation.AnnotatedBeanDefinition;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.beans.factory.support.ManagedList;
import org.springframework.core.type.MethodMetadata;
import org.springframework.core.type.StandardMethodMetadata;
import org.springframework.integration.config.IntegrationConfigurationInitializer;
import org.springframework.integration.security.channel.ChannelSecurityInterceptor;
import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils;

/**
* The Integration Security infrastructure {@code beanFactory} initializer.
Expand All @@ -30,13 +42,49 @@
*/
public class SecurityIntegrationConfigurationInitializer implements IntegrationConfigurationInitializer {

private static final String CHANNEL_SECURITY_INTERCEPTOR_BPP_BEAN_NAME = ChannelSecurityInterceptorBeanPostProcessor.class.getName();
private static final String CHANNEL_SECURITY_INTERCEPTOR_BPP_BEAN_NAME =
ChannelSecurityInterceptorBeanPostProcessor.class.getName();

@Override
public void initialize(ConfigurableListableBeanFactory beanFactory) throws BeansException {
BeanDefinitionRegistry registry = (BeanDefinitionRegistry) beanFactory;
if (!registry.containsBeanDefinition(CHANNEL_SECURITY_INTERCEPTOR_BPP_BEAN_NAME)) {
registry.registerBeanDefinition(CHANNEL_SECURITY_INTERCEPTOR_BPP_BEAN_NAME, new RootBeanDefinition(ChannelSecurityInterceptorBeanPostProcessor.class));

List<BeanDefinition> securityInterceptors = new ManagedList<BeanDefinition>();

for (String beanName : registry.getBeanDefinitionNames()) {
BeanDefinition beanDefinition = registry.getBeanDefinition(beanName);
String beanClassName = beanDefinition.getBeanClassName();
Class<?> clazz = null;
if (StringUtils.hasText(beanClassName)) {
try {
clazz = ClassUtils.forName(beanClassName, beanFactory.getBeanClassLoader());
}
catch (ClassNotFoundException e) {
throw new CannotLoadBeanClassException(this.toString(), beanName, beanClassName, e);
}
}
else if (beanDefinition instanceof AnnotatedBeanDefinition
&& beanDefinition.getSource() instanceof MethodMetadata) {
MethodMetadata beanMethod = (MethodMetadata) beanDefinition.getSource();
if (beanMethod instanceof StandardMethodMetadata) {
Method method = ((StandardMethodMetadata) beanMethod).getIntrospectedMethod();
clazz = method.getReturnType();
}
}

if (clazz != null &&
(ChannelSecurityInterceptor.class.isAssignableFrom(clazz)
|| ChannelSecurityInterceptorFactoryBean.class.isAssignableFrom(clazz))) {
securityInterceptors.add(beanDefinition);
}
}

if (!securityInterceptors.isEmpty()) {
BeanDefinition securityPostProcessorBd =
BeanDefinitionBuilder.rootBeanDefinition(ChannelSecurityInterceptorBeanPostProcessor.class)
.addConstructorArgValue(securityInterceptors)
.getBeanDefinition();
registry.registerBeanDefinition(CHANNEL_SECURITY_INTERCEPTOR_BPP_BEAN_NAME, securityPostProcessorBd);
}
}

Expand Down
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2010 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,74 +19,55 @@
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

import java.util.Collections;
import java.util.Map;
import java.util.Arrays;
import java.util.regex.Pattern;

import org.junit.Test;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import org.springframework.aop.support.AopUtils;
import org.springframework.beans.factory.ListableBeanFactory;
import org.springframework.integration.channel.QueueChannel;
import org.springframework.integration.security.config.ChannelSecurityInterceptorBeanPostProcessor;
import org.springframework.messaging.MessageChannel;

/**
* @author Mark Fisher
* @author Artem Bilan
*/
public class ChannelSecurityInterceptorBeanPostProcessorTests {

@Test
public void securedChannelIsProxied() throws Exception {
ChannelSecurityMetadataSource securityMetadataSource = new ChannelSecurityMetadataSource();
securityMetadataSource.addPatternMapping(Pattern.compile("secured.*"), new DefaultChannelAccessPolicy("ROLE_ADMIN", null));
securityMetadataSource.addPatternMapping(Pattern.compile("secured.*"),
new DefaultChannelAccessPolicy("ROLE_ADMIN", null));

final ChannelSecurityInterceptor interceptor = new ChannelSecurityInterceptor(securityMetadataSource);
ChannelSecurityInterceptor interceptor = new ChannelSecurityInterceptor(securityMetadataSource);

ListableBeanFactory beanFactory = Mockito.mock(ListableBeanFactory.class);
Mockito.doAnswer(new Answer<Map<String, ChannelSecurityInterceptor>>() {

@Override
public Map<String, ChannelSecurityInterceptor> answer(InvocationOnMock invocation) throws Throwable {
return Collections.singletonMap("interceptor", interceptor);
}
}).when(beanFactory).getBeansOfType(ChannelSecurityInterceptor.class);

ChannelSecurityInterceptorBeanPostProcessor postProcessor = new ChannelSecurityInterceptorBeanPostProcessor();
postProcessor.setBeanFactory(beanFactory);
postProcessor.afterPropertiesSet();
ChannelSecurityInterceptorBeanPostProcessor postProcessor =
new ChannelSecurityInterceptorBeanPostProcessor(Arrays.asList(interceptor));

QueueChannel securedChannel = new QueueChannel();
securedChannel.setBeanName("securedChannel");
MessageChannel postProcessedChannel = (MessageChannel) postProcessor.postProcessAfterInitialization(securedChannel, "securedChannel");
MessageChannel postProcessedChannel =
(MessageChannel) postProcessor.postProcessAfterInitialization(securedChannel, "securedChannel");
assertTrue(AopUtils.isAopProxy(postProcessedChannel));
}

@Test
public void nonsecuredChannelIsNotProxied() throws Exception {
ChannelSecurityMetadataSource securityMetadataSource = new ChannelSecurityMetadataSource();
securityMetadataSource.addPatternMapping(Pattern.compile("secured.*"), new DefaultChannelAccessPolicy("ROLE_ADMIN", null));
final ChannelSecurityInterceptor interceptor = new ChannelSecurityInterceptor(securityMetadataSource);

ListableBeanFactory beanFactory = Mockito.mock(ListableBeanFactory.class);
Mockito.doAnswer(new Answer<Map<String, ChannelSecurityInterceptor>>() {
securityMetadataSource.addPatternMapping(Pattern.compile("secured.*"),
new DefaultChannelAccessPolicy("ROLE_ADMIN", null));

@Override
public Map<String, ChannelSecurityInterceptor> answer(InvocationOnMock invocation) throws Throwable {
return Collections.singletonMap("interceptor", interceptor);
}
}).when(beanFactory).getBeansOfType(ChannelSecurityInterceptor.class);
ChannelSecurityInterceptor interceptor = new ChannelSecurityInterceptor(securityMetadataSource);

ChannelSecurityInterceptorBeanPostProcessor postProcessor = new ChannelSecurityInterceptorBeanPostProcessor();
postProcessor.setBeanFactory(beanFactory);
postProcessor.afterPropertiesSet();
ChannelSecurityInterceptorBeanPostProcessor postProcessor =
new ChannelSecurityInterceptorBeanPostProcessor(Arrays.asList(interceptor));

QueueChannel channel = new QueueChannel();
channel.setBeanName("testChannel");
MessageChannel postProcessedChannel = (MessageChannel) postProcessor.postProcessAfterInitialization(channel, "testChannel");
MessageChannel postProcessedChannel =
(MessageChannel) postProcessor.postProcessAfterInitialization(channel, "testChannel");
assertFalse(AopUtils.isAopProxy(postProcessedChannel));
}

Expand Down

0 comments on commit 542fc4f

Please sign in to comment.